1{
2 lib,
3 stdenv,
4 buildPythonPackage,
5 fetchPypi,
6 autoPatchelfHook,
7 pypaInstallHook,
8 wheelUnpackHook,
9 cudaPackages,
10 python,
11 jaxlib,
12 jax-cuda12-pjrt,
13}:
14let
15 inherit (jaxlib) version;
16 inherit (cudaPackages) cudaAtLeast;
17 inherit (jax-cuda12-pjrt) cudaLibPath;
18
19 getSrcFromPypi =
20 {
21 platform,
22 dist,
23 hash,
24 }:
25 fetchPypi {
26 inherit
27 version
28 platform
29 dist
30 hash
31 ;
32 pname = "jax_cuda12_plugin";
33 format = "wheel";
34 python = dist;
35 abi = dist;
36 };
37
  # Prebuilt wheels fetched from PyPI, keyed by "<python version>-<nix system>".
  # Wheels are pinned for both x86_64-linux and aarch64-linux.
39 srcs = {
40 "3.10-x86_64-linux" = getSrcFromPypi {
41 platform = "manylinux2014_x86_64";
42 dist = "cp310";
43 hash = "sha256-pwDhcYI84lUQIALkDJR4j6ho8hYle30/BWjQn+dcEHs=";
44 };
45 "3.10-aarch64-linux" = getSrcFromPypi {
46 platform = "manylinux2014_aarch64";
47 dist = "cp310";
48 hash = "sha256-UwrYUcpGKZHOgtsmrUfwKwjOvkg8nI0MADfp4np7Up8=";
49 };
50 "3.11-x86_64-linux" = getSrcFromPypi {
51 platform = "manylinux2014_x86_64";
52 dist = "cp311";
53 hash = "sha256-DZ7O3mbEAlhwKkImHoaM21ahA1UafDyISzX1Mcms1I4=";
54 };
55 "3.11-aarch64-linux" = getSrcFromPypi {
56 platform = "manylinux2014_aarch64";
57 dist = "cp311";
58 hash = "sha256-fNG0iKVKMInolYjMr2dwiZUsglKefQQD4LBQGZ5SVBg=";
59 };
60 "3.12-x86_64-linux" = getSrcFromPypi {
61 platform = "manylinux2014_x86_64";
62 dist = "cp312";
63 hash = "sha256-5w608IRpbD474StekJ7xIFyfVu/j3OzyYhvZtatZVNU=";
64 };
65 "3.12-aarch64-linux" = getSrcFromPypi {
66 platform = "manylinux2014_aarch64";
67 dist = "cp312";
68 hash = "sha256-oqOvX5iIDYb40kartGpVLlou9J12e/xKdMjDV3UgB8Y=";
69 };
70 "3.13-x86_64-linux" = getSrcFromPypi {
71 platform = "manylinux2014_x86_64";
72 dist = "cp313";
73 hash = "sha256-6W891KlCUWroeMn2l+au/teOFI8JAYynPuKLI0JqfYo=";
74 };
75 "3.13-aarch64-linux" = getSrcFromPypi {
76 platform = "manylinux2014_aarch64";
77 dist = "cp313";
78 hash = "sha256-o0LyznxLH1nUA/Zlo1qGuGUCU7sl3jRkf7IlxFzrCgQ=";
79 };
80 };
81in
buildPythonPackage {
  pname = "jax-cuda12-plugin";
  inherit version;
  # The source is a prebuilt wheel; it is unpacked and installed by the
  # wheelUnpackHook / pypaInstallHook listed in nativeBuildInputs.
  pyproject = false;

  # Select the wheel matching the interpreter version and host system.
  # The error message names the nixpkgs package set (e.g. python312Packages),
  # whose attribute name has the dot removed from pythonVersion.
  src =
    srcs."${python.pythonVersion}-${stdenv.hostPlatform.system}" or (throw
      "python${lib.replaceStrings [ "." ] [ "" ] python.pythonVersion}Packages.jax-cuda12-plugin is not supported on ${stdenv.hostPlatform.system}"
    );

  nativeBuildInputs = [
    autoPatchelfHook
    pypaInstallHook
    wheelUnpackHook
  ];

  # jax-cuda12-plugin looks for ptxas at runtime, e.g. with a triton kernel, so
  # ptxas (and nvlink) from cuda_nvcc are symlinked into the plugin's cuda/bin.
  # Linking into $out is the least bad solution. See
  # * https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621
  # * https://github.com/NixOS/nixpkgs/pull/288829#discussion_r1493852211
  # * https://github.com/NixOS/nixpkgs/pull/375186
  # for more info.
  postInstall = ''
    pluginBin=$out/${python.sitePackages}/jax_cuda12_plugin/cuda/bin
    mkdir -p "$pluginBin"
    ln -s ${lib.getExe' cudaPackages.cuda_nvcc "ptxas"} "$pluginBin"
    ln -s ${lib.getExe' cudaPackages.cuda_nvcc "nvlink"} "$pluginBin"
  '';

  # jax-cuda12-plugin contains shared libraries that open other shared
  # libraries via dlopen. Those implicit dependencies are not recognized by ldd
  # or autoPatchelfHook, so the CUDA library path must be added to the rpath by
  # hand. This has to happen after fixupPhase: both the automatic stripping of
  # artifacts and autoPatchelfHook (registered as a postFixup hook) run there.
  # installCheckPhase is the first phase that runs after fixupPhase.
  preInstallCheck = ''
    patchelf --add-rpath "${cudaLibPath}" $out/${python.sitePackages}/jax_cuda12_plugin/*.so
  '';

  dependencies = [ jax-cuda12-pjrt ];

  pythonImportsCheck = [ "jax_cuda12_plugin" ];

  # FIXME: there are no tests. doCheck is only enabled so that the install
  # check phase (and therefore the preInstallCheck rpath patch above) runs;
  # turning checks off would ship the plugin without the CUDA rpath.
  doCheck = true;

  meta = {
    description = "JAX Plugin for CUDA12";
    homepage = "https://github.com/jax-ml/jax/tree/main/jax_plugins/cuda";
    sourceProvenance = [ lib.sourceTypes.binaryNativeCode ];
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ natsukium ];
    platforms = lib.platforms.linux;
    # see CUDA compatibility matrix
    # https://jax.readthedocs.io/en/latest/installation.html#pip-installation-nvidia-gpu-cuda-installed-locally-harder
    broken = !(cudaAtLeast "12.1") || !(lib.versionAtLeast cudaPackages.cudnn.version "9.1");
  };
}