Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TF frontend][bugfix]Avoid making a new node when already has span info #7789

Merged
merged 9 commits into from
Apr 9, 2021
Merged

[TF frontend][bugfix]Avoid making a new node when already has span info #7789

merged 9 commits into from
Apr 9, 2021

Conversation

xqdan
Copy link
Contributor

@xqdan xqdan commented Apr 2, 2021

Thanks for contributing to TVM! Please refer to guideline https://tvm.apache.org/docs/contribute/ for useful information and tips. After the pull request is submitted, please request code reviews from Reviewers by @ them in the pull request thread.

@lixiaoquan @zhiics @srkreddy1238

This bug will lead redundent operations in graph ir when parsering bert:

%4 = add(%0, %3) /* bert/embeddings/add */;
  %5 = strided_slice(meta[relay.Constant][1], begin=[0, 0], end=[128, -1], strides=[1, 1], slice_mode="size") /* bert/embeddings/Slice */;
  %6 = reshape(%5, newshape=[1, 128, 768]) /* bert/embeddings/Reshape_4 */;
  %7 = add(%4, %6) /* bert/embeddings/add_1 */;
  %8 = mean(%7, axis=[2], keepdims=True) /* bert/embeddings/LayerNorm/moments/StopGradient */;
  %9 = subtract(%7, %8);
  %10 = multiply(%9, %9) /* bert/embeddings/LayerNorm/moments/SquaredDifference */;
  %11 = mean(%10, axis=[2], keepdims=True) /* bert/embeddings/LayerNorm/moments/variance */;
  %12 = add(%11, 1e-12f) /* bert/embeddings/LayerNorm/batchnorm/add */;
  %13 = power(%12, -0.5f) /* bert/embeddings/LayerNorm/batchnorm/Rsqrt */;
  %14 = multiply(%13, meta[relay.Constant][2]) /* bert/embeddings/LayerNorm/batchnorm/mul */;
  %15 = multiply(%7, %14) /* bert/embeddings/LayerNorm/batchnorm/mul_1 */;
  %16 = mean(%7, axis=[2], keepdims=True) /* bert/embeddings/LayerNorm/moments/mean */; --> %16 is the same as %8
  %17 = multiply(%16, %14) /* bert/embeddings/LayerNorm/batchnorm/mul_2 */;
  %18 = subtract(meta[relay.Constant][3], %17) /* bert/embeddings/LayerNorm/batchnorm/sub */;

The last mean is redundent op created by _set_span.

Since this is pb parsering flow, does anyone know how to make a unitest?

@xqdan xqdan changed the title Avoid making a new node when already has span info [TF front end]Avoid making a new node when already has span info Apr 2, 2021
@xqdan xqdan changed the title [TF front end]Avoid making a new node when already has span info [TF front end][bugfix]Avoid making a new node when already has span info Apr 2, 2021
@xqdan xqdan changed the title [TF front end][bugfix]Avoid making a new node when already has span info [TF frontend][bugfix]Avoid making a new node when already has span info Apr 2, 2021
@tqchen
Copy link
Member

tqchen commented Apr 3, 2021

cc @zhiics can you please help manage this PR Thank you

@@ -3851,11 +3851,11 @@ def _convert_operator(
@staticmethod
def _set_span(sym, node_name):
span = tvm.relay.Span(tvm.relay.SourceName(node_name), 0, 0, 0, 0)
if isinstance(sym, _expr.Call):
if isinstance(sym, _expr.Call) and sym.span is None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you add a test case for these two lines?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd love to, but I don't know how to create unitest for pb parsering flow, any suggestion? thanks

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this require pb parsing or directly creating a tf program would be sufficient?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this require pb parsing or directly creating a tf program would be sufficient?

You are right,where can I find an exmple? I know the basic idea is creatig a tf graph, using from_tensorflow to handle it, then check the output graph ir.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can refer to the existing unit/integration tests for tf frontend. Most of them just create a tf graph.

@lixiaoquan
Copy link
Contributor

lixiaoquan commented Apr 6, 2021

It is ok to me.

I guess this is caused by different nodes in PB are being parsed to the same Relay node, so they have different names.

I think EliminateCommonSubexpr() should be able to eliminate the redundancy, but it requires opt_level 3, so it is disabled by default.

int opt_level{2};

xiaoqiang.dan added 2 commits April 6, 2021 17:48
@xqdan
Copy link
Contributor Author

xqdan commented Apr 6, 2021

It is ok to me.

I guess this is caused by different nodes in PB are being parsed to the same Relay node, so they have different names.

I think EliminateCommonSubexpr() should be able to eliminate the redundancy, but it requires opt_level 3, so it is disabled by default.

int opt_level{2};

Yes,EliminateCommonSubexpr() can remove redundant operations, however better to fix at the very beginning.

@xqdan xqdan closed this Apr 6, 2021
@xqdan xqdan reopened this Apr 6, 2021
xiaoqiang.dan added 3 commits April 6, 2021 17:58
@xqdan
Copy link
Contributor Author

xqdan commented Apr 7, 2021

@zhiics @kevinthesun Added a unites, please review. BTW it's easy to add nn.moments as unitest because we met it, however I don't know which tf graph/layer can tigger tuple branch, appreciate if you have suggestions.

@zhiics
Copy link
Member

zhiics commented Apr 8, 2021

@xqdan thanks for adding the test. But I think we probably don't need to create a new test file though. Instead, we can put it under test_forward as the operator tests are sitting there.

xiaoqiang.dan added 2 commits April 8, 2021 14:54
@zhiics zhiics merged commit 461d06e into apache:main Apr 9, 2021
@zhiics
Copy link
Member

zhiics commented Apr 9, 2021

Thanks @xqdan @kevinthesun

tmoreau89 pushed a commit to tmoreau89/tvm that referenced this pull request Apr 11, 2021
…fo (apache#7789)

* Avoid making a new node when already has span info

* add test

* add test

* add test

* fix

* fix

* move test to test_forward.py

* fix

* fix

Co-authored-by: xiaoqiang.dan <xiaoqiang.dan@streamcoputing.com>
trevor-m pushed a commit to trevor-m/tvm that referenced this pull request May 6, 2021
…fo (apache#7789)

* Avoid making a new node when already has span info

* add test

* add test

* add test

* fix

* fix

* move test to test_forward.py

* fix

* fix

Co-authored-by: xiaoqiang.dan <xiaoqiang.dan@streamcoputing.com>
trevor-m pushed a commit to trevor-m/tvm that referenced this pull request May 6, 2021
…fo (apache#7789)

* Avoid making a new node when already has span info

* add test

* add test

* add test

* fix

* fix

* move test to test_forward.py

* fix

* fix

Co-authored-by: xiaoqiang.dan <xiaoqiang.dan@streamcoputing.com>
trevor-m pushed a commit to trevor-m/tvm that referenced this pull request May 6, 2021
…fo (apache#7789)

* Avoid making a new node when already has span info

* add test

* add test

* add test

* fix

* fix

* move test to test_forward.py

* fix

* fix

Co-authored-by: xiaoqiang.dan <xiaoqiang.dan@streamcoputing.com>
trevor-m pushed a commit to neo-ai/tvm that referenced this pull request May 11, 2021
…fo (apache#7789)

* Avoid making a new node when already has span info

* add test

* add test

* add test

* fix

* fix

* move test to test_forward.py

* fix

* fix

Co-authored-by: xiaoqiang.dan <xiaoqiang.dan@streamcoputing.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants