Ginkgo version 1.10.0 (generated from a branch based on main)
A numerical linear algebra library targeting many-core architectures
mpi.hpp
1// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#ifndef GKO_PUBLIC_CORE_BASE_MPI_HPP_
6#define GKO_PUBLIC_CORE_BASE_MPI_HPP_
7
8
9#include <memory>
10#include <type_traits>
11#include <utility>
12
13#include <ginkgo/config.hpp>
14#include <ginkgo/core/base/exception.hpp>
15#include <ginkgo/core/base/exception_helpers.hpp>
16#include <ginkgo/core/base/executor.hpp>
17#include <ginkgo/core/base/half.hpp>
18#include <ginkgo/core/base/types.hpp>
19#include <ginkgo/core/base/utils_helper.hpp>
20
21
22#if GINKGO_BUILD_MPI
23
24
25#include <mpi.h>
26
27
28namespace gko {
29namespace experimental {
36namespace mpi {
37
38
42inline constexpr bool is_gpu_aware()
43{
44#if GINKGO_HAVE_GPU_AWARE_MPI
45 return true;
46#else
47 return false;
48#endif
49}
50
51
59int map_rank_to_device_id(MPI_Comm comm, int num_devices);
60
61
62#define GKO_REGISTER_MPI_TYPE(input_type, mpi_type) \
63 template <> \
64 struct type_impl<input_type> { \
65 static MPI_Datatype get_type() { return mpi_type; } \
66 }
67
76template <typename T>
77struct type_impl {};
78
79
80GKO_REGISTER_MPI_TYPE(char, MPI_CHAR);
81GKO_REGISTER_MPI_TYPE(unsigned char, MPI_UNSIGNED_CHAR);
82GKO_REGISTER_MPI_TYPE(unsigned, MPI_UNSIGNED);
83GKO_REGISTER_MPI_TYPE(int, MPI_INT);
84GKO_REGISTER_MPI_TYPE(unsigned short, MPI_UNSIGNED_SHORT);
85GKO_REGISTER_MPI_TYPE(unsigned long, MPI_UNSIGNED_LONG);
86GKO_REGISTER_MPI_TYPE(long, MPI_LONG);
87GKO_REGISTER_MPI_TYPE(long long, MPI_LONG_LONG_INT);
88GKO_REGISTER_MPI_TYPE(unsigned long long, MPI_UNSIGNED_LONG_LONG);
89GKO_REGISTER_MPI_TYPE(float, MPI_FLOAT);
90GKO_REGISTER_MPI_TYPE(double, MPI_DOUBLE);
91GKO_REGISTER_MPI_TYPE(long double, MPI_LONG_DOUBLE);
92#if GINKGO_ENABLE_HALF
93// OpenMPI 5.0 and MPICH v3.4a1 provide half-precision support via MPIX_C_FLOAT16.
94// Only OpenMPI supports complex float16.
95// TODO: use the native type once MPI is configured with half-precision support
96GKO_REGISTER_MPI_TYPE(half, MPI_UNSIGNED_SHORT);
97GKO_REGISTER_MPI_TYPE(std::complex<half>, MPI_FLOAT);
98#endif // GINKGO_ENABLE_HALF
99#if GINKGO_ENABLE_BFLOAT16
100GKO_REGISTER_MPI_TYPE(bfloat16, MPI_UNSIGNED_SHORT);
101GKO_REGISTER_MPI_TYPE(std::complex<bfloat16>, MPI_FLOAT);
102#endif // GINKGO_ENABLE_BFLOAT16
103GKO_REGISTER_MPI_TYPE(std::complex<float>, MPI_C_FLOAT_COMPLEX);
104GKO_REGISTER_MPI_TYPE(std::complex<double>, MPI_C_DOUBLE_COMPLEX);
105
106
113class contiguous_type {
114public:
121 contiguous_type(int count, MPI_Datatype old_type) : type_(MPI_DATATYPE_NULL)
122 {
123 GKO_ASSERT_NO_MPI_ERRORS(MPI_Type_contiguous(count, old_type, &type_));
124 GKO_ASSERT_NO_MPI_ERRORS(MPI_Type_commit(&type_));
125 }
126
130 contiguous_type() : type_(MPI_DATATYPE_NULL) {}
131
135 contiguous_type(const contiguous_type&) = delete;
136
140 contiguous_type& operator=(const contiguous_type&) = delete;
141
147 contiguous_type(contiguous_type&& other) noexcept : type_(MPI_DATATYPE_NULL)
148 {
149 *this = std::move(other);
150 }
151
159 contiguous_type& operator=(contiguous_type&& other) noexcept
160 {
161 if (this != &other) {
162 this->type_ = std::exchange(other.type_, MPI_DATATYPE_NULL);
163 }
164 return *this;
165 }
166
170 ~contiguous_type()
171 {
172 if (type_ != MPI_DATATYPE_NULL) {
173 MPI_Type_free(&type_);
174 }
175 }
176
182 MPI_Datatype get() const { return type_; }
183
184private:
185 MPI_Datatype type_;
186};
187
188
193enum class thread_type {
194 serialized = MPI_THREAD_SERIALIZED,
195 funneled = MPI_THREAD_FUNNELED,
196 single = MPI_THREAD_SINGLE,
197 multiple = MPI_THREAD_MULTIPLE
198};
199
200
210class environment {
211public:
212 static bool is_finalized()
213 {
214 int flag = 0;
215 GKO_ASSERT_NO_MPI_ERRORS(MPI_Finalized(&flag));
216 return flag;
217 }
218
219 static bool is_initialized()
220 {
221 int flag = 0;
222 GKO_ASSERT_NO_MPI_ERRORS(MPI_Initialized(&flag));
223 return flag;
224 }
225
231 int get_provided_thread_support() const { return provided_thread_support_; }
232
241 environment(int& argc, char**& argv,
242 const thread_type thread_t = thread_type::serialized)
243 {
244 this->required_thread_support_ = static_cast<int>(thread_t);
245 GKO_ASSERT_NO_MPI_ERRORS(
246 MPI_Init_thread(&argc, &argv, this->required_thread_support_,
247 &(this->provided_thread_support_)));
248 }
249
253 ~environment() { MPI_Finalize(); }
254
255 environment(const environment&) = delete;
256 environment(environment&&) = delete;
257 environment& operator=(const environment&) = delete;
258 environment& operator=(environment&&) = delete;
259
260private:
261 int required_thread_support_;
262 int provided_thread_support_;
263};
264
265
266namespace {
267
268
273class comm_deleter {
274public:
275 using pointer = MPI_Comm*;
276 void operator()(pointer comm) const
277 {
278 GKO_ASSERT(*comm != MPI_COMM_NULL);
279 GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_free(comm));
280 delete comm;
281 }
282};
283
284
285} // namespace
286
287
291struct status {
295 status() : status_(MPI_Status{}) {}
296
302 MPI_Status* get() { return &this->status_; }
303
314 template <typename T>
315 int get_count(const T* data) const
316 {
317 int count;
318 MPI_Get_count(&status_, type_impl<T>::get_type(), &count);
319 return count;
320 }
321
322private:
323 MPI_Status status_;
324};
325
326
331class request {
332public:
337 request() : req_(MPI_REQUEST_NULL) {}
338
339 request(const request&) = delete;
340
341 request& operator=(const request&) = delete;
342
343 request(request&& o) noexcept { *this = std::move(o); }
344
345 request& operator=(request&& o) noexcept
346 {
347 if (this != &o) {
348 this->req_ = std::exchange(o.req_, MPI_REQUEST_NULL);
349 }
350 return *this;
351 }
352
353 ~request()
354 {
355 if (req_ != MPI_REQUEST_NULL) {
356 if (MPI_Request_free(&req_) != MPI_SUCCESS) {
357 std::terminate(); // since we can't throw in destructors, we
358 // have to terminate the program
359 }
360 }
361 }
362
368 MPI_Request* get() { return &this->req_; }
369
376 status wait()
377 {
378 status status;
379 GKO_ASSERT_NO_MPI_ERRORS(MPI_Wait(&req_, status.get()));
380 return status;
381 }
382
383private:
384 MPI_Request req_;
385};
386
387
395inline std::vector<status> wait_all(std::vector<request>& req)
396{
397 std::vector<status> stat;
398 for (std::size_t i = 0; i < req.size(); ++i) {
399 stat.emplace_back(req[i].wait());
400 }
401 return stat;
402}
403
404
419class communicator {
420public:
431 communicator(const MPI_Comm& comm, bool force_host_buffer = false)
432 : comm_(), force_host_buffer_(force_host_buffer)
433 {
434 this->comm_.reset(new MPI_Comm(comm));
435 }
436
445 communicator(const MPI_Comm& comm, int color, int key)
446 {
447 MPI_Comm comm_out;
448 GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_split(comm, color, key, &comm_out));
449 this->comm_.reset(new MPI_Comm(comm_out), comm_deleter{});
450 }
451
460 communicator(const communicator& comm, int color, int key)
461 {
462 MPI_Comm comm_out;
463 GKO_ASSERT_NO_MPI_ERRORS(
464 MPI_Comm_split(comm.get(), color, key, &comm_out));
465 this->comm_.reset(new MPI_Comm(comm_out), comm_deleter{});
466 }
467
477 static communicator create_owning(const MPI_Comm& comm,
478 bool force_host_buffer = false)
479 {
480 communicator comm_out(MPI_COMM_NULL, force_host_buffer);
481 comm_out.comm_.reset(new MPI_Comm(comm), comm_deleter{});
482 return comm_out;
483 }
484
490 communicator(const communicator& other) = default;
491
498 communicator(communicator&& other) { *this = std::move(other); }
499
503 communicator& operator=(const communicator& other) = default;
504
508 communicator& operator=(communicator&& other)
509 {
510 if (this != &other) {
511 comm_ = std::exchange(other.comm_,
512 std::make_shared<MPI_Comm>(MPI_COMM_NULL));
513 force_host_buffer_ = other.force_host_buffer_;
514 }
515 return *this;
516 }
517
523 const MPI_Comm& get() const { return *(this->comm_.get()); }
524
525 bool force_host_buffer() const { return force_host_buffer_; }
526
532 int size() const { return get_num_ranks(); }
533
539 int rank() const { return get_my_rank(); };
540
546 int node_local_rank() const { return get_node_local_rank(); };
547
553 bool operator==(const communicator& rhs) const { return is_identical(rhs); }
554
560 bool operator!=(const communicator& rhs) const { return !(*this == rhs); }
561
571 bool is_identical(const communicator& rhs) const
572 {
573 if (get() == MPI_COMM_NULL || rhs.get() == MPI_COMM_NULL) {
574 return get() == rhs.get();
575 }
576 int flag;
577 GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_compare(get(), rhs.get(), &flag));
578 return flag == MPI_IDENT;
579 }
580
593 bool is_congruent(const communicator& rhs) const
594 {
595 if (get() == MPI_COMM_NULL || rhs.get() == MPI_COMM_NULL) {
596 return get() == rhs.get();
597 }
598 int flag;
599 GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_compare(get(), rhs.get(), &flag));
600 return flag == MPI_CONGRUENT;
601 }
602
607 void synchronize() const
608 {
609 GKO_ASSERT_NO_MPI_ERRORS(MPI_Barrier(this->get()));
610 }
611
625 template <typename SendType>
626 void send(std::shared_ptr<const Executor> exec, const SendType* send_buffer,
627 const int send_count, const int destination_rank,
628 const int send_tag) const
629 {
630 auto guard = exec->get_scoped_device_id_guard();
631 GKO_ASSERT_NO_MPI_ERRORS(
632 MPI_Send(send_buffer, send_count, type_impl<SendType>::get_type(),
633 destination_rank, send_tag, this->get()));
634 }
635
652 template <typename SendType>
653 request i_send(std::shared_ptr<const Executor> exec,
654 const SendType* send_buffer, const int send_count,
655 const int destination_rank, const int send_tag) const
656 {
657 auto guard = exec->get_scoped_device_id_guard();
658 request req;
659 GKO_ASSERT_NO_MPI_ERRORS(
660 MPI_Isend(send_buffer, send_count, type_impl<SendType>::get_type(),
661 destination_rank, send_tag, this->get(), req.get()));
662 return req;
663 }
664
680 template <typename RecvType>
681 status recv(std::shared_ptr<const Executor> exec, RecvType* recv_buffer,
682 const int recv_count, const int source_rank,
683 const int recv_tag) const
684 {
685 auto guard = exec->get_scoped_device_id_guard();
686 status st;
687 GKO_ASSERT_NO_MPI_ERRORS(
688 MPI_Recv(recv_buffer, recv_count, type_impl<RecvType>::get_type(),
689 source_rank, recv_tag, this->get(), st.get()));
690 return st;
691 }
692
708 template <typename RecvType>
709 request i_recv(std::shared_ptr<const Executor> exec, RecvType* recv_buffer,
710 const int recv_count, const int source_rank,
711 const int recv_tag) const
712 {
713 auto guard = exec->get_scoped_device_id_guard();
714 request req;
715 GKO_ASSERT_NO_MPI_ERRORS(
716 MPI_Irecv(recv_buffer, recv_count, type_impl<RecvType>::get_type(),
717 source_rank, recv_tag, this->get(), req.get()));
718 return req;
719 }
720
733 template <typename BroadcastType>
734 void broadcast(std::shared_ptr<const Executor> exec, BroadcastType* buffer,
735 int count, int root_rank) const
736 {
737 auto guard = exec->get_scoped_device_id_guard();
738 GKO_ASSERT_NO_MPI_ERRORS(MPI_Bcast(buffer, count,
739 type_impl<BroadcastType>::get_type(),
740 root_rank, this->get()));
741 }
742
758 template <typename BroadcastType>
759 request i_broadcast(std::shared_ptr<const Executor> exec,
760 BroadcastType* buffer, int count, int root_rank) const
761 {
762 auto guard = exec->get_scoped_device_id_guard();
763 request req;
764 GKO_ASSERT_NO_MPI_ERRORS(
765 MPI_Ibcast(buffer, count, type_impl<BroadcastType>::get_type(),
766 root_rank, this->get(), req.get()));
767 return req;
768 }
769
784 template <typename ReduceType>
785 void reduce(std::shared_ptr<const Executor> exec,
786 const ReduceType* send_buffer, ReduceType* recv_buffer,
787 int count, MPI_Op operation, int root_rank) const
788 {
789 auto guard = exec->get_scoped_device_id_guard();
790 GKO_ASSERT_NO_MPI_ERRORS(MPI_Reduce(send_buffer, recv_buffer, count,
791 type_impl<ReduceType>::get_type(),
792 operation, root_rank, this->get()));
793 }
794
811 template <typename ReduceType>
812 request i_reduce(std::shared_ptr<const Executor> exec,
813 const ReduceType* send_buffer, ReduceType* recv_buffer,
814 int count, MPI_Op operation, int root_rank) const
815 {
816 auto guard = exec->get_scoped_device_id_guard();
817 request req;
818 GKO_ASSERT_NO_MPI_ERRORS(MPI_Ireduce(
819 send_buffer, recv_buffer, count, type_impl<ReduceType>::get_type(),
820 operation, root_rank, this->get(), req.get()));
821 return req;
822 }
823
837 template <typename ReduceType>
838 void all_reduce(std::shared_ptr<const Executor> exec,
839 ReduceType* recv_buffer, int count, MPI_Op operation) const
840 {
841 auto guard = exec->get_scoped_device_id_guard();
842 GKO_ASSERT_NO_MPI_ERRORS(MPI_Allreduce(
843 MPI_IN_PLACE, recv_buffer, count, type_impl<ReduceType>::get_type(),
844 operation, this->get()));
845 }
846
862 template <typename ReduceType>
863 request i_all_reduce(std::shared_ptr<const Executor> exec,
864 ReduceType* recv_buffer, int count,
865 MPI_Op operation) const
866 {
867 auto guard = exec->get_scoped_device_id_guard();
868 request req;
869 GKO_ASSERT_NO_MPI_ERRORS(MPI_Iallreduce(
870 MPI_IN_PLACE, recv_buffer, count, type_impl<ReduceType>::get_type(),
871 operation, this->get(), req.get()));
872 return req;
873 }
874
889 template <typename ReduceType>
890 void all_reduce(std::shared_ptr<const Executor> exec,
891 const ReduceType* send_buffer, ReduceType* recv_buffer,
892 int count, MPI_Op operation) const
893 {
894 auto guard = exec->get_scoped_device_id_guard();
895 GKO_ASSERT_NO_MPI_ERRORS(MPI_Allreduce(
896 send_buffer, recv_buffer, count, type_impl<ReduceType>::get_type(),
897 operation, this->get()));
898 }
899
916 template <typename ReduceType>
917 request i_all_reduce(std::shared_ptr<const Executor> exec,
918 const ReduceType* send_buffer, ReduceType* recv_buffer,
919 int count, MPI_Op operation) const
920 {
921 auto guard = exec->get_scoped_device_id_guard();
922 request req;
923 GKO_ASSERT_NO_MPI_ERRORS(MPI_Iallreduce(
924 send_buffer, recv_buffer, count, type_impl<ReduceType>::get_type(),
925 operation, this->get(), req.get()));
926 return req;
927 }
928
945 template <typename SendType, typename RecvType>
946 void gather(std::shared_ptr<const Executor> exec,
947 const SendType* send_buffer, const int send_count,
948 RecvType* recv_buffer, const int recv_count,
949 int root_rank) const
950 {
951 auto guard = exec->get_scoped_device_id_guard();
952 GKO_ASSERT_NO_MPI_ERRORS(
953 MPI_Gather(send_buffer, send_count, type_impl<SendType>::get_type(),
954 recv_buffer, recv_count, type_impl<RecvType>::get_type(),
955 root_rank, this->get()));
956 }
957
977 template <typename SendType, typename RecvType>
978 request i_gather(std::shared_ptr<const Executor> exec,
979 const SendType* send_buffer, const int send_count,
980 RecvType* recv_buffer, const int recv_count,
981 int root_rank) const
982 {
983 auto guard = exec->get_scoped_device_id_guard();
984 request req;
985 GKO_ASSERT_NO_MPI_ERRORS(MPI_Igather(
986 send_buffer, send_count, type_impl<SendType>::get_type(),
987 recv_buffer, recv_count, type_impl<RecvType>::get_type(), root_rank,
988 this->get(), req.get()));
989 return req;
990 }
991
1010 template <typename SendType, typename RecvType>
1011 void gather_v(std::shared_ptr<const Executor> exec,
1012 const SendType* send_buffer, const int send_count,
1013 RecvType* recv_buffer, const int* recv_counts,
1014 const int* displacements, int root_rank) const
1015 {
1016 auto guard = exec->get_scoped_device_id_guard();
1017 GKO_ASSERT_NO_MPI_ERRORS(MPI_Gatherv(
1018 send_buffer, send_count, type_impl<SendType>::get_type(),
1019 recv_buffer, recv_counts, displacements,
1020 type_impl<RecvType>::get_type(), root_rank, this->get()));
1021 }
1022
1043 template <typename SendType, typename RecvType>
1044 request i_gather_v(std::shared_ptr<const Executor> exec,
1045 const SendType* send_buffer, const int send_count,
1046 RecvType* recv_buffer, const int* recv_counts,
1047 const int* displacements, int root_rank) const
1048 {
1049 auto guard = exec->get_scoped_device_id_guard();
1050 request req;
1051 GKO_ASSERT_NO_MPI_ERRORS(MPI_Igatherv(
1052 send_buffer, send_count, type_impl<SendType>::get_type(),
1053 recv_buffer, recv_counts, displacements,
1054 type_impl<RecvType>::get_type(), root_rank, this->get(),
1055 req.get()));
1056 return req;
1057 }
1058
1074 template <typename SendType, typename RecvType>
1075 void all_gather(std::shared_ptr<const Executor> exec,
1076 const SendType* send_buffer, const int send_count,
1077 RecvType* recv_buffer, const int recv_count) const
1078 {
1079 auto guard = exec->get_scoped_device_id_guard();
1080 GKO_ASSERT_NO_MPI_ERRORS(MPI_Allgather(
1081 send_buffer, send_count, type_impl<SendType>::get_type(),
1082 recv_buffer, recv_count, type_impl<RecvType>::get_type(),
1083 this->get()));
1084 }
1085
1104 template <typename SendType, typename RecvType>
1105 request i_all_gather(std::shared_ptr<const Executor> exec,
1106 const SendType* send_buffer, const int send_count,
1107 RecvType* recv_buffer, const int recv_count) const
1108 {
1109 auto guard = exec->get_scoped_device_id_guard();
1110 request req;
1111 GKO_ASSERT_NO_MPI_ERRORS(MPI_Iallgather(
1112 send_buffer, send_count, type_impl<SendType>::get_type(),
1113 recv_buffer, recv_count, type_impl<RecvType>::get_type(),
1114 this->get(), req.get()));
1115 return req;
1116 }
1117
1133 template <typename SendType, typename RecvType>
1134 void scatter(std::shared_ptr<const Executor> exec,
1135 const SendType* send_buffer, const int send_count,
1136 RecvType* recv_buffer, const int recv_count,
1137 int root_rank) const
1138 {
1139 auto guard = exec->get_scoped_device_id_guard();
1140 GKO_ASSERT_NO_MPI_ERRORS(MPI_Scatter(
1141 send_buffer, send_count, type_impl<SendType>::get_type(),
1142 recv_buffer, recv_count, type_impl<RecvType>::get_type(), root_rank,
1143 this->get()));
1144 }
1145
1164 template <typename SendType, typename RecvType>
1165 request i_scatter(std::shared_ptr<const Executor> exec,
1166 const SendType* send_buffer, const int send_count,
1167 RecvType* recv_buffer, const int recv_count,
1168 int root_rank) const
1169 {
1170 auto guard = exec->get_scoped_device_id_guard();
1171 request req;
1172 GKO_ASSERT_NO_MPI_ERRORS(MPI_Iscatter(
1173 send_buffer, send_count, type_impl<SendType>::get_type(),
1174 recv_buffer, recv_count, type_impl<RecvType>::get_type(), root_rank,
1175 this->get(), req.get()));
1176 return req;
1177 }
1178
1197 template <typename SendType, typename RecvType>
1198 void scatter_v(std::shared_ptr<const Executor> exec,
1199 const SendType* send_buffer, const int* send_counts,
1200 const int* displacements, RecvType* recv_buffer,
1201 const int recv_count, int root_rank) const
1202 {
1203 auto guard = exec->get_scoped_device_id_guard();
1204 GKO_ASSERT_NO_MPI_ERRORS(MPI_Scatterv(
1205 send_buffer, send_counts, displacements,
1206 type_impl<SendType>::get_type(), recv_buffer, recv_count,
1207 type_impl<RecvType>::get_type(), root_rank, this->get()));
1208 }
1209
1230 template <typename SendType, typename RecvType>
1231 request i_scatter_v(std::shared_ptr<const Executor> exec,
1232 const SendType* send_buffer, const int* send_counts,
1233 const int* displacements, RecvType* recv_buffer,
1234 const int recv_count, int root_rank) const
1235 {
1236 auto guard = exec->get_scoped_device_id_guard();
1237 request req;
1238 GKO_ASSERT_NO_MPI_ERRORS(
1239 MPI_Iscatterv(send_buffer, send_counts, displacements,
1240 type_impl<SendType>::get_type(), recv_buffer,
1241 recv_count, type_impl<RecvType>::get_type(),
1242 root_rank, this->get(), req.get()));
1243 return req;
1244 }
1245
1262 template <typename RecvType>
1263 void all_to_all(std::shared_ptr<const Executor> exec, RecvType* recv_buffer,
1264 const int recv_count) const
1265 {
1266 auto guard = exec->get_scoped_device_id_guard();
1267 GKO_ASSERT_NO_MPI_ERRORS(MPI_Alltoall(
1268 MPI_IN_PLACE, recv_count, type_impl<RecvType>::get_type(),
1269 recv_buffer, recv_count, type_impl<RecvType>::get_type(),
1270 this->get()));
1271 }
1272
1291 template <typename RecvType>
1292 request i_all_to_all(std::shared_ptr<const Executor> exec,
1293 RecvType* recv_buffer, const int recv_count) const
1294 {
1295 auto guard = exec->get_scoped_device_id_guard();
1296 request req;
1297 GKO_ASSERT_NO_MPI_ERRORS(MPI_Ialltoall(
1298 MPI_IN_PLACE, recv_count, type_impl<RecvType>::get_type(),
1299 recv_buffer, recv_count, type_impl<RecvType>::get_type(),
1300 this->get(), req.get()));
1301 return req;
1302 }
1303
1320 template <typename SendType, typename RecvType>
1321 void all_to_all(std::shared_ptr<const Executor> exec,
1322 const SendType* send_buffer, const int send_count,
1323 RecvType* recv_buffer, const int recv_count) const
1324 {
1325 auto guard = exec->get_scoped_device_id_guard();
1326 GKO_ASSERT_NO_MPI_ERRORS(MPI_Alltoall(
1327 send_buffer, send_count, type_impl<SendType>::get_type(),
1328 recv_buffer, recv_count, type_impl<RecvType>::get_type(),
1329 this->get()));
1330 }
1331
1350 template <typename SendType, typename RecvType>
1351 request i_all_to_all(std::shared_ptr<const Executor> exec,
1352 const SendType* send_buffer, const int send_count,
1353 RecvType* recv_buffer, const int recv_count) const
1354 {
1355 auto guard = exec->get_scoped_device_id_guard();
1356 request req;
1357 GKO_ASSERT_NO_MPI_ERRORS(MPI_Ialltoall(
1358 send_buffer, send_count, type_impl<SendType>::get_type(),
1359 recv_buffer, recv_count, type_impl<RecvType>::get_type(),
1360 this->get(), req.get()));
1361 return req;
1362 }
1363
1383 template <typename SendType, typename RecvType>
1384 void all_to_all_v(std::shared_ptr<const Executor> exec,
1385 const SendType* send_buffer, const int* send_counts,
1386 const int* send_offsets, RecvType* recv_buffer,
1387 const int* recv_counts, const int* recv_offsets) const
1388 {
1389 this->all_to_all_v(std::move(exec), send_buffer, send_counts,
1390 send_offsets, type_impl<SendType>::get_type(),
1391 recv_buffer, recv_counts, recv_offsets,
1392 type_impl<RecvType>::get_type());
1393 }
1394
1410 void all_to_all_v(std::shared_ptr<const Executor> exec,
1411 const void* send_buffer, const int* send_counts,
1412 const int* send_offsets, MPI_Datatype send_type,
1413 void* recv_buffer, const int* recv_counts,
1414 const int* recv_offsets, MPI_Datatype recv_type) const
1415 {
1416 auto guard = exec->get_scoped_device_id_guard();
1417 GKO_ASSERT_NO_MPI_ERRORS(MPI_Alltoallv(
1418 send_buffer, send_counts, send_offsets, send_type, recv_buffer,
1419 recv_counts, recv_offsets, recv_type, this->get()));
1420 }
1421
1441 request i_all_to_all_v(std::shared_ptr<const Executor> exec,
1442 const void* send_buffer, const int* send_counts,
1443 const int* send_offsets, MPI_Datatype send_type,
1444 void* recv_buffer, const int* recv_counts,
1445 const int* recv_offsets,
1446 MPI_Datatype recv_type) const
1447 {
1448 auto guard = exec->get_scoped_device_id_guard();
1449 request req;
1450 GKO_ASSERT_NO_MPI_ERRORS(MPI_Ialltoallv(
1451 send_buffer, send_counts, send_offsets, send_type, recv_buffer,
1452 recv_counts, recv_offsets, recv_type, this->get(), req.get()));
1453 return req;
1454 }
1455
1476 template <typename SendType, typename RecvType>
1477 request i_all_to_all_v(std::shared_ptr<const Executor> exec,
1478 const SendType* send_buffer, const int* send_counts,
1479 const int* send_offsets, RecvType* recv_buffer,
1480 const int* recv_counts,
1481 const int* recv_offsets) const
1482 {
1483 return this->i_all_to_all_v(
1484 std::move(exec), send_buffer, send_counts, send_offsets,
1485 type_impl<SendType>::get_type(), recv_buffer, recv_counts,
1486 recv_offsets, type_impl<RecvType>::get_type());
1487 }
1488
1503 template <typename ScanType>
1504 void scan(std::shared_ptr<const Executor> exec, const ScanType* send_buffer,
1505 ScanType* recv_buffer, int count, MPI_Op operation) const
1506 {
1507 auto guard = exec->get_scoped_device_id_guard();
1508 GKO_ASSERT_NO_MPI_ERRORS(MPI_Scan(send_buffer, recv_buffer, count,
1509 type_impl<ScanType>::get_type(),
1510 operation, this->get()));
1511 }
1512
1529 template <typename ScanType>
1530 request i_scan(std::shared_ptr<const Executor> exec,
1531 const ScanType* send_buffer, ScanType* recv_buffer,
1532 int count, MPI_Op operation) const
1533 {
1534 auto guard = exec->get_scoped_device_id_guard();
1535 request req;
1536 GKO_ASSERT_NO_MPI_ERRORS(MPI_Iscan(send_buffer, recv_buffer, count,
1537 type_impl<ScanType>::get_type(),
1538 operation, this->get(), req.get()));
1539 return req;
1540 }
1541
1542private:
1543 std::shared_ptr<MPI_Comm> comm_;
1544 bool force_host_buffer_;
1545
1546 int get_my_rank() const
1547 {
1548 int my_rank = 0;
1549 GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_rank(get(), &my_rank));
1550 return my_rank;
1551 }
1552
1553 int get_node_local_rank() const
1554 {
1555 MPI_Comm local_comm;
1556 int rank;
1557 GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_split_type(
1558 this->get(), MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, &local_comm));
1559 GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_rank(local_comm, &rank));
1560 MPI_Comm_free(&local_comm);
1561 return rank;
1562 }
1563
1564 int get_num_ranks() const
1565 {
1566 int size = 1;
1567 GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_size(this->get(), &size));
1568 return size;
1569 }
1570};
1571
1572
1577bool requires_host_buffer(const std::shared_ptr<const Executor>& exec,
1578 const communicator& comm);
1579
1580
1586inline double get_walltime() { return MPI_Wtime(); }
1587
1588
1597template <typename ValueType>
1598class window {
1599public:
1603 enum class create_type { allocate = 1, create = 2, dynamic_create = 3 };
1604
1608 enum class lock_type { shared = 1, exclusive = 2 };
1609
1613 window() : window_(MPI_WIN_NULL) {}
1614
1615 window(const window& other) = delete;
1616
1617 window& operator=(const window& other) = delete;
1618
1625 window(window&& other) : window_{std::exchange(other.window_, MPI_WIN_NULL)}
1626 {}
1627
1634 window& operator=(window&& other)
1635 {
1636 window_ = std::exchange(other.window_, MPI_WIN_NULL);
1637 }
1638
1651 window(std::shared_ptr<const Executor> exec, ValueType* base, int num_elems,
1652 const communicator& comm, const int disp_unit = sizeof(ValueType),
1653 MPI_Info input_info = MPI_INFO_NULL,
1654 create_type c_type = create_type::create)
1655 {
1656 auto guard = exec->get_scoped_device_id_guard();
1657 unsigned size = num_elems * sizeof(ValueType);
1658 if (c_type == create_type::create) {
1659 GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_create(
1660 base, size, disp_unit, input_info, comm.get(), &this->window_));
1661 } else if (c_type == create_type::dynamic_create) {
1662 GKO_ASSERT_NO_MPI_ERRORS(
1663 MPI_Win_create_dynamic(input_info, comm.get(), &this->window_));
1664 } else if (c_type == create_type::allocate) {
1665 GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_allocate(
1666 size, disp_unit, input_info, comm.get(), base, &this->window_));
1667 } else {
1668 GKO_NOT_IMPLEMENTED;
1669 }
1670 }
1671
1677 MPI_Win get_window() const { return this->window_; }
1678
1685 void fence(int assert = 0) const
1686 {
1687 GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_fence(assert, this->window_));
1688 }
1689
1698 void lock(int rank, lock_type lock_t = lock_type::shared,
1699 int assert = 0) const
1700 {
1701 if (lock_t == lock_type::shared) {
1702 GKO_ASSERT_NO_MPI_ERRORS(
1703 MPI_Win_lock(MPI_LOCK_SHARED, rank, assert, this->window_));
1704 } else if (lock_t == lock_type::exclusive) {
1705 GKO_ASSERT_NO_MPI_ERRORS(
1706 MPI_Win_lock(MPI_LOCK_EXCLUSIVE, rank, assert, this->window_));
1707 } else {
1708 GKO_NOT_IMPLEMENTED;
1709 }
1710 }
1711
1718 void unlock(int rank) const
1719 {
1720 GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_unlock(rank, this->window_));
1721 }
1722
1729 void lock_all(int assert = 0) const
1730 {
1731 GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_lock_all(assert, this->window_));
1732 }
1733
1738 void unlock_all() const
1739 {
1740 GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_unlock_all(this->window_));
1741 }
1742
1749 void flush(int rank) const
1750 {
1751 GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_flush(rank, this->window_));
1752 }
1753
1760 void flush_local(int rank) const
1761 {
1762 GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_flush_local(rank, this->window_));
1763 }
1764
1769 void flush_all() const
1770 {
1771 GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_flush_all(this->window_));
1772 }
1773
1778 void flush_all_local() const
1779 {
1780 GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_flush_local_all(this->window_));
1781 }
1782
1786 void sync() const { GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_sync(this->window_)); }
1787
1791 ~window()
1792 {
1793 if (this->window_ && this->window_ != MPI_WIN_NULL) {
1794 MPI_Win_free(&this->window_);
1795 }
1796 }
1797
1808 template <typename PutType>
1809 void put(std::shared_ptr<const Executor> exec, const PutType* origin_buffer,
1810 const int origin_count, const int target_rank,
1811 const unsigned int target_disp, const int target_count) const
1812 {
1813 auto guard = exec->get_scoped_device_id_guard();
1814 GKO_ASSERT_NO_MPI_ERRORS(
1815 MPI_Put(origin_buffer, origin_count, type_impl<PutType>::get_type(),
1816 target_rank, target_disp, target_count,
1817 type_impl<PutType>::get_type(), this->get_window()));
1818 }
1819
1832 template <typename PutType>
1833 request r_put(std::shared_ptr<const Executor> exec,
1834 const PutType* origin_buffer, const int origin_count,
1835 const int target_rank, const unsigned int target_disp,
1836 const int target_count) const
1837 {
1838 auto guard = exec->get_scoped_device_id_guard();
1839 request req;
1840 GKO_ASSERT_NO_MPI_ERRORS(MPI_Rput(
1841 origin_buffer, origin_count, type_impl<PutType>::get_type(),
1842 target_rank, target_disp, target_count,
1843 type_impl<PutType>::get_type(), this->get_window(), req.get()));
1844 return req;
1845 }
1846
1858 template <typename PutType>
1859 void accumulate(std::shared_ptr<const Executor> exec,
1860 const PutType* origin_buffer, const int origin_count,
1861 const int target_rank, const unsigned int target_disp,
1862 const int target_count, MPI_Op operation) const
1863 {
1864 auto guard = exec->get_scoped_device_id_guard();
1865 GKO_ASSERT_NO_MPI_ERRORS(MPI_Accumulate(
1866 origin_buffer, origin_count, type_impl<PutType>::get_type(),
1867 target_rank, target_disp, target_count,
1868 type_impl<PutType>::get_type(), operation, this->get_window()));
1869 }
1870
1884 template <typename PutType>
1885 request r_accumulate(std::shared_ptr<const Executor> exec,
1886 const PutType* origin_buffer, const int origin_count,
1887 const int target_rank, const unsigned int target_disp,
1888 const int target_count, MPI_Op operation) const
1889 {
1890 auto guard = exec->get_scoped_device_id_guard();
1891 request req;
1892 GKO_ASSERT_NO_MPI_ERRORS(MPI_Raccumulate(
1893 origin_buffer, origin_count, type_impl<PutType>::get_type(),
1894 target_rank, target_disp, target_count,
1895 type_impl<PutType>::get_type(), operation, this->get_window(),
1896 req.get()));
1897 return req;
1898 }
1899
1910 template <typename GetType>
1911 void get(std::shared_ptr<const Executor> exec, GetType* origin_buffer,
1912 const int origin_count, const int target_rank,
1913 const unsigned int target_disp, const int target_count) const
1914 {
1915 auto guard = exec->get_scoped_device_id_guard();
1916 GKO_ASSERT_NO_MPI_ERRORS(
1917 MPI_Get(origin_buffer, origin_count, type_impl<GetType>::get_type(),
1918 target_rank, target_disp, target_count,
1919 type_impl<GetType>::get_type(), this->get_window()));
1920 }
1921
1934 template <typename GetType>
1935 request r_get(std::shared_ptr<const Executor> exec, GetType* origin_buffer,
1936 const int origin_count, const int target_rank,
1937 const unsigned int target_disp, const int target_count) const
1938 {
1939 auto guard = exec->get_scoped_device_id_guard();
1940 request req;
1941 GKO_ASSERT_NO_MPI_ERRORS(MPI_Rget(
1942 origin_buffer, origin_count, type_impl<GetType>::get_type(),
1943 target_rank, target_disp, target_count,
1944 type_impl<GetType>::get_type(), this->get_window(), req.get()));
1945 return req;
1946 }
1947
1961 template <typename GetType>
1962 void get_accumulate(std::shared_ptr<const Executor> exec,
1963 GetType* origin_buffer, const int origin_count,
1964 GetType* result_buffer, const int result_count,
1965 const int target_rank, const unsigned int target_disp,
1966 const int target_count, MPI_Op operation) const
1967 {
1968 auto guard = exec->get_scoped_device_id_guard();
1969 GKO_ASSERT_NO_MPI_ERRORS(MPI_Get_accumulate(
1970 origin_buffer, origin_count, type_impl<GetType>::get_type(),
1971 result_buffer, result_count, type_impl<GetType>::get_type(),
1972 target_rank, target_disp, target_count,
1973 type_impl<GetType>::get_type(), operation, this->get_window()));
1974 }
1975
1991 template <typename GetType>
1992 request r_get_accumulate(std::shared_ptr<const Executor> exec,
1993 GetType* origin_buffer, const int origin_count,
1994 GetType* result_buffer, const int result_count,
1995 const int target_rank,
1996 const unsigned int target_disp,
1997 const int target_count, MPI_Op operation) const
1998 {
1999 auto guard = exec->get_scoped_device_id_guard();
2000 request req;
2001 GKO_ASSERT_NO_MPI_ERRORS(MPI_Rget_accumulate(
2002 origin_buffer, origin_count, type_impl<GetType>::get_type(),
2003 result_buffer, result_count, type_impl<GetType>::get_type(),
2004 target_rank, target_disp, target_count,
2005 type_impl<GetType>::get_type(), operation, this->get_window(),
2006 req.get()));
2007 return req;
2008 }
2009
2020 template <typename GetType>
2021 void fetch_and_op(std::shared_ptr<const Executor> exec,
2022 GetType* origin_buffer, GetType* result_buffer,
2023 const int target_rank, const unsigned int target_disp,
2024 MPI_Op operation) const
2025 {
2026 auto guard = exec->get_scoped_device_id_guard();
2027 GKO_ASSERT_NO_MPI_ERRORS(MPI_Fetch_and_op(
2028 origin_buffer, result_buffer, type_impl<GetType>::get_type(),
2029 target_rank, target_disp, operation, this->get_window()));
2030 }
2031
2032private:
2033 MPI_Win window_;
2034};
2035
2036
2037} // namespace mpi
2038} // namespace experimental
2039} // namespace gko
2040
2041
2042#endif // GINKGO_BUILD_MPI
2043
2044
2045#endif // GKO_PUBLIC_CORE_BASE_MPI_HPP_
A class providing basic support for bfloat16 precision floating point types.
Definition bfloat16.hpp:76
A thin wrapper of MPI_Comm that supports most MPI calls.
Definition mpi.hpp:419
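As a rough usage sketch (not part of this header, and assuming MPI has already been initialized via the environment class documented below, with host-side ReferenceExecutor execution):

#include <ginkgo/core/base/executor.hpp>
#include <ginkgo/core/base/mpi.hpp>

// Sketch: query the communicator and perform an in-place all-reduce.
void sum_ranks_example()
{
    namespace mpi = gko::experimental::mpi;
    mpi::communicator comm(MPI_COMM_WORLD);      // non-owning wrapper
    auto exec = gko::ReferenceExecutor::create();
    int value = comm.rank();                     // each rank contributes its id
    comm.all_reduce(exec, &value, 1, MPI_SUM);   // in-place MPI_Allreduce
    // value == 0 + 1 + ... + (comm.size() - 1) on every rank
}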
status recv(std::shared_ptr< const Executor > exec, RecvType *recv_buffer, const int recv_count, const int source_rank, const int recv_tag) const
Receive data from source rank.
Definition mpi.hpp:681
void scatter_v(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int *send_counts, const int *displacements, RecvType *recv_buffer, const int recv_count, int root_rank) const
Scatter data from root rank to all ranks in the communicator with offsets.
Definition mpi.hpp:1198
communicator(const communicator &other)=default
Create a copy of a communicator.
request i_broadcast(std::shared_ptr< const Executor > exec, BroadcastType *buffer, int count, int root_rank) const
(Non-blocking) Broadcast data from calling process to all ranks in the communicator
Definition mpi.hpp:759
void gather(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int recv_count, int root_rank) const
Gather data onto the root rank from all ranks in the communicator.
Definition mpi.hpp:946
request i_recv(std::shared_ptr< const Executor > exec, RecvType *recv_buffer, const int recv_count, const int source_rank, const int recv_tag) const
Receive (Non-blocking, Immediate return) data from source rank.
Definition mpi.hpp:709
request i_scatter_v(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int *send_counts, const int *displacements, RecvType *recv_buffer, const int recv_count, int root_rank) const
(Non-blocking) Scatter data from root rank to all ranks in the communicator with offsets.
Definition mpi.hpp:1231
void all_to_all(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int recv_count) const
Communicate data from all ranks to all other ranks (MPI_Alltoall).
Definition mpi.hpp:1321
bool is_identical(const communicator &rhs) const
Checks if the rhs communicator is identical to this communicator.
Definition mpi.hpp:571
request i_all_to_all(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int recv_count) const
(Non-blocking) Communicate data from all ranks to all other ranks (MPI_Ialltoall).
Definition mpi.hpp:1351
request i_all_to_all_v(std::shared_ptr< const Executor > exec, const void *send_buffer, const int *send_counts, const int *send_offsets, MPI_Datatype send_type, void *recv_buffer, const int *recv_counts, const int *recv_offsets, MPI_Datatype recv_type) const
Communicate data from all ranks to all other ranks with offsets (MPI_Ialltoallv).
Definition mpi.hpp:1441
bool operator!=(const communicator &rhs) const
Compare two communicator objects for non-equality.
Definition mpi.hpp:560
void scatter(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int recv_count, int root_rank) const
Scatter data from root rank to all ranks in the communicator.
Definition mpi.hpp:1134
void synchronize() const
This function is used to synchronize the ranks in the communicator.
Definition mpi.hpp:607
int rank() const
Return the rank of the calling process in the communicator.
Definition mpi.hpp:539
request i_reduce(std::shared_ptr< const Executor > exec, const ReduceType *send_buffer, ReduceType *recv_buffer, int count, MPI_Op operation, int root_rank) const
(Non-blocking) Reduce data into root from all calling processes on the same communicator.
Definition mpi.hpp:812
int size() const
Return the size of the communicator (number of ranks).
Definition mpi.hpp:532
void send(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, const int destination_rank, const int send_tag) const
Send (Blocking) data from calling process to destination rank.
Definition mpi.hpp:626
request i_all_to_all_v(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int *send_counts, const int *send_offsets, RecvType *recv_buffer, const int *recv_counts, const int *recv_offsets) const
Communicate data from all ranks to all other ranks with offsets (MPI_Ialltoallv).
Definition mpi.hpp:1477
request i_gather(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int recv_count, int root_rank) const
(Non-blocking) Gather data onto the root rank from all ranks in the communicator.
Definition mpi.hpp:978
void all_to_all(std::shared_ptr< const Executor > exec, RecvType *recv_buffer, const int recv_count) const
(In-place) Communicate data from all ranks to all other ranks in place (MPI_Alltoall).
Definition mpi.hpp:1263
void all_to_all_v(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int *send_counts, const int *send_offsets, RecvType *recv_buffer, const int *recv_counts, const int *recv_offsets) const
Communicate data from all ranks to all other ranks with offsets (MPI_Alltoallv).
Definition mpi.hpp:1384
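A sketch of the offset-based variant, assuming an initialized MPI environment, host execution, and exactly one element exchanged with every rank:

#include <numeric>
#include <vector>

#include <ginkgo/core/base/mpi.hpp>

// Sketch: every rank sends one double to every other rank via all_to_all_v.
void all_to_all_v_example()
{
    namespace mpi = gko::experimental::mpi;
    mpi::communicator comm(MPI_COMM_WORLD);
    auto exec = gko::ReferenceExecutor::create();
    const int n = comm.size();
    std::vector<double> send(n, static_cast<double>(comm.rank()));
    std::vector<double> recv(n, 0.0);
    std::vector<int> counts(n, 1);                 // one element per rank
    std::vector<int> offsets(n);
    std::iota(offsets.begin(), offsets.end(), 0);  // 0, 1, ..., n - 1
    comm.all_to_all_v(exec, send.data(), counts.data(), offsets.data(),
                      recv.data(), counts.data(), offsets.data());
    // afterwards recv[r] holds the rank id of process r
}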
request i_all_reduce(std::shared_ptr< const Executor > exec, ReduceType *recv_buffer, int count, MPI_Op operation) const
(In-place, non-blocking) Reduce data from all calling processes on the same communicator.
Definition mpi.hpp:863
request i_all_to_all(std::shared_ptr< const Executor > exec, RecvType *recv_buffer, const int recv_count) const
(In-place, Non-blocking) Communicate data from all ranks to all other ranks in place (MPI_Ialltoall).
Definition mpi.hpp:1292
void all_to_all_v(std::shared_ptr< const Executor > exec, const void *send_buffer, const int *send_counts, const int *send_offsets, MPI_Datatype send_type, void *recv_buffer, const int *recv_counts, const int *recv_offsets, MPI_Datatype recv_type) const
Communicate data from all ranks to all other ranks with offsets (MPI_Alltoallv).
Definition mpi.hpp:1410
int node_local_rank() const
Return the node local rank of the calling process in the communicator.
Definition mpi.hpp:546
void broadcast(std::shared_ptr< const Executor > exec, BroadcastType *buffer, int count, int root_rank) const
Broadcast data from calling process to all ranks in the communicator.
Definition mpi.hpp:734
static communicator create_owning(const MPI_Comm &comm, bool force_host_buffer=false)
Creates a new communicator and takes ownership of the MPI_Comm.
Definition mpi.hpp:477
const MPI_Comm & get() const
Return the underlying MPI_Comm object.
Definition mpi.hpp:523
communicator(const MPI_Comm &comm, int color, int key)
Create a communicator object from an existing MPI_Comm object using color and key.
Definition mpi.hpp:445
void all_reduce(std::shared_ptr< const Executor > exec, ReduceType *recv_buffer, int count, MPI_Op operation) const
(In-place) Reduce data from all calling processes on the same communicator.
Definition mpi.hpp:838
communicator & operator=(const communicator &other)=default
void all_gather(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int recv_count) const
Gather data onto all ranks from all ranks in the communicator.
Definition mpi.hpp:1075
communicator & operator=(communicator &&other)
Definition mpi.hpp:508
request i_all_gather(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int recv_count) const
(Non-blocking) Gather data onto all ranks from all ranks in the communicator.
Definition mpi.hpp:1105
bool operator==(const communicator &rhs) const
Compare two communicator objects for equality.
Definition mpi.hpp:553
bool is_congruent(const communicator &rhs) const
Checks if the rhs communicator is congruent to this communicator.
Definition mpi.hpp:593
void all_reduce(std::shared_ptr< const Executor > exec, const ReduceType *send_buffer, ReduceType *recv_buffer, int count, MPI_Op operation) const
Reduce data from all calling processes on the same communicator.
Definition mpi.hpp:890
request i_gather_v(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int *recv_counts, const int *displacements, int root_rank) const
(Non-blocking) Gather data onto the root rank from all ranks in the communicator with offsets.
Definition mpi.hpp:1044
request i_all_reduce(std::shared_ptr< const Executor > exec, const ReduceType *send_buffer, ReduceType *recv_buffer, int count, MPI_Op operation) const
(Non-blocking) Reduce data from all calling processes on the same communicator.
Definition mpi.hpp:917
communicator(const MPI_Comm &comm, bool force_host_buffer=false)
Non-owning constructor for an existing communicator of type MPI_Comm.
Definition mpi.hpp:431
request i_scan(std::shared_ptr< const Executor > exec, const ScanType *send_buffer, ScanType *recv_buffer, int count, MPI_Op operation) const
Does a scan operation with the given operator.
Definition mpi.hpp:1530
void reduce(std::shared_ptr< const Executor > exec, const ReduceType *send_buffer, ReduceType *recv_buffer, int count, MPI_Op operation, int root_rank) const
Reduce data into root from all calling processes on the same communicator.
Definition mpi.hpp:785
request i_scatter(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int recv_count, int root_rank) const
(Non-blocking) Scatter data from root rank to all ranks in the communicator.
Definition mpi.hpp:1165
void scan(std::shared_ptr< const Executor > exec, const ScanType *send_buffer, ScanType *recv_buffer, int count, MPI_Op operation) const
Does a scan operation with the given operator.
Definition mpi.hpp:1504
void gather_v(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int *recv_counts, const int *displacements, int root_rank) const
Gather data onto the root rank from all ranks in the communicator with offsets.
Definition mpi.hpp:1011
request i_send(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, const int destination_rank, const int send_tag) const
Send (Non-blocking, Immediate return) data from calling process to destination rank.
Definition mpi.hpp:653
communicator(communicator &&other)
Move constructor.
Definition mpi.hpp:498
communicator(const communicator &comm, int color, int key)
Create a communicator object from an existing MPI_Comm object using color and key.
Definition mpi.hpp:460
MPI_Datatype get() const
Access the underlying MPI_Datatype.
Definition mpi.hpp:182
contiguous_type(int count, MPI_Datatype old_type)
Constructs a wrapper for a contiguous MPI_Datatype.
Definition mpi.hpp:121
contiguous_type()
Constructs empty wrapper with MPI_DATATYPE_NULL.
Definition mpi.hpp:130
contiguous_type(const contiguous_type &)=delete
Disallow copying of wrapper type.
contiguous_type(contiguous_type &&other) noexcept
Move constructor, leaves other with MPI_DATATYPE_NULL.
Definition mpi.hpp:147
contiguous_type & operator=(contiguous_type &&other) noexcept
Move assignment, leaves other with MPI_DATATYPE_NULL.
Definition mpi.hpp:159
contiguous_type & operator=(const contiguous_type &)=delete
Disallow copying of wrapper type.
~contiguous_type()
Destructs object by freeing wrapped MPI_Datatype.
Definition mpi.hpp:170
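A small usage sketch for the datatype wrapper, assuming an initialized MPI environment; the wrapped type is committed on construction and freed automatically:

#include <utility>

#include <ginkgo/core/base/mpi.hpp>

// Sketch: treat blocks of three doubles as one MPI datatype.
void contiguous_type_example()
{
    gko::experimental::mpi::contiguous_type block(3, MPI_DOUBLE);
    MPI_Datatype type = block.get();  // committed type, freed by ~contiguous_type
    // The wrapper is move-only, so ownership can be transferred but not copied.
    auto owner = std::move(block);    // block now wraps MPI_DATATYPE_NULL
    (void)type;
    (void)owner;
}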
Class that sets up and finalizes the MPI environment.
Definition mpi.hpp:210
~environment()
Call MPI_Finalize at the end of the scope of this class.
Definition mpi.hpp:253
int get_provided_thread_support() const
Return the provided thread support.
Definition mpi.hpp:231
environment(int &argc, char **&argv, const thread_type thread_t=thread_type::serialized)
Call MPI_Init_thread and initialize the MPI environment.
Definition mpi.hpp:241
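A minimal sketch of the intended RAII usage; the choice of thread_type::multiple is only illustrative:

#include <ginkgo/core/base/mpi.hpp>

int main(int argc, char* argv[])
{
    namespace mpi = gko::experimental::mpi;
    // MPI_Init_thread is called here ...
    mpi::environment env(argc, argv, mpi::thread_type::multiple);
    mpi::communicator comm(MPI_COMM_WORLD);
    // ... communication happens within the lifetime of env ...
    return 0;
}   // ... and MPI_Finalize is called when env goes out of scope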
The request class is a light, move-only wrapper around the MPI_Request handle.
Definition mpi.hpp:331
request()
The default constructor.
Definition mpi.hpp:337
MPI_Request * get()
Get a pointer to the underlying MPI_Request handle.
Definition mpi.hpp:368
status wait()
Allows a rank to wait on a particular request handle.
Definition mpi.hpp:376
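A sketch of a non-blocking receive matched by a blocking send, assuming at least two ranks and host execution:

#include <vector>

#include <ginkgo/core/base/mpi.hpp>

// Sketch: rank 1 sends ten doubles to rank 0, which receives them asynchronously.
void ping_example()
{
    namespace mpi = gko::experimental::mpi;
    mpi::communicator comm(MPI_COMM_WORLD);
    auto exec = gko::ReferenceExecutor::create();
    std::vector<double> buf(10, static_cast<double>(comm.rank()));
    const int tag = 0;
    if (comm.rank() == 0) {
        mpi::request req = comm.i_recv(exec, buf.data(), 10, 1, tag);
        mpi::status st = req.wait();              // blocks until the message arrived
        int received = st.get_count(buf.data());  // number of doubles received (10)
        (void)received;
    } else if (comm.rank() == 1) {
        comm.send(exec, buf.data(), 10, 0, tag);
    }
}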
This class wraps the MPI_Window class with RAII functionality.
Definition mpi.hpp:1598
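A sketch of active-target one-sided communication through the wrapper, assuming at least two ranks, host memory, and fence synchronization:

#include <vector>

#include <ginkgo/core/base/mpi.hpp>

// Sketch: rank 0 writes its local values into rank 1's exposed window.
void window_example()
{
    namespace mpi = gko::experimental::mpi;
    mpi::communicator comm(MPI_COMM_WORLD);
    auto exec = gko::ReferenceExecutor::create();
    std::vector<double> local(4, static_cast<double>(comm.rank()));
    mpi::window<double> win(exec, local.data(), 4, comm);  // create_type::create
    win.fence();                        // open the exposure epoch on all ranks
    if (comm.rank() == 0) {
        // write 4 doubles into rank 1's window starting at displacement 0
        win.put(exec, local.data(), 4, 1, 0, 4);
    }
    win.fence();                        // close the epoch; data is now visible
}   // ~window() calls MPI_Win_free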
void get(std::shared_ptr< const Executor > exec, GetType *origin_buffer, const int origin_count, const int target_rank, const unsigned int target_disp, const int target_count) const
Get data from the target window.
Definition mpi.hpp:1911
request r_put(std::shared_ptr< const Executor > exec, const PutType *origin_buffer, const int origin_count, const int target_rank, const unsigned int target_disp, const int target_count) const
Put data into the target window.
Definition mpi.hpp:1833
window()
The default constructor.
Definition mpi.hpp:1613
void get_accumulate(std::shared_ptr< const Executor > exec, GetType *origin_buffer, const int origin_count, GetType *result_buffer, const int result_count, const int target_rank, const unsigned int target_disp, const int target_count, MPI_Op operation) const
Get Accumulate data from the target window.
Definition mpi.hpp:1962
void put(std::shared_ptr< const Executor > exec, const PutType *origin_buffer, const int origin_count, const int target_rank, const unsigned int target_disp, const int target_count) const
Put data into the target window.
Definition mpi.hpp:1809
~window()
The deleter which calls MPI_Win_free when the window leaves its scope.
Definition mpi.hpp:1791
lock_type
The lock type for passive target synchronization of the windows.
Definition mpi.hpp:1608
window & operator=(window &&other)
The move assignment operator.
Definition mpi.hpp:1634
request r_accumulate(std::shared_ptr< const Executor > exec, const PutType *origin_buffer, const int origin_count, const int target_rank, const unsigned int target_disp, const int target_count, MPI_Op operation) const
(Non-blocking) Accumulate data into the target window.
Definition mpi.hpp:1885
request r_get_accumulate(std::shared_ptr< const Executor > exec, GetType *origin_buffer, const int origin_count, GetType *result_buffer, const int result_count, const int target_rank, const unsigned int target_disp, const int target_count, MPI_Op operation) const
(Non-blocking) Get Accumulate data (with handle) from the target window.
Definition mpi.hpp:1992
void fetch_and_op(std::shared_ptr< const Executor > exec, GetType *origin_buffer, GetType *result_buffer, const int target_rank, const unsigned int target_disp, MPI_Op operation) const
Fetch and operate on data from the target window (An optimized version of Get_accumulate).
Definition mpi.hpp:2021
void sync() const
Synchronize the public and private buffers for the window object.
Definition mpi.hpp:1786
void unlock(int rank) const
Close the epoch using MPI_Win_unlock for the window object.
Definition mpi.hpp:1718
void fence(int assert=0) const
The active target synchronization using MPI_Win_fence for the window object.
Definition mpi.hpp:1685
void flush(int rank) const
Flush the existing RDMA operations on the target rank for the calling process for the window object.
Definition mpi.hpp:1749
void unlock_all() const
Close the epoch on all ranks using MPI_Win_unlock_all for the window object.
Definition mpi.hpp:1738
create_type
The create type for the window object.
Definition mpi.hpp:1603
window(std::shared_ptr< const Executor > exec, ValueType *base, int num_elems, const communicator &comm, const int disp_unit=sizeof(ValueType), MPI_Info input_info=MPI_INFO_NULL, create_type c_type=create_type::create)
Create a window object with a given data pointer and type.
Definition mpi.hpp:1651
void accumulate(std::shared_ptr< const Executor > exec, const PutType *origin_buffer, const int origin_count, const int target_rank, const unsigned int target_disp, const int target_count, MPI_Op operation) const
Accumulate data into the target window.
Definition mpi.hpp:1859
void lock_all(int assert=0) const
Create the epoch on all ranks using MPI_Win_lock_all for the window object.
Definition mpi.hpp:1729
void lock(int rank, lock_type lock_t=lock_type::shared, int assert=0) const
Create an epoch using MPI_Win_lock for the window object.
Definition mpi.hpp:1698
void flush_all_local() const
Flush all the local existing RDMA operations on the calling rank for the window object.
Definition mpi.hpp:1778
window(window &&other)
The move constructor.
Definition mpi.hpp:1625
void flush_local(int rank) const
Flush the existing RDMA operations on the calling rank from the target rank for the window object.
Definition mpi.hpp:1760
MPI_Win get_window() const
Get the underlying window object of MPI_Win type.
Definition mpi.hpp:1677
request r_get(std::shared_ptr< const Executor > exec, GetType *origin_buffer, const int origin_count, const int target_rank, const unsigned int target_disp, const int target_count) const
Get data (with handle) from the target window.
Definition mpi.hpp:1935
void flush_all() const
Flush all the existing RDMA operations for the calling process for the window object.
Definition mpi.hpp:1769
A class providing basic support for half precision floating point types.
Definition half.hpp:288
The mpi namespace, contains wrapper for many MPI functions.
Definition mpi.hpp:36
int map_rank_to_device_id(MPI_Comm comm, int num_devices)
Maps each MPI rank to a single device id in a round robin manner.
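A sketch of device selection in a multi-GPU job; the device count and the CUDA executor are illustrative assumptions:

#include <memory>

#include <ginkgo/core/base/executor.hpp>
#include <ginkgo/core/base/mpi.hpp>

// Sketch: assign each rank one of the visible GPUs in round-robin order.
std::shared_ptr<gko::Executor> pick_device_executor()
{
    namespace mpi = gko::experimental::mpi;
    const int num_devices = 2;  // illustrative; normally queried from the runtime
    const int device_id =
        mpi::map_rank_to_device_id(MPI_COMM_WORLD, num_devices);
    return gko::CudaExecutor::create(device_id, gko::ReferenceExecutor::create());
}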
bool requires_host_buffer(const std::shared_ptr< const Executor > &exec, const communicator &comm)
Checks if the combination of Executor and communicator requires passing MPI buffers from host memory.
double get_walltime()
Get the current wall-clock time on the calling process (MPI_Wtime).
Definition mpi.hpp:1586
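A small timing sketch; the communicator passed in is assumed to have been created as shown above:

#include <ginkgo/core/base/mpi.hpp>

// Sketch: time a barrier with the MPI wall clock.
double time_barrier(const gko::experimental::mpi::communicator& comm)
{
    double start = gko::experimental::mpi::get_walltime();
    comm.synchronize();  // stand-in for the communication being measured
    return gko::experimental::mpi::get_walltime() - start;  // seconds
}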
constexpr bool is_gpu_aware()
Return if GPU aware functionality is available.
Definition mpi.hpp:42
thread_type
This enum specifies the threading type to be used when creating an MPI environment.
Definition mpi.hpp:193
std::vector< status > wait_all(std::vector< request > &req)
Allows a rank to wait on multiple request handles.
Definition mpi.hpp:395
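A sketch that collects several non-blocking sends and waits on all of them at once; the matching receives are omitted for brevity:

#include <vector>

#include <ginkgo/core/base/mpi.hpp>

// Sketch: fan a single value out to every other rank and wait for completion.
void fan_out_example()
{
    namespace mpi = gko::experimental::mpi;
    mpi::communicator comm(MPI_COMM_WORLD);
    auto exec = gko::ReferenceExecutor::create();
    double value = 3.14;
    std::vector<mpi::request> reqs;
    for (int target = 0; target < comm.size(); ++target) {
        if (target != comm.rank()) {
            reqs.push_back(comm.i_send(exec, &value, 1, target, 0));
        }
    }
    auto statuses = mpi::wait_all(reqs);  // blocks until every send has completed
    (void)statuses;
}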
The Ginkgo namespace.
Definition abstract_factory.hpp:20
STL namespace.
The status struct is a light wrapper around the MPI_Status struct.
Definition mpi.hpp:291
int get_count(const T *data) const
Get the count of the number of elements received by the communication call.
Definition mpi.hpp:315
status()
The default constructor.
Definition mpi.hpp:295
MPI_Status * get()
Get a pointer to the underlying MPI_Status object.
Definition mpi.hpp:302
A struct that is used to determine the MPI_Datatype of a specified type.
Definition mpi.hpp:77
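A short sketch of the compile-time type mapping; registering an additional type with GKO_REGISTER_MPI_TYPE is shown only as a hypothetical pattern:

#include <ginkgo/core/base/mpi.hpp>

// Sketch: type_impl resolves C++ types to MPI datatypes at compile time.
void type_impl_example()
{
    MPI_Datatype int_type = gko::experimental::mpi::type_impl<int>::get_type();
    // int_type == MPI_INT; double maps to MPI_DOUBLE, and so on.
    (void)int_type;
}

// Hypothetical: additional types could be mapped by specializing type_impl
// inside namespace gko::experimental::mpi, e.g.
//     GKO_REGISTER_MPI_TYPE(signed char, MPI_SIGNED_CHAR);  // illustrative only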