python3Packages.jax: add missing cuda libraries (#375186)

authored by

Samuel Ainsworth and committed by
GitHub
59f337a6 3db65802

+44 -12
+10 -2
pkgs/development/python-modules/jax-cuda12-pjrt/default.nix
··· 18 cudaLibPath = lib.makeLibraryPath ( 19 with cudaPackages; 20 [ 21 (lib.getLib cuda_cudart) # libcudart.so 22 (lib.getLib cudnn) # libcudnn.so 23 - (lib.getLib libcublas) # libcublas.so 24 - addDriverRunpath.driverLink # libcuda.so 25 ] 26 ); 27 ··· 82 doCheck = true; 83 84 pythonImportsCheck = [ "jax_plugins" ]; 85 86 meta = { 87 description = "JAX XLA PJRT Plugin for NVIDIA GPUs";
··· 18 cudaLibPath = lib.makeLibraryPath ( 19 with cudaPackages; 20 [ 21 + (lib.getLib libcublas) # libcublas.so 22 + (lib.getLib cuda_cupti) # libcupti.so 23 (lib.getLib cuda_cudart) # libcudart.so 24 (lib.getLib cudnn) # libcudnn.so 25 + (lib.getLib libcufft) # libcufft.so 26 + (lib.getLib libcusolver) # libcusolver.so 27 + (lib.getLib libcusparse) # libcusparse.so 28 + (lib.getLib nccl) # libnccl.so 29 + (lib.getLib libnvjitlink) # libnvJitLink.so 30 + (lib.getLib addDriverRunpath.driverLink) # libcuda.so 31 ] 32 ); 33 ··· 88 doCheck = true; 89 90 pythonImportsCheck = [ "jax_plugins" ]; 91 + 92 + inherit cudaLibPath; 93 94 meta = { 95 description = "JAX XLA PJRT Plugin for NVIDIA GPUs";
+26 -3
pkgs/development/python-modules/jax-cuda12-plugin/default.nix
··· 12 jax-cuda12-pjrt, 13 }: 14 let 15 inherit (cudaPackages) cudaVersion; 16 - inherit (jaxlib) version; 17 18 getSrcFromPypi = 19 { ··· 94 wheelUnpackHook 95 ]; 96 97 dependencies = [ jax-cuda12-pjrt ]; 98 99 pythonImportsCheck = [ "jax_cuda12_plugin" ]; 100 101 - # no tests 102 - doCheck = false; 103 104 meta = { 105 description = "JAX Plugin for CUDA12";
··· 12 jax-cuda12-pjrt, 13 }: 14 let 15 + inherit (jaxlib) version; 16 inherit (cudaPackages) cudaVersion; 17 + inherit (jax-cuda12-pjrt) cudaLibPath; 18 19 getSrcFromPypi = 20 { ··· 95 wheelUnpackHook 96 ]; 97 98 + # jax-cuda12-plugin looks for ptxas at runtime, e.g. with a triton kernel. 99 + # Linking into $out is the least bad solution. See 100 + # * https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 101 + # * https://github.com/NixOS/nixpkgs/pull/288829#discussion_r1493852211 102 + # * https://github.com/NixOS/nixpkgs/pull/375186 103 + # for more info. 104 + postInstall = '' 105 + mkdir -p $out/${python.sitePackages}/jax_cuda12_plugin/cuda/bin 106 + ln -s ${lib.getExe' cudaPackages.cuda_nvcc "ptxas"} $out/${python.sitePackages}/jax_cuda12_plugin/cuda/bin 107 + ln -s ${lib.getExe' cudaPackages.cuda_nvcc "nvlink"} $out/${python.sitePackages}/jax_cuda12_plugin/cuda/bin 108 + ''; 109 + 110 + # jax-cuda12-plugin contains shared libraries that open other shared libraries via dlopen 111 + # and these implicit dependencies are not recognized by ldd or 112 + # autoPatchelfHook. That means we need to sneak them into rpath. This step 113 + # must be done after autoPatchelfHook and the automatic stripping of 114 + # artifacts. autoPatchelfHook runs in postFixup and auto-stripping runs in the 115 + # patchPhase. 116 + preInstallCheck = '' 117 + patchelf --add-rpath "${cudaLibPath}" $out/${python.sitePackages}/jax_cuda12_plugin/*.so 118 + ''; 119 + 120 dependencies = [ jax-cuda12-pjrt ]; 121 122 pythonImportsCheck = [ "jax_cuda12_plugin" ]; 123 124 + # FIXME: there are no tests, but we need to run preInstallCheck above 125 + doCheck = true; 126 127 meta = { 128 description = "JAX Plugin for CUDA12";
-4
pkgs/development/python-modules/jax/default.nix
··· 198 199 meta = { 200 description = "Source-built JAX frontend: differentiate, compile, and transform Numpy code"; 201 - longDescription = '' 202 - This is the JAX frontend package, it's meant to be used together with one of the jaxlib implementations, 203 - e.g. `python3Packages.jaxlib`, `python3Packages.jaxlib-bin`, or `python3Packages.jaxlibWithCuda`. 204 - ''; 205 homepage = "https://github.com/google/jax"; 206 license = lib.licenses.asl20; 207 maintainers = with lib.maintainers; [ samuela ];
··· 198 199 meta = { 200 description = "Source-built JAX frontend: differentiate, compile, and transform Numpy code"; 201 homepage = "https://github.com/google/jax"; 202 license = lib.licenses.asl20; 203 maintainers = with lib.maintainers; [ samuela ];
+8 -3
pkgs/development/python-modules/jax/test-cuda.nix
··· 11 } 12 '' 13 import jax 14 from jax import random 15 16 - assert jax.devices()[0].platform == "gpu" 17 18 - rng = random.PRNGKey(0) 19 x = random.normal(rng, (100, 100)) 20 - x @ x 21 22 print("success!") 23 ''
··· 11 } 12 '' 13 import jax 14 + import jax.numpy as jnp 15 from jax import random 16 + from jax.experimental import sparse 17 18 + assert jax.devices()[0].platform == "gpu" # libcuda.so 19 20 + rng = random.key(0) # libcudart.so, libcudnn.so 21 x = random.normal(rng, (100, 100)) 22 + x @ x # libcublas.so 23 + jnp.fft.fft(x) # libcufft.so 24 + jnp.linalg.inv(x) # libcusolver.so 25 + sparse.CSR.fromdense(x) @ x # libcusparse.so 26 27 print("success!") 28 ''