Fix device handling in tests #52
Conversation
```diff
@@ -185,7 +185,8 @@ def _create_additive_hook(

     def hook_fn(_m: Any, _inputs: Any, outputs: Any) -> Any:
         original_tensor = untuple_tensor(outputs)
-        original_tensor[None] = operator(original_tensor, target_activation)
+        act = target_activation.to(original_tensor.device)
```
Setting the device here allows us to skip setting the activation devices. E.g.:

```python
activations = {
    1: torch.randn(1, 512),  # can be on cpu
}
...
model_patcher.patch_activations(
    activations, layer_type=layer_type, operator="piecewise_addition"
)
```

As of 16426e4, all the tests pass for me locally (with GPU enabled):

```
$ python -m pytest tests/util/test_model_patcher.py
...
12 passed, 9 warnings in 15.54s
```
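As a minimal, hedged sketch of the pattern being reviewed (the toy `nn.Linear` model and variable names below are illustrative assumptions, not the repo's actual code): the forward hook moves the target activation onto the output tensor's device, so the caller's `activations` dict can stay on CPU even when the model runs on GPU.

```python
import torch
import torch.nn as nn

# Toy stand-in for a model layer; the real ModelPatcher hooks into actual
# model layers, but the device-transfer logic inside the hook is the same idea.
model = nn.Linear(512, 512)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

# The target activation can live on CPU, like the `activations` dict above.
target_activation = torch.randn(512)

def hook_fn(_module, _inputs, output):
    # Move the target activation to wherever the layer's output lives
    # before combining, so the caller never has to pre-assign devices.
    act = target_activation.to(output.device)
    return output + act

handle = model.register_forward_hook(hook_fn)
out = model(torch.randn(1, 512, device=device))
handle.remove()
print(out.shape)  # torch.Size([1, 512])
```

This is why the test activations above can be created with plain `torch.randn(1, 512)` and still work when the patched model is on CUDA.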
makes sense, LGTM!
@chanind after merging in your latest changes from main, one of the tests isn't working any more. Looking at the error logs, it seems to be in this code chunk:

```python
def _create_additive_hook(
    target_activation: torch.Tensor, operator: PatchOperator
) -> Any:
    """Create a hook function that adds the given target_activation to the model output"""

    def hook_fn(_m: Any, _inputs: Any, outputs: Any) -> Any:
        original_tensor = untuple_tensor(outputs)
        act = target_activation.to(original_tensor.device)  # This line raises an error
        original_tensor[None] = operator(original_tensor, act)
        return outputs

    return hook_fn
```

Looking at the CI logs here, this line fails:

with this error:

I.e. the
Good catch! Fixed in #55
Modifies `ModelPatcher` so that target activations are moved onto the appropriate device before the operator is applied.

Closes #51
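For readers skimming the diff, here is a hedged sketch of the core idea of this change. The names `PatchOperator` and `piecewise_addition` follow the discussion above; their actual definitions and signatures in the repo may differ.

```python
import torch
from typing import Callable

# Assumed shape of the repo's PatchOperator type: a callable that combines
# the layer output with the target activation (real definition may differ).
PatchOperator = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]

def piecewise_addition(original: torch.Tensor, patch: torch.Tensor) -> torch.Tensor:
    # Simple additive operator, as referenced by operator="piecewise_addition".
    return original + patch

def apply_operator(
    original_tensor: torch.Tensor,
    target_activation: torch.Tensor,
    operator: PatchOperator,
) -> torch.Tensor:
    # The device transfer happens immediately before the operator is applied,
    # so callers can pass CPU activations against a model on any device.
    act = target_activation.to(original_tensor.device)
    return operator(original_tensor, act)

# Usage: a CPU activation patched onto a layer output (possibly on GPU).
layer_output = torch.randn(1, 512)
patched = apply_operator(layer_output, torch.randn(1, 512), piecewise_addition)
```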