# For the moment we only support the CPU and GPU backends of jaxlib. The TPU
# backend will require some additional work. Those wheels are located here:
# https://storage.googleapis.com/jax-releases/libtpu_releases.html.

# See `python3Packages.jax.passthru` for CUDA tests.

{
  absl-py,
  autoPatchelfHook,
  buildPythonPackage,
  fetchPypi,
  flatbuffers,
  lib,
  ml-dtypes,
  python,
  scipy,
  stdenv,
}:

let
  version = "0.6.0";
  inherit (python) pythonVersion;

  # Key into `srcs` for the current interpreter and host, e.g. "3.12-x86_64-linux".
  # Both parts matter: a supported platform can still lack a wheel for the
  # Python version in use.
  srcKey = "${pythonVersion}-${stdenv.hostPlatform.system}";

  # As of 2023-06-06, google/jax upstream is no longer publishing CPU-only wheels to their GCS bucket. Instead the
  # official instructions recommend installing CPU-only versions via PyPI.
  srcs =
    let
      # Fetch the jaxlib `version` wheel from PyPI.
      #   platform: wheel platform tag, e.g. "manylinux2014_x86_64".
      #   dist:     CPython tag, e.g. "cp312".
      #   hash:     SRI hash of the wheel file.
      getSrcFromPypi =
        {
          platform,
          dist,
          hash,
        }:
        fetchPypi {
          inherit
            version
            platform
            dist
            hash
            ;
          pname = "jaxlib";
          format = "wheel";
          # jaxlib ships CPython-specific wheels whose python tag and ABI tag are
          # the same CPython tag (…-cp312-cp312-<platform>.whl), so `dist` is
          # used for both parts of the wheel file name.
          python = dist;
          abi = dist;
        };
    in
    {
      "3.10-x86_64-linux" = getSrcFromPypi {
        platform = "manylinux2014_x86_64";
        dist = "cp310";
        hash = "sha256-pNQlTHEziIh6MhN508WxogITqNzckD+vFRObqB4+zWE=";
      };
      "3.10-aarch64-linux" = getSrcFromPypi {
        platform = "manylinux2014_aarch64";
        dist = "cp310";
        hash = "sha256-VBpBi5iyjfW9Oh6TxistP2TUSwxwt7YI9/4rSqRSsq8=";
      };
      "3.10-aarch64-darwin" = getSrcFromPypi {
        platform = "macosx_11_0_arm64";
        dist = "cp310";
        hash = "sha256-ZKgvjrQP23uh1G75BzANQuT5jL2pYCou2OcNsamsSmA=";
      };

      "3.11-x86_64-linux" = getSrcFromPypi {
        platform = "manylinux2014_x86_64";
        dist = "cp311";
        hash = "sha256-vtRVJeO7XsCGML/SB8Ca+dYun/E/XwfC7iz9jthBG6E=";
      };
      "3.11-aarch64-linux" = getSrcFromPypi {
        platform = "manylinux2014_aarch64";
        dist = "cp311";
        hash = "sha256-wK6VmJmALhMpzI7ForTUvpoHa1vrIFLrSbo3UU5iPrw=";
      };
      "3.11-aarch64-darwin" = getSrcFromPypi {
        platform = "macosx_11_0_arm64";
        dist = "cp311";
        hash = "sha256-7xY88H3gC8VpAWnpf6+q3DePHDgfAofopHPnirW6sbU=";
      };

      "3.12-x86_64-linux" = getSrcFromPypi {
        platform = "manylinux2014_x86_64";
        dist = "cp312";
        hash = "sha256-tthbjR/XkkiwRQNRcgHnL8vNOYDPeR036BRwnqUKPII=";
      };
      "3.12-aarch64-linux" = getSrcFromPypi {
        platform = "manylinux2014_aarch64";
        dist = "cp312";
        hash = "sha256-JTb6k+wUjVAW2osgd7pmMlsNhqriKJphwSaHfwQrPRw=";
      };
      "3.12-aarch64-darwin" = getSrcFromPypi {
        platform = "macosx_11_0_arm64";
        dist = "cp312";
        hash = "sha256-fjzi7w7cm0izbicEw2GB8eznoSrBFN91PbQobqLG6Lg=";
      };

      "3.13-x86_64-linux" = getSrcFromPypi {
        platform = "manylinux2014_x86_64";
        dist = "cp313";
        hash = "sha256-0PsSLceDDKKlyjyHSghzY6AFMrZEUJwhnDv9HVRRXo0=";
      };
      "3.13-aarch64-linux" = getSrcFromPypi {
        platform = "manylinux2014_aarch64";
        dist = "cp313";
        hash = "sha256-GJcpY5diBQwXgLBQ6Y/2IEgLHqMr8WdTPgAKXPTFc44=";
      };
      "3.13-aarch64-darwin" = getSrcFromPypi {
        platform = "macosx_11_0_arm64";
        dist = "cp313";
        hash = "sha256-xOl5NMuvUXI0OqWujvDFhGLOJhVN/adUICswNBYMrHs=";
      };
    };
in
buildPythonPackage {
  pname = "jaxlib";
  inherit version;
  format = "wheel";

  # Evaluating `src` throws for combinations without a prebuilt wheel.
  # See https://discourse.nixos.org/t/ofborg-does-not-respect-meta-platforms/27019/6.
  # The message names both the Python version and the system, because either one
  # can be the reason no wheel is available.
  src =
    srcs.${srcKey} or (throw (
      "jaxlib-bin is not supported on ${stdenv.hostPlatform.system} with Python ${pythonVersion}; "
      + "available: ${lib.concatStringsSep ", " (lib.attrNames srcs)}"
    ));

  # Prebuilt wheels are dynamically linked against things that nix can't find.
  # Run `autoPatchelfHook` to automagically fix them. It only applies to ELF
  # binaries, hence Linux only.
  nativeBuildInputs = lib.optionals stdenv.hostPlatform.isLinux [ autoPatchelfHook ];
  # Dynamic link dependencies: the compiler's runtime libraries (presumably the
  # C++ runtime the wheel's shared objects link against), for autoPatchelfHook
  # to resolve.
  buildInputs = [ (lib.getLib stdenv.cc.cc) ];

  dependencies = [
    absl-py
    flatbuffers
    ml-dtypes
    scipy
  ];

  pythonImportsCheck = [ "jaxlib" ];

  meta = {
    description = "Prebuilt jaxlib backend from PyPi";
    homepage = "https://github.com/google/jax";
    sourceProvenance = with lib.sourceTypes; [ binaryNativeCode ];
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ samuela ];
    badPlatforms = [
      # Fails at pythonImportsCheckPhase:
      # ...-python-imports-check-hook.sh/nix-support/setup-hook: line 10: 28017 Illegal instruction: 4
      # /nix/store/5qpssbvkzfh73xih07xgmpkj5r565975-python3-3.11.9/bin/python3.11 -c
      # 'import os; import importlib; list(map(lambda mod: importlib.import_module(mod), os.environ["pythonImportsCheck"].split()))'
      "x86_64-darwin"
    ];
  };
}