{
  lib,
  stdenv,
  fetchFromGitHub,
  rocmUpdateScript,
  cmake,
  rocm-cmake,
  rocprim,
  clr,
  gtest,
  gbenchmark,
  # When true: configure with -DBUILD_TEST=ON, add gtest to buildInputs and
  # move the installed test_* executables into a separate `test` output.
  buildTests ? false,
  # When true: configure with -DBUILD_BENCHMARK=ON, add gbenchmark to
  # buildInputs and move the installed benchmark_* executables into a separate
  # `benchmark` output.
  buildBenchmarks ? false,
  # GPU architectures to build for; joined with ";" and passed as
  # AMDGPU_TARGETS. An empty list omits the flag entirely, leaving the choice
  # to hipCUB's own CMake configuration (presumably its default target list —
  # TODO confirm).
  gpuTargets ? [ ],
}:

# hipCUB: thin wrapper library on top of rocPRIM or CUB. This derivation builds
# it against rocPRIM (see buildInputs).
# CUB can also be used as a backend instead of rocPRIM.
stdenv.mkDerivation (finalAttrs: {
  pname = "hipcub";
  version = "6.3.3";

  # The optional outputs are declared only when the matching executables are
  # built, so postInstall always has a destination to move them into.
  outputs = [
    "out"
  ]
  ++ lib.optionals buildTests [
    "test"
  ]
  ++ lib.optionals buildBenchmarks [
    "benchmark"
  ];

  src = fetchFromGitHub {
    owner = "ROCm";
    repo = "hipCUB";
    rev = "rocm-${finalAttrs.version}";
    hash = "sha256-uECOQWG9C64tg5YZdm9/3+fZXaZVGslu8vElK3m23GY=";
  };

  # clr is needed at build time; it is also passed as HIP_ROOT_DIR below.
  nativeBuildInputs = [
    cmake
    rocm-cmake
    clr
  ];

  buildInputs = [
    rocprim
  ]
  ++ lib.optionals buildTests [
    gtest
  ]
  ++ lib.optionals buildBenchmarks [
    gbenchmark
  ];

  cmakeFlags = [
    "-DHIP_ROOT_DIR=${clr}"
    # Manually define CMAKE_INSTALL_<DIR>
    # See: https://github.com/NixOS/nixpkgs/pull/197838
    "-DCMAKE_INSTALL_BINDIR=bin"
    "-DCMAKE_INSTALL_LIBDIR=lib"
    "-DCMAKE_INSTALL_INCLUDEDIR=include"
  ]
  ++ lib.optionals (gpuTargets != [ ]) [
    "-DAMDGPU_TARGETS=${lib.concatStringsSep ";" gpuTargets}"
  ]
  ++ lib.optionals buildTests [
    "-DBUILD_TEST=ON"
  ]
  ++ lib.optionals buildBenchmarks [
    "-DBUILD_BENCHMARK=ON"
  ];

  # Split the test/benchmark executables out of $out into their own outputs.
  # Afterwards $out/bin is removed with rmdir, which only succeeds if nothing
  # else was installed into it. Globs stay outside the quotes so they still
  # expand.
  postInstall =
    lib.optionalString buildTests ''
      mkdir -p "$test/bin"
      mv "$out"/bin/test_* "$test/bin"
    ''
    + lib.optionalString buildBenchmarks ''
      mkdir -p "$benchmark/bin"
      mv "$out"/bin/benchmark_* "$benchmark/bin"
    ''
    + lib.optionalString (buildTests || buildBenchmarks) ''
      rmdir "$out/bin"
    '';

  passthru.updateScript = rocmUpdateScript {
    name = finalAttrs.pname;
    inherit (finalAttrs.src) owner repo;
  };

  # Names are spelled out with `lib.` instead of `with lib;` so their origin
  # is explicit.
  meta = {
    description = "Thin wrapper library on top of rocPRIM or CUB";
    homepage = "https://github.com/ROCm/hipCUB";
    license = [ lib.licenses.bsd3 ];
    teams = [ lib.teams.rocm ];
    platforms = lib.platforms.linux;
  };
})