diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index ab4d71bc6..980f405d0 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -10,7 +10,7 @@ on: jobs: tests: - uses: pytorch/test-infra/.github/workflows/linux_job.yml@main + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: runner: linux.12xlarge docker-image: cimg/python:3.11 diff --git a/.github/workflows/test-pip-cpu-with-type-checks.yml b/.github/workflows/test-pip-cpu-with-type-checks.yml index 3336a76f6..47569fc22 100644 --- a/.github/workflows/test-pip-cpu-with-type-checks.yml +++ b/.github/workflows/test-pip-cpu-with-type-checks.yml @@ -14,7 +14,7 @@ jobs: matrix: pytorch_args: ["", "-n"] fail-fast: false - uses: pytorch/test-infra/.github/workflows/linux_job.yml@main + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: runner: linux.12xlarge docker-image: cimg/python:3.11 diff --git a/.github/workflows/test-pip-cpu.yml b/.github/workflows/test-pip-cpu.yml index 83a513ac2..81e006ff7 100644 --- a/.github/workflows/test-pip-cpu.yml +++ b/.github/workflows/test-pip-cpu.yml @@ -37,7 +37,7 @@ jobs: - pytorch_args: "-v 2.1.0" docker_img: "cimg/python:3.12" fail-fast: false - uses: pytorch/test-infra/.github/workflows/linux_job.yml@main + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: runner: linux.12xlarge docker-image: ${{ matrix.docker_img }} diff --git a/.github/workflows/test-pip-gpu.yml b/.github/workflows/test-pip-gpu.yml index 117f515f4..d1ada427a 100644 --- a/.github/workflows/test-pip-gpu.yml +++ b/.github/workflows/test-pip-gpu.yml @@ -14,7 +14,7 @@ jobs: matrix: cuda_arch_version: ["12.1"] fail-fast: false - uses: pytorch/test-infra/.github/workflows/linux_job.yml@main + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: runner: linux.4xlarge.nvidia.gpu repository: pytorch/captum diff --git a/captum/_utils/common.py b/captum/_utils/common.py index 95ab41c5a..e0631b507 100644 --- a/captum/_utils/common.py +++ b/captum/_utils/common.py @@ -91,7 +91,7 @@ def _is_tuple(inputs: Tensor) -> Literal[False]: ... @typing.overload def _is_tuple( - inputs: TensorOrTupleOfTensorsGeneric, # type: ignore + inputs: Union[Tensor, Tuple[Tensor, ...]], ) -> bool: ... @@ -373,8 +373,6 @@ def _expand_target( def _expand_feature_mask( feature_mask: Union[Tensor, Tuple[Tensor, ...]], n_samples: int ) -> Tuple[Tensor, ...]: - # pyre-fixme[6]: For 1st argument expected `Tensor` but got `Union[Tensor, - # typing.Tuple[Tensor, ...]]`. is_feature_mask_tuple = _is_tuple(feature_mask) feature_mask = _format_tensor_into_tuples(feature_mask) feature_mask_new = tuple( diff --git a/captum/_utils/models/model.py b/captum/_utils/models/model.py index f6cb6600f..1ebaba171 100644 --- a/captum/_utils/models/model.py +++ b/captum/_utils/models/model.py @@ -22,8 +22,7 @@ class Model(ABC): def fit( self, train_data: DataLoader, - # pyre-fixme[2]: Parameter must be annotated. - **kwargs, + **kwargs: object, ) -> Optional[Dict[str, Union[int, float, Tensor]]]: r""" Override this method to actually train your model. diff --git a/captum/attr/_core/deep_lift.py b/captum/attr/_core/deep_lift.py index 339894713..dfc323acd 100644 --- a/captum/attr/_core/deep_lift.py +++ b/captum/attr/_core/deep_lift.py @@ -348,20 +348,21 @@ def attribute( # type: ignore self._remove_hooks(main_model_hooks) undo_gradient_requirements(inputs_tuple, gradient_mask) - # pyre-fixme[7]: Expected `Union[Tuple[Variable[TensorOrTupleOfTensorsGeneric... - return _compute_conv_delta_and_format_attrs( - self, - return_convergence_delta, - attributions, - baselines, - inputs_tuple, - additional_forward_args, - target, - is_inputs_tuple, + return cast( + TensorOrTupleOfTensorsGeneric, + _compute_conv_delta_and_format_attrs( + self, + return_convergence_delta, + attributions, + baselines, + inputs_tuple, + additional_forward_args, + target, + is_inputs_tuple, + ), ) - # pyre-fixme[24] Generic type `Callable` expects 2 type parameters. - def attribute_future(self) -> Callable: + def attribute_future(self) -> None: r""" This method is not implemented for DeepLift. """ @@ -831,11 +832,18 @@ def attribute( # type: ignore ) if return_convergence_delta: - # pyre-fixme[7]: Expected `Union[Tuple[Variable[TensorOrTupleOfTensorsGen... - return _format_output(is_inputs_tuple, attributions), delta + return ( + cast( + TensorOrTupleOfTensorsGeneric, + _format_output(is_inputs_tuple, attributions), + ), + delta, + ) else: - # pyre-fixme[7]: Expected `Union[Tuple[Variable[TensorOrTupleOfTensorsGen... - return _format_output(is_inputs_tuple, attributions) + return cast( + TensorOrTupleOfTensorsGeneric, + _format_output(is_inputs_tuple, attributions), + ) def _expand_inputs_baselines_targets( self, @@ -995,10 +1003,8 @@ def maxpool3d( def maxpool( module: Module, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - pool_func: Callable, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - unpool_func: Callable, + pool_func: Callable[..., Tuple[Tensor, Tensor]], + unpool_func: Callable[..., Tensor], inputs: Tensor, outputs: Tensor, grad_input: Tensor, diff --git a/captum/attr/_core/feature_ablation.py b/captum/attr/_core/feature_ablation.py index ab9b9f9c6..379be8673 100644 --- a/captum/attr/_core/feature_ablation.py +++ b/captum/attr/_core/feature_ablation.py @@ -401,10 +401,10 @@ def attribute( if attr_progress is not None: attr_progress.close() - # pyre-fixme[7]: Expected `Variable[TensorOrTupleOfTensorsGeneric <: - # [Tensor, typing.Tuple[Tensor, ...]]]` - # but got `Union[Tensor, typing.Tuple[Tensor, ...]]`. - return self._generate_result(total_attrib, weights, is_inputs_tuple) # type: ignore # noqa: E501 line too long + return cast( + TensorOrTupleOfTensorsGeneric, + self._generate_result(total_attrib, weights, is_inputs_tuple), + ) def _attribute_with_independent_feature_masks( self, @@ -629,8 +629,7 @@ def _should_skip_inputs_and_warn( all_empty = False if self._min_examples_per_batch_grouped is not None and ( formatted_inputs[tensor_idx].shape[0] - # pyre-ignore[58]: Type has been narrowed to int - < self._min_examples_per_batch_grouped + < cast(int, self._min_examples_per_batch_grouped) ): should_skip = True break @@ -789,35 +788,35 @@ def attribute_future( ) if enable_cross_tensor_attribution: - # pyre-fixme[7]: Expected `Future[Variable[TensorOrTupleOfTensorsGeneric - # <:[Tensor, typing.Tuple[Tensor, ...]]]]` but got - # `Future[Union[Tensor, typing.Tuple[Tensor, ...]]]` - return self._attribute_with_cross_tensor_feature_masks_future( # type: ignore # noqa: E501 line too long - formatted_inputs=formatted_inputs, - formatted_additional_forward_args=formatted_additional_forward_args, - target=target, - baselines=baselines, - formatted_feature_mask=formatted_feature_mask, - attr_progress=attr_progress, - processed_initial_eval_fut=processed_initial_eval_fut, - is_inputs_tuple=is_inputs_tuple, - perturbations_per_eval=perturbations_per_eval, + return cast( + Future[TensorOrTupleOfTensorsGeneric], + self._attribute_with_cross_tensor_feature_masks_future( # type: ignore + formatted_inputs=formatted_inputs, + formatted_additional_forward_args=formatted_additional_forward_args, # noqa: E501 line too long + target=target, + baselines=baselines, + formatted_feature_mask=formatted_feature_mask, + attr_progress=attr_progress, + processed_initial_eval_fut=processed_initial_eval_fut, + is_inputs_tuple=is_inputs_tuple, + perturbations_per_eval=perturbations_per_eval, + ), ) else: - # pyre-fixme[7]: Expected `Future[Variable[TensorOrTupleOfTensorsGeneric - # <:[Tensor, typing.Tuple[Tensor, ...]]]]` but got - # `Future[Union[Tensor, typing.Tuple[Tensor, ...]]]` - return self._attribute_with_independent_feature_masks_future( # type: ignore # noqa: E501 line too long - formatted_inputs, - formatted_additional_forward_args, - target, - baselines, - formatted_feature_mask, - perturbations_per_eval, - attr_progress, - processed_initial_eval_fut, - is_inputs_tuple, - **kwargs, + return cast( + Future[TensorOrTupleOfTensorsGeneric], + self._attribute_with_independent_feature_masks_future( # type: ignore # noqa: E501 line too long + formatted_inputs, + formatted_additional_forward_args, + target, + baselines, + formatted_feature_mask, + perturbations_per_eval, + attr_progress, + processed_initial_eval_fut, + is_inputs_tuple, + **kwargs, + ), ) def _attribute_with_independent_feature_masks_future( diff --git a/captum/attr/_core/guided_backprop_deconvnet.py b/captum/attr/_core/guided_backprop_deconvnet.py index e38542559..e1e221686 100644 --- a/captum/attr/_core/guided_backprop_deconvnet.py +++ b/captum/attr/_core/guided_backprop_deconvnet.py @@ -2,7 +2,7 @@ # pyre-strict import warnings -from typing import Callable, List, Optional, Tuple, Union +from typing import cast, List, Optional, Tuple, Union import torch import torch.nn.functional as F @@ -78,12 +78,11 @@ def attribute( self._remove_hooks() undo_gradient_requirements(inputs_tuple, gradient_mask) - # pyre-fixme[7]: Expected `TensorOrTupleOfTensorsGeneric` but got - # `Tuple[Tensor, ...]`. - return _format_output(is_inputs_tuple, gradients) + return cast( + TensorOrTupleOfTensorsGeneric, _format_output(is_inputs_tuple, gradients) + ) - # pyre-fixme[24] Generic type `Callable` expects 2 type parameters. - def attribute_future(self) -> Callable: + def attribute_future(self) -> None: r""" This method is not implemented for ModifiedReluGradientAttribution. """ diff --git a/captum/attr/_core/guided_grad_cam.py b/captum/attr/_core/guided_grad_cam.py index 2278c42e6..9c4f14b15 100644 --- a/captum/attr/_core/guided_grad_cam.py +++ b/captum/attr/_core/guided_grad_cam.py @@ -2,7 +2,7 @@ # pyre-strict import warnings -from typing import List, Optional, Union +from typing import cast, List, Optional, Union import torch from captum._utils.common import _format_output, _format_tensor_into_tuples, _is_tuple @@ -223,6 +223,7 @@ def attribute( ) output_attr.append(torch.empty(0)) - # pyre-fixme[7]: Expected `TensorOrTupleOfTensorsGeneric` but got - # `Tuple[Tensor, ...]`. - return _format_output(is_inputs_tuple, tuple(output_attr)) + return cast( + TensorOrTupleOfTensorsGeneric, + _format_output(is_inputs_tuple, tuple(output_attr)), + ) diff --git a/captum/attr/_core/input_x_gradient.py b/captum/attr/_core/input_x_gradient.py index 020a8f516..d2f7b30e3 100644 --- a/captum/attr/_core/input_x_gradient.py +++ b/captum/attr/_core/input_x_gradient.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # pyre-strict -from typing import Callable, Optional +from typing import Callable, cast, Optional from captum._utils.common import _format_output, _format_tensor_into_tuples, _is_tuple from captum._utils.gradient import ( @@ -126,12 +126,11 @@ def attribute( ) undo_gradient_requirements(inputs_tuple, gradient_mask) - # pyre-fixme[7]: Expected `TensorOrTupleOfTensorsGeneric` but got - # `Tuple[Tensor, ...]`. - return _format_output(is_inputs_tuple, attributions) + return cast( + TensorOrTupleOfTensorsGeneric, _format_output(is_inputs_tuple, attributions) + ) - # pyre-fixme[24] Generic type `Callable` expects 2 type parameters. - def attribute_future(self) -> Callable: + def attribute_future(self) -> None: r""" This method is not implemented for InputXGradient. """ diff --git a/captum/attr/_core/integrated_gradients.py b/captum/attr/_core/integrated_gradients.py index cbdb87053..47e71d276 100644 --- a/captum/attr/_core/integrated_gradients.py +++ b/captum/attr/_core/integrated_gradients.py @@ -2,7 +2,7 @@ # pyre-strict import typing -from typing import Callable, List, Literal, Optional, Tuple, Union +from typing import Callable, cast, List, Literal, Optional, Tuple, Union import torch from captum._utils.common import ( @@ -301,16 +301,18 @@ def attribute( # type: ignore additional_forward_args=additional_forward_args, target=target, ) - # pyre-fixme[7]: Expected `Union[Tuple[Variable[TensorOrTupleOfTensorsGen... - return _format_output(is_inputs_tuple, attributions), delta - # pyre-fixme[7]: Expected - # `Union[Tuple[Variable[TensorOrTupleOfTensorsGeneric <: [Tensor, - # typing.Tuple[Tensor, ...]]], Tensor], Variable[TensorOrTupleOfTensorsGeneric - # <: [Tensor, typing.Tuple[Tensor, ...]]]]` but got `Tuple[Tensor, ...]`. - return _format_output(is_inputs_tuple, attributions) - - # pyre-fixme[24] Generic type `Callable` expects 2 type parameters. - def attribute_future(self) -> Callable: + return ( + cast( + TensorOrTupleOfTensorsGeneric, + _format_output(is_inputs_tuple, attributions), + ), + delta, + ) + return cast( + TensorOrTupleOfTensorsGeneric, _format_output(is_inputs_tuple, attributions) + ) + + def attribute_future(self) -> None: r""" This method is not implemented for IntegratedGradients. """ diff --git a/captum/attr/_core/kernel_shap.py b/captum/attr/_core/kernel_shap.py index 1b5c9d5ed..b804b58af 100644 --- a/captum/attr/_core/kernel_shap.py +++ b/captum/attr/_core/kernel_shap.py @@ -292,8 +292,7 @@ def attribute( # type: ignore show_progress=show_progress, ) - # pyre-fixme[24] Generic type `Callable` expects 2 type parameters. - def attribute_future(self) -> Callable: + def attribute_future(self) -> None: r""" This method is not implemented for KernelShap. """ diff --git a/captum/attr/_core/lime.py b/captum/attr/_core/lime.py index 1a4ea06d5..39e840bf6 100644 --- a/captum/attr/_core/lime.py +++ b/captum/attr/_core/lime.py @@ -548,8 +548,7 @@ def generate_perturbation() -> ( return generate_perturbation - # pyre-fixme[24] Generic type `Callable` expects 2 type parameters. - def attribute_future(self) -> Callable: + def attribute_future(self) -> None: r""" This method is not implemented for LimeBase. """ @@ -1116,8 +1115,7 @@ def attribute( # type: ignore show_progress=show_progress, ) - # pyre-fixme[24] Generic type `Callable` expects 2 type parameters. - def attribute_future(self) -> Callable: + def attribute_future(self) -> None: return super().attribute_future() def _attribute_kwargs( # type: ignore diff --git a/captum/attr/_core/lrp.py b/captum/attr/_core/lrp.py index c42125fa3..ac2d4576e 100644 --- a/captum/attr/_core/lrp.py +++ b/captum/attr/_core/lrp.py @@ -4,7 +4,7 @@ import typing from collections import defaultdict -from typing import Any, Callable, cast, Dict, List, Literal, Optional, Tuple, Union +from typing import Any, cast, Dict, List, Literal, Optional, Tuple, Union import torch.nn as nn from captum._utils.common import ( @@ -230,16 +230,17 @@ def attribute( undo_gradient_requirements(input_tuple, gradient_mask) if return_convergence_delta: - # pyre-fixme[7]: Expected `Union[Tuple[Variable[TensorOrTupleOfTensorsGen... return ( - _format_output(is_inputs_tuple, relevances), + cast( + TensorOrTupleOfTensorsGeneric, + _format_output(is_inputs_tuple, relevances), + ), self.compute_convergence_delta(relevances, output), ) else: return _format_output(is_inputs_tuple, relevances) # type: ignore - # pyre-fixme[24] Generic type `Callable` expects 2 type parameters. - def attribute_future(self) -> Callable: + def attribute_future(self) -> None: r""" This method is not implemented for LRP. """ diff --git a/captum/attr/_core/neuron/neuron_conductance.py b/captum/attr/_core/neuron/neuron_conductance.py index b04d9d79e..504adcf74 100644 --- a/captum/attr/_core/neuron/neuron_conductance.py +++ b/captum/attr/_core/neuron/neuron_conductance.py @@ -2,7 +2,7 @@ # pyre-strict import warnings -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union import torch from captum._utils.common import ( @@ -328,9 +328,9 @@ def attribute( attribute_to_neuron_input=attribute_to_neuron_input, grad_kwargs=grad_kwargs, ) - # pyre-fixme[7]: Expected `TensorOrTupleOfTensorsGeneric` but got - # `Tuple[Tensor, ...]`. - return _format_output(is_inputs_tuple, attrs) + return cast( + TensorOrTupleOfTensorsGeneric, _format_output(is_inputs_tuple, attrs) + ) def _attribute( self, diff --git a/captum/attr/_core/neuron/neuron_gradient.py b/captum/attr/_core/neuron/neuron_gradient.py index 4ece2aed2..3d1c293a3 100644 --- a/captum/attr/_core/neuron/neuron_gradient.py +++ b/captum/attr/_core/neuron/neuron_gradient.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # pyre-strict -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, cast, List, Optional, Tuple, Union from captum._utils.common import ( _format_additional_forward_args, @@ -184,7 +184,6 @@ def attribute( undo_gradient_requirements(inputs_tuple, gradient_mask) - # pyre-fixme[7]: Expected `Variable[TensorOrTupleOfTensorsGeneric <: - # [Tensor, typing.Tuple[Tensor, ...]]]` but got `Union[Tensor, - # typing.Tuple[Tensor, ...]]`. - return _format_output(is_inputs_tuple, input_grads) + return cast( + TensorOrTupleOfTensorsGeneric, _format_output(is_inputs_tuple, input_grads) + ) diff --git a/captum/attr/_core/noise_tunnel.py b/captum/attr/_core/noise_tunnel.py index bef35a300..0d029a181 100644 --- a/captum/attr/_core/noise_tunnel.py +++ b/captum/attr/_core/noise_tunnel.py @@ -2,7 +2,7 @@ # pyre-strict from enum import Enum -from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, cast, Dict, List, Optional, Sequence, Tuple, Union import torch from captum._utils.common import ( @@ -14,7 +14,6 @@ _format_tensor_into_tuples, _is_tuple, ) -from captum._utils.typing import TensorOrTupleOfTensorsGeneric from captum.attr._utils.attribution import Attribution, GradientAttribution from captum.attr._utils.common import _validate_noise_tunnel_type from captum.log import log_usage @@ -91,12 +90,10 @@ def attribute( draw_baseline_from_distrib: bool = False, **kwargs: Any, ) -> Union[ - Union[ - Tensor, - Tuple[Tensor, Tensor], - Tuple[Tensor, ...], - Tuple[Tuple[Tensor, ...], Tensor], - ] + Tensor, + Tuple[Tensor, Tensor], + Tuple[Tensor, ...], + Tuple[Tuple[Tensor, ...], Tensor], ]: r""" Args: @@ -298,8 +295,7 @@ def attribute( delta, ) - # pyre-fixme[24] Generic type `Callable` expects 2 type parameters. - def attribute_future(self) -> Callable: + def attribute_future(self) -> None: r""" This method is not implemented for NoiseTunnel. """ @@ -490,11 +486,11 @@ def _apply_checks_and_return_attributions( is_attrib_tuple: bool, return_convergence_delta: bool, delta: Union[None, Tensor], - # pyre-fixme[34]: `Variable[TensorOrTupleOfTensorsGeneric <: - # [torch._tensor.Tensor, typing.Tuple[torch._tensor.Tensor, ...]]]` - # isn't present in the function's parameters. ) -> Union[ - TensorOrTupleOfTensorsGeneric, Tuple[TensorOrTupleOfTensorsGeneric, Tensor] + Tensor, + Tuple[Tensor, Tensor], + Tuple[Tensor, ...], + Tuple[Tuple[Tensor, ...], Tensor], ]: attributions_tuple = _format_output(is_attrib_tuple, attributions) @@ -503,17 +499,15 @@ def _apply_checks_and_return_attributions( if self.is_delta_supported and return_convergence_delta else attributions_tuple ) - ret = cast( - # pyre-fixme[34]: `Variable[TensorOrTupleOfTensorsGeneric <: - # [torch._tensor.Tensor, typing.Tuple[torch._tensor.Tensor, ...]]]` - # isn't present in the function's parameters. + return cast( Union[ - TensorOrTupleOfTensorsGeneric, - Tuple[TensorOrTupleOfTensorsGeneric, Tensor], + Tensor, + Tuple[Tensor, Tensor], + Tuple[Tensor, ...], + Tuple[Tuple[Tensor, ...], Tensor], ], ret, ) - return ret def has_convergence_delta(self) -> bool: return self.is_delta_supported diff --git a/captum/attr/_core/occlusion.py b/captum/attr/_core/occlusion.py index 710e53166..1b6d6e2bf 100644 --- a/captum/attr/_core/occlusion.py +++ b/captum/attr/_core/occlusion.py @@ -269,8 +269,7 @@ def attribute( # type: ignore show_progress=show_progress, ) - # pyre-fixme[24] Generic type `Callable` expects 2 type parameters. - def attribute_future(self) -> Callable: + def attribute_future(self) -> None: r""" This method is not implemented for Occlusion. """ diff --git a/captum/attr/_core/saliency.py b/captum/attr/_core/saliency.py index 408da915d..0eb50e501 100644 --- a/captum/attr/_core/saliency.py +++ b/captum/attr/_core/saliency.py @@ -2,7 +2,7 @@ # pyre-strict -from typing import Callable, Optional +from typing import Callable, cast, Optional import torch from captum._utils.common import _format_output, _format_tensor_into_tuples, _is_tuple @@ -138,12 +138,12 @@ def attribute( else: attributions = gradients undo_gradient_requirements(inputs_tuple, gradient_mask) - # pyre-fixme[7]: Expected `TensorOrTupleOfTensorsGeneric` but got - # `Tuple[Tensor, ...]`. - return _format_output(is_inputs_tuple, attributions) - # pyre-fixme[24] Generic type `Callable` expects 2 type parameters. - def attribute_future(self) -> Callable: + return cast( + TensorOrTupleOfTensorsGeneric, _format_output(is_inputs_tuple, attributions) + ) + + def attribute_future(self) -> None: r""" This method is not implemented for Saliency. """ diff --git a/captum/attr/_core/shapley_value.py b/captum/attr/_core/shapley_value.py index 46fb410b7..f294222f1 100644 --- a/captum/attr/_core/shapley_value.py +++ b/captum/attr/_core/shapley_value.py @@ -5,7 +5,7 @@ import itertools import math import warnings -from typing import Callable, cast, Iterable, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, cast, Iterable, List, Optional, Sequence, Tuple, Union import torch from captum._utils.common import ( @@ -452,9 +452,7 @@ def attribute( tensor_attrib_total / iter_count for tensor_attrib_total in total_attrib ) formatted_attr = _format_output(is_inputs_tuple, attrib) - # pyre-fixme[7]: Expected `TensorOrTupleOfTensorsGeneric` but got - # `Tuple[Tensor, ...]`. - return formatted_attr + return cast(TensorOrTupleOfTensorsGeneric, formatted_attr) def attribute_future( self, @@ -595,9 +593,7 @@ def attribute_future( ) ) ) - # pyre-fixme[7]: Expected `TensorOrTupleOfTensorsGeneric` but got - # `Tuple[Tensor, ...]`. - return formatted_attr # type: ignore + return cast(Future[TensorOrTupleOfTensorsGeneric], formatted_attr) def _initialEvalToPrevResultsTuple( self, @@ -860,8 +856,7 @@ def _get_n_evaluations( """return the total number of forward evaluations needed""" return math.ceil(total_features / perturbations_per_eval) * n_samples - # pyre-fixme[2]: Parameter must be annotated. - def _strict_run_forward(self, *args, **kwargs) -> Tensor: + def _strict_run_forward(self, *args: Any, **kwargs: Any) -> Tensor: """ A temp wrapper for global _run_forward util to force forward output type assertion & conversion. @@ -886,8 +881,7 @@ def _strict_run_forward(self, *args, **kwargs) -> Tensor: # ref: https://github.com/pytorch/pytorch/pull/21215 return torch.tensor([forward_output], dtype=cast(dtype, output_type)) - # pyre-fixme[2]: Parameter must be annotated. - def _strict_run_forward_future(self, *args, **kwargs) -> Future[Tensor]: + def _strict_run_forward_future(self, *args: Any, **kwargs: Any) -> Future[Tensor]: """ A temp wrapper for global _run_forward util to force forward outputtype assertion & conversion, but takes diff --git a/captum/attr/_utils/interpretable_input.py b/captum/attr/_utils/interpretable_input.py index 5ee947976..ec97f8e8c 100644 --- a/captum/attr/_utils/interpretable_input.py +++ b/captum/attr/_utils/interpretable_input.py @@ -25,9 +25,8 @@ def _scatter_itp_attr_by_mask( # input_shape in shape(batch_size, *inp_feature_dims) # attribute in shape(*output_dims, *inp_feature_dims) - # pyre-fixme[60]: Concatenation not yet support for multiple variadic tuples: - # `*output_dims, *input_shape[slice(1, None, None)]`. - attr_shape = (*output_dims, *input_shape[1:]) + # Current limitation in pyre with multiple variadic tuples + attr_shape = (*output_dims, *input_shape[1:]) # pyre-ignore[60] expanded_feature_indices = mask.expand(attr_shape) @@ -39,12 +38,17 @@ def _scatter_itp_attr_by_mask( # (*output_dims, 1..., 1, n_itp_features) # then broadcast to (*output_dims, *inp.shape[1:-1], n_itp_features) n_extra_dims = len(extra_inp_dims) - # pyre-fixme[60]: Concatenation not yet support for multiple variadic - # tuples: `*output_dims, *(1).__mul__(n_extra_dims)`. - unsqueezed_shape = (*output_dims, *(1,) * n_extra_dims, n_itp_features) - # pyre-fixme[60]: Concatenation not yet support for multiple variadic - # tuples: `*output_dims, *extra_inp_dims`. - expanded_shape = (*output_dims, *extra_inp_dims, n_itp_features) + # Current limitation in pyre with multiple variadic tuples + unsqueezed_shape = ( # pyre-ignore[60] + *output_dims, + *(1,) * n_extra_dims, + n_itp_features, + ) + expanded_shape = ( # pyre-ignore[60] + *output_dims, + *extra_inp_dims, + n_itp_features, + ) expanded_itp_attr = itp_attr.reshape(unsqueezed_shape).expand(expanded_shape) else: expanded_itp_attr = itp_attr diff --git a/captum/module/binary_concrete_stochastic_gates.py b/captum/module/binary_concrete_stochastic_gates.py index e517e1379..8d47b7dc6 100644 --- a/captum/module/binary_concrete_stochastic_gates.py +++ b/captum/module/binary_concrete_stochastic_gates.py @@ -134,8 +134,9 @@ def __init__( self.eps = eps # pre-calculate the fixed term used in active prob - # pyre-fixme[4]: Attribute must be annotated. - self.active_prob_offset = temperature * math.log(-lower_bound / upper_bound) + self.active_prob_offset: float = temperature * math.log( + -lower_bound / upper_bound + ) def _sample_gate_values(self, batch_size: int) -> Tensor: """ diff --git a/captum/module/gaussian_stochastic_gates.py b/captum/module/gaussian_stochastic_gates.py index 55b804e3e..7f9ec86c0 100644 --- a/captum/module/gaussian_stochastic_gates.py +++ b/captum/module/gaussian_stochastic_gates.py @@ -45,8 +45,8 @@ def __init__( self, n_gates: int, mask: Optional[Tensor] = None, - reg_weight: Optional[float] = 1.0, - std: Optional[float] = 0.5, + reg_weight: float = 1.0, + std: float = 0.5, reg_reduction: str = "sum", ) -> None: """ @@ -79,9 +79,7 @@ def __init__( super().__init__( n_gates, mask=mask, - # pyre-fixme[6]: For 3rd argument expected `float` but got - # `Optional[float]`. - reg_weight=reg_weight, # type: ignore + reg_weight=reg_weight, reg_reduction=reg_reduction, ) @@ -89,8 +87,6 @@ def __init__( nn.init.normal_(mu, mean=0.5, std=0.01) self.mu = nn.Parameter(mu) - # pyre-fixme[58]: `<` is not supported for operand types `int` and - # `Optional[float]`. assert 0 < std, f"the standard deviation should be positive, received {std}" # type: ignore # noqa: E501 line too long self.std = std @@ -107,9 +103,7 @@ def _sample_gate_values(self, batch_size: int) -> Tensor: if self.training: n = torch.empty(batch_size, self.n_gates, device=self.mu.device) - # pyre-fixme[6]: For 2nd argument expected `float` but got - # `Optional[float]`. - n.normal_(mean=0, std=self.std) # type: ignore + n.normal_(mean=0, std=self.std) return self.mu + n return self.mu.expand(batch_size, self.n_gates) diff --git a/captum/testing/attr/helpers/get_config_util.py b/captum/testing/attr/helpers/get_config_util.py index aa66d08a8..f950c4a05 100644 --- a/captum/testing/attr/helpers/get_config_util.py +++ b/captum/testing/attr/helpers/get_config_util.py @@ -1,7 +1,7 @@ # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. # pyre-strict -from typing import Any, Tuple +from typing import List, Tuple import torch from captum._utils.gradient import compute_gradients @@ -10,17 +10,15 @@ from torch.nn import Module -# pyre-fixme[3]: Return annotation cannot contain `Any`. -def get_basic_config() -> Tuple[Module, Tensor, Tensor, Any]: +def get_basic_config() -> Tuple[Module, Tensor, Tensor, None]: input = torch.tensor([1.0, 2.0, 3.0, 0.0, -1.0, 7.0], requires_grad=True).T # manually percomputed gradients grads = torch.tensor([-0.0, -0.0, -0.0, 1.0, 1.0, -0.0]) return BasicModel(), input, grads, None -# pyre-fixme[3]: Return annotation cannot contain `Any`. def get_multiargs_basic_config() -> ( - Tuple[Module, Tuple[Tensor, ...], Tuple[Tensor, ...], Any] + Tuple[Module, Tuple[Tensor, ...], Tuple[Tensor, ...], Tuple[List[int], int]] ): model = BasicModel5_MultiArgs() additional_forward_args = ([2, 3], 1) @@ -34,9 +32,8 @@ def get_multiargs_basic_config() -> ( return model, inputs, grads, additional_forward_args -# pyre-fixme[3]: Return annotation cannot contain `Any`. def get_multiargs_basic_config_large() -> ( - Tuple[Module, Tuple[Tensor, ...], Tuple[Tensor, ...], Any] + Tuple[Module, Tuple[Tensor, ...], Tuple[Tensor, ...], Tuple[List[int], int]] ): model = BasicModel5_MultiArgs() additional_forward_args = ([2, 3], 1) diff --git a/captum/testing/helpers/basic.py b/captum/testing/helpers/basic.py index 04d6c8f66..9e7a1c953 100644 --- a/captum/testing/helpers/basic.py +++ b/captum/testing/helpers/basic.py @@ -5,20 +5,18 @@ import random import unittest -from typing import Callable, Generator +from typing import Any, Callable, Generator, List, Tuple, TypeVar, Union import numpy as np import torch from captum.log import patch_methods from torch import Tensor +ReturnType = TypeVar("ReturnType") -# pyre-fixme[3]: Return type must be annotated. -# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. -def deep_copy_args(func: Callable): - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. - def copy_args(*args, **kwargs): + +def deep_copy_args(func: Callable[..., ReturnType]) -> Callable[..., ReturnType]: + def copy_args(*args: Any, **kwargs: Any) -> ReturnType: return func( *(copy.deepcopy(x) for x in args), **{k: copy.deepcopy(v) for k, v in kwargs.items()}, @@ -28,8 +26,7 @@ def copy_args(*args, **kwargs): def assertTensorAlmostEqual( - # pyre-fixme[2]: Parameter must be annotated. - test, + test: unittest.TestCase, # pyre-fixme[2]: Parameter must be annotated. actual, # pyre-fixme[2]: Parameter must be annotated. @@ -75,8 +72,7 @@ def assertTensorAlmostEqual( def assertTensorTuplesAlmostEqual( - # pyre-fixme[2]: Parameter must be annotated. - test, + test: unittest.TestCase, # pyre-fixme[2]: Parameter must be annotated. actual, # pyre-fixme[2]: Parameter must be annotated. @@ -95,15 +91,17 @@ def assertTensorTuplesAlmostEqual( assertTensorAlmostEqual(test, actual, expected, delta, mode) -# pyre-fixme[2]: Parameter must be annotated. -def assertAttributionComparision(test, attributions1, attributions2) -> None: +def assertAttributionComparision( + test: unittest.TestCase, + attributions1: Union[Tensor, Tuple[Tensor, ...]], + attributions2: Union[Tensor, Tuple[Tensor, ...]], +) -> None: for attribution1, attribution2 in zip(attributions1, attributions2): for attr_row1, attr_row2 in zip(attribution1, attribution2): assertTensorAlmostEqual(test, attr_row1, attr_row2, 0.05, "max") -# pyre-fixme[2]: Parameter must be annotated. -def assert_delta(test, delta) -> None: +def assert_delta(test: unittest.TestCase, delta: Tensor) -> None: delta_condition = (delta.abs() < 0.00001).all() test.assertTrue( delta_condition, diff --git a/tests/attr/test_deeplift_basic.py b/tests/attr/test_deeplift_basic.py index 1f54eb006..ee5963cde 100644 --- a/tests/attr/test_deeplift_basic.py +++ b/tests/attr/test_deeplift_basic.py @@ -310,7 +310,7 @@ def test_futures_not_implemented(self) -> None: dl = DeepLift(model, multiply_by_inputs=False) attributions = None with self.assertRaises(NotImplementedError): - attributions = dl.attribute_future() + attributions = dl.attribute_future() # type: ignore self.assertEqual(attributions, None) def _deeplift_assert( diff --git a/tests/attr/test_input_x_gradient.py b/tests/attr/test_input_x_gradient.py index 8718fae6d..84178726a 100644 --- a/tests/attr/test_input_x_gradient.py +++ b/tests/attr/test_input_x_gradient.py @@ -55,7 +55,7 @@ def test_futures_not_implemented(self) -> None: input_x_grad = InputXGradient(model.forward) attributions = None with self.assertRaises(NotImplementedError): - attributions = input_x_grad.attribute_future() + attributions = input_x_grad.attribute_future() # type: ignore self.assertEqual(attributions, None) def _input_x_gradient_base_assert( diff --git a/tests/attr/test_integrated_gradients_basic.py b/tests/attr/test_integrated_gradients_basic.py index 074b0bd63..ad1376543 100644 --- a/tests/attr/test_integrated_gradients_basic.py +++ b/tests/attr/test_integrated_gradients_basic.py @@ -161,7 +161,7 @@ def test_futures_not_implemented(self) -> None: ig = IntegratedGradients(model, multiply_by_inputs=True) attributions = None with self.assertRaises(NotImplementedError): - attributions = ig.attribute_future() + attributions = ig.attribute_future() # type: ignore self.assertEqual(attributions, None) def _assert_multi_variable( diff --git a/tests/attr/test_kernel_shap.py b/tests/attr/test_kernel_shap.py index 61bd66397..ad31702ed 100644 --- a/tests/attr/test_kernel_shap.py +++ b/tests/attr/test_kernel_shap.py @@ -337,7 +337,7 @@ def test_futures_not_implemented(self) -> None: kernel_shap = KernelShap(net) attributions = None with self.assertRaises(NotImplementedError): - attributions = kernel_shap.attribute_future() + attributions = kernel_shap.attribute_future() # type: ignore self.assertEqual(attributions, None) def _multi_input_scalar_kernel_shap_assert(self, func: Callable) -> None: diff --git a/tests/attr/test_lime.py b/tests/attr/test_lime.py index 095ef9cf0..68a18bc7b 100644 --- a/tests/attr/test_lime.py +++ b/tests/attr/test_lime.py @@ -482,7 +482,7 @@ def test_futures_not_implemented(self) -> None: ) attributions = None with self.assertRaises(NotImplementedError): - attributions = lime.attribute_future() + attributions = lime.attribute_future() # type: ignore self.assertEqual(attributions, None) # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. diff --git a/tests/attr/test_lrp.py b/tests/attr/test_lrp.py index d2098e41e..6d49edeba 100644 --- a/tests/attr/test_lrp.py +++ b/tests/attr/test_lrp.py @@ -335,5 +335,5 @@ def test_futures_not_implemented(self) -> None: lrp = LRP(model) attributions = None with self.assertRaises(NotImplementedError): - attributions = lrp.attribute_future() + attributions = lrp.attribute_future() # type: ignore self.assertEqual(attributions, None) diff --git a/tests/attr/test_occlusion.py b/tests/attr/test_occlusion.py index 32705cbdb..f4a884fdd 100644 --- a/tests/attr/test_occlusion.py +++ b/tests/attr/test_occlusion.py @@ -287,7 +287,7 @@ def test_futures_not_implemented(self) -> None: occ = Occlusion(net) attributions = None with self.assertRaises(NotImplementedError): - attributions = occ.attribute_future() + attributions = occ.attribute_future() # type: ignore self.assertEqual(attributions, None) @unittest.mock.patch("sys.stderr", new_callable=io.StringIO) diff --git a/tests/attr/test_saliency.py b/tests/attr/test_saliency.py index a1518f47f..e322a6467 100644 --- a/tests/attr/test_saliency.py +++ b/tests/attr/test_saliency.py @@ -112,7 +112,7 @@ def test_futures_not_implemented(self) -> None: saliency = Saliency(model) attributions = None with self.assertRaises(NotImplementedError): - attributions = saliency.attribute_future() + attributions = saliency.attribute_future() # type: ignore self.assertEqual(attributions, None) def _saliency_base_assert(