This repository was archived by the owner on Feb 12, 2020. It is now read-only.

Solved some bug #3

Open. Wants to merge 1 commit into master.
14 changes: 7 additions & 7 deletions README.md
@@ -1,12 +1,12 @@
# MaskGAN.pytorch

A PyTorch attempt at reimplementing

* MaskGAN: Better Text Generation via Filling in the _______ , William Fedus, Ian Goodfellow, Andrew M. Dai
[[paper]](https://openreview.net/pdf?id=ByOExmWAb)

**This is a work in progress.**

==Solved some bugs in original repository==
https://github.com/jerinphilip/MaskGAN.pytorch

# Setting up

@@ -28,12 +28,12 @@ python3 -m pip install git+https://github.com/pytorch/fairseq

#### IMDB Reviews Dataset
```
mkdir datasets
cd datasets
IMDB_DATASET='http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz'
wget $IMDB_DATASET -O aclImdb_v1.tar.gz
tar xvzf aclImdb_v1.tar.gz
```

#### Training

@@ -48,6 +48,6 @@ Run the training script.

```
python3 -m mgan.main \
-   --path datasets/aclImdb/train/ \
-   --spm_path datasets/aclImdb/train/imdb.model
+   --path datasets/aclImdb/ \
+   --spm_prefix datasets/aclImdb/train/imdb
```
4 changes: 2 additions & 2 deletions mgan/data/imdb_tensor.py
@@ -22,7 +22,7 @@ def __init__(self, path, tokenizer, mask_builder, truncate_length, vocab=None):
def _construct_vocabulary(self):
if self.vocab is None:
raw_dataset = IMDbDataset(self.path)
- builder = VocabBuilder(raw_dataset, self.tokenizer, self.path)
+ builder = VocabBuilder(raw_dataset, self.tokenizer, self.path, self.mask_builder)
self.vocab = builder.vocab()

def __len__(self):
@@ -76,7 +76,7 @@ def collate(samples):

lengths = torch.LongTensor(lengths)
lengths, sort_order = lengths.sort(descending=True)

def _rearrange(tensor):
return tensor.index_select(0, sort_order)

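A note on the `collate` hunk above: sorting `lengths` in descending order and applying the same `sort_order` to every batch tensor is the usual prerequisite for packing padded sequences before an RNN. A minimal sketch of that idea with made-up tensors (not the repository's exact collate output):

```
import torch
from torch.nn.utils.rnn import pack_padded_sequence

# Hypothetical batch: 3 sequences padded to length 9.
lengths = torch.LongTensor([5, 9, 7])
tokens = torch.zeros(3, 9, dtype=torch.long)

# Sort longest-first and rearrange every batch tensor the same way,
# mirroring _rearrange in the collate function above.
lengths, sort_order = lengths.sort(descending=True)
tokens = tokens.index_select(0, sort_order)

# The packed form lets an RNN skip the padded positions.
packed = pack_padded_sequence(tokens, lengths, batch_first=True)
```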
12 changes: 8 additions & 4 deletions mgan/data/vocab_builder.py
@@ -1,31 +1,35 @@
import os
from tqdm import tqdm
from fairseq.data.dictionary import Dictionary

class VocabBuilder:
- def __init__(self, dataset, tokenizer, save_path):
+ def __init__(self, dataset, tokenizer, save_path, mask_builder):
self.save_path = save_path
self.dataset = dataset
self.tokenizer = tokenizer
self.vocab_path = os.path.join(save_path, 'vocab.pt')
self._vocab = None
+ self.mask_builder = mask_builder

def vocab(self):
if self._vocab is None:
self.build_vocab()
return self._vocab

def build_vocab(self):
+ print('vocab path:',self.vocab_path)
if os.path.exists(self.vocab_path):
self._vocab = Dictionary.load(self.vocab_path)
else:
self.rebuild_vocab()

def rebuild_vocab(self):
self._vocab = Dictionary()
+ self._vocab.add_symbol(self.mask_builder.mask_token)
desc = 'build-vocab: {}'.format(self.save_path)
pbar = tqdm(
range(len(self.dataset)),
desc=desc,
leave=True
)

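With this change, `rebuild_vocab` registers the mask token before counting corpus tokens, so masked positions get a dedicated id instead of falling back to `<unk>`. A minimal sketch of the pattern with fairseq's `Dictionary`; the `'<mask>'` string and the sample tokens are assumptions, not the repository's actual symbols:

```
from fairseq.data.dictionary import Dictionary

vocab = Dictionary()
mask_token = '<mask>'            # assumed value of mask_builder.mask_token
vocab.add_symbol(mask_token)     # reserve an id for the mask up front

for token in ['the', 'film', 'was', 'the', 'best']:
    vocab.add_symbol(token)      # repeated tokens only bump the count

assert vocab.index(mask_token) != vocab.unk()
```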
15 changes: 8 additions & 7 deletions mgan/main.py
@@ -37,36 +37,37 @@ def main(args):
truncate_length = 20
batch_size = int(max_tokens/truncate_length)

checkpoint_path = "/home/jerin/mgan-attempts/"
checkpoint_path = "/data/neil_noadmin/jerin/mgan-attempts/"
saver = Saver(checkpoint_path)

train_path = os.path.join(args.path, 'train')
+ print('train path:',train_path)
dev_path = os.path.join(args.path, 'test')

train_dataset = TensorIMDbDataset(
train_path, spm_tokenize,
rmask, truncate_length
)

# Constructed vocabulary from train
vocab = train_dataset.vocab
Task = namedtuple('Task', 'source_dictionary target_dictionary')
task = Task(source_dictionary=vocab,
target_dictionary=vocab)

trainer = MGANTrainer(args, task, saver, visdom, vocab)
def loader(dataset):
_loader = DataLoader(dataset, batch_size=batch_size,
collate_fn=TensorIMDbDataset.collate,
shuffle=True, num_workers=8)
return _loader

#trainer.validate_dataset(loader(train_dataset))

dev_dataset = TensorIMDbDataset(
dev_path, spm_tokenize,
rmask, truncate_length,
vocab
)

Datasets = namedtuple('Dataset', 'train dev')
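For context on `batch_size` in the hunk above: `main.py` sizes batches by a token budget, dividing `max_tokens` by the fixed `truncate_length`. A toy calculation under an assumed `max_tokens` value:

```
max_tokens = 4000                                # assumed per-batch token budget
truncate_length = 20                             # fixed length used in main.py
batch_size = int(max_tokens / truncate_length)   # 200 sequences per batch
```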
7 changes: 3 additions & 4 deletions mgan/models/critic.py
@@ -19,8 +19,8 @@ def __init__(self, *args, **kwargs):
out_embed_dim = self.additional_fc.out_features if hasattr(self, "additional_fc") else self.hidden_size
self.fc_out = nn.Linear(out_embed_dim, 1)

- def forward(self, prev_output_tokens, encoder_out_dict, incremental_state=None):
-     x, attn_scores = super().forward(prev_output_tokens, encoder_out_dict, incremental_state)
+ def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
+     x, attn_scores = super().forward(prev_output_tokens, encoder_out, incremental_state)
return x, attn_scores


@@ -58,7 +58,7 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim):
)
if args.encoder_embed_dim != args.decoder_embed_dim:
raise RuntimeError(
- '--share-all-embeddings requires --encoder-embed-dim to '
+ '--share-all-embeddings requires --encoder_embed_dim to '
'match --decoder-embed-dim'
)
pretrained_decoder_embed = pretrained_encoder_embed
@@ -99,7 +99,6 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim):
dropout_in=args.decoder_dropout_in,
dropout_out=args.decoder_dropout_out,
attention=options.eval_bool(args.decoder_attention),
- encoder_embed_dim=args.encoder_embed_dim,
encoder_output_units=encoder.output_units,
pretrained_embed=pretrained_decoder_embed,
share_input_output_embed=args.share_decoder_input_output_embed,
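Both the critic here and the discriminator below keep fairseq's LSTM decoder but swap its vocabulary projection for a single-unit `fc_out`, so every time step produces one scalar (a value estimate or a real/fake logit). A minimal sketch of that output-head pattern with made-up dimensions:

```
import torch
import torch.nn as nn

hidden_size = 512                                   # assumed decoder hidden size
decoder_states = torch.randn(8, 20, hidden_size)    # (batch, time, hidden)

fc_out = nn.Linear(hidden_size, 1)                  # one score per time step
scores = fc_out(decoder_states)                     # shape: (8, 20, 1)
```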
7 changes: 3 additions & 4 deletions mgan/models/discriminator.py
@@ -21,8 +21,8 @@ def __init__(self, *args, **kwargs):
out_embed_dim = self.additional_fc.out_features if hasattr(self, "additional_fc") else self.hidden_size
self.fc_out = nn.Linear(out_embed_dim, 1)

- def forward(self, prev_output_tokens, encoder_out_dict, incremental_state=None):
-     x, attn_scores = super().forward(prev_output_tokens, encoder_out_dict, incremental_state)
+ def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
+     x, attn_scores = super().forward(prev_output_tokens, encoder_out, incremental_state)
# Do not apply sigmoid, numerically unstable while training.
# Get logits and use BCEWithLogitsLoss() instead.
# x = torch.sigmoid(x)
@@ -64,7 +64,7 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim):
)
if args.encoder_embed_dim != args.decoder_embed_dim:
raise RuntimeError(
- '--share-all-embeddings requires --encoder-embed-dim to '
+ '--share-all-embeddings requires --encoder_embed_dim to '
'match --decoder-embed-dim'
)
pretrained_decoder_embed = pretrained_encoder_embed
@@ -105,7 +105,6 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim):
dropout_in=args.decoder_dropout_in,
dropout_out=args.decoder_dropout_out,
attention=options.eval_bool(args.decoder_attention),
- encoder_embed_dim=args.encoder_embed_dim,
encoder_output_units=encoder.output_units,
pretrained_embed=pretrained_decoder_embed,
share_input_output_embed=args.share_decoder_input_output_embed,
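The comment in the discriminator's `forward` above recommends returning logits and using `BCEWithLogitsLoss()` rather than applying a sigmoid first. A small illustration of that stability point with made-up logits and targets:

```
import torch
import torch.nn as nn

logits = torch.tensor([[9.5], [-8.0]])    # raw per-token discriminator outputs
targets = torch.tensor([[1.0], [0.0]])    # 1 = real token, 0 = generated token

# Stable: the loss works on logits directly (log-sum-exp inside).
stable = nn.BCEWithLogitsLoss()(logits, targets)

# Discouraged alternative: with large-magnitude logits the sigmoid saturates
# in float32 and the log inside BCE degrades numerically.
risky = nn.BCELoss()(torch.sigmoid(logits), targets)
```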
5 changes: 3 additions & 2 deletions mgan/preproc/tokenize.py
@@ -17,7 +17,7 @@ def __init__(self, model_prefix):
for key in ['model', 'vocab']:
self.path[key] = '{}.{}'.format(self.prefix, key)

self.sp = spm.SentencePieceProcessor()
self.sp.Load(self.path['model'])

# Build vocabulary.
@@ -33,7 +33,8 @@ def build_vocabulary(self):
def __call__(self, text):
tokens = self.sp.EncodeAsPieces(text)

- to_utf = lambda x: x.decode("utf-8")
+ # to_utf = lambda x: x.decode("utf-8")
+ to_utf = lambda x: x
stokens = list(map(to_utf, tokens))

wanted = lambda s: s in self.vocab
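On the tokenizer change above: the removed `decode("utf-8")` suggests the previously installed sentencepiece build returned bytes pieces, while recent builds return `str`, making the decode unnecessary. A small sketch, assuming the SentencePiece model referenced in the README exists at that path:

```
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.Load('datasets/aclImdb/train/imdb.model')   # model path used in the README

pieces = sp.EncodeAsPieces('the movie was great')
# Recent sentencepiece releases return str pieces directly (e.g. '▁the'),
# so no per-piece .decode('utf-8') is needed.
print(pieces)
```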