Skip to content

Commit

Permalink
Fix arbitrary file write during tarfile extraction
Browse files Browse the repository at this point in the history
Fixes #3302

Address arbitrary file write vulnerability during tarfile extraction in `luigi/contrib/lsf_runner.py`.

* Add a function `_is_within_directory(directory, target)` to check if a target path is within a directory.
* Add a function `_safe_extract(tar, path=".", members=None, *, numeric_owner=False)` to safely extract tar files.
* Replace the existing tar extraction code with a call to `_safe_extract` in the `extract_packages_archive` function.

* Add a test case `test_safe_extract` to verify the safe extraction of tar files in `test/contrib/lsf_test.py`.
* Add a test case `test_safe_extract_with_traversal` to verify that directory traversal is prevented in `test/contrib/lsf_test.py`.
  • Loading branch information
Ali-Razmjoo committed Sep 4, 2024
1 parent 74e6e63 commit 91ce61f
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 2 deletions.
18 changes: 16 additions & 2 deletions luigi/contrib/lsf_runner.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,21 @@ def do_work_on_compute_node(work_dir):
job.work()


def _is_within_directory(directory, target):
    """Return True if *target* is *directory* itself or lies beneath it.

    Both paths are made absolute and normalised with ``os.path.abspath``
    before comparison, so ``..`` components in *target* are collapsed first.
    The comparison is purely lexical: symbolic links are not resolved.

    :param directory: path of the directory that must contain *target*.
    :param target: path to test.
    :return: ``True`` when *target* is inside (or equal to) *directory*.
    """
    abs_directory = os.path.abspath(directory)
    abs_target = os.path.abspath(target)
    try:
        # commonpath compares whole path components. A character-wise prefix
        # test would wrongly accept a sibling such as "/tmp/foobar" as being
        # inside "/tmp/foo".
        return os.path.commonpath([abs_directory, abs_target]) == abs_directory
    except ValueError:
        # Raised when the two paths cannot share an ancestor (for example
        # different drives on Windows); such a target is never inside.
        return False


def _safe_extract(tar, path=".", members=None, *, numeric_owner=False):
    """Extract *tar* into *path* after rejecting entries that escape *path*.

    Every member of the archive is inspected before anything is written:

    * the member's own name, joined onto *path*, must stay inside *path*;
    * for symbolic-link and hard-link members, the link target must also
      stay inside *path*, so a link entry cannot be used to redirect a later
      write outside the extraction directory.

    NOTE(review): the checks are lexical (no symlink resolution), so chains
    of links that are individually "inside" are not detected. Where the
    running Python supports tarfile extraction filters, consider using them
    as well.

    :param tar: an open ``tarfile.TarFile`` object.
    :param path: directory to extract into.
    :param members: optional subset of members, passed through to
        ``TarFile.extractall``.
    :param numeric_owner: passed through to ``TarFile.extractall``.
    :raises ValueError: if any member name or link target escapes *path*.
    """
    for member in tar.getmembers():
        member_path = os.path.join(path, member.name)
        if not _is_within_directory(path, member_path):
            raise ValueError("Attempted Path Traversal in Tar File")

        if member.issym():
            # A symlink target is interpreted relative to the directory that
            # holds the link; an absolute target replaces the whole path.
            link_target = os.path.join(os.path.dirname(member_path), member.linkname)
        elif member.islnk():
            # A hard-link target names another member, relative to the
            # archive root.
            link_target = os.path.join(path, member.linkname)
        else:
            continue

        if not _is_within_directory(path, link_target):
            raise ValueError("Attempted Path Traversal in Tar File")

    tar.extractall(path, members, numeric_owner=numeric_owner)


def extract_packages_archive(work_dir):
package_file = os.path.join(work_dir, "packages.tar")
if not os.path.exists(package_file):
Expand All @@ -53,8 +68,7 @@ def extract_packages_archive(work_dir):

os.chdir(work_dir)
tar = tarfile.open(package_file)
for tarinfo in tar:
tar.extract(tarinfo)
_safe_extract(tar)
tar.close()
if '' not in sys.path:
sys.path.insert(0, '')
Expand Down
43 changes: 43 additions & 0 deletions test/contrib/lsf_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@

import luigi
from luigi.contrib.lsf import LSFJobTask
import tarfile
import tempfile
import shutil

import pytest

Expand Down Expand Up @@ -103,5 +106,45 @@ def tearDown(self):
pass


class TestSafeExtract(unittest.TestCase):
    """Tests for the tar extraction guard ``_safe_extract`` in ``luigi.contrib.lsf_runner``."""

    def setUp(self):
        # Scratch directory holding the archive, its source files and the
        # extraction target; removed again in tearDown.
        self.temp_dir = tempfile.mkdtemp()

    def tearDown(self):
        shutil.rmtree(self.temp_dir)

    def test_safe_extract(self):
        """A well-formed archive is extracted with its contents intact."""
        # Local import: the helper is private to lsf_runner and is not among
        # the imports visible at the top of this file.
        from luigi.contrib.lsf_runner import _safe_extract

        # Keep source files and extracted files in different directories so
        # the assertions below can only pass if extraction really happened.
        source_dir = os.path.join(self.temp_dir, 'source')
        extract_dir = os.path.join(self.temp_dir, 'extracted')
        os.mkdir(source_dir)
        os.mkdir(extract_dir)

        tar_path = os.path.join(self.temp_dir, 'test.tar')
        with tarfile.open(tar_path, 'w') as tar:
            for i in range(3):
                file_path = os.path.join(source_dir, f'test_file_{i}.txt')
                with open(file_path, 'w') as f:
                    f.write(f'This is test file {i}')
                tar.add(file_path, arcname=f'test_file_{i}.txt')

        with tarfile.open(tar_path, 'r') as tar:
            _safe_extract(tar, extract_dir)

        for i in range(3):
            file_path = os.path.join(extract_dir, f'test_file_{i}.txt')
            self.assertTrue(os.path.exists(file_path))
            with open(file_path, 'r') as f:
                content = f.read()
            self.assertEqual(content, f'This is test file {i}')

    def test_safe_extract_with_traversal(self):
        """An archive whose member name climbs out of the target is rejected."""
        from luigi.contrib.lsf_runner import _safe_extract

        extract_dir = os.path.join(self.temp_dir, 'extracted')
        os.mkdir(extract_dir)

        tar_path = os.path.join(self.temp_dir, 'test.tar')
        with tarfile.open(tar_path, 'w') as tar:
            file_path = os.path.join(self.temp_dir, 'test_file.txt')
            with open(file_path, 'w') as f:
                f.write('This is a test file')
            # The stored name points two levels above the extraction root.
            tar.add(file_path, arcname='../../test_file.txt')

        with tarfile.open(tar_path, 'r') as tar:
            with self.assertRaises(ValueError):
                _safe_extract(tar, extract_dir)

        # Nothing may have been written into the extraction directory.
        self.assertEqual(os.listdir(extract_dir), [])


# Allow this test module to be run directly as a script.
if __name__ == '__main__':
    unittest.main()

0 comments on commit 91ce61f

Please sign in to comment.