# Nix expression for the `pytorch-metric-learning` Python package.
#
# The source is fetched from GitHub at the tag `v<version>` and built with
# `buildPythonPackage` in pyproject mode using setuptools. The upstream test
# suite is run through pytest on the CPU only (see `preCheck`), with tests
# that need network access, or that crash when CUDA support is enabled,
# deselected.
{
  # NOTE(review): `stdenv` is not referenced anywhere in this expression; it is
  # kept in the argument set so existing `.override` calls keep working.
  stdenv,
  lib,
  buildPythonPackage,
  fetchFromGitHub,
  isPy27,
  config,

  # build-system
  setuptools,

  # dependencies
  numpy,
  scikit-learn,
  torch,
  tqdm,

  # optional-dependencies
  faiss,
  tensorboard,

  # tests
  cudaSupport ? config.cudaSupport,
  pytestCheckHook,
  torchvision,
}:

buildPythonPackage rec {
  pname = "pytorch-metric-learning";
  version = "2.8.1";
  pyproject = true;

  disabled = isPy27;

  src = fetchFromGitHub {
    owner = "KevinMusgrave";
    # Spelled out instead of reusing `pname`, so that overriding or renaming
    # `pname` cannot silently change which repository is fetched.
    repo = "pytorch-metric-learning";
    tag = "v${version}";
    hash = "sha256-WO/gv8rKkxY3pR627WrEPVyvZnvUZIKMzOierIW8bJA=";
  };

  build-system = [
    setuptools
  ];

  dependencies = [
    numpy
    torch
    scikit-learn
    tqdm
  ];

  # Both extras currently resolve to the same inputs in this expression.
  optional-dependencies = {
    with-hooks = [
      # TODO: record-keeper
      faiss
      tensorboard
    ];
    with-hooks-cpu = [
      # TODO: record-keeper
      faiss
      tensorboard
    ];
  };

  # Point HOME at the build's temporary directory and restrict the test suite
  # to the CPU device with full-/double-precision dtypes.
  preCheck = ''
    export HOME=$TMP
    export TEST_DEVICE=cpu
    export TEST_DTYPES=float32,float64 # half-precision tests fail on CPU
  '';

  # package only requires `unittest`, but use `pytest` to exclude tests
  nativeCheckInputs = [
    pytestCheckHook
    torchvision
  ]
  # Pull in every optional dependency so the tests can exercise the extras too.
  ++ lib.flatten (lib.attrValues optional-dependencies);

  disabledTests = [
    # network access
    "test_tuplestoweights_sampler"
    "test_metric_loss_only"
    "test_add_to_indexer"
    "test_get_nearest_neighbors"
    "test_list_of_text"
    "test_untrained_indexer"
  ]
  ++ lib.optionals cudaSupport [
    # crashes with SIGABRT
    "test_accuracy_calculator_and_faiss_with_torch_and_numpy"
    "test_accuracy_calculator_large_k"
    "test_custom_knn"
    "test_global_embedding_space_tester"
    "test_global_two_stream_embedding_space_tester"
    "test_index_type"
    "test_k_warning"
    "test_many_tied_distances"
    "test_query_within_reference"
    "test_tied_distances"
    "test_with_same_parent_label_tester"
  ];

  meta = {
    description = "Metric learning library for PyTorch";
    homepage = "https://github.com/KevinMusgrave/pytorch-metric-learning";
    changelog = "https://github.com/KevinMusgrave/pytorch-metric-learning/releases/tag/v${version}";
    license = lib.licenses.mit;
    maintainers = with lib.maintainers; [ bcdarwin ];
  };
}