Treat bool as equivalent to Literal[True, False] to allow better type inference #6113

Closed
ahmedsoe opened this issue Dec 28, 2018 · 6 comments · Fixed by #10389

Comments

@ahmedsoe

ahmedsoe commented Dec 28, 2018

Similar to #6027, but a more basic use case:

```python
from typing import Collection, Union, List


def foo(bar: Union[bool, Collection[int]]) -> Union[int, List[int]]:
    if bar is False:  # or   if not bar:
        return 10

    if bar is True:
        return 20

    # otherwise, bar must be a collection (list, tuple, ...)
    return [i*2 for i in bar]


print(foo(True))
print(foo(False))
print(foo([1, 2, 3]))
print(foo((1, 2, 3,)))
```

Mypy reports the following error:

```
error: Item "bool" of "Union[bool, Collection[int]]" has no attribute "__iter__" (not iterable)
```

Shouldn't mypy be able to infer that `bar` can't be a `bool` at that point? Using `if`/`elif`/`else` also gives the error.
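Independently of any fix to mypy, one workaround (a sketch, not something proposed in this thread) is to narrow with `isinstance`, which mypy has long applied to unions. Removing `bool` in a single check leaves only `Collection[int]` for the rest of the function:

```python
from typing import Collection, List, Union


def foo(bar: Union[bool, Collection[int]]) -> Union[int, List[int]]:
    # isinstance() narrows the union: inside this branch bar is bool,
    # and after it mypy knows bar is Collection[int].
    if isinstance(bar, bool):
        return 20 if bar else 10

    return [i * 2 for i in bar]
```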

@ilevkivskyi
Member

I think this can be done using literal types: essentially we can treat bool as Union[Literal[True], Literal[False]].

cc @Michael0x2a
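Under that approach, `bool` behaves like `Literal[True] | Literal[False]`, so each identity check removes one member of the union. A minimal sketch of what this enables, assuming a mypy version that includes the fix referenced above (#10389): the final branch type-checks only if `flag` has been narrowed to the empty type. The `assert_never` helper is defined here for illustration (Python 3.11+ also ships `typing.assert_never`).

```python
from typing import NoReturn


def assert_never(value: NoReturn) -> NoReturn:
    # Illustrative helper: a type checker rejects this call unless `value`
    # has been narrowed to the empty type (Never / NoReturn).
    raise AssertionError(f"Unhandled value: {value!r}")


def describe(flag: bool) -> str:
    # With bool treated as Literal[True] | Literal[False], each `is` check
    # removes one member, leaving nothing for the else branch.
    if flag is True:
        return "on"
    elif flag is False:
        return "off"
    else:
        assert_never(flag)
```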

@ilevkivskyi
Member

This came up again, so raising the priority to normal.

@ilevkivskyi changed the title from "Infer type after membership check" to "Treat bool as equivalent to Literal[True, False] to allow better type inference" on Aug 7, 2019
@ethan-leba
Contributor

Would like to take a stab at this if no one's picked it up already!

@srittau
Contributor

srittau commented May 17, 2021

Another use case:

```python
from typing import Literal, overload

@overload
def foo(x: Literal[False]) -> int: ...
@overload
def foo(x: Literal[True]) -> str: ...
def foo(x): pass

b: bool = True
foo(b)
```

Gives us:

```
foo.py:10: error: No overload variant of "foo" matches argument type "bool"
foo.py:10: note: Possible overload variants:
foo.py:10: note:     def foo(x: Literal[False]) -> int
foo.py:10: note:     def foo(x: Literal[True]) -> str
foo.py:10: note: Revealed type is 'Any'
Found 1 error in 1 file (checked 1 source file)
```

It would be great if mypy could infer `foo(x: bool) -> str | int`, as it does for other kinds of overloads. python/typeshed#5471 contains a case where this can't be worked around with a third overload.
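For comparison, the overload behavior referred to here is what mypy does for `Union` arguments: when no single overload accepts the whole union, it checks each member separately and combines the return types. A minimal sketch with hypothetical `int`/`str` overloads; the request is for a `bool` argument to be split into `Literal[True]` and `Literal[False]` the same way.

```python
from typing import Union, overload


@overload
def double(x: int) -> int: ...
@overload
def double(x: str) -> str: ...
def double(x: Union[int, str]) -> Union[int, str]:
    # int * 2 doubles the number; str * 2 repeats the string
    return x * 2


def double_either(value: Union[int, str]) -> Union[int, str]:
    # No single overload accepts Union[int, str], but mypy checks each
    # union member against the overloads and infers Union[int, str].
    return double(value)
```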

@MetRonnie

I was puzzled by this for a while, with a case similar to srittau's example above. The third overload they mention is shown in the mypy docs:

```python
from typing import Literal, Union, overload

@overload
def fetch_data(raw: Literal[True]) -> bytes: ...
@overload
def fetch_data(raw: Literal[False]) -> str: ...

# The last overload is a fallback in case the caller
# provides a regular bool:

@overload
def fetch_data(raw: bool) -> Union[bytes, str]: ...

def fetch_data(raw: bool) -> Union[bytes, str]:
    # Implementation is omitted
    ...
```

Presumably, once this issue is fixed, the third overload can be dropped.

@matangover
Contributor

matangover commented Oct 3, 2022

This issue is marked as fixed, but the last two comments have not been addressed: the fallback `bool` overload is still needed. Is this by design? @hauntsaninja

pytorchmergebot pushed a commit to pytorch/pytorch that referenced this issue Jan 13, 2023
Fixes #91654.

Currently, the `hook` parameters of `nn.Module.register_forward_pre_hook` and `nn.Module.register_forward_hook` are typed as `Callable[..., None]`, which 1) does not enable the validation of the signature of `hook` and 2) incorrectly restricts the return type of `hook`, which the docstrings of these methods themselves state can be non-`None`.

Typing the first parameter of `hook` as `TypeVar("T", bound="Module")` allows binding a `Callable` whose first parameter is a subclass of `Module`.
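A torch-free sketch of why a bound `TypeVar` helps: `Callable` is contravariant in its parameter types, so a hook annotated to take only a `Linear` is not assignable to a `Callable` whose first parameter is `Module`. Solving `T` from the hook instead lets such a hook through. `Module`, `Linear`, and the list-based registry below are hypothetical stand-ins, not PyTorch's classes.

```python
from typing import Any, Callable, List, Optional, Tuple, TypeVar

# Bound TypeVar as described above; "Module" is a forward reference.
T = TypeVar("T", bound="Module")


class Module:
    """Hypothetical stand-in for torch.nn.Module, not the real class."""

    def __init__(self) -> None:
        self.forward_hooks: List[Callable[..., Optional[Any]]] = []

    def register_forward_hook(
        self, hook: Callable[[T, Tuple[Any, ...], Any], Optional[Any]]
    ) -> None:
        # mypy solves T from the type of the hook's first parameter.
        self.forward_hooks.append(hook)


class Linear(Module):
    """Hypothetical subclass standing in for nn.Linear."""


def linear_hook(module: Linear, args: Tuple[Any, ...], output: Any) -> None:
    # A hook that only declares support for Linear modules.
    return None


layer = Linear()
# Accepted: T is inferred as Linear, which satisfies the bound. With
# Callable[[Module, Tuple[Any, ...], Any], Optional[Any]] instead, mypy
# would reject linear_hook because Callable parameters are contravariant.
layer.register_forward_hook(linear_hook)
```

One consequence of this design: `T` is solved from the hook alone and is not tied to the type of `self`, so mypy would equally accept `Module().register_forward_hook(linear_hook)`.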

---

Here are some examples of:
1. forward hooks and forward pre-hooks being accepted by mypy according to the new type hints
2. mypy throwing errors due to incorrect `hook` signatures
3. false negatives of pre-hooks being accepted as forward hooks
4. false negatives of hooks with kwargs being accepted irrespective of the value provided for `with_kwargs`

```python
from typing import Any, Dict, Tuple

import torch
from torch import nn

def forward_pre_hook(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
) -> None:
    ...

def forward_pre_hook_return_input(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
) -> Tuple[torch.Tensor, ...]:
    ...

def forward_pre_hook_with_kwargs(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
    kwargs: Dict[str, Any],
) -> None:
    ...

def forward_pre_hook_with_kwargs_return_input(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
    kwargs: Dict[str, Any],
) -> Tuple[Tuple[torch.Tensor, ...], Dict[str, Any]]:
    ...

def forward_hook(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
    output: torch.Tensor,
) -> None:
    ...

def forward_hook_return_output(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
    output: torch.Tensor,
) -> torch.Tensor:
    ...

def forward_hook_with_kwargs(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
    kwargs: Dict[str, Any],
    output: torch.Tensor,
) -> None:
    ...

def forward_hook_with_kwargs_return_output(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
    kwargs: Dict[str, Any],
    output: torch.Tensor,
) -> torch.Tensor:
    ...

model = nn.Module()

# OK
model.register_forward_pre_hook(forward_pre_hook)
model.register_forward_pre_hook(forward_pre_hook_return_input)
model.register_forward_pre_hook(forward_pre_hook_with_kwargs, with_kwargs=True)
model.register_forward_pre_hook(forward_pre_hook_with_kwargs_return_input, with_kwargs=True)

model.register_forward_hook(forward_hook)
model.register_forward_hook(forward_hook_return_output)
model.register_forward_hook(forward_hook_with_kwargs, with_kwargs=True)
model.register_forward_hook(forward_hook_with_kwargs_return_output, with_kwargs=True)

# mypy(error): [arg-type]
model.register_forward_pre_hook(forward_hook)
model.register_forward_pre_hook(forward_hook_return_output)
model.register_forward_pre_hook(forward_hook_with_kwargs)
model.register_forward_pre_hook(forward_hook_with_kwargs_return_output)

model.register_forward_hook(forward_pre_hook)
model.register_forward_hook(forward_pre_hook_return_input)

# false negatives
model.register_forward_hook(forward_pre_hook_with_kwargs)
model.register_forward_hook(forward_pre_hook_with_kwargs_return_input)

model.register_forward_pre_hook(forward_pre_hook_with_kwargs, with_kwargs=False)
model.register_forward_pre_hook(forward_pre_hook_with_kwargs_return_input, with_kwargs=False)
...
```

---

Although it does not work as of mypy 0.991, the ideal typing of these methods would use [`typing.Literal`](https://mypy.readthedocs.io/en/stable/literal_types.html#literal-types):

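```python
from typing import Any, Callable, Dict, Literal, Optional, Tuple, TypeVar, overload

from torch.utils.hooks import RemovableHandle

T = TypeVar("T", bound="Module")

class Module:

    @overload
    def register_forward_hook(
        self,
        hook: Callable[[T, Tuple[Any, ...], Any], Optional[Any]],
        *,
        prepend: bool = ...,
        with_kwargs: Literal[False] = ...,
    ) -> RemovableHandle:
        ...

    @overload
    def register_forward_hook(
        self,
        hook: Callable[[T, Tuple[Any, ...], Dict[str, Any], Any], Optional[Any]],
        *,
        prepend: bool = ...,
        # No default: a kwargs hook requires an explicit with_kwargs=True
        with_kwargs: Literal[True],
    ) -> RemovableHandle:
        ...

    def register_forward_hook(self, hook, *, prepend=False, with_kwargs=False):
        # Implementation omitted; the unannotated signature accepts every overload
        ...

```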

which would:

1. validate the signature of `hook` according to the corresponding literal value provided for `with_kwargs` (and fix the false negative examples above)
2. implicitly define the [fallback `bool` signature](python/mypy#6113 (comment)), e.g. to handle a non-literal `bool` being passed for `with_kwargs`
Pull Request resolved: #92061
Approved by: https://github.com/albanD
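A torch-free sketch of the overload pattern described in the commit message, using a hypothetical `register_hook` function and simplified hook signatures (assumptions, not PyTorch's API). Given the earlier comment that the fallback overload is still needed, it spells out the third `bool` overload explicitly rather than relying on mypy to derive it:

```python
from typing import Any, Callable, Dict, List, Literal, Tuple, Union, overload

# Hypothetical, simplified hook shapes (not PyTorch's real signatures).
PlainHook = Callable[[Tuple[Any, ...]], None]
KwargsHook = Callable[[Tuple[Any, ...], Dict[str, Any]], None]

# Registered hooks, each paired with the with_kwargs flag it was registered with.
_registry: List[Tuple[Callable[..., None], bool]] = []


@overload
def register_hook(hook: PlainHook, *, with_kwargs: Literal[False] = ...) -> int: ...
@overload
def register_hook(hook: KwargsHook, *, with_kwargs: Literal[True]) -> int: ...
@overload
def register_hook(
    hook: Union[PlainHook, KwargsHook], *, with_kwargs: bool
) -> int: ...  # explicit fallback for a non-literal bool
def register_hook(hook: Callable[..., None], *, with_kwargs: bool = False) -> int:
    _registry.append((hook, with_kwargs))
    return len(_registry) - 1  # index of the newly registered hook


def run_hooks(args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
    # Call each hook with the arguments its registration promised.
    for hook, with_kwargs in _registry:
        if with_kwargs:
            hook(args, kwargs)
        else:
            hook(args)
```

Here a kwargs hook must be registered with an explicit `with_kwargs=True`, since the second overload gives that parameter no default; a plain `bool` variable matches neither literal overload and falls through to the third.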