Skip to content

Commit

Permalink
Fix potential data race with OrtValue usage in Python (#9841)
Browse files Browse the repository at this point in the history
  • Loading branch information
hariharans29 authored Nov 23, 2021
1 parent 0ae0f29 commit 18fd2cf
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 3 deletions.
21 changes: 19 additions & 2 deletions onnxruntime/core/providers/cuda/cuda_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,26 @@ struct ProviderInfo_CUDA_Impl : ProviderInfo_CUDA {
}

// Used by slice_concatenate_test.cc and onnxruntime_pybind_state.cc
void cudaMemcpy_HostToDevice(void* dst, const void* src, size_t count) override { CUDA_CALL_THROW(cudaMemcpy(dst, src, count, cudaMemcpyHostToDevice)); }

void cudaMemcpy_HostToDevice(void* dst, const void* src, size_t count) override {
// Issue the copy on the (legacy) default stream.
CUDA_CALL_THROW(cudaMemcpy(dst, src, count, cudaMemcpyHostToDevice));

// Per the CUDA API synchronization rules
// (https://docs.nvidia.com/cuda/cuda-runtime-api/api-sync-behavior.html#api-sync-behavior__memcpy-sync),
// an H2D cudaMemcpy from *pageable* host memory may return as soon as the
// source has been staged into an internal buffer for DMA — the DMA into the
// final device destination may still be in flight. Synchronizing the default
// stream here guarantees the data has actually landed on the device before
// this function returns.
CUDA_CALL_THROW(cudaStreamSynchronize(0));
}

// Used by onnxruntime_pybind_state.cc
void cudaMemcpy_DeviceToHost(void* dst, const void* src, size_t count) override { CUDA_CALL_THROW(cudaMemcpy(dst, src, count, cudaMemcpyDeviceToHost)); }
void cudaMemcpy_DeviceToHost(void* dst, const void* src, size_t count) override {
// No explicit stream sync is required here: per the CUDA API synchronization
// rules (https://docs.nvidia.com/cuda/cuda-runtime-api/api-sync-behavior.html#api-sync-behavior__memcpy-sync),
// a D2H cudaMemcpy returns only once the copy has fully completed, regardless
// of whether the host destination is pageable or pinned memory.
CUDA_CALL_THROW(cudaMemcpy(dst, src, count, cudaMemcpyDeviceToHost));
}

int cudaGetDeviceCount() override {
int num_devices = 0;
Expand Down
14 changes: 13 additions & 1 deletion onnxruntime/test/python/onnxruntime_test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -884,7 +884,19 @@ def test_session_with_ortvalue_input(ortvalue):

# The constructed OrtValue should still be valid after being used in a session
self.assertTrue(np.array_equal(ortvalue2.numpy(), numpy_arr_input))


def testOrtValue_ghIssue9799(self):
    # Regression test for GitHub issue #9799: an H2D copy populating an
    # OrtValue could race with the subsequent session Run(), producing
    # stale/garbage output. An Identity model must echo its input exactly.
    #
    # Report a skip (rather than silently passing) when no CUDA EP is
    # available, so the test result is honest about what was exercised.
    if 'CUDAExecutionProvider' not in onnxrt.get_available_providers():
        self.skipTest('CUDAExecutionProvider not available')

    session = onnxrt.InferenceSession(get_name("identity_9799.onnx"),
                                      providers=onnxrt.get_available_providers())

    # Sweep a range of input sizes: the original race only manifested once
    # the pageable-host -> device transfer was large/slow enough.
    for seq_length in range(40, 200):
        inps = np.ones((seq_length, 16, 7, 5, 3, 3)).astype(np.float32)
        ort_val = onnxrt.OrtValue.ortvalue_from_numpy(inps, 'cuda', 0)
        upstreams_onnxrt = {'input': ort_val}
        outs = session.run(output_names=['output'], input_feed=upstreams_onnxrt)[0]
        # Identity model: output must equal input element-wise.
        self.assertTrue(np.allclose(inps, outs))

def testSparseTensorCooFormat(self):
cpu_device = onnxrt.OrtDevice.make('cpu', 0)
shape = [9,9]
Expand Down
20 changes: 20 additions & 0 deletions onnxruntime/test/testdata/identity_9799.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
pytorch1.10:�
%
inputoutput
Identity_0"Identitytorch-jit-exportZ;
input2
0,
input_dynamic_axes_1




b<
output2
0,
input_dynamic_axes_1




B

0 comments on commit 18fd2cf

Please sign in to comment.