Fixes for running on the correct stream when changing devices.
Thank you, Vladimir.

Signed-off-by: Daniel Galvez <dgalvez@computelab-frontend-3.nvidia.com>
Daniel Galvez committed Feb 22, 2024
1 parent 2abe43e commit 7bbbe3d
Showing 2 changed files with 87 additions and 17 deletions.
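
The change applies one pattern throughout both files: every stream that gets created, queried, or synchronized is pinned to an explicit device instead of relying on whichever GPU happens to be current on the host thread. A minimal illustrative sketch of that pattern (not part of the diff; the function and variable names are made up):

import torch

def scale_on(x: torch.Tensor, device: torch.device) -> torch.Tensor:
    # Create the side stream on the target device, not on the current device.
    stream = torch.cuda.Stream(device)
    with torch.cuda.stream(stream):
        y = x.to(device, non_blocking=True) * 2  # enqueued on `stream`
        # Ask for this device's stream explicitly; a bare
        # torch.cuda.current_stream() answers for the *current* device, which
        # is a different stream whenever `device` is not current.
        torch.cuda.current_stream(device=device).synchronize()
    return y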
@@ -106,7 +106,7 @@ def create_inner_while_loop_kernel():


@contextlib.contextmanager
def with_conditional_node(while_loop_kernel, while_loop_args, while_loop_conditional_handle):
def with_conditional_node(while_loop_kernel, while_loop_args, while_loop_conditional_handle, device):
"""
Even though we add a conditional node only once, we need to
capture the kernel that calls cudaGraphSetConditional() both
@@ -115,15 +115,15 @@ def with_conditional_node(while_loop_kernel, while_loop_args, while_loop_conditi
to decide both whether to enter the loop, and also whether to
execute the next iteration of the loop).
"""
capture_status, _, graph, _, _ = cu_call(cudart.cudaStreamGetCaptureInfo(torch.cuda.current_stream().cuda_stream))
capture_status, _, graph, _, _ = cu_call(cudart.cudaStreamGetCaptureInfo(torch.cuda.current_stream(device=device).cuda_stream))
assert capture_status == cudart.cudaStreamCaptureStatus.cudaStreamCaptureStatusActive

cuda.cuLaunchKernel(
while_loop_kernel, 1, 1, 1, 1, 1, 1, 0, torch.cuda.current_stream().cuda_stream, while_loop_args.ctypes.data, 0
while_loop_kernel, 1, 1, 1, 1, 1, 1, 0, torch.cuda.current_stream(device=device).cuda_stream, while_loop_args.ctypes.data, 0
)

capture_status, _, graph, dependencies, _ = cu_call(
cudart.cudaStreamGetCaptureInfo(torch.cuda.current_stream().cuda_stream)
cudart.cudaStreamGetCaptureInfo(torch.cuda.current_stream(device=device).cuda_stream)
)
assert capture_status == cudart.cudaStreamCaptureStatus.cudaStreamCaptureStatusActive

@@ -145,14 +145,14 @@ def with_conditional_node(while_loop_kernel, while_loop_args, while_loop_conditi

cu_call(
cudart.cudaStreamUpdateCaptureDependencies(
torch.cuda.current_stream().cuda_stream,
torch.cuda.current_stream(device=device).cuda_stream,
[node],
1,
cudart.cudaStreamUpdateCaptureDependenciesFlags.cudaStreamSetCaptureDependencies,
)
)
body_stream = torch.cuda.Stream()
previous_stream = torch.cuda.current_stream()
body_stream = torch.cuda.Stream(device)
previous_stream = torch.cuda.current_stream(device=device)
cu_call(
cudart.cudaStreamBeginCaptureToGraph(
body_stream.cuda_stream,
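
The new device parameter matters because both the capture-status query and the kernel launch above have to target the stream that is actually being captured, and stream lookups in PyTorch are per-device. A hypothetical two-GPU illustration of that (using only public PyTorch APIs, not the repository's helpers):

import torch

other = torch.device("cuda:0")
device = torch.device("cuda:1")   # the decoder's device in this scenario
graph = torch.cuda.CUDAGraph()
stream = torch.cuda.Stream(device)

# torch.cuda.graph needs the capture device to be current, hence the device
# context; the point is that stream queries are answered per device.
with torch.cuda.device(device), torch.cuda.stream(stream), torch.cuda.graph(graph, stream=stream):
    out = torch.zeros(4, device=device) + 1                    # captured work
    assert torch.cuda.is_current_stream_capturing()            # cuda:1's current stream is capturing
    assert torch.cuda.current_stream(device=device) == stream  # the capturing stream, looked up by device
    # Asking about another device returns a stream that is not capturing at
    # all, which is what a device-less current_stream() call did whenever the
    # host thread's current device was not the model's device.
    assert torch.cuda.current_stream(device=other) != stream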
@@ -198,6 +198,8 @@ def __init__(self, max_symbols: int, caller):
self.encoder_output = None
self.encoder_output_length = None
self.f = None
# We also lazily initialize a variable holding the current device
self.device = None

# Reasonable default maximum time. 375 frames * (80ms / frame) = 30 seconds
# 80ms is the frame size of recent fastconformer models
@@ -214,6 +216,7 @@ def __init__(self, max_symbols: int, caller):
self.caller = caller

def _reinitialize(self, max_time, batch_size, encoder_output, encoder_output_length):
torch.cuda.nvtx.range_push("Init")
if self.first_call:
# We need to call the original _greedy_decode_blank_as_pad
# implementation at least once beforehand in order to make
@@ -223,9 +226,12 @@ def _reinitialize(self, max_time, batch_size, encoder_output, encoder_output_len
# initializes things like a cudnnHandle_t via
# cudnnCreate(), which can involve synchronizing with the
# host. Such actions are not stream capturable to a graph.
self.caller._greedy_decode_blank_as_pad_loop_frames(
encoder_output, encoder_output_length, encoder_output.device
)
with torch.cuda.stream(torch.cuda.Stream(self.device)):
self.caller._greedy_decode_blank_as_pad_loop_frames(
encoder_output, encoder_output_length, encoder_output.device
)

self.device = encoder_output.device

self.symbols_added_t = torch.tensor(0, dtype=torch.int64, device=encoder_output.device)
self.max_symbols_t = torch.tensor(self.max_symbols, dtype=torch.int64, device=encoder_output.device)
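
The comment above is the constraint behind self.first_call and the new self.device bookkeeping: first-use initialization such as cuDNN handle creation may synchronize with the host, so it has to run once, on the right device, before anything is captured. A hedged sketch of that warm-up-then-capture pattern with made-up names (not the repository's API):

import torch

def warm_up_then_capture(fn, example_input, device):
    # `example_input` is assumed to already live on `device`.
    with torch.cuda.device(device):
        # 1) Run the eager implementation once on a side stream of the target
        #    device so lazy, possibly host-synchronizing initialization
        #    (cuDNN handles and the like) happens outside any graph capture.
        warmup_stream = torch.cuda.Stream(device)
        with torch.cuda.stream(warmup_stream):
            fn(example_input)
        torch.cuda.synchronize(device)

        # 2) Capture the steady-state call on a fresh stream of the same
        #    device; the default stream cannot be captured into a graph.
        graph = torch.cuda.CUDAGraph()
        capture_stream = torch.cuda.Stream(device)
        static_input = example_input.clone()
        with torch.cuda.stream(capture_stream), torch.cuda.graph(graph, stream=capture_stream):
            static_output = fn(static_input)
    return graph, static_input, static_output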
@@ -261,11 +267,18 @@ def _reinitialize(self, max_time, batch_size, encoder_output, encoder_output_len
(self.batch_size, self.max_time, self.max_symbols), dtype=torch.int64, device="cpu", pin_memory=True
)

self.graph = None

self.graph = torch.cuda.CUDAGraph()

# Always create a new stream, because the per-thread default stream disallows stream capture to a graph.
with torch.cuda.stream(torch.cuda.Stream()), torch.inference_mode(), torch.cuda.graph(self.graph):
torch.cuda.nvtx.range_pop()

torch.cuda.nvtx.range_push("build graph")

# Always create a new stream, because the per-thread default stream disallows stream capture to a graph.
stream_for_graph = torch.cuda.Stream(self.device)
with torch.cuda.stream(stream_for_graph), torch.inference_mode(), torch.cuda.graph(self.graph, stream=stream_for_graph):
# This is failing...
self.f = torch.zeros(
(self.batch_size, 1, self.encoder_output.shape[-1]),
dtype=encoder_output.dtype,
@@ -295,7 +308,7 @@ def _reinitialize(self, max_time, batch_size, encoder_output, encoder_output_len
self.max_out_len_t = self.encoder_output_length.max()

capture_status, _, graph, _, _ = cu_call(
cudart.cudaStreamGetCaptureInfo(torch.cuda.current_stream().cuda_stream)
cudart.cudaStreamGetCaptureInfo(torch.cuda.current_stream(device=self.device).cuda_stream)
)
assert capture_status == cudart.cudaStreamCaptureStatus.cudaStreamCaptureStatusActive

@@ -308,7 +321,7 @@ def _reinitialize(self, max_time, batch_size, encoder_output, encoder_output_len
dtype=np.uint64,
)

with with_conditional_node(for_loop_kernel, for_loop_args, for_loop_conditional_handle):
with with_conditional_node(for_loop_kernel, for_loop_args, for_loop_conditional_handle, self.device):
torch.index_select(self.encoder_output, 1, self.time_idx_t.unsqueeze(0), out=self.f)

self.not_all_blank_t.fill_(True)
@@ -330,7 +343,7 @@ def _reinitialize(self, max_time, batch_size, encoder_output, encoder_output_len
],
dtype=np.uint64,
)
with with_conditional_node(while_loop_kernel, while_loop_args, while_loop_conditional_handle):
with with_conditional_node(while_loop_kernel, while_loop_args, while_loop_conditional_handle, self.device):
g, hidden_prime = self.caller._pred_step(
self.last_label.unsqueeze(1), hidden, batch_size=self.batch_size
)
@@ -373,6 +386,8 @@ def _reinitialize(self, max_time, batch_size, encoder_output, encoder_output_len
self.last_label.fill_(self.caller._SOS)
self.time_idx_t.fill_(0)

torch.cuda.nvtx.range_pop()

def __call__(
self,
x: torch.Tensor,
@@ -400,13 +415,28 @@ def __call__(
if torch.is_autocast_enabled():
x = x.to(torch.get_autocast_gpu_dtype())

if max_time > self.max_time or batch_size > self.batch_size:
if (max_time > self.max_time or
batch_size > self.batch_size or
self.device != x.device):
# In the first two cases, we need to recreate the cuda
# graph to handle larger tensor sizes. In the third case,
# we need to recreate the graph, as well as all tensors,
# because the computation is now happening on a different
# GPU. Therefore, in the third case, we unconditionally
# set self.first_call to True to make sure that all
# possibly blocking initializers are initialized properly
# again on the new device.
print("GALVEZ: reinit!")
if self.device != x.device:
self.first_call = True
self._reinitialize(max_time, batch_size, x, out_len)

torch.cuda.nvtx.range_push("Graph")
self.encoder_output[: x.shape[0], : x.shape[1], ...].copy_(x)
self.encoder_output_length[: out_len.shape[0]].copy_(out_len)
self.graph.replay()
torch.cuda.current_stream().synchronize()
torch.cuda.current_stream(device=self.device).synchronize()
torch.cuda.nvtx.range_pop()

self.scores_cpu[self.labels_cpu == self.caller._blank_index] = 0.0
total_scores = self.scores_cpu.sum(dtype=torch.float32, axis=(1, 2))
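
The replay path in __call__ follows the same discipline: the results are read from pinned CPU tensors right after replay, so the host must wait on the stream of the device the graph lives on, not on whichever device is current. A compact sketch of that replay-then-read pattern with illustrative names (not the exact repository code):

import torch

def replay_and_read(graph, static_in, new_in, static_out_cpu, device):
    # Copy fresh inputs into the buffers the graph was captured with.
    static_in[: new_in.shape[0]].copy_(new_in)
    graph.replay()  # replay is issued on the graph's own device
    # Wait on that device's stream before touching the pinned host buffers; a
    # bare current_stream() could synchronize a different GPU's stream after
    # the model has been moved with .to(...).
    torch.cuda.current_stream(device=device).synchronize()
    return static_out_cpu.clone()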
@@ -76,3 +76,43 @@ def test_cuda_graph_rnnt_greedy_decoder(model_name, batch_size, enable_bfloat16)
print("erroneous samples:")
print("Original transcript:", actual)
print("New transcript:", fast)

def test_change_devices():
if torch.cuda.device_count() < 2:
pytest.skip("Test requires at least 2 GPUs")

first_device = torch.device("cuda:0")
second_device = torch.device("cuda:1")

model_name = "stt_en_fastconformer_transducer_xlarge"
batch_size = 8

conf = ASRModel.from_pretrained(model_name, return_config=True)
with open_dict(conf):
conf["decoding"]["greedy"]["max_symbols"] = 5
conf["decoding"]["greedy"]["loop_labels"] = False
conf["decoding"]["greedy"]["use_cuda_graph_decoder"] = True

nemo_model = ASRModel.from_pretrained(model_name, map_location=second_device)
nemo_model.change_decoding_strategy(conf["decoding"])

# Test that the model can run successfully when it is first
# initialized on second_device and then transferred to
# first_device
nemo_model.to(first_device)
audio_filepaths = glob.glob("tests/.data/asr/test/an4/wav/*.wav")
with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=True):
first_device_transcripts, _ = nemo_model.transcribe(audio_filepaths, batch_size=batch_size, num_workers=None)

# Test that the model can run successfully back on second_device
# after having been first run on first_device. Because the
# decoder's data structures are lazily initialized, this activates
# slightly different code than the first case (where the decoder
# has not run at all), so we want to exercise both cases.
nemo_model.to(second_device)

with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=True):
second_device_transcripts, _ = nemo_model.transcribe(audio_filepaths, batch_size=batch_size, num_workers=None)
# Sanity check: The device we run on should not change execution
# output.
assert first_device_transcripts == second_device_transcripts
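
In condensed form, the property the test asserts, written against a generic graph-backed module (hypothetical helper, not NeMo's API): moving the model between devices must rebuild the captured graph rather than replay one whose buffers live on the old GPU, and the outputs must not depend on which device ran last.

import torch

def check_device_round_trip(model, inputs, dev_a, dev_b):
    model.to(dev_a)
    out_a = model(inputs.to(dev_a))   # forces (re)capture on dev_a
    model.to(dev_b)
    out_b = model(inputs.to(dev_b))   # must trigger a rebuild, not a stale replay
    return torch.allclose(out_a.cpu(), out_b.cpu())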
