Merge pull request #258951 from GaetanLepage/jax

python310Packages.{jax,jaxlib,jaxlib-bin}: 0.4.16 -> 0.4.17

authored by Nick Cao and committed by GitHub cababf47 aeb18cc5

+13 -13
+3 -3
pkgs/development/python-modules/jax/default.nix
··· 27 in 28 buildPythonPackage rec { 29 pname = "jax"; 30 - version = "0.4.16"; 31 - format = "pyproject"; 32 33 disabled = pythonOlder "3.9"; 34 ··· 37 repo = pname; 38 # google/jax contains tags for jax and jaxlib. Only use jax tags! 39 rev = "refs/tags/${pname}-v${version}"; 40 - hash = "sha256-q+8CXGxK8JX0bUMK4KJB3qV/EaLHg68D1B5UrtRz0Eg="; 41 }; 42 43 nativeBuildInputs = [
··· 27 in 28 buildPythonPackage rec { 29 pname = "jax"; 30 + version = "0.4.17"; 31 + pyproject = true; 32 33 disabled = pythonOlder "3.9"; 34 ··· 37 repo = pname; 38 # google/jax contains tags for jax and jaxlib. Only use jax tags! 39 rev = "refs/tags/${pname}-v${version}"; 40 + hash = "sha256-Lxi/lBBq7VlsT6CgnXPFcwbRU+T8630rBdm693E2jok="; 41 }; 42 43 nativeBuildInputs = [
+5 -5
pkgs/development/python-modules/jaxlib/bin.nix
··· 39 assert cudaSupport -> lib.versionAtLeast cudatoolkit.version "11.1" && lib.versionAtLeast cudnn.version "8.2" && stdenv.isLinux; 40 41 let 42 - version = "0.4.16"; 43 44 inherit (python) pythonVersion; 45 ··· 60 { 61 "x86_64-linux" = getSrcFromPypi { 62 platform = "manylinux2014_x86_64"; 63 - hash = "sha256-4XyaDnKEMhAbfPEvN3RCDEjXTWbOL6tWrTlyYeiboVs="; 64 }; 65 "aarch64-darwin" = getSrcFromPypi { 66 platform = "macosx_11_0_arm64"; 67 - hash = "sha256-IG2pCui/Yj+LDMbQwBVlu7yl2llqnaxMzz/MtBvBr6U="; 68 }; 69 "x86_64-darwin" = getSrcFromPypi { 70 platform = "macosx_10_14_x86_64"; 71 - hash = "sha256-x5DqsmHqEb7Dl7dnxT5N0l30GKt5OPZpq3HGX9MFKmo="; 72 }; 73 }; 74 ··· 78 # https://github.com/google/jax/issues/12879 as to why this specific URL is the correct index. 79 gpuSrc = fetchurl { 80 url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl"; 81 - hash = "sha256-eLOprP2kv6roodwRKZXVZFQCD1wC26TSTEDJBjMu/Uo="; 82 }; 83 84 in
··· 39 assert cudaSupport -> lib.versionAtLeast cudatoolkit.version "11.1" && lib.versionAtLeast cudnn.version "8.2" && stdenv.isLinux; 40 41 let 42 + version = "0.4.17"; 43 44 inherit (python) pythonVersion; 45 ··· 60 { 61 "x86_64-linux" = getSrcFromPypi { 62 platform = "manylinux2014_x86_64"; 63 + hash = "sha256-Fg/OaLgqeabFImUujdmhCqycANFZnLfhZmca2QmqE54="; 64 }; 65 "aarch64-darwin" = getSrcFromPypi { 66 platform = "macosx_11_0_arm64"; 67 + hash = "sha256-OSx3n5AsQ+Ggr0kVna/++bWvlSq6ABRj+Yz5WlnvF/8="; 68 }; 69 "x86_64-darwin" = getSrcFromPypi { 70 platform = "macosx_10_14_x86_64"; 71 + hash = "sha256-1L4axL8b4a4c2PX02kFKbQ3o3jbPLv/bV1jU1neJYHg="; 72 }; 73 }; 74 ··· 78 # https://github.com/google/jax/issues/12879 as to why this specific URL is the correct index. 79 gpuSrc = fetchurl { 80 url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl"; 81 + hash = "sha256-Ctdlr8mvlMcTnBSiyjEEvle5AGr+o1v6OI7XIqcTENM="; 82 }; 83 84 in
+5 -5
pkgs/development/python-modules/jaxlib/default.nix
··· 54 inherit (cudaPackages) backendStdenv cudatoolkit cudaFlags cudnn nccl; 55 56 pname = "jaxlib"; 57 - version = "0.4.16"; 58 59 meta = with lib; { 60 description = "JAX is Autograd and XLA, brought together for high-performance machine learning research."; ··· 151 repo = "jax"; 152 # google/jax contains tags for jax and jaxlib. Only use jaxlib tags! 153 rev = "refs/tags/${pname}-v${version}"; 154 - hash = "sha256-q+8CXGxK8JX0bUMK4KJB3qV/EaLHg68D1B5UrtRz0Eg="; 155 }; 156 157 nativeBuildInputs = [ ··· 264 ]; 265 266 sha256 = (if cudaSupport then { 267 - x86_64-linux = "sha256-6HkrEWAPjGPj4zRxahl0FLiV7WZO/6zsdCX8STfV5EE="; 268 } else { 269 - x86_64-linux = "sha256-MDnuJwJ/xKnC72Qub0ETYj5uQB2r8/AgGm10oqmzzcc="; 270 - aarch64-linux = "sha256-aVUm612VNEsjZLDrtiOPTqSk1t+AhmOx+pOG3bZdOAw="; 271 }).${stdenv.system} or (throw "jaxlib: unsupported system: ${stdenv.system}"); 272 }; 273
··· 54 inherit (cudaPackages) backendStdenv cudatoolkit cudaFlags cudnn nccl; 55 56 pname = "jaxlib"; 57 + version = "0.4.17"; 58 59 meta = with lib; { 60 description = "JAX is Autograd and XLA, brought together for high-performance machine learning research."; ··· 151 repo = "jax"; 152 # google/jax contains tags for jax and jaxlib. Only use jaxlib tags! 153 rev = "refs/tags/${pname}-v${version}"; 154 + hash = "sha256-Lxi/lBBq7VlsT6CgnXPFcwbRU+T8630rBdm693E2jok="; 155 }; 156 157 nativeBuildInputs = [ ··· 264 ]; 265 266 sha256 = (if cudaSupport then { 267 + x86_64-linux = "sha256-nRvvFAuP/9D8BWWVPjuZijVtk+F9IrBBHsNc5Daluy4="; 268 } else { 269 + x86_64-linux = "sha256-pPIJOELN62GqUuaKpcpaqHu7wbJHiZgtb2PVUPRr1Ek="; 270 + aarch64-linux = "sha256-Q0PYZkOkUYUHVtSHZDlWitslDZbjNq6yRZv/ZkhTmyc="; 271 }).${stdenv.system} or (throw "jaxlib: unsupported system: ${stdenv.system}"); 272 }; 273