diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5218f91..4fe542d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,11 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
 ## [Unreleased]
+
+## [v0.4.8]
+### Changed
+- `Tokenizer.detokenize()` now truncates the output to the first stop token it finds, if `trim_stop_token=True`.
+
 ## [v0.4.7]
 ### Fixed
 - Add stop and start tokens for `AnnotatedSpectrumDataset`, when available.
diff --git a/depthcharge/tokenizers/peptides.py b/depthcharge/tokenizers/peptides.py
index 9207f81..427eeb7 100644
--- a/depthcharge/tokenizers/peptides.py
+++ b/depthcharge/tokenizers/peptides.py
@@ -194,7 +194,7 @@ def detokenize(
             tokens=tokens,
             join=join,
             trim_start_token=trim_start_token,
-            trim_stop_token=trim_start_token,
+            trim_stop_token=trim_stop_token,
         )
 
         if self.reverse:
diff --git a/depthcharge/tokenizers/tokenizer.py b/depthcharge/tokenizers/tokenizer.py
index c07a0fa..d7f8d2c 100644
--- a/depthcharge/tokenizers/tokenizer.py
+++ b/depthcharge/tokenizers/tokenizer.py
@@ -133,7 +133,7 @@ def detokenize(
         trim_start_token : bool, optional
             Remove the start token from the beginning of a sequence.
         trim_stop_token : bool, optional
-            Remove the stop token from the end of a sequence.
+            Remove the stop token and anything following it from the sequence.
 
         Returns
         -------
@@ -143,16 +143,18 @@
         """
         decoded = []
         for row in tokens:
-            seq = [
-                self.reverse_index[i]
-                for i in row
-                if self.reverse_index[i] is not None
-            ]
+            seq = []
+            for idx in row:
+                if self.reverse_index[idx] is None:
+                    continue
+
+                if trim_stop_token and idx == self.stop_int:
+                    break
+
+                seq.append(self.reverse_index[idx])
 
             if trim_start_token and seq[0] == self.start_token:
                 seq.pop(0)
-            if trim_stop_token and seq[-1] == self.stop_token:
-                seq.pop(-1)
 
             if join:
                 seq = "".join(seq)
diff --git a/tests/unit_tests/test_tokenizers/test_peptides.py b/tests/unit_tests/test_tokenizers/test_peptides.py
index 6e4a66f..b4cef95 100644
--- a/tests/unit_tests/test_tokenizers/test_peptides.py
+++ b/tests/unit_tests/test_tokenizers/test_peptides.py
@@ -115,3 +115,25 @@ def test_almost_compliant_proform():
     """Test initializing with a peptide without an expicit mass sign."""
     tokenizer = PeptideTokenizer.from_proforma("[10]-EDITHR")
     assert "[+10.000000]-" in tokenizer.residues
+
+
+@pytest.mark.parametrize(
+    ("start", "stop", "expected"),
+    [
+        (True, True, "ACD"),
+        (True, False, "ACD$E"),
+        (False, True, "?ACD"),
+        (False, False, "?ACD$E"),
+    ],
+)
+def test_trim(start, stop, expected):
+    """Test that the start and stop tokens can be trimmed."""
+    tokenizer = PeptideTokenizer(start_token="?")
+    tokens = torch.tensor([[0, 2, 3, 4, 5, 1, 6]])
+    out = tokenizer.detokenize(
+        tokens,
+        trim_start_token=start,
+        trim_stop_token=stop,
+    )
+
+    assert out[0] == expected
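
A quick illustration of the behavior change (not part of the patch): a minimal sketch mirroring the `test_trim` case above. It assumes `PeptideTokenizer` is importable from `depthcharge.tokenizers` and reuses the test's token ids, where, per the expected outputs, id 0 is padding and id 1 is the stop token (`$`):

    import torch
    from depthcharge.tokenizers import PeptideTokenizer

    tokenizer = PeptideTokenizer(start_token="?")
    # Token ids from the test above: 0 = padding, 1 = stop ("$"),
    # and id 6 decodes to a residue that follows the stop token.
    tokens = torch.tensor([[0, 2, 3, 4, 5, 1, 6]])

    # New behavior: the output is truncated at the first stop token.
    tokenizer.detokenize(tokens, trim_start_token=True, trim_stop_token=True)
    # -> ["ACD"]

    # With trim_stop_token=False, the stop token and anything after it remain.
    tokenizer.detokenize(tokens, trim_start_token=True, trim_stop_token=False)
    # -> ["ACD$E"]

Previously, `trim_stop_token=True` only removed a stop token sitting at the very end of a sequence, so any residues a model emitted after the stop token leaked into the decoded output.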