Skip to content

Commit

Permalink
Merge pull request openvla#34 from siddk/download-hub
Browse files Browse the repository at this point in the history
Automated HF Hub Download
  • Loading branch information
siddk authored Apr 19, 2024
2 parents c7d6eae + a449ea5 commit 4b8b31e
Showing 1 changed file with 55 additions and 24 deletions.
79 changes: 55 additions & 24 deletions prismatic/models/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from pathlib import Path
from typing import List, Optional, Union

from huggingface_hub import hf_hub_download
from huggingface_hub import HfFileSystem, hf_hub_download

from prismatic.conf import ModelConfig
from prismatic.models.materialize import get_llm_backbone_and_tokenizer, get_vision_backbone_and_transform
Expand All @@ -25,6 +25,7 @@

# === HF Hub Repository ===
HF_HUB_REPO = "TRI-ML/prismatic-vlms"
VLA_HF_HUB_REPO = "openvla/openvla-dev"


# === Available Models ===
Expand Down Expand Up @@ -118,37 +119,67 @@ def load(

# === Load Pretrained VLA Model ===
def load_vla(
model_path: Union[str, Path],
model_id_or_path: Union[str, Path],
hf_token: Optional[str] = None,
cache_dir: Optional[Union[str, Path]] = None,
load_for_training: bool = False,
step_to_load: Optional[int] = None,
model_type: str = "pretrained",
) -> OpenVLA:
"""Loads a pretrained VLA model directly from checkpoint path."""
overwatch.info(f"Loading from local checkpoint path `{model_path}`")

# Assert that the checkpoint path looks like: `..../<RUN_ID>/checkpoints/<CHECKPOINT_DIR>`
model_path = str(model_path)
assert os.path.isfile(model_path)
assert model_path[-3:] == ".pt" and model_path.split("/")[-2] == "checkpoints" and len(model_path.split("/")) >= 3
run_dir = Path("/".join(model_path.split("/")[:-2])) # `..../<RUN_ID>`

# Get paths for `config.json`, 'dataset_statistics.json' and pretrained checkpoint
config_json = run_dir / "config.json"
dataset_stats_json = run_dir / "dataset_statistics.json"
checkpoint_pt = model_path
assert config_json.exists(), f"Missing `config.json` for `{run_dir = }`"
assert dataset_stats_json.exists(), f"Missing `dataset_statistics.json` for `{run_dir = }`"

# Load VLA Config from `config.json` and extract Model Config
"""Loads a pretrained OpenVLA from either local disk or the HuggingFace Hub."""

# TODO (siddk, moojink) :: Unify semantics with `load()` above; right now, `load_vla()` assumes path points to
# checkpoint `.pt` file, rather than the top-level run directory!
if os.path.isfile(model_id_or_path):
overwatch.info(f"Loading from local checkpoint path `{(checkpoint_pt := Path(model_id_or_path))}`")

# [Validate] Checkpoint Path should look like `.../<RUN_ID>/checkpoints/<CHECKPOINT_PATH>.pt`
assert (checkpoint_pt.suffix == ".pt") and (checkpoint_pt.parent.name == "checkpoints"), "Invalid checkpoint!"
run_dir = checkpoint_pt.parents[1]

# Get paths for `config.json`, `dataset_statistics.json` and pretrained checkpoint
config_json, dataset_statistics_json = run_dir / "config.json", run_dir / "dataset_statistics.json"
assert config_json.exists(), f"Missing `config.json` for `{run_dir = }`"
assert dataset_statistics_json.exists(), f"Missing `dataset_statistics.json` for `{run_dir = }`"

# Otherwise =>> try looking for a match on `model_id_or_path` on the HF Hub (`VLA_HF_HUB_REPO`)
else:
# Search HF Hub Repo via fsspec API
overwatch.info(f"Checking HF for `{(hf_path := str(Path(VLA_HF_HUB_REPO) / model_type / model_id_or_path))}`")
if not (tmpfs := HfFileSystem()).exists(hf_path):
raise ValueError(f"Couldn't find valid HF Hub Path `{hf_path = }`")

# Identify Checkpoint to Load (via `step_to_load`)
valid_ckpts = tmpfs.glob(f"{hf_path}/checkpoints/step-{step_to_load if step_to_load is not None else ''}*.pt")
if (len(valid_ckpts) == 0) or (step_to_load is not None and len(valid_ckpts) != 1):
raise ValueError(f"Couldn't find a valid checkpoint to load from HF Hub Path `{hf_path}/checkpoints/")

# Call to `glob` will sort steps in ascending order (if `step_to_load` is None); just grab last element
target_ckpt = Path(valid_ckpts[-1]).name

overwatch.info(f"Downloading Model `{model_id_or_path}` Config & Checkpoint `{target_ckpt}`")
with overwatch.local_zero_first():
relpath = Path(model_type) / model_id_or_path
config_json = hf_hub_download(
repo_id=VLA_HF_HUB_REPO, filename=f"{(relpath / 'config.json')!s}", cache_dir=cache_dir
)
dataset_statistics_json = hf_hub_download(
repo_id=VLA_HF_HUB_REPO, filename=f"{(relpath / 'dataset_statistics.json')!s}", cache_dir=cache_dir
)
checkpoint_pt = hf_hub_download(
repo_id=VLA_HF_HUB_REPO, filename=f"{(relpath / target_ckpt)!s}", cache_dir=cache_dir
)

# Load VLA Config (and corresponding base VLM `ModelConfig`) from `config.json`
with open(config_json, "r") as f:
vla_cfg = json.load(f)["vla"]
model_cfg = ModelConfig.get_choice_class(vla_cfg["base_vlm"])()

# Load dataset statistics for action de-normalization
with open(dataset_stats_json, "r") as f:
# Load Dataset Statistics for Action Denormalization
with open(dataset_statistics_json, "r") as f:
norm_stats = json.load(f)

# = Load Individual Components necessary for Instantiating a VLM =
# = Load Individual Components necessary for Instantiating a VLA (via base VLM components) =
# =>> Print Minimal Config
overwatch.info(
f"Found Config =>> Loading & Freezing [bold blue]{model_cfg.model_id}[/] with:\n"
Expand All @@ -174,11 +205,11 @@ def load_vla(
inference_mode=not load_for_training,
)

# Create action tokenizer
# Create Action Tokenizer
action_tokenizer = ActionTokenizer(llm_backbone.get_tokenizer())

# Load VLM using `from_pretrained` (clobbers HF syntax... eventually should reconcile)
overwatch.info(f"Loading VLM [bold blue]{model_cfg.model_id}[/] from Checkpoint")
overwatch.info(f"Loading VLA [bold blue]{model_cfg.model_id}[/] from Checkpoint")
vla = OpenVLA.from_pretrained(
checkpoint_pt,
model_cfg.model_id,
Expand Down

0 comments on commit 4b8b31e

Please sign in to comment.