python3Packages.torchvision: added cudaSupport option (#132917)

Co-authored-by: Sandro <sandro.jaeckel@gmail.com>

authored by

Alexander Kiselyov
Sandro
and committed by
GitHub
717538e9 0d078fcd

+26 -3
+5
pkgs/development/python-modules/pytorch/default.nix
··· 301 301 # Builds in 2+h with 2 cores, and ~15m with a big-parallel builder. 302 302 requiredSystemFeatures = [ "big-parallel" ]; 303 303 304 + passthru = { 305 + inherit cudaSupport; 306 + cudaArchList = final_cudaArchList; 307 + }; 308 + 304 309 meta = with lib; { 305 310 description = "Open source, prototype-to-production deep learning platform"; 306 311 homepage = "https://pytorch.org/";
+21 -3
pkgs/development/python-modules/torchvision/default.nix
··· 1 1 { lib 2 + , symlinkJoin 2 3 , buildPythonPackage 3 4 , fetchFromGitHub 4 5 , ninja ··· 10 11 , pillow 11 12 , pytorch 12 13 , pytest 14 + , cudatoolkit 15 + , cudnn 16 + , cudaSupport ? pytorch.cudaSupport or false # by default uses the value from pytorch 13 17 }: 14 18 15 - buildPythonPackage rec { 19 + let 20 + cudatoolkit_joined = symlinkJoin { 21 + name = "${cudatoolkit.name}-unsplit"; 22 + paths = [ cudatoolkit.out cudatoolkit.lib ]; 23 + }; 24 + cudaArchStr = lib.optionalString cudaSupport lib.strings.concatStringsSep ";" pytorch.cudaArchList; 25 + in buildPythonPackage rec { 16 26 pname = "torchvision"; 17 27 version = "0.10.0"; 18 28 ··· 23 33 sha256 = "13j04ij0jmi58nhav1p69xrm8dg7jisg23268i3n6lnms37n02kc"; 24 34 }; 25 35 26 - nativeBuildInputs = [ libpng ninja which ]; 36 + nativeBuildInputs = [ libpng ninja which ] 37 + ++ lib.optionals cudaSupport [ cudatoolkit_joined ]; 27 38 28 39 TORCHVISION_INCLUDE = "${libjpeg_turbo.dev}/include/"; 29 40 TORCHVISION_LIBRARY = "${libjpeg_turbo}/lib/"; 30 41 31 - buildInputs = [ libjpeg_turbo libpng ]; 42 + buildInputs = [ libjpeg_turbo libpng ] 43 + ++ lib.optionals cudaSupport [ cudnn ]; 32 44 33 45 propagatedBuildInputs = [ numpy pillow pytorch scipy ]; 34 46 47 + preBuild = lib.optionalString cudaSupport '' 48 + export TORCH_CUDA_ARCH_LIST="${cudaArchStr}" 49 + export FORCE_CUDA=1 50 + ''; 51 + 35 52 # tries to download many datasets for tests 36 53 doCheck = false; 37 54 ··· 45 62 description = "PyTorch vision library"; 46 63 homepage = "https://pytorch.org/"; 47 64 license = licenses.bsd3; 65 + platforms = with platforms; linux ++ lib.optionals (!cudaSupport) darwin; 48 66 maintainers = with maintainers; [ ericsagnes ]; 49 67 }; 50 68 }