at 25.11-pre 4.2 kB view raw
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,
  fetchpatch,

  # dependencies
  chex,
  jaxlib,
  numpy,
  tensorflow-probability,

  # tests
  dm-haiku,
  pytest-xdist,
  pytestCheckHook,
}:

buildPythonPackage rec {
  pname = "distrax";
  version = "0.1.5";
  # NOTE(review): pyproject = true but no `build-system` is declared here;
  # confirm the upstream build backend is available in the build env.
  pyproject = true;

  src = fetchFromGitHub {
    owner = "google-deepmind";
    repo = "distrax";
    tag = "v${version}";
    hash = "sha256-A1aCL/I89Blg9sNmIWQru4QJteUTN6+bhgrEJPmCrM0=";
  };

  patches = [
    # Compatibility fix for jax 0.6.0.
    # TODO: remove at the next release (already on master)
    (fetchpatch {
      name = "fix-jax-0.6.0-compat";
      url = "https://github.com/google-deepmind/distrax/commit/c02708ac46518fac00ab2945311e0f2ee32c672c.patch";
      hash = "sha256-hFNXKoA1b5I6dzhwTRXp/SnkHv89GI6tYwlnBBHwG78=";
    })
  ];

  # NOTE(review): `jax` itself is not listed, only `jaxlib`; presumably it is
  # propagated by another dependency (e.g. chex) — TODO confirm.
  dependencies = [
    chex
    jaxlib
    numpy
    tensorflow-probability
  ];

  nativeCheckInputs = [
    dm-haiku
    pytest-xdist
    pytestCheckHook
  ];

  pythonImportsCheck = [ "distrax" ];

  disabledTests = [
    # Flaky: AssertionError: 1 not less than 0.7000000000000001
    "test_von_mises_sample_uniform_ks_test"

    # Flaky: AssertionError: Not equal to tolerance
    "test_composite_methods_are_consistent__with_jit"

    # NotImplementedError: Primitive 'square' does not have a registered inverse.
    "test_against_tfp_bijectors_square"
    "test_log_dets_square__with_device"
    "test_log_dets_square__without_device"
    "test_log_dets_square__without_jit"

    # AssertionError on numerical values
    # Reported upstream in https://github.com/google-deepmind/distrax/issues/267
    "test_method_with_input_unnormalized_probs__with_device"
    "test_method_with_input_unnormalized_probs__with_jit"
    "test_method_with_input_unnormalized_probs__without_device"
    "test_method_with_input_unnormalized_probs__without_jit"
    "test_method_with_value_1d"
    "test_nested_distributions__with_device"
    "test_nested_distributions__without_device"
    "test_nested_distributions__with_jit"
    "test_nested_distributions__without_jit"
    "test_stability__with_device"
    "test_stability__with_jit"
    "test_stability__without_device"
    "test_stability__without_jit"
    "test_von_mises_sample_gradient"
    "test_von_mises_sample_moments"
  ];

  disabledTestPaths = [
    # Since jax 0.6.0:
    # TypeError: <lambda>() got an unexpected keyword argument 'accuracy'
    "distrax/_src/bijectors/lambda_bijector_test.py"

    # These test files fail with TypeErrors
    "distrax/_src/bijectors/tfp_compatible_bijector_test.py"
    "distrax/_src/distributions/distribution_from_tfp_test.py"
    "distrax/_src/distributions/laplace_test.py"
    "distrax/_src/distributions/multinomial_test.py"
    "distrax/_src/distributions/mvn_diag_plus_low_rank_test.py"
    "distrax/_src/distributions/mvn_kl_test.py"
    "distrax/_src/distributions/straight_through_test.py"
    "distrax/_src/distributions/tfp_compatible_distribution_test.py"
    "distrax/_src/distributions/transformed_test.py"
    "distrax/_src/distributions/uniform_test.py"
    "distrax/_src/utils/transformations_test.py"
    # See https://github.com/google-deepmind/distrax/pull/270
    "distrax/_src/distributions/deterministic_test.py"
    "distrax/_src/distributions/epsilon_greedy_test.py"
    "distrax/_src/distributions/gamma_test.py"
    "distrax/_src/distributions/greedy_test.py"
    "distrax/_src/distributions/gumbel_test.py"
    "distrax/_src/distributions/logistic_test.py"
    "distrax/_src/distributions/log_stddev_normal_test.py"
    "distrax/_src/distributions/mvn_diag_test.py"
    "distrax/_src/distributions/mvn_full_covariance_test.py"
    "distrax/_src/distributions/mvn_tri_test.py"
    "distrax/_src/distributions/one_hot_categorical_test.py"
    "distrax/_src/distributions/softmax_test.py"
    "distrax/_src/utils/hmm_test.py"
  ];

  meta = {
    description = "Probability distributions in JAX";
    # Same org as `src.owner`; the old github.com/deepmind URL is a redirect.
    homepage = "https://github.com/google-deepmind/distrax";
    # Reuse the fetched tag so the changelog link cannot drift from `src`.
    changelog = "https://github.com/google-deepmind/distrax/releases/tag/${src.tag}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ onny ];
    badPlatforms = [
      # SystemError: nanobind::detail::nb_func_error_except(): exception could not be translated!
      lib.systems.inspect.patterns.isDarwin
    ];
  };
}