-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Frontend][Relay] Fix MXNet frontend to support NLP backbones in GluonNLP #6699
Conversation
Update type_relations.cc Update transform.cc Update transform.cc Update transform.cc Update transform.cc Update transform.cc Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py update Update mxnet.py debug Update generic.py Update topi_integration.py fix bug update Update test_forward.py Update test_forward.py fix test case Update mxnet.py update Update mxnet.py Update mxnet.py Update test_forward.py Update mxnet.py Update mxnet.py Update test_forward.py Update mxnet.py Update mxnet.py Update mxnet.py debug Update mxnet.py Update mxnet.py Update test_forward.py Update mxnet.py
As we add more tests can we measure what kind of time increase this will induce in CI? integration tests are becoming increasingly slow and expensive to run. cc @areusch and @tkonolige |
The integration tests take a very long time because there are two many combinations. For example: https://github.com/apache/incubator-tvm/blob/461e75bd5ffaf45a0f270998514d444463d11261/tests/python/frontend/mxnet/test_forward.py#L2119-L2125 We may try to simplify the tests by not using a full cartesian product |
debug Update common.py update Update mxnet.py update Update test_forward.py Update test_forward.py
I've verified the TVM integration with 5 NLP backbones in GluonNLP: BERT, ALBERT, ELECTRA, RoBERTA, and BART import mxnet as mx
import numpy as np
import gluonnlp
from gluonnlp.models import get_backbone
import numpy.testing as npt
import tvm
from tvm import relay
import tvm.contrib.graph_runtime as runtime
mx.npx.set_np()
instance_info = {
'g4': {'target': "cuda -model=t4", 'use_gpu': True},
'c4': {'target': 'llvm -mcpu=core-avx2 -libs=cblas', 'use_gpu': False},
'c5': {'target': 'llvm -mcpu=skylake-avx512 -libs=cblas', 'use_gpu': False},
'p3': {'target': 'cuda -model=v100', 'use_gpu': True}
}
def test_backbone(model_name, batch_size=2, seq_length=128, instance='g4',
required_pass=None, opt_level=3):
if required_pass is None:
required_pass = ["FastMath"]
model_cls, cfg, tokenizer, backbone_param_path, _ = get_backbone(model_name)
model = model_cls.from_cfg(cfg)
model.load_parameters(backbone_param_path)
model.hybridize()
token_ids = mx.np.random.randint(0, cfg.MODEL.vocab_size, (batch_size, seq_length), dtype=np.int32)
token_types = mx.np.random.randint(0, 2, (batch_size, seq_length), dtype=np.int32)
valid_length = mx.np.random.randint(seq_length // 2, seq_length, (batch_size,), dtype=np.int32)
if 'bart' in model_name:
mx_out = model(token_ids, valid_length, token_ids, valid_length)
shape_dict = {
'data0': token_ids.shape,
'data1': valid_length.shape,
'data2': token_ids.shape,
'data3': valid_length.shape,
}
dtype_dict = {
'data0': token_ids.dtype.name,
'data1': valid_length.dtype.name,
'data2': token_ids.dtype.name,
'data3': valid_length.dtype.name,
}
elif 'roberta' in model_name or 'xlmr' in model_name:
mx_out = model(token_ids, valid_length)
shape_dict = {
'data0': token_ids.shape,
'data1': valid_length.shape,
}
dtype_dict = {
'data0': token_ids.dtype.name,
'data1': valid_length.dtype.name,
}
else:
mx_out = model(token_ids, token_types, valid_length)
shape_dict = {
'data0': token_ids.shape,
'data1': token_types.shape,
'data2': valid_length.shape
}
dtype_dict = {
'data0': token_ids.dtype.name,
'data1': token_types.dtype.name,
'data2': valid_length.dtype.name
}
sym = model._cached_graph[1]
params = {}
for k, v in model.collect_params().items():
params[v._var_name] = tvm.nd.array(v.data().asnumpy())
mod, params = relay.frontend.from_mxnet(sym, shape=shape_dict, dtype=dtype_dict, arg_params=params)
target = instance_info[instance]['target']
use_gpu = instance_info[instance]['use_gpu']
with relay.build_config(opt_level=opt_level, required_pass=required_pass):
graph, lib, cparams = relay.build(mod, target, params=params)
if use_gpu:
ctx = tvm.gpu()
else:
ctx = tvm.cpu()
rt = runtime.create(graph, lib, ctx)
rt.set_input(**cparams)
if 'bart' in model_name:
rt.set_input(data0=token_ids, data1=valid_length, data2=token_ids, data3=valid_length)
elif 'roberta' in model_name:
rt.set_input(data0=token_ids, data1=valid_length)
else:
rt.set_input(data0=token_ids, data1=token_types, data2=valid_length)
rt.run()
for i in range(rt.get_num_outputs()):
out = rt.get_output(i)
if rt.get_num_outputs() == 1:
mx_out_gt = mx_out.asnumpy()
else:
mx_out_gt = mx_out[i].asnumpy()
if 'mobilebert' in model_name and len(out.shape) == 3:
npt.assert_allclose(out.asnumpy()[:, 1:, :], mx_out[i].asnumpy()[:, 1:, :],
rtol=6e-2, atol=6e-2)
else:
npt.assert_allclose(out.asnumpy(), mx_out_gt, rtol=6e-2, atol=6e-2)
# test_backbone('google_en_cased_bert_base', instance='g4')
test_model_names = ['google_albert_base_v2',
'google_en_cased_bert_base',
'google_electra_small',
'fairseq_roberta_base',
'fairseq_bart_base']
for model_name in test_model_names:
test_backbone(model_name, instance='g4') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good to me. Thanks @sxjscience
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thanks!
Thanks @sxjscience @yzhliu. The test simplification could be in the follow up PRs. |
…nNLP (apache#6699) * update Update type_relations.cc Update transform.cc Update transform.cc Update transform.cc Update transform.cc Update transform.cc Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py update Update mxnet.py debug Update generic.py Update topi_integration.py fix bug update Update test_forward.py Update test_forward.py fix test case Update mxnet.py update Update mxnet.py Update mxnet.py Update test_forward.py Update mxnet.py Update mxnet.py Update test_forward.py Update mxnet.py Update mxnet.py Update mxnet.py debug Update mxnet.py Update mxnet.py Update test_forward.py Update mxnet.py * address comments * Update mxnet.py * Update mxnet.py * fix * improve where test * Update test_forward.py * Update test_forward.py * Update test_forward.py * update * Update mxnet.py * Update mxnet.py * Update mxnet.py debug Update common.py update Update mxnet.py update Update test_forward.py Update test_forward.py * update * fix lint * Update mxnet.py * Update test_op_level1.py * fix lint
…nNLP (apache#6699) * update Update type_relations.cc Update transform.cc Update transform.cc Update transform.cc Update transform.cc Update transform.cc Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py update Update mxnet.py debug Update generic.py Update topi_integration.py fix bug update Update test_forward.py Update test_forward.py fix test case Update mxnet.py update Update mxnet.py Update mxnet.py Update test_forward.py Update mxnet.py Update mxnet.py Update test_forward.py Update mxnet.py Update mxnet.py Update mxnet.py debug Update mxnet.py Update mxnet.py Update test_forward.py Update mxnet.py * address comments * Update mxnet.py * Update mxnet.py * fix * improve where test * Update test_forward.py * Update test_forward.py * Update test_forward.py * update * Update mxnet.py * Update mxnet.py * Update mxnet.py debug Update common.py update Update mxnet.py update Update test_forward.py Update test_forward.py * update * fix lint * Update mxnet.py * Update test_op_level1.py * fix lint
…nNLP (apache#6699) * update Update type_relations.cc Update transform.cc Update transform.cc Update transform.cc Update transform.cc Update transform.cc Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py update Update mxnet.py debug Update generic.py Update topi_integration.py fix bug update Update test_forward.py Update test_forward.py fix test case Update mxnet.py update Update mxnet.py Update mxnet.py Update test_forward.py Update mxnet.py Update mxnet.py Update test_forward.py Update mxnet.py Update mxnet.py Update mxnet.py debug Update mxnet.py Update mxnet.py Update test_forward.py Update mxnet.py * address comments * Update mxnet.py * Update mxnet.py * fix * improve where test * Update test_forward.py * Update test_forward.py * Update test_forward.py * update * Update mxnet.py * Update mxnet.py * Update mxnet.py debug Update common.py update Update mxnet.py update Update test_forward.py Update test_forward.py * update * fix lint * Update mxnet.py * Update test_op_level1.py * fix lint
…nNLP (apache#6699) * update Update type_relations.cc Update transform.cc Update transform.cc Update transform.cc Update transform.cc Update transform.cc Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py Update mxnet.py update Update mxnet.py debug Update generic.py Update topi_integration.py fix bug update Update test_forward.py Update test_forward.py fix test case Update mxnet.py update Update mxnet.py Update mxnet.py Update test_forward.py Update mxnet.py Update mxnet.py Update test_forward.py Update mxnet.py Update mxnet.py Update mxnet.py debug Update mxnet.py Update mxnet.py Update test_forward.py Update mxnet.py * address comments * Update mxnet.py * Update mxnet.py * fix * improve where test * Update test_forward.py * Update test_forward.py * Update test_forward.py * update * Update mxnet.py * Update mxnet.py * Update mxnet.py debug Update common.py update Update mxnet.py update Update test_forward.py Update test_forward.py * update * fix lint * Update mxnet.py * Update test_op_level1.py * fix lint
Fix the MXNet 2.0 integration in relay. Tested the BERT and ALBERT model in the new GluonNLP 1.0 and has passed the test. I will later add unittests in GluonNLP side to ensure that the backbones can be run with the graph runtime.