Ginkgo — generated from a branch based on main. Ginkgo version 1.10.0
A numerical linear algebra library targeting many-core architectures
Loading...
Searching...
No Matches
bfloat16.hpp
1// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#ifndef GKO_PUBLIC_CORE_BASE_BFLOAT16_HPP_
6#define GKO_PUBLIC_CORE_BASE_BFLOAT16_HPP_
7
8
9#include <climits>
10#include <complex>
11#include <cstdint>
12#include <cstring>
13#include <type_traits>
14
15#include <ginkgo/core/base/half.hpp>
16
17
18class __nv_bfloat16;
19class hip_bfloat16;
20class __hip_bfloat16;
21
22
23namespace gko {
24
25
26class bfloat16;
27
28
29namespace detail {
30template <>
31struct basic_float_traits<bfloat16> {
32 using type = bfloat16;
33 static constexpr int sign_bits = 1;
34 static constexpr int significand_bits = 7;
35 static constexpr int exponent_bits = 8;
36 static constexpr bool rounds_to_nearest = true;
37};
38
39template <>
40struct basic_float_traits<__nv_bfloat16> {
41 using type = __nv_bfloat16;
42 static constexpr int sign_bits = 1;
43 static constexpr int significand_bits = 7;
44 static constexpr int exponent_bits = 8;
45 static constexpr bool rounds_to_nearest = true;
46};
47
48template <>
49struct basic_float_traits<hip_bfloat16> {
50 using type = hip_bfloat16;
51 static constexpr int sign_bits = 1;
52 static constexpr int significand_bits = 7;
53 static constexpr int exponent_bits = 8;
54 static constexpr bool rounds_to_nearest = true;
55};
56
57template <>
58struct basic_float_traits<__hip_bfloat16> {
59 using type = __hip_bfloat16;
60 static constexpr int sign_bits = 1;
61 static constexpr int significand_bits = 7;
62 static constexpr int exponent_bits = 8;
63 static constexpr bool rounds_to_nearest = true;
64};
65
66
67} // namespace detail
68
69
76class alignas(std::uint16_t) bfloat16 {
77public:
78 // create bfloat16 value from the bits directly.
79 static constexpr bfloat16 create_from_bits(
80 const std::uint16_t& bits) noexcept
81 {
82 bfloat16 result;
83 result.data_ = bits;
84 return result;
85 }
86
87 // TODO: NVHPC (host side) may not use zero initialization for the data
88 // member by default constructor in some cases. Not sure whether it is
89 // caused by something else in jacobi or isai.
90 constexpr bfloat16() noexcept : data_(0){};
91
92 template <typename T,
93 typename = std::enable_if_t<std::is_scalar<T>::value ||
94 std::is_same_v<T, half>>>
95 bfloat16(const T& val) : data_(0)
96 {
97 this->float2bfloat16(static_cast<float>(val));
98 }
99
100 template <typename V>
101 bfloat16& operator=(const V& val)
102 {
103 this->float2bfloat16(static_cast<float>(val));
104 return *this;
105 }
106
107 operator float() const noexcept
108 {
109 const auto bits = bfloat162float(data_);
110 float ans(0);
111 std::memcpy(&ans, &bits, sizeof(float));
112 return ans;
113 }
114
115 // can not use bfloat16 operator _op(const bfloat16) for bfloat16 + bfloat16
116 // operation will cast it to float and then do float operation such that it
117 // becomes float in the end.
118#define BFLOAT16_OPERATOR(_op, _opeq) \
119 friend bfloat16 operator _op(const bfloat16& lhf, const bfloat16& rhf) \
120 { \
121 return static_cast<bfloat16>(static_cast<float>(lhf) \
122 _op static_cast<float>(rhf)); \
123 } \
124 bfloat16& operator _opeq(const bfloat16& hf) \
125 { \
126 auto result = *this _op hf; \
127 data_ = result.data_; \
128 return *this; \
129 }
130
131 BFLOAT16_OPERATOR(+, +=)
132 BFLOAT16_OPERATOR(-, -=)
133 BFLOAT16_OPERATOR(*, *=)
134 BFLOAT16_OPERATOR(/, /=)
135
136#undef BFLOAT16_OPERATOR
137
138 // Do operation with different type
139 // If it is floating point, using floating point as type.
140 // If it is bfloat16, using float as type.
141 // If it is integer, using bfloat16 as type.
142#define BFLOAT16_FRIEND_OPERATOR(_op, _opeq) \
143 template <typename T> \
144 friend std::enable_if_t< \
145 !std::is_same<T, bfloat16>::value && \
146 (std::is_scalar<T>::value || std::is_same_v<T, half>), \
147 std::conditional_t< \
148 std::is_floating_point<T>::value, T, \
149 std::conditional_t<std::is_same_v<T, half>, float, bfloat16>>> \
150 operator _op(const bfloat16& hf, const T& val) \
151 { \
152 using type = \
153 std::conditional_t<std::is_floating_point<T>::value, T, bfloat16>; \
154 auto result = static_cast<type>(hf); \
155 result _opeq static_cast<type>(val); \
156 return result; \
157 } \
158 template <typename T> \
159 friend std::enable_if_t< \
160 !std::is_same<T, bfloat16>::value && \
161 (std::is_scalar<T>::value || std::is_same_v<T, half>), \
162 std::conditional_t< \
163 std::is_floating_point<T>::value, T, \
164 std::conditional_t<std::is_same_v<T, half>, float, bfloat16>>> \
165 operator _op(const T& val, const bfloat16& hf) \
166 { \
167 using type = \
168 std::conditional_t<std::is_floating_point<T>::value, T, bfloat16>; \
169 auto result = static_cast<type>(val); \
170 result _opeq static_cast<type>(hf); \
171 return result; \
172 }
173
174 BFLOAT16_FRIEND_OPERATOR(+, +=)
175 BFLOAT16_FRIEND_OPERATOR(-, -=)
176 BFLOAT16_FRIEND_OPERATOR(*, *=)
177 BFLOAT16_FRIEND_OPERATOR(/, /=)
178
179#undef BFLOAT16_FRIEND_OPERATOR
180
181 // the negative
182 bfloat16 operator-() const
183 {
184 auto val = 0.0f - *this;
185 return static_cast<bfloat16>(val);
186 }
187
188private:
189 using f16_traits = detail::float_traits<bfloat16>;
190 using f32_traits = detail::float_traits<float>;
191
192 void float2bfloat16(const float& val) noexcept
193 {
194 std::uint32_t bit_val(0);
195 std::memcpy(&bit_val, &val, sizeof(float));
196 data_ = float2bfloat16(bit_val);
197 }
198
199 static constexpr std::uint16_t float2bfloat16(std::uint32_t data_) noexcept
200 {
201 using conv = detail::precision_converter<float, bfloat16>;
202 if (f32_traits::is_inf(data_)) {
203 return conv::shift_sign(data_) | f16_traits::exponent_mask;
204 } else if (f32_traits::is_nan(data_)) {
205 return conv::shift_sign(data_) | f16_traits::exponent_mask |
206 f16_traits::significand_mask;
207 } else {
208 const auto exp = conv::shift_exponent(data_);
209 if (f16_traits::is_inf(exp)) {
210 return conv::shift_sign(data_) | exp;
211 } else if (f16_traits::is_denom(exp)) {
212 // TODO: handle denormals
213 return conv::shift_sign(data_);
214 } else {
215 // Rounding to even
216 const auto result = conv::shift_sign(data_) | exp |
217 conv::shift_significand(data_);
218 const auto tail =
219 data_ & static_cast<f32_traits::bits_type>(
220 (1 << conv::significand_offset) - 1);
221
222 constexpr auto bfloat16 = static_cast<f32_traits::bits_type>(
223 1 << (conv::significand_offset - 1));
224 return result + (tail > bfloat16 ||
225 ((tail == bfloat16) && (result & 1)));
226 }
227 }
228 }
229
230 static constexpr std::uint32_t bfloat162float(std::uint16_t data_) noexcept
231 {
232 using conv = detail::precision_converter<bfloat16, float>;
233 if (f16_traits::is_inf(data_)) {
234 return conv::shift_sign(data_) | f32_traits::exponent_mask;
235 } else if (f16_traits::is_nan(data_)) {
236 return conv::shift_sign(data_) | f32_traits::exponent_mask |
237 f32_traits::significand_mask;
238 } else if (f16_traits::is_denom(data_)) {
239 // TODO: handle denormals
240 return conv::shift_sign(data_);
241 } else {
242 return conv::shift_sign(data_) | conv::shift_exponent(data_) |
243 conv::shift_significand(data_);
244 }
245 }
246
247 std::uint16_t data_;
248};
249
250
251} // namespace gko
252
253
254namespace std {
255
256
257template <>
258class complex<gko::bfloat16> {
259public:
260 using value_type = gko::bfloat16;
261
262 complex(const value_type& real = value_type(0.f),
263 const value_type& imag = value_type(0.f))
264 : real_(real), imag_(imag)
265 {}
266
267 template <
268 typename T, typename U,
269 typename = std::enable_if_t<
270 (std::is_scalar<T>::value || std::is_same_v<T, gko::half>)&&(
271 std::is_scalar<U>::value || std::is_same_v<U, gko::half>)>>
272 explicit complex(const T& real, const U& imag)
273 : real_(static_cast<value_type>(real)),
274 imag_(static_cast<value_type>(imag))
275 {}
276
277 template <typename T,
278 typename = std::enable_if_t<std::is_scalar<T>::value ||
279 std::is_same_v<T, gko::half>>>
280 complex(const T& real)
281 : real_(static_cast<value_type>(real)),
282 imag_(static_cast<value_type>(0.f))
283 {}
284
285 // When using complex(real, imag), MSVC with CUDA try to recognize the
286 // complex is a member not constructor.
287 template <typename T,
288 typename = std::enable_if_t<std::is_scalar<T>::value ||
289 std::is_same_v<T, gko::half>>>
290 explicit complex(const complex<T>& other)
291 : real_(static_cast<value_type>(other.real())),
292 imag_(static_cast<value_type>(other.imag()))
293 {}
294
295 value_type real() const noexcept { return real_; }
296
297 value_type imag() const noexcept { return imag_; }
298
299 operator std::complex<float>() const noexcept
300 {
301 return std::complex<float>(static_cast<float>(real_),
302 static_cast<float>(imag_));
303 }
304
305 template <typename V>
306 complex& operator=(const V& val)
307 {
308 real_ = val;
309 imag_ = value_type();
310 return *this;
311 }
312
313 template <typename V>
314 complex& operator=(const std::complex<V>& val)
315 {
316 real_ = val.real();
317 imag_ = val.imag();
318 return *this;
319 }
320
321 complex& operator+=(const value_type& real)
322 {
323 real_ += real;
324 return *this;
325 }
326
327 complex& operator-=(const value_type& real)
328 {
329 real_ -= real;
330 return *this;
331 }
332
333 complex& operator*=(const value_type& real)
334 {
335 real_ *= real;
336 imag_ *= real;
337 return *this;
338 }
339
340 complex& operator/=(const value_type& real)
341 {
342 real_ /= real;
343 imag_ /= real;
344 return *this;
345 }
346
347 template <typename T>
348 complex& operator+=(const complex<T>& val)
349 {
350 real_ += val.real();
351 imag_ += val.imag();
352 return *this;
353 }
354
355 template <typename T>
356 complex& operator-=(const complex<T>& val)
357 {
358 real_ -= val.real();
359 imag_ -= val.imag();
360 return *this;
361 }
362
363 template <typename T>
364 complex& operator*=(const complex<T>& val)
365 {
366 auto val_f = static_cast<std::complex<float>>(val);
367 auto result_f = static_cast<std::complex<float>>(*this);
368 result_f *= val_f;
369 real_ = result_f.real();
370 imag_ = result_f.imag();
371 return *this;
372 }
373
374 template <typename T>
375 complex& operator/=(const complex<T>& val)
376 {
377 auto val_f = static_cast<std::complex<float>>(val);
378 auto result_f = static_cast<std::complex<float>>(*this);
379 result_f /= val_f;
380 real_ = result_f.real();
381 imag_ = result_f.imag();
382 return *this;
383 }
384
385#define COMPLEX_BFLOAT16_OPERATOR(_op, _opeq) \
386 friend complex operator _op(const complex& lhf, const complex& rhf) \
387 { \
388 auto a = lhf; \
389 a _opeq rhf; \
390 return a; \
391 }
392
393 COMPLEX_BFLOAT16_OPERATOR(+, +=)
394 COMPLEX_BFLOAT16_OPERATOR(-, -=)
395 COMPLEX_BFLOAT16_OPERATOR(*, *=)
396 COMPLEX_BFLOAT16_OPERATOR(/, /=)
397
398#undef COMPLEX_BFLOAT16_OPERATOR
399
400private:
401 value_type real_;
402 value_type imag_;
403};
404
405
406template <>
407struct numeric_limits<gko::bfloat16> {
408 static constexpr bool is_specialized{true};
409 static constexpr bool is_signed{true};
410 static constexpr bool is_integer{false};
411 static constexpr bool is_exact{false};
412 static constexpr bool is_bounded{true};
413 static constexpr bool is_modulo{false};
414 static constexpr int digits{
415 gko::detail::float_traits<gko::bfloat16>::significand_bits + 1};
416 // 3/10 is approx. log_10(2)
417 static constexpr int digits10{digits * 3 / 10};
418
419 static constexpr gko::bfloat16 epsilon()
420 {
421 constexpr auto bits = static_cast<std::uint16_t>(0b0'01111000'0000000u);
422 return gko::bfloat16::create_from_bits(bits);
423 }
424
425 static constexpr gko::bfloat16 infinity()
426 {
427 constexpr auto bits = static_cast<std::uint16_t>(0b0'11111111'0000000u);
428 return gko::bfloat16::create_from_bits(bits);
429 }
430
431 static constexpr gko::bfloat16 min()
432 {
433 constexpr auto bits = static_cast<std::uint16_t>(0b0'00000001'0000000u);
434 return gko::bfloat16::create_from_bits(bits);
435 }
436
437 static constexpr gko::bfloat16 max()
438 {
439 constexpr auto bits = static_cast<std::uint16_t>(0b0'11111110'1111111u);
440 return gko::bfloat16::create_from_bits(bits);
441 }
442
443 static constexpr gko::bfloat16 lowest()
444 {
445 constexpr auto bits = static_cast<std::uint16_t>(0b1'11111110'1111111u);
446 return gko::bfloat16::create_from_bits(bits);
447 };
448
449 static constexpr gko::bfloat16 quiet_NaN()
450 {
451 constexpr auto bits = static_cast<std::uint16_t>(0b0'11111111'1111111u);
452 return gko::bfloat16::create_from_bits(bits);
453 }
454};
455
456
457// complex using a template on operator= for any kind of complex<T>, so we can
458// do full specialization for bfloat16
459template <>
460inline complex<double>& complex<double>::operator=(
461 const std::complex<gko::bfloat16>& a)
462{
463 complex<double> t(a.real(), a.imag());
464 operator=(t);
465 return *this;
466}
467
468
469// For MSVC
470template <>
471inline complex<float>& complex<float>::operator=(
472 const std::complex<gko::bfloat16>& a)
473{
474 complex<float> t(a.real(), a.imag());
475 operator=(t);
476 return *this;
477}
478
479
480} // namespace std
481
482
#endif  // GKO_PUBLIC_CORE_BASE_BFLOAT16_HPP_
A class providing basic support for the bfloat16 floating-point format.
Definition bfloat16.hpp:76
The Ginkgo namespace.
Definition abstract_factory.hpp:20
constexpr auto imag(const T &x)
Returns the imaginary part of the object.
Definition math.hpp:916
constexpr T min(const T &x, const T &y)
Returns the smaller of the arguments.
Definition math.hpp:750
constexpr T max(const T &x, const T &y)
Returns the larger of the arguments.
Definition math.hpp:732
constexpr auto real(const T &x)
Returns the real part of the object.
Definition math.hpp:900
STL namespace.