# Nixpkgs package expression for torchvision, the PyTorch vision library.
# Dependencies are taken as arguments and assembled with `buildPythonPackage`;
# CUDA settings and the build stdenv are taken from the `torch` argument.
{
  lib,
  stdenv,
  torch,
  apple-sdk_13,
  buildPythonPackage,
  darwinMinVersionHook,
  fetchFromGitHub,

  # nativeBuildInputs
  libpng,
  ninja,
  which,

  # buildInputs
  libjpeg_turbo,

  # dependencies
  numpy,
  pillow,
  scipy,

  # tests
  pytest,
  writableTmpDirAsHomeHook,
}:

let
  # Reuse torch's CUDA configuration: `cudaPackages` supplies nvcc below and
  # `cudaCapabilities` becomes TORCH_CUDA_ARCH_LIST, so the extension targets
  # the same toolkit and GPU architectures as the torch it is built against.
  inherit (torch) cudaCapabilities cudaPackages cudaSupport;

  pname = "torchvision";
  version = "0.21.0";
in
buildPythonPackage {
  inherit pname version;

  # Build with torch's stdenv instead of the default one; presumably so the
  # native extension is compiled with the same toolchain as torch itself.
  stdenv = torch.stdenv;

  src = fetchFromGitHub {
    owner = "pytorch";
    repo = "vision";
    tag = "v${version}";
    hash = "sha256-eDWw1Lt/sUc2Xt6cqOM5xaOfmsm+NEL5lZO+cIJKMtU=";
  };

  # libpng is listed here as well as in buildInputs; NOTE(review): presumably
  # so its `libpng-config` helper is on PATH while setup.py probes for PNG
  # support, and `which` for a build script that shells out to it — confirm.
  nativeBuildInputs = [
    libpng
    ninja
    which
  ] ++ lib.optionals cudaSupport [ cudaPackages.cuda_nvcc ];

  buildInputs =
    [
      libjpeg_turbo
      libpng
      # C++ headers/libraries of torch needed to compile the native extension.
      torch.cxxdev
    ]
    ++ lib.optionals stdenv.hostPlatform.isDarwin [
      # This should match the SDK used by `torch` above
      apple-sdk_13

      # error: unknown type name 'MPSGraphCompilationDescriptor'; did you mean 'MPSGraphExecutionDescriptor'?
      # https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraphcompilationdescriptor/
      (darwinMinVersionHook "12.0")
    ];

  dependencies = [
    numpy
    pillow
    torch
    scipy
  ];

  # Point the build at libjpeg-turbo's headers and libraries. When CUDA is
  # enabled, also pass torch's GPU architectures (";"-separated, as
  # TORCH_CUDA_ARCH_LIST expects) and set FORCE_CUDA=1, presumably so the CUDA
  # extension is built even though no GPU is visible inside the sandbox.
  preConfigure =
    ''
      export TORCHVISION_INCLUDE="${libjpeg_turbo.dev}/include/"
      export TORCHVISION_LIBRARY="${libjpeg_turbo}/lib/"
    ''
    + lib.optionalString cudaSupport ''
      export TORCH_CUDA_ARCH_LIST="${lib.concatStringsSep ";" cudaCapabilities}"
      export FORCE_CUDA=1
    '';

  # tests download big datasets, models, require internet connection, etc.
  doCheck = false;

  # Smoke test that still runs while the test suite above is disabled.
  pythonImportsCheck = [ "torchvision" ];

  nativeCheckInputs = [
    pytest
    writableTmpDirAsHomeHook
  ];

  # Only used when doCheck is re-enabled (e.g. via an override). The runHook
  # calls keep any preCheck/postCheck hooks an overrider supplies working.
  checkPhase = ''
    runHook preCheck

    pytest test --ignore=test/test_datasets_download.py

    runHook postCheck
  '';

  meta = {
    description = "PyTorch vision library";
    homepage = "https://pytorch.org/";
    changelog = "https://github.com/pytorch/vision/releases/tag/v${version}";
    license = lib.licenses.bsd3;
    # CUDA builds are offered on Linux only; Darwin only without CUDA.
    platforms = with lib.platforms; linux ++ lib.optionals (!cudaSupport) darwin;
    maintainers = with lib.maintainers; [ ];
  };
}