[Relay][ADT] Static Tensor Array #5103
Conversation
There is a lot of AST written by hand. Can you try using the Relay parser?
Is the constraint only fixed "rank"? I have a use case for a fixed-"shape" tensor array (a tighter constraint).
@MarisaKirisame This PR follows the pattern of the current Tensor Array. What do you mean by using the Relay parser?
@masahi Fixed shape is a subset of fixed rank and is supported in this PR. For the stack operator, even if all tensors in the array have the same static shape, the first axis of the output shape will still be Any().
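To make the Any() point concrete, here is a small plain-NumPy analogy (an editorial sketch, not Relay code and not part of the PR): even when every element has the same static shape, the array length, and hence the first output axis of stack, is only known at run time.

import numpy as np

def build_and_stack(n):
    # Every element has the fixed shape (3, 4), but n (the tensor array
    # length) is a runtime value, so the stacked result can only be typed
    # statically as (Any, 3, 4) even though each element shape is known.
    tensor_array = [np.zeros((3, 4), dtype='float32') for _ in range(n)]
    return np.stack(tensor_array)

print(build_and_stack(5).shape)  # (5, 3, 4)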
@kevinthesun Great! I've been waiting for this (thanks @wweic). I think the first output axis of stack being Any makes sense, and it is not a problem for me, assuming it is the length of the array.
I think Marisa is talking about writing the new ADT and functions in the "Relay language" itself and using the parser to make them available to Python, similar to the way the List ADT and its helper functions are implemented in the Prelude.
I see. My understanding is that the List ADT primitives are small and mature enough to be implemented in the Prelude. We can also move tensor-array-related primitives into the Prelude when they become more mature. In addition, we might want to move the generic tensor array as well, not just the static tensor array. Does this make sense?
Yes. Since the generic tensor array is already implemented in this way, I also think it is better to let this in first and port all the tensor array code to the Relay language later. There could be common code between the generic and static versions that should be refactored in the process.
@MarisaKirisame The reason we cannot use the Relay parser is that we want to dynamically generate the operators for a specific shape while converting the TF model to Relay IR. The current Relay text parser cannot easily do that.
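For illustration, a minimal sketch of the shape-specialized registration being described, modeled on the StaticTensorArrayOps class this PR adds (the constructor arguments and method names below are my reading of the PR, so treat them as assumptions rather than a stable API):

import tvm
from tvm.relay.prelude import Prelude, StaticTensorArrayOps

mod = tvm.IRModule()
prelude = Prelude(mod)

# The element dtype and shape only become known while converting the TF
# graph, so the specialized ADT and helper functions are generated
# programmatically here instead of being written once in Relay text.
ops = StaticTensorArrayOps(prelude, 'float32', (3, 4))
ops.register()  # defines the shape-specialized ADT and ops in mod

# Look up one of the generated globals, e.g. the tensor array constructor.
tensor_array_ctor = ops.get_var('tensor_array')

This is why a one-time Relay-text definition in the Prelude does not fit: the shape is a parameter of the definition itself, not of a call site.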
python/tvm/relay/prelude.py
Outdated
shape_str = str(self.shape).replace('[', '').replace(']', '')\
    .replace('(', '').replace(')', '').replace(', ', '_')\
    .replace(',', '')
Maybe we should improve this a bit. Can we use '_'.join(self.shape)?
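For concreteness, a standalone sketch of both name-mangling approaches (editorial, not from the PR; note that a literal '_'.join(self.shape) would raise a TypeError on integer dimensions, so a str conversion is needed):

shape = (3, 4, 5)  # stands in for self.shape

# Current approach: stringify the tuple and strip the punctuation.
shape_str = str(shape).replace('[', '').replace(']', '')\
    .replace('(', '').replace(')', '').replace(', ', '_')\
    .replace(',', '')

# Suggested approach: join the dimensions, converting each to str first.
shape_str_join = '_'.join(str(dim) for dim in shape)

assert shape_str == shape_str_join == '3_4_5'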
python/tvm/relay/prelude.py
Outdated
"""Defines the dynamic tensor ADT, which is the container for tensors | ||
with variable shapes.""" |
"""Defines the dynamic tensor ADT, which is the container for tensors | |
with variable shapes.""" | |
"""Defines the static tensor ADT, which is the container for tensors | |
with fixed shape.""" |
python/tvm/relay/prelude.py
Outdated
lower = Var('lower', scalar_type('int32'))
upper = Var('upper', scalar_type('int32'))
tvar = Var('t')
case = Clause(PatternConstructor(self.get_var('tensor_constructor'), [PatternVar(tvar)]),
Suggested change:
-    case = Clause(PatternConstructor(self.get_var('tensor_constructor'), [PatternVar(tvar)]),
+    case = Clause(PatternConstructor(tensor_constructor, [PatternVar(tvar)]),
python/tvm/relay/prelude.py
Outdated
        Function([x], Match(x, [case], False), tensor_type_var(), [])

    def define_tensor_array_read(self):
        """Defines a function to get the head of a list. Assume the list has at least one
"""Defines a function to get the head of a list. Assume the list has at least one | |
"""Defines a function to get the nth element of a list. Assume the list has at least one |
python/tvm/relay/prelude.py
Outdated
list[tensor_t] -> Tensor[(Any), int32] -> tensor_t -> list[tensor_t]

Set static indices shape by specifying indices_shape.
Set for_update to get static indices shape operator.
Suggested change:
-    Set for_update to get static indices shape operator.
+    Set force_update to get static indices shape operator.
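As a hedged usage sketch, reusing the ops object from the earlier sketch (the keyword names follow this PR's define_tensor_array_scatter as I read it, and may differ in later versions):

# indices_shape pins the indices tensor to a static shape; force_update
# re-registers the specialized operator even if one was already defined.
ops.define_tensor_array_scatter(indices_shape=[5], force_update=True)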
@wweic Comments addressed. PTAL.
LGTM. Some minor comments.
python/tvm/relay/prelude.py
Outdated
@@ -217,7 +218,7 @@ def define_tensor_array_unstack(self):
         tensor_array_unstack_tensor(t) : tensor_t -> list[tensor_t]
         """
         ndim = len(self.shape)
-        # Skip scalar case
+        # We don't register unstask for scalar tensor array
Suggested change:
-    # We don't register unstask for scalar tensor array
+    # We don't register unstack for scalar tensor array
@wweic Fixed.
Thanks @wweic @masahi @MarisaKirisame
* Add other static tensor array ops
* Add tensor array get data
* Minor refactor
* Fix pylint
* Update docstring
* Make get data more generic
* Improve test
* Improve split test
* Improve get data
* Minor fix
* Further improvement for static shape
* Improve shape parsing
* Unify get_static_name
Add a static tensor array for fixed-rank tensor arrays. With this change, we can more easily do type inference and optimization for most tensor array use cases.
Thanks to @wweic for working on the base infrastructure of the StaticTensorArrayOps class.
@wweic @zhiics @yongwww @icemelon9