1{ buildPythonPackage
2, fetchFromGitHub
3, lib
4, libjpeg_turbo
5, libpng
6, ninja
7, numpy
8, pillow
9, pytest
10, scipy
11, symlinkJoin
12, torch
13, which
14}:
15
let
  # CUDA configuration is taken from the torch this package is built against,
  # so both packages agree on CUDA support, package set and capabilities.
  inherit (torch) cudaCapabilities cudaPackages cudaSupport;
  inherit (cudaPackages) backendStdenv cudaVersion;

  # Merge a list of CUDA redistributables into one symlinked tree named
  # "<prefix>-<cudaVersion>".
  mkCudaJoin = prefix: paths: symlinkJoin {
    name = "${prefix}-${cudaVersion}";
    inherit paths;
  };

  # CUDA libraries shared by the build-time and host trees below.
  # cuDNN is intentionally absent: torchvision itself does not use it
  # (only torch does).
  cuda-common-redist = [
    cudaPackages.cuda_cccl # <thrust/*>
    cudaPackages.libcublas # cublas_v2.h
    cudaPackages.libcusolver # cusolverDn.h
    cudaPackages.libcusparse # cusparse.h
  ];

  # Tree used in nativeBuildInputs: CUDA runtime + nvcc, followed by the
  # shared libraries, all joined into a single tree.
  cuda-native-redist = mkCudaJoin "cuda-native-redist" (
    [
      cudaPackages.cuda_cudart # cuda_runtime.h
      cudaPackages.cuda_nvcc
    ]
    ++ cuda-common-redist
  );

  # Tree used in buildInputs: only the shared libraries.
  cuda-redist = mkCudaJoin "cuda-redist" cuda-common-redist;

  pname = "torchvision";
  version = "0.15.2";
in
buildPythonPackage {
  inherit pname version;

  src = fetchFromGitHub {
    owner = "pytorch";
    repo = "vision";
    rev = "refs/tags/v${version}";
    hash = "sha256-KNbOgd6PCINZqZ24c/Ev+ODux3ik5iUlzem9uUfQArM=";
  };

  # libpng appears here as well as in buildInputs; presumably the build needs
  # its tooling on PATH at build time (e.g. libpng-config) — TODO confirm.
  nativeBuildInputs = [ libpng ninja which ] ++ lib.optionals cudaSupport [ cuda-native-redist ];

  buildInputs = [ libjpeg_turbo libpng ] ++ lib.optionals cudaSupport [ cuda-redist ];

  propagatedBuildInputs = [ numpy pillow torch scipy ];

  # Export libjpeg-turbo's include and lib directories; presumably consumed by
  # torchvision's setup.py to locate the JPEG library — verify upstream.
  preConfigure = ''
    export TORCHVISION_INCLUDE="${libjpeg_turbo.dev}/include/"
    export TORCHVISION_LIBRARY="${libjpeg_turbo}/lib/"
  ''
  # NOTE: We essentially override the compilers provided by stdenv because we don't have a hook
  # for cudaPackages to swap in compilers supported by NVCC.
  # TORCH_CUDA_ARCH_LIST is the ";"-separated list of torch's cudaCapabilities.
  + lib.optionalString cudaSupport ''
    export CC=${backendStdenv.cc}/bin/cc
    export CXX=${backendStdenv.cc}/bin/c++
    export TORCH_CUDA_ARCH_LIST="${lib.concatStringsSep ";" cudaCapabilities}"
    export FORCE_CUDA=1
  '';

  # tries to download many datasets for tests
  doCheck = false;

  pythonImportsCheck = [ "torchvision" ];

  # Used only when doCheck is enabled. Uses the `pytest` entry point rather
  # than the legacy `py.test` alias, and skips the dataset-download tests.
  checkPhase = ''
    HOME=$TMPDIR pytest test --ignore=test/test_datasets_download.py
  '';

  nativeCheckInputs = [ pytest ];

  meta = with lib; {
    description = "PyTorch vision library";
    homepage = "https://pytorch.org/";
    # `version` here is the let-bound package version (lexical bindings take
    # precedence over `with lib;`).
    changelog = "https://github.com/pytorch/vision/releases/tag/v${version}";
    license = licenses.bsd3;
    # Darwin is only offered for non-CUDA builds.
    platforms = with platforms; linux ++ lib.optionals (!cudaSupport) darwin;
    maintainers = with maintainers; [ ericsagnes ];
  };
}