diff --git a/tests/unit/forklift/test_legacy.py b/tests/unit/forklift/test_legacy.py index b99a4113925b..0b9817abae91 100644 --- a/tests/unit/forklift/test_legacy.py +++ b/tests/unit/forklift/test_legacy.py @@ -2310,16 +2310,37 @@ def test_upload_fails_with_diff_filename_same_blake2( "400 File already exists. See /the/help/url/ for more information." ) - def test_upload_fails_with_wrong_filename(self, pyramid_config, db_request): + @pytest.mark.parametrize( + "filename, project_name", + [ + # completely different + ("nope-{version}.tar.gz", "something_else"), + ("nope-{version}-py3-none-any.whl", "something_else"), + # starts with same prefix + ("nope-{version}.tar.gz", "no"), + ("nope-{version}-py3-none-any.whl", "no"), + # starts with same prefix with hyphen + ("no-way-{version}.tar.gz", "no"), + ("no_way-{version}-py3-none-any.whl", "no"), + ], + ) + def test_upload_fails_with_wrong_filename( + self, pyramid_config, db_request, metrics, filename, project_name + ): user = UserFactory.create() pyramid_config.testing_securitypolicy(identity=user) db_request.user = user + db_request.user_agent = "warehouse-tests/6.6.6" EmailFactory.create(user=user) - project = ProjectFactory.create() + project = ProjectFactory.create(name=project_name) release = ReleaseFactory.create(project=project, version="1.0") RoleFactory.create(user=user, project=project) - filename = f"nope-{release.version}.tar.gz" + storage_service = pretend.stub(store=lambda path, filepath, meta: None) + db_request.find_service = lambda svc, name=None, context=None: { + IFileStorage: storage_service, + IMetricsService: metrics, + }.get(svc) db_request.POST = MultiDict( { @@ -2327,14 +2348,15 @@ def test_upload_fails_with_wrong_filename(self, pyramid_config, db_request): "name": project.name, "version": release.version, "filetype": "sdist", - "md5_digest": "nope!", + "md5_digest": _TAR_GZ_PKG_MD5, "content": pretend.stub( - filename=filename, - file=io.BytesIO(b"a" * (legacy.MAX_FILESIZE + 1)), + filename=filename.format(version=release.version), + file=io.BytesIO(_TAR_GZ_PKG_TESTDATA), type="application/tar", ), } ) + db_request.help_url = lambda **kw: "/the/help/url/" with pytest.raises(HTTPBadRequest) as excinfo: legacy.file_upload(db_request) diff --git a/warehouse/forklift/legacy.py b/warehouse/forklift/legacy.py index f07515f055cd..3e830828c950 100644 --- a/warehouse/forklift/legacy.py +++ b/warehouse/forklift/legacy.py @@ -1208,10 +1208,20 @@ def file_upload(request): # Ensure the filename doesn't contain any characters that are too 🌶️spicy🥵 _validate_filename(filename) + # Extract the project name from the filename and normalize it. + filename_prefix = pkg_resources.safe_name( + # For wheels, the project name is normalized and won't contain hyphens, so + # we can split on the first hyphen. + filename.partition("-")[0] + if filename.endswith(".whl") + # For source releases, we know that the version should not contain any + # hypens, so we can split on the last hypen to get the project name. + else filename.rpartition("-")[0] + ).lower() + # Make sure that our filename matches the project that it is being uploaded # to. - prefix = pkg_resources.safe_name(project.name).lower() - if not pkg_resources.safe_name(filename).lower().startswith(prefix): + if (prefix := pkg_resources.safe_name(project.name).lower()) != filename_prefix: raise _exc_with_message( HTTPBadRequest, f"Start filename for {project.name!r} with {prefix!r}.",