# Clone of https://github.com/NixOS/nixpkgs.git (to stress-test knotserver)
# at flake-libs — 106 lines, 3.6 kB
# Binary (wheel-based) build of jax-cuda12-pjrt, the XLA PJRT runtime plugin
# that lets JAX drive NVIDIA GPUs via CUDA 12. The wheel ships prebuilt
# shared objects, so this derivation mostly patches ELF metadata and wires
# up the CUDA toolchain paths rather than compiling anything.
{
  lib,
  stdenv,
  buildPythonPackage,
  fetchPypi,
  addDriverRunpath,
  autoPatchelfHook,
  pypaInstallHook,
  wheelUnpackHook,
  cudaPackages,
  python,
  jaxlib,
}:
let
  # The plugin version must match jaxlib's, so reuse it directly.
  inherit (jaxlib) version;
  inherit (cudaPackages) cudaAtLeast;

  # Colon-separated library path covering every CUDA library the plugin
  # dlopens at runtime (see preInstallCheck below for why ldd/autoPatchelf
  # cannot discover these on their own).
  cudaLibPath = lib.makeLibraryPath (
    with cudaPackages;
    [
      (lib.getLib libcublas) # libcublas.so
      (lib.getLib cuda_cupti) # libcupti.so
      (lib.getLib cuda_cudart) # libcudart.so
      (lib.getLib cudnn) # libcudnn.so
      (lib.getLib libcufft) # libcufft.so
      (lib.getLib libcusolver) # libcusolver.so
      (lib.getLib libcusparse) # libcusparse.so
      (lib.getLib nccl) # libnccl.so
      (lib.getLib libnvjitlink) # libnvJitLink.so
      (lib.getLib addDriverRunpath.driverLink) # libcuda.so
    ]
  );

in
buildPythonPackage rec {
  pname = "jax-cuda12-pjrt";
  inherit version;
  pyproject = false;

  src = fetchPypi {
    pname = "jax_cuda12_pjrt";
    inherit version;
    format = "wheel";
    python = "py3";
    dist = "py3";
    # The wheel is platform-specific; select the manylinux tag and hash for
    # the host platform, failing eval on unsupported systems.
    platform =
      {
        x86_64-linux = "manylinux2014_x86_64";
        aarch64-linux = "manylinux2014_aarch64";
      }
      .${stdenv.hostPlatform.system};
    hash =
      {
        x86_64-linux = "sha256-TJfRClqawJ+gAVaMrDtxUBTo27ws2GdjdT9Y5acwwzM=";
        aarch64-linux = "sha256-lnB2z7by4zlZ5zdmY1maoMEcwO3o8vUaIG2godQixrs=";
      }
      .${stdenv.hostPlatform.system};
  };

  nativeBuildInputs = [
    autoPatchelfHook
    pypaInstallHook
    wheelUnpackHook
  ];

  # jax-cuda12-pjrt looks for ptxas, nvlink and nvvm at runtime, eg when
  # running `jax.random.PRNGKey(0)`. Linking into $out is the least bad
  # solution. See
  # * https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621
  # * https://github.com/NixOS/nixpkgs/pull/288829#discussion_r1493852211
  # for more info.
  postInstall = ''
    mkdir -p $out/${python.sitePackages}/jax_plugins/nvidia/cuda_nvcc/bin
    ln -s ${lib.getExe' cudaPackages.cuda_nvcc "ptxas"} $out/${python.sitePackages}/jax_plugins/nvidia/cuda_nvcc/bin/ptxas
    ln -s ${lib.getExe' cudaPackages.cuda_nvcc "nvlink"} $out/${python.sitePackages}/jax_plugins/nvidia/cuda_nvcc/bin/nvlink
    ln -s ${cudaPackages.cuda_nvcc}/nvvm $out/${python.sitePackages}/jax_plugins/nvidia/cuda_nvcc/nvvm
  '';

  # jax-cuda12-pjrt contains shared libraries that open other shared libraries
  # via dlopen and these implicit dependencies are not recognized by ldd or
  # autoPatchelfHook. That means we need to sneak them into rpath. This step
  # must be done after autoPatchelfHook and the automatic stripping of
  # artifacts. autoPatchelfHook runs in postFixup and auto-stripping runs in
  # the patchPhase.
  preInstallCheck = ''
    patchelf --add-rpath "${cudaLibPath}" $out/${python.sitePackages}/jax_plugins/xla_cuda12/xla_cuda_plugin.so
  '';

  # FIXME: there are no tests, but we need to run preInstallCheck above
  doCheck = true;

  pythonImportsCheck = [ "jax_plugins" ];

  # Exposed as a passthru-style attribute so dependents (e.g. jax-cuda12-plugin)
  # can reuse the same runtime library path.
  inherit cudaLibPath;

  meta = {
    description = "JAX XLA PJRT Plugin for NVIDIA GPUs";
    homepage = "https://github.com/jax-ml/jax/tree/main/jax_plugins/cuda";
    sourceProvenance = [ lib.sourceTypes.binaryNativeCode ];
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ natsukium ];
    platforms = lib.platforms.linux;
    # see CUDA compatibility matrix
    # https://jax.readthedocs.io/en/latest/installation.html#pip-installation-nvidia-gpu-cuda-installed-locally-harder
    broken = !(cudaAtLeast "12.1") || !(lib.versionAtLeast cudaPackages.cudnn.version "9.1");
  };
}