lol

python312Packages.jax[lib]: 0.5.0 -> 0.5.1 (#384879)

authored by

Nick Cao and committed by
GitHub
2b0c931f 2f0c9948

+69 -76
+3
pkgs/development/python-modules/distrax/default.nix
··· 42 42 # Flaky: AssertionError: 1 not less than 0.7000000000000001 43 43 "test_von_mises_sample_uniform_ks_test" 44 44 45 + # Flaky: AssertionError: Not equal to tolerance 46 + "test_composite_methods_are_consistent__with_jit" 47 + 45 48 # NotImplementedError: Primitive 'square' does not have a registered inverse. 46 49 "test_against_tfp_bijectors_square" 47 50 "test_log_dets_square__with_device"
+8
pkgs/development/python-modules/equinox/default.nix
··· 32 32 hash = "sha256-hor2qw+aTL7yhV53E/y5DUwyDEYJA8RPRS39xxa8xcw="; 33 33 }; 34 34 35 + # Relax speed constraints on tests that can fail on busy builders 36 + postPatch = '' 37 + substituteInPlace tests/test_while_loop.py \ 38 + --replace-fail "speed < 0.1" "speed < 0.5" \ 39 + --replace-fail "speed < 0.5" "speed < 1" \ 40 + --replace-fail "speed < 1" "speed < 4" \ 41 + ''; 42 + 35 43 build-system = [ hatchling ]; 36 44 37 45 dependencies = [
+6
pkgs/development/python-modules/flax/default.nix
··· 84 84 tensorflow 85 85 ]; 86 86 87 + pytestFlagsArray = [ 88 + # DeprecationWarning: linear_util.wrap_init is missing a DebugInfo object. 89 + "-W" 90 + "ignore::DeprecationWarning" 91 + ]; 92 + 87 93 disabledTestPaths = [ 88 94 # Docs test, needs extra deps + we're not interested in it. 89 95 "docs/_ext/codediff_test.py"
+1 -1
pkgs/development/python-modules/jax-cuda12-pjrt/default.nix
··· 39 39 srcs = { 40 40 "x86_64-linux" = fetchurl { 41 41 url = "https://storage.googleapis.com/jax-releases/cuda12_plugin/jax_cuda12_pjrt-${version}-py3-none-manylinux2014_x86_64.whl"; 42 - hash = "sha256-0jgzwbiF2WwnZAAOlQUvK1gnx31JLqaPZ+kDoTJlbbs="; 42 + hash = "sha256-05Xe87NP1oSOEVlu8pdaiV0fUG31EuQbH8XS3lIMjlE="; 43 43 }; 44 44 # "aarch64-linux" = fetchurl { 45 45 # url = "https://storage.googleapis.com/jax-releases/cuda12_plugin/jax_cuda12_pjrt-${version}-py3-none-manylinux2014_aarch64.whl";
+8 -8
pkgs/development/python-modules/jax-cuda12-plugin/default.nix
··· 40 40 "3.10-x86_64-linux" = getSrcFromPypi { 41 41 platform = "manylinux2014_x86_64"; 42 42 dist = "cp310"; 43 - hash = "sha256-D0Q6azcpjt+weW/NvR+GzoWksIS2vT8fUKT7/Wfe2Gs="; 43 + hash = "sha256-ymCGSgWlzqRK51dthHtHeeTeYYUKmhgjg0H8Q6dY1Vs="; 44 44 }; 45 45 "3.10-aarch64-linux" = getSrcFromPypi { 46 46 platform = "manylinux2014_aarch64"; 47 47 dist = "cp310"; 48 - hash = "sha256-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="; 48 + hash = "sha256-488emSaNinBBPw+sM1krh2nBPckdch+RxYeEa+nYhAM="; 49 49 }; 50 50 "3.11-x86_64-linux" = getSrcFromPypi { 51 51 platform = "manylinux2014_x86_64"; 52 52 dist = "cp311"; 53 - hash = "sha256-qYE1oCIwZLj1xoU+It3BpOOGIVLTf7aF8Nve/+DIASI="; 53 + hash = "sha256-NGVbjq/H2b/sMbB3rBPULgjY7YZV0kFHxa38AVFSaU8="; 54 54 }; 55 55 "3.11-aarch64-linux" = getSrcFromPypi { 56 56 platform = "manylinux2014_aarch64"; 57 57 dist = "cp311"; 58 - hash = "sha256-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="; 58 + hash = "sha256-ik8Mje8QnXb1uqwbKuRpCURFzS1/vYxZf83WI+sC/1Q="; 59 59 }; 60 60 "3.12-x86_64-linux" = getSrcFromPypi { 61 61 platform = "manylinux2014_x86_64"; 62 62 dist = "cp312"; 63 - hash = "sha256-QwWN/FZdjJ2mn0fNTkuVxJXxaG8onvRYTCtygD5vFgc="; 63 + hash = "sha256-Fp22Rbr+whOO2YOvjxTk0RqElyivpXIC55qRBNmJLxY="; 64 64 }; 65 65 "3.12-aarch64-linux" = getSrcFromPypi { 66 66 platform = "manylinux2014_aarch64"; 67 67 dist = "cp312"; 68 - hash = "sha256-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="; 68 + hash = "sha256-wqhtVj3AR5rLTHmuz/vuwYn6vY6XAP4/cxvwOV4dbBE="; 69 69 }; 70 70 "3.13-x86_64-linux" = getSrcFromPypi { 71 71 platform = "manylinux2014_x86_64"; 72 72 dist = "cp313"; 73 - hash = "sha256-3zbEsXbi01qCqfOM13zDadJx5gBR43GgqO9FFD+PWLY="; 73 + hash = "sha256-J0kTjLGHMZBoa7FPMxBIskwXjwXkdIo9L/fSA2c1rT0="; 74 74 }; 75 75 "3.13-aarch64-linux" = getSrcFromPypi { 76 76 platform = "manylinux2014_aarch64"; 77 77 dist = "cp313"; 78 - hash = "sha256-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="; 78 + hash = "sha256-C7aXl3+NbGOS9WLDQF6D68xvXLfCCBQFCE9tXPoJ4yo="; 79 79 }; 80 80 }; 81 81 in
+2 -34
pkgs/development/python-modules/jax/default.nix
··· 40 40 in 41 41 buildPythonPackage rec { 42 42 pname = "jax"; 43 - version = "0.5.0"; 43 + version = "0.5.1"; 44 44 pyproject = true; 45 45 46 46 src = fetchFromGitHub { ··· 48 48 repo = "jax"; 49 49 # google/jax contains tags for jax and jaxlib. Only use jax tags! 50 50 tag = "jax-v${version}"; 51 - hash = "sha256-D6n9Z34nrCbBd9IS8YW6uio5Yi9GLCo9PViO3YYbkQ8="; 51 + hash = "sha256-WXtoLiRGcD8PqCMV+VYGeFr/qFEavuoVM5JSapO1QXc="; 52 52 }; 53 53 54 54 build-system = [ setuptools ]; ··· 113 113 [ 114 114 # Exceeds tolerance when the machine is busy 115 115 "test_custom_linear_solve_aux" 116 - # UserWarning: Explicitly requested dtype <class 'numpy.float64'> 117 - # requested in astype is not available, and will be truncated to 118 - # dtype float32. (With numpy 1.24) 119 - "testKde3" 120 - "testKde5" 121 - "testKde6" 122 - # Invokes python manually in a subprocess, which does not have the correct dependencies 123 - # ImportError: This version of jax requires jaxlib version >= 0.4.19. 124 - "test_no_log_spam" 125 116 ] 126 117 ++ lib.optionals usingMKL [ 127 118 # See ··· 133 124 "testEigvalsGrad_shape" 134 125 ] 135 126 ++ lib.optionals stdenv.hostPlatform.isAarch64 [ 136 - # See https://github.com/google/jax/issues/14793. 137 - "test_for_loop_fixpoint_correctly_identifies_loop_varying_residuals_unrolled_for_loop" 138 - "testQdwhWithRandomMatrix3" 139 - "testScanGrad_jit_scan" 140 - 141 - # See https://github.com/google/jax/issues/17867. 142 - "test_array" 143 - "test_async" 144 - "test_copy0" 145 - "test_device_put" 146 - "test_make_array_from_callback" 147 - "test_make_array_from_single_device_arrays" 148 - 149 127 # Fails on some hardware due to some numerical error 150 128 # See https://github.com/google/jax/issues/18535 151 129 "testQdwhWithOnRankDeficientInput5" ··· 161 139 "test_vjp_rule_inconsistent_pytree_structures_error" 162 140 "test_vmap_in_axes_tree_prefix_error" 163 141 "test_vmap_mismatched_axis_sizes_error_message_issue_705" 164 - ]; 165 - 166 - disabledTestPaths = 167 - [ 168 - # Segmentation fault. See https://gist.github.com/zimbatm/e9b61891f3bcf5e4aaefd13f94344fba 169 - "tests/linalg_test.py" 170 - ] 171 - ++ lib.optionals (stdenv.hostPlatform.isDarwin && stdenv.hostPlatform.isAarch64) [ 172 - # RuntimeWarning: invalid value encountered in cast 173 - "tests/lax_test.py" 174 142 ]; 175 143 176 144 pythonImportsCheck = [ "jax" ];
+13 -13
pkgs/development/python-modules/jaxlib/bin.nix
··· 18 18 }: 19 19 20 20 let 21 - version = "0.5.0"; 21 + version = "0.5.1"; 22 22 inherit (python) pythonVersion; 23 23 24 24 # As of 2023-06-06, google/jax upstream is no longer publishing CPU-only wheels to their GCS bucket. Instead the ··· 49 49 "3.10-x86_64-linux" = getSrcFromPypi { 50 50 platform = "manylinux2014_x86_64"; 51 51 dist = "cp310"; 52 - hash = "sha256-dEQLYyEHM2QA1Pl6Fkgddn8T6pFMU7oU5UTG/aVIGbM="; 52 + hash = "sha256-ZbxJAKBJHfxvubamLoEA0SFCnVjHQplF7CtCTBqCvyc="; 53 53 }; 54 54 "3.10-aarch64-linux" = getSrcFromPypi { 55 55 platform = "manylinux2014_aarch64"; 56 56 dist = "cp310"; 57 - hash = "sha256-Wy7+Pf6/GKhMRR04A6yITuJCAhwRE7J5wT9LvDeMPcA="; 57 + hash = "sha256-CQ/n1LyOGXaHcdvIMZ3pAkdGu1Z/XZujebUQLxdljEE="; 58 58 }; 59 59 "3.10-aarch64-darwin" = getSrcFromPypi { 60 60 platform = "macosx_11_0_arm64"; 61 61 dist = "cp310"; 62 - hash = "sha256-G4psQ0XxN/OHZQ3i28SIwgJRt0ErVd1kjhpPE7z1B/s="; 62 + hash = "sha256-LavMsIZHaBj3N80OYi/YjcT1u08N9Melxow0s2bohH8="; 63 63 }; 64 64 65 65 "3.11-x86_64-linux" = getSrcFromPypi { 66 66 platform = "manylinux2014_x86_64"; 67 67 dist = "cp311"; 68 - hash = "sha256-CRE+8Vgro018vEQP7bMY9IVbWbd2cRqKuiRzyXJ9MCU="; 68 + hash = "sha256-gMDtVEZkSzg8qj5hdUCAO7DvNuVgnNd1bU0pE6jeUS4="; 69 69 }; 70 70 "3.11-aarch64-linux" = getSrcFromPypi { 71 71 platform = "manylinux2014_aarch64"; 72 72 dist = "cp311"; 73 - hash = "sha256-YwiNv6qFu1bNUhqSWjRy/XMosY7JPC2P+oWvMxCVyZU="; 73 + hash = "sha256-afS54HrQdNRBuZIbeoOu5PT/09VCAz/exQvhRW0GEcY="; 74 74 }; 75 75 "3.11-aarch64-darwin" = getSrcFromPypi { 76 76 platform = "macosx_11_0_arm64"; 77 77 dist = "cp311"; 78 - hash = "sha256-bNdi7RYjEySZ+nAcQgNEYQLgqcgsojGUuHKI90bRKik="; 78 + hash = "sha256-M0xJrUEfOaUFXCPxOVUq4yvZr+aWq7scrfnETr72B/c="; 79 79 }; 80 80 81 81 "3.12-x86_64-linux" = getSrcFromPypi { 82 82 platform = "manylinux2014_x86_64"; 83 83 dist = "cp312"; 84 - hash = "sha256-+YDHM+mMmYqNqHyajMYbZybQvmZ6WL1mTB1xe0tOrnU="; 84 + hash = "sha256-Uee1n8QLsnBEDFBJs8gvn3/a2uMZnxaBhiDP24C5Z68="; 85 85 }; 86 86 "3.12-aarch64-linux" = getSrcFromPypi { 87 87 platform = "manylinux2014_aarch64"; 88 88 dist = "cp312"; 89 - hash = "sha256-S0sBr7Dd7JbAg1a/8rtoXdvpf9/+Ttbi2DSzCrqXLyI="; 89 + hash = "sha256-W0ulqj9ZtfLjfVJc7dav4P7suIQW5fQ+uacJz97YslA="; 90 90 }; 91 91 "3.12-aarch64-darwin" = getSrcFromPypi { 92 92 platform = "macosx_11_0_arm64"; 93 93 dist = "cp312"; 94 - hash = "sha256-c+M1cVdgxW5jUQnWFCZDWl1/RvM2OhFdrqCUJ9XNDv0="; 94 + hash = "sha256-rj3ii/m4Z4HDCjLIi3y9HTIiqMIpqpbL4hBV3NCeuIk="; 95 95 }; 96 96 97 97 "3.13-x86_64-linux" = getSrcFromPypi { 98 98 platform = "manylinux2014_x86_64"; 99 99 dist = "cp313"; 100 - hash = "sha256-Ee7wHTfA8cUwYmW3byB/EALRNIDe0uMf1j7HaRLJPKI="; 100 + hash = "sha256-3BCf+mhzZAImw2DaeTqKe+r4xIdRrD/bKfprM3kB4YY="; 101 101 }; 102 102 "3.13-aarch64-linux" = getSrcFromPypi { 103 103 platform = "manylinux2014_aarch64"; 104 104 dist = "cp313"; 105 - hash = "sha256-fZsXp+oZNV1F7Nsv8NtdcHqG8MWoYtlLibRWjWxFMRo="; 105 + hash = "sha256-3zcE8TXP+H/Z1BkwJIkl8vFjvu1u/qqrzNl0AVgNz9g="; 106 106 }; 107 107 "3.13-aarch64-darwin" = getSrcFromPypi { 108 108 platform = "macosx_11_0_arm64"; 109 109 dist = "cp313"; 110 - hash = "sha256-7RjqcWHQOqj9TRtVSUiC8hQg79/qaOXymMSuvPKsPzQ="; 110 + hash = "sha256-jFf7vnmqPOOsLsZXp/F4Z7mzwsrYhbLIOQVnzJc47ug="; 111 111 }; 112 112 }; 113 113 in
+1 -2
pkgs/development/python-modules/jaxlib/prefetch.sh
··· 1 1 #! /usr/bin/env nix-shell 2 2 #! nix-shell -i sh -p jq 3 3 4 - prefetch () { 4 + prefetch() { 5 5 expr="(import <nixpkgs> { system = \"$2\"; config.cudaSupport = true; }).python$1.pkgs.$3.src.url" 6 6 url=$(NIX_PATH=.. nix-instantiate --eval -E "$expr" | jq -r) 7 7 echo "$url" ··· 14 14 prefetch "$py" "x86_64-linux" "jaxlib-bin" 15 15 prefetch "$py" "aarch64-linux" "jaxlib-bin" 16 16 prefetch "$py" "aarch64-darwin" "jaxlib-bin" 17 - prefetch "$py" "x86_64-darwin" "jaxlib-bin" 18 17 prefetch "$py" "x86_64-linux" "jax-cuda12-plugin" 19 18 prefetch "$py" "aarch64-linux" "jax-cuda12-plugin" 20 19 done
+24 -18
pkgs/development/python-modules/numpyro/default.nix
··· 1 1 { 2 2 lib, 3 + stdenv, 3 4 buildPythonPackage, 4 5 fetchFromGitHub, 5 6 ··· 75 76 "ignore::UserWarning" 76 77 ]; 77 78 78 - disabledTests = [ 79 - # AssertionError due to tolerance issues 80 - "test_bijective_transforms" 81 - "test_cpu" 82 - "test_entropy_categorical" 83 - "test_gaussian_model" 79 + disabledTests = 80 + [ 81 + # AssertionError due to tolerance issues 82 + "test_bijective_transforms" 83 + "test_cpu" 84 + "test_entropy_categorical" 85 + "test_gaussian_model" 84 86 85 - # > with pytest.warns(UserWarning, match="Hessian of log posterior"): 86 - # E Failed: DID NOT WARN. No warnings of type (<class 'UserWarning'>,) were emitted. 87 - # E Emitted warnings: []. 88 - "test_laplace_approximation_warning" 87 + # > with pytest.warns(UserWarning, match="Hessian of log posterior"): 88 + # E Failed: DID NOT WARN. No warnings of type (<class 'UserWarning'>,) were emitted. 89 + # E Emitted warnings: []. 90 + "test_laplace_approximation_warning" 89 91 90 - # Tests want to download data 91 - "data_load" 92 - "test_jsb_chorales" 92 + # Tests want to download data 93 + "data_load" 94 + "test_jsb_chorales" 93 95 94 - # ValueError: compiling computation that requires 2 logical devices, but only 1 XLA devices are available (num_replicas=2) 95 - "test_chain" 96 + # ValueError: compiling computation that requires 2 logical devices, but only 1 XLA devices are available (num_replicas=2) 97 + "test_chain" 96 98 97 - # test_biject_to[CorrMatrix()-(15,)] - assert Array(False, dtype=bool) 98 - "test_biject_to" 99 - ]; 99 + # test_biject_to[CorrMatrix()-(15,)] - assert Array(False, dtype=bool) 100 + "test_biject_to" 101 + ] 102 + ++ lib.optionals stdenv.hostPlatform.isDarwin [ 103 + # AssertionError: Not equal to tolerance rtol=0.06, atol=0 104 + "test_functional_map" 105 + ]; 100 106 101 107 meta = { 102 108 description = "Library for probabilistic programming with NumPy";
+3
pkgs/development/python-modules/oryx/default.nix
··· 114 114 changelog = "https://github.com/jax-ml/oryx/releases/tag/v${version}"; 115 115 license = lib.licenses.asl20; 116 116 maintainers = with lib.maintainers; [ GaetanLepage ]; 117 + # oryx seems to be incompatible with jax 0.5.1 118 + # 237 additional test failures are resulting from the jax bump. 119 + broken = true; 117 120 }; 118 121 }