Skip to content

Commit

Permalink
Merge pull request #86 from octo-models/lang_paraphrases
Browse files Browse the repository at this point in the history
Language paraphrase augmentations
  • Loading branch information
mees authored May 8, 2024
2 parents 89045cc + 1ceea1f commit 923089e
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 0 deletions.
89 changes: 89 additions & 0 deletions octo/data/utils/task_augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,100 @@
Contains basic logic for randomly zero-ing out keys in the task specification.
"""

import pickle

from huggingface_hub import hf_hub_download
import tensorflow as tf

from octo.data.utils.data_utils import to_padding


def delete_and_rephrase(
traj,
paraphrases_repo: str,
paraphrases_filename: str,
rephrase_prob: float,
keep_image_prob: float,
):
traj = rephrase_instruction(
traj, paraphrases_repo, paraphrases_filename, rephrase_prob
)
traj = delete_task_conditioning(traj, keep_image_prob)
return traj


class Rephraser:
def create_static_hash_table(self, dictionary):
"""Takes a python dictionary with string keys and values and creates a tf static hash table"""
keys = list(dictionary.keys())
values = list(dictionary.values())
initializer = tf.lookup.KeyValueTensorInitializer(
keys, values, key_dtype=tf.string, value_dtype=tf.string
)
hash_table = tf.lookup.StaticHashTable(initializer, default_value="")
return hash_table

def __init__(self, paraphrases_repo: str, paraphrases_filename: str):
if isinstance(paraphrases_repo, str) and isinstance(paraphrases_filename, str):
with open(
hf_hub_download(
repo_id=paraphrases_repo,
filename=paraphrases_filename,
repo_type="dataset",
),
"rb",
) as file:
lang_paraphrases = pickle.load(file)
# Create StaticHashTable
self.rephrase_lookup = self.create_static_hash_table(lang_paraphrases)


def rephrase_instruction(
traj: dict, paraphrases_repo: str, paraphrases_filename: str, rephrase_prob: float
) -> dict:
"""Randomly rephrases language instructions with precomputed paraphrases
Args:
traj: A dictionary containing trajectory data. Should have a "task" key.
paraphrases_repo: The name of the HF repo containing the paraphrases file.
paraphrases_filename: The name of the file containing the paraphrases.
rephrase_prob: The probability of augmenting the language instruction. The probability of keeping the language
instruction is 1 - rephrase_prob.
"""
rephraser = Rephraser(paraphrases_repo, paraphrases_filename)

if "language_instruction" not in traj["task"]:
return traj
original_language = traj["task"]["language_instruction"]
# check the language key is not empty
string_is_not_empty = tf.reduce_all(tf.strings.length(original_language) > 0)
# check dict is not empty
dict_is_not_empty = bool(rephraser.rephrase_lookup)
if dict_is_not_empty and string_is_not_empty:
rephrased_instruction = rephraser.rephrase_lookup.lookup(original_language[0])
rephrased_instruction = tf.where(
tf.strings.length(rephrased_instruction) > 0,
original_language[0] + "." + rephrased_instruction,
original_language[0],
)
split_tensor = tf.strings.split(rephrased_instruction, sep=".")
num_strings = tf.cast(tf.shape(split_tensor)[0], tf.int32)
random_index = tf.random.uniform(
(tf.shape(original_language)[0],),
minval=0,
maxval=num_strings,
dtype=tf.int32,
)
sampled_language = tf.gather(split_tensor, random_index)
rand = tf.random.uniform(shape=(), minval=0, maxval=1, dtype=tf.float32)
sampled_language = tf.where(
rand < rephrase_prob,
sampled_language,
original_language,
)
traj["task"]["language_instruction"] = sampled_language
return traj


def delete_task_conditioning(
traj: dict,
keep_image_prob: float,
Expand Down
6 changes: 6 additions & 0 deletions scripts/configs/octo_pretrain_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,12 @@ def get_config(config_string=None):
),
traj_transform_kwargs=dict(
future_action_window_size=3,
task_augment_strategy="delete_and_rephrase",
task_augment_kwargs=dict(
paraphrases_repo="rail-berkeley/OXE_paraphrases",
paraphrases_filename="paraphrases_oxe.pkl",
rephrase_prob=0.5,
),
),
frame_transform_kwargs=dict(
image_dropout_prob=0.5,
Expand Down

0 comments on commit 923089e

Please sign in to comment.