{
  lib,
  cudaLib,
  cudaMajorMinorVersion,
  redistSystem,
  stdenv,
  # Builder-specific arguments
  # Short package name (e.g., "cuda_cccl")
  # pname :: String
  pname,
  # Common name (e.g., "cutensor" or "cudnn") -- used in the URL.
  # Also known as the Redistributable Name.
  # redistName :: String
  redistName,
  # releasesModule :: Path
  # A path to a module which provides a `releases` attribute
  releasesModule,
  # shimsFn :: Path | ({ package, redistSystem, ... } -> AttrSet)
  # Called via `final.callPackage shimsFn { inherit package redistSystem; }` for every selected package, where
  # `redistSystem` is the platform being built for (not necessarily `package.redistSystem`). The result must provide
  # `redistribRelease` and `featureRelease`, which are forwarded to ./manifest.nix.
  # The redistribRelease is only used in ./manifest.nix for the package version
  # and the package description (which NVIDIA's manifest calls the "name").
  # It's also used for fetching the source, but we override that since we can't
  # re-use that portion of the functionality (different URLs, etc.).
  # The featureRelease is used to populate meta.platforms (by way of looking at the attribute names), determine the
  # outputs of the package, and provide additional package-specific constraints (e.g., min/max supported CUDA versions,
  # required versions of other packages, etc.).
  # There is no usable default: evaluation throws as soon as shimsFn is needed without having been supplied.
  shimsFn ? (throw "shimsFn must be provided"),
}:
let
  # Evaluate the shared CUDA module tree together with this package's release module; the releases for `pname` are
  # read from the resulting `config` further down.
  evaluatedModules = lib.evalModules {
    modules = [ ../modules releasesModule ];
  };
37
38 # NOTE: Important types:
39 # - Releases: ../modules/${pname}/releases/releases.nix
40 # - Package: ../modules/${pname}/releases/package.nix
41
42 # Check whether a package supports our CUDA version.
43 # satisfiesCudaVersion :: Package -> Bool
44 satisfiesCudaVersion =
45 package:
46 lib.versionAtLeast cudaMajorMinorVersion package.minCudaVersion
47 && lib.versionAtLeast package.maxCudaVersion cudaMajorMinorVersion;
48
49 # FIXME: do this at the module system level
50 propagatePlatforms = lib.mapAttrs (redistSystem: lib.map (p: { inherit redistSystem; } // p));
51
52 # Releases for all platforms and all CUDA versions.
53 allReleases = propagatePlatforms evaluatedModules.config.${pname}.releases;
54
55 # Releases for all platforms and our CUDA version.
56 allReleases' = lib.mapAttrs (_: lib.filter satisfiesCudaVersion) allReleases;
57
58 # Packages for all platforms and our CUDA versions.
59 allPackages = lib.concatLists (lib.attrValues allReleases');
60
61 packageOlder = p1: p2: lib.versionOlder p1.version p2.version;
62 packageSupportedPlatform = p: p.redistSystem == redistSystem;
63
64 # Compute versioned attribute name to be used in this package set
65 # Patch version changes should not break the build, so we only use major and minor
66 # computeName :: Package -> String
67 computeName = { version, ... }: cudaLib.mkVersionedName pname (lib.versions.majorMinor version);
68
69 # The newest package for each major-minor version, with newest first.
70 # newestPackages :: List Package
71 newestPackages =
72 let
73 newestForEachMajorMinorVersion = lib.foldl' (
74 newestPackages: package:
75 let
76 majorMinorVersion = lib.versions.majorMinor package.version;
77 existingPackage = newestPackages.${majorMinorVersion} or null;
78 in
79 newestPackages
80 // {
81 ${majorMinorVersion} =
82 # Only keep the existing package if it is newer than the one we are considering or it is supported on the
83 # current platform and the one we are considering is not.
84 if
85 existingPackage != null
86 && (
87 packageOlder package existingPackage
88 || (!packageSupportedPlatform package && packageSupportedPlatform existingPackage)
89 )
90 then
91 existingPackage
92 else
93 package;
94 }
95 ) { } allPackages;
96 in
97 # Sort the packages by version so the newest is first.
98 # NOTE: builtins.sort requires a strict weak ordering, so we must use versionOlder rather than versionAtLeast.
99 # See https://github.com/NixOS/nixpkgs/commit/9fd753ea84e5035b357a275324e7fd7ccfb1fc77.
100 lib.sort (lib.flip packageOlder) (lib.attrValues newestForEachMajorMinorVersion);
101
102 extension =
103 final: _:
104 let
105 # Builds our package into derivation and wraps it in a nameValuePair, where the name is the versioned name
106 # of the package.
107 buildPackage =
108 package:
109 let
110 shims = final.callPackage shimsFn { inherit package redistSystem; };
111 name = computeName package;
112 drv = final.callPackage ./manifest.nix {
113 inherit pname redistName;
114 inherit (shims) redistribRelease featureRelease;
115 };
116 in
117 lib.nameValuePair name drv;
118
119 # versionedDerivations :: AttrSet Derivation
120 versionedDerivations = builtins.listToAttrs (lib.map buildPackage newestPackages);
121
122 defaultDerivation = {
123 ${pname} = (buildPackage (lib.head newestPackages)).value;
124 };
125 in
126 # NOTE: Must condition on the length of newestPackages to avoid non-total function lib.head aborting if
127 # newestPackages is empty.
128 lib.optionalAttrs (lib.length newestPackages > 0) (versionedDerivations // defaultDerivation);
129in
130extension