diff --git a/casanovo/denovo/model.py b/casanovo/denovo/model.py index 67d561b..8850c3d 100644 --- a/casanovo/denovo/model.py +++ b/casanovo/denovo/model.py @@ -1148,7 +1148,7 @@ def _calc_match_score( (for an entire batch) """ # Remove trailing tokens from predictions based on decoder reversal - if decoder_reverse: + if not decoder_reverse: batch_all_aa_scores = batch_all_aa_scores[:, 1:] else: batch_all_aa_scores = batch_all_aa_scores[:, :-1] @@ -1163,6 +1163,8 @@ def _calc_match_score( per_aa_scores = batch_all_aa_scores[rows, cols, truth_aa_indices] + logging.debug("$$$$$$$$$$$$$||%s||$$$$$$$$$$$$$$", per_aa_scores) + per_aa_scores[per_aa_scores == 0] += 1e-10 score_mask = truth_aa_indices != 0 per_aa_scores[~score_mask] = 0