# Nix expression for the torchvision Python package (PyTorch vision library).
# Optional CUDA support follows the pytorch package unless overridden.
{ lib
, symlinkJoin
, buildPythonPackage
, fetchFromGitHub
, ninja
, which
, libjpeg_turbo
, libpng
, numpy
, scipy
, pillow
, pytorch
, pytest
, cudatoolkit
, cudnn
, cudaSupport ? pytorch.cudaSupport or false # by default uses the value from pytorch
}:

let
  # Merge the split cudatoolkit outputs (`out` and `lib`) into one symlinked
  # tree, presumably so the CUDA build finds everything under a single prefix.
  cudatoolkit_joined = symlinkJoin {
    name = "${cudatoolkit.name}-unsplit";
    paths = [ cudatoolkit.out cudatoolkit.lib ];
  };

  # Semicolon-separated CUDA architecture list taken from pytorch, or "" when
  # CUDA support is disabled. The parentheses matter: without them Nix parses
  # `optionalString cudaSupport concatStringsSep ";" list` left-to-right and
  # passes the function `concatStringsSep` itself as the string argument,
  # which only works by accident when cudaSupport is true and errors when forced
  # with cudaSupport false.
  cudaArchStr = lib.optionalString cudaSupport
    (lib.strings.concatStringsSep ";" pytorch.cudaArchList);
in buildPythonPackage rec {
  pname = "torchvision";
  version = "0.11.1";

  src = fetchFromGitHub {
    owner = "pytorch";
    repo = "vision";
    rev = "v${version}";
    sha256 = "05dg835mmpzf7k2jn101l7x7cnra1kldwbgf19zblym5lfn21zhf";
  };

  # libpng is listed here as well as in buildInputs, presumably so its
  # build-time helper tooling is on PATH for torchvision's setup — confirm.
  nativeBuildInputs = [ libpng ninja which ]
    ++ lib.optionals cudaSupport [ cudatoolkit_joined ];

  # Point torchvision's build at libjpeg-turbo headers and libraries
  # (presumably consumed by its setup.py — verify against upstream).
  TORCHVISION_INCLUDE = "${libjpeg_turbo.dev}/include/";
  TORCHVISION_LIBRARY = "${libjpeg_turbo}/lib/";

  buildInputs = [ libjpeg_turbo libpng ]
    ++ lib.optionals cudaSupport [ cudnn ];

  propagatedBuildInputs = [ numpy pillow pytorch scipy ];

  # Only when building with CUDA: restrict target architectures to those
  # pytorch was built for, and force the CUDA extensions to be compiled.
  preBuild = lib.optionalString cudaSupport ''
    export TORCH_CUDA_ARCH_LIST="${cudaArchStr}"
    export FORCE_CUDA=1
  '';

  # tries to download many datasets for tests
  doCheck = false;

  checkPhase = ''
    HOME=$TMPDIR py.test test --ignore=test/test_datasets_download.py
  '';

  checkInputs = [ pytest ];

  meta = with lib; {
    description = "PyTorch vision library";
    homepage = "https://pytorch.org/";
    license = licenses.bsd3;
    # Darwin is only offered for CPU-only builds.
    platforms = with platforms; linux ++ lib.optionals (!cudaSupport) darwin;
    maintainers = with maintainers; [ ericsagnes ];
  };
}