
[Relay][Frontend][ONNX] New Operators and Opsets to Support BERT #4197

Merged: 23 commits, Oct 30, 2019
Commits
0f9e667
Added slice v10
Oct 16, 2019
d7d0de3
Added constantofshape operation and small refactor.
jwfromm Oct 16, 2019
019190c
Finished one_hot implementation.
jwfromm Oct 17, 2019
3826216
Reshape working across all bert layers.
jwfromm Oct 18, 2019
9442f6a
Fixed constantofshape and removed code duplication.
jwfromm Oct 18, 2019
c558ef7
onnx model fully ingested.
jwfromm Oct 18, 2019
01d2145
Working on improving onnx tests.
jwfromm Oct 22, 2019
05d8905
Changed onnx testing to use onnxruntime instead of caffe2, also forma…
jwfromm Oct 22, 2019
bffad89
Add arbitrary output nodes to onnx frontend.
jwfromm Oct 22, 2019
94654af
Added v6 tiling for bert squad 8 support.
jwfromm Oct 22, 2019
e9a2591
Small syntax fixes
jwfromm Oct 23, 2019
cb25dad
Reduced code duplication in split opset versions.
jwfromm Oct 23, 2019
123cae8
Added batch matmul test
jwfromm Oct 24, 2019
7f07ffa
Added unstack split testing.
jwfromm Oct 24, 2019
7703198
Added onehot test, needs a little cleanup probably.
jwfromm Oct 24, 2019
4988d33
Replaced deprecated constant fill with constantofshape and updated te…
jwfromm Oct 24, 2019
b7e2644
Added tests for new opset version of slice and tile.
jwfromm Oct 24, 2019
14737d2
lint clean up
jwfromm Oct 24, 2019
eea6fc4
Lint fixes
jwfromm Oct 24, 2019
0c108a7
Changed onnx dependency
jwfromm Oct 24, 2019
89876cb
Went back to caffe2 runtime for CI integration.
jwfromm Oct 25, 2019
b18de8b
Rebase and small typo/syntax changes.
jwfromm Oct 28, 2019
bbf203c
Added hard casting of onehot attributes to int.
jwfromm Oct 28, 2019
46 changes: 46 additions & 0 deletions python/tvm/relay/frontend/common.py
@@ -19,11 +19,13 @@
import logging

import tvm
import numpy as np
from topi.util import get_const_tuple
from .. import expr as _expr
from .. import module as _module
from .. import transform as _transform
from .. import op as _op
from .. import analysis


class RequiredAttr(object):
@@ -474,6 +476,50 @@ def infer_channels(inputs, transpose=False):
return channels


def infer_value(input_val, params):
    """A hack for getting the value of an expression by evaluating a
    portion of the relay graph. This is often needed for operators
    whose output shape depends on the value of a tensor.
    """
from tvm.contrib import graph_runtime
# Check that all free variables have associated parameters.
assert all(var.name_hint in params.keys() for var in analysis.free_vars(
input_val)), "All inputs to infer must be available in params."
func = _expr.Function(analysis.free_vars(input_val), input_val)
with tvm.relay.build_config(opt_level=0):
graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
ctx = tvm.cpu(0)
m = graph_runtime.create(graph, lib, ctx)
m.set_input(**params)
m.run()
return m.get_output(0)
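To see why this helper is needed: ONNX operators such as Reshape take their target shape as a tensor input, so the importer must evaluate that sub-graph to a concrete value before it can emit a static Relay op. The idea can be sketched with a TVM-free NumPy stand-in (the names here are hypothetical and are not the PR's code):

```python
import numpy as np

# Hypothetical stand-in for a sub-graph that computes a shape tensor.
# In the real frontend this would be a Relay expression handed to
# infer_value; here we simply evaluate it eagerly.
def shape_subgraph(data):
    # The graph produces the target shape as a tensor, e.g. for Reshape.
    return np.array([data.shape[0], -1], dtype="int64")

data = np.zeros((4, 3, 2), dtype="float32")
target_shape = shape_subgraph(data)           # "evaluated", like infer_value
reshaped = data.reshape(tuple(target_shape))  # now usable as a static shape
```

Because the evaluated tensor is a plain constant, the importer can bake it into the Relay graph instead of carrying a dynamic shape.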


def infer_value_simulated(input_val, params):
    """Extension to infer_value that can be used when some input
    values are missing. This function creates dummy inputs with the same
    shapes and random values, then calls infer_value. This is helpful when
    implementing certain ONNX operators where we need to evaluate the graph
    to determine a static shape.
    """
fake_params = []
# Add a fake copy of all missing params.
for free_param in analysis.free_vars(input_val):
if free_param.name_hint not in params:
fp_dtype = free_param.type_annotation.dtype
fp_shape = [s.value for s in free_param.type_annotation.shape]
fake_params.append(free_param)
params[free_param.name_hint] = tvm.nd.array(
np.random.rand(*fp_shape).astype(fp_dtype)
)
# Now infer the value.
output_value = infer_value(input_val, params)
# Clean fake params out of param dictionary.
for fake_p in fake_params:
params.pop(fake_p.name_hint, None)
return output_value
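The fill-with-fakes-then-clean-up pattern above can be illustrated in isolation; `eval_with_fake_params` and its arguments below are hypothetical stand-ins for the Relay machinery, not part of the PR:

```python
import numpy as np

def eval_with_fake_params(fn, free_vars, params):
    """Temporarily fill missing inputs with random data of the declared
    shape/dtype, evaluate fn, then restore params to its original keys."""
    fake = []
    for name, (shape, dtype) in free_vars.items():
        if name not in params:
            fake.append(name)
            params[name] = np.random.rand(*shape).astype(dtype)
    out = fn(params)
    # Clean the fake entries back out, mirroring infer_value_simulated.
    for name in fake:
        params.pop(name, None)
    return out
```

Random values are sufficient here because the caller only consumes the result's shape, not its contents, and restoring `params` keeps the dummy tensors from leaking into the final module.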


def new_var(name_hint,
type_annotation=None,
shape=None,