# torchvision: the PyTorch vision library, built against the `torch` passed in.
# CUDA support follows torch's own setting unless overridden by the caller.
{ lib
, symlinkJoin
, buildPythonPackage
, fetchFromGitHub
, ninja
, which
, libjpeg_turbo
, libpng
, numpy
, scipy
, pillow
, torch
, pytest
, cudaSupport ? torch.cudaSupport or false # by default uses the value from torch
}:

let
  # Use the exact CUDA packages torch was built with, so both agree.
  inherit (torch.cudaPackages) cudatoolkit cudnn;

  # Combine cudatoolkit's `out` and `lib` outputs into a single symlinked tree.
  cudatoolkit_joined = symlinkJoin {
    name = "${cudatoolkit.name}-unsplit";
    paths = [ cudatoolkit.out cudatoolkit.lib ];
  };

  # ";"-separated list of torch's CUDA architectures, exported as
  # TORCH_CUDA_ARCH_LIST in preBuild. The parentheses are required: without
  # them `optionalString` is applied to the function `concatStringsSep`
  # itself, and when cudaSupport is false the resulting "" is then called
  # as a function, which is an evaluation error.
  cudaArchStr = lib.optionalString cudaSupport
    (lib.strings.concatStringsSep ";" torch.cudaArchList);
in
buildPythonPackage rec {
  pname = "torchvision";
  version = "0.13.1";

  src = fetchFromGitHub {
    owner = "pytorch";
    repo = "vision";
    rev = "refs/tags/v${version}";
    hash = "sha256-QlUAFAG6zEDCDSXR5n2CznspU3fT0kbqySzofGLPgK4=";
  };

  # NOTE(review): libpng appears here as well as in buildInputs; presumably
  # so a build-time helper it ships is on PATH during configuration — confirm.
  nativeBuildInputs = [ libpng ninja which ]
    ++ lib.optionals cudaSupport [ cudatoolkit_joined ];

  # Exported as environment variables to the build; presumably read by
  # torchvision's setup.py to locate libjpeg-turbo — confirm against upstream.
  TORCHVISION_INCLUDE = "${libjpeg_turbo.dev}/include/";
  TORCHVISION_LIBRARY = "${libjpeg_turbo}/lib/";

  buildInputs = [ libjpeg_turbo libpng ]
    ++ lib.optionals cudaSupport [ cudnn ];

  propagatedBuildInputs = [ numpy pillow torch scipy ];

  # Only for CUDA builds: pin the target architectures to torch's list and
  # force the CUDA code path.
  preBuild = lib.optionalString cudaSupport ''
    export TORCH_CUDA_ARCH_LIST="${cudaArchStr}"
    export FORCE_CUDA=1
  '';

  # tries to download many datasets for tests
  doCheck = false;

  # Kept for manual runs (doCheck is false); skips the dataset-download tests.
  checkPhase = ''
    HOME=$TMPDIR py.test test --ignore=test/test_datasets_download.py
  '';

  checkInputs = [ pytest ];

  meta = with lib; {
    description = "PyTorch vision library";
    homepage = "https://pytorch.org/";
    license = licenses.bsd3;
    # Darwin is only offered for non-CUDA builds.
    platforms = with platforms; linux ++ lib.optionals (!cudaSupport) darwin;
    maintainers = with maintainers; [ ericsagnes ];
  };
}