
Add perturbation functionality to geneformer's .predict() method #98

Merged 21 commits from sf_predict_with_perturbations into main on Feb 13, 2024

Conversation

@sjfleming (Contributor) commented Oct 23, 2023

Adds the optional kwargs

feature_deletion
feature_activation
feature_map

to Geneformer.predict()

The ideas of "deletion" and "activation" are described in the Geneformer paper. Deletion is achieved by zeroing a feature's expression before tokenization. Activation is achieved by setting a feature's expression to a high value so that it lands at the top of the rank-ordered token list after tokenization.

"Map" is a general kind of idea to implement other kinds of in silico experiments. It takes the data after tokenization, and maps tokens to other tokens. This could be used to implement the replacement of a specific token or set of tokens with a pad (0) or mask (1) token, or to switch the role of two tokens, etc.

@sjfleming (Author):

Closes #98

@sjfleming (Author):

So far there are no tests covering the new functionality.

@sjfleming (Author):

Added a test to cover the new functionality.

@sjfleming sjfleming requested a review from ordabayevy February 2, 2024 06:43
@sjfleming sjfleming marked this pull request as ready for review February 2, 2024 06:43
@sjfleming (Author):

@ordabayevy I think I might need some help with the mypy error here. It's coming from my added geneformer test. It doesn't seem to like the way I'm trying to pass **kwargs as an input. Do you know what's wrong? The pytest test itself passes.

@sjfleming sjfleming marked this pull request as draft February 2, 2024 06:47
@sjfleming (Author):

Okay, never mind. I did some extra copying and pasting and now mypy seems happy. I wonder how you're supposed to pass **kwargs as an input in a way that keeps mypy happy.
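One common mypy-friendly pattern, shown here as a sketch rather than what the test necessarily settled on: annotate the dict as dict[str, Any] before unpacking it, since Any values are compatible with any parameter type.

```python
from typing import Any

# Hypothetical kwargs; the key mirrors one of the new predict() arguments.
kwargs: dict[str, Any] = {"feature_activation": ["GENE_ACTIVATE"]}
output = geneformer.predict(x_ng=x_ng, **kwargs)
```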

@sjfleming sjfleming marked this pull request as ready for review February 2, 2024 07:45
@sjfleming sjfleming requested a review from mbabadi February 2, 2024 07:45
@ordabayevy (Contributor) left a review:

Looks great! Left some comments.

x_ng: torch.Tensor,
feature_deletion: list[str] | None = None,
feature_activation: list[str] | None = None,
feature_map: dict[str, int] | None = None,
@ordabayevy (Contributor): nit: sort inputs in the order feature_activation, feature_deletion, feature_map (here and elsewhere) so they appear in the same order everywhere, just to avoid confusion.

Returns:
A dictionary with the inference results.

NOTE:
@ordabayevy (Contributor): nit: change this to rst format .. note::

@@ -166,8 +202,14 @@ def predict(
output_attentions: bool = True,
output_input_ids: bool = True,
output_attention_mask: bool = True,
feature_map: dict[str, int] | None = None,
feature_activation: list[str] | None = None,
feature_deletion: list[str] | None = None,
@ordabayevy (Contributor): nit: sort inputs in the order feature_activation, feature_deletion, feature_map.

feature_deletion:
Specify features whose expression should be set to zero before tokenization (remove from inputs).
feature_activation:
Specify features whose expression should be set to > max(x_ng) before tokenization (top rank).
@ordabayevy (Contributor): nit: sort inputs in the order feature_activation, feature_deletion, feature_map.

if feature_deletion:
assert all(
[g in self.var_names_g for g in feature_deletion]
), "Some feature_deletion elements are not in self.var_names_g"
@ordabayevy (Contributor): nit: raise ValueError.
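One way this nit might be addressed inside predict(), sketched here rather than taken from the merged diff (the message text is illustrative):

```python
if feature_deletion:
    # Collect all missing gene names so the error reports them in one pass.
    missing = [g for g in feature_deletion if g not in self.var_names_g]
    if missing:
        raise ValueError(f"feature_deletion elements not in self.var_names_g: {missing}")
```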

assert all(
[g in self.var_names_g for g in feature_deletion]
), "Some feature_deletion elements are not in self.var_names_g"
deletion_logic_g = np.logical_or.reduce([(self.var_names_g == g) for g in feature_deletion])
@ordabayevy (Contributor): If I understand this correctly, there is already a numpy function for this: np.isin(self.var_names_g, feature_deletion)

@sjfleming (Author): Great call, I did not realize that!
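A self-contained check of the equivalence, with illustrative gene names:

```python
import numpy as np

var_names_g = np.array(["GENE_A", "GENE_B", "GENE_C", "GENE_D"])  # illustrative
feature_deletion = ["GENE_B", "GENE_D"]

# The reduction over one boolean mask per gene ...
deletion_logic_g = np.logical_or.reduce([(var_names_g == g) for g in feature_deletion])
# ... matches a single vectorized membership test.
assert (deletion_logic_g == np.isin(var_names_g, feature_deletion)).all()
```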

max_val = x_ng.max()
for i, g in enumerate(feature_activation[::-1]):
feature_logic_g = self.var_names_g == g
assert feature_logic_g.sum() == 1, f"feature_activation element {g} is not in self.var_names_g"
@ordabayevy (Contributor): nit: raise ValueError.

if feature_map:
for g, target_token in feature_map.items():
feature_logic_g = self.var_names_g == g
assert feature_logic_g.sum() == 1, f"feature_map key {g} not in self.var_names_g"
@ordabayevy (Contributor): nit: raise ValueError.

for g, target_token in feature_map.items():
feature_logic_g = self.var_names_g == g
assert feature_logic_g.sum() == 1, f"feature_map key {g} not in self.var_names_g"
initial_token = self.feature_ids[feature_logic_g]
@ordabayevy (Contributor): I wonder if there would be any benefit to having a gene_name -> gene_id dictionary so we don't have to do this.

@sjfleming (Author): Yeah, it might not hurt.
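A sketch of the suggested lookup table, with illustrative gene names and token ids; built once (e.g. at module init), it turns each boolean scan into an O(1) dictionary lookup:

```python
import numpy as np

var_names_g = np.array(["GENE_A", "GENE_B", "GENE_C"])  # illustrative gene names
feature_ids = np.array([27, 412, 9001])                 # illustrative token ids

# Build once; afterwards no per-gene array scan is needed.
gene_to_token = dict(zip(var_names_g.tolist(), feature_ids.tolist()))

initial_token = gene_to_token["GENE_B"]  # instead of feature_ids[var_names_g == "GENE_B"]
```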


print(f"Expected input_ids:\n{expected_input_ids}")
print(f"Actual input_ids:\n{input_ids}")
torch.testing.assert_close(input_ids, expected_input_ids)
@ordabayevy (Contributor): I didn't know that torch.testing existed!
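For reference, torch.testing.assert_close compares shape, dtype, device, and values (with tolerances for floating point) and raises an AssertionError with a readable diff on mismatch:

```python
import torch

expected_input_ids = torch.tensor([[2, 5, 9, 0]])
input_ids = torch.tensor([[2, 5, 9, 0]])

torch.testing.assert_close(input_ids, expected_input_ids)  # passes silently
```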

@ordabayevy (Contributor) commented Feb 6, 2024:

Notebook comment: you can directly index CellariumPipeline:

var_names_g = pipeline.get_submodule("2").var_names_g
# instead:
var_names_g = pipeline[2].var_names_g
# or better:
# var_names_g = pipeline.model.var_names_g  # sorry, this actually won't work; it is for CellariumModule.model

@sjfleming (Author):

Alright, I think I've addressed everything; worth double-checking.

@ordabayevy ordabayevy merged commit 55a902c into main Feb 13, 2024
5 checks passed
@ordabayevy ordabayevy deleted the sf_predict_with_perturbations branch February 13, 2024 18:22