[Frontend] Dynamic shape fx trace #294

Merged Jul 12, 2023 · 24 commits
Changes from 18 commits

Commits (24)
803d302
lint again for some reason
Aalanli Jun 9, 2023
a59ba09
lint again for some reason
Aalanli Jun 9, 2023
0be7d38
Merge branch 'main' of https://github.com/Aalanli/hidet into main
Aalanli Jun 9, 2023
06a9cd0
nevermind
Aalanli Jun 9, 2023
f550d2f
Merge branch 'main' of https://github.com/Aalanli/hidet into main
Aalanli Jun 12, 2023
fd19d60
Merge branch 'hidet-org:main' into main
Aalanli Jun 17, 2023
323e1a8
Merge branch 'main' of https://github.com/Aalanli/hidet into main
Aalanli Jun 26, 2023
d26f6aa
Merge branch 'main' of https://github.com/Aalanli/hidet into main
Aalanli Jun 26, 2023
7c4fdc1
add support for dynamic shape compilation
Jun 27, 2023
cb70f69
format/lint
Jun 27, 2023
f9ae927
add extra check
Aalanli Jun 29, 2023
87abfe9
add another helpful debug option
Aalanli Jun 29, 2023
20a61c6
minor fixes
Aalanli Jun 29, 2023
ad7f50f
fix norm bug
Aalanli Jun 29, 2023
8dcbc39
format/lint
Aalanli Jun 29, 2023
f58fed5
fix bug in predict correctness conditional
Jun 29, 2023
48b8190
Merge branch 'main' into dynamic-fx-trace
Jun 30, 2023
b92442c
lint
Jun 30, 2023
cf21a7d
Merge branch 'main' into dynamic-fx-trace
Jul 4, 2023
b83f399
slightly change
yaoyaoding Jul 11, 2023
f5ea3ab
fix a small bug in reshape
yaoyaoding Jul 11, 2023
3e4bc32
reshape fix
Aalanli Jul 12, 2023
1546bf4
Merge branch 'dynamic-fx-trace' of https://github.com/Aalanli/hidet i…
Aalanli Jul 12, 2023
4b6be25
fix tensor_size registered method
Aalanli Jul 12, 2023
53 changes: 40 additions & 13 deletions python/hidet/graph/frontend/torch/dynamo_backends.py
@@ -15,6 +15,7 @@
import torch
import hidet.option
from hidet.ir.type import data_type
from hidet.ir.expr import is_constant
from hidet.graph.flow_graph import FlowGraph
from hidet.graph.transforms import PassContext, optimize
from hidet.runtime import CompiledGraph
@@ -103,40 +104,66 @@ def hidet_backend(graph_module, example_inputs):
interpreter: Interpreter = hidet.frontend.from_torch(graph_module)

# prepare dummy and symbolic inputs for correctness and flow graph construction
symbolic_inputs: List[Tensor] = [] # for flow graph construction
# unfortunately, when dynamic=True in torch.compile, there may exist other non-tensor parameters
# in example inputs
Member comment:

For those dynamic shapes, I am wondering whether these scalar parameters act as the shapes of the input tensors. If that's the case, we can ignore those scalar parameters.

Say a torch model gives us

sample_inputs = [tensor(['m', 'n']), 'm', 'n']

We can declare the symbol variables for 'm' and 'n' (when we define the symbol tensor) and ignore the 'm' and 'n' scalar parameters.

Member comment:

Any clue on this?

inputs = [] # for flow graph construction
for example_input in example_inputs:
if isinstance(example_input, torch.Tensor):
symbolic_input = symbol_like_torch(example_input)
symbolic_inputs.append(symbolic_input)
inputs.append(symbolic_input)
elif isinstance(example_input, int):
inputs.append(example_input)
elif isinstance(example_input, torch.SymInt):
try:
inputs.append(int(example_input))
except Exception as e:
raise ValueError(f"hidet_backend: free symbolic example input {example_input}") from e
else:
raise ValueError('hidet_backend: only support torch.Tensor as example input')
raise ValueError(f'hidet_backend: unexpected example input {example_input}, type {type(example_input)}')

if dynamo_config['correctness_report']:
# check correctness using random inputs
logger.info('start to check correctness')
dummy_inputs: List[Tensor] = [] # for correctness check
for symbolic_input in symbolic_inputs:
if data_type(symbolic_input.dtype).is_integer():
dummy_input = hidet.zeros_like(symbolic_input)
# there exist some symbolic shapes; currently we don't support this option,
# as there is no principled way to get concrete shapes from symbolic shapes at this stage,
# since some models like resnet require the image to be above a certain size.
if any(not all(is_constant(s) for s in t.shape) for t in inputs if isinstance(t, hidet.Tensor)):
raise ValueError("hidet_backend: cannot print correctness report with dynamic=True")
dummy_inputs = [] # for correctness check
for arg in inputs:
if isinstance(arg, hidet.Tensor):
if data_type(arg.dtype).is_integer():
dummy_input = hidet.zeros_like(arg)
else:
dummy_input = hidet.randn_like(arg)
else:
dummy_input = hidet.randn_like(symbolic_input)
dummy_input = arg
dummy_inputs.append(dummy_input)
report: str = interpreter.forward_with_check(*dummy_inputs)
logger.info('finish checking correctness')
print(report)

logger.info('hidet: symbolic inputs: ')
for symbolic_input in symbolic_inputs:
logger.info('hidet: %s', symbolic_input.signature())
logger.info('hidet: inputs: ')
for arg in inputs:
if isinstance(arg, hidet.Tensor):
logger.info('hidet: %s', arg.signature())
else:
logger.info('hidet: %s', arg)

# symbolic run to get flow graph
output = interpreter(*symbolic_inputs)
output = interpreter(*inputs)
output_format, output_tensors = serialize_output(output)
flow_graph: FlowGraph = hidet.trace_from(output_tensors, inputs=symbolic_inputs)
input_tensors = list(filter(lambda x: isinstance(x, hidet.Tensor), inputs))
# essentially, I think this is a bug in torch._inductor
# the example inputs have instances of torch.SymInt (when dynamic=True), while the inputs to the compiled model
# are torch.Tensors.
input_map = [isinstance(x, hidet.Tensor) for x in inputs]
flow_graph: FlowGraph = hidet.trace_from(output_tensors, inputs=input_tensors)

executor = generate_executor(flow_graph)

def wrapper(*args: Tensor):
args = [t for (t, is_hidet_tensor) in zip(args, input_map) if is_hidet_tensor]
outputs: Sequence[torch.Tensor] = executor(*args)
ret = deserialize_output(output_format, outputs)
return ret
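
A note outside the diff: with dynamic=True, dynamo's example_inputs mix (fake) tensors with plain ints and torch.SymInt scalars that describe shapes, and only the tensor positions are forwarded to the compiled flow graph. Below is a minimal sketch of that filtering idea, with hypothetical helper names, not code from this PR:

from typing import Any, List, Tuple

import torch


def split_example_inputs(example_inputs: List[Any]) -> Tuple[List[torch.Tensor], List[bool]]:
    """Separate tensor inputs from scalar/SymInt shape parameters (hypothetical helper)."""
    tensors: List[torch.Tensor] = []
    is_tensor: List[bool] = []
    for arg in example_inputs:
        if isinstance(arg, torch.Tensor):
            tensors.append(arg)
            is_tensor.append(True)
        else:
            # ints / torch.SymInt entries only describe shapes; they are not graph inputs
            is_tensor.append(False)
    return tensors, is_tensor


# at call time, the wrapper drops the non-tensor positions, mirroring input_map in the diff above:
# args = [a for a, keep in zip(args, is_tensor) if keep]
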
51 changes: 46 additions & 5 deletions python/hidet/graph/frontend/torch/interpreter.py
@@ -149,6 +149,22 @@ def tensor_from_torch(tensor: torch.Tensor) -> Tensor:
return hidet.graph.tensor.from_torch(tensor)


def is_torch_path(name: str) -> bool:
name = name.split(".")
if len(name) > 0:
return name[0] == "torch"
return False


def belong_to_torch(code_obj) -> bool:
belongs = False
if hasattr(code_obj, "__module__") and code_obj.__module__ is not None:
belongs |= is_torch_path(code_obj.__module__)
if not belongs and hasattr(code_obj, "__package__") and code_obj.__package__ is not None:
belongs |= is_torch_path(code_obj.__package__)
return belongs


class HidetModule:
def __init__(self, torch_module: torch.nn.Module):
self.mod: torch.nn.Module = torch_module
@@ -182,6 +198,15 @@ def __init__(self, graph_module: torch.fx.GraphModule):
self.torch_modules: Dict[str, torch.nn.Module] = dict(graph_module.named_modules())
self.hidet_modules: Dict[str, HidetModule] = {}

# basically, dynamo further wraps some builtin functions with local wrapper functions
# which get dispatched incorrectly
self.ignore_funcs: Dict[str, Callable] = {
# see torch._dynamo.variables.lists.SizeVariable.get_item_dyn
# this signifies that the target of getitem is a torch.Size, we overload torch.Tensor.size by
# returning a list, so this method needs to be overloaded in the interpreter as well
'_dynamo_get_item_lambda': lambda target, index: target[index]
}

self._check_support()

def __call__(self, *args):
@@ -210,8 +235,10 @@ def _check_support(self):
if torch_cls not in Registry.registered_modules:
not_supported.add(torch_cls)
elif node.op == "call_function":
if node.target not in Registry.registered_functions:
target_fn = self._lookup_function(node.target)
if target_fn is None:
not_supported.add(node.target)

if len(not_supported) > 0:
lines = []
lines.append("The following modules/functions are not supported by hidet yet:")
@@ -233,6 +260,20 @@ def _lookup_hidet_method(self, torch_method):
raise NotImplementedError(f"hidet: method {method_name} is not supported yet.")
return Registry.registered_methods[torch_method]

def _lookup_function(self, code_obj):
if code_obj.__name__ in self.ignore_funcs:
return self.ignore_funcs[code_obj.__name__]
if belong_to_torch(code_obj):
if code_obj in Registry.registered_functions:
return Registry.registered_functions[code_obj]
else:
return None
else:
# this branch handles all the other cases, such as getitem, operator.add, etc.
# since the inputs are all hidet tensors, applying this function should resolve to
# the actual traced implementation
return code_obj

@staticmethod
def _callable_info(f: Callable) -> Tuple[str, str, int]:
if inspect.ismethod(f):
@@ -315,13 +356,13 @@ def load_arg(a, env):
attr = getattr(attr, atom)
hidet_env[node.name] = tensor_from_torch(attr) if isinstance(attr, torch.Tensor) else attr
elif node.op == "call_function":
hidet_func = Registry.registered_functions[node.target]
exec_func = self._lookup_function(node.target)
hidet_args = load_arg(node.args, hidet_env)
hidet_kwargs = load_arg(node.kwargs, hidet_env)
try:
hidet_env[node.name] = hidet_func(*hidet_args, **hidet_kwargs)
hidet_env[node.name] = exec_func(*hidet_args, **hidet_kwargs)
except Exception as e:
self._raise_exception(e, node.target, hidet_func, hidet_args, hidet_kwargs)
self._raise_exception(e, node.target, exec_func, hidet_args, hidet_kwargs)
elif node.op == "call_method":
args = load_arg(node.args, hidet_env)
kwargs = load_arg(node.kwargs, hidet_env)
@@ -403,7 +444,7 @@ def load_arg(a, env):
torch_kwargs = load_arg(node.kwargs, torch_env)
torch_env[node.name] = torch_func(*torch_args, **torch_kwargs)

hidet_func = Registry.registered_functions[torch_func]
hidet_func = self._lookup_function(node.target)
hidet_args = load_arg(node.args, hidet_env)
hidet_kwargs = load_arg(node.kwargs, hidet_env)

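To summarize the dispatch policy that _lookup_function implements in the diff above: callables owned by the torch namespace must have a registered hidet equivalent (otherwise they are reported as unsupported), dynamo's wrapper lambdas are special-cased, and everything else (operator.add, getitem, ...) is called directly, since the arguments are already hidet values. A standalone sketch with stand-in names for the registry, not the PR's actual code:

import operator
from typing import Callable, Dict, Optional

REGISTERED: Dict[Callable, Callable] = {}      # stand-in for Registry.registered_functions
IGNORED: Dict[str, Callable] = {
    # dynamo wraps torch.Size indexing in this lambda; dispatch it to plain getitem
    '_dynamo_get_item_lambda': lambda target, index: target[index],
}


def owned_by_torch(fn: Callable) -> bool:
    module = getattr(fn, '__module__', None) or getattr(fn, '__package__', None) or ''
    return module.split('.')[0] == 'torch'


def lookup(fn: Callable) -> Optional[Callable]:
    if getattr(fn, '__name__', '') in IGNORED:
        return IGNORED[fn.__name__]
    if owned_by_torch(fn):
        return REGISTERED.get(fn)              # None means "not supported yet"
    return fn                                  # non-torch callables are applied directly


assert lookup(operator.getitem) is operator.getitem
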
8 changes: 4 additions & 4 deletions python/hidet/graph/frontend/torch/register_functions.py
@@ -145,12 +145,12 @@ def bilinear(x_1: Tensor, x_2: Tensor, weight: Tensor, bias: Optional[Tensor]):
@register_function(operator.add)
@register_function(torch.ops.aten.add.Tensor)
def add(x: Tensor, y: Tensor):
return ops.add(x, y)
return x + y


@register_function(operator.iadd)
def iadd(x: Tensor, y: Tensor):
return ops.add(x, y)
return x + y
Member comment: So the x and y could be DynInt?



@register_function(torch.sin)
@@ -362,7 +362,7 @@ def zeros(*size, out=None, dtype=None, layout=None, device=None, pin_memory=Fals

@register_function(torch.ones)
def ones(
*size: Union[int, Sequence[int]],
*size: Union[Int, Sequence[Int]],
out: Optional[Tensor] = None,
dtype: Optional[torch.dtype] = None,
layout: Optional[torch.layout] = None,
@@ -381,7 +381,7 @@ def ones(
if isinstance(size[0], (list, tuple)):
size = size[0]

shape = [int(v) for v in size]
shape = [v if isinstance(v, hidet.ir.Expr) else int(v) for v in size]
if dtype is None:
dtype = torch.get_default_dtype()

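Regarding the review question above about x and y possibly being DynInt: returning x + y instead of ops.add(x, y) defers to Python's + operator, so the same registration works whether the traced values are hidet tensors, plain ints coming from shapes, or symbolic integer expressions. A trivial, dependency-free illustration (ints only):

import operator


def add(x, y):
    # what the registered function now does: defer to the operands' own __add__
    return x + y


assert add(2, 3) == operator.add(2, 3) == 5
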
9 changes: 7 additions & 2 deletions python/hidet/graph/frontend/torch/register_methods.py
@@ -14,7 +14,7 @@
from typing import List, Union
import torch

from hidet.ir.type import DataType
from hidet.ir.type import DataType, Int
from hidet.graph.tensor import Tensor
from hidet.graph import ops
from hidet.runtime.device import instantiate_device
@@ -130,7 +130,7 @@ def tensor_view(self: Tensor, *args) -> Tensor:
else:
if len(args) == 1 and isinstance(args[0], (list, tuple)):
args = args[0]
dst_shape = [int(arg) for arg in args]
dst_shape = list(args)
return ops.reshape(self, dst_shape)


@@ -161,6 +161,11 @@ def tensor_split(self: Tensor, split_size, dim=0) -> List[Tensor]:
return ops.split(self, axis=dim, parts_or_sections=parts)


@register_method(torch.Tensor.size)
def tensor_size(self: Tensor) -> List[Int]:
return self.shape


@register_method(torch.Tensor.chunk)
def tensor_chunk(self: Tensor, chunks, dim=0) -> List[Tensor]:
dim_size = self.shape[dim]
15 changes: 13 additions & 2 deletions python/hidet/graph/frontend/torch/utils.py
@@ -112,8 +112,19 @@ def device_from_torch(torch_device) -> Device:
def symbol_like_torch(tensor) -> Tensor:
import hidet
import torch

if isinstance(tensor, torch.Tensor):
from torch._subclasses.fake_tensor import FakeTensor

if isinstance(tensor, FakeTensor):
# this should be fine for now; torch wraps around the sympy library
symbolic_shape = []
for s in tensor.shape:
try:
i = int(s)
except Exception: # pylint: disable=broad-except
i = str(s)
symbolic_shape.append(i)
return hidet.symbol(shape=symbolic_shape, dtype=dtype_from_torch(tensor.dtype).name, device=tensor.device.type)
elif isinstance(tensor, torch.Tensor):
return hidet.symbol(
shape=list(tensor.shape), dtype=dtype_from_torch(tensor.dtype).name, device=tensor.device.type
)
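
For intuition, here is roughly what the FakeTensor branch above produces: constant dimensions stay as ints, while symbolic dimensions are passed to hidet.symbol as name strings. The dimension name below is made up for illustration:

import hidet

# a FakeTensor traced with dynamic=True might report shape (s0, 3, 224, 224);
# symbol_like_torch then builds a symbolic hidet tensor along the lines of:
x = hidet.symbol(shape=['s0', 3, 224, 224], dtype='float32', device='cpu')
print(x.shape)   # the first entry is a symbolic expression, the rest are constants
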
10 changes: 9 additions & 1 deletion python/hidet/graph/graph_utils/instruments/debug_instrument.py
@@ -19,10 +19,11 @@
class GraphForwardDebugInstrument(GraphForwardInstrument):
_template = '{:>5} {:>30} {:>3} {:<25} {:>8} {:>8} {:>8} {:>10} {:>10} {:>10} {:>10}'

def __init__(self, output_dir='./outs/debug', print_summary=False, dump_outputs=False):
def __init__(self, output_dir='./outs/debug', print_summary=False, dump_outputs=False, dump_op=False):
self.output_dir: str = output_dir
self.print_summary: bool = print_summary
self.dump_outputs: bool = dump_outputs
self.dump_op: bool = dump_op

self.debugging: bool = False
self.summary_file: Optional[str] = None
@@ -141,6 +142,13 @@ def after_operator(self, op: Operator, inputs: List[Tensor], outputs: List[Tenso
with open(array_path, 'w') as f:
with np.printoptions(precision=8, edgeitems=30, linewidth=512):
f.write(str(array))
if self.dump_op:
op_path = os.path.join(
self.output_dir, '{}_{}{}.txt'.format(self.operator_idx, op.name, f'_def{idx}' if idx > 0 else '')
)
with open(op_path, 'w') as f:
f.write('Operator:\n{}\n'.format(op))
f.write('Task:\n{}\n'.format(op.task))

with open(self.summary_file, 'a') as f:
f.write('\n'.join(lines) + '\n')
6 changes: 3 additions & 3 deletions python/hidet/graph/ops/create.py
@@ -21,7 +21,7 @@

class FullTask(Task):
def __init__(
self, shape: Sequence[int], value: Union[int, float, bool, Constant, Expr], dtype: Union[DataType, str]
self, shape: Sequence[Int], value: Union[int, float, bool, Constant, Expr], dtype: Union[DataType, str]
):
dtype: DataType = data_type(dtype)
value: Constant = dtype(value) if isinstance(value, (int, float, bool)) else value
@@ -123,12 +123,12 @@ def infer_dtype(self, start, stop, step):
class FullOp(Operator):
def __init__(
self,
shape: Sequence[int],
shape: Sequence[Int],
value: Union[float, int, bool, Constant, Tensor],
dtype: Optional[DataType] = None,
device: Union[Device, str] = 'cpu',
):
shape = [int(v) for v in shape]
shape = list(shape)
device: Device = instantiate_device(device)

if isinstance(value, Tensor):
21 changes: 3 additions & 18 deletions python/hidet/graph/ops/normalize/layers.py
@@ -9,21 +9,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hidet

from ..utils import Tensor, normalize_dim
from ..arithmetic import rsqrt
from .norm import normalize
from .norm_f16 import normalize_f16


def resolve_norm_func(dtype):
if dtype == hidet.float32:
return normalize
elif dtype == hidet.float16:
return normalize_f16
else:
raise NotImplementedError("normalize function for dtype {} is not implemented".format(dtype))


def batch_norm_infer(x: Tensor, running_mean: Tensor, running_var: Tensor, epsilon=1e-5, axis=1) -> Tensor:
@@ -58,8 +46,7 @@ def instance_norm(x: Tensor, epsilon: float = 1e-5, accumulate_dtype: str = 'flo
The normalized tensor.
"""
dims = [dim for dim in range(2, len(x.shape))]
norm_func = resolve_norm_func(x.dtype)
return norm_func(x, axis=dims, epsilon=epsilon, accumulate_dtype=accumulate_dtype)
return normalize(x, axis=dims, epsilon=epsilon, accumulate_dtype=accumulate_dtype)


def layer_norm(x: Tensor, num_last_dims: int = 1, epsilon: float = 1e-5, accumulate_dtype: str = 'float32') -> Tensor:
@@ -82,9 +69,8 @@ def layer_norm(x: Tensor, num_last_dims: int = 1, epsilon: float = 1e-5, accumul
ret: Tensor
The normalized tensor.
"""
norm_func = resolve_norm_func(x.dtype)
dims = list(range(len(x.shape) - num_last_dims, len(x.shape)))
return norm_func(x, axis=dims, epsilon=epsilon, accumulate_dtype=accumulate_dtype)
return normalize(x, axis=dims, epsilon=epsilon, accumulate_dtype=accumulate_dtype)


def group_norm(x: Tensor, num_groups, epsilon: float = 1e-5, accumulate_dtype: str = 'float32'):
@@ -119,7 +105,6 @@ def group_norm(x: Tensor, num_groups, epsilon: float = 1e-5, accumulate_dtype: s

x = x.reshape(new_shape)
dims = list(range(2, len(x.shape)))
norm_func = resolve_norm_func(x.dtype)
normed = norm_func(x, axis=dims, epsilon=epsilon, accumulate_dtype=accumulate_dtype)
normed = normalize(x, axis=dims, epsilon=epsilon, accumulate_dtype=accumulate_dtype)

return normed.reshape(x_shape)
44 changes: 27 additions & 17 deletions python/hidet/graph/ops/normalize/norm.py
@@ -15,7 +15,7 @@
from hidet.ir.compute import reduce
from hidet.ir.expr import Expr
from hidet.lang import spatial, repeat, view, cast, register_tensor, shared_tensor
from hidet.lang import data_type, TensorType, i32, f32, attrs
from hidet.lang import data_type, TensorType, i32, attrs
from hidet.lang.cuda import blockIdx, threadIdx, syncthreads
from hidet.graph.ops.utils import Task, Operator, Tensor, TensorNode
from hidet.graph.ops.utils import compute, input_like, normalize_dim
@@ -111,6 +111,8 @@ def implement_cuda(self, working_dir: str) -> IRModule:
import math

x, y = self.inputs[0], self.outputs[0]
dtype = x.type.dtype

input_shape: List[Expr] = list(x.shape)
dims = self.dims

@@ -152,14 +154,20 @@ def welford_combine(
return
delta = mean_b[0] - mean_a[0]

mean_a[0] = mean_a[0] + delta * cast(count_b[0], f32) / cast(count, f32)
mean_a[0] = mean_a[0] + delta * cast(count_b[0], accumulate_dtype) / cast(count, accumulate_dtype)
m2_a[0] = (
m2_a[0] + m2_b[0] + delta * delta * cast(count_a[0], f32) * cast(count_b[0], f32) / cast(count, f32)
m2_a[0]
+ m2_b[0]
+ delta
* delta
* cast(count_a[0], accumulate_dtype)
* cast(count_b[0], accumulate_dtype)
/ cast(count, accumulate_dtype)
)
count_a[0] = count

@hidet.script
def norm_kernel(x: f32[x.shape], y: f32[y.shape]):
def norm_kernel(x: dtype[x.shape], y: dtype[y.shape]):
attrs.cuda.grid_dim = grid_size
attrs.cuda.block_dim = block_size
attrs.cuda.min_blocks = 1
@@ -170,9 +178,9 @@ def norm_kernel(x: f32[x.shape], y: f32[y.shape]):
smem_count = shared_tensor(i32, shape=[used_smem_bytes_per_block])

# cache repeated loads
regs_repeat = register_tensor(f32, shape=[repeat_reduction])
regs_repeat = register_tensor(dtype, shape=[repeat_reduction])

reg32 = register_tensor(f32, [1])
reg_dtype = register_tensor(dtype, [1])
mean_final = register_tensor(accumulate_dtype, [1])
m2_final = register_tensor(accumulate_dtype, [1])
count_final = register_tensor('int32', [1])
@@ -185,7 +193,7 @@ def norm_kernel(x: f32[x.shape], y: f32[y.shape]):
# note, this is evaluated at compile time
ele_idx = spatial_idxs + dim_zeros
norm_tensor = ~x[ele_idx]
flat_tensor = view(norm_tensor, f32[reduce_extent])
flat_tensor = view(norm_tensor, dtype[reduce_extent])

reduce_mapping = repeat(repeat_reduction) * spatial(block_size)
for reduction_idx in reduce_mapping.on(threadIdx.x):
@@ -197,15 +205,15 @@ def norm_kernel(x: f32[x.shape], y: f32[y.shape]):
other_count = register_tensor('int32', [1])

if reduction_idx < reduce_extent:
reg32[0] = flat_tensor[reduction_idx]
reg_dtype[0] = flat_tensor[reduction_idx]
count[0] = 1
else:
reg32[0] = f32.zero
reg_dtype[0] = dtype.zero
count[0] = 0
regs_repeat[reduction_idx // block_size] = reg32[0]
regs_repeat[reduction_idx // block_size] = reg_dtype[0]

mean[0] = reg32[0]
m2[0] = f32.zero
mean[0] = reg_dtype[0]
m2[0] = accumulate_dtype.zero

# Warp reduce by shuffle down
mask = active_mask()
@@ -243,8 +251,8 @@ def norm_kernel(x: f32[x.shape], y: f32[y.shape]):

# reduce shared memory with just a single warp
if stages > 1 and threadIdx.x < warp_size:
mean[0] = smem_mean[threadIdx.x] if threadIdx.x < shm_count else f32.zero
m2[0] = smem_m2[threadIdx.x] if threadIdx.x < shm_count else f32.zero
mean[0] = smem_mean[threadIdx.x] if threadIdx.x < shm_count else accumulate_dtype.zero
m2[0] = smem_m2[threadIdx.x] if threadIdx.x < shm_count else accumulate_dtype.zero
count[0] = smem_count[threadIdx.x] if threadIdx.x < shm_count else 0

syncthreads()
@@ -292,18 +300,20 @@ def norm_kernel(x: f32[x.shape], y: f32[y.shape]):
welford_combine(mean_final, m2_final, count_final, mean, m2, count)

# end of mean and var calculation, perform write back
m2_final[0] = m2_final[0] / cast(count_final[0], f32)
m2_final[0] = m2_final[0] / cast(count_final[0], dtype)

for spatial_idxs in task_layout.on(blockIdx.x, bind_tuple=True):
ele_idx = spatial_idxs + dim_zeros
norm_tensor = ~y[ele_idx]
flat_tensor = view(norm_tensor, f32[reduce_extent])
flat_tensor = view(norm_tensor, dtype[reduce_extent])

reduce_mapping = repeat(repeat_reduction) * spatial(block_size)
for reduction_idx in reduce_mapping.on(threadIdx.x):
if reduction_idx < reduce_extent:
val = regs_repeat[reduction_idx // block_size]
normed = (val - mean_final[0]) * prim.rsqrt(m2_final[0] + self.attrs['epsilon'])
normed = (val - mean_final[0]) * prim.rsqrt(
m2_final[0] + cast(self.attrs['epsilon'], accumulate_dtype)
)
flat_tensor[reduction_idx] = normed

ir_module = module.ir_module()
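
For reference, the welford_combine updates above implement the standard parallel variance combination (Chan et al.); the PR only generalizes the accumulator type from hard-coded f32 to accumulate_dtype. A plain-Python sketch of the same math, with a quick check against a direct computation:

import statistics


def welford_combine(mean_a, m2_a, n_a, mean_b, m2_b, n_b):
    # combine two partial (mean, M2, count) aggregates; variance = M2 / n
    n = n_a + n_b
    if n == 0:
        return mean_a, m2_a, n_a
    delta = mean_b - mean_a
    mean = mean_a + delta * n_b / n
    m2 = m2_a + m2_b + delta * delta * n_a * n_b / n
    return mean, m2, n


xs, ys = [1.0, 2.0, 3.0], [10.0, 20.0]
mean, m2, n = welford_combine(statistics.mean(xs), statistics.pvariance(xs) * len(xs), len(xs),
                              statistics.mean(ys), statistics.pvariance(ys) * len(ys), len(ys))
assert abs(m2 / n - statistics.pvariance(xs + ys)) < 1e-9
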
276 changes: 0 additions & 276 deletions python/hidet/graph/ops/normalize/norm_f16.py

This file was deleted.

14 changes: 1 addition & 13 deletions python/hidet/graph/ops/normalize/resolve.py
@@ -11,16 +11,13 @@
# limitations under the License.
from typing import List, Optional, Callable, Any

from hidet.ir import dtypes
from hidet.ir.expr import is_constant
from hidet.graph.operator import Operator, Tensor
from hidet.graph.transforms import ResolveRule, register_resolve_rule
from hidet.graph.ops.utils import is_contiguous_norm
from hidet.utils import prod


from .norm import NormalizeOp
from .norm_f16 import normalize_f16


@register_resolve_rule(NormalizeOp)
@@ -32,15 +29,6 @@ class NormalizeResolveRule(ResolveRule):
2) resolve_generic: Default case, return the output of the regular f32 reduce schedule.
Collaborator comment: remove the resolve_fp16 comment above

"""

def resolve_f16(self, op: Operator) -> Optional[List[Tensor]]:
dims = op.attrs['dims']
x: Tensor = op.inputs[0]
if not is_contiguous_norm(dims, len(x.shape)):
return None
if x.dtype != dtypes.float16 or prod([x.shape[dd] for dd in dims]) % 2 != 0:
Collaborator comment (@xinli-git, Jul 2, 2023): removing this is safe for now, but we might need to think about how to handle it when we decide to use 2xfp16 types and the norm size is odd.

return None
return [normalize_f16(x, dims)]

def resolve_generic(self, op: Operator) -> Optional[List[Tensor]]:
dims = op.attrs['dims']
x: Tensor = op.inputs[0]
@@ -57,7 +45,7 @@ def resolve(self, op: Operator) -> Optional[List[Tensor]]:
assert isinstance(op, NormalizeOp)
if not is_constant(*op.inputs[0].shape):
return None
resolve_funcs: List[Callable[[Operator], Any]] = [self.resolve_f16, self.resolve_generic]
resolve_funcs: List[Callable[[Operator], Any]] = [self.resolve_generic]
for resolve_func in resolve_funcs:
outs = resolve_func(op)
if outs is not None:
4 changes: 2 additions & 2 deletions python/hidet/testing/torch_utils.py
@@ -24,12 +24,12 @@ def forward(self, *args, **kwargs):
return self.op(*args, **kwargs)


def check_module(model: torch.nn.Module, args: Sequence[torch.Tensor], atol=1e-4, rtol=1e-4):
def check_module(model: torch.nn.Module, args: Sequence[torch.Tensor], atol=1e-4, rtol=1e-4, dynamic=False):
model = model.cuda()
model.eval()
args = [x.cuda() if isinstance(x, torch.Tensor) else x for x in args]
# we use a lambda to make sure the model is compiled by pytorch
model_opt = torch.compile(lambda *args, **kwargs: model(*args, **kwargs), backend='hidet')
model_opt = torch.compile(lambda *args, **kwargs: model(*args, **kwargs), backend='hidet', dynamic=dynamic)
torch_outputs = model(*args)
hidet_outputs = model_opt(*args)
if isinstance(torch_outputs, torch.Tensor):
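
A hedged usage sketch of the new dynamic flag (the module and shapes below are made up; it assumes a CUDA device and that hidet's dynamo backend is available after importing hidet):

import torch
import hidet  # assumed to make the 'hidet' backend available to torch.compile
from hidet.testing.torch_utils import check_module


class TinyMLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(16, 8)

    def forward(self, x):
        return torch.relu(self.fc(x))


# static-shape path, then the dynamic-shape path exercised by this PR
check_module(TinyMLP(), [torch.randn(4, 16)], atol=1e-4, rtol=1e-4, dynamic=False)
check_module(TinyMLP(), [torch.randn(4, 16)], atol=1e-4, rtol=1e-4, dynamic=True)
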
8 changes: 5 additions & 3 deletions tests/frontends/torch/test_torch_bert.py
@@ -18,21 +18,23 @@
@pytest.mark.parametrize('batch_size', [1])
@pytest.mark.parametrize('seq_length', [128])
@pytest.mark.parametrize('use_fp16,use_tensor_core', [(False, False), (False, True), (True, True)])
def test_bert(batch_size: int, seq_length: int, use_fp16, use_tensor_core):
@pytest.mark.parametrize('dynamic', [False, True])
def test_bert(batch_size: int, seq_length: int, use_fp16, use_tensor_core, dynamic):
tokens_tensor = torch.zeros((batch_size, seq_length), dtype=torch.long, device='cuda')
segments_tensors = torch.zeros((batch_size, seq_length), dtype=torch.long, device='cuda')
args = (tokens_tensor.cuda(),)
kwargs = {'token_type_ids': segments_tensors.cuda()}
model = torch.hub.load('huggingface/pytorch-transformers', 'model', 'bert-base-uncased').cuda().eval()
model_opt = torch.compile(model, backend='hidet')
model_opt = torch.compile(model, backend='hidet', dynamic=dynamic)
y1 = model(*args, **kwargs).last_hidden_state

try:
hidet.torch.dynamo_config.use_fp16(use_fp16)
hidet.torch.dynamo_config.use_tensor_core(use_tensor_core)

y2 = model_opt(*args, **kwargs).last_hidden_state
torch.testing.assert_close(y1, y2, atol=1e-2, rtol=1e-2)
tol = 1e-1 if use_fp16 else 1e-2
torch.testing.assert_close(y1, y2, atol=tol, rtol=tol)
finally:
# in case of failure, reset the config
hidet.torch.dynamo_config.reset()
9 changes: 5 additions & 4 deletions tests/frontends/torch/test_torch_resnet50.py
@@ -16,11 +16,12 @@


@pytest.mark.parametrize('shape', [[1, 3, 224, 224]])
def test_resnet50(shape):
@pytest.mark.parametrize('dynamic', [False, True])
def test_resnet50(shape, dynamic):
torch.backends.cudnn.allow_tf32 = False # disable tf32 for accuracy
model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet50', pretrained=True)
x = torch.randn(*shape)
check_module(model, [x], atol=1e-2, rtol=1e-2)
model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet50', pretrained=True).cuda().eval()
x = torch.randn(*shape).cuda()
check_module(model, [x], atol=1e-2, rtol=1e-2, dynamic=dynamic)
Collaborator comment: Have we been using the CPU path before this change?

torch.backends.cudnn.allow_tf32 = True