# Builds the `pytorch-metric-learning` Python package from the upstream GitHub
# release tag `v${version}` and runs its test suite on CPU through pytest.
{
  stdenv,
  lib,
  buildPythonPackage,
  fetchFromGitHub,
  isPy27,
  numpy,
  scikit-learn,
  pytestCheckHook,
  torch,
  torchvision,
  tqdm,
  faiss,
}:

buildPythonPackage rec {
  pname = "pytorch-metric-learning";
  version = "2.5.0";
  format = "setuptools";

  disabled = isPy27;

  src = fetchFromGitHub {
    owner = "KevinMusgrave";
    # Spelled out instead of `pname` so renaming the package attribute cannot
    # silently change which repository is fetched.
    repo = "pytorch-metric-learning";
    rev = "refs/tags/v${version}";
    hash = "sha256-1y7VCnzgwFOMeMloVdYyszNhf/zZlBJUjuF4qgA5c0A=";
  };

  propagatedBuildInputs = [
    numpy
    torch
    scikit-learn
    torchvision
    tqdm
  ];

  preCheck = ''
    # The sandbox has no writable home directory; point HOME at the build temp dir.
    export HOME="$TMPDIR"
    export TEST_DEVICE=cpu
    export TEST_DTYPES=float32,float64 # half-precision tests fail on CPU
  '';

  # The package only requires `unittest`, but pytest is used so that
  # individual tests can be excluded via `disabledTests`.
  nativeCheckInputs = [
    faiss
    pytestCheckHook
  ];

  disabledTests =
    [
      # TypeError: setup() missing 1 required positional argument: 'world_size'
      "TestDistributedLossWrapper"
      # require network access:
      "TestInference"
      "test_get_nearest_neighbors"
      "test_tuplestoweights_sampler"
      "test_untrained_indexer"
      "test_metric_loss_only"
      "test_pca"
      # flaky
      "test_distributed_classifier_loss_and_miner"
    ]
    # `stdenv.isLinux` / `stdenv.isAarch64` are deprecated aliases for these.
    ++ lib.optionals (stdenv.hostPlatform.isLinux && stdenv.hostPlatform.isAarch64) [
      # RuntimeError: DataLoader worker (pid(s) <...>) exited unexpectedly
      "test_global_embedding_space_tester"
      "test_with_same_parent_label_tester"
    ];

  # Cheap smoke test that the installed module imports with its dependencies.
  pythonImportsCheck = [ "pytorch_metric_learning" ];

  meta = {
    description = "Metric learning library for PyTorch";
    homepage = "https://github.com/KevinMusgrave/pytorch-metric-learning";
    changelog = "https://github.com/KevinMusgrave/pytorch-metric-learning/releases/tag/v${version}";
    license = lib.licenses.mit;
    maintainers = with lib.maintainers; [ bcdarwin ];
  };
}