# Snapshot of this file at tag 23.11-beta (1.4 kB, raw view).
# Builds the `objax` Python package from the google/objax GitHub release tag.
{ lib
, buildPythonPackage
, fetchFromGitHub
# NOTE(review): `fetchpatch` is not used by this expression; it is kept only so
# that existing `.override { fetchpatch = ...; }` callers keep working.
, fetchpatch
, jaxlib
, jax
, numpy
, parameterized
, pillow
, scipy
, tensorboard
, keras
, pytestCheckHook
, tensorflow
}:

buildPythonPackage rec {
  pname = "objax";
  version = "1.8.0";

  src = fetchFromGitHub {
    owner = "google";
    repo = "objax";
    rev = "refs/tags/v${version}";
    hash = "sha256-WD+pmR8cEay4iziRXqF3sHUzCMBjmLJ3wZ3iYOD+hzk=";
  };

  # Avoid propagating the dependency on `jaxlib`, see
  # https://github.com/NixOS/nixpkgs/issues/156767
  # (it is listed in buildInputs rather than propagatedBuildInputs on purpose).
  buildInputs = [
    jaxlib
  ];

  propagatedBuildInputs = [
    jax
    numpy
    parameterized
    pillow
    scipy
    tensorboard
  ];

  pythonImportsCheck = [
    "objax"
  ];

  # This is necessary to ignore the presence of two protobuf versions
  # (tensorflow is bringing in an older one).
  catchConflicts = false;

  # Test-only dependencies.
  nativeCheckInputs = [
    keras
    pytestCheckHook
    tensorflow
  ];

  # Run the test modules found directly under tests/.
  pytestFlagsArray = [
    "tests/*.py"
  ];

  disabledTests = [
    # Test requires internet access for prefetching some weights
    "test_pretrained_keras_weight_0_ResNet50V2"
  ];

  meta = with lib; {
    # Per nixpkgs conventions: no leading package name, no trailing period.
    description = "Machine learning framework that provides an Object Oriented layer for JAX";
    homepage = "https://github.com/google/objax";
    license = licenses.asl20;
    maintainers = with maintainers; [ ndl ];
  };
}