Clone of https://github.com/NixOS/nixpkgs.git (to stress-test knotserver)
at lanzaboote · 104 lines · 2.5 kB · view raw
# Nix expression for the `flax` Python package (a neural network library for
# JAX), built from the upstream google/flax GitHub tag matching `version`.
{ lib
, buildPythonPackage
, fetchFromGitHub
, fetchpatch
, jax
, jaxlib
, keras
, matplotlib
, msgpack
, numpy
, optax
, pytest-xdist
, pytestCheckHook
, pythonRelaxDepsHook
, rich
, tensorflow
, tensorstore
}:

# NOTE(review): `fetchpatch` is accepted but not referenced anywhere below.
# It can probably be dropped, since callPackage only supplies the arguments a
# package asks for. It is kept here so existing `.override` calls that pass it
# keep evaluating.
buildPythonPackage rec {
  pname = "flax";
  version = "0.7.4";

  src = fetchFromGitHub {
    owner = "google";
    repo = "flax";
    rev = "refs/tags/v${version}";
    hash = "sha256-i48omag/1Si3mCCGfsUD9qeejyeCLWzvvwKJqH8vm8k=";
  };

  # jaxlib is available during the build but is deliberately not propagated
  # to consumers.
  # NOTE(review): presumably this lets users pick their own jaxlib variant — confirm.
  nativeBuildInputs = [
    jaxlib
    pythonRelaxDepsHook # applies `pythonRemoveDeps` below
  ];

  # Runtime Python dependencies propagated to anything that depends on flax.
  propagatedBuildInputs = [
    jax
    matplotlib
    msgpack
    numpy
    optax
    rich
    tensorstore
  ];

  # Strip `orbax` from the declared runtime requirements;
  # background in https://github.com/google/flax/pull/2882.
  pythonRemoveDeps = [ "orbax" ];

  # Smoke test: the installed package must be importable.
  pythonImportsCheck = [
    "flax"
  ];

  nativeCheckInputs = [
    keras
    pytest-xdist
    pytestCheckHook
    tensorflow
  ];

  # Silence FutureWarning and DeprecationWarning noise during the test run.
  pytestFlagsArray = [
    "-W ignore::FutureWarning"
    "-W ignore::DeprecationWarning"
  ];

  disabledTestPaths = [
    # Documentation test: it needs extra dependencies, and its coverage is not
    # relevant for this package.
    "docs/_ext/codediff_test.py"

    # The `examples` tests are not meant to run together in a single pytest
    # session. Depending on how they are invoked, their modules either clash
    # with each other or resolve imports incorrectly. Many of them also need
    # dependencies that are missing from nixpkgs (`clu`, `jgraph`,
    # `tensorflow_datasets`, `vocabulary`), so running them would gain little.
    "examples/*"

    # Tracked upstream in https://github.com/google/flax/issues/3232.
    "tests/jax_utils_test.py"

    # Needs orbax, which was not packaged in nixpkgs as of 2023-07-27.
    "tests/checkpoints_test.py"
  ];

  disabledTests = [
    # Tracked upstream in https://github.com/google/flax/issues/2554.
    "test_async_save_checkpoints"
    "test_jax_array0"
    "test_jax_array1"
    "test_keep0"
    "test_keep1"
    "test_optimized_lstm_cell_matches_regular"
    "test_overwrite_checkpoints"
    "test_save_restore_checkpoints_target_empty"
    "test_save_restore_checkpoints_target_none"
    "test_save_restore_checkpoints_target_singular"
    "test_save_restore_checkpoints_w_float_steps"
    "test_save_restore_checkpoints"
  ];

  meta = {
    description = "Neural network library for JAX";
    homepage = "https://github.com/google/flax";
    changelog = "https://github.com/google/flax/releases/tag/v${version}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ ndl ];
  };
}