-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Zero-Dim] support input 0D Tensor for distribution transform api #47677
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
04e0d6e
to
043ceae
Compare
043ceae
to
098e654
Compare
raise ValueError( | ||
f"The numel of 'in_event_shape' should be 'out_event_shape', " | ||
f"but got {functools.reduce(operator.mul, in_event_shape)}!={functools.reduce(operator.mul, out_event_shape)}" | ||
f"but got {in_size}!={out_size}" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个修改的目的是?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
0D Tensor的shape为[],numel为1,用functools.reduce无法实现,会报空列表的错误,所以就自行实现了
if shape[-len(self._in_event_shape) :] != self._in_event_shape: | ||
if list(shape[-len(self._in_event_shape) :]) != list( | ||
self._in_event_shape | ||
): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
建议只读类型用tuple
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
if shape[-len(self._out_event_shape) :] != self._out_event_shape: | ||
if list(shape[-len(self._out_event_shape) :]) != list( | ||
self._out_event_shape | ||
): | ||
raise ValueError( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
# original patching object as well as re-constrfucted patches. | ||
delete_patches_if_need(f) | ||
|
||
f.__test__ = False |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
貌似有 typo
PR types
New features
PR changes
APIs
Describe
为Distribution Transform类API支持输入0D Tensor,如下: