# Smoke test for jax's CUDA support: builds a Python 3 executable named
# `jax-test-cuda` (with `jax` plus `jax.optional-dependencies.cuda` on its
# path) that asserts the first JAX device is a GPU, then runs a handful of
# operations, each annotated with the CUDA library it is meant to exercise.
# Every result is forced with `.block_until_ready()` because JAX dispatches
# work asynchronously; without it a failure in one of those operations might
# not be raised before "success!" is printed.
{
  jax,
  pkgs,
}:

pkgs.writers.writePython3Bin "jax-test-cuda"
  {
    libraries = [
      jax
    ] ++ jax.optional-dependencies.cuda;
  }
  ''
    import jax
    import jax.numpy as jnp
    from jax import random
    from jax.experimental import sparse

    assert jax.devices()[0].platform == "gpu", jax.devices()  # libcuda.so

    rng = random.key(0)  # libcudart.so, libcudnn.so
    x = random.normal(rng, (100, 100))

    # JAX dispatches asynchronously: block on every result so that a failure
    # in the underlying CUDA library surfaces before "success!" is printed.
    (x @ x).block_until_ready()  # libcublas.so
    jnp.fft.fft(x).block_until_ready()  # libcufft.so
    jnp.linalg.inv(x).block_until_ready()  # libcusolver.so
    (sparse.CSR.fromdense(x) @ x).block_until_ready()  # libcusparse.so

    print("success!")
  ''