nixpkgs mirror (for testing) github.com/NixOS/nixpkgs
nix
at r-updates 139 lines 3.1 kB view raw
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  setuptools,
  setuptools-scm,

  # dependencies
  flaxlib,
  jax,
  msgpack,
  numpy,
  optax,
  orbax-checkpoint,
  pyyaml,
  rich,
  tensorstore,
  treescope,
  typing-extensions,

  # optional-dependencies
  matplotlib,

  # tests
  cloudpickle,
  keras,
  einops,
  pytestCheckHook,
  pytest-xdist,
  sphinx,
  tensorflow,

  # passthru
  writeScript,
  tomlq,
}:

buildPythonPackage rec {
  pname = "flax";
  version = "0.12.2";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "google";
    repo = "flax";
    tag = "v${version}";
    hash = "sha256-Wdfc35/iah98C5WNYZWiAd2FJUJlyGLJ8xELpuYD3GU=";
  };

  build-system = [
    setuptools
    setuptools-scm
  ];

  dependencies = [
    flaxlib
    jax
    msgpack
    numpy
    optax
    orbax-checkpoint
    pyyaml
    rich
    tensorstore
    treescope
    typing-extensions
  ];

  optional-dependencies = {
    all = [ matplotlib ];
  };

  pythonImportsCheck = [ "flax" ];

  nativeCheckInputs = [
    cloudpickle
    keras
    einops
    pytestCheckHook
    pytest-xdist
    sphinx
    tensorflow
  ];

  pytestFlags = [
    # FutureWarning: In the future `np.object` will be defined as the corresponding NumPy scalar.
    "-Wignore::FutureWarning"
  ];

  disabledTestPaths = [
    # Docs test, needs extra deps + we're not interested in it.
    "docs/_ext/codediff_test.py"

    # The tests in `examples` are not designed to be executed from a single test
    # session and thus either have modules that conflict with each other or
    # wrong import paths, depending on how they're invoked. Many tests also have
    # dependencies that are not packaged in `nixpkgs` (`clu`, `jgraph`,
    # `tensorflow_datasets`, `vocabulary`) so the benefits of trying to run them
    # would be limited anyway.
    "examples/*"
  ];

  disabledTests = [
    # AssertionError: [Chex] Function 'add' is traced > 1 times!
    "PadShardUnpadTest"

    # AssertionError: nnx_model.kernel.value.sharding = NamedSharding(...
    "test_linen_to_nnx_metadata"

    # AssertionError: 'Linear_0' not found in State({})
    "test_compact_basic"
    # KeyError: 'intermediates'
    "test_linen_submodule"
    "test_pure_nnx_submodule"
    # KeyError: 'counts'
    "test_mutable_state"
    # AttributeError: 'Top' object has no attribute '_pytree__state'. Did you mean: '_pytree__flatten'?
    "test_shared_modules"
    # AttributeError: 'MLP' object has no attribute 'scope'
    "test_transforms"
  ];

  passthru = {
    # Bumps flax first, then realises its source (`./result`) so the flaxlib
    # version can be read from it and flaxlib bumped to match.
    # `writeScript` writes the text verbatim, so the shebang is required for the
    # script to be directly executable; `set -euo pipefail` stops the update as
    # soon as any step fails instead of bumping flaxlib against a stale source.
    # NOTE(review): assumes flaxlib's Cargo manifest lives at `flaxlib_src/` in
    # the flax source tree and that `tomlq` accepts `--file <path> <dotted.key>`
    # — confirm against the flaxlib package and the tomlq CLI.
    updateScript = writeScript "update.sh" ''
      #!/usr/bin/env bash
      set -euo pipefail

      nix-update flax # does not --build by default
      nix-build . -A flax.src # src is essentially a passthru
      nix-update flaxlib --version="$(${lib.getExe tomlq} --file result/flaxlib_src/Cargo.toml package.version)" --commit
    '';
  };

  meta = {
    description = "Neural network library for JAX";
    homepage = "https://github.com/google/flax";
    changelog = "https://github.com/google/flax/releases/tag/${src.tag}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ ndl ];
  };
}