5#ifndef GKO_PUBLIC_CORE_BASE_BFLOAT16_HPP_
6#define GKO_PUBLIC_CORE_BASE_BFLOAT16_HPP_
15#include <ginkgo/core/base/half.hpp>
// Bit-layout description for bfloat16: 1 sign bit, 8 exponent bits and
// 7 significand bits (16 bits in total), flagged as rounding to nearest.
31struct basic_float_traits<bfloat16> {
32 using type = bfloat16;
33 static constexpr int sign_bits = 1;
34 static constexpr int significand_bits = 7;
35 static constexpr int exponent_bits = 8;
36 static constexpr bool rounds_to_nearest =
true;
// Same 1/8/7 bit layout as basic_float_traits<bfloat16>, declared for the
// type __nv_bfloat16 (presumably the CUDA device type - confirm).
40struct basic_float_traits<__nv_bfloat16> {
41 using type = __nv_bfloat16;
42 static constexpr int sign_bits = 1;
43 static constexpr int significand_bits = 7;
44 static constexpr int exponent_bits = 8;
45 static constexpr bool rounds_to_nearest =
true;
// Same 1/8/7 bit layout as basic_float_traits<bfloat16>, declared for the
// type hip_bfloat16 (presumably a HIP device type - confirm).
49struct basic_float_traits<hip_bfloat16> {
50 using type = hip_bfloat16;
51 static constexpr int sign_bits = 1;
52 static constexpr int significand_bits = 7;
53 static constexpr int exponent_bits = 8;
54 static constexpr bool rounds_to_nearest =
true;
// Same 1/8/7 bit layout as basic_float_traits<bfloat16>, declared for the
// type __hip_bfloat16 (presumably a HIP device type - confirm).
58struct basic_float_traits<__hip_bfloat16> {
59 using type = __hip_bfloat16;
60 static constexpr int sign_bits = 1;
61 static constexpr int significand_bits = 7;
62 static constexpr int exponent_bits = 8;
63 static constexpr bool rounds_to_nearest =
true;
// Value type for bfloat16 numbers; the class is aligned like std::uint16_t.
76class alignas(std::uint16_t) bfloat16 {
// Named constructor taking a raw 16-bit pattern. The body is not visible in
// this listing; presumably the bits are stored without conversion - confirm.
79 static constexpr bfloat16 create_from_bits(
80 const std::uint16_t& bits)
noexcept
// Default constructor: initializes the stored bit pattern data_ to 0.
90 constexpr bfloat16() noexcept : data_(0){};
// Converting constructor, enabled only when T is a scalar type or gko::half.
// data_ is first set to 0, then the value is cast to float and converted
// through float2bfloat16.
93 typename = std::enable_if_t<std::is_scalar<T>::value ||
94 std::is_same_v<T, half>>>
95 bfloat16(
const T& val) : data_(0)
97 this->float2bfloat16(
static_cast<float>(val));
// Assignment from an arbitrary type V: the value is cast to float and then
// converted through float2bfloat16. V is unconstrained in the visible lines;
// the return statement is not visible in this listing.
100 template <
typename V>
101 bfloat16& operator=(
const V& val)
103 this->float2bfloat16(
static_cast<float>(val));
// Implicit conversion to float: bfloat162float expands the stored 16 bits
// into a 32-bit pattern, which is then copied byte-wise into `ans` with
// std::memcpy. The declaration and return of `ans` are not visible in this
// listing (presumably a float - confirm).
107 operator float()
const noexcept
109 const auto bits = bfloat162float(data_);
111 std::memcpy(&ans, &bits,
sizeof(
float));
// Generates the binary operator `_op` and the compound assignment `_opeq` for
// two bfloat16 operands, and is instantiated below for + - * / before being
// undefined. The binary form converts both operands to float, applies `_op`
// there and casts the result back to bfloat16. The compound form evaluates
// `*this _op hf` and copies the resulting bit pattern into data_.
118#define BFLOAT16_OPERATOR(_op, _opeq) \
119 friend bfloat16 operator _op(const bfloat16& lhf, const bfloat16& rhf) \
121 return static_cast<bfloat16>(static_cast<float>(lhf) \
122 _op static_cast<float>(rhf)); \
124 bfloat16& operator _opeq(const bfloat16& hf) \
126 auto result = *this _op hf; \
127 data_ = result.data_; \
131 BFLOAT16_OPERATOR(+, +=)
132 BFLOAT16_OPERATOR(-, -=)
133 BFLOAT16_OPERATOR(*, *=)
134 BFLOAT16_OPERATOR(/, /=)
136#undef BFLOAT16_OPERATOR
// Generates mixed-type binary operators between bfloat16 and another type T,
// in both operand orders, and is instantiated below for + - * / before being
// undefined. The overloads participate only when T is not bfloat16 and is
// either a scalar type or gko::half.
// Declared return type: T when T is a floating-point type, float when T is
// gko::half, bfloat16 otherwise.
// Working type `type`: T when T is a floating-point type, bfloat16 otherwise.
// The left operand is cast to `type` and the right one is applied with `_opeq`.
// NOTE(review): the visible `type` alias has no special case for gko::half, so
// the return type (float) and the working type can differ for T = half if
// std::is_floating_point<half> is false - confirm this is intended.
142#define BFLOAT16_FRIEND_OPERATOR(_op, _opeq) \
143 template <typename T> \
144 friend std::enable_if_t< \
145 !std::is_same<T, bfloat16>::value && \
146 (std::is_scalar<T>::value || std::is_same_v<T, half>), \
147 std::conditional_t< \
148 std::is_floating_point<T>::value, T, \
149 std::conditional_t<std::is_same_v<T, half>, float, bfloat16>>> \
150 operator _op(const bfloat16& hf, const T& val) \
153 std::conditional_t<std::is_floating_point<T>::value, T, bfloat16>; \
154 auto result = static_cast<type>(hf); \
155 result _opeq static_cast<type>(val); \
158 template <typename T> \
159 friend std::enable_if_t< \
160 !std::is_same<T, bfloat16>::value && \
161 (std::is_scalar<T>::value || std::is_same_v<T, half>), \
162 std::conditional_t< \
163 std::is_floating_point<T>::value, T, \
164 std::conditional_t<std::is_same_v<T, half>, float, bfloat16>>> \
165 operator _op(const T& val, const bfloat16& hf) \
168 std::conditional_t<std::is_floating_point<T>::value, T, bfloat16>; \
169 auto result = static_cast<type>(val); \
170 result _opeq static_cast<type>(hf); \
174 BFLOAT16_FRIEND_OPERATOR(+, +=)
175 BFLOAT16_FRIEND_OPERATOR(-, -=)
176 BFLOAT16_FRIEND_OPERATOR(*, *=)
177 BFLOAT16_FRIEND_OPERATOR(/, /=)
179#undef BFLOAT16_FRIEND_OPERATOR
// Unary minus: evaluates `0.0f - *this` and casts the result to bfloat16.
// NOTE(review): implemented as a subtraction rather than a sign-bit flip, so
// negating +0 presumably yields +0 instead of -0 - confirm this is acceptable.
182 bfloat16 operator-()
const
184 auto val = 0.0f - *
this;
185 return static_cast<bfloat16
>(val);
// Trait shorthands used by the conversion helpers. Despite its name,
// f16_traits refers to the bfloat16 traits.
189 using f16_traits = detail::float_traits<bfloat16>;
190 using f32_traits = detail::float_traits<float>;
// Stores `val` into this object: the float's bytes are copied into a 32-bit
// integer with std::memcpy, and the bit-level overload of float2bfloat16
// produces the 16-bit pattern assigned to data_.
192 void float2bfloat16(
const float& val)
noexcept
194 std::uint32_t bit_val(0);
195 std::memcpy(&bit_val, &val,
sizeof(
float));
196 data_ = float2bfloat16(bit_val);
// Converts the 32-bit pattern of a float into a 16-bit bfloat16 pattern.
// The parameter is named data_ like the member it ends up in; since the
// function is static it only ever refers to the argument.
199 static constexpr std::uint16_t float2bfloat16(std::uint32_t data_)
noexcept
201 using conv = detail::precision_converter<float, bfloat16>;
// Infinity keeps its sign and gets the bfloat16 exponent mask.
202 if (f32_traits::is_inf(data_)) {
203 return conv::shift_sign(data_) | f16_traits::exponent_mask;
204 }
// NaN keeps its sign and gets the exponent and significand masks set, so the
// original NaN payload is not carried over.
else if (f32_traits::is_nan(data_)) {
205 return conv::shift_sign(data_) | f16_traits::exponent_mask |
206 f16_traits::significand_mask;
// Remaining inputs: move the exponent into the bfloat16 position first.
208 const auto exp = conv::shift_exponent(data_);
// When the shifted exponent is classified as infinity, return sign | exp.
209 if (f16_traits::is_inf(exp)) {
210 return conv::shift_sign(data_) | exp;
211 }
// When is_denom(exp) holds, only the sign bit is returned (presumably a
// flush of denormals to signed zero - confirm against float_traits).
else if (f16_traits::is_denom(exp)) {
213 return conv::shift_sign(data_);
// Truncated result, before rounding.
216 const auto result = conv::shift_sign(data_) | exp |
217 conv::shift_significand(data_);
// Low input bits below conv::significand_offset, i.e. those masked out here.
// The declaration this expression initializes is not visible in this
// listing; presumably it is the `tail` used in the return statement.
219 data_ &
static_cast<f32_traits::bits_type
>(
220 (1 << conv::significand_offset) - 1);
// Halfway point of the masked range: 1 << (significand_offset - 1).
// NOTE(review): this local is named `bfloat16`, shadowing the class name
// inside the function; a name such as `halfway` would be clearer.
222 constexpr auto bfloat16 =
static_cast<f32_traits::bits_type
>(
223 1 << (conv::significand_offset - 1));
// Adds 1 when the tail is above the halfway point, or exactly at it while the
// truncated result is odd (round to nearest, ties to even).
224 return result + (tail > bfloat16 ||
225 ((tail == bfloat16) && (result & 1)));
// Expands a 16-bit bfloat16 pattern into the 32-bit pattern of a float.
230 static constexpr std::uint32_t bfloat162float(std::uint16_t data_)
noexcept
232 using conv = detail::precision_converter<bfloat16, float>;
// Infinity keeps its sign and gets the float exponent mask.
233 if (f16_traits::is_inf(data_)) {
234 return conv::shift_sign(data_) | f32_traits::exponent_mask;
235 }
// NaN keeps its sign and gets the float exponent and significand masks set.
else if (f16_traits::is_nan(data_)) {
236 return conv::shift_sign(data_) | f32_traits::exponent_mask |
237 f32_traits::significand_mask;
238 }
// When is_denom(data_) holds, only the sign bit is returned (presumably
// denormals map to signed zero - confirm against float_traits).
else if (f16_traits::is_denom(data_)) {
240 return conv::shift_sign(data_);
// Remaining inputs: combine the shifted sign, exponent and significand fields.
242 return conv::shift_sign(data_) | conv::shift_exponent(data_) |
243 conv::shift_significand(data_);
// Specialization of complex for gko::bfloat16 components (the enclosing
// namespace is not visible in this listing; presumably std - confirm).
258class complex<
gko::bfloat16> {
// Constructs from real and imaginary parts; both default to value_type(0.f).
262 complex(
const value_type& real = value_type(0.f),
263 const value_type& imag = value_type(0.f))
264 : real_(real), imag_(imag)
// Explicit two-argument constructor for mixed component types. Enabled only
// when both T and U are scalar types or gko::half; each argument is cast to
// value_type.
268 typename T,
typename U,
269 typename = std::enable_if_t<
270 (std::is_scalar<T>::value || std::is_same_v<T, gko::half>)&&(
271 std::is_scalar<U>::value || std::is_same_v<U, gko::half>)>>
272 explicit complex(
const T& real,
const U& imag)
273 : real_(static_cast<value_type>(
real)),
274 imag_(static_cast<value_type>(
imag))
// Implicit constructor from a single real value, enabled only when T is a
// scalar type or gko::half. The real part is cast to value_type and the
// imaginary part is set to value_type(0.f).
277 template <
typename T,
278 typename = std::enable_if_t<std::is_scalar<T>::value ||
279 std::is_same_v<T, gko::half>>>
280 complex(
const T& real)
281 : real_(static_cast<value_type>(
real)),
282 imag_(static_cast<value_type>(0.f))
// Explicit conversion from complex<T>, enabled only when T is a scalar type
// or gko::half. Both components are read via real()/imag() and cast to
// value_type.
287 template <
typename T,
288 typename = std::enable_if_t<std::is_scalar<T>::value ||
289 std::is_same_v<T, gko::half>>>
290 explicit complex(
const complex<T>& other)
291 : real_(static_cast<value_type>(other.
real())),
292 imag_(static_cast<value_type>(other.
imag()))
// Returns the real part by value.
295 value_type
real() const noexcept {
return real_; }
// Returns the imaginary part by value.
297 value_type
imag() const noexcept {
return imag_; }
// Implicit conversion to std::complex<float>: each component is cast to float.
299 operator std::complex<float>() const noexcept
301 return std::complex<float>(
static_cast<float>(real_),
302 static_cast<float>(imag_));
// Assignment from a non-complex value of type V. The visible line resets the
// imaginary part to a value-initialized value_type; the assignment of the
// real part is not visible in this listing (presumably real_ = val - confirm).
305 template <
typename V>
306 complex& operator=(
const V& val)
309 imag_ = value_type();
// Assignment from std::complex<V>. The body is not visible in this listing;
// presumably both components are copied from val - confirm.
313 template <
typename V>
314 complex& operator=(
const std::complex<V>& val)
// Compound assignment operators taking a real scalar of value_type.
// Only the signatures are visible in this listing; the bodies are not shown,
// so which components each operator updates cannot be confirmed from here.
321 complex& operator+=(
const value_type& real)
327 complex& operator-=(
const value_type& real)
333 complex& operator*=(
const value_type& real)
340 complex& operator/=(
const value_type& real)
// Compound addition and subtraction with a complex<T> operand.
// Only the signatures are visible in this listing; the bodies are not shown
// (presumably component-wise updates of real_ and imag_ - confirm).
347 template <
typename T>
348 complex& operator+=(
const complex<T>& val)
355 template <
typename T>
356 complex& operator-=(
const complex<T>& val)
// Compound multiplication with a complex<T> operand. Both the operand and
// *this are converted to std::complex<float>, and the components of result_f
// are stored back into real_ and imag_. The statement that combines result_f
// with val_f is not visible in this listing (presumably result_f *= val_f -
// confirm).
363 template <
typename T>
364 complex& operator*=(
const complex<T>& val)
366 auto val_f =
static_cast<std::complex<float>
>(val);
367 auto result_f =
static_cast<std::complex<float>
>(*this);
369 real_ = result_f.real();
370 imag_ = result_f.imag();
// Compound division by a complex<T> operand. Both the operand and *this are
// converted to std::complex<float>, and the components of result_f are stored
// back into real_ and imag_. The statement that combines result_f with val_f
// is not visible in this listing (presumably result_f /= val_f - confirm).
374 template <
typename T>
375 complex& operator/=(
const complex<T>& val)
377 auto val_f =
static_cast<std::complex<float>
>(val);
378 auto result_f =
static_cast<std::complex<float>
>(*this);
380 real_ = result_f.real();
381 imag_ = result_f.imag();
// Generates the friend binary operator `_op` for two complex operands and is
// instantiated below for + - * / before being undefined. The macro body is
// not visible in this listing; since `_opeq` is a macro parameter, it is
// presumably implemented via the matching compound assignment - confirm.
385#define COMPLEX_BFLOAT16_OPERATOR(_op, _opeq) \
386 friend complex operator _op(const complex& lhf, const complex& rhf) \
393 COMPLEX_BFLOAT16_OPERATOR(+, +=)
394 COMPLEX_BFLOAT16_OPERATOR(-, -=)
395 COMPLEX_BFLOAT16_OPERATOR(*, *=)
396 COMPLEX_BFLOAT16_OPERATOR(/, /=)
398#undef COMPLEX_BFLOAT16_OPERATOR
// numeric_limits specialization for gko::bfloat16. The special values below
// are written as binary literals grouped as sign ' exponent ' significand
// (1 + 8 + 7 bits) and turned into objects via bfloat16::create_from_bits.
407struct numeric_limits<gko::bfloat16> {
408 static constexpr bool is_specialized{
true};
409 static constexpr bool is_signed{
true};
410 static constexpr bool is_integer{
false};
411 static constexpr bool is_exact{
false};
412 static constexpr bool is_bounded{
true};
413 static constexpr bool is_modulo{
false};
// One more than the significand bit count reported by float_traits
// (presumably accounting for the implicit leading bit).
414 static constexpr int digits{
415 gko::detail::float_traits<gko::bfloat16>::significand_bits + 1};
// Decimal digits, approximated as digits * 3 / 10 in integer arithmetic.
417 static constexpr int digits10{digits * 3 / 10};
// Sign 0, exponent field 01111000, significand 0 (presumably 2^-7 with the
// usual exponent bias of 127 - confirm).
419 static constexpr gko::bfloat16 epsilon()
421 constexpr auto bits =
static_cast<std::uint16_t
>(0b0'01111000'0000000u);
422 return gko::bfloat16::create_from_bits(bits);
// Sign 0, exponent field all ones, significand 0.
425 static constexpr gko::bfloat16 infinity()
427 constexpr auto bits =
static_cast<std::uint16_t
>(0b0'11111111'0000000u);
428 return gko::bfloat16::create_from_bits(bits);
// Sign 0, exponent field 00000001, significand 0 (smallest positive pattern
// with a non-zero exponent field).
431 static constexpr gko::bfloat16
min()
433 constexpr auto bits =
static_cast<std::uint16_t
>(0b0'00000001'0000000u);
434 return gko::bfloat16::create_from_bits(bits);
// Sign 0, exponent field 11111110, significand all ones (largest pattern
// below the all-ones exponent).
437 static constexpr gko::bfloat16
max()
439 constexpr auto bits =
static_cast<std::uint16_t
>(0b0'11111110'1111111u);
440 return gko::bfloat16::create_from_bits(bits);
// Same pattern as max() with the sign bit set.
443 static constexpr gko::bfloat16 lowest()
445 constexpr auto bits =
static_cast<std::uint16_t
>(0b1'11111110'1111111u);
446 return gko::bfloat16::create_from_bits(bits);
// Sign 0, exponent field all ones, significand all ones.
449 static constexpr gko::bfloat16 quiet_NaN()
451 constexpr auto bits =
static_cast<std::uint16_t
>(0b0'11111111'1111111u);
452 return gko::bfloat16::create_from_bits(bits);
// Assignment of a std::complex<gko::bfloat16> to complex<double>. A temporary
// complex<double> is built from a.real() and a.imag(); the rest of the body
// (presumably assigning t to *this and returning it) is not visible in this
// listing.
460inline complex<double>& complex<double>::operator=(
461 const std::complex<gko::bfloat16>& a)
463 complex<double> t(a.real(), a.imag());
// Assignment of a std::complex<gko::bfloat16> to complex<float>. A temporary
// complex<float> is built from a.real() and a.imag(); the rest of the body
// (presumably assigning t to *this and returning it) is not visible in this
// listing.
471inline complex<float>& complex<float>::operator=(
472 const std::complex<gko::bfloat16>& a)
474 complex<float> t(a.real(), a.imag());
A class providing basic support for bfloat16 precision floating point types.
Definition bfloat16.hpp:76
The Ginkgo namespace.
Definition abstract_factory.hpp:20
constexpr auto imag(const T &x)
Returns the imaginary part of the object.
Definition math.hpp:916
constexpr T min(const T &x, const T &y)
Returns the smaller of the arguments.
Definition math.hpp:750
constexpr T max(const T &x, const T &y)
Returns the larger of the arguments.
Definition math.hpp:732
constexpr auto real(const T &x)
Returns the real part of the object.
Definition math.hpp:900