# Nix package expression for the `objax` Python package, built from the
# `v<version>` tag of the google/objax GitHub repository.
{ lib
, fetchFromGitHub
, buildPythonPackage
, jax
, jaxlib
, numpy
, parameterized
, pillow
, scipy
, tensorboard
}:

buildPythonPackage rec {
  pname = "objax";
  version = "1.6.0";

  # `rec` lets `rev` below be derived from `version`, so a version bump only
  # needs `version` and `hash` to be updated.
  src = fetchFromGitHub {
    owner = "google";
    repo = "objax";
    rev = "v${version}";
    # The value is an SRI-format hash, so it is passed as `hash` rather than
    # through the legacy `sha256` attribute.
    hash = "sha256-/6tZxVDe/3C53Re14odU9VA3mKvSj9X3/xt6bHFLHwQ=";
  };

  # Avoid propagating the dependency on `jaxlib`, see
  # https://github.com/NixOS/nixpkgs/issues/156767
  buildInputs = [
    jaxlib
  ];

  # Dependencies that are propagated to consumers of this package.
  propagatedBuildInputs = [
    jax
    numpy
    parameterized
    pillow
    scipy
    tensorboard
  ];

  # Only the top-level module name is listed for the import check.
  pythonImportsCheck = [
    "objax"
  ];

  meta = {
    # Per nixpkgs convention the description does not repeat the package name
    # and carries no trailing period.
    description = "Machine learning framework that provides an Object Oriented layer for JAX";
    homepage = "https://github.com/google/objax";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ ndl ];
  };
}