{
  lib,
  stdenv,
  buildPythonPackage,
  fetchPypi,
  autoPatchelfHook,
  pypaInstallHook,
  wheelUnpackHook,
  cudaPackages,
  python,
  jaxlib,
  jax-cuda12-pjrt,
}:
let
  # The plugin version is taken from jaxlib, so every wheel hash in `srcs`
  # below has to be refreshed whenever jaxlib is bumped.
  inherit (jaxlib) version;
  inherit (cudaPackages) cudaAtLeast;
  # Search path for the CUDA libraries the plugin dlopen()s at runtime,
  # reused from jax-cuda12-pjrt.
  inherit (jax-cuda12-pjrt) cudaLibPath;

  # Fetch the prebuilt `jax_cuda12_plugin` wheel from PyPI.
  #   platform: manylinux platform tag, e.g. "manylinux2014_x86_64".
  #   dist:     CPython tag, e.g. "cp312"; used for both the python and abi tags.
  #   hash:     SRI hash of the wheel.
  getSrcFromPypi =
    {
      platform,
      dist,
      hash,
    }:
    fetchPypi {
      inherit
        version
        platform
        dist
        hash
        ;
      pname = "jax_cuda12_plugin";
      format = "wheel";
      python = dist;
      abi = dist;
    };

  # Prebuilt wheels keyed by "<python version>-<nix system>". Upstream only
  # ships binary wheels; x86_64-linux and aarch64-linux are packaged here.
  srcs = {
    "3.10-x86_64-linux" = getSrcFromPypi {
      platform = "manylinux2014_x86_64";
      dist = "cp310";
      hash = "sha256-pwDhcYI84lUQIALkDJR4j6ho8hYle30/BWjQn+dcEHs=";
    };
    "3.10-aarch64-linux" = getSrcFromPypi {
      platform = "manylinux2014_aarch64";
      dist = "cp310";
      hash = "sha256-UwrYUcpGKZHOgtsmrUfwKwjOvkg8nI0MADfp4np7Up8=";
    };
    "3.11-x86_64-linux" = getSrcFromPypi {
      platform = "manylinux2014_x86_64";
      dist = "cp311";
      hash = "sha256-DZ7O3mbEAlhwKkImHoaM21ahA1UafDyISzX1Mcms1I4=";
    };
    "3.11-aarch64-linux" = getSrcFromPypi {
      platform = "manylinux2014_aarch64";
      dist = "cp311";
      hash = "sha256-fNG0iKVKMInolYjMr2dwiZUsglKefQQD4LBQGZ5SVBg=";
    };
    "3.12-x86_64-linux" = getSrcFromPypi {
      platform = "manylinux2014_x86_64";
      dist = "cp312";
      hash = "sha256-5w608IRpbD474StekJ7xIFyfVu/j3OzyYhvZtatZVNU=";
    };
    "3.12-aarch64-linux" = getSrcFromPypi {
      platform = "manylinux2014_aarch64";
      dist = "cp312";
      hash = "sha256-oqOvX5iIDYb40kartGpVLlou9J12e/xKdMjDV3UgB8Y=";
    };
    "3.13-x86_64-linux" = getSrcFromPypi {
      platform = "manylinux2014_x86_64";
      dist = "cp313";
      hash = "sha256-6W891KlCUWroeMn2l+au/teOFI8JAYynPuKLI0JqfYo=";
    };
    "3.13-aarch64-linux" = getSrcFromPypi {
      platform = "manylinux2014_aarch64";
      dist = "cp313";
      hash = "sha256-o0LyznxLH1nUA/Zlo1qGuGUCU7sl3jRkf7IlxFzrCgQ=";
    };
  };
in
buildPythonPackage {
  pname = "jax-cuda12-plugin";
  inherit version;
  pyproject = false;

  # Pick the prebuilt wheel matching this Python version and host platform.
  src =
    srcs."${python.pythonVersion}-${stdenv.hostPlatform.system}"
      or (throw "python${python.pythonVersion}Packages.jax-cuda12-plugin is not supported on ${stdenv.hostPlatform.system}");

  nativeBuildInputs = [
    autoPatchelfHook
    pypaInstallHook
    wheelUnpackHook
  ];

  # jax-cuda12-plugin looks for ptxas at runtime, e.g. with a triton kernel.
  # Linking into $out is the least bad solution. See
  # * https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621
  # * https://github.com/NixOS/nixpkgs/pull/288829#discussion_r1493852211
  # * https://github.com/NixOS/nixpkgs/pull/375186
  # for more info.
  postInstall = ''
    cudaBinDir=$out/${python.sitePackages}/jax_cuda12_plugin/cuda/bin
    mkdir -p "$cudaBinDir"
    ln -s ${lib.getExe' cudaPackages.cuda_nvcc "ptxas"} "$cudaBinDir"
    ln -s ${lib.getExe' cudaPackages.cuda_nvcc "nvlink"} "$cudaBinDir"
  '';

  # jax-cuda12-plugin contains shared libraries that open other shared
  # libraries via dlopen. These implicit dependencies are not recognized by
  # ldd or autoPatchelfHook, so their directories have to be added to the
  # rpath by hand. That must happen after the automatic stripping and RPATH
  # shrinking (both run in fixupPhase) and after autoPatchelfHook (which runs
  # in postFixup). installCheckPhase is the first phase that runs after all
  # of them.
  preInstallCheck = ''
    patchelf --add-rpath "${cudaLibPath}" $out/${python.sitePackages}/jax_cuda12_plugin/*.so
  '';

  dependencies = [ jax-cuda12-pjrt ];

  pythonImportsCheck = [ "jax_cuda12_plugin" ];

  # FIXME: there are no tests, but checks must stay enabled. buildPythonPackage
  # runs its checks in installCheckPhase, and the rpath patch in
  # preInstallCheck above only runs when that phase is enabled. Overriding
  # doCheck to false would silently skip the patch.
  doCheck = true;

  meta = {
    description = "JAX Plugin for CUDA12";
    homepage = "https://github.com/jax-ml/jax/tree/main/jax_plugins/cuda";
    sourceProvenance = [ lib.sourceTypes.binaryNativeCode ];
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ natsukium ];
    platforms = lib.platforms.linux;
    # Requires CUDA >= 12.1 and cuDNN >= 9.1; see the CUDA compatibility matrix
    # https://jax.readthedocs.io/en/latest/installation.html#pip-installation-nvidia-gpu-cuda-installed-locally-harder
    broken = !(cudaAtLeast "12.1" && lib.versionAtLeast cudaPackages.cudnn.version "9.1");
  };
}