{
  lib,
  stdenv,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  hatchling,

  # dependencies
  jax,
  jaxtyping,
  typing-extensions,
  wadler-lindig,

  # tests
  beartype,
  optax,
  pytest-xdist,
  pytestCheckHook,
}:

buildPythonPackage rec {
  pname = "equinox";
  version = "0.12.1";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "patrick-kidger";
    repo = "equinox";
    tag = "v${version}";
    hash = "sha256-mw2fk+527b6Rx6FGe6QJf3ZbxZ3rjYFXKleX2g6AryU=";
  };

  # Relax speed constraints on tests that can fail on busy builders.
  #
  # substituteInPlace applies each --replace-fail to the result of the previous
  # one. The thresholds are therefore rewritten from the largest down.
  # Otherwise 0.1 -> 0.5 -> 1 -> 4 would cascade: every bound would end up at
  # 4, and the later --replace-fail checks could never fail.
  #
  # NOTE(review): each of the three patterns must now occur in the unpatched
  # tests/test_while_loop.py. Confirm this against the file for this version.
  postPatch = ''
    substituteInPlace tests/test_while_loop.py \
      --replace-fail "speed < 1" "speed < 4" \
      --replace-fail "speed < 0.5" "speed < 1" \
      --replace-fail "speed < 0.1" "speed < 0.5"
  '';

  build-system = [ hatchling ];

  dependencies = [
    jax
    jaxtyping
    typing-extensions
    wadler-lindig
  ];

  nativeCheckInputs = [
    beartype
    optax
    pytest-xdist
    pytestCheckHook
  ];

  disabledTests =
    [
      # AssertionError: assert '<function te...n.<locals>.f>' == '<function f>'
      # https://github.com/patrick-kidger/equinox/issues/1008
      "test_function"
    ]
    ++ lib.optionals stdenv.hostPlatform.isDarwin [
      # SystemError: nanobind::detail::nb_func_error_except(): exception could not be translated!
      "test_filter"
    ];

  pythonImportsCheck = [ "equinox" ];

  meta = {
    description = "JAX library based around a simple idea: represent parameterised functions (such as neural networks) as PyTrees";
    # Derived from src.tag so the changelog link cannot drift from the fetched tag.
    changelog = "https://github.com/patrick-kidger/equinox/releases/tag/${src.tag}";
    homepage = "https://github.com/patrick-kidger/equinox";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ GaetanLepage ];
  };
}