{ lib
, fetchPypi
, buildPythonPackage
, numpy
, scipy
, cython
, matplotlib
, scikit-learn
, cupy
, pymanopt
, autograd
, pytestCheckHook
, enableDimensionalityReduction ? false
, enableGPU ? false
}:

# POT, the Python Optimal Transport library, packaged from its PyPI sdist.
#
# Feature flags (both off by default):
#   enableGPU                     - also propagate cupy.
#   enableDimensionalityReduction - also propagate pymanopt and autograd.
let
  version = "0.7.0";

  # Runtime dependencies propagated unconditionally.
  coreDeps = [ numpy scipy ];

  # Optional runtime dependencies; each list is empty unless its flag is set.
  gpuDeps = lib.optionals enableGPU [ cupy ];
  dimReductionDeps = lib.optionals enableDimensionalityReduction [ pymanopt autograd ];
in
buildPythonPackage {
  pname = "pot";
  inherit version;

  meta = {
    description = "Python Optimal Transport Library";
    homepage = "https://pythonot.github.io/";
    license = lib.licenses.mit;
    maintainers = [ lib.maintainers.yl3dy ];
  };

  # Note: fetchPypi is given the upper-case project name "POT", while the
  # derivation's own pname is lower-case "pot".
  src = fetchPypi {
    pname = "POT";
    inherit version;
    sha256 = "01mdsiv8rlgqzvm3bds9aj49khnn33i523c2cqqrl10zg742pb6l";
  };

  # Remove the coverage options from setup.cfg. pytest-cov is not among the
  # check inputs below, so presumably these options would otherwise break the
  # test run.
  postPatch = ''
    substituteInPlace setup.cfg \
      --replace "--cov-report= --cov=ot" ""
  '';

  # numpy and cython are needed while building (cython for the compiled
  # extension; numpy presumably for its headers).
  nativeBuildInputs = [
    numpy
    cython
  ];

  propagatedBuildInputs = coreDeps ++ gpuDeps ++ dimReductionDeps;

  checkInputs = [
    matplotlib
    scikit-learn
    pytestCheckHook
  ];

  # `ot` is the top-level package name. Deleting the source-tree copy makes
  # the tests import the installed package from the Nix store instead of the
  # incomplete one in the build directory.
  preCheck = ''
    rm -r ot
  '';

  # Tests whose names match "warnings" are excluded.
  # NOTE(review): the reason for this exclusion is not recorded here; confirm.
  # GPU tests need no entry in this list: they are always skipped inside the
  # build sandbox.
  disabledTests = [ "warnings" ];

  # Import smoke test on the installed output; `ot.lp` presumably exercises
  # the compiled extension.
  pythonImportsCheck = [ "ot" "ot.lp" ];
}