···1212 jax-cuda12-pjrt,
1313}:
1414let
1515+ inherit (jaxlib) version;
1516 inherit (cudaPackages) cudaVersion;
1616- inherit (jaxlib) version;
1717+ inherit (jax-cuda12-pjrt) cudaLibPath;
17181819 getSrcFromPypi =
1920 {
···9495 wheelUnpackHook
9596 ];
96979898+ # jax-cuda12-plugin looks for ptxas at runtime, e.g. with a triton kernel.
9999+ # Linking into $out is the least bad solution. See
100100+ # * https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621
101101+ # * https://github.com/NixOS/nixpkgs/pull/288829#discussion_r1493852211
102102+ # * https://github.com/NixOS/nixpkgs/pull/375186
103103+ # for more info.
104104+ postInstall = ''
105105+ mkdir -p $out/${python.sitePackages}/jax_cuda12_plugin/cuda/bin
106106+ ln -s ${lib.getExe' cudaPackages.cuda_nvcc "ptxas"} $out/${python.sitePackages}/jax_cuda12_plugin/cuda/bin
107107+ ln -s ${lib.getExe' cudaPackages.cuda_nvcc "nvlink"} $out/${python.sitePackages}/jax_cuda12_plugin/cuda/bin
108108+ '';
109109+110110+ # jax-cuda12-plugin contains shared libraries that open other shared libraries via dlopen
111111+ # and these implicit dependencies are not recognized by ldd or
112112+ # autoPatchelfHook. That means we need to sneak them into rpath. This step
113113+ # must be done after autoPatchelfHook and the automatic stripping of
114114+ # artifacts. autoPatchelfHook runs in postFixup and auto-stripping runs in the
115115+ # patchPhase.
116116+ preInstallCheck = ''
117117+ patchelf --add-rpath "${cudaLibPath}" $out/${python.sitePackages}/jax_cuda12_plugin/*.so
118118+ '';
119119+97120 dependencies = [ jax-cuda12-pjrt ];
9812199122 pythonImportsCheck = [ "jax_cuda12_plugin" ];
100123101101- # no tests
102102- doCheck = false;
124124+ # FIXME: there are no tests, but we need to run preInstallCheck above
125125+ doCheck = true;
103126104127 meta = {
105128 description = "JAX Plugin for CUDA12";
-4
pkgs/development/python-modules/jax/default.nix
···198198199199 meta = {
200200 description = "Source-built JAX frontend: differentiate, compile, and transform Numpy code";
201201- longDescription = ''
202202- This is the JAX frontend package, it's meant to be used together with one of the jaxlib implementations,
203203- e.g. `python3Packages.jaxlib`, `python3Packages.jaxlib-bin`, or `python3Packages.jaxlibWithCuda`.
204204- '';
205201 homepage = "https://github.com/google/jax";
206202 license = lib.licenses.asl20;
207203 maintainers = with lib.maintainers; [ samuela ];
+8-3
pkgs/development/python-modules/jax/test-cuda.nix
···1111 }
1212 ''
1313 import jax
1414+ import jax.numpy as jnp
1415 from jax import random
1616+ from jax.experimental import sparse
15171616- assert jax.devices()[0].platform == "gpu"
1818+ assert jax.devices()[0].platform == "gpu" # libcuda.so
17191818- rng = random.PRNGKey(0)
2020+ rng = random.key(0) # libcudart.so, libcudnn.so
1921 x = random.normal(rng, (100, 100))
2020- x @ x
2222+ x @ x # libcublas.so
2323+ jnp.fft.fft(x) # libcufft.so
2424+ jnp.linalg.inv(x) # libcusolver.so
2525+ sparse.CSR.fromdense(x) @ x # libcusparse.so
21262227 print("success!")
2328 ''