Merge pull request #291705 from GaetanLepage/jax

python311Packages.{jax,jaxlib,jaxlib-bin}: 0.4.24 -> 0.4.28

authored by Samuel Ainsworth and committed by GitHub 3a993d32 8fe1aa68

+128 -100
+6 -2
pkgs/development/python-modules/blackjax/default.nix
··· 16 17 buildPythonPackage rec { 18 pname = "blackjax"; 19 - version = "1.2.0"; 20 pyproject = true; 21 22 disabled = pythonOlder "3.9"; ··· 25 owner = "blackjax-devs"; 26 repo = "blackjax"; 27 rev = "refs/tags/${version}"; 28 - hash = "sha256-vXyxK3xALKG61YGK7fmoqQNGfOiagHFrvnU02WKZThw="; 29 }; 30 31 build-system = [ ··· 56 disabledTests = [ 57 # too slow 58 "test_adaptive_tempered_smc" 59 ]; 60 61 pythonImportsCheck = [
··· 16 17 buildPythonPackage rec { 18 pname = "blackjax"; 19 + version = "1.2.1"; 20 pyproject = true; 21 22 disabled = pythonOlder "3.9"; ··· 25 owner = "blackjax-devs"; 26 repo = "blackjax"; 27 rev = "refs/tags/${version}"; 28 + hash = "sha256-VoWBCjFMyE5LVJyf7du/pKlnvDHj22lguiP6ZUzH9ak="; 29 }; 30 31 build-system = [ ··· 56 disabledTests = [ 57 # too slow 58 "test_adaptive_tempered_smc" 59 + ] ++ lib.optionals (stdenv.isLinux && stdenv.isAarch64) [ 60 + # Numerical test (AssertionError) 61 + # https://github.com/blackjax-devs/blackjax/issues/668 62 + "test_chees_adaptation" 63 ]; 64 65 pythonImportsCheck = [
+15 -2
pkgs/development/python-modules/equinox/default.nix
··· 48 pythonImportsCheck = [ "equinox" ]; 49 50 disabledTests = [ 51 - # Failed: DID NOT WARN. No warnings of type (<class 'UserWarning'>,) were emitted. 52 - "test_tracetime" 53 ]; 54 55 meta = with lib; {
··· 48 pythonImportsCheck = [ "equinox" ]; 49 50 disabledTests = [ 51 + # For simplicity, JAX has removed its internal frames from the traceback of the following exception. 52 + # https://github.com/patrick-kidger/equinox/issues/716 53 + "test_abstract" 54 + "test_complicated" 55 + "test_grad" 56 + "test_jvp" 57 + "test_mlp" 58 + "test_num_traces" 59 + "test_pytree_in" 60 + "test_simple" 61 + "test_vmap" 62 + 63 + # AssertionError: assert 'foo:\n pri...pe=float32)\n' == 'foo:\n pri...pe=float32)\n' 64 + # Also reported in patrick-kidger/equinox#716 65 + "test_backward_nan" 66 ]; 67 68 meta = with lib; {
+4 -4
pkgs/development/python-modules/flax/default.nix
··· 25 26 buildPythonPackage rec { 27 pname = "flax"; 28 - version = "0.8.2"; 29 pyproject = true; 30 31 disabled = pythonOlder "3.9"; ··· 34 owner = "google"; 35 repo = "flax"; 36 rev = "refs/tags/v${version}"; 37 - hash = "sha256-UABgJGe1grUSkwOJpjeIoFqhXsqG//HlC1YyYPxXV+g="; 38 }; 39 40 - nativeBuildInputs = [ 41 jaxlib 42 pythonRelaxDepsHook 43 setuptools-scm 44 ]; 45 46 - propagatedBuildInputs = [ 47 jax 48 msgpack 49 numpy
··· 25 26 buildPythonPackage rec { 27 pname = "flax"; 28 + version = "0.8.3"; 29 pyproject = true; 30 31 disabled = pythonOlder "3.9"; ··· 34 owner = "google"; 35 repo = "flax"; 36 rev = "refs/tags/v${version}"; 37 + hash = "sha256-uDGTyksUZTTL6FiTJP+qteFLOjr75dcTj9yRJ6Jm8xU="; 38 }; 39 40 + build-system = [ 41 jaxlib 42 pythonRelaxDepsHook 43 setuptools-scm 44 ]; 45 46 + dependencies = [ 47 jax 48 msgpack 49 numpy
+10 -2
pkgs/development/python-modules/jax/default.nix
··· 29 in 30 buildPythonPackage rec { 31 pname = "jax"; 32 - version = "0.4.25"; 33 pyproject = true; 34 35 disabled = pythonOlder "3.9"; ··· 39 repo = "jax"; 40 # google/jax contains tags for jax and jaxlib. Only use jax tags! 41 rev = "refs/tags/jax-v${version}"; 42 - hash = "sha256-poQQo2ZgEhPYzK3aCs+BjaHTNZbezJAECd+HOdY1Yok="; 43 }; 44 45 nativeBuildInputs = [ ··· 80 "-W ignore::DeprecationWarning" 81 "tests/" 82 ]; 83 84 disabledTests = [ 85 # Exceeds tolerance when the machine is busy
··· 29 in 30 buildPythonPackage rec { 31 pname = "jax"; 32 + version = "0.4.28"; 33 pyproject = true; 34 35 disabled = pythonOlder "3.9"; ··· 39 repo = "jax"; 40 # google/jax contains tags for jax and jaxlib. Only use jax tags! 41 rev = "refs/tags/jax-v${version}"; 42 + hash = "sha256-qSHPwi3is6Ts7pz5s4KzQHBMbcjGp+vAOsejW3o36Ek="; 43 }; 44 45 nativeBuildInputs = [ ··· 80 "-W ignore::DeprecationWarning" 81 "tests/" 82 ]; 83 + 84 + # Prevents `tests/export_back_compat_test.py::CompatTest::test_*` tests from failing on darwin with 85 + # PermissionError: [Errno 13] Permission denied: '/tmp/back_compat_testdata/test_*.py' 86 + # See https://github.com/google/jax/blob/jaxlib-v0.4.27/jax/_src/internal_test_util/export_back_compat_test_util.py#L240-L241 87 + # NOTE: this doesn't seem to be an issue on linux 88 + preCheck = lib.optionalString stdenv.isDarwin '' 89 + export TEST_UNDECLARED_OUTPUTS_DIR=$(mktemp -d) 90 + ''; 91 92 disabledTests = [ 93 # Exceeds tolerance when the machine is busy
+22 -38
pkgs/development/python-modules/jaxlib/bin.nix
··· 20 , stdenv 21 # Options: 22 , cudaSupport ? config.cudaSupport 23 - , cudaPackagesGoogle 24 }: 25 26 let 27 - inherit (cudaPackagesGoogle) cudaVersion; 28 29 - version = "0.4.24"; 30 31 inherit (python) pythonVersion; 32 33 - cudaLibPath = lib.makeLibraryPath (with cudaPackagesGoogle; [ 34 cuda_cudart.lib # libcudart.so 35 cuda_cupti.lib # libcupti.so 36 cudnn.lib # libcudnn.so ··· 56 "3.9-x86_64-linux" = getSrcFromPypi { 57 platform = "manylinux2014_x86_64"; 58 dist = "cp39"; 59 - hash = "sha256-6P5ArMoLZiUkHUoQ/mJccbNj5/7el/op+Qo6cGQ33xE="; 60 }; 61 "3.9-aarch64-darwin" = getSrcFromPypi { 62 platform = "macosx_11_0_arm64"; 63 dist = "cp39"; 64 - hash = "sha256-23JQZRwMLtt7sK/JlCBqqRyfTVIAVJFN2sL+nAkQgvU="; 65 }; 66 "3.9-x86_64-darwin" = getSrcFromPypi { 67 platform = "macosx_10_14_x86_64"; 68 dist = "cp39"; 69 - hash = "sha256-OgMedn9GHGs5THZf3pkP3Aw/jJ0vL5qK1b+Lzf634Ik="; 70 }; 71 72 "3.10-x86_64-linux" = getSrcFromPypi { 73 platform = "manylinux2014_x86_64"; 74 dist = "cp310"; 75 - hash = "sha256-/VwUIIa7mTs/wLz0ArsEfNrz2pGriVVT5GX9XRFRxfY="; 76 }; 77 "3.10-aarch64-darwin" = getSrcFromPypi { 78 platform = "macosx_11_0_arm64"; 79 dist = "cp310"; 80 - hash = "sha256-LgICOyDGts840SQQJh+yOMobMASb62llvJjpGvhzrSw="; 81 }; 82 "3.10-x86_64-darwin" = getSrcFromPypi { 83 platform = "macosx_10_14_x86_64"; 84 dist = "cp310"; 85 - hash = "sha256-vhyULw+zBpz1UEi2tqgBMQEzY9a6YBgEIg6A4PPh3bQ="; 86 }; 87 88 "3.11-x86_64-linux" = getSrcFromPypi { 89 platform = "manylinux2014_x86_64"; 90 dist = "cp311"; 91 - hash = "sha256-VJO/VVwBFkOEtq4y/sLVgAV8Cung01JULiuT6W96E/8="; 92 }; 93 "3.11-aarch64-darwin" = getSrcFromPypi { 94 platform = "macosx_11_0_arm64"; 95 dist = "cp311"; 96 - hash = "sha256-VtuwXxurpSp1KI8ty1bizs5cdy8GEBN2MgS227sOCmE="; 97 }; 98 "3.11-x86_64-darwin" = getSrcFromPypi { 99 platform = "macosx_10_14_x86_64"; 100 dist = "cp311"; 101 - hash = "sha256-4Dj5dEGKb9hpg3HlVogNO1Gc9UibJhy1eym2mjivxAQ="; 102 }; 103 104 "3.12-x86_64-linux" = getSrcFromPypi { 105 platform = "manylinux2014_x86_64"; 106 dist = "cp312"; 107 - hash = "sha256-TlrGVtb3NTLmhnILWPLJR+jISCZ5SUV4wxNFpSfkCBo="; 108 }; 109 "3.12-aarch64-darwin" = getSrcFromPypi { 110 platform = "macosx_11_0_arm64"; 111 dist = "cp312"; 112 - hash = "sha256-FIwK5CGykQjteuWzLZnbtAggIxLQeGV96bXlZGEytN0="; 113 }; 114 "3.12-x86_64-darwin" = getSrcFromPypi { 115 platform = "macosx_10_14_x86_64"; 116 dist = "cp312"; 117 - hash = "sha256-9/jw/wr6oUD9pOadVAaMRL086iVMUXwVgnUMcG1UNvE="; 118 }; 119 }; 120 ··· 130 gpuSrcs = { 131 "cuda12.2-3.9" = fetchurl { 132 url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp39-cp39-manylinux2014_x86_64.whl"; 133 - hash = "sha256-xdJKLPtx+CIza2CrWKM3M0cZJzyNFVTTTsvlgh38bfM="; 134 }; 135 "cuda12.2-3.10" = fetchurl { 136 url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp310-cp310-manylinux2014_x86_64.whl"; 137 - hash = "sha256-QCjrOczD2mp+CDwVXBc0/4rJnAizeV62AK0Dpx9X6TE="; 138 }; 139 "cuda12.2-3.11" = fetchurl { 140 url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp311-cp311-manylinux2014_x86_64.whl"; 141 - hash = "sha256-Ipy3vk1yUplpNzECAFt63aOIhgEWgXG7hkoeTIk9bQQ="; 142 }; 143 "cuda12.2-3.12" = fetchurl { 144 url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp312-cp312-manylinux2014_x86_64.whl"; 145 - hash = "sha256-LSnZHaUga/8Z65iKXWBnZDk4yUpNykFTu3vukCchO6Q="; 146 - }; 147 - "cuda11.8-3.9" = fetchurl { 148 - url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp39-cp39-manylinux2014_x86_64.whl"; 149 - hash = "sha256-UmyugL0VjlXkiD7fuDPWgW8XUpr/QaP5ggp6swoZTzU="; 150 - }; 151 - "cuda11.8-3.10" = fetchurl { 152 - url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl"; 153 - hash = "sha256-luKULEiV1t/sO6eckDxddJTiOFa0dtJeDlrvp+WYmHk="; 154 - }; 155 - "cuda11.8-3.11" = fetchurl { 156 - url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp311-cp311-manylinux2014_x86_64.whl"; 157 - hash = "sha256-4+uJ8Ij6mFGEmjFEgi3fLnSLZs+v18BRoOt7mZuqydw="; 158 - }; 159 - "cuda11.8-3.12" = fetchurl { 160 - url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp312-cp312-manylinux2014_x86_64.whl"; 161 - hash = "sha256-bUDFb94Ar/65SzzR9RLIs/SL/HdjaPT1Su5whmjkS00="; 162 }; 163 }; 164 ··· 213 # for more info. 214 postInstall = lib.optional cudaSupport '' 215 mkdir -p $out/${python.sitePackages}/jaxlib/cuda/bin 216 - ln -s ${lib.getExe' cudaPackagesGoogle.cuda_nvcc "ptxas"} $out/${python.sitePackages}/jaxlib/cuda/bin/ptxas 217 ''; 218 219 inherit (jaxlib-build) pythonImportsCheck; ··· 227 platforms = [ "aarch64-darwin" "x86_64-linux" "x86_64-darwin" ]; 228 broken = 229 !(cudaSupport -> lib.versionAtLeast cudaVersion "11.1") 230 - || !(cudaSupport -> lib.versionAtLeast cudaPackagesGoogle.cudnn.version "8.2") 231 || !(cudaSupport -> stdenv.isLinux) 232 || !(cudaSupport -> (gpuSrcs ? "cuda${cudaVersion}-${pythonVersion}")) 233 # Fails at pythonImportsCheckPhase:
··· 20 , stdenv 21 # Options: 22 , cudaSupport ? config.cudaSupport 23 + , cudaPackages 24 }: 25 26 let 27 + inherit (cudaPackages) cudaVersion; 28 29 + version = "0.4.28"; 30 31 inherit (python) pythonVersion; 32 33 + cudaLibPath = lib.makeLibraryPath (with cudaPackages; [ 34 cuda_cudart.lib # libcudart.so 35 cuda_cupti.lib # libcupti.so 36 cudnn.lib # libcudnn.so ··· 56 "3.9-x86_64-linux" = getSrcFromPypi { 57 platform = "manylinux2014_x86_64"; 58 dist = "cp39"; 59 + hash = "sha256-Slbr8FtKTBeRaZ2HTgcvP4CPCYa0AQsU+1SaackMqdw="; 60 }; 61 "3.9-aarch64-darwin" = getSrcFromPypi { 62 platform = "macosx_11_0_arm64"; 63 dist = "cp39"; 64 + hash = "sha256-sBVi7IrXVxm30DiXUkiel+trTctMjBE75JFjTVKCrTw="; 65 }; 66 "3.9-x86_64-darwin" = getSrcFromPypi { 67 platform = "macosx_10_14_x86_64"; 68 dist = "cp39"; 69 + hash = "sha256-T5jMg3srbG3P4Kt/+esQkxSSCUYRmqOvn6oTlxj/J4c="; 70 }; 71 72 "3.10-x86_64-linux" = getSrcFromPypi { 73 platform = "manylinux2014_x86_64"; 74 dist = "cp310"; 75 + hash = "sha256-47zcb45g+FVPQVwU2TATTmAuPKM8OOVGJ0/VRfh1dps="; 76 }; 77 "3.10-aarch64-darwin" = getSrcFromPypi { 78 platform = "macosx_11_0_arm64"; 79 dist = "cp310"; 80 + hash = "sha256-8Djmi9ENGjVUcisLvjbmpEg4RDenWqnSg/aW8O2fjAk="; 81 }; 82 "3.10-x86_64-darwin" = getSrcFromPypi { 83 platform = "macosx_10_14_x86_64"; 84 dist = "cp310"; 85 + hash = "sha256-pCHSN/jCXShQFm0zRgPGc925tsJvUrxJZwS4eCKXvWY="; 86 }; 87 88 "3.11-x86_64-linux" = getSrcFromPypi { 89 platform = "manylinux2014_x86_64"; 90 dist = "cp311"; 91 + hash = "sha256-Rc4PPIQM/4I2z/JsN/Jsn/B4aV+T4MFiwyDCgfUEEnU="; 92 }; 93 "3.11-aarch64-darwin" = getSrcFromPypi { 94 platform = "macosx_11_0_arm64"; 95 dist = "cp311"; 96 + hash = "sha256-eThX+vN/Nxyv51L+pfyBH0NeQ7j7S1AgWERKf17M+Ck="; 97 }; 98 "3.11-x86_64-darwin" = getSrcFromPypi { 99 platform = "macosx_10_14_x86_64"; 100 dist = "cp311"; 101 + hash = "sha256-L/gpDtx7ksfq5SUX9lSSYz4mey6QZ7rT5MMj0hPnfPU="; 102 }; 103 104 "3.12-x86_64-linux" = getSrcFromPypi { 105 platform = "manylinux2014_x86_64"; 106 dist = "cp312"; 107 + hash = "sha256-RqGqhX9P7uikP8upXA4Kti1AwmzJcwtsaWVZCLo1n40="; 108 }; 109 "3.12-aarch64-darwin" = getSrcFromPypi { 110 platform = "macosx_11_0_arm64"; 111 dist = "cp312"; 112 + hash = "sha256-jdi//jhTcC9jzZJNoO4lc0pNGc1ckmvgM9dyun0cF10="; 113 }; 114 "3.12-x86_64-darwin" = getSrcFromPypi { 115 platform = "macosx_10_14_x86_64"; 116 dist = "cp312"; 117 + hash = "sha256-1sCaVFMpciRhrwVuc1FG0sjHTCKsdCaoRetp8ya096A="; 118 }; 119 }; 120 ··· 130 gpuSrcs = { 131 "cuda12.2-3.9" = fetchurl { 132 url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp39-cp39-manylinux2014_x86_64.whl"; 133 + hash = "sha256-d8LIl22gIvmWfoyKfXKElZJXicPQIZxdS4HumhwQGCw="; 134 }; 135 "cuda12.2-3.10" = fetchurl { 136 url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp310-cp310-manylinux2014_x86_64.whl"; 137 + hash = "sha256-PXtWv+UEcMWF8LhWe6Z1UGkf14PG3dkJ0Iop0LiimnQ="; 138 }; 139 "cuda12.2-3.11" = fetchurl { 140 url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp311-cp311-manylinux2014_x86_64.whl"; 141 + hash = "sha256-QO2WSOzmJ48VaCha596mELiOfPsAGLpGctmdzcCHE/o="; 142 }; 143 "cuda12.2-3.12" = fetchurl { 144 url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp312-cp312-manylinux2014_x86_64.whl"; 145 + hash = "sha256-ixWMaIChy4Ammsn23/3cCoala0lFibuUxyUr3tjfFKU="; 146 }; 147 }; 148 ··· 197 # for more info. 198 postInstall = lib.optional cudaSupport '' 199 mkdir -p $out/${python.sitePackages}/jaxlib/cuda/bin 200 + ln -s ${lib.getExe' cudaPackages.cuda_nvcc "ptxas"} $out/${python.sitePackages}/jaxlib/cuda/bin/ptxas 201 ''; 202 203 inherit (jaxlib-build) pythonImportsCheck; ··· 211 platforms = [ "aarch64-darwin" "x86_64-linux" "x86_64-darwin" ]; 212 broken = 213 !(cudaSupport -> lib.versionAtLeast cudaVersion "11.1") 214 + || !(cudaSupport -> lib.versionAtLeast cudaPackages.cudnn.version "8.2") 215 || !(cudaSupport -> stdenv.isLinux) 216 || !(cudaSupport -> (gpuSrcs ? "cuda${cudaVersion}-${pythonVersion}")) 217 # Fails at pythonImportsCheckPhase:
+15 -25
pkgs/development/python-modules/jaxlib/default.nix
··· 13 , curl 14 , cython 15 , fetchFromGitHub 16 - , fetchpatch 17 , git 18 , IOKit 19 , jsoncpp ··· 45 , config 46 # CUDA flags: 47 , cudaSupport ? config.cudaSupport 48 - , cudaPackagesGoogle 49 50 # MKL: 51 , mklSupport ? true 52 }@inputs: 53 54 let 55 - inherit (cudaPackagesGoogle) cudaFlags cudaVersion cudnn nccl; 56 57 pname = "jaxlib"; 58 - version = "0.4.24"; 59 60 # It's necessary to consistently use backendStdenv when building with CUDA 61 # support, otherwise we get libstdc++ errors downstream 62 stdenv = throw "Use effectiveStdenv instead"; 63 - effectiveStdenv = if cudaSupport then cudaPackagesGoogle.backendStdenv else inputs.stdenv; 64 65 meta = with lib; { 66 description = "JAX is Autograd and XLA, brought together for high-performance machine learning research."; ··· 78 # These are necessary at build time and run time. 79 cuda_libs_joined = symlinkJoin { 80 name = "cuda-joined"; 81 - paths = with cudaPackagesGoogle; [ 82 cuda_cudart.lib # libcudart.so 83 cuda_cudart.static # libcudart_static.a 84 cuda_cupti.lib # libcupti.so ··· 92 # These are only necessary at build time. 93 cuda_build_deps_joined = symlinkJoin { 94 name = "cuda-build-deps-joined"; 95 - paths = with cudaPackagesGoogle; [ 96 cuda_libs_joined 97 98 # Binaries 99 - cudaPackagesGoogle.cuda_nvcc.bin # nvcc 100 101 # Headers 102 cuda_cccl.dev # block_load.cuh ··· 181 owner = "openxla"; 182 repo = "xla"; 183 # Update this according to https://github.com/google/jax/blob/jaxlib-v${version}/third_party/xla/workspace.bzl. 184 - rev = "12eee889e1f2ad41e27d7b0e970cb92d282d3ec5"; 185 - hash = "sha256-68kjjgwYjRlcT0TVJo9BN6s+WTkdu5UMJqQcfHpBT90="; 186 }; 187 188 - patches = [ 189 - # Resolves "could not convert ‘result’ from ‘SmallVector<[...],6>’ to 190 - # ‘SmallVector<[...],4>’" compilation error. See https://github.com/google/jax/issues/19814#issuecomment-1945141259. 191 - (fetchpatch { 192 - url = "https://github.com/openxla/xla/commit/7a614cd346594fc7ea2fe75570c9c53a4a444f60.patch"; 193 - hash = "sha256-RtuQTH8wzNiJcOtISLhf+gMlH1gg8hekvxEB+4wX6BM="; 194 - }) 195 - ]; 196 - 197 dontBuild = true; 198 199 # This is necessary for patchShebangs to know the right path to use. ··· 220 repo = "jax"; 221 # google/jax contains tags for jax and jaxlib. Only use jaxlib tags! 222 rev = "refs/tags/${pname}-v${version}"; 223 - hash = "sha256-hmx7eo3pephc6BQfoJ3U0QwWBWmhkAc+7S4QmW32qQs="; 224 }; 225 226 nativeBuildInputs = [ ··· 364 ]; 365 366 sha256 = (if cudaSupport then { 367 - x86_64-linux = "sha256-8JilAoTbqOjOOJa/Zc/n/quaEDcpdcLXCNb34mfB+OM="; 368 } else { 369 - x86_64-linux = "sha256-iqS+I1FQLNWXNMsA20cJp7YkyGUeshee5b2QfRBNZtk="; 370 - aarch64-linux = "sha256-qmJ0Fm/VGMTmko4PhKs1P8/GLEJmVxb8xg+ss/HsakY=="; 371 }).${effectiveStdenv.system} or (throw "jaxlib: unsupported system: ${effectiveStdenv.system}"); 372 }; 373 ··· 414 # for more info. 415 postInstall = lib.optionalString cudaSupport '' 416 mkdir -p $out/bin 417 - ln -s ${cudaPackagesGoogle.cuda_nvcc.bin}/bin/ptxas $out/bin/ptxas 418 419 find $out -type f \( -name '*.so' -or -name '*.so.*' \) | while read lib; do 420 patchelf --add-rpath "${lib.makeLibraryPath [cuda_libs_joined cudnn nccl]}" "$lib" ··· 423 424 nativeBuildInputs = lib.optionals cudaSupport [ autoAddDriverRunpath ]; 425 426 - propagatedBuildInputs = [ 427 absl-py 428 curl 429 double-conversion
··· 13 , curl 14 , cython 15 , fetchFromGitHub 16 , git 17 , IOKit 18 , jsoncpp ··· 44 , config 45 # CUDA flags: 46 , cudaSupport ? config.cudaSupport 47 + , cudaPackages 48 49 # MKL: 50 , mklSupport ? true 51 }@inputs: 52 53 let 54 + inherit (cudaPackages) cudaFlags cudaVersion cudnn nccl; 55 56 pname = "jaxlib"; 57 + version = "0.4.28"; 58 59 # It's necessary to consistently use backendStdenv when building with CUDA 60 # support, otherwise we get libstdc++ errors downstream 61 stdenv = throw "Use effectiveStdenv instead"; 62 + effectiveStdenv = if cudaSupport then cudaPackages.backendStdenv else inputs.stdenv; 63 64 meta = with lib; { 65 description = "JAX is Autograd and XLA, brought together for high-performance machine learning research."; ··· 77 # These are necessary at build time and run time. 78 cuda_libs_joined = symlinkJoin { 79 name = "cuda-joined"; 80 + paths = with cudaPackages; [ 81 cuda_cudart.lib # libcudart.so 82 cuda_cudart.static # libcudart_static.a 83 cuda_cupti.lib # libcupti.so ··· 91 # These are only necessary at build time. 92 cuda_build_deps_joined = symlinkJoin { 93 name = "cuda-build-deps-joined"; 94 + paths = with cudaPackages; [ 95 cuda_libs_joined 96 97 # Binaries 98 + cudaPackages.cuda_nvcc.bin # nvcc 99 100 # Headers 101 cuda_cccl.dev # block_load.cuh ··· 180 owner = "openxla"; 181 repo = "xla"; 182 # Update this according to https://github.com/google/jax/blob/jaxlib-v${version}/third_party/xla/workspace.bzl. 183 + rev = "e8247c3ea1d4d7f31cf27def4c7ac6f2ce64ecd4"; 184 + hash = "sha256-ZhgMIVs3Z4dTrkRWDqaPC/i7yJz2dsYXrZbjzqvPX3E="; 185 }; 186 187 dontBuild = true; 188 189 # This is necessary for patchShebangs to know the right path to use. ··· 210 repo = "jax"; 211 # google/jax contains tags for jax and jaxlib. Only use jaxlib tags! 212 rev = "refs/tags/${pname}-v${version}"; 213 + hash = "sha256-qSHPwi3is6Ts7pz5s4KzQHBMbcjGp+vAOsejW3o36Ek="; 214 }; 215 216 nativeBuildInputs = [ ··· 354 ]; 355 356 sha256 = (if cudaSupport then { 357 + x86_64-linux = "sha256-VGNMf5/DgXbgsu1w5J1Pmrukw+7UO31BNU+crKVsX5k="; 358 } else { 359 + x86_64-linux = "sha256-uOoAyMBLHPX6jzdN43b5wZV5eW0yI8sCDD7BSX2h4oQ="; 360 + aarch64-linux = "sha256-+SnGKY9LIT1Qhu/x6Uh7sHRaAEjlc//qyKj1m4t16PA="; 361 }).${effectiveStdenv.system} or (throw "jaxlib: unsupported system: ${effectiveStdenv.system}"); 362 }; 363 ··· 404 # for more info. 405 postInstall = lib.optionalString cudaSupport '' 406 mkdir -p $out/bin 407 + ln -s ${cudaPackages.cuda_nvcc.bin}/bin/ptxas $out/bin/ptxas 408 409 find $out -type f \( -name '*.so' -or -name '*.so.*' \) | while read lib; do 410 patchelf --add-rpath "${lib.makeLibraryPath [cuda_libs_joined cudnn nccl]}" "$lib" ··· 413 414 nativeBuildInputs = lib.optionals cudaSupport [ autoAddDriverRunpath ]; 415 416 + dependencies = [ 417 absl-py 418 curl 419 double-conversion
+17 -3
pkgs/development/python-modules/jaxopt/default.nix
··· 6 , fetchpatch 7 , pytest-xdist 8 , pytestCheckHook 9 , absl-py 10 , cvxpy 11 , jax ··· 20 buildPythonPackage rec { 21 pname = "jaxopt"; 22 version = "0.8.3"; 23 - format = "setuptools"; 24 25 disabled = pythonOlder "3.8"; 26 ··· 41 }) 42 ]; 43 44 - propagatedBuildInputs = [ 45 absl-py 46 jax 47 jaxlib ··· 66 "jaxopt.tree_util" 67 ]; 68 69 - disabledTests = lib.optionals (stdenv.isLinux && stdenv.isAarch64) [ 70 # https://github.com/google/jaxopt/issues/577 71 "test_binary_logit_log_likelihood" 72 "test_solve_sparse" 73 "test_logreg_with_intercept_manual_loop3" 74 ]; 75 76 meta = with lib; {
··· 6 , fetchpatch 7 , pytest-xdist 8 , pytestCheckHook 9 + , setuptools 10 , absl-py 11 , cvxpy 12 , jax ··· 21 buildPythonPackage rec { 22 pname = "jaxopt"; 23 version = "0.8.3"; 24 + pyproject = true; 25 26 disabled = pythonOlder "3.8"; 27 ··· 42 }) 43 ]; 44 45 + build-system = [ 46 + setuptools 47 + ]; 48 + 49 + dependencies = [ 50 absl-py 51 jax 52 jaxlib ··· 71 "jaxopt.tree_util" 72 ]; 73 74 + disabledTests = [ 75 + # https://github.com/google/jaxopt/issues/592 76 + "test_solve_sparse" 77 + ] ++ lib.optionals (stdenv.isLinux && stdenv.isAarch64) [ 78 # https://github.com/google/jaxopt/issues/577 79 "test_binary_logit_log_likelihood" 80 "test_solve_sparse" 81 "test_logreg_with_intercept_manual_loop3" 82 + 83 + # https://github.com/google/jaxopt/issues/593 84 + # Makes the test suite crash 85 + "test_dtype_consistency" 86 + # AssertionError: Array(0.01411963, dtype=float32) not less than or equal to 0.01 87 + "test_multiclass_logreg6" 88 ]; 89 90 meta = with lib; {
+4 -2
pkgs/development/python-modules/nanobind/default.nix
··· 51 scipy 52 torch 53 tensorflow 54 - jax 55 - jaxlib 56 ]; 57 58 meta = with lib; {
··· 51 scipy 52 torch 53 tensorflow 54 + # Uncomment at next release (1.9.3) 55 + # See https://github.com/wjakob/nanobind/issues/578 56 + # jax 57 + # jaxlib 58 ]; 59 60 meta = with lib; {
+7 -3
pkgs/development/python-modules/objax/default.nix
··· 1 { lib 2 , buildPythonPackage 3 , fetchFromGitHub 4 - , fetchpatch 5 , jax 6 , jaxlib 7 , keras ··· 30 hash = "sha256-WD+pmR8cEay4iziRXqF3sHUzCMBjmLJ3wZ3iYOD+hzk="; 31 }; 32 33 - nativeBuildInputs = [ 34 setuptools 35 ]; 36 ··· 40 jaxlib 41 ]; 42 43 - propagatedBuildInputs = [ 44 jax 45 numpy 46 parameterized
··· 1 { lib 2 , buildPythonPackage 3 , fetchFromGitHub 4 , jax 5 , jaxlib 6 , keras ··· 29 hash = "sha256-WD+pmR8cEay4iziRXqF3sHUzCMBjmLJ3wZ3iYOD+hzk="; 30 }; 31 32 + patches = [ 33 + # Issue reported upstream: https://github.com/google/objax/issues/270 34 + ./replace-deprecated-device_buffers.patch 35 + ]; 36 + 37 + build-system = [ 38 setuptools 39 ]; 40 ··· 44 jaxlib 45 ]; 46 47 + dependencies = [ 48 jax 49 numpy 50 parameterized
+14
pkgs/development/python-modules/objax/replace-deprecated-device_buffers.patch
···
··· 1 + diff --git a/objax/util/util.py b/objax/util/util.py 2 + index c31a356..344cf9a 100644 3 + --- a/objax/util/util.py 4 + +++ b/objax/util/util.py 5 + @@ -117,7 +117,8 @@ def get_local_devices(): 6 + if _local_devices is None: 7 + x = jn.zeros((jax.local_device_count(), 1), dtype=jn.float32) 8 + sharded_x = map_to_device(x) 9 + - _local_devices = [b.device() for b in sharded_x.device_buffers] 10 + + device_buffers = [buf.data for buf in sharded_x.addressable_shards] 11 + + _local_devices = [list(b.devices())[0] for b in device_buffers] 12 + return _local_devices 13 + 14 +
+2 -6
pkgs/development/python-modules/tensorflow/bin.nix
··· 22 , tensorboard 23 , config 24 , cudaSupport ? config.cudaSupport 25 - , cudaPackagesGoogle 26 , zlib 27 , python 28 , keras-applications ··· 43 44 let 45 packages = import ./binary-hashes.nix; 46 - inherit (cudaPackagesGoogle) cudatoolkit cudnn; 47 in buildPythonPackage { 48 pname = "tensorflow" + lib.optionalString cudaSupport "-gpu"; 49 inherit (packages) version; ··· 198 "tensorflow.python" 199 "tensorflow.python.framework" 200 ]; 201 - 202 - passthru = { 203 - cudaPackages = cudaPackagesGoogle; 204 - }; 205 206 meta = with lib; { 207 description = "Computation using data flow graphs for scalable machine learning";
··· 22 , tensorboard 23 , config 24 , cudaSupport ? config.cudaSupport 25 + , cudaPackages 26 , zlib 27 , python 28 , keras-applications ··· 43 44 let 45 packages = import ./binary-hashes.nix; 46 + inherit (cudaPackages) cudatoolkit cudnn; 47 in buildPythonPackage { 48 pname = "tensorflow" + lib.optionalString cudaSupport "-gpu"; 49 inherit (packages) version; ··· 198 "tensorflow.python" 199 "tensorflow.python.framework" 200 ]; 201 202 meta = with lib; { 203 description = "Computation using data flow graphs for scalable machine learning";
+7 -8
pkgs/development/python-modules/tensorflow/default.nix
··· 19 # https://groups.google.com/a/tensorflow.org/forum/#!topic/developers/iRCt5m4qUz0 20 , config 21 , cudaSupport ? config.cudaSupport 22 - , cudaPackagesGoogle 23 - , cudaCapabilities ? cudaPackagesGoogle.cudaFlags.cudaCapabilities 24 , mklSupport ? false, mkl 25 , tensorboardSupport ? true 26 # XLA without CUDA is broken ··· 50 # __ZN4llvm11SmallPtrSetIPKNS_10AllocaInstELj8EED1Ev in any of the 51 # translation units, so the build fails at link time 52 stdenv = 53 - if cudaSupport then cudaPackagesGoogle.backendStdenv 54 else if originalStdenv.isDarwin then llvmPackages.stdenv 55 else originalStdenv; 56 - inherit (cudaPackagesGoogle) cudatoolkit nccl; 57 # use compatible cuDNN (https://www.tensorflow.org/install/source#gpu) 58 # cudaPackages.cudnn led to this: 59 # https://github.com/tensorflow/tensorflow/issues/60398 60 cudnnAttribute = "cudnn_8_6"; 61 - cudnn = cudaPackagesGoogle.${cudnnAttribute}; 62 gentoo-patches = fetchzip { 63 url = "https://dev.gentoo.org/~perfinion/patches/tensorflow-patches-2.12.0.tar.bz2"; 64 hash = "sha256-SCRX/5/zML7LmKEPJkcM5Tebez9vv/gmE4xhT/jyqWs="; ··· 490 broken = 491 stdenv.isDarwin 492 || !(xlaSupport -> cudaSupport) 493 - || !(cudaSupport -> builtins.hasAttr cudnnAttribute cudaPackagesGoogle) 494 - || !(cudaSupport -> cudaPackagesGoogle ? cudatoolkit); 495 } // lib.optionalAttrs stdenv.isDarwin { 496 timeout = 86400; # 24 hours 497 maxSilent = 14400; # 4h, double the default of 7200s ··· 594 # Regression test for #77626 removed because not more `tensorflow.contrib`. 595 596 passthru = { 597 - cudaPackages = cudaPackagesGoogle; 598 deps = bazel-build.deps; 599 libtensorflow = bazel-build.out; 600 };
··· 19 # https://groups.google.com/a/tensorflow.org/forum/#!topic/developers/iRCt5m4qUz0 20 , config 21 , cudaSupport ? config.cudaSupport 22 + , cudaPackages 23 + , cudaCapabilities ? cudaPackages.cudaFlags.cudaCapabilities 24 , mklSupport ? false, mkl 25 , tensorboardSupport ? true 26 # XLA without CUDA is broken ··· 50 # __ZN4llvm11SmallPtrSetIPKNS_10AllocaInstELj8EED1Ev in any of the 51 # translation units, so the build fails at link time 52 stdenv = 53 + if cudaSupport then cudaPackages.backendStdenv 54 else if originalStdenv.isDarwin then llvmPackages.stdenv 55 else originalStdenv; 56 + inherit (cudaPackages) cudatoolkit nccl; 57 # use compatible cuDNN (https://www.tensorflow.org/install/source#gpu) 58 # cudaPackages.cudnn led to this: 59 # https://github.com/tensorflow/tensorflow/issues/60398 60 cudnnAttribute = "cudnn_8_6"; 61 + cudnn = cudaPackages.${cudnnAttribute}; 62 gentoo-patches = fetchzip { 63 url = "https://dev.gentoo.org/~perfinion/patches/tensorflow-patches-2.12.0.tar.bz2"; 64 hash = "sha256-SCRX/5/zML7LmKEPJkcM5Tebez9vv/gmE4xhT/jyqWs="; ··· 490 broken = 491 stdenv.isDarwin 492 || !(xlaSupport -> cudaSupport) 493 + || !(cudaSupport -> builtins.hasAttr cudnnAttribute cudaPackages) 494 + || !(cudaSupport -> cudaPackages ? cudatoolkit); 495 } // lib.optionalAttrs stdenv.isDarwin { 496 timeout = 86400; # 24 hours 497 maxSilent = 14400; # 4h, double the default of 7200s ··· 594 # Regression test for #77626 removed because not more `tensorflow.contrib`. 595 596 passthru = { 597 deps = bazel-build.deps; 598 libtensorflow = bazel-build.out; 599 };
-1
pkgs/test/cuda/default.nix
··· 3 recurseIntoAttrs, 4 5 cudaPackages, 6 - cudaPackagesGoogle, 7 8 cudaPackages_10_0, 9 cudaPackages_10_1,
··· 3 recurseIntoAttrs, 4 5 cudaPackages, 6 7 cudaPackages_10_0, 8 cudaPackages_10_1,
-4
pkgs/top-level/all-packages.nix
··· 7125 cudaPackages_12_3 = callPackage ./cuda-packages.nix { cudaVersion = "12.3"; }; 7126 cudaPackages_12 = cudaPackages_12_2; # Latest supported by cudnn 7127 7128 - # Use the older cudaPackages for tensorflow and jax, as determined by cudnn 7129 - # compatibility: https://www.tensorflow.org/install/source#gpu 7130 - cudaPackagesGoogle = cudaPackages_11; 7131 - 7132 cudaPackages = recurseIntoAttrs cudaPackages_12; 7133 7134 # TODO: move to alias
··· 7125 cudaPackages_12_3 = callPackage ./cuda-packages.nix { cudaVersion = "12.3"; }; 7126 cudaPackages_12 = cudaPackages_12_2; # Latest supported by cudnn 7127 7128 cudaPackages = recurseIntoAttrs cudaPackages_12; 7129 7130 # TODO: move to alias
+5
pkgs/top-level/python-packages.nix
··· 14885 14886 tensorflow-bin = callPackage ../development/python-modules/tensorflow/bin.nix { 14887 inherit (pkgs.config) cudaSupport; 14888 }; 14889 14890 tensorflow-build = let ··· 14892 protobufTF = pkgs.protobuf_21.override { 14893 abseil-cpp = pkgs.abseil-cpp_202301; 14894 }; 14895 grpcTF = (pkgs.grpc.overrideAttrs ( 14896 oldAttrs: rec { 14897 # nvcc fails on recent grpc versions, so we use the latest patch level ··· 14937 inherit (pkgs.darwin.apple_sdk.frameworks) Foundation Security; 14938 flatbuffers-core = pkgs.flatbuffers; 14939 flatbuffers-python = self.flatbuffers; 14940 protobuf-core = compat.protobufTF; 14941 protobuf-python = compat.protobuf-pythonTF; 14942 grpc = compat.grpcTF;
··· 14885 14886 tensorflow-bin = callPackage ../development/python-modules/tensorflow/bin.nix { 14887 inherit (pkgs.config) cudaSupport; 14888 + # https://www.tensorflow.org/install/source#gpu 14889 + cudaPackages = pkgs.cudaPackages_11; 14890 }; 14891 14892 tensorflow-build = let ··· 14894 protobufTF = pkgs.protobuf_21.override { 14895 abseil-cpp = pkgs.abseil-cpp_202301; 14896 }; 14897 + # https://www.tensorflow.org/install/source#gpu 14898 + cudaPackagesTF = pkgs.cudaPackages_11; 14899 grpcTF = (pkgs.grpc.overrideAttrs ( 14900 oldAttrs: rec { 14901 # nvcc fails on recent grpc versions, so we use the latest patch level ··· 14941 inherit (pkgs.darwin.apple_sdk.frameworks) Foundation Security; 14942 flatbuffers-core = pkgs.flatbuffers; 14943 flatbuffers-python = self.flatbuffers; 14944 + cudaPackages = compat.cudaPackagesTF; 14945 protobuf-core = compat.protobufTF; 14946 protobuf-python = compat.protobuf-pythonTF; 14947 grpc = compat.grpcTF;