Skip to content

Commit ad77e79

Browse files
ezyangfacebook-github-bot
authored andcommitted
Prepare for "Fix type-safety of torch.nn.Module instances": wave 2
Summary: See D52890934 Reviewed By: malfet, r-barnes Differential Revision: D66245100 fbshipit-source-id: 019058106ac7eaacf29c1c55912922ea55894d23
1 parent 478942d commit ad77e79

File tree

5 files changed

+14
-0
lines changed

5 files changed

+14
-0
lines changed

captum/_utils/transformers_typing.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,4 +116,5 @@ def supports_caching(model: nn.Module) -> bool:
116116
# Cache is mandatory
117117
return True
118118
# Fallback on _supports_cache_class attribute
119+
# pyre-fixme[7]: Expected `bool` but got `Union[Module, Tensor]`.
119120
return getattr(model, "_supports_cache_class", False)

captum/attr/_core/deep_lift.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,7 @@ def _forward_pre_hook(
407407
set necessary hooks on inputs there.
408408
"""
409409
inputs = _format_tensor_into_tuples(inputs)
410+
# pyre-fixme[16]: `Module` has no attribute `input`.
410411
module.input = inputs[0].clone().detach()
411412

412413
def _forward_hook(
@@ -420,6 +421,7 @@ def _forward_hook(
420421
outputs of a neuron
421422
"""
422423
outputs = _format_tensor_into_tuples(outputs)
424+
# pyre-fixme[16]: `Module` has no attribute `output`.
423425
module.output = outputs[0].clone().detach()
424426

425427
def _backward_hook(
@@ -536,6 +538,8 @@ def forward_hook(
536538
):
537539
return [
538540
self.model.module.register_forward_pre_hook(pre_hook), # type: ignore
541+
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no
542+
# attribute `register_forward_hook`.
539543
self.model.module.register_forward_hook(forward_hook),
540544
] # type: ignore
541545
else:

captum/attr/_core/llm_attr.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,7 @@ def _get_target_tokens(
404404
gen_args = DEFAULT_GEN_ARGS
405405

406406
model_inp = self._format_model_input(inp.to_model_input())
407+
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
407408
output_tokens = self.model.generate(model_inp, **gen_args)
408409
target_tokens = output_tokens[0][model_inp.size(1) :]
409410
else:
@@ -558,9 +559,11 @@ def _forward_func(
558559
outputs.past_key_values = DynamicCache.from_legacy_cache(
559560
outputs.past_key_values
560561
)
562+
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
561563
model_kwargs = self.model._update_model_kwargs_for_generation(
562564
outputs, model_kwargs
563565
)
566+
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
564567
model_inputs = self.model.prepare_inputs_for_generation(
565568
model_inp, **model_kwargs
566569
)

captum/insights/attr_vis/app.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,8 @@ def _calculate_vis_output(
407407
else self.models
408408
)
409409
results = []
410+
# pyre-fixme[6]: For 1st argument expected `Iterable[_T]` but got
411+
# `Union[List[Any], Module]`.
410412
for model_index, model in enumerate(models_used):
411413
# Get list of model visualizations for each input
412414
actual_label_output = None

tests/attr/layer/test_layer_gradient_x_activation.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,8 @@ def _layer_activation_test_assert(
174174
self, attributions, expected_activation, delta=0.01
175175
)
176176
else:
177+
# pyre-fixme[6]: For 1st argument expected
178+
# `pyre_extensions.PyreReadOnly[Sized]` but got `ModuleOrModuleList`.
177179
for i in range(len(target_layer)):
178180
assertTensorTuplesAlmostEqual(
179181
self, attributions[i], expected_activation[i], delta=0.01
@@ -196,6 +198,8 @@ def _layer_activation_test_assert(
196198
delta=0.01,
197199
)
198200
else:
201+
# pyre-fixme[6]: For 1st argument expected
202+
# `pyre_extensions.PyreReadOnly[Sized]` but got `ModuleOrModuleList`.
199203
for i in range(len(target_layer)):
200204
assertTensorTuplesAlmostEqual(
201205
self,

0 commit comments

Comments
 (0)