Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor ccl send and recv #8855

Merged
merged 51 commits into from
Aug 21, 2022
Merged
Show file tree
Hide file tree
Changes from 33 commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
0edc3b8
rename REGISTER_COLLECTIVE_COMMUNICATION_FACTORY to REGISTER_COLLECTI…
clackhan Aug 1, 2022
4ea2510
refactor_ccl_allgather_and_reduce_scatter
clackhan Aug 1, 2022
52c8e2f
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
clackhan Aug 1, 2022
8ea051a
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
clackhan Aug 2, 2022
043828a
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
clackhan Aug 2, 2022
280a3f7
refactor ccl::Reduce
clackhan Aug 2, 2022
5500e9f
remove useless code
clackhan Aug 2, 2022
3b3d2d9
refactor ccl::Broadcast
clackhan Aug 2, 2022
b66da0f
fix static check error
clackhan Aug 2, 2022
f789a14
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
clackhan Aug 2, 2022
ab0f0a8
resolve comment
clackhan Aug 2, 2022
6064515
Merge branch 'refactor_ccl_all_gather_and_reduce_scatter' of https://…
clackhan Aug 2, 2022
31da07f
minor fix
clackhan Aug 2, 2022
ff24acc
Merge branch 'master' into refactor_ccl_all_gather_and_reduce_scatter
clackhan Aug 4, 2022
31c157d
resolve comments
clackhan Aug 4, 2022
069fcd7
fix macro lock error
clackhan Aug 4, 2022
0656a82
Merge branch 'refactor_ccl_all_gather_and_reduce_scatter' of https://…
clackhan Aug 4, 2022
a42f39c
refine
clackhan Aug 4, 2022
62bbd34
Merge branch 'master' into refactor_ccl_all_gather_and_reduce_scatter
clackhan Aug 4, 2022
0b978f3
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
clackhan Aug 4, 2022
fe9d9ba
Merge branch 'master' into refactor_ccl_all_gather_and_reduce_scatter
clackhan Aug 4, 2022
13b637f
Merge branch 'master' into refactor_ccl_all_gather_and_reduce_scatter
mergify[bot] Aug 4, 2022
8a14774
Merge branch 'refactor_ccl_all_gather_and_reduce_scatter' of https://…
clackhan Aug 4, 2022
944860a
fix an idiot error
clackhan Aug 4, 2022
01ac174
Merge branch 'refactor_ccl_all_gather_and_reduce_scatter' of https://…
clackhan Aug 4, 2022
b220706
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
clackhan Aug 4, 2022
b6cefb8
fix reduce functor bug
clackhan Aug 5, 2022
9936f1a
Merge branch 'master' into refactor_ccl_reduce_and_broadcast
mergify[bot] Aug 5, 2022
5d88ec3
refactor_ccl_send_and_recv
clackhan Aug 5, 2022
392a56b
Merge branch 'refactor_ccl_reduce_and_broadcast' of https://github.co…
clackhan Aug 5, 2022
b429728
refine
clackhan Aug 5, 2022
d0eb19d
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
clackhan Aug 7, 2022
32f9f5a
Merge branch 'master' into refactor_ccl_send_and_recv
clackhan Aug 18, 2022
64f5507
refine
clackhan Aug 18, 2022
993c5e3
Merge branch 'refactor_ccl_send_and_recv' of https://github.com/Onefl…
clackhan Aug 18, 2022
010d832
Merge branch 'master' into refactor_ccl_send_and_recv
clackhan Aug 18, 2022
cc858ae
Merge branch 'master' into refactor_ccl_send_and_recv
clackhan Aug 18, 2022
acdde00
Merge branch 'master' into refactor_ccl_send_and_recv
mergify[bot] Aug 19, 2022
c199a3e
Merge branch 'master' into refactor_ccl_send_and_recv
mergify[bot] Aug 19, 2022
ca9cfe6
Merge branch 'master' into refactor_ccl_send_and_recv
mergify[bot] Aug 19, 2022
c2d35fb
Merge branch 'master' into refactor_ccl_send_and_recv
mergify[bot] Aug 19, 2022
7d88b52
Merge branch 'master' into refactor_ccl_send_and_recv
mergify[bot] Aug 19, 2022
0c7c5e5
Merge branch 'master' into refactor_ccl_send_and_recv
mergify[bot] Aug 19, 2022
2669fa0
Merge branch 'master' into refactor_ccl_send_and_recv
clackhan Aug 20, 2022
47503d7
Merge branch 'master' into refactor_ccl_send_and_recv
mergify[bot] Aug 20, 2022
2f89584
Merge branch 'master' into refactor_ccl_send_and_recv
mergify[bot] Aug 20, 2022
804a010
Merge branch 'master' into refactor_ccl_send_and_recv
mergify[bot] Aug 20, 2022
13a0fb3
Merge branch 'master' into refactor_ccl_send_and_recv
mergify[bot] Aug 20, 2022
90b6a90
Merge branch 'master' into refactor_ccl_send_and_recv
mergify[bot] Aug 20, 2022
5647535
Merge branch 'master' into refactor_ccl_send_and_recv
mergify[bot] Aug 20, 2022
182c18c
Merge branch 'master' into refactor_ccl_send_and_recv
clackhan Aug 21, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions oneflow/core/boxing/one_to_one_boxing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.
#include "oneflow/core/boxing/eager_boxing_interpreter.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/common/decorator.h"
#include "oneflow/user/kernels/communicate_util.h"

namespace oneflow {

Expand All @@ -31,8 +32,7 @@ Maybe<void> RawCheckNaiveOneToOne(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> ou
CHECK_EQ_OR_RETURN(out->placement()->parallel_num(), 1);
CHECK_EQ_OR_RETURN(in->placement()->device_tag(), out->placement()->device_tag());
CHECK_OR_RETURN(in->placement() != out->placement());
CHECK_OR_RETURN(in->placement()->device_type() == DeviceType::kCPU
|| in->placement()->device_type() == DeviceType::kCUDA);
CHECK_OR_RETURN(IsSendAndRecvRegistered(in->placement()->device_type())); // NOLINT
return Maybe<void>::Ok();
}
// NOLINTEND(maybe-need-error-msg)
Expand Down
6 changes: 2 additions & 4 deletions oneflow/core/boxing/slice_boxing_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License.
#include "oneflow/core/boxing/eager_boxing_interpreter_mgr.h"
#include "oneflow/core/boxing/eager_boxing_logger.h"
#include "oneflow/core/boxing/eager_boxing_interpreter.h"
#include "oneflow/user/kernels/communicate_util.h"

namespace oneflow {

Expand All @@ -26,10 +27,7 @@ namespace private_details {
Maybe<one::Tensor> PreprocessInputTensor4SliceBoxing(const std::shared_ptr<one::Tensor>& tensor,
const std::string& log_prefix) {
const auto& tensor_placement = JUST(tensor->parallel_desc());
if (tensor_placement->device_type() == DeviceType::kCPU
|| tensor_placement->device_type() == DeviceType::kCUDA) {
return tensor;
}
if (IsSendAndRecvRegistered(tensor_placement->device_type())) { return tensor; }

const auto& tensor_nd_sbp = JUST(tensor->nd_sbp());
Symbol<ParallelDesc> new_placement = JUST(ReplaceDeviceType(tensor_placement, DeviceType::kCPU));
Expand Down
59 changes: 2 additions & 57 deletions oneflow/core/ccl/ccl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,24 +80,7 @@ Maybe<void> CpuBroadcast(const void* in, void* out, size_t buffer_size, int64_t
return Maybe<void>::Ok();
}

#ifdef WITH_CUDA
std::pair<ncclComm_t, int64_t> RawGetNcclCommAndPeerNcclRank(int64_t peer_process_id) {
std::set<std::pair<int64_t, int64_t>> device_set;
const int64_t& rank = GlobalProcessCtx::Rank();
const int64_t peer_nccl_rank = (peer_process_id > rank) ? 1 : 0;
device_set.emplace(rank, GlobalProcessCtx::LocalRank());
device_set.emplace(peer_process_id, GlobalProcessCtx::LocalRank(peer_process_id));
return {CHECK_NOTNULL(Singleton<EagerNcclCommMgr>::Get())->GetCommForDevice(device_set),
peer_nccl_rank};
}
auto* GetNcclCommAndPeerNcclRank = DECORATE(&RawGetNcclCommAndPeerNcclRank, ThreadLocal);
#endif

template<>
Maybe<void> Send<DeviceType::kCPU>(const void* in, size_t elem_cnt, DataType dtype, int64_t dst,
ep::Stream* stream) {
CHECK_OR_RETURN(IsPODDataType(dtype));
size_t buffer_size = elem_cnt * GetSizeOfDataType(dtype);
Maybe<void> CpuSend(const void* in, size_t buffer_size, int64_t dst) {
TransportToken transport_token = JUST(TransportToken::NewTransportToken(kTransportTokenTypeData));
NaiveAsyncTransportCtx transport_ctx(
transport_token,
Expand All @@ -115,28 +98,7 @@ Maybe<void> Send<DeviceType::kCPU>(const void* in, size_t elem_cnt, DataType dty
return Maybe<void>::Ok();
}

#ifdef WITH_CUDA
template<>
Maybe<void> Send<DeviceType::kCUDA>(const void* in, size_t elem_cnt, DataType dtype, int64_t dst,
ep::Stream* stream) {
#if NCCL_VERSION_CODE >= 2700
CHECK_OR_RETURN(IsPODDataType(dtype));
const auto& comm_and_peer_rank = GetNcclCommAndPeerNcclRank(dst);
OF_NCCL_CHECK_OR_RETURN(ncclSend(in, elem_cnt, GetNcclDataType(dtype), comm_and_peer_rank.second,
comm_and_peer_rank.first,
stream->As<ep::CudaStream>()->cuda_stream()));
return Maybe<void>::Ok();
#else
UNIMPLEMENTED_THEN_RETURN() << "GPU send is only supported when nccl version >= 2.7"
#endif
}
#endif

template<>
Maybe<void> Recv<DeviceType::kCPU>(void* out, size_t elem_cnt, DataType dtype, int64_t src,
ep::Stream* stream) {
CHECK_OR_RETURN(IsPODDataType(dtype));
size_t buffer_size = elem_cnt * GetSizeOfDataType(dtype);
Maybe<void> CpuRecv(void* out, size_t buffer_size, int64_t src) {
TransportToken transport_token = JUST(TransportToken::NewTransportToken(kTransportTokenTypeData));
NaiveAsyncTransportCtx transport_ctx(
transport_token,
Expand All @@ -154,22 +116,5 @@ Maybe<void> Recv<DeviceType::kCPU>(void* out, size_t elem_cnt, DataType dtype, i
return Maybe<void>::Ok();
}

#ifdef WITH_CUDA
template<>
Maybe<void> Recv<DeviceType::kCUDA>(void* out, size_t elem_cnt, DataType dtype, int64_t src,
ep::Stream* stream) {
#if NCCL_VERSION_CODE >= 2700
CHECK_OR_RETURN(IsPODDataType(dtype));
const auto& comm_and_peer_rank = GetNcclCommAndPeerNcclRank(src);
OF_NCCL_CHECK_OR_RETURN(ncclRecv(out, elem_cnt, GetNcclDataType(dtype), comm_and_peer_rank.second,
comm_and_peer_rank.first,
stream->As<ep::CudaStream>()->cuda_stream()));
return Maybe<void>::Ok();
#else
UNIMPLEMENTED_THEN_RETURN() << "GPU recv is only supported when nccl version >= 2.7"
#endif
}
#endif

} // namespace ccl
} // namespace oneflow
6 changes: 2 additions & 4 deletions oneflow/core/ccl/ccl.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,9 @@ class TransportToken;
// collective communication library
namespace ccl {

template<DeviceType device_type>
Maybe<void> Send(const void* in, size_t elem_cnt, DataType dtype, int64_t dst, ep::Stream* stream);
Maybe<void> CpuSend(const void* in, size_t buffer_size, int64_t dst);

template<DeviceType device_type>
Maybe<void> Recv(void* out, size_t elem_cnt, DataType dtype, int64_t src, ep::Stream* stream);
Maybe<void> CpuRecv(void* out, size_t buffer_size, int64_t src);

Maybe<void> CpuBroadcast(const void* in, void* out, size_t buffer_size, int64_t root,
Symbol<ParallelDesc> parallel_desc, const TransportToken& transport_token);
Expand Down
17 changes: 6 additions & 11 deletions oneflow/core/functional/impl/comm_functor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -336,15 +336,13 @@ class SendFunctor {
JUST(attrs.SetAttr<int64_t>("dst_process_id", dst));
if (send_meta) {
std::shared_ptr<FlatShape> flat_shape = JUST(FlatShape::New(*x->shape()));
JUST(ccl::Send<DeviceType::kCPU>(flat_shape.get(), sizeof(*flat_shape), DataType::kChar, dst,
nullptr));
JUST(ccl::CpuSend(flat_shape.get(), sizeof(*flat_shape), dst));

DataType dtype = x->dtype()->data_type();
JUST(ccl::Send<DeviceType::kCPU>(&dtype, sizeof(dtype), DataType::kChar, dst, nullptr));
JUST(ccl::CpuSend(&dtype, sizeof(dtype), dst));

DeviceType device_type = JUST(Device::GetPlacement(*JUST(x->device())))->device_type();
JUST(ccl::Send<DeviceType::kCPU>(&device_type, sizeof(device_type), DataType::kChar, dst,
nullptr));
JUST(ccl::CpuSend(&device_type, sizeof(device_type), dst));
}
JUST(OpInterpUtil::Dispatch<TensorTuple>(*op_expr_, {x}, attrs));
return Maybe<void>::Ok();
Expand Down Expand Up @@ -373,16 +371,13 @@ class RecvFunctor {
} else if (!optional_shape.has_value() && !optional_dtype.has_value()
&& !optional_device.has_value()) {
FlatShape flat_shape{};
JUST(ccl::Recv<DeviceType::kCPU>(&flat_shape, sizeof(flat_shape), DataType::kChar, src,
nullptr));
JUST(ccl::CpuRecv(&flat_shape, sizeof(flat_shape), src));
shape = *JUST(flat_shape.ToShape());

JUST(ccl::Recv<DeviceType::kCPU>(&data_type, sizeof(data_type), DataType::kChar, src,
nullptr));
JUST(ccl::CpuRecv(&data_type, sizeof(data_type), src));

DeviceType device_type = DeviceType::kInvalidDevice;
JUST(ccl::Recv<DeviceType::kCPU>(&device_type, sizeof(device_type), DataType::kChar, src,
nullptr));
JUST(ccl::CpuRecv(&device_type, sizeof(device_type), src));
device = JUST(Device::New(*JUST(DeviceTag4DeviceType(device_type))));
} else {
UNIMPLEMENTED_THEN_RETURN() << "All or none of shape, dtype and device should have value.";
Expand Down
1 change: 0 additions & 1 deletion oneflow/core/functional/impl/slice_boxing_functor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ limitations under the License.
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/functional/function_library.h"
#include "oneflow/core/functional/impl/common.h"
#include "oneflow/core/ccl/ccl.h"

namespace oneflow {
namespace one {
Expand Down
51 changes: 51 additions & 0 deletions oneflow/user/kernels/collective_communication/cpu/cpu_recv.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/data_type.h"
#include "oneflow/core/ccl/ccl.h"
#include "oneflow/core/job/rank_group.h"
#include "oneflow/core/framework/transport_util.h"
#include "oneflow/user/kernels/collective_communication/include/recv.h"

namespace oneflow {

namespace ccl {

// Use CpuRecvImpl to avoid name confilict
class CpuRecvImpl final : public Recv {
public:
OF_DISALLOW_COPY_AND_MOVE(CpuRecvImpl);
CpuRecvImpl() : size_of_datatype_(0) {}
~CpuRecvImpl() = default;

void Init(DataType datatype) override {
CHECK(IsPODDataType(datatype));
this->size_of_datatype_ = GetSizeOfDataType(datatype);
}

void Launch(ep::Stream* stream, void* out, size_t elem_cnt, int64_t src) const override {
size_t buffer_size = elem_cnt * size_of_datatype_;
CHECK_JUST(CpuRecv(out, buffer_size, src));
}

private:
size_t size_of_datatype_;
};

REGISTER_COLLECTIVE_COMMUNICATION(DeviceType::kCPU, Recv, CpuRecvImpl);

} // namespace ccl

} // namespace oneflow
51 changes: 51 additions & 0 deletions oneflow/user/kernels/collective_communication/cpu/cpu_send.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/data_type.h"
#include "oneflow/core/ccl/ccl.h"
#include "oneflow/core/job/rank_group.h"
#include "oneflow/core/framework/transport_util.h"
#include "oneflow/user/kernels/collective_communication/include/send.h"

namespace oneflow {

namespace ccl {

// Use CpuSendImpl to avoid name confilict
class CpuSendImpl final : public Send {
public:
OF_DISALLOW_COPY_AND_MOVE(CpuSendImpl);
CpuSendImpl() : size_of_datatype_(0) {}
~CpuSendImpl() = default;

void Init(DataType datatype) override {
CHECK(IsPODDataType(datatype));
this->size_of_datatype_ = GetSizeOfDataType(datatype);
}

void Launch(ep::Stream* stream, const void* in, size_t elem_cnt, int64_t dst) const override {
size_t buffer_size = elem_cnt * size_of_datatype_;
CHECK_JUST(CpuSend(in, buffer_size, dst));
}

private:
size_t size_of_datatype_;
};

REGISTER_COLLECTIVE_COMMUNICATION(DeviceType::kCPU, Send, CpuSendImpl);

} // namespace ccl

} // namespace oneflow
53 changes: 53 additions & 0 deletions oneflow/user/kernels/collective_communication/cuda/cuda_recv.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifdef WITH_CUDA
#include "oneflow/user/kernels/collective_communication/include/recv.h"
#include "oneflow/user/kernels/collective_communication/cuda/cuda_send_recv_util.h"
#include "oneflow/core/device/nccl_util.h"

namespace oneflow {

namespace ccl {

// CUDA implementation of point-to-point Recv, backed by ncclRecv.
class CudaRecv final : public Recv {
 public:
  OF_DISALLOW_COPY_AND_MOVE(CudaRecv);
  CudaRecv() : nccl_datatype_() {}
  ~CudaRecv() = default;

  // Caches the NCCL dtype corresponding to `datatype` for later Launch() calls.
  void Init(DataType datatype) override { this->nccl_datatype_ = GetNcclDataType(datatype); }

  // Receives `elem_cnt` elements from rank `src` on the given CUDA stream.
  void Launch(ep::Stream* stream, void* out, size_t elem_cnt, int64_t src) const override {
#if HAS_NCCL_SEND_RECV
    // Communicator shared with `src`, plus the peer's rank within that communicator.
    const auto& comm_and_peer_rank = GetNcclCommAndPeerNcclRank(src);
    OF_NCCL_CHECK(ncclRecv(out, elem_cnt, nccl_datatype_, comm_and_peer_rank.second,
                           comm_and_peer_rank.first, stream->As<ep::CudaStream>()->cuda_stream()));
#else
    // Fixed: the statement previously lacked its terminating ';', which broke
    // compilation whenever HAS_NCCL_SEND_RECV was false (nccl < 2.7).
    UNIMPLEMENTED() << "GPU recv is only supported when nccl version >= 2.7";
#endif  // HAS_NCCL_SEND_RECV
  }

 private:
  ncclDataType_t nccl_datatype_;  // NCCL dtype, set by Init()
};

REGISTER_COLLECTIVE_COMMUNICATION(DeviceType::kCUDA, Recv, CudaRecv);

} // namespace ccl

} // namespace oneflow

#endif // WITH_CUDA
53 changes: 53 additions & 0 deletions oneflow/user/kernels/collective_communication/cuda/cuda_send.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifdef WITH_CUDA
#include "oneflow/user/kernels/collective_communication/include/send.h"
#include "oneflow/user/kernels/collective_communication/cuda/cuda_send_recv_util.h"
#include "oneflow/core/device/nccl_util.h"

namespace oneflow {

namespace ccl {

// CUDA implementation of point-to-point Send, backed by ncclSend.
class CudaSend final : public Send {
 public:
  OF_DISALLOW_COPY_AND_MOVE(CudaSend);
  CudaSend() : nccl_datatype_() {}
  ~CudaSend() = default;

  // Caches the NCCL dtype corresponding to `datatype` for later Launch() calls.
  void Init(DataType datatype) override { this->nccl_datatype_ = GetNcclDataType(datatype); }

  // Sends `elem_cnt` elements to rank `dst` on the given CUDA stream.
  void Launch(ep::Stream* stream, const void* in, size_t elem_cnt, int64_t dst) const override {
#if HAS_NCCL_SEND_RECV
    // Communicator shared with `dst`, plus the peer's rank within that communicator.
    const auto& comm_and_peer_rank = GetNcclCommAndPeerNcclRank(dst);
    OF_NCCL_CHECK(ncclSend(in, elem_cnt, nccl_datatype_, comm_and_peer_rank.second,
                           comm_and_peer_rank.first, stream->As<ep::CudaStream>()->cuda_stream()));
#else
    // Fixed: the statement previously lacked its terminating ';', which broke
    // compilation whenever HAS_NCCL_SEND_RECV was false (nccl < 2.7).
    UNIMPLEMENTED() << "GPU send is only supported when nccl version >= 2.7";
#endif  // HAS_NCCL_SEND_RECV
  }

 private:
  ncclDataType_t nccl_datatype_;  // NCCL dtype, set by Init()
};

REGISTER_COLLECTIVE_COMMUNICATION(DeviceType::kCUDA, Send, CudaSend);

} // namespace ccl

} // namespace oneflow

#endif // WITH_CUDA
Loading