
[Feature] Offline datasets: D4RL #928

Merged
merged 86 commits into from
Mar 16, 2023

Conversation

vmoens
Contributor

@vmoens vmoens commented Feb 20, 2023

Description

Integrates offline RL datasets into torchrl.

>>> from torchrl.data.datasets import D4RLExperienceReplay
>>> data = D4RLExperienceReplay('kitchen-complete-v0', split_trajs=True)
>>> print(data._storage._storage)
TensorDict(
    fields={
        _batch_size: MemmapTensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.int64, is_shared=False),
        action: MemmapTensor(shape=torch.Size([20, 207, 9]), device=cpu, dtype=torch.float32, is_shared=False),
        done: MemmapTensor(shape=torch.Size([20, 207]), device=cpu, dtype=torch.bool, is_shared=False),
        index: MemmapTensor(shape=torch.Size([20]), device=cpu, dtype=torch.int32, is_shared=False),
        infos: MemmapTensor(shape=torch.Size([20, 207]), device=cpu, dtype=torch.int64, is_shared=False),
        mask: MemmapTensor(shape=torch.Size([20, 207]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: MemmapTensor(shape=torch.Size([20, 207]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: MemmapTensor(shape=torch.Size([20, 207, 60]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: MemmapTensor(shape=torch.Size([20, 207]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([20, 207]),
            device=cpu,
            is_shared=False),
        observation: MemmapTensor(shape=torch.Size([20, 207, 60]), device=cpu, dtype=torch.float32, is_shared=False),
        timeouts: MemmapTensor(shape=torch.Size([20, 207]), device=cpu, dtype=torch.bool, is_shared=False),
        traj_ids: MemmapTensor(shape=torch.Size([20, 207]), device=cpu, dtype=torch.int64, is_shared=False)},
    batch_size=torch.Size([20]),
    device=cpu,
    is_shared=False)
>>> print(data.sample(10))  # will sample 10 trajectories since split_trajs is set to True
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([10, 207, 9]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([10, 207]), device=cpu, dtype=torch.bool, is_shared=False),
        index: Tensor(shape=torch.Size([10, 207]), device=cpu, dtype=torch.int32, is_shared=False),
        infos: Tensor(shape=torch.Size([10, 207]), device=cpu, dtype=torch.int64, is_shared=False),
        mask: Tensor(shape=torch.Size([10, 207]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([10, 207]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([10, 207, 60]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([10, 207]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([10, 207]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([10, 207, 60]), device=cpu, dtype=torch.float32, is_shared=False),
        timeouts: Tensor(shape=torch.Size([10, 207]), device=cpu, dtype=torch.bool, is_shared=False),
        traj_ids: Tensor(shape=torch.Size([10, 207]), device=cpu, dtype=torch.int64, is_shared=False)},
    batch_size=torch.Size([10, 207]),
    device=cpu,
    is_shared=False)
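In the layout above, `split_trajs=True` reshapes the flat dataset into `[n_trajectories, max_traj_length]` (here `[20, 207]`), padding shorter trajectories and marking valid steps in `mask`. As a minimal stand-alone sketch of that padding scheme (plain Python, not the torchrl implementation; `split_and_pad` is a hypothetical helper for illustration):

```python
def split_and_pad(trajectories, pad_value=0.0):
    """Pad variable-length trajectories to a common length and
    return (padded, mask), mimicking the split_trajs layout."""
    max_len = max(len(t) for t in trajectories)
    padded, mask = [], []
    for traj in trajectories:
        n_pad = max_len - len(traj)
        # Valid steps keep their values; padded steps get pad_value and mask=False.
        padded.append(list(traj) + [pad_value] * n_pad)
        mask.append([True] * len(traj) + [False] * n_pad)
    return padded, mask

# Two trajectories of different lengths, e.g. per-step rewards.
rewards = [[1.0, 0.5], [0.2, 0.3, 0.4]]
padded, mask = split_and_pad(rewards)
# padded -> [[1.0, 0.5, 0.0], [0.2, 0.3, 0.4]]
# mask   -> [[True, True, False], [True, True, True]]
```

Downstream losses would then use `mask` to exclude the padded steps, which is why the sampled batches above carry a `mask` entry of the same `[*, 207]` shape.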

These datasets can be used with transforms:

>>> from torchrl.data.datasets.d4rl import D4RLExperienceReplay
>>> from torchrl.envs import ObservationNorm
>>> data = D4RLExperienceReplay("maze2d-umaze-v1")
>>> # we can append transforms to the dataset
>>> data.append_transform(ObservationNorm(loc=-1, scale=1.0))
>>> data.sample(128)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([128, 2]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.bool, is_shared=False),
        index: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.int32, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([128, 4]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([128]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([128, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([128]),
    device=cpu,
    is_shared=False)
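`ObservationNorm` applies an affine map to the observation entries as they are sampled. A minimal sketch of that idea in plain Python (the exact semantics and `standard_normal` flag shown here are an assumption for illustration, not the library's documented contract):

```python
def observation_norm(obs, loc, scale, standard_normal=False):
    """Affine normalization in the spirit of ObservationNorm.

    standard_normal=False: obs * scale + loc (shift-and-scale)
    standard_normal=True:  (obs - loc) / scale (standardization)
    """
    if standard_normal:
        return [(o - loc) / scale for o in obs]
    return [o * scale + loc for o in obs]

print(observation_norm([0.0, 2.0], loc=-1.0, scale=1.0))  # [-1.0, 1.0]
```

Because the transform is appended to the replay buffer rather than to an environment, the normalization is applied lazily at sampling time, so the on-disk memory-mapped storage stays untouched.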

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Feb 20, 2023
@vmoens vmoens added the enhancement New feature or request label Feb 20, 2023
@BY571 BY571 mentioned this pull request Feb 22, 2023
9 tasks
# Conflicts:
#	torchrl/envs/common.py
#	torchrl/envs/libs/vmas.py
#	torchrl/envs/vec_env.py
@vmoens vmoens changed the title [Feature] Offline datasets [Feature] Offline datasets: D4RL Mar 10, 2023
@vmoens vmoens merged commit ce81995 into main Mar 16, 2023
@vmoens vmoens deleted the offline_datasets branch March 16, 2023 20:27