# For now we only support the CPU and GPU backends of jaxlib. The TPU
# backend will require some additional work; the TPU wheels are located here:
# https://storage.googleapis.com/jax-releases/libtpu_releases.html
4
5# See `python3Packages.jax.passthru` for CUDA tests.
6
7{
8 absl-py,
9 autoPatchelfHook,
10 buildPythonPackage,
11 fetchPypi,
12 flatbuffers,
13 lib,
14 ml-dtypes,
15 python,
16 scipy,
17 stdenv,
18}:
19
20let
# Interpreter version string of the `python` argument (e.g. "3.12"); it forms the
# first half of the lookup key into `srcs`.
inherit (python) pythonVersion;
# jaxlib release to fetch; every wheel in `srcs` is fetched at this version.
version = "0.6.0";
23
# As of 2023-06-06, google/jax upstream is no longer publishing CPU-only wheels to their GCS bucket. Instead the
# official instructions recommend installing CPU-only versions via PyPI.
#
# `srcs` maps "<pythonVersion>-<nix system>" (e.g. "3.12-x86_64-linux") to the matching prebuilt wheel.
# This key format must stay in sync with the lookup performed by `src` in the derivation below; any
# Python-version/system combination without an entry here is unsupported.
srcs =
  let
    # Fetch a single jaxlib wheel from PyPI.
    #   platform: wheel platform tag, e.g. "manylinux2014_x86_64" or "macosx_11_0_arm64".
    #   dist:     CPython tag, e.g. "cp312".
    #   hash:     SRI hash of the wheel file.
    getSrcFromPypi =
      {
        platform,
        dist,
        hash,
      }:
      fetchPypi {
        inherit
          version
          platform
          dist
          hash
          ;
        pname = "jaxlib";
        format = "wheel";
        # Each jaxlib wheel targets one specific CPython version, so the same `dist` tag (e.g. "cp312")
        # is used for both the wheel's python tag and its ABI tag.
        python = dist;
        abi = dist;
      };
  in
  {
    "3.10-x86_64-linux" = getSrcFromPypi {
      platform = "manylinux2014_x86_64";
      dist = "cp310";
      hash = "sha256-pNQlTHEziIh6MhN508WxogITqNzckD+vFRObqB4+zWE=";
    };
    "3.10-aarch64-linux" = getSrcFromPypi {
      platform = "manylinux2014_aarch64";
      dist = "cp310";
      hash = "sha256-VBpBi5iyjfW9Oh6TxistP2TUSwxwt7YI9/4rSqRSsq8=";
    };
    "3.10-aarch64-darwin" = getSrcFromPypi {
      platform = "macosx_11_0_arm64";
      dist = "cp310";
      hash = "sha256-ZKgvjrQP23uh1G75BzANQuT5jL2pYCou2OcNsamsSmA=";
    };

    "3.11-x86_64-linux" = getSrcFromPypi {
      platform = "manylinux2014_x86_64";
      dist = "cp311";
      hash = "sha256-vtRVJeO7XsCGML/SB8Ca+dYun/E/XwfC7iz9jthBG6E=";
    };
    "3.11-aarch64-linux" = getSrcFromPypi {
      platform = "manylinux2014_aarch64";
      dist = "cp311";
      hash = "sha256-wK6VmJmALhMpzI7ForTUvpoHa1vrIFLrSbo3UU5iPrw=";
    };
    "3.11-aarch64-darwin" = getSrcFromPypi {
      platform = "macosx_11_0_arm64";
      dist = "cp311";
      hash = "sha256-7xY88H3gC8VpAWnpf6+q3DePHDgfAofopHPnirW6sbU=";
    };

    "3.12-x86_64-linux" = getSrcFromPypi {
      platform = "manylinux2014_x86_64";
      dist = "cp312";
      hash = "sha256-tthbjR/XkkiwRQNRcgHnL8vNOYDPeR036BRwnqUKPII=";
    };
    "3.12-aarch64-linux" = getSrcFromPypi {
      platform = "manylinux2014_aarch64";
      dist = "cp312";
      hash = "sha256-JTb6k+wUjVAW2osgd7pmMlsNhqriKJphwSaHfwQrPRw=";
    };
    "3.12-aarch64-darwin" = getSrcFromPypi {
      platform = "macosx_11_0_arm64";
      dist = "cp312";
      hash = "sha256-fjzi7w7cm0izbicEw2GB8eznoSrBFN91PbQobqLG6Lg=";
    };

    "3.13-x86_64-linux" = getSrcFromPypi {
      platform = "manylinux2014_x86_64";
      dist = "cp313";
      hash = "sha256-0PsSLceDDKKlyjyHSghzY6AFMrZEUJwhnDv9HVRRXo0=";
    };
    "3.13-aarch64-linux" = getSrcFromPypi {
      platform = "manylinux2014_aarch64";
      dist = "cp313";
      hash = "sha256-GJcpY5diBQwXgLBQ6Y/2IEgLHqMr8WdTPgAKXPTFc44=";
    };
    "3.13-aarch64-darwin" = getSrcFromPypi {
      platform = "macosx_11_0_arm64";
      dist = "cp313";
      hash = "sha256-xOl5NMuvUXI0OqWujvDFhGLOJhVN/adUICswNBYMrHs=";
    };
  };
113in
buildPythonPackage {
  pname = "jaxlib";
  inherit version;
  format = "wheel";

  # Select the prebuilt wheel for this Python version and host system. The explicit `throw` (rather than
  # relying only on meta platform checks) is deliberate; see
  # https://discourse.nixos.org/t/ofborg-does-not-respect-meta-platforms/27019/6.
  # The lookup key combines the Python version and the system, so the error reports both: a supported system
  # paired with a Python version that has no wheel in `srcs` also lands here.
  src =
    srcs."${pythonVersion}-${stdenv.hostPlatform.system}"
      or (throw "jaxlib-bin is not supported on ${stdenv.hostPlatform.system} with Python ${pythonVersion}");

  # Prebuilt wheels are dynamically linked against things that nix can't find.
  # On Linux, run `autoPatchelfHook` to fix up the binaries automatically; it is not added on other platforms.
  nativeBuildInputs = lib.optionals stdenv.hostPlatform.isLinux [ autoPatchelfHook ];
  # Dynamic link dependencies: the library output of the compiler package
  # (presumably the C++ runtime the wheel's shared objects need — confirm).
  buildInputs = [ (lib.getLib stdenv.cc.cc) ];

  # Python runtime dependencies of jaxlib.
  dependencies = [
    absl-py
    flatbuffers
    ml-dtypes
    scipy
  ];

  # Smoke test: the installed package must at least be importable.
  pythonImportsCheck = [ "jaxlib" ];

  meta = {
    description = "Prebuilt jaxlib backend from PyPi";
    # Upstream moved from github.com/google/jax to github.com/jax-ml/jax.
    homepage = "https://github.com/jax-ml/jax";
    sourceProvenance = with lib.sourceTypes; [ binaryNativeCode ];
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ samuela ];
    badPlatforms = [
      # Fails at pythonImportsCheckPhase:
      # ...-python-imports-check-hook.sh/nix-support/setup-hook: line 10: 28017 Illegal instruction: 4
      # /nix/store/5qpssbvkzfh73xih07xgmpkj5r565975-python3-3.11.9/bin/python3.11 -c
      # 'import os; import importlib; list(map(lambda mod: importlib.import_module(mod), os.environ["pythonImportsCheck"].split()))'
      # NOTE(review): `srcs` currently has no "*-x86_64-darwin" entry, so `src` already throws on this
      # system; this entry may be redundant — confirm before removing.
      "x86_64-darwin"
    ];
  };
}