// A game about forced loneliness, made by TACStudios
1#ifndef THREADING_SM6_IMPL
2#define THREADING_SM6_IMPL
3
4namespace Threading
5{
6 // Currently we only cover scalar types as at the time of writing this utility library we only needed emulation for those.
7 // Support for vector types is currently not there but can be added as needed (and this comment removed).
 //
 // Group-shared scratch buffer holding THREADING_BLOCK_SIZE uint slots.
 // The emulated Group operations in this file index it with the calling thread's groupIndex to exchange
 // values between threads; non-uint values are stored through asuint() and read back with the matching
 // as<type>() reinterpretation.
8 groupshared uint g_Scratch[THREADING_BLOCK_SIZE];
9
// Returns the wave index that Init() stored for this thread.
uint Wave::GetIndex()
{
    return indexW;
}
11
// Records the thread's index within its group and derives the index of the wave it belongs to.
//   groupIndex - index of the calling thread inside the thread group.
// The wave index is the group index divided by the wave's lane count.
// NOTE(review): this presumes threads are packed into waves in groupIndex order - confirm for the target platforms.
void Wave::Init(uint groupIndex)
{
    const uint lanesPerWave = GetLaneCount();

    indexG = groupIndex;
    indexW = indexG / lanesPerWave;
}
17
// Note: our API library is expected to replace the HLSL intrinsics below with the console-specific ones.
//
// DEFINE_API_FOR_TYPE(TYPE) defines the Wave member functions that are overloaded per value type.
// Each one forwards directly to the matching SM6 wave intrinsic:
//   AllEqual / Sum / Product / Min / Max  -> WaveActive* intrinsics
//   PrefixSum / PrefixProduct             -> WavePrefix* intrinsics
//   InclusivePrefixSum / -Product         -> the WavePrefix* result combined with the lane's own value
//   ReadLaneAt / ReadLaneFirst            -> WaveReadLane* intrinsics
#define DEFINE_API_FOR_TYPE(TYPE)                                                                   \
    bool Wave::AllEqual(TYPE v)               { return WaveActiveAllEqual(v); }                     \
    TYPE Wave::Sum(TYPE v)                    { return WaveActiveSum(v); }                          \
    TYPE Wave::Product(TYPE v)                { return WaveActiveProduct(v); }                      \
    TYPE Wave::Min(TYPE v)                    { return WaveActiveMin(v); }                          \
    TYPE Wave::Max(TYPE v)                    { return WaveActiveMax(v); }                          \
    TYPE Wave::PrefixSum(TYPE v)              { return WavePrefixSum(v); }                          \
    TYPE Wave::PrefixProduct(TYPE v)          { return WavePrefixProduct(v); }                      \
    TYPE Wave::InclusivePrefixSum(TYPE v)     { return WavePrefixSum(v) + v; }                      \
    TYPE Wave::InclusivePrefixProduct(TYPE v) { return WavePrefixProduct(v) * v; }                  \
    TYPE Wave::ReadLaneAt(TYPE v, uint i)     { return WaveReadLaneAt(v, i); }                      \
    TYPE Wave::ReadLaneFirst(TYPE v)          { return WaveReadLaneFirst(v); }

// Only scalar types are instantiated for now.
DEFINE_API_FOR_TYPE(uint)
DEFINE_API_FOR_TYPE(int)
DEFINE_API_FOR_TYPE(float)
36
// The wrappers below are not overloaded per value type, so each is defined exactly once.
// Every function forwards straight to the SM6 wave intrinsic named in its body.

// Number of lanes in a wave.
uint Wave::GetLaneCount()
{
    return WaveGetLaneCount();
}

// Index of the calling lane within its wave.
uint Wave::GetLaneIndex()
{
    return WaveGetLaneIndex();
}

// Forwards to WaveIsFirstLane().
bool Wave::IsFirstLane()
{
    return WaveIsFirstLane();
}

// Forwards to WaveActiveAllTrue().
bool Wave::AllTrue(bool v)
{
    return WaveActiveAllTrue(v);
}

// Forwards to WaveActiveAnyTrue().
bool Wave::AnyTrue(bool v)
{
    return WaveActiveAnyTrue(v);
}

// Forwards to WaveActiveBallot(); the result is a 4-dword bit mask.
uint4 Wave::Ballot(bool v)
{
    return WaveActiveBallot(v);
}

// Forwards to WaveActiveCountBits().
uint Wave::CountBits(bool v)
{
    return WaveActiveCountBits(v);
}

// Forwards to WavePrefixCountBits().
uint Wave::PrefixCountBits(bool v)
{
    return WavePrefixCountBits(v);
}

// Bitwise AND across the wave (WaveActiveBitAnd).
uint Wave::And(uint v)
{
    return WaveActiveBitAnd(v);
}

// Bitwise OR across the wave (WaveActiveBitOr).
uint Wave::Or(uint v)
{
    return WaveActiveBitOr(v);
}

// Bitwise XOR across the wave (WaveActiveBitXor).
uint Wave::Xor(uint v)
{
    return WaveActiveBitXor(v);
}
49
// Emulated group-wide reduction using an infix operator OP (e.g. +, *, &, |, ^).
//
// Expands to a complete function body. The enclosing function must provide:
//   v          - the calling thread's input value, of type TYPE
//   groupIndex - the calling thread's index, used as its slot in g_Scratch
//
// Every thread publishes its value to g_Scratch (bit-cast to uint), then the live range of slots is folded
// onto its lower half once per pass until a single value remains in slot 0, which all threads return.
// The kept half is rounded up on each pass, so thread-group sizes that are not a power of two are reduced
// in full; for power-of-two sizes the pairing (and therefore the result) matches a plain halving-stride fold.
//
// The leading barrier keeps this call from overwriting scratch data that a previous operation may still be
// reading. The expansion contains group-wide barriers, so it is meant to be executed by every thread of the group.
#define EMULATED_GROUP_REDUCE(TYPE, OP) \
    GroupMemoryBarrierWithGroupSync(); \
    g_Scratch[groupIndex] = asuint(v); \
    GroupMemoryBarrierWithGroupSync(); \
    [unroll] \
    for (uint liveCount = THREADING_BLOCK_SIZE; liveCount > 1u; liveCount = (liveCount + 1u) >> 1u) \
    { \
        uint foldStride = (liveCount + 1u) >> 1u; \
        if (groupIndex < (liveCount >> 1u)) \
            g_Scratch[groupIndex] = asuint(as##TYPE(g_Scratch[groupIndex]) OP as##TYPE(g_Scratch[groupIndex + foldStride])); \
        GroupMemoryBarrierWithGroupSync(); \
    } \
    return as##TYPE(g_Scratch[0]);
// Emulated group-wide reduction using a two-argument function OP (e.g. min, max).
//
// Expands to a complete function body. The enclosing function must provide:
//   v          - the calling thread's input value, of type TYPE
//   groupIndex - the calling thread's index, used as its slot in g_Scratch
//
// Every thread publishes its value to g_Scratch (bit-cast to uint), then the live range of slots is folded
// onto its lower half once per pass until a single value remains in slot 0, which all threads return.
// The kept half is rounded up on each pass, so thread-group sizes that are not a power of two are reduced
// in full; for power-of-two sizes the pairing matches a plain halving-stride fold.
//
// The leading barrier keeps this call from overwriting scratch data that a previous operation may still be
// reading. The expansion contains group-wide barriers, so it is meant to be executed by every thread of the group.
#define EMULATED_GROUP_REDUCE_CMP(TYPE, OP) \
    GroupMemoryBarrierWithGroupSync(); \
    g_Scratch[groupIndex] = asuint(v); \
    GroupMemoryBarrierWithGroupSync(); \
    [unroll] \
    for (uint liveCount = THREADING_BLOCK_SIZE; liveCount > 1u; liveCount = (liveCount + 1u) >> 1u) \
    { \
        uint foldStride = (liveCount + 1u) >> 1u; \
        if (groupIndex < (liveCount >> 1u)) \
            g_Scratch[groupIndex] = asuint(OP(as##TYPE(g_Scratch[groupIndex]), as##TYPE(g_Scratch[groupIndex + foldStride]))); \
        GroupMemoryBarrierWithGroupSync(); \
    } \
    return as##TYPE(g_Scratch[0]);
// Emulated group-wide exclusive prefix operation using an infix operator OP (e.g. +, *).
//
// Expands to a complete function body. The enclosing function must provide:
//   v          - the calling thread's input value, of type TYPE
//   groupIndex - the calling thread's index, used as its slot in g_Scratch
// FILL_VALUE is the value combined in when a thread has no partner at the current offset, and is also what
// thread 0 returns; callers pass the identity of OP.
//
// Each pass doubles the look-back offset: a thread combines its own slot with the slot 'offset' entries
// below it (or with FILL_VALUE when there is none). The combined value is computed before a barrier and
// written back after it, so no slot is overwritten while another thread may still be reading it.
// After the last pass slot i holds the inclusive result for thread i, so the exclusive result is read from
// slot groupIndex - 1.
#define EMULATED_GROUP_PREFIX(TYPE, OP, FILL_VALUE) \
    GroupMemoryBarrierWithGroupSync(); \
    g_Scratch[groupIndex] = asuint(v); \
    GroupMemoryBarrierWithGroupSync(); \
    [unroll] \
    for (uint offset = 1u; offset < THREADING_BLOCK_SIZE; offset <<= 1u) \
    { \
        TYPE combined = FILL_VALUE; \
        if (groupIndex >= offset) \
            combined = as##TYPE(g_Scratch[groupIndex - offset]); \
        combined = as##TYPE(g_Scratch[groupIndex]) OP combined; \
        GroupMemoryBarrierWithGroupSync(); \
        g_Scratch[groupIndex] = asuint(combined); \
        GroupMemoryBarrierWithGroupSync(); \
    } \
    TYPE exclusive = FILL_VALUE; \
    if (groupIndex > 0u) \
        exclusive = as##TYPE(g_Scratch[groupIndex - 1]); \
    return exclusive;
// Returns the number of waves needed to cover the thread group.
// Uses a ceiling division so that a group smaller than one wave reports 1 (not 0), and a group whose size
// is not a multiple of the lane count also counts its partially filled last wave. When
// THREADING_BLOCK_SIZE is a multiple of the lane count the result equals the plain quotient.
// NOTE(review): presumes the group is packed into the minimum number of waves - confirm for the target platforms.
uint Group::GetWaveCount()
{
    const uint laneCount = WaveGetLaneCount();
    return (THREADING_BLOCK_SIZE + laneCount - 1u) / laneCount;
}
102
// DEFINE_API_FOR_TYPE_GROUP(TYPE) defines the Group member functions that are overloaded per value type.
// All of them are emulated through the g_Scratch group-shared buffer:
//   ReadThreadAt       - every thread publishes v to its own slot, then reads back slot i
//   ReadThreadFirst    - ReadThreadAt with slot 0
//   ReadThreadShuffle  - same implementation as ReadThreadAt
//   AllEqual           - true when every thread's v equals thread 0's v
//   Sum / Product      - EMULATED_GROUP_REDUCE with + / *
//   Min / Max          - EMULATED_GROUP_REDUCE_CMP with min / max
//   PrefixSum / PrefixProduct                   - EMULATED_GROUP_PREFIX with identity 0 / 1
//   InclusivePrefixSum / InclusivePrefixProduct - the exclusive result combined with the thread's own v
// ReadThreadAt does not range-check i against the size of g_Scratch.
#define DEFINE_API_FOR_TYPE_GROUP(TYPE) \
    TYPE Group::ReadThreadAt(TYPE v, uint i) \
    { \
        GroupMemoryBarrierWithGroupSync(); \
        g_Scratch[groupIndex] = asuint(v); \
        GroupMemoryBarrierWithGroupSync(); \
        return as##TYPE(g_Scratch[i]); \
    } \
    TYPE Group::ReadThreadFirst(TYPE v) \
    { \
        return ReadThreadAt(v, 0u); \
    } \
    TYPE Group::ReadThreadShuffle(TYPE v, uint i) \
    { \
        return ReadThreadAt(v, i); \
    } \
    bool Group::AllEqual(TYPE v) \
    { \
        return AllTrue(ReadThreadFirst(v) == v); \
    } \
    TYPE Group::Sum(TYPE v) \
    { \
        EMULATED_GROUP_REDUCE(TYPE, +) \
    } \
    TYPE Group::Product(TYPE v) \
    { \
        EMULATED_GROUP_REDUCE(TYPE, *) \
    } \
    TYPE Group::Min(TYPE v) \
    { \
        EMULATED_GROUP_REDUCE_CMP(TYPE, min) \
    } \
    TYPE Group::Max(TYPE v) \
    { \
        EMULATED_GROUP_REDUCE_CMP(TYPE, max) \
    } \
    TYPE Group::PrefixSum(TYPE v) \
    { \
        EMULATED_GROUP_PREFIX(TYPE, +, (TYPE)0) \
    } \
    TYPE Group::PrefixProduct(TYPE v) \
    { \
        EMULATED_GROUP_PREFIX(TYPE, *, (TYPE)1) \
    } \
    TYPE Group::InclusivePrefixSum(TYPE v) \
    { \
        return PrefixSum(v) + v; \
    } \
    TYPE Group::InclusivePrefixProduct(TYPE v) \
    { \
        return PrefixProduct(v) * v; \
    }

// Only scalar types are instantiated for now.
DEFINE_API_FOR_TYPE_GROUP(uint)
DEFINE_API_FOR_TYPE_GROUP(int)
DEFINE_API_FOR_TYPE_GROUP(float)
121
// The emulated functions below are not overloaded per value type, so each is defined exactly once.

// Number of threads in the group (the compile-time THREADING_BLOCK_SIZE).
uint Group::GetThreadCount()
{
    return THREADING_BLOCK_SIZE;
}

// Index of the calling thread within the group.
uint Group::GetThreadIndex()
{
    return groupIndex;
}

// True only for the thread whose group index is 0.
bool Group::IsFirstThread()
{
    return groupIndex == 0u;
}

// True when v is true on every thread: the bool is reduced with the group-wide bitwise AND.
bool Group::AllTrue(bool v)
{
    return And(v) != 0u;
}

// True when v is true on at least one thread: the bool is reduced with the group-wide bitwise OR.
bool Group::AnyTrue(bool v)
{
    return Or(v) != 0u;
}

// Exclusive prefix sum of v converted to uint, i.e. how many lower-indexed threads passed true.
uint Group::PrefixCountBits(bool v)
{
    return PrefixSum((uint)v);
}

// Group-wide bitwise AND of v.
uint Group::And(uint v)
{
    EMULATED_GROUP_REDUCE(uint, &)
}

// Group-wide bitwise OR of v.
uint Group::Or(uint v)
{
    EMULATED_GROUP_REDUCE(uint, |)
}

// Group-wide bitwise XOR of v.
uint Group::Xor(uint v)
{
    EMULATED_GROUP_REDUCE(uint, ^)
}
132
// Builds a group-wide ballot: bit (t % 32) of dword (t / 32) is set when thread t passed v == true.
//
// Each thread writes its own bit into its g_Scratch slot, then every run of 32 consecutive slots is
// OR-reduced into the first slot of the run. The first slot of each run is finally copied into the
// returned GroupBallot.
//
// The reduction always walks the strides 16, 8, 4, 2, 1 and guards the partner index against
// THREADING_BLOCK_SIZE, so group sizes that are not a multiple of 32 (or not a power of two) neither read
// past the end of g_Scratch nor lose bits. For power-of-two group sizes the executed passes are the same
// as with a start stride of min(THREADING_BLOCK_SIZE / 2, 16).
//
// Contains group-wide barriers, so it is meant to be called by every thread of the group.
// NOTE(review): assumes _THREADING_GROUP_BALLOT_DWORDS * 32 does not exceed THREADING_BLOCK_SIZE rounded up
// to a multiple of 32 - confirm against its definition.
GroupBallot Group::Ballot(bool v)
{
    // Bit position of this thread inside its 32-bit ballot word.
    const uint bitIndex = groupIndex % 32u;

    // Wait for earlier users of the scratch buffer before overwriting it.
    GroupMemoryBarrierWithGroupSync();

    g_Scratch[groupIndex] = ((uint)v) << bitIndex;

    GroupMemoryBarrierWithGroupSync();

    [unroll]
    for (uint stride = 16u; stride > 0u; stride >>= 1u)
    {
        // Constant once the loop is unrolled: strides that cannot pair any two threads are skipped
        // entirely, barrier included, so small groups do not pay for empty passes.
        if (stride < THREADING_BLOCK_SIZE)
        {
            // The second test keeps the partner slot inside g_Scratch for a partially filled last run.
            if ((bitIndex < stride) && (groupIndex + stride < THREADING_BLOCK_SIZE))
                g_Scratch[groupIndex] = g_Scratch[groupIndex] | g_Scratch[groupIndex + stride];

            GroupMemoryBarrierWithGroupSync();
        }
    }

    GroupBallot ballot = (GroupBallot)0;

    // Explicitly mark this loop as "unroll" to avoid warnings about assigning to an array reference
    [unroll]
    for (uint dwordIndex = 0; dwordIndex < _THREADING_GROUP_BALLOT_DWORDS; ++dwordIndex)
    {
        ballot.dwords[dwordIndex] = g_Scratch[dwordIndex * 32];
    }

    return ballot;
}
165
// Number of threads in the group that passed v == true: takes the group-wide ballot and counts its set bits.
uint Group::CountBits(bool v)
{
    GroupBallot votes = Ballot(v);
    return votes.CountBits();
}
170}
171#endif