nixpkgs mirror (for testing)
github.com/NixOS/nixpkgs
nix
{
  lib,
  stdenv,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  setuptools,
  setuptools-scm,

  # dependencies
  jaxtyping,
  linear-operator,
  mpmath,
  scikit-learn,
  scipy,
  torch,

  # tests
  pytestCheckHook,
}:

let
  # Host-platform predicates used to gate the test files that hang (see
  # disabledTestPaths below).
  inherit (stdenv.hostPlatform) isDarwin isx86_64;

  # Tests whose numerical results are flaky; they are spliced into
  # disabledTests in their original position.
  flakyNumericalTests = [
    "test_classification_error"
    "test_matmul_matrix_broadcast"
    "test_optimization_optimal_error"
  ];
in
buildPythonPackage (finalAttrs: {
  pname = "gpytorch";
  version = "1.15.1";
  pyproject = true;

  # The upstream release tag is the version prefixed with "v".
  src = fetchFromGitHub {
    owner = "cornellius-gp";
    repo = "gpytorch";
    tag = "v${finalAttrs.version}";
    hash = "sha256-ftiAY02K0EwVQZufk8xR+/21A+2ONWchuWPF3a5lRW0=";
  };

  # The NumPy in use no longer provides `np.trapz` (importing code that calls
  # it fails with "AttributeError: module 'numpy' has no attribute 'trapz'"),
  # so rewrite the `np.trapz(emp_spect, freq)` call in the spectral mixture
  # kernel to use `np.trapezoid`. --replace-fail makes the build stop if the
  # pattern disappears upstream, so this patch is noticed and dropped.
  postPatch = ''
    substituteInPlace gpytorch/kernels/spectral_mixture_kernel.py \
      --replace-fail \
      "np.trapz(emp_spect, freq)" \
      "np.trapezoid(emp_spect, freq)"
  '';

  build-system = [
    setuptools
    setuptools-scm
  ];

  dependencies = [
    jaxtyping
    linear-operator
    mpmath
    scikit-learn
    scipy
    torch
  ];

  nativeCheckInputs = [ pytestCheckHook ];

  # Smoke test: the top-level module must import after installation.
  pythonImportsCheck = [ "gpytorch" ];

  disabledTests = [
    # Fails with an AssertionError on the number of warnings emitted
    "test_deprecated_methods"
  ]
  ++ flakyNumericalTests
  ++ [
    # https://github.com/cornellius-gp/gpytorch/issues/2396
    "test_t_matmul_matrix"
  ];

  # These test files hang indefinitely on x86_64-darwin only; elsewhere the
  # list is empty and every test path runs.
  disabledTestPaths = lib.optionals (isDarwin && isx86_64) [
    "test/examples/test_spectral_mixture_gp_regression.py"
    "test/kernels/test_spectral_mixture_kernel.py"
    "test/utils/test_nearest_neighbors.py"
    "test/variational/test_nearest_neighbor_variational_strategy.py"
  ];

  meta = {
    description = "Highly efficient and modular implementation of Gaussian Processes, with GPU acceleration";
    homepage = "https://gpytorch.ai";
    downloadPage = "https://github.com/cornellius-gp/gpytorch";
    changelog = "https://github.com/cornellius-gp/gpytorch/releases/tag/${finalAttrs.src.tag}";
    license = lib.licenses.mit;
    maintainers = with lib.maintainers; [ veprbl ];
  };
})