{ lib
, buildPythonPackage
, fetchFromGitHub
, jaxlib
, jax
, numpy
, parameterized
, pillow
, scipy
, tensorboard
, keras
, pytestCheckHook
, tensorflow
}:

buildPythonPackage rec {
  pname = "objax";
  version = "1.8.0";

  src = fetchFromGitHub {
    owner = "google";
    repo = "objax";
    rev = "refs/tags/v${version}";
    hash = "sha256-WD+pmR8cEay4iziRXqF3sHUzCMBjmLJ3wZ3iYOD+hzk=";
  };

  # Avoid propagating the dependency on `jaxlib`, see
  # https://github.com/NixOS/nixpkgs/issues/156767
  buildInputs = [
    jaxlib
  ];

  propagatedBuildInputs = [
    jax
    numpy
    parameterized
    pillow
    scipy
    tensorboard
  ];

  pythonImportsCheck = [
    "objax"
  ];

  # This is necessary to ignore the presence of two protobuf versions
  # (tensorflow brings in an older one).
  catchConflicts = false;

  # keras and tensorflow are only needed by the test suite, not at runtime.
  nativeCheckInputs = [
    keras
    pytestCheckHook
    tensorflow
  ];

  # Run the test modules located directly under tests/.
  pytestFlagsArray = [
    "tests/*.py"
  ];

  disabledTests = [
    # Test requires internet access for prefetching some weights
    "test_pretrained_keras_weight_0_ResNet50V2"
  ];

  meta = {
    description = "Machine learning framework that provides an Object Oriented layer for JAX";
    homepage = "https://github.com/google/objax";
    changelog = "https://github.com/google/objax/releases/tag/v${version}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ ndl ];
  };
}