{
  buildPythonPackage,
  cloudpickle,
  dm-haiku,
  einops,
  fetchFromGitHub,
  flax,
  hypothesis,
  jaxlib,
  keras,
  lib,
  poetry-core,
  pytestCheckHook,
  pyyaml,
  rich,
  tensorflow,
  treeo,
  torchmetrics,
  pythonRelaxDepsHook,
  torch,
}:

# treex: a Pytree Module system for Deep Learning in JAX, built from the
# upstream GitHub tag with poetry-core.
buildPythonPackage rec {
  pname = "treex";
  version = "0.6.11";
  format = "pyproject";

  src = fetchFromGitHub {
    owner = "cgarciae";
    # Spelled out instead of `pname` so overriding pname cannot silently
    # change which repository is fetched.
    repo = "treex";
    rev = "refs/tags/${version}";
    hash = "sha256-ObOnbtAT4SlrwOms1jtn7/XKZorGISGY6VuhQlC3DaQ=";
  };

  # pythonRelaxDeps drops the upstream version constraints on these
  # dependencies so the versions packaged in nixpkgs are accepted:
  #   - rich: at the time of writing (2022-03-29) rich was at 11.0.0, which
  #     treex's declared range did not admit.
  #   - treeo: compatible with a patch, but not marked as such in treex;
  #     see https://github.com/cgarciae/treex/issues/68.
  #   - certifi, flax: also relaxed; the reason was not recorded here.
  #     NOTE(review): confirm against upstream's pyproject.toml on update.
  pythonRelaxDeps = [
    "certifi"
    "flax"
    "rich"
    "treeo"
  ];

  nativeBuildInputs = [
    poetry-core
    pythonRelaxDepsHook
  ];

  # jaxlib is passed as a build input rather than propagated.
  buildInputs = [ jaxlib ];

  propagatedBuildInputs = [
    einops
    flax
    pyyaml
    rich
    treeo
    torch
  ];

  # Extra packages needed only by the upstream test suite (run via pytest).
  nativeCheckInputs = [
    cloudpickle
    dm-haiku
    hypothesis
    keras
    pytestCheckHook
    tensorflow
    torchmetrics
  ];

  pythonImportsCheck = [ "treex" ];

  meta = {
    description = "Pytree Module system for Deep Learning in JAX";
    homepage = "https://github.com/cgarciae/treex";
    license = lib.licenses.mit;
    maintainers = with lib.maintainers; [ ndl ];
  };
}