python3Packages.jax: add missing cuda libraries (#375186)

Authored by Samuel Ainsworth; committed by GitHub.
Commit hashes: 59f337a6, 3db65802 (presumably the commit and its parent — not determinable from this extraction).

+44 -12 (total across all 4 changed files)
+10 -2
pkgs/development/python-modules/jax-cuda12-pjrt/default.nix
··· 18 18 cudaLibPath = lib.makeLibraryPath ( 19 19 with cudaPackages; 20 20 [ 21 + (lib.getLib libcublas) # libcublas.so 22 + (lib.getLib cuda_cupti) # libcupti.so 21 23 (lib.getLib cuda_cudart) # libcudart.so 22 24 (lib.getLib cudnn) # libcudnn.so 23 - (lib.getLib libcublas) # libcublas.so 24 - addDriverRunpath.driverLink # libcuda.so 25 + (lib.getLib libcufft) # libcufft.so 26 + (lib.getLib libcusolver) # libcusolver.so 27 + (lib.getLib libcusparse) # libcusparse.so 28 + (lib.getLib nccl) # libnccl.so 29 + (lib.getLib libnvjitlink) # libnvJitLink.so 30 + (lib.getLib addDriverRunpath.driverLink) # libcuda.so 25 31 ] 26 32 ); 27 33 ··· 82 88 doCheck = true; 83 89 84 90 pythonImportsCheck = [ "jax_plugins" ]; 91 + 92 + inherit cudaLibPath; 85 93 86 94 meta = { 87 95 description = "JAX XLA PJRT Plugin for NVIDIA GPUs";
+26 -3
pkgs/development/python-modules/jax-cuda12-plugin/default.nix
··· 12 12 jax-cuda12-pjrt, 13 13 }: 14 14 let 15 + inherit (jaxlib) version; 15 16 inherit (cudaPackages) cudaVersion; 16 - inherit (jaxlib) version; 17 + inherit (jax-cuda12-pjrt) cudaLibPath; 17 18 18 19 getSrcFromPypi = 19 20 { ··· 94 95 wheelUnpackHook 95 96 ]; 96 97 98 + # jax-cuda12-plugin looks for ptxas at runtime, e.g. with a triton kernel. 99 + # Linking into $out is the least bad solution. See 100 + # * https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 101 + # * https://github.com/NixOS/nixpkgs/pull/288829#discussion_r1493852211 102 + # * https://github.com/NixOS/nixpkgs/pull/375186 103 + # for more info. 104 + postInstall = '' 105 + mkdir -p $out/${python.sitePackages}/jax_cuda12_plugin/cuda/bin 106 + ln -s ${lib.getExe' cudaPackages.cuda_nvcc "ptxas"} $out/${python.sitePackages}/jax_cuda12_plugin/cuda/bin 107 + ln -s ${lib.getExe' cudaPackages.cuda_nvcc "nvlink"} $out/${python.sitePackages}/jax_cuda12_plugin/cuda/bin 108 + ''; 109 + 110 + # jax-cuda12-plugin contains shared libraries that open other shared libraries via dlopen 111 + # and these implicit dependencies are not recognized by ldd or 112 + # autoPatchelfHook. That means we need to sneak them into rpath. This step 113 + # must be done after autoPatchelfHook and the automatic stripping of 114 + # artifacts. autoPatchelfHook runs in postFixup and auto-stripping runs in the 115 + # patchPhase. 116 + preInstallCheck = '' 117 + patchelf --add-rpath "${cudaLibPath}" $out/${python.sitePackages}/jax_cuda12_plugin/*.so 118 + ''; 119 + 97 120 dependencies = [ jax-cuda12-pjrt ]; 98 121 99 122 pythonImportsCheck = [ "jax_cuda12_plugin" ]; 100 123 101 - # no tests 102 - doCheck = false; 124 + # FIXME: there are no tests, but we need to run preInstallCheck above 125 + doCheck = true; 103 126 104 127 meta = { 105 128 description = "JAX Plugin for CUDA12";
-4
pkgs/development/python-modules/jax/default.nix
··· 198 198 199 199 meta = { 200 200 description = "Source-built JAX frontend: differentiate, compile, and transform Numpy code"; 201 - longDescription = '' 202 - This is the JAX frontend package, it's meant to be used together with one of the jaxlib implementations, 203 - e.g. `python3Packages.jaxlib`, `python3Packages.jaxlib-bin`, or `python3Packages.jaxlibWithCuda`. 204 - ''; 205 201 homepage = "https://github.com/google/jax"; 206 202 license = lib.licenses.asl20; 207 203 maintainers = with lib.maintainers; [ samuela ];
+8 -3
pkgs/development/python-modules/jax/test-cuda.nix
··· 11 11 } 12 12 '' 13 13 import jax 14 + import jax.numpy as jnp 14 15 from jax import random 16 + from jax.experimental import sparse 15 17 16 - assert jax.devices()[0].platform == "gpu" 18 + assert jax.devices()[0].platform == "gpu" # libcuda.so 17 19 18 - rng = random.PRNGKey(0) 20 + rng = random.key(0) # libcudart.so, libcudnn.so 19 21 x = random.normal(rng, (100, 100)) 20 - x @ x 22 + x @ x # libcublas.so 23 + jnp.fft.fft(x) # libcufft.so 24 + jnp.linalg.inv(x) # libcusolver.so 25 + sparse.CSR.fromdense(x) @ x # libcusparse.so 21 26 22 27 print("success!") 23 28 ''