diff --git a/tests/unit/forklift/test_legacy.py b/tests/unit/forklift/test_legacy.py index b2f9ab3479bf..c0d79f9f333b 100644 --- a/tests/unit/forklift/test_legacy.py +++ b/tests/unit/forklift/test_legacy.py @@ -20,7 +20,6 @@ from cgi import FieldStorage from unittest import mock -import pkg_resources import pretend import pytest import requests @@ -2354,7 +2353,7 @@ def test_upload_fails_with_diff_filename_same_blake2( "filetype": "sdist", "md5_digest": hashlib.md5(file_content.getvalue()).hexdigest(), "content": pretend.stub( - filename="{}-fake.tar.gz".format(project.name), + filename="{}-0.1.tar.gz".format(project.name), file=file_content, type="application/tar", ), @@ -2388,7 +2387,18 @@ def test_upload_fails_with_diff_filename_same_blake2( "400 File already exists. See /the/help/url/ for more information." ) - def test_upload_fails_with_wrong_filename(self, pyramid_config, db_request): + @pytest.mark.parametrize( + "filename", + [ + "nope-{version}.tar.gz", + "nope-{version}-py3-none-any.whl", + "nope-notaversion.tar.gz", + "nope-notaversion-py3-none-any.whl", + ], + ) + def test_upload_fails_with_wrong_filename( + self, pyramid_config, db_request, filename + ): pyramid_config.testing_securitypolicy(userid=1) user = UserFactory.create() @@ -2398,7 +2408,7 @@ def test_upload_fails_with_wrong_filename(self, pyramid_config, db_request): release = ReleaseFactory.create(project=project, version="1.0") RoleFactory.create(user=user, project=project) - filename = "nope-{}.tar.gz".format(release.version) + filename = filename.format(version=release.version) db_request.POST = MultiDict( { @@ -2422,8 +2432,8 @@ def test_upload_fails_with_wrong_filename(self, pyramid_config, db_request): assert resp.status_code == 400 assert resp.status == ( - "400 Start filename for {!r} with {!r}.".format( - project.name, pkg_resources.safe_name(project.name).lower() + "400 Filename {!r} must match project {!r}.".format( + filename, project.normalized_name ) ) diff --git a/warehouse/forklift/legacy.py b/warehouse/forklift/legacy.py index 5542d09090c2..35e825668792 100644 --- a/warehouse/forklift/legacy.py +++ b/warehouse/forklift/legacy.py @@ -26,7 +26,6 @@ import packaging.specifiers import packaging.utils import packaging.version -import pkg_resources import requests import stdlib_list import wtforms @@ -628,6 +627,22 @@ def full_validate(self): ) +def _is_valid_filename(filename, specified_normalized_name): + if filename.endswith(".whl"): + parse_func = packaging.utils.parse_wheel_filename + else: + parse_func = packaging.utils.parse_sdist_filename + try: + parsed_parts = parse_func(filename) + except ( + packaging.utils.InvalidSdistFilename, + packaging.utils.InvalidWheelFilename, + packaging.version.InvalidVersion, + ): + return False + return parsed_parts[0] == specified_normalized_name + + _safe_zipnames = re.compile(r"(purelib|platlib|headers|scripts|data).+", re.I) # .tar uncompressed, .tar.gz .tgz, .tar.bz2 .tbz2 _tar_filenames_re = re.compile(r"\.(?:tar$|t(?:ar\.)?(?Pgz|bz2)$)") @@ -1194,11 +1209,11 @@ def file_upload(request): # Make sure that our filename matches the project that it is being uploaded # to. - prefix = pkg_resources.safe_name(project.name).lower() - if not pkg_resources.safe_name(filename).lower().startswith(prefix): + normalized_name = project.normalized_name + if not _is_valid_filename(filename, normalized_name): raise _exc_with_message( HTTPBadRequest, - "Start filename for {!r} with {!r}.".format(project.name, prefix), + "Filename {!r} must match project {!r}.".format(filename, normalized_name), ) # Check the content type of what is being uploaded