Ginkgo Generated from branch based on main. Ginkgo version 1.10.0
A numerical linear algebra library targeting many-core architectures
Loading...
Searching...
No Matches
half.hpp
1// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#ifndef GKO_PUBLIC_CORE_BASE_HALF_HPP_
6#define GKO_PUBLIC_CORE_BASE_HALF_HPP_
7
8
#include <climits>
#include <complex>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <limits>
#include <type_traits>
14
15
16class __half;
17
18
19namespace gko {
20
21
// Forward declarations. `truncated` is needed by the
// detail::basic_float_traits specialization below; `half` and `bfloat16` are
// referenced before their definitions (bfloat16 lives in bfloat16.hpp).
template <typename, std::size_t, std::size_t>
class truncated;


class half;

class bfloat16;
29
30
31namespace detail {
32
33
/** Number of bits in one byte of storage. */
constexpr std::size_t byte_size{CHAR_BIT};

/**
 * Maps a bit count to the smallest of std::uint16_t, std::uint32_t and
 * std::uint64_t that can hold it. Counts above 64 match no specialization,
 * so `type` is absent for them.
 */
template <std::size_t, typename = void>
struct uint_of_impl {};

// 0..16 bits
template <std::size_t Bits>
struct uint_of_impl<Bits, typename std::enable_if<(Bits <= 16)>::type> {
    using type = std::uint16_t;
};

// 17..32 bits
template <std::size_t Bits>
struct uint_of_impl<Bits,
                    typename std::enable_if<(Bits > 16 && Bits <= 32)>::type> {
    using type = std::uint32_t;
};

// 33..64 bits
template <std::size_t Bits>
struct uint_of_impl<Bits,
                    typename std::enable_if<(Bits > 32 && Bits <= 64)>::type> {
    using type = std::uint64_t;
};

/** Shorthand for the unsigned integer type selected by uint_of_impl. */
template <std::size_t Bits>
using uint_of = typename uint_of_impl<Bits>::type;
56
57
58template <typename T>
59struct basic_float_traits {};
60
61template <>
62struct basic_float_traits<half> {
63 using type = half;
64 static constexpr int sign_bits = 1;
65 static constexpr int significand_bits = 10;
66 static constexpr int exponent_bits = 5;
67 static constexpr bool rounds_to_nearest = true;
68};
69
70template <>
71struct basic_float_traits<__half> {
72 using type = __half;
73 static constexpr int sign_bits = 1;
74 static constexpr int significand_bits = 10;
75 static constexpr int exponent_bits = 5;
76 static constexpr bool rounds_to_nearest = true;
77};
78
79template <>
80struct basic_float_traits<float> {
81 using type = float;
82 static constexpr int sign_bits = 1;
83 static constexpr int significand_bits = 23;
84 static constexpr int exponent_bits = 8;
85 static constexpr bool rounds_to_nearest = true;
86};
87
88template <>
89struct basic_float_traits<double> {
90 using type = double;
91 static constexpr int sign_bits = 1;
92 static constexpr int significand_bits = 52;
93 static constexpr int exponent_bits = 11;
94 static constexpr bool rounds_to_nearest = true;
95};
96
97template <typename FloatType, std::size_t NumComponents,
98 std::size_t ComponentId>
99struct basic_float_traits<truncated<FloatType, NumComponents, ComponentId>> {
100 using type = truncated<FloatType, NumComponents, ComponentId>;
101 static constexpr int sign_bits = ComponentId == 0 ? 1 : 0;
102 static constexpr int exponent_bits =
103 ComponentId == 0 ? basic_float_traits<FloatType>::exponent_bits : 0;
104 static constexpr int significand_bits =
105 ComponentId == 0 ? sizeof(type) * byte_size - exponent_bits - 1
106 : sizeof(type) * byte_size;
107 static constexpr bool rounds_to_nearest = false;
108};
109
110
111template <typename UintType>
112constexpr UintType create_ones(int n)
113{
114 return (n == sizeof(UintType) * byte_size ? static_cast<UintType>(0)
115 : static_cast<UintType>(1) << n) -
116 static_cast<UintType>(1);
117}
118
119
120template <typename T>
121struct float_traits {
122 using type = typename basic_float_traits<T>::type;
123 using bits_type = uint_of<sizeof(type) * byte_size>;
124 static constexpr int sign_bits = basic_float_traits<T>::sign_bits;
125 static constexpr int significand_bits =
126 basic_float_traits<T>::significand_bits;
127 static constexpr int exponent_bits = basic_float_traits<T>::exponent_bits;
128 static constexpr bits_type significand_mask =
129 create_ones<bits_type>(significand_bits);
130 static constexpr bits_type exponent_mask =
131 create_ones<bits_type>(significand_bits + exponent_bits) -
132 significand_mask;
133 static constexpr bits_type bias_mask =
134 create_ones<bits_type>(significand_bits + exponent_bits - 1) -
135 significand_mask;
136 static constexpr bits_type sign_mask =
137 create_ones<bits_type>(sign_bits + significand_bits + exponent_bits) -
138 exponent_mask - significand_mask;
139 static constexpr bool rounds_to_nearest =
140 basic_float_traits<T>::rounds_to_nearest;
141
142 static constexpr auto eps =
143 1.0 / (1ll << (significand_bits + rounds_to_nearest));
144
145 static constexpr bool is_inf(bits_type data)
146 {
147 return (data & exponent_mask) == exponent_mask &&
148 (data & significand_mask) == bits_type{};
149 }
150
151 static constexpr bool is_nan(bits_type data)
152 {
153 return (data & exponent_mask) == exponent_mask &&
154 (data & significand_mask) != bits_type{};
155 }
156
157 static constexpr bool is_denom(bits_type data)
158 {
159 return (data & exponent_mask) == bits_type{};
160 }
161};
162
163
164template <typename SourceType, typename ResultType,
165 bool = (sizeof(SourceType) <= sizeof(ResultType))>
166struct precision_converter;
167
168// upcasting implementation details
169template <typename SourceType, typename ResultType>
170struct precision_converter<SourceType, ResultType, true> {
171 using source_traits = float_traits<SourceType>;
172 using result_traits = float_traits<ResultType>;
173 using source_bits = typename source_traits::bits_type;
174 using result_bits = typename result_traits::bits_type;
175
176 static_assert(source_traits::exponent_bits <=
177 result_traits::exponent_bits &&
178 source_traits::significand_bits <=
179 result_traits::significand_bits,
180 "SourceType has to have both lower range and precision or "
181 "higher range and precision than ResultType");
182
183 static constexpr int significand_offset =
184 result_traits::significand_bits - source_traits::significand_bits;
185 static constexpr int exponent_offset = significand_offset;
186 static constexpr int sign_offset = result_traits::exponent_bits -
187 source_traits::exponent_bits +
188 exponent_offset;
189 static constexpr result_bits bias_change =
190 result_traits::bias_mask -
191 (static_cast<result_bits>(source_traits::bias_mask) << exponent_offset);
192
193 static constexpr result_bits shift_significand(source_bits data) noexcept
194 {
195 return static_cast<result_bits>(data & source_traits::significand_mask)
196 << significand_offset;
197 }
198
199 static constexpr result_bits shift_exponent(source_bits data) noexcept
200 {
201 return update_bias(
202 static_cast<result_bits>(data & source_traits::exponent_mask)
203 << exponent_offset);
204 }
205
206 static constexpr result_bits shift_sign(source_bits data) noexcept
207 {
208 return static_cast<result_bits>(data & source_traits::sign_mask)
209 << sign_offset;
210 }
211
212private:
213 static constexpr result_bits update_bias(result_bits data) noexcept
214 {
215 return data == typename result_traits::bits_type{} ? data
216 : data + bias_change;
217 }
218};
219
220// downcasting implementation details
221template <typename SourceType, typename ResultType>
222struct precision_converter<SourceType, ResultType, false> {
223 using source_traits = float_traits<SourceType>;
224 using result_traits = float_traits<ResultType>;
225 using source_bits = typename source_traits::bits_type;
226 using result_bits = typename result_traits::bits_type;
227
228 static_assert(source_traits::exponent_bits >=
229 result_traits::exponent_bits &&
230 source_traits::significand_bits >=
231 result_traits::significand_bits,
232 "SourceType has to have both lower range and precision or "
233 "higher range and precision than ResultType");
234
235 static constexpr int significand_offset =
236 source_traits::significand_bits - result_traits::significand_bits;
237 static constexpr int exponent_offset = significand_offset;
238 static constexpr int sign_offset = source_traits::exponent_bits -
239 result_traits::exponent_bits +
240 exponent_offset;
241 static constexpr source_bits bias_change =
242 (source_traits::bias_mask >> exponent_offset) -
243 static_cast<source_bits>(result_traits::bias_mask);
244
245 static constexpr result_bits shift_significand(source_bits data) noexcept
246 {
247 return static_cast<result_bits>(
248 (data & source_traits::significand_mask) >> significand_offset);
249 }
250
251 static constexpr result_bits shift_exponent(source_bits data) noexcept
252 {
253 return static_cast<result_bits>(update_bias(
254 (data & source_traits::exponent_mask) >> exponent_offset));
255 }
256
257 static constexpr result_bits shift_sign(source_bits data) noexcept
258 {
259 return static_cast<result_bits>((data & source_traits::sign_mask) >>
260 sign_offset);
261 }
262
263private:
264 static constexpr source_bits update_bias(source_bits data) noexcept
265 {
266 return data <= bias_change ? typename source_traits::bits_type{}
267 : limit_exponent(data - bias_change);
268 }
269
270 static constexpr source_bits limit_exponent(source_bits data) noexcept
271 {
272 return data >= static_cast<source_bits>(result_traits::exponent_mask)
273 ? static_cast<source_bits>(result_traits::exponent_mask)
274 : data;
275 }
276};
277
278
279} // namespace detail
280
281
288class alignas(std::uint16_t) half {
289public:
290 // create half value from the bits directly.
291 static constexpr half create_from_bits(const std::uint16_t& bits) noexcept
292 {
293 half result;
294 result.data_ = bits;
295 return result;
296 }
297
298 // TODO: NVHPC (host side) may not use zero initialization for the data
299 // member by default constructor in some cases. Not sure whether it is
300 // caused by something else in jacobi or isai.
301 constexpr half() noexcept : data_(0){};
302
303 template <typename T,
304 typename = std::enable_if_t<std::is_scalar<T>::value ||
305 std::is_same_v<T, bfloat16>>>
306 half(const T& val) : data_(0)
307 {
308 this->float2half(static_cast<float>(val));
309 }
310
311 template <typename V>
312 half& operator=(const V& val)
313 {
314 this->float2half(static_cast<float>(val));
315 return *this;
316 }
317
318 operator float() const noexcept
319 {
320 const auto bits = half2float(data_);
321 float ans(0);
322 std::memcpy(&ans, &bits, sizeof(float));
323 return ans;
324 }
325
326 // can not use half operator _op(const half) for half + half
327 // operation will cast it to float and then do float operation such that it
328 // becomes float in the end.
329#define HALF_OPERATOR(_op, _opeq) \
330 friend half operator _op(const half& lhf, const half& rhf) \
331 { \
332 return static_cast<half>(static_cast<float>(lhf) \
333 _op static_cast<float>(rhf)); \
334 } \
335 half& operator _opeq(const half& hf) \
336 { \
337 auto result = *this _op hf; \
338 data_ = result.data_; \
339 return *this; \
340 }
341
342 HALF_OPERATOR(+, +=)
343 HALF_OPERATOR(-, -=)
344 HALF_OPERATOR(*, *=)
345 HALF_OPERATOR(/, /=)
346
347#undef HALF_OPERATOR
348
349 // Do operation with different type
350 // If it is floating point, using floating point as type.
351 // If it is integer, using half as type
352 // Note: we do not define the operation with bfloat16, which is already
353 // defined in bfloat16.hpp
354#define HALF_FRIEND_OPERATOR(_op, _opeq) \
355 template <typename T> \
356 friend std::enable_if_t< \
357 !std::is_same<T, half>::value && std::is_scalar<T>::value, \
358 std::conditional_t<std::is_floating_point<T>::value, T, half>> \
359 operator _op(const half& hf, const T& val) \
360 { \
361 using type = \
362 std::conditional_t<std::is_floating_point<T>::value, T, half>; \
363 auto result = static_cast<type>(hf); \
364 result _opeq static_cast<type>(val); \
365 return result; \
366 } \
367 template <typename T> \
368 friend std::enable_if_t< \
369 !std::is_same<T, half>::value && std::is_scalar<T>::value, \
370 std::conditional_t<std::is_floating_point<T>::value, T, half>> \
371 operator _op(const T& val, const half& hf) \
372 { \
373 using type = \
374 std::conditional_t<std::is_floating_point<T>::value, T, half>; \
375 auto result = static_cast<type>(val); \
376 result _opeq static_cast<type>(hf); \
377 return result; \
378 }
379
380 HALF_FRIEND_OPERATOR(+, +=)
381 HALF_FRIEND_OPERATOR(-, -=)
382 HALF_FRIEND_OPERATOR(*, *=)
383 HALF_FRIEND_OPERATOR(/, /=)
384
385#undef HALF_FRIEND_OPERATOR
386
387 // the negative
388 half operator-() const
389 {
390 auto val = 0.0f - *this;
391 return static_cast<half>(val);
392 }
393
394private:
395 using f16_traits = detail::float_traits<half>;
396 using f32_traits = detail::float_traits<float>;
397
398 void float2half(const float& val) noexcept
399 {
400 std::uint32_t bit_val(0);
401 std::memcpy(&bit_val, &val, sizeof(float));
402 data_ = float2half(bit_val);
403 }
404
405 static constexpr std::uint16_t float2half(std::uint32_t data_) noexcept
406 {
407 using conv = detail::precision_converter<float, half>;
408 if (f32_traits::is_inf(data_)) {
409 return conv::shift_sign(data_) | f16_traits::exponent_mask;
410 } else if (f32_traits::is_nan(data_)) {
411 return conv::shift_sign(data_) | f16_traits::exponent_mask |
412 f16_traits::significand_mask;
413 } else {
414 const auto exp = conv::shift_exponent(data_);
415 if (f16_traits::is_inf(exp)) {
416 return conv::shift_sign(data_) | exp;
417 } else if (f16_traits::is_denom(exp)) {
418 // TODO: handle denormals
419 return conv::shift_sign(data_);
420 } else {
421 // Rounding to even
422 const auto result = conv::shift_sign(data_) | exp |
423 conv::shift_significand(data_);
424 const auto tail =
425 data_ & static_cast<f32_traits::bits_type>(
426 (1 << conv::significand_offset) - 1);
427
428 constexpr auto half = static_cast<f32_traits::bits_type>(
429 1 << (conv::significand_offset - 1));
430 return result +
431 (tail > half || ((tail == half) && (result & 1)));
432 }
433 }
434 }
435
436 static constexpr std::uint32_t half2float(std::uint16_t data_) noexcept
437 {
438 using conv = detail::precision_converter<half, float>;
439 if (f16_traits::is_inf(data_)) {
440 return conv::shift_sign(data_) | f32_traits::exponent_mask;
441 } else if (f16_traits::is_nan(data_)) {
442 return conv::shift_sign(data_) | f32_traits::exponent_mask |
443 f32_traits::significand_mask;
444 } else if (f16_traits::is_denom(data_)) {
445 // TODO: handle denormals
446 return conv::shift_sign(data_);
447 } else {
448 return conv::shift_sign(data_) | conv::shift_exponent(data_) |
449 conv::shift_significand(data_);
450 }
451 }
452
453 std::uint16_t data_;
454};
455
456
457} // namespace gko
458
459
460namespace std {
461
462
463template <>
464class complex<gko::half> {
465public:
466 using value_type = gko::half;
467
468 complex(const value_type& real = value_type(0.f),
469 const value_type& imag = value_type(0.f))
470 : real_(real), imag_(imag)
471 {}
472
473 template <
474 typename T, typename U,
475 typename = std::enable_if_t<
476 (std::is_scalar<T>::value || std::is_same_v<T, gko::bfloat16>)&&(
477 std::is_scalar<U>::value || std::is_same_v<U, gko::bfloat16>)>>
478 explicit complex(const T& real, const U& imag)
479 : real_(static_cast<value_type>(real)),
480 imag_(static_cast<value_type>(imag))
481 {}
482
483 template <typename T,
484 typename = std::enable_if_t<std::is_scalar<T>::value ||
485 std::is_same_v<T, gko::bfloat16>>>
486 complex(const T& real)
487 : real_(static_cast<value_type>(real)),
488 imag_(static_cast<value_type>(0.f))
489 {}
490
491 // When using complex(real, imag), MSVC with CUDA try to recognize the
492 // complex is a member not constructor.
493 template <typename T,
494 typename = std::enable_if_t<std::is_scalar<T>::value ||
495 std::is_same_v<T, gko::bfloat16>>>
496 explicit complex(const complex<T>& other)
497 : real_(static_cast<value_type>(other.real())),
498 imag_(static_cast<value_type>(other.imag()))
499 {}
500
501 value_type real() const noexcept { return real_; }
502
503 value_type imag() const noexcept { return imag_; }
504
505 operator std::complex<float>() const noexcept
506 {
507 return std::complex<float>(static_cast<float>(real_),
508 static_cast<float>(imag_));
509 }
510
511 template <typename V>
512 complex& operator=(const V& val)
513 {
514 real_ = val;
515 imag_ = value_type();
516 return *this;
517 }
518
519 template <typename V>
520 complex& operator=(const std::complex<V>& val)
521 {
522 real_ = val.real();
523 imag_ = val.imag();
524 return *this;
525 }
526
527 complex& operator+=(const value_type& real)
528 {
529 real_ += real;
530 return *this;
531 }
532
533 complex& operator-=(const value_type& real)
534 {
535 real_ -= real;
536 return *this;
537 }
538
539 complex& operator*=(const value_type& real)
540 {
541 real_ *= real;
542 imag_ *= real;
543 return *this;
544 }
545
546 complex& operator/=(const value_type& real)
547 {
548 real_ /= real;
549 imag_ /= real;
550 return *this;
551 }
552
553 template <typename T>
554 complex& operator+=(const complex<T>& val)
555 {
556 real_ += val.real();
557 imag_ += val.imag();
558 return *this;
559 }
560
561 template <typename T>
562 complex& operator-=(const complex<T>& val)
563 {
564 real_ -= val.real();
565 imag_ -= val.imag();
566 return *this;
567 }
568
569 template <typename T>
570 complex& operator*=(const complex<T>& val)
571 {
572 auto val_f = static_cast<std::complex<float>>(val);
573 auto result_f = static_cast<std::complex<float>>(*this);
574 result_f *= val_f;
575 real_ = result_f.real();
576 imag_ = result_f.imag();
577 return *this;
578 }
579
580 template <typename T>
581 complex& operator/=(const complex<T>& val)
582 {
583 auto val_f = static_cast<std::complex<float>>(val);
584 auto result_f = static_cast<std::complex<float>>(*this);
585 result_f /= val_f;
586 real_ = result_f.real();
587 imag_ = result_f.imag();
588 return *this;
589 }
590
591#define COMPLEX_HALF_OPERATOR(_op, _opeq) \
592 friend complex operator _op(const complex& lhf, const complex& rhf) \
593 { \
594 auto a = lhf; \
595 a _opeq rhf; \
596 return a; \
597 }
598
599 COMPLEX_HALF_OPERATOR(+, +=)
600 COMPLEX_HALF_OPERATOR(-, -=)
601 COMPLEX_HALF_OPERATOR(*, *=)
602 COMPLEX_HALF_OPERATOR(/, /=)
603
604#undef COMPLEX_HALF_OPERATOR
605
606private:
607 value_type real_;
608 value_type imag_;
609};
610
611
612template <>
613struct numeric_limits<gko::half> {
614 static constexpr bool is_specialized{true};
615 static constexpr bool is_signed{true};
616 static constexpr bool is_integer{false};
617 static constexpr bool is_exact{false};
618 static constexpr bool is_bounded{true};
619 static constexpr bool is_modulo{false};
620 static constexpr int digits{
621 gko::detail::float_traits<gko::half>::significand_bits + 1};
622 // 3/10 is approx. log_10(2)
623 static constexpr int digits10{digits * 3 / 10};
624
625 static constexpr gko::half epsilon()
626 {
627 constexpr auto bits = static_cast<std::uint16_t>(0b0'00101'0000000000u);
628 return gko::half::create_from_bits(bits);
629 }
630
631 static constexpr gko::half infinity()
632 {
633 constexpr auto bits = static_cast<std::uint16_t>(0b0'11111'0000000000u);
634 return gko::half::create_from_bits(bits);
635 }
636
637 static constexpr gko::half min()
638 {
639 constexpr auto bits = static_cast<std::uint16_t>(0b0'00001'0000000000u);
640 return gko::half::create_from_bits(bits);
641 }
642
643 static constexpr gko::half max()
644 {
645 constexpr auto bits = static_cast<std::uint16_t>(0b0'11110'1111111111u);
646 return gko::half::create_from_bits(bits);
647 }
648
649 static constexpr gko::half lowest()
650 {
651 constexpr auto bits = static_cast<std::uint16_t>(0b1'11110'1111111111u);
652 return gko::half::create_from_bits(bits);
653 };
654
655 static constexpr gko::half quiet_NaN()
656 {
657 constexpr auto bits = static_cast<std::uint16_t>(0b0'11111'1111111111u);
658 return gko::half::create_from_bits(bits);
659 }
660};
661
662
663// complex using a template on operator= for any kind of complex<T>, so we can
664// do full specialization for half
665template <>
666inline complex<double>& complex<double>::operator=(
667 const std::complex<gko::half>& a)
668{
669 complex<double> t(a.real(), a.imag());
670 operator=(t);
671 return *this;
672}
673
674
675// For MSVC
676template <>
677inline complex<float>& complex<float>::operator=(
678 const std::complex<gko::half>& a)
679{
680 complex<float> t(a.real(), a.imag());
681 operator=(t);
682 return *this;
683}
684
685
686} // namespace std
687
688
689#endif // GKO_PUBLIC_CORE_BASE_HALF_HPP_
A class providing basic support for bfloat16 precision floating point types.
Definition bfloat16.hpp:76
A class providing basic support for half precision floating point types.
Definition half.hpp:288
Definition half.hpp:23
The Ginkgo namespace.
Definition abstract_factory.hpp:20
constexpr size_type byte_size
Number of bits in a byte.
Definition types.hpp:178
STL namespace.