Variant/kraken #99

Open

wants to merge 9 commits into main
11 changes: 9 additions & 2 deletions octo/data/dataset.py
@@ -132,7 +132,9 @@ def apply_trajectory_transforms(
num_parallel_calls,
)

# chunks observations and actions
# chunks observations and actions, giving them a new axis at index 1 of size `window_size` and
# `window_size + future_action_window_size`, respectively

dataset = dataset.traj_map(
partial(
traj_transforms.chunk_act_obs,
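
Note: a standalone sketch of the chunking semantics described in the new comment (illustration only, not the library's `chunk_act_obs`; the helper name and shapes are invented):

    import numpy as np

    # Each timestep t sees indices [t - window_size + 1, ..., t], extended by
    # future_window steps for actions; out-of-range indices clip to the
    # trajectory bounds.
    def chunk_indices(traj_len, window_size, future_window=0):
        horizon = window_size + future_window
        idx = (np.arange(traj_len)[:, None]
               + np.arange(horizon)[None, :]
               - (window_size - 1))
        return np.clip(idx, 0, traj_len - 1)

    obs = np.random.rand(10, 3)                 # (T, obs_dim)
    chunked_obs = obs[chunk_indices(10, 2)]     # (T, 2, obs_dim)
    act = np.random.rand(10, 7)
    chunked_act = act[chunk_indices(10, 2, 3)]  # (T, 5, act_dim)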
@@ -391,7 +393,9 @@ def is_nonzero_length(traj):
full_dataset = full_dataset.filter(ModuleSpec.instantiate(filter_fcn_spec))
if ignore_errors:
full_dataset = full_dataset.ignore_errors()

full_dataset = full_dataset.traj_map(restructure).filter(is_nonzero_length)

# tries to load from cache, otherwise computes on the fly
dataset_statistics = get_dataset_statistics(
full_dataset,
@@ -454,13 +458,13 @@ def is_nonzero_length(traj):

return dataset, dataset_statistics


def make_single_dataset(
dataset_kwargs: dict,
*,
train: bool,
traj_transform_kwargs: dict = {},
frame_transform_kwargs: dict = {},
user_modify_traj: Optional[Callable] = None,
) -> dl.DLataset:
"""Creates a single dataset from kwargs. Returns a dataset of trajectories.

@@ -474,6 +478,9 @@
**dataset_kwargs,
train=train,
)

if user_modify_traj is not None:
    dataset = dataset.traj_map(user_modify_traj)

dataset = apply_trajectory_transforms(dataset, **traj_transform_kwargs, train=train)
dataset = apply_frame_transforms(dataset, **frame_transform_kwargs, train=train)

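Note: a hypothetical usage of the new `user_modify_traj` hook in `make_single_dataset` (the transform below is invented; it runs once per trajectory, before the standard trajectory transforms):

    def scale_actions(traj):
        # must return the (possibly modified) trajectory dict
        traj["action"] = traj["action"] * 2.0
        return traj

    dataset = make_single_dataset(
        dataset_kwargs,  # same dataset_kwargs as elsewhere in this file
        train=True,
        user_modify_traj=scale_actions,
    )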
1 change: 1 addition & 0 deletions octo/data/utils/data_utils.py
@@ -257,6 +257,7 @@ def normalize_action_and_proprio(
mask = metadata[key].get(
"mask", tf.ones_like(metadata[key]["mean"], dtype=tf.bool)
)

traj = dl.transforms.selective_tree_map(
traj,
match=lambda k, _: k == traj_key,
4 changes: 4 additions & 0 deletions octo/model/components/action_heads.py
@@ -244,6 +244,8 @@ class DiscreteActionHead(nn.Module, ActionHead):
action_dim: int = 7
vocab_size: int = 256
normalization_type: str = "uniform"
low: Optional[float] = None
high: Optional[float] = None

def setup(self):
total_output = self.action_horizon * self.action_dim * self.vocab_size
@@ -267,6 +269,8 @@ def setup(self):
self.action_tokenizer = BinTokenizer(
n_bins=self.vocab_size,
bin_type=self.normalization_type,
low=self.low,
high=self.high,
)

def __call__(
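Note: the new `low`/`high` fields are threaded through to the head's internal `BinTokenizer`. A hypothetical kwargs snippet (other head arguments omitted):

    action_head_kwargs = dict(
        action_dim=7,
        vocab_size=256,
        normalization_type="uniform",
        low=-1.0,   # now required for uniform binning; no silent [0, 1] default
        high=1.0,
    )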
7 changes: 5 additions & 2 deletions octo/model/components/tokenizers.py
@@ -244,11 +244,14 @@ class BinTokenizer(nn.Module):

n_bins: int = 256
bin_type: str = "uniform"
low: float = 0
high: float = 1
low: Optional[float] = None
high: Optional[float] = None

def setup(self):
if self.bin_type == "uniform":
if self.low is None or self.high is None:
raise ValueError("Low and high must be provided for uniform normalization")

self.thresholds = jnp.linspace(self.low, self.high, self.n_bins + 1)
elif self.bin_type == "normal":
self.thresholds = norm.ppf(jnp.linspace(EPS, 1 - EPS, self.n_bins + 1))
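Note: a standalone numeric sketch of the uniform binning contract (mirrors the semantics shown here, not the library code):

    import numpy as np

    n_bins, low, high = 256, -1.0, 1.0
    thresholds = np.linspace(low, high, n_bins + 1)

    def encode(x):
        # token = index of the bin containing x, in [0, n_bins - 1]
        return np.clip(np.digitize(x, thresholds[1:-1]), 0, n_bins - 1)

    def decode(tokens):
        # value = center of the token's bin
        return (thresholds[tokens] + thresholds[tokens + 1]) / 2

    a = np.array([-0.99, 0.0, 0.42])
    assert np.allclose(decode(encode(a)), a, atol=(high - low) / n_bins)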
158 changes: 158 additions & 0 deletions octo/model/octo_model.py
@@ -250,6 +250,162 @@ def sample_actions(
else:
raise ValueError(f"Unknown normalization type: {normalization_type}")
return action

@partial(jax.jit, static_argnames=("train", "sample_shape", "argmax", "beam"))
def sample_future_actions(
self,
observations: Data,
tasks: Data,
unnormalization_statistics: Optional[Data] = None,
normalization_type: NormalizationType = NormalizationType.NORMAL,
beam: int = 1,
timestep_pad_mask: Optional[ArrayLike] = None,
train: bool = False,
argmax: bool = False,
sample_shape: Tuple[int, ...] = (),
rng: Optional[PRNGKey] = None,
temperature: float = 1.0,
):
"""Samples actions from the model. See `action_heads.py` for more info.

Args:
observations: dictionary of arrays of shape (batch_size, window_size, *)
tasks: dict of tasks of shape (batch_size, *)
unnormalization_statistics: dict of statistics for unnormalizing actions (must contain "mean",
"std", and optionally "mask")
normalization_type: type of normalization applied to the actions
beam: number of highest-probability action tokens to keep per position
timestep_pad_mask: (batch_size, window_size) boolean mask that is False where a timestep is padding
train: whether to run in train mode
...see `action_heads.py` for the rest of the kwargs.
Returns:
actions: decoded top-`beam` actions for the final timestep
confidence: softmax probabilities of the selected action tokens
"""
if timestep_pad_mask is None:
timestep_pad_mask = observations["pad_mask"]

transformer_outputs = self.run_transformer(
observations, tasks, timestep_pad_mask, train=train
)
action_head = self.module.bind({"params": self.params}).heads["action"]

action_logits = action_head(transformer_outputs, train=train)[:, -1]  # final-timestep logits

action_distribution = jax.nn.softmax(action_logits, axis=-1)

action_tokens = jnp.argsort(action_distribution, axis=-1)[..., -beam:].astype(jnp.int32)
confidence = jnp.take_along_axis(action_distribution, action_tokens, axis=-1)

action_tokens = jnp.broadcast_to(
action_tokens, sample_shape + action_tokens.shape
)

action = action_head.action_tokenizer.decode(action_tokens)

if unnormalization_statistics is not None:
if normalization_type == NormalizationType.NORMAL:
mask = unnormalization_statistics.get(
"mask",
jnp.ones_like(unnormalization_statistics["mean"], dtype=bool),
)
action = action[..., : len(mask)]
action = jnp.where(
mask,
(action * unnormalization_statistics["std"])
+ unnormalization_statistics["mean"],
action,
)
elif normalization_type == NormalizationType.BOUNDS:
mask = unnormalization_statistics.get(
"mask", jnp.ones_like(unnormalization_statistics["p01"], dtype=bool)
)
action = action[..., : len(mask)]
action = jnp.where(
mask,
(action + 1)
* (
unnormalization_statistics["p99"]
- unnormalization_statistics["p01"]
)
/ 2
+ unnormalization_statistics["p01"],
action,
)
else:
raise ValueError(f"Unknown normalization type: {normalization_type}")

return action, confidence
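
Note: a toy sketch of the top-`beam` selection used above (argsort ascending, keep the last `beam` indices, then gather their probabilities):

    import jax.numpy as jnp

    probs = jnp.array([[0.1, 0.6, 0.3]])
    beam = 2
    tokens = jnp.argsort(probs, axis=-1)[..., -beam:]   # [[2, 1]] (2nd-best, best)
    conf = jnp.take_along_axis(probs, tokens, axis=-1)  # [[0.3, 0.6]]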

@partial(jax.jit, static_argnames=("train", "sample_shape", "beam"))
def sample_trajectory(
self,
observations: Data,
next_action,
tasks: Data,
unnormalization_statistics: Optional[Data] = None,
normalization_type: NormalizationType = NormalizationType.NORMAL,
beam: int = 1,
timestep_pad_mask: Optional[ArrayLike] = None,
train: bool = False,
argmax: bool = False,
sample_shape: Tuple[int, ...] = (),
rng: Optional[PRNGKey] = None,
temperature: float = 1.0,
):
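"""Predicts future actions via the model's "trajectory" head.

Mirrors `sample_future_actions`; see that method and `action_heads.py` for the kwargs.
"""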
if timestep_pad_mask is None:
timestep_pad_mask = observations["pad_mask"]

transformer_outputs = self.run_transformer(
observations, tasks, timestep_pad_mask, train=train
)

trajectory_head = self.module.bind({"params": self.params}).heads["trajectory"]

action = trajectory_head.predict_action(
transformer_outputs,
train=train,
argmax=argmax,
sample_shape=sample_shape,
rng=rng,
temperature=temperature,
)

if unnormalization_statistics is not None:
if normalization_type == NormalizationType.NORMAL:
mask = unnormalization_statistics.get(
"mask",
jnp.ones_like(unnormalization_statistics["mean"], dtype=bool),
)
action = action[..., : len(mask)]
action = jnp.where(
mask,
(action * unnormalization_statistics["std"])
+ unnormalization_statistics["mean"],
action,
)
elif normalization_type == NormalizationType.BOUNDS:
mask = unnormalization_statistics.get(
"mask", jnp.ones_like(unnormalization_statistics["p01"], dtype=bool)
)
action = action[..., : len(mask)]
action = jnp.where(
mask,
(action + 1)
* (
unnormalization_statistics["p99"]
- unnormalization_statistics["p01"]
)
/ 2
+ unnormalization_statistics["p01"],
action,
)
else:
raise ValueError(f"Unknown normalization type: {normalization_type}")

return action
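
Note: a numeric sketch of the two unnormalization rules shared by both sampling methods above (statistics invented for illustration):

    import jax.numpy as jnp

    a = jnp.array([0.5, -1.0])
    mean, std = jnp.array([1.0, 2.0]), jnp.array([0.1, 0.2])
    p01, p99 = jnp.array([-3.0, 0.0]), jnp.array([3.0, 10.0])

    normal = a * std + mean                    # NORMAL: undo standardization -> [1.05, 1.8]
    bounds = (a + 1) * (p99 - p01) / 2 + p01   # BOUNDS: map [-1, 1] onto [p01, p99] -> [1.5, 0.0]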

@classmethod
def load_pretrained(
@@ -277,6 +433,8 @@ def load_pretrained(
tf.io.gfile.join(checkpoint_path, "config.json"), "r"
) as f:
config = json.load(f)
if "readouts" in config["model"]:
config["model"]["readout_tokenizers"] = config["model"].pop("readouts")

# shim to support old configs
if "pred_horizon" in config["model"]["heads"]["action"]["kwargs"]:
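Note: the `readouts` shim in `load_pretrained` above simply renames an old config key in place; a minimal sketch with an invented config:

    config = {"model": {"readouts": {"action": 1}, "heads": {}}}
    if "readouts" in config["model"]:
        config["model"]["readout_tokenizers"] = config["model"].pop("readouts")
    assert config["model"] == {"readout_tokenizers": {"action": 1}, "heads": {}}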