Skip to content

Commit

Permalink
[Relay][Frontend][ONNX] New Operators and Opsets to Support BERT (#4197)
Browse files Browse the repository at this point in the history
* Added slice v10

* Added constantofshape operation and small refactor.

* Finished one_hot implementation.

* Reshape working across all bert layers.

* Fixed constantofshape and removed code duplication.

* onnx model fully ingested.

* Working on improving onnx tests.

* Changed onnx testing to use onnxruntime instead of caffe2, also formatted.

* Add arbitrary output nodes to onnx frontend.

* Added v6 tiling for bert squad 8 support.

* Small syntax fixes

* Reduced code duplication in split opset versions.

* Added batch matmul test

* Added unstack split testing.

* Added onehot test, needs a little cleanup probably.

* Replaced deprecated constant fill with constantofshape and updated tests accordingly.

* Added tests for new opset version of slice and tile.

* lint clean up

* Lint fixes

* Changed onnx dependency

* Went back to caffe2 runtime for CI integration.

* Rebase and small typo/syntax changes.

* Added hard casting of onehot attributes to int.
  • Loading branch information
jwfromm authored and jroesch committed Oct 30, 2019
1 parent 71f39be commit 156aa59
Show file tree
Hide file tree
Showing 4 changed files with 744 additions and 381 deletions.
46 changes: 46 additions & 0 deletions python/tvm/relay/frontend/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@
import logging

import tvm
import numpy as np
from topi.util import get_const_tuple
from .. import expr as _expr
from .. import module as _module
from .. import transform as _transform
from .. import op as _op
from .. import analysis


class RequiredAttr(object):
Expand Down Expand Up @@ -474,6 +476,50 @@ def infer_channels(inputs, transpose=False):
return channels


def infer_value(input_val, params):
    """A hack for getting the value of an expression by evaluating a
    portion of the relay graph. This is often needed for operators
    whose output shape depends on the value of an input tensor.
    """
    from tvm.contrib import graph_runtime
    free_variables = analysis.free_vars(input_val)
    # Every free variable must have a concrete binding before we can evaluate.
    assert all(
        var.name_hint in params.keys() for var in free_variables
    ), "All inputs to infer must be available in params."
    func = _expr.Function(free_variables, input_val)
    # Compile at opt level 0 — we only need a value, not fast code.
    with tvm.relay.build_config(opt_level=0):
        graph, lib, built_params = tvm.relay.build(func, target="llvm", params=params)
    runtime = graph_runtime.create(graph, lib, tvm.cpu(0))
    runtime.set_input(**built_params)
    runtime.run()
    return runtime.get_output(0)


def infer_value_simulated(input_val, params):
    """Extension to infer_value that can be used when some input
    values are missing. This function creates dummy inputs with the same
    shape and random values then calls infer_value. This is helpful when
    implementing certain onnx operators where we need to evaluate the graph
    to determine a static shape.
    """
    fake_params = []
    # Add a fake copy of all missing params.
    for free_param in analysis.free_vars(input_val):
        if free_param.name_hint not in params:
            fp_dtype = free_param.type_annotation.dtype
            fp_shape = [s.value for s in free_param.type_annotation.shape]
            fake_params.append(free_param)
            params[free_param.name_hint] = tvm.nd.array(
                np.random.rand(*fp_shape).astype(fp_dtype)
            )
    try:
        # Now infer the value.
        output_value = infer_value(input_val, params)
    finally:
        # Clean fake params out of the caller's param dictionary even when
        # inference raises, so the temporary random tensors never leak out.
        for fake_p in fake_params:
            params.pop(fake_p.name_hint, None)
    return output_value


def new_var(name_hint,
type_annotation=None,
shape=None,
Expand Down
Loading

0 comments on commit 156aa59

Please sign in to comment.