Treat bool as equivalent to Literal[True, False] to allow better type inference #6113
Comments
I think this can be done using literal types: essentially we can treat cc @Michael0x2a
This appeared again so raising priority to normal.
Would like to take a stab at this if no one's picked it up already!
Another use case:

```python
from typing import Literal, overload

@overload
def foo(x: Literal[False]) -> int: ...
@overload
def foo(x: Literal[True]) -> str: ...
def foo(x): pass

b: bool = True
foo(b)
```

Gives us:

It would be great if this could infer
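Until a plain `bool` is matched against both `Literal` overloads, one call-site workaround is to branch on the value and pass the literals explicitly. This is only a sketch: the body of `foo` and the helper `call_foo` are hypothetical and not taken from the example above.

```python
from typing import Literal, Union, overload


@overload
def foo(x: Literal[False]) -> int: ...
@overload
def foo(x: Literal[True]) -> str: ...
def foo(x: bool) -> Union[int, str]:
    # Hypothetical implementation, only here so the sketch runs.
    return "yes" if x else 0


def call_foo(b: bool) -> Union[int, str]:
    # Hypothetical helper: each branch passes a literal, so each call
    # can be matched against exactly one overload.
    if b:
        return foo(True)
    return foo(False)
```

The cost is that every caller holding a non-literal `bool` has to repeat the branch, which is what this issue would make unnecessary.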
I was puzzled by this for a while, with a case similar to srittau's example above. The third overload they mention is here in the docs:

```python
from typing import Literal, Union, overload

@overload
def fetch_data(raw: Literal[True]) -> bytes: ...
@overload
def fetch_data(raw: Literal[False]) -> str: ...
# The last overload is a fallback in case the caller
# provides a regular bool:
@overload
def fetch_data(raw: bool) -> Union[bytes, str]: ...
def fetch_data(raw: bool) -> Union[bytes, str]:
    # Implementation is omitted
    ...
```

Presumably when this issue is fixed, the third overload can be ditched
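For what it's worth, here is a rough sketch of what the fallback overload means for a caller that only has a plain `bool`: the result comes back as `Union[bytes, str]` and has to be narrowed by hand. The implementation body and the `describe` helper are assumptions made up for illustration, not part of the docs example.

```python
from typing import Literal, Union, overload


@overload
def fetch_data(raw: Literal[True]) -> bytes: ...
@overload
def fetch_data(raw: Literal[False]) -> str: ...
@overload
def fetch_data(raw: bool) -> Union[bytes, str]: ...
def fetch_data(raw: bool) -> Union[bytes, str]:
    # Hypothetical body: the docs example omits the implementation.
    payload = "data"
    return payload.encode("utf-8") if raw else payload


def describe(raw: bool) -> str:
    # `raw` is a plain bool here, so only the fallback overload applies
    # and the result has to be narrowed with isinstance.
    result = fetch_data(raw)
    if isinstance(result, bytes):
        return f"{len(result)} bytes"
    return f"text: {result}"
```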
This issue is marked as fixed but the last two comments have not been resolved - the fallback
Fixes #91654.

Currently, the `hook` parameters of `nn.Module.register_forward_pre_hook` and `nn.Module.register_forward_hook` are typed as `Callable[..., None]`, which 1) does not enable the validation of the signature of `hook` and 2) incorrectly restricts the return type of `hook`, which the docstrings of these methods themselves state can be non-`None`.

The typing of the first parameter of `hook` as `TypeVar("T", bound="Module")` allows the binding of `Callable` whose first parameter is a subclass of `Module`.

---

Here are some examples of:

1. forward hooks and pre-hook hooks being accepted by mypy according to the new type hints
2. mypy throwing errors due to incorrect `hook` signatures
3. false negatives of pre-hooks being accepted as forward hooks
4. false negatives of hooks with kwargs being accepted irrespective of the value provided for `with_kwargs`

```python
from typing import Any, Dict, Tuple

import torch
from torch import nn

def forward_pre_hook(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
) -> None:
    ...

def forward_pre_hook_return_input(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
) -> Tuple[torch.Tensor, ...]:
    ...

def forward_pre_hook_with_kwargs(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
    kwargs: Dict[str, Any],
) -> None:
    ...

def forward_pre_hook_with_kwargs_return_input(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
    kwargs: Dict[str, Any],
) -> Tuple[Tuple[torch.Tensor, ...], Dict[str, Any]]:
    ...

def forward_hook(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
    output: torch.Tensor,
) -> None:
    ...

def forward_hook_return_output(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
    output: torch.Tensor,
) -> torch.Tensor:
    ...

def forward_hook_with_kwargs(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
    kwargs: Dict[str, Any],
    output: torch.Tensor,
) -> None:
    ...

def forward_hook_with_kwargs_return_output(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
    kwargs: Dict[str, Any],
    output: torch.Tensor,
) -> torch.Tensor:
    ...

model = nn.Module()

# OK
model.register_forward_pre_hook(forward_pre_hook)
model.register_forward_pre_hook(forward_pre_hook_return_input)
model.register_forward_pre_hook(forward_pre_hook_with_kwargs, with_kwargs=True)
model.register_forward_pre_hook(forward_pre_hook_with_kwargs_return_input, with_kwargs=True)

model.register_forward_hook(forward_hook)
model.register_forward_hook(forward_hook_return_output)
model.register_forward_hook(forward_hook_with_kwargs, with_kwargs=True)
model.register_forward_hook(forward_hook_with_kwargs_return_output, with_kwargs=True)

# mypy(error): [arg-type]
model.register_forward_pre_hook(forward_hook)
model.register_forward_pre_hook(forward_hook_return_output)
model.register_forward_pre_hook(forward_hook_with_kwargs)
model.register_forward_pre_hook(forward_hook_with_kwargs_return_output)

model.register_forward_hook(forward_pre_hook)
model.register_forward_hook(forward_pre_hook_return_input)

# false negatives
model.register_forward_hook(forward_pre_hook_with_kwargs)
model.register_forward_hook(forward_pre_hook_with_kwargs_return_input)

model.register_forward_pre_hook(forward_pre_hook_with_kwargs, with_kwargs=False)
model.register_forward_pre_hook(forward_pre_hook_with_kwargs_return_input, with_kwargs=False)
...
```

---

Though it is not functional as of mypy 0.991, the ideal typing of these methods would use [`typing.Literal`](https://mypy.readthedocs.io/en/stable/literal_types.html#literal-types):

```python
T = TypeVar("T", bound="Module")

class Module:

    @overload
    def register_forward_hook(
        self,
        hook: Callable[[T, Tuple[Any, ...], Any], Optional[Any]],
        *,
        prepend: bool = ...,
        with_kwargs: Literal[False] = ...,
    ) -> RemovableHandle:
        ...

    @overload
    def register_forward_hook(
        self,
        hook: Callable[[T, Tuple[Any, ...], Dict[str, Any], Any], Optional[Any]],
        *,
        prepend: bool = ...,
        with_kwargs: Literal[True] = ...,
    ) -> RemovableHandle:
        ...

    def register_forward_hook(...):
        ...
```

which would:

1. validate the signature of `hook` according to the corresponding literal value provided for `with_kwargs` (and fix the false negative examples above)
2. implicitly define the [fallback `bool` signature](python/mypy#6113 (comment)) e.g. to handle if a non-literal is provided for `with_kwargs`

Pull Request resolved: #92061
Approved by: https://github.com/albanD
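To make the `Literal`-keyed overload idea concrete without depending on torch, here is a minimal, runnable sketch. Everything in it (`ToyModule`, `HookNoKwargs`, `HookWithKwargs`, the integer handle, the `run` method) is a hypothetical stand-in and not torch's API. The explicit third overload is the `bool` fallback that a checker would need as long as it does not derive it from the two `Literal` overloads. How a given mypy version reports these overloads has not been verified here.

```python
from typing import Any, Callable, Dict, List, Literal, Tuple, Union, overload

# Hypothetical hook shapes: (module, args, output) and (module, args, kwargs, output).
HookNoKwargs = Callable[[Any, Tuple[Any, ...], Any], Any]
HookWithKwargs = Callable[[Any, Tuple[Any, ...], Dict[str, Any], Any], Any]


class ToyModule:
    """Toy stand-in for a module with forward hooks (not torch.nn.Module)."""

    def __init__(self) -> None:
        self._hooks: List[Tuple[Callable[..., Any], bool]] = []

    @overload
    def register_forward_hook(
        self, hook: HookNoKwargs, *, with_kwargs: Literal[False] = ...
    ) -> int: ...
    @overload
    def register_forward_hook(
        self, hook: HookWithKwargs, *, with_kwargs: Literal[True]
    ) -> int: ...
    @overload
    def register_forward_hook(
        self, hook: Union[HookNoKwargs, HookWithKwargs], *, with_kwargs: bool
    ) -> int: ...
    def register_forward_hook(
        self, hook: Callable[..., Any], *, with_kwargs: bool = False
    ) -> int:
        # Store the hook with its flag; the list index serves as a toy handle.
        self._hooks.append((hook, with_kwargs))
        return len(self._hooks) - 1

    def run(self, *args: int, **kwargs: Any) -> Any:
        # Hypothetical "forward": sum the positional inputs.
        output: Any = sum(args)
        for hook, with_kwargs in self._hooks:
            if with_kwargs:
                result = hook(self, args, kwargs, output)
            else:
                result = hook(self, args, output)
            # A non-None return value replaces the output.
            if result is not None:
                output = result
        return output


def double_output(module: Any, args: Tuple[Any, ...], output: Any) -> Any:
    return output * 2


def add_bias(module: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any], output: Any) -> Any:
    return output + kwargs.get("bias", 0)


toy = ToyModule()
toy.register_forward_hook(double_output)
toy.register_forward_hook(add_bias, with_kwargs=True)
```

With this shape, the hook signature expected by the checker follows the literal passed for `with_kwargs`, and a non-literal `bool` lands on the third overload.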
Similar to #6027, but a more basic use case.
Mypy is giving the error
error: Item "bool" of "Union[bool, Collection[int]]" has no attribute "__iter__" (not iterable)
But it should be able to infer that it can't be a `bool`? Using if / elif / else also gives the error.
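The snippet that triggered the error is not shown here, so the following is only a guess at the general pattern, with hypothetical names (`select`, `flags`, `universe`). One workaround that does not rely on `bool` being expanded to `Literal[True, False]` is to narrow with `isinstance`, which removes the `bool` item from the union in the remaining branch.

```python
from typing import Collection, List, Union


def select(flags: Union[bool, Collection[int]], universe: Collection[int]) -> List[int]:
    # Hypothetical function: True selects everything, False selects nothing,
    # and a collection selects the listed items that exist in `universe`.
    if isinstance(flags, bool):
        return list(universe) if flags else []
    # After the isinstance check, `flags` is narrowed to Collection[int],
    # so iterating over it is accepted.
    return [i for i in flags if i in universe]
```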