1# Support matrix can be found at
2# https://docs.nvidia.com/deeplearning/cudnn/archives/cudnn-880/support-matrix/index.html
3#
# TODO(@connorbaker):
# This is a very similar strategy to the one used for CUDA/cuDNN:
#
# - Get all versions supported by the current release of CUDA
# - Build all of them
# - Make the newest the default
#
# Unique twists:
#
# - Instead of providing different releases for each version of CUDA, cuTENSOR has multiple
#   subdirectories in `lib` -- one for each version of CUDA.
15{
16 cudaLib,
17 cudaMajorMinorVersion,
18 lib,
19 redistSystem,
20}:
21let
22 inherit (lib)
23 attrsets
24 lists
25 modules
26 versions
27 trivial
28 ;
29
30 redistName = "cutensor";
31 pname = "libcutensor";
32
33 cutensorVersions = [
34 "1.3.3"
35 "1.4.0"
36 "1.5.0"
37 "1.6.2"
38 "1.7.0"
39 "2.0.2"
40 "2.1.0"
41 ];
42
43 # Manifests :: { redistrib, feature }
44
  # Each release of cuTENSOR gets mapped to an evaluated module for that release.
  # From there, we read that release's `redistrib` and `feature` manifests.
47 # listOfManifests :: List Manifests
48 listOfManifests =
49 let
50 configEvaluator =
51 fullCutensorVersion:
52 modules.evalModules {
53 modules = [
54 ../modules
55 # We need to nest the manifests in a config.cutensor.manifests attribute so the
56 # module system can evaluate them.
57 {
58 cutensor.manifests = {
59 redistrib = trivial.importJSON (./manifests + "/redistrib_${fullCutensorVersion}.json");
60 feature = trivial.importJSON (./manifests + "/feature_${fullCutensorVersion}.json");
61 };
62 }
63 ];
64 };
65 # Un-nest the manifests attribute set.
66 releaseGrabber = evaluatedModules: evaluatedModules.config.cutensor.manifests;
67 in
68 lists.map (trivial.flip trivial.pipe [
69 configEvaluator
70 releaseGrabber
71 ]) cutensorVersions;
72
73 # Our cudaMajorMinorVersion tells us which version of CUDA we're building against.
74 # The subdirectories in lib/ tell us which versions of CUDA are supported.
75 # Typically the names will look like this:
76 #
77 # - 10.2
78 # - 11
79 # - 11.0
80 # - 12
81
82 # libPath :: String
83 libPath =
84 let
85 cudaMajorVersion = versions.major cudaMajorMinorVersion;
86 in
87 if cudaMajorMinorVersion == "10.2" then cudaMajorMinorVersion else cudaMajorVersion;
88
  # A release is considered supported if its feature manifest has an entry for our platform
  # (`redistSystem`). Whether that release actually ships a libPath matching our CUDA version
  # is not checked here: libPaths are not constant across the same release -- one platform
  # may support fewer CUDA versions than another.
92 # platformIsSupported :: Manifests -> Boolean
93 platformIsSupported =
94 { feature, redistrib, ... }:
95 (attrsets.attrByPath [
96 pname
97 redistSystem
98 ] null feature) != null;
99
  # TODO(@connorbaker): With an auxiliary file keeping track of the CUDA versions each release supports,
  # we could filter out releases that don't support our CUDA version.
  # However, we don't have that currently, so we make a best effort to build cuTENSOR with whatever
  # libPath corresponds to our CUDA version.
104 # supportedManifests :: List Manifests
105 supportedManifests = builtins.filter platformIsSupported listOfManifests;
106
107 # Compute versioned attribute name to be used in this package set
108 # Patch version changes should not break the build, so we only use major and minor
109 # computeName :: RedistribRelease -> String
110 computeName =
111 { version, ... }: cudaLib.mkVersionedName redistName (lib.versions.majorMinor version);
112in
final: _:
let
  # Build the derivation for one supported release and pair it with its
  # versioned attribute name, ready for builtins.listToAttrs.
  # buildCutensorPackage :: Manifests -> { name :: String; value :: Derivation; }
  buildCutensorPackage =
    { redistrib, feature }:
    let
      release = redistrib.${pname};
    in
    attrsets.nameValuePair (computeName release) (
      final.callPackage ../generic-builders/manifest.nix {
        inherit pname redistName libPath;
        redistribRelease = release;
        featureRelease = feature.${pname};
      }
    );

  # One attribute per supported release.
  versionedPackages = builtins.listToAttrs (lists.map buildCutensorPackage supportedManifests);

  # supportedManifests is ordered oldest to newest, so its last element is the
  # newest release. This is only forced when versionedPackages is non-empty,
  # which guarantees supportedManifests is non-empty too.
  newestName = computeName (lists.last supportedManifests).redistrib.${pname};

  # Alias the newest release as plain `cutensor`, but only when at least one
  # release is supported on this platform.
  defaultAlias = attrsets.optionalAttrs (versionedPackages != { }) {
    cutensor = versionedPackages.${newestName};
  };
in
versionedPackages // defaultAlias