Add perturbation functionality to geneformer's .predict() method #98
Conversation
Closes #98
So far there are no tests covering the new functionality
…with_perturbations
Added a test to cover the new functionality
@ordabayevy I think I might need some help with the
Okay, never mind, I just did some extra copy and paste and now mypy seems happy. I wonder how you're supposed to give **kwargs as an input in a way mypy will be happy with.
Looks great! Left some comments
x_ng: torch.Tensor,
feature_deletion: list[str] | None = None,
feature_activation: list[str] | None = None,
feature_map: dict[str, int] | None = None,
nit: sort inputs in the order feature_activation, feature_deletion, feature_map, and elsewhere so they are in the same order, just to avoid confusion
cellarium/ml/models/geneformer.py
Outdated
Returns:
    A dictionary with the inference results.

NOTE:
nit: change this to rst format .. note::
@@ -166,8 +202,14 @@ def predict(
    output_attentions: bool = True,
    output_input_ids: bool = True,
    output_attention_mask: bool = True,
    feature_map: dict[str, int] | None = None,
    feature_activation: list[str] | None = None,
    feature_deletion: list[str] | None = None,
nit: sort inputs in the order feature_activation, feature_deletion, feature_map
feature_deletion:
    Specify features whose expression should be set to zero before tokenization (remove from inputs).
feature_activation:
    Specify features whose expression should be set to > max(x_ng) before tokenization (top rank).
nit: sort inputs in the order feature_activation, feature_deletion, feature_map
cellarium/ml/models/geneformer.py
Outdated
if feature_deletion:
    assert all(
        [g in self.var_names_g for g in feature_deletion]
    ), "Some feature_deletion elements are not in self.var_names_g"
nit: raise ValueError
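That suggestion could be factored into a small validation helper; this is a hypothetical sketch (`check_features_present` is an invented name, not code from the PR). A `ValueError` is preferable to `assert` because it survives `python -O` and gives callers a specific, catchable exception type:

```python
import numpy as np


def check_features_present(var_names_g: np.ndarray, features: list[str], arg_name: str) -> None:
    # Hypothetical helper: validate membership with a ValueError rather than
    # an assert, so the check is not stripped under `python -O` and the error
    # message names exactly which features are missing.
    missing = [g for g in features if g not in var_names_g]
    if missing:
        raise ValueError(f"Some {arg_name} elements are not in var_names_g: {missing}")
```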
cellarium/ml/models/geneformer.py
Outdated
assert all(
    [g in self.var_names_g for g in feature_deletion]
), "Some feature_deletion elements are not in self.var_names_g"
deletion_logic_g = np.logical_or.reduce([(self.var_names_g == g) for g in feature_deletion])
If I understand this correctly there is already a numpy function for this: np.isin(self.var_names_g, feature_deletion)
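For illustration, a small self-contained check (with made-up gene names) that `np.isin` produces the same boolean mask as the `logical_or.reduce` construction in the diff:

```python
import numpy as np

var_names_g = np.array(["ENSG01", "ENSG02", "ENSG03", "ENSG04"])
feature_deletion = ["ENSG02", "ENSG04"]

# reduce over per-gene equality masks, as in the original diff
mask_reduce = np.logical_or.reduce([(var_names_g == g) for g in feature_deletion])

# the single vectorized call suggested in the review
mask_isin = np.isin(var_names_g, feature_deletion)

assert np.array_equal(mask_reduce, mask_isin)
```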
Great call, I did not realize that!
cellarium/ml/models/geneformer.py
Outdated
max_val = x_ng.max()
for i, g in enumerate(feature_activation[::-1]):
    feature_logic_g = self.var_names_g == g
    assert feature_logic_g.sum() == 1, f"feature_activation element {g} is not in self.var_names_g"
nit: raise ValueError
cellarium/ml/models/geneformer.py
Outdated
if feature_map:
    for g, target_token in feature_map.items():
        feature_logic_g = self.var_names_g == g
        assert feature_logic_g.sum() == 1, f"feature_map key {g} not in self.var_names_g"
nit: raise ValueError
for g, target_token in feature_map.items():
    feature_logic_g = self.var_names_g == g
    assert feature_logic_g.sum() == 1, f"feature_map key {g} not in self.var_names_g"
    initial_token = self.feature_ids[feature_logic_g]
I wonder if there would be any benefits of having a gene_name -> gene_id dictionary so we don't have to do this
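A sketch of that idea, assuming the lookup dict is built once (e.g. in `__init__`) from `var_names_g` and `feature_ids`; the names and values here are illustrative, not the PR's actual code:

```python
import numpy as np

var_names_g = np.array(["ENSG01", "ENSG02", "ENSG03"])
feature_ids = np.array([10, 11, 12])

# built once, e.g. in __init__
gene_name_to_id = dict(zip(var_names_g.tolist(), feature_ids.tolist()))

# an O(1) lookup replaces the per-gene boolean-mask scan, and a missing
# gene fails loudly with a KeyError naming the offending key
initial_token = gene_name_to_id["ENSG02"]
```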
Yeah it might not hurt
print(f"Expected input_ids:\n{expected_input_ids}")
print(f"Actual input_ids:\n{input_ids}")
torch.testing.assert_close(input_ids, expected_input_ids)
I didn't know that torch.testing existed!
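For reference, a minimal standalone example of `torch.testing.assert_close` (the tensors here are made up): it passes silently when tensors agree within dtype-dependent tolerances, and raises `AssertionError` with a readable per-element report when they do not:

```python
import torch

a = torch.tensor([1.0, 2.0, 3.0])

# a difference within the default float32 tolerances passes silently
torch.testing.assert_close(a, a + 1e-7)

# a real mismatch raises AssertionError with a readable diff
try:
    torch.testing.assert_close(a, a + 1.0)
    mismatch_caught = False
except AssertionError:
    mismatch_caught = True
```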
Notebook comment: instead of var_names_g = pipeline.get_submodule("2").var_names_g you can directly index:

var_names_g = pipeline[2].var_names_g
# or better
# var_names_g = pipeline.model.var_names_g  # sorry this actually won't work, it is for CellariumModule.model
Alright, I think I've addressed everything; worth double-checking.
Adds the optional kwargs to Geneformer.predict().
The ideas of "deletion" and "activation" are described in the Geneformer paper. Deletion is achieved by zeroing a feature's expression before tokenization. Activation is achieved by setting a feature's expression to a high value so that it lands at the top of the rank-ordered list after tokenization.
"Map" is a more general mechanism for implementing other kinds of in silico experiments: it takes the data after tokenization and maps tokens to other tokens. This could be used to replace a specific token or set of tokens with the pad (0) or mask (1) token, to swap the roles of two tokens, etc.