{
  lib,
  buildPythonPackage,
  pythonOlder,
  fetchFromGitHub,
  hatchling,
  pythonRelaxDepsHook,
  numpy,
  typeguard,
  typing-extensions,
  cloudpickle,
  equinox,
  ipython,
  jax,
  jaxlib,
  pytestCheckHook,
  tensorflow,
  torch,
}:

let
  # The derivation is bound in a `let` so that `passthru.tests.check` can
  # refer back to it through `overridePythonAttrs`.
  jaxtyping = buildPythonPackage rec {
    pname = "jaxtyping";
    version = "0.2.28";
    pyproject = true;

    disabled = pythonOlder "3.9";

    src = fetchFromGitHub {
      owner = "google";
      repo = "jaxtyping";
      rev = "refs/tags/v${version}";
      hash = "sha256-xDFrgPecUIfCACg/xkMQ8G1+6hNiUUDg9eCZKNpNfzs=";
    };

    nativeBuildInputs = [
      hatchling
      pythonRelaxDepsHook
    ];

    # Relax the upstream version pin on typeguard (handled by
    # pythonRelaxDepsHook) so the typeguard packaged here is accepted.
    pythonRelaxDeps = [ "typeguard" ];

    propagatedBuildInputs = [
      numpy
      typeguard
      typing-extensions
    ];

    pythonImportsCheck = [ "jaxtyping" ];

    # The test suite is not run in the main build: equinox (a check input)
    # would create a dependency cycle with this package. Tests are instead
    # exposed through `passthru.tests.check` below.
    doCheck = false;

    nativeCheckInputs = [
      cloudpickle
      equinox
      ipython
      jax
      jaxlib
      pytestCheckHook
      tensorflow
      torch
    ];

    passthru.tests.check = jaxtyping.overridePythonAttrs {
      # NOTE(review): checks are still disabled in this override, so pytest
      # does not run here either; the tests reportedly reject the packaged
      # (newer) typeguard version.
      doCheck = false;
      # Presumably needed because the check inputs (e.g. equinox) bring a
      # second copy of jaxtyping into the closure — TODO confirm.
      catchConflicts = false;
    };

    meta = {
      description = "Type annotations and runtime checking for JAX arrays and PyTrees";
      homepage = "https://github.com/google/jaxtyping";
      license = lib.licenses.mit;
      maintainers = [ lib.maintainers.GaetanLepage ];
    };
  };
in
jaxtyping