# nixpkgs package definition for equinox, built from the upstream GitHub
# release tag with the hatchling build backend.
{ lib
, buildPythonPackage
, fetchFromGitHub
, pythonOlder
, hatchling
, jax
, jaxlib
, jaxtyping
, typing-extensions
, beartype
, optax
, pytestCheckHook
}:

buildPythonPackage rec {
  pname = "equinox";
  version = "0.11.2";
  # Build via the PEP 517 backend declared in pyproject.toml (hatchling below).
  pyproject = true;

  # Upstream declares requires-python >= 3.9; mark older interpreters as
  # unsupported instead of producing a broken build. Re-check on version bumps.
  disabled = pythonOlder "3.9";

  src = fetchFromGitHub {
    owner = "patrick-kidger";
    repo = "equinox";
    rev = "refs/tags/v${version}";
    hash = "sha256-qFTKiY/t2LCCWJBOSfaX0hYQInrpXgfhTc+J4iuyVbM=";
  };

  # Build backend only; not needed at runtime.
  nativeBuildInputs = [
    hatchling
  ];

  # Runtime dependencies, propagated to consumers of this package.
  propagatedBuildInputs = [
    jax
    jaxlib
    jaxtyping
    typing-extensions
  ];

  # Extra packages used only by the test suite run through pytestCheckHook.
  nativeCheckInputs = [
    beartype
    optax
    pytestCheckHook
  ];

  pythonImportsCheck = [ "equinox" ];

  meta = with lib; {
    description = "A JAX library based around a simple idea: represent parameterised functions (such as neural networks) as PyTrees";
    homepage = "https://github.com/patrick-kidger/equinox";
    changelog = "https://github.com/patrick-kidger/equinox/releases/tag/v${version}";
    license = licenses.asl20;
    maintainers = with maintainers; [ GaetanLepage ];
  };
}