diff --git a/docs/conf.py b/docs/conf.py
index 390dbed..8bfe0c2 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -16,7 +16,12 @@
 # -- General configuration ---------------------------------------------------
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
 
-extensions = ["sphinx.ext.autodoc", "sphinx.ext.autosummary", "sphinx.ext.napoleon", "sphinx.ext.doctest"]
+extensions = [
+    "sphinx.ext.autodoc",
+    "sphinx.ext.autosummary",
+    "sphinx.ext.napoleon",
+    "sphinx.ext.doctest",
+]
 
 templates_path = ["_templates"]
 exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
@@ -27,9 +32,9 @@
 
 html_theme = "sphinx_book_theme"
 html_logo = "../sceptr.svg"
-html_static_path = ['_static']
+html_static_path = ["_static"]
 html_css_files = [
-    'css/colours.css',
+    "css/colours.css",
 ]
 html_theme_options = {
     "repository_url": "https://github.com/yutanagano/sceptr",
diff --git a/src/sceptr/model.py b/src/sceptr/model.py
index d321f07..4c61eb8 100644
--- a/src/sceptr/model.py
+++ b/src/sceptr/model.py
@@ -97,10 +97,13 @@ class ResidueRepresentations:
 
     Note that the zeroth element of the shape tuple above is 14 because the CDR3B sequence of the first TCR in ``tcrs`` is 14 residues long, and the first element of the shape tuple is 64 because the model dimensionality of the default SCEPTR variant is 64.
     """
+
     representation_array: ndarray
     compartment_mask: ndarray
 
-    def __init__(self, representation_array: ndarray, compartment_mask: ndarray) -> None:
+    def __init__(
+        self, representation_array: ndarray, compartment_mask: ndarray
+    ) -> None:
         self.representation_array = representation_array
         self.compartment_mask = compartment_mask
 
@@ -150,7 +153,9 @@ def calc_vector_representations(self, instances: DataFrame) -> ndarray:
         return torch_representations.cpu().numpy()
 
     @torch.no_grad()
-    def calc_residue_representations(self, instances: DataFrame) -> ResidueRepresentations:
+    def calc_residue_representations(
+        self, instances: DataFrame
+    ) -> ResidueRepresentations:
         """
         Map each TCR to a set of amino acid residue-level representations.
         The residue-level representations are the output of the penultimate self-attention layer, as also used by the :py:func:`~sceptr.variant.average_pooling` variant when generating TCR receptor-level representations.
@@ -171,7 +176,9 @@ def calc_residue_representations(self, instances: DataFrame) -> ResidueRepresent
             For details on how to interpret/use this output, please refer to the documentation for :py:class:`~sceptr.model.ResidueRepresentations`.
         """
         if not isinstance(self._tokeniser, CdrTokeniser):
-            raise NotImplementedError("The calc_residue_representations method is currently only supported on SCEPTR model variants that 1) use both the alpha and beta chains, and 2) take into account all three CDR loops from each chain.")
+            raise NotImplementedError(
+                "The calc_residue_representations method is currently only supported on SCEPTR model variants that 1) use both the alpha and beta chains, and 2) take into account all three CDR loops from each chain."
+            )
 
         instances = instances.copy()
 
@@ -196,7 +203,9 @@ def calc_residue_representations(self, instances: DataFrame) -> ResidueRepresent
             raw_token_embeddings = self._bert._embed(padded_batch)
             padding_mask = self._bert._get_padding_mask(padded_batch)
 
-            residue_reps = self._bert._self_attention_stack.get_token_embeddings_at_penultimate_layer(raw_token_embeddings, padding_mask)
+            residue_reps = self._bert._self_attention_stack.get_token_embeddings_at_penultimate_layer(
+                raw_token_embeddings, padding_mask
+            )
             residue_reps = residue_reps[:, 1:, :]
             compartment_masks = padded_batch[:, 1:, 3]
 
@@ -204,8 +213,12 @@ def calc_residue_representations(self, instances: DataFrame) -> ResidueRepresent
             residue_reps_collection.append(residue_reps)
             compartment_masks_collection.append(compartment_masks)
 
-        residue_reps_combined = torch.concatenate(residue_reps_collection, dim=0).cpu().numpy()
-        compartment_masks_combined = torch.concatenate(compartment_masks_collection, dim=0).cpu().numpy()
+        residue_reps_combined = (
+            torch.concatenate(residue_reps_collection, dim=0).cpu().numpy()
+        )
+        compartment_masks_combined = (
+            torch.concatenate(compartment_masks_collection, dim=0).cpu().numpy()
+        )
 
         return ResidueRepresentations(residue_reps_combined, compartment_masks_combined)
 
diff --git a/src/sceptr/variant.py b/src/sceptr/variant.py
index 99cea3e..104ee42 100644
--- a/src/sceptr/variant.py
+++ b/src/sceptr/variant.py
@@ -198,6 +198,11 @@ def a_sceptr():
     """
     Load the alpha chain-only variant of SCEPTR.
     This variant has the same architecture as the default, but is specifically trained only with the alpha chain in distribution.
+    Thus, this model cannot interpret paired-chain or beta chain-only data.
+
+    .. important ::
+        **This variant is unrelated to the single-chain analysis in our preprint**, which involved applying the :py:func:`~sceptr.variant.default` model to single-chain data.
+        In contrast, this variant is a *distinct model* that was pre-trained specifically *only on alpha chains*.
 
     .. note ::
         Because this model is trained only with the alpha chain in distribution, we expect it to perform slightly better than the default in settings where strictly only the alpha chains are available.
@@ -215,6 +220,11 @@ def b_sceptr():
     """
     Load the beta chain-only variant of SCEPTR.
     This variant has the same architecture as the default, but is specifically trained only with the beta chain in distribution.
+    Thus, this model cannot interpret paired-chain or alpha chain-only data.
+
+    .. important ::
+        **This variant is unrelated to the single-chain analysis in our preprint**, which involved applying the :py:func:`~sceptr.variant.default` model to single-chain data.
+        In contrast, this variant is a *distinct model* that was pre-trained specifically *only on beta chains*.
 
     .. note ::
         Because this model is trained only with the beta chain in distribution, we expect it to perform slightly better than the default in settings where strictly only the beta chains are available.
diff --git a/tests/test_variants.py b/tests/test_variants.py
index 96ec164..04c24d3 100644
--- a/tests/test_variants.py
+++ b/tests/test_variants.py
@@ -63,7 +63,7 @@ def test_residue_representations(self, model, dummy_data):
             "SCEPTR (small)",
             "SCEPTR (BLOSUM)",
             "SCEPTR (average-pooling)",
-            "SCEPTR (finetuned)"
+            "SCEPTR (finetuned)",
         ):
            result = model.calc_residue_representations(dummy_data)
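
For anyone who wants to exercise the `calc_residue_representations` path touched by this diff, here is a minimal usage sketch. It is not part of the change itself: the DataFrame column names (`TRAV`, `CDR3A`, `TRAJ`, `TRBV`, `CDR3B`, `TRBJ`) and the example receptor are assumptions for illustration, not values taken from the sceptr test suite. Only `sceptr.variant.default()`, `calc_residue_representations`, and the `representation_array`/`compartment_mask` attributes are confirmed by the code above.

```python
# Minimal sketch (assumed column names and example TCR; not part of this diff).
import pandas as pd

from sceptr import variant

# Residue-level representations require a paired-chain, all-CDR variant,
# so use the default model rather than a_sceptr()/b_sceptr().
model = variant.default()

tcrs = pd.DataFrame(
    {
        "TRAV": ["TRAV38-1*01"],
        "CDR3A": ["CAHRSAGGGTSYGKLTF"],
        "TRAJ": ["TRAJ52*01"],
        "TRBV": ["TRBV2*01"],
        "CDR3B": ["CASSEARGGNEKLFF"],
        "TRBJ": ["TRBJ1-4*01"],
    }
)

res = model.calc_residue_representations(tcrs)

# representation_array holds one embedding per residue token;
# compartment_mask indicates which CDR loop each position belongs to.
print(res.representation_array.shape)
print(res.compartment_mask.shape)
```

With the default variant, the last axis of `representation_array` should be 64, matching the model dimensionality quoted in the `ResidueRepresentations` docstring above.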