tangled
alpha
login
or
join now
tjh.dev
/
nixpkgs
0
fork
atom
nixpkgs mirror (for testing)
github.com/NixOS/nixpkgs
nix
0
fork
atom
overview
issues
pulls
pipelines
rocmPackages.clr: backport bf16 compilation fix
Yaroslav Bolyukin
2 years ago
bbf286d9
24e3450a
+1015
3 changed files
expand all
collapse all
unified
split
pkgs
development
rocm-modules
5
clr
add-missing-operators.patch
default.nix
static-functions.patch
+979
pkgs/development/rocm-modules/5/clr/add-missing-operators.patch
reviewed
···
1
1
+
From 86bd518981b364c138f9901b28a529899d8654f3 Mon Sep 17 00:00:00 2001
2
2
+
From: Jatin Chaudhary <JatinJaikishan.Chaudhary@amd.com>
3
3
+
Date: Wed, 11 Oct 2023 23:19:29 +0100
4
4
+
Subject: [PATCH] SWDEV-367537 - Add missing operators to __hip_bfloat16
5
5
+
implementation
6
6
+
7
7
+
Add __host__ and __device__ to bunch of operator/function matching CUDA
8
8
+
Fix some bugs seen in __hisinf
9
9
+
10
10
+
Change-Id: I9e67e3e3eb2083b463158f3e250e5221c89b2896
11
11
+
---
12
12
+
hipamd/include/hip/amd_detail/amd_hip_bf16.h | 533 ++++++++++++++++---
13
13
+
1 file changed, 446 insertions(+), 87 deletions(-)
14
14
+
15
15
+
diff --git a/hipamd/include/hip/amd_detail/amd_hip_bf16.h b/hipamd/include/hip/amd_detail/amd_hip_bf16.h
16
16
+
index 757cb7ada..b15ea3b65 100644
17
17
+
--- a/hipamd/include/hip/amd_detail/amd_hip_bf16.h
18
18
+
+++ b/hipamd/include/hip/amd_detail/amd_hip_bf16.h
19
19
+
@@ -96,10 +96,20 @@
20
20
+
#if defined(__HIPCC_RTC__)
21
21
+
#define __HOST_DEVICE__ __device__
22
22
+
#else
23
23
+
+#include <algorithm>
24
24
+
#include <climits>
25
25
+
-#define __HOST_DEVICE__ __host__ __device__
26
26
+
+#include <cmath>
27
27
+
+#define __HOST_DEVICE__ __host__ __device__ inline
28
28
+
#endif
29
29
+
30
30
+
+#define HIPRT_ONE_BF16 __float2bfloat16(1.0f)
31
31
+
+#define HIPRT_ZERO_BF16 __float2bfloat16(0.0f)
32
32
+
+#define HIPRT_INF_BF16 __ushort_as_bfloat16((unsigned short)0x7F80U)
33
33
+
+#define HIPRT_MAX_NORMAL_BF16 __ushort_as_bfloat16((unsigned short)0x7F7FU)
34
34
+
+#define HIPRT_MIN_DENORM_BF16 __ushort_as_bfloat16((unsigned short)0x0001U)
35
35
+
+#define HIPRT_NAN_BF16 __ushort_as_bfloat16((unsigned short)0x7FFFU)
36
36
+
+#define HIPRT_NEG_ZERO_BF16 __ushort_as_bfloat16((unsigned short)0x8000U)
37
37
+
+
38
38
+
// Since we are using unsigned short to represent data in bfloat16, it can be of different sizes on
39
39
+
// different machines. These naive checks should prevent some undefined behavior on systems which
40
40
+
// have different sizes for basic types.
41
41
+
@@ -189,7 +199,7 @@ __HOST_DEVICE__ float2 __bfloat1622float2(const __hip_bfloat162 a) {
42
42
+
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
43
43
+
* \brief Moves bfloat16 value to bfloat162
44
44
+
*/
45
45
+
-__device__ __hip_bfloat162 __bfloat162bfloat162(const __hip_bfloat16 a) {
46
46
+
+__HOST_DEVICE__ __hip_bfloat162 __bfloat162bfloat162(const __hip_bfloat16 a) {
47
47
+
return __hip_bfloat162{a, a};
48
48
+
}
49
49
+
50
50
+
@@ -197,13 +207,13 @@ __device__ __hip_bfloat162 __bfloat162bfloat162(const __hip_bfloat16 a) {
51
51
+
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
52
52
+
* \brief Reinterprets bits in a __hip_bfloat16 as a signed short integer
53
53
+
*/
54
54
+
-__device__ short int __bfloat16_as_short(const __hip_bfloat16 h) { return (short)h.data; }
55
55
+
+__HOST_DEVICE__ short int __bfloat16_as_short(const __hip_bfloat16 h) { return (short)h.data; }
56
56
+
57
57
+
/**
58
58
+
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
59
59
+
* \brief Reinterprets bits in a __hip_bfloat16 as an unsigned signed short integer
60
60
+
*/
61
61
+
-__device__ unsigned short int __bfloat16_as_ushort(const __hip_bfloat16 h) { return h.data; }
62
62
+
+__HOST_DEVICE__ unsigned short int __bfloat16_as_ushort(const __hip_bfloat16 h) { return h.data; }
63
63
+
64
64
+
/**
65
65
+
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
66
66
+
@@ -225,7 +235,7 @@ __HOST_DEVICE__ __hip_bfloat162 __float22bfloat162_rn(const float2 a) {
67
67
+
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
68
68
+
* \brief Combine two __hip_bfloat16 to __hip_bfloat162
69
69
+
*/
70
70
+
-__device__ __hip_bfloat162 __halves2bfloat162(const __hip_bfloat16 a, const __hip_bfloat16 b) {
71
71
+
+__HOST_DEVICE__ __hip_bfloat162 __halves2bfloat162(const __hip_bfloat16 a, const __hip_bfloat16 b) {
72
72
+
return __hip_bfloat162{a, b};
73
73
+
}
74
74
+
75
75
+
@@ -233,13 +243,13 @@ __device__ __hip_bfloat162 __halves2bfloat162(const __hip_bfloat16 a, const __hi
76
76
+
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
77
77
+
* \brief Returns high 16 bits of __hip_bfloat162
78
78
+
*/
79
79
+
-__device__ __hip_bfloat16 __high2bfloat16(const __hip_bfloat162 a) { return a.y; }
80
80
+
+__HOST_DEVICE__ __hip_bfloat16 __high2bfloat16(const __hip_bfloat162 a) { return a.y; }
81
81
+
82
82
+
/**
83
83
+
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
84
84
+
* \brief Returns high 16 bits of __hip_bfloat162
85
85
+
*/
86
86
+
-__device__ __hip_bfloat162 __high2bfloat162(const __hip_bfloat162 a) {
87
87
+
+__HOST_DEVICE__ __hip_bfloat162 __high2bfloat162(const __hip_bfloat162 a) {
88
88
+
return __hip_bfloat162{a.y, a.y};
89
89
+
}
90
90
+
91
91
+
@@ -253,7 +263,8 @@ __HOST_DEVICE__ float __high2float(const __hip_bfloat162 a) { return __bfloat162
92
92
+
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
93
93
+
* \brief Extracts high 16 bits from each and combines them
94
94
+
*/
95
95
+
-__device__ __hip_bfloat162 __highs2bfloat162(const __hip_bfloat162 a, const __hip_bfloat162 b) {
96
96
+
+__HOST_DEVICE__ __hip_bfloat162 __highs2bfloat162(const __hip_bfloat162 a,
97
97
+
+ const __hip_bfloat162 b) {
98
98
+
return __hip_bfloat162{a.y, b.y};
99
99
+
}
100
100
+
101
101
+
@@ -261,13 +272,13 @@ __device__ __hip_bfloat162 __highs2bfloat162(const __hip_bfloat162 a, const __hi
102
102
+
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
103
103
+
* \brief Returns low 16 bits of __hip_bfloat162
104
104
+
*/
105
105
+
-__device__ __hip_bfloat16 __low2bfloat16(const __hip_bfloat162 a) { return a.x; }
106
106
+
+__HOST_DEVICE__ __hip_bfloat16 __low2bfloat16(const __hip_bfloat162 a) { return a.x; }
107
107
+
108
108
+
/**
109
109
+
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
110
110
+
* \brief Returns low 16 bits of __hip_bfloat162
111
111
+
*/
112
112
+
-__device__ __hip_bfloat162 __low2bfloat162(const __hip_bfloat162 a) {
113
113
+
+__HOST_DEVICE__ __hip_bfloat162 __low2bfloat162(const __hip_bfloat162 a) {
114
114
+
return __hip_bfloat162{a.x, a.x};
115
115
+
}
116
116
+
117
117
+
@@ -281,7 +292,7 @@ __HOST_DEVICE__ float __low2float(const __hip_bfloat162 a) { return __bfloat162f
118
118
+
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
119
119
+
* \brief Swaps both halves
120
120
+
*/
121
121
+
-__device__ __hip_bfloat162 __lowhigh2highlow(const __hip_bfloat162 a) {
122
122
+
+__HOST_DEVICE__ __hip_bfloat162 __lowhigh2highlow(const __hip_bfloat162 a) {
123
123
+
return __hip_bfloat162{a.y, a.x};
124
124
+
}
125
125
+
126
126
+
@@ -289,7 +300,7 @@ __device__ __hip_bfloat162 __lowhigh2highlow(const __hip_bfloat162 a) {
127
127
+
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
128
128
+
* \brief Extracts low 16 bits from each and combines them
129
129
+
*/
130
130
+
-__device__ __hip_bfloat162 __lows2bfloat162(const __hip_bfloat162 a, const __hip_bfloat162 b) {
131
131
+
+__HOST_DEVICE__ __hip_bfloat162 __lows2bfloat162(const __hip_bfloat162 a, const __hip_bfloat162 b) {
132
132
+
return __hip_bfloat162{a.x, b.x};
133
133
+
}
134
134
+
135
135
+
@@ -297,7 +308,7 @@ __device__ __hip_bfloat162 __lows2bfloat162(const __hip_bfloat162 a, const __hip
136
136
+
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
137
137
+
* \brief Reinterprets short int into a bfloat16
138
138
+
*/
139
139
+
-__device__ __hip_bfloat16 __short_as_bfloat16(const short int a) {
140
140
+
+__HOST_DEVICE__ __hip_bfloat16 __short_as_bfloat16(const short int a) {
141
141
+
return __hip_bfloat16{(unsigned short)a};
142
142
+
}
143
143
+
144
144
+
@@ -305,7 +316,7 @@ __device__ __hip_bfloat16 __short_as_bfloat16(const short int a) {
145
145
+
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
146
146
+
* \brief Reinterprets unsigned short int into a bfloat16
147
147
+
*/
148
148
+
-__device__ __hip_bfloat16 __ushort_as_bfloat16(const unsigned short int a) {
149
149
+
+__HOST_DEVICE__ __hip_bfloat16 __ushort_as_bfloat16(const unsigned short int a) {
150
150
+
return __hip_bfloat16{a};
151
151
+
}
152
152
+
153
153
+
@@ -314,7 +325,7 @@ __device__ __hip_bfloat16 __ushort_as_bfloat16(const unsigned short int a) {
154
154
+
* \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
155
155
+
* \brief Adds two bfloat16 values
156
156
+
*/
157
157
+
-__device__ __hip_bfloat16 __hadd(const __hip_bfloat16 a, const __hip_bfloat16 b) {
158
158
+
+__HOST_DEVICE__ __hip_bfloat16 __hadd(const __hip_bfloat16 a, const __hip_bfloat16 b) {
159
159
+
return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b));
160
160
+
}
161
161
+
162
162
+
@@ -322,7 +333,7 @@ __device__ __hip_bfloat16 __hadd(const __hip_bfloat16 a, const __hip_bfloat16 b)
163
163
+
* \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
164
164
+
* \brief Subtracts two bfloat16 values
165
165
+
*/
166
166
+
-__device__ __hip_bfloat16 __hsub(const __hip_bfloat16 a, const __hip_bfloat16 b) {
167
167
+
+__HOST_DEVICE__ __hip_bfloat16 __hsub(const __hip_bfloat16 a, const __hip_bfloat16 b) {
168
168
+
return __float2bfloat16(__bfloat162float(a) - __bfloat162float(b));
169
169
+
}
170
170
+
171
171
+
@@ -330,7 +341,7 @@ __device__ __hip_bfloat16 __hsub(const __hip_bfloat16 a, const __hip_bfloat16 b)
172
172
+
* \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
173
173
+
* \brief Divides two bfloat16 values
174
174
+
*/
175
175
+
-__device__ __hip_bfloat16 __hdiv(const __hip_bfloat16 a, const __hip_bfloat16 b) {
176
176
+
+__HOST_DEVICE__ __hip_bfloat16 __hdiv(const __hip_bfloat16 a, const __hip_bfloat16 b) {
177
177
+
return __float2bfloat16(__bfloat162float(a) / __bfloat162float(b));
178
178
+
}
179
179
+
180
180
+
@@ -348,7 +359,7 @@ __device__ __hip_bfloat16 __hfma(const __hip_bfloat16 a, const __hip_bfloat16 b,
181
181
+
* \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
182
182
+
* \brief Multiplies two bfloat16 values
183
183
+
*/
184
184
+
-__device__ __hip_bfloat16 __hmul(const __hip_bfloat16 a, const __hip_bfloat16 b) {
185
185
+
+__HOST_DEVICE__ __hip_bfloat16 __hmul(const __hip_bfloat16 a, const __hip_bfloat16 b) {
186
186
+
return __float2bfloat16(__bfloat162float(a) * __bfloat162float(b));
187
187
+
}
188
188
+
189
189
+
@@ -356,7 +367,7 @@ __device__ __hip_bfloat16 __hmul(const __hip_bfloat16 a, const __hip_bfloat16 b)
190
190
+
* \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
191
191
+
* \brief Negate a bfloat16 value
192
192
+
*/
193
193
+
-__device__ __hip_bfloat16 __hneg(const __hip_bfloat16 a) {
194
194
+
+__HOST_DEVICE__ __hip_bfloat16 __hneg(const __hip_bfloat16 a) {
195
195
+
auto ret = a;
196
196
+
ret.data ^= 0x8000;
197
197
+
return ret;
198
198
+
@@ -366,7 +377,7 @@ __device__ __hip_bfloat16 __hneg(const __hip_bfloat16 a) {
199
199
+
* \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
200
200
+
* \brief Returns absolute of a bfloat16
201
201
+
*/
202
202
+
-__device__ __hip_bfloat16 __habs(const __hip_bfloat16 a) {
203
203
+
+__HOST_DEVICE__ __hip_bfloat16 __habs(const __hip_bfloat16 a) {
204
204
+
auto ret = a;
205
205
+
ret.data &= 0x7FFF;
206
206
+
return ret;
207
207
+
@@ -376,7 +387,7 @@ __device__ __hip_bfloat16 __habs(const __hip_bfloat16 a) {
208
208
+
* \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
209
209
+
* \brief Divides bfloat162 values
210
210
+
*/
211
211
+
-__device__ __hip_bfloat162 __h2div(const __hip_bfloat162 a, const __hip_bfloat162 b) {
212
212
+
+__HOST_DEVICE__ __hip_bfloat162 __h2div(const __hip_bfloat162 a, const __hip_bfloat162 b) {
213
213
+
return __hip_bfloat162{__float2bfloat16(__bfloat162float(a.x) / __bfloat162float(b.x)),
214
214
+
__float2bfloat16(__bfloat162float(a.y) / __bfloat162float(b.y))};
215
215
+
}
216
216
+
@@ -385,7 +396,7 @@ __device__ __hip_bfloat162 __h2div(const __hip_bfloat162 a, const __hip_bfloat16
217
217
+
* \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
218
218
+
* \brief Returns absolute of a bfloat162
219
219
+
*/
220
220
+
-__device__ __hip_bfloat162 __habs2(const __hip_bfloat162 a) {
221
221
+
+__HOST_DEVICE__ __hip_bfloat162 __habs2(const __hip_bfloat162 a) {
222
222
+
return __hip_bfloat162{__habs(a.x), __habs(a.y)};
223
223
+
}
224
224
+
225
225
+
@@ -393,7 +404,7 @@ __device__ __hip_bfloat162 __habs2(const __hip_bfloat162 a) {
226
226
+
* \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
227
227
+
* \brief Adds two bfloat162 values
228
228
+
*/
229
229
+
-__device__ __hip_bfloat162 __hadd2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
230
230
+
+__HOST_DEVICE__ __hip_bfloat162 __hadd2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
231
231
+
return __hip_bfloat162{__hadd(a.x, b.x), __hadd(a.y, b.y)};
232
232
+
}
233
233
+
234
234
+
@@ -410,7 +421,7 @@ __device__ __hip_bfloat162 __hfma2(const __hip_bfloat162 a, const __hip_bfloat16
235
235
+
* \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
236
236
+
* \brief Multiplies two bfloat162 values
237
237
+
*/
238
238
+
-__device__ __hip_bfloat162 __hmul2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
239
239
+
+__HOST_DEVICE__ __hip_bfloat162 __hmul2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
240
240
+
return __hip_bfloat162{__hmul(a.x, b.x), __hmul(a.y, b.y)};
241
241
+
}
242
242
+
243
243
+
@@ -418,7 +429,7 @@ __device__ __hip_bfloat162 __hmul2(const __hip_bfloat162 a, const __hip_bfloat16
244
244
+
* \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
245
245
+
* \brief Converts a bfloat162 into negative
246
246
+
*/
247
247
+
-__device__ __hip_bfloat162 __hneg2(const __hip_bfloat162 a) {
248
248
+
+__HOST_DEVICE__ __hip_bfloat162 __hneg2(const __hip_bfloat162 a) {
249
249
+
return __hip_bfloat162{__hneg(a.x), __hneg(a.y)};
250
250
+
}
251
251
+
252
252
+
@@ -426,15 +437,251 @@ __device__ __hip_bfloat162 __hneg2(const __hip_bfloat162 a) {
253
253
+
* \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
254
254
+
* \brief Subtracts two bfloat162 values
255
255
+
*/
256
256
+
-__device__ __hip_bfloat162 __hsub2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
257
257
+
+__HOST_DEVICE__ __hip_bfloat162 __hsub2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
258
258
+
return __hip_bfloat162{__hsub(a.x, b.x), __hsub(a.y, b.y)};
259
259
+
}
260
260
+
261
261
+
+/**
262
262
+
+ * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
263
263
+
+ * \brief Operator to multiply two __hip_bfloat16 numbers
264
264
+
+ */
265
265
+
+__HOST_DEVICE__ __hip_bfloat16 operator*(const __hip_bfloat16& l, const __hip_bfloat16& r) {
266
266
+
+ return __hmul(l, r);
267
267
+
+}
268
268
+
+
269
269
+
+/**
270
270
+
+ * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
271
271
+
+ * \brief Operator to multiply-assign two __hip_bfloat16 numbers
272
272
+
+ */
273
273
+
+__HOST_DEVICE__ __hip_bfloat16 operator*=(__hip_bfloat16& l, const __hip_bfloat16& r) {
274
274
+
+ l = __hmul(l, r);
275
275
+
+ return l;
276
276
+
+}
277
277
+
+
278
278
+
+/**
279
279
+
+ * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
280
280
+
+ * \brief Operator to unary+ on a __hip_bfloat16 number
281
281
+
+ */
282
282
+
+__HOST_DEVICE__ __hip_bfloat16 operator+(const __hip_bfloat16& l) { return l; }
283
283
+
+
284
284
+
+/**
285
285
+
+ * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
286
286
+
+ * \brief Operator to add two __hip_bfloat16 numbers
287
287
+
+ */
288
288
+
+__HOST_DEVICE__ __hip_bfloat16 operator+(const __hip_bfloat16& l, const __hip_bfloat16& r) {
289
289
+
+ return __hadd(l, r);
290
290
+
+}
291
291
+
+
292
292
+
+/**
293
293
+
+ * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
294
294
+
+ * \brief Operator to negate a __hip_bfloat16 number
295
295
+
+ */
296
296
+
+__HOST_DEVICE__ __hip_bfloat16 operator-(const __hip_bfloat16& l) { return __hneg(l); }
297
297
+
+
298
298
+
+/**
299
299
+
+ * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
300
300
+
+ * \brief Operator to subtract two __hip_bfloat16 numbers
301
301
+
+ */
302
302
+
+__HOST_DEVICE__ __hip_bfloat16 operator-(const __hip_bfloat16& l, const __hip_bfloat16& r) {
303
303
+
+ return __hsub(l, r);
304
304
+
+}
305
305
+
+
306
306
+
+/**
307
307
+
+ * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
308
308
+
+ * \brief Operator to post increment a __hip_bfloat16 number
309
309
+
+ */
310
310
+
+__HOST_DEVICE__ __hip_bfloat16 operator++(__hip_bfloat16& l, const int) {
311
311
+
+ auto ret = l;
312
312
+
+ l = __hadd(l, HIPRT_ONE_BF16);
313
313
+
+ return ret;
314
314
+
+}
315
315
+
+
316
316
+
+/**
317
317
+
+ * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
318
318
+
+ * \brief Operator to pre increment a __hip_bfloat16 number
319
319
+
+ */
320
320
+
+__HOST_DEVICE__ __hip_bfloat16& operator++(__hip_bfloat16& l) {
321
321
+
+ l = __hadd(l, HIPRT_ONE_BF16);
322
322
+
+ return l;
323
323
+
+}
324
324
+
+
325
325
+
+/**
326
326
+
+ * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
327
327
+
+ * \brief Operator to post decrement a __hip_bfloat16 number
328
328
+
+ */
329
329
+
+__HOST_DEVICE__ __hip_bfloat16 operator--(__hip_bfloat16& l, const int) {
330
330
+
+ auto ret = l;
331
331
+
+ l = __hsub(l, HIPRT_ONE_BF16);
332
332
+
+ return ret;
333
333
+
+}
334
334
+
+
335
335
+
+/**
336
336
+
+ * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
337
337
+
+ * \brief Operator to pre decrement a __hip_bfloat16 number
338
338
+
+ */
339
339
+
+__HOST_DEVICE__ __hip_bfloat16& operator--(__hip_bfloat16& l) {
340
340
+
+ l = __hsub(l, HIPRT_ONE_BF16);
341
341
+
+ return l;
342
342
+
+}
343
343
+
+
344
344
+
+/**
345
345
+
+ * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
346
346
+
+ * \brief Operator to add-assign two __hip_bfloat16 numbers
347
347
+
+ */
348
348
+
+__HOST_DEVICE__ __hip_bfloat16& operator+=(__hip_bfloat16& l, const __hip_bfloat16& r) {
349
349
+
+ l = l + r;
350
350
+
+ return l;
351
351
+
+}
352
352
+
+
353
353
+
+/**
354
354
+
+ * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
355
355
+
+ * \brief Operator to subtract-assign two __hip_bfloat16 numbers
356
356
+
+ */
357
357
+
+__HOST_DEVICE__ __hip_bfloat16& operator-=(__hip_bfloat16& l, const __hip_bfloat16& r) {
358
358
+
+ l = l - r;
359
359
+
+ return l;
360
360
+
+}
361
361
+
+
362
362
+
+/**
363
363
+
+ * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
364
364
+
+ * \brief Operator to divide two __hip_bfloat16 numbers
365
365
+
+ */
366
366
+
+__HOST_DEVICE__ __hip_bfloat16 operator/(const __hip_bfloat16& l, const __hip_bfloat16& r) {
367
367
+
+ return __hdiv(l, r);
368
368
+
+}
369
369
+
+
370
370
+
+/**
371
371
+
+ * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
372
372
+
+ * \brief Operator to divide-assign two __hip_bfloat16 numbers
373
373
+
+ */
374
374
+
+__HOST_DEVICE__ __hip_bfloat16& operator/=(__hip_bfloat16& l, const __hip_bfloat16& r) {
375
375
+
+ l = l / r;
376
376
+
+ return l;
377
377
+
+}
378
378
+
+
379
379
+
+/**
380
380
+
+ * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
381
381
+
+ * \brief Operator to multiply two __hip_bfloat162 numbers
382
382
+
+ */
383
383
+
+__HOST_DEVICE__ __hip_bfloat162 operator*(const __hip_bfloat162& l, const __hip_bfloat162& r) {
384
384
+
+ return __hmul2(l, r);
385
385
+
+}
386
386
+
+
387
387
+
+/**
388
388
+
+ * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
389
389
+
+ * \brief Operator to multiply-assign two __hip_bfloat162 numbers
390
390
+
+ */
391
391
+
+__HOST_DEVICE__ __hip_bfloat162 operator*=(__hip_bfloat162& l, const __hip_bfloat162& r) {
392
392
+
+ l = __hmul2(l, r);
393
393
+
+ return l;
394
394
+
+}
395
395
+
+
396
396
+
+/**
397
397
+
+ * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
398
398
+
+ * \brief Operator to unary+ on a __hip_bfloat162 number
399
399
+
+ */
400
400
+
+__HOST_DEVICE__ __hip_bfloat162 operator+(const __hip_bfloat162& l) { return l; }
401
401
+
+
402
402
+
+/**
403
403
+
+ * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
404
404
+
+ * \brief Operator to add two __hip_bfloat162 numbers
405
405
+
+ */
406
406
+
+__HOST_DEVICE__ __hip_bfloat162 operator+(const __hip_bfloat162& l, const __hip_bfloat162& r) {
407
407
+
+ return __hadd2(l, r);
408
408
+
+}
409
409
+
+
410
410
+
+/**
411
411
+
+ * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
412
412
+
+ * \brief Operator to negate a __hip_bfloat162 number
413
413
+
+ */
414
414
+
+__HOST_DEVICE__ __hip_bfloat162 operator-(const __hip_bfloat162& l) { return __hneg2(l); }
415
415
+
+
416
416
+
+/**
417
417
+
+ * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
418
418
+
+ * \brief Operator to subtract two __hip_bfloat162 numbers
419
419
+
+ */
420
420
+
+__HOST_DEVICE__ __hip_bfloat162 operator-(const __hip_bfloat162& l, const __hip_bfloat162& r) {
421
421
+
+ return __hsub2(l, r);
422
422
+
+}
423
423
+
+
424
424
+
+/**
425
425
+
+ * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
426
426
+
+ * \brief Operator to post increment a __hip_bfloat162 number
427
427
+
+ */
428
428
+
+__HOST_DEVICE__ __hip_bfloat162 operator++(__hip_bfloat162& l, const int) {
429
429
+
+ auto ret = l;
430
430
+
+ l = __hadd2(l, {HIPRT_ONE_BF16, HIPRT_ONE_BF16});
431
431
+
+ return ret;
432
432
+
+}
433
433
+
+
434
434
+
+/**
435
435
+
+ * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
436
436
+
+ * \brief Operator to pre increment a __hip_bfloat162 number
437
437
+
+ */
438
438
+
+__HOST_DEVICE__ __hip_bfloat162& operator++(__hip_bfloat162& l) {
439
439
+
+ l = __hadd2(l, {HIPRT_ONE_BF16, HIPRT_ONE_BF16});
440
440
+
+ return l;
441
441
+
+}
442
442
+
+
443
443
+
+/**
444
444
+
+ * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
445
445
+
+ * \brief Operator to post decrement a __hip_bfloat162 number
446
446
+
+ */
447
447
+
+__HOST_DEVICE__ __hip_bfloat162 operator--(__hip_bfloat162& l, const int) {
448
448
+
+ auto ret = l;
449
449
+
+ l = __hsub2(l, {HIPRT_ONE_BF16, HIPRT_ONE_BF16});
450
450
+
+ return ret;
451
451
+
+}
452
452
+
+
453
453
+
+/**
454
454
+
+ * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
455
455
+
+ * \brief Operator to pre decrement a __hip_bfloat162 number
456
456
+
+ */
457
457
+
+__HOST_DEVICE__ __hip_bfloat162& operator--(__hip_bfloat162& l) {
458
458
+
+ l = __hsub2(l, {HIPRT_ONE_BF16, HIPRT_ONE_BF16});
459
459
+
+ return l;
460
460
+
+}
461
461
+
+
462
462
+
+/**
463
463
+
+ * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
464
464
+
+ * \brief Operator to add-assign two __hip_bfloat162 numbers
465
465
+
+ */
466
466
+
+__HOST_DEVICE__ __hip_bfloat162& operator+=(__hip_bfloat162& l, const __hip_bfloat162& r) {
467
467
+
+ l = l + r;
468
468
+
+ return l;
469
469
+
+}
470
470
+
+
471
471
+
+/**
472
472
+
+ * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
473
473
+
+ * \brief Operator to subtract-assign two __hip_bfloat162 numbers
474
474
+
+ */
475
475
+
+__HOST_DEVICE__ __hip_bfloat162& operator-=(__hip_bfloat162& l, const __hip_bfloat162& r) {
476
476
+
+ l = l - r;
477
477
+
+ return l;
478
478
+
+}
479
479
+
+
480
480
+
+/**
481
481
+
+ * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
482
482
+
+ * \brief Operator to divide two __hip_bfloat162 numbers
483
483
+
+ */
484
484
+
+__HOST_DEVICE__ __hip_bfloat162 operator/(const __hip_bfloat162& l, const __hip_bfloat162& r) {
485
485
+
+ return __h2div(l, r);
486
486
+
+}
487
487
+
+
488
488
+
+/**
489
489
+
+ * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
490
490
+
+ * \brief Operator to divide-assign two __hip_bfloat162 numbers
491
491
+
+ */
492
492
+
+__HOST_DEVICE__ __hip_bfloat162& operator/=(__hip_bfloat162& l, const __hip_bfloat162& r) {
493
493
+
+ l = l / r;
494
494
+
+ return l;
495
495
+
+}
496
496
+
+
497
497
+
/**
498
498
+
* \ingroup HIP_INTRINSIC_BFLOAT16_COMP
499
499
+
* \brief Compare two bfloat162 values
500
500
+
*/
501
501
+
-__device__ bool __heq(const __hip_bfloat16 a, const __hip_bfloat16 b) {
502
502
+
+__HOST_DEVICE__ bool __heq(const __hip_bfloat16 a, const __hip_bfloat16 b) {
503
503
+
return __bfloat162float(a) == __bfloat162float(b);
504
504
+
}
505
505
+
506
506
+
@@ -442,7 +689,7 @@ __device__ bool __heq(const __hip_bfloat16 a, const __hip_bfloat16 b) {
507
507
+
* \ingroup HIP_INTRINSIC_BFLOAT16_COMP
508
508
+
* \brief Compare two bfloat162 values - unordered equal
509
509
+
*/
510
510
+
-__device__ bool __hequ(const __hip_bfloat16 a, const __hip_bfloat16 b) {
511
511
+
+__HOST_DEVICE__ bool __hequ(const __hip_bfloat16 a, const __hip_bfloat16 b) {
512
512
+
return !(__bfloat162float(a) < __bfloat162float(b)) &&
513
513
+
!(__bfloat162float(a) > __bfloat162float(b));
514
514
+
}
515
515
+
@@ -451,7 +698,7 @@ __device__ bool __hequ(const __hip_bfloat16 a, const __hip_bfloat16 b) {
516
516
+
* \ingroup HIP_INTRINSIC_BFLOAT16_COMP
517
517
+
* \brief Compare two bfloat162 values - greater than
518
518
+
*/
519
519
+
-__device__ bool __hgt(const __hip_bfloat16 a, const __hip_bfloat16 b) {
520
520
+
+__HOST_DEVICE__ bool __hgt(const __hip_bfloat16 a, const __hip_bfloat16 b) {
521
521
+
return __bfloat162float(a) > __bfloat162float(b);
522
522
+
}
523
523
+
524
524
+
@@ -459,7 +706,7 @@ __device__ bool __hgt(const __hip_bfloat16 a, const __hip_bfloat16 b) {
525
525
+
* \ingroup HIP_INTRINSIC_BFLOAT16_COMP
526
526
+
* \brief Compare two bfloat162 values - unordered greater than
527
527
+
*/
528
528
+
-__device__ bool __hgtu(const __hip_bfloat16 a, const __hip_bfloat16 b) {
529
529
+
+__HOST_DEVICE__ bool __hgtu(const __hip_bfloat16 a, const __hip_bfloat16 b) {
530
530
+
return !(__bfloat162float(a) <= __bfloat162float(b));
531
531
+
}
532
532
+
533
533
+
@@ -467,7 +714,7 @@ __device__ bool __hgtu(const __hip_bfloat16 a, const __hip_bfloat16 b) {
534
534
+
* \ingroup HIP_INTRINSIC_BFLOAT16_COMP
535
535
+
* \brief Compare two bfloat162 values - greater than equal
536
536
+
*/
537
537
+
-__device__ bool __hge(const __hip_bfloat16 a, const __hip_bfloat16 b) {
538
538
+
+__HOST_DEVICE__ bool __hge(const __hip_bfloat16 a, const __hip_bfloat16 b) {
539
539
+
return __bfloat162float(a) >= __bfloat162float(b);
540
540
+
}
541
541
+
542
542
+
@@ -475,7 +722,7 @@ __device__ bool __hge(const __hip_bfloat16 a, const __hip_bfloat16 b) {
543
543
+
* \ingroup HIP_INTRINSIC_BFLOAT16_COMP
544
544
+
* \brief Compare two bfloat162 values - unordered greater than equal
545
545
+
*/
546
546
+
-__device__ bool __hgeu(const __hip_bfloat16 a, const __hip_bfloat16 b) {
547
547
+
+__HOST_DEVICE__ bool __hgeu(const __hip_bfloat16 a, const __hip_bfloat16 b) {
548
548
+
return !(__bfloat162float(a) < __bfloat162float(b));
549
549
+
}
550
550
+
551
551
+
@@ -483,7 +730,7 @@ __device__ bool __hgeu(const __hip_bfloat16 a, const __hip_bfloat16 b) {
552
552
+
* \ingroup HIP_INTRINSIC_BFLOAT16_COMP
553
553
+
* \brief Compare two bfloat162 values - not equal
554
554
+
*/
555
555
+
-__device__ bool __hne(const __hip_bfloat16 a, const __hip_bfloat16 b) {
556
556
+
+__HOST_DEVICE__ bool __hne(const __hip_bfloat16 a, const __hip_bfloat16 b) {
557
557
+
return __bfloat162float(a) != __bfloat162float(b);
558
558
+
}
559
559
+
560
560
+
@@ -491,7 +738,7 @@ __device__ bool __hne(const __hip_bfloat16 a, const __hip_bfloat16 b) {
561
561
+
* \ingroup HIP_INTRINSIC_BFLOAT16_COMP
562
562
+
* \brief Compare two bfloat162 values - unordered not equal
563
563
+
*/
564
564
+
-__device__ bool __hneu(const __hip_bfloat16 a, const __hip_bfloat16 b) {
565
565
+
+__HOST_DEVICE__ bool __hneu(const __hip_bfloat16 a, const __hip_bfloat16 b) {
566
566
+
return !(__bfloat162float(a) == __bfloat162float(b));
567
567
+
}
568
568
+
569
569
+
@@ -499,23 +746,31 @@ __device__ bool __hneu(const __hip_bfloat16 a, const __hip_bfloat16 b) {
570
570
+
* \ingroup HIP_INTRINSIC_BFLOAT16_COMP
571
571
+
* \brief Compare two bfloat162 values - return max
572
572
+
*/
573
573
+
-__device__ __hip_bfloat16 __hmax(const __hip_bfloat16 a, const __hip_bfloat16 b) {
574
574
+
+__HOST_DEVICE__ __hip_bfloat16 __hmax(const __hip_bfloat16 a, const __hip_bfloat16 b) {
575
575
+
+#if __HIP_DEVICE_COMPILE__
576
576
+
return __float2bfloat16(__ocml_fmax_f32(__bfloat162float(a), __bfloat162float(b)));
577
577
+
+#else
578
578
+
+ return __float2bfloat16(std::max(__bfloat162float(a), __bfloat162float(b)));
579
579
+
+#endif
580
580
+
}
581
581
+
582
582
+
/**
583
583
+
* \ingroup HIP_INTRINSIC_BFLOAT16_COMP
584
584
+
* \brief Compare two bfloat162 values - return min
585
585
+
*/
586
586
+
-__device__ __hip_bfloat16 __hmin(const __hip_bfloat16 a, const __hip_bfloat16 b) {
587
587
+
+__HOST_DEVICE__ __hip_bfloat16 __hmin(const __hip_bfloat16 a, const __hip_bfloat16 b) {
588
588
+
+#if __HIP_DEVICE_COMPILE__
589
589
+
return __float2bfloat16(__ocml_fmin_f32(__bfloat162float(a), __bfloat162float(b)));
590
590
+
+#else
591
591
+
+ return __float2bfloat16(std::min(__bfloat162float(a), __bfloat162float(b)));
592
592
+
+#endif
593
593
+
}
594
594
+
595
595
+
/**
596
596
+
* \ingroup HIP_INTRINSIC_BFLOAT16_COMP
597
597
+
* \brief Compare two bfloat162 values - less than operator
598
598
+
*/
599
599
+
-__device__ bool __hlt(const __hip_bfloat16 a, const __hip_bfloat16 b) {
600
600
+
+__HOST_DEVICE__ bool __hlt(const __hip_bfloat16 a, const __hip_bfloat16 b) {
601
601
+
return __bfloat162float(a) < __bfloat162float(b);
602
602
+
}
603
603
+
604
604
+
@@ -523,15 +778,15 @@ __device__ bool __hlt(const __hip_bfloat16 a, const __hip_bfloat16 b) {
605
605
+
* \ingroup HIP_INTRINSIC_BFLOAT16_COMP
606
606
+
* \brief Compare two bfloat162 values - unordered less than
607
607
+
*/
608
608
+
-__device__ bool __hltu(const __hip_bfloat16 a, const __hip_bfloat16 b) {
609
609
+
+__HOST_DEVICE__ bool __hltu(const __hip_bfloat16 a, const __hip_bfloat16 b) {
610
610
+
return !(__bfloat162float(a) >= __bfloat162float(b));
611
611
+
}
612
612
+
613
613
+
/**
614
614
+
* \ingroup HIP_INTRINSIC_BFLOAT16_COMP
615
615
+
- * \brief Compare two bfloat162 values - less than
616
616
+
+ * \brief Compare two bfloat162 values - less than equal
617
617
+
*/
618
618
+
-__device__ bool __hle(const __hip_bfloat16 a, const __hip_bfloat16 b) {
619
619
+
+__HOST_DEVICE__ bool __hle(const __hip_bfloat16 a, const __hip_bfloat16 b) {
620
620
+
return __bfloat162float(a) <= __bfloat162float(b);
621
621
+
}
622
622
+
623
623
+
@@ -539,7 +794,7 @@ __device__ bool __hle(const __hip_bfloat16 a, const __hip_bfloat16 b) {
624
624
+
* \ingroup HIP_INTRINSIC_BFLOAT16_COMP
625
625
+
* \brief Compare two bfloat162 values - unordered less than equal
626
626
+
*/
627
627
+
-__device__ bool __hleu(const __hip_bfloat16 a, const __hip_bfloat16 b) {
628
628
+
+__HOST_DEVICE__ bool __hleu(const __hip_bfloat16 a, const __hip_bfloat16 b) {
629
629
+
return !(__bfloat162float(a) > __bfloat162float(b));
630
630
+
}
631
631
+
632
632
+
@@ -547,19 +802,33 @@ __device__ bool __hleu(const __hip_bfloat16 a, const __hip_bfloat16 b) {
633
633
+
* \ingroup HIP_INTRINSIC_BFLOAT16_COMP
634
634
+
* \brief Checks if number is inf
635
635
+
*/
636
636
+
-__device__ int __hisinf(const __hip_bfloat16 a) { return __ocml_isinf_f32(__bfloat162float(a)); }
637
637
+
+__HOST_DEVICE__ int __hisinf(const __hip_bfloat16 a) {
638
638
+
+ unsigned short sign = a.data & 0x8000U;
639
639
+
+#if __HIP_DEVICE_COMPILE__
640
640
+
+ int res = __ocml_isinf_f32(__bfloat162float(a));
641
641
+
+#else
642
642
+
+ int res = std::isinf(__bfloat162float(a)) ? 1 : 0;
643
643
+
+#endif
644
644
+
+ return (res == 0) ? res : ((sign != 0U) ? -res : res);
645
645
+
+}
646
646
+
647
647
+
/**
648
648
+
* \ingroup HIP_INTRINSIC_BFLOAT16_COMP
649
649
+
* \brief Checks if number is nan
650
650
+
*/
651
651
+
-__device__ bool __hisnan(const __hip_bfloat16 a) { return __ocml_isnan_f32(__bfloat162float(a)); }
652
652
+
+__HOST_DEVICE__ bool __hisnan(const __hip_bfloat16 a) {
653
653
+
+#if __HIP_DEVICE_COMPILE__
654
654
+
+ return __ocml_isnan_f32(__bfloat162float(a));
655
655
+
+#else
656
656
+
+ return std::isnan(__bfloat162float(a));
657
657
+
+#endif
658
658
+
+}
659
659
+
660
660
+
/**
661
661
+
* \ingroup HIP_INTRINSIC_BFLOAT162_COMP
662
662
+
* \brief Checks if two numbers are equal
663
663
+
*/
664
664
+
-__device__ bool __hbeq2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
665
665
+
+__HOST_DEVICE__ bool __hbeq2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
666
666
+
return __heq(a.x, b.x) && __heq(a.y, b.y);
667
667
+
}
668
668
+
669
669
+
@@ -567,7 +836,7 @@ __device__ bool __hbeq2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
670
670
+
* \ingroup HIP_INTRINSIC_BFLOAT162_COMP
671
671
+
* \brief Checks if two numbers are equal - unordered
672
672
+
*/
673
673
+
-__device__ bool __hbequ2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
674
674
+
+__HOST_DEVICE__ bool __hbequ2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
675
675
+
return __hequ(a.x, b.x) && __hequ(a.y, b.y);
676
676
+
}
677
677
+
678
678
+
@@ -575,7 +844,7 @@ __device__ bool __hbequ2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
679
679
+
* \ingroup HIP_INTRINSIC_BFLOAT162_COMP
680
680
+
* \brief Check for a >= b
681
681
+
*/
682
682
+
-__device__ bool __hbge2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
683
683
+
+__HOST_DEVICE__ bool __hbge2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
684
684
+
return __hge(a.x, b.x) && __hge(a.y, b.y);
685
685
+
}
686
686
+
687
687
+
@@ -583,7 +852,7 @@ __device__ bool __hbge2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
688
688
+
* \ingroup HIP_INTRINSIC_BFLOAT162_COMP
689
689
+
* \brief Check for a >= b - unordered
690
690
+
*/
691
691
+
-__device__ bool __hbgeu2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
692
692
+
+__HOST_DEVICE__ bool __hbgeu2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
693
693
+
return __hgeu(a.x, b.x) && __hgeu(a.y, b.y);
694
694
+
}
695
695
+
696
696
+
@@ -591,7 +860,7 @@ __device__ bool __hbgeu2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
697
697
+
* \ingroup HIP_INTRINSIC_BFLOAT162_COMP
698
698
+
* \brief Check for a > b
699
699
+
*/
700
700
+
-__device__ bool __hbgt2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
701
701
+
+__HOST_DEVICE__ bool __hbgt2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
702
702
+
return __hgt(a.x, b.x) && __hgt(a.y, b.y);
703
703
+
}
704
704
+
705
705
+
@@ -599,7 +868,7 @@ __device__ bool __hbgt2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
706
706
+
* \ingroup HIP_INTRINSIC_BFLOAT162_COMP
707
707
+
* \brief Check for a > b - unordered
708
708
+
*/
709
709
+
-__device__ bool __hbgtu2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
710
710
+
+__HOST_DEVICE__ bool __hbgtu2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
711
711
+
return __hgtu(a.x, b.x) && __hgtu(a.y, b.y);
712
712
+
}
713
713
+
714
714
+
@@ -607,7 +876,7 @@ __device__ bool __hbgtu2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
715
715
+
* \ingroup HIP_INTRINSIC_BFLOAT162_COMP
716
716
+
* \brief Check for a <= b
717
717
+
*/
718
718
+
-__device__ bool __hble2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
719
719
+
+__HOST_DEVICE__ bool __hble2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
720
720
+
return __hle(a.x, b.x) && __hle(a.y, b.y);
721
721
+
}
722
722
+
723
723
+
@@ -615,7 +884,7 @@ __device__ bool __hble2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
724
724
+
* \ingroup HIP_INTRINSIC_BFLOAT162_COMP
725
725
+
* \brief Check for a <= b - unordered
726
726
+
*/
727
727
+
-__device__ bool __hbleu2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
728
728
+
+__HOST_DEVICE__ bool __hbleu2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
729
729
+
return __hleu(a.x, b.x) && __hleu(a.y, b.y);
730
730
+
}
731
731
+
732
732
+
@@ -623,7 +892,7 @@ __device__ bool __hbleu2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
733
733
+
* \ingroup HIP_INTRINSIC_BFLOAT162_COMP
734
734
+
* \brief Check for a < b
735
735
+
*/
736
736
+
-__device__ bool __hblt2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
737
737
+
+__HOST_DEVICE__ bool __hblt2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
738
738
+
return __hlt(a.x, b.x) && __hlt(a.y, b.y);
739
739
+
}
740
740
+
741
741
+
@@ -631,7 +900,7 @@ __device__ bool __hblt2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
742
742
+
* \ingroup HIP_INTRINSIC_BFLOAT162_COMP
743
743
+
* \brief Check for a < b - unordered
744
744
+
*/
745
745
+
-__device__ bool __hbltu2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
746
746
+
+__HOST_DEVICE__ bool __hbltu2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
747
747
+
return __hltu(a.x, b.x) && __hltu(a.y, b.y);
748
748
+
}
749
749
+
750
750
+
@@ -639,7 +908,7 @@ __device__ bool __hbltu2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
751
751
+
* \ingroup HIP_INTRINSIC_BFLOAT162_COMP
752
752
+
* \brief Check for a != b
753
753
+
*/
754
754
+
-__device__ bool __hbne2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
755
755
+
+__HOST_DEVICE__ bool __hbne2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
756
756
+
return __hne(a.x, b.x) && __hne(a.y, b.y);
757
757
+
}
758
758
+
759
759
+
@@ -647,7 +916,7 @@ __device__ bool __hbne2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
760
760
+
* \ingroup HIP_INTRINSIC_BFLOAT162_COMP
761
761
+
* \brief Check for a != b
762
762
+
*/
763
763
+
-__device__ bool __hbneu2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
764
764
+
+__HOST_DEVICE__ bool __hbneu2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
765
765
+
return __hneu(a.x, b.x) && __hneu(a.y, b.y);
766
766
+
}
767
767
+
768
768
+
@@ -655,84 +924,175 @@ __device__ bool __hbneu2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
769
769
+
* \ingroup HIP_INTRINSIC_BFLOAT162_COMP
770
770
+
* \brief Check for a != b, returns 1.0 if equal, otherwise 0.0
771
771
+
*/
772
772
+
-__device__ __hip_bfloat162 __heq2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
773
773
+
- return __hip_bfloat162{{__heq(a.x, b.x) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)},
774
774
+
- {__heq(a.y, b.y) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)}};
775
775
+
+__HOST_DEVICE__ __hip_bfloat162 __heq2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
776
776
+
+ return __hip_bfloat162{{__heq(a.x, b.x) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16},
777
777
+
+ {__heq(a.y, b.y) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}};
778
778
+
}
779
779
+
780
780
+
/**
781
781
+
* \ingroup HIP_INTRINSIC_BFLOAT162_COMP
782
782
+
* \brief Check for a >= b, returns 1.0 if greater than equal, otherwise 0.0
783
783
+
*/
784
784
+
-__device__ __hip_bfloat162 __hge2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
785
785
+
- return __hip_bfloat162{{__hge(a.x, b.x) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)},
786
786
+
- {__hge(a.y, b.y) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)}};
787
787
+
+__HOST_DEVICE__ __hip_bfloat162 __hge2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
788
788
+
+ return __hip_bfloat162{{__hge(a.x, b.x) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16},
789
789
+
+ {__hge(a.y, b.y) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}};
790
790
+
}
791
791
+
792
792
+
/**
793
793
+
* \ingroup HIP_INTRINSIC_BFLOAT162_COMP
794
794
+
* \brief Check for a > b, returns 1.0 if greater than equal, otherwise 0.0
795
795
+
*/
796
796
+
-__device__ __hip_bfloat162 __hgt2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
797
797
+
- return __hip_bfloat162{{__hgt(a.x, b.x) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)},
798
798
+
- {__hgt(a.y, b.y) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)}};
799
799
+
+__HOST_DEVICE__ __hip_bfloat162 __hgt2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
800
800
+
+ return __hip_bfloat162{{__hgt(a.x, b.x) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16},
801
801
+
+ {__hgt(a.y, b.y) ? HIPRT_ONE_BF16 : HIPRT_ONE_BF16}};
802
802
+
}
803
803
+
804
804
+
/**
805
805
+
* \ingroup HIP_INTRINSIC_BFLOAT162_COMP
806
806
+
* \brief Check for a is NaN, returns 1.0 if NaN, otherwise 0.0
807
807
+
*/
808
808
+
-__device__ __hip_bfloat162 __hisnan2(const __hip_bfloat162 a) {
809
809
+
- return __hip_bfloat162{
810
810
+
- {__ocml_isnan_f32(__bfloat162float(a.x)) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)},
811
811
+
- {__ocml_isnan_f32(__bfloat162float(a.y)) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)}};
812
812
+
+__HOST_DEVICE__ __hip_bfloat162 __hisnan2(const __hip_bfloat162 a) {
813
813
+
+ return __hip_bfloat162{{__hisnan(a.x) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16},
814
814
+
+ {__hisnan(a.y) ? HIPRT_ONE_BF16 : HIPRT_ONE_BF16}};
815
815
+
}
816
816
+
817
817
+
/**
818
818
+
* \ingroup HIP_INTRINSIC_BFLOAT162_COMP
819
819
+
* \brief Check for a <= b, returns 1.0 if greater than equal, otherwise 0.0
820
820
+
*/
821
821
+
-__device__ __hip_bfloat162 __hle2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
822
822
+
- return __hip_bfloat162{{__hle(a.x, b.x) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)},
823
823
+
- {__hle(a.y, b.y) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)}};
824
824
+
+__HOST_DEVICE__ __hip_bfloat162 __hle2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
825
825
+
+ return __hip_bfloat162{{__hle(a.x, b.x) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16},
826
826
+
+ {__hle(a.y, b.y) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}};
827
827
+
}
828
828
+
829
829
+
/**
830
830
+
* \ingroup HIP_INTRINSIC_BFLOAT162_COMP
831
831
+
* \brief Check for a < b, returns 1.0 if greater than equal, otherwise 0.0
832
832
+
*/
833
833
+
-__device__ __hip_bfloat162 __hlt2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
834
834
+
- return __hip_bfloat162{{__hlt(a.x, b.x) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)},
835
835
+
- {__hlt(a.y, b.y) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)}};
836
836
+
+__HOST_DEVICE__ __hip_bfloat162 __hlt2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
837
837
+
+ return __hip_bfloat162{{__hlt(a.x, b.x) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16},
838
838
+
+ {__hlt(a.y, b.y) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}};
839
839
+
}
840
840
+
841
841
+
/**
842
842
+
* \ingroup HIP_INTRINSIC_BFLOAT162_COMP
843
843
+
* \brief Returns max of two elements
844
844
+
*/
845
845
+
-__device__ __hip_bfloat162 __hmax2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
846
846
+
- return __hip_bfloat162{
847
847
+
- __float2bfloat16(__ocml_fmax_f32(__bfloat162float(a.x), __bfloat162float(b.x))),
848
848
+
- __float2bfloat16(__ocml_fmax_f32(__bfloat162float(a.y), __bfloat162float(b.y)))};
849
849
+
+__HOST_DEVICE__ __hip_bfloat162 __hmax2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
850
850
+
+ return __hip_bfloat162{__hmax(a.x, b.x), __hmax(a.y, b.y)};
851
851
+
}
852
852
+
853
853
+
/**
854
854
+
* \ingroup HIP_INTRINSIC_BFLOAT162_COMP
855
855
+
* \brief Returns min of two elements
856
856
+
*/
857
857
+
-__device__ __hip_bfloat162 __hmin2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
858
858
+
- return __hip_bfloat162{
859
859
+
- __float2bfloat16(__ocml_fmin_f32(__bfloat162float(a.x), __bfloat162float(b.x))),
860
860
+
- __float2bfloat16(__ocml_fmin_f32(__bfloat162float(a.y), __bfloat162float(b.y)))};
861
861
+
+__HOST_DEVICE__ __hip_bfloat162 __hmin2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
862
862
+
+ return __hip_bfloat162{__hmin(a.x, b.x), __hmin(a.y, b.y)};
863
863
+
}
864
864
+
865
865
+
/**
866
866
+
* \ingroup HIP_INTRINSIC_BFLOAT162_COMP
867
867
+
* \brief Checks for not equal to
868
868
+
*/
869
869
+
-__device__ __hip_bfloat162 __hne2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
870
870
+
- return __hip_bfloat162{{__hne(a.x, b.x) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)},
871
871
+
- {__hne(a.y, b.y) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)}};
872
872
+
+__HOST_DEVICE__ __hip_bfloat162 __hne2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
873
873
+
+ return __hip_bfloat162{{__hne(a.x, b.x) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16},
874
874
+
+ {__hne(a.y, b.y) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}};
875
875
+
+}
876
876
+
+
877
877
+
+/**
878
878
+
+ * \ingroup HIP_INTRINSIC_BFLOAT16_COMP
879
879
+
+ * \brief Operator to perform an equal compare on two __hip_bfloat16 numbers
880
880
+
+ */
881
881
+
+__HOST_DEVICE__ bool operator==(const __hip_bfloat16& l, const __hip_bfloat16& r) {
882
882
+
+ return __heq(l, r);
883
883
+
+}
884
884
+
+
885
885
+
+/**
886
886
+
+ * \ingroup HIP_INTRINSIC_BFLOAT16_COMP
887
887
+
+ * \brief Operator to perform a not equal on two __hip_bfloat16 numbers
888
888
+
+ */
889
889
+
+__HOST_DEVICE__ bool operator!=(const __hip_bfloat16& l, const __hip_bfloat16& r) {
890
890
+
+ return __hne(l, r);
891
891
+
+}
892
892
+
+
893
893
+
+/**
894
894
+
+ * \ingroup HIP_INTRINSIC_BFLOAT16_COMP
895
895
+
+ * \brief Operator to perform a less than on two __hip_bfloat16 numbers
896
896
+
+ */
897
897
+
+__HOST_DEVICE__ bool operator<(const __hip_bfloat16& l, const __hip_bfloat16& r) {
898
898
+
+ return __hlt(l, r);
899
899
+
+}
900
900
+
+
901
901
+
+/**
902
902
+
+ * \ingroup HIP_INTRINSIC_BFLOAT16_COMP
903
903
+
+ * \brief Operator to perform a less than equal on two __hip_bfloat16 numbers
904
904
+
+ */
905
905
+
+__HOST_DEVICE__ bool operator<=(const __hip_bfloat16& l, const __hip_bfloat16& r) {
906
906
+
+ return __hle(l, r);
907
907
+
+}
908
908
+
+
909
909
+
+/**
910
910
+
+ * \ingroup HIP_INTRINSIC_BFLOAT16_COMP
911
911
+
+ * \brief Operator to perform a greater than on two __hip_bfloat16 numbers
912
912
+
+ */
913
913
+
+__HOST_DEVICE__ bool operator>(const __hip_bfloat16& l, const __hip_bfloat16& r) {
914
914
+
+ return __hgt(l, r);
915
915
+
+}
916
916
+
+
917
917
+
+/**
918
918
+
+ * \ingroup HIP_INTRINSIC_BFLOAT16_COMP
919
919
+
+ * \brief Operator to perform a greater than equal on two __hip_bfloat16 numbers
920
920
+
+ */
921
921
+
+__HOST_DEVICE__ bool operator>=(const __hip_bfloat16& l, const __hip_bfloat16& r) {
922
922
+
+ return __hge(l, r);
923
923
+
+}
924
924
+
+
925
925
+
+/**
926
926
+
+ * \ingroup HIP_INTRINSIC_BFLOAT162_COMP
927
927
+
+ * \brief Operator to perform an equal compare on two __hip_bfloat16 numbers
928
928
+
+ */
929
929
+
+__HOST_DEVICE__ bool operator==(const __hip_bfloat162& l, const __hip_bfloat162& r) {
930
930
+
+ return __heq(l.x, r.x) && __heq(l.y, r.y);
931
931
+
+}
932
932
+
+
933
933
+
+/**
934
934
+
+ * \ingroup HIP_INTRINSIC_BFLOAT162_COMP
935
935
+
+ * \brief Operator to perform a not equal on two __hip_bfloat16 numbers
936
936
+
+ */
937
937
+
+__HOST_DEVICE__ bool operator!=(const __hip_bfloat162& l, const __hip_bfloat162& r) {
938
938
+
+ return __hne(l.x, r.x) || __hne(l.y, r.y);
939
939
+
+}
940
940
+
+
941
941
+
+/**
942
942
+
+ * \ingroup HIP_INTRINSIC_BFLOAT162_COMP
943
943
+
+ * \brief Operator to perform a less than on two __hip_bfloat16 numbers
944
944
+
+ */
945
945
+
+__HOST_DEVICE__ bool operator<(const __hip_bfloat162& l, const __hip_bfloat162& r) {
946
946
+
+ return __hlt(l.x, r.x) && __hlt(l.y, r.y);
947
947
+
+}
948
948
+
+
949
949
+
+/**
950
950
+
+ * \ingroup HIP_INTRINSIC_BFLOAT162_COMP
951
951
+
+ * \brief Operator to perform a less than equal on two __hip_bfloat16 numbers
952
952
+
+ */
953
953
+
+__HOST_DEVICE__ bool operator<=(const __hip_bfloat162& l, const __hip_bfloat162& r) {
954
954
+
+ return __hle(l.x, r.x) && __hle(l.y, r.y);
955
955
+
+}
956
956
+
+
957
957
+
+/**
958
958
+
+ * \ingroup HIP_INTRINSIC_BFLOAT162_COMP
959
959
+
+ * \brief Operator to perform a greater than on two __hip_bfloat16 numbers
960
960
+
+ */
961
961
+
+__HOST_DEVICE__ bool operator>(const __hip_bfloat162& l, const __hip_bfloat162& r) {
962
962
+
+ return __hgt(l.x, r.x) && __hgt(l.y, r.y);
963
963
+
+}
964
964
+
+
965
965
+
+/**
966
966
+
+ * \ingroup HIP_INTRINSIC_BFLOAT16_COMP
967
967
+
+ * \brief Operator to perform a greater than equal on two __hip_bfloat16 numbers
968
968
+
+ */
969
969
+
+__HOST_DEVICE__ bool operator>=(const __hip_bfloat162& l, const __hip_bfloat162& r) {
970
970
+
+ return __hge(l.x, r.x) && __hge(l.y, r.y);
971
971
+
}
972
972
+
973
973
+
/**
974
974
+
@@ -974,5 +1334,4 @@ __device__ __hip_bfloat162 h2sqrt(const __hip_bfloat162 h) {
975
975
+
__device__ __hip_bfloat162 h2trunc(const __hip_bfloat162 h) {
976
976
+
return __hip_bfloat162{htrunc(h.x), htrunc(h.y)};
977
977
+
}
978
978
+
-
979
979
+
#endif
+5
pkgs/development/rocm-modules/5/clr/default.nix
reviewed
···
88
88
"-DCMAKE_INSTALL_LIBDIR=lib"
89
89
];
90
90
91
91
+
patches = [
92
92
+
./add-missing-operators.patch
93
93
+
./static-functions.patch
94
94
+
];
95
95
+
91
96
postPatch = ''
92
97
patchShebangs hipamd/src
93
98
+31
pkgs/development/rocm-modules/5/clr/static-functions.patch
reviewed
···
1
1
+
From 77c581a3ebd47b5e2908973b70adea66891159ee Mon Sep 17 00:00:00 2001
2
2
+
From: Jatin Chaudhary <JatinJaikishan.Chaudhary@amd.com>
3
3
+
Date: Mon, 4 Dec 2023 17:21:39 +0000
4
4
+
Subject: [PATCH] SWDEV-435702 - the functions in bf16 header need to be static
5
5
+
6
6
+
If the compiler decides not to inline these functions, we might break ODR (one definition rule) due to this file being included in multiple files and being linked together
7
7
+
8
8
+
Change-Id: Iacbfdabb53f5b4e5db8c690b23f3730ec9af16c0
9
9
+
---
10
10
+
hipamd/include/hip/amd_detail/amd_hip_bf16.h | 4 ++--
11
11
+
1 file changed, 2 insertions(+), 2 deletions(-)
12
12
+
13
13
+
diff --git a/hipamd/include/hip/amd_detail/amd_hip_bf16.h b/hipamd/include/hip/amd_detail/amd_hip_bf16.h
14
14
+
index 836e090eb..204269a84 100644
15
15
+
--- a/hipamd/include/hip/amd_detail/amd_hip_bf16.h
16
16
+
+++ b/hipamd/include/hip/amd_detail/amd_hip_bf16.h
17
17
+
@@ -94,12 +94,12 @@
18
18
+
#include "math_fwd.h" // ocml device functions
19
19
+
20
20
+
#if defined(__HIPCC_RTC__)
21
21
+
-#define __HOST_DEVICE__ __device__
22
22
+
+#define __HOST_DEVICE__ __device__ static
23
23
+
#else
24
24
+
#include <algorithm>
25
25
+
#include <climits>
26
26
+
#include <cmath>
27
27
+
-#define __HOST_DEVICE__ __host__ __device__ inline
28
28
+
+#define __HOST_DEVICE__ __host__ __device__ static inline
29
29
+
#endif
30
30
+
31
31
+
#define HIPRT_ONE_BF16 __float2bfloat16(1.0f)