Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Added shared embedding option to director model. #4763

Merged
merged 2 commits into from
Sep 12, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 29 additions & 2 deletions projects/director/director_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,37 @@
import parlai.utils.logging as logging


class ScalarLayer(nn.Module):
def __init__(self):
super().__init__()
self.params = nn.Parameter(torch.Tensor([1.0, 0.0]))

def forward(self, input: torch.Tensor):
return input * self.params[0].expand_as(input) + self.params[1].expand_as(input)


class DirectorModel(TransformerGeneratorModel):
"""
Director model that extends TransformerGeneratorModel and adds |V| binary classifier
heads.
"""

def __init__(self, opt: Opt, dictionary: DictionaryAgent, **kwargs):
def __init__(
self,
opt: Opt,
dictionary: DictionaryAgent,
**kwargs,
):
super().__init__(opt, dictionary, **kwargs)

vocabulary_size = len(dictionary)

decoder_output_dim = self.decoder.out_dim
self.classifier_heads = nn.Linear(decoder_output_dim, vocabulary_size)
self.use_shared_embedding = opt.get('director_use_shared_embedding', False)
if self.use_shared_embedding:
self.classifier_heads = ScalarLayer()
else:
self.classifier_heads = nn.Linear(decoder_output_dim, vocabulary_size)

self.infer_gamma = opt['train_gamma']
if opt.get('infer_gamma') is not None:
Expand All @@ -56,6 +74,9 @@ def classifier_output(self, input: torch.Tensor):
if self.freeze_decoder:
input = input.detach()

if self.use_shared_embedding:
input = self.generator_output(input)

return self.classifier_heads(input)

def output(self, latent: torch.Tensor):
Expand Down Expand Up @@ -182,6 +203,12 @@ def add_cmdline_args(
default=False,
help='Train the generation head with the positive examples from the feedback data.',
)
group.add_argument(
'--director-use-shared-embedding',
type=bool,
default=False,
help='Use a shared final embedding for the generator and classifier head of the director.',
)
return parser

def __init__(self, opt: Opt, shared=None):
Expand Down