Clone of https://github.com/NixOS/nixpkgs.git (to stress-test knotserver)
at fix-function-merge
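{
  jax,
  jaxlib,
  pkgs,
}:

# Smoke test script: fails unless JAX's first device is a GPU that can run a matmul.
pkgs.writers.writePython3Bin "jax-test-cuda"
  {
    libraries = [
      jax
      jaxlib
    ];
  }
  ''
    import jax
    from jax import random

    # The first visible device must be a CUDA GPU, not the CPU fallback.
    assert jax.devices()[0].platform == "gpu"

    rng = random.PRNGKey(0)
    x = random.normal(rng, (100, 100))
    # JAX dispatches asynchronously; block so any device error surfaces here.
    (x @ x).block_until_ready()

    print("success!")
  ''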