
Commit 003ebca

ebsmothers authored and facebook-github-bot committed
Follow-ups from initial mypy PR (facebookresearch#25)

Summary: Pull Request resolved: facebookresearch#25

A couple of follow-ups from facebookresearch#22:

- Add a TODO for type checking the FLAVA code
- Add a couple of other type hints missed by mypy in clip_text_encoder (due to https://mypy.readthedocs.io/en/stable/common_issues.html#no-errors-reported-for-obviously-wrong-code)
- Make some changes to weighted_embedding_encoder to silence mypy errors

Test Plan:

```
$ python -m pytest -v
...
====================================================================================== short test summary info =======================================================================================
FAILED test/transforms/test_clip_transform.py::TestCLIPTransform::test_clip_multi_transform - requests.exceptions.MissingSchema: Invalid URL '/data/home/ebs/torchmultimodal/torchmultimodal/test/a...
FAILED test/transforms/test_clip_transform.py::TestCLIPTransform::test_clip_single_transform - requests.exceptions.MissingSchema: Invalid URL '/data/home/ebs/torchmultimodal/torchmultimodal/test/...
================================================================== 2 failed, 48 passed, 3 skipped, 26 warnings in 76.48s (0:01:16) ===================================================================
```

(Failures are pre-existing and will be fixed in a separate commit.)

```
$ mypy
Success: no issues found in 29 source files
```

Reviewed By: langong347

Differential Revision: D35617460

Pulled By: ebsmothers

fbshipit-source-id: ab2bf0ec7f85fa700287787728aa40443b9c7a0c
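For context on the second bullet: per the linked mypy doc, mypy treats the body of a function without any type annotations as dynamically typed and skips checking it entirely, which is how hints in clip_text_encoder slipped through. A minimal sketch of that behavior (both functions are hypothetical, not from this repo):

```python
# mypy reports nothing for untyped(): its unannotated body is not checked.
def untyped():
    return "one" + 2  # obviously wrong, but no error is reported

# Adding any annotation makes the body checked.
def typed() -> str:
    return "one" + 2  # mypy error: unsupported operand types for +
```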
1 parent 82b6641 · commit 003ebca

3 files changed: +12 -11 lines changed


mypy.ini

Lines changed: 1 addition & 0 deletions

```diff
@@ -13,6 +13,7 @@ allow_redefinition = True
 namespace_packages = True
 install_types = True
 
+# TODO (T116951827): Remove after fixing FLAVA type check errors
 exclude = models/flava.py|modules/losses/flava.py
 
 [mypy-PIL.*]
```

torchmultimodal/modules/encoders/clip_text_encoder.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -61,7 +61,7 @@ def __init__(
         if use_clip_init:
             self.initialize_parameters()
 
-    def initialize_parameters(self):
+    def initialize_parameters(self) -> None:
         # Initialize token and positional embeddings
         nn.init.normal_(
             self.encoder.token_embedding.weight, std=self.TOKEN_EMBEDDING_INIT_STD
@@ -85,7 +85,7 @@ def initialize_parameters(self):
         # Initialize projection
         nn.init.normal_(self.projection.weight, std=self.width ** -0.5)
 
-    def build_attention_mask(self):
+    def build_attention_mask(self) -> torch.Tensor:
         mask = torch.full((self.context_length, self.context_length), True).triu(1)
         return mask.to(device=None, dtype=torch.bool)
 
```
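As a standalone illustration of the torch.Tensor return type added above (a sketch, not part of the diff): build_attention_mask produces a boolean causal mask where True marks the future positions attention should ignore.

```python
# Sketch of what build_attention_mask computes, with a small context_length.
import torch

context_length = 4
mask = torch.full((context_length, context_length), True).triu(1)
mask = mask.to(device=None, dtype=torch.bool)
print(mask)
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])
```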

torchmultimodal/modules/encoders/weighted_embedding_encoder.py

Lines changed: 9 additions & 9 deletions

```diff
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Tuple, Union, Callable
+from typing import Callable, Tuple, Union
 
 import torch
 from torch import nn, Tensor
@@ -16,7 +16,7 @@ class WeightedEmbeddingEncoder(nn.Module):
 
     Args:
         embedding (nn.Embedding): embedding module
-        pooling_function (Callable[[Tensor], Union[Tensor, Tuple]]): pooling function to combine the weighted embeddings,\
+        pooling_function (Callable[[Tensor, int], Union[Tensor, Tuple]]): pooling function to combine the weighted embeddings,\
         example: torch.sum function should return a tensor or namedtuple containing the tensor in the values field like torch.max
         pooling_dim (int) : dimension along which the pooling function is applied
         use_hash (bool): if hashing based on embedding vocab size if applied to input
@@ -31,7 +31,7 @@ class WeightedEmbeddingEncoder(nn.Module):
     def __init__(
         self,
         embedding: nn.Embedding,
-        pooling_function: Callable[[Tensor], Union[Tensor, Tuple]],
+        pooling_function: Callable[[Tensor, int], Union[Tensor, Tuple]],
         pooling_dim: int = 1,
         use_hash: bool = False,
     ) -> None:
@@ -67,12 +67,12 @@ def forward(self, x: Tensor) -> Tensor:
 
         weighted_embeddings = self.embedding(index) * weights.unsqueeze(-1)
 
-        pooled_embeddings = self.pooling_function(
-            weighted_embeddings, dim=self.pooling_dim
-        )
-        if not isinstance(pooled_embeddings, Tensor):
+        pooled_embeddings = self.pooling_function(weighted_embeddings, self.pooling_dim)
+        if isinstance(pooled_embeddings, Tensor):
+            output: Tensor = pooled_embeddings
+        else:
             assert hasattr(
                 pooled_embeddings, "values"
             ), "pooled embeddings should be Tensor or tuple with values field as Tensor"
-            pooled_embeddings = pooled_embeddings.values
-        return pooled_embeddings
+            output = pooled_embeddings.values  # type: ignore
+        return output
```
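The signature change to Callable[[Tensor, int], Union[Tensor, Tuple]] matches how pooling_function is now invoked positionally as pooling_function(weighted_embeddings, self.pooling_dim), since Callable types cannot express keyword arguments. A small sketch of the two return shapes the forward pass handles (the shapes here are illustrative assumptions): torch.sum returns a Tensor directly, while torch.max returns a namedtuple whose values field holds the pooled Tensor, which is what the else branch unwraps.

```python
# Sketch of the two pooling_function cases handled in forward() above.
import torch

weighted_embeddings = torch.randn(2, 3, 4)  # (batch, seq_len, embedding_dim)
pooling_dim = 1

summed = torch.sum(weighted_embeddings, pooling_dim)
print(type(summed))              # <class 'torch.Tensor'>: the isinstance branch

maxed = torch.max(weighted_embeddings, pooling_dim)
print(hasattr(maxed, "values"))  # True: the else branch unwraps maxed.values
print(type(maxed.values))        # <class 'torch.Tensor'>
```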
