# Source: nixpkgs 24.11-pre (raw file, 1.7 kB)
# Nix package expression for the `equinox` Python library, fetched from the
# upstream GitHub release tag and built as a pyproject (PEP 517) package with
# the hatchling backend.
{
  lib,
  buildPythonPackage,
  pythonOlder,
  fetchFromGitHub,
  hatchling,
  jax,
  jaxlib,
  jaxtyping,
  typing-extensions,
  beartype,
  optax,
  pytest-xdist,
  pytestCheckHook,
}:

buildPythonPackage rec {
  pname = "equinox";
  version = "0.11.4";
  pyproject = true;

  # Refuse to build for interpreters older than Python 3.9.
  disabled = pythonOlder "3.9";

  src = fetchFromGitHub {
    owner = "patrick-kidger";
    repo = "equinox";
    rev = "refs/tags/v${version}";
    hash = "sha256-3OwHND1YEdg/SppqiB7pCdp6v+lYwTbtX07tmyEMWDo=";
  };

  # PEP 517 build backend (pyproject = true above).
  build-system = [ hatchling ];

  # Runtime Python dependencies; propagated to anything that depends on equinox.
  dependencies = [
    jax
    jaxlib
    jaxtyping
    typing-extensions
  ];

  # Inputs needed only to run the test suite during the check phase.
  nativeCheckInputs = [
    beartype
    optax
    pytest-xdist
    pytestCheckHook
  ];

  pythonImportsCheck = [ "equinox" ];

  disabledTests = [
    # For simplicity, JAX has removed its internal frames from the traceback of the following exception.
    # https://github.com/patrick-kidger/equinox/issues/716
    "test_abstract"
    "test_complicated"
    "test_grad"
    "test_jvp"
    "test_mlp"
    "test_num_traces"
    "test_pytree_in"
    "test_simple"
    "test_vmap"

    # AssertionError: assert 'foo:\n pri...pe=float32)\n' == 'foo:\n pri...pe=float32)\n'
    # Also reported in patrick-kidger/equinox#716
    "test_backward_nan"
  ];

  meta = {
    # nixpkgs convention: descriptions do not start with an article.
    description = "JAX library based around a simple idea: represent parameterised functions (such as neural networks) as PyTrees";
    changelog = "https://github.com/patrick-kidger/equinox/releases/tag/v${version}";
    homepage = "https://github.com/patrick-kidger/equinox";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ GaetanLepage ];
  };
}