# Nix derivation for the `gpytorch` Python package (Gaussian processes on
# top of PyTorch), built from the upstream GitHub release tag.
{ lib
, buildPythonPackage
, fetchFromGitHub
, linear_operator
, scikit-learn
, torch
, pytestCheckHook
}:

buildPythonPackage rec {
  pname = "gpytorch";
  version = "1.10";
  format = "pyproject";

  src = fetchFromGitHub {
    owner = "cornellius-gp";
    repo = pname;
    rev = "v${version}";
    hash = "sha256-KY3ItkVjBfIYMkZAmD56EBGR9YN/MRN7b2K3zrK6Qmk=";
  };

  # setup.py derives its version by calling
  # find_version("gpytorch", "version.py"); replace that call with the
  # literal derivation version (the shell variable $version is exported
  # from the `version` attribute above).
  # NOTE(review): presumably done because the dynamic lookup does not
  # work in the sandboxed source tree — confirm against upstream setup.py.
  postPatch = ''
    substituteInPlace setup.py \
      --replace 'find_version("gpytorch", "version.py")' \"$version\"
  '';

  propagatedBuildInputs = [
    linear_operator
    scikit-learn
    torch
  ];

  # pytestCheckHook is a setup hook that runs on the build platform, so it
  # belongs in nativeCheckInputs (checkInputs is for host-platform
  # libraries and is added to buildInputs).
  nativeCheckInputs = [
    pytestCheckHook
  ];

  pythonImportsCheck = [ "gpytorch" ];

  disabledTests = [
    # AssertionError on number of warnings emitted
    "test_deprecated_methods"
    # flaky numerical tests
    "test_classification_error"
    "test_matmul_matrix_broadcast"
  ];

  meta = with lib; {
    description = "A highly efficient and modular implementation of Gaussian Processes, with GPU acceleration";
    homepage = "https://gpytorch.ai";
    license = licenses.mit;
    maintainers = with maintainers; [ veprbl ];
  };
}