{
  lib,
  stdenv,
  torch,
  apple-sdk_13,
  buildPythonPackage,
  darwinMinVersionHook,
  fetchFromGitHub,

  # nativeBuildInputs
  libpng,
  ninja,
  which,

  # buildInputs
  libjpeg_turbo,

  # dependencies
  numpy,
  pillow,
  scipy,

  # tests
  pytest,
  writableTmpDirAsHomeHook,
}:

# torchvision is built against the specific `torch` passed in: its CUDA
# settings (capabilities, package set, on/off switch) are taken from torch
# rather than configured independently.
let
  inherit (torch) cudaCapabilities cudaPackages cudaSupport;

  pname = "torchvision";
  version = "0.21.0";
in
buildPythonPackage {
  inherit pname version;

  # Compile the C++/CUDA extensions with the same stdenv (toolchain) as torch.
  stdenv = torch.stdenv;

  src = fetchFromGitHub {
    owner = "pytorch";
    repo = "vision";
    tag = "v${version}";
    hash = "sha256-eDWw1Lt/sUc2Xt6cqOM5xaOfmsm+NEL5lZO+cIJKMtU=";
  };

  nativeBuildInputs = [
    libpng
    ninja
    which
  ] ++ lib.optionals cudaSupport [ cudaPackages.cuda_nvcc ];

  buildInputs =
    [
      libjpeg_turbo
      libpng
      # torch's C++ development output (headers etc.) needed to build the extensions.
      torch.cxxdev
    ]
    ++ lib.optionals stdenv.hostPlatform.isDarwin [
      # This should match the SDK used by `torch` above
      apple-sdk_13

      # error: unknown type name 'MPSGraphCompilationDescriptor'; did you mean 'MPSGraphExecutionDescriptor'?
      # https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraphcompilationdescriptor/
      (darwinMinVersionHook "12.0")
    ];

  dependencies = [
    numpy
    pillow
    torch
    scipy
  ];

  # Extra include/library search paths for libjpeg-turbo; presumably read by
  # torchvision's setup.py (TODO confirm against upstream setup.py).
  # With CUDA enabled, also restrict codegen to torch's CUDA capabilities and
  # set FORCE_CUDA=1 (presumably so CUDA kernels are built even though no GPU
  # is visible inside the build sandbox).
  preConfigure =
    ''
      export TORCHVISION_INCLUDE="${libjpeg_turbo.dev}/include/"
      export TORCHVISION_LIBRARY="${libjpeg_turbo}/lib/"
    ''
    + lib.optionalString cudaSupport ''
      export TORCH_CUDA_ARCH_LIST="${lib.concatStringsSep ";" cudaCapabilities}"
      export FORCE_CUDA=1
    '';

  # tests download big datasets, models, require internet connection, etc.
  doCheck = false;

  pythonImportsCheck = [ "torchvision" ];

  nativeCheckInputs = [
    pytest
    writableTmpDirAsHomeHook
  ];

  # Only used when doCheck is re-enabled (e.g. via overridePythonAttrs).
  # runHook keeps preCheck/postCheck hooks working for this overridden phase.
  checkPhase = ''
    runHook preCheck
    pytest test --ignore=test/test_datasets_download.py
    runHook postCheck
  '';

  meta = {
    description = "PyTorch vision library";
    homepage = "https://pytorch.org/";
    changelog = "https://github.com/pytorch/vision/releases/tag/v${version}";
    license = lib.licenses.bsd3;
    # Linux always; Darwin only when CUDA support is disabled.
    platforms = with lib.platforms; linux ++ lib.optionals (!cudaSupport) darwin;
    maintainers = with lib.maintainers; [ ];
  };
}