Skip to content

Commit

Permalink
Added more fusion and vectorized kernel for transducer (#1125)
Browse files Browse the repository at this point in the history
* Added support for fused ReLU and dropout into transducer joint

* Reorganized code selection path in transducer joint fwd
* Added support for fused ReLU+dropout into transducer joint

* Vectorize transducer loss backward with fused softmax (#3)

* Nanz/transducer loss (#4)

* Vectorize transducer loss backward with fused softmax

* Added a predicate to avoid potential IMA

* Nanz/transducer loss (#5)

* Vectorize transducer loss backward with fused softmax

* Added a predicate to avoid potential IMA

* Added more predicates to avoid IMAs

* Updated documentation for newly added features.

* Fixed an error in transducer.py
  • Loading branch information
nanz-nv authored Jul 17, 2021
1 parent ed71996 commit 0c2c6ee
Show file tree
Hide file tree
Showing 8 changed files with 662 additions and 185 deletions.
32 changes: 23 additions & 9 deletions apex/contrib/csrc/transducer/transducer_joint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

torch::Tensor transducer_joint_cuda_forward(
std::vector<torch::Tensor> transducer_joint_cuda_forward(
torch::Tensor f,
torch::Tensor g,
torch::Tensor fLen,
Expand All @@ -14,19 +14,23 @@ torch::Tensor transducer_joint_cuda_forward(
int64_t packedBatch,
int opt,
bool packOutput,
bool relu,
bool dropout,
float dropoutProb,
int tileSize);


std::vector<torch::Tensor> transducer_joint_cuda_backward(
torch::Tensor grad,
std::vector<torch::Tensor> in,
torch::Tensor fLen,
torch::Tensor gLen,
torch::Tensor batchOffset,
int maxFLen,
int maxGLen,
bool packOutput);
bool packOutput,
float scale);

torch::Tensor transducer_joint_forward(
std::vector<torch::Tensor> transducer_joint_forward(
torch::Tensor f,
torch::Tensor g,
torch::Tensor fLen,
Expand All @@ -35,6 +39,9 @@ torch::Tensor transducer_joint_forward(
int64_t packedBatch,
int opt,
bool packOutput,
bool relu,
bool dropout,
float dropoutProb,
int tileSize) {
CHECK_INPUT(f);
CHECK_INPUT(g);
Expand All @@ -51,30 +58,37 @@ torch::Tensor transducer_joint_forward(
packedBatch,
opt,
packOutput,
relu,
dropout,
dropoutProb,
tileSize);
}

std::vector<torch::Tensor> transducer_joint_backward(
torch::Tensor grad,
std::vector<torch::Tensor> in,
torch::Tensor fLen,
torch::Tensor gLen,
torch::Tensor batchOffset,
int maxFLen,
int maxGLen,
bool packOutput) {
CHECK_INPUT(grad);
bool packOutput,
float scale) {
for (auto t : in){
CHECK_INPUT(t);
}
CHECK_INPUT(fLen);
CHECK_INPUT(gLen);
if (packOutput)
CHECK_INPUT(batchOffset);
return transducer_joint_cuda_backward(
grad,
in,
fLen,
gLen,
batchOffset,
maxFLen,
maxGLen,
packOutput);
packOutput,
scale);
}


Expand Down
Loading

0 comments on commit 0c2c6ee

Please sign in to comment.