[Dynamic Shape] Adding more dynamic shape support (#228)
yaoyaoding authored May 16, 2023
1 parent 8a57169 commit ff6e6dc
Showing 58 changed files with 1,140 additions and 759 deletions.
11 changes: 0 additions & 11 deletions gallery/how-to-guides/visualize-flow-graph.py
@@ -116,17 +116,6 @@ def forward(self, hidden_states: Tensor, attention_mask: Tensor):

print(graph_opt)

# %%
# The dumped netron graphs that can be visualized:
#
# :download:`Download 1_FoldConstantPass.json <../../../../gallery/how-to-guides/outs/1_FoldConstantPass.json>`
#
# :download:`Download 2_PatternTransformPass.json <../../../../gallery/how-to-guides/outs/2_SubgraphRewritePass.json>`
#
# :download:`Download 4_ResolveVariantPass.json <../../../../gallery/how-to-guides/outs/4_ResolveVariantPass.json>`
#
# :download:`Download 5_FuseOperatorPass.json <../../../../gallery/how-to-guides/outs/5_FuseOperatorPass.json>`

# %%
# Summary
# -------
1 change: 1 addition & 0 deletions python/hidet/__init__.py
@@ -36,6 +36,7 @@
from .graph import ops
from .graph import empty, randn, zeros, ones, full, randint, symbol, asarray, from_torch
from .graph import empty_like, randn_like, zeros_like, ones_like, symbol_like, full_like
from .graph import symbolic_size
from .graph import trace_from, load_graph, save_graph
from .graph import jit
from .graph import from_dlpack
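For reference, a minimal sketch of how the newly exported `symbolic_size` could be used. This assumes `symbolic_size(name)` returns a symbolic dimension (`SizeVar`) that `hidet.symbol` accepts as a shape entry; the exact call signature is not shown on this page.

```python
import hidet

# Hypothetical usage sketch of the new export: symbolic_size is assumed to
# return a SizeVar that can stand in for a dynamic dimension.
batch = hidet.symbolic_size('batch')
x = hidet.symbol([batch, 3, 224, 224], dtype='float32', device='cuda')
print(x.shape)  # first dimension stays symbolic until a concrete tensor binds it
```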
31 changes: 25 additions & 6 deletions python/hidet/driver.py
@@ -28,6 +28,7 @@
from hidet.ir.func import IRModule, Function
from hidet.ir.type import FuncType
from hidet.runtime.module import compiled_task_cache, CompiledFunction
from hidet.runtime.device import Device

logger = logging.Logger(__name__)
logger.setLevel(logging.INFO)
@@ -42,7 +43,7 @@ def build_task(task: Task, target_device='cuda', load=True) -> Optional[Compiled
----------
task: Task
The task to be built.
target_device: str
target_device: str or Device
The target device. Candidates are 'cuda' and 'cpu'.
load: bool
Whether to load the compiled function. If False, the compiled function will not be loaded, and None is returned.
@@ -55,6 +56,9 @@
task_string: str = str(task)
compiled_func: Optional[CompiledFunction] = None

if isinstance(target_device, Device):
target_device = target_device.type

space_level = option.get_option('search_space')
op_cache_dir = os.path.join(option.get_option('cache_dir'), './ops')
use_cache = option.get_option('cache_operator')
@@ -152,20 +156,35 @@ def _lazy_initialize_cuda():
hidet.cuda.compute_capability(i)


def build_task_batch(tasks: List[Task], target_device: str = 'cuda', raise_on_error: bool = True):
def get_objects(obj, predicate, visited: set, path: list):
import gc

for referent in gc.get_referents(obj):
if id(referent) not in visited:
visited.add(id(referent))
if predicate(referent):
for p in path:
print(p)
assert False
path.append(referent)
get_objects(referent, predicate, visited, path)
path.pop()


def build_task_batch(task_device_pairs: List[Tuple[Task, Device]], raise_on_error: bool = True):
dumped_options = option.dump_options()
jobs = [(task, target_device, dumped_options) for task in tasks]
jobs = [(task, device, dumped_options) for task, device in task_device_pairs]
if option.get_option('parallel_build') and len(jobs) > 1:
_lazy_initialize_cuda()
with multiprocessing.Pool() as pool:
status_list = list(pool.map(_build_task_job, jobs))
else:
status_list = list(map(_build_task_job, jobs))
if not all(status_list) and raise_on_error:
if not all(status_list) and option.get_option('parallel_build') and raise_on_error:
msg = ['Failed to build {} tasks:'.format(sum(1 for s in status_list if not s))]
for task, status in zip(tasks, status_list):
for (task, device), status in zip(task_device_pairs, status_list):
if not status:
msg.append(f' {task.signature()}')
msg.append(f' [{device}] {task.signature()}')
msg.append('Please turn off parallel build to see the error message:')
msg.append(' hidet.option.parallel_build(False)')
raise RuntimeError('\n'.join(msg))
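A small sketch of the updated driver entry points: `build_task` now also accepts a `Device` instance, and `build_task_batch` takes `(task, device)` pairs instead of a single shared `target_device` string. The sketch assumes operators expose `.task` and `.device`, as used in the `FlowGraph._build` change further down.

```python
import hidet
from hidet.driver import build_task, build_task_batch

# Trace a tiny graph to obtain some tasks (illustration only).
x = hidet.symbol([1, 16], dtype='float32', device='cuda')
graph = hidet.trace_from(hidet.ops.relu(x), inputs=[x])

# Each task is now paired with the device of its operator.
pairs = [(node.task, node.device) for node in graph.nodes]
build_task_batch(pairs, raise_on_error=True)

# A single task can be built against a Device object as well as a plain string.
first = graph.nodes[0]
compiled = build_task(first.task, target_device=first.device)
```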
11 changes: 6 additions & 5 deletions python/hidet/graph/__init__.py
@@ -11,22 +11,23 @@
# limitations under the License.
from . import tensor
from . import operator
from . import module
from . import modules
from . import nn
from . import ops
from . import ir
from . import frontend

from .tensor import Tensor
from .operator import Operator
from .module import Module
from .operator import Operator, SizeVar
from .ir import FlowGraph
from .transforms import GraphPass, PassContext, GraphPassInstrument
from .ir.flow_graph import GraphForwardContext, GraphForwardInstrument
from .ir.instruments import GraphForwardBenchmarkInstrument, GraphForwardDebugInstrument
from .nn import Module

from .tensor import asarray, randn, empty, zeros, ones, symbol, randint, randn_like, empty_like, zeros_like, ones_like
from .tensor import symbol_like, full, full_like
from .tensor import from_numpy, from_dlpack, from_torch
from .tensor import symbolic_size
from .ir import trace_from, load_graph, save_graph, forward_context
from .transforms import optimize
from .modules import nn
from .jit import jit
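With the `module`/`modules` packages replaced by `nn`, module definitions are now imported from `hidet.graph.nn` (and `Module` is re-exported from `hidet.graph`). A brief sketch of a module written against the relocated package; `nn.Linear` is assumed to be available as it was under the previous layout.

```python
import hidet
from hidet.graph import nn  # previously hidet.graph.modules / hidet.graph.module

class TinyMlp(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 8)  # assumed to exist as in the old modules package

    def forward(self, x):
        return hidet.ops.relu(self.fc(x))
```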
2 changes: 1 addition & 1 deletion python/hidet/graph/frontend/onnx/onnx.py
@@ -26,7 +26,7 @@
import onnx.numpy_helper
import onnx.external_data_helper
import hidet
from hidet.graph.modules import nn
from hidet.graph.nn import nn
from hidet.graph import ops
from hidet.graph.tensor import Tensor, from_numpy, randn
from . import utils
127 changes: 40 additions & 87 deletions python/hidet/graph/ir/flow_graph.py
@@ -20,11 +20,10 @@
import hidet.graph.operator
import hidet.cuda
from hidet import option
from hidet.ir.expr import Var, Constant, convert
from hidet.ir.expr import Var
from hidet.ir.task import Task
from hidet.graph.tensor import Tensor, zeros_like, randn_like
from hidet.graph.operator import Operator
from hidet.utils.doc import Doc, NewLine, Text, doc_join
from hidet.utils.namer import Namer
from hidet.graph.operator import Operator, SizeVar

logger = logging.getLogger(__name__)

@@ -36,10 +35,12 @@ def before_graph(self, graph: FlowGraph, inputs: List[Tensor]) -> None:
def after_graph(self, graph: FlowGraph, inputs: List[Tensor], outputs: List[Tensor]) -> None:
pass

def before_operator(self, op: Operator, inputs: List[Tensor]) -> None:
def before_operator(self, op: Operator, inputs: List[Tensor], shape_map: Dict[SizeVar, int]) -> None:
pass

def after_operator(self, op: Operator, inputs: List[Tensor], outputs: List[Tensor]) -> None:
def after_operator(
self, op: Operator, inputs: List[Tensor], shape_map: Dict[SizeVar, int], outputs: List[Tensor]
) -> None:
pass


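The new `shape_map` argument passes the current bindings of symbolic dimensions into the instrument hooks. An illustrative subclass using the updated signatures; the import paths follow the exports shown on this page and should be treated as an assumption.

```python
from typing import Dict, List

from hidet.graph import Tensor, Operator, SizeVar
from hidet.graph.ir.flow_graph import GraphForwardInstrument

class ShapeLoggingInstrument(GraphForwardInstrument):
    """Print the symbolic-dimension bindings seen before each operator runs."""

    def before_operator(self, op: Operator, inputs: List[Tensor], shape_map: Dict[SizeVar, int]) -> None:
        bindings = ', '.join('{}={}'.format(dim, value) for dim, value in shape_map.items())
        print('{}: {}'.format(op.name, bindings))
```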
@@ -75,27 +76,29 @@ def after_graph(graph: FlowGraph, inputs: List[Tensor], outputs: List[Tensor]) -
instrument.after_graph(graph, inputs, outputs)

@staticmethod
def before_operator(op: Operator, inputs: List[Tensor]) -> None:
def before_operator(op: Operator, inputs: List[Tensor], shape_map: Dict[SizeVar, int]) -> None:
ctx = GraphForwardContext.current()
for instrument in ctx.instruments:
instrument.before_operator(op, inputs)
instrument.before_operator(op, inputs, shape_map)

@staticmethod
def after_operator(op: Operator, inputs: List[Tensor], outputs: List[Tensor]) -> None:
def after_operator(
op: Operator, inputs: List[Tensor], shape_map: Dict[SizeVar, int], outputs: List[Tensor]
) -> None:
ctx = GraphForwardContext.current()
for instrument in ctx.instruments:
instrument.after_operator(op, inputs, outputs)
instrument.after_operator(op, inputs, shape_map, outputs)

def append_instrument(self, instrument: GraphForwardInstrument):
self.instruments.append(instrument)

def debug(self, output_dir='./outs/debug', print_summary: bool = False):
from .flow_graph_impl import GraphForwardDebugInstrument
def debug(self, output_dir='./outs/debug', print_summary: bool = False, dump_outputs: bool = False):
from .instruments import GraphForwardDebugInstrument

self.instruments.append(GraphForwardDebugInstrument(output_dir, print_summary))
self.instruments.append(GraphForwardDebugInstrument(output_dir, print_summary, dump_outputs))

def benchmark(self, output_dir='./outs/benchmark', print_summary: bool = False, warmup=3, number=10, repeat=3):
from .flow_graph_impl import GraphForwardBenchmarkInstrument
from .instruments import GraphForwardBenchmarkInstrument

self.instruments.append(GraphForwardBenchmarkInstrument(output_dir, print_summary, warmup, number, repeat))

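A sketch of attaching the built-in instruments through a forward context, assuming `forward_context()` (exported from `hidet.graph`) yields the active `GraphForwardContext`; `dump_outputs` is the new flag added above, and `graph` / `x` are placeholders for an existing FlowGraph and its input.

```python
import hidet

# Assumed usage pattern: forward_context() yields the current GraphForwardContext.
with hidet.graph.forward_context() as ctx:
    ctx.debug(output_dir='./outs/debug', print_summary=True, dump_outputs=True)
    ctx.benchmark(output_dir='./outs/benchmark', print_summary=True, warmup=3, number=10, repeat=3)
    y = graph(x)  # placeholders: a traced FlowGraph and a concrete input tensor
```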
@@ -132,62 +135,9 @@ def __call__(self, *inputs: Tensor) -> Union[List[Tensor], Tensor]:
return outputs[0] if len(outputs) == 1 else outputs

def __str__(self):
namer = Namer()

def get_tensor_sig(x: Tensor) -> Doc:
return Text(x.dtype.name) + '[' + doc_join([str(v) for v in x.shape], ', ') + ']'

def get_attr_repr(value: Union[float, int, bool, str, list, tuple, FlowGraph]) -> Doc:
if isinstance(value, (float, int, bool)):
return Text(str(value))
elif isinstance(value, str):
return Text('"{}"'.format(value))
elif isinstance(value, list):
return '[' + doc_join([get_attr_repr(v) for v in value], ', ') + ']'
elif isinstance(value, tuple):
return '(' + doc_join([get_attr_repr(v) for v in value], ', ') + ')'
elif isinstance(value, FlowGraph):
return Text('FlowGraph({})'.format(', '.join(u.name for u in value.nodes)))
else:
return Text(str(value))

param_docs = []
for x in self.inputs:
name = namer(x)
param_docs.append(Text(name) + ': ' + get_tensor_sig(x))

# head
head_doc = 'Graph(' + doc_join(param_docs, ', ') + ')'

# body
body_doc = Doc()
const_doc = Doc()
for op in self.nodes:
# const inputs
for x in op.inputs:
if x not in namer.obj_name:
assert x.storage is not None
const_doc += NewLine() + namer.get_name(x, hint='c') + ' = Constant(' + get_tensor_sig(x) + ')'
outputs = op.outputs
if len(outputs) > 1:
raise NotImplementedError()
output: Tensor = outputs[0]
line_doc = Doc()
line_doc += namer(output) + ': ' + get_tensor_sig(output) + ' = '
line_doc += op.name + '('
line_doc += doc_join([namer(x) for x in op.inputs], sep=', ')
if op.attrs:
line_doc += ', ' + doc_join(
[Text(name) + '=' + get_attr_repr(value) for name, value in op.attrs.items()], ', '
)
line_doc += ') '
body_doc += NewLine() + line_doc
from .flow_graph_impl import flow_graph_as_text

# return statement
body_doc += NewLine() + Text('return ') + doc_join([namer(x) for x in self.outputs], ', ')

graph_doc = head_doc + '{' + const_doc.indent() + body_doc.indent() + NewLine() + '}'
return str(graph_doc)
return flow_graph_as_text(self)

@property
def nodes(self) -> List[Operator]:
@@ -208,8 +158,10 @@ def invalid_cache(self):
self._usage_count = None

def _build(self):
tasks = []
tunable_tasks = []
from hidet.runtime.device import Device

tasks: List[Tuple[Task, Device]] = []
tunable_tasks: List[Tuple[Task, Device]] = []
task_keys = set()
search_space = hidet.option.get_option('search_space')
for node in self.nodes:
@@ -222,9 +174,9 @@
method not in node.task.__class__.__dict__
for method in ['implement_cuda', 'implement_cpu', 'implement']
):
tasks.append(node.task)
tasks.append((node.task, node.device))
else:
tunable_tasks.append(node.task)
tunable_tasks.append((node.task, node.device))

hidet.driver.build_task_batch(tasks)

@@ -263,12 +215,10 @@ def forward(self, *inputs: Tensor) -> Union[List[Tensor], Tensor]:
# a tensor should be freed after running an operator.
usage_count = self.usage_count.copy()
tensor_map: Dict[Tensor, Tensor] = {} # symbolic tensor -> actual tensor during the forward process
shape_remap: Dict[Var, Constant] = {} # symbolic dimension -> actual shape dimension for the symbolic tensors
shape_map: Dict[SizeVar, int] = {} # symbolic dimension -> actual shape dimension for the symbolic tensors
for st, at in zip(self.inputs, inputs):
tensor_map[st] = at
shape_remap.update(
{dim: convert(at.shape[idx]) for idx, dim in enumerate(st.shape) if isinstance(dim, Var)}
)
shape_map.update({dim: at.shape[idx] for idx, dim in enumerate(st.shape) if isinstance(dim, Var)})

# run each operator in the graph in a topological order
for idx, node in enumerate(self.nodes):
@@ -285,21 +235,22 @@ def forward(self, *inputs: Tensor) -> Union[List[Tensor], Tensor]:
else:
# constant input
node_inputs.append(node_input)
node_inputs = node_inputs[: len(node.inputs)]

# run node
GraphForwardContext.before_operator(node, node_inputs)
logger.debug('[%4d/%d] run operator %s', idx, len(self.nodes), node.name)
GraphForwardContext.before_operator(node, node_inputs, shape_map)
logger.debug('[%4d/%d] run operator %s, %s', idx, len(self.nodes), node.name, node.task)
logger.debug(' inputs: %s', [x.signature() for x in node_inputs])
node_outputs = node.imperative_run(node_inputs, remap=shape_remap)
node_outputs = node.imperative_run(node_inputs, shape_map=shape_map)
logger.debug(' outputs: %s', [x.signature() for x in node_outputs])
GraphForwardContext.after_operator(node, node_inputs, node_outputs)
GraphForwardContext.after_operator(node, node_inputs, shape_map, node_outputs)

# update map
for node_output, symbolic_output in zip(node_outputs, node.outputs):
tensor_map[symbolic_output] = node_output
shape_remap.update(
shape_map.update(
{
dim: convert(node_output.shape[idx])
dim: node_output.shape[idx]
for idx, dim in enumerate(symbolic_output.shape)
if isinstance(dim, Var)
}
@@ -478,10 +429,10 @@ def _analyze(outputs: List[Tensor]) -> Tuple[List[Tensor], List[Operator], Dict[

def find_all_nodes(u: Operator):
all_nodes.add(u)
for it in u.inputs:
if it.op is None:
for x in u.inputs:
if x.op is None:
continue
v: Operator = it.op
v: Operator = x.op
if v not in all_nodes:
find_all_nodes(v)

@@ -493,7 +444,7 @@ def find_all_nodes(u: Operator):
out_degree: Dict[Operator, int] = {u: 0 for u in all_nodes}
for u in all_nodes:
for it in u.inputs:
if it.op is None:
if it.op is None or it.op not in all_nodes:
continue
out_degree[it.op] += 1
for u in outputs:
@@ -514,6 +465,8 @@
if it.storage is None and all(it is not v for v in free_vars):
# input
free_vars.append(it)
elif it.op not in all_nodes:
pass
else:
if it is not it.op.outputs[it.trace[1]]:
raise ValueError('The trace is broken')
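Taken together, the `forward` changes bind each symbolic dimension of the graph inputs to the concrete sizes of the actual tensors (`shape_map`) and thread those bindings through `imperative_run` and the instrument hooks. A minimal end-to-end sketch of the resulting dynamic-shape flow, again assuming the `symbolic_size` API exported above.

```python
import hidet

seq = hidet.symbolic_size('seq')  # assumed API for a symbolic dimension
x = hidet.symbol([seq, 64], dtype='float32', device='cuda')
w = hidet.randn([64, 64], device='cuda')
graph = hidet.trace_from(hidet.ops.matmul(x, w), inputs=[x])

# One traced graph serves several sequence lengths; `seq` is bound per call.
for n in (1, 8, 32):
    y = graph(hidet.randn([n, 64], device='cuda'))
    print(n, y.shape)
```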