Skip to content

Commit

Permalink
Prettier spec
Browse files Browse the repository at this point in the history
  • Loading branch information
dibyaghosh committed Dec 13, 2023
1 parent 8226978 commit 3fdd859
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 2 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@ Octo models are transformer-based diffusion policies, trained on a diverse mix o
Follow the installation instructions, then load a pre-trained OCTO model! See [examples](examples/) for guides to zero-shot evaluation and finetuning.

```
from octo.model.octo_model import OCTOModel
model = OCTOModel.load_pretrained("hf://rail-berkeley/octo-base")
from octo.model.octo_model import OctoModel
model = OctoModel.load_pretrained("hf://rail-berkeley/octo-base")
print(model.get_pretty_spec())
```

![Octo model](docs/assets/teaser.png)
Expand Down
48 changes: 48 additions & 0 deletions octo/model/octo_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,42 @@ def _init(rng):
dataset_statistics=dataset_statistics,
)

def get_pretty_spec(self):
# TODO: generalize this to work with other models
window_size = self.example_batch["observation"]["pad_mask"].shape[1]

observation_space = {
k: ("batch", "history_window", *v.shape[2:])
for k, v in self.example_batch["observation"].items()
if k.startswith("image")
}
task_space = {
k: ("batch", *v.shape[1:])
for k, v in self.example_batch["task"].items()
if k.startswith("image")
}
if self.text_processor is not None:
task_space["language_instruction"] = jax.tree_map(
lambda arr: ("batch", *arr.shape[1:]),
self.example_batch["task"]["language_instruction"],
)

try:
action_head = self.module.heads["action"]
action_head_repr = str(action_head.__class__)
action_dim, pred_horizon = action_head.action_dim, action_head.pred_horizon
except:
action_head_repr, action_dim, pred_horizon = "", None, None

return SPEC_TEMPLATE.format(
window_size=window_size,
observation_space=flax.core.pretty_repr(observation_space),
task_space=flax.core.pretty_repr(task_space),
action_head_repr=action_head_repr,
action_dim=action_dim,
pred_horizon=pred_horizon,
)


def _verify_shapes(
pytree,
Expand Down Expand Up @@ -458,3 +494,15 @@ def _download_from_huggingface(huggingface_repo_id: str):

folder = huggingface_hub.snapshot_download(huggingface_repo_id)
return folder
return folder


SPEC_TEMPLATE = """
This model is trained with a window size of {window_size}, predicting {action_dim} dimensional actions {pred_horizon} steps into the future.
Observations and tasks conform to the following spec:
Observations: {observation_space}
Tasks: {task_space}
At inference, you may pass in any subset of these observation and task keys, with a history window up to {window_size} timesteps.
"""

0 comments on commit 3fdd859

Please sign in to comment.