
[Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array #5243

Merged: 6 commits into apache:master from kevinthesun:ImproveTFTensorArray on Apr 11, 2020

Conversation

kevinthesun
Contributor

Improve the TensorFlow frontend to handle static-shape tensor arrays. After this PR, most tensor array operators will have static input/output shapes.

@wweic @zhiics @yongwww @masahi

@masahi
Member

masahi commented Apr 6, 2020

Hi @kevinthesun, I started experimenting with how to integrate the static tensor array into the Torch frontend. My use case is to support Python tensor list append and stack. I ran into two problems:

  1. When I append a tensor to the tensor array (by concat), I can run shape inference on the input tensor to get the fixed shape the static tensor array expects. But after I've done some appends and try to stack the static tensor array, I don't have a way to tell what fixed shape the input tensor array to stack expects. See
    https://github.com/masahi/tvm/blob/support-more-rnn/python/tvm/relay/frontend/pytorch.py#L989-L990
    Since the shape is fixed, I think there should be an easy way to query the shape associated with a static array. I see you have such a function, check_tensor_array_shape, in this PR (it parses the op name). Is this the recommended way?

  2. The output type of stack is currently static_tensor_float32_?_2_4_t[] in my test. Is there a way to easily unwrap the static tensor type wrapper and get a Relay tensor? @wweic had such an unwrapper in [Prelude][Relay] Add get tensor data utilities #4325 for generic arrays. We should have something equivalent for static arrays.

@masahi
Member

masahi commented Apr 6, 2020

Update: With the new static tensor array, the following PyTorch LSTM model, originally from the fastrnn benchmark in the PyTorch repo (https://github.com/pytorch/pytorch/blob/master/benchmarks/fastrnns/custom_lstms.py#L187), converts correctly to Relay and produces results identical to PyTorch! This was not possible with the generic tensor array. @kevinthesun @wweic

import torch
import torch.jit as jit
from torch.nn import Parameter


class LSTMCell(jit.ScriptModule):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.weight_ih = Parameter(torch.randn(4 * hidden_size, input_size))
        self.weight_hh = Parameter(torch.randn(4 * hidden_size, hidden_size))
        self.bias_ih = Parameter(torch.randn(4 * hidden_size))
        self.bias_hh = Parameter(torch.randn(4 * hidden_size))

    @jit.script_method
    def forward(self, input, state):
        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
        hx, cx = state
        gates = (torch.mm(input, self.weight_ih.t()) + self.bias_ih +
                 torch.mm(hx, self.weight_hh.t()) + self.bias_hh)
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)

        cy = (forgetgate * cx) + (ingate * cellgate)
        hy = outgate * torch.tanh(cy)

        return hy, (hy, cy)


class LSTMLayer(jit.ScriptModule):
    def __init__(self, cell, *cell_args):
        super().__init__()
        self.cell = cell(*cell_args)

    @jit.script_method
    def forward(self, input, state):
        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
        outputs = []
        for i in range(input.size(0)):
            out, state = self.cell(input[i], state)
            outputs += [out]
        return torch.stack(outputs), state

Here is the converted Relay IR:

fn (%input: Tensor[(5, 2, 3), float32], %v25: Tensor[(16, 3), float32], %v28: Tensor[(16), float32], %v30: Tensor[(16, 4), float32], %v34: Tensor[(16), float32], %states: (Tensor[(2, 4), float32], Tensor[(2, 4), float32])) -> (static_tensor_float32_?_2_4_t[], (Tensor[(2, 4), float32], Tensor[(2, 4), float32])) {
  %0 = Nil /* ty=List[static_tensor_float32_2_4_t[]] */;
  %36 = (
    let %while_loop: fn (int32, List[static_tensor_float32_2_4_t[]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32])) -> (int32, List[static_tensor_float32_2_4_t[]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32])) = fn (%i.1: int32, %outputs.6: List[static_tensor_float32_2_4_t[]], %state.6: (Tensor[(2, 4), float32], Tensor[(2, 4), float32])) -> (int32, List[static_tensor_float32_2_4_t[]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32])) {
      %1 = less(%i.1, 5 /* ty=int32 */) /* ty=bool */;
      if (%1) {
        %2 = add(%i.1, 1 /* ty=int32 */) /* ty=int32 */;
        %3 = take(%input, %i.1, axis=0) /* ty=Tensor[(2, 3), float32] */;
        %4 = transpose(%v25, axes=[1, 0]) /* ty=Tensor[(3, 16), float32] */;
        %5 = transpose(%4, axes=[1, 0]) /* ty=Tensor[(16, 3), float32] */;
        %6 = nn.dense(%3, %5, units=None) /* ty=Tensor[(2, 16), float32] */;
        %7 = add(%6, %v28) /* ty=Tensor[(2, 16), float32] */;
        %8 = %state.6.0;
        %9 = transpose(%v30, axes=[1, 0]) /* ty=Tensor[(4, 16), float32] */;
        %10 = transpose(%9, axes=[1, 0]) /* ty=Tensor[(16, 4), float32] */;
        %11 = nn.dense(%8, %10, units=None) /* ty=Tensor[(2, 16), float32] */;
        %12 = add(%7, %11) /* ty=Tensor[(2, 16), float32] */;
        %13 = add(%12, %v34) /* ty=Tensor[(2, 16), float32] */;
        %14 = strided_slice(%13, begin=[0, 12], end=[2, 16], strides=[1, 1]) /* ty=Tensor[(2, 4), float32] */;
        %15 = sigmoid(%14) /* ty=Tensor[(2, 4), float32] */;
        %16 = strided_slice(%13, begin=[0, 4], end=[2, 8], strides=[1, 1]) /* ty=Tensor[(2, 4), float32] */;
        %17 = sigmoid(%16) /* ty=Tensor[(2, 4), float32] */;
        %18 = %state.6.1;
        %19 = multiply(%17, %18) /* ty=Tensor[(2, 4), float32] */;
        %20 = strided_slice(%13, begin=[0, 0], end=[2, 4], strides=[1, 1]) /* ty=Tensor[(2, 4), float32] */;
        %21 = sigmoid(%20) /* ty=Tensor[(2, 4), float32] */;
        %22 = strided_slice(%13, begin=[0, 8], end=[2, 12], strides=[1, 1]) /* ty=Tensor[(2, 4), float32] */;
        %23 = tanh(%22) /* ty=Tensor[(2, 4), float32] */;
        %24 = multiply(%21, %23) /* ty=Tensor[(2, 4), float32] */;
        %25 = add(%19, %24) /* ty=Tensor[(2, 4), float32] */;
        %26 = tanh(%25) /* ty=Tensor[(2, 4), float32] */;
        %27 = multiply(%15, %26) /* ty=Tensor[(2, 4), float32] */;
        %28 = (%27, %25);
        %29 = (%27, %28);
        %30 = %29.0;
        %31 = tensor_constructor_float32_2_4(%30) /* ty=static_tensor_float32_2_4_t[] */;
        %32 = Nil /* ty=List[static_tensor_float32_2_4_t[]] */;
        %33 = Cons(%31, %32) /* ty=List[static_tensor_float32_2_4_t[]] */;
        %34 = @concat(%outputs.6, %33) /* ty=List[static_tensor_float32_2_4_t[]] */;
        %35 = %29.1;
        %while_loop(%2, %34, %35) /* ty=(int32, List[static_tensor_float32_2_4_t[]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32])) */
      } else {
        (%i.1, %outputs.6, %state.6)
      }
    };
    %while_loop
  );
  %37 = %36(0 /* ty=int32 */, %0, %states) /* ty=(int32, List[static_tensor_float32_2_4_t[]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32])) */;
  %38 = %37.1;
  %39 = @tensor_array_stack_float32_2_4(%38) /* ty=static_tensor_float32_?_2_4_t[] */;
  %40 = %37.2;
  (%39, %40)
}

@kevinthesun
Contributor Author

2. The output type of stack is currently static_tensor_float32_?_2_4_t[] in my test. Is there a way to easily unwrap the static tensor type wrapper and get a Relay tensor? @wweic had such an unwrapper in #4325 for generic arrays. We should have something equivalent for static arrays.

@masahi You can use tensor_get_data to achieve this.

@kevinthesun
Contributor Author

  1. When I append a tensor to the tensor array (by concat), I can run shape inference on the input tensor to get the fixed shape the static tensor array expects. But after I've done some appends and try to stack the static tensor array, I don't have a way to tell what fixed shape the input tensor array to stack expects. See
    https://github.com/masahi/tvm/blob/support-more-rnn/python/tvm/relay/frontend/pytorch.py#L989-L990
    Since the shape is fixed, I think there should be an easy way to query the shape associated with a static array. I see you have such a function, check_tensor_array_shape, in this PR (it parses the op name). Is this the recommended way?

Yes, you can use check_tensor_array_shape. I'll rename it to get_tensor_array_shape.
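
A minimal sketch of how the renamed helper would be called from a frontend converter; the argument order follows the usage later in this thread, and the import path (the prelude module, where this PR moves the helper) is an assumption:

from tvm.relay.prelude import get_tensor_array_shape

def element_shape_of(tensor_array, prelude, dtype="float32"):
    # Recover the fixed element shape encoded in the static tensor array's
    # type, e.g. (2, 4) for static_tensor_float32_2_4_t.
    return get_tensor_array_shape(tensor_array, dtype, prelude)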

@masahi
Member

masahi commented Apr 6, 2020

@masahi You can use tensor_get_data to achieve this

Ah, thanks. I tried to use it on the output of stack, but since the first axis is 'Any', I don't know how to pass the shape param in get_var_static('tensor_get_data', "float32", shape). How do I do it?

A better question might be: why do we need to pass shape all over the place? Intuitively, I'd imagine stack and other ops that operate on an already existing tensor array should be able to figure out the shape themselves, no?

@kevinthesun
Contributor Author

@masahi The shape passed to get_var_static is for identification. For tensor_get_data, it is just used to pick the corresponding global var from the prelude module. You just need to pass the shape with which you created StaticTensorArrayOps. For example, if you create a tensor array with shape (1, 2, 3), you just need to pass (1, 2, 3) to get_var_static. However, when calling define_tensor_get_data, you want to pass (Any(), 1, 2, 3), since that is the actual output shape.
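
A short sketch of the distinction described above, assuming an existing Prelude instance; the import paths are assumptions and may differ between TVM versions:

from tvm.relay import Any
from tvm.relay.prelude import StaticTensorArrayOps

def register_static_get_data(prelude, elem_shape=(1, 2, 3), dtype="float32"):
    # All static ops are keyed by the element shape the tensor array was
    # created with.
    ops = StaticTensorArrayOps(prelude, dtype, elem_shape)
    ops.register()
    # tensor_get_data is registered manually with the actual output shape,
    # which for stack/gather results has a leading Any() dimension ...
    ops.define_tensor_get_data((Any(),) + elem_shape)
    # ... but it is still looked up with the original element shape.
    return prelude.get_var_static('tensor_get_data', dtype, elem_shape)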

@masahi
Member

masahi commented Apr 6, 2020

hmm I tried this:

def _tensor_array_stack(prelude):
    def _impl(inputs, input_types):
        # print(prelude.mod)
        # TODO: how to get the fixed shape of static_tensor_array inputs[0]?
        shape = (2, 4)
        stack = prelude.get_var_static('tensor_array_stack', "float32", shape)
        stacked = stack(inputs[0])

        stacked_shape = (Any(), 2, 4)
        static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", stacked_shape)
        static_tensor_array_ops.register()
        get_tensor = prelude.get_var_static('tensor_get_data', "float32", stacked_shape)
        return get_tensor(stacked)
    return _impl

But I'm still getting AttributeError: 'Prelude' object has no attribute 'tensor_get_data_float32_?_2_4'. Do I need a new prelude to register a new shape?

@kevinthesun
Contributor Author

https://github.com/apache/incubator-tvm/pull/5243/files#diff-eae8ecf976e0031823eeae454466f964R903 Take tensor_array_gather as an example: you create a new static tensor array ops object with your input tensor array shape and register all ops except tensor_get_data. After this, https://github.com/apache/incubator-tvm/pull/5243/files#diff-eae8ecf976e0031823eeae454466f964R924 you need to manually register tensor_get_data. It won't be registered automatically, since the input shape and the output shape might not match.

@masahi
Member

masahi commented Apr 6, 2020

Great, I got the following working. Also confirmed that get_tensor_array_shape works. Happy now :) Thank you very much! Unwrapping should enable supporting the "stacked" LSTM in https://github.com/pytorch/pytorch/blob/master/benchmarks/fastrnns/custom_lstms.py#L252, where the output from LSTMLayer is pipelined num_layers times to get a bigger network.

# Imports needed to run this snippet; paths follow the final version of this
# PR, which moves get_tensor_array_shape into the prelude module.
from tvm.relay import Any
from tvm.relay.prelude import StaticTensorArrayOps, get_tensor_array_shape


def _tensor_array_stack(prelude):
    def _impl(inputs, input_types):
        shape = get_tensor_array_shape(inputs[0], "float32", prelude)
        stack = prelude.get_var_static('tensor_array_stack', "float32", shape)
        stacked = stack(inputs[0])

        stacked_shape = (Any(),) + shape
        static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", shape)
        static_tensor_array_ops.define_tensor_get_data(stacked_shape)
        # passing stacked_shape below gives "'Prelude' object has no attribute" error
        get_tensor = prelude.get_var_static('tensor_get_data', "float32", shape)
        return get_tensor(stacked)
    return _impl

@masahi
Member

masahi commented Apr 10, 2020

@kevinthesun @wweic Is it reasonable to add an "axis" parameter to tensor array concat? I encountered a need to concat along the -1 axis.

@kevinthesun
Contributor Author

@kevinthesun @wweic Is it reasonable to add an "axis" parameter to tensor array concat? I encountered a need to concat along the -1 axis.

@masahi To support a different axis we need to change both define_tensor_concatenate and define_tensor_array_concat to accept an axis argument. The main thing we need to take care of is the output shape.
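
For illustration only, a sketch of the output-shape bookkeeping such an axis argument would involve; this is not the prelude change itself, and the Any import path is an assumption:

from tvm.relay import Any

def concat_output_shape(elem_shape, axis):
    # The concatenated dimension becomes dynamic, because the number of
    # tensors in the array is not known at compile time.
    axis = axis % len(elem_shape)
    out = list(elem_shape)
    out[axis] = Any()
    return tuple(out)

# E.g. concatenating (?, 2, 4) elements along axis -1 yields (?, 2, ?),
# matching the static_tensor_float32_?_2_? type seen later in this thread.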

@kevinthesun
Contributor Author

@masahi @wweic PTAL

@masahi
Member

masahi commented Apr 10, 2020

@kevinthesun I'm not entirely familiar with TF, let alone its tensor array support. If that is fine, I can review.

@kevinthesun
Contributor Author

@masahi Sure, please go ahead and review. I think a lot of the logic can be reused in the PyTorch frontend.

@masahi
Member

masahi commented Apr 10, 2020

@masahi To support a different axis we need to change both define_tensor_concatenate and define_tensor_array_concat to accept an axis argument. The main thing we need to take care of is the output shape.

OK, for now I went the easy route of just defining a concat_last op. It seems to work, but I'm getting the following typing error:

...
  %101 = @map(tensor_constructor_float32_?_2_4(Tensor[(?, 2, 4), float32]), %100);
  %102 = @tensor_array_concat_last_float32_?_2_?(%101) unable to unify: `static_tensor_float32_?_2_4_t` and `static_tensor_float32_?_2_?_t`; ;
  %103 = @tensor_get_data_float32_?_2_?(%102);
...

The first axis is already Any because of tensor array stack. Now I'm trying to concat (?, 2, 4) tensors along the -1 axis to get a (?, 2, ?) tensor. Is this possible? The typing error suggests not.

UPDATE: Solved by mapping tensor_constructor with the concatenated shape (?, 2, ?):

  %101 = @map(tensor_constructor_float32_?_2_?(Tensor[(?, 2, ?), float32]), %100) /* ty=List[static_tensor_float32_?_2_?_t[]] */;
  %102 = @tensor_array_concat_last_float32_?_2_?(%101) /* ty=static_tensor_float32_?_2_?_t[] */;
  %103 = @tensor_get_data_float32_?_2_?(%102) /* ty=Tensor[(?, 2, ?), float32] */;

masahi merged commit 4b27cd1 into apache:master on Apr 11, 2020
@masahi
Member

masahi commented Apr 11, 2020

Thanks @kevinthesun

trevor-m pushed a commit to trevor-m/tvm that referenced this pull request Apr 16, 2020
[Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array (apache#5243)

* Support TF Frontend Static TensorArray

* Fix pylint

* Fix lint

* Move get_tensor_array_shape into prelude

* Fix lint

* Fix common
zhiics pushed a commit to neo-ai/tvm that referenced this pull request Apr 17, 2020
[Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array (apache#5243)

* Support TF Frontend Static TensorArray

* Fix pylint

* Fix lint

* Move get_tensor_array_shape into prelude

* Fix lint

* Fix common
dpankratz pushed a commit to dpankratz/incubator-tvm that referenced this pull request Apr 24, 2020
[Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array (apache#5243)

* Support TF Frontend Static TensorArray

* Fix pylint

* Fix lint

* Move get_tensor_array_shape into prelude

* Fix lint

* Fix common
kevinthesun deleted the ImproveTFTensorArray branch on May 26, 2020, 17:31
monklof pushed a commit to monklof/incubator-tvm that referenced this pull request Jan 22, 2021
…m_data:master to master

* commit 'cd0d52daa6942bdafa9363ff6cfa3d25fcd5b8d6': (824 commits)
  [Intrinsic] Add log1p, ldexp, atan2, hypot, nextafter, copysign (apache#5312)
  [Rust][CI] Restore Rust CI (apache#5137)
  Remove PrimExpr from String (apache#5311)
  [Requantize] Cleanup and Optimize Lowering (apache#5286)
  [IR][TRANSFORM] Enable CopyOnWrite for passes. (apache#5309)
  [PYTORCH]Abs, Arange, Softplus ops (apache#5295)
  [LLVM] Fix generation of LLVM intrinsics (apache#5282)
  [BYOC] Add example of Composite + Annotate for DNNL fused op (apache#5272)
  [Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array (apache#5243)
  [RUNTIME] Introduce RValue reference(move) support to TypedPackedFunc (apache#5271)
  [RELAY][FRONTEND][CAFFE2] add Mul and ConvTranspose operator (apache#5302)
  [BYOC] Refine AnnotateTarget and MergeCompilerRegion Passes (apache#5277)
  [CI] Fix the hexagon string (apache#5304)
  [Arith] linear system and equation solver (apache#5171)
  [PYTORCH]Repeat, Reciprocal & Reshape Op support (apache#5280)
  [FRONTEND][TENSORFLOW] Fix gather_nd indices (apache#5279)
  Update device_annotation.cc (apache#5291)
  [REFACTOR][IR] Move to runtime::String (apache#5276)
  [NDArray] Set shape_ in NDArray::FromDLPack (apache#5301)
  [RUNTIME] Initial implementation of Hexagon runtime support (apache#5252)
  ...