Skip to content
This repository has been archived by the owner on Nov 22, 2022. It is now read-only.

Commit

Permalink
Optional shared encoder for the two tower model. (#1705)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1705

Provide shared encoder option for the two tower model.

Reviewed By: sinonwang

Differential Revision: D29557881

fbshipit-source-id: cd2702c157248fbbfa365358fd23aba075509336
  • Loading branch information
HannaMao authored and facebook-github-bot committed Jul 22, 2021
1 parent 5577474 commit fdd8aa2
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 14 deletions.
14 changes: 14 additions & 0 deletions pytext/config/config_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1019,6 +1019,20 @@ def v34_to_v35(json_config):
return json_config


@register_down_grade_adapter(from_version=36)
def v36_to_v35(json_config):
    """Downgrade a v36 config to v35.

    v36 added the `use_shared_encoder` field to TwoTowerClassificationModel;
    removing it from every such section makes the config valid v35 again.
    """
    for model_config in get_json_config_iterator(
        json_config, "TwoTowerClassificationModel"
    ):
        # pop with a default is a no-op when the key is absent
        model_config.pop("use_shared_encoder", None)
    return json_config


@register_adapter(from_version=35)
def v35_to_v36(json_config):
    """Upgrade a v35 config to v36.

    Nothing to rewrite: the only v36 change is a new config field
    (`use_shared_encoder`) with a backwards-compatible default, so any
    valid v35 config is already a valid v36 config.
    """
    return json_config


def get_name_from_options(export_config):
"""
Reverse engineer which model is which based on recognized
Expand Down
2 changes: 1 addition & 1 deletion pytext/config/pytext_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,4 +409,4 @@ class LogitsConfig(TestConfig):
fp16: bool = False


LATEST_VERSION = 35
LATEST_VERSION = 36
1 change: 1 addition & 0 deletions pytext/config/test/json_config/v36_test_downgrade.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[]
1 change: 1 addition & 0 deletions pytext/config/test/json_config/v36_test_upgrade.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[]
52 changes: 39 additions & 13 deletions pytext/models/two_tower_classification_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class InputConfig(ConfigBase):
output_layer: ClassificationOutputLayer.Config = (
ClassificationOutputLayer.Config()
)
use_shared_encoder: Optional[bool] = False

def trace(self, inputs):
    """Return a TorchScript trace of this module driven by `inputs`."""
    traced_module = torch.jit.trace(self, inputs)
    return traced_module
def forward(
    self,
    right_encoder_inputs: Tuple[torch.Tensor, ...],
    left_encoder_inputs: Tuple[torch.Tensor, ...],
    *args
) -> List[torch.Tensor]:
    """Encode both input towers and feed the two representations to the decoder.

    When `use_shared_encoder` is set, the right encoder is applied to both
    towers' inputs; otherwise each tower uses its own encoder. Extra
    positional args are forwarded to the decoder unchanged.

    NOTE(review): this span of SOURCE interleaved removed and added diff
    lines and was not valid Python; this body reconstructs the post-commit
    version implied by the hunks.
    """
    if self.use_shared_encoder:
        # Shared mode: run both towers through the (right) encoder.
        if self.right_encoder.output_encoded_layers:
            # if encoded layers are returned, discard them
            right_representation = self.right_encoder(right_encoder_inputs)[1]
            left_representation = self.right_encoder(left_encoder_inputs)[1]
        else:
            right_representation = self.right_encoder(right_encoder_inputs)[0]
            left_representation = self.right_encoder(left_encoder_inputs)[0]
    else:
        # Separate encoders per tower (pre-v36 behavior).
        if self.right_encoder.output_encoded_layers:
            # if encoded layers are returned, discard them
            right_representation = self.right_encoder(right_encoder_inputs)[1]
        else:
            right_representation = self.right_encoder(right_encoder_inputs)[0]
        if self.left_encoder.output_encoded_layers:
            # if encoded layers are returned, discard them
            left_representation = self.left_encoder(left_encoder_inputs)[1]
        else:
            left_representation = self.left_encoder(left_encoder_inputs)[0]
    return self.decoder(right_representation, left_representation, *args)

def caffe2_export(self, tensorizers, tensor_dict, path, export_onnx_path=None):
Expand Down Expand Up @@ -178,16 +188,32 @@ def from_config(cls, config: Config, tensorizers: Dict[str, Tensorizer]):
output_layer_cls = MulticlassOutputLayer

output_layer = output_layer_cls(list(labels), loss)
return cls(right_encoder, left_encoder, decoder, output_layer)
return cls(
right_encoder,
left_encoder,
decoder,
output_layer,
config.use_shared_encoder,
)

def __init__(
    self,
    right_encoder,
    left_encoder,
    decoder,
    output_layer,
    use_shared_encoder,
    stage=Stage.TRAIN,
) -> None:
    """Build the two-tower model.

    Args:
        right_encoder: encoder for the right tower (and for both towers
            when `use_shared_encoder` is truthy).
        left_encoder: encoder for the left tower; only registered when
            the encoder is not shared.
        decoder: combines the two tower representations.
        output_layer: classification output layer.
        use_shared_encoder: when truthy, both towers share `right_encoder`.
        stage: training stage, defaults to Stage.TRAIN.

    NOTE(review): this span of SOURCE interleaved removed and added diff
    lines (two signatures, two `module_list` assignments) and was not valid
    Python; this body reconstructs the post-commit version implied by the
    hunks. In shared mode `self.left_encoder` is never assigned — confirm
    no caller reads it when `use_shared_encoder` is set.
    """
    super().__init__(stage=stage)
    self.right_encoder = right_encoder
    self.use_shared_encoder = use_shared_encoder
    self.decoder = decoder
    if self.use_shared_encoder:
        # Shared mode: the left tower reuses right_encoder, so only one
        # encoder goes into module_list.
        self.module_list = [right_encoder, decoder]
    else:
        self.left_encoder = left_encoder
        self.module_list = [right_encoder, left_encoder, decoder]
    self.output_layer = output_layer
    self.stage = stage
    log_class_usage(__class__)

0 comments on commit fdd8aa2

Please sign in to comment.