Clone of https://github.com/NixOS/nixpkgs.git (to stress-test knotserver)
{
  lib,
  stdenv,
  buildPythonPackage,
  fetchPypi,
  autoPatchelfHook,
  pypaInstallHook,
  wheelUnpackHook,
  cudaPackages,
  python,
  jaxlib,
  jax-cuda12-pjrt,
}:
let
  # The plugin is versioned in lockstep with jaxlib.
  inherit (jaxlib) version;
  inherit (cudaPackages) cudaAtLeast;
  # Library search path taken from jax-cuda12-pjrt; appended to the plugin's
  # rpath in preInstallCheck below.
  inherit (jax-cuda12-pjrt) cudaLibPath;

  # Fetch the prebuilt wheel for one (CPython tag, manylinux platform) pair.
  # `dist` (e.g. "cp312") is used as the wheel's dist, python and abi tags.
  getSrcFromPypi =
    {
      platform,
      dist,
      hash,
    }:
    fetchPypi {
      inherit
        version
        platform
        dist
        hash
        ;
      pname = "jax_cuda12_plugin";
      format = "wheel";
      python = dist;
      abi = dist;
    };

  # Prebuilt wheels keyed by "<python version>-<nix system>". Wheels are listed
  # for x86_64-linux and aarch64-linux only; any other python/system
  # combination is rejected by `src` below.
  srcs = {
    "3.10-x86_64-linux" = getSrcFromPypi {
      platform = "manylinux2014_x86_64";
      dist = "cp310";
      hash = "sha256-F1H4iYkmmzzbDf5PewcqZEIUmBjJvJjDo5XIrK+RCnk=";
    };
    "3.10-aarch64-linux" = getSrcFromPypi {
      platform = "manylinux2014_aarch64";
      dist = "cp310";
      hash = "sha256-vFw6ddBVGbTTJuRmnQ960P4PCs+HX5MT2RN0jMylqeo=";
    };
    "3.11-x86_64-linux" = getSrcFromPypi {
      platform = "manylinux2014_x86_64";
      dist = "cp311";
      hash = "sha256-CJbLswjZUpHiBc2J0lQCne46HfQ9ZumDEzGpr9LSeHA=";
    };
    "3.11-aarch64-linux" = getSrcFromPypi {
      platform = "manylinux2014_aarch64";
      dist = "cp311";
      hash = "sha256-LNjieaWaOLoMl4qDHhOt627p5Fcvujh8eXW6OtU13Tg=";
    };
    "3.12-x86_64-linux" = getSrcFromPypi {
      platform = "manylinux2014_x86_64";
      dist = "cp312";
      hash = "sha256-/r0Jn5cNNQ64+losmi+0sOp7PWqJ3xSWZj7fp6/lkOU=";
    };
    "3.12-aarch64-linux" = getSrcFromPypi {
      platform = "manylinux2014_aarch64";
      dist = "cp312";
      hash = "sha256-bJsALROx/LlANxPu3Th2oietH/vfs4EbH5+Jr0wlpfc=";
    };
    "3.13-x86_64-linux" = getSrcFromPypi {
      platform = "manylinux2014_x86_64";
      dist = "cp313";
      hash = "sha256-20xhA8kS2M0a35TDTTE7tHYMp/AciXynzWLmXyeZQZk=";
    };
    "3.13-aarch64-linux" = getSrcFromPypi {
      platform = "manylinux2014_aarch64";
      dist = "cp313";
      hash = "sha256-dz76i1WoN0BsVh8O8CFE3akBkYEZN2DsVBnuyd0rmqw=";
    };
  };
in
buildPythonPackage {
  pname = "jax-cuda12-plugin";
  inherit version;
  pyproject = false;

  src =
    srcs."${python.pythonVersion}-${stdenv.hostPlatform.system}"
      or (throw "python${python.pythonVersion}Packages.jax-cuda12-plugin is not supported on ${stdenv.hostPlatform.system}");

  nativeBuildInputs = [
    autoPatchelfHook
    pypaInstallHook
    wheelUnpackHook
  ];

  # jax-cuda12-plugin looks for ptxas at runtime, e.g. with a triton kernel, so
  # ptxas (and nvlink) from cuda_nvcc are symlinked into the package's
  # cuda/bin directory. Linking into $out is the least bad solution. See
  # * https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621
  # * https://github.com/NixOS/nixpkgs/pull/288829#discussion_r1493852211
  # * https://github.com/NixOS/nixpkgs/pull/375186
  # for more info.
  postInstall = ''
    mkdir -p $out/${python.sitePackages}/jax_cuda12_plugin/cuda/bin
    ln -s ${lib.getExe' cudaPackages.cuda_nvcc "ptxas"} $out/${python.sitePackages}/jax_cuda12_plugin/cuda/bin
    ln -s ${lib.getExe' cudaPackages.cuda_nvcc "nvlink"} $out/${python.sitePackages}/jax_cuda12_plugin/cuda/bin
  '';

  # jax-cuda12-plugin contains shared libraries that open other shared libraries
  # via dlopen, and these implicit dependencies are not recognized by ldd or
  # autoPatchelfHook. That means we need to sneak them into rpath. This step
  # must be done after autoPatchelfHook and the automatic stripping of
  # artifacts: autoPatchelfHook runs in postFixup and auto-stripping runs during
  # the fixupPhase, whereas preInstallCheck runs only once fixup has finished.
  preInstallCheck = ''
    patchelf --add-rpath "${cudaLibPath}" $out/${python.sitePackages}/jax_cuda12_plugin/*.so
  '';

  dependencies = [ jax-cuda12-pjrt ];

  pythonImportsCheck = [ "jax_cuda12_plugin" ];

  # There are no upstream tests. doCheck must nevertheless stay enabled: it is
  # what makes preInstallCheck (and therefore the rpath patching above) run.
  # NOTE: overriding doCheck = false would silently skip that patching.
  doCheck = true;

  meta = {
    description = "JAX Plugin for CUDA12";
    homepage = "https://github.com/jax-ml/jax/tree/main/jax_plugins/cuda";
    sourceProvenance = [ lib.sourceTypes.binaryNativeCode ];
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ natsukium ];
    # Only the systems that have wheels in `srcs`.
    platforms = [
      "aarch64-linux"
      "x86_64-linux"
    ];
    # see CUDA compatibility matrix
    # https://jax.readthedocs.io/en/latest/installation.html#pip-installation-nvidia-gpu-cuda-installed-locally-harder
    broken = !(cudaAtLeast "12.1") || !(lib.versionAtLeast cudaPackages.cudnn.version "9.1");
  };
}