# Nix expression for the `pytorch-metric-learning` Python package, fetched from
# the upstream GitHub release tag and checked with pytest on the CPU.
{ stdenv
, lib
, buildPythonPackage
, fetchFromGitHub
, isPy27
, numpy
, scikit-learn
, pytestCheckHook
, torch
, torchvision
, tqdm
, faiss
}:

buildPythonPackage rec {
  pname = "pytorch-metric-learning";
  version = "2.1.2";

  disabled = isPy27;

  src = fetchFromGitHub {
    owner = "KevinMusgrave";
    # Spelled out instead of `repo = pname;` so that overriding `pname`
    # does not silently change which repository is fetched.
    repo = "pytorch-metric-learning";
    rev = "refs/tags/v${version}";
    hash = "sha256-B2gDPOQSJGg29xjJbZsWUrSalBnn+S8h+2j8NQ4tfTM=";
  };

  propagatedBuildInputs = [
    numpy
    torch
    scikit-learn
    torchvision
    tqdm
  ];

  preCheck = ''
    # Point HOME at the build's temporary directory (the sandbox HOME is not
    # writable; presumably something in the test run writes under $HOME).
    export HOME="$TMPDIR"
    # Select the device and dtypes the upstream test suite runs with.
    export TEST_DEVICE=cpu
    export TEST_DTYPES=float32,float64 # half-precision tests fail on CPU
  '';

  # The package only requires `unittest`, but pytest is used so that
  # individual tests can be excluded via `disabledTests` below.
  nativeCheckInputs = [
    faiss
    pytestCheckHook
  ];

  # Cheap smoke test that the installed package imports, independent of the
  # (partially disabled) test suite.
  pythonImportsCheck = [ "pytorch_metric_learning" ];

  disabledTests = [
    # TypeError: setup() missing 1 required positional argument: 'world_size'
    "TestDistributedLossWrapper"
    # require network access
    "TestInference"
    "test_get_nearest_neighbors"
    "test_tuplestoweights_sampler"
    "test_untrained_indexer"
    "test_metric_loss_only"
    "test_pca"
    # flaky
    "test_distributed_classifier_loss_and_miner"
  ] ++ lib.optionals (stdenv.hostPlatform.isLinux && stdenv.hostPlatform.isAarch64) [
    # RuntimeError: DataLoader worker (pid(s) <...>) exited unexpectedly
    "test_global_embedding_space_tester"
    "test_with_same_parent_label_tester"
  ];

  meta = {
    description = "Metric learning library for PyTorch";
    homepage = "https://github.com/KevinMusgrave/pytorch-metric-learning";
    changelog = "https://github.com/KevinMusgrave/pytorch-metric-learning/releases/tag/v${version}";
    license = lib.licenses.mit;
    maintainers = with lib.maintainers; [ bcdarwin ];
  };
}