
[Frontend] Dynamic shape fx trace #294

Merged Jul 12, 2023 · 24 commits

Changes from 18 commits

Commits (24)
803d302
lint again for some reason
Aalanli Jun 9, 2023
a59ba09
lint again for some reason
Aalanli Jun 9, 2023
0be7d38
Merge branch 'main' of https://github.com/Aalanli/hidet into main
Aalanli Jun 9, 2023
06a9cd0
nevermind
Aalanli Jun 9, 2023
f550d2f
Merge branch 'main' of https://github.com/Aalanli/hidet into main
Aalanli Jun 12, 2023
fd19d60
Merge branch 'hidet-org:main' into main
Aalanli Jun 17, 2023
323e1a8
Merge branch 'main' of https://github.com/Aalanli/hidet into main
Aalanli Jun 26, 2023
d26f6aa
Merge branch 'main' of https://github.com/Aalanli/hidet into main
Aalanli Jun 26, 2023
7c4fdc1
add support for dynamic shape compilation
Jun 27, 2023
cb70f69
format/lint
Jun 27, 2023
f9ae927
add extra check
Aalanli Jun 29, 2023
87abfe9
add another helpful debug option
Aalanli Jun 29, 2023
20a61c6
minor fixes
Aalanli Jun 29, 2023
ad7f50f
fix norm bug
Aalanli Jun 29, 2023
8dcbc39
format/lint
Aalanli Jun 29, 2023
f58fed5
fix bug in predict correctness conditional
Jun 29, 2023
48b8190
Merge branch 'main' into dynamic-fx-trace
Jun 30, 2023
b92442c
lint
Jun 30, 2023
cf21a7d
Merge branch 'main' into dynamic-fx-trace
Jul 4, 2023
b83f399
slightly change
yaoyaoding Jul 11, 2023
f5ea3ab
fix a small bug in reshape
yaoyaoding Jul 11, 2023
3e4bc32
reshape fix
Aalanli Jul 12, 2023
1546bf4
Merge branch 'dynamic-fx-trace' of https://github.com/Aalanli/hidet i…
Aalanli Jul 12, 2023
4b6be25
fix tensor_size registered method
Aalanli Jul 12, 2023
53 changes: 40 additions & 13 deletions python/hidet/graph/frontend/torch/dynamo_backends.py
@@ -15,6 +15,7 @@
import torch
import hidet.option
from hidet.ir.type import data_type
from hidet.ir.expr import is_constant
from hidet.graph.flow_graph import FlowGraph
from hidet.graph.transforms import PassContext, optimize
from hidet.runtime import CompiledGraph
@@ -103,40 +104,66 @@ def hidet_backend(graph_module, example_inputs):
interpreter: Interpreter = hidet.frontend.from_torch(graph_module)

# prepare dummy and symbolic inputs for correctness and flow graph construction
symbolic_inputs: List[Tensor] = [] # for flow graph construction
# unfortunately, when dynamic=True in torch.compile, there may exist other non-tensor parameters
# in example inputs
Review comment (Member):
For those dynamic shapes, I am wondering whether these scalar parameters act as the shapes of the input tensors. If that's the case, we can ignore those scalar parameters.

Say a torch model gives us

sample_inputs = [tensor(['m', 'n']), 'm', 'n']

We can declare the symbol variables for 'm' and 'n' (when we define the symbol tensor) and ignore the 'm' and 'n' scalar parameters.

Review comment (Member):
Any clue on this?

inputs = [] # for flow graph construction
for example_input in example_inputs:
if isinstance(example_input, torch.Tensor):
symbolic_input = symbol_like_torch(example_input)
symbolic_inputs.append(symbolic_input)
inputs.append(symbolic_input)
elif isinstance(example_input, int):
inputs.append(example_input)
elif isinstance(example_input, torch.SymInt):
try:
inputs.append(int(example_input))
except Exception as e:
raise ValueError(f"hidet_backend: free symbolic example input {example_input}") from e
else:
raise ValueError('hidet_backend: only support torch.Tensor as example input')
raise ValueError(f'hidet_backend: unexpected example input {example_input}, type {type(example_input)}')

if dynamo_config['correctness_report']:
# check correctness using random inputs
logger.info('start to check correctness')
dummy_inputs: List[Tensor] = [] # for correctness check
for symbolic_input in symbolic_inputs:
if data_type(symbolic_input.dtype).is_integer():
dummy_input = hidet.zeros_like(symbolic_input)
# there exist some symbolic shapes; currently we don't support this option,
# as there is no principled way to get concrete shapes from symbolic shapes at this stage,
# since some models (e.g. resnet) require the image to be above a certain size.
if any(not all(is_constant(s) for s in t.shape) for t in inputs if isinstance(t, hidet.Tensor)):
raise ValueError("hidet_backend: cannot print correctness report with dynamic=True")
dummy_inputs = [] # for correctness check
for arg in inputs:
if isinstance(arg, hidet.Tensor):
if data_type(arg.dtype).is_integer():
dummy_input = hidet.zeros_like(arg)
else:
dummy_input = hidet.randn_like(arg)
else:
dummy_input = hidet.randn_like(symbolic_input)
dummy_input = arg
dummy_inputs.append(dummy_input)
report: str = interpreter.forward_with_check(*dummy_inputs)
logger.info('finish checking correctness')
print(report)

logger.info('hidet: symbolic inputs: ')
for symbolic_input in symbolic_inputs:
logger.info('hidet: %s', symbolic_input.signature())
logger.info('hidet: inputs: ')
for arg in inputs:
if isinstance(arg, hidet.Tensor):
logger.info('hidet: %s', arg.signature())
else:
logger.info('hidet: %s', arg)

# symbolic run to get flow graph
output = interpreter(*symbolic_inputs)
output = interpreter(*inputs)
output_format, output_tensors = serialize_output(output)
flow_graph: FlowGraph = hidet.trace_from(output_tensors, inputs=symbolic_inputs)
input_tensors = list(filter(lambda x: isinstance(x, hidet.Tensor), inputs))
# essentially, I think this is a bug in torch._inductor
# the example inputs have instances of torch.SymInt (when dynamic=True), while the inputs to the compiled model
# are torch.Tensors.
input_map = [isinstance(x, hidet.Tensor) for x in inputs]
flow_graph: FlowGraph = hidet.trace_from(output_tensors, inputs=input_tensors)

executor = generate_executor(flow_graph)

def wrapper(*args: Tensor):
args = [t for (t, is_hidet_tensor) in zip(args, input_map) if is_hidet_tensor]
outputs: Sequence[torch.Tensor] = executor(*args)
ret = deserialize_output(output_format, outputs)
return ret
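The upshot of this hunk: with dynamic=True, dynamo hands the backend torch.SymInt entries alongside the (fake) tensors, and the wrapper drops the non-tensor arguments again at call time via input_map. A minimal usage sketch of the path this enables (model and shapes are illustrative; it assumes the 'hidet' dynamo backend is registered by importing hidet, as in the existing integration):

```python
import torch
import hidet  # assumed to register the 'hidet' dynamo backend on import

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU()).cuda()
compiled = torch.compile(model, backend='hidet', dynamic=True)

# With dynamic=True, dynamo may pass torch.SymInt entries alongside the tensors to
# hidet_backend; the wrapper above filters non-tensor args out via input_map before
# calling the compiled executor, so different batch sizes go through the same flow graph.
for batch in (2, 7, 31):
    x = torch.randn(batch, 64, device='cuda')
    y = compiled(x)
```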
51 changes: 46 additions & 5 deletions python/hidet/graph/frontend/torch/interpreter.py
@@ -149,6 +149,22 @@ def tensor_from_torch(tensor: torch.Tensor) -> Tensor:
return hidet.graph.tensor.from_torch(tensor)


def is_torch_path(name: str) -> bool:
name = name.split(".")
if len(name) > 0:
return name[0] == "torch"
return False


def belong_to_torch(code_obj) -> bool:
belongs = False
if hasattr(code_obj, "__module__") and code_obj.__module__ is not None:
belongs |= is_torch_path(code_obj.__module__)
if not belongs and hasattr(code_obj, "__package__") and code_obj.__package__ is not None:
belongs |= is_torch_path(code_obj.__package__)
return belongs


class HidetModule:
def __init__(self, torch_module: torch.nn.Module):
self.mod: torch.nn.Module = torch_module
@@ -182,6 +198,15 @@ def __init__(self, graph_module: torch.fx.GraphModule):
self.torch_modules: Dict[str, torch.nn.Module] = dict(graph_module.named_modules())
self.hidet_modules: Dict[str, HidetModule] = {}

# basically dynamo further wraps some builtin functions with annoying local wrapper functions,
# which get dispatched incorrectly
self.ignore_funcs: Dict[str, Callable] = {
# see torch._dynamo.variables.lists.SizeVariable.get_item_dyn
# this signifies that the target of getitem is a torch.Size; we overload torch.Tensor.size to
# return a list, so this method needs to be overridden in the interpreter as well
'_dynamo_get_item_lambda': lambda target, index: target[index]
}

self._check_support()

def __call__(self, *args):
@@ -210,8 +235,10 @@ def _check_support(self):
if torch_cls not in Registry.registered_modules:
not_supported.add(torch_cls)
elif node.op == "call_function":
if node.target not in Registry.registered_functions:
target_fn = self._lookup_function(node.target)
if target_fn is None:
not_supported.add(node.target)

if len(not_supported) > 0:
lines = []
lines.append("The following modules/functions are not supported by hidet yet:")
@@ -233,6 +260,20 @@ def _lookup_hidet_method(self, torch_method):
raise NotImplementedError(f"hidet: method {method_name} is not supported yet.")
return Registry.registered_methods[torch_method]

def _lookup_function(self, code_obj):
if code_obj.__name__ in self.ignore_funcs:
return self.ignore_funcs[code_obj.__name__]
if belong_to_torch(code_obj):
if code_obj in Registry.registered_functions:
return Registry.registered_functions[code_obj]
else:
return None
else:
# this branch handles all the other cases, such as getitem, operator.add, etc.
# since the inputs are all hidet tensors, applying this function should resolve to
# the actual traced implementation
return code_obj

@staticmethod
def _callable_info(f: Callable) -> Tuple[str, str, int]:
if inspect.ismethod(f):
@@ -315,13 +356,13 @@ def load_arg(a, env):
attr = getattr(attr, atom)
hidet_env[node.name] = tensor_from_torch(attr) if isinstance(attr, torch.Tensor) else attr
elif node.op == "call_function":
hidet_func = Registry.registered_functions[node.target]
exec_func = self._lookup_function(node.target)
hidet_args = load_arg(node.args, hidet_env)
hidet_kwargs = load_arg(node.kwargs, hidet_env)
try:
hidet_env[node.name] = hidet_func(*hidet_args, **hidet_kwargs)
hidet_env[node.name] = exec_func(*hidet_args, **hidet_kwargs)
except Exception as e:
self._raise_exception(e, node.target, hidet_func, hidet_args, hidet_kwargs)
self._raise_exception(e, node.target, exec_func, hidet_args, hidet_kwargs)
elif node.op == "call_method":
args = load_arg(node.args, hidet_env)
kwargs = load_arg(node.kwargs, hidet_env)
@@ -403,7 +444,7 @@ def load_arg(a, env):
torch_kwargs = load_arg(node.kwargs, torch_env)
torch_env[node.name] = torch_func(*torch_args, **torch_kwargs)

hidet_func = Registry.registered_functions[torch_func]
hidet_func = self._lookup_function(node.target)
hidet_args = load_arg(node.args, hidet_env)
hidet_kwargs = load_arg(node.kwargs, hidet_env)
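
The dispatch rule introduced above is: callables that live under the torch namespace must have a registered hidet counterpart, while everything else (operator.getitem, operator.add, dynamo's wrapper lambdas) is returned unchanged and traced through hidet's own operator overloads. A small sketch of the helper's behaviour, assuming is_torch_path/belong_to_torch from this diff are in scope:

```python
import operator
import torch

# belong_to_torch keys off __module__/__package__ of the callable:
print(belong_to_torch(torch.nn.functional.relu))  # True  -> must be in Registry.registered_functions
print(belong_to_torch(operator.getitem))          # False -> _lookup_function returns it unchanged,
                                                  #          and hidet.Tensor.__getitem__ handles it
```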

8 changes: 4 additions & 4 deletions python/hidet/graph/frontend/torch/register_functions.py
@@ -145,12 +145,12 @@ def bilinear(x_1: Tensor, x_2: Tensor, weight: Tensor, bias: Optional[Tensor]):
@register_function(operator.add)
@register_function(torch.ops.aten.add.Tensor)
def add(x: Tensor, y: Tensor):
return ops.add(x, y)
return x + y


@register_function(operator.iadd)
def iadd(x: Tensor, y: Tensor):
return ops.add(x, y)
return x + y
Review comment (Member):
So the x and y could be DynInt?



@register_function(torch.sin)
@@ -362,7 +362,7 @@ def zeros(*size, out=None, dtype=None, layout=None, device=None, pin_memory=Fals

@register_function(torch.ones)
def ones(
*size: Union[int, Sequence[int]],
*size: Union[Int, Sequence[Int]],
out: Optional[Tensor] = None,
dtype: Optional[torch.dtype] = None,
layout: Optional[torch.layout] = None,
@@ -381,7 +381,7 @@ def ones(
if isinstance(size[0], (list, tuple)):
size = size[0]

shape = [int(v) for v in size]
shape = [v if isinstance(v, hidet.ir.Expr) else int(v) for v in size]
if dtype is None:
dtype = torch.get_default_dtype()
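
Switching add/iadd to plain x + y (and letting ones keep expression-valued sizes) matters because, under dynamic=True, the traced operands are not always hidet Tensors: they can be symbolic dimension expressions or plain Python ints, and Python's `+` dispatches correctly for all of them. A sketch, using the string-shape form of hidet.symbol that this PR also relies on:

```python
import hidet

x = hidet.symbol(shape=['n', 4], dtype='float32', device='cpu')
y = hidet.symbol(shape=['n', 4], dtype='float32', device='cpu')

z = x + y         # Tensor + Tensor: equivalent to ops.add
n = x.shape[0]    # a symbolic dimension expression, not a Python int
m = n + 1         # expression + int: plain `+` still composes; ops.add would not apply here
```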

9 changes: 7 additions & 2 deletions python/hidet/graph/frontend/torch/register_methods.py
@@ -14,7 +14,7 @@
from typing import List, Union
import torch

from hidet.ir.type import DataType
from hidet.ir.type import DataType, Int
from hidet.graph.tensor import Tensor
from hidet.graph import ops
from hidet.runtime.device import instantiate_device
@@ -130,7 +130,7 @@ def tensor_view(self: Tensor, *args) -> Tensor:
else:
if len(args) == 1 and isinstance(args[0], (list, tuple)):
args = args[0]
dst_shape = [int(arg) for arg in args]
dst_shape = list(args)
return ops.reshape(self, dst_shape)


@@ -161,6 +161,11 @@ def tensor_split(self: Tensor, split_size, dim=0) -> List[Tensor]:
return ops.split(self, axis=dim, parts_or_sections=parts)


@register_method(torch.Tensor.size)
def tensor_size(self: Tensor) -> List[Int]:
return self.shape


@register_method(torch.Tensor.chunk)
def tensor_chunk(self: Tensor, chunks, dim=0) -> List[Tensor]:
dim_size = self.shape[dim]
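
Returning the (possibly symbolic) shape list from torch.Tensor.size is what the interpreter's _dynamo_get_item_lambda override pairs with: indexing the returned list yields a symbolic dimension instead of a baked-in int. A hypothetical module showing the pattern this supports:

```python
import torch

# Hypothetical module: under dynamo, `x.size()[0]` becomes a getitem on a torch.Size
# (rewritten to _dynamo_get_item_lambda), which the interpreter maps back to plain
# indexing on the list returned by tensor_size, so `b` stays symbolic under dynamic=True.
class Flatten(torch.nn.Module):
    def forward(self, x):
        b = x.size()[0]
        return x.reshape(b, -1)
```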
15 changes: 13 additions & 2 deletions python/hidet/graph/frontend/torch/utils.py
@@ -112,8 +112,19 @@ def device_from_torch(torch_device) -> Device:
def symbol_like_torch(tensor) -> Tensor:
import hidet
import torch

if isinstance(tensor, torch.Tensor):
from torch._subclasses.fake_tensor import FakeTensor

if isinstance(tensor, FakeTensor):
# this should be fine for now; torch wraps around the sympy library
symbolic_shape = []
for s in tensor.shape:
try:
i = int(s)
except Exception: # pylint: disable=broad-except
i = str(s)
symbolic_shape.append(i)
return hidet.symbol(shape=symbolic_shape, dtype=dtype_from_torch(tensor.dtype).name, device=tensor.device.type)
elif isinstance(tensor, torch.Tensor):
return hidet.symbol(
shape=list(tensor.shape), dtype=dtype_from_torch(tensor.dtype).name, device=tensor.device.type
)
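
For a FakeTensor whose leading dimension is a torch.SymInt, the new branch falls back to the symbol's string name, so the resulting hidet symbol tensor carries a named symbolic dimension. The direct equivalent of what it produces (the symbol name 's0' is illustrative):

```python
import hidet

# What symbol_like_torch would build for a FakeTensor of shape (s0, 3, 224, 224) on CUDA:
x = hidet.symbol(shape=['s0', 3, 224, 224], dtype='float32', device='cuda')
print(x.signature())
```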
10 changes: 9 additions & 1 deletion python/hidet/graph/graph_utils/instruments/debug_instrument.py
@@ -19,10 +19,11 @@
class GraphForwardDebugInstrument(GraphForwardInstrument):
_template = '{:>5} {:>30} {:>3} {:<25} {:>8} {:>8} {:>8} {:>10} {:>10} {:>10} {:>10}'

def __init__(self, output_dir='./outs/debug', print_summary=False, dump_outputs=False):
def __init__(self, output_dir='./outs/debug', print_summary=False, dump_outputs=False, dump_op=False):
self.output_dir: str = output_dir
self.print_summary: bool = print_summary
self.dump_outputs: bool = dump_outputs
self.dump_op: bool = dump_op

self.debugging: bool = False
self.summary_file: Optional[str] = None
@@ -141,6 +142,13 @@ def after_operator(self, op: Operator, inputs: List[Tensor], outputs: List[Tenso
with open(array_path, 'w') as f:
with np.printoptions(precision=8, edgeitems=30, linewidth=512):
f.write(str(array))
if self.dump_op:
op_path = os.path.join(
self.output_dir, '{}_{}{}.txt'.format(self.operator_idx, op.name, f'_def{idx}' if idx > 0 else '')
)
with open(op_path, 'w') as f:
f.write('Operator:\n{}\n'.format(op))
f.write('Task:\n{}\n'.format(op.task))

with open(self.summary_file, 'a') as f:
f.write('\n'.join(lines) + '\n')
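
For reference, a sketch of constructing the instrument with the new flag; how it is attached to graph execution follows the debug instrument's existing mechanism, which is outside this diff:

```python
from hidet.graph.graph_utils.instruments.debug_instrument import GraphForwardDebugInstrument

# dump_op=True additionally writes a per-operator text file containing the Operator and its
# Task, next to the existing summary/output dumps under output_dir.
instrument = GraphForwardDebugInstrument(output_dir='./outs/debug', print_summary=True, dump_op=True)
```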
6 changes: 3 additions & 3 deletions python/hidet/graph/ops/create.py
@@ -21,7 +21,7 @@

class FullTask(Task):
def __init__(
self, shape: Sequence[int], value: Union[int, float, bool, Constant, Expr], dtype: Union[DataType, str]
self, shape: Sequence[Int], value: Union[int, float, bool, Constant, Expr], dtype: Union[DataType, str]
):
dtype: DataType = data_type(dtype)
value: Constant = dtype(value) if isinstance(value, (int, float, bool)) else value
@@ -123,12 +123,12 @@ def infer_dtype(self, start, stop, step):
class FullOp(Operator):
def __init__(
self,
shape: Sequence[int],
shape: Sequence[Int],
value: Union[float, int, bool, Constant, Tensor],
dtype: Optional[DataType] = None,
device: Union[Device, str] = 'cpu',
):
shape = [int(v) for v in shape]
shape = list(shape)
device: Device = instantiate_device(device)

if isinstance(value, Tensor):
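
Keeping the shape entries as expressions (rather than forcing int(v)) lets FullOp, and the creation routines built on it, accept symbolic dimensions coming out of a traced graph. A sketch under that assumption:

```python
import hidet

x = hidet.symbol(shape=['n', 4], dtype='float32', device='cpu')
n = x.shape[0]           # symbolic dim; `shape = list(shape)` now keeps it as an expression
y = hidet.zeros([n, 4])  # routed through FullOp, so the output keeps the symbolic dim
```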
21 changes: 3 additions & 18 deletions python/hidet/graph/ops/normalize/layers.py
@@ -9,21 +9,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hidet

from ..utils import Tensor, normalize_dim
from ..arithmetic import rsqrt
from .norm import normalize
from .norm_f16 import normalize_f16


def resolve_norm_func(dtype):
if dtype == hidet.float32:
return normalize
elif dtype == hidet.float16:
return normalize_f16
else:
raise NotImplementedError("normalize function for dtype {} is not implemented".format(dtype))


def batch_norm_infer(x: Tensor, running_mean: Tensor, running_var: Tensor, epsilon=1e-5, axis=1) -> Tensor:
@@ -58,8 +46,7 @@ def instance_norm(x: Tensor, epsilon: float = 1e-5, accumulate_dtype: str = 'flo
The normalized tensor.
"""
dims = [dim for dim in range(2, len(x.shape))]
norm_func = resolve_norm_func(x.dtype)
return norm_func(x, axis=dims, epsilon=epsilon, accumulate_dtype=accumulate_dtype)
return normalize(x, axis=dims, epsilon=epsilon, accumulate_dtype=accumulate_dtype)


def layer_norm(x: Tensor, num_last_dims: int = 1, epsilon: float = 1e-5, accumulate_dtype: str = 'float32') -> Tensor:
@@ -82,9 +69,8 @@ def layer_norm(x: Tensor, num_last_dims: int = 1, epsilon: float = 1e-5, accumul
ret: Tensor
The normalized tensor.
"""
norm_func = resolve_norm_func(x.dtype)
dims = list(range(len(x.shape) - num_last_dims, len(x.shape)))
return norm_func(x, axis=dims, epsilon=epsilon, accumulate_dtype=accumulate_dtype)
return normalize(x, axis=dims, epsilon=epsilon, accumulate_dtype=accumulate_dtype)


def group_norm(x: Tensor, num_groups, epsilon: float = 1e-5, accumulate_dtype: str = 'float32'):
@@ -119,7 +105,6 @@ def group_norm(x: Tensor, num_groups, epsilon: float = 1e-5, accumulate_dtype: s

x = x.reshape(new_shape)
dims = list(range(2, len(x.shape)))
norm_func = resolve_norm_func(x.dtype)
normed = norm_func(x, axis=dims, epsilon=epsilon, accumulate_dtype=accumulate_dtype)
normed = normalize(x, axis=dims, epsilon=epsilon, accumulate_dtype=accumulate_dtype)

return normed.reshape(x_shape)