This repository has been archived by the owner on Nov 3, 2023. It is now read-only.
* Avoid some LR scheduler warnings.
* Round out schedulers and add end-to-end tests.
* Unigram agent works with beam search.
1 parent 9c894ba, commit 5de0fbc
Showing 7 changed files with 185 additions and 19 deletions.
@@ -0,0 +1,51 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
UnigramAgent always predicts the unigram distribution.
It is a full TorchGeneratorAgent model, so it can be used heavily in testing, while
being very quick to optimize.
"""

import torch
import torch.nn as nn
from parlai.core.torch_generator_agent import TorchGeneratorAgent, TorchGeneratorModel


class UnigramEncoder(nn.Module):
    # Dummy encoder: the unigram model ignores its input entirely.
    def forward(self, x):
        return None


class UnigramDecoder(nn.Module):
    # Dummy decoder: returns the input with a trailing dim so output() can read its shape.
    def forward(self, x, encoder_state, incr_state=None):
        return x.unsqueeze(-1), None


class UnigramModel(TorchGeneratorModel):
    def __init__(self, dictionary):
        super().__init__()
        self.encoder = UnigramEncoder()
        self.decoder = UnigramDecoder()
        self.v = len(dictionary)
        # The only parameters: one unigram logit per vocabulary item.
        self.p = nn.Parameter(torch.zeros(self.v))

    def output(self, do):
        # Broadcast the same unigram scores over every batch element and timestep.
        desired = list(do.shape)[:2] + [self.v]
        x = self.p.unsqueeze(0).unsqueeze(0)
        return x.expand(desired)

    def reorder_encoder_states(self, *args):
        return None

    def reorder_decoder_incremental_state(self, *args):
        return None


class UnigramAgent(TorchGeneratorAgent):
    def build_model(self):
        return UnigramModel(self.dict)
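
The core trick is in output(): the model holds a single learnable vector of unigram logits and broadcasts it over every batch element and decoding timestep, so there is almost nothing to optimize. Below is a minimal, torch-only sketch of that broadcast; the vocabulary size of 8 and the fake decoder output are stand-ins for the real ParlAI dictionary and UnigramDecoder output.

import torch
import torch.nn as nn

# Stand-ins: ParlAI would supply len(dictionary) and the decoder output shape.
vocab_size = 8
p = nn.Parameter(torch.zeros(vocab_size))   # the model's only parameters
decoder_out = torch.zeros(2, 3, 1)          # (batch=2, time=3, 1), as UnigramDecoder returns

# Mirror UnigramModel.output(): expand the unigram logits to (batch, time, vocab).
desired = list(decoder_out.shape)[:2] + [vocab_size]
scores = p.unsqueeze(0).unsqueeze(0).expand(desired)
print(scores.shape)  # torch.Size([2, 3, 8])

Because the scores are identical at every position, greedy decoding and beam search both reduce to emitting the highest-scoring unigram token, which is what makes the agent cheap enough for fast end-to-end generator tests, including the beam-search case this commit mentions.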