diff --git a/tests/conftest.py b/tests/conftest.py index 4ad8ae227e7a..443f22accf9f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,6 +14,7 @@ import os.path import xmlrpc.client +from collections import defaultdict from contextlib import contextmanager from unittest import mock @@ -75,13 +76,13 @@ def metrics(): class _Services: def __init__(self): - self._services = {} + self._services = defaultdict(lambda: defaultdict(dict)) - def register_service(self, iface, context, service_obj): - self._services[(iface, context)] = service_obj + def register_service(self, service_obj, iface=None, context=None, name=""): + self._services[iface][context][name] = service_obj - def find_service(self, iface, context): - return self._services[(iface, context)] + def find_service(self, iface=None, context=None, name=""): + return self._services[iface][context][name] @pytest.fixture @@ -89,7 +90,7 @@ def pyramid_services(metrics): services = _Services() # Register our global services. - services.register_service(IMetricsService, None, metrics) + services.register_service(metrics, IMetricsService, None, name="") return services diff --git a/tests/functional/manage/test_views.py b/tests/functional/manage/test_views.py index 4da267ab4167..dec2e80d2ba1 100644 --- a/tests/functional/manage/test_views.py +++ b/tests/functional/manage/test_views.py @@ -23,9 +23,9 @@ class TestManageAccount: def test_save_account(self, pyramid_services, user_service, db_request): breach_service = pretend.stub() - pyramid_services.register_service(IUserService, None, user_service) + pyramid_services.register_service(user_service, IUserService, None) pyramid_services.register_service( - IPasswordBreachedService, None, breach_service + breach_service, IPasswordBreachedService, None ) user = UserFactory.create(name="old name") EmailFactory.create(primary=True, verified=True, public=True, user=user) diff --git a/tests/unit/accounts/test_core.py b/tests/unit/accounts/test_core.py index f3f369c31bbf..cc26a917b849 100644 --- a/tests/unit/accounts/test_core.py +++ b/tests/unit/accounts/test_core.py @@ -35,9 +35,9 @@ class TestLogin: def test_invalid_route(self, pyramid_request, pyramid_services): service = pretend.stub(find_userid=pretend.call_recorder(lambda username: None)) - pyramid_services.register_service(IUserService, None, service) + pyramid_services.register_service(service, IUserService, None) pyramid_services.register_service( - IPasswordBreachedService, None, pretend.stub() + pretend.stub(), IPasswordBreachedService, None ) pyramid_request.matched_route = pretend.stub(name="route_name") assert accounts._basic_auth_login("myuser", "mypass", pyramid_request) is None @@ -45,9 +45,9 @@ def test_invalid_route(self, pyramid_request, pyramid_services): def test_with_no_user(self, pyramid_request, pyramid_services): service = pretend.stub(find_userid=pretend.call_recorder(lambda username: None)) - pyramid_services.register_service(IUserService, None, service) + pyramid_services.register_service(service, IUserService, None) pyramid_services.register_service( - IPasswordBreachedService, None, pretend.stub() + pretend.stub(), IPasswordBreachedService, None ) pyramid_request.matched_route = pretend.stub(name="forklift.legacy.file_upload") assert accounts._basic_auth_login("myuser", "mypass", pyramid_request) is None @@ -63,9 +63,9 @@ def test_with_invalid_password(self, pyramid_request, pyramid_services): ), is_disabled=pretend.call_recorder(lambda user_id: (False, None)), ) - pyramid_services.register_service(IUserService, None, service) + pyramid_services.register_service(service, IUserService, None) pyramid_services.register_service( - IPasswordBreachedService, None, pretend.stub() + pretend.stub(), IPasswordBreachedService, None ) pyramid_request.matched_route = pretend.stub(name="forklift.legacy.file_upload") assert accounts._basic_auth_login("myuser", "mypass", pyramid_request) is None @@ -86,9 +86,9 @@ def test_with_disabled_user_no_reason(self, pyramid_request, pyramid_services): ), is_disabled=pretend.call_recorder(lambda user_id: (True, None)), ) - pyramid_services.register_service(IUserService, None, service) + pyramid_services.register_service(service, IUserService, None) pyramid_services.register_service( - IPasswordBreachedService, None, pretend.stub() + pretend.stub(), IPasswordBreachedService, None ) pyramid_request.matched_route = pretend.stub(name="forklift.legacy.file_upload") assert accounts._basic_auth_login("myuser", "mypass", pyramid_request) is None @@ -111,11 +111,11 @@ def test_with_disabled_user_compromised_pw(self, pyramid_request, pyramid_servic lambda user_id: (True, DisableReason.CompromisedPassword) ), ) - pyramid_services.register_service(IUserService, None, service) + pyramid_services.register_service(service, IUserService, None) pyramid_services.register_service( + pretend.stub(failure_message_plain="Bad Password!"), IPasswordBreachedService, None, - pretend.stub(failure_message_plain="Bad Password!"), ) pyramid_request.matched_route = pretend.stub(name="forklift.legacy.file_upload") @@ -149,9 +149,9 @@ def test_with_valid_password(self, monkeypatch, pyramid_request, pyramid_service check_password=pretend.call_recorder(lambda pw, tags=None: False) ) - pyramid_services.register_service(IUserService, None, service) + pyramid_services.register_service(service, IUserService, None) pyramid_services.register_service( - IPasswordBreachedService, None, breach_service + breach_service, IPasswordBreachedService, None ) pyramid_request.matched_route = pretend.stub(name="forklift.legacy.file_upload") @@ -199,9 +199,9 @@ def test_via_basic_auth_compromised( failure_message_plain="Bad Password!", ) - pyramid_services.register_service(IUserService, None, service) + pyramid_services.register_service(service, IUserService, None) pyramid_services.register_service( - IPasswordBreachedService, None, breach_service + breach_service, IPasswordBreachedService, None ) pyramid_request.matched_route = pretend.stub(name="forklift.legacy.file_upload") diff --git a/tests/unit/accounts/test_views.py b/tests/unit/accounts/test_views.py index fdb51c71d627..aa99c8ca2a41 100644 --- a/tests/unit/accounts/test_views.py +++ b/tests/unit/accounts/test_views.py @@ -98,9 +98,9 @@ def test_get_returns_form(self, pyramid_request, pyramid_services, next_url): user_service = pretend.stub() breach_service = pretend.stub() - pyramid_services.register_service(IUserService, None, user_service) + pyramid_services.register_service(user_service, IUserService, None) pyramid_services.register_service( - IPasswordBreachedService, None, breach_service + breach_service, IPasswordBreachedService, None ) form_obj = pretend.stub() @@ -132,9 +132,9 @@ def test_post_invalid_returns_form( user_service = pretend.stub() breach_service = pretend.stub() - pyramid_services.register_service(IUserService, None, user_service) + pyramid_services.register_service(user_service, IUserService, None) pyramid_services.register_service( - IPasswordBreachedService, None, breach_service + breach_service, IPasswordBreachedService, None ) pyramid_request.method = "POST" @@ -179,9 +179,9 @@ def test_post_validate_redirects( ) breach_service = pretend.stub(check_password=lambda password, tags=None: False) - pyramid_services.register_service(IUserService, None, user_service) + pyramid_services.register_service(user_service, IUserService, None) pyramid_services.register_service( - IPasswordBreachedService, None, breach_service + breach_service, IPasswordBreachedService, None ) pyramid_request.method = "POST" @@ -272,9 +272,9 @@ def test_post_validate_no_redirects( ) breach_service = pretend.stub(check_password=lambda password, tags=None: False) - pyramid_services.register_service(IUserService, None, user_service) + pyramid_services.register_service(user_service, IUserService, None) pyramid_services.register_service( - IPasswordBreachedService, None, breach_service + breach_service, IPasswordBreachedService, None ) pyramid_request.method = "POST" @@ -1919,7 +1919,7 @@ def test_reauth(self, monkeypatch, pyramid_request, pyramid_services, next_route monkeypatch.setattr(views, "HTTPSeeOther", lambda url: response) - pyramid_services.register_service(IUserService, None, user_service) + pyramid_services.register_service(user_service, IUserService, None) pyramid_request.route_path = lambda *args, **kwargs: pretend.stub() pyramid_request.session.record_auth_timestamp = pretend.call_recorder( diff --git a/tests/unit/legacy/api/xmlrpc/test_xmlrpc.py b/tests/unit/legacy/api/xmlrpc/test_xmlrpc.py index b693a339fc6c..4e9448b02045 100644 --- a/tests/unit/legacy/api/xmlrpc/test_xmlrpc.py +++ b/tests/unit/legacy/api/xmlrpc/test_xmlrpc.py @@ -18,6 +18,7 @@ from warehouse.legacy.api.xmlrpc import views as xmlrpc from warehouse.packaging.models import Classifier +from warehouse.rate_limiting.interfaces import IRateLimiter from .....common.db.accounts import UserFactory from .....common.db.packaging import ( @@ -29,6 +30,81 @@ ) +class TestRateLimiting: + def test_ratelimiting_pass(self, pyramid_services, pyramid_request, metrics): + def view(context, request): + return None + + ratelimited_view = xmlrpc.ratelimit()(view) + context = pretend.stub() + pyramid_request.remote_addr = "127.0.0.1" + fake_rate_limiter = pretend.stub( + test=lambda *a: True, hit=lambda *a: True, resets_in=lambda *a: None + ) + pyramid_services.register_service( + fake_rate_limiter, IRateLimiter, None, name="xmlrpc.client" + ) + ratelimited_view(context, pyramid_request) + + assert metrics.increment.calls == [ + pretend.call("warehouse.xmlrpc.ratelimiter.hit", tags=[]) + ] + + def test_ratelimiting_block(self, pyramid_services, pyramid_request, metrics): + def view(context, request): + return None + + ratelimited_view = xmlrpc.ratelimit()(view) + context = pretend.stub() + pyramid_request.remote_addr = "127.0.0.1" + fake_rate_limiter = pretend.stub( + test=lambda *a: False, hit=lambda *a: True, resets_in=lambda *a: None + ) + pyramid_services.register_service( + fake_rate_limiter, IRateLimiter, None, name="xmlrpc.client" + ) + with pytest.raises(xmlrpc.XMLRPCWrappedError) as exc: + ratelimited_view(context, pyramid_request) + + assert exc.value.faultString == ( + "HTTPTooManyRequests: The action could not be performed because there " + "were too many requests by the client." + ) + + assert metrics.increment.calls == [ + pretend.call("warehouse.xmlrpc.ratelimiter.exceeded", tags=[]) + ] + + def test_ratelimiting_block_with_hint( + self, pyramid_services, pyramid_request, metrics + ): + def view(context, request): + return None + + ratelimited_view = xmlrpc.ratelimit()(view) + context = pretend.stub() + pyramid_request.remote_addr = "127.0.0.1" + fake_rate_limiter = pretend.stub( + test=lambda *a: False, + hit=lambda *a: True, + resets_in=lambda *a: datetime.timedelta(minutes=11, seconds=6.9), + ) + pyramid_services.register_service( + fake_rate_limiter, IRateLimiter, None, name="xmlrpc.client" + ) + with pytest.raises(xmlrpc.XMLRPCWrappedError) as exc: + ratelimited_view(context, pyramid_request) + + assert exc.value.faultString == ( + "HTTPTooManyRequests: The action could not be performed because there " + "were too many requests by the client. Limit may reset in 666 seconds." + ) + + assert metrics.increment.calls == [ + pretend.call("warehouse.xmlrpc.ratelimiter.exceeded", tags=[]) + ] + + class TestSearch: def test_fails_with_invalid_operator(self, pyramid_request, metrics): with pytest.raises(xmlrpc.XMLRPCWrappedError) as exc: diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index 77ba844b1cb9..6828090df8c4 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -183,6 +183,7 @@ def __init__(self): "camo.url": "http://camo.example.com/", "pyramid.reload_assets": False, "dirs.packages": "/srv/data/pypi/packages/", + "warehouse.xmlrpc.client.ratelimit_string": "3600 per hour", } configurator_settings = other_settings.copy() @@ -228,6 +229,7 @@ def __init__(self): "site.name": "Warehouse", "token.two_factor.max_age": 300, "token.default.max_age": 21600, + "warehouse.xmlrpc.client.ratelimit_string": "3600 per hour", } if environment == config.Environment.development: @@ -298,6 +300,7 @@ def __init__(self): pretend.call("pyramid_mailer"), pretend.call("pyramid_retry"), pretend.call("pyramid_tm"), + pretend.call(".legacy.api.xmlrpc"), pretend.call(".legacy.api.xmlrpc.cache"), pretend.call("pyramid_rpc.xmlrpc"), pretend.call(".legacy.action_routing"), diff --git a/warehouse/config.py b/warehouse/config.py index a2d9f0b9225a..61397e4be3da 100644 --- a/warehouse/config.py +++ b/warehouse/config.py @@ -184,6 +184,12 @@ def configure(settings=None): maybe_set(settings, "token.email.secret", "TOKEN_EMAIL_SECRET") maybe_set(settings, "token.two_factor.secret", "TOKEN_TWO_FACTOR_SECRET") maybe_set(settings, "warehouse.xmlrpc.cache.url", "REDIS_URL") + maybe_set( + settings, + "warehouse.xmlrpc.client.ratelimit_string", + "XMLRPC_RATELIMIT_STRING", + default="3600 per hour", + ) maybe_set(settings, "token.password.max_age", "TOKEN_PASSWORD_MAX_AGE", coercer=int) maybe_set(settings, "token.email.max_age", "TOKEN_EMAIL_MAX_AGE", coercer=int) maybe_set( @@ -338,6 +344,9 @@ def configure(settings=None): ) config.include("pyramid_tm") + # Register our XMLRPC service + config.include(".legacy.api.xmlrpc") + # Register our XMLRPC cache config.include(".legacy.api.xmlrpc.cache") diff --git a/warehouse/legacy/api/xmlrpc/__init__.py b/warehouse/legacy/api/xmlrpc/__init__.py index 36ab810bdee3..4874d4b1f525 100644 --- a/warehouse/legacy/api/xmlrpc/__init__.py +++ b/warehouse/legacy/api/xmlrpc/__init__.py @@ -11,3 +11,14 @@ # limitations under the License. import warehouse.legacy.api.xmlrpc.views # noqa + +from warehouse.rate_limiting import IRateLimiter, RateLimit + + +def includeme(config): + ratelimit_string = config.registry.settings.get( + "warehouse.xmlrpc.client.ratelimit_string" + ) + config.register_service_factory( + RateLimit(ratelimit_string), IRateLimiter, name="xmlrpc.client" + ) diff --git a/warehouse/legacy/api/xmlrpc/views.py b/warehouse/legacy/api/xmlrpc/views.py index 3d51a6a7d63e..9a48fa25eb07 100644 --- a/warehouse/legacy/api/xmlrpc/views.py +++ b/warehouse/legacy/api/xmlrpc/views.py @@ -23,6 +23,7 @@ from elasticsearch_dsl import Q from packaging.utils import canonicalize_name +from pyramid.httpexceptions import HTTPTooManyRequests from pyramid.view import view_config from pyramid_rpc.mapper import MapplyViewMapper from pyramid_rpc.xmlrpc import ( @@ -45,6 +46,7 @@ Role, release_classifiers, ) +from warehouse.rate_limiting import IRateLimiter from warehouse.search.queries import SEARCH_BOOSTS # From https://stackoverflow.com/a/22273639 @@ -108,6 +110,33 @@ def wrapped(context, request): return decorator +def ratelimit(): + def decorator(f): + def wrapped(context, request): + ratelimiter = request.find_service( + IRateLimiter, name="xmlrpc.client", context=None + ) + metrics = request.find_service(IMetricsService, context=None) + ratelimiter.hit(request.remote_addr) + if not ratelimiter.test(request.remote_addr): + metrics.increment("warehouse.xmlrpc.ratelimiter.exceeded", tags=[]) + message = ( + "The action could not be performed because there were too " + "many requests by the client." + ) + _resets_in = ratelimiter.resets_in(request.remote_addr) + if _resets_in is not None: + _resets_in = int(_resets_in.total_seconds()) + message += f" Limit may reset in {_resets_in} seconds." + raise XMLRPCWrappedError(HTTPTooManyRequests(message)) + metrics.increment("warehouse.xmlrpc.ratelimiter.hit", tags=[]) + return f(context, request) + + return wrapped + + return decorator + + def xmlrpc_method(**kwargs): """ Support multiple endpoints serving the same views by chaining calls to @@ -117,7 +146,7 @@ def xmlrpc_method(**kwargs): kwargs.update( require_csrf=False, require_methods=["POST"], - decorator=(submit_xmlrpc_metrics(method=kwargs["method"]),), + decorator=(submit_xmlrpc_metrics(method=kwargs["method"]), ratelimit()), mapper=TypedMapplyViewMapper, )