[Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array #5243
Conversation
Hi @kevinthesun, I started experimenting with how to integrate the static tensor array in the Torch frontend. My use case is supporting Python tensor list append and stack. I ran into the two problems below:
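For reference, the "tensor list append and stack" pattern in question looks roughly like this in TorchScript (a minimal sketch, not taken from the PR):

```python
import torch

# Minimal sketch of the targeted pattern: append tensors to a Python list in
# a loop, then torch.stack the list. In Relay this maps onto a tensor array.
@torch.jit.script
def append_and_stack(x: torch.Tensor) -> torch.Tensor:
    outputs = []  # TorchScript treats an empty list literal as List[Tensor]
    for i in range(x.size(0)):
        outputs.append(x[i] * 2.0)
    return torch.stack(outputs)
```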
Update: With the new static tensor array, I took the following PyTorch LSTM model, originally from the fastrnn benchmark in the PyTorch repo (https://github.com/pytorch/pytorch/blob/master/benchmarks/fastrnns/custom_lstms.py#L187), converted it correctly to Relay, and got results identical to torch! This was not possible with the generic tensor array. @kevinthesun @wweic

```python
class LSTMCell(jit.ScriptModule):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.weight_ih = Parameter(torch.randn(4 * hidden_size, input_size))
        self.weight_hh = Parameter(torch.randn(4 * hidden_size, hidden_size))
        self.bias_ih = Parameter(torch.randn(4 * hidden_size))
        self.bias_hh = Parameter(torch.randn(4 * hidden_size))

    @jit.script_method
    def forward(self, input, state):
        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
        hx, cx = state
        gates = (torch.mm(input, self.weight_ih.t()) + self.bias_ih +
                 torch.mm(hx, self.weight_hh.t()) + self.bias_hh)
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)

        cy = (forgetgate * cx) + (ingate * cellgate)
        hy = outgate * torch.tanh(cy)

        return hy, (hy, cy)


class LSTMLayer(jit.ScriptModule):
    def __init__(self, cell, *cell_args):
        super().__init__()
        self.cell = cell(*cell_args)

    @jit.script_method
    def forward(self, input, state):
        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
        outputs = []
        for i in range(input.size(0)):
            out, state = self.cell(input[i], state)
            outputs += [out]
        return torch.stack(outputs), state
```

Here is the converted Relay IR:
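For reference, the conversion and the comparison against torch can be driven roughly as below. This is a sketch, not code from the PR; the input-shape ("input info") format expected by relay.frontend.from_pytorch, especially for the tuple-typed `states` argument, has changed across TVM versions.

```python
import torch
from tvm import relay

input_size = hidden_size = 4
seq_len, batch = 5, 2

# LSTMLayer subclasses jit.ScriptModule, so the instance is already a
# TorchScript module and can be handed to the frontend directly.
model = LSTMLayer(LSTMCell, input_size, hidden_size).eval()

inp = torch.randn(seq_len, batch, input_size)
states = (torch.randn(batch, hidden_size), torch.randn(batch, hidden_size))

with torch.no_grad():
    pt_out, _ = model(inp, states)  # reference result from PyTorch

# Illustrative call only: older TVM takes [(name, shape), ...]; specifying the
# tuple-typed `states` input may require a nested shape/dtype description.
mod, params = relay.frontend.from_pytorch(
    model,
    [("input", (seq_len, batch, input_size)),
     ("states", ((batch, hidden_size), (batch, hidden_size)))],
)
print(mod["main"])  # the converted Relay IR
```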
@masahi You can use `tensor_get_data` to achieve this.
Yes, you can use `tensor_get_data`.
Ah, thanks. I tried to use it on the output of stack, but since the first axis is 'Any', I don't know what shape to pass. A better question might be: why do we need to pass a shape at all?
@masahi The shape passed to `get_var_static` should be the static shape the tensor array ops were registered with.
Hmm, I tried this:

```python
def _tensor_array_stack(prelude):
    def _impl(inputs, input_types):
        # print(prelude.mod)
        # TODO: how to get the fixed shape of static_tensor_array inputs[0]?
        shape = (2, 4)
        stack = prelude.get_var_static('tensor_array_stack', "float32", shape)
        stacked = stack(inputs[0])
        stacked_shape = (Any(), 2, 4)
        static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", stacked_shape)
        static_tensor_array_ops.register()
        get_tensor = prelude.get_var_static('tensor_get_data', "float32", stacked_shape)
        return get_tensor(stacked)
    return _impl
```

But I'm still getting an error.
Take tensor_array_gather as an example (https://github.com/apache/incubator-tvm/pull/5243/files#diff-eae8ecf976e0031823eeae454466f964R903): you create a new static tensor array ops object with your input tensor array shape and register all ops except tensor_get_data. After this, you need to manually register tensor_get_data (https://github.com/apache/incubator-tvm/pull/5243/files#diff-eae8ecf976e0031823eeae454466f964R924). It won't be registered automatically, since the input shape and output shape might not match.
Great, I got the following working. Also confirmed …

```python
def _tensor_array_stack(prelude):
    def _impl(inputs, input_types):
        shape = get_tensor_array_shape(inputs[0], "float32", prelude)
        stack = prelude.get_var_static('tensor_array_stack', "float32", shape)
        stacked = stack(inputs[0])

        stacked_shape = (Any(),) + shape
        static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", shape)
        static_tensor_array_ops.define_tensor_get_data(stacked_shape)
        # passing stacked_shape below gives "'Prelude' object has no attribute" error
        get_tensor = prelude.get_var_static('tensor_get_data', "float32", shape)
        return get_tensor(stacked)
    return _impl
```
@kevinthesun @wweic Is it reasonable to add an "axis" parameter to tensor array concat? I encountered a need to concat along the -1 axis.
@masahi To support a different axis, we would need to change both …
@kevinthesun I'm not entirely familiar with TF, let alone its tensor array support. If that is fine, I can review.
@masahi Sure, please go ahead and review. I think a lot of the logic can be reused in PyTorch.
OK, for now I went the easy route of just defining a concat_last op. It seems to work, but I'm getting the following typing error: …

The first axis is already Any from tensor array stack. Now I'm trying to concat (?, 2, 4) tensors along the -1 axis to get a (?, 2, ?) tensor. Is this possible? The typing error suggests not.

UPDATE: Solved by mapping tensor_constructor with the concat-ed shape (?, 2, ?):
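As a hypothetical sketch of what mapping tensor_constructor to the concat-ed shape could look like (the actual snippet is not shown here; `prelude`, `Any`, `StaticTensorArrayOps`, and `get_var_static` follow the earlier examples, and the exact registration calls may differ):

```python
# Hypothetical sketch, not the actual fix from this thread: register static
# tensor array ops for the concat-ed element shape (?, 2, ?) so that the
# tensor constructor used for the concat result type-checks.
concat_shape = (Any(), 2, Any())
ops = StaticTensorArrayOps(prelude, "float32", concat_shape)
ops.register()

# Constructor for tensors of shape (?, 2, ?); values produced by concat_last
# would be wrapped with this constructor before further tensor array ops.
tensor_ctor = prelude.get_var_static('tensor_constructor', "float32", concat_shape)
```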
Thanks @kevinthesun |
[Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array (apache#5243)
* Support TF Frontend Static TensorArray
* Fix pylint
* Fix lint
* Move get_tensor_array_shape into prelude
* Fix lint
* Fix common
Improve the TensorFlow frontend to deal with static shape tensor arrays. After this PR, most tensor array operators will have static input/output shapes.
@wweic @zhiics @yongwww @masahi
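For context, here is a minimal sketch (not taken from this PR) of the kind of TF graph this targets. With element_shape known statically, the converted Relay program can use a static tensor array with element shape (2, 4) instead of a fully dynamic one.

```python
import tensorflow as tf  # TF 1.x graph API (tf.compat.v1 under TF 2.x)
from tvm import relay

with tf.Graph().as_default() as graph:
    x = tf.placeholder(tf.float32, shape=(3, 2, 4), name="x")
    # element_shape is statically known, so the frontend can pick the
    # static tensor array ops instead of the generic (dynamic) ones.
    ta = tf.TensorArray(dtype=tf.float32, size=3, element_shape=(2, 4))

    def body(i, ta):
        return i + 1, ta.write(i, x[i])

    _, ta = tf.while_loop(lambda i, ta: i < 3, body, (0, ta))
    out = tf.identity(ta.stack(), name="out")

# Convert the graph_def with the TF frontend; the shape dict pins the
# placeholder shape so the tensor array element shape can be inferred.
mod, params = relay.frontend.from_tensorflow(
    graph.as_graph_def(), shape={"x": (3, 2, 4)}, outputs=["out"])
print(mod["main"])
```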