5#ifndef GKO_PUBLIC_CORE_BASE_MATH_HPP_
6#define GKO_PUBLIC_CORE_BASE_MATH_HPP_
17#include <ginkgo/config.hpp>
18#include <ginkgo/core/base/half.hpp>
19#include <ginkgo/core/base/types.hpp>
20#include <ginkgo/core/base/utils.hpp>
41struct remove_complex_impl {
49struct remove_complex_impl<std::complex<T>> {
60struct to_complex_impl {
61 using type = std::complex<T>;
70struct to_complex_impl<std::complex<T>> {
71 using type = std::complex<T>;
76struct is_complex_impl :
public std::integral_constant<bool, false> {};
79struct is_complex_impl<std::complex<T>>
80 :
public std::integral_constant<bool, true> {};
84struct is_complex_or_scalar_impl : std::is_scalar<T> {};
87struct is_complex_or_scalar_impl<half> : std::true_type {};
90struct is_complex_or_scalar_impl<bfloat16> : std::true_type {};
93struct is_complex_or_scalar_impl<std::complex<T>>
94 : is_complex_or_scalar_impl<T> {};
/**
 * Primary template of template_converter: deliberately empty, so it exposes
 * no `type` member. The partial specialization for arguments of the form
 * T<Rest...> supplies `type` by applying `converter` to every entry of Rest.
 *
 * @tparam converter  unary class template that is expected to expose a
 *                    nested `type`
 * @tparam T  the type to convert
 */
104template <
template <
typename>
class converter,
typename T>
105struct template_converter {};
116template <
template <
typename>
class converter,
template <
typename...>
class T,
118struct template_converter<converter, T<Rest...>> {
119 using type = T<typename converter<Rest>::type...>;
/**
 * Primary template of remove_complex_s: deliberately empty. The defaulted
 * second parameter exists only so that the partial specializations can be
 * selected with std::enable_if_t on is_complex_or_scalar_impl<T>::value.
 *
 * @tparam T  the type to strip the complex wrapper from
 */
123template <
typename T,
typename =
void>
124struct remove_complex_s {};
133struct remove_complex_s<T,
134 std::enable_if_t<is_complex_or_scalar_impl<T>::value>> {
135 using type =
typename detail::remove_complex_impl<T>::type;
145struct remove_complex_s<
146 T, std::enable_if_t<!is_complex_or_scalar_impl<T>::value>> {
148 typename detail::template_converter<detail::remove_complex_impl,
/**
 * Primary template of to_complex_s: deliberately empty. The defaulted second
 * parameter exists only so that the partial specializations can be selected
 * with std::enable_if_t on is_complex_or_scalar_impl<T>::value.
 *
 * @tparam T  the type to convert to its complex counterpart
 */
153template <
typename T,
typename =
void>
154struct to_complex_s {};
163struct to_complex_s<T, std::enable_if_t<is_complex_or_scalar_impl<T>::value>> {
164 using type =
typename detail::to_complex_impl<T>::type;
174struct to_complex_s<T, std::enable_if_t<!is_complex_or_scalar_impl<T>::value>> {
176 typename detail::template_converter<detail::to_complex_impl, T>::type;
202 using type =
typename std::complex<T>::value_type;
227 return detail::is_complex_impl<T>::value;
251 return detail::is_complex_or_scalar_impl<T>::value;
300struct next_precision_base_impl {};
303struct next_precision_base_impl<float> {
308struct next_precision_base_impl<double> {
313struct next_precision_base_impl<
std::complex<T>> {
314 using type = std::complex<typename next_precision_base_impl<T>::type>;
/**
 * Forward declaration of the helper that searches a list of precision types
 * for T and selects the entry `step` positions away from it. Only the partial
 * specializations are defined.
 *
 * @tparam T  the type to look up in the list
 * @tparam step  signed offset from T's position; the matching specialization
 *               wraps it around the list length with a modulo
 * @tparam Visited  the types already inspected; the specializations pass
 *                  them as std::tuple<Visited...>
 * @tparam Rest  the types that still have to be inspected
 */
323template <
typename T,
int step,
typename Visited,
typename... Rest>
324struct find_precision_list_impl;
326template <
typename T,
int step,
typename... Visited,
typename U,
328struct find_precision_list_impl<T, step,
std::tuple<Visited...>, U, Rest...> {
330 typename find_precision_list_impl<T, step, std::tuple<Visited..., U>,
334template <
typename T,
int step,
typename... Visited,
typename... Rest>
335struct find_precision_list_impl<T, step,
std::tuple<Visited...>, T, Rest...> {
336 using tuple = std::tuple<T, Rest..., Visited...>;
337 constexpr static auto tuple_size =
338 static_cast<int>(std::tuple_size_v<tuple>);
340 constexpr static int index = (tuple_size + step % tuple_size) % tuple_size;
341 using type = std::tuple_element_t<index, tuple>;
345template <
typename T,
int step = 1>
346struct find_precision_impl {
347 using type =
typename find_precision_list_impl<T, step, std::tuple<>,
348#if GINKGO_ENABLE_HALF
351#if GINKGO_ENABLE_BFLOAT16
354 float,
double>::type;
358template <
typename T,
int step>
359struct find_precision_impl<
std::complex<T>, step> {
360 using type = std::complex<typename find_precision_impl<T, step>::type>;
365struct reduce_precision_impl {
370struct reduce_precision_impl<
std::complex<T>> {
371 using type = std::complex<typename reduce_precision_impl<T>::type>;
375struct reduce_precision_impl<double> {
381struct reduce_precision_impl<float> {
387struct increase_precision_impl {
392struct increase_precision_impl<
std::complex<T>> {
393 using type = std::complex<typename increase_precision_impl<T>::type>;
397struct increase_precision_impl<float> {
403struct increase_precision_impl<
half> {
409struct infinity_impl {
412 static constexpr auto value = std::numeric_limits<T>::infinity();
419template <
typename T1,
typename T2>
420struct highest_precision_impl {
421 using type =
decltype(T1{} + T2{});
424template <
typename T1,
typename T2>
425struct highest_precision_impl<
std::complex<T1>, std::complex<T2>> {
426 using type = std::complex<typename highest_precision_impl<T1, T2>::type>;
429template <
typename Head,
typename... Tail>
430struct highest_precision_variadic {
431 using type =
typename highest_precision_impl<
432 Head,
typename highest_precision_variadic<Tail...>::type>::type;
435template <
typename Head>
436struct highest_precision_variadic<Head> {
465template <
typename T,
int step = 1>
472template <
typename T,
int step = 1>
501template <
typename... Ts>
503 typename detail::highest_precision_variadic<Ts...>::type;
538template <
typename FloatType,
size_type NumComponents,
size_type ComponentId>
546struct truncate_type_impl {
547 using type = truncated<T, 2, 0>;
550template <
typename T,
size_type Components>
551struct truncate_type_impl<truncated<T, Components, 0>> {
552 using type = truncated<T, 2 * Components, 0>;
556struct truncate_type_impl<std::complex<T>> {
557 using type = std::complex<typename truncate_type_impl<T>::type>;
562struct type_size_impl {
563 static constexpr auto value =
sizeof(T) *
byte_size;
567struct type_size_impl<std::complex<T>> {
568 static constexpr auto value =
sizeof(T) *
byte_size;
579template <
typename T,
size_type Limit = sizeof(u
int16) *
byte_size>
581 std::conditional_t<detail::type_size_impl<T>::value >= 2 * Limit,
582 typename detail::truncate_type_impl<T>::type, T>;
591template <
typename S,
typename R>
/**
 * Converts the given value to the result type R using static_cast.
 *
 * @param val  the value to convert
 *
 * @return val converted to R
 */
599 GKO_ATTRIBUTES R
operator()(S val) {
return static_cast<R
>(val); }
616 return (num + den - 1) / den;
642GKO_INLINE
constexpr T
zero(
const T&)
654GKO_INLINE
constexpr T
one()
660GKO_INLINE
constexpr half one<half>()
662 constexpr auto bits =
static_cast<uint16>(0b0'01111'0000000000u);
663 return half::create_from_bits(bits);
669 constexpr auto bits =
static_cast<uint16>(0b0'01111111'0000000u);
670 return bfloat16::create_from_bits(bits);
684GKO_INLINE
constexpr T
one(
const T&)
732GKO_INLINE
constexpr T
max(
const T& x,
const T& y)
734 return x >= y ? x : y;
750GKO_INLINE
constexpr T
min(
const T& x,
const T& y)
752 return x <= y ? x : y;
768template <
typename Ref,
typename Dummy = std::
void_t<>>
769struct has_to_arithmetic_type : std::false_type {
770 static_assert(std::is_same<Dummy, void>::value,
771 "Do not modify the Dummy value!");
775template <
typename Ref>
776struct has_to_arithmetic_type<
777 Ref, std::
void_t<decltype(std::declval<Ref>().to_arithmetic_type())>>
779 using type =
decltype(std::declval<Ref>().to_arithmetic_type());
787template <
typename Ref,
typename Dummy = std::
void_t<>>
788struct has_arithmetic_type : std::false_type {
789 static_assert(std::is_same<Dummy, void>::value,
790 "Do not modify the Dummy value!");
793template <
typename Ref>
794struct has_arithmetic_type<Ref, std::
void_t<typename Ref::arithmetic_type>>
809template <
typename Ref>
810constexpr GKO_ATTRIBUTES
811 std::enable_if_t<has_to_arithmetic_type<Ref>::value,
812 typename has_to_arithmetic_type<Ref>::type>
813 to_arithmetic_type(
const Ref& ref)
815 return ref.to_arithmetic_type();
818template <
typename Ref>
819constexpr GKO_ATTRIBUTES std::enable_if_t<!has_to_arithmetic_type<Ref>::value &&
820 has_arithmetic_type<Ref>::value,
821 typename Ref::arithmetic_type>
822to_arithmetic_type(
const Ref& ref)
827template <
typename Ref>
828constexpr GKO_ATTRIBUTES std::enable_if_t<!has_to_arithmetic_type<Ref>::value &&
829 !has_arithmetic_type<Ref>::value,
831to_arithmetic_type(
const Ref& ref)
841GKO_ATTRIBUTES GKO_INLINE
constexpr std::enable_if_t<!is_complex_s<T>::value, T>
848GKO_ATTRIBUTES GKO_INLINE
constexpr std::enable_if_t<is_complex_s<T>::value,
857GKO_ATTRIBUTES GKO_INLINE
constexpr std::enable_if_t<!is_complex_s<T>::value, T>
864GKO_ATTRIBUTES GKO_INLINE
constexpr std::enable_if_t<is_complex_s<T>::value,
873GKO_ATTRIBUTES GKO_INLINE
constexpr std::enable_if_t<!is_complex_s<T>::value, T>
880GKO_ATTRIBUTES GKO_INLINE
constexpr std::enable_if_t<is_complex_s<T>::value, T>
883 return T{real_impl(x), -imag_impl(x)};
900GKO_ATTRIBUTES GKO_INLINE
constexpr auto real(
const T& x)
902 return detail::real_impl(detail::to_arithmetic_type(x));
916GKO_ATTRIBUTES GKO_INLINE
constexpr auto imag(
const T& x)
918 return detail::imag_impl(detail::to_arithmetic_type(x));
930GKO_ATTRIBUTES GKO_INLINE
constexpr auto conj(
const T& x)
932 return detail::conj_impl(detail::to_arithmetic_type(x));
962GKO_INLINE
constexpr std::enable_if_t<!is_complex_s<T>::value, T>
abs(
965 return x >=
zero<T>() ? x : -x;
977GKO_INLINE gko::half
abs(
const std::complex<gko::half>& x)
980 return static_cast<gko::half
>(
abs(std::complex<float>(x)));
983GKO_INLINE gko::bfloat16
abs(
const std::complex<gko::bfloat16>& x)
986 return static_cast<gko::bfloat16
>(
abs(std::complex<float>(x)));
992GKO_INLINE gko::half sqrt(gko::half a)
994 return gko::half(std::sqrt(
float(a)));
997GKO_INLINE std::complex<gko::half> sqrt(std::complex<gko::half> a)
999 return std::complex<gko::half>(sqrt(std::complex<float>(
1000 static_cast<float>(a.real()),
static_cast<float>(a.imag()))));
1003GKO_INLINE gko::bfloat16 sqrt(gko::bfloat16 a)
1005 return gko::bfloat16(std::sqrt(
float(a)));
1008GKO_INLINE std::complex<gko::bfloat16> sqrt(std::complex<gko::bfloat16> a)
1010 return std::complex<gko::bfloat16>(sqrt(std::complex<float>(
1011 static_cast<float>(a.real()),
static_cast<float>(a.imag()))));
1020template <
typename T>
1021GKO_INLINE
constexpr T
pi()
1023 return static_cast<T
>(3.1415926535897932384626433);
1035template <
typename T>
1056template <
typename T>
1074template <
typename T>
1076 const T& hint = T{1})
noexcept
1093template <
typename T>
1094GKO_INLINE GKO_ATTRIBUTES std::enable_if_t<!is_complex_s<T>::value,
bool>
1097 constexpr T infinity{detail::infinity_impl<T>::value};
1098 return abs(value) < infinity;
1113template <
typename T>
1114GKO_INLINE GKO_ATTRIBUTES std::enable_if_t<is_complex_s<T>::value,
bool>
1132template <
typename T>
1148template <
typename T>
1150 "is_nan can't be used safely on the device (MSVC+CUDA), and will thus be "
1151 "removed in a future release, without replacement")
1152GKO_INLINE GKO_ATTRIBUTES
1156 return isnan(value);
1169template <
typename T>
1171 "is_nan can't be used safely on the device (MSVC+CUDA), and will thus be "
1172 "removed in a future release, without replacement")
1187template <
typename T>
1188GKO_INLINE
constexpr std::enable_if_t<!is_complex_s<T>::value, T>
nan()
1190 return std::numeric_limits<T>::quiet_NaN();
1201template <
typename T>
1202GKO_INLINE
constexpr std::enable_if_t<is_complex_s<T>::value, T>
nan()
A class providing basic support for bfloat16 precision floating point types.
Definition bfloat16.hpp:76
A class providing basic support for half precision floating point types.
Definition half.hpp:288
typename detail::make_void< Ts... >::type void_t
Use the custom implementation, since the std::void_t used in is_matrix_type_builder seems to trigger ...
Definition std_extensions.hpp:47
The Ginkgo namespace.
Definition abstract_factory.hpp:20
constexpr T one()
Returns the multiplicative identity for T.
Definition math.hpp:654
std::enable_if_t<!is_complex_s< T >::value, bool > is_finite(const T &value)
Checks if a floating point number is finite, meaning it is neither +/- infinity nor NaN.
Definition math.hpp:1095
constexpr T pi()
Returns the value of pi.
Definition math.hpp:1021
constexpr std::enable_if_t<!is_complex_s< T >::value, T > abs(const T &x)
Returns the absolute value of the object.
Definition math.hpp:962
typename detail::remove_complex_s< T >::type remove_complex
Obtains T with the complex wrapper removed: applied directly when T is a complex/scalar type, otherwise applied to each template parameter of the class template T.
Definition math.hpp:264
constexpr increase_precision< T > round_up(T val)
Increases the precision of the input parameter.
Definition math.hpp:532
typename detail::next_precision_base_impl< T >::type next_precision_base
Obtains the next type in the singly-linked precision list.
Definition math.hpp:448
std::conditional_t< detail::type_size_impl< T >::value >=2 *Limit, typename detail::truncate_type_impl< T >::type, T > truncate_type
Truncates the type by half (by dropping bits), but ensures that it is at least Limit bits wide.
Definition math.hpp:580
typename detail::highest_precision_variadic< Ts... >::type highest_precision
Obtains the smallest arithmetic type that is able to store elements of all template parameter types e...
Definition math.hpp:502
typename detail::to_complex_s< T >::type to_complex
Obtains the complex counterpart of T: applied directly when T is a complex/scalar type, otherwise applied to each template parameter of the class template T.
Definition math.hpp:283
constexpr uint32 get_significant_bit(const T &n, uint32 hint=0u) noexcept
Returns the position of the most significant bit of the number.
Definition math.hpp:1057
constexpr bool is_complex_or_scalar()
Checks if T is a complex/scalar type.
Definition math.hpp:249
std::enable_if_t<!is_complex_s< T >::value, bool > is_nan(const T &value)
Checks if a floating point number is NaN.
Definition math.hpp:1153
detail::is_complex_impl< T > is_complex_s
Allows checking at compile time whether T is a complex type by accessing the value attribute of this ...
Definition math.hpp:215
constexpr std::enable_if_t<!is_complex_s< T >::value, T > nan()
Returns a quiet NaN of the given type.
Definition math.hpp:1188
constexpr T zero()
Returns the additive identity for T.
Definition math.hpp:626
constexpr bool is_zero(T value)
Returns true if and only if the given value is zero.
Definition math.hpp:699
constexpr auto imag(const T &x)
Returns the imaginary part of the object.
Definition math.hpp:916
std::uint32_t uint32
32-bit unsigned integral type.
Definition types.hpp:130
constexpr std::complex< remove_complex< T > > unit_root(int64 n, int64 k=1)
Returns the value of exp(2 * pi * i * k / n), i.e.
Definition math.hpp:1036
next_precision_base< T > previous_precision_base
Obtains the previous type in the singly-linked precision list.
Definition math.hpp:458
typename detail::reduce_precision_impl< T >::type reduce_precision
Obtains the next type in the hierarchy with lower precision than T.
Definition math.hpp:480
constexpr reduce_precision< T > round_down(T val)
Reduces the precision of the input parameter.
Definition math.hpp:516
std::int64_t int64
64-bit signed integral type.
Definition types.hpp:113
constexpr int64 ceildiv(int64 num, int64 den)
Performs integer division with rounding up.
Definition math.hpp:614
constexpr bool is_complex()
Checks if T is a complex type.
Definition math.hpp:225
T safe_divide(T a, T b)
Computes the quotient of the given parameters, guarding against division by zero.
Definition math.hpp:1133
constexpr T min(const T &x, const T &y)
Returns the smaller of the arguments.
Definition math.hpp:750
constexpr auto squared_norm(const T &x) -> decltype(real(conj(x) *x))
Returns the squared norm of the object.
Definition math.hpp:944
detail::is_complex_or_scalar_impl< T > is_complex_or_scalar_s
Allows checking at compile time whether T is a complex or scalar type by accessing the value attribute of this ...
Definition math.hpp:239
constexpr size_type byte_size
Number of bits in a byte.
Definition types.hpp:178
constexpr T get_superior_power(const T &base, const T &limit, const T &hint=T{1}) noexcept
Returns the smallest power of base not smaller than limit.
Definition math.hpp:1075
typename detail::increase_precision_impl< T >::type increase_precision
Obtains the next type in the hierarchy with higher precision than T.
Definition math.hpp:487
remove_complex< T > to_real
to_real is an alias of remove_complex.
Definition math.hpp:292
constexpr auto conj(const T &x)
Returns the conjugate of an object.
Definition math.hpp:930
typename detail::find_precision_impl< T, -step >::type previous_precision
Obtains the type located step positions before T in the cyclic precision list (float and double, plus half/bfloat16 when enabled).
Definition math.hpp:473
typename detail::find_precision_impl< T, step >::type next_precision
Obtains the type located step positions after T in the cyclic precision list (float and double, plus half/bfloat16 when enabled).
Definition math.hpp:466
constexpr bool is_nonzero(T value)
Returns true if and only if the given value is not zero.
Definition math.hpp:714
std::uint16_t uint16
16-bit unsigned integral type.
Definition types.hpp:124
constexpr T max(const T &x, const T &y)
Returns the larger of the arguments.
Definition math.hpp:732
constexpr auto real(const T &x)
Returns the real part of the object.
Definition math.hpp:900
Access the underlying real type of a complex number.
Definition math.hpp:189
T type
The type.
Definition math.hpp:191
Used to convert objects of type S to objects of type R using static_cast.
Definition math.hpp:592
R operator()(S val)
Converts the object to result type.
Definition math.hpp:599