Ginkgo Generated from branch based on main. Ginkgo version 1.10.0
A numerical linear algebra library targeting many-core architectures
Loading...
Searching...
No Matches
workspace.hpp
1// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#ifndef GKO_PUBLIC_CORE_SOLVER_WORKSPACE_HPP_
6#define GKO_PUBLIC_CORE_SOLVER_WORKSPACE_HPP_
7
8
9#include <typeinfo>
10
11#include <ginkgo/core/matrix/dense.hpp>
12
13
14namespace gko {
15namespace solver {
16namespace detail {
17
18
22class any_array {
23public:
24 template <typename ValueType>
25 array<ValueType>& init(std::shared_ptr<const Executor> exec, size_type size)
26 {
27 auto container = std::make_unique<concrete_container<ValueType>>(
28 std::move(exec), size);
29 auto& arr = container->arr;
30 data_ = std::move(container);
31 return arr;
32 }
33
34 bool empty() const { return data_.get() == nullptr; }
35
36 template <typename ValueType>
37 bool contains() const
38 {
39 return dynamic_cast<const concrete_container<ValueType>*>(data_.get());
40 }
41
42 template <typename ValueType>
43 array<ValueType>& get()
44 {
45 GKO_ASSERT(this->template contains<ValueType>());
46 return dynamic_cast<concrete_container<ValueType>*>(data_.get())->arr;
47 }
48
49 template <typename ValueType>
50 const array<ValueType>& get() const
51 {
52 GKO_ASSERT(this->template contains<ValueType>());
53 return dynamic_cast<const concrete_container<ValueType>*>(data_.get())
54 ->arr;
55 }
56
57 void clear() { data_.reset(); }
58
59private:
60 struct generic_container {
61 virtual ~generic_container() = default;
62 };
63
64 template <typename ValueType>
65 struct concrete_container : generic_container {
66 template <typename... Args>
67 concrete_container(Args&&... args) : arr{std::forward<Args>(args)...}
68 {}
69
71 };
72
73 std::unique_ptr<generic_container> data_;
74};
75
76
77class workspace {
78public:
79 workspace(std::shared_ptr<const Executor> exec) : exec_{std::move(exec)} {}
80
81 workspace(const workspace& other) : workspace{other.get_executor()} {}
82
83 workspace(workspace&& other) : workspace{other.get_executor()}
84 {
85 other.clear();
86 }
87
88 workspace& operator=(const workspace& other) { return *this; }
89
90 workspace& operator=(workspace&& other)
91 {
92 other.clear();
93 return *this;
94 }
95
96 template <typename LinOpType, typename CreateOperation>
97 LinOpType* create_or_get_op(int op_id, CreateOperation create,
98 const std::type_info& expected_type,
99 dim<2> size, size_type stride)
100 {
101 GKO_ASSERT(op_id >= 0 && op_id < operators_.size());
102 // does the existing object have the wrong type?
103 // vector types may vary e.g. if users derive from Dense
104 auto stored_op = operators_[op_id].get();
105 LinOpType* op{};
106 if (!stored_op || typeid(*stored_op) != expected_type) {
107 auto new_op = create();
108 op = new_op.get();
109 operators_[op_id] = std::move(new_op);
110 return op;
111 }
112 // does the existing object have the wrong dimensions?
113 op = dynamic_cast<LinOpType*>(operators_[op_id].get());
114 GKO_ASSERT(op);
115 if (op->get_size() != size || op->get_stride() != stride) {
116 auto new_op = create();
117 op = new_op.get();
118 operators_[op_id] = std::move(new_op);
119 }
120 return op;
121 }
122
123 const LinOp* get_op(int op_id) const
124 {
125 GKO_ASSERT(op_id >= 0 && op_id < operators_.size());
126 return operators_[op_id].get();
127 }
128
129 template <typename ValueType>
130 array<ValueType>& init_or_get_array(int array_id)
131 {
132 GKO_ASSERT(array_id >= 0 && array_id < arrays_.size());
133 auto& array = arrays_[array_id];
134 if (array.empty()) {
135 auto& result =
136 array.template init<ValueType>(this->get_executor(), 0);
137 return result;
138 }
139 // array types should not change!
140 GKO_ASSERT(array.template contains<ValueType>());
141 return array.template get<ValueType>();
142 }
143
144 template <typename ValueType>
145 array<ValueType>& create_or_get_array(int array_id, size_type size)
146 {
147 auto& result = init_or_get_array<ValueType>(array_id);
148 if (result.get_size() != size) {
149 result.resize_and_reset(size);
150 }
151 return result;
152 }
153
154 std::shared_ptr<const Executor> get_executor() const { return exec_; }
155
156 void set_size(int num_operators, int num_arrays)
157 {
158 operators_.resize(num_operators);
159 arrays_.resize(num_arrays);
160 }
161
162 void clear()
163 {
164 for (auto& op : operators_) {
165 op.reset();
166 }
167 for (auto& array : arrays_) {
168 array.clear();
169 }
170 }
171
172private:
173 std::shared_ptr<const Executor> exec_;
174 std::vector<std::unique_ptr<LinOp>> operators_;
175 std::vector<any_array> arrays_;
176};
177
178
179} // namespace detail
180} // namespace solver
181} // namespace gko
182
183#endif // GKO_PUBLIC_CORE_SOLVER_WORKSPACE_HPP_
The Ginkgo solver namespace.
Definition bicg.hpp:28
The Ginkgo namespace.
Definition abstract_factory.hpp:20
std::size_t size_type
Integral type used for allocation quantities.
Definition types.hpp:90
@ array
The matrix should be written as a dense matrix in column-major order.
Definition mtx_io.hpp:96