# Package expression for equinox: a JAX library that represents parameterised
# functions (e.g. neural networks) as PyTrees. Built from the upstream GitHub
# release tag with buildPythonPackage.
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  hatchling,

  # dependencies
  jax,
  jaxlib,
  jaxtyping,
  typing-extensions,

  # tests
  beartype,
  optax,
  pytest-xdist,
  pytestCheckHook,
}:

buildPythonPackage rec {
  pname = "equinox";
  version = "0.11.10";
  # Build through the PEP 517 backend listed in `build-system` below.
  pyproject = true;

  src = fetchFromGitHub {
    owner = "patrick-kidger";
    repo = "equinox";
    # Upstream release tags carry a leading "v".
    tag = "v${version}";
    hash = "sha256-QoqwLdtWCDrXyqiI6Xw7jq2sxiRCmLaxk3/ZGHrvqL4=";
  };

  build-system = [ hatchling ];

  dependencies = [
    jax
    # NOTE(review): jaxlib is listed explicitly next to jax; it can be dropped
    # only if the jax package in this tree already propagates jaxlib — confirm.
    jaxlib
    jaxtyping
    typing-extensions
  ];

  # Inputs used only by the check phase (pytest-driven test suite).
  nativeCheckInputs = [
    beartype
    optax
    pytest-xdist
    pytestCheckHook
  ];

  # Smoke test: the installed package must be importable as `equinox`.
  pythonImportsCheck = [ "equinox" ];

  meta = {
    description = "JAX library based around a simple idea: represent parameterised functions (such as neural networks) as PyTrees";
    # Reuse the fetched tag so the changelog URL cannot drift from `src`.
    changelog = "https://github.com/patrick-kidger/equinox/releases/tag/${src.tag}";
    homepage = "https://github.com/patrick-kidger/equinox";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ GaetanLepage ];
  };
}