From 3674e2fbd9e3a814fbf49f08be968a60bbabce57 Mon Sep 17 00:00:00 2001 From: Luke Hinds Date: Sun, 9 Feb 2025 23:06:16 +0000 Subject: [PATCH 1/7] Fix pytorch weights check --- bandit/plugins/pytorch_load_save.py | 13 +++++++++---- examples/pytorch_load_save.py | 23 +++++++++++++++-------- tests/functional/test_functional.py | 4 ++-- 3 files changed, 26 insertions(+), 14 deletions(-) diff --git a/bandit/plugins/pytorch_load_save.py b/bandit/plugins/pytorch_load_save.py index 77522da22..a1e876db5 100644 --- a/bandit/plugins/pytorch_load_save.py +++ b/bandit/plugins/pytorch_load_save.py @@ -10,8 +10,9 @@ `torch.load` with untrusted data can lead to arbitrary code execution, and improper use of `torch.save` might expose sensitive data or lead to data corruption. A safe alternative is to use `torch.load` with the `safetensors` -library from hugingface, which provides a safe deserialization mechanism. - +library from hugingface, which provides a safe deserialization mechanism. Or +use the `weights_only` argument for `torch.load` to load only the model weights +and avoid deserializing the entire model state. :Example: .. 
code-block:: none @@ -59,10 +60,14 @@ def pytorch_load_save(context): if all( [ "torch" in qualname_list, - func in ["load", "save"], - not context.check_call_arg_value("map_location", "cpu"), + func == "load", ] ): + # For torch.load, check if weights_only=True is specified + weights_only = context.get_call_arg_value("weights_only") + if weights_only == "True" or weights_only is True: + return + return bandit.Issue( severity=bandit.MEDIUM, confidence=bandit.HIGH, diff --git a/examples/pytorch_load_save.py b/examples/pytorch_load_save.py index e1f912022..a6d380889 100644 --- a/examples/pytorch_load_save.py +++ b/examples/pytorch_load_save.py @@ -5,17 +5,24 @@ model = models.resnet18(pretrained=True) torch.save(model.state_dict(), 'model_weights.pth') -# Example of loading the model weights in an insecure way +# Example of loading the model weights in an insecure way (should trigger B614) loaded_model = models.resnet18() loaded_model.load_state_dict(torch.load('model_weights.pth')) -# Save the model -torch.save(loaded_model.state_dict(), 'model_weights.pth') +# Example of loading with weights_only=True (should NOT trigger B614) +safe_model = models.resnet18() +safe_model.load_state_dict(torch.load('model_weights.pth', weights_only=True)) -# Another example using torch.load with more parameters -another_model = models.resnet18() -another_model.load_state_dict(torch.load('model_weights.pth', map_location='cpu')) +# Example of loading with weights_only=False (should trigger B614) +unsafe_model = models.resnet18() +unsafe_model.load_state_dict(torch.load('model_weights.pth', weights_only=False)) + +# Example of loading with map_location but no weights_only (should trigger B614) +cpu_model = models.resnet18() +cpu_model.load_state_dict(torch.load('model_weights.pth', map_location='cpu')) + +# Example of loading with both map_location and weights_only=True (should NOT trigger B614) +safe_cpu_model = models.resnet18() 
+safe_cpu_model.load_state_dict(torch.load('model_weights.pth', map_location='cpu', weights_only=True)) -# Save the model -torch.save(another_model.state_dict(), 'model_weights.pth') diff --git a/tests/functional/test_functional.py b/tests/functional/test_functional.py index f9fe6956b..ba5df2647 100644 --- a/tests/functional/test_functional.py +++ b/tests/functional/test_functional.py @@ -875,8 +875,8 @@ def test_tarfile_unsafe_members(self): def test_pytorch_load_save(self): """Test insecure usage of torch.load and torch.save.""" expect = { - "SEVERITY": {"UNDEFINED": 0, "LOW": 0, "MEDIUM": 4, "HIGH": 0}, - "CONFIDENCE": {"UNDEFINED": 0, "LOW": 0, "MEDIUM": 0, "HIGH": 4}, + "SEVERITY": {"UNDEFINED": 0, "LOW": 0, "MEDIUM": 3, "HIGH": 0}, + "CONFIDENCE": {"UNDEFINED": 0, "LOW": 0, "MEDIUM": 0, "HIGH": 3}, } self.check_example("pytorch_load_save.py", expect) From c335d37faca4218b8fa8f19c756f21b895f1c851 Mon Sep 17 00:00:00 2001 From: Luke Hinds Date: Mon, 10 Feb 2025 09:43:50 +0000 Subject: [PATCH 2/7] B614: Fix PyTorch plugin to handle weights_only parameter correctly The PyTorch plugin (B614) has been updated to properly handle the weights_only parameter in torch.load calls. When weights_only=True is specified, PyTorch will only deserialize known safe types, making the operation more secure. I also removed the torch.save check, as there is nothing inherently insecure about saving as such; saving any file or artifact simply requires consideration of what it is you are saving. Changes: - Update plugin to only check torch.load calls (not torch.save) - Fix weights_only check to handle both string and boolean True values - Remove map_location check as it doesn't affect security - Update example file to demonstrate both safe and unsafe cases - Update plugin documentation to mention weights_only as a safe alternative The plugin now correctly identifies unsafe torch.load calls while allowing safe usage with weights_only=True to pass without warning. 
Fixes: #1224 --- bandit/plugins/pytorch_load_save.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/bandit/plugins/pytorch_load_save.py b/bandit/plugins/pytorch_load_save.py index a1e876db5..fd71061a6 100644 --- a/bandit/plugins/pytorch_load_save.py +++ b/bandit/plugins/pytorch_load_save.py @@ -10,9 +10,11 @@ `torch.load` with untrusted data can lead to arbitrary code execution, and improper use of `torch.save` might expose sensitive data or lead to data corruption. A safe alternative is to use `torch.load` with the `safetensors` -library from hugingface, which provides a safe deserialization mechanism. Or -use the `weights_only` argument for `torch.load` to load only the model weights -and avoid deserializing the entire model state. +library from huggingface, which provides a safe deserialization mechanism. A +second option is to use the `weights_only` argument for `torch.load` where +only tensor data is extracted, and no arbitrary Python objects (like custom +layers, optimizers or hooks) are deserialized. With `weights_only=True`, PyTorch +enforces a strict type check, ensuring that only torch.Tensor objects are loaded. :Example: .. code-block:: none From 2818b55092a49034527f61122380d269f6b575f6 Mon Sep 17 00:00:00 2001 From: Luke Hinds Date: Mon, 10 Feb 2025 11:16:52 +0000 Subject: [PATCH 3/7] Fix E501 line too long --- bandit/plugins/pytorch_load_save.py | 6 ++++-- examples/pytorch_load_save.py | 2 -- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/bandit/plugins/pytorch_load_save.py b/bandit/plugins/pytorch_load_save.py index fd71061a6..921fd2576 100644 --- a/bandit/plugins/pytorch_load_save.py +++ b/bandit/plugins/pytorch_load_save.py @@ -13,8 +13,10 @@ library from huggingface, which provides a safe deserialization mechanism. 
A second option is to use the `weights_only` argument for `torch.load` where only tensor data is extracted, and no arbitrary Python objects (like custom -layers, optimizers or hooks) are deserialized. With `weights_only=True`, PyTorch -enforces a strict type check, ensuring that only torch.Tensor objects are loaded. +layers, optimizers) are deserialized. With `weights_only=True`, PyTorch +enforces a strict type check, ensuring that only torch.Tensor objects are +loaded. + :Example: .. code-block:: none diff --git a/examples/pytorch_load_save.py b/examples/pytorch_load_save.py index a6d380889..c5129a035 100644 --- a/examples/pytorch_load_save.py +++ b/examples/pytorch_load_save.py @@ -24,5 +24,3 @@ # Example of loading with both map_location and weights_only=True (should NOT trigger B614) safe_cpu_model = models.resnet18() safe_cpu_model.load_state_dict(torch.load('model_weights.pth', map_location='cpu', weights_only=True)) - - From 6e334271da02e39f781f0ee6a7357e48d55437d6 Mon Sep 17 00:00:00 2001 From: Luke Hinds Date: Fri, 14 Feb 2025 22:54:01 +0000 Subject: [PATCH 4/7] Rename files to new test scope --- .../{pytorch_load_save.py => pytorch_load.py} | 37 +++++++++---------- doc/source/plugins/b614_pytorch_load.rst | 5 +++ doc/source/plugins/b614_pytorch_load_save.rst | 5 --- .../{pytorch_load_save.py => pytorch_load.py} | 0 setup.cfg | 4 +- tests/functional/test_functional.py | 6 +-- 6 files changed, 28 insertions(+), 29 deletions(-) rename bandit/plugins/{pytorch_load_save.py => pytorch_load.py} (61%) create mode 100644 doc/source/plugins/b614_pytorch_load.rst delete mode 100644 doc/source/plugins/b614_pytorch_load_save.rst rename examples/{pytorch_load_save.py => pytorch_load.py} (100%) diff --git a/bandit/plugins/pytorch_load_save.py b/bandit/plugins/pytorch_load.py similarity index 61% rename from bandit/plugins/pytorch_load_save.py rename to bandit/plugins/pytorch_load.py index 921fd2576..6015136c4 100644 --- a/bandit/plugins/pytorch_load_save.py +++ 
b/bandit/plugins/pytorch_load.py @@ -3,25 +3,25 @@ # SPDX-License-Identifier: Apache-2.0 r""" ========================================== -B614: Test for unsafe PyTorch load or save +B614: Test for unsafe PyTorch load ========================================== -This plugin checks for the use of `torch.load` and `torch.save`. Using -`torch.load` with untrusted data can lead to arbitrary code execution, and -improper use of `torch.save` might expose sensitive data or lead to data -corruption. A safe alternative is to use `torch.load` with the `safetensors` -library from huggingface, which provides a safe deserialization mechanism. A -second option is to use the `weights_only` argument for `torch.load` where -only tensor data is extracted, and no arbitrary Python objects (like custom -layers, optimizers) are deserialized. With `weights_only=True`, PyTorch -enforces a strict type check, ensuring that only torch.Tensor objects are -loaded. +This plugin checks for unsafe use of `torch.load`. Using `torch.load` with +untrusted data can lead to arbitrary code execution. There are two safe +alternatives: +1. Use `torch.load` with `weights_only=True` where only tensor data is + extracted, and no arbitrary Python objects are deserialized +2. Use the `safetensors` library from huggingface, which provides a safe + deserialization mechanism + +With `weights_only=True`, PyTorch enforces a strict type check, ensuring +that only torch.Tensor objects are loaded. :Example: .. code-block:: none - >> Issue: Use of unsafe PyTorch load or save + >> Issue: Use of unsafe PyTorch load Severity: Medium Confidence: High CWE: CWE-94 (https://cwe.mitre.org/data/definitions/94.html) Location: examples/pytorch_load_save.py:8 @@ -47,12 +47,11 @@ @test.checks("Call") @test.test_id("B614") -def pytorch_load_save(context): +def pytorch_load(context): """ - This plugin checks for the use of `torch.load` and `torch.save`. 
Using - `torch.load` with untrusted data can lead to arbitrary code execution, - and improper use of `torch.save` might expose sensitive data or lead - to data corruption. + This plugin checks for unsafe use of `torch.load`. Using `torch.load` + with untrusted data can lead to arbitrary code execution. The safe + alternative is to use `weights_only=True` or the safetensors library. """ imported = context.is_module_imported_exact("torch") qualname = context.call_function_name_qual @@ -71,11 +70,11 @@ def pytorch_load_save(context): weights_only = context.get_call_arg_value("weights_only") if weights_only == "True" or weights_only is True: return - + return bandit.Issue( severity=bandit.MEDIUM, confidence=bandit.HIGH, - text="Use of unsafe PyTorch load or save", + text="Use of unsafe PyTorch load", cwe=issue.Cwe.DESERIALIZATION_OF_UNTRUSTED_DATA, lineno=context.get_lineno_for_call_arg("load"), ) diff --git a/doc/source/plugins/b614_pytorch_load.rst b/doc/source/plugins/b614_pytorch_load.rst new file mode 100644 index 000000000..74f549fb9 --- /dev/null +++ b/doc/source/plugins/b614_pytorch_load.rst @@ -0,0 +1,5 @@ +----------------------- +B614: pytorch_load +----------------------- + +.. automodule:: bandit.plugins.pytorch_load diff --git a/doc/source/plugins/b614_pytorch_load_save.rst b/doc/source/plugins/b614_pytorch_load_save.rst deleted file mode 100644 index dcc1ae3a0..000000000 --- a/doc/source/plugins/b614_pytorch_load_save.rst +++ /dev/null @@ -1,5 +0,0 @@ ------------------------ -B614: pytorch_load_save ------------------------ - -.. 
automodule:: bandit.plugins.pytorch_load_save diff --git a/examples/pytorch_load_save.py b/examples/pytorch_load.py similarity index 100% rename from examples/pytorch_load_save.py rename to examples/pytorch_load.py diff --git a/setup.cfg b/setup.cfg index 83d57d1ea..e0288e600 100644 --- a/setup.cfg +++ b/setup.cfg @@ -155,8 +155,8 @@ bandit.plugins = #bandit/plugins/tarfile_unsafe_members.py tarfile_unsafe_members = bandit.plugins.tarfile_unsafe_members:tarfile_unsafe_members - #bandit/plugins/pytorch_load_save.py - pytorch_load_save = bandit.plugins.pytorch_load_save:pytorch_load_save + #bandit/plugins/pytorch_load.py + pytorch_load = bandit.plugins.pytorch_load:pytorch_load # bandit/plugins/trojansource.py trojansource = bandit.plugins.trojansource:trojansource diff --git a/tests/functional/test_functional.py b/tests/functional/test_functional.py index ba5df2647..660b65f94 100644 --- a/tests/functional/test_functional.py +++ b/tests/functional/test_functional.py @@ -872,13 +872,13 @@ def test_tarfile_unsafe_members(self): } self.check_example("tarfile_extractall.py", expect) - def test_pytorch_load_save(self): - """Test insecure usage of torch.load and torch.save.""" + def test_pytorch_load(self): + """Test insecure usage of torch.load.""" expect = { "SEVERITY": {"UNDEFINED": 0, "LOW": 0, "MEDIUM": 3, "HIGH": 0}, "CONFIDENCE": {"UNDEFINED": 0, "LOW": 0, "MEDIUM": 0, "HIGH": 3}, } - self.check_example("pytorch_load_save.py", expect) + self.check_example("pytorch_load.py", expect) def test_trojansource(self): expect = { From b8c00d8d88a9d0b9d0a06d56fb1bf2f3cc212793 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 14 Feb 2025 22:55:15 +0000 Subject: [PATCH 5/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- bandit/plugins/pytorch_load.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bandit/plugins/pytorch_load.py 
b/bandit/plugins/pytorch_load.py index 6015136c4..2e7046b29 100644 --- a/bandit/plugins/pytorch_load.py +++ b/bandit/plugins/pytorch_load.py @@ -70,7 +70,7 @@ def pytorch_load(context): weights_only = context.get_call_arg_value("weights_only") if weights_only == "True" or weights_only is True: return - + return bandit.Issue( severity=bandit.MEDIUM, confidence=bandit.HIGH, From b7ab6f344ab84093f50635df08c8da93412126f4 Mon Sep 17 00:00:00 2001 From: Luke Hinds Date: Sun, 16 Feb 2025 01:07:04 +0000 Subject: [PATCH 6/7] Update doc/source/plugins/b614_pytorch_load.rst Co-authored-by: Eric Brown --- doc/source/plugins/b614_pytorch_load.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/source/plugins/b614_pytorch_load.rst b/doc/source/plugins/b614_pytorch_load.rst index 74f549fb9..808383e6a 100644 --- a/doc/source/plugins/b614_pytorch_load.rst +++ b/doc/source/plugins/b614_pytorch_load.rst @@ -1,5 +1,5 @@ ------------------------ +------------------ B614: pytorch_load ------------------------ +------------------ .. automodule:: bandit.plugins.pytorch_load From e9b74ec8ef39805fb05173b6adf9ab2de4d19d73 Mon Sep 17 00:00:00 2001 From: Luke Hinds Date: Sun, 16 Feb 2025 01:07:41 +0000 Subject: [PATCH 7/7] Update pytorch_load.py --- bandit/plugins/pytorch_load.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bandit/plugins/pytorch_load.py b/bandit/plugins/pytorch_load.py index 2e7046b29..8be5e3451 100644 --- a/bandit/plugins/pytorch_load.py +++ b/bandit/plugins/pytorch_load.py @@ -2,9 +2,9 @@ # # SPDX-License-Identifier: Apache-2.0 r""" -========================================== +================================== B614: Test for unsafe PyTorch load -========================================== +================================== This plugin checks for unsafe use of `torch.load`. Using `torch.load` with untrusted data can lead to arbitrary code execution. There are two safe