Internal #1194

Open · wants to merge 1 commit into base: main
uncertainty_baselines/datasets/augmix.py (20 changes: 10 additions & 10 deletions)
@@ -140,7 +140,6 @@ def mixup(batch_size, aug_params, images, labels):
     aug_params: Dict of data augmentation hyper parameters.
     images: A batch of images of shape [batch_size, ...]
     labels: A batch of labels of shape [batch_size, num_classes]
-
   Returns:
     A tuple of (images, labels) with the same dimensions as the input with
       Mixup regularization applied.
@@ -192,15 +191,16 @@ def mixup(batch_size, aug_params, images, labels):
     labels = tf.reshape(
         tf.tile(labels, [1, aug_count + 1]), [batch_size, aug_count + 1, -1])
     labels_mix = (
-        labels * mix_weight +
-        tf.gather(labels, mixup_index) * (1. - mix_weight))
+        labels * mix_weight + tf.gather(labels, mixup_index) *
+        (1. - mix_weight))
     labels_mix = tf.reshape(
         tf.transpose(labels_mix, [1, 0, 2]), [batch_size * (aug_count + 1), -1])
   else:
     labels_mix = (
-        labels * mix_weight +
-        tf.gather(labels, mixup_index) * (1. - mix_weight))
-  return images_mix, labels_mix
+        labels * mix_weight + tf.gather(labels, mixup_index) *
+        (1. - mix_weight))
+
+  return images_mix, labels_mix


 def adaptive_mixup(batch_size, aug_params, images, labels):
@@ -215,7 +215,6 @@ def adaptive_mixup(batch_size, aug_params, images, labels):
     aug_params: Dict of data augmentation hyper parameters.
     images: A batch of images of shape [batch_size, ...]
     labels: A batch of labels of shape [batch_size, num_classes]
-
   Returns:
     A tuple of (images, labels) with the same dimensions as the input with
       Mixup regularization applied.
@@ -229,8 +228,8 @@ def adaptive_mixup(batch_size, aug_params, images, labels):
   # Need to filter out elements in alpha which equal to 0.
   greater_zero_indicator = tf.cast(alpha > 0, alpha.dtype)
   less_one_indicator = tf.cast(alpha < 1, alpha.dtype)
-  valid_alpha_indicator = tf.cast(
-      greater_zero_indicator * less_one_indicator, tf.bool)
+  valid_alpha_indicator = tf.cast(greater_zero_indicator * less_one_indicator,
+                                  tf.bool)
   sampled_alpha = tf.where(valid_alpha_indicator, alpha, 0.1)
   mix_weight = tfd.Beta(sampled_alpha, sampled_alpha).sample()
   mix_weight = tf.where(valid_alpha_indicator, mix_weight, alpha)
@@ -253,4 +252,5 @@ def adaptive_mixup(batch_size, aug_params, images, labels):
   images_mix = (
       images * images_mix_weight + images[::-1] * (1. - images_mix_weight))
   labels_mix = labels * mix_weight + labels[::-1] * (1. - mix_weight)
-  return images_mix, labels_mix
+
+  return images_mix, labels_mix
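For readers skimming the diff, here is a minimal, self-contained sketch of the Mixup idea these hunks reformat: each example is convexly combined with a randomly chosen partner from the same batch. It is an illustration only, not the repository's mixup/adaptive_mixup implementation (which, as the hunks above show, also tiles labels across the aug_count augmented copies and supports adaptive per-class alphas); the name simple_mixup and the assumed [batch, height, width, channels] image layout are my own.

# Minimal Mixup sketch (illustrative only; names and shapes are assumptions).
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions


def simple_mixup(images, labels, alpha=0.2):
  """Convexly combines each example with a random partner from the batch.

  Args:
    images: float tensor of shape [batch_size, height, width, channels].
    labels: one-hot float tensor of shape [batch_size, num_classes].
    alpha: Beta concentration; small values keep most mixes close to 0 or 1.

  Returns:
    A tuple (images_mix, labels_mix) with the same shapes as the inputs.
  """
  batch_size = tf.shape(images)[0]
  # One mixing coefficient per example, drawn from Beta(alpha, alpha).
  mix_weight = tfd.Beta(alpha, alpha).sample(batch_size)
  # Broadcast the coefficient over the image and label dimensions.
  image_weight = tf.reshape(mix_weight, [-1, 1, 1, 1])
  label_weight = tf.reshape(mix_weight, [-1, 1])
  # Pair each example with a random partner, analogous to the tf.gather on
  # mixup_index in the diff above.
  partner = tf.random.shuffle(tf.range(batch_size))
  images_mix = (images * image_weight +
                tf.gather(images, partner) * (1. - image_weight))
  labels_mix = (labels * label_weight +
                tf.gather(labels, partner) * (1. - label_weight))
  return images_mix, labels_mix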
uncertainty_baselines/datasets/base.py (18 changes: 15 additions & 3 deletions)
@@ -267,6 +267,7 @@ def _add_example_id(enumerate_id, example):
   def _load(self,
             *,
             preprocess_fn: Optional[PreProcessFn] = None,
+            process_batch_fn: Optional[PreProcessFn] = None,
             batch_size: int = -1) -> tf.data.Dataset:
     """Transforms the dataset from builder.as_dataset() to batch, repeat, etc.

@@ -278,6 +279,9 @@ def _load(self,
       preprocess_fn: an optional preprocessing function, if not provided then a
         subclass must define _create_process_example_fn() which will be used to
         preprocess the data.
+      process_batch_fn: an optional processing batch function, if not
+        provided then _create_process_batch_fn() will be used to generate the
+        function that will process a batch of data.
       batch_size: the batch size to use.

     Returns:
@@ -372,7 +376,8 @@ def _load(self,
     else:
       dataset = dataset.batch(batch_size, drop_remainder=self._drop_remainder)

-    process_batch_fn = self._create_process_batch_fn(batch_size)  # pylint: disable=assignment-from-none
+    if process_batch_fn is None:
+      process_batch_fn = self._create_process_batch_fn(batch_size)
     if process_batch_fn:
       dataset = dataset.map(
           process_batch_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
@@ -406,6 +411,7 @@ def load(
       self,
       *,
       preprocess_fn: Optional[PreProcessFn] = None,
+      process_batch_fn: Optional[PreProcessFn] = None,
       batch_size: int = -1,
       strategy: Optional[tf.distribute.Strategy] = None) -> tf.data.Dataset:
     """Function definition to support multi-host dataset sharding.
@@ -431,11 +437,13 @@

     Args:
       preprocess_fn: see `load()`.
+      process_batch_fn: see `load()`.
       batch_size: the *global* batch size to use. This should equal
         `per_replica_batch_size * num_replica_in_sync`.
       strategy: the DistributionStrategy used to shard the dataset. Note that
         this is only required if TensorFlow for training, otherwise it can be
         ignored.
+
     Returns:
       A sharded dataset, with its seed combined with the per-host id.
     """
@@ -445,11 +453,15 @@ def _load_distributed(ctx: tf.distribute.InputContext):
            self._seed, ctx.input_pipeline_id)
        per_replica_batch_size = ctx.get_per_replica_batch_size(batch_size)
        return self._load(
-            preprocess_fn=preprocess_fn, batch_size=per_replica_batch_size)
+            preprocess_fn=preprocess_fn,
+            process_batch_fn=process_batch_fn,
+            batch_size=per_replica_batch_size)

       return strategy.distribute_datasets_from_function(_load_distributed)
     else:
-      return self._load(preprocess_fn=preprocess_fn, batch_size=batch_size)
+      return self._load(preprocess_fn=preprocess_fn,
+                        process_batch_fn=process_batch_fn,
+                        batch_size=batch_size)


 _BaseDatasetClass = Type[TypeVar('B', bound=BaseDataset)]
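To make the new hook concrete, here is a hypothetical usage sketch, not code from this PR: a caller supplies its own process_batch_fn when loading a dataset, and _load() only falls back to the dataset's _create_process_batch_fn() when the argument is left as None. The helper add_input_noise, the batch size, and the assumption that loaded batches are dicts containing a float 'features' image tensor are my own; ub.datasets.Cifar10Dataset is used as a representative builder from this package.

# Hypothetical usage of the new process_batch_fn hook (not part of this PR).
import tensorflow as tf
import uncertainty_baselines as ub


def add_input_noise(batch):
  # Batch-level processing: perturb the whole batch in one call. Assumes the
  # dataset yields a dict with a float 'features' image tensor.
  batch['features'] = batch['features'] + tf.random.normal(
      tf.shape(batch['features']), stddev=0.01)
  return batch


builder = ub.datasets.Cifar10Dataset(split='train')
# When process_batch_fn is omitted, behavior is unchanged from before this PR.
train_dataset = builder.load(process_batch_fn=add_input_noise, batch_size=128)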