{ lib
, buildPythonPackage
, fetchFromGitHub
, hatchling
, numpy
, typeguard
, typing-extensions
, cloudpickle
, equinox
, jax
, jaxlib
, torch
, pytestCheckHook
}:

# The derivation is bound to `self` with `let` so that `passthru.tests.check`
# below can refer back to this same package and rebuild it with its test
# suite enabled.
let
  self = buildPythonPackage rec {
    pname = "jaxtyping";
    version = "0.2.23";
    pyproject = true;

    src = fetchFromGitHub {
      owner = "google";
      repo = "jaxtyping";
      rev = "refs/tags/v${version}";
      hash = "sha256-22dIuIjFgqRmV9AQok02skVt7fm17/WpzBm3FrJ6/zs=";
    };

    # Build backend for the `pyproject = true` build.
    nativeBuildInputs = [
      hatchling
    ];

    propagatedBuildInputs = [
      numpy
      typeguard
      typing-extensions
    ];

    # Test-only dependencies. They are exercised by the `passthru.tests.check`
    # variant, which turns `doCheck` back on.
    nativeCheckInputs = [
      cloudpickle
      equinox
      jax
      jaxlib
      pytestCheckHook
      torch
    ];

    # The test suite needs equinox, and equinox depends on jaxtyping, so running
    # the tests inside this derivation would create a dependency cycle.
    # Tests are run instead through `passthru.tests.check`.
    doCheck = false;

    passthru.tests = {
      check = self.overridePythonAttrs { doCheck = true; };
    };

    pythonImportsCheck = [ "jaxtyping" ];

    meta = with lib; {
      description = "Type annotations and runtime checking for JAX arrays and PyTrees";
      homepage = "https://github.com/google/jaxtyping";
      changelog = "https://github.com/google/jaxtyping/releases/tag/v${version}";
      license = licenses.mit;
      maintainers = with maintainers; [ GaetanLepage ];
    };
  };
in
self