Merge pull request #291705 from GaetanLepage/jax

python311Packages.{jax,jaxlib,jaxlib-bin}: 0.4.24 -> 0.4.28

authored by Samuel Ainsworth and committed by GitHub 3a993d32 8fe1aa68

+128 -100
+6 -2
pkgs/development/python-modules/blackjax/default.nix
··· 16 16 17 17 buildPythonPackage rec { 18 18 pname = "blackjax"; 19 - version = "1.2.0"; 19 + version = "1.2.1"; 20 20 pyproject = true; 21 21 22 22 disabled = pythonOlder "3.9"; ··· 25 25 owner = "blackjax-devs"; 26 26 repo = "blackjax"; 27 27 rev = "refs/tags/${version}"; 28 - hash = "sha256-vXyxK3xALKG61YGK7fmoqQNGfOiagHFrvnU02WKZThw="; 28 + hash = "sha256-VoWBCjFMyE5LVJyf7du/pKlnvDHj22lguiP6ZUzH9ak="; 29 29 }; 30 30 31 31 build-system = [ ··· 56 56 disabledTests = [ 57 57 # too slow 58 58 "test_adaptive_tempered_smc" 59 + ] ++ lib.optionals (stdenv.isLinux && stdenv.isAarch64) [ 60 + # Numerical test (AssertionError) 61 + # https://github.com/blackjax-devs/blackjax/issues/668 62 + "test_chees_adaptation" 59 63 ]; 60 64 61 65 pythonImportsCheck = [
+15 -2
pkgs/development/python-modules/equinox/default.nix
··· 48 48 pythonImportsCheck = [ "equinox" ]; 49 49 50 50 disabledTests = [ 51 - # Failed: DID NOT WARN. No warnings of type (<class 'UserWarning'>,) were emitted. 52 - "test_tracetime" 51 + # For simplicity, JAX has removed its internal frames from the traceback of the following exception. 52 + # https://github.com/patrick-kidger/equinox/issues/716 53 + "test_abstract" 54 + "test_complicated" 55 + "test_grad" 56 + "test_jvp" 57 + "test_mlp" 58 + "test_num_traces" 59 + "test_pytree_in" 60 + "test_simple" 61 + "test_vmap" 62 + 63 + # AssertionError: assert 'foo:\n pri...pe=float32)\n' == 'foo:\n pri...pe=float32)\n' 64 + # Also reported in patrick-kidger/equinox#716 65 + "test_backward_nan" 53 66 ]; 54 67 55 68 meta = with lib; {
+4 -4
pkgs/development/python-modules/flax/default.nix
··· 25 25 26 26 buildPythonPackage rec { 27 27 pname = "flax"; 28 - version = "0.8.2"; 28 + version = "0.8.3"; 29 29 pyproject = true; 30 30 31 31 disabled = pythonOlder "3.9"; ··· 34 34 owner = "google"; 35 35 repo = "flax"; 36 36 rev = "refs/tags/v${version}"; 37 - hash = "sha256-UABgJGe1grUSkwOJpjeIoFqhXsqG//HlC1YyYPxXV+g="; 37 + hash = "sha256-uDGTyksUZTTL6FiTJP+qteFLOjr75dcTj9yRJ6Jm8xU="; 38 38 }; 39 39 40 - nativeBuildInputs = [ 40 + build-system = [ 41 41 jaxlib 42 42 pythonRelaxDepsHook 43 43 setuptools-scm 44 44 ]; 45 45 46 - propagatedBuildInputs = [ 46 + dependencies = [ 47 47 jax 48 48 msgpack 49 49 numpy
+10 -2
pkgs/development/python-modules/jax/default.nix
··· 29 29 in 30 30 buildPythonPackage rec { 31 31 pname = "jax"; 32 - version = "0.4.25"; 32 + version = "0.4.28"; 33 33 pyproject = true; 34 34 35 35 disabled = pythonOlder "3.9"; ··· 39 39 repo = "jax"; 40 40 # google/jax contains tags for jax and jaxlib. Only use jax tags! 41 41 rev = "refs/tags/jax-v${version}"; 42 - hash = "sha256-poQQo2ZgEhPYzK3aCs+BjaHTNZbezJAECd+HOdY1Yok="; 42 + hash = "sha256-qSHPwi3is6Ts7pz5s4KzQHBMbcjGp+vAOsejW3o36Ek="; 43 43 }; 44 44 45 45 nativeBuildInputs = [ ··· 80 80 "-W ignore::DeprecationWarning" 81 81 "tests/" 82 82 ]; 83 + 84 + # Prevents `tests/export_back_compat_test.py::CompatTest::test_*` tests from failing on darwin with 85 + # PermissionError: [Errno 13] Permission denied: '/tmp/back_compat_testdata/test_*.py' 86 + # See https://github.com/google/jax/blob/jaxlib-v0.4.27/jax/_src/internal_test_util/export_back_compat_test_util.py#L240-L241 87 + # NOTE: this doesn't seem to be an issue on linux 88 + preCheck = lib.optionalString stdenv.isDarwin '' 89 + export TEST_UNDECLARED_OUTPUTS_DIR=$(mktemp -d) 90 + ''; 83 91 84 92 disabledTests = [ 85 93 # Exceeds tolerance when the machine is busy
+22 -38
pkgs/development/python-modules/jaxlib/bin.nix
··· 20 20 , stdenv 21 21 # Options: 22 22 , cudaSupport ? config.cudaSupport 23 - , cudaPackagesGoogle 23 + , cudaPackages 24 24 }: 25 25 26 26 let 27 - inherit (cudaPackagesGoogle) cudaVersion; 27 + inherit (cudaPackages) cudaVersion; 28 28 29 - version = "0.4.24"; 29 + version = "0.4.28"; 30 30 31 31 inherit (python) pythonVersion; 32 32 33 - cudaLibPath = lib.makeLibraryPath (with cudaPackagesGoogle; [ 33 + cudaLibPath = lib.makeLibraryPath (with cudaPackages; [ 34 34 cuda_cudart.lib # libcudart.so 35 35 cuda_cupti.lib # libcupti.so 36 36 cudnn.lib # libcudnn.so ··· 56 56 "3.9-x86_64-linux" = getSrcFromPypi { 57 57 platform = "manylinux2014_x86_64"; 58 58 dist = "cp39"; 59 - hash = "sha256-6P5ArMoLZiUkHUoQ/mJccbNj5/7el/op+Qo6cGQ33xE="; 59 + hash = "sha256-Slbr8FtKTBeRaZ2HTgcvP4CPCYa0AQsU+1SaackMqdw="; 60 60 }; 61 61 "3.9-aarch64-darwin" = getSrcFromPypi { 62 62 platform = "macosx_11_0_arm64"; 63 63 dist = "cp39"; 64 - hash = "sha256-23JQZRwMLtt7sK/JlCBqqRyfTVIAVJFN2sL+nAkQgvU="; 64 + hash = "sha256-sBVi7IrXVxm30DiXUkiel+trTctMjBE75JFjTVKCrTw="; 65 65 }; 66 66 "3.9-x86_64-darwin" = getSrcFromPypi { 67 67 platform = "macosx_10_14_x86_64"; 68 68 dist = "cp39"; 69 - hash = "sha256-OgMedn9GHGs5THZf3pkP3Aw/jJ0vL5qK1b+Lzf634Ik="; 69 + hash = "sha256-T5jMg3srbG3P4Kt/+esQkxSSCUYRmqOvn6oTlxj/J4c="; 70 70 }; 71 71 72 72 "3.10-x86_64-linux" = getSrcFromPypi { 73 73 platform = "manylinux2014_x86_64"; 74 74 dist = "cp310"; 75 - hash = "sha256-/VwUIIa7mTs/wLz0ArsEfNrz2pGriVVT5GX9XRFRxfY="; 75 + hash = "sha256-47zcb45g+FVPQVwU2TATTmAuPKM8OOVGJ0/VRfh1dps="; 76 76 }; 77 77 "3.10-aarch64-darwin" = getSrcFromPypi { 78 78 platform = "macosx_11_0_arm64"; 79 79 dist = "cp310"; 80 - hash = "sha256-LgICOyDGts840SQQJh+yOMobMASb62llvJjpGvhzrSw="; 80 + hash = "sha256-8Djmi9ENGjVUcisLvjbmpEg4RDenWqnSg/aW8O2fjAk="; 81 81 }; 82 82 "3.10-x86_64-darwin" = getSrcFromPypi { 83 83 platform = "macosx_10_14_x86_64"; 84 84 dist = "cp310"; 85 - hash = "sha256-vhyULw+zBpz1UEi2tqgBMQEzY9a6YBgEIg6A4PPh3bQ="; 85 + hash = "sha256-pCHSN/jCXShQFm0zRgPGc925tsJvUrxJZwS4eCKXvWY="; 86 86 }; 87 87 88 88 "3.11-x86_64-linux" = getSrcFromPypi { 89 89 platform = "manylinux2014_x86_64"; 90 90 dist = "cp311"; 91 - hash = "sha256-VJO/VVwBFkOEtq4y/sLVgAV8Cung01JULiuT6W96E/8="; 91 + hash = "sha256-Rc4PPIQM/4I2z/JsN/Jsn/B4aV+T4MFiwyDCgfUEEnU="; 92 92 }; 93 93 "3.11-aarch64-darwin" = getSrcFromPypi { 94 94 platform = "macosx_11_0_arm64"; 95 95 dist = "cp311"; 96 - hash = "sha256-VtuwXxurpSp1KI8ty1bizs5cdy8GEBN2MgS227sOCmE="; 96 + hash = "sha256-eThX+vN/Nxyv51L+pfyBH0NeQ7j7S1AgWERKf17M+Ck="; 97 97 }; 98 98 "3.11-x86_64-darwin" = getSrcFromPypi { 99 99 platform = "macosx_10_14_x86_64"; 100 100 dist = "cp311"; 101 - hash = "sha256-4Dj5dEGKb9hpg3HlVogNO1Gc9UibJhy1eym2mjivxAQ="; 101 + hash = "sha256-L/gpDtx7ksfq5SUX9lSSYz4mey6QZ7rT5MMj0hPnfPU="; 102 102 }; 103 103 104 104 "3.12-x86_64-linux" = getSrcFromPypi { 105 105 platform = "manylinux2014_x86_64"; 106 106 dist = "cp312"; 107 - hash = "sha256-TlrGVtb3NTLmhnILWPLJR+jISCZ5SUV4wxNFpSfkCBo="; 107 + hash = "sha256-RqGqhX9P7uikP8upXA4Kti1AwmzJcwtsaWVZCLo1n40="; 108 108 }; 109 109 "3.12-aarch64-darwin" = getSrcFromPypi { 110 110 platform = "macosx_11_0_arm64"; 111 111 dist = "cp312"; 112 - hash = "sha256-FIwK5CGykQjteuWzLZnbtAggIxLQeGV96bXlZGEytN0="; 112 + hash = "sha256-jdi//jhTcC9jzZJNoO4lc0pNGc1ckmvgM9dyun0cF10="; 113 113 }; 114 114 "3.12-x86_64-darwin" = getSrcFromPypi { 115 115 platform = "macosx_10_14_x86_64"; 116 116 dist = "cp312"; 117 - hash = "sha256-9/jw/wr6oUD9pOadVAaMRL086iVMUXwVgnUMcG1UNvE="; 117 + hash = "sha256-1sCaVFMpciRhrwVuc1FG0sjHTCKsdCaoRetp8ya096A="; 118 118 }; 119 119 }; 120 120 ··· 130 130 gpuSrcs = { 131 131 "cuda12.2-3.9" = fetchurl { 132 132 url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp39-cp39-manylinux2014_x86_64.whl"; 133 - hash = "sha256-xdJKLPtx+CIza2CrWKM3M0cZJzyNFVTTTsvlgh38bfM="; 133 + hash = "sha256-d8LIl22gIvmWfoyKfXKElZJXicPQIZxdS4HumhwQGCw="; 134 134 }; 135 135 "cuda12.2-3.10" = fetchurl { 136 136 url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp310-cp310-manylinux2014_x86_64.whl"; 137 - hash = "sha256-QCjrOczD2mp+CDwVXBc0/4rJnAizeV62AK0Dpx9X6TE="; 137 + hash = "sha256-PXtWv+UEcMWF8LhWe6Z1UGkf14PG3dkJ0Iop0LiimnQ="; 138 138 }; 139 139 "cuda12.2-3.11" = fetchurl { 140 140 url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp311-cp311-manylinux2014_x86_64.whl"; 141 - hash = "sha256-Ipy3vk1yUplpNzECAFt63aOIhgEWgXG7hkoeTIk9bQQ="; 141 + hash = "sha256-QO2WSOzmJ48VaCha596mELiOfPsAGLpGctmdzcCHE/o="; 142 142 }; 143 143 "cuda12.2-3.12" = fetchurl { 144 144 url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp312-cp312-manylinux2014_x86_64.whl"; 145 - hash = "sha256-LSnZHaUga/8Z65iKXWBnZDk4yUpNykFTu3vukCchO6Q="; 146 - }; 147 - "cuda11.8-3.9" = fetchurl { 148 - url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp39-cp39-manylinux2014_x86_64.whl"; 149 - hash = "sha256-UmyugL0VjlXkiD7fuDPWgW8XUpr/QaP5ggp6swoZTzU="; 150 - }; 151 - "cuda11.8-3.10" = fetchurl { 152 - url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl"; 153 - hash = "sha256-luKULEiV1t/sO6eckDxddJTiOFa0dtJeDlrvp+WYmHk="; 154 - }; 155 - "cuda11.8-3.11" = fetchurl { 156 - url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp311-cp311-manylinux2014_x86_64.whl"; 157 - hash = "sha256-4+uJ8Ij6mFGEmjFEgi3fLnSLZs+v18BRoOt7mZuqydw="; 158 - }; 159 - "cuda11.8-3.12" = fetchurl { 160 - url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp312-cp312-manylinux2014_x86_64.whl"; 161 - hash = "sha256-bUDFb94Ar/65SzzR9RLIs/SL/HdjaPT1Su5whmjkS00="; 145 + hash = "sha256-ixWMaIChy4Ammsn23/3cCoala0lFibuUxyUr3tjfFKU="; 162 146 }; 163 147 }; 164 148 ··· 213 197 # for more info. 214 198 postInstall = lib.optional cudaSupport '' 215 199 mkdir -p $out/${python.sitePackages}/jaxlib/cuda/bin 216 - ln -s ${lib.getExe' cudaPackagesGoogle.cuda_nvcc "ptxas"} $out/${python.sitePackages}/jaxlib/cuda/bin/ptxas 200 + ln -s ${lib.getExe' cudaPackages.cuda_nvcc "ptxas"} $out/${python.sitePackages}/jaxlib/cuda/bin/ptxas 217 201 ''; 218 202 219 203 inherit (jaxlib-build) pythonImportsCheck; ··· 227 211 platforms = [ "aarch64-darwin" "x86_64-linux" "x86_64-darwin" ]; 228 212 broken = 229 213 !(cudaSupport -> lib.versionAtLeast cudaVersion "11.1") 230 - || !(cudaSupport -> lib.versionAtLeast cudaPackagesGoogle.cudnn.version "8.2") 214 + || !(cudaSupport -> lib.versionAtLeast cudaPackages.cudnn.version "8.2") 231 215 || !(cudaSupport -> stdenv.isLinux) 232 216 || !(cudaSupport -> (gpuSrcs ? "cuda${cudaVersion}-${pythonVersion}")) 233 217 # Fails at pythonImportsCheckPhase:
+15 -25
pkgs/development/python-modules/jaxlib/default.nix
··· 13 13 , curl 14 14 , cython 15 15 , fetchFromGitHub 16 - , fetchpatch 17 16 , git 18 17 , IOKit 19 18 , jsoncpp ··· 45 44 , config 46 45 # CUDA flags: 47 46 , cudaSupport ? config.cudaSupport 48 - , cudaPackagesGoogle 47 + , cudaPackages 49 48 50 49 # MKL: 51 50 , mklSupport ? true 52 51 }@inputs: 53 52 54 53 let 55 - inherit (cudaPackagesGoogle) cudaFlags cudaVersion cudnn nccl; 54 + inherit (cudaPackages) cudaFlags cudaVersion cudnn nccl; 56 55 57 56 pname = "jaxlib"; 58 - version = "0.4.24"; 57 + version = "0.4.28"; 59 58 60 59 # It's necessary to consistently use backendStdenv when building with CUDA 61 60 # support, otherwise we get libstdc++ errors downstream 62 61 stdenv = throw "Use effectiveStdenv instead"; 63 - effectiveStdenv = if cudaSupport then cudaPackagesGoogle.backendStdenv else inputs.stdenv; 62 + effectiveStdenv = if cudaSupport then cudaPackages.backendStdenv else inputs.stdenv; 64 63 65 64 meta = with lib; { 66 65 description = "JAX is Autograd and XLA, brought together for high-performance machine learning research."; ··· 78 77 # These are necessary at build time and run time. 79 78 cuda_libs_joined = symlinkJoin { 80 79 name = "cuda-joined"; 81 - paths = with cudaPackagesGoogle; [ 80 + paths = with cudaPackages; [ 82 81 cuda_cudart.lib # libcudart.so 83 82 cuda_cudart.static # libcudart_static.a 84 83 cuda_cupti.lib # libcupti.so ··· 92 91 # These are only necessary at build time. 93 92 cuda_build_deps_joined = symlinkJoin { 94 93 name = "cuda-build-deps-joined"; 95 - paths = with cudaPackagesGoogle; [ 94 + paths = with cudaPackages; [ 96 95 cuda_libs_joined 97 96 98 97 # Binaries 99 - cudaPackagesGoogle.cuda_nvcc.bin # nvcc 98 + cudaPackages.cuda_nvcc.bin # nvcc 100 99 101 100 # Headers 102 101 cuda_cccl.dev # block_load.cuh ··· 181 180 owner = "openxla"; 182 181 repo = "xla"; 183 182 # Update this according to https://github.com/google/jax/blob/jaxlib-v${version}/third_party/xla/workspace.bzl. 184 - rev = "12eee889e1f2ad41e27d7b0e970cb92d282d3ec5"; 185 - hash = "sha256-68kjjgwYjRlcT0TVJo9BN6s+WTkdu5UMJqQcfHpBT90="; 183 + rev = "e8247c3ea1d4d7f31cf27def4c7ac6f2ce64ecd4"; 184 + hash = "sha256-ZhgMIVs3Z4dTrkRWDqaPC/i7yJz2dsYXrZbjzqvPX3E="; 186 185 }; 187 186 188 - patches = [ 189 - # Resolves "could not convert ‘result’ from ‘SmallVector<[...],6>’ to 190 - # ‘SmallVector<[...],4>’" compilation error. See https://github.com/google/jax/issues/19814#issuecomment-1945141259. 191 - (fetchpatch { 192 - url = "https://github.com/openxla/xla/commit/7a614cd346594fc7ea2fe75570c9c53a4a444f60.patch"; 193 - hash = "sha256-RtuQTH8wzNiJcOtISLhf+gMlH1gg8hekvxEB+4wX6BM="; 194 - }) 195 - ]; 196 - 197 187 dontBuild = true; 198 188 199 189 # This is necessary for patchShebangs to know the right path to use. ··· 220 210 repo = "jax"; 221 211 # google/jax contains tags for jax and jaxlib. Only use jaxlib tags! 222 212 rev = "refs/tags/${pname}-v${version}"; 223 - hash = "sha256-hmx7eo3pephc6BQfoJ3U0QwWBWmhkAc+7S4QmW32qQs="; 213 + hash = "sha256-qSHPwi3is6Ts7pz5s4KzQHBMbcjGp+vAOsejW3o36Ek="; 224 214 }; 225 215 226 216 nativeBuildInputs = [ ··· 364 354 ]; 365 355 366 356 sha256 = (if cudaSupport then { 367 - x86_64-linux = "sha256-8JilAoTbqOjOOJa/Zc/n/quaEDcpdcLXCNb34mfB+OM="; 357 + x86_64-linux = "sha256-VGNMf5/DgXbgsu1w5J1Pmrukw+7UO31BNU+crKVsX5k="; 368 358 } else { 369 - x86_64-linux = "sha256-iqS+I1FQLNWXNMsA20cJp7YkyGUeshee5b2QfRBNZtk="; 370 - aarch64-linux = "sha256-qmJ0Fm/VGMTmko4PhKs1P8/GLEJmVxb8xg+ss/HsakY=="; 359 + x86_64-linux = "sha256-uOoAyMBLHPX6jzdN43b5wZV5eW0yI8sCDD7BSX2h4oQ="; 360 + aarch64-linux = "sha256-+SnGKY9LIT1Qhu/x6Uh7sHRaAEjlc//qyKj1m4t16PA="; 371 361 }).${effectiveStdenv.system} or (throw "jaxlib: unsupported system: ${effectiveStdenv.system}"); 372 362 }; 373 363 ··· 414 404 # for more info. 415 405 postInstall = lib.optionalString cudaSupport '' 416 406 mkdir -p $out/bin 417 - ln -s ${cudaPackagesGoogle.cuda_nvcc.bin}/bin/ptxas $out/bin/ptxas 407 + ln -s ${cudaPackages.cuda_nvcc.bin}/bin/ptxas $out/bin/ptxas 418 408 419 409 find $out -type f \( -name '*.so' -or -name '*.so.*' \) | while read lib; do 420 410 patchelf --add-rpath "${lib.makeLibraryPath [cuda_libs_joined cudnn nccl]}" "$lib" ··· 423 413 424 414 nativeBuildInputs = lib.optionals cudaSupport [ autoAddDriverRunpath ]; 425 415 426 - propagatedBuildInputs = [ 416 + dependencies = [ 427 417 absl-py 428 418 curl 429 419 double-conversion
+17 -3
pkgs/development/python-modules/jaxopt/default.nix
··· 6 6 , fetchpatch 7 7 , pytest-xdist 8 8 , pytestCheckHook 9 + , setuptools 9 10 , absl-py 10 11 , cvxpy 11 12 , jax ··· 20 21 buildPythonPackage rec { 21 22 pname = "jaxopt"; 22 23 version = "0.8.3"; 23 - format = "setuptools"; 24 + pyproject = true; 24 25 25 26 disabled = pythonOlder "3.8"; 26 27 ··· 41 42 }) 42 43 ]; 43 44 44 - propagatedBuildInputs = [ 45 + build-system = [ 46 + setuptools 47 + ]; 48 + 49 + dependencies = [ 45 50 absl-py 46 51 jax 47 52 jaxlib ··· 66 71 "jaxopt.tree_util" 67 72 ]; 68 73 69 - disabledTests = lib.optionals (stdenv.isLinux && stdenv.isAarch64) [ 74 + disabledTests = [ 75 + # https://github.com/google/jaxopt/issues/592 76 + "test_solve_sparse" 77 + ] ++ lib.optionals (stdenv.isLinux && stdenv.isAarch64) [ 70 78 # https://github.com/google/jaxopt/issues/577 71 79 "test_binary_logit_log_likelihood" 72 80 "test_solve_sparse" 73 81 "test_logreg_with_intercept_manual_loop3" 82 + 83 + # https://github.com/google/jaxopt/issues/593 84 + # Makes the test suite crash 85 + "test_dtype_consistency" 86 + # AssertionError: Array(0.01411963, dtype=float32) not less than or equal to 0.01 87 + "test_multiclass_logreg6" 74 88 ]; 75 89 76 90 meta = with lib; {
+4 -2
pkgs/development/python-modules/nanobind/default.nix
··· 51 51 scipy 52 52 torch 53 53 tensorflow 54 - jax 55 - jaxlib 54 + # Uncomment at next release (1.9.3) 55 + # See https://github.com/wjakob/nanobind/issues/578 56 + # jax 57 + # jaxlib 56 58 ]; 57 59 58 60 meta = with lib; {
+7 -3
pkgs/development/python-modules/objax/default.nix
··· 1 1 { lib 2 2 , buildPythonPackage 3 3 , fetchFromGitHub 4 - , fetchpatch 5 4 , jax 6 5 , jaxlib 7 6 , keras ··· 30 29 hash = "sha256-WD+pmR8cEay4iziRXqF3sHUzCMBjmLJ3wZ3iYOD+hzk="; 31 30 }; 32 31 33 - nativeBuildInputs = [ 32 + patches = [ 33 + # Issue reported upstream: https://github.com/google/objax/issues/270 34 + ./replace-deprecated-device_buffers.patch 35 + ]; 36 + 37 + build-system = [ 34 38 setuptools 35 39 ]; 36 40 ··· 40 44 jaxlib 41 45 ]; 42 46 43 - propagatedBuildInputs = [ 47 + dependencies = [ 44 48 jax 45 49 numpy 46 50 parameterized
+14
pkgs/development/python-modules/objax/replace-deprecated-device_buffers.patch
··· 1 + diff --git a/objax/util/util.py b/objax/util/util.py 2 + index c31a356..344cf9a 100644 3 + --- a/objax/util/util.py 4 + +++ b/objax/util/util.py 5 + @@ -117,7 +117,8 @@ def get_local_devices(): 6 + if _local_devices is None: 7 + x = jn.zeros((jax.local_device_count(), 1), dtype=jn.float32) 8 + sharded_x = map_to_device(x) 9 + - _local_devices = [b.device() for b in sharded_x.device_buffers] 10 + + device_buffers = [buf.data for buf in sharded_x.addressable_shards] 11 + + _local_devices = [list(b.devices())[0] for b in device_buffers] 12 + return _local_devices 13 + 14 +
+2 -6
pkgs/development/python-modules/tensorflow/bin.nix
··· 22 22 , tensorboard 23 23 , config 24 24 , cudaSupport ? config.cudaSupport 25 - , cudaPackagesGoogle 25 + , cudaPackages 26 26 , zlib 27 27 , python 28 28 , keras-applications ··· 43 43 44 44 let 45 45 packages = import ./binary-hashes.nix; 46 - inherit (cudaPackagesGoogle) cudatoolkit cudnn; 46 + inherit (cudaPackages) cudatoolkit cudnn; 47 47 in buildPythonPackage { 48 48 pname = "tensorflow" + lib.optionalString cudaSupport "-gpu"; 49 49 inherit (packages) version; ··· 198 198 "tensorflow.python" 199 199 "tensorflow.python.framework" 200 200 ]; 201 - 202 - passthru = { 203 - cudaPackages = cudaPackagesGoogle; 204 - }; 205 201 206 202 meta = with lib; { 207 203 description = "Computation using data flow graphs for scalable machine learning";
+7 -8
pkgs/development/python-modules/tensorflow/default.nix
··· 19 19 # https://groups.google.com/a/tensorflow.org/forum/#!topic/developers/iRCt5m4qUz0 20 20 , config 21 21 , cudaSupport ? config.cudaSupport 22 - , cudaPackagesGoogle 23 - , cudaCapabilities ? cudaPackagesGoogle.cudaFlags.cudaCapabilities 22 + , cudaPackages 23 + , cudaCapabilities ? cudaPackages.cudaFlags.cudaCapabilities 24 24 , mklSupport ? false, mkl 25 25 , tensorboardSupport ? true 26 26 # XLA without CUDA is broken ··· 50 50 # __ZN4llvm11SmallPtrSetIPKNS_10AllocaInstELj8EED1Ev in any of the 51 51 # translation units, so the build fails at link time 52 52 stdenv = 53 - if cudaSupport then cudaPackagesGoogle.backendStdenv 53 + if cudaSupport then cudaPackages.backendStdenv 54 54 else if originalStdenv.isDarwin then llvmPackages.stdenv 55 55 else originalStdenv; 56 - inherit (cudaPackagesGoogle) cudatoolkit nccl; 56 + inherit (cudaPackages) cudatoolkit nccl; 57 57 # use compatible cuDNN (https://www.tensorflow.org/install/source#gpu) 58 58 # cudaPackages.cudnn led to this: 59 59 # https://github.com/tensorflow/tensorflow/issues/60398 60 60 cudnnAttribute = "cudnn_8_6"; 61 - cudnn = cudaPackagesGoogle.${cudnnAttribute}; 61 + cudnn = cudaPackages.${cudnnAttribute}; 62 62 gentoo-patches = fetchzip { 63 63 url = "https://dev.gentoo.org/~perfinion/patches/tensorflow-patches-2.12.0.tar.bz2"; 64 64 hash = "sha256-SCRX/5/zML7LmKEPJkcM5Tebez9vv/gmE4xhT/jyqWs="; ··· 490 490 broken = 491 491 stdenv.isDarwin 492 492 || !(xlaSupport -> cudaSupport) 493 - || !(cudaSupport -> builtins.hasAttr cudnnAttribute cudaPackagesGoogle) 494 - || !(cudaSupport -> cudaPackagesGoogle ? cudatoolkit); 493 + || !(cudaSupport -> builtins.hasAttr cudnnAttribute cudaPackages) 494 + || !(cudaSupport -> cudaPackages ? cudatoolkit); 495 495 } // lib.optionalAttrs stdenv.isDarwin { 496 496 timeout = 86400; # 24 hours 497 497 maxSilent = 14400; # 4h, double the default of 7200s ··· 594 594 # Regression test for #77626 removed because not more `tensorflow.contrib`. 595 595 596 596 passthru = { 597 - cudaPackages = cudaPackagesGoogle; 598 597 deps = bazel-build.deps; 599 598 libtensorflow = bazel-build.out; 600 599 };
-1
pkgs/test/cuda/default.nix
··· 3 3 recurseIntoAttrs, 4 4 5 5 cudaPackages, 6 - cudaPackagesGoogle, 7 6 8 7 cudaPackages_10_0, 9 8 cudaPackages_10_1,
-4
pkgs/top-level/all-packages.nix
··· 7125 7125 cudaPackages_12_3 = callPackage ./cuda-packages.nix { cudaVersion = "12.3"; }; 7126 7126 cudaPackages_12 = cudaPackages_12_2; # Latest supported by cudnn 7127 7127 7128 - # Use the older cudaPackages for tensorflow and jax, as determined by cudnn 7129 - # compatibility: https://www.tensorflow.org/install/source#gpu 7130 - cudaPackagesGoogle = cudaPackages_11; 7131 - 7132 7128 cudaPackages = recurseIntoAttrs cudaPackages_12; 7133 7129 7134 7130 # TODO: move to alias
+5
pkgs/top-level/python-packages.nix
··· 14885 14885 14886 14886 tensorflow-bin = callPackage ../development/python-modules/tensorflow/bin.nix { 14887 14887 inherit (pkgs.config) cudaSupport; 14888 + # https://www.tensorflow.org/install/source#gpu 14889 + cudaPackages = pkgs.cudaPackages_11; 14888 14890 }; 14889 14891 14890 14892 tensorflow-build = let ··· 14892 14894 protobufTF = pkgs.protobuf_21.override { 14893 14895 abseil-cpp = pkgs.abseil-cpp_202301; 14894 14896 }; 14897 + # https://www.tensorflow.org/install/source#gpu 14898 + cudaPackagesTF = pkgs.cudaPackages_11; 14895 14899 grpcTF = (pkgs.grpc.overrideAttrs ( 14896 14900 oldAttrs: rec { 14897 14901 # nvcc fails on recent grpc versions, so we use the latest patch level ··· 14937 14941 inherit (pkgs.darwin.apple_sdk.frameworks) Foundation Security; 14938 14942 flatbuffers-core = pkgs.flatbuffers; 14939 14943 flatbuffers-python = self.flatbuffers; 14944 + cudaPackages = compat.cudaPackagesTF; 14940 14945 protobuf-core = compat.protobufTF; 14941 14946 protobuf-python = compat.protobuf-pythonTF; 14942 14947 grpc = compat.grpcTF;