5#ifndef GKO_PUBLIC_CORE_BASE_PRECISION_DISPATCH_HPP_
6#define GKO_PUBLIC_CORE_BASE_PRECISION_DISPATCH_HPP_
9#include <ginkgo/config.hpp>
10#include <ginkgo/core/base/math.hpp>
11#include <ginkgo/core/base/temporary_conversion.hpp>
12#include <ginkgo/core/distributed/vector.hpp>
13#include <ginkgo/core/matrix/dense.hpp>
43template <
typename ValueType,
typename Ptr>
44detail::temporary_conversion<std::conditional_t<
49 using Pointee = detail::pointee<Ptr>;
54 using MaybeConstDense =
55 std::conditional_t<std::is_const<Pointee>::value,
const Dense, Dense>;
57 detail::temporary_conversion<MaybeConstDense>::template create<
58 NextDense, Next2Dense, Next3Dense>(
matrix);
80template <
typename ValueType,
typename Function,
typename... Args>
96template <
typename ValueType,
typename Function>
103 auto complex_to_real =
106 if (complex_to_real) {
113 fn(
dynamic_cast<const Dense*
>(dense_in->create_real_view().get()),
114 dynamic_cast<Dense*
>(dense_out->create_real_view().get()));
130template <
typename ValueType,
typename Function>
138 auto complex_to_real =
141 if (complex_to_real) {
149 fn(dense_alpha.get(),
150 dynamic_cast<const Dense*
>(dense_in->create_real_view().get()),
151 dynamic_cast<Dense*
>(dense_out->create_real_view().get()));
167template <
typename ValueType,
typename Function>
176 auto complex_to_real =
179 if (complex_to_real) {
188 fn(dense_alpha.get(),
189 dynamic_cast<const Dense*
>(dense_in->create_real_view().get()),
191 dynamic_cast<Dense*
>(dense_out->create_real_view().get()));
227template <
typename ValueType,
typename Function>
230#ifdef GINKGO_MIXED_PRECISION
235 auto dispatch_out_vector = [&](
auto dense_in) {
236 if (
auto dense_out =
dynamic_cast<fst_type*
>(out)) {
237 fn(dense_in, dense_out);
238 }
else if (
auto dense_out =
dynamic_cast<snd_type*
>(out)) {
239 fn(dense_in, dense_out);
240 }
else if (
auto dense_out =
dynamic_cast<trd_type*
>(out)) {
241 fn(dense_in, dense_out);
242 }
else if (
auto dense_out =
dynamic_cast<fth_type*
>(out)) {
243 fn(dense_in, dense_out);
245 GKO_NOT_SUPPORTED(out);
248 if (
auto dense_in =
dynamic_cast<const fst_type*
>(in)) {
249 dispatch_out_vector(dense_in);
250 }
else if (
auto dense_in =
dynamic_cast<const snd_type*
>(in)) {
251 dispatch_out_vector(dense_in);
252 }
else if (
auto dense_in =
dynamic_cast<const trd_type*
>(in)) {
253 dispatch_out_vector(dense_in);
254 }
else if (
auto dense_in =
dynamic_cast<const fth_type*
>(in)) {
255 dispatch_out_vector(dense_in);
257 GKO_NOT_SUPPORTED(in);
274template <
typename ValueType,
typename Function,
275 std::enable_if_t<is_complex<ValueType>()>* =
nullptr>
279#ifdef GINKGO_MIXED_PRECISION
287template <
typename ValueType,
typename Function,
288 std::enable_if_t<!is_complex<ValueType>()>* =
nullptr>
292#ifdef GINKGO_MIXED_PRECISION
293 if (!
dynamic_cast<const ConvertibleTo<matrix::Dense<>
>*>(in)) {
295 [&fn](
auto dense_in,
auto dense_out) {
296 fn(dense_in->create_real_view().get(),
297 dense_out->create_real_view().get());
309namespace experimental {
343template <
typename ValueType>
348 gko::detail::temporary_conversion<Vector<ValueType>>::template create<
353 GKO_NOT_SUPPORTED(
matrix);
362template <
typename ValueType>
363gko::detail::temporary_conversion<const Vector<ValueType>>
366 auto result = gko::detail::temporary_conversion<const Vector<ValueType>>::
367 template create<Vector<next_precision<ValueType>>,
371 GKO_NOT_SUPPORTED(
matrix);
391template <
typename ValueType,
typename Function,
typename... Args>
398template <
typename ValueType,
typename Function>
401#ifdef GINKGO_MIXED_PRECISION
402 using fst_type = Vector<ValueType>;
403 using snd_type = Vector<next_precision<ValueType, 2>>;
404 using trd_type = Vector<next_precision<ValueType, 3>>;
405 auto dispatch_out_vector = [&](
auto vector_in) {
406 if (
auto vector_out =
dynamic_cast<fst_type*
>(out)) {
407 fn(vector_in, vector_out);
408 }
else if (
auto vector_out =
dynamic_cast<snd_type*
>(out)) {
409 fn(vector_in, vector_out);
410 }
else if (
auto vector_out =
dynamic_cast<trd_type*
>(out)) {
411 fn(vector_in, vector_out);
413 GKO_NOT_SUPPORTED(out);
416 if (
auto vector_in =
dynamic_cast<const fst_type*
>(in)) {
417 dispatch_out_vector(vector_in);
418 }
else if (
auto vector_in =
dynamic_cast<const snd_type*
>(in)) {
419 dispatch_out_vector(vector_in);
420 }
else if (
auto vector_in =
dynamic_cast<const trd_type*
>(in)) {
421 dispatch_out_vector(vector_in);
423 GKO_NOT_SUPPORTED(in);
441template <
typename ValueType,
typename Function>
444 auto complex_to_real = !(
448 if (complex_to_real) {
457 fn(
dynamic_cast<const Vector*
>(dense_in->create_real_view().get()),
458 dynamic_cast<Vector*
>(dense_out->create_real_view().get()));
465template <
typename ValueType,
typename Function>
469 auto complex_to_real = !(
473 if (complex_to_real) {
474 distributed::mixed_precision_dispatch<to_complex<ValueType>>(
475 [&fn](
auto vector_in,
auto vector_out) {
476 fn(vector_in->create_real_view().get(),
477 vector_out->create_real_view().get());
481 distributed::mixed_precision_dispatch<ValueType>(fn, in, out);
489template <
typename ValueType,
typename Function>
493 auto complex_to_real = !(
497 if (complex_to_real) {
507 fn(dense_alpha.get(),
508 dynamic_cast<const Vector*
>(dense_in->create_real_view().get()),
509 dynamic_cast<Vector*
>(dense_out->create_real_view().get()));
521template <
typename ValueType,
typename Function>
526 auto complex_to_real = !(
530 if (complex_to_real) {
541 fn(dense_alpha.get(),
542 dynamic_cast<const Vector*
>(dense_in->create_real_view().get()),
544 dynamic_cast<Vector*
>(dense_out->create_real_view().get()));
570template <
typename ValueType,
typename Function>
571void precision_dispatch_real_complex_distributed(Function fn,
const LinOp* in,
587template <
typename ValueType,
typename Function>
588void precision_dispatch_real_complex_distributed(Function fn,
590 const LinOp* in, LinOp* out)
592 if (
dynamic_cast<const experimental::distributed::DistributedBase*
>(in)) {
605template <
typename ValueType,
typename Function>
606void precision_dispatch_real_complex_distributed(Function fn,
609 const LinOp* beta, LinOp* out)
611 if (
dynamic_cast<const experimental::distributed::DistributedBase*
>(in)) {
613 fn, alpha, in, beta, out);
635template <
typename ValueType,
typename Function,
typename... Args>
636void precision_dispatch_real_complex_distributed(Function fn, Args*... args)
ConvertibleTo interface is used to mark that the implementer can be converted to the object of ResultType.
Definition polymorphic_object.hpp:479
Definition lin_op.hpp:117
A base class for distributed objects.
Definition base.hpp:32
Vector is a format which explicitly stores (multiple) distributed column vectors in a dense storage format.
Definition vector.hpp:77
Dense is a matrix format which explicitly stores all values of the matrix.
Definition dense.hpp:120
The distributed namespace.
Definition polymorphic_object.hpp:19
gko::detail::temporary_conversion< Vector< ValueType > > make_temporary_conversion(LinOp *matrix)
Convert the given LinOp from experimental::distributed::Vector<...> to experimental::distributed::Vector<ValueType>.
Definition precision_dispatch.hpp:344
void precision_dispatch_real_complex(Function fn, const LinOp *in, LinOp *out)
Calls the given function with the given LinOps temporarily converted to experimental::distributed::Vector<ValueType>* as parameters.
Definition precision_dispatch.hpp:442
void precision_dispatch(Function fn, Args *... linops)
Calls the given function with each given argument LinOp temporarily converted into experimental::distributed::Vector<ValueType> as parameters.
Definition precision_dispatch.hpp:392
The matrix namespace.
Definition dense_cache.hpp:24
The Ginkgo namespace.
Definition abstract_factory.hpp:20
void mixed_precision_dispatch(Function fn, const LinOp *in, LinOp *out)
Calls the given function with each given argument LinOp converted into matrix::Dense<ValueType> as parameters.
Definition precision_dispatch.hpp:228
void mixed_precision_dispatch_real_complex(Function fn, const LinOp *in, LinOp *out)
Calls the given function with the given LinOps cast to their dynamic type matrix::Dense<ValueType>* as parameters.
Definition precision_dispatch.hpp:276
void precision_dispatch(Function fn, Args *... linops)
Calls the given function with each given argument LinOp temporarily converted into matrix::Dense<ValueType> as parameters.
Definition precision_dispatch.hpp:81
detail::temporary_conversion< std::conditional_t< std::is_const< detail::pointee< Ptr > >::value, const matrix::Dense< ValueType >, matrix::Dense< ValueType > > > make_temporary_conversion(Ptr &&matrix)
Convert the given LinOp from matrix::Dense<...> to matrix::Dense<ValueType>.
Definition precision_dispatch.hpp:47
void precision_dispatch_real_complex(Function fn, const LinOp *in, LinOp *out)
Calls the given function with the given LinOps temporarily converted to matrix::Dense<ValueType>* as parameters.
Definition precision_dispatch.hpp:97
constexpr bool is_complex()
Checks if T is a complex type.
Definition math.hpp:225