{
  lib,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  setuptools,

  # dependencies
  boltons,
  numpy,
  scipy,
  torch,
  trampoline,

  # tests
  pytestCheckHook,
}:

# torchsde, built from the upstream google-research GitHub release tag.
buildPythonPackage rec {
  pname = "torchsde";
  version = "0.2.4";
  # Build through the PEP 517 backend; setuptools is listed in build-system.
  pyproject = true;

  src = fetchFromGitHub {
    owner = "google-research";
    repo = "torchsde";
    rev = "v${version}";
    hash = "sha256-qQ7oswm0qTdq1xpQElt5cd3K0zskH+H/lgyEnxbCqsI=";
  };

  # Upstream setup.py pins exact minor series (numpy==1.19.*, scipy==1.5.*).
  # Drop those pins so the numpy/scipy provided here are accepted.
  # --replace-fail (rather than the deprecated --replace) aborts the build if
  # a pattern is no longer present, so a stale patch cannot pass silently.
  postPatch = ''
    substituteInPlace setup.py \
      --replace-fail "numpy==1.19.*" "numpy" \
      --replace-fail "scipy==1.5.*" "scipy"
  '';

  build-system = [ setuptools ];

  dependencies = [
    boltons
    numpy
    scipy
    torch
    trampoline
  ];

  pythonImportsCheck = [ "torchsde" ];

  nativeCheckInputs = [ pytestCheckHook ];

  disabledTests = [
    # RuntimeError: a view of a leaf Variable that requires grad is being used in an in-place operation.
    "test_adjoint"
  ];

  meta = {
    changelog = "https://github.com/google-research/torchsde/releases/tag/v${version}";
    description = "Differentiable SDE solvers with GPU support and efficient sensitivity analysis";
    homepage = "https://github.com/google-research/torchsde";
    license = lib.licenses.asl20;
    maintainers = lib.teams.tts.members;
  };
}