
Attempt fix ci: only cast reward from float64 to float32 #2


Merged: 2 commits into feat/mps-support, Sep 17, 2024

Conversation

deathcoder
Owner

@deathcoder deathcoder commented Sep 17, 2024

Trying to fix https://github.com/deathcoder/stable-baselines3/actions/runs/10903046945/job/30256261083#step:9:947

______________________________ test_get_original _______________________________

    def test_get_original():
        venv = _make_warmstart_cartpole()
        for _ in range(3):
            actions = [venv.action_space.sample()]
            obs, rewards, _, _ = venv.step(actions)
            obs = obs[0]
            orig_obs = venv.get_original_obs()[0]
            rewards = rewards[0]
            orig_rewards = venv.get_original_reward()[0]
    
            assert np.all(orig_rewards == 1)
            assert orig_obs.shape == obs.shape
            assert orig_rewards.dtype == rewards.dtype
            assert not np.array_equal(orig_obs, obs)
            assert not np.array_equal(orig_rewards, rewards)
            np.testing.assert_allclose(venv.normalize_obs(orig_obs), obs)
>           np.testing.assert_allclose(venv.normalize_reward(orig_rewards), rewards)

tests/test_vec_normalize.py:318: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

args = (<function assert_allclose.<locals>.compare at 0x7f2dc0786940>, array(0.0709089, dtype=float32), array(0.0709089, dtype=float32))
kwds = {'equal_nan': True, 'err_msg': '', 'header': 'Not equal to tolerance rtol=1e-07, atol=0', 'verbose': True}

    @wraps(func)
    def inner(*args, **kwds):
        with self._recreate_cm():
>           return func(*args, **kwds)
E           AssertionError: 
E           Not equal to tolerance rtol=1e-07, atol=0
E           
E           Mismatched elements: 1 / 1 (100%)
E           Max absolute difference: 7.450581e-09
E           Max relative difference: 1.0507258e-07
E            x: array(0.070909, dtype=float32)
E            y: array(0.070909, dtype=float32)
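The mismatch is a single float32 ulp: at this magnitude one ulp is 2**-27 (about 7.45e-9), a relative difference of roughly 1.05e-7, which sits just above the default `rtol=1e-07` of `assert_allclose`. A standalone reproduction (illustrative, not taken from the PR):

```python
import numpy as np

# At ~0.07 one float32 ulp is 2**-27 (~7.45e-9), a relative difference
# of ~1.05e-7, just above assert_allclose's default rtol=1e-07.
x = np.float32(0.0709089)
y = np.nextafter(x, np.float32(1.0))  # nearest float32 above x

print(y - x)  # ~7.45e-09, matching "Max absolute difference" in the log
try:
    np.testing.assert_allclose(y, x)  # default rtol=1e-07, atol=0
except AssertionError:
    print("not equal to tolerance")   # same failure mode as the CI run
```

Any pipeline that normalizes in float64 and compares the result against a float32-cast copy can end up one ulp apart in exactly this way.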

A recent change made `normalize_reward` always cast rewards to float32; with this change, the cast is only applied when the reward dtype is float64.
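The conditional cast can be sketched as follows (a hypothetical standalone version of the behavior described above, not the exact PR code):

```python
import numpy as np

def maybe_cast_reward(reward: np.ndarray) -> np.ndarray:
    """Cast np.float64 rewards to np.float32; leave other dtypes unchanged."""
    if reward.dtype == np.float64:
        return reward.astype(np.float32)
    return reward

print(maybe_cast_reward(np.array([1.0], dtype=np.float64)).dtype)  # float32
print(maybe_cast_reward(np.array([1.0], dtype=np.float32)).dtype)  # float32 (unchanged)
print(maybe_cast_reward(np.array([1], dtype=np.int64)).dtype)      # int64 (unchanged)
```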

Summary by CodeRabbit

  • New Features

    • Introduced a manual trigger for CI workflows.
    • Added support for MacOS Metal "mps" device to enhance performance.
    • Implemented version tracking for the cloudpickle library.
  • Bug Fixes

    • Improved type handling in reward normalization for better compatibility.
  • Documentation

    • Enhanced clarity in function parameter documentation and system information reporting.
  • Tests

    • Updated tests to skip unsupported conditions for MPS backend and revised assertions for system information.


coderabbitai bot commented Sep 17, 2024

Walkthrough

The pull request introduces several enhancements across multiple files, including the addition of a workflow_dispatch trigger to the CI configuration, improvements for MacOS Metal support, and updates to type handling in various functions. Notable changes include the integration of a new function to determine available computational accelerators, modifications to reward normalization, and adjustments to testing conditions based on backend compatibility. These changes collectively aim to enhance functionality, compatibility, and clarity within the codebase.

Changes

Files Change Summary
.github/workflows/ci.yml Added a workflow_dispatch trigger and reformatted branch specifications.
docs/misc/changelog.rst Introduced use_mps_device for MacOS Metal support and added save_cloudpickle_version for serialization.
stable_baselines3/common/buffers.py Added a check for the mps backend in the to_torch function to support Apple Silicon devices.
stable_baselines3/common/envs/bit_flipping_env.py Ensured state elements are converted to integers and updated image representation casting to np.uint8.
stable_baselines3/common/utils.py Improved documentation in set_random_seed, added get_available_accelerator function, and modified get_system_info to report available accelerators.
stable_baselines3/common/vec_env/vec_normalize.py Added _maybe_cast_reward static method to convert rewards to np.float32.
tests/test_spaces.py Introduced a conditional check to skip tests when using the MPS backend due to lack of float64 support.
tests/test_utils.py Modified assertions in test_get_system_info to check for "Accelerator" instead of "GPU Enabled" and added a check for "Cloudpickle".



@deathcoder deathcoder changed the base branch from feat/mps-support to master September 17, 2024 15:59

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 16

Outside diff range and nitpick comments (3)
.github/workflows/ci.yml (1)

8-11: Confirm the branch names and consider updating to 'main' if applicable.

The workflow is triggered on pushes and pull requests to the master branch. Ensure that master is the correct default branch for your repository. If you have moved to main as the default branch, please update the branches specification:

 branches:
-  [master]
+  [main]
stable_baselines3/common/vec_env/vec_normalize.py (1)

131-132: Fix grammatical error in the docstring

There's a minor grammatical issue in the docstring. The phrase "keep the others dtype unchanged" should be corrected to "keep the other dtypes unchanged."

Apply this diff:

    Cast `np.float64` reward datatype to `np.float32`,
-   keep the others dtype unchanged.
+   keep the other dtypes unchanged.
stable_baselines3/common/buffers.py (1)

138-139: Ensure consistent data types when converting arrays to tensors

In the MPS conditional block, the array is converted to a tensor with dtype=th.float32. This forces the data type to float32, which might not be intended if the input array has a different data type. To preserve the original data type of the input array, consider using array.dtype instead of explicitly setting it to float32.

Apply this diff to retain the input array's data type:

-return th.tensor(array, dtype=th.float32, device=self.device)
+return th.tensor(array, dtype=array.dtype, device=self.device)

Alternatively, if casting to float32 is necessary for MPS compatibility, please ensure that this is intentional and documented.

Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

Commits

Files that changed from the base of the PR and between 512eea9 and 1ac4a60.

Files selected for processing (8)
  • .github/workflows/ci.yml (2 hunks)
  • docs/misc/changelog.rst (2 hunks)
  • stable_baselines3/common/buffers.py (1 hunks)
  • stable_baselines3/common/envs/bit_flipping_env.py (1 hunks)
  • stable_baselines3/common/utils.py (5 hunks)
  • stable_baselines3/common/vec_env/vec_normalize.py (2 hunks)
  • tests/test_spaces.py (2 hunks)
  • tests/test_utils.py (1 hunks)
Additional context used
actionlint
.github/workflows/ci.yml

32-32: shellcheck reported issue in this script: SC2102:info:11:14: Ranges can only match single chars (mentioned due to duplicates)

(shellcheck)

Additional comments not posted (7)
.github/workflows/ci.yml (1)

43-45: Ensure compatibility of opencv-python-headless with your dependencies.

Installing opencv-python-headless separately may conflict with opencv-python if it's included in your project's dependencies. Verify that only one version is installed to prevent unexpected behavior.

You can run the following script to check for conflicting installations:

stable_baselines3/common/envs/bit_flipping_env.py (1)

85-85: Confirmed: Correct Data Type Handling for Image Representation

Casting state to np.uint8 before multiplying by 255 ensures that the binary values are properly converted to image pixel intensities (0 or 255). The concatenation with zeros to match the size allows for accurate reshaping into self.image_shape. This change enhances type safety and prevents unintended type coercion during image construction.

stable_baselines3/common/vec_env/vec_normalize.py (1)

271-272: Ensure consistent reward data types after normalization

The addition of _maybe_cast_reward(reward) in normalize_reward ensures that rewards are properly cast to np.float32 when needed, addressing the precision issues causing CI failures. Good job on implementing this solution.

stable_baselines3/common/utils.py (1)

551-551: Include accelerator details in system information

The addition of "Accelerator": get_available_accelerator(), enhances the system information by specifying the available accelerator. This provides more detailed and helpful information for debugging and user awareness.

tests/test_utils.py (2)

445-445: Ensure 'Accelerator' is consistently present in system info

The assertion assert "Accelerator" in info_str replaces the previous check for "GPU Enabled". Please verify that the 'Accelerator' keyword is included in info_str across all environments, including those without GPU support, to prevent potential test failures on systems without accelerators.


448-448: Confirm 'Cloudpickle' is always included in system info

The addition of assert "Cloudpickle" in info_str assumes that 'Cloudpickle' is present in all environments. Ensure that 'Cloudpickle' is a mandatory dependency and that it's consistently included in info_str to avoid assertion errors in environments where 'Cloudpickle' might not be installed.

stable_baselines3/common/buffers.py (1)

138-139: Verify that tensors are placed on the correct device

The function uses self.device when creating tensors. Ensure that self.device is correctly set to 'mps' when the MPS device is available and intended for use. If self.device is not 'mps', tensors will not be placed on the MPS device despite the conditional check.

Run the following script to check the device configuration:

This script searches for instances where self.device is assigned and ensures that it's set appropriately based on device availability.

Verification successful

Verification confirms correct device handling, including MPS support

The codebase correctly handles device selection, including MPS support. The get_device function in utils.py properly manages device selection, with MPS being prioritized when available. The self.device attribute is consistently set using get_device(device) in both base_class.py and buffers.py, ensuring that the correct device is used throughout the codebase.

The additional check in buffers.py (lines 138-139) is a safeguard specific to tensor creation and doesn't conflict with the overall device selection logic. This check ensures that tensors are created correctly on the MPS device when it's available, regardless of the self.device setting.

In conclusion, the concern about self.device potentially not being set to 'mps' when the MPS device is available is unfounded. The codebase handles device selection and MPS support correctly.

Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify that 'self.device' is set to 'mps' when MPS is available.

# Find all assignments to 'self.device' in the codebase
rg --type python 'self\.device\s*=\s*.*' stable_baselines3/

# Check where 'get_device' is called and how 'self.device' is initialized
rg --type python 'self\.device\s*=\s*get_device\(' -A 2 stable_baselines3/

Length of output: 200


Script:

#!/bin/bash
# Description: Verify 'self.device' assignments and MPS-related configurations

# Find all assignments to 'self.device' in the codebase
echo "Searching for 'self.device' assignments:"
grep -R "self\.device\s*=" stable_baselines3/

echo -e "\nSearching for 'get_device' function calls:"
grep -R "self\.device\s*=\s*get_device(" stable_baselines3/

echo -e "\nSearching for MPS-related device configurations:"
grep -R "mps" stable_baselines3/

Length of output: 1892


Script:

#!/bin/bash
# Description: Verify the implementation of 'get_device' function and MPS handling

echo "Searching for 'get_device' function implementation:"
grep -n -A 20 "def get_device" stable_baselines3/common/utils.py

echo -e "\nSearching for MPS-related logic in utils.py:"
grep -n -C 5 "mps" stable_baselines3/common/utils.py

echo -e "\nChecking 'device' parameter usage in base_class.py:"
grep -n -C 3 "self\.device = get_device" stable_baselines3/common/base_class.py

echo -e "\nChecking 'device' parameter usage in buffers.py:"
grep -n -C 3 "self\.device = get_device" stable_baselines3/common/buffers.py

Length of output: 3421

Comment on lines +32 to +60
run: |
python -m pip install --upgrade pip
# cpu version of pytorch
pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cpu

# Install Atari Roms
pip install autorom
wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
AutoROM --accept-license --source-file Roms.tar.gz

pip install .[extra_no_roms,tests,docs]
# Use headless version
pip install opencv-python-headless
- name: Lint with ruff
run: |
make lint
- name: Build the doc
run: |
make doc
- name: Check codestyle
run: |
make check-codestyle
- name: Type check
run: |
make type
- name: Test with pytest
run: |
make pytest

Review shell commands for compatibility issues flagged by ShellCheck.

ShellCheck reported: SC2102: Ranges can only match single chars. This suggests that a character range in your shell script is incorrect.

Please inspect the shell commands for any misuse of character ranges, especially in loops or pattern matches. Ensure that ranges in expressions like [a-z] are correctly used.

Tools
actionlint

32-32: shellcheck reported issue in this script: SC2102:info:11:14: Ranges can only match single chars (mentioned due to duplicates)

(shellcheck)

# Install Atari Roms
pip install autorom
wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz

Correct the redirection operator in the base64 decode command.

The use of &> is not POSIX-compliant and may cause issues in some shells. To redirect the output properly, use >:

- base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
+ base64 Roms.tar.gz.b64 --decode > Roms.tar.gz

Comment on lines +38 to +41
pip install autorom
wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
AutoROM --accept-license --source-file Roms.tar.gz

Simplify the installation of Atari ROMs using AutoROM.

Manually downloading and decoding the ROMs is unnecessary. AutoROM can handle the download and installation process:

 pip install autorom
-wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
-base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
-AutoROM --accept-license --source-file Roms.tar.gz
+AutoROM --accept-license

This approach is cleaner and reduces potential points of failure.


run: |
python -m pip install --upgrade pip
# cpu version of pytorch
pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cpu

Use --extra-index-url instead of --index-url when installing PyTorch.

Using --index-url replaces the default PyPI index, potentially causing dependency resolution issues. To add the PyTorch CPU repository without overriding PyPI, use --extra-index-url:

- pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cpu
+ pip install torch==2.1.0 --extra-index-url https://download.pytorch.org/whl/cpu

This ensures all dependencies are correctly installed from both indexes.


Comment on lines 155 to 156
if hasattr(th, "backends") and th.backends.mps.is_built():
pytest.skip("MPS framework doesn't support float64")

Use is_available() instead of is_built() to check MPS backend support

The method th.backends.mps.is_built() checks if PyTorch was compiled with MPS support but does not verify if the MPS device is actually available on the system. To ensure the test skips correctly on machines where MPS is available and being used, it's better to use th.backends.mps.is_available(), which confirms both the build and the availability of the MPS device.

Apply this diff to fix the issue:

-if hasattr(th, "backends") and th.backends.mps.is_built():
+if hasattr(th, "backends") and th.backends.mps.is_available():
     pytest.skip("MPS framework doesn't support float64")
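The distinction matters because a PyTorch build can include MPS support (`is_built()` returns `True`) without a usable Metal device at runtime. Stripped of the PyTorch specifics, the selection rule the review is asking for can be sketched as plain decision logic (hypothetical helper, flag names chosen for illustration):

```python
def select_device(cuda_available: bool, mps_built: bool, mps_available: bool) -> str:
    """Pick a device string from runtime availability, not build flags.

    `mps_built` only says the binary was compiled with MPS support;
    `mps_available` additionally confirms a usable Metal device exists,
    so availability is what must gate the choice.
    """
    if cuda_available:
        return "cuda"
    if mps_built and mps_available:
        return "mps"
    return "cpu"

# Built with MPS but no usable device at runtime: fall back to CPU.
print(select_device(cuda_available=False, mps_built=True, mps_available=False))  # cpu
print(select_device(cuda_available=False, mps_built=True, mps_available=True))   # mps
```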

Comment on lines 526 to 529
if hasattr(th, "backends") and th.backends.mps.is_built():
# MacOS Metal GPU
th.set_default_dtype(th.float32)
return "mps"

Use th.backends.mps.is_available() and avoid setting global default dtype

Similar to the previous comment, use th.backends.mps.is_available() to accurately check for MPS availability at runtime. Additionally, setting the global default dtype with th.set_default_dtype(th.float32) can have unintended side effects throughout the codebase. It's advisable to specify the dtype explicitly when creating tensors instead of modifying the global setting.

Apply this diff to update the check and remove the global dtype change:

-if hasattr(th, "backends") and th.backends.mps.is_built():
-    # MacOS Metal GPU
-    th.set_default_dtype(th.float32)
+if hasattr(th.backends, "mps") and th.backends.mps.is_available():
     # MacOS Metal GPU
     return "mps"

If changing the default dtype is necessary for MPS devices, consider isolating this change or documenting its implications to prevent unexpected behavior in other parts of the code.


Comment on lines 154 to 156
# CUDA not available
if device.type == th.device("cuda").type and not th.cuda.is_available():
return th.device("cpu")

Add handling for unavailable MPS devices

Currently, if the device is set to cuda but CUDA is not available, the function defaults to CPU. Similar handling should be added for MPS devices to ensure consistent behavior and prevent potential runtime errors.

Apply this diff to add the MPS availability check:

+if device.type == th.device("mps").type and not (hasattr(th.backends, "mps") and th.backends.mps.is_available()):
+    return th.device("cpu")

Comment on lines 138 to 139
if hasattr(th, "backends") and th.backends.mps.is_built():
return th.tensor(array, dtype=th.float32, device=self.device)

Update the conditional check to use is_available() for MPS

Currently, the code checks th.backends.mps.is_built() to determine if the MPS backend is available. However, is_built() only indicates if PyTorch was built with MPS support, not if the MPS device is available at runtime. To accurately determine MPS availability on the current system, use th.backends.mps.is_available() instead.

Apply this diff to correct the check:

-if hasattr(th, "backends") and th.backends.mps.is_built():
+if hasattr(th, "backends") and th.backends.mps.is_available():

@@ -558,6 +558,7 @@ New Features:
- Added checkpoints for replay buffer and ``VecNormalize`` statistics (@anand-bala)
- Added option for ``Monitor`` to append to existing file instead of overriding (@sidney-tio)
- The env checker now raises an error when using dict observation spaces and observation keys don't match observation space keys
- Use MacOS Metal "mps" device when available

Correct capitalization of 'macOS'

The term "macOS" should be written with a lowercase 'm' to match Apple's official branding.

Apply this diff to fix the capitalization:

- - Use MacOS Metal "mps" device when available
+ - Use macOS Metal "mps" device when available

@@ -615,6 +616,7 @@ Breaking Changes:

New Features:
^^^^^^^^^^^^^
- Save cloudpickle version

Enclose 'cloudpickle' in code formatting for consistency

To maintain consistency in the documentation, module names like 'cloudpickle' should be enclosed with double backticks.

Apply this diff to enclose 'cloudpickle' in code formatting:

- - Save cloudpickle version
+ - Save ``cloudpickle`` version

@deathcoder deathcoder changed the base branch from master to feat/mps-support September 17, 2024 16:09
@deathcoder deathcoder merged commit 9970f51 into feat/mps-support Sep 17, 2024
1 check passed
@coderabbitai coderabbitai bot mentioned this pull request Oct 29, 2024
61 tasks