Skip to content

Commit

Permalink
Fix factor map op for shortlist
Browse files Browse the repository at this point in the history
  • Loading branch information
rhenry-nv committed Jul 9, 2021
1 parent 22d13b3 commit 7be492a
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions src/graph/node_operators_binary.h
Original file line number Diff line number Diff line change
Expand Up @@ -1291,9 +1291,9 @@ size_t numLemmas_;
bool hasShortlist_;
public:
AddFactorMaxesOp(const std::vector<Expr>& nodes, bool hasShortlist, size_t groupStart, size_t numLemmas)
: NaryNodeOp(nodes, getShape(nodes, hasShortlist), commonType(std::vector<Expr>(nodes.begin() + 1, nodes.end())) ) {
: NaryNodeOp(nodes, getShape(nodes, hasShortlist), commonType(std::vector<Expr>(nodes.begin() + 1 + (int)hasShortlist, nodes.end())) ) {
groupStart_ = groupStart;
numLemmas_ = hasShortlist? nodes[1]->shape().size(): numLemmas;
numLemmas_ = numLemmas;
hasShortlist_ = hasShortlist;
}

Expand Down
2 changes: 1 addition & 1 deletion src/layers/generic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ namespace marian {
} else {
auto numGroups = getNumFactorGroups();
if(numGroups > 1 && graph()->isInference() && graph()->getBackend()->getDeviceId().type == DeviceType::gpu) {
Expr shortlistIndices = shortlist? constant(shortlist->indices()) : nullptr;
Expr shortlistIndices = shortlist? indices(shortlist->indices()) : nullptr;
Expr lemmaHasFactorGroupTensor = getLemmaHasFactorGroupTensor();
std::vector<Expr> groupLosses(logits_.size());
std::transform(logits_.begin(), logits_.end(), groupLosses.begin(), [](const Ptr<RationalLoss>& loss) -> Expr {return loss->loss();});
Expand Down

0 comments on commit 7be492a

Please sign in to comment.