# torchvision: Python package built from the pytorch/vision GitHub repository.
#
# The CUDA configuration (cudaSupport, cudaCapabilities, cudaPackages) is
# inherited from the `torch` argument rather than taken as separate arguments,
# so this package always follows whatever `torch` was built with.
{
  buildPythonPackage,
  fetchFromGitHub,
  lib,
  libjpeg_turbo,
  libpng,
  ninja,
  numpy,
  pillow,
  pytest,
  scipy,
  torch,
  which,
}:

let
  inherit (torch) cudaCapabilities cudaPackages cudaSupport;
  inherit (cudaPackages) backendStdenv;

  pname = "torchvision";
  version = "0.18.0";
in
buildPythonPackage {
  inherit pname version;

  src = fetchFromGitHub {
    owner = "pytorch";
    repo = "vision";
    rev = "refs/tags/v${version}";
    hash = "sha256-VWbalbLSV5a+t9eAO7QzQ/e11KkhGg6MHgd5vXcAUXc=";
  };

  # nvcc is only added when torch itself was built with CUDA support.
  nativeBuildInputs = [
    libpng
    ninja
    which
  ] ++ lib.optionals cudaSupport [ cudaPackages.cuda_nvcc ];

  buildInputs = [
    libjpeg_turbo
    libpng
    torch.cxxdev
  ];

  propagatedBuildInputs = [
    numpy
    pillow
    torch
    scipy
  ];

  preConfigure =
    # Select the libjpeg-turbo outputs explicitly: lib.getDev / lib.getLib
    # resolve to the `dev` / `lib` (or `out`) output and fall back to the
    # default output when the package is not split, so these paths do not
    # depend on which output happens to be the package's default.
    ''
      export TORCHVISION_INCLUDE="${lib.getDev libjpeg_turbo}/include/"
      export TORCHVISION_LIBRARY="${lib.getLib libjpeg_turbo}/lib/"
    ''
    # NOTE: We essentially override the compilers provided by stdenv because we don't have a hook
    #   for cudaPackages to swap in compilers supported by NVCC.
    + lib.optionalString cudaSupport ''
      export CC=${backendStdenv.cc}/bin/cc
      export CXX=${backendStdenv.cc}/bin/c++
      export TORCH_CUDA_ARCH_LIST="${lib.concatStringsSep ";" cudaCapabilities}"
      export FORCE_CUDA=1
    '';

  # tries to download many datasets for tests
  doCheck = false;

  pythonImportsCheck = [ "torchvision" ];

  # Only takes effect if doCheck is overridden to true.
  checkPhase = ''
    HOME=$TMPDIR py.test test --ignore=test/test_datasets_download.py
  '';

  nativeCheckInputs = [ pytest ];

  meta = {
    description = "PyTorch vision library";
    homepage = "https://pytorch.org/";
    license = lib.licenses.bsd3;
    # Darwin is only listed as a supported platform for non-CUDA builds.
    platforms = lib.platforms.linux ++ lib.optionals (!cudaSupport) lib.platforms.darwin;
    maintainers = with lib.maintainers; [ ericsagnes ];
  };
}