# Nix expression for the `pytorch-metric-learning` Python package.
{
  stdenv,
  lib,
  buildPythonPackage,
  fetchFromGitHub,
  isPy27,
  config,

  # build-system
  setuptools,

  # dependencies
  numpy,
  scikit-learn,
  torch,
  tqdm,

  # optional-dependencies
  faiss,
  tensorboard,

  # tests
  # Only used to select extra test exclusions below; defaults to the global nixpkgs setting.
  cudaSupport ? config.cudaSupport,
  pytestCheckHook,
  torchvision,
}:

buildPythonPackage rec {
  pname = "pytorch-metric-learning";
  version = "2.8.1";
  pyproject = true;

  disabled = isPy27;

  src = fetchFromGitHub {
    owner = "KevinMusgrave";
    # Spelled out instead of `repo = pname;`. Otherwise overriding `pname`
    # would silently change which repository is fetched and invalidate `hash`.
    repo = "pytorch-metric-learning";
    tag = "v${version}";
    hash = "sha256-WO/gv8rKkxY3pR627WrEPVyvZnvUZIKMzOierIW8bJA=";
  };

  build-system = [
    setuptools
  ];

  dependencies = [
    numpy
    torch
    scikit-learn
    tqdm
  ];

  # Both extras map to the same nixpkgs packages: nixpkgs has a single `faiss`
  # attribute rather than separate CPU/GPU variants.
  optional-dependencies = {
    with-hooks = [
      # TODO: record-keeper
      faiss
      tensorboard
    ];
    with-hooks-cpu = [
      # TODO: record-keeper
      faiss
      tensorboard
    ];
  };

  # NOTE(review): HOME is presumably redirected because the tests or their
  # dependencies write under $HOME, which is not writable in the sandbox.
  # TEST_DEVICE / TEST_DTYPES are assumed to be read by the upstream test suite
  # to pick the device and dtypes to exercise. Confirm against upstream tests.
  preCheck = ''
    export HOME=$TMP
    export TEST_DEVICE=cpu
    export TEST_DTYPES=float32,float64 # half-precision tests fail on CPU
  '';

  # The package only requires `unittest`, but `pytest` is used so that
  # individual tests can be excluded via `disabledTests`.
  # All optional dependencies are added so the hook-related tests can run.
  nativeCheckInputs = [
    pytestCheckHook
    torchvision
  ] ++ lib.flatten (lib.attrValues optional-dependencies);

  disabledTests =
    [
      # network access
      "test_tuplestoweights_sampler"
      "test_metric_loss_only"
      "test_add_to_indexer"
      "test_get_nearest_neighbors"
      "test_list_of_text"
      "test_untrained_indexer"
    ]
    ++ lib.optionals cudaSupport [
      # crashes with SIGABRT
      "test_accuracy_calculator_and_faiss_with_torch_and_numpy"
      "test_accuracy_calculator_large_k"
      "test_custom_knn"
      "test_global_embedding_space_tester"
      "test_global_two_stream_embedding_space_tester"
      "test_index_type"
      "test_k_warning"
      "test_many_tied_distances"
      "test_query_within_reference"
      "test_tied_distances"
      "test_with_same_parent_label_tester"
    ];

  meta = {
    description = "Metric learning library for PyTorch";
    homepage = "https://github.com/KevinMusgrave/pytorch-metric-learning";
    # Reuses the fetched tag so the link cannot drift from `src`.
    changelog = "https://github.com/KevinMusgrave/pytorch-metric-learning/releases/tag/${src.tag}";
    license = lib.licenses.mit;
    maintainers = with lib.maintainers; [ bcdarwin ];
  };
}