{
  lib,
  stdenv,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  hatchling,

  # dependencies
  jax,
  jaxtyping,
  typing-extensions,
  wadler-lindig,

  # tests
  beartype,
  optax,
  pytest-xdist,
  pytestCheckHook,
}:

buildPythonPackage rec {
  pname = "equinox";
  version = "0.13.0";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "patrick-kidger";
    repo = "equinox";
    tag = "v${version}";
    hash = "sha256-zXgAuFGWKHShKodi9swnWIry4VU9s4pBhBRoK5KzaL0=";
  };

  # Relax speed constraints on tests that can fail on busy builders.
  #
  # substituteInPlace applies its --replace-* flags sequentially, each one to
  # the output of the previous one. The thresholds are therefore rewritten
  # from the largest to the smallest, so a value produced by one replacement
  # (e.g. the new "speed < 1") is never matched again by a later one.
  # In the opposite order every threshold would cascade into "speed < 4",
  # and the --replace-fail checks for "speed < 0.5" and "speed < 1" could
  # never fail, because the earlier replacements would have just created
  # those strings.
  # NOTE(review): with this order each pattern must literally exist in the
  # upstream test file; --replace-fail will flag it if upstream changes.
  postPatch = ''
    substituteInPlace tests/test_while_loop.py \
      --replace-fail "speed < 1" "speed < 4" \
      --replace-fail "speed < 0.5" "speed < 1" \
      --replace-fail "speed < 0.1" "speed < 0.5"
  '';

  build-system = [ hatchling ];

  dependencies = [
    jax
    jaxtyping
    typing-extensions
    wadler-lindig
  ];

  nativeCheckInputs = [
    beartype
    optax
    pytest-xdist
    pytestCheckHook
  ];

  disabledTests = lib.optionals stdenv.hostPlatform.isDarwin [
    # SystemError: nanobind::detail::nb_func_error_except(): exception could not be translated!
    "test_filter"
  ];

  pythonImportsCheck = [ "equinox" ];

  meta = {
    description = "JAX library based around a simple idea: represent parameterised functions (such as neural networks) as PyTrees";
    changelog = "https://github.com/patrick-kidger/equinox/releases/tag/v${version}";
    homepage = "https://github.com/patrick-kidger/equinox";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ GaetanLepage ];
  };
}