# Smoke test for libtorch-bin: builds the CMake project in this directory
# against libtorch-bin (optionally with CUDA/cuDNN) and runs `./test` in the
# check phase. Nothing is installed; $out is an empty file, so the derivation
# only succeeds if both the build and `./test` succeed.
{ lib
, stdenv
, cmake
, libtorch-bin
, linkFarm
, symlinkJoin

, cudaSupport
, cudaPackages ? {}
}:
let
  # These bindings are only referenced behind `cudaSupport` guards below, so
  # with the default `cudaPackages = {}` the missing attributes are never
  # evaluated (Nix is lazy).
  inherit (cudaPackages) cudatoolkit cudnn;

  # Merge the toolkit's `out` and `lib` outputs into a single tree so it can
  # be handed to CMake as one CUDA_TOOLKIT_ROOT_DIR.
  cudatoolkit_joined = symlinkJoin {
    name = "${cudatoolkit.name}-unsplit";
    paths = [ cudatoolkit.out cudatoolkit.lib ];
  };

  # We do not have access to /run/opengl-driver/lib in the sandbox,
  # so use a stub instead: a directory containing `libcuda.so.1` that
  # links to the toolkit's stub `libcuda.so`.
  cudaStub = linkFarm "cuda-stub" [{
    name = "libcuda.so.1";
    path = "${cudatoolkit}/lib/stubs/libcuda.so";
  }];

in stdenv.mkDerivation {
  pname = "libtorch-test";
  version = libtorch-bin.version;

  src = ./.;

  nativeBuildInputs = [ cmake ];

  buildInputs = [ libtorch-bin ] ++
    lib.optionals cudaSupport [ cudnn ];

  cmakeFlags = lib.optionals cudaSupport
    [ "-DCUDA_TOOLKIT_ROOT_DIR=${cudatoolkit_joined}" ];

  doCheck = true;

  # Nothing to install; an empty $out is the success marker.
  installPhase = ''
    runHook preInstall
    touch "$out"
    runHook postInstall
  '';

  # With CUDA, the stub directory is prepended to LD_LIBRARY_PATH (keeping any
  # existing value). `''${` escapes Nix interpolation so the shell sees
  # `${LD_LIBRARY_PATH:+:}`, and the trailing `\` continues the line so the
  # assignment applies only to the `./test` invocation that follows.
  checkPhase = ''
    runHook preCheck
  '' + lib.optionalString cudaSupport ''
    LD_LIBRARY_PATH=${cudaStub}''${LD_LIBRARY_PATH:+:}$LD_LIBRARY_PATH \
  '' + ''
    ./test
    runHook postCheck
  '';
}