{
  buildPythonPackage,
  fetchFromGitHub,
  lib,
  libjpeg_turbo,
  libpng,
  ninja,
  numpy,
  pillow,
  pytest,
  scipy,
  torch,
  which,
}:

let
  # Take the CUDA configuration (on/off switch, target capabilities and CUDA
  # package set) from the torch we build against, so both agree.
  inherit (torch) cudaCapabilities cudaPackages cudaSupport;
  inherit (cudaPackages) backendStdenv;

  pname = "torchvision";
  version = "0.18.0";
in
buildPythonPackage {
  inherit pname version;

  src = fetchFromGitHub {
    owner = "pytorch";
    repo = "vision";
    rev = "refs/tags/v${version}";
    hash = "sha256-VWbalbLSV5a+t9eAO7QzQ/e11KkhGg6MHgd5vXcAUXc=";
  };

  nativeBuildInputs = [
    # libpng is listed here in addition to buildInputs, presumably so that its
    # `libpng-config` helper is on PATH for the build script.
    # NOTE(review): confirm against torchvision's setup.py before removing.
    libpng
    ninja
    which
  ] ++ lib.optionals cudaSupport [ cudaPackages.cuda_nvcc ];

  buildInputs = [
    libjpeg_turbo
    libpng
    torch.cxxdev
  ];

  propagatedBuildInputs = [
    numpy
    pillow
    torch
    scipy
  ];

  preConfigure =
    # Point torchvision's build at libjpeg-turbo's headers and libraries.
    ''
      export TORCHVISION_INCLUDE="${libjpeg_turbo.dev}/include/"
      export TORCHVISION_LIBRARY="${libjpeg_turbo}/lib/"
    ''
    # NOTE: We essentially override the compilers provided by stdenv because we don't have a hook
    # for cudaPackages to swap in compilers supported by NVCC.
    + lib.optionalString cudaSupport ''
      export CC=${backendStdenv.cc}/bin/cc
      export CXX=${backendStdenv.cc}/bin/c++
      export TORCH_CUDA_ARCH_LIST="${lib.concatStringsSep ";" cudaCapabilities}"
      export FORCE_CUDA=1
    '';

  # The test suite tries to download many datasets, which needs network access.
  doCheck = false;

  pythonImportsCheck = [ "torchvision" ];

  # Kept for manual runs (e.g. via overrideAttrs with doCheck = true). The
  # dataset-download tests are always skipped. runHook keeps user-supplied
  # preCheck/postCheck hooks working.
  checkPhase = ''
    runHook preCheck
    HOME=$TMPDIR pytest test --ignore=test/test_datasets_download.py
    runHook postCheck
  '';

  nativeCheckInputs = [ pytest ];

  meta = {
    description = "PyTorch vision library";
    homepage = "https://pytorch.org/";
    changelog = "https://github.com/pytorch/vision/releases/tag/v${version}";
    license = lib.licenses.bsd3;
    # Darwin is only offered when CUDA support is disabled.
    platforms = lib.platforms.linux ++ lib.optionals (!cudaSupport) lib.platforms.darwin;
    maintainers = with lib.maintainers; [ ericsagnes ];
  };
}