Attempt fix ci: only cast reward from float64 to float32 #2
Walkthrough
The pull request introduces several enhancements across multiple files, including the addition of a
Changes
Actionable comments posted: 16
Outside diff range and nitpick comments (3)
.github/workflows/ci.yml (1)
8-11: Confirm the branch names and consider updating to 'main' if applicable.
The workflow is triggered on pushes and pull requests to the `master` branch. Ensure that `master` is the correct default branch for your repository. If you have moved to `main` as the default branch, please update the `branches` specification:
branches:
- [master]
+ [main]
stable_baselines3/common/vec_env/vec_normalize.py (1)
131-132: Fix grammatical error in the docstring
There's a minor grammatical issue in the docstring. The phrase "keep the others dtype unchanged" should be corrected to "keep the other dtypes unchanged."
Apply this diff:
 Cast `np.float64` reward datatype to `np.float32`,
- keep the others dtype unchanged.
+ keep the other dtypes unchanged.
stable_baselines3/common/buffers.py (1)
138-139: Ensure consistent data types when converting arrays to tensors
In the MPS conditional block, the array is converted to a tensor with `dtype=th.float32`. This forces the data type to `float32`, which might not be intended if the input array has a different data type. To preserve the original data type of the input array, consider using `array.dtype` instead of explicitly setting it to `float32`.
Apply this diff to retain the input array's data type:
-return th.tensor(array, dtype=th.float32, device=self.device)
+return th.tensor(array, dtype=array.dtype, device=self.device)
Alternatively, if casting to `float32` is necessary for MPS compatibility, please ensure that this is intentional and documented.
Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Files selected for processing (8)
- .github/workflows/ci.yml (2 hunks)
- docs/misc/changelog.rst (2 hunks)
- stable_baselines3/common/buffers.py (1 hunks)
- stable_baselines3/common/envs/bit_flipping_env.py (1 hunks)
- stable_baselines3/common/utils.py (5 hunks)
- stable_baselines3/common/vec_env/vec_normalize.py (2 hunks)
- tests/test_spaces.py (2 hunks)
- tests/test_utils.py (1 hunks)
Additional context used
actionlint
.github/workflows/ci.yml
32-32: shellcheck reported issue in this script: SC2102:info:11:14: Ranges can only match single chars (mentioned due to duplicates)
(shellcheck)
Additional comments not posted (7)
.github/workflows/ci.yml (1)
43-45: Ensure compatibility of `opencv-python-headless` with your dependencies.
Installing `opencv-python-headless` separately may conflict with `opencv-python` if it's included in your project's dependencies. Verify that only one version is installed to prevent unexpected behavior.
You can run the following script to check for conflicting installations:
stable_baselines3/common/envs/bit_flipping_env.py (1)
85-85: Confirmed: Correct Data Type Handling for Image Representation
Casting `state` to `np.uint8` before multiplying by 255 ensures that the binary values are properly converted to image pixel intensities (0 or 255). The concatenation with zeros to match the `size` allows for accurate reshaping into `self.image_shape`. This change enhances type safety and prevents unintended type coercion during image construction.
stable_baselines3/common/vec_env/vec_normalize.py (1)
271-272: Ensure consistent reward data types after normalization
The addition of `_maybe_cast_reward(reward)` in `normalize_reward` ensures that rewards are properly cast to `np.float32` when needed, addressing the precision issues causing CI failures. Good job on implementing this solution.
stable_baselines3/common/utils.py (1)
551-551: Include accelerator details in system information
The addition of `"Accelerator": get_available_accelerator(),` enhances the system information by specifying the available accelerator. This provides more detailed and helpful information for debugging and user awareness.
tests/test_utils.py (2)
445-445: Ensure 'Accelerator' is consistently present in system info
The assertion `assert "Accelerator" in info_str` replaces the previous check for "GPU Enabled". Please verify that the 'Accelerator' keyword is included in `info_str` across all environments, including those without GPU support, to prevent potential test failures on systems without accelerators.
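One way to satisfy the concern above is to make the accelerator entry fall back to a fixed string such as "cpu", so the key is emitted on every machine. The following is only a minimal sketch using the standard library: `pick_accelerator`, `system_info`, and `format_info` are hypothetical names, and the availability flags are passed in as plain booleans rather than queried from PyTorch.

```python
import platform


def pick_accelerator(cuda_ok: bool, mps_ok: bool) -> str:
    """Hypothetical helper: name the accelerator, or "cpu" when none is usable."""
    if mps_ok:
        return "mps"
    if cuda_ok:
        return "cuda"
    return "cpu"


def system_info(cuda_ok: bool = False, mps_ok: bool = False) -> dict:
    """Collect a few system facts; the "Accelerator" key is always present."""
    return {
        "OS": platform.platform(),
        "Python": platform.python_version(),
        "Accelerator": pick_accelerator(cuda_ok, mps_ok),
    }


def format_info(info: dict) -> str:
    """Render the info dict as one "- key: value" line per entry."""
    return "\n".join(f"- {key}: {value}" for key, value in info.items())


# Even with no accelerator detected, the key still appears in the output.
info_str = format_info(system_info())
assert "Accelerator" in info_str
```

Because the fallback value is a string rather than an omitted entry, a test such as `assert "Accelerator" in info_str` should hold on CPU-only runners as well.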
448-448: Confirm 'Cloudpickle' is always included in system info
The addition of `assert "Cloudpickle" in info_str` assumes that 'Cloudpickle' is present in all environments. Ensure that 'Cloudpickle' is a mandatory dependency and that it's consistently included in `info_str` to avoid assertion errors in environments where 'Cloudpickle' might not be installed.
stable_baselines3/common/buffers.py (1)
138-139: Verify that tensors are placed on the correct device
The function uses `self.device` when creating tensors. Ensure that `self.device` is correctly set to `'mps'` when the MPS device is available and intended for use. If `self.device` is not `'mps'`, tensors will not be placed on the MPS device despite the conditional check.
Run the following script to check the device configuration:
This script searches for instances where `self.device` is assigned and ensures that it's set appropriately based on device availability.
Verification successful
Verification confirms correct device handling, including MPS support
The codebase correctly handles device selection, including MPS support. The `get_device` function in utils.py properly manages device selection, with MPS being prioritized when available. The `self.device` attribute is consistently set using `get_device(device)` in both base_class.py and buffers.py, ensuring that the correct device is used throughout the codebase.
The additional check in buffers.py (lines 138-139) is a safeguard specific to tensor creation and doesn't conflict with the overall device selection logic. This check ensures that tensors are created correctly on the MPS device when it's available, regardless of the `self.device` setting.
In conclusion, the concern about `self.device` potentially not being set to 'mps' when the MPS device is available is unfounded. The codebase handles device selection and MPS support correctly.
Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Verify that 'self.device' is set to 'mps' when MPS is available.
# Find all assignments to 'self.device' in the codebase
rg --type python 'self\.device\s*=\s*.*' stable_baselines3/
# Check where 'get_device' is called and how 'self.device' is initialized
rg --type python 'self\.device\s*=\s*get_device\(' -A 2 stable_baselines3/
Length of output: 200
Script:
#!/bin/bash
# Description: Verify 'self.device' assignments and MPS-related configurations
# Find all assignments to 'self.device' in the codebase
echo "Searching for 'self.device' assignments:"
grep -R "self\.device\s*=" stable_baselines3/
echo -e "\nSearching for 'get_device' function calls:"
grep -R "self\.device\s*=\s*get_device(" stable_baselines3/
echo -e "\nSearching for MPS-related device configurations:"
grep -R "mps" stable_baselines3/
Length of output: 1892
Script:
#!/bin/bash
# Description: Verify the implementation of 'get_device' function and MPS handling
echo "Searching for 'get_device' function implementation:"
grep -n -A 20 "def get_device" stable_baselines3/common/utils.py
echo -e "\nSearching for MPS-related logic in utils.py:"
grep -n -C 5 "mps" stable_baselines3/common/utils.py
echo -e "\nChecking 'device' parameter usage in base_class.py:"
grep -n -C 3 "self\.device = get_device" stable_baselines3/common/base_class.py
echo -e "\nChecking 'device' parameter usage in buffers.py:"
grep -n -C 3 "self\.device = get_device" stable_baselines3/common/buffers.py
Length of output: 3421
  run: |
    python -m pip install --upgrade pip
    # cpu version of pytorch
    pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cpu

    # Install Atari Roms
    pip install autorom
    wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
    base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
    AutoROM --accept-license --source-file Roms.tar.gz

    pip install .[extra_no_roms,tests,docs]
    # Use headless version
    pip install opencv-python-headless
- name: Lint with ruff
  run: |
    make lint
- name: Build the doc
  run: |
    make doc
- name: Check codestyle
  run: |
    make check-codestyle
- name: Type check
  run: |
    make type
- name: Test with pytest
  run: |
    make pytest
Review shell commands for compatibility issues flagged by ShellCheck.
ShellCheck reported: SC2102: Ranges can only match single chars. This suggests that a character range in your shell script is incorrect.
Please inspect the shell commands for any misuse of character ranges, especially in loops or pattern matches. Ensure that ranges in expressions like [a-z] are correctly used.
Tools
actionlint
32-32: shellcheck reported issue in this script: SC2102:info:11:14: Ranges can only match single chars (mentioned due to duplicates)
(shellcheck)
# Install Atari Roms
pip install autorom
wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
Correct the redirection operator in the `base64` decode command.
The use of `&>` is not POSIX-compliant and may cause issues in some shells. To redirect the output properly, use `>`:
- base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
+ base64 Roms.tar.gz.b64 --decode > Roms.tar.gz
Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz | |
base64 Roms.tar.gz.b64 --decode > Roms.tar.gz |
pip install autorom
wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
AutoROM --accept-license --source-file Roms.tar.gz
Simplify the installation of Atari ROMs using `AutoROM`.
Manually downloading and decoding the ROMs is unnecessary. `AutoROM` can handle the download and installation process:
pip install autorom
-wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
-base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
-AutoROM --accept-license --source-file Roms.tar.gz
+AutoROM --accept-license
This approach is cleaner and reduces potential points of failure.
run: |
  python -m pip install --upgrade pip
  # cpu version of pytorch
  pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cpu
Use `--extra-index-url` instead of `--index-url` when installing PyTorch.
Using `--index-url` replaces the default PyPI index, potentially causing dependency resolution issues. To add the PyTorch CPU repository without overriding PyPI, use `--extra-index-url`:
- pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cpu
+ pip install torch==2.1.0 --extra-index-url https://download.pytorch.org/whl/cpu
This ensures all dependencies are correctly installed from both indexes.
tests/test_spaces.py
Outdated
if hasattr(th, "backends") and th.backends.mps.is_built():
    pytest.skip("MPS framework doesn't support float64")
Use `is_available()` instead of `is_built()` to check MPS backend support
The method `th.backends.mps.is_built()` checks if PyTorch was compiled with MPS support but does not verify if the MPS device is actually available on the system. To ensure the test skips correctly on machines where MPS is available and being used, it's better to use `th.backends.mps.is_available()`, which confirms both the build and the availability of the MPS device.
Apply this diff to fix the issue:
-if hasattr(th, "backends") and th.backends.mps.is_built():
+if hasattr(th, "backends") and th.backends.mps.is_available():
pytest.skip("MPS framework doesn't support float64")
stable_baselines3/common/utils.py
Outdated
if hasattr(th, "backends") and th.backends.mps.is_built():
    # MacOS Metal GPU
    th.set_default_dtype(th.float32)
    return "mps"
Use `th.backends.mps.is_available()` and avoid setting global default dtype
Similar to the previous comment, use `th.backends.mps.is_available()` to accurately check for MPS availability at runtime. Additionally, setting the global default dtype with `th.set_default_dtype(th.float32)` can have unintended side effects throughout the codebase. It's advisable to specify the dtype explicitly when creating tensors instead of modifying the global setting.
Apply this diff to update the check and remove the global dtype change:
-if hasattr(th, "backends") and th.backends.mps.is_built():
- # MacOS Metal GPU
- th.set_default_dtype(th.float32)
+if hasattr(th.backends, "mps") and th.backends.mps.is_available():
# MacOS Metal GPU
return "mps"
If changing the default dtype is necessary for MPS devices, consider isolating this change or documenting its implications to prevent unexpected behavior in other parts of the code.
stable_baselines3/common/utils.py
Outdated
# CUDA not available
if device.type == th.device("cuda").type and not th.cuda.is_available():
    return th.device("cpu")
Add handling for unavailable MPS devices
Currently, if the device is set to `cuda` but CUDA is not available, the function defaults to CPU. Similar handling should be added for MPS devices to ensure consistent behavior and prevent potential runtime errors.
Apply this diff to add the MPS availability check:
+if device.type == th.device("mps").type and not (hasattr(th.backends, "mps") and th.backends.mps.is_available()):
+ return th.device("cpu")
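The fallback rule discussed in this comment can be illustrated without PyTorch installed. The sketch below is an assumption-laden stand-in, not the project's `get_device`: `Availability` and `resolve_device` are hypothetical names, and the two booleans stand in for the results of runtime checks such as `th.cuda.is_available()` and `th.backends.mps.is_available()`.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Availability:
    """Hypothetical record of which accelerators are usable at runtime."""

    cuda: bool = False
    mps: bool = False


def resolve_device(requested: str, avail: Availability) -> str:
    """Return the device type to use, falling back to "cpu" when the
    requested accelerator is not usable on this machine."""
    if requested == "auto":
        # Prefer an accelerator that is actually available (MPS first,
        # mirroring the priority described in the review).
        if avail.mps:
            return "mps"
        if avail.cuda:
            return "cuda"
        return "cpu"
    if requested == "cuda" and not avail.cuda:
        return "cpu"
    if requested == "mps" and not avail.mps:
        return "cpu"
    return requested


# A machine with no accelerators silently falls back to the CPU.
print(resolve_device("mps", Availability()))
print(resolve_device("cuda", Availability(cuda=True)))
```

Treating CUDA and MPS symmetrically keeps the behaviour predictable: a request for an accelerator that is missing degrades to CPU instead of failing later at tensor creation.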
stable_baselines3/common/buffers.py
Outdated
if hasattr(th, "backends") and th.backends.mps.is_built():
    return th.tensor(array, dtype=th.float32, device=self.device)
Update the conditional check to use `is_available()` for MPS
Currently, the code checks `th.backends.mps.is_built()` to determine if the MPS backend is available. However, `is_built()` only indicates if PyTorch was built with MPS support, not if the MPS device is available at runtime. To accurately determine MPS availability on the current system, use `th.backends.mps.is_available()` instead.
Apply this diff to correct the check:
-if hasattr(th, "backends") and th.backends.mps.is_built():
+if hasattr(th, "backends") and th.backends.mps.is_available():
Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
if hasattr(th, "backends") and th.backends.mps.is_built(): | |
return th.tensor(array, dtype=th.float32, device=self.device) | |
if hasattr(th, "backends") and th.backends.mps.is_available(): | |
return th.tensor(array, dtype=th.float32, device=self.device) |
docs/misc/changelog.rst
Outdated
@@ -558,6 +558,7 @@ New Features:
- Added checkpoints for replay buffer and ``VecNormalize`` statistics (@anand-bala)
- Added option for ``Monitor`` to append to existing file instead of overriding (@sidney-tio)
- The env checker now raises an error when using dict observation spaces and observation keys don't match observation space keys
- Use MacOS Metal "mps" device when available
Correct capitalization of 'macOS'
The term "macOS" should be written with a lowercase 'm' to match Apple's official branding.
Apply this diff to fix the capitalization:
- - Use MacOS Metal "mps" device when available
+ - Use macOS Metal "mps" device when available
docs/misc/changelog.rst
Outdated
@@ -615,6 +616,7 @@ Breaking Changes:

New Features:
^^^^^^^^^^^^^
- Save cloudpickle version
Enclose 'cloudpickle' in code formatting for consistency
To maintain consistency in the documentation, module names like 'cloudpickle' should be enclosed with double backticks.
Apply this diff to enclose 'cloudpickle' in code formatting:
- - Save cloudpickle version
+ - Save ``cloudpickle`` version
Trying to fix https://github.com/deathcoder/stable-baselines3/actions/runs/10903046945/job/30256261083#step:9:947
In a recent change I always cast rewards to float32 in `normalize_reward`. With this change, I only cast from float64 to float32.
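The behaviour described here, casting only `float64` rewards and leaving every other dtype alone, can be sketched as follows. This is an illustration under assumptions rather than the PR's actual `_maybe_cast_reward`: the function name `maybe_cast_reward` and the sample arrays are made up for the example, and NumPy is assumed to be installed.

```python
import numpy as np


def maybe_cast_reward(reward: np.ndarray) -> np.ndarray:
    """Cast float64 rewards to float32; keep every other dtype unchanged."""
    if reward.dtype == np.float64:
        return reward.astype(np.float32)
    return reward


# float64 is narrowed, while integer and float16 rewards pass through untouched.
for sample in (
    np.array([0.5, 1.5], dtype=np.float64),
    np.array([1, 2], dtype=np.int64),
    np.array([0.5, 1.5], dtype=np.float16),
):
    print(sample.dtype, "->", maybe_cast_reward(sample).dtype)
```

Compared with an unconditional `astype(np.float32)`, the conditional cast avoids silently changing integer or lower-precision rewards that an environment or test may rely on.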