Option for noise in teacher forcing
PiperOrigin-RevId: 448489522
CLRSDev authored and copybara-github committed May 13, 2022
1 parent 1a10ef6 commit 807952a
Showing 2 changed files with 42 additions and 14 deletions.
22 changes: 13 additions & 9 deletions clrs/_src/baselines.py
@@ -69,6 +69,7 @@ def __init__(
       checkpoint_path: str = '/tmp/clrs3',
       freeze_processor: bool = False,
       dropout_prob: float = 0.0,
+      hint_teacher_forcing_noise: float = 0.0,
       name: str = 'base_model',
   ):
     """Constructor for BaselineModel.
@@ -97,6 +98,9 @@ def __init__(
       freeze_processor: If True, the processor weights will be frozen and
         only encoders and decoders (and, if used, the lstm) will be trained.
       dropout_prob: Dropout rate in the message-passing stage.
+      hint_teacher_forcing_noise: Probability of using predicted hints instead
+        of ground-truth hints as inputs during training (only relevant if
+        `encode_hints`=True).
       name: Model name.
 
     Raises:
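
In other words, the new flag is the per-example probability that a training step feeds the model's own decoded hints back in rather than the ground truth. A minimal, self-contained JAX sketch of that semantics (the toy arrays and noise value below are illustrative, not from the commit):

import jax
import jax.numpy as jnp

noise = 0.5  # hint_teacher_forcing_noise: chance of feeding back predictions
batch_size = 4
ground_truth = jnp.ones((batch_size, 3))  # stand-in for true hint data
predicted = jnp.zeros((batch_size, 3))    # stand-in for decoded hint data

# True -> keep teacher forcing for that example; False -> use its prediction.
force_mask = jax.random.bernoulli(
    jax.random.PRNGKey(0), 1.0 - noise, (batch_size,))
mixed = jnp.where(force_mask[:, None], ground_truth, predicted)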
@@ -128,18 +132,18 @@ def __init__(
         nb_dims[outp.name] = outp.data.shape[-1]
       self.nb_dims.append(nb_dims)
 
-    self._create_net_fns(hidden_dim, encode_hints, kind,
-                         use_lstm, dropout_prob, nb_heads)
+    self._create_net_fns(hidden_dim, encode_hints, kind, use_lstm,
+                         dropout_prob, hint_teacher_forcing_noise, nb_heads)
     self.params = None
     self.opt_state = None
     self.opt_state_skeleton = None
 
-  def _create_net_fns(self, hidden_dim, encode_hints, kind,
-                      use_lstm, dropout_prob, nb_heads):
+  def _create_net_fns(self, hidden_dim, encode_hints, kind, use_lstm,
+                      dropout_prob, hint_teacher_forcing_noise, nb_heads):
     def _use_net(*args, **kwargs):
       return nets.Net(self._spec, hidden_dim, encode_hints,
                       self.decode_hints, self.decode_diffs,
-                      kind, use_lstm, dropout_prob,
+                      kind, use_lstm, dropout_prob, hint_teacher_forcing_noise,
                       nb_heads, self.nb_dims)(*args, **kwargs)
 
     self.net_fn = hk.transform(_use_net)
@@ -279,7 +283,7 @@ def verbose_loss(self, feedback: _Feedback, extra_info) -> Dict[str, _Array]:
       losses_.update(
           losses.hint_loss(
               truth=truth,
-              preds=hint_preds,
+              preds=[x[truth.name] for x in hint_preds],
               gt_diffs=gt_diffs,
               lengths=lengths,
               nb_nodes=nb_nodes,
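
The `preds` fix above reflects that `hint_preds` here is a sequence of per-step dicts keyed by hint name, so the loss for one hint must gather the matching entry from each step. A toy illustration (the names and values below are assumptions for this example, not the library's actual arrays):

# One dict of predictions per unroll step (toy values).
hint_preds = [{'pi_h': 0.1, 'reach_h': 0.7},
              {'pi_h': 0.2, 'reach_h': 0.6}]
truth_name = 'pi_h'
preds = [x[truth_name] for x in hint_preds]  # [0.1, 0.2]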
@@ -323,13 +327,13 @@ class BaselineModelChunked(BaselineModel):
   `BaselineModel`.
   """
 
-  def _create_net_fns(self, hidden_dim, encode_hints, kind,
-                      use_lstm, dropout_prob, nb_heads):
+  def _create_net_fns(self, hidden_dim, encode_hints, kind, use_lstm,
+                      dropout_prob, hint_teacher_forcing_noise, nb_heads):
     def _use_net(*args, **kwargs):
       return nets.NetChunked(
           self._spec, hidden_dim, encode_hints,
           self.decode_hints, self.decode_diffs,
-          kind, use_lstm, dropout_prob,
+          kind, use_lstm, dropout_prob, hint_teacher_forcing_noise,
           nb_heads, self.nb_dims)(*args, **kwargs)
 
     self.net_fn = hk.transform(_use_net)
34 changes: 29 additions & 5 deletions clrs/_src/nets.py
@@ -85,6 +85,7 @@ def __init__(
       kind: str,
       use_lstm: bool,
       dropout_prob: float,
+      hint_teacher_forcing_noise: float,
       nb_heads: int,
       nb_dims=None,
       name: str = 'net',
@@ -93,6 +94,7 @@
     super().__init__(name=name)
 
     self._dropout_prob = dropout_prob
+    self._hint_teacher_forcing_noise = hint_teacher_forcing_noise
     self.spec = spec
     self.hidden_dim = hidden_dim
     self.encode_hints = encode_hints
@@ -118,20 +120,34 @@ def _msg_passing_step(self,
                         decs: Dict[str, Tuple[hk.Module]],
                         diff_decs: Dict[str, Any],
                         ):
-    if (not first_step) and repred and self.decode_hints:
+    if self.decode_hints and not first_step:
       decoded_hint = decoders.postprocess(spec,
                                           mp_state.hint_preds)
+    if repred and self.decode_hints and not first_step:
       cur_hint = []
       for hint in decoded_hint:
         cur_hint.append(decoded_hint[hint])
     else:
       cur_hint = []
+      needs_noise = (self.decode_hints and not first_step and
+                     self._hint_teacher_forcing_noise > 0)
+      if needs_noise:
+        # For noisy teacher forcing, choose which examples in the batch to force
+        force_mask = jax.random.bernoulli(
+            hk.next_rng_key(), 1.0 - self._hint_teacher_forcing_noise,
+            (batch_size,))
+      else:
+        force_mask = None
       for hint in hints:
-        hint.data = jnp.asarray(hint.data)
+        hint_data = jnp.asarray(hint.data)[i]
+        if needs_noise:
+          hint_data = jnp.where(_expand_to(force_mask, hint_data),
+                                hint_data,
+                                decoded_hint[hint.name].data)
         _, loc, typ = spec[hint.name]
         cur_hint.append(
             probing.DataPoint(
-                name=hint.name, location=loc, type_=typ, data=hint.data[i]))
+                name=hint.name, location=loc, type_=typ, data=hint_data))
 
     gt_diffs = None
     if hints[0].data.shape[0] > 1 and self.decode_diffs:
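
The `_expand_to(force_mask, hint_data)` call above broadcasts the per-example mask up to the hint's rank before the `jnp.where` select. A self-contained sketch of that pattern (`expand_to` below is a stand-in written for this example, not the library's `_expand_to`; shapes are toy assumptions):

import jax
import jax.numpy as jnp

def expand_to(mask, target):
  # Append trailing singleton axes until mask broadcasts against target.
  while mask.ndim < target.ndim:
    mask = jnp.expand_dims(mask, -1)
  return mask

batch_size, nb_nodes = 4, 5
true_hint = jnp.ones((batch_size, nb_nodes, nb_nodes))
decoded_hint = jnp.zeros((batch_size, nb_nodes, nb_nodes))

noise = 0.25  # hint_teacher_forcing_noise
force_mask = jax.random.bernoulli(
    jax.random.PRNGKey(42), 1.0 - noise, (batch_size,))
mixed = jnp.where(expand_to(force_mask, true_hint), true_hint, decoded_hint)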
@@ -501,13 +517,21 @@ def _as_prediction_data(hint):
       hints_for_pred = hints
     else:
       prev_hint_preds = mp_state.hint_preds
-      if repred and self.decode_hints:
+      if self.decode_hints:
+        if repred:
+          force_mask = jnp.zeros(batch_size, dtype=bool)
+        elif self._hint_teacher_forcing_noise == 0.:
+          force_mask = jnp.ones(batch_size, dtype=bool)
+        else:
+          force_mask = jax.random.bernoulli(
+              hk.next_rng_key(), 1.0 - self._hint_teacher_forcing_noise,
+              (batch_size,))
         decoded_hints = decoders.postprocess(spec, prev_hint_preds)
         hints_for_pred = []
         for h in hints:
           hints_for_pred.append(probing.DataPoint(
               name=h.name, location=h.location, type_=h.type_,
-              data=jnp.where(_expand_to(is_first, h.data),
+              data=jnp.where(_expand_to(is_first | force_mask, h.data),
                              h.data, decoded_hints[h.name].data)))
       else:
         hints_for_pred = hints
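
In the chunked net the mask has three regimes rather than two: at evaluation time (`repred`) hints are never forced, with zero noise every example is forced, and otherwise forcing is sampled per example; `is_first | force_mask` additionally keeps ground truth at each sample's first step. A runnable sketch of just the three-way choice (the helper below is illustrative, not part of the library):

import jax
import jax.numpy as jnp

def choose_force_mask(key, repred, noise, batch_size):
  # True -> keep the ground-truth hint; False -> feed back the decoded one.
  if repred:
    return jnp.zeros(batch_size, dtype=bool)
  if noise == 0.:
    return jnp.ones(batch_size, dtype=bool)
  return jax.random.bernoulli(key, 1.0 - noise, (batch_size,))

key = jax.random.PRNGKey(0)
print(choose_force_mask(key, repred=True, noise=0.5, batch_size=4))   # all False
print(choose_force_mask(key, repred=False, noise=0.0, batch_size=4))  # all True
print(choose_force_mask(key, repred=False, noise=0.5, batch_size=4))  # mixed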
