[Frontend] Dynamic shape fx trace #294
Changes from 18 commits
```diff
@@ -145,12 +145,12 @@ def bilinear(x_1: Tensor, x_2: Tensor, weight: Tensor, bias: Optional[Tensor]):
 @register_function(operator.add)
 @register_function(torch.ops.aten.add.Tensor)
 def add(x: Tensor, y: Tensor):
-    return ops.add(x, y)
+    return x + y


 @register_function(operator.iadd)
 def iadd(x: Tensor, y: Tensor):
-    return ops.add(x, y)
+    return x + y


 @register_function(torch.sin)
```

Review comment (truncated): So the …
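As a rough illustration of the pattern above (hypothetical registry and helper names, not hidet's actual frontend code), `register_function` can be thought of as a table keyed by the original torch/Python callable, with the body deferring to `+` so the same rule also covers non-Tensor operands (for example, symbolic shape values) that dynamic-shape tracing can feed to `operator.add`:

```python
# Minimal sketch of a register_function-style mapping (hypothetical names;
# hidet's real registry lives in its torch frontend).
import operator
from typing import Callable, Dict

_REGISTRY: Dict[Callable, Callable] = {}

def register_function(torch_fn: Callable):
    """Associate a torch/Python callable with a frontend implementation."""
    def decorator(impl: Callable):
        _REGISTRY[torch_fn] = impl
        return impl
    return decorator

@register_function(operator.add)
def add(x, y):
    # Using `+` instead of a tensor-only add lets the same rule handle
    # tensors, Python ints, and symbolic shape values seen during tracing.
    return x + y

print(_REGISTRY[operator.add](2, 3))  # 5
```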
```diff
@@ -362,7 +362,7 @@ def zeros(*size, out=None, dtype=None, layout=None, device=None, pin_memory=Fals

 @register_function(torch.ones)
 def ones(
-    *size: Union[int, Sequence[int]],
+    *size: Union[Int, Sequence[Int]],
     out: Optional[Tensor] = None,
     dtype: Optional[torch.dtype] = None,
     layout: Optional[torch.layout] = None,
@@ -381,7 +381,7 @@ def ones(
     if isinstance(size[0], (list, tuple)):
         size = size[0]

-    shape = [int(v) for v in size]
+    shape = [v if isinstance(v, hidet.ir.Expr) else int(v) for v in size]
     if dtype is None:
         dtype = torch.get_default_dtype()
```
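A small sketch of the `ones` shape handling above, with a stand-in class in place of `hidet.ir.Expr` (illustration only): constant entries are coerced to `int`, while symbolic entries pass through unchanged so the graph keeps the dynamic dimension.

```python
class SymDim:
    """Stand-in for a symbolic dimension such as hidet.ir.Expr (illustration only)."""
    def __init__(self, name: str):
        self.name = name
    def __repr__(self) -> str:
        return self.name

def normalize_size(*size):
    # Accept both ones(1, 3, 224, 224) and ones([1, 3, 224, 224]).
    if len(size) == 1 and isinstance(size[0], (list, tuple)):
        size = size[0]
    # Keep symbolic dims as expressions; coerce everything else to int.
    return [v if isinstance(v, SymDim) else int(v) for v in size]

print(normalize_size(SymDim('batch'), 3, 224, 224))  # [batch, 3, 224, 224]
```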
This file was deleted.
```diff
@@ -11,16 +11,13 @@
 # limitations under the License.
 from typing import List, Optional, Callable, Any

-from hidet.ir import dtypes
+from hidet.ir.expr import is_constant
 from hidet.graph.operator import Operator, Tensor
 from hidet.graph.transforms import ResolveRule, register_resolve_rule
-from hidet.graph.ops.utils import is_contiguous_norm
-from hidet.utils import prod

 from .norm import NormalizeOp
-from .norm_f16 import normalize_f16


 @register_resolve_rule(NormalizeOp)
```
```diff
@@ -32,15 +29,6 @@ class NormalizeResolveRule(ResolveRule):
     2) resolve_generic: Default case, return the output of the regular f32 reduce schedule.
     """

-    def resolve_f16(self, op: Operator) -> Optional[List[Tensor]]:
-        dims = op.attrs['dims']
-        x: Tensor = op.inputs[0]
-        if not is_contiguous_norm(dims, len(x.shape)):
-            return None
-        if x.dtype != dtypes.float16 or prod([x.shape[dd] for dd in dims]) % 2 != 0:
-            return None
-        return [normalize_f16(x, dims)]
-
     def resolve_generic(self, op: Operator) -> Optional[List[Tensor]]:
         dims = op.attrs['dims']
         x: Tensor = op.inputs[0]
```

Review comment (on the class docstring): remove the resolve_fp16 comment above.

Review comment (on the removed fp16 check): removing this is safe for now, but we might need to think about how to handle it when we decide to use 2xfp16 types and the norm size is odd.
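The removed `resolve_f16` guard only accepted fp16 inputs whose normalized extent is even, which is what the comment about 2xfp16 types refers to: a half2-style schedule reads and writes fp16 elements in pairs. A small sketch of that check (made-up helper name, illustration only):

```python
from math import prod

def can_use_fp16x2(shape, dims, dtype='float16'):
    # A 2-wide fp16 (half2) schedule processes elements in pairs, so the
    # number of elements being normalized must be even.
    norm_size = prod(shape[d] for d in dims)
    return dtype == 'float16' and norm_size % 2 == 0

print(can_use_fp16x2([8, 1024], dims=[1]))  # True: 1024 pairs up evenly
print(can_use_fp16x2([8, 1023], dims=[1]))  # False: an odd extent breaks pairing
```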
```diff
@@ -57,7 +45,7 @@ def resolve(self, op: Operator) -> Optional[List[Tensor]]:
         assert isinstance(op, NormalizeOp)
         if not is_constant(*op.inputs[0].shape):
             return None
-        resolve_funcs: List[Callable[[Operator], Any]] = [self.resolve_f16, self.resolve_generic]
+        resolve_funcs: List[Callable[[Operator], Any]] = [self.resolve_generic]
         for resolve_func in resolve_funcs:
             outs = resolve_func(op)
             if outs is not None:
```
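The `is_constant(*op.inputs[0].shape)` check in `resolve` is what keeps this rule off dynamic-shape graphs: if any input dimension is still symbolic, the rule returns None and the operator's default path is used. A simplified sketch of that control flow (stand-in types, illustration only):

```python
from typing import List, Optional

class SymDim:
    """Stand-in for a symbolic dimension expression."""
    def __init__(self, name: str):
        self.name = name

def is_constant(*dims) -> bool:
    # Mirrors the intent of the guard: every dimension must be a concrete int.
    return all(isinstance(d, int) for d in dims)

def resolve(shape) -> Optional[List[str]]:
    if not is_constant(*shape):
        return None               # symbolic shape: fall back to the default path
    return ['resolve_generic']    # constant shape: resolve rules may specialize

print(resolve([8, 1024]))            # ['resolve_generic']
print(resolve([SymDim('n'), 1024]))  # None
```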
```diff
@@ -16,11 +16,12 @@


 @pytest.mark.parametrize('shape', [[1, 3, 224, 224]])
-def test_resnet50(shape):
+@pytest.mark.parametrize('dynamic', [False, True])
+def test_resnet50(shape, dynamic):
     torch.backends.cudnn.allow_tf32 = False  # disable tf32 for accuracy
-    model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet50', pretrained=True)
-    x = torch.randn(*shape)
-    check_module(model, [x], atol=1e-2, rtol=1e-2)
+    model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet50', pretrained=True).cuda().eval()
+    x = torch.randn(*shape).cuda()
+    check_module(model, [x], atol=1e-2, rtol=1e-2, dynamic=dynamic)

     torch.backends.cudnn.allow_tf32 = True
```

Review comment: Have we been using the CPU path before this change?
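For the new `dynamic` parameter, here is a hedged sketch of what a dynamic-shape run looks like from the user side, assuming the standard torch.compile entry point with hidet registered as a backend (the test's `check_module` helper wraps this kind of comparison; its internals are not shown here):

```python
import torch
import hidet  # noqa: F401  -- importing hidet registers the 'hidet' dynamo backend

model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet50', pretrained=True).cuda().eval()
compiled = torch.compile(model, backend='hidet', dynamic=True)  # trace with symbolic shapes

x = torch.randn(1, 3, 224, 224).cuda()
with torch.no_grad():
    torch.testing.assert_close(compiled(x), model(x), atol=1e-2, rtol=1e-2)
```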
Review comment:
For dynamic shapes, I am wondering whether these scalar parameters act as the shape of the input tensors. If that's the case, we can ignore those scalar parameters. Say a torch model gives us … We can declare the symbol variable for 'm' and 'n' (when we define the symbol tensor) and ignore the 'm' and 'n' scalar parameters.

Review comment:
Any clue on this?
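One way to read the suggestion above, as a sketch only (it assumes hidet.symbol accepts named symbolic dimensions; the exact API may differ): declare the input as a symbolic tensor whose 'm' and 'n' dimensions are symbols, and recover the scalar 'm'/'n' values from that tensor's shape instead of treating them as separate graph inputs.

```python
import hidet

# Assumption: hidet.symbol accepts string names for symbolic dimensions.
x = hidet.symbol(['m', 'n'], dtype='float32', device='cuda')

# The scalar shape parameters the traced graph would otherwise take as extra
# inputs are just the symbolic dims of x, so they can be ignored as inputs.
m, n = x.shape
print(m, n)  # symbolic expressions standing for the two dimensions
```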