nixpkgs mirror (for testing) github.com/NixOS/nixpkgs
nix
at python-updates 137 lines 4.4 kB view raw
# Binary (wheel) package for the JAX CUDA 12 plugin, fetched from PyPI and
# patched so its shared libraries can find the CUDA libraries they dlopen.
{
  lib,
  stdenv,
  buildPythonPackage,
  fetchPypi,
  autoPatchelfHook,
  pypaInstallHook,
  wheelUnpackHook,
  cudaPackages,
  python,
  jaxlib,
  jax-cuda12-pjrt,
}:
let
  # The plugin version follows jaxlib; the CUDA library search path is the one
  # exported by jax-cuda12-pjrt.
  inherit (jaxlib) version;
  inherit (jax-cuda12-pjrt) cudaLibPath;

  # Fetch one prebuilt wheel from PyPI.
  #   platform: manylinux platform tag of the wheel
  #   dist:     CPython tag (e.g. "cp312"), used for both the python and abi tags
  #   hash:     fixed-output hash of the wheel
  getSrcFromPypi =
    {
      platform,
      dist,
      hash,
    }:
    fetchPypi {
      inherit
        version
        platform
        dist
        hash
        ;
      pname = "jax_cuda12_plugin";
      format = "wheel";
      python = dist;
      abi = dist;
    };

  # Prebuilt wheels, keyed by "<python version>-<nix system>" to match the
  # lookup in `src` below. Both x86_64-linux and aarch64-linux are provided.
  srcs = {
    "3.11-x86_64-linux" = getSrcFromPypi {
      platform = "manylinux_2_27_x86_64";
      dist = "cp311";
      hash = "sha256-CwozBM5+SUrNjZxZNJDBEqMs22AQ/hr8WE2eQf2GMWc=";
    };
    "3.11-aarch64-linux" = getSrcFromPypi {
      platform = "manylinux_2_27_aarch64";
      dist = "cp311";
      hash = "sha256-cNMyIkhK1cN1uPg1e3wjysuET27Pw5Vn+N1H/eboeFg=";
    };
    "3.12-x86_64-linux" = getSrcFromPypi {
      platform = "manylinux_2_27_x86_64";
      dist = "cp312";
      hash = "sha256-IBZYYbPT5m67LA9jpUfR1e4X6kSsO+cVPHkIycqMiPM=";
    };
    "3.12-aarch64-linux" = getSrcFromPypi {
      platform = "manylinux_2_27_aarch64";
      dist = "cp312";
      hash = "sha256-QD1eB3MbXNrDvZ+z9Ei9hIAGLLLAq2HqKtI/zQplR5o=";
    };
    "3.13-x86_64-linux" = getSrcFromPypi {
      platform = "manylinux_2_27_x86_64";
      dist = "cp313";
      hash = "sha256-gsZ5i+Zr+MdzOGkY5MjlzYEZdT87+zyku8RoGCg3UMY=";
    };
    "3.13-aarch64-linux" = getSrcFromPypi {
      platform = "manylinux_2_27_aarch64";
      dist = "cp313";
      hash = "sha256-Y3OH3DQIzSBFYmaFAvnpX3bG7d4KbS5I8FUWLcKuvw0=";
    };
    "3.14-x86_64-linux" = getSrcFromPypi {
      platform = "manylinux_2_27_x86_64";
      dist = "cp314";
      hash = "sha256-pYmLrB2KtgILVFRkQCVkCfLGa8u7OhCZykc8hIQ63a0=";
    };
    "3.14-aarch64-linux" = getSrcFromPypi {
      platform = "manylinux_2_27_aarch64";
      dist = "cp314";
      hash = "sha256-WMUUc/xiLgMTgDWYX3QYM1ZNcKS9WiF49htizaoy/5Q=";
    };
  };
in
buildPythonPackage {
  pname = "jax-cuda12-plugin";
  inherit version;
  pyproject = false;

  # Unsupported python/system combinations fail at evaluation time with a clear message.
  src =
    srcs.${python.pythonVersion + "-" + stdenv.hostPlatform.system}
      or (throw "python${python.pythonVersion}Packages.jax-cuda12-plugin is not supported on ${stdenv.hostPlatform.system}");

  nativeBuildInputs = [
    autoPatchelfHook
    pypaInstallHook
    wheelUnpackHook
  ];

  # jax-cuda12-plugin looks for ptxas at runtime, e.g. with a triton kernel.
  # Linking into $out is the least bad solution. See
  # * https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621
  # * https://github.com/NixOS/nixpkgs/pull/288829#discussion_r1493852211
  # * https://github.com/NixOS/nixpkgs/pull/375186
  # for more info.
  postInstall = ''
    mkdir -p $out/${python.sitePackages}/jax_cuda12_plugin/cuda/bin
    ln -s ${lib.getExe' cudaPackages.cuda_nvcc "ptxas"} $out/${python.sitePackages}/jax_cuda12_plugin/cuda/bin
    ln -s ${lib.getExe' cudaPackages.cuda_nvcc "nvlink"} $out/${python.sitePackages}/jax_cuda12_plugin/cuda/bin
  '';

  # jax-cuda12-plugin contains shared libraries that open other shared libraries
  # via dlopen. ldd and autoPatchelfHook do not see these implicit dependencies,
  # so we add them to the rpath by hand.
  # This must happen after autoPatchelfHook and the automatic stripping of
  # artifacts, and both of those run during fixupPhase. Stripping runs first.
  # autoPatchelfHook is registered in postFixupHooks, and those hooks run after
  # the `postFixup` string, so `postFixup` itself would be too early.
  # installCheckPhase runs after fixupPhase, so preInstallCheck is the first
  # point that sees the final libraries.
  preInstallCheck = ''
    patchelf --add-rpath "${cudaLibPath}" $out/${python.sitePackages}/jax_cuda12_plugin/*.so
  '';

  dependencies = [ jax-cuda12-pjrt ];

  pythonImportsCheck = [ "jax_cuda12_plugin" ];

  # There are no upstream tests. buildPythonPackage only runs the install-check
  # phase, and therefore preInstallCheck above, when doCheck is enabled.
  # Keep this true: overriding it to false silently skips the rpath patch.
  # With no rpath entry, the plugin presumably cannot dlopen its CUDA libraries
  # at runtime.
  doCheck = true;

  meta = {
    description = "JAX Plugin for CUDA12";
    homepage = "https://github.com/jax-ml/jax/tree/main/jax_plugins/cuda";
    sourceProvenance = [ lib.sourceTypes.binaryNativeCode ];
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ natsukium ];
    platforms = lib.platforms.linux;
    # see CUDA compatibility matrix
    # https://jax.readthedocs.io/en/latest/installation.html#pip-installation-nvidia-gpu-cuda-installed-locally-harder
    broken = !(lib.versionAtLeast cudaPackages.cudnn.version "9.1");
  };
}