Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pytorch fix #1231

Merged
merged 8 commits into from
Feb 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,26 @@
#
# SPDX-License-Identifier: Apache-2.0
r"""
==========================================
B614: Test for unsafe PyTorch load or save
==========================================
==================================
B614: Test for unsafe PyTorch load
==================================

This plugin checks for the use of `torch.load` and `torch.save`. Using
`torch.load` with untrusted data can lead to arbitrary code execution, and
improper use of `torch.save` might expose sensitive data or lead to data
corruption. A safe alternative is to use `torch.load` with the `safetensors`
library from Hugging Face, which provides a safe deserialization mechanism.
This plugin checks for unsafe use of `torch.load`. Using `torch.load` with
untrusted data can lead to arbitrary code execution. There are two safe
alternatives:

1. Use `torch.load` with `weights_only=True`, so that only tensor data is
   extracted and no arbitrary Python objects are deserialized.
2. Use the `safetensors` library from Hugging Face, which provides a safe
   deserialization mechanism.

With `weights_only=True`, PyTorch uses a restricted unpickler that only
loads tensors, primitive types, dictionaries, and types that have been
explicitly allowlisted, rather than arbitrary Python objects.

:Example:

.. code-block:: none

>> Issue: Use of unsafe PyTorch load or save
>> Issue: Use of unsafe PyTorch load
Severity: Medium Confidence: High
CWE: CWE-94 (https://cwe.mitre.org/data/definitions/94.html)
Location: examples/pytorch_load_save.py:8
Expand All @@ -42,12 +47,11 @@

@test.checks("Call")
@test.test_id("B614")
def pytorch_load_save(context):
def pytorch_load(context):
"""
This plugin checks for the use of `torch.load` and `torch.save`. Using
`torch.load` with untrusted data can lead to arbitrary code execution,
and improper use of `torch.save` might expose sensitive data or lead
to data corruption.
This plugin checks for unsafe use of `torch.load`. Using `torch.load`
with untrusted data can lead to arbitrary code execution. The safe
alternative is to use `weights_only=True` or the safetensors library.
"""
imported = context.is_module_imported_exact("torch")
qualname = context.call_function_name_qual
Expand All @@ -59,14 +63,18 @@ def pytorch_load_save(context):
if all(
[
"torch" in qualname_list,
func in ["load", "save"],
not context.check_call_arg_value("map_location", "cpu"),
func == "load",
]
):
# For torch.load, check if weights_only=True is specified
weights_only = context.get_call_arg_value("weights_only")
if weights_only == "True" or weights_only is True:
return

return bandit.Issue(
severity=bandit.MEDIUM,
confidence=bandit.HIGH,
text="Use of unsafe PyTorch load or save",
text="Use of unsafe PyTorch load",
cwe=issue.Cwe.DESERIALIZATION_OF_UNTRUSTED_DATA,
lineno=context.get_lineno_for_call_arg("load"),
)
5 changes: 5 additions & 0 deletions doc/source/plugins/b614_pytorch_load.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
------------------
B614: pytorch_load
------------------

.. automodule:: bandit.plugins.pytorch_load
5 changes: 0 additions & 5 deletions doc/source/plugins/b614_pytorch_load_save.rst

This file was deleted.

26 changes: 26 additions & 0 deletions examples/pytorch_load.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import torch
import torchvision.models as models

# Save a model's state dict to disk. This creates the file that the loads
# below read; it carries no B614 expectation of its own.
model = models.resnet18(pretrained=True)
torch.save(model.state_dict(), 'model_weights.pth')

# Insecure load: no weights_only argument is passed (should trigger B614)
loaded_model = models.resnet18()
loaded_model.load_state_dict(torch.load('model_weights.pth'))

# Safe load: weights_only=True is passed (should NOT trigger B614)
safe_model = models.resnet18()
safe_model.load_state_dict(torch.load('model_weights.pth', weights_only=True))

# Insecure load: weights_only is passed explicitly as False
# (should trigger B614)
unsafe_model = models.resnet18()
unsafe_model.load_state_dict(torch.load('model_weights.pth', weights_only=False))

# Insecure load: map_location is given but weights_only is not, so the
# call is still treated as unsafe (should trigger B614)
cpu_model = models.resnet18()
cpu_model.load_state_dict(torch.load('model_weights.pth', map_location='cpu'))

# Safe load: map_location combined with weights_only=True
# (should NOT trigger B614)
safe_cpu_model = models.resnet18()
safe_cpu_model.load_state_dict(torch.load('model_weights.pth', map_location='cpu', weights_only=True))
21 changes: 0 additions & 21 deletions examples/pytorch_load_save.py

This file was deleted.

4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,8 @@ bandit.plugins =
#bandit/plugins/tarfile_unsafe_members.py
tarfile_unsafe_members = bandit.plugins.tarfile_unsafe_members:tarfile_unsafe_members

#bandit/plugins/pytorch_load_save.py
pytorch_load_save = bandit.plugins.pytorch_load_save:pytorch_load_save
#bandit/plugins/pytorch_load.py
pytorch_load = bandit.plugins.pytorch_load:pytorch_load

# bandit/plugins/trojansource.py
trojansource = bandit.plugins.trojansource:trojansource
Expand Down
10 changes: 5 additions & 5 deletions tests/functional/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,13 +872,13 @@ def test_tarfile_unsafe_members(self):
}
self.check_example("tarfile_extractall.py", expect)

def test_pytorch_load_save(self):
"""Test insecure usage of torch.load and torch.save."""
def test_pytorch_load(self):
"""Test insecure usage of torch.load."""
expect = {
"SEVERITY": {"UNDEFINED": 0, "LOW": 0, "MEDIUM": 4, "HIGH": 0},
"CONFIDENCE": {"UNDEFINED": 0, "LOW": 0, "MEDIUM": 0, "HIGH": 4},
"SEVERITY": {"UNDEFINED": 0, "LOW": 0, "MEDIUM": 3, "HIGH": 0},
"CONFIDENCE": {"UNDEFINED": 0, "LOW": 0, "MEDIUM": 0, "HIGH": 3},
}
self.check_example("pytorch_load_save.py", expect)
self.check_example("pytorch_load.py", expect)

def test_trojansource(self):
expect = {
Expand Down