Skip to content
This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

Commit

Permalink
[BUGFIX] fix valid candidates issue (#1323)
Browse files Browse the repository at this point in the history
* fix valid candidates issue

* replace numpy with mxnet numpy

* update gumbel trick

Co-authored-by: Ubuntu <ubuntu@ip-10-20-2-34.ec2.internal>
  • Loading branch information
liuzh47 and Ubuntu authored Sep 1, 2020
1 parent ff95fb4 commit 1bd85b6
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions scripts/pretraining/pretraining_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,22 +499,23 @@ def dynamic_masking(self, F, input_ids, valid_lengths):

for ignore_token in ignore_tokens:
# TODO(zheyuye): update once in-place operators (e.g. +=) are supported
valid_candidates = valid_candidates + \
valid_candidates = valid_candidates * \
F.np.not_equal(input_ids, ignore_token)
valid_lengths = valid_lengths.astype(np.float32)
valid_candidates = valid_candidates.astype(np.float32)
num_masked_position = F.np.maximum(
1, F.np.minimum(N, round(valid_lengths * self._mask_prob)))

# Get the masking probability of each position
sample_probs = F.npx.softmax(
self._proposal_distribution * valid_candidates, axis=-1) # (B, L)
sample_probs = self._proposal_distribution * valid_candidates
sample_probs /= F.np.sum(sample_probs, axis=-1, keepdims=True)
sample_probs = F.npx.stop_gradient(sample_probs)
gumbels = F.np.random.gumbel(F.np.zeros_like(sample_probs))
# Following the guidance of the official repo, use top-k sampling to avoid duplicate positions;
# see https://github.com/google-research/electra/issues/41
masked_positions = F.npx.topk(
sample_probs + gumbels, k=N, axis=-1, ret_typ='indices', dtype=np.int32)
F.np.log(sample_probs) + gumbels, k=N,
axis=-1, ret_typ='indices', dtype=np.int32)

masked_weights = F.npx.sequence_mask(
F.np.ones_like(masked_positions),
Expand Down

0 comments on commit 1bd85b6

Please sign in to comment.