Skip to content

Commit

Permalink
Add aten::mv support (#9894)
Browse files Browse the repository at this point in the history
* Rebase

* Fix bad merge

Co-authored-by: driazati <driazati@users.noreply.github.com>
  • Loading branch information
driazati and driazati authored Jan 11, 2022
1 parent f6b0c38 commit d1ee201
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 6 deletions.
14 changes: 14 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2856,6 +2856,19 @@ def dot(self, inputs, _):
lhs, rhs = inputs
return _op.sum(_op.multiply(lhs, rhs))

def mv(self, inputs, _):
lhs, rhs = inputs

# Convert the 1D matrix (vector) into a 2D matrix with the extra
# dimension=1
rhs_matrix = _op.transform.expand_dims(rhs, 0)

# Run multiplication
dense_result = _op.nn.dense(lhs, rhs_matrix, units=None)

# Chop off the extra result dimension
return _op.transform.squeeze(dense_result)

# Operator mappings
def create_convert_map(self):
self.convert_map = {
Expand Down Expand Up @@ -3081,6 +3094,7 @@ def create_convert_map(self):
"aten::roll": self.roll,
"aten::einsum": self.einsum,
"aten::dot": self.dot,
"aten::mv": self.mv,
}

def update_convert_map(self, custom_map):
Expand Down
2 changes: 1 addition & 1 deletion src/ir/diagnostic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ void ReportAt(const DiagnosticContext& context, std::ostream& out, const Span& s
if (it == context->module->source_map->source_map.end()) {
LOG(FATAL) << "The source maps are not populated for this module. "
<< "Please use `tvm.relay.transform.AnnotateSpans` to attach source maps for error "
"reporting. "
"reporting.\n"
<< "Error: " << diagnostic->message;
}

Expand Down
8 changes: 5 additions & 3 deletions src/relay/analysis/type_solver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -258,10 +258,12 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {

if (mismatches.size() != 0) {
auto err = Diagnostic::Error(this->span);
err << "The Relay type checker is unable to show the following types match.\n";
err << "In particular ";
err << "The Relay type checker is unable to show the following types match:\n"
<< " " << PrettyPrint(tt1) << "\n"
<< " " << PrettyPrint(tt2) << "\n";
err << "In particular:\n";
for (auto mismatch : mismatches) {
err << "dimension " << std::get<0>(mismatch) << " conflicts: " << std::get<1>(mismatch)
err << " dimension " << std::get<0>(mismatch) << " conflicts: " << std::get<1>(mismatch)
<< " does not match " << std::get<2>(mismatch) << ".";
}
this->solver_->Emit(err);
Expand Down
10 changes: 10 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -4096,5 +4096,15 @@ def test_fn(x):
verify_model(test_fn, [x])


@tvm.testing.uses_gpu
def test_mv():
def test_fn(m, v):
return m.mv(v)

verify_model(test_fn, [torch.randn(4, 4), torch.randn(4)])
verify_model(test_fn, [torch.randn(2, 2), torch.randn(2)])
verify_model(test_fn, [torch.randn(3, 8), torch.randn(8)])


if __name__ == "__main__":
pytest.main([__file__])
8 changes: 6 additions & 2 deletions tests/python/relay/test_any.py
Original file line number Diff line number Diff line change
Expand Up @@ -1418,10 +1418,14 @@ def _body(i, st):
start = relay.var("start", shape=(), dtype="int32")
body = loop(start, relay.op.reshape(relay.const(0), newshape=(1, 1)))
func = relay.Function([start], relay.TupleGetItem(body, 1))

with DiagnosticTesting() as diagnostics:
diagnostics.assert_message(
"The Relay type checker is unable to show the following types "
"match.\nIn particular dimension 0 conflicts: 2 does not match 1."
"The Relay type checker is unable to show the following types match:\n"
" Tensor[(2, 1), int32]\n"
" Tensor[(1, 1), int32]\n"
"In particular:\n"
" dimension 0 conflicts: 2 does not match 1."
)
func = infer_type(func)

Expand Down

0 comments on commit d1ee201

Please sign in to comment.