From 1bd85b6dca9bf01f4f420e435eaeec7e9931c633 Mon Sep 17 00:00:00 2001 From: liuzh91 Date: Tue, 1 Sep 2020 14:58:51 +0800 Subject: [PATCH] [BUGFIX] fix valid candidates issue (#1323) * fix valid candidates issue * replace numpy with mxnet numpy * update gumbel trick Co-authored-by: Ubuntu --- scripts/pretraining/pretraining_utils.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/scripts/pretraining/pretraining_utils.py b/scripts/pretraining/pretraining_utils.py index d8d75df45f..cc84641589 100644 --- a/scripts/pretraining/pretraining_utils.py +++ b/scripts/pretraining/pretraining_utils.py @@ -499,7 +499,7 @@ def dynamic_masking(self, F, input_ids, valid_lengths): for ignore_token in ignore_tokens: # TODO(zheyuye), Update when operation += supported - valid_candidates = valid_candidates + \ + valid_candidates = valid_candidates * \ F.np.not_equal(input_ids, ignore_token) valid_lengths = valid_lengths.astype(np.float32) valid_candidates = valid_candidates.astype(np.float32) @@ -507,14 +507,15 @@ def dynamic_masking(self, F, input_ids, valid_lengths): 1, F.np.minimum(N, round(valid_lengths * self._mask_prob))) # Get the masking probability of each position - sample_probs = F.npx.softmax( - self._proposal_distribution * valid_candidates, axis=-1) # (B, L) + sample_probs = self._proposal_distribution * valid_candidates + sample_probs /= F.np.sum(sample_probs, axis=-1, keepdims=True) sample_probs = F.npx.stop_gradient(sample_probs) gumbels = F.np.random.gumbel(F.np.zeros_like(sample_probs)) # Following the instruction of official repo to avoid deduplicate postions # with Top_k Sampling as https://github.com/google-research/electra/issues/41 masked_positions = F.npx.topk( - sample_probs + gumbels, k=N, axis=-1, ret_typ='indices', dtype=np.int32) + F.np.log(sample_probs) + gumbels, k=N, + axis=-1, ret_typ='indices', dtype=np.int32) masked_weights = F.npx.sequence_mask( F.np.ones_like(masked_positions),