Ginkgo Generated from branch based on main. Ginkgo version 1.10.0
A numerical linear algebra library targeting many-core architectures
Loading...
Searching...
No Matches
precision_dispatch.hpp
1// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#ifndef GKO_PUBLIC_CORE_BASE_PRECISION_DISPATCH_HPP_
6#define GKO_PUBLIC_CORE_BASE_PRECISION_DISPATCH_HPP_
7
8
9#include <ginkgo/config.hpp>
10#include <ginkgo/core/base/math.hpp>
11#include <ginkgo/core/base/temporary_conversion.hpp>
12#include <ginkgo/core/distributed/vector.hpp>
13#include <ginkgo/core/matrix/dense.hpp>
14
15
16namespace gko {
17
18
// Converts `matrix` into a detail::temporary_conversion of matrix::Dense<ValueType>.
// The target is `const matrix::Dense<ValueType>` exactly when the pointee of
// `Ptr` is const (see the return type and MaybeConstDense below). If `result`
// tests false, GKO_NOT_SUPPORTED(matrix) is raised instead of returning.
// NOTE(review): source lines 46-47 and 51-53 are missing from this listing;
// NextDense/Next2Dense/Next3Dense are presumably declared in 51-53 (likely as
// Dense of the next precisions) -- verify against the header.
43template <typename ValueType, typename Ptr>
44detail::temporary_conversion<std::conditional_t<
45 std::is_const<detail::pointee<Ptr>>::value, const matrix::Dense<ValueType>,
48{
49 using Pointee = detail::pointee<Ptr>;
50 using Dense = matrix::Dense<ValueType>;
 // Mirror the constness of the pointee so const inputs stay read-only.
54 using MaybeConstDense =
55 std::conditional_t<std::is_const<Pointee>::value, const Dense, Dense>;
56 auto result =
57 detail::temporary_conversion<MaybeConstDense>::template create<
58 NextDense, Next2Dense, Next3Dense>(matrix);
59 if (!result) {
60 GKO_NOT_SUPPORTED(matrix);
61 }
62 return result;
63}
64
65
80template <typename ValueType, typename Function, typename... Args>
81void precision_dispatch(Function fn, Args*... linops)
82{
83 fn(make_temporary_conversion<ValueType>(linops).get()...);
84}
85
86
// Calls `fn` with `in`/`out` as matrix::Dense<ValueType> pointers.
// `complex_to_real` is derived partly from whether `in` is
// ConvertibleTo<matrix::Dense<>> (the code comment's test for a real dense
// matrix). When it holds, `fn` receives the real views (create_real_view())
// of dense_in/dense_out cast to Dense pointers; otherwise the else-branch runs.
// NOTE(review): source lines 104 (part of the complex_to_real condition),
// 107-108 (the declarations of dense_in/dense_out) and 116 (the else-branch
// body) are missing from this listing -- verify against the header.
96template <typename ValueType, typename Function>
97void precision_dispatch_real_complex(Function fn, const LinOp* in, LinOp* out)
98{
99 // do we need to convert complex Dense to real Dense?
100 // all real dense vectors are intra-convertible, thus by casting to
101 // ConvertibleTo<matrix::Dense<>>, we can check whether a LinOp is a real
102 // dense matrix:
103 auto complex_to_real =
105 dynamic_cast<const ConvertibleTo<matrix::Dense<>>*>(in));
106 if (complex_to_real) {
109 using Dense = matrix::Dense<ValueType>;
110 // These dynamic_casts are only needed to make the code compile
111 // If ValueType is complex, this branch will never be taken
112 // If ValueType is real, the cast is a no-op
113 fn(dynamic_cast<const Dense*>(dense_in->create_real_view().get()),
114 dynamic_cast<Dense*>(dense_out->create_real_view().get()));
115 } else {
117 }
118}
119
120
// Overload with a scalar `alpha`.
// In the complex_to_real branch:
// - `alpha` is converted with make_temporary_conversion<ValueType> and passed
//   as dense_alpha.get(), without a real view.
// - `in`/`out` are passed as real views cast to Dense pointers.
// Otherwise all three operands go through precision_dispatch<ValueType>.
// NOTE(review): source lines 139 (part of the complex_to_real condition) and
// 142-143 (the declarations of dense_in/dense_out) are missing from this
// listing -- verify against the header.
130template <typename ValueType, typename Function>
131void precision_dispatch_real_complex(Function fn, const LinOp* alpha,
132 const LinOp* in, LinOp* out)
133{
134 // do we need to convert complex Dense to real Dense?
135 // all real dense vectors are intra-convertible, thus by casting to
136 // ConvertibleTo<matrix::Dense<>>, we can check whether a LinOp is a real
137 // dense matrix:
138 auto complex_to_real =
140 dynamic_cast<const ConvertibleTo<matrix::Dense<>>*>(in));
141 if (complex_to_real) {
144 auto dense_alpha = make_temporary_conversion<ValueType>(alpha);
145 using Dense = matrix::Dense<ValueType>;
146 // These dynamic_casts are only needed to make the code compile
147 // If ValueType is complex, this branch will never be taken
148 // If ValueType is real, the cast is a no-op
149 fn(dense_alpha.get(),
150 dynamic_cast<const Dense*>(dense_in->create_real_view().get()),
151 dynamic_cast<Dense*>(dense_out->create_real_view().get()));
152 } else {
153 precision_dispatch<ValueType>(fn, alpha, in, out);
154 }
155}
156
157
// Overload with scalars `alpha` and `beta`.
// In the complex_to_real branch:
// - both scalars are converted with make_temporary_conversion<ValueType> and
//   passed without a real view.
// - `in`/`out` are passed as real views cast to Dense pointers.
// Otherwise all four operands go through precision_dispatch<ValueType>.
// NOTE(review): source lines 177 (part of the complex_to_real condition) and
// 180-181 (the declarations of dense_in/dense_out) are missing from this
// listing -- verify against the header.
167template <typename ValueType, typename Function>
168void precision_dispatch_real_complex(Function fn, const LinOp* alpha,
169 const LinOp* in, const LinOp* beta,
170 LinOp* out)
171{
172 // do we need to convert complex Dense to real Dense?
173 // all real dense vectors are intra-convertible, thus by casting to
174 // ConvertibleTo<matrix::Dense<>>, we can check whether a LinOp is a real
175 // dense matrix:
176 auto complex_to_real =
178 dynamic_cast<const ConvertibleTo<matrix::Dense<>>*>(in));
179 if (complex_to_real) {
182 auto dense_alpha = make_temporary_conversion<ValueType>(alpha);
183 auto dense_beta = make_temporary_conversion<ValueType>(beta);
184 using Dense = matrix::Dense<ValueType>;
185 // These dynamic_casts are only needed to make the code compile
186 // If ValueType is complex, this branch will never be taken
187 // If ValueType is real, the cast is a no-op
188 fn(dense_alpha.get(),
189 dynamic_cast<const Dense*>(dense_in->create_real_view().get()),
190 dense_beta.get(),
191 dynamic_cast<Dense*>(dense_out->create_real_view().get()));
192 } else {
193 precision_dispatch<ValueType>(fn, alpha, in, beta, out);
194 }
195}
196
197
// With GINKGO_MIXED_PRECISION defined:
// - `in` (as const) is dynamic_cast to the first matching candidate among
//   fst_type..fth_type, and then `out` likewise.
// - fn(dense_in, dense_out) is called with those typed pointers. No value
//   conversion happens in this branch, only dynamic_casts.
// - GKO_NOT_SUPPORTED(in) or GKO_NOT_SUPPORTED(out) is raised when no
//   candidate type matches.
// NOTE(review): source lines 232-234 (the declarations of snd_type, trd_type
// and fth_type) and 260 (the body used without GINKGO_MIXED_PRECISION) are
// missing from this listing -- verify against the header.
227template <typename ValueType, typename Function>
228void mixed_precision_dispatch(Function fn, const LinOp* in, LinOp* out)
229{
230#ifdef GINKGO_MIXED_PRECISION
231 using fst_type = matrix::Dense<ValueType>;
 // Resolve the dynamic type of `out` once `in` has been resolved below.
235 auto dispatch_out_vector = [&](auto dense_in) {
236 if (auto dense_out = dynamic_cast<fst_type*>(out)) {
237 fn(dense_in, dense_out);
238 } else if (auto dense_out = dynamic_cast<snd_type*>(out)) {
239 fn(dense_in, dense_out);
240 } else if (auto dense_out = dynamic_cast<trd_type*>(out)) {
241 fn(dense_in, dense_out);
242 } else if (auto dense_out = dynamic_cast<fth_type*>(out)) {
243 fn(dense_in, dense_out);
244 } else {
245 GKO_NOT_SUPPORTED(out);
246 }
247 };
248 if (auto dense_in = dynamic_cast<const fst_type*>(in)) {
249 dispatch_out_vector(dense_in);
250 } else if (auto dense_in = dynamic_cast<const snd_type*>(in)) {
251 dispatch_out_vector(dense_in);
252 } else if (auto dense_in = dynamic_cast<const trd_type*>(in)) {
253 dispatch_out_vector(dense_in);
254 } else if (auto dense_in = dynamic_cast<const fth_type*>(in)) {
255 dispatch_out_vector(dense_in);
256 } else {
257 GKO_NOT_SUPPORTED(in);
258 }
259#else
261#endif
262}
263
264
// mixed_precision_dispatch_real_complex overload that is enabled only when
// ValueType is complex (std::enable_if_t<is_complex<ValueType>()>).
// NOTE(review): source line 276 (the declarator with the function name and
// leading parameters) and the bodies of both preprocessor branches (lines 280
// and 282) are missing from this listing -- verify against the header.
274template <typename ValueType, typename Function,
275 std::enable_if_t<is_complex<ValueType>()>* = nullptr>
277 LinOp* out)
278{
279#ifdef GINKGO_MIXED_PRECISION
281#else
283#endif
284}
285
286
// mixed_precision_dispatch_real_complex overload that is enabled only when
// ValueType is real (std::enable_if_t<!is_complex<ValueType>()>).
// With GINKGO_MIXED_PRECISION defined and `in` NOT ConvertibleTo<matrix::Dense<>>:
// - a dispatch call (whose callee is on a missing line) receives a lambda.
// - the lambda forwards the real views (create_real_view()) of the dispatched
//   in/out objects to `fn`.
// Otherwise the else-branch applies.
// NOTE(review): source lines 294 (the start of that dispatch call), 301 (the
// else-branch body) and 304 (the body without GINKGO_MIXED_PRECISION) are
// missing from this listing -- verify against the header.
287template <typename ValueType, typename Function,
288 std::enable_if_t<!is_complex<ValueType>()>* = nullptr>
289void mixed_precision_dispatch_real_complex(Function fn, const LinOp* in,
290 LinOp* out)
291{
292#ifdef GINKGO_MIXED_PRECISION
293 if (!dynamic_cast<const ConvertibleTo<matrix::Dense<>>*>(in)) {
295 [&fn](auto dense_in, auto dense_out) {
296 fn(dense_in->create_real_view().get(),
297 dense_out->create_real_view().get());
298 },
299 in, out);
300 } else {
302 }
303#else
305#endif
306}
307
308
309namespace experimental {
310
311
312#if GINKGO_BUILD_MPI
313
314
315namespace distributed {
316
317
// Distributed counterpart of make_temporary_conversion for a mutable LinOp.
// It wraps `matrix` in a temporary_conversion of Vector<ValueType>, and raises
// GKO_NOT_SUPPORTED(matrix) if `result` tests false.
// NOTE(review): source lines 349-351 (the template arguments passed to
// `create`) are missing from this listing -- verify against the header.
343template <typename ValueType>
344gko::detail::temporary_conversion<Vector<ValueType>> make_temporary_conversion(
345 LinOp* matrix)
346{
347 auto result =
348 gko::detail::temporary_conversion<Vector<ValueType>>::template create<
352 if (!result) {
353 GKO_NOT_SUPPORTED(matrix);
354 }
355 return result;
356}
357
358
// Read-only variant: returns a temporary_conversion<const Vector<ValueType>>.
// Vector<next_precision<ValueType>> is the first source type listed for
// `create`. Raises GKO_NOT_SUPPORTED(matrix) if `result` tests false.
// NOTE(review): source line 364 (the declarator, presumably
// `make_temporary_conversion(const LinOp* matrix)`) and lines 368-369 (the
// remaining `create` template arguments) are missing from this listing --
// verify against the header.
362template <typename ValueType>
363gko::detail::temporary_conversion<const Vector<ValueType>>
365{
366 auto result = gko::detail::temporary_conversion<const Vector<ValueType>>::
367 template create<Vector<next_precision<ValueType>>,
370 if (!result) {
371 GKO_NOT_SUPPORTED(matrix);
372 }
373 return result;
374}
375
376
// Distributed counterpart of precision_dispatch. Per its Doxygen summary, it
// calls `fn` with each LinOp temporarily converted into an
// experimental::distributed type.
// NOTE(review): the body (source line 394) is missing from this listing.
// Presumably it mirrors the Dense version using
// distributed::make_temporary_conversion -- verify against the header.
391template <typename ValueType, typename Function, typename... Args>
392void precision_dispatch(Function fn, Args*... linops)
393{
395}
396
397
// Distributed counterpart of mixed_precision_dispatch.
// With GINKGO_MIXED_PRECISION defined:
// - `in` (as const) and then `out` are each dynamic_cast to the first
//   matching type among fst_type, snd_type and trd_type.
// - fn(vector_in, vector_out) is called with those typed pointers.
// - GKO_NOT_SUPPORTED(in) or GKO_NOT_SUPPORTED(out) is raised when nothing
//   matches.
// NOTE(review): source line 427 (the body without GINKGO_MIXED_PRECISION,
// following the "avoid ambiguous" comment) is missing from this listing --
// verify against the header.
398template <typename ValueType, typename Function>
399void mixed_precision_dispatch(Function fn, const LinOp* in, LinOp* out)
400{
401#ifdef GINKGO_MIXED_PRECISION
 // NOTE(review): the accepted precisions are next_precision steps 0, 2 and 3.
 // Step 1 (next_precision<ValueType>) is absent, although:
 // - the const make_temporary_conversion overload lists
 //   Vector<next_precision<ValueType>> as a conversion source, and
 // - the Dense mixed_precision_dispatch accepts four candidate types.
 // Unless step 1 coincides with one of the listed types, Vectors of that
 // precision end in GKO_NOT_SUPPORTED -- confirm this is intentional.
402 using fst_type = Vector<ValueType>;
403 using snd_type = Vector<next_precision<ValueType, 2>>;
404 using trd_type = Vector<next_precision<ValueType, 3>>;
405 auto dispatch_out_vector = [&](auto vector_in) {
406 if (auto vector_out = dynamic_cast<fst_type*>(out)) {
407 fn(vector_in, vector_out);
408 } else if (auto vector_out = dynamic_cast<snd_type*>(out)) {
409 fn(vector_in, vector_out);
410 } else if (auto vector_out = dynamic_cast<trd_type*>(out)) {
411 fn(vector_in, vector_out);
412 } else {
413 GKO_NOT_SUPPORTED(out);
414 }
415 };
416 if (auto vector_in = dynamic_cast<const fst_type*>(in)) {
417 dispatch_out_vector(vector_in);
418 } else if (auto vector_in = dynamic_cast<const snd_type*>(in)) {
419 dispatch_out_vector(vector_in);
420 } else if (auto vector_in = dynamic_cast<const trd_type*>(in)) {
421 dispatch_out_vector(vector_in);
422 } else {
423 GKO_NOT_SUPPORTED(in);
424 }
425#else
426 // avoid ambiguous
428#endif
429}
430
431
// Distributed counterpart of precision_dispatch_real_complex(fn, in, out).
// `complex_to_real` is the negation of a condition that involves `in`. When it
// holds, `fn` receives the real views (create_real_view()) of dense_in and
// dense_out, cast to Vector pointers.
// NOTE(review): the following source lines are missing from this listing --
// verify against the header:
// - 445-446: the negated condition.
// - 450 and 452-453: the dense_in/dense_out initializers and, presumably, the
//   `Vector` alias.
// - 460: the else-branch body.
441template <typename ValueType, typename Function>
442void precision_dispatch_real_complex(Function fn, const LinOp* in, LinOp* out)
443{
444 auto complex_to_real = !(
447 in));
448 if (complex_to_real) {
449 auto dense_in =
451 auto dense_out =
454 // These dynamic_casts are only needed to make the code compile
455 // If ValueType is complex, this branch will never be taken
456 // If ValueType is real, the cast is a no-op
457 fn(dynamic_cast<const Vector*>(dense_in->create_real_view().get()),
458 dynamic_cast<Vector*>(dense_out->create_real_view().get()));
459 } else {
461 }
462}
463
464
// Distributed counterpart of mixed_precision_dispatch_real_complex.
// - When `complex_to_real` holds, it dispatches through
//   distributed::mixed_precision_dispatch<to_complex<ValueType>> and hands `fn`
//   the real views of the dispatched vectors.
// - Otherwise it dispatches directly through
//   distributed::mixed_precision_dispatch<ValueType>.
// There is no GINKGO_MIXED_PRECISION guard here; that switch is handled inside
// distributed::mixed_precision_dispatch.
// NOTE(review): source lines 470-471 (the condition negated into
// complex_to_real) are missing from this listing -- verify against the header.
465template <typename ValueType, typename Function>
466void mixed_precision_dispatch_real_complex(Function fn, const LinOp* in,
467 LinOp* out)
468{
469 auto complex_to_real = !(
472 in));
473 if (complex_to_real) {
474 distributed::mixed_precision_dispatch<to_complex<ValueType>>(
475 [&fn](auto vector_in, auto vector_out) {
476 fn(vector_in->create_real_view().get(),
477 vector_out->create_real_view().get());
478 },
479 in, out);
480 } else {
481 distributed::mixed_precision_dispatch<ValueType>(fn, in, out);
482 }
483}
484
485
// Distributed counterpart of precision_dispatch_real_complex(fn, alpha, in, out).
// In the complex_to_real branch:
// - `alpha` goes through gko::make_temporary_conversion<ValueType> (the
//   matrix::Dense conversion in namespace gko, not the distributed one) and is
//   passed without a real view.
// - `in`/`out` are passed as real views cast to Vector pointers.
// NOTE(review): the following source lines are missing from this listing --
// verify against the header:
// - 494-495: the condition.
// - 499, 501 and 503: the dense_in/dense_out initializers and, presumably, the
//   `Vector` alias.
// - 511-513: the else-branch body.
489template <typename ValueType, typename Function>
490void precision_dispatch_real_complex(Function fn, const LinOp* alpha,
491 const LinOp* in, LinOp* out)
492{
493 auto complex_to_real = !(
496 in));
497 if (complex_to_real) {
498 auto dense_in =
500 auto dense_out =
502 auto dense_alpha = gko::make_temporary_conversion<ValueType>(alpha);
504 // These dynamic_casts are only needed to make the code compile
505 // If ValueType is complex, this branch will never be taken
506 // If ValueType is real, the cast is a no-op
507 fn(dense_alpha.get(),
508 dynamic_cast<const Vector*>(dense_in->create_real_view().get()),
509 dynamic_cast<Vector*>(dense_out->create_real_view().get()));
510 } else {
514 }
515}
516
517
// Distributed counterpart of
// precision_dispatch_real_complex(fn, alpha, in, beta, out).
// In the complex_to_real branch:
// - `alpha` and `beta` go through gko::make_temporary_conversion<ValueType>
//   (the matrix::Dense conversion in namespace gko) and are passed without a
//   real view.
// - `in`/`out` are passed as real views cast to Vector pointers.
// NOTE(review): the following source lines are missing from this listing --
// verify against the header:
// - 527-528: the condition.
// - 532, 534 and 537: the dense_in/dense_out initializers and, presumably, the
//   `Vector` alias.
// - 546-549: the else-branch body.
521template <typename ValueType, typename Function>
522void precision_dispatch_real_complex(Function fn, const LinOp* alpha,
523 const LinOp* in, const LinOp* beta,
524 LinOp* out)
525{
526 auto complex_to_real = !(
529 in));
530 if (complex_to_real) {
531 auto dense_in =
533 auto dense_out =
535 auto dense_alpha = gko::make_temporary_conversion<ValueType>(alpha);
536 auto dense_beta = gko::make_temporary_conversion<ValueType>(beta);
538 // These dynamic_casts are only needed to make the code compile
539 // If ValueType is complex, this branch will never be taken
540 // If ValueType is real, the cast is a no-op
541 fn(dense_alpha.get(),
542 dynamic_cast<const Vector*>(dense_in->create_real_view().get()),
543 dense_beta.get(),
544 dynamic_cast<Vector*>(dense_out->create_real_view().get()));
545 } else {
550 }
551}
552
553
554} // namespace distributed
555
556
// Chooses between two dispatch paths depending on whether `in` is an
// experimental::distributed::DistributedBase.
// NOTE(review): source lines 575 and 578 (the two dispatch calls) are missing
// from this listing. Presumably they are distributed::precision_dispatch_real_complex
// and gko::precision_dispatch_real_complex respectively -- verify against the header.
570template <typename ValueType, typename Function>
571void precision_dispatch_real_complex_distributed(Function fn, const LinOp* in,
572 LinOp* out)
573{
574 if (dynamic_cast<const experimental::distributed::DistributedBase*>(in)) {
576 fn, in, out);
577 } else {
579 }
580}
581
582
// Overload with a scalar `alpha`. The dispatch path is chosen by whether `in`
// is an experimental::distributed::DistributedBase; alpha, in and out are
// forwarded in the distributed case.
// NOTE(review): source lines 593 and 596 (the two dispatch calls) are missing
// from this listing -- verify against the header.
587template <typename ValueType, typename Function>
588void precision_dispatch_real_complex_distributed(Function fn,
589 const LinOp* alpha,
590 const LinOp* in, LinOp* out)
591{
592 if (dynamic_cast<const experimental::distributed::DistributedBase*>(in)) {
594 fn, alpha, in, out);
595 } else {
597 }
598}
599
600
// Overload with scalars `alpha` and `beta`. The dispatch path is chosen by
// whether `in` is an experimental::distributed::DistributedBase; alpha, in,
// beta and out are forwarded in the distributed case.
// NOTE(review): source lines 612 and 616 (the starts of the two dispatch
// calls) are missing from this listing -- verify against the header.
605template <typename ValueType, typename Function>
606void precision_dispatch_real_complex_distributed(Function fn,
607 const LinOp* alpha,
608 const LinOp* in,
609 const LinOp* beta, LinOp* out)
610{
611 if (dynamic_cast<const experimental::distributed::DistributedBase*>(in)) {
613 fn, alpha, in, beta, out);
614
615 } else {
617 out);
618 }
619}
620
621
622#else
623
624
// Variant compiled when GINKGO_BUILD_MPI is not set (the #else branch). It
// accepts any number of LinOp pointers after `fn`.
// NOTE(review): the body (source line 638) is missing from this listing.
// Presumably it forwards to the non-distributed precision_dispatch_real_complex
// -- verify against the header.
635template <typename ValueType, typename Function, typename... Args>
636void precision_dispatch_real_complex_distributed(Function fn, Args*... args)
637{
639}
640
641
642#endif
643
644
645} // namespace experimental
646} // namespace gko
647
648
649#endif // GKO_PUBLIC_CORE_BASE_PRECISION_DISPATCH_HPP_
ConvertibleTo interface is used to mark that the implementer can be converted to the object of Result...
Definition polymorphic_object.hpp:479
Definition lin_op.hpp:117
A base class for distributed objects.
Definition base.hpp:32
Vector is a format which explicitly stores (multiple) distributed column vectors in a dense storage f...
Definition vector.hpp:77
Dense is a matrix format which explicitly stores all values of the matrix.
Definition dense.hpp:120
The distributed namespace.
Definition polymorphic_object.hpp:19
gko::detail::temporary_conversion< Vector< ValueType > > make_temporary_conversion(LinOp *matrix)
Convert the given LinOp from experimental::distributed::Vector<...> to experimental::distributed::Vec...
Definition precision_dispatch.hpp:344
void precision_dispatch_real_complex(Function fn, const LinOp *in, LinOp *out)
Calls the given function with the given LinOps temporarily converted to experimental::distributed::Ve...
Definition precision_dispatch.hpp:442
void precision_dispatch(Function fn, Args *... linops)
Calls the given function with each given argument LinOp temporarily converted into experimental::dist...
Definition precision_dispatch.hpp:392
The matrix namespace.
Definition dense_cache.hpp:24
The Ginkgo namespace.
Definition abstract_factory.hpp:20
void mixed_precision_dispatch(Function fn, const LinOp *in, LinOp *out)
Calls the given function with each given argument LinOp converted into matrix::Dense<ValueType> as pa...
Definition precision_dispatch.hpp:228
void mixed_precision_dispatch_real_complex(Function fn, const LinOp *in, LinOp *out)
Calls the given function with the given LinOps cast to their dynamic type matrix::Dense<ValueType>* a...
Definition precision_dispatch.hpp:276
void precision_dispatch(Function fn, Args *... linops)
Calls the given function with each given argument LinOp temporarily converted into matrix::Dense<Valu...
Definition precision_dispatch.hpp:81
detail::temporary_conversion< std::conditional_t< std::is_const< detail::pointee< Ptr > >::value, const matrix::Dense< ValueType >, matrix::Dense< ValueType > > > make_temporary_conversion(Ptr &&matrix)
Convert the given LinOp from matrix::Dense<...> to matrix::Dense<ValueType>.
Definition precision_dispatch.hpp:47
void precision_dispatch_real_complex(Function fn, const LinOp *in, LinOp *out)
Calls the given function with the given LinOps temporarily converted to matrix::Dense<ValueType>* as ...
Definition precision_dispatch.hpp:97
constexpr bool is_complex()
Checks if T is a complex type.
Definition math.hpp:225