nixpkgs mirror (for testing)
github.com/NixOS/nixpkgs
nix
1{
2 lib,
3 stdenv,
4 buildPythonPackage,
5 fetchPypi,
6 autoPatchelfHook,
7 pypaInstallHook,
8 wheelUnpackHook,
9 cudaPackages,
10 python,
11 jaxlib,
12 jax-cuda12-pjrt,
13}:
14let
# The plugin is released in lockstep with jaxlib, and the rpath directories it
# needs are taken from jax-cuda12-pjrt (see the preInstallCheck rpath patch).
version = jaxlib.version;
cudaLibPath = jax-cuda12-pjrt.cudaLibPath;

# Fetch a single prebuilt jax_cuda12_plugin wheel from PyPI.
#   platform: manylinux platform tag of the wheel, e.g. "manylinux_2_27_x86_64"
#   dist:     CPython tag such as "cp312"; also used as the wheel's python and abi tags
#   hash:     SRI hash of the wheel file
getSrcFromPypi =
  { platform, dist, hash }:
  fetchPypi {
    pname = "jax_cuda12_plugin";
    inherit version;
    format = "wheel";
    inherit platform dist hash;
    python = dist;
    abi = dist;
  };
36
# Prebuilt wheels, keyed by "<python version>-<nix system>"; `src` selects the
# entry matching the current interpreter and host platform. Wheels are listed
# for CPython 3.11 through 3.14 on both x86_64-linux and aarch64-linux; any
# combination missing from this set is unsupported.
srcs = {
  "3.11-x86_64-linux" = getSrcFromPypi {
    platform = "manylinux_2_27_x86_64";
    dist = "cp311";
    hash = "sha256-CwozBM5+SUrNjZxZNJDBEqMs22AQ/hr8WE2eQf2GMWc=";
  };
  "3.11-aarch64-linux" = getSrcFromPypi {
    platform = "manylinux_2_27_aarch64";
    dist = "cp311";
    hash = "sha256-cNMyIkhK1cN1uPg1e3wjysuET27Pw5Vn+N1H/eboeFg=";
  };
  "3.12-x86_64-linux" = getSrcFromPypi {
    platform = "manylinux_2_27_x86_64";
    dist = "cp312";
    hash = "sha256-IBZYYbPT5m67LA9jpUfR1e4X6kSsO+cVPHkIycqMiPM=";
  };
  "3.12-aarch64-linux" = getSrcFromPypi {
    platform = "manylinux_2_27_aarch64";
    dist = "cp312";
    hash = "sha256-QD1eB3MbXNrDvZ+z9Ei9hIAGLLLAq2HqKtI/zQplR5o=";
  };
  "3.13-x86_64-linux" = getSrcFromPypi {
    platform = "manylinux_2_27_x86_64";
    dist = "cp313";
    hash = "sha256-gsZ5i+Zr+MdzOGkY5MjlzYEZdT87+zyku8RoGCg3UMY=";
  };
  "3.13-aarch64-linux" = getSrcFromPypi {
    platform = "manylinux_2_27_aarch64";
    dist = "cp313";
    hash = "sha256-Y3OH3DQIzSBFYmaFAvnpX3bG7d4KbS5I8FUWLcKuvw0=";
  };
  "3.14-x86_64-linux" = getSrcFromPypi {
    platform = "manylinux_2_27_x86_64";
    dist = "cp314";
    hash = "sha256-pYmLrB2KtgILVFRkQCVkCfLGa8u7OhCZykc8hIQ63a0=";
  };
  "3.14-aarch64-linux" = getSrcFromPypi {
    platform = "manylinux_2_27_aarch64";
    dist = "cp314";
    hash = "sha256-WMUUc/xiLgMTgDWYX3QYM1ZNcKS9WiF49htizaoy/5Q=";
  };
};
80in
81buildPythonPackage {
pname = "jax-cuda12-plugin";
inherit version;
# Installed from a prebuilt wheel, so there is no pyproject build step.
pyproject = false;

# Pick the wheel for this interpreter/system pair; pairs without an entry in
# `srcs` fail evaluation with an explicit message.
src =
  let
    wheelKey = "${python.pythonVersion}-${stdenv.hostPlatform.system}";
  in
  srcs.${wheelKey}
    or (throw "python${python.pythonVersion}Packages.jax-cuda12-plugin is not supported on ${stdenv.hostPlatform.system}");

# Hooks that unpack and install the wheel and patch ELF dependencies of the
# bundled binaries (order kept as-is).
nativeBuildInputs = [
  autoPatchelfHook
  pypaInstallHook
  wheelUnpackHook
];
96
# jax-cuda12-plugin looks for ptxas at runtime (e.g. when compiling a triton
# kernel), so ptxas and nvlink from cudaPackages.cuda_nvcc are symlinked into
# jax_cuda12_plugin/cuda/bin inside site-packages. Linking into $out is the
# least bad solution; for background see
#   * https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621
#   * https://github.com/NixOS/nixpkgs/pull/288829#discussion_r1493852211
#   * https://github.com/NixOS/nixpkgs/pull/375186
postInstall =
  let
    cudaBinDir = "$out/${python.sitePackages}/jax_cuda12_plugin/cuda/bin";
  in
  ''
    mkdir -p ${cudaBinDir}
    ln -s ${lib.getExe' cudaPackages.cuda_nvcc "ptxas"} ${cudaBinDir}
    ln -s ${lib.getExe' cudaPackages.cuda_nvcc "nvlink"} ${cudaBinDir}
  '';
108
# jax-cuda12-plugin contains shared libraries that open other shared libraries
# via dlopen. These implicit dependencies are not recognized by ldd or
# autoPatchelfHook, so the directories in cudaLibPath are appended to the rpath
# of the plugin's top-level *.so files by hand.
#
# This must happen after both autoPatchelfHook and the automatic stripping of
# artifacts. Both of those run during the fixup phase: stripping in fixupPhase
# itself, and autoPatchelfHook from a postFixup hook. preInstallCheck belongs
# to installCheckPhase, which stdenv schedules after fixupPhase.
preInstallCheck = ''
  patchelf --add-rpath "${cudaLibPath}" $out/${python.sitePackages}/jax_cuda12_plugin/*.so
'';

dependencies = [ jax-cuda12-pjrt ];

pythonImportsCheck = [ "jax_cuda12_plugin" ];

# There are no tests to run. doCheck must nevertheless stay true:
# buildPythonPackage runs its checks in installCheckPhase, and turning them off
# would also skip preInstallCheck above, leaving the rpath unpatched.
doCheck = true;
125
meta = {
  description = "JAX Plugin for CUDA12";
  homepage = "https://github.com/jax-ml/jax/tree/main/jax_plugins/cuda";
  sourceProvenance = [ lib.sourceTypes.binaryNativeCode ];
  license = lib.licenses.asl20;
  maintainers = with lib.maintainers; [ natsukium ];
  # Prebuilt wheels exist only for these systems (see `srcs`). Listing exactly
  # these lets other Linux systems be reported as unsupported, rather than
  # failing with the `throw` in `src`.
  platforms = [
    "x86_64-linux"
    "aarch64-linux"
  ];
  # Needs cuDNN >= 9.1; see the CUDA compatibility matrix
  # https://jax.readthedocs.io/en/latest/installation.html#pip-installation-nvidia-gpu-cuda-installed-locally-harder
  broken = lib.versionOlder cudaPackages.cudnn.version "9.1";
};
137}