{ lib
, buildPythonPackage
, fetchFromGitHub
, isPy27
, numpy
, scikit-learn
, pytestCheckHook
, pytorch
, torchvision
, tqdm
}:

buildPythonPackage rec {
  pname = "pytorch-metric-learning";
  version = "0.9.99";

  # This derivation is not offered for Python 2.7 interpreters.
  disabled = isPy27;

  src = fetchFromGitHub {
    owner = "KevinMusgrave";
    repo = pname;
    rev = "v${version}";
    sha256 = "1ahs2b7q3hxi6yv4g2fjy7mvl899h56dvlpc2r9301440qsgkdzr";
  };

  propagatedBuildInputs = [
    numpy
    pytorch
    scikit-learn
    torchvision
    tqdm
  ];

  # Point HOME at the build's temp dir (the sandbox HOME is presumably not
  # writable), and configure the test suite via environment variables to run
  # on the CPU with single/double precision only.
  preCheck = ''
    export HOME=$TMP
    export TEST_DEVICE=cpu
    export TEST_DTYPES=float32,float64 # half-precision tests fail on CPU
  '';

  # The package only requires `unittest`, but pytest is used so that
  # individual tests can be excluded via `disabledTests`.
  checkInputs = [ pytestCheckHook ];
  disabledTests = [
    # require FAISS (not in Nixpkgs)
    "test_accuracy_calculator_and_faiss"
    "test_global_embedding_space_tester"
    "test_with_same_parent_label_tester"
    # require network access
    "test_get_nearest_neighbors"
    "test_tuplestoweights_sampler"
    "test_untrained_indexer"
    "test_metric_loss_only"
    "test_pca"
    # flaky
    "test_distributed_classifier_loss_and_miner"
  ];

  # Smoke test that the installed module can be imported, independent of
  # which unit tests are disabled above.
  pythonImportsCheck = [ "pytorch_metric_learning" ];

  meta = {
    description = "Metric learning library for PyTorch";
    homepage = "https://github.com/KevinMusgrave/pytorch-metric-learning";
    changelog = "https://github.com/KevinMusgrave/pytorch-metric-learning/releases/tag/v${version}";
    license = lib.licenses.mit;
    maintainers = with lib.maintainers; [ bcdarwin ];
  };
}