# Test derivation for libtorch-bin: builds the CMake project in this
# directory against libtorch-bin and runs the resulting ./test binary
# during checkPhase.
{ lib
, stdenv
, cmake
, libtorch-bin
, linkFarm
, symlinkJoin

, cudaSupport
, cudatoolkit
, cudnn
}:
let
  # Merge the `out` and `lib` outputs of cudatoolkit into one tree so it can
  # be handed to CMake as a single CUDA_TOOLKIT_ROOT_DIR.
  cudatoolkit_joined = symlinkJoin {
    name = "${cudatoolkit.name}-unsplit";
    paths = [ cudatoolkit.out cudatoolkit.lib ];
  };

  # We do not have access to /run/opengl-driver/lib in the sandbox,
  # so use a stub instead: expose the toolkit's stub libcuda.so under the
  # name libcuda.so.1.
  cudaStub = linkFarm "cuda-stub" [{
    name = "libcuda.so.1";
    path = "${cudatoolkit}/lib/stubs/libcuda.so";
  }];

in stdenv.mkDerivation {
  pname = "libtorch-test";
  version = libtorch-bin.version;

  src = ./.;

  nativeBuildInputs = [ cmake ];

  buildInputs = [ libtorch-bin ] ++
    lib.optionals cudaSupport [ cudnn ];

  cmakeFlags = lib.optionals cudaSupport
    [ "-DCUDA_TOOLKIT_ROOT_DIR=${cudatoolkit_joined}" ];

  doCheck = true;

  # Nothing is installed; $out is just an empty file so the build has an output.
  installPhase = ''
    runHook preInstall
    touch "$out"
    runHook postInstall
  '';

  # With CUDA enabled, prepend the stub directory to LD_LIBRARY_PATH for the
  # ./test invocation only. The trailing backslash continues the assignment
  # onto the ./test line that the next string fragment supplies.
  checkPhase = ''
    runHook preCheck
  '' + lib.optionalString cudaSupport ''
    LD_LIBRARY_PATH=${cudaStub}''${LD_LIBRARY_PATH:+:}$LD_LIBRARY_PATH \
  '' + ''
    ./test
    runHook postCheck
  '';
}