# Package expression for the Python library "equinox": the source is fetched
# from GitHub and built as a pyproject package using the hatchling backend.
{
  lib,
  stdenv,
  buildPythonPackage,
  fetchFromGitHub,

  # build backend
  hatchling,

  # runtime requirements
  jax,
  jaxtyping,
  typing-extensions,
  wadler-lindig,

  # only used while running the test suite
  beartype,
  optax,
  pytest-xdist,
  pytestCheckHook,
}:

let
  pname = "equinox";
  version = "0.13.1";

  # Platform flag used below to skip a test on Darwin hosts only.
  inherit (stdenv.hostPlatform) isDarwin;
in
buildPythonPackage {
  inherit pname version;
  pyproject = true;

  # The upstream tag is the version number prefixed with "v".
  src = fetchFromGitHub {
    owner = "patrick-kidger";
    repo = "equinox";
    tag = "v${version}";
    hash = "sha256-txgL5a+kKT28gAS8HianBgnnR+J25R2wrpRr8HEWCXA=";
  };

  # Loosen the timing thresholds asserted in tests/test_while_loop.py, since
  # those checks can fail on heavily loaded builders.
  # NOTE(review): the --replace-fail pairs are applied one after another, so
  # text produced by an earlier pair can be matched again by a later one
  # (e.g. "speed < 0.1" does not stay at "speed < 0.5") -- confirm that this
  # cascade is the intended outcome.
  postPatch = ''
    substituteInPlace tests/test_while_loop.py \
      --replace-fail "speed < 0.1" "speed < 0.5" \
      --replace-fail "speed < 0.5" "speed < 1" \
      --replace-fail "speed < 1" "speed < 20" \
      --replace-fail "speed < 2" "speed < 20"
  '';

  build-system = [ hatchling ];

  dependencies = [
    jax
    jaxtyping
    typing-extensions
    wadler-lindig
  ];

  # pytestCheckHook runs the tests; the other entries are test-time inputs.
  nativeCheckInputs = [
    beartype
    optax
    pytest-xdist
    pytestCheckHook
  ];

  # Skipped on Darwin only; the list is empty on every other platform.
  disabledTests = lib.optionals isDarwin [
    # SystemError: nanobind::detail::nb_func_error_except(): exception could not be translated!
    "test_filter"
  ];

  # Smoke test: the top-level module must be importable after installation.
  pythonImportsCheck = [ "equinox" ];

  meta = {
    description = "JAX library based around a simple idea: represent parameterised functions (such as neural networks) as PyTrees";
    changelog = "https://github.com/patrick-kidger/equinox/releases/tag/v${version}";
    homepage = "https://github.com/patrick-kidger/equinox";
    license = lib.licenses.asl20;
    maintainers = [ lib.maintainers.GaetanLepage ];
  };
}