{
  lib,
  stdenv,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  setuptools,

  # dependencies
  apricot-select,
  networkx,
  numpy,
  scikit-learn,
  scipy,
  torch,

  # tests
  pytestCheckHook,
}:

# pomegranate: probabilistic models for Python, built from a pinned upstream
# commit with setuptools and tested with pytest.
buildPythonPackage rec {
  pname = "pomegranate";
  version = "1.1.2";
  pyproject = true;

  src = fetchFromGitHub {
    repo = "pomegranate";
    owner = "jmschrei";
    # Upstream did not push a tag for 1.1.2, so this pins the commit instead of
    # using `tag = "v${version}";`. Switch back to `tag` once one exists.
    rev = "e9162731f4f109b7b17ecffde768734cacdb839b";
    hash = "sha256-vVoAoZ+mph11ZfINT+yxRyk9rXv6FBDgxBz56P2K95Y=";
  };

  # The tests reload their own serialized models via
  # `torch.load(".pytest.torch")`, which fails with
  #   _pickle.UnpicklingError: Weights only load failed.
  # on torch versions where `weights_only` defaults to True. These files are
  # written by the test suite itself, so opt out explicitly.
  # https://pytorch.org/docs/stable/generated/torch.load.html
  postPatch = ''
    substituteInPlace \
      tests/distributions/test_bernoulli.py \
      tests/distributions/test_categorical.py \
      tests/distributions/test_exponential.py \
      tests/distributions/test_gamma.py \
      tests/distributions/test_independent_component.py \
      tests/distributions/test_normal_diagonal.py \
      tests/distributions/test_normal_full.py \
      tests/distributions/test_poisson.py \
      tests/distributions/test_student_t.py \
      tests/distributions/test_uniform.py \
      tests/test_bayes_classifier.py \
      tests/test_gmm.py \
      tests/test_kmeans.py \
      --replace-fail \
        'torch.load(".pytest.torch")' \
        'torch.load(".pytest.torch", weights_only=False)'
  '';

  build-system = [ setuptools ];

  dependencies = [
    apricot-select
    networkx
    numpy
    scikit-learn
    scipy
    torch
  ];

  pythonImportsCheck = [ "pomegranate" ];

  nativeCheckInputs = [
    pytestCheckHook
  ];

  # Only x86_64-darwin is affected, so deselect these specific node IDs there.
  # They must be plain pytest node IDs (`path::test`): a leading `=` would make
  # them match no collected test, and the failing tests would still run.
  disabledTestPaths = lib.optionals (stdenv.hostPlatform.isDarwin && stdenv.hostPlatform.isx86_64) [
    # AssertionError: Arrays are not almost equal to 6 decimals
    "tests/distributions/test_normal_full.py::test_fit"
    "tests/distributions/test_normal_full.py::test_from_summaries"
    "tests/distributions/test_normal_full.py::test_serialization"
  ];

  disabledTests = [
    # AssertionError: Arrays are not almost equal to 6 decimals
    "test_sample"
  ];

  meta = {
    # 1.x is built on PyTorch (see `dependencies`); the old Cython
    # implementation is not part of this build.
    description = "Probabilistic and graphical models for Python, implemented in PyTorch";
    homepage = "https://github.com/jmschrei/pomegranate";
    # There is no v1.1.2 tag (see `src`), so a releases/tag URL would not
    # resolve; link the release notes at the pinned revision instead.
    changelog = "https://github.com/jmschrei/pomegranate/blob/${src.rev}/docs/whats_new.rst";
    license = lib.licenses.mit;
    maintainers = with lib.maintainers; [ rybern ];
  };
}