# Smoke test for a CUDA-enabled jax/jaxlib build. Produces an executable
# "jax-test-cuda" that exits non-zero unless JAX's first device is a GPU and
# a small matrix multiplication on it runs to completion.
{
  jax,
  jaxlib,
  pkgs,
}:

pkgs.writers.writePython3Bin "jax-test-cuda"
  {
    libraries = [
      jax
      jaxlib
    ];
  }
  ''
    import jax
    from jax import random

    # Fail loudly (and report what was found) if JAX did not pick up a GPU.
    # An explicit exit is used instead of `assert`, which `python -O` strips.
    devices = jax.devices()
    if devices[0].platform != "gpu":
        raise SystemExit("first JAX device is not a GPU: %r" % (devices,))

    rng = random.PRNGKey(0)
    x = random.normal(rng, (100, 100))

    # JAX dispatches work asynchronously; block on the result so any
    # device/kernel error surfaces here rather than after "success!".
    (x @ x).block_until_ready()

    print("success!")
  ''