diff --git a/bandit/plugins/pytorch_load_save.py b/bandit/plugins/pytorch_load.py similarity index 54% rename from bandit/plugins/pytorch_load_save.py rename to bandit/plugins/pytorch_load.py index 77522da22..8be5e3451 100644 --- a/bandit/plugins/pytorch_load_save.py +++ b/bandit/plugins/pytorch_load.py @@ -2,21 +2,26 @@ # # SPDX-License-Identifier: Apache-2.0 r""" -========================================== -B614: Test for unsafe PyTorch load or save -========================================== +================================== +B614: Test for unsafe PyTorch load +================================== -This plugin checks for the use of `torch.load` and `torch.save`. Using -`torch.load` with untrusted data can lead to arbitrary code execution, and -improper use of `torch.save` might expose sensitive data or lead to data -corruption. A safe alternative is to use `torch.load` with the `safetensors` -library from hugingface, which provides a safe deserialization mechanism. +This plugin checks for unsafe use of `torch.load`. Using `torch.load` with +untrusted data can lead to arbitrary code execution. There are two safe +alternatives: +1. Use `torch.load` with `weights_only=True` where only tensor data is + extracted, and no arbitrary Python objects are deserialized +2. Use the `safetensors` library from huggingface, which provides a safe + deserialization mechanism + +With `weights_only=True`, PyTorch enforces a strict type check, ensuring +that only torch.Tensor objects are loaded. :Example: .. code-block:: none - >> Issue: Use of unsafe PyTorch load or save + >> Issue: Use of unsafe PyTorch load Severity: Medium Confidence: High CWE: CWE-94 (https://cwe.mitre.org/data/definitions/94.html) Location: examples/pytorch_load_save.py:8 @@ -42,12 +47,11 @@ @test.checks("Call") @test.test_id("B614") -def pytorch_load_save(context): +def pytorch_load(context): """ - This plugin checks for the use of `torch.load` and `torch.save`. Using - `torch.load` with untrusted data can lead to arbitrary code execution, - and improper use of `torch.save` might expose sensitive data or lead - to data corruption. + This plugin checks for unsafe use of `torch.load`. Using `torch.load` + with untrusted data can lead to arbitrary code execution. The safe + alternative is to use `weights_only=True` or the safetensors library. """ imported = context.is_module_imported_exact("torch") qualname = context.call_function_name_qual @@ -59,14 +63,18 @@ def pytorch_load_save(context): if all( [ "torch" in qualname_list, - func in ["load", "save"], - not context.check_call_arg_value("map_location", "cpu"), + func == "load", ] ): + # For torch.load, check if weights_only=True is specified + weights_only = context.get_call_arg_value("weights_only") + if weights_only == "True" or weights_only is True: + return + return bandit.Issue( severity=bandit.MEDIUM, confidence=bandit.HIGH, - text="Use of unsafe PyTorch load or save", + text="Use of unsafe PyTorch load", cwe=issue.Cwe.DESERIALIZATION_OF_UNTRUSTED_DATA, lineno=context.get_lineno_for_call_arg("load"), ) diff --git a/doc/source/plugins/b614_pytorch_load.rst b/doc/source/plugins/b614_pytorch_load.rst new file mode 100644 index 000000000..808383e6a --- /dev/null +++ b/doc/source/plugins/b614_pytorch_load.rst @@ -0,0 +1,5 @@ +------------------ +B614: pytorch_load +------------------ + +.. automodule:: bandit.plugins.pytorch_load diff --git a/doc/source/plugins/b614_pytorch_load_save.rst b/doc/source/plugins/b614_pytorch_load_save.rst deleted file mode 100644 index dcc1ae3a0..000000000 --- a/doc/source/plugins/b614_pytorch_load_save.rst +++ /dev/null @@ -1,5 +0,0 @@ ------------------------ -B614: pytorch_load_save ------------------------ - -.. automodule:: bandit.plugins.pytorch_load_save diff --git a/examples/pytorch_load.py b/examples/pytorch_load.py new file mode 100644 index 000000000..c5129a035 --- /dev/null +++ b/examples/pytorch_load.py @@ -0,0 +1,26 @@ +import torch +import torchvision.models as models + +# Example of saving a model +model = models.resnet18(pretrained=True) +torch.save(model.state_dict(), 'model_weights.pth') + +# Example of loading the model weights in an insecure way (should trigger B614) +loaded_model = models.resnet18() +loaded_model.load_state_dict(torch.load('model_weights.pth')) + +# Example of loading with weights_only=True (should NOT trigger B614) +safe_model = models.resnet18() +safe_model.load_state_dict(torch.load('model_weights.pth', weights_only=True)) + +# Example of loading with weights_only=False (should trigger B614) +unsafe_model = models.resnet18() +unsafe_model.load_state_dict(torch.load('model_weights.pth', weights_only=False)) + +# Example of loading with map_location but no weights_only (should trigger B614) +cpu_model = models.resnet18() +cpu_model.load_state_dict(torch.load('model_weights.pth', map_location='cpu')) + +# Example of loading with both map_location and weights_only=True (should NOT trigger B614) +safe_cpu_model = models.resnet18() +safe_cpu_model.load_state_dict(torch.load('model_weights.pth', map_location='cpu', weights_only=True)) diff --git a/examples/pytorch_load_save.py b/examples/pytorch_load_save.py deleted file mode 100644 index e1f912022..000000000 --- a/examples/pytorch_load_save.py +++ /dev/null @@ -1,21 +0,0 @@ -import torch -import torchvision.models as models - -# Example of saving a model -model = models.resnet18(pretrained=True) -torch.save(model.state_dict(), 'model_weights.pth') - -# Example of loading the model weights in an insecure way -loaded_model = models.resnet18() -loaded_model.load_state_dict(torch.load('model_weights.pth')) - -# Save the model -torch.save(loaded_model.state_dict(), 'model_weights.pth') - -# Another example using torch.load with more parameters -another_model = models.resnet18() -another_model.load_state_dict(torch.load('model_weights.pth', map_location='cpu')) - -# Save the model -torch.save(another_model.state_dict(), 'model_weights.pth') - diff --git a/setup.cfg b/setup.cfg index 83d57d1ea..e0288e600 100644 --- a/setup.cfg +++ b/setup.cfg @@ -155,8 +155,8 @@ bandit.plugins = #bandit/plugins/tarfile_unsafe_members.py tarfile_unsafe_members = bandit.plugins.tarfile_unsafe_members:tarfile_unsafe_members - #bandit/plugins/pytorch_load_save.py - pytorch_load_save = bandit.plugins.pytorch_load_save:pytorch_load_save + #bandit/plugins/pytorch_load.py + pytorch_load = bandit.plugins.pytorch_load:pytorch_load # bandit/plugins/trojansource.py trojansource = bandit.plugins.trojansource:trojansource diff --git a/tests/functional/test_functional.py b/tests/functional/test_functional.py index f9fe6956b..660b65f94 100644 --- a/tests/functional/test_functional.py +++ b/tests/functional/test_functional.py @@ -872,13 +872,13 @@ def test_tarfile_unsafe_members(self): } self.check_example("tarfile_extractall.py", expect) - def test_pytorch_load_save(self): - """Test insecure usage of torch.load and torch.save.""" + def test_pytorch_load(self): + """Test insecure usage of torch.load.""" expect = { - "SEVERITY": {"UNDEFINED": 0, "LOW": 0, "MEDIUM": 4, "HIGH": 0}, - "CONFIDENCE": {"UNDEFINED": 0, "LOW": 0, "MEDIUM": 0, "HIGH": 4}, + "SEVERITY": {"UNDEFINED": 0, "LOW": 0, "MEDIUM": 3, "HIGH": 0}, + "CONFIDENCE": {"UNDEFINED": 0, "LOW": 0, "MEDIUM": 0, "HIGH": 3}, } - self.check_example("pytorch_load_save.py", expect) + self.check_example("pytorch_load.py", expect) def test_trojansource(self): expect = {