···12 jax-cuda12-pjrt,
13}:
14let
15+ inherit (jaxlib) version;
16 inherit (cudaPackages) cudaVersion;
17+ inherit (jax-cuda12-pjrt) cudaLibPath;
1819 getSrcFromPypi =
20 {
···95 wheelUnpackHook
96 ];
9798+ # jax-cuda12-plugin looks for ptxas at runtime, e.g. with a triton kernel.
99+ # Linking into $out is the least bad solution. See
100+ # * https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621
101+ # * https://github.com/NixOS/nixpkgs/pull/288829#discussion_r1493852211
102+ # * https://github.com/NixOS/nixpkgs/pull/375186
103+ # for more info.
104+ postInstall = ''
105+ mkdir -p $out/${python.sitePackages}/jax_cuda12_plugin/cuda/bin
106+ ln -s ${lib.getExe' cudaPackages.cuda_nvcc "ptxas"} $out/${python.sitePackages}/jax_cuda12_plugin/cuda/bin
107+ ln -s ${lib.getExe' cudaPackages.cuda_nvcc "nvlink"} $out/${python.sitePackages}/jax_cuda12_plugin/cuda/bin
108+ '';
109+110+ # jax-cuda12-plugin contains shared libraries that open other shared libraries via dlopen
111+ # and these implicit dependencies are not recognized by ldd or
112+ # autoPatchelfHook. That means we need to sneak them into rpath. This step
113+ # must be done after autoPatchelfHook and the automatic stripping of
114+ # artifacts. autoPatchelfHook runs in postFixup and auto-stripping runs in the
115+ # patchPhase.
116+ preInstallCheck = ''
117+ patchelf --add-rpath "${cudaLibPath}" $out/${python.sitePackages}/jax_cuda12_plugin/*.so
118+ '';
119+120 dependencies = [ jax-cuda12-pjrt ];
121122 pythonImportsCheck = [ "jax_cuda12_plugin" ];
123124+ # FIXME: there are no tests, but we need to run preInstallCheck above
125+ doCheck = true;
126127 meta = {
128 description = "JAX Plugin for CUDA12";
-4
pkgs/development/python-modules/jax/default.nix
···198199 meta = {
200 description = "Source-built JAX frontend: differentiate, compile, and transform Numpy code";
201- longDescription = ''
202- This is the JAX frontend package, it's meant to be used together with one of the jaxlib implementations,
203- e.g. `python3Packages.jaxlib`, `python3Packages.jaxlib-bin`, or `python3Packages.jaxlibWithCuda`.
204- '';
205 homepage = "https://github.com/google/jax";
206 license = lib.licenses.asl20;
207 maintainers = with lib.maintainers; [ samuela ];