Skip to content

Commit

Permalink
[feat] Avoid reallocating and copying device memory when the inputs of the Remove and ImportValues functions are already device variables.
Browse files Browse the repository at this point in the history
  • Loading branch information
MoFHeka authored and rhdong committed Sep 27, 2022
1 parent 634b96d commit 0d53e01
Showing 1 changed file with 41 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -282,17 +282,26 @@ class CuckooHashTableOfTensorsGpu final : public LookupInterface {

CUDA_CHECK(cudaStreamCreate(&_stream));
if (len > 0) {
CUDA_CHECK(cudaMallocManaged((void**)&d_keys, sizeof(K) * len));
CUDA_CHECK(cudaMemcpy((void*)d_keys, (void*)keys.tensor_data().data(),
sizeof(K) * len, cudaMemcpyDefault));
cudaPointerAttributes keys_attr;
CUDA_CHECK(cudaPointerGetAttributes(&keys_attr,
(void*)keys.tensor_data().data()));
if (keys_attr.type != cudaMemoryTypeDevice) {
CUDA_CHECK(cudaMallocManaged((void**)&d_keys, sizeof(K) * len));
CUDA_CHECK(cudaMemcpy((void*)d_keys, (void*)keys.tensor_data().data(),
sizeof(K) * len, cudaMemcpyDefault));
} else {
d_keys = (K*)keys.tensor_data().data();
}
{
mutex_lock l(mu_);
table_->remove((const K*)d_keys, len, _stream);
RehashIfNeeded(_stream);
CUDA_CHECK(cudaStreamSynchronize(_stream));
}
CUDA_CHECK(cudaStreamDestroy(_stream));
CUDA_CHECK(cudaFree(d_keys));
if (keys_attr.type != cudaMemoryTypeDevice) {
CUDA_CHECK(cudaFree(d_keys));
}
}
return Status::OK();
}
Expand All @@ -318,13 +327,28 @@ class CuckooHashTableOfTensorsGpu final : public LookupInterface {
if (len > 0) {
cudaStream_t _stream;
CUDA_CHECK(cudaStreamCreate(&_stream));
CUDA_CHECK(cudaMallocManaged((void**)&d_keys, sizeof(K) * len));
CUDA_CHECK(
cudaMallocManaged((void**)&d_values, sizeof(V) * runtime_dim_ * len));
CUDA_CHECK(cudaMemcpy((void*)d_keys, (void*)keys.tensor_data().data(),
sizeof(K) * len, cudaMemcpyDefault));
CUDA_CHECK(cudaMemcpy((void*)d_values, (void*)values.tensor_data().data(),
sizeof(V) * runtime_dim_ * len, cudaMemcpyDefault));
cudaPointerAttributes keys_attr;
CUDA_CHECK(cudaPointerGetAttributes(&keys_attr,
(void*)keys.tensor_data().data()));
if (keys_attr.type != cudaMemoryTypeDevice) {
CUDA_CHECK(cudaMallocManaged((void**)&d_keys, sizeof(K) * len));
CUDA_CHECK(cudaMemcpy((void*)d_keys, (void*)keys.tensor_data().data(),
sizeof(K) * len, cudaMemcpyDefault));
} else {
d_keys = (K*)keys.tensor_data().data();
}
cudaPointerAttributes values_attr;
CUDA_CHECK(cudaPointerGetAttributes(&values_attr,
(void*)values.tensor_data().data()));
if (values_attr.type != cudaMemoryTypeDevice) {
CUDA_CHECK(cudaMallocManaged((void**)&d_values,
sizeof(V) * runtime_dim_ * len));
CUDA_CHECK(
cudaMemcpy((void*)d_values, (void*)values.tensor_data().data(),
sizeof(V) * runtime_dim_ * len, cudaMemcpyDefault));
} else {
d_values = (gpu::ValueArrayBase<V>*)values.tensor_data().data();
}
{
mutex_lock l(mu_);
table_->clear(_stream);
Expand All @@ -334,8 +358,12 @@ class CuckooHashTableOfTensorsGpu final : public LookupInterface {
CUDA_CHECK(cudaStreamSynchronize(_stream));
}
CUDA_CHECK(cudaStreamDestroy(_stream));
CUDA_CHECK(cudaFree(d_keys));
CUDA_CHECK(cudaFree(d_values));
if (keys_attr.type != cudaMemoryTypeDevice) {
CUDA_CHECK(cudaFree(d_keys));
}
if (values_attr.type != cudaMemoryTypeDevice) {
CUDA_CHECK(cudaFree(d_values));
}
}
return Status::OK();
}
Expand Down

0 comments on commit 0d53e01

Please sign in to comment.