{
  lib,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  setuptools,

  # dependencies
  absl-py,
  chex,
  dm-env,
  dm-tree,
  flax,
  immutabledict,
  jax,
  matplotlib,
  mediapy,
  numpy,
  pillow,
  tensorflow,
  tqdm,

  # tests
  pytestCheckHook,
}:
21
buildPythonPackage {
  pname = "waymax";
  # Unstable snapshot; the exact upstream commit is pinned in `src.rev`.
  version = "0-unstable-2024-03-23";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "waymo-research";
    repo = "waymax";
    rev = "720f9214a9bf79b3da7926497f0cd0468ca3e630";
    hash = "sha256-B1Rp5MATbEelp6G6K2wwV83QpINhOHgvAxb3mBN52Eg=";
  };

  # Upstream still calls `jax.tree_map`, which was removed in JAX v0.6.0
  # (using it now raises AttributeError). Rewrite every call site to
  # `jax.tree_util.tree_map`, which works with any JAX version.
  # `--replace-fail` aborts the build if a listed file no longer contains the
  # pattern, so this file list must be pruned whenever upstream changes it.
  # Upstream fix: https://github.com/waymo-research/waymax/pull/77
  postPatch = ''
    substituteInPlace \
      waymax/agents/expert.py \
      waymax/agents/waypoint_following_agent.py \
      waymax/agents/waypoint_following_agent_test.py \
      waymax/dynamics/abstract_dynamics_test.py \
      waymax/dynamics/state_dynamics_test.py \
      waymax/env/base_environment_test.py \
      waymax/env/rollout_test.py \
      waymax/env/wrappers/brax_wrapper_test.py \
      --replace-fail "jax.tree_map" "jax.tree_util.tree_map"
  '';

  build-system = [ setuptools ];

  dependencies = [
    absl-py
    chex
    dm-env
    dm-tree
    flax
    immutabledict
    jax
    matplotlib
    mediapy
    numpy
    pillow
    tensorflow
    tqdm
  ];

  nativeCheckInputs = [
    pytestCheckHook
  ];

  pythonImportsCheck = [ "waymax" ];

  disabledTestPaths = [
    # The visualization tests abort the interpreter during the check phase
    # ("Fatal Python error: Aborted"), presumably because they need a GUI.
    "waymax/visualization/viz_test.py"
  ];

  meta = {
    # Nixpkgs convention: descriptions must not start with an article.
    description = "JAX-based simulator for autonomous driving research";
    homepage = "https://github.com/waymo-research/waymax";
    changelog = "https://github.com/waymo-research/waymax/blob/main/CHANGELOG.md";
    # NOTE(review): `license` is not set. Check upstream's LICENSE file and map
    # it to the matching `lib.licenses` entry. It may be a non-commercial
    # (non-free) agreement, in which case `lib.licenses.unfree` or a custom
    # license attrset with `free = false` is needed. Confirm before adding.
    maintainers = with lib.maintainers; [ samuela ];
  };
}