at 24.11-pre 332 B view raw
# Smoke test for a CUDA-enabled JAX build: builds an executable
# `jax-test-cuda` that checks JAX's default backend is a GPU and that a small
# computation actually runs on it.
#
# Arguments:
#   jax, jaxlib - the Python packages under test (presumably the CUDA
#                 variants; which build is passed in is up to the caller).
#   pkgs        - package set providing `writers.writePython3Bin`.
#
# The script prints "success!" when both checks pass; a failed assertion or
# any exception from JAX aborts it with a non-zero exit status.
{
  jax,
  jaxlib,
  pkgs,
}:

pkgs.writers.writePython3Bin "jax-test-cuda"
  {
    libraries = [
      jax
      jaxlib
    ];
  }
  ''
    import jax
    from jax import random

    # jax.devices() lists the devices of JAX's default backend; on a working
    # CUDA build those report platform "gpu". Include the device list in the
    # failure message so a CPU fallback is easy to diagnose.
    devices = jax.devices()
    assert devices[0].platform == "gpu", f"expected a GPU, got {devices}"

    rng = random.PRNGKey(0)
    x = random.normal(rng, (100, 100))
    # JAX dispatches work asynchronously: without blocking, the matmul may
    # not have run (or failed) yet when "success!" is printed.
    (x @ x).block_until_ready()

    print("success!")
  ''