# Nix expression for the `equinox` Python package (a JAX library), built from
# the upstream GitHub release tag as a pyproject package using hatchling.
{ lib
, buildPythonPackage
, fetchFromGitHub
, hatchling
, jax
, jaxlib
, jaxtyping
, typing-extensions
, beartype
, pytestCheckHook
}:

buildPythonPackage rec {
  pname = "equinox";
  version = "0.10.11";
  format = "pyproject";

  src = fetchFromGitHub {
    owner = "patrick-kidger";
    # Spelled out instead of `repo = pname` so that renaming the Nix package
    # can never silently point the fetcher at a different repository.
    repo = "equinox";
    rev = "refs/tags/v${version}";
    hash = "sha256-JffuPplIROPog29FBsWH9cQHSkrFKuXjaTjjEwIqW/0=";
  };

  # Needed only at build time, to build the pyproject-format package.
  nativeBuildInputs = [
    hatchling
  ];

  # Runtime dependencies; propagated so that packages depending on equinox
  # also get them.
  propagatedBuildInputs = [
    jax
    jaxlib
    jaxtyping
    typing-extensions
  ];

  # Inputs used only by the check phase; pytestCheckHook runs the upstream
  # pytest suite, with beartype made available to it.
  nativeCheckInputs = [
    beartype
    pytestCheckHook
  ];

  # Smoke test: the installed package must be importable.
  pythonImportsCheck = [ "equinox" ];

  meta = {
    description = "JAX library based around a simple idea: represent parameterised functions (such as neural networks) as PyTrees";
    homepage = "https://github.com/patrick-kidger/equinox";
    # Same tag that `src` is fetched from.
    changelog = "https://github.com/patrick-kidger/equinox/releases/tag/v${version}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ GaetanLepage ];
  };
}