at 23.05-pre 902 B view raw
# Package definition for the `objax` Python library, built from its GitHub
# release tag.
{ lib
, fetchFromGitHub
, buildPythonPackage
, jax
, jaxlib
, numpy
, parameterized
, pillow
, scipy
, tensorboard
}:

buildPythonPackage rec {
  pname = "objax";
  version = "1.6.0";

  # `rec` is what lets `rev` below refer to `version`.
  src = fetchFromGitHub {
    owner = "google";
    repo = "objax";
    rev = "v${version}";
    # SRI-format hash, so it goes in `hash` (`sha256` is the legacy name).
    hash = "sha256-/6tZxVDe/3C53Re14odU9VA3mKvSj9X3/xt6bHFLHwQ=";
  };

  # Avoid propagating the dependency on `jaxlib`, see
  # https://github.com/NixOS/nixpkgs/issues/156767
  buildInputs = [
    jaxlib
  ];

  # Dependencies that are propagated to dependents (unlike `jaxlib` above).
  propagatedBuildInputs = [
    jax
    numpy
    parameterized
    pillow
    scipy
    tensorboard
  ];

  # Check that the built package can be imported as `objax`.
  pythonImportsCheck = [
    "objax"
  ];

  meta = with lib; {
    # Per nixpkgs meta conventions: no leading package name, no trailing period.
    description = "Machine learning framework that provides an Object Oriented layer for JAX";
    homepage = "https://github.com/google/objax";
    license = licenses.asl20;
    maintainers = with maintainers; [ ndl ];
  };
}