# nixpkgs expression for the `equinox` Python library, built from the
# upstream GitHub release tag and tested with pytest.
{
  lib,
  buildPythonPackage,
  pythonOlder,
  fetchFromGitHub,
  hatchling,
  jax,
  jaxlib,
  jaxtyping,
  typing-extensions,
  beartype,
  optax,
  pytest-xdist,
  pytestCheckHook,
}:

buildPythonPackage rec {
  pname = "equinox";
  version = "0.11.4";
  pyproject = true;

  disabled = pythonOlder "3.9";

  src = fetchFromGitHub {
    owner = "patrick-kidger";
    repo = "equinox";
    rev = "refs/tags/v${version}";
    hash = "sha256-3OwHND1YEdg/SppqiB7pCdp6v+lYwTbtX07tmyEMWDo=";
  };

  # Build backend used for the pyproject build.
  nativeBuildInputs = [ hatchling ];

  propagatedBuildInputs = [
    jax
    # NOTE(review): jaxlib is listed explicitly next to jax, presumably because
    # the jax package does not propagate a jaxlib by itself — confirm.
    jaxlib
    jaxtyping
    typing-extensions
  ];

  # Test-only dependencies; pytest-xdist lets pytestCheckHook run tests in parallel.
  nativeCheckInputs = [
    beartype
    optax
    pytest-xdist
    pytestCheckHook
  ];

  pythonImportsCheck = [ "equinox" ];

  disabledTests = [
    # These tests fail because newer JAX releases trim their own internal frames
    # from exception tracebacks (JAX prints "For simplicity, JAX has removed its
    # internal frames from the traceback of the following exception.").
    # https://github.com/patrick-kidger/equinox/issues/716
    "test_abstract"
    "test_complicated"
    "test_grad"
    "test_jvp"
    "test_mlp"
    "test_num_traces"
    "test_pytree_in"
    "test_simple"
    "test_vmap"

    # AssertionError: assert 'foo:\n pri...pe=float32)\n' == 'foo:\n pri...pe=float32)\n'
    # Also reported in patrick-kidger/equinox#716
    "test_backward_nan"
  ];

  meta = {
    description = "JAX library based around a simple idea: represent parameterised functions (such as neural networks) as PyTrees";
    changelog = "https://github.com/patrick-kidger/equinox/releases/tag/v${version}";
    homepage = "https://github.com/patrick-kidger/equinox";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ GaetanLepage ];
  };
}