1{
2 lib,
3 autograd,
4 buildPythonPackage,
5 fetchFromGitHub,
6 cvxopt,
7 cython,
8 jax,
9 jaxlib,
10 matplotlib,
11 numpy,
12 pymanopt,
13 pytestCheckHook,
14 pythonOlder,
15 scikit-learn,
16 scipy,
17 setuptools,
18 tensorflow,
19 torch,
20}:
21
22buildPythonPackage rec {
23 pname = "pot";
24 version = "0.9.4";
25 pyproject = true;
26
27 disabled = pythonOlder "3.6";
28
29 src = fetchFromGitHub {
30 owner = "PythonOT";
31 repo = "POT";
32 rev = "refs/tags/${version}";
33 hash = "sha256-Yx9hjniXebn7ZZeqou0JEsn2Yf9hyJSu/acDlM4kCCI=";
34 };
35
36 build-system = [
37 setuptools
38 cython
39 numpy
40 ];
41
42 dependencies = [
43 numpy
44 scipy
45 ];
46
47 optional-dependencies = {
48 backend-numpy = [ ];
49 backend-jax = [
50 jax
51 jaxlib
52 ];
53 backend-cupy = [ ];
54 backend-tf = [ tensorflow ];
55 backend-torch = [ torch ];
56 cvxopt = [ cvxopt ];
57 dr = [
58 scikit-learn
59 pymanopt
60 autograd
61 ];
62 gnn = [
63 torch
64 # torch-geometric
65 ];
66 plot = [ matplotlib ];
67 all =
68 with optional-dependencies;
69 (
70 backend-numpy
71 ++ backend-jax
72 ++ backend-cupy
73 ++ backend-tf
74 ++ backend-torch
75 ++ optional-dependencies.cvxopt
76 ++ dr
77 ++ gnn
78 ++ plot
79 );
80 };
81
82 nativeCheckInputs = [ pytestCheckHook ];
83
84 postPatch = ''
85 substituteInPlace setup.cfg \
86 --replace " --cov-report= --cov=ot" "" \
87 --replace " --durations=20" "" \
88 --replace " --junit-xml=junit-results.xml" ""
89
90 substituteInPlace pyproject.toml \
91 --replace-fail "numpy>=2.0.0" "numpy"
92
93 # we don't need setup.py to find the macos sdk for us
94 sed -i '/sdk_path/d' setup.py
95 '';
96
97 # need to run the tests with the built package next to the test directory
98 preCheck = ''
99 pushd build/lib.*
100 ln -s -t . "$OLDPWD/test"
101 '';
102
103 postCheck = ''
104 popd
105 '';
106
107 disabledTests = [
108 # GPU tests are always skipped because of sandboxing
109 "warnings"
110 # Fixture is not available
111 "test_conditional_gradient"
112 "test_convert_between_backends"
113 "test_emd_backends"
114 "test_emd_emd2_types_devices"
115 "test_emd1d_type_devices"
116 "test_emd2_backends"
117 "test_factored_ot_backends"
118 "test_free_support_barycenter_backends"
119 "test_func_backends"
120 "test_generalized_conditional_gradient"
121 "test_line_search_armijo"
122 "test_loss_dual"
123 "test_max_sliced_backend"
124 "test_plan_dual"
125 "test_random_backends"
126 "test_sliced_backend"
127 "test_to_numpy"
128 "test_wasserstein_1d_type_devices"
129 "test_wasserstein"
130 "test_weak_ot_bakends"
131 # TypeError: Only integers, slices...
132 "test_emd1d_device_tf"
133 ];
134
135 pythonImportsCheck = [
136 "ot"
137 "ot.lp"
138 ];
139
140 meta = with lib; {
141 description = "Python Optimal Transport Library";
142 homepage = "https://pythonot.github.io/";
143 license = licenses.mit;
144 maintainers = with maintainers; [ yl3dy ];
145 };
146}