# Package expression for torchvision, built against the `torch` passed in.
#
# Arguments:
#   cudaSupport - whether to build the CUDA extensions; defaults to
#                 `torch.cudaSupport` when torch exposes it, otherwise false.
{ lib
, symlinkJoin
, buildPythonPackage
, fetchFromGitHub
, ninja
, which
, libjpeg_turbo
, libpng
, numpy
, scipy
, pillow
, torch
, pytest
, cudaSupport ? torch.cudaSupport or false # by default uses the value from torch
}:

let
  # Use the same CUDA package set as torch itself.
  inherit (torch.cudaPackages) cudatoolkit cudnn;

  # Merge the `out` and `lib` outputs of cudatoolkit into one tree.
  cudatoolkit_joined = symlinkJoin {
    name = "${cudatoolkit.name}-unsplit";
    paths = [ cudatoolkit.out cudatoolkit.lib ];
  };

  # Semicolon-separated list of CUDA architectures taken from torch, or ""
  # when CUDA is disabled.
  #
  # The parentheses are required: function application is left-associative,
  # so without them `optionalString` would receive the *function*
  # `concatStringsSep` as its string argument, and the result would then be
  # applied to ";" -- an evaluation error whenever cudaSupport is false.
  cudaArchStr = lib.optionalString cudaSupport
    (lib.strings.concatStringsSep ";" torch.cudaArchList);
in
buildPythonPackage rec {
  pname = "torchvision";
  version = "0.13.1";

  src = fetchFromGitHub {
    owner = "pytorch";
    repo = "vision";
    rev = "refs/tags/v${version}";
    hash = "sha256-QlUAFAG6zEDCDSXR5n2CznspU3fT0kbqySzofGLPgK4=";
  };

  nativeBuildInputs = [ libpng ninja which ]
    ++ lib.optionals cudaSupport [ cudatoolkit_joined ];

  # Header and library directories of libjpeg_turbo, exported to the build
  # environment (presumably read by torchvision's setup.py -- confirm upstream).
  TORCHVISION_INCLUDE = "${libjpeg_turbo.dev}/include/";
  TORCHVISION_LIBRARY = "${libjpeg_turbo}/lib/";

  buildInputs = [ libjpeg_turbo libpng ]
    ++ lib.optionals cudaSupport [ cudnn ];

  propagatedBuildInputs = [ numpy pillow torch scipy ];

  # Only set the CUDA-related variables when building with CUDA support.
  preBuild = lib.optionalString cudaSupport ''
    export TORCH_CUDA_ARCH_LIST="${cudaArchStr}"
    export FORCE_CUDA=1
  '';

  # tries to download many datasets for tests
  doCheck = false;

  # Kept for anyone who overrides doCheck; not run by default.
  checkPhase = ''
    HOME=$TMPDIR py.test test --ignore=test/test_datasets_download.py
  '';

  checkInputs = [ pytest ];

  meta = with lib; {
    description = "PyTorch vision library";
    homepage = "https://pytorch.org/";
    license = licenses.bsd3;
    # darwin is only listed when CUDA support is disabled.
    platforms = with platforms; linux ++ lib.optionals (!cudaSupport) darwin;
    maintainers = with maintainers; [ ericsagnes ];
  };
}