{ lib
, buildPythonPackage
, fetchFromGitHub
, isPy27
, numpy
, scikit-learn
, pytestCheckHook
, torch
, torchvision
, tqdm
, faiss
}:

# Nix package for pytorch-metric-learning, a metric learning library for PyTorch.
buildPythonPackage rec {
  pname = "pytorch-metric-learning";
  version = "1.6.2";

  # Not built for Python 2.7.
  disabled = isPy27;

  src = fetchFromGitHub {
    owner = "KevinMusgrave";
    repo = pname;
    # Use the fully qualified tag ref so it cannot be confused with a branch of the same name.
    rev = "refs/tags/v${version}";
    # SRI-formatted hashes ("sha256-…") go in `hash`; `sha256` is the legacy attribute.
    hash = "sha256-y/KqMqxSzTGsjwtbhHbFK+S4CX6yHC6tR6jdPWUzeGg=";
  };

  propagatedBuildInputs = [
    numpy
    torch
    scikit-learn
    torchvision
    tqdm
  ];

  preCheck = ''
    # Point HOME at a writable temporary directory for the test run.
    export HOME=$TMP
    # Test-suite configuration: run on CPU only.
    export TEST_DEVICE=cpu
    # Restrict to float32/float64; half-precision tests fail on CPU.
    export TEST_DTYPES=float32,float64
  '';

  # The upstream suite only needs `unittest`. pytestCheckHook is used instead so that
  # individual tests can be excluded through `disabledTests` below.
  # faiss is needed only for the tests (presumably the indexer/inference tests — TODO confirm).
  checkInputs = [
    faiss
    pytestCheckHook
  ];

  disabledTests = [
    # TypeError: setup() missing 1 required positional argument: 'world_size'
    "TestDistributedLossWrapper"
    # These require network access:
    "TestInference"
    "test_get_nearest_neighbors"
    "test_tuplestoweights_sampler"
    "test_untrained_indexer"
    "test_metric_loss_only"
    "test_pca"
    # Flaky:
    "test_distributed_classifier_loss_and_miner"
  ];

  meta = {
    description = "Metric learning library for PyTorch";
    homepage = "https://github.com/KevinMusgrave/pytorch-metric-learning";
    changelog = "https://github.com/KevinMusgrave/pytorch-metric-learning/releases/tag/v${version}";
    license = lib.licenses.mit;
    maintainers = with lib.maintainers; [ bcdarwin ];
  };
}