# Nix expression for the `jaxtyping` Python package
# (type annotations and runtime checking for JAX arrays and PyTrees).
{
  lib,
  buildPythonPackage,
  pythonOlder,
  fetchFromGitHub,
  hatchling,
  typeguard,
  cloudpickle,
  equinox,
  ipython,
  jax,
  jaxlib,
  pytestCheckHook,
  tensorflow,
  torch,
}:

let
  # Bound with `let` so that `passthru.tests` below can refer to the finished
  # package itself (via `self.overridePythonAttrs`).
  self = buildPythonPackage rec {
    pname = "jaxtyping";
    version = "0.2.33";
    pyproject = true;

    disabled = pythonOlder "3.9";

    # src, meta.homepage and meta.changelog all refer to the same upstream
    # repository.
    # NOTE(review): the hash is meant to cover the unpacked tag contents, not the
    # download URL; confirm it still matches after switching the owner.
    src = fetchFromGitHub {
      owner = "patrick-kidger";
      repo = "jaxtyping";
      rev = "refs/tags/v${version}";
      hash = "sha256-CL1EONbjjT3SCAn2o1x+1cgfuYWMEgQwbX9j34t+HGs=";
    };

    build-system = [ hatchling ];

    dependencies = [ typeguard ];

    # Drop upstream's version constraint on typeguard so that the typeguard
    # version provided by nixpkgs is accepted.
    pythonRelaxDeps = [ "typeguard" ];

    nativeCheckInputs = [
      cloudpickle
      equinox
      ipython
      jax
      jaxlib
      pytestCheckHook
      tensorflow
      torch
    ];

    # Tests are not run in the main build: equinox is a check input here, and
    # running the tests in this derivation would create a cyclic dependency
    # with equinox.
    doCheck = false;

    # Enable tests via passthru to avoid cyclic dependency with equinox.
    passthru.tests = {
      check = self.overridePythonAttrs {
        # We disable tests because they complain about the version of typeguard being too new.
        # Consequently this passthru test currently only rebuilds the package
        # (with catchConflicts turned off); it does not run the test suite.
        # Set doCheck = true here once the typeguard incompatibility is resolved.
        doCheck = false;
        catchConflicts = false;
      };
    };

    pythonImportsCheck = [ "jaxtyping" ];

    meta = {
      description = "Type annotations and runtime checking for JAX arrays and PyTrees";
      homepage = "https://github.com/patrick-kidger/jaxtyping";
      changelog = "https://github.com/patrick-kidger/jaxtyping/releases/tag/v${version}";
      license = lib.licenses.mit;
      maintainers = with lib.maintainers; [ GaetanLepage ];
    };
  };
in
self