diff --git a/python/hidet/graph/ops/pool.py b/python/hidet/graph/ops/pool.py
index ddbf8e9e0..feddf7dbf 100644
--- a/python/hidet/graph/ops/pool.py
+++ b/python/hidet/graph/ops/pool.py
@@ -349,7 +349,7 @@ def __init__(
 class AdaptivePoolNdOp(Operator):
     spatial_ndim: Optional[int] = None
     reduce_type: Optional[str] = None
-    last_channel_layout: Optional[bool] = None
+    last_channel: Optional[bool] = None
 
     def __init__(self, x: Tensor, output_size):
         if len(x.shape) != self.spatial_ndim + 2:
@@ -362,7 +362,7 @@ def __init__(self, x: Tensor, output_size):
         self.reduce_type = self.reduce_type
 
         # todo: merge AdaptivePoolTask and AdaptivePoolChannelLastTask into one class
-        if self.last_channel_layout:
+        if self.last_channel:
             task = AdaptivePoolChannelLastTask(input_like(x, 'x'), output_size, reduce_type=self.reduce_type)
         else:
             task = AdaptivePoolTask(input_like(x, 'x'), output_size, reduce_type=self.reduce_type)
@@ -433,43 +433,43 @@ class AvgPool3dChannelLastOp(AvgPoolNdOp):
 class AdaptiveAvgPool1dOp(AdaptivePoolNdOp):
     reduce_type = 'avg'
     spatial_ndim = 1
-    last_channel_layout = False
+    last_channel = False
 
 
 class AdaptiveAvgPool2dOp(AdaptivePoolNdOp):
     reduce_type = 'avg'
     spatial_ndim = 2
-    last_channel_layout = False
+    last_channel = False
 
 
 class AdaptiveAvgPool3dOp(AdaptivePoolNdOp):
     reduce_type = 'avg'
     spatial_ndim = 3
-    last_channel_layout = False
+    last_channel = False
 
 
 class AdaptiveAvgPool2dChannelLastOp(AdaptivePoolNdOp):
     reduce_type = 'avg'
     spatial_ndim = 2
-    last_channel_layout = True
+    last_channel = True
 
 
 class AdaptiveMaxPool1dOp(AdaptivePoolNdOp):
     reduce_type = 'max'
     spatial_ndim = 1
-    last_channel_layout = False
+    last_channel = False
 
 
 class AdaptiveMaxPool2dOp(AdaptivePoolNdOp):
     reduce_type = 'max'
     spatial_ndim = 2
-    last_channel_layout = False
+    last_channel = False
 
 
 class AdaptiveMaxPool3dOp(AdaptivePoolNdOp):
     reduce_type = 'max'
     spatial_ndim = 3
-    last_channel_layout = False
+    last_channel = False
 
 
 def max_pool1d(x: Tensor, kernel, stride, padding, ceil_mode=False) -> Tensor:
@@ -575,8 +575,10 @@ def adaptive_avg_pool2d_channel_last(x: Tensor, output_size: Union[int, Sequence
 
 @register_resolve_rule(AdaptivePoolNdOp)
 class AdaptivePoolResolveRule(ResolveRule):
-    def resolve(self, op: Operator) -> Optional[List[Tensor]]:
+    def resolve(self, op: AdaptivePoolNdOp) -> Optional[List[Tensor]]:
         assert isinstance(op, AdaptivePoolNdOp)
+        if op.last_channel:
+            return None
         x: Tensor = op.inputs[0]
         output_size = op.attrs['output_size']
         reduce_type = op.reduce_type
@@ -590,24 +592,3 @@ def resolve(self, op: Operator) -> Optional[List[Tensor]]:
         elif reduce_type == 'avg':
             return [mean(x, dims=dims[2:], keep_dim=True)]
         return None
-
-
-@register_resolve_rule(AdaptivePoolChannelLastOp)
-class AdaptivePoolChannelLastResolveRule(ResolveRule):
-    def resolve(self, op: Operator) -> Optional[List[Tensor]]:
-        assert isinstance(op, AdaptivePoolChannelLastOp)
-        x: Tensor = op.inputs[0]
-        # TODO: Deal with generic N-dimensional convolution
-        if len(x.shape) != 4:
-            return None
-        output_size = op.attrs['output_size']
-        reduce_type = op.reduce_type
-        resolve_to_reduce = output_size == 1 if isinstance(output_size, int) else all(d == 1 for d in output_size)
-        if resolve_to_reduce:
-            from hidet.graph.ops import mean, max
-
-            if reduce_type == 'max':
-                return [max(x, dims=[1, 2], keep_dim=True)]
-            elif reduce_type == 'avg':
-                return [mean(x, dims=[1, 2], keep_dim=True)]
-        return None
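
For reviewers who want to sanity-check the surviving resolve rule, the sketch below shows the equivalence it relies on: for a channel-first (NCHW) input with output_size == 1, the adaptive average pool is the same as a mean over the spatial dimensions (dims[2:]) with keep_dim=True. This is a minimal sketch, not part of the patch; it assumes the public hidet.randn, hidet.ops.adaptive_avg_pool2d, and hidet.ops.mean entry points and a CPU default device, and the tensor shape and tolerances are illustrative only.

import numpy as np
import hidet

# Channel-first (NCHW) input; the shape is illustrative.
x = hidet.randn([2, 3, 8, 8])

# output_size == 1: the resolve rule rewrites the adaptive pool into a mean
# over the spatial dims (dims[2:]) with keep_dim=True.
pooled = hidet.ops.adaptive_avg_pool2d(x, output_size=1)
reduced = hidet.ops.mean(x, dims=[2, 3], keep_dim=True)

np.testing.assert_allclose(pooled.numpy(), reduced.numpy(), rtol=1e-5, atol=1e-5)

The same check for a channel-last (NHWC) input would need dims=[1, 2] rather than dims[2:], which is why the merged rule returns None for last_channel ops and leaves them to AdaptivePoolChannelLastTask instead of resolving them to a reduction here.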