nixpkgs mirror (for testing) github.com/NixOS/nixpkgs
nix

rocmPackages.clr: backport bf16 compilation fix

+1015
+979
pkgs/development/rocm-modules/5/clr/add-missing-operators.patch
··· 1 + From 86bd518981b364c138f9901b28a529899d8654f3 Mon Sep 17 00:00:00 2001 2 + From: Jatin Chaudhary <JatinJaikishan.Chaudhary@amd.com> 3 + Date: Wed, 11 Oct 2023 23:19:29 +0100 4 + Subject: [PATCH] SWDEV-367537 - Add missing operators to __hip_bfloat16 5 + implementation 6 + 7 + Add __host__ and __device__ to bunch of operator/function matching CUDA 8 + Fix some bugs seen in __hisinf 9 + 10 + Change-Id: I9e67e3e3eb2083b463158f3e250e5221c89b2896 11 + --- 12 + hipamd/include/hip/amd_detail/amd_hip_bf16.h | 533 ++++++++++++++++--- 13 + 1 file changed, 446 insertions(+), 87 deletions(-) 14 + 15 + diff --git a/hipamd/include/hip/amd_detail/amd_hip_bf16.h b/hipamd/include/hip/amd_detail/amd_hip_bf16.h 16 + index 757cb7ada..b15ea3b65 100644 17 + --- a/hipamd/include/hip/amd_detail/amd_hip_bf16.h 18 + +++ b/hipamd/include/hip/amd_detail/amd_hip_bf16.h 19 + @@ -96,10 +96,20 @@ 20 + #if defined(__HIPCC_RTC__) 21 + #define __HOST_DEVICE__ __device__ 22 + #else 23 + +#include <algorithm> 24 + #include <climits> 25 + -#define __HOST_DEVICE__ __host__ __device__ 26 + +#include <cmath> 27 + +#define __HOST_DEVICE__ __host__ __device__ inline 28 + #endif 29 + 30 + +#define HIPRT_ONE_BF16 __float2bfloat16(1.0f) 31 + +#define HIPRT_ZERO_BF16 __float2bfloat16(0.0f) 32 + +#define HIPRT_INF_BF16 __ushort_as_bfloat16((unsigned short)0x7F80U) 33 + +#define HIPRT_MAX_NORMAL_BF16 __ushort_as_bfloat16((unsigned short)0x7F7FU) 34 + +#define HIPRT_MIN_DENORM_BF16 __ushort_as_bfloat16((unsigned short)0x0001U) 35 + +#define HIPRT_NAN_BF16 __ushort_as_bfloat16((unsigned short)0x7FFFU) 36 + +#define HIPRT_NEG_ZERO_BF16 __ushort_as_bfloat16((unsigned short)0x8000U) 37 + + 38 + // Since we are using unsigned short to represent data in bfloat16, it can be of different sizes on 39 + // different machines. These naive checks should prevent some undefined behavior on systems which 40 + // have different sizes for basic types. 41 + @@ -189,7 +199,7 @@ __HOST_DEVICE__ float2 __bfloat1622float2(const __hip_bfloat162 a) { 42 + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV 43 + * \brief Moves bfloat16 value to bfloat162 44 + */ 45 + -__device__ __hip_bfloat162 __bfloat162bfloat162(const __hip_bfloat16 a) { 46 + +__HOST_DEVICE__ __hip_bfloat162 __bfloat162bfloat162(const __hip_bfloat16 a) { 47 + return __hip_bfloat162{a, a}; 48 + } 49 + 50 + @@ -197,13 +207,13 @@ __device__ __hip_bfloat162 __bfloat162bfloat162(const __hip_bfloat16 a) { 51 + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV 52 + * \brief Reinterprets bits in a __hip_bfloat16 as a signed short integer 53 + */ 54 + -__device__ short int __bfloat16_as_short(const __hip_bfloat16 h) { return (short)h.data; } 55 + +__HOST_DEVICE__ short int __bfloat16_as_short(const __hip_bfloat16 h) { return (short)h.data; } 56 + 57 + /** 58 + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV 59 + * \brief Reinterprets bits in a __hip_bfloat16 as an unsigned signed short integer 60 + */ 61 + -__device__ unsigned short int __bfloat16_as_ushort(const __hip_bfloat16 h) { return h.data; } 62 + +__HOST_DEVICE__ unsigned short int __bfloat16_as_ushort(const __hip_bfloat16 h) { return h.data; } 63 + 64 + /** 65 + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV 66 + @@ -225,7 +235,7 @@ __HOST_DEVICE__ __hip_bfloat162 __float22bfloat162_rn(const float2 a) { 67 + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV 68 + * \brief Combine two __hip_bfloat16 to __hip_bfloat162 69 + */ 70 + -__device__ __hip_bfloat162 __halves2bfloat162(const __hip_bfloat16 a, const __hip_bfloat16 b) { 71 + +__HOST_DEVICE__ __hip_bfloat162 __halves2bfloat162(const __hip_bfloat16 a, const __hip_bfloat16 b) { 72 + return __hip_bfloat162{a, b}; 73 + } 74 + 75 + @@ -233,13 +243,13 @@ __device__ __hip_bfloat162 __halves2bfloat162(const __hip_bfloat16 a, const __hi 76 + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV 77 + * \brief Returns high 16 bits of __hip_bfloat162 78 + */ 79 + -__device__ __hip_bfloat16 __high2bfloat16(const __hip_bfloat162 a) { return a.y; } 80 + +__HOST_DEVICE__ __hip_bfloat16 __high2bfloat16(const __hip_bfloat162 a) { return a.y; } 81 + 82 + /** 83 + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV 84 + * \brief Returns high 16 bits of __hip_bfloat162 85 + */ 86 + -__device__ __hip_bfloat162 __high2bfloat162(const __hip_bfloat162 a) { 87 + +__HOST_DEVICE__ __hip_bfloat162 __high2bfloat162(const __hip_bfloat162 a) { 88 + return __hip_bfloat162{a.y, a.y}; 89 + } 90 + 91 + @@ -253,7 +263,8 @@ __HOST_DEVICE__ float __high2float(const __hip_bfloat162 a) { return __bfloat162 92 + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV 93 + * \brief Extracts high 16 bits from each and combines them 94 + */ 95 + -__device__ __hip_bfloat162 __highs2bfloat162(const __hip_bfloat162 a, const __hip_bfloat162 b) { 96 + +__HOST_DEVICE__ __hip_bfloat162 __highs2bfloat162(const __hip_bfloat162 a, 97 + + const __hip_bfloat162 b) { 98 + return __hip_bfloat162{a.y, b.y}; 99 + } 100 + 101 + @@ -261,13 +272,13 @@ __device__ __hip_bfloat162 __highs2bfloat162(const __hip_bfloat162 a, const __hi 102 + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV 103 + * \brief Returns low 16 bits of __hip_bfloat162 104 + */ 105 + -__device__ __hip_bfloat16 __low2bfloat16(const __hip_bfloat162 a) { return a.x; } 106 + +__HOST_DEVICE__ __hip_bfloat16 __low2bfloat16(const __hip_bfloat162 a) { return a.x; } 107 + 108 + /** 109 + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV 110 + * \brief Returns low 16 bits of __hip_bfloat162 111 + */ 112 + -__device__ __hip_bfloat162 __low2bfloat162(const __hip_bfloat162 a) { 113 + +__HOST_DEVICE__ __hip_bfloat162 __low2bfloat162(const __hip_bfloat162 a) { 114 + return __hip_bfloat162{a.x, a.x}; 115 + } 116 + 117 + @@ -281,7 +292,7 @@ __HOST_DEVICE__ float __low2float(const __hip_bfloat162 a) { return __bfloat162f 118 + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV 119 + * \brief Swaps both halves 120 + */ 121 + -__device__ __hip_bfloat162 __lowhigh2highlow(const __hip_bfloat162 a) { 122 + +__HOST_DEVICE__ __hip_bfloat162 __lowhigh2highlow(const __hip_bfloat162 a) { 123 + return __hip_bfloat162{a.y, a.x}; 124 + } 125 + 126 + @@ -289,7 +300,7 @@ __device__ __hip_bfloat162 __lowhigh2highlow(const __hip_bfloat162 a) { 127 + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV 128 + * \brief Extracts low 16 bits from each and combines them 129 + */ 130 + -__device__ __hip_bfloat162 __lows2bfloat162(const __hip_bfloat162 a, const __hip_bfloat162 b) { 131 + +__HOST_DEVICE__ __hip_bfloat162 __lows2bfloat162(const __hip_bfloat162 a, const __hip_bfloat162 b) { 132 + return __hip_bfloat162{a.x, b.x}; 133 + } 134 + 135 + @@ -297,7 +308,7 @@ __device__ __hip_bfloat162 __lows2bfloat162(const __hip_bfloat162 a, const __hip 136 + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV 137 + * \brief Reinterprets short int into a bfloat16 138 + */ 139 + -__device__ __hip_bfloat16 __short_as_bfloat16(const short int a) { 140 + +__HOST_DEVICE__ __hip_bfloat16 __short_as_bfloat16(const short int a) { 141 + return __hip_bfloat16{(unsigned short)a}; 142 + } 143 + 144 + @@ -305,7 +316,7 @@ __device__ __hip_bfloat16 __short_as_bfloat16(const short int a) { 145 + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV 146 + * \brief Reinterprets unsigned short int into a bfloat16 147 + */ 148 + -__device__ __hip_bfloat16 __ushort_as_bfloat16(const unsigned short int a) { 149 + +__HOST_DEVICE__ __hip_bfloat16 __ushort_as_bfloat16(const unsigned short int a) { 150 + return __hip_bfloat16{a}; 151 + } 152 + 153 + @@ -314,7 +325,7 @@ __device__ __hip_bfloat16 __ushort_as_bfloat16(const unsigned short int a) { 154 + * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH 155 + * \brief Adds two bfloat16 values 156 + */ 157 + -__device__ __hip_bfloat16 __hadd(const __hip_bfloat16 a, const __hip_bfloat16 b) { 158 + +__HOST_DEVICE__ __hip_bfloat16 __hadd(const __hip_bfloat16 a, const __hip_bfloat16 b) { 159 + return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b)); 160 + } 161 + 162 + @@ -322,7 +333,7 @@ __device__ __hip_bfloat16 __hadd(const __hip_bfloat16 a, const __hip_bfloat16 b) 163 + * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH 164 + * \brief Subtracts two bfloat16 values 165 + */ 166 + -__device__ __hip_bfloat16 __hsub(const __hip_bfloat16 a, const __hip_bfloat16 b) { 167 + +__HOST_DEVICE__ __hip_bfloat16 __hsub(const __hip_bfloat16 a, const __hip_bfloat16 b) { 168 + return __float2bfloat16(__bfloat162float(a) - __bfloat162float(b)); 169 + } 170 + 171 + @@ -330,7 +341,7 @@ __device__ __hip_bfloat16 __hsub(const __hip_bfloat16 a, const __hip_bfloat16 b) 172 + * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH 173 + * \brief Divides two bfloat16 values 174 + */ 175 + -__device__ __hip_bfloat16 __hdiv(const __hip_bfloat16 a, const __hip_bfloat16 b) { 176 + +__HOST_DEVICE__ __hip_bfloat16 __hdiv(const __hip_bfloat16 a, const __hip_bfloat16 b) { 177 + return __float2bfloat16(__bfloat162float(a) / __bfloat162float(b)); 178 + } 179 + 180 + @@ -348,7 +359,7 @@ __device__ __hip_bfloat16 __hfma(const __hip_bfloat16 a, const __hip_bfloat16 b, 181 + * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH 182 + * \brief Multiplies two bfloat16 values 183 + */ 184 + -__device__ __hip_bfloat16 __hmul(const __hip_bfloat16 a, const __hip_bfloat16 b) { 185 + +__HOST_DEVICE__ __hip_bfloat16 __hmul(const __hip_bfloat16 a, const __hip_bfloat16 b) { 186 + return __float2bfloat16(__bfloat162float(a) * __bfloat162float(b)); 187 + } 188 + 189 + @@ -356,7 +367,7 @@ __device__ __hip_bfloat16 __hmul(const __hip_bfloat16 a, const __hip_bfloat16 b) 190 + * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH 191 + * \brief Negate a bfloat16 value 192 + */ 193 + -__device__ __hip_bfloat16 __hneg(const __hip_bfloat16 a) { 194 + +__HOST_DEVICE__ __hip_bfloat16 __hneg(const __hip_bfloat16 a) { 195 + auto ret = a; 196 + ret.data ^= 0x8000; 197 + return ret; 198 + @@ -366,7 +377,7 @@ __device__ __hip_bfloat16 __hneg(const __hip_bfloat16 a) { 199 + * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH 200 + * \brief Returns absolute of a bfloat16 201 + */ 202 + -__device__ __hip_bfloat16 __habs(const __hip_bfloat16 a) { 203 + +__HOST_DEVICE__ __hip_bfloat16 __habs(const __hip_bfloat16 a) { 204 + auto ret = a; 205 + ret.data &= 0x7FFF; 206 + return ret; 207 + @@ -376,7 +387,7 @@ __device__ __hip_bfloat16 __habs(const __hip_bfloat16 a) { 208 + * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH 209 + * \brief Divides bfloat162 values 210 + */ 211 + -__device__ __hip_bfloat162 __h2div(const __hip_bfloat162 a, const __hip_bfloat162 b) { 212 + +__HOST_DEVICE__ __hip_bfloat162 __h2div(const __hip_bfloat162 a, const __hip_bfloat162 b) { 213 + return __hip_bfloat162{__float2bfloat16(__bfloat162float(a.x) / __bfloat162float(b.x)), 214 + __float2bfloat16(__bfloat162float(a.y) / __bfloat162float(b.y))}; 215 + } 216 + @@ -385,7 +396,7 @@ __device__ __hip_bfloat162 __h2div(const __hip_bfloat162 a, const __hip_bfloat16 217 + * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH 218 + * \brief Returns absolute of a bfloat162 219 + */ 220 + -__device__ __hip_bfloat162 __habs2(const __hip_bfloat162 a) { 221 + +__HOST_DEVICE__ __hip_bfloat162 __habs2(const __hip_bfloat162 a) { 222 + return __hip_bfloat162{__habs(a.x), __habs(a.y)}; 223 + } 224 + 225 + @@ -393,7 +404,7 @@ __device__ __hip_bfloat162 __habs2(const __hip_bfloat162 a) { 226 + * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH 227 + * \brief Adds two bfloat162 values 228 + */ 229 + -__device__ __hip_bfloat162 __hadd2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 230 + +__HOST_DEVICE__ __hip_bfloat162 __hadd2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 231 + return __hip_bfloat162{__hadd(a.x, b.x), __hadd(a.y, b.y)}; 232 + } 233 + 234 + @@ -410,7 +421,7 @@ __device__ __hip_bfloat162 __hfma2(const __hip_bfloat162 a, const __hip_bfloat16 235 + * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH 236 + * \brief Multiplies two bfloat162 values 237 + */ 238 + -__device__ __hip_bfloat162 __hmul2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 239 + +__HOST_DEVICE__ __hip_bfloat162 __hmul2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 240 + return __hip_bfloat162{__hmul(a.x, b.x), __hmul(a.y, b.y)}; 241 + } 242 + 243 + @@ -418,7 +429,7 @@ __device__ __hip_bfloat162 __hmul2(const __hip_bfloat162 a, const __hip_bfloat16 244 + * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH 245 + * \brief Converts a bfloat162 into negative 246 + */ 247 + -__device__ __hip_bfloat162 __hneg2(const __hip_bfloat162 a) { 248 + +__HOST_DEVICE__ __hip_bfloat162 __hneg2(const __hip_bfloat162 a) { 249 + return __hip_bfloat162{__hneg(a.x), __hneg(a.y)}; 250 + } 251 + 252 + @@ -426,15 +437,251 @@ __device__ __hip_bfloat162 __hneg2(const __hip_bfloat162 a) { 253 + * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH 254 + * \brief Subtracts two bfloat162 values 255 + */ 256 + -__device__ __hip_bfloat162 __hsub2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 257 + +__HOST_DEVICE__ __hip_bfloat162 __hsub2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 258 + return __hip_bfloat162{__hsub(a.x, b.x), __hsub(a.y, b.y)}; 259 + } 260 + 261 + +/** 262 + + * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH 263 + + * \brief Operator to multiply two __hip_bfloat16 numbers 264 + + */ 265 + +__HOST_DEVICE__ __hip_bfloat16 operator*(const __hip_bfloat16& l, const __hip_bfloat16& r) { 266 + + return __hmul(l, r); 267 + +} 268 + + 269 + +/** 270 + + * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH 271 + + * \brief Operator to multiply-assign two __hip_bfloat16 numbers 272 + + */ 273 + +__HOST_DEVICE__ __hip_bfloat16 operator*=(__hip_bfloat16& l, const __hip_bfloat16& r) { 274 + + l = __hmul(l, r); 275 + + return l; 276 + +} 277 + + 278 + +/** 279 + + * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH 280 + + * \brief Operator to unary+ on a __hip_bfloat16 number 281 + + */ 282 + +__HOST_DEVICE__ __hip_bfloat16 operator+(const __hip_bfloat16& l) { return l; } 283 + + 284 + +/** 285 + + * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH 286 + + * \brief Operator to add two __hip_bfloat16 numbers 287 + + */ 288 + +__HOST_DEVICE__ __hip_bfloat16 operator+(const __hip_bfloat16& l, const __hip_bfloat16& r) { 289 + + return __hadd(l, r); 290 + +} 291 + + 292 + +/** 293 + + * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH 294 + + * \brief Operator to negate a __hip_bfloat16 number 295 + + */ 296 + +__HOST_DEVICE__ __hip_bfloat16 operator-(const __hip_bfloat16& l) { return __hneg(l); } 297 + + 298 + +/** 299 + + * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH 300 + + * \brief Operator to subtract two __hip_bfloat16 numbers 301 + + */ 302 + +__HOST_DEVICE__ __hip_bfloat16 operator-(const __hip_bfloat16& l, const __hip_bfloat16& r) { 303 + + return __hsub(l, r); 304 + +} 305 + + 306 + +/** 307 + + * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH 308 + + * \brief Operator to post increment a __hip_bfloat16 number 309 + + */ 310 + +__HOST_DEVICE__ __hip_bfloat16 operator++(__hip_bfloat16& l, const int) { 311 + + auto ret = l; 312 + + l = __hadd(l, HIPRT_ONE_BF16); 313 + + return ret; 314 + +} 315 + + 316 + +/** 317 + + * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH 318 + + * \brief Operator to pre increment a __hip_bfloat16 number 319 + + */ 320 + +__HOST_DEVICE__ __hip_bfloat16& operator++(__hip_bfloat16& l) { 321 + + l = __hadd(l, HIPRT_ONE_BF16); 322 + + return l; 323 + +} 324 + + 325 + +/** 326 + + * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH 327 + + * \brief Operator to post decrement a __hip_bfloat16 number 328 + + */ 329 + +__HOST_DEVICE__ __hip_bfloat16 operator--(__hip_bfloat16& l, const int) { 330 + + auto ret = l; 331 + + l = __hsub(l, HIPRT_ONE_BF16); 332 + + return ret; 333 + +} 334 + + 335 + +/** 336 + + * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH 337 + + * \brief Operator to pre decrement a __hip_bfloat16 number 338 + + */ 339 + +__HOST_DEVICE__ __hip_bfloat16& operator--(__hip_bfloat16& l) { 340 + + l = __hsub(l, HIPRT_ONE_BF16); 341 + + return l; 342 + +} 343 + + 344 + +/** 345 + + * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH 346 + + * \brief Operator to add-assign two __hip_bfloat16 numbers 347 + + */ 348 + +__HOST_DEVICE__ __hip_bfloat16& operator+=(__hip_bfloat16& l, const __hip_bfloat16& r) { 349 + + l = l + r; 350 + + return l; 351 + +} 352 + + 353 + +/** 354 + + * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH 355 + + * \brief Operator to subtract-assign two __hip_bfloat16 numbers 356 + + */ 357 + +__HOST_DEVICE__ __hip_bfloat16& operator-=(__hip_bfloat16& l, const __hip_bfloat16& r) { 358 + + l = l - r; 359 + + return l; 360 + +} 361 + + 362 + +/** 363 + + * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH 364 + + * \brief Operator to divide two __hip_bfloat16 numbers 365 + + */ 366 + +__HOST_DEVICE__ __hip_bfloat16 operator/(const __hip_bfloat16& l, const __hip_bfloat16& r) { 367 + + return __hdiv(l, r); 368 + +} 369 + + 370 + +/** 371 + + * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH 372 + + * \brief Operator to divide-assign two __hip_bfloat16 numbers 373 + + */ 374 + +__HOST_DEVICE__ __hip_bfloat16& operator/=(__hip_bfloat16& l, const __hip_bfloat16& r) { 375 + + l = l / r; 376 + + return l; 377 + +} 378 + + 379 + +/** 380 + + * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH 381 + + * \brief Operator to multiply two __hip_bfloat162 numbers 382 + + */ 383 + +__HOST_DEVICE__ __hip_bfloat162 operator*(const __hip_bfloat162& l, const __hip_bfloat162& r) { 384 + + return __hmul2(l, r); 385 + +} 386 + + 387 + +/** 388 + + * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH 389 + + * \brief Operator to multiply-assign two __hip_bfloat162 numbers 390 + + */ 391 + +__HOST_DEVICE__ __hip_bfloat162 operator*=(__hip_bfloat162& l, const __hip_bfloat162& r) { 392 + + l = __hmul2(l, r); 393 + + return l; 394 + +} 395 + + 396 + +/** 397 + + * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH 398 + + * \brief Operator to unary+ on a __hip_bfloat162 number 399 + + */ 400 + +__HOST_DEVICE__ __hip_bfloat162 operator+(const __hip_bfloat162& l) { return l; } 401 + + 402 + +/** 403 + + * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH 404 + + * \brief Operator to add two __hip_bfloat162 numbers 405 + + */ 406 + +__HOST_DEVICE__ __hip_bfloat162 operator+(const __hip_bfloat162& l, const __hip_bfloat162& r) { 407 + + return __hadd2(l, r); 408 + +} 409 + + 410 + +/** 411 + + * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH 412 + + * \brief Operator to negate a __hip_bfloat162 number 413 + + */ 414 + +__HOST_DEVICE__ __hip_bfloat162 operator-(const __hip_bfloat162& l) { return __hneg2(l); } 415 + + 416 + +/** 417 + + * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH 418 + + * \brief Operator to subtract two __hip_bfloat162 numbers 419 + + */ 420 + +__HOST_DEVICE__ __hip_bfloat162 operator-(const __hip_bfloat162& l, const __hip_bfloat162& r) { 421 + + return __hsub2(l, r); 422 + +} 423 + + 424 + +/** 425 + + * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH 426 + + * \brief Operator to post increment a __hip_bfloat162 number 427 + + */ 428 + +__HOST_DEVICE__ __hip_bfloat162 operator++(__hip_bfloat162& l, const int) { 429 + + auto ret = l; 430 + + l = __hadd2(l, {HIPRT_ONE_BF16, HIPRT_ONE_BF16}); 431 + + return ret; 432 + +} 433 + + 434 + +/** 435 + + * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH 436 + + * \brief Operator to pre increment a __hip_bfloat162 number 437 + + */ 438 + +__HOST_DEVICE__ __hip_bfloat162& operator++(__hip_bfloat162& l) { 439 + + l = __hadd2(l, {HIPRT_ONE_BF16, HIPRT_ONE_BF16}); 440 + + return l; 441 + +} 442 + + 443 + +/** 444 + + * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH 445 + + * \brief Operator to post decrement a __hip_bfloat162 number 446 + + */ 447 + +__HOST_DEVICE__ __hip_bfloat162 operator--(__hip_bfloat162& l, const int) { 448 + + auto ret = l; 449 + + l = __hsub2(l, {HIPRT_ONE_BF16, HIPRT_ONE_BF16}); 450 + + return ret; 451 + +} 452 + + 453 + +/** 454 + + * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH 455 + + * \brief Operator to pre decrement a __hip_bfloat162 number 456 + + */ 457 + +__HOST_DEVICE__ __hip_bfloat162& operator--(__hip_bfloat162& l) { 458 + + l = __hsub2(l, {HIPRT_ONE_BF16, HIPRT_ONE_BF16}); 459 + + return l; 460 + +} 461 + + 462 + +/** 463 + + * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH 464 + + * \brief Operator to add-assign two __hip_bfloat162 numbers 465 + + */ 466 + +__HOST_DEVICE__ __hip_bfloat162& operator+=(__hip_bfloat162& l, const __hip_bfloat162& r) { 467 + + l = l + r; 468 + + return l; 469 + +} 470 + + 471 + +/** 472 + + * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH 473 + + * \brief Operator to subtract-assign two __hip_bfloat162 numbers 474 + + */ 475 + +__HOST_DEVICE__ __hip_bfloat162& operator-=(__hip_bfloat162& l, const __hip_bfloat162& r) { 476 + + l = l - r; 477 + + return l; 478 + +} 479 + + 480 + +/** 481 + + * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH 482 + + * \brief Operator to divide two __hip_bfloat162 numbers 483 + + */ 484 + +__HOST_DEVICE__ __hip_bfloat162 operator/(const __hip_bfloat162& l, const __hip_bfloat162& r) { 485 + + return __h2div(l, r); 486 + +} 487 + + 488 + +/** 489 + + * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH 490 + + * \brief Operator to divide-assign two __hip_bfloat162 numbers 491 + + */ 492 + +__HOST_DEVICE__ __hip_bfloat162& operator/=(__hip_bfloat162& l, const __hip_bfloat162& r) { 493 + + l = l / r; 494 + + return l; 495 + +} 496 + + 497 + /** 498 + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP 499 + * \brief Compare two bfloat162 values 500 + */ 501 + -__device__ bool __heq(const __hip_bfloat16 a, const __hip_bfloat16 b) { 502 + +__HOST_DEVICE__ bool __heq(const __hip_bfloat16 a, const __hip_bfloat16 b) { 503 + return __bfloat162float(a) == __bfloat162float(b); 504 + } 505 + 506 + @@ -442,7 +689,7 @@ __device__ bool __heq(const __hip_bfloat16 a, const __hip_bfloat16 b) { 507 + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP 508 + * \brief Compare two bfloat162 values - unordered equal 509 + */ 510 + -__device__ bool __hequ(const __hip_bfloat16 a, const __hip_bfloat16 b) { 511 + +__HOST_DEVICE__ bool __hequ(const __hip_bfloat16 a, const __hip_bfloat16 b) { 512 + return !(__bfloat162float(a) < __bfloat162float(b)) && 513 + !(__bfloat162float(a) > __bfloat162float(b)); 514 + } 515 + @@ -451,7 +698,7 @@ __device__ bool __hequ(const __hip_bfloat16 a, const __hip_bfloat16 b) { 516 + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP 517 + * \brief Compare two bfloat162 values - greater than 518 + */ 519 + -__device__ bool __hgt(const __hip_bfloat16 a, const __hip_bfloat16 b) { 520 + +__HOST_DEVICE__ bool __hgt(const __hip_bfloat16 a, const __hip_bfloat16 b) { 521 + return __bfloat162float(a) > __bfloat162float(b); 522 + } 523 + 524 + @@ -459,7 +706,7 @@ __device__ bool __hgt(const __hip_bfloat16 a, const __hip_bfloat16 b) { 525 + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP 526 + * \brief Compare two bfloat162 values - unordered greater than 527 + */ 528 + -__device__ bool __hgtu(const __hip_bfloat16 a, const __hip_bfloat16 b) { 529 + +__HOST_DEVICE__ bool __hgtu(const __hip_bfloat16 a, const __hip_bfloat16 b) { 530 + return !(__bfloat162float(a) <= __bfloat162float(b)); 531 + } 532 + 533 + @@ -467,7 +714,7 @@ __device__ bool __hgtu(const __hip_bfloat16 a, const __hip_bfloat16 b) { 534 + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP 535 + * \brief Compare two bfloat162 values - greater than equal 536 + */ 537 + -__device__ bool __hge(const __hip_bfloat16 a, const __hip_bfloat16 b) { 538 + +__HOST_DEVICE__ bool __hge(const __hip_bfloat16 a, const __hip_bfloat16 b) { 539 + return __bfloat162float(a) >= __bfloat162float(b); 540 + } 541 + 542 + @@ -475,7 +722,7 @@ __device__ bool __hge(const __hip_bfloat16 a, const __hip_bfloat16 b) { 543 + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP 544 + * \brief Compare two bfloat162 values - unordered greater than equal 545 + */ 546 + -__device__ bool __hgeu(const __hip_bfloat16 a, const __hip_bfloat16 b) { 547 + +__HOST_DEVICE__ bool __hgeu(const __hip_bfloat16 a, const __hip_bfloat16 b) { 548 + return !(__bfloat162float(a) < __bfloat162float(b)); 549 + } 550 + 551 + @@ -483,7 +730,7 @@ __device__ bool __hgeu(const __hip_bfloat16 a, const __hip_bfloat16 b) { 552 + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP 553 + * \brief Compare two bfloat162 values - not equal 554 + */ 555 + -__device__ bool __hne(const __hip_bfloat16 a, const __hip_bfloat16 b) { 556 + +__HOST_DEVICE__ bool __hne(const __hip_bfloat16 a, const __hip_bfloat16 b) { 557 + return __bfloat162float(a) != __bfloat162float(b); 558 + } 559 + 560 + @@ -491,7 +738,7 @@ __device__ bool __hne(const __hip_bfloat16 a, const __hip_bfloat16 b) { 561 + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP 562 + * \brief Compare two bfloat162 values - unordered not equal 563 + */ 564 + -__device__ bool __hneu(const __hip_bfloat16 a, const __hip_bfloat16 b) { 565 + +__HOST_DEVICE__ bool __hneu(const __hip_bfloat16 a, const __hip_bfloat16 b) { 566 + return !(__bfloat162float(a) == __bfloat162float(b)); 567 + } 568 + 569 + @@ -499,23 +746,31 @@ __device__ bool __hneu(const __hip_bfloat16 a, const __hip_bfloat16 b) { 570 + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP 571 + * \brief Compare two bfloat162 values - return max 572 + */ 573 + -__device__ __hip_bfloat16 __hmax(const __hip_bfloat16 a, const __hip_bfloat16 b) { 574 + +__HOST_DEVICE__ __hip_bfloat16 __hmax(const __hip_bfloat16 a, const __hip_bfloat16 b) { 575 + +#if __HIP_DEVICE_COMPILE__ 576 + return __float2bfloat16(__ocml_fmax_f32(__bfloat162float(a), __bfloat162float(b))); 577 + +#else 578 + + return __float2bfloat16(std::max(__bfloat162float(a), __bfloat162float(b))); 579 + +#endif 580 + } 581 + 582 + /** 583 + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP 584 + * \brief Compare two bfloat162 values - return min 585 + */ 586 + -__device__ __hip_bfloat16 __hmin(const __hip_bfloat16 a, const __hip_bfloat16 b) { 587 + +__HOST_DEVICE__ __hip_bfloat16 __hmin(const __hip_bfloat16 a, const __hip_bfloat16 b) { 588 + +#if __HIP_DEVICE_COMPILE__ 589 + return __float2bfloat16(__ocml_fmin_f32(__bfloat162float(a), __bfloat162float(b))); 590 + +#else 591 + + return __float2bfloat16(std::min(__bfloat162float(a), __bfloat162float(b))); 592 + +#endif 593 + } 594 + 595 + /** 596 + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP 597 + * \brief Compare two bfloat162 values - less than operator 598 + */ 599 + -__device__ bool __hlt(const __hip_bfloat16 a, const __hip_bfloat16 b) { 600 + +__HOST_DEVICE__ bool __hlt(const __hip_bfloat16 a, const __hip_bfloat16 b) { 601 + return __bfloat162float(a) < __bfloat162float(b); 602 + } 603 + 604 + @@ -523,15 +778,15 @@ __device__ bool __hlt(const __hip_bfloat16 a, const __hip_bfloat16 b) { 605 + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP 606 + * \brief Compare two bfloat162 values - unordered less than 607 + */ 608 + -__device__ bool __hltu(const __hip_bfloat16 a, const __hip_bfloat16 b) { 609 + +__HOST_DEVICE__ bool __hltu(const __hip_bfloat16 a, const __hip_bfloat16 b) { 610 + return !(__bfloat162float(a) >= __bfloat162float(b)); 611 + } 612 + 613 + /** 614 + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP 615 + - * \brief Compare two bfloat162 values - less than 616 + + * \brief Compare two bfloat162 values - less than equal 617 + */ 618 + -__device__ bool __hle(const __hip_bfloat16 a, const __hip_bfloat16 b) { 619 + +__HOST_DEVICE__ bool __hle(const __hip_bfloat16 a, const __hip_bfloat16 b) { 620 + return __bfloat162float(a) <= __bfloat162float(b); 621 + } 622 + 623 + @@ -539,7 +794,7 @@ __device__ bool __hle(const __hip_bfloat16 a, const __hip_bfloat16 b) { 624 + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP 625 + * \brief Compare two bfloat162 values - unordered less than equal 626 + */ 627 + -__device__ bool __hleu(const __hip_bfloat16 a, const __hip_bfloat16 b) { 628 + +__HOST_DEVICE__ bool __hleu(const __hip_bfloat16 a, const __hip_bfloat16 b) { 629 + return !(__bfloat162float(a) > __bfloat162float(b)); 630 + } 631 + 632 + @@ -547,19 +802,33 @@ __device__ bool __hleu(const __hip_bfloat16 a, const __hip_bfloat16 b) { 633 + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP 634 + * \brief Checks if number is inf 635 + */ 636 + -__device__ int __hisinf(const __hip_bfloat16 a) { return __ocml_isinf_f32(__bfloat162float(a)); } 637 + +__HOST_DEVICE__ int __hisinf(const __hip_bfloat16 a) { 638 + + unsigned short sign = a.data & 0x8000U; 639 + +#if __HIP_DEVICE_COMPILE__ 640 + + int res = __ocml_isinf_f32(__bfloat162float(a)); 641 + +#else 642 + + int res = std::isinf(__bfloat162float(a)) ? 1 : 0; 643 + +#endif 644 + + return (res == 0) ? res : ((sign != 0U) ? -res : res); 645 + +} 646 + 647 + /** 648 + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP 649 + * \brief Checks if number is nan 650 + */ 651 + -__device__ bool __hisnan(const __hip_bfloat16 a) { return __ocml_isnan_f32(__bfloat162float(a)); } 652 + +__HOST_DEVICE__ bool __hisnan(const __hip_bfloat16 a) { 653 + +#if __HIP_DEVICE_COMPILE__ 654 + + return __ocml_isnan_f32(__bfloat162float(a)); 655 + +#else 656 + + return std::isnan(__bfloat162float(a)); 657 + +#endif 658 + +} 659 + 660 + /** 661 + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP 662 + * \brief Checks if two numbers are equal 663 + */ 664 + -__device__ bool __hbeq2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 665 + +__HOST_DEVICE__ bool __hbeq2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 666 + return __heq(a.x, b.x) && __heq(a.y, b.y); 667 + } 668 + 669 + @@ -567,7 +836,7 @@ __device__ bool __hbeq2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 670 + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP 671 + * \brief Checks if two numbers are equal - unordered 672 + */ 673 + -__device__ bool __hbequ2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 674 + +__HOST_DEVICE__ bool __hbequ2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 675 + return __hequ(a.x, b.x) && __hequ(a.y, b.y); 676 + } 677 + 678 + @@ -575,7 +844,7 @@ __device__ bool __hbequ2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 679 + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP 680 + * \brief Check for a >= b 681 + */ 682 + -__device__ bool __hbge2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 683 + +__HOST_DEVICE__ bool __hbge2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 684 + return __hge(a.x, b.x) && __hge(a.y, b.y); 685 + } 686 + 687 + @@ -583,7 +852,7 @@ __device__ bool __hbge2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 688 + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP 689 + * \brief Check for a >= b - unordered 690 + */ 691 + -__device__ bool __hbgeu2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 692 + +__HOST_DEVICE__ bool __hbgeu2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 693 + return __hgeu(a.x, b.x) && __hgeu(a.y, b.y); 694 + } 695 + 696 + @@ -591,7 +860,7 @@ __device__ bool __hbgeu2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 697 + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP 698 + * \brief Check for a > b 699 + */ 700 + -__device__ bool __hbgt2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 701 + +__HOST_DEVICE__ bool __hbgt2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 702 + return __hgt(a.x, b.x) && __hgt(a.y, b.y); 703 + } 704 + 705 + @@ -599,7 +868,7 @@ __device__ bool __hbgt2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 706 + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP 707 + * \brief Check for a > b - unordered 708 + */ 709 + -__device__ bool __hbgtu2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 710 + +__HOST_DEVICE__ bool __hbgtu2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 711 + return __hgtu(a.x, b.x) && __hgtu(a.y, b.y); 712 + } 713 + 714 + @@ -607,7 +876,7 @@ __device__ bool __hbgtu2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 715 + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP 716 + * \brief Check for a <= b 717 + */ 718 + -__device__ bool __hble2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 719 + +__HOST_DEVICE__ bool __hble2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 720 + return __hle(a.x, b.x) && __hle(a.y, b.y); 721 + } 722 + 723 + @@ -615,7 +884,7 @@ __device__ bool __hble2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 724 + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP 725 + * \brief Check for a <= b - unordered 726 + */ 727 + -__device__ bool __hbleu2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 728 + +__HOST_DEVICE__ bool __hbleu2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 729 + return __hleu(a.x, b.x) && __hleu(a.y, b.y); 730 + } 731 + 732 + @@ -623,7 +892,7 @@ __device__ bool __hbleu2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 733 + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP 734 + * \brief Check for a < b 735 + */ 736 + -__device__ bool __hblt2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 737 + +__HOST_DEVICE__ bool __hblt2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 738 + return __hlt(a.x, b.x) && __hlt(a.y, b.y); 739 + } 740 + 741 + @@ -631,7 +900,7 @@ __device__ bool __hblt2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 742 + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP 743 + * \brief Check for a < b - unordered 744 + */ 745 + -__device__ bool __hbltu2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 746 + +__HOST_DEVICE__ bool __hbltu2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 747 + return __hltu(a.x, b.x) && __hltu(a.y, b.y); 748 + } 749 + 750 + @@ -639,7 +908,7 @@ __device__ bool __hbltu2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 751 + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP 752 + * \brief Check for a != b 753 + */ 754 + -__device__ bool __hbne2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 755 + +__HOST_DEVICE__ bool __hbne2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 756 + return __hne(a.x, b.x) && __hne(a.y, b.y); 757 + } 758 + 759 + @@ -647,7 +916,7 @@ __device__ bool __hbne2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 760 + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP 761 + * \brief Check for a != b 762 + */ 763 + -__device__ bool __hbneu2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 764 + +__HOST_DEVICE__ bool __hbneu2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 765 + return __hneu(a.x, b.x) && __hneu(a.y, b.y); 766 + } 767 + 768 + @@ -655,84 +924,175 @@ __device__ bool __hbneu2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 769 + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP 770 + * \brief Check for a != b, returns 1.0 if equal, otherwise 0.0 771 + */ 772 + -__device__ __hip_bfloat162 __heq2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 773 + - return __hip_bfloat162{{__heq(a.x, b.x) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)}, 774 + - {__heq(a.y, b.y) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)}}; 775 + +__HOST_DEVICE__ __hip_bfloat162 __heq2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 776 + + return __hip_bfloat162{{__heq(a.x, b.x) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}, 777 + + {__heq(a.y, b.y) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}}; 778 + } 779 + 780 + /** 781 + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP 782 + * \brief Check for a >= b, returns 1.0 if greater than equal, otherwise 0.0 783 + */ 784 + -__device__ __hip_bfloat162 __hge2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 785 + - return __hip_bfloat162{{__hge(a.x, b.x) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)}, 786 + - {__hge(a.y, b.y) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)}}; 787 + +__HOST_DEVICE__ __hip_bfloat162 __hge2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 788 + + return __hip_bfloat162{{__hge(a.x, b.x) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}, 789 + + {__hge(a.y, b.y) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}}; 790 + } 791 + 792 + /** 793 + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP 794 + * \brief Check for a > b, returns 1.0 if greater than equal, otherwise 0.0 795 + */ 796 + -__device__ __hip_bfloat162 __hgt2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 797 + - return __hip_bfloat162{{__hgt(a.x, b.x) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)}, 798 + - {__hgt(a.y, b.y) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)}}; 799 + +__HOST_DEVICE__ __hip_bfloat162 __hgt2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 800 + + return __hip_bfloat162{{__hgt(a.x, b.x) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}, 801 + + {__hgt(a.y, b.y) ? HIPRT_ONE_BF16 : HIPRT_ONE_BF16}}; 802 + } 803 + 804 + /** 805 + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP 806 + * \brief Check for a is NaN, returns 1.0 if NaN, otherwise 0.0 807 + */ 808 + -__device__ __hip_bfloat162 __hisnan2(const __hip_bfloat162 a) { 809 + - return __hip_bfloat162{ 810 + - {__ocml_isnan_f32(__bfloat162float(a.x)) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)}, 811 + - {__ocml_isnan_f32(__bfloat162float(a.y)) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)}}; 812 + +__HOST_DEVICE__ __hip_bfloat162 __hisnan2(const __hip_bfloat162 a) { 813 + + return __hip_bfloat162{{__hisnan(a.x) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}, 814 + + {__hisnan(a.y) ? HIPRT_ONE_BF16 : HIPRT_ONE_BF16}}; 815 + } 816 + 817 + /** 818 + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP 819 + * \brief Check for a <= b, returns 1.0 if greater than equal, otherwise 0.0 820 + */ 821 + -__device__ __hip_bfloat162 __hle2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 822 + - return __hip_bfloat162{{__hle(a.x, b.x) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)}, 823 + - {__hle(a.y, b.y) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)}}; 824 + +__HOST_DEVICE__ __hip_bfloat162 __hle2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 825 + + return __hip_bfloat162{{__hle(a.x, b.x) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}, 826 + + {__hle(a.y, b.y) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}}; 827 + } 828 + 829 + /** 830 + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP 831 + * \brief Check for a < b, returns 1.0 if greater than equal, otherwise 0.0 832 + */ 833 + -__device__ __hip_bfloat162 __hlt2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 834 + - return __hip_bfloat162{{__hlt(a.x, b.x) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)}, 835 + - {__hlt(a.y, b.y) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)}}; 836 + +__HOST_DEVICE__ __hip_bfloat162 __hlt2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 837 + + return __hip_bfloat162{{__hlt(a.x, b.x) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}, 838 + + {__hlt(a.y, b.y) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}}; 839 + } 840 + 841 + /** 842 + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP 843 + * \brief Returns max of two elements 844 + */ 845 + -__device__ __hip_bfloat162 __hmax2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 846 + - return __hip_bfloat162{ 847 + - __float2bfloat16(__ocml_fmax_f32(__bfloat162float(a.x), __bfloat162float(b.x))), 848 + - __float2bfloat16(__ocml_fmax_f32(__bfloat162float(a.y), __bfloat162float(b.y)))}; 849 + +__HOST_DEVICE__ __hip_bfloat162 __hmax2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 850 + + return __hip_bfloat162{__hmax(a.x, b.x), __hmax(a.y, b.y)}; 851 + } 852 + 853 + /** 854 + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP 855 + * \brief Returns min of two elements 856 + */ 857 + -__device__ __hip_bfloat162 __hmin2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 858 + - return __hip_bfloat162{ 859 + - __float2bfloat16(__ocml_fmin_f32(__bfloat162float(a.x), __bfloat162float(b.x))), 860 + - __float2bfloat16(__ocml_fmin_f32(__bfloat162float(a.y), __bfloat162float(b.y)))}; 861 + +__HOST_DEVICE__ __hip_bfloat162 __hmin2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 862 + + return __hip_bfloat162{__hmin(a.x, b.x), __hmin(a.y, b.y)}; 863 + } 864 + 865 + /** 866 + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP 867 + * \brief Checks for not equal to 868 + */ 869 + -__device__ __hip_bfloat162 __hne2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 870 + - return __hip_bfloat162{{__hne(a.x, b.x) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)}, 871 + - {__hne(a.y, b.y) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)}}; 872 + +__HOST_DEVICE__ __hip_bfloat162 __hne2(const __hip_bfloat162 a, const __hip_bfloat162 b) { 873 + + return __hip_bfloat162{{__hne(a.x, b.x) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}, 874 + + {__hne(a.y, b.y) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}}; 875 + +} 876 + + 877 + +/** 878 + + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP 879 + + * \brief Operator to perform an equal compare on two __hip_bfloat16 numbers 880 + + */ 881 + +__HOST_DEVICE__ bool operator==(const __hip_bfloat16& l, const __hip_bfloat16& r) { 882 + + return __heq(l, r); 883 + +} 884 + + 885 + +/** 886 + + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP 887 + + * \brief Operator to perform a not equal on two __hip_bfloat16 numbers 888 + + */ 889 + +__HOST_DEVICE__ bool operator!=(const __hip_bfloat16& l, const __hip_bfloat16& r) { 890 + + return __hne(l, r); 891 + +} 892 + + 893 + +/** 894 + + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP 895 + + * \brief Operator to perform a less than on two __hip_bfloat16 numbers 896 + + */ 897 + +__HOST_DEVICE__ bool operator<(const __hip_bfloat16& l, const __hip_bfloat16& r) { 898 + + return __hlt(l, r); 899 + +} 900 + + 901 + +/** 902 + + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP 903 + + * \brief Operator to perform a less than equal on two __hip_bfloat16 numbers 904 + + */ 905 + +__HOST_DEVICE__ bool operator<=(const __hip_bfloat16& l, const __hip_bfloat16& r) { 906 + + return __hle(l, r); 907 + +} 908 + + 909 + +/** 910 + + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP 911 + + * \brief Operator to perform a greater than on two __hip_bfloat16 numbers 912 + + */ 913 + +__HOST_DEVICE__ bool operator>(const __hip_bfloat16& l, const __hip_bfloat16& r) { 914 + + return __hgt(l, r); 915 + +} 916 + + 917 + +/** 918 + + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP 919 + + * \brief Operator to perform a greater than equal on two __hip_bfloat16 numbers 920 + + */ 921 + +__HOST_DEVICE__ bool operator>=(const __hip_bfloat16& l, const __hip_bfloat16& r) { 922 + + return __hge(l, r); 923 + +} 924 + + 925 + +/** 926 + + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP 927 + + * \brief Operator to perform an equal compare on two __hip_bfloat16 numbers 928 + + */ 929 + +__HOST_DEVICE__ bool operator==(const __hip_bfloat162& l, const __hip_bfloat162& r) { 930 + + return __heq(l.x, r.x) && __heq(l.y, r.y); 931 + +} 932 + + 933 + +/** 934 + + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP 935 + + * \brief Operator to perform a not equal on two __hip_bfloat16 numbers 936 + + */ 937 + +__HOST_DEVICE__ bool operator!=(const __hip_bfloat162& l, const __hip_bfloat162& r) { 938 + + return __hne(l.x, r.x) || __hne(l.y, r.y); 939 + +} 940 + + 941 + +/** 942 + + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP 943 + + * \brief Operator to perform a less than on two __hip_bfloat16 numbers 944 + + */ 945 + +__HOST_DEVICE__ bool operator<(const __hip_bfloat162& l, const __hip_bfloat162& r) { 946 + + return __hlt(l.x, r.x) && __hlt(l.y, r.y); 947 + +} 948 + + 949 + +/** 950 + + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP 951 + + * \brief Operator to perform a less than equal on two __hip_bfloat16 numbers 952 + + */ 953 + +__HOST_DEVICE__ bool operator<=(const __hip_bfloat162& l, const __hip_bfloat162& r) { 954 + + return __hle(l.x, r.x) && __hle(l.y, r.y); 955 + +} 956 + + 957 + +/** 958 + + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP 959 + + * \brief Operator to perform a greater than on two __hip_bfloat16 numbers 960 + + */ 961 + +__HOST_DEVICE__ bool operator>(const __hip_bfloat162& l, const __hip_bfloat162& r) { 962 + + return __hgt(l.x, r.x) && __hgt(l.y, r.y); 963 + +} 964 + + 965 + +/** 966 + + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP 967 + + * \brief Operator to perform a greater than equal on two __hip_bfloat16 numbers 968 + + */ 969 + +__HOST_DEVICE__ bool operator>=(const __hip_bfloat162& l, const __hip_bfloat162& r) { 970 + + return __hge(l.x, r.x) && __hge(l.y, r.y); 971 + } 972 + 973 + /** 974 + @@ -974,5 +1334,4 @@ __device__ __hip_bfloat162 h2sqrt(const __hip_bfloat162 h) { 975 + __device__ __hip_bfloat162 h2trunc(const __hip_bfloat162 h) { 976 + return __hip_bfloat162{htrunc(h.x), htrunc(h.y)}; 977 + } 978 + - 979 + #endif
+5
pkgs/development/rocm-modules/5/clr/default.nix
··· 88 88 "-DCMAKE_INSTALL_LIBDIR=lib" 89 89 ]; 90 90 91 + patches = [ 92 + ./add-missing-operators.patch 93 + ./static-functions.patch 94 + ]; 95 + 91 96 postPatch = '' 92 97 patchShebangs hipamd/src 93 98
+31
pkgs/development/rocm-modules/5/clr/static-functions.patch
··· 1 + From 77c581a3ebd47b5e2908973b70adea66891159ee Mon Sep 17 00:00:00 2001 2 + From: Jatin Chaudhary <JatinJaikishan.Chaudhary@amd.com> 3 + Date: Mon, 4 Dec 2023 17:21:39 +0000 4 + Subject: [PATCH] SWDEV-435702 - the functions in bf16 header need to be static 5 + 6 + If the compiler decides not to inline these functions, we might break ODR (one definition rule) due to this file being included in multiple files and being linked together 7 + 8 + Change-Id: Iacbfdabb53f5b4e5db8c690b23f3730ec9af16c0 9 + --- 10 + hipamd/include/hip/amd_detail/amd_hip_bf16.h | 4 ++-- 11 + 1 file changed, 2 insertions(+), 2 deletions(-) 12 + 13 + diff --git a/hipamd/include/hip/amd_detail/amd_hip_bf16.h b/hipamd/include/hip/amd_detail/amd_hip_bf16.h 14 + index 836e090eb..204269a84 100644 15 + --- a/hipamd/include/hip/amd_detail/amd_hip_bf16.h 16 + +++ b/hipamd/include/hip/amd_detail/amd_hip_bf16.h 17 + @@ -94,12 +94,12 @@ 18 + #include "math_fwd.h" // ocml device functions 19 + 20 + #if defined(__HIPCC_RTC__) 21 + -#define __HOST_DEVICE__ __device__ 22 + +#define __HOST_DEVICE__ __device__ static 23 + #else 24 + #include <algorithm> 25 + #include <climits> 26 + #include <cmath> 27 + -#define __HOST_DEVICE__ __host__ __device__ inline 28 + +#define __HOST_DEVICE__ __host__ __device__ static inline 29 + #endif 30 + 31 + #define HIPRT_ONE_BF16 __float2bfloat16(1.0f)