Clone of https://github.com/NixOS/nixpkgs.git (to stress-test knotserver)
at litex 2.4 kB view raw
# torchvision: the PyTorch vision library, packaged for nixpkgs.
#
# All CUDA configuration (enable flag, package set, target capabilities) is
# read from the `torch` argument, so this package is built with the same CUDA
# setup as the torch it propagates.
{ buildPythonPackage
, fetchFromGitHub
, lib
, libjpeg_turbo
, libpng
, ninja
, numpy
, pillow
, pytest
, scipy
, symlinkJoin
, torch
, which
}:

let
  pname = "torchvision";
  version = "0.15.2";

  inherit (torch) cudaCapabilities cudaPackages cudaSupport;
  inherit (cudaPackages) backendStdenv cudaVersion;

  # CUDA components needed by both joins below. cuDNN is intentionally not
  # listed: it is a dependency of torch, not of torchvision.
  cudaCommonLibs = [
    cudaPackages.cuda_cccl    # <thrust/*>
    cudaPackages.libcublas    # cublas_v2.h
    cudaPackages.libcusolver  # cusolverDn.h
    cudaPackages.libcusparse  # cusparse.h
  ];

  # Toolchain-side join (used in nativeBuildInputs): the CUDA runtime
  # headers and nvcc, followed by the common components.
  cudaNativeJoin = symlinkJoin {
    name = "cuda-native-redist-${cudaVersion}";
    paths = [
      cudaPackages.cuda_cudart  # cuda_runtime.h
      cudaPackages.cuda_nvcc
    ] ++ cudaCommonLibs;
  };

  # Join of only the common components (used in buildInputs).
  cudaHostJoin = symlinkJoin {
    name = "cuda-redist-${cudaVersion}";
    paths = cudaCommonLibs;
  };

  # Tell torchvision's build where libjpeg-turbo's headers and libraries live.
  jpegEnv = ''
    export TORCHVISION_INCLUDE="${libjpeg_turbo.dev}/include/"
    export TORCHVISION_LIBRARY="${libjpeg_turbo}/lib/"
  '';

  # CUDA-only additions. The stdenv compilers are overridden by hand because
  # there is no hook in cudaPackages that swaps in an NVCC-compatible
  # compiler; TORCH_CUDA_ARCH_LIST is the ";"-joined cudaCapabilities list.
  cudaEnv = lib.optionalString cudaSupport ''
    export CC=${backendStdenv.cc}/bin/cc
    export CXX=${backendStdenv.cc}/bin/c++
    export TORCH_CUDA_ARCH_LIST="${lib.concatStringsSep ";" cudaCapabilities}"
    export FORCE_CUDA=1
  '';
in
buildPythonPackage {
  inherit pname version;

  src = fetchFromGitHub {
    owner = "pytorch";
    repo = "vision";
    rev = "refs/tags/v${version}";
    hash = "sha256-KNbOgd6PCINZqZ24c/Ev+ODux3ik5iUlzem9uUfQArM=";
  };

  nativeBuildInputs = [ libpng ninja which ]
    ++ lib.optional cudaSupport cudaNativeJoin;

  buildInputs = [ libjpeg_turbo libpng ]
    ++ lib.optional cudaSupport cudaHostJoin;

  propagatedBuildInputs = [ numpy pillow torch scipy ];

  preConfigure = jpegEnv + cudaEnv;

  # Disabled: per the original note, the test suite tries to download many
  # datasets. The checkPhase below only runs if doCheck is re-enabled.
  doCheck = false;

  pythonImportsCheck = [ "torchvision" ];

  checkPhase = ''
    HOME=$TMPDIR py.test test --ignore=test/test_datasets_download.py
  '';

  nativeCheckInputs = [ pytest ];

  meta = {
    description = "PyTorch vision library";
    homepage = "https://pytorch.org/";
    license = lib.licenses.bsd3;
    # Darwin is only offered for non-CUDA builds.
    platforms = lib.platforms.linux ++ lib.optionals (!cudaSupport) lib.platforms.darwin;
    maintainers = [ lib.maintainers.ericsagnes ];
  };
}