Skip to content

Commit

Permalink
Update on "[Feature] ConditionalPolicySwitch transform"
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
mikaylagawarecki committed Feb 5, 2025
2 parents 4c6f563 + 116c0e1 commit 11f7d8a
Show file tree
Hide file tree
Showing 23 changed files with 755 additions and 168 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ jobs:
strategy:
matrix:
python_version: ["3.10"]
cuda_arch_version: ["12.1"]
cuda_arch_version: ["12.4"]
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
with:
repository: pytorch/rl
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test-linux-habitat.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
strategy:
matrix:
python_version: ["3.9"]
cuda_arch_version: ["12.1"]
cuda_arch_version: ["12.4"]
fail-fast: false
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
with:
Expand Down
54 changes: 27 additions & 27 deletions .github/workflows/test-linux-libs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:
strategy:
matrix:
python_version: ["3.9"]
cuda_arch_version: ["12.1"]
cuda_arch_version: ["12.4"]
if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Data') }}
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
with:
Expand Down Expand Up @@ -59,7 +59,7 @@ jobs:
strategy:
matrix:
python_version: ["3.11"]
cuda_arch_version: ["12.1"]
cuda_arch_version: ["12.4"]
if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Environments') }}
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
with:
Expand Down Expand Up @@ -96,7 +96,7 @@ jobs:
strategy:
matrix:
python_version: ["3.9"]
cuda_arch_version: ["12.1"]
cuda_arch_version: ["12.4"]
if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Data') }}
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
with:
Expand Down Expand Up @@ -131,7 +131,7 @@ jobs:
strategy:
matrix:
python_version: ["3.9"]
cuda_arch_version: ["12.1"]
cuda_arch_version: ["12.4"]
if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Environments') }}
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
with:
Expand Down Expand Up @@ -166,7 +166,7 @@ jobs:
strategy:
matrix:
python_version: ["3.9"]
cuda_arch_version: ["12.1"]
cuda_arch_version: ["12.4"]
if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Data') }}
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
with:
Expand Down Expand Up @@ -200,7 +200,7 @@ jobs:
strategy:
matrix:
python_version: ["3.9"]
cuda_arch_version: ["12.1"]
cuda_arch_version: ["12.4"]
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
with:
repository: pytorch/rl
Expand Down Expand Up @@ -235,7 +235,7 @@ jobs:
strategy:
matrix:
python_version: ["3.9"]
cuda_arch_version: ["12.1"]
cuda_arch_version: ["12.4"]
if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Environments') }}
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
with:
Expand All @@ -256,7 +256,7 @@ jobs:
set -euo pipefail
export PYTHON_VERSION="3.9"
export CU_VERSION="12.1"
export CU_VERSION="12.4"
export TAR_OPTIONS="--no-same-owner"
export UPLOAD_CHANNEL="nightly"
export TF_CPP_MIN_LOG_LEVEL=0
Expand All @@ -277,7 +277,7 @@ jobs:
repository: pytorch/rl
runner: "linux.g5.4xlarge.nvidia.gpu"
gpu-arch-type: cuda
gpu-arch-version: "12.1"
gpu-arch-version: "12.4"
docker-image: "nvidia/cuda:12.4.1-runtime-ubuntu22.04"
timeout: 120
script: |
Expand All @@ -291,7 +291,7 @@ jobs:
set -euo pipefail
export PYTHON_VERSION="3.11"
export CU_VERSION="12.1"
export CU_VERSION="12.4"
export TAR_OPTIONS="--no-same-owner"
export UPLOAD_CHANNEL="nightly"
export TF_CPP_MIN_LOG_LEVEL=0
Expand All @@ -309,7 +309,7 @@ jobs:
strategy:
matrix:
python_version: ["3.9"]
cuda_arch_version: ["12.1"]
cuda_arch_version: ["12.4"]
if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Environments') }}
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
with:
Expand All @@ -330,7 +330,7 @@ jobs:
set -euo pipefail
export PYTHON_VERSION="3.9"
export CU_VERSION="12.1"
export CU_VERSION="12.4"
export TAR_OPTIONS="--no-same-owner"
export UPLOAD_CHANNEL="nightly"
export TF_CPP_MIN_LOG_LEVEL=0
Expand All @@ -347,7 +347,7 @@ jobs:
strategy:
matrix:
python_version: ["3.9"]
cuda_arch_version: ["12.1"]
cuda_arch_version: ["12.4"]
if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Environments') }}
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
with:
Expand All @@ -368,7 +368,7 @@ jobs:
set -euo pipefail
export PYTHON_VERSION="3.9"
export CU_VERSION="12.1"
export CU_VERSION="12.4"
export TAR_OPTIONS="--no-same-owner"
export UPLOAD_CHANNEL="nightly"
export TF_CPP_MIN_LOG_LEVEL=0
Expand All @@ -385,7 +385,7 @@ jobs:
strategy:
matrix:
python_version: ["3.10.12"]
cuda_arch_version: ["12.1"]
cuda_arch_version: ["12.4"]
if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Environments') }}
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
with:
Expand All @@ -406,7 +406,7 @@ jobs:
set -euo pipefail
export PYTHON_VERSION="3.10.12"
export CU_VERSION="12.1"
export CU_VERSION="12.4"
export TAR_OPTIONS="--no-same-owner"
export UPLOAD_CHANNEL="nightly"
export TF_CPP_MIN_LOG_LEVEL=0
Expand All @@ -423,7 +423,7 @@ jobs:
strategy:
matrix:
python_version: ["3.9"]
cuda_arch_version: ["12.1"]
cuda_arch_version: ["12.4"]
if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Data') }}
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
with:
Expand Down Expand Up @@ -458,7 +458,7 @@ jobs:
strategy:
matrix:
python_version: ["3.9"]
cuda_arch_version: ["12.1"]
cuda_arch_version: ["12.4"]
if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Data') }}
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
with:
Expand Down Expand Up @@ -510,7 +510,7 @@ jobs:
set -euo pipefail
export PYTHON_VERSION="3.9"
export CU_VERSION="12.1"
export CU_VERSION="12.4"
export TAR_OPTIONS="--no-same-owner"
export UPLOAD_CHANNEL="nightly"
export TF_CPP_MIN_LOG_LEVEL=0
Expand All @@ -528,7 +528,7 @@ jobs:
strategy:
matrix:
python_version: ["3.9"]
cuda_arch_version: ["12.1"]
cuda_arch_version: ["12.4"]
if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Environments') }}
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
with:
Expand Down Expand Up @@ -562,7 +562,7 @@ jobs:
strategy:
matrix:
python_version: ["3.9"]
cuda_arch_version: ["12.1"]
cuda_arch_version: ["12.4"]
if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Data') }}
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
with:
Expand Down Expand Up @@ -597,7 +597,7 @@ jobs:
strategy:
matrix:
python_version: ["3.9"]
cuda_arch_version: ["12.1"]
cuda_arch_version: ["12.4"]
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
with:
repository: pytorch/rl
Expand Down Expand Up @@ -633,7 +633,7 @@ jobs:
strategy:
matrix:
python_version: ["3.9"]
cuda_arch_version: ["12.1"]
cuda_arch_version: ["12.4"]
if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Environments') }}
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
with:
Expand All @@ -654,7 +654,7 @@ jobs:
set -euo pipefail
export PYTHON_VERSION="3.9"
export CU_VERSION="12.1"
export CU_VERSION="12.4"
export TAR_OPTIONS="--no-same-owner"
export UPLOAD_CHANNEL="nightly"
export TF_CPP_MIN_LOG_LEVEL=0
Expand All @@ -672,7 +672,7 @@ jobs:
strategy:
matrix:
python_version: ["3.9"]
cuda_arch_version: ["12.1"]
cuda_arch_version: ["12.4"]
if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Data') }}
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
with:
Expand Down Expand Up @@ -707,7 +707,7 @@ jobs:
strategy:
matrix:
python_version: ["3.9"]
cuda_arch_version: ["12.1"]
cuda_arch_version: ["12.4"]
if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Environments') }}
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
with:
Expand All @@ -728,7 +728,7 @@ jobs:
set -euo pipefail
export PYTHON_VERSION="3.9"
export CU_VERSION="12.1"
export CU_VERSION="12.4"
export TAR_OPTIONS="--no-same-owner"
export UPLOAD_CHANNEL="nightly"
export TF_CPP_MIN_LOG_LEVEL=0
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test-linux-rlhf.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
strategy:
matrix:
python_version: ["3.9"]
cuda_arch_version: ["12.1"]
cuda_arch_version: ["12.4"]
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
with:
repository: pytorch/rl
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test-linux-sota.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ jobs:
strategy:
matrix:
python_version: ["3.9"]
cuda_arch_version: ["12.1"]
cuda_arch_version: ["12.4"]
fail-fast: false
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
with:
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/test-linux.yml
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ jobs:
strategy:
matrix:
python_version: ["3.11"]
cuda_arch_version: ["12.1"]
cuda_arch_version: ["12.4"]
fail-fast: false
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
with:
Expand Down Expand Up @@ -158,7 +158,7 @@ jobs:
strategy:
matrix:
python_version: ["3.11"]
cuda_arch_version: ["12.1"]
cuda_arch_version: ["12.4"]
fail-fast: false
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
with:
Expand Down
1 change: 1 addition & 0 deletions docs/source/reference/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ using the following components:
LazyMemmapStorage
LazyTensorStorage
ListStorage
LazyStackStorage
ListStorageCheckpointer
NestedStorageCheckpointer
PrioritizedSampler
Expand Down
3 changes: 2 additions & 1 deletion docs/source/reference/envs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1118,7 +1118,7 @@ in the relevant functions:
>>> print(env2._env.env.env)
<gym.envs.classic_control.pendulum.PendulumEnv at 0x1629916a0>

We can see that the two libraries modify the value returned by :func:`~.gym.gym_backend()`
We can see that the two libraries modify the value returned by :func:`~torchrl.envs.gym.gym_backend()`
which can be further used to indicate which library needs to be used for
the current computation. :class:`~.gym.set_gym_backend` is also a decorator:
we can use it to tell a specific function which gym backend needs to be used
Expand Down Expand Up @@ -1189,3 +1189,4 @@ the following function will return ``1`` when queried:
VmasWrapper
gym_backend
set_gym_backend
register_gym_spec_conversion
44 changes: 42 additions & 2 deletions test/mocking_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1068,6 +1068,11 @@ def _step(
return tensordict


def get_random_string(min_size, max_size):
    """Return a random string of lowercase ASCII letters.

    The length is drawn with ``random.randint(min_size, max_size)`` (both
    bounds inclusive); each character is then drawn independently from
    ``string.ascii_lowercase`` using the global ``random`` generator.
    """
    length = random.randint(min_size, max_size)
    alphabet = string.ascii_lowercase
    letters = [random.choice(alphabet) for _ in range(length)]
    return "".join(letters)


class CountingEnvWithString(CountingEnv):
def __init__(self, *args, **kwargs):
self.max_size = kwargs.pop("max_size", 30)
Expand All @@ -1083,8 +1088,7 @@ def __init__(self, *args, **kwargs):
)

def get_random_string(self):
size = random.randint(self.min_size, self.max_size)
return "".join(random.choice(string.ascii_lowercase) for _ in range(size))
return get_random_string(self.min_size, self.max_size)

def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
res = super()._reset(tensordict, **kwargs)
Expand Down Expand Up @@ -2202,3 +2206,39 @@ def _step(

def _set_seed(self, seed):
...


class Str2StrEnv(EnvBase):
    """Mock environment whose observation and action entries are strings.

    Observations are random lowercase strings whose length lies between
    ``min_size`` and ``max_size``; ``done`` and ``reward`` are sampled from a
    Bernoulli distribution with probability 0.01.

    Args:
        min_size: minimum length of the generated observation strings.
        max_size: maximum length of the generated observation strings.
        **kwargs: forwarded to ``EnvBase.__init__``.
    """

    def __init__(self, min_size=4, max_size=10, **kwargs):
        # The specs and size bounds are assigned before ``super().__init__``
        # is called, mirroring the original statement order.
        self.observation_spec = Composite(
            observation=NonTensor(example_data="an observation!", shape=())
        )
        self.full_action_spec = Composite(
            action=NonTensor(example_data="an action!", shape=())
        )
        self.reward_spec = Unbounded(shape=(1,), dtype=torch.float)
        self.min_size = min_size
        self.max_size = max_size
        super().__init__(**kwargs)

    def _step(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
        """Check that the action is a string and emit a fresh random transition."""
        assert isinstance(tensordict["action"], str)
        out = tensordict.empty()
        out.set("observation", self.get_random_string())
        out.set("done", torch.zeros(1, dtype=torch.bool).bernoulli_(0.01))
        out.set("reward", torch.zeros(1, dtype=torch.float).bernoulli_(0.01))
        return out

    def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
        """Return a random string observation and a sampled ``done`` flag.

        A new ``TensorDict`` is created when no input tensordict is given.
        """
        out = tensordict.empty() if tensordict is not None else TensorDict()
        out.set("observation", self.get_random_string())
        out.set("done", torch.zeros(1, dtype=torch.bool).bernoulli_(0.01))
        return out

    def get_random_string(self):
        """Return a random string sized between ``min_size`` and ``max_size``."""
        return get_random_string(self.min_size, self.max_size)

    def _set_seed(self, seed: Optional[int]):
        """Seed the Python and torch random generators with ``seed``.

        Returns:
            The seed that was passed in.
        """
        random.seed(seed)
        # Seed torch with the requested value rather than a constant, so that
        # different seeds produce different ``done``/``reward`` draws.
        # ``torch.manual_seed`` requires an int, so a ``None`` seed leaves the
        # torch generator untouched.
        if seed is not None:
            torch.manual_seed(seed)
        return seed
Loading

0 comments on commit 11f7d8a

Please sign in to comment.