# jaxtyping: type annotations and runtime checking for JAX arrays and PyTrees,
# packaged with buildPythonPackage.
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  hatchling,

  # tests
  cloudpickle,
  equinox,
  ipython,
  jax,
  jaxlib,
  pytestCheckHook,
  tensorflow,
  torch,
}:

let
  # The derivation is bound to a name so that passthru.tests below can refer
  # back to (and override) this very package.
  self = buildPythonPackage rec {
    pname = "jaxtyping";
    version = "0.2.36";
    pyproject = true;

    src = fetchFromGitHub {
      owner = "google";
      repo = "jaxtyping";
      rev = "refs/tags/v${version}";
      hash = "sha256-TXhHh6Nka9TOnfFPaNyHmLdTkhzyFEY0mLSfoDf9KQc=";
    };

    build-system = [ hatchling ];

    pythonImportsCheck = [ "jaxtyping" ];

    # Test-only inputs; they matter only for builds where doCheck is enabled.
    nativeCheckInputs = [
      cloudpickle
      equinox
      ipython
      jax
      jaxlib
      pytestCheckHook
      tensorflow
      torch
    ];

    # Tests are not run in the main build: equinox, one of the test inputs,
    # would otherwise form a dependency cycle with this package.
    doCheck = false;

    # Tests are exposed via passthru instead, to avoid the cyclic dependency
    # with equinox.
    passthru.tests = {
      check = self.overridePythonAttrs {
        # Tests stay disabled because they complain that the available
        # typeguard version is too new.
        # NOTE(review): with doCheck = false this derivation only rebuilds the
        # package and does not execute the pytest suite; re-enable once the
        # typeguard incompatibility is resolved.
        doCheck = false;
        # Turn off buildPythonPackage's duplicate-package (conflict) check for
        # this test derivation.
        catchConflicts = false;
      };
    };

    meta = {
      description = "Type annotations and runtime checking for JAX arrays and PyTrees";
      # Keep homepage and changelog on the same GitHub owner that src is
      # fetched from.
      homepage = "https://github.com/google/jaxtyping";
      changelog = "https://github.com/google/jaxtyping/releases/tag/v${version}";
      license = lib.licenses.mit;
      maintainers = with lib.maintainers; [ GaetanLepage ];
    };
  };
in
self