5#ifndef GKO_PUBLIC_CORE_SOLVER_SOLVER_BASE_HPP_
6#define GKO_PUBLIC_CORE_SOLVER_SOLVER_BASE_HPP_
13#include <ginkgo/core/base/lin_op.hpp>
14#include <ginkgo/core/base/math.hpp>
15#include <ginkgo/core/log/logger.hpp>
16#include <ginkgo/core/matrix/dense.hpp>
17#include <ginkgo/core/matrix/identity.hpp>
18#include <ginkgo/core/solver/workspace.hpp>
19#include <ginkgo/core/stop/combined.hpp>
20#include <ginkgo/core/stop/criterion.hpp>
23GKO_BEGIN_DISABLE_DEPRECATION_WARNINGS
65class ApplyWithInitialGuess {
67 friend class multigrid::detail::MultigridState;
82 virtual void apply_with_initial_guess(
const LinOp* b,
LinOp* x,
88 apply_with_initial_guess(b.get(), x.get(), guess);
103 virtual void apply_with_initial_guess(
const LinOp* alpha,
const LinOp* b,
114 apply_with_initial_guess(alpha.get(), b.get(), beta.get(), x.get(),
131 explicit ApplyWithInitialGuess(
160template <
typename DerivedType>
161class EnableApplyWithInitialGuess :
public ApplyWithInitialGuess {
163 friend class multigrid::detail::MultigridState;
165 explicit EnableApplyWithInitialGuess(
167 : ApplyWithInitialGuess(guess)
174 void apply_with_initial_guess(
const LinOp* b,
LinOp* x,
178 auto exec = self()->get_executor();
179 GKO_ASSERT_CONFORMANT(self(), b);
180 GKO_ASSERT_EQUAL_ROWS(self(), x);
181 GKO_ASSERT_EQUAL_COLS(b, x);
192 void apply_with_initial_guess(
const LinOp* alpha,
const LinOp* b,
197 self(), alpha, b, beta, x);
198 auto exec = self()->get_executor();
199 GKO_ASSERT_CONFORMANT(self(), b);
200 GKO_ASSERT_EQUAL_ROWS(self(), x);
201 GKO_ASSERT_EQUAL_COLS(b, x);
202 GKO_ASSERT_EQUAL_DIMENSIONS(alpha,
dim<2>(1, 1));
203 GKO_ASSERT_EQUAL_DIMENSIONS(beta,
dim<2>(1, 1));
204 this->apply_with_initial_guess_impl(
210 self(), alpha, b, beta, x);
218 virtual void apply_with_initial_guess_impl(
225 virtual void apply_with_initial_guess_impl(
229 GKO_ENABLE_SELF(DerivedType);
237template <
typename Solver>
// Default trait value: reports zero workspace vectors.
// The solver argument is unnamed and unused; 0 is returned for any Solver.
240 static int num_vectors(
const Solver&) {
return 0; }
// Default trait value: reports zero workspace arrays.
// The solver argument is unnamed and unused; 0 is returned for any Solver.
242 static int num_arrays(
const Solver&) {
return 0; }
// Default trait value: returns an empty list of workspace operator names.
// The solver argument is unnamed and unused.
244 static std::vector<std::string> op_names(
const Solver&) {
return {}; }
// Default trait value: returns an empty list of workspace array names.
// The solver argument is unnamed and unused.
246 static std::vector<std::string> array_names(
const Solver&) {
return {}; }
// Default trait value: returns an empty vector of ints, i.e. no workspace
// entries are reported as scalars. The solver argument is unused.
248 static std::vector<int> scalars(
const Solver&) {
return {}; }
// Default trait value: returns an empty vector of ints, i.e. no workspace
// entries are reported as vectors. The solver argument is unused.
250 static std::vector<int> vectors(
const Solver&) {
return {}; }
269template <
typename DerivedType>
280 auto exec = self()->get_executor();
282 GKO_ASSERT_EQUAL_DIMENSIONS(self(), new_precond);
283 GKO_ASSERT_IS_SQUARE_MATRIX(new_precond);
284 if (new_precond->get_executor() != exec) {
295 EnablePreconditionable&
operator=(
const EnablePreconditionable& other)
297 if (&other !=
this) {
308 EnablePreconditionable&
operator=(EnablePreconditionable&& other)
310 if (&other !=
this) {
312 other.set_preconditioner(
nullptr);
338 *
this = std::move(other);
// Returns this object viewed as the derived type (mutable overload).
// Uses static_cast, so no runtime type check is performed.
342 DerivedType* self() {
return static_cast<DerivedType*
>(
this); }
344 const DerivedType* self()
const
346 return static_cast<const DerivedType*
>(
this);
362class SolverBaseLinOp {
364 SolverBaseLinOp(std::shared_ptr<const Executor> exec)
365 : workspace_{std::move(exec)}
368 virtual ~SolverBaseLinOp() =
default;
375 std::shared_ptr<const LinOp> get_system_matrix()
const
377 return system_matrix_;
380 const LinOp* get_workspace_op(
int vector_id)
const
382 return workspace_.get_op(vector_id);
// Base implementation: reports zero workspace operators.
// Declared virtual; an overriding version that queries workspace_traits
// appears later in this listing.
385 virtual int get_num_workspace_ops()
const {
return 0; }
387 virtual std::vector<std::string> get_workspace_op_names()
const
// Base implementation: returns an empty vector, i.e. no workspace scalar
// IDs are reported. Declared virtual so it can be overridden.
396 virtual std::vector<int> get_workspace_scalars()
const {
return {}; }
// Base implementation: returns an empty vector, i.e. no workspace vector
// IDs are reported. Declared virtual so it can be overridden.
402 virtual std::vector<int> get_workspace_vectors()
const {
return {}; }
405 void set_system_matrix_base(std::shared_ptr<const LinOp> system_matrix)
407 system_matrix_ = std::move(system_matrix);
410 void set_workspace_size(
int num_operators,
int num_arrays)
const
412 workspace_.set_size(num_operators, num_arrays);
415 template <
typename LinOpType>
416 LinOpType* create_workspace_op(
int vector_id, gko::dim<2> size)
const
418 return workspace_.template create_or_get_op<LinOpType>(
421 return LinOpType::create(this->workspace_.get_executor(), size);
423 typeid(LinOpType), size, size[1]);
426 template <
typename LinOpType>
427 LinOpType* create_workspace_op_with_config_of(
int vector_id,
428 const LinOpType* vec)
const
430 return workspace_.template create_or_get_op<LinOpType>(
431 vector_id, [&] {
return LinOpType::create_with_config_of(vec); },
432 typeid(*vec), vec->get_size(), vec->get_stride());
435 template <
typename LinOpType>
436 LinOpType* create_workspace_op_with_type_of(
int vector_id,
437 const LinOpType* vec,
440 return workspace_.template create_or_get_op<LinOpType>(
443 return LinOpType::create_with_type_of(
444 vec, workspace_.get_executor(), size, size[1]);
446 typeid(*vec), size, size[1]);
449 template <
typename LinOpType>
450 LinOpType* create_workspace_op_with_type_of(
int vector_id,
451 const LinOpType* vec,
453 dim<2> local_size)
const
455 return workspace_.template create_or_get_op<LinOpType>(
458 return LinOpType::create_with_type_of(
459 vec, workspace_.get_executor(), global_size, local_size,
462 typeid(*vec), global_size, local_size[1]);
465 template <
typename ValueType>
466 matrix::Dense<ValueType>* create_workspace_scalar(
int vector_id,
469 return workspace_.template create_or_get_op<matrix::Dense<ValueType>>(
473 workspace_.get_executor(), dim<2>{1, size});
475 typeid(matrix::Dense<ValueType>), gko::dim<2>{1, size}, size);
478 template <
typename ValueType>
479 const matrix::Dense<ValueType>* create_workspace_fixed_scalar(
480 int vector_id,
size_type size, ValueType val)
const
482 return workspace_.template create_or_get_op<matrix::Dense<ValueType>>(
486 workspace_.get_executor(), dim<2>{1, size});
490 typeid(matrix::Dense<ValueType>), gko::dim<2>{1, size}, size);
493 template <
typename ValueType>
496 return workspace_.template create_or_get_array<ValueType>(array_id,
500 template <
typename ValueType>
503 return workspace_.template init_or_get_array<ValueType>(array_id);
507 mutable detail::workspace workspace_;
509 std::shared_ptr<const LinOp> system_matrix_;
516template <
typename MatrixType>
519 GKO_DEPRECATED(
"This class will be replaced by the template-less detail::SolverBaseLinOp in a future release")
SolverBase
521 : public detail::SolverBaseLinOp {
523 using detail::SolverBaseLinOp::SolverBaseLinOp;
534 return std::dynamic_pointer_cast<const MatrixType>(
535 SolverBaseLinOp::get_system_matrix());
539 void set_system_matrix_base(std::shared_ptr<const MatrixType> system_matrix)
541 SolverBaseLinOp::set_system_matrix_base(std::move(system_matrix));
555template <
typename DerivedType,
typename MatrixType = LinOp>
562 EnableSolverBase&
operator=(
const EnableSolverBase& other)
564 if (&other !=
this) {
576 if (&other !=
this) {
577 set_system_matrix(other.get_system_matrix());
578 other.set_system_matrix(
nullptr);
585 EnableSolverBase(std::shared_ptr<const MatrixType> system_matrix)
586 : SolverBase<MatrixType>{self()->get_executor()}
588 set_system_matrix(std::move(system_matrix));
595 :
SolverBase<MatrixType>{other.self()->get_executor()}
605 :
SolverBase<MatrixType>{other.self()->get_executor()}
607 *
this = std::move(other);
610 int get_num_workspace_ops()
const override
612 using traits = workspace_traits<DerivedType>;
613 return traits::num_vectors(*self());
616 std::vector<std::string> get_workspace_op_names()
const override
618 using traits = workspace_traits<DerivedType>;
619 return traits::op_names(*self());
629 return traits::scalars(*self());
639 return traits::vectors(*self());
643 void set_system_matrix(std::shared_ptr<const MatrixType> new_system_matrix)
645 auto exec = self()->get_executor();
646 if (new_system_matrix) {
647 GKO_ASSERT_EQUAL_DIMENSIONS(self(), new_system_matrix);
648 GKO_ASSERT_IS_SQUARE_MATRIX(new_system_matrix);
649 if (new_system_matrix->get_executor() != exec) {
650 new_system_matrix =
gko::clone(exec, new_system_matrix);
653 this->set_system_matrix_base(new_system_matrix);
656 void setup_workspace()
const
658 using traits = workspace_traits<DerivedType>;
659 this->set_workspace_size(traits::num_vectors(*self()),
660 traits::num_arrays(*self()));
// Returns this object viewed as the derived type (mutable overload).
// Uses static_cast, so no runtime type check is performed.
664 DerivedType* self() {
return static_cast<DerivedType*
>(
this); }
666 const DerivedType* self()
const
668 return static_cast<const DerivedType*
>(
this);
689 return stop_factory_;
698 std::shared_ptr<const stop::CriterionFactory> new_stop_factory)
700 stop_factory_ = new_stop_factory;
704 std::shared_ptr<const stop::CriterionFactory> stop_factory_;
717template <
typename DerivedType>
724 EnableIterativeBase&
operator=(
const EnableIterativeBase& other)
726 if (&other !=
this) {
737 EnableIterativeBase&
operator=(EnableIterativeBase&& other)
739 if (&other !=
this) {
741 other.set_stop_criterion_factory(
nullptr);
749 std::shared_ptr<const stop::CriterionFactory> stop_factory)
765 *
this = std::move(other);
769 std::shared_ptr<const stop::CriterionFactory> new_stop_factory)
override
771 auto exec = self()->get_executor();
772 if (new_stop_factory && new_stop_factory->get_executor() != exec) {
773 new_stop_factory =
gko::clone(exec, new_stop_factory);
// Returns this object viewed as the derived type (mutable overload).
// Uses static_cast, so no runtime type check is performed.
779 DerivedType* self() {
return static_cast<DerivedType*
>(
this); }
781 const DerivedType* self()
const
783 return static_cast<const DerivedType*
>(
this);
798template <
typename ValueType,
typename DerivedType>
799class EnablePreconditionedIterativeSolver
800 :
public EnableSolverBase<DerivedType>,
801 public EnableIterativeBase<DerivedType>,
802 public EnablePreconditionable<DerivedType> {
804 EnablePreconditionedIterativeSolver() =
default;
806 EnablePreconditionedIterativeSolver(
807 std::shared_ptr<const LinOp> system_matrix,
808 std::shared_ptr<const stop::CriterionFactory> stop_factory,
810 : EnableSolverBase<DerivedType>(std::move(system_matrix)),
811 EnableIterativeBase<DerivedType>{std::move(stop_factory)},
815 template <
typename FactoryParameters>
816 EnablePreconditionedIterativeSolver(
817 std::shared_ptr<const LinOp> system_matrix,
818 const FactoryParameters& params)
819 : EnablePreconditionedIterativeSolver{
821 generate_preconditioner(system_matrix, params)}
825 template <
typename FactoryParameters>
826 static std::shared_ptr<const LinOp> generate_preconditioner(
827 std::shared_ptr<const LinOp> system_matrix,
828 const FactoryParameters& params)
830 if (params.generated_preconditioner) {
831 return params.generated_preconditioner;
832 }
else if (params.preconditioner) {
833 return params.preconditioner->generate(system_matrix);
836 system_matrix->get_executor(), system_matrix->get_size());
842template <
typename Parameters,
typename Factory>
848 std::vector<std::shared_ptr<const stop::CriterionFactory>>
853template <
typename Parameters,
typename Factory>
860 std::shared_ptr<const LinOpFactory> GKO_DEFERRED_FACTORY_PARAMETER(
876GKO_END_DISABLE_DEPRECATION_WARNINGS
Definition lin_op.hpp:117
A LinOp implementing this interface can be preconditioned.
Definition lin_op.hpp:682
virtual void set_preconditioner(std::shared_ptr< const LinOp > new_precond)
Sets the preconditioner operator used by the Preconditionable.
Definition lin_op.hpp:702
virtual std::shared_ptr< const LinOp > get_preconditioner() const
Returns the preconditioner operator used by the Preconditionable.
Definition lin_op.hpp:691
The enable_parameters_type mixin is used to create a base implementation of the factory parameters st...
Definition abstract_factory.hpp:211
static std::unique_ptr< Dense > create(std::shared_ptr< const Executor > exec, const dim< 2 > &size={}, size_type stride=0)
Creates an uninitialized Dense matrix of the specified size.
static std::unique_ptr< Identity > create(std::shared_ptr< const Executor > exec, dim< 2 > size)
Creates an Identity matrix of the specified size.
This class is used for function parameters in the place of raw pointers.
Definition utils_helper.hpp:41
A LinOp deriving from this CRTP class stores a stopping criterion factory and allows applying with a ...
Definition solver_base.hpp:718
EnableIterativeBase & operator=(EnableIterativeBase &&other)
Moves the provided stopping criterion, clones it onto this executor if executors don't match.
Definition solver_base.hpp:737
void set_stop_criterion_factory(std::shared_ptr< const stop::CriterionFactory > new_stop_factory) override
Sets the stopping criterion of the solver.
Definition solver_base.hpp:768
EnableIterativeBase(EnableIterativeBase &&other)
Moves the provided stopping criterion.
Definition solver_base.hpp:763
EnableIterativeBase(const EnableIterativeBase &other)
Creates a shallow copy of the provided stopping criterion.
Definition solver_base.hpp:757
EnableIterativeBase & operator=(const EnableIterativeBase &other)
Creates a shallow copy of the provided stopping criterion, clones it onto this executor if executors ...
Definition solver_base.hpp:724
Mixin providing default operation for Preconditionable with correct value semantics.
Definition solver_base.hpp:270
EnablePreconditionable(const EnablePreconditionable &other)
Creates a shallow copy of the provided preconditioner.
Definition solver_base.hpp:327
EnablePreconditionable & operator=(EnablePreconditionable &&other)
Moves the provided preconditioner, clones it onto this executor if executors don't match.
Definition solver_base.hpp:308
EnablePreconditionable(EnablePreconditionable &&other)
Moves the provided preconditioner.
Definition solver_base.hpp:336
EnablePreconditionable & operator=(const EnablePreconditionable &other)
Creates a shallow copy of the provided preconditioner, clones it onto this executor if executors don'...
Definition solver_base.hpp:295
void set_preconditioner(std::shared_ptr< const LinOp > new_precond) override
Sets the preconditioner operator used by the Preconditionable.
Definition solver_base.hpp:278
A LinOp deriving from this CRTP class stores a system matrix.
Definition solver_base.hpp:556
EnableSolverBase(EnableSolverBase &&other)
Moves the provided system matrix.
Definition solver_base.hpp:604
std::vector< int > get_workspace_vectors() const override
Returns the IDs of all vectors (workspace vectors with system dimension-dependent size,...
Definition solver_base.hpp:636
std::vector< int > get_workspace_scalars() const override
Returns the IDs of all scalars (workspace vectors with system dimension-independent size,...
Definition solver_base.hpp:626
EnableSolverBase(const EnableSolverBase &other)
Creates a shallow copy of the provided system matrix.
Definition solver_base.hpp:594
EnableSolverBase & operator=(EnableSolverBase &&other)
Moves the provided system matrix, clones it onto this executor if executors don't match.
Definition solver_base.hpp:574
EnableSolverBase & operator=(const EnableSolverBase &other)
Creates a shallow copy of the provided system matrix, clones it onto this executor if executors don't...
Definition solver_base.hpp:562
A LinOp implementing this interface stores a stopping criterion factory.
Definition solver_base.hpp:679
std::shared_ptr< const stop::CriterionFactory > get_stop_criterion_factory() const
Gets the stopping criterion factory of the solver.
Definition solver_base.hpp:686
virtual void set_stop_criterion_factory(std::shared_ptr< const stop::CriterionFactory > new_stop_factory)
Sets the stopping criterion of the solver.
Definition solver_base.hpp:697
Definition solver_base.hpp:521
std::shared_ptr< const MatrixType > get_system_matrix() const
Returns the system matrix, with its concrete type, used by the solver.
Definition solver_base.hpp:532
#define GKO_FACTORY_PARAMETER_SCALAR(_name, _default)
Creates a scalar factory parameter in the factory parameters structure.
Definition abstract_factory.hpp:445
std::shared_ptr< const CriterionFactory > combine(FactoryContainer &&factories)
Combines multiple criterion factories into a single combined criterion factory.
Definition combined.hpp:109
The logger namespace.
Definition convergence.hpp:22
The multigrid components namespace.
Definition matrix.hpp:36
The Preconditioner namespace.
Definition gauss_seidel.hpp:19
The Ginkgo solver namespace.
Definition bicg.hpp:28
initial_guess_mode
Gives an initial guess mode describing the input of the apply method.
Definition solver_base.hpp:33
@ provided
the input is provided
Definition solver_base.hpp:45
@ rhs
the input is right hand side
Definition solver_base.hpp:41
@ zero
the input is zero
Definition solver_base.hpp:37
The Ginkgo namespace.
Definition abstract_factory.hpp:20
std::size_t size_type
Integral type used for allocation quantities.
Definition types.hpp:90
detail::cloned_type< Pointer > clone(const Pointer &p)
Creates a unique clone of the object pointed to by p.
Definition utils_helper.hpp:173
detail::temporary_clone< detail::pointee< Ptr > > make_temporary_clone(std::shared_ptr< const Executor > exec, Ptr &&ptr)
Creates a temporary_clone.
Definition temporary_clone.hpp:208
@ array
The matrix should be written as a dense matrix in column-major order.
Definition mtx_io.hpp:96
A type representing the dimensions of a multidimensional object.
Definition dim.hpp:26
Definition solver_base.hpp:844
std::vector< std::shared_ptr< const stop::CriterionFactory > > criteria
Stopping criteria to be used by the solver.
Definition solver_base.hpp:849
Definition solver_base.hpp:855
std::shared_ptr< const LinOp > generated_preconditioner
Already generated preconditioner.
Definition solver_base.hpp:868
std::shared_ptr< const LinOpFactory > preconditioner
The preconditioner to be used by the iterative solver.
Definition solver_base.hpp:861
Traits class providing information on the type and location of workspace vectors inside a solver.
Definition solver_base.hpp:238