
Commit fccaac3

Force GPU device objects that refer to the same physical card using the
same stream id to use the same CUDA stream objects. This avoids confusing
the per-device memory allocator in ways that cause memory corruption.

Fixes tensorflow/serving#335.

PiperOrigin-RevId: 157258318
tensorflower-gardener committed May 26, 2017
1 parent 2ff1d7b commit fccaac3
Showing 3 changed files with 131 additions and 40 deletions.
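
To make the mechanism easier to follow before reading the diff, here is a minimal standalone sketch of the pattern this change introduces: a process-wide factory that hands out exactly one stream group per {gpu_id, stream group id} key, so every device object that names the same physical GPU and stream id shares the same stream objects. Everything TensorFlow-specific is a placeholder here — the plain Stream struct, std::mutex, and the main driver stand in for gpu::Stream, gpu::StreamExecutor, and TensorFlow's mutex_lock — so only the shape of GetOrCreate and Global mirrors the real code below.

#include <cassert>
#include <map>
#include <mutex>
#include <tuple>

struct Stream { int id; };  // placeholder for gpu::Stream

struct StreamGroup {
  Stream* compute = nullptr;
  Stream* host_to_device = nullptr;
  Stream* device_to_host = nullptr;
  Stream* device_to_device = nullptr;
};

class StreamGroupFactory {
 public:
  // Thread-safe get-or-create keyed by {gpu_id, stream_group_within_gpu}.
  StreamGroup* GetOrCreate(int gpu_id, int stream_group_within_gpu) {
    std::lock_guard<std::mutex> guard(lock_);
    // map::operator[] default-constructs an empty group on first access;
    // the null compute pointer tells us the streams still need creating.
    StreamGroup* group =
        &streams_[std::make_tuple(gpu_id, stream_group_within_gpu)];
    if (group->compute == nullptr) {
      group->compute = new Stream{0};
      group->host_to_device = new Stream{1};
      group->device_to_host = new Stream{2};
      group->device_to_device = new Stream{3};
    }
    return group;
  }

  // Leak-on-exit singleton: never destroyed, so the shared streams outlive
  // every device object that points at them.
  static StreamGroupFactory& Global() {
    static StreamGroupFactory* instance = new StreamGroupFactory();
    return *instance;
  }

 private:
  StreamGroupFactory() = default;
  std::mutex lock_;
  std::map<std::tuple<int, int>, StreamGroup> streams_;
};

int main() {
  // Two device objects for {gpu 0, stream group 0} now see the same streams.
  StreamGroup* a = StreamGroupFactory::Global().GetOrCreate(0, 0);
  StreamGroup* b = StreamGroupFactory::Global().GetOrCreate(0, 0);
  assert(a == b && a->compute == b->compute);

  // A different stream group id on the same GPU still gets its own streams.
  StreamGroup* c = StreamGroupFactory::Global().GetOrCreate(0, 1);
  assert(c != a && c->compute != a->compute);
  return 0;
}
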
98 changes: 66 additions & 32 deletions tensorflow/core/common_runtime/gpu/gpu_device.cc
@@ -24,6 +24,8 @@ limitations under the License.
#include <stdlib.h>
#include <string.h>
#include <algorithm>
#include <map>
#include <tuple>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -174,6 +176,63 @@ class EigenCudaStreamDevice : public ::Eigen::StreamInterface {
TF_DISALLOW_COPY_AND_ASSIGN(EigenCudaStreamDevice);
};

// This factory helps to ensure that different GPU device objects that refer to
// the same physical device and stream group id use the same stream group
// object (and therefore the same CUDA streams). This is necessary since there
// is a single memory allocator per device (see ProcessState::GetGPUAllocator)
// and allocators must not be shared across streams.
class BaseGPUDevice::StreamGroupFactory {
 public:
  // Returns the unique stream group for use with the stream defined by
  // {gpu_id, stream_group_within_gpu}, creating it if it does not yet exist.
  // This function is thread safe.
  BaseGPUDevice::StreamGroup* GetOrCreate(int gpu_id,
                                          int stream_group_within_gpu,
                                          gpu::StreamExecutor* executor) {
    mutex_lock guard(lock_);
    StreamGroup* group = &streams_[key_type(gpu_id, stream_group_within_gpu)];
    if (!group->compute) {
      group->compute = new gpu::Stream(executor);
      group->compute->Init();
      VLOG(2) << "Created stream[" << stream_group_within_gpu
              << "] = " << group->compute;

      group->host_to_device = new gpu::Stream(executor);
      group->host_to_device->Init();
      VLOG(2) << "Created host_to_device_stream[" << stream_group_within_gpu
              << "] = " << group->host_to_device;

      group->device_to_host = new gpu::Stream(executor);
      group->device_to_host->Init();
      VLOG(2) << "Created device_to_host_stream[" << stream_group_within_gpu
              << "] = " << group->device_to_host;

      group->device_to_device = new gpu::Stream(executor);
      group->device_to_device->Init();
      VLOG(2) << "Created device_to_device_stream[" << stream_group_within_gpu
<< "] = " << group->device_to_host;
    }
    return group;
  }

  // Returns a reference to the StreamGroupFactory singleton. Note that this is
  // never destroyed, so the objects it owns are never deleted.
  static StreamGroupFactory& Global() {
    static StreamGroupFactory* instance = new StreamGroupFactory();
    return *instance;
  }

 private:
  mutex lock_;
  using key_type = std::tuple<int, int>;
  std::map<key_type, StreamGroup> streams_;

  // StreamGroupFactory cannot be created directly; call
  // StreamGroupFactory::Global() to get the global instance.
  StreamGroupFactory() = default;
  TF_DISALLOW_COPY_AND_ASSIGN(StreamGroupFactory);
};

BaseGPUDevice::BaseGPUDevice(const SessionOptions& options, const string& name,
Bytes memory_limit, const DeviceLocality& locality,
int gpu_id, const string& physical_device_desc,
@@ -193,12 +252,6 @@ BaseGPUDevice::BaseGPUDevice(const SessionOptions& options, const string& name,
BaseGPUDevice::~BaseGPUDevice() {
delete gpu_device_info_;
for (auto ctx : device_contexts_) ctx->Unref();
for (auto& stream_group : streams_) {
delete stream_group.compute;
delete stream_group.host_to_device;
delete stream_group.device_to_host;
delete stream_group.device_to_device;
}
}

Status BaseGPUDevice::Init(const SessionOptions& options) {
Expand All @@ -217,27 +270,8 @@ Status BaseGPUDevice::Init(const SessionOptions& options) {

// Create the specified number of GPU streams
for (int i = 0; i < max_streams_; i++) {
auto stream = new gpu::Stream(executor_);
stream->Init();
VLOG(2) << "Created stream[" << i << "] = " << stream;

auto host_to_device_stream = new gpu::Stream(executor_);
host_to_device_stream->Init();
VLOG(2) << "Created host_to_device_stream[" << i
<< "] = " << host_to_device_stream;

auto device_to_host_stream = new gpu::Stream(executor_);
device_to_host_stream->Init();
VLOG(2) << "Created device_to_host_stream[" << i
<< "] = " << device_to_host_stream;

auto device_to_device_stream = new gpu::Stream(executor_);
device_to_device_stream->Init();
VLOG(2) << "Created device_to_device_stream[" << i
<< "] = " << device_to_device_stream;

streams_.push_back({stream, host_to_device_stream, device_to_host_stream,
device_to_device_stream});
streams_.push_back(
StreamGroupFactory::Global().GetOrCreate(gpu_id_, i, executor_));

size_t scratch_buffer_size = Eigen::kCudaScratchSize + sizeof(unsigned int);
void* scratch_buffer = gpu_allocator_->AllocateRaw(
@@ -259,12 +293,12 @@ Status BaseGPUDevice::Init(const SessionOptions& options) {
"Failed to memcopy into scratch buffer for device ", gpu_id_);
}

device_contexts_.push_back(
new GPUDeviceContext(i, stream, host_to_device_stream,
device_to_host_stream, device_to_device_stream));
device_contexts_.push_back(new GPUDeviceContext(
i, streams_.back()->compute, streams_.back()->host_to_device,
streams_.back()->device_to_host, streams_.back()->device_to_device));
}
gpu_device_info_ = new GpuDeviceInfo;
gpu_device_info_->stream = streams_[0].compute;
gpu_device_info_->stream = streams_[0]->compute;
gpu_device_info_->default_context = device_contexts_[0];
gpu_device_info_->event_mgr = em_.get();
gpu_device_info_->gpu_id = gpu_id_;
@@ -511,7 +545,7 @@ void BaseGPUDevice::ReinitializeDevice(OpKernelContext* context,
static_cast<ConcretePerOpGpuDevice*>(device);
DCHECK(concrete_device);
const cudaStream_t* cuda_stream = reinterpret_cast<const cudaStream_t*>(
streams_[stream_id].compute->implementation()->CudaStreamMemberHack());
streams_[stream_id]->compute->implementation()->CudaStreamMemberHack());
concrete_device->Reinitialize(context, cuda_stream, gpu_id_, allocator,
scratch_[stream_id]);
}
16 changes: 11 additions & 5 deletions tensorflow/core/common_runtime/gpu/gpu_device.h
@@ -20,7 +20,11 @@ limitations under the License.
#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_GPU_DEVICE_H_
#define TENSORFLOW_COMMON_RUNTIME_GPU_GPU_DEVICE_H_

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
@@ -96,12 +100,14 @@ class BaseGPUDevice : public LocalDevice {

private:
struct StreamGroup {
gpu::Stream* compute;
gpu::Stream* host_to_device;
gpu::Stream* device_to_host;
gpu::Stream* device_to_device;
gpu::Stream* compute = nullptr;
gpu::Stream* host_to_device = nullptr;
gpu::Stream* device_to_host = nullptr;
gpu::Stream* device_to_device = nullptr;
};
gtl::InlinedVector<StreamGroup, 4> streams_;
class StreamGroupFactory;

gtl::InlinedVector<StreamGroup*, 4> streams_;
gtl::InlinedVector<char*, 4> scratch_;
std::vector<GPUDeviceContext*> device_contexts_;
GpuDeviceInfo* gpu_device_info_ = nullptr;
57 changes: 54 additions & 3 deletions tensorflow/python/kernel_tests/basic_gpu_test.py
@@ -18,15 +18,21 @@
from __future__ import division
from __future__ import print_function

import math
import itertools
import threading

import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.gen_array_ops import _broadcast_gradient_args
from tensorflow.python.platform import test

@@ -219,5 +225,50 @@ def testGradient(self):
    self._compareGpu(x, y + 0.1, np.floor_divide, math_ops.floordiv)


if __name__ == "__main__":
class GpuMultiSessionMemoryTest(test_util.TensorFlowTestCase):
  """Tests concurrent sessions executing on the same GPU."""

  def _run_session(self, results):
    n_iterations = 500
    with self.test_session(use_gpu=True) as s:
      data = variables.Variable(1.0)
      with ops.device('/gpu:0'):
        random_seed.set_random_seed(1)
        matrix1 = variables.Variable(
            random_ops.truncated_normal([1024, 1]), name='matrix1')
        matrix2 = variables.Variable(
            random_ops.truncated_normal([1, 1024]), name='matrix2')
        x1 = math_ops.multiply(data, matrix1, name='x1')
        x3 = math_ops.matmul(x1, math_ops.matmul(matrix2, matrix1))
        x4 = math_ops.matmul(array_ops.transpose(x3), x3, name='x4')
        s.run(variables.global_variables_initializer())

        for _ in xrange(n_iterations):
          value = s.run(x4)
          results.append(value)
          if value != results[0]:
            break

  def testConcurrentSessions(self):
    if not test.is_gpu_available():
      return

    n_threads = 4
    results = [[] for _ in xrange(n_threads)]
    threads = [
        threading.Thread(target=self._run_session, args=(results[i],))
        for i in xrange(n_threads)
    ]
    for thread in threads:
      thread.start()
    for thread in threads:
      thread.join()

    flat_results = [x for x in itertools.chain(*results)]
    self.assertNotEqual(0, len(flat_results))
    for result in flat_results:
      self.assertEqual(result, flat_results[0])


if __name__ == '__main__':
  test.main()
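
The same check, reduced to a standalone sketch outside TensorFlow (plain C++ threads and a deterministic CPU loop stand in for sessions and the GPU graph; RunOnce is a made-up placeholder, while the 4 threads and 500 iterations mirror the test): each thread repeatedly evaluates the same computation and records what it sees, and the run passes only if every recorded value is identical.

#include <cassert>
#include <thread>
#include <vector>

// Placeholder for the deterministic graph the test evaluates on /gpu:0.
double RunOnce() {
  double acc = 0.0;
  for (int i = 1; i <= 1024; ++i) acc += 1.0 / (static_cast<double>(i) * i);
  return acc;
}

// One "session" per thread: record every result, bail out early on a mismatch.
void RunSession(std::vector<double>* results) {
  const int kIterations = 500;
  for (int i = 0; i < kIterations; ++i) {
    double value = RunOnce();
    results->push_back(value);
    if (value != results->front()) break;
  }
}

int main() {
  const int kThreads = 4;
  std::vector<std::vector<double>> results(kThreads);  // one list per thread
  std::vector<std::thread> threads;
  for (int i = 0; i < kThreads; ++i) {
    threads.emplace_back(RunSession, &results[i]);
  }
  for (std::thread& t : threads) t.join();

  // Every value seen by every thread must equal the very first one.
  const double expected = results[0][0];
  for (const std::vector<double>& per_thread : results) {
    assert(!per_thread.empty());
    for (double value : per_thread) assert(value == expected);
  }
  return 0;
}
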
