[Frontend] Dynamic shape fx trace #294

Merged
24 commits merged on Jul 12, 2023

Commits
803d302
lint again for some reason
Aalanli Jun 9, 2023
a59ba09
lint again for some reason
Aalanli Jun 9, 2023
0be7d38
Merge branch 'main' of https://github.com/Aalanli/hidet into main
Aalanli Jun 9, 2023
06a9cd0
nevermind
Aalanli Jun 9, 2023
f550d2f
Merge branch 'main' of https://github.com/Aalanli/hidet into main
Aalanli Jun 12, 2023
fd19d60
Merge branch 'hidet-org:main' into main
Aalanli Jun 17, 2023
323e1a8
Merge branch 'main' of https://github.com/Aalanli/hidet into main
Aalanli Jun 26, 2023
d26f6aa
Merge branch 'main' of https://github.com/Aalanli/hidet into main
Aalanli Jun 26, 2023
7c4fdc1
add support for dynamic shape compilation
Jun 27, 2023
cb70f69
format/lint
Jun 27, 2023
f9ae927
add extra check
Aalanli Jun 29, 2023
87abfe9
add another helpful debug option
Aalanli Jun 29, 2023
20a61c6
minor fixes
Aalanli Jun 29, 2023
ad7f50f
fix norm bug
Aalanli Jun 29, 2023
8dcbc39
format/lint
Aalanli Jun 29, 2023
f58fed5
fix bug in predict correctness conditional
Jun 29, 2023
48b8190
Merge branch 'main' into dynamic-fx-trace
Jun 30, 2023
b92442c
lint
Jun 30, 2023
cf21a7d
Merge branch 'main' into dynamic-fx-trace
Jul 4, 2023
b83f399
slightly change
yaoyaoding Jul 11, 2023
f5ea3ab
fix a small bug in reshape
yaoyaoding Jul 11, 2023
3e4bc32
reshape fix
Aalanli Jul 12, 2023
1546bf4
Merge branch 'dynamic-fx-trace' of https://github.com/Aalanli/hidet i…
Aalanli Jul 12, 2023
4b6be25
fix tensor_size registered method
Aalanli Jul 12, 2023
50 changes: 36 additions & 14 deletions python/hidet/graph/frontend/torch/dynamo_backends.py
@@ -10,11 +10,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=no-name-in-module
from typing import List, Callable, Sequence
from typing import List, Callable, Sequence, Union
import logging
import torch
import hidet.option
from hidet.ir.type import data_type
from hidet.ir.expr import is_constant
from hidet.graph.flow_graph import FlowGraph
from hidet.graph.transforms import PassContext, optimize
from hidet.runtime import CompiledGraph
@@ -103,40 +104,61 @@ def hidet_backend(graph_module, example_inputs):
interpreter: Interpreter = hidet.frontend.from_torch(graph_module)

# prepare dummy and symbolic inputs for correctness and flow graph construction
symbolic_inputs: List[Tensor] = [] # for flow graph construction
inputs: List[Union[Tensor, int, bool, float]] = [] # for flow graph construction
for example_input in example_inputs:
if isinstance(example_input, torch.Tensor):
symbolic_input = symbol_like_torch(example_input)
symbolic_inputs.append(symbolic_input)
inputs.append(symbolic_input)
elif isinstance(example_input, (int, bool, float)):
inputs.append(example_input)
elif isinstance(example_input, torch.SymInt):
try:
inputs.append(int(example_input))
except Exception as e:
raise ValueError(f"hidet_backend: free symbolic example input {example_input}") from e
else:
raise ValueError('hidet_backend: only support torch.Tensor as example input')
raise ValueError(f'hidet_backend: unexpected example input {example_input}, type {type(example_input)}')

if dynamo_config['correctness_report']:
# check correctness using random inputs
logger.info('start to check correctness')
dummy_inputs: List[Tensor] = [] # for correctness check
for symbolic_input in symbolic_inputs:
if data_type(symbolic_input.dtype).is_integer():
dummy_input = hidet.zeros_like(symbolic_input)
# some shapes are symbolic; we currently don't support this option, since there is
# no principled way to obtain concrete shapes from symbolic shapes at this stage
# (some models, e.g. resnet, require the image to be above a certain size).
if any(not all(is_constant(s) for s in t.shape) for t in inputs if isinstance(t, hidet.Tensor)):
raise ValueError("hidet_backend: cannot print correctness report with dynamic=True")
dummy_inputs = [] # for correctness check
for arg in inputs:
if isinstance(arg, hidet.Tensor):
if data_type(arg.dtype).is_integer():
dummy_input = hidet.zeros_like(arg)
else:
dummy_input = hidet.randn_like(arg)
else:
dummy_input = hidet.randn_like(symbolic_input)
dummy_input = arg
dummy_inputs.append(dummy_input)
report: str = interpreter.forward_with_check(*dummy_inputs)
logger.info('finish checking correctness')
print(report)

logger.info('hidet: symbolic inputs: ')
for symbolic_input in symbolic_inputs:
logger.info('hidet: %s', symbolic_input.signature())
logger.info('hidet: inputs: ')
for arg in inputs:
if isinstance(arg, hidet.Tensor):
logger.info('hidet: %s', arg.signature())
else:
logger.info('hidet: %s', arg)

# symbolic run to get flow graph
output = interpreter(*symbolic_inputs)
output = interpreter(*inputs)
output_format, output_tensors = serialize_output(output)
flow_graph: FlowGraph = hidet.trace_from(output_tensors, inputs=symbolic_inputs)
input_tensors = [x for x in inputs if isinstance(x, hidet.Tensor)]
input_tensor_indices = [i for (i, x) in enumerate(inputs) if isinstance(x, hidet.Tensor)]
flow_graph: FlowGraph = hidet.trace_from(output_tensors, inputs=input_tensors)

executor = generate_executor(flow_graph)

def wrapper(*args: Tensor):
args = [args[i] for i in input_tensor_indices]
outputs: Sequence[torch.Tensor] = executor(*args)
ret = deserialize_output(output_format, outputs)
return ret
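Taken together, the changes in this file let the backend accept a mix of torch.Tensor and torch.SymInt example inputs, trace the flow graph over the tensor inputs only, and filter the runtime arguments through input_tensor_indices. A minimal usage sketch (not part of the diff; it assumes hidet's dynamo backend is registered under the name 'hidet' and that a CUDA device is available):

import torch
import hidet  # assumption: importing hidet makes the 'hidet' dynamo backend available

model = torch.nn.Linear(64, 64).cuda().eval()
compiled = torch.compile(model, backend='hidet', dynamic=True)

# with dynamic=True, dynamo may hand the backend SymInt example inputs for the varying
# batch dimension; the wrapper above drops non-tensor arguments before execution
y1 = compiled(torch.randn(8, 64, device='cuda'))
y2 = compiled(torch.randn(32, 64, device='cuda'))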
51 changes: 46 additions & 5 deletions python/hidet/graph/frontend/torch/interpreter.py
@@ -149,6 +149,22 @@ def tensor_from_torch(tensor: torch.Tensor) -> Tensor:
return hidet.graph.tensor.from_torch(tensor)


def is_torch_path(name: str) -> bool:
name = name.split(".")
if len(name) > 0:
return name[0] == "torch"
return False


def belong_to_torch(code_obj) -> bool:
belongs = False
if hasattr(code_obj, "__module__") and code_obj.__module__ is not None:
belongs |= is_torch_path(code_obj.__module__)
if not belongs and hasattr(code_obj, "__package__") and code_obj.__package__ is not None:
belongs |= is_torch_path(code_obj.__package__)
return belongs


class HidetModule:
def __init__(self, torch_module: torch.nn.Module):
self.mod: torch.nn.Module = torch_module
@@ -182,6 +198,15 @@ def __init__(self, graph_module: torch.fx.GraphModule):
self.torch_modules: Dict[str, torch.nn.Module] = dict(graph_module.named_modules())
self.hidet_modules: Dict[str, HidetModule] = {}

# dynamo wraps some builtin functions in local lambda functions,
# which get dispatched incorrectly
self.ignore_funcs: Dict[str, Callable] = {
# see torch._dynamo.variables.lists.SizeVariable.get_item_dyn
# this signifies that the target of getitem is a torch.Size; we overload torch.Tensor.size to
# return a list, so this method needs to be overridden in the interpreter as well
'_dynamo_get_item_lambda': lambda target, index: target[index]
}

self._check_support()

def __call__(self, *args):
@@ -210,8 +235,10 @@ def _check_support(self):
if torch_cls not in Registry.registered_modules:
not_supported.add(torch_cls)
elif node.op == "call_function":
if node.target not in Registry.registered_functions:
target_fn = self._lookup_function(node.target)
if target_fn is None:
not_supported.add(node.target)

if len(not_supported) > 0:
lines = []
lines.append("The following modules/functions are not supported by hidet yet:")
@@ -233,6 +260,20 @@ def _lookup_hidet_method(self, torch_method):
raise NotImplementedError(f"hidet: method {method_name} is not supported yet.")
return Registry.registered_methods[torch_method]

def _lookup_function(self, code_obj):
if code_obj.__name__ in self.ignore_funcs:
return self.ignore_funcs[code_obj.__name__]
if belong_to_torch(code_obj):
if code_obj in Registry.registered_functions:
return Registry.registered_functions[code_obj]
else:
return None
else:
# this branch handles all other cases, such as getitem, operator.add, etc.;
# since the inputs are all hidet tensors, applying this function directly resolves to
# the actual traced implementation
return code_obj

@staticmethod
def _callable_info(f: Callable) -> Tuple[str, str, int]:
if inspect.ismethod(f):
@@ -315,13 +356,13 @@ def load_arg(a, env):
attr = getattr(attr, atom)
hidet_env[node.name] = tensor_from_torch(attr) if isinstance(attr, torch.Tensor) else attr
elif node.op == "call_function":
hidet_func = Registry.registered_functions[node.target]
exec_func = self._lookup_function(node.target)
hidet_args = load_arg(node.args, hidet_env)
hidet_kwargs = load_arg(node.kwargs, hidet_env)
try:
hidet_env[node.name] = hidet_func(*hidet_args, **hidet_kwargs)
hidet_env[node.name] = exec_func(*hidet_args, **hidet_kwargs)
except Exception as e:
self._raise_exception(e, node.target, hidet_func, hidet_args, hidet_kwargs)
self._raise_exception(e, node.target, exec_func, hidet_args, hidet_kwargs)
elif node.op == "call_method":
args = load_arg(node.args, hidet_env)
kwargs = load_arg(node.kwargs, hidet_env)
@@ -403,7 +444,7 @@ def load_arg(a, env):
torch_kwargs = load_arg(node.kwargs, torch_env)
torch_env[node.name] = torch_func(*torch_args, **torch_kwargs)

hidet_func = Registry.registered_functions[torch_func]
hidet_func = self._lookup_function(node.target)
hidet_args = load_arg(node.args, hidet_env)
hidet_kwargs = load_arg(node.kwargs, hidet_env)

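The rule implemented by _lookup_function: callables that live under the torch namespace must have a registered hidet equivalent, while everything else (operator.getitem, operator.add, builtins) is applied directly to hidet tensors and resolves through their operator overloads. An illustrative sketch (not part of the diff; the helper below is a simplified stand-in for belong_to_torch):

import operator
import torch
import hidet

def belongs_to_torch(fn) -> bool:
    # simplified: only check __module__, as belong_to_torch does first
    module = getattr(fn, '__module__', None) or ''
    return module.split('.')[0] == 'torch'

a = hidet.randn([2, 3])
b = hidet.randn([2, 3])

print(belongs_to_torch(torch.nn.functional.relu))  # True: needs an entry in Registry.registered_functions
print(belongs_to_torch(operator.add))              # False: called directly on hidet tensors
print(operator.add(a, b).shape)                    # resolves through hidet's Tensor.__add__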
8 changes: 4 additions & 4 deletions python/hidet/graph/frontend/torch/register_functions.py
@@ -145,12 +145,12 @@ def bilinear(x_1: Tensor, x_2: Tensor, weight: Tensor, bias: Optional[Tensor]):
@register_function(operator.add)
@register_function(torch.ops.aten.add.Tensor)
def add(x: Tensor, y: Tensor):
return ops.add(x, y)
return x + y


@register_function(operator.iadd)
def iadd(x: Tensor, y: Tensor):
return ops.add(x, y)
return x + y
Member review comment: So the x and y could be DynInt?

@register_function(torch.sin)
@@ -362,7 +362,7 @@ def zeros(*size, out=None, dtype=None, layout=None, device=None, pin_memory=Fals

@register_function(torch.ones)
def ones(
*size: Union[int, Sequence[int]],
*size: Union[Int, Sequence[Int]],
out: Optional[Tensor] = None,
dtype: Optional[torch.dtype] = None,
layout: Optional[torch.layout] = None,
@@ -381,7 +381,7 @@ def ones(
if isinstance(size[0], (list, tuple)):
size = size[0]

shape = [int(v) for v in size]
shape = [v if isinstance(v, hidet.ir.Expr) else int(v) for v in size]
if dtype is None:
dtype = torch.get_default_dtype()

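A plausible reading of the change from ops.add to x + y (suggested by the review question above, not stated in the diff): under dynamic shapes the operands reaching operator.add can be symbolic integers rather than tensors, and Python's + dispatches correctly for both, whereas ops.add expects tensors. A hedged sketch:

import hidet

x = hidet.symbol(shape=['n', 4], dtype='float32')  # dynamic first dimension
n = x.shape[0]        # a symbolic expression, not a Python int
print(n + 1)          # '+' works for symbolic scalar expressions
print((x + x).shape)  # and for tensors, via Tensor.__add__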
12 changes: 10 additions & 2 deletions python/hidet/graph/frontend/torch/register_methods.py
@@ -14,7 +14,7 @@
from typing import List, Union
import torch

from hidet.ir.type import DataType
from hidet.ir.type import DataType, Int
from hidet.graph.tensor import Tensor
from hidet.graph import ops
from hidet.runtime.device import instantiate_device
@@ -130,7 +130,7 @@ def tensor_view(self: Tensor, *args) -> Tensor:
else:
if len(args) == 1 and isinstance(args[0], (list, tuple)):
args = args[0]
dst_shape = [int(arg) for arg in args]
dst_shape = list(args)
return ops.reshape(self, dst_shape)


@@ -161,6 +161,14 @@ def tensor_split(self: Tensor, split_size, dim=0) -> List[Tensor]:
return ops.split(self, axis=dim, parts_or_sections=parts)


@register_method(torch.Tensor.size)
def tensor_size(self: Tensor, dim=None) -> List[Int]:
if dim is None:
return self.shape
else:
return self.shape[dim]


@register_method(torch.Tensor.chunk)
def tensor_chunk(self: Tensor, chunks, dim=0) -> List[Tensor]:
dim_size = self.shape[dim]
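Returning the (possibly symbolic) shape from size, combined with tensor_view no longer forcing int(), keeps common reshape idioms traceable when a dimension is dynamic. A hypothetical module of the kind this enables (not from the diff):

import torch

class FlattenBatch(torch.nn.Module):
    # x.size(0) stays symbolic under dynamic shapes and view() accepts it,
    # so the reshape is traced with a symbolic leading dimension,
    # e.g. via torch.compile(FlattenBatch(), backend='hidet', dynamic=True)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.view(x.size(0), -1)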
15 changes: 13 additions & 2 deletions python/hidet/graph/frontend/torch/utils.py
@@ -112,8 +112,19 @@ def device_from_torch(torch_device) -> Device:
def symbol_like_torch(tensor) -> Tensor:
import hidet
import torch

if isinstance(tensor, torch.Tensor):
from torch._subclasses.fake_tensor import FakeTensor

if isinstance(tensor, FakeTensor):
# this should be fine for now; torch wraps around the sympy library
symbolic_shape = []
for s in tensor.shape:
try:
i = int(s)
except Exception: # pylint: disable=broad-except
i = str(s)
symbolic_shape.append(i)
return hidet.symbol(shape=symbolic_shape, dtype=dtype_from_torch(tensor.dtype).name, device=tensor.device.type)
elif isinstance(tensor, torch.Tensor):
return hidet.symbol(
shape=list(tensor.shape), dtype=dtype_from_torch(tensor.dtype).name, device=tensor.device.type
)
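The FakeTensor branch maps each dimension to either a concrete int or the string form of its SymInt, and hidet.symbol accepts such mixed shapes as named symbolic dimensions. A small sketch of the resulting symbol (the dimension name is illustrative):

import hidet

# a fake tensor with shape (s0, 3, 224, 224) would map to roughly:
x = hidet.symbol(shape=['s0', 3, 224, 224], dtype='float32', device='cpu')
print(x.signature())  # the 's0' dimension is only bound to a value at run time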
10 changes: 9 additions & 1 deletion python/hidet/graph/graph_utils/instruments/debug_instrument.py
@@ -19,10 +19,11 @@
class GraphForwardDebugInstrument(GraphForwardInstrument):
_template = '{:>5} {:>30} {:>3} {:<25} {:>8} {:>8} {:>8} {:>10} {:>10} {:>10} {:>10}'

def __init__(self, output_dir='./outs/debug', print_summary=False, dump_outputs=False):
def __init__(self, output_dir='./outs/debug', print_summary=False, dump_outputs=False, dump_op=False):
self.output_dir: str = output_dir
self.print_summary: bool = print_summary
self.dump_outputs: bool = dump_outputs
self.dump_op: bool = dump_op

self.debugging: bool = False
self.summary_file: Optional[str] = None
@@ -141,6 +142,13 @@ def after_operator(self, op: Operator, inputs: List[Tensor], outputs: List[Tenso
with open(array_path, 'w') as f:
with np.printoptions(precision=8, edgeitems=30, linewidth=512):
f.write(str(array))
if self.dump_op:
op_path = os.path.join(
self.output_dir, '{}_{}{}.txt'.format(self.operator_idx, op.name, f'_def{idx}' if idx > 0 else '')
)
with open(op_path, 'w') as f:
f.write('Operator:\n{}\n'.format(op))
f.write('Task:\n{}\n'.format(op.task))

with open(self.summary_file, 'a') as f:
f.write('\n'.join(lines) + '\n')
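The new dump_op flag only adds a per-operator dump alongside the existing per-output dumps; how the instrument is attached to a FlowGraph's forward pass is unchanged by this PR. A sketch of constructing it with the new argument (paths are the defaults shown above):

from hidet.graph.graph_utils.instruments.debug_instrument import GraphForwardDebugInstrument

instrument = GraphForwardDebugInstrument(
    output_dir='./outs/debug',
    print_summary=True,
    dump_outputs=True,
    dump_op=True,  # new: also write each Operator and its Task definition to a .txt file
)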
6 changes: 3 additions & 3 deletions python/hidet/graph/ops/create.py
@@ -21,7 +21,7 @@

class FullTask(Task):
def __init__(
self, shape: Sequence[int], value: Union[int, float, bool, Constant, Expr], dtype: Union[DataType, str]
self, shape: Sequence[Int], value: Union[int, float, bool, Constant, Expr], dtype: Union[DataType, str]
):
dtype: DataType = data_type(dtype)
value: Constant = dtype(value) if isinstance(value, (int, float, bool)) else value
@@ -123,12 +123,12 @@ def infer_dtype(self, start, stop, step):
class FullOp(Operator):
def __init__(
self,
shape: Sequence[int],
shape: Sequence[Int],
value: Union[float, int, bool, Constant, Tensor],
dtype: Optional[DataType] = None,
device: Union[Device, str] = 'cpu',
):
shape = [int(v) for v in shape]
shape = list(shape)
device: Device = instantiate_device(device)

if isinstance(value, Tensor):
21 changes: 3 additions & 18 deletions python/hidet/graph/ops/normalize/layers.py
@@ -9,21 +9,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hidet

from ..utils import Tensor, normalize_dim
from ..arithmetic import rsqrt
from .norm import normalize
from .norm_f16 import normalize_f16


def resolve_norm_func(dtype):
if dtype == hidet.float32:
return normalize
elif dtype == hidet.float16:
return normalize_f16
else:
raise NotImplementedError("normalize function for dtype {} is not implemented".format(dtype))


def batch_norm_infer(x: Tensor, running_mean: Tensor, running_var: Tensor, epsilon=1e-5, axis=1) -> Tensor:
@@ -58,8 +46,7 @@ def instance_norm(x: Tensor, epsilon: float = 1e-5, accumulate_dtype: str = 'flo
The normalized tensor.
"""
dims = [dim for dim in range(2, len(x.shape))]
norm_func = resolve_norm_func(x.dtype)
return norm_func(x, axis=dims, epsilon=epsilon, accumulate_dtype=accumulate_dtype)
return normalize(x, axis=dims, epsilon=epsilon, accumulate_dtype=accumulate_dtype)


def layer_norm(x: Tensor, num_last_dims: int = 1, epsilon: float = 1e-5, accumulate_dtype: str = 'float32') -> Tensor:
@@ -82,9 +69,8 @@ def layer_norm(x: Tensor, num_last_dims: int = 1, epsilon: float = 1e-5, accumul
ret: Tensor
The normalized tensor.
"""
norm_func = resolve_norm_func(x.dtype)
dims = list(range(len(x.shape) - num_last_dims, len(x.shape)))
return norm_func(x, axis=dims, epsilon=epsilon, accumulate_dtype=accumulate_dtype)
return normalize(x, axis=dims, epsilon=epsilon, accumulate_dtype=accumulate_dtype)


def group_norm(x: Tensor, num_groups, epsilon: float = 1e-5, accumulate_dtype: str = 'float32'):
@@ -119,7 +105,6 @@ def group_norm(x: Tensor, num_groups, epsilon: float = 1e-5, accumulate_dtype: s

x = x.reshape(new_shape)
dims = list(range(2, len(x.shape)))
norm_func = resolve_norm_func(x.dtype)
normed = norm_func(x, axis=dims, epsilon=epsilon, accumulate_dtype=accumulate_dtype)
normed = normalize(x, axis=dims, epsilon=epsilon, accumulate_dtype=accumulate_dtype)

return normed.reshape(x_shape)
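With resolve_norm_func gone, instance_norm, layer_norm, and group_norm all call the generic normalize directly; presumably normalize now handles non-float32 dtypes itself (an assumption, not shown in this diff). The public API is unchanged, e.g.:

import hidet

x = hidet.randn([2, 16, 8])                   # float32 on cpu
y = hidet.ops.layer_norm(x, num_last_dims=1)  # same signature as defined above
print(y.shape)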