feat: support prod, max, min, and mean via reduce layer #2355
Conversation
```python
):
    input_val = cast_trt_tensor(ctx, input_val, trt.float32, name)

    if dim is None:
```
Based on torch behavior, it seems that torch.mean can also accept an empty list as an indicator to reduce over all dimensions (but min, max, and prod don't seem to accept this). Consider switching this to:

```python
if dim is None or (isinstance(dim, (tuple, list)) and len(dim) == 0):
```
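For illustration, a minimal sketch of how that guard could normalize the argument; the `tuple(range(...))` fallback is an assumption made for this sketch, not code from the PR:

```python
# Treat both None and an empty tuple/list of dims as "reduce over all axes".
if dim is None or (isinstance(dim, (tuple, list)) and len(dim) == 0):
    dim = tuple(range(len(input_val.shape)))  # assumed fallback: every axis
```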
Good catch! I also checked the docs for amax and sum; they accept an empty sequence as well, so I made changes accordingly.
Force-pushed from d4f1451 to 6f136de.
```python
@dynamo_tensorrt_converter(torch.ops.aten.min.default)  # type: ignore[misc]
def aten_ops_min(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return impl.reduce.min(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        args[0],
        dim=None,
        keepdim=False,
        return_indices=False,
    )


@dynamo_tensorrt_converter(torch.ops.aten.min.dim, capability_validator=one_user_validator)  # type: ignore[misc]
def aten_ops_min_dim(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return impl.reduce.min(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        args[0],
        args[1],
        args_bounds_check(args, 2, replacement=False),
        return_indices=True,
    )
```
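As context for the snippet above, args_bounds_check guards optional positional arguments. A minimal sketch of such a helper, assuming behavior equivalent to the converter utilities' version:

```python
from typing import Any, Sequence

def args_bounds_check(args: Sequence[Any], i: int, replacement: Any = None) -> Any:
    # Return args[i] when the optional positional argument was supplied,
    # otherwise fall back to the given replacement default.
    return args[i] if len(args) > i else replacement
```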
Consider coalescing these by using something like:

```python
return impl.reduce.min(
    ctx,
    target,
    SourceIR.ATEN,
    name,
    args[0],
    args_bounds_check(args, 1, replacement=None),
    args_bounds_check(args, 2, replacement=False),
    return_indices=(target == torch.ops.aten.min.dim),
)
```

The same could apply for max.
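Applied to max, the coalesced converter might look like the following sketch; stacking both decorators on a single function is an assumption about how the registration decorator composes, not code from this PR:

```python
@dynamo_tensorrt_converter(torch.ops.aten.max.default)  # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten.max.dim, capability_validator=one_user_validator)  # type: ignore[misc]
def aten_ops_max(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return impl.reduce.max(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        args[0],
        args_bounds_check(args, 1, replacement=None),
        args_bounds_check(args, 2, replacement=False),
        return_indices=(target == torch.ops.aten.max.dim),
    )
```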
```python
if return_indices:
    return layer.get_output(0), None
```
@narendasan, @apbose - there are certain converter cases where indices are returned from an operator but never used or accessed in the graph (confirmed via validator). In these cases, we wouldn't want to spend extra computation time adding layers for an unused tensor. Which of these seems best?

- Return None for the unused tensor, as here: (data, None)
- Return only the data, since the unused tensor should never be accessed, as in: (data,)
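For concreteness, a sketch of the two options inside the reduce implementation; the helper name is hypothetical, chosen only to frame the choice:

```python
def finalize_reduce_outputs(layer, return_indices: bool):
    # Hypothetical helper illustrating the two options under discussion.
    if return_indices:
        # Option 1: pad with None so the return arity matches the two-output
        # ATen schema; getitem(result, 1) would then surface None.
        return layer.get_output(0), None
        # Option 2 (alternative): return only the data tensor, relying on the
        # validator's guarantee that the indices output is never accessed:
        # return (layer.get_output(0),)
    return layer.get_output(0)
```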
For reference, the ATen schema declares two outputs:

```
max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)
```
Force-pushed from 6f136de to 1717018.
gs-olive left a comment:
Looks good to me, pending CI pass.
Tests passing - verified locally.
Description
Support prod, max, min, and mean via reduce layer
Fixes #2205
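As a usage sketch of what this PR enables (a hypothetical module; assumes the standard dynamo compile path and a CUDA-capable setup):

```python
import torch
import torch_tensorrt

class ReduceOps(torch.nn.Module):
    def forward(self, x):
        # Each of these reductions should now lower to a TensorRT reduce layer.
        return x.prod(dim=1), x.max(dim=1)[0], x.min(dim=1)[0], x.mean(dim=1)

model = ReduceOps().eval().cuda()
inputs = [torch.randn(4, 8, device="cuda")]
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
print([tuple(t.shape) for t in trt_model(*inputs)])
```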