at 25.11-pre 590 B view raw
# Smoke test for JAX's CUDA support: builds a `jax-test-cuda` executable whose
# Python script requires a GPU device and exercises one operation per CUDA
# library (noted inline), printing "success!" only after all of them finish.
{
  jax,
  pkgs,
}:

pkgs.writers.writePython3Bin "jax-test-cuda"
  {
    # JAX itself plus the packages in its `cuda` optional-dependency group.
    libraries = [
      jax
    ] ++ jax.optional-dependencies.cuda;
  }
  ''
    import jax
    import jax.numpy as jnp
    from jax import random
    from jax.experimental import sparse

    # libcuda.so: the first device must be a GPU rather than a CPU.
    devices = jax.devices()
    assert devices[0].platform == "gpu", f"expected a GPU, got {devices}"

    rng = random.key(0)  # libcudart.so, libcudnn.so
    x = random.normal(rng, (100, 100)).block_until_ready()

    # JAX dispatches work asynchronously, so block on every result: otherwise a
    # failing kernel could go unnoticed while "success!" is still printed.
    (x @ x).block_until_ready()  # libcublas.so
    jnp.fft.fft(x).block_until_ready()  # libcufft.so
    jnp.linalg.inv(x).block_until_ready()  # libcusolver.so
    (sparse.CSR.fromdense(x) @ x).block_until_ready()  # libcusparse.so

    print("success!")
  ''