1{
2 _cuda,
3 cudaOlder,
4 cudaPackages,
5 cudaMajorMinorVersion,
6 lib,
7 patchelf,
8 requireFile,
9 stdenv,
10}:
11let
12 inherit (lib)
13 attrsets
14 maintainers
15 meta
16 strings
17 versions
18 ;
19 inherit (stdenv) hostPlatform;
# targetArch :: String
# Name of the per-architecture directory under `targets/` in the unpacked
# source for the host system, or "unsupported" when the host system has no
# entry in the table below.
targetArch =
  let
    archBySystem = {
      "x86_64-linux" = "x86_64-linux-gnu";
      "aarch64-linux" = "aarch64-linux-gnu";
    };
  in
  archBySystem.${hostPlatform.system} or "unsupported";
25in
26finalAttrs: prevAttrs: {
# Useful for inspecting why something went wrong.
# Each key is a human-readable reason and each value is true when that reason
# applies. Entries already present in prevAttrs.brokenConditions are kept and
# extended (or overridden on key collision) by the ones defined here.
brokenConditions =
  let
    inherit (finalAttrs.passthru) featureRelease;

    # The CUDNN checks only apply when the feature release pins a CUDNN version.
    hasCudnnPin = featureRelease.cudnnVersion != null;

    # Both sides are reduced to major.minor before comparing. `pinnedCudnn` is
    # only forced when `hasCudnnPin` is true, thanks to `&&` short-circuiting.
    pinnedCudnn = versions.majorMinor featureRelease.cudnnVersion;
    actualCudnn = versions.majorMinor finalAttrs.passthru.cudnn.version;
  in
  (prevAttrs.brokenConditions or { })
  // {
    "CUDA version is too old" = cudaOlder featureRelease.minCudaVersion;
    # A null maxCudaVersion means there is no upper bound to check.
    "CUDA version is too new" =
      (featureRelease.maxCudaVersion != null)
      && (strings.versionOlder featureRelease.maxCudaVersion cudaMajorMinorVersion);
    "CUDNN version is too old" = hasCudnnPin && (strings.versionOlder actualCudnn pinnedCudnn);
    "CUDNN version is too new" = hasCudnnPin && (strings.versionOlder pinnedCudnn actualCudnn);
  };
49
# The archive is not fetched automatically: requireFile is given the expected
# file name and hash, plus a message telling the user how to obtain the file
# and add it to the store.
src =
  let
    inherit (finalAttrs.passthru) redistribRelease;
  in
  requireFile {
    inherit (redistribRelease) hash;
    name = redistribRelease.filename;
    message = ''
      To use the TensorRT derivation, you must join the NVIDIA Developer Program and
      download the ${finalAttrs.version} TAR package for CUDA ${cudaMajorMinorVersion} from
      ${finalAttrs.meta.homepage}.

      Once you have downloaded the file, add it to the store with the following
      command, and try building this derivation again.

      $ nix-store --add-fixed sha256 ${redistribRelease.filename}
    '';
  };
64
# We need to look inside the extracted output to get the files we need:
# the top-level directory is named after the TensorRT version.
sourceRoot = "TensorRT-" + finalAttrs.version;
67
# Extend any buildInputs from prevAttrs with the `lib` output of the selected
# cudnn; the entry is null when that cudnn has no `lib` attribute.
buildInputs =
  let
    cudnnLib = finalAttrs.passthru.cudnn.lib or null;
  in
  (prevAttrs.buildInputs or [ ]) ++ [ cudnnLib ];
69
# Appends a relocation script to any preInstall from prevAttrs. The script is
# only added when the host system maps to a known target architecture.
preInstall =
  let
    relocateTargets = ''
      # Replace symlinks to bin and lib with the actual directories from targets.
      for dir in bin lib; do
      rm "$dir"
      mv "targets/${targetArch}/$dir" "$dir"
      done

      # Remove broken symlinks
      for dir in include samples; do
      rm "targets/${targetArch}/$dir" || :
      done
    '';
  in
  (prevAttrs.preInstall or "") + (if targetArch != "unsupported" then relocateTargets else "");
84
# Tell autoPatchelf about runtime dependencies.
# Appends a patchelf invocation to any postFixup from prevAttrs; it adds
# libnvinfer.so as a needed library to the three versioned shared objects below.
postFixup =
  let
    inherit (finalAttrs) version;
    # major.minor.patch suffix used in the shared-object file names.
    soVersion = "${versions.majorMinor version}.${versions.patch version}";
    patchelfExe = meta.getExe' patchelf "patchelf";
  in
  (prevAttrs.postFixup or "")
  + ''
    ${patchelfExe} --add-needed libnvinfer.so \
    "$lib/lib/libnvinfer.so.${soVersion}" \
    "$lib/lib/libnvinfer_plugin.so.${soVersion}" \
    "$lib/lib/libnvinfer_builder_resource.so.${soVersion}"
  '';
97
# Extends any passthru from prevAttrs with TensorRT-specific attributes.
passthru = prevAttrs.passthru or { } // {
  # True when cudaMajorMinorVersion sorts before 11.3.999.
  useCudatoolkitRunfile = strings.versionOlder cudaMajorMinorVersion "11.3.999";
  # The CUDNN used with TensorRT.
  # If null, the default cudnn derivation will be used.
  # If a version is specified, the cudnn derivation with that version will be used,
  # unless it is not available, in which case the default cudnn derivation will be used.
  cudnn =
    let
      inherit (finalAttrs.passthru.featureRelease) cudnnVersion;
      # Only forced when cudnnVersion is non-null (see the `&&` below), so
      # versions.majorMinor never receives null.
      desiredName = _cuda.lib.mkVersionedName "cudnn" (versions.majorMinor cudnnVersion);
    in
    # `? ${desiredName}` tests for the attribute whose name is the *value* of
    # desiredName; a bare `? desiredName` would test for an attribute literally
    # called "desiredName" and never match the versioned package.
    if (cudnnVersion != null) && (cudaPackages ? ${desiredName}) then
      cudaPackages.${desiredName}
    else
      cudaPackages.cudnn;
};
115
# Extends any meta from prevAttrs; list-valued fields are appended to rather
# than replaced.
meta =
  let
    prevMeta = prevAttrs.meta or { };
    # The host system is marked bad when it has no known target architecture.
    unsupportedHere = lib.optionals (targetArch == "unsupported") [ hostPlatform.system ];
  in
  prevMeta
  // {
    homepage = "https://developer.nvidia.com/tensorrt";
    badPlatforms = (prevMeta.badPlatforms or [ ]) ++ unsupportedHere;
    maintainers = (prevMeta.maintainers or [ ]) ++ [ maintainers.aidalgol ];
    teams = prevMeta.teams or [ ];

    # Building TensorRT on Hydra is impossible because of the non-redistributable
    # license and because the source needs to be manually downloaded from the
    # NVIDIA Developer Program (see requireFile above).
    hydraPlatforms = lib.platforms.none;
  };
129}