# Package definition for torchmetrics, built with buildPythonPackage from a
# tagged GitHub source checkout.
#
# Every argument below is expected to be supplied by the caller; `torchmetrics`
# itself is taken as an argument so that the finished package can be
# re-evaluated with its test suite enabled (see `passthru` further down).
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,
  cloudpickle,
  scikit-learn,
  scikitimage,
  packaging,
  psutil,
  py-deprecate,
  torch,
  pytestCheckHook,
  torchmetrics,
  pytorch-lightning,
}:

buildPythonPackage rec {
  pname = "torchmetrics";
  version = "0.9.3";

  # Sources are fetched from the upstream release tag matching `version`.
  src = fetchFromGitHub {
    owner = "PyTorchLightning";
    repo = "metrics";
    rev = "refs/tags/v${version}";
    hash = "sha256-L2p8UftRkuBuRJX4V5+OYkJeJ5pCK3MvfA1OvSfgglY=";
  };

  # torch is deliberately kept out of the propagated inputs: whoever consumes
  # this package is expected to provide their own torch instance.
  buildInputs = [ torch ];

  propagatedBuildInputs = [
    packaging
    py-deprecate
  ];

  # The test suite is off by default, because integrations/test_lightning.py
  # introduces a cyclic dependency (pytorch-lightning is one of the check
  # inputs below).
  doCheck = false;

  # The tests remain reachable through passthru.tests.check, which overrides
  # this same package with doCheck switched back on.
  passthru = {
    tests = {
      check = torchmetrics.overridePythonAttrs (previousAttrs: {
        doCheck = true;
      });
    };
  };

  checkInputs = [
    pytorch-lightning
    scikit-learn
    scikitimage
    cloudpickle
    psutil
    pytestCheckHook
  ];

  disabledTestPaths = [
    # Skipped because they would pull in too many "leftpad-level" dependencies.
    "tests/text"
    "tests/audio"
    "tests/image"

    # Skipped because it contains non-deterministic tests, for example
    # test_check_compute_groups_is_faster.
    "tests/bases/test_collections.py"
  ];

  pythonImportsCheck = [ "torchmetrics" ];

  meta = {
    description = "Machine learning metrics for distributed, scalable PyTorch applications (used in pytorch-lightning)";
    homepage = "https://torchmetrics.readthedocs.io";
    license = lib.licenses.asl20;
    maintainers = [ lib.maintainers.SomeoneSerge ];
  };
}
78