13 changes: 13 additions & 0 deletions accelerator/abstract_accelerator.py
@@ -189,6 +189,19 @@ def communication_backend_name(self):
def is_triton_supported(self):
...

# Graph operations
@abc.abstractmethod
def create_graph(self):
...

@abc.abstractmethod
def capture_to_graph(self, graph):
...

@abc.abstractmethod
def replay_graph(self, graph):
...

# Tensor operations
@property
@abc.abstractmethod
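The three new abstract methods give callers a backend-neutral way to capture and replay a graph. A minimal caller-side sketch of the intended pattern, not taken from this PR (warm-up before capture is elided; on backends without graph support the context is a no-op and the function simply runs eagerly):

```python
from deepspeed.accelerator import get_accelerator

def capture_once_and_replay(fn, static_input):
    accel = get_accelerator()
    graph = accel.create_graph()          # CUDAGraph on CUDA, None on fallback backends
    with accel.capture_to_graph(graph):   # records kernels on CUDA, no-op elsewhere
        static_output = fn(static_input)  # the same tensors are reused on every replay
    accel.replay_graph(graph)             # re-runs the captured kernels (no-op on fallbacks)
    return static_output
```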
12 changes: 11 additions & 1 deletion accelerator/cpu_accelerator.py
@@ -195,8 +195,18 @@ def is_fp16_supported(self):
def supported_dtypes(self):
return [torch.float, torch.bfloat16]

# Tensor operations
# Graph operations
def create_graph(self):
return None

def capture_to_graph(self, graph):
from deepspeed.runtime.utils import noop_context
return noop_context()

def replay_graph(self, graph):
return

# Tensor operations
@property
def BFloat16Tensor(self):
return torch.BFloat16Tensor
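The CPU backend has no graph support, so `capture_to_graph` hands back `noop_context` from `deepspeed.runtime.utils`. That helper is not part of this diff; a functionally equivalent context manager would look like this (sketch, not the actual implementation):

```python
from contextlib import contextmanager

@contextmanager
def noop_context():
    # Entering the context does nothing, so code "captured" under it
    # simply executes eagerly on backends without graph support.
    yield
```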
11 changes: 11 additions & 0 deletions accelerator/cuda_accelerator.py
@@ -180,6 +180,17 @@ def is_triton_supported(self):
else:
return False

# Graph operations
def create_graph(self):
return torch.cuda.CUDAGraph()

def capture_to_graph(self, graph):
return torch.cuda.graph(graph)

def replay_graph(self, graph):
graph.replay()
return

# Tensor operations

@property
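On CUDA these three methods map directly onto `torch.cuda.CUDAGraph`, `torch.cuda.graph`, and `CUDAGraph.replay`. A self-contained sketch of the full cycle through the accelerator interface, assuming a CUDA device; the model and shapes are illustrative, and the side-stream warm-up mirrors the pattern used by the engine code further down:

```python
import torch
from deepspeed.accelerator import get_accelerator

accel = get_accelerator()
device = accel.device_name()                       # 'cuda' when this backend is active
model = torch.nn.Linear(8, 8).to(device)
static_in = torch.randn(4, 8, device=device)

# warm up on a side stream so lazy allocations happen before capture
side = accel.Stream()
side.wait_stream(accel.current_stream())
with accel.stream(side):
    static_out = model(static_in)
accel.current_stream().wait_stream(side)

graph = accel.create_graph()
with accel.capture_to_graph(graph):                # kernels are recorded, not executed
    static_out = model(static_in)

static_in.copy_(torch.randn(4, 8, device=device))  # new data into the captured addresses
accel.replay_graph(graph)                           # static_out now holds the new result
```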
11 changes: 11 additions & 0 deletions accelerator/mps_accelerator.py
@@ -163,6 +163,17 @@ def communication_backend_name(self):
def is_triton_supported(self):
return False

# Graph operations
def create_graph(self):
return None

def capture_to_graph(self, graph):
from deepspeed.runtime.utils import noop_context
return noop_context()

def replay_graph(self, graph):
return

# Tensor operations
@property
def BFloat16Tensor(self):
11 changes: 11 additions & 0 deletions accelerator/npu_accelerator.py
@@ -161,6 +161,17 @@ def communication_backend_name(self):
def is_triton_supported(self):
return False

# Graph operations
def create_graph(self):
return None

def capture_to_graph(self, graph):
from deepspeed.runtime.utils import noop_context
return noop_context()

def replay_graph(self, graph):
return

# Tensor operations

@property
6 changes: 3 additions & 3 deletions deepspeed/inference/engine.py
@@ -524,11 +524,11 @@ def _create_cuda_graph(self, *inputs, **kwargs):
get_accelerator().current_stream().wait_stream(cuda_stream)

# create cuda_graph and assign static_inputs and static_outputs
self._cuda_graphs = torch.cuda.CUDAGraph()
self._cuda_graphs = get_accelerator().create_graph()
self.static_inputs = inputs
self.static_kwargs = kwargs

with torch.cuda.graph(self._cuda_graphs):
with get_accelerator().capture_to_graph(self._cuda_graphs):
self.static_output = self.module(*self.static_inputs, **self.static_kwargs)

self.cuda_graph_created = True
@@ -540,7 +540,7 @@ def _graph_replay(self, *inputs, **kwargs):
for k in kwargs:
if torch.is_tensor(kwargs[k]):
self.static_kwargs[k].copy_(kwargs[k])
self._cuda_graphs.replay()
get_accelerator().replay_graph(self._cuda_graphs)
return self.static_output

def model_times(self):
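Both methods now go through the accelerator interface, but the contract is unchanged: a replayed graph always reads and writes the exact tensors that were live during capture, so `_graph_replay` must copy fresh inputs into those static tensors before replaying. A condensed, illustrative restatement of that contract (names are not the engine's API):

```python
def run_with_graph(accel, graph, static_inputs, static_output, new_inputs):
    for static, new in zip(static_inputs, new_inputs):
        static.copy_(new)             # in-place copy keeps the captured addresses valid
    accel.replay_graph(graph)         # re-executes the recorded kernels
    return static_output              # results appear in the captured output tensor
```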
7 changes: 4 additions & 3 deletions deepspeed/model_implementations/diffusers/unet.py
@@ -4,6 +4,7 @@
# DeepSpeed Team

import torch
from deepspeed.accelerator import get_accelerator
from ..features.cuda_graph import CUDAGraph


@@ -29,7 +30,7 @@ def _graph_replay(self, *inputs, **kwargs):
for k in kwargs:
if torch.is_tensor(kwargs[k]):
self.static_kwargs[k].copy_(kwargs[k])
self._cuda_graphs.replay()
get_accelerator().replay_graph(self._cuda_graphs)
return self.static_output

def forward(self, *inputs, **kwargs):
@@ -53,11 +54,11 @@ def _create_cuda_graph(self, *inputs, **kwargs):
torch.cuda.current_stream().wait_stream(cuda_stream)

# create cuda_graph and assign static_inputs and static_outputs
self._cuda_graphs = torch.cuda.CUDAGraph()
self._cuda_graphs = get_accelerator().create_graph()
self.static_inputs = inputs
self.static_kwargs = kwargs

with torch.cuda.graph(self._cuda_graphs):
with get_accelerator().capture_to_graph(self._cuda_graphs):
self.static_output = self._forward(*self.static_inputs, **self.static_kwargs)

self.cuda_graph_created = True
19 changes: 10 additions & 9 deletions deepspeed/model_implementations/diffusers/vae.py
@@ -4,6 +4,7 @@
# DeepSpeed Team

import torch
from deepspeed.accelerator import get_accelerator
from ..features.cuda_graph import CUDAGraph


@@ -27,7 +28,7 @@ def _graph_replay_decoder(self, *inputs, **kwargs):
for k in kwargs:
if torch.is_tensor(kwargs[k]):
self.static_decoder_kwargs[k].copy_(kwargs[k])
self._decoder_cuda_graph.replay()
get_accelerator().replay_graph(self._decoder_cuda_graph)
return self.static_decoder_output

def _decode(self, x, return_dict=True):
@@ -43,11 +44,11 @@ def _create_cuda_graph_decoder(self, *inputs, **kwargs):
torch.cuda.current_stream().wait_stream(cuda_stream)

# create cuda_graph and assign static_inputs and static_outputs
self._decoder_cuda_graph = torch.cuda.CUDAGraph()
self._decoder_cuda_graph = get_accelerator().create_graph()
self.static_decoder_inputs = inputs
self.static_decoder_kwargs = kwargs

with torch.cuda.graph(self._decoder_cuda_graph):
with get_accelerator().capture_to_graph(self._decoder_cuda_graph):
self.static_decoder_output = self._decode(*self.static_decoder_inputs, **self.static_decoder_kwargs)

self.decoder_cuda_graph_created = True
@@ -70,7 +71,7 @@ def _graph_replay_encoder(self, *inputs, **kwargs):
for k in kwargs:
if torch.is_tensor(kwargs[k]):
self.static_encoder_kwargs[k].copy_(kwargs[k])
self._encoder_cuda_graph.replay()
get_accelerator().replay_graph(self._encoder_cuda_graph)
return self.static_encoder_output

def _encode(self, x, return_dict=True):
@@ -86,11 +87,11 @@ def _create_cuda_graph_encoder(self, *inputs, **kwargs):
torch.cuda.current_stream().wait_stream(cuda_stream)

# create cuda_graph and assign static_inputs and static_outputs
self._encoder_cuda_graph = torch.cuda.CUDAGraph()
self._encoder_cuda_graph = get_accelerator().create_graph()
self.static_encoder_inputs = inputs
self.static_encoder_kwargs = kwargs

with torch.cuda.graph(self._encoder_cuda_graph):
with get_accelerator().capture_to_graph(self._encoder_cuda_graph):
self.static_encoder_output = self._encode(*self.static_encoder_inputs, **self.static_encoder_kwargs)

self.encoder_cuda_graph_created = True
@@ -113,7 +114,7 @@ def _graph_replay(self, *inputs, **kwargs):
for k in kwargs:
if torch.is_tensor(kwargs[k]):
self.static_kwargs[k].copy_(kwargs[k])
self._all_cuda_graph.replay()
get_accelerator().replay_graph(self._all_cuda_graph)
return self.static_output

def forward(self, *inputs, **kwargs):
@@ -137,11 +138,11 @@ def _create_cuda_graph(self, *inputs, **kwargs):
torch.cuda.current_stream().wait_stream(cuda_stream)

# create cuda_graph and assign static_inputs and static_outputs
self._all_cuda_graph = torch.cuda.CUDAGraph()
self._all_cuda_graph = get_accelerator().create_graph()
self.static_inputs = inputs
self.static_kwargs = kwargs

with torch.cuda.graph(self._all_cuda_graph):
with get_accelerator().capture_to_graph(self._all_cuda_graph):
self.static_output = self._forward(*self.static_inputs, **self.static_kwargs)

self.all_cuda_graph_created = True
6 changes: 3 additions & 3 deletions deepspeed/model_implementations/transformers/clip_encoder.py
@@ -38,7 +38,7 @@ def _graph_replay(self, *inputs, **kwargs):
for k in kwargs:
if torch.is_tensor(kwargs[k]):
self.static_kwargs[self.iter][k].copy_(kwargs[k])
self._cuda_graphs[self.iter].replay()
get_accelerator().replay_graph(self._cuda_graphs[self.iter])
return self.static_output[self.iter]

def forward(self, *inputs, **kwargs):
@@ -63,11 +63,11 @@ def _create_cuda_graph(self, *inputs, **kwargs):
torch.cuda.current_stream().wait_stream(cuda_stream)

# create cuda_graph and assign static_inputs and static_outputs
self._cuda_graphs[self.iter] = torch.cuda.CUDAGraph()
self._cuda_graphs[self.iter] = get_accelerator().create_graph()
self.static_inputs[self.iter] = inputs
self.static_kwargs[self.iter] = kwargs

with torch.cuda.graph(self._cuda_graphs[self.iter]):
with get_accelerator().capture_to_graph(self._cuda_graphs[self.iter]):
self.static_output[self.iter] = self._forward(*self.static_inputs[self.iter],
**self.static_kwargs[self.iter])

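Unlike the single-graph modules above, the CLIP encoder keeps one graph per iteration index (`self.iter`), so several input signatures can each be captured once and replayed independently. A hedged sketch of that bookkeeping, with illustrative names and warm-up omitted:

```python
from deepspeed.accelerator import get_accelerator

class PerIndexGraphs:
    def __init__(self):
        self.graphs, self.static_in, self.static_out = {}, {}, {}

    def run(self, idx, fn, x):
        accel = get_accelerator()
        if idx not in self.graphs:                         # first call for this index: capture
            self.graphs[idx] = accel.create_graph()
            self.static_in[idx] = x
            with accel.capture_to_graph(self.graphs[idx]):
                self.static_out[idx] = fn(x)
        else:                                              # later calls: feed the static tensor
            self.static_in[idx].copy_(x)
        accel.replay_graph(self.graphs[idx])               # replay, including right after capture
        return self.static_out[idx]
```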
9 changes: 4 additions & 5 deletions deepspeed/runtime/utils.py
@@ -60,14 +60,13 @@ def graph_process(replay_first_step, func, *args, **kwargs):
with get_accelerator().stream(cuda_stream):
func(*args, **kwargs)
get_accelerator().current_stream().wait_stream(cuda_stream)
# TODO: Apply get_accelerator interface for torch.cuda.CUDAGraph and torch.cuda.graph #ignore-cuda
graph_cache[func.__name__] = torch.cuda.CUDAGraph() #ignore-cuda
with torch.cuda.graph(graph_cache[func.__name__]): #ignore-cuda
graph_cache[func.__name__] = get_accelerator().create_graph()
with get_accelerator().capture_to_graph(graph_cache[func.__name__]):
func(*args, **kwargs)
if replay_first_step:
graph_cache[func.__name__].replay()
get_accelerator().replay_graph(graph_cache[func.__name__])
else:
graph_cache[func.__name__].replay()
get_accelerator().replay_graph(graph_cache[func.__name__])


def noop_decorator(func):
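`graph_process` caches one graph per function name: the first call warms the function up on a side stream and captures it, later calls only replay, and `replay_first_step` controls whether the very first call also replays immediately after capture. A hedged usage sketch (the buffer and callee are illustrative; only the `graph_process` signature comes from the hunk above):

```python
import torch
from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.utils import graph_process

buf = torch.ones(1024, device=get_accelerator().device_name())

def scale_(t, factor):
    t.mul_(factor)        # in-place, so the captured buffer address stays valid

for step in range(3):
    # first iteration captures `scale_` into the graph cache keyed by its __name__,
    # the remaining iterations replay the cached graph
    graph_process(True, scale_, buf, 2.0)
```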