1{
2 lib,
3 buildPythonPackage,
4 fetchFromGitHub,
5 stdenv,
6
7 # build-system
8 setuptools,
9
10 # dependencies
11 absl-py,
12 dm-env,
13 etils,
14 flask,
15 flask-cors,
16 flax,
17 grpcio,
18 gym,
19 jax,
20 jaxlib,
21 jaxopt,
22 jinja2,
23 ml-collections,
24 mujoco,
25 mujoco-mjx,
26 numpy,
27 optax,
28 orbax-checkpoint,
29 pillow,
30 pytinyrenderer,
31 scipy,
32 tensorboardx,
33 trimesh,
34
35 # tests
36 pytestCheckHook,
37 pytest-xdist,
38 transforms3d,
39}:
40
41buildPythonPackage rec {
42 pname = "brax";
43 version = "0.12.1";
44 pyproject = true;
45
46 src = fetchFromGitHub {
47 owner = "google";
48 repo = "brax";
49 tag = "v${version}";
50 hash = "sha256-whkkqTTy5CY6soyS5D7hWtBZuVHc6si1ArqwLgzHDkw=";
51 };
52
53 build-system = [
54 setuptools
55 ];
56
57 dependencies = [
58 absl-py
59 # TODO: remove dm_env after dropping legacy v1 code
60 dm-env
61 etils
62 flask
63 flask-cors
64 flax
65 # TODO: remove grpcio and gym after dropping legacy v1 code
66 grpcio
67 gym
68 jax
69 jaxlib
70 jaxopt
71 jinja2
72 ml-collections
73 mujoco
74 mujoco-mjx
75 numpy
76 optax
77 orbax-checkpoint
78 pillow
79 # TODO: remove pytinyrenderer after dropping legacy v1 code
80 pytinyrenderer
81 scipy
82 tensorboardx
83 trimesh
84 ];
85
86 nativeCheckInputs = [
87 pytestCheckHook
88 pytest-xdist
89 transforms3d
90 ];
91
92 disabledTests = lib.optionals stdenv.hostPlatform.isAarch64 [
93 # Flaky:
94 # AssertionError: Array(-0.00135638, dtype=float32) != 0.0 within 0.001 delta (Array(0.00135638, dtype=float32) difference)
95 "test_pendulum_period2"
96 ];
97
98 disabledTestPaths = [
99 # ValueError: matmul: Input operand 1 has a mismatch in its core dimension
100 "brax/generalized/constraint_test.py"
101 ];
102
103 pythonImportsCheck = [
104 "brax"
105 ];
106
107 meta = {
108 description = "Massively parallel rigidbody physics simulation on accelerator hardware";
109 homepage = "https://github.com/google/brax";
110 license = lib.licenses.asl20;
111 maintainers = with lib.maintainers; [ nim65s ];
112 };
113}