Add manage_dataset script #543

Open · wants to merge 2 commits into main
10 changes: 8 additions & 2 deletions lerobot/common/datasets/lerobot_dataset.py
@@ -873,7 +873,13 @@ def encode_episode_videos(self, episode_index: int) -> dict:
         return video_paths
 
-    def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = False) -> None:
+    def consolidate(
+        self,
+        run_compute_stats: bool = True,
+        keep_image_files: bool = False,
+        batch_size: int = 8,
+        num_workers: int = 8,
+    ) -> None:
         self.hf_dataset = self.load_hf_dataset()
         self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
         check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
@@ -896,7 +902,7 @@ def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = False) -> None:
         if run_compute_stats:
             self.stop_image_writer()
             # TODO(aliberts): refactor stats in save_episodes
-            self.meta.stats = compute_stats(self)
+            self.meta.stats = compute_stats(self, batch_size=batch_size, num_workers=num_workers)
             serialized_stats = serialize_dict(self.meta.stats)
             write_json(serialized_stats, self.root / STATS_PATH)
             self.consolidated = True
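
For reference, a minimal sketch of how the new parameters could be exercised programmatically (the repo id is illustrative; `local_files_only` is passed as a 0/1 flag here to mirror the script added below):

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# Illustrative repo id; replace with your own dataset.
dataset = LeRobotDataset(repo_id="user/koch_test", local_files_only=1)

# batch_size and num_workers are forwarded to compute_stats; per the CLI help
# below they configure its DataLoader, so larger values can speed up the
# statistics pass at the cost of memory and CPU.
dataset.consolidate(run_compute_stats=True, batch_size=16, num_workers=4)
```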
110 changes: 110 additions & 0 deletions lerobot/scripts/manage_dataset.py
@@ -0,0 +1,110 @@
"""
Utilities to manage a dataset.

Examples of usage:

- Consolidate a dataset, by encoding images into videos and computing statistics:
```bash
python lerobot/scripts/manage_dataset.py consolidate \
--repo-id $USER/koch_test
```

- Consolidate a dataset which is not uploaded on the hub yet:
```bash
python lerobot/scripts/manage_dataset.py consolidate \
--repo-id $USER/koch_test \
--local-files-only 1
```

- Upload a dataset on the hub:
```bash
python lerobot/scripts/manage_dataset.py push_to_hub \
--repo-id $USER/koch_test
```
"""

import argparse
from pathlib import Path

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

if __name__ == "__main__":
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers(dest="mode", required=True)

# Set common options for all the subparsers
base_parser = argparse.ArgumentParser(add_help=False)
base_parser.add_argument(
"--root",
type=Path,
default=None,
help="Root directory where the dataset will be stored (e.g. 'dataset/path').",
)
base_parser.add_argument(
"--repo-id",
type=str,
default="lerobot/test",
help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
)
base_parser.add_argument(
"--local-files-only",
type=int,
default=0,
help="Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.",
)

    parser_conso = subparsers.add_parser("consolidate", parents=[base_parser])
    parser_conso.add_argument(
        "--batch-size",
        type=int,
        default=8,
        help="Batch size used by the DataLoader when computing the dataset statistics.",
    )
    parser_conso.add_argument(
        "--num-workers",
        type=int,
        default=8,
        help="Number of DataLoader worker processes used when computing the dataset statistics.",
    )

    parser_push = subparsers.add_parser("push_to_hub", parents=[base_parser])
    parser_push.add_argument(
        "--tags",
        type=str,
        nargs="*",
        default=None,
        help="Optional additional tags to categorize the dataset on the Hugging Face Hub. Use space-separated values (e.g. 'so100 indoor'). The tag 'LeRobot' will always be added.",
    )
    parser_push.add_argument(
        "--license",
        type=str,
        default="apache-2.0",
        help="Repo license. Must be one of https://huggingface.co/docs/hub/repositories-licenses. Defaults to apache-2.0.",
    )
    parser_push.add_argument(
        "--private",
        type=int,
        default=0,
        help="Create a private dataset repository on the Hugging Face Hub. Pushes publicly by default.",
    )

    args = parser.parse_args()
    kwargs = vars(args)

    mode = kwargs.pop("mode")
    repo_id = kwargs.pop("repo_id")
    root = kwargs.pop("root")
    local_files_only = kwargs.pop("local_files_only")

    dataset = LeRobotDataset(
        repo_id=repo_id,
        root=root,
        local_files_only=local_files_only,
    )

    if mode == "consolidate":
        dataset.consolidate(**kwargs)

    elif mode == "push_to_hub":
        private = kwargs.pop("private") == 1
        dataset.push_to_hub(private=private, **kwargs)
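
For completeness, the subcommand-specific flags defined above combine with the common options; a usage sketch in the style of the docstring examples (values are illustrative):

```bash
# Tune the statistics computation during consolidation:
python lerobot/scripts/manage_dataset.py consolidate \
    --repo-id $USER/koch_test \
    --batch-size 16 \
    --num-workers 4

# Push with extra tags and a private repository:
python lerobot/scripts/manage_dataset.py push_to_hub \
    --repo-id $USER/koch_test \
    --tags so100 indoor \
    --private 1
```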