# nixpkgs expression for brax, built from the upstream GitHub tag
# `v<version>` as a pyproject package with hatchling as its build backend.
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,
  stdenv,

  # build-system
  hatchling,

  # dependencies
  absl-py,
  etils,
  flask,
  flask-cors,
  flax,
  jax,
  jaxlib,
  jaxopt,
  jinja2,
  ml-collections,
  mujoco,
  mujoco-mjx,
  numpy,
  optax,
  orbax-checkpoint,
  pillow,
  scipy,
  tensorboardx,
  typing-extensions,

  # tests
  dm-env,
  gym,
  pytestCheckHook,
  pytest-xdist,
  transforms3d,
}:

buildPythonPackage rec {
  pname = "brax";
  version = "0.12.4";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "google";
    repo = "brax";
    tag = "v${version}";
    hash = "sha256-/eb0WjMzHwD1tjTyZ2fb2dzvGrWnyOLcVLOx4BeKvqk=";
  };

  build-system = [
    hatchling
  ];

  dependencies = [
    absl-py
    etils
    flask
    flask-cors
    flax
    jax
    jaxlib
    jaxopt
    jinja2
    ml-collections
    mujoco
    mujoco-mjx
    numpy
    optax
    orbax-checkpoint
    pillow
    scipy
    tensorboardx
    typing-extensions
  ];

  # Test-only inputs; pytest-xdist is the pytest plugin for running tests
  # in parallel.
  nativeCheckInputs = [
    dm-env
    gym
    pytestCheckHook
    pytest-xdist
    transforms3d
  ];

  # The skip below only applies on aarch64 hosts; other platforms run it.
  disabledTests = lib.optionals stdenv.hostPlatform.isAarch64 [
    # Flaky:
    # AssertionError: Array(-0.00135638, dtype=float32) != 0.0 within 0.001 delta (Array(0.00135638, dtype=float32) difference)
    "test_pendulum_period2"
  ];

  disabledTestPaths = [
    # ValueError: matmul: Input operand 1 has a mismatch in its core dimension
    "brax/generalized/constraint_test.py"
  ];

  # Smoke test: the top-level `brax` module must import from the built package.
  pythonImportsCheck = [
    "brax"
  ];

  meta = {
    description = "Massively parallel rigidbody physics simulation on accelerator hardware";
    homepage = "https://github.com/google/brax";
    # Release notes for the exact tag being packaged.
    changelog = "https://github.com/google/brax/releases/tag/${src.tag}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ nim65s ];
  };
}