1{
2 lib,
3 stdenv,
4 buildPythonPackage,
5 pythonOlder,
6 fetchFromGitHub,
7
8 # build-system
9 poetry-core,
10
11 # dependencies
12 deepdiff,
13 gymnasium,
14 h5py,
15 matplotlib,
16 numba,
17 numpy,
18 overrides,
19 packaging,
20 pandas,
21 pettingzoo,
22 sensai-utils,
23 tensorboard,
24 torch,
25 tqdm,
26
27 # optional-dependencies
28 docstring-parser,
29 jsonargparse,
30 ale-py,
31 opencv,
32 shimmy,
33 pybox2d,
34 pygame,
35 swig,
36 mujoco,
37 imageio,
38 cython,
39 pybullet,
40 joblib,
41 scipy,
42
43 # tests
44 pymunk,
45 pytestCheckHook,
46}:
47
buildPythonPackage rec {
  pname = "tianshou";
  version = "1.2.0";
  pyproject = true;

  disabled = pythonOlder "3.11";

  src = fetchFromGitHub {
    owner = "thu-ml";
    repo = "tianshou";
    tag = "v${version}";
    hash = "sha256-lJAxjE+GMwssov1r4jOCOTf5Aonu+q6FSz5oWvZpuQQ=";
  };

  # Accept the nixpkgs versions of these packages even where they fall
  # outside the version bounds declared upstream.
  pythonRelaxDeps = [
    "deepdiff"
    "gymnasium"
    "numpy"
  ];

  # Strip virtualenv from the declared runtime requirements.
  # NOTE(review): presumably unused at runtime; confirm against upstream pyproject.toml.
  pythonRemoveDeps = [ "virtualenv" ];

  # All stdenv phases run in one shell session, so this export stays in
  # effect for the later (check) phases as well.
  postPatch = ''
    # silence matplotlib warning
    export MPLCONFIGDIR=$(mktemp -d)
  '';

  build-system = [ poetry-core ];

  dependencies = [
    deepdiff
    gymnasium
    h5py
    matplotlib
    numba
    numpy
    overrides
    packaging
    pandas
    pettingzoo
    sensai-utils
    tensorboard
    torch
    tqdm
  ];

  optional-dependencies = {
    # Union of every other extra. `removeAttrs` excludes `all` itself (only
    # attribute names are forced, so there is no self-recursion), and
    # `lib.unique` keeps a single copy of packages that appear in several
    # extras (pygame, docstring-parser, jsonargparse).
    all = lib.unique (
      lib.concatLists (lib.attrValues (removeAttrs optional-dependencies [ "all" ]))
    );

    argparse = [
      docstring-parser
      jsonargparse
    ];

    atari = [
      ale-py
      # autorom
      # NOTE(review): confirm `opencv` resolves to a build that exposes the
      # Python `cv2` module (upstream depends on opencv-python).
      opencv
      shimmy
    ];

    box2d = [
      # instead of box2d-py
      pybox2d
      pygame
      # NOTE(review): swig is a build tool for box2d-py; presumably kept only
      # to mirror upstream's extra since pybox2d is prebuilt here.
      swig
    ];

    classic_control = [
      pygame
    ];

    mujoco = [
      mujoco
      imageio
      cython
    ];

    pybullet = [
      pybullet
    ];

    # envpool = [
    #   envpool
    # ];

    # robotics = [
    #   gymnasium-robotics
    # ];

    # vizdoom = [
    #   vizdoom
    # ];

    eval = [
      docstring-parser
      joblib
      jsonargparse
      # rliable
      scipy
    ];
  };

  pythonImportsCheck = [ "tianshou" ];

  nativeCheckInputs = [
    pygame
    pymunk
    pytestCheckHook
  ];

  disabledTestPaths = [
    # remove tests that require lot of compute (ai model training tests)
    "test/continuous"
    "test/discrete"
    "test/highlevel"
    "test/modelbased"
    "test/offline"
  ];

  disabledTests = [
    # AttributeError: 'TimeLimit' object has no attribute 'test_attribute'
    "test_attr_unwrapped"
    # Failed: DID NOT RAISE <class 'TypeError'>
    "test_batch"
    # Failed: Raised AssertionError
    "test_vecenv"
  ]
  ++ lib.optionals stdenv.hostPlatform.isDarwin [
    # Fatal Python error: Aborted
    # pettingzoo/classic/tictactoe/tictactoe.py", line 254 in reset
    "test_tic_tac_toe"
  ];

  meta = {
    description = "Elegant PyTorch deep reinforcement learning library";
    homepage = "https://github.com/thu-ml/tianshou";
    changelog = "https://github.com/thu-ml/tianshou/releases/tag/${src.tag}";
    license = lib.licenses.mit;
    maintainers = with lib.maintainers; [ derdennisop ];
  };
}