Follow-ups from initial mypy PR (facebookresearch#25)
Summary:
Pull Request resolved: facebookresearch#25

A couple of follow-ups from facebookresearch#22:
- Add a TODO for type-checking the FLAVA code
- Add a couple of other type hints in clip_text_encoder that mypy missed, because mypy does not check the bodies of unannotated functions (see https://mypy.readthedocs.io/en/stable/common_issues.html#no-errors-reported-for-obviously-wrong-code and the sketch below)
- Make some changes to weighted_embedding_encoder to silence mypy errors
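
To illustrate the mypy behavior linked above, here is a minimal hypothetical sketch (not code from this repo): mypy skips the body of any function without type annotations, so bugs in unannotated code go unreported until hints are added.

```python
def unannotated(x):
    # mypy does not check this body at all, so nothing is reported
    return x + "oops"


def annotated(x: int) -> int:
    # with annotations, mypy flags the same bug:
    # error: Unsupported operand types for + ("int" and "str")
    return x + "oops"
```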

Test Plan:
```
$ python -m pytest -v
...
====================================================================================== short test summary info =======================================================================================
FAILED test/transforms/test_clip_transform.py::TestCLIPTransform::test_clip_multi_transform - requests.exceptions.MissingSchema: Invalid URL '/data/home/ebs/torchmultimodal/torchmultimodal/test/a...
FAILED test/transforms/test_clip_transform.py::TestCLIPTransform::test_clip_single_transform - requests.exceptions.MissingSchema: Invalid URL '/data/home/ebs/torchmultimodal/torchmultimodal/test/...
================================================================== 2 failed, 48 passed, 3 skipped, 26 warnings in 76.48s (0:01:16) ===================================================================

```
(These failures are pre-existing and will be fixed in a separate commit.)

```
$ mypy
Success: no issues found in 29 source files
```

Reviewed By: langong347

Differential Revision: D35617460

Pulled By: ebsmothers

fbshipit-source-id: ab2bf0ec7f85fa700287787728aa40443b9c7a0c
ebsmothers authored and facebook-github-bot committed Apr 13, 2022
1 parent 82b6641 commit 003ebca
Showing 3 changed files with 12 additions and 11 deletions.
1 change: 1 addition & 0 deletions mypy.ini
@@ -13,6 +13,7 @@ allow_redefinition = True
 namespace_packages = True
 install_types = True
 
+# TODO (T116951827): Remove after fixing FLAVA type check errors
 exclude = models/flava.py|modules/losses/flava.py
 
 [mypy-PIL.*]
4 changes: 2 additions & 2 deletions torchmultimodal/modules/encoders/clip_text_encoder.py
@@ -61,7 +61,7 @@ def __init__(
         if use_clip_init:
             self.initialize_parameters()
 
-    def initialize_parameters(self):
+    def initialize_parameters(self) -> None:
         # Initialize token and positional embeddings
         nn.init.normal_(
             self.encoder.token_embedding.weight, std=self.TOKEN_EMBEDDING_INIT_STD
@@ -85,7 +85,7 @@ def initialize_parameters(self):
         # Initialize projection
         nn.init.normal_(self.projection.weight, std=self.width ** -0.5)
 
-    def build_attention_mask(self):
+    def build_attention_mask(self) -> torch.Tensor:
         mask = torch.full((self.context_length, self.context_length), True).triu(1)
         return mask.to(device=None, dtype=torch.bool)
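
As a quick aside on the newly annotated return value, here is a minimal standalone sketch (with context_length fixed at 4 purely for display) of the causal boolean mask that build_attention_mask computes:

```python
import torch

context_length = 4

# Upper-triangular boolean mask: True marks positions a token may NOT attend to
mask = torch.full((context_length, context_length), True).triu(1)
mask = mask.to(device=None, dtype=torch.bool)
print(mask)
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])
```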
18 changes: 9 additions & 9 deletions torchmultimodal/modules/encoders/weighted_embedding_encoder.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Tuple, Union, Callable
+from typing import Callable, Tuple, Union
 
 import torch
 from torch import nn, Tensor
@@ -16,7 +16,7 @@ class WeightedEmbeddingEncoder(nn.Module):
     Args:
         embedding (nn.Embedding): embedding module
-        pooling_function (Callable[[Tensor], Union[Tensor, Tuple]]): pooling function to combine the weighted embeddings,\
+        pooling_function (Callable[[Tensor, int], Union[Tensor, Tuple]]): pooling function to combine the weighted embeddings,\
         example: torch.sum function should return a tensor or namedtuple containing the tensor in the values field like torch.max
         pooling_dim (int) : dimension along which the pooling function is applied
         use_hash (bool): if hashing based on embedding vocab size if applied to input
@@ -31,7 +31,7 @@ class WeightedEmbeddingEncoder(nn.Module):
     def __init__(
         self,
         embedding: nn.Embedding,
-        pooling_function: Callable[[Tensor], Union[Tensor, Tuple]],
+        pooling_function: Callable[[Tensor, int], Union[Tensor, Tuple]],
         pooling_dim: int = 1,
         use_hash: bool = False,
     ) -> None:
@@ -67,12 +67,12 @@ def forward(self, x: Tensor) -> Tensor:
 
         weighted_embeddings = self.embedding(index) * weights.unsqueeze(-1)
 
-        pooled_embeddings = self.pooling_function(
-            weighted_embeddings, dim=self.pooling_dim
-        )
-        if not isinstance(pooled_embeddings, Tensor):
+        pooled_embeddings = self.pooling_function(weighted_embeddings, self.pooling_dim)
+        if isinstance(pooled_embeddings, Tensor):
+            output: Tensor = pooled_embeddings
+        else:
             assert hasattr(
                 pooled_embeddings, "values"
             ), "pooled embeddings should be Tensor or tuple with values field as Tensor"
-            pooled_embeddings = pooled_embeddings.values
-        return pooled_embeddings
+            output = pooled_embeddings.values  # type: ignore
+        return output
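
To see why both the Callable[[Tensor, int], ...] signature and the Tensor/namedtuple branch above are needed, here is a minimal standalone sketch (illustrative values only): torch.sum reduced along a dim returns a plain Tensor, while torch.max returns a namedtuple whose tensor sits in its values field.

```python
import torch

x = torch.arange(6.0).reshape(2, 3)

# torch.sum(input, dim) called positionally returns a plain Tensor
summed = torch.sum(x, 1)
print(isinstance(summed, torch.Tensor))  # True

# torch.max(input, dim) returns a (values, indices) namedtuple instead,
# which is why forward() unwraps .values in the else branch above
maxed = torch.max(x, 1)
print(isinstance(maxed, torch.Tensor))   # False
print(maxed.values)                      # tensor([2., 5.])
```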
