forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 30
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Relay/TOPI][Op] Add batch_matmul in relay and TOPI (apache#2561)
* Add batch_dot and cpu schedule * Add relay support for batch_dot * Rename batch_dot to batch_matmul * nits * Add missing file * Put batch_matmul and dense x86 schedule in separate files * Fix pylint * Remove unused import * Add cuda schedule for batch_matmul * Add test case with larger batch size * Add batch_matmul in api doc * Fix quantize pass rounding error * Fix pylint and minor change * bug fix
- Loading branch information
Showing
23 changed files
with
715 additions
and
212 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
/*! | ||
* Copyright (c) 2019 by Contributors | ||
* \brief Batch matmul op constructions | ||
* \file nn/batch_matmul.h | ||
*/ | ||
#ifndef TOPI_NN_BATCH_MATMUL_H_ | ||
#define TOPI_NN_BATCH_MATMUL_H_ | ||
|
||
#include <string> | ||
|
||
#include "topi/tags.h" | ||
#include "tvm/tvm.h" | ||
|
||
namespace topi { | ||
namespace nn { | ||
using namespace tvm; | ||
|
||
/*! | ||
* \brief Creates an operation that calculates matrix multiplication in batch. | ||
* | ||
* \param x Tensor with shape [batch, M, K] | ||
* \param y Tensor with shape [batch, N, K] | ||
* | ||
* \return Tensor with shape [batch, M, N] | ||
*/ | ||
inline tvm::Tensor batch_matmul(const tvm::Tensor& x, | ||
const tvm::Tensor& y) { | ||
CHECK_EQ(x->shape.size(), 3) << "batch_matmul requires 3-D data"; | ||
CHECK_EQ(y->shape.size(), 3) << "batch_matmul requires 3-D data"; | ||
|
||
auto batch = x->shape[0]; | ||
auto M = x->shape[1]; | ||
auto K = x->shape[2]; | ||
auto N = y->shape[1]; | ||
|
||
auto k = tvm::reduce_axis(Range(0, K), "k"); | ||
auto result = tvm::compute( | ||
{ batch, M, N }, | ||
[&](Var b, Var i, Var j) { | ||
return tvm::sum(x(b, i, k) * y(b, j, k), { k }); | ||
}, "tensor", "batch_matmul"); | ||
|
||
return result; | ||
} | ||
|
||
} // namespace nn | ||
} // namespace topi | ||
|
||
#endif // TOPI_NN_BATCH_MATMUL_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.