Option for noise in teacher forcing #74

Merged · 1 commit · May 13, 2022
clrs/_src/baselines.py (22 changes: 13 additions & 9 deletions)
@@ -69,6 +69,7 @@ def __init__(
       checkpoint_path: str = '/tmp/clrs3',
       freeze_processor: bool = False,
       dropout_prob: float = 0.0,
+      hint_teacher_forcing_noise: float = 0.0,
       name: str = 'base_model',
   ):
     """Constructor for BaselineModel.
@@ -97,6 +98,9 @@ def __init__(
       freeze_processor: If True, the processor weights will be frozen and
         only encoders and decoders (and, if used, the lstm) will be trained.
       dropout_prob: Dropout rate in the message-passing stage.
+      hint_teacher_forcing_noise: Probability of using predicted hints instead
+        of ground-truth hints as inputs during training (only relevant if
+        `encode_hints`=True).
       name: Model name.

     Raises:
@@ -128,18 +132,18 @@ def __init__(
         nb_dims[outp.name] = outp.data.shape[-1]
       self.nb_dims.append(nb_dims)

-    self._create_net_fns(hidden_dim, encode_hints, kind,
-                         use_lstm, dropout_prob, nb_heads)
+    self._create_net_fns(hidden_dim, encode_hints, kind, use_lstm,
+                         dropout_prob, hint_teacher_forcing_noise, nb_heads)
     self.params = None
     self.opt_state = None
     self.opt_state_skeleton = None

-  def _create_net_fns(self, hidden_dim, encode_hints, kind,
-                      use_lstm, dropout_prob, nb_heads):
+  def _create_net_fns(self, hidden_dim, encode_hints, kind, use_lstm,
+                      dropout_prob, hint_teacher_forcing_noise, nb_heads):
     def _use_net(*args, **kwargs):
       return nets.Net(self._spec, hidden_dim, encode_hints,
                       self.decode_hints, self.decode_diffs,
-                      kind, use_lstm, dropout_prob,
+                      kind, use_lstm, dropout_prob, hint_teacher_forcing_noise,
                       nb_heads, self.nb_dims)(*args, **kwargs)

     self.net_fn = hk.transform(_use_net)
@@ -279,7 +283,7 @@ def verbose_loss(self, feedback: _Feedback, extra_info) -> Dict[str, _Array]:
       losses_.update(
           losses.hint_loss(
               truth=truth,
-              preds=hint_preds,
+              preds=[x[truth.name] for x in hint_preds],
               gt_diffs=gt_diffs,
               lengths=lengths,
               nb_nodes=nb_nodes,
@@ -323,13 +327,13 @@ class BaselineModelChunked(BaselineModel):
   `BaselineModel`.
   """

-  def _create_net_fns(self, hidden_dim, encode_hints, kind,
-                      use_lstm, dropout_prob, nb_heads):
+  def _create_net_fns(self, hidden_dim, encode_hints, kind, use_lstm,
+                      dropout_prob, hint_teacher_forcing_noise, nb_heads):
     def _use_net(*args, **kwargs):
       return nets.NetChunked(
           self._spec, hidden_dim, encode_hints,
           self.decode_hints, self.decode_diffs,
-          kind, use_lstm, dropout_prob,
+          kind, use_lstm, dropout_prob, hint_teacher_forcing_noise,
           nb_heads, self.nb_dims)(*args, **kwargs)

     self.net_fn = hk.transform(_use_net)
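A minimal, hypothetical usage sketch of the new flag, for context. Only `dropout_prob`, `hint_teacher_forcing_noise`, `freeze_processor`, `checkpoint_path` and `name` are visible in this diff; the remaining constructor arguments (spec, hidden dimension, processor kind, and so on) are elided here and would still be required in practice.

# Hypothetical sketch, not taken verbatim from the CLRS examples: only the
# keyword arguments listed below are visible in this diff.
model_kwargs = dict(
    dropout_prob=0.0,
    # With probability 0.5 per batch element, the model's own hint predictions
    # replace the ground-truth hints during training (needs encode_hints=True).
    hint_teacher_forcing_noise=0.5,
    freeze_processor=False,
    checkpoint_path='/tmp/clrs3',
    name='base_model',
)
# model = baselines.BaselineModel(spec=spec, ..., **model_kwargs)  # other args assumed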
clrs/_src/nets.py (34 changes: 29 additions & 5 deletions)
@@ -85,6 +85,7 @@ def __init__(
       kind: str,
       use_lstm: bool,
       dropout_prob: float,
+      hint_teacher_forcing_noise: float,
       nb_heads: int,
       nb_dims=None,
       name: str = 'net',
@@ -93,6 +94,7 @@ def __init__(
     super().__init__(name=name)

     self._dropout_prob = dropout_prob
+    self._hint_teacher_forcing_noise = hint_teacher_forcing_noise
     self.spec = spec
     self.hidden_dim = hidden_dim
     self.encode_hints = encode_hints
@@ -118,20 +120,34 @@ def _msg_passing_step(self,
                         decs: Dict[str, Tuple[hk.Module]],
                         diff_decs: Dict[str, Any],
                         ):
-    if (not first_step) and repred and self.decode_hints:
+    if self.decode_hints and not first_step:
       decoded_hint = decoders.postprocess(spec,
                                           mp_state.hint_preds)
+    if repred and self.decode_hints and not first_step:
       cur_hint = []
       for hint in decoded_hint:
         cur_hint.append(decoded_hint[hint])
     else:
       cur_hint = []
+      needs_noise = (self.decode_hints and not first_step and
+                     self._hint_teacher_forcing_noise > 0)
+      if needs_noise:
+        # For noisy teacher forcing, choose which examples in the batch to force
+        force_mask = jax.random.bernoulli(
+            hk.next_rng_key(), 1.0 - self._hint_teacher_forcing_noise,
+            (batch_size,))
+      else:
+        force_mask = None
       for hint in hints:
-        hint.data = jnp.asarray(hint.data)
+        hint_data = jnp.asarray(hint.data)[i]
+        if needs_noise:
+          hint_data = jnp.where(_expand_to(force_mask, hint_data),
+                                hint_data,
+                                decoded_hint[hint.name].data)
         _, loc, typ = spec[hint.name]
         cur_hint.append(
             probing.DataPoint(
-                name=hint.name, location=loc, type_=typ, data=hint.data[i]))
+                name=hint.name, location=loc, type_=typ, data=hint_data))

     gt_diffs = None
     if hints[0].data.shape[0] > 1 and self.decode_diffs:
@@ -501,13 +517,21 @@ def _as_prediction_data(hint):
      hints_for_pred = hints
    else:
      prev_hint_preds = mp_state.hint_preds
-      if repred and self.decode_hints:
+      if self.decode_hints:
+        if repred:
+          force_mask = jnp.zeros(batch_size, dtype=bool)
+        elif self._hint_teacher_forcing_noise == 0.:
+          force_mask = jnp.ones(batch_size, dtype=bool)
+        else:
+          force_mask = jax.random.bernoulli(
+              hk.next_rng_key(), 1.0 - self._hint_teacher_forcing_noise,
+              (batch_size,))
        decoded_hints = decoders.postprocess(spec, prev_hint_preds)
        hints_for_pred = []
        for h in hints:
          hints_for_pred.append(probing.DataPoint(
              name=h.name, location=h.location, type_=h.type_,
-              data=jnp.where(_expand_to(is_first, h.data),
+              data=jnp.where(_expand_to(is_first | force_mask, h.data),
                             h.data, decoded_hints[h.name].data)))
      else:
        hints_for_pred = hints
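For reference, a standalone sketch of the per-example noisy teacher forcing used in both branches above: draw one Bernoulli sample per batch element, then use it to choose between the ground-truth hint and the model's decoded hint. `expand_to` below stands in for the library's internal `_expand_to` helper; the function names and array shapes are illustrative assumptions, not the CLRS API.

# Illustrative sketch of noisy teacher forcing; names and shapes are assumptions.
import jax
import jax.numpy as jnp


def expand_to(mask, target):
  """Reshapes a (batch,) mask so that it broadcasts against `target`."""
  return mask.reshape(mask.shape + (1,) * (target.ndim - mask.ndim))


def mix_hints(rng, noise, gt_hint, pred_hint):
  """Keeps ground truth with prob. (1 - noise) per example, else the prediction."""
  batch_size = gt_hint.shape[0]
  # True => force this example (use the ground-truth hint); drawn once per example.
  force_mask = jax.random.bernoulli(rng, 1.0 - noise, (batch_size,))
  return jnp.where(expand_to(force_mask, gt_hint), gt_hint, pred_hint)


rng = jax.random.PRNGKey(0)
gt = jnp.ones((8, 16, 16))     # e.g. a node-pair hint for a batch of 8 graphs
pred = jnp.zeros((8, 16, 16))  # the model's decoded hint at this step
mixed = mix_hints(rng, noise=0.25, gt_hint=gt, pred_hint=pred)

Because the mask is drawn per batch element rather than per entry, each example either keeps all of its ground-truth hints or switches entirely to predicted hints at a given step, matching the `(batch_size,)`-shaped `force_mask` in the diff.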