diff --git a/src/crawlee/_types.py b/src/crawlee/_types.py index 51f9d357e7..684f64d02e 100644 --- a/src/crawlee/_types.py +++ b/src/crawlee/_types.py @@ -180,6 +180,17 @@ class AddRequestsKwargs(EnqueueLinksKwargs): requests: Sequence[str | Request] """Requests to be added to the `RequestManager`.""" + rq_id: str | None + """ID of the `RequestQueue` to add the requests to. Only one of `rq_id`, `rq_name` or `rq_alias` can be provided.""" + + rq_name: str | None + """Name of the `RequestQueue` to add the requests to. Only one of `rq_id`, `rq_name` or `rq_alias` can be provided. + """ + + rq_alias: str | None + """Alias of the `RequestQueue` to add the requests to. Only one of `rq_id`, `rq_name` or `rq_alias` can be provided. + """ + class PushDataKwargs(TypedDict): """Keyword arguments for dataset's `push_data` method.""" @@ -261,10 +272,18 @@ def __init__(self, *, key_value_store_getter: GetKeyValueStoreFunction) -> None: async def add_requests( self, requests: Sequence[str | Request], + rq_id: str | None = None, + rq_name: str | None = None, + rq_alias: str | None = None, **kwargs: Unpack[EnqueueLinksKwargs], ) -> None: """Track a call to the `add_requests` context helper.""" - self.add_requests_calls.append(AddRequestsKwargs(requests=requests, **kwargs)) + specified_params = sum(1 for param in [rq_id, rq_name, rq_alias] if param is not None) + if specified_params > 1: + raise ValueError('Only one of `rq_id`, `rq_name` or `rq_alias` can be provided.') + self.add_requests_calls.append( + AddRequestsKwargs(requests=requests, rq_id=rq_id, rq_name=rq_name, rq_alias=rq_alias, **kwargs) + ) async def push_data( self, @@ -311,12 +330,21 @@ class AddRequestsFunction(Protocol): def __call__( self, requests: Sequence[str | Request], + rq_id: str | None = None, + rq_name: str | None = None, + rq_alias: str | None = None, **kwargs: Unpack[EnqueueLinksKwargs], ) -> Coroutine[None, None, None]: """Call dunder method. Args: requests: Requests to be added to the `RequestManager`. + rq_id: ID of the `RequestQueue` to add the requests to. Only one of `rq_id`, `rq_name` or `rq_alias` can be + provided. + rq_name: Name of the `RequestQueue` to add the requests to. Only one of `rq_id`, `rq_name` or `rq_alias` + can be provided. + rq_alias: Alias of the `RequestQueue` to add the requests to. Only one of `rq_id`, `rq_name` or `rq_alias` + can be provided. **kwargs: Additional keyword arguments. """ @@ -344,12 +372,21 @@ def __call__( label: str | None = None, user_data: dict[str, Any] | None = None, transform_request_function: Callable[[RequestOptions], RequestOptions | RequestTransformAction] | None = None, + rq_id: str | None = None, + rq_name: str | None = None, + rq_alias: str | None = None, **kwargs: Unpack[EnqueueLinksKwargs], ) -> Coroutine[None, None, None]: ... @overload def __call__( - self, *, requests: Sequence[str | Request] | None = None, **kwargs: Unpack[EnqueueLinksKwargs] + self, + *, + requests: Sequence[str | Request] | None = None, + rq_id: str | None = None, + rq_name: str | None = None, + rq_alias: str | None = None, + **kwargs: Unpack[EnqueueLinksKwargs], ) -> Coroutine[None, None, None]: ... 
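Taken together, the `_types.py` changes above let a request handler push requests to a `RequestQueue` other than the crawler's default `RequestManager` by passing exactly one of `rq_id`, `rq_name` or `rq_alias`; passing more than one raises a `ValueError`. A minimal sketch of how a handler could use this, assuming the import paths used by the test files further down and an illustrative queue name `'details'` that is not taken from this diff:

```python
import asyncio

from crawlee import Request
from crawlee.crawlers import BasicCrawler, BasicCrawlingContext


async def main() -> None:
    crawler = BasicCrawler()

    @crawler.router.default_handler
    async def handler(context: BasicCrawlingContext) -> None:
        # Route follow-up requests to a separate, named queue instead of the
        # crawler's own request manager. Only one of rq_id / rq_name / rq_alias
        # may be given; passing two or more raises ValueError.
        await context.add_requests(
            [Request.from_url('https://a.placeholder.com/detail/1')],
            rq_name='details',  # illustrative queue name
        )

    await crawler.run(['https://a.placeholder.com'])


if __name__ == '__main__':
    asyncio.run(main())
```

Because the added requests land in the separate queue, the running crawler itself only processes the start URL, which is what the new unit tests below assert.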
def __call__( @@ -360,6 +397,9 @@ def __call__( user_data: dict[str, Any] | None = None, transform_request_function: Callable[[RequestOptions], RequestOptions | RequestTransformAction] | None = None, requests: Sequence[str | Request] | None = None, + rq_id: str | None = None, + rq_name: str | None = None, + rq_alias: str | None = None, **kwargs: Unpack[EnqueueLinksKwargs], ) -> Coroutine[None, None, None]: """Call enqueue links function. @@ -377,6 +417,12 @@ def __call__( - `'skip'` to exclude the request from being enqueued, - `'unchanged'` to use the original request options without modification. requests: Requests to be added to the `RequestManager`. + rq_id: ID of the `RequestQueue` to add the requests to. Only one of `rq_id`, `rq_name` or `rq_alias` can be + provided. + rq_name: Name of the `RequestQueue` to add the requests to. Only one of `rq_id`, `rq_name` or `rq_alias` + can be provided. + rq_alias: Alias of the `RequestQueue` to add the requests to. Only one of `rq_id`, `rq_name` or `rq_alias` + can be provided. **kwargs: Additional keyword arguments. """ diff --git a/src/crawlee/crawlers/_basic/_basic_crawler.py b/src/crawlee/crawlers/_basic/_basic_crawler.py index 0a6d4eae87..1d384c0455 100644 --- a/src/crawlee/crawlers/_basic/_basic_crawler.py +++ b/src/crawlee/crawlers/_basic/_basic_crawler.py @@ -944,6 +944,9 @@ async def enqueue_links( transform_request_function: Callable[[RequestOptions], RequestOptions | RequestTransformAction] | None = None, requests: Sequence[str | Request] | None = None, + rq_id: str | None = None, + rq_name: str | None = None, + rq_alias: str | None = None, **kwargs: Unpack[EnqueueLinksKwargs], ) -> None: kwargs.setdefault('strategy', 'same-hostname') @@ -955,7 +958,9 @@ async def enqueue_links( '`transform_request_function` arguments when `requests` is provided.' ) # Add directly passed requests. - await context.add_requests(requests or list[str | Request](), **kwargs) + await context.add_requests( + requests or list[str | Request](), rq_id=rq_id, rq_name=rq_name, rq_alias=rq_alias, **kwargs + ) else: # Add requests from extracted links. await context.add_requests( @@ -965,6 +970,9 @@ async def enqueue_links( user_data=user_data, transform_request_function=transform_request_function, ), + rq_id=rq_id, + rq_name=rq_name, + rq_alias=rq_alias, **kwargs, ) @@ -1241,10 +1249,28 @@ async def _commit_request_handler_result(self, context: BasicCrawlingContext) -> """Commit request handler result for the input `context`. 
Result is taken from `_context_result_map`.""" result = self._context_result_map[context] - request_manager = await self.get_request_manager() + base_request_manager = await self.get_request_manager() + origin = context.request.loaded_url or context.request.url for add_requests_call in result.add_requests_calls: + rq_id = add_requests_call.get('rq_id') + rq_name = add_requests_call.get('rq_name') + rq_alias = add_requests_call.get('rq_alias') + specified_params = sum(1 for param in [rq_id, rq_name, rq_alias] if param is not None) + if specified_params > 1: + raise ValueError('You can only provide one of `rq_id`, `rq_name` or `rq_alias` arguments.') + if rq_id or rq_name or rq_alias: + request_manager: RequestManager | RequestQueue = await RequestQueue.open( + id=rq_id, + name=rq_name, + alias=rq_alias, + storage_client=self._service_locator.get_storage_client(), + configuration=self._service_locator.get_configuration(), + ) + else: + request_manager = base_request_manager + requests = list[Request]() base_url = url if (url := add_requests_call.get('base_url')) else origin diff --git a/tests/unit/crawlers/_basic/test_basic_crawler.py b/tests/unit/crawlers/_basic/test_basic_crawler.py index 62ede11e67..7f864afbd4 100644 --- a/tests/unit/crawlers/_basic/test_basic_crawler.py +++ b/tests/unit/crawlers/_basic/test_basic_crawler.py @@ -1549,3 +1549,71 @@ def listener(event_data: EventCrawlerStatusData) -> None: event_manager.off(event=Event.CRAWLER_STATUS, listener=listener) assert status_message_listener.called + + +@pytest.mark.parametrize( + ('queue_name', 'queue_alias', 'by_id'), + [ + pytest.param('named-queue', None, False, id='with rq_name'), + pytest.param(None, 'alias-queue', False, id='with rq_alias'), + pytest.param('id-queue', None, True, id='with rq_id'), + ], +) +async def test_add_requests_with_rq_param(queue_name: str | None, queue_alias: str | None, *, by_id: bool) -> None: + crawler = BasicCrawler() + rq = await RequestQueue.open(name=queue_name, alias=queue_alias) + if by_id: + queue_id = rq.id + queue_name = None + else: + queue_id = None + visit_urls = set() + + check_requests = [ + Request.from_url('https://a.placeholder.com'), + Request.from_url('https://b.placeholder.com'), + Request.from_url('https://c.placeholder.com'), + ] + + @crawler.router.default_handler + async def handler(context: BasicCrawlingContext) -> None: + visit_urls.add(context.request.url) + await context.add_requests(check_requests, rq_id=queue_id, rq_name=queue_name, rq_alias=queue_alias) + + await crawler.run(['https://start.placeholder.com']) + + requests_from_queue = [] + while request := await rq.fetch_next_request(): + requests_from_queue.append(request) + + assert requests_from_queue == check_requests + assert visit_urls == {'https://start.placeholder.com'} + + await rq.drop() + + +@pytest.mark.parametrize( + ('queue_name', 'queue_alias', 'queue_id'), + [ + pytest.param('named-queue', 'alias-queue', None, id='rq_name and rq_alias'), + pytest.param('named-queue', None, 'id-queue', id='rq_name and rq_id'), + pytest.param(None, 'alias-queue', 'id-queue', id='rq_alias and rq_id'), + pytest.param('named-queue', 'alias-queue', 'id-queue', id='rq_name and rq_alias and rq_id'), + ], +) +async def test_add_requests_error_with_multi_params( + queue_id: str | None, queue_name: str | None, queue_alias: str | None +) -> None: + crawler = BasicCrawler() + + @crawler.router.default_handler + async def handler(context: BasicCrawlingContext) -> None: + with pytest.raises(ValueError, match='Only one of `rq_id`, 
`rq_name` or `rq_alias` can be set'): + await context.add_requests( + [Request.from_url('https://a.placeholder.com')], + rq_id=queue_id, + rq_name=queue_name, + rq_alias=queue_alias, + ) + + await crawler.run(['https://start.placeholder.com']) diff --git a/tests/unit/crawlers/_beautifulsoup/test_beautifulsoup_crawler.py b/tests/unit/crawlers/_beautifulsoup/test_beautifulsoup_crawler.py index 37f2d1b8ed..efe58665dd 100644 --- a/tests/unit/crawlers/_beautifulsoup/test_beautifulsoup_crawler.py +++ b/tests/unit/crawlers/_beautifulsoup/test_beautifulsoup_crawler.py @@ -3,8 +3,11 @@ from typing import TYPE_CHECKING from unittest import mock +import pytest + from crawlee import ConcurrencySettings, Glob, HttpHeaders, RequestTransformAction, SkippedReason from crawlee.crawlers import BeautifulSoupCrawler, BeautifulSoupCrawlingContext +from crawlee.storages import RequestQueue if TYPE_CHECKING: from yarl import URL @@ -198,3 +201,107 @@ async def request_handler(context: BeautifulSoupCrawlingContext) -> None: assert len(extracted_links) == 1 assert extracted_links[0] == str(server_url / 'page_1') + + +@pytest.mark.parametrize( + ('queue_name', 'queue_alias', 'by_id'), + [ + pytest.param('named-queue', None, False, id='with rq_name'), + pytest.param(None, 'alias-queue', False, id='with rq_alias'), + pytest.param('id-queue', None, True, id='with rq_id'), + ], +) +async def test_enqueue_links_with_rq_param( + server_url: URL, http_client: HttpClient, queue_name: str | None, queue_alias: str | None, *, by_id: bool +) -> None: + crawler = BeautifulSoupCrawler(http_client=http_client) + rq = await RequestQueue.open(name=queue_name, alias=queue_alias) + if by_id: + queue_name = None + queue_id = rq.id + else: + queue_id = None + visit_urls: set[str] = set() + + @crawler.router.default_handler + async def handler(context: BeautifulSoupCrawlingContext) -> None: + visit_urls.add(context.request.url) + await context.enqueue_links(rq_id=queue_id, rq_name=queue_name, rq_alias=queue_alias) + + await crawler.run([str(server_url / 'start_enqueue')]) + + requests_from_queue: list[str] = [] + while request := await rq.fetch_next_request(): + requests_from_queue.append(request.url) + + assert set(requests_from_queue) == {str(server_url / 'page_1'), str(server_url / 'sub_index')} + assert visit_urls == {str(server_url / 'start_enqueue')} + + await rq.drop() + + +@pytest.mark.parametrize( + ('queue_name', 'queue_alias', 'by_id'), + [ + pytest.param('named-queue', None, False, id='with rq_name'), + pytest.param(None, 'alias-queue', False, id='with rq_alias'), + pytest.param('id-queue', None, True, id='with rq_id'), + ], +) +async def test_enqueue_links_requests_with_rq_param( + server_url: URL, http_client: HttpClient, queue_name: str | None, queue_alias: str | None, *, by_id: bool +) -> None: + crawler = BeautifulSoupCrawler(http_client=http_client) + rq = await RequestQueue.open(name=queue_name, alias=queue_alias) + if by_id: + queue_name = None + queue_id = rq.id + else: + queue_id = None + visit_urls: set[str] = set() + + check_requests: list[str] = [ + 'https://a.placeholder.com', + 'https://b.placeholder.com', + 'https://c.placeholder.com', + ] + + @crawler.router.default_handler + async def handler(context: BeautifulSoupCrawlingContext) -> None: + visit_urls.add(context.request.url) + await context.enqueue_links( + requests=check_requests, rq_name=queue_name, rq_alias=queue_alias, rq_id=queue_id, strategy='all' + ) + + await crawler.run([str(server_url / 'start_enqueue')]) + + requests_from_queue: list[str] = [] 
+ while request := await rq.fetch_next_request(): + requests_from_queue.append(request.url) + + assert set(requests_from_queue) == set(check_requests) + assert visit_urls == {str(server_url / 'start_enqueue')} + + await rq.drop() + + +@pytest.mark.parametrize( + ('queue_id', 'queue_name', 'queue_alias'), + [ + pytest.param('named-queue', 'alias-queue', None, id='rq_name and rq_alias'), + pytest.param('named-queue', None, 'id-queue', id='rq_name and rq_id'), + pytest.param(None, 'alias-queue', 'id-queue', id='rq_alias and rq_id'), + pytest.param('named-queue', 'alias-queue', 'id-queue', id='rq_name and rq_alias and rq_id'), + ], +) +async def test_enqueue_links_error_with_multi_params( + server_url: URL, http_client: HttpClient, queue_id: str | None, queue_name: str | None, queue_alias: str | None +) -> None: + crawler = BeautifulSoupCrawler(http_client=http_client) + + @crawler.router.default_handler + async def handler(context: BeautifulSoupCrawlingContext) -> None: + with pytest.raises(ValueError, match='Cannot use both `rq_name` and `rq_alias`'): + await context.enqueue_links(rq_id=queue_id, rq_name=queue_name, rq_alias=queue_alias) + + await crawler.run([str(server_url / 'start_enqueue')]) diff --git a/tests/unit/crawlers/_parsel/test_parsel_crawler.py b/tests/unit/crawlers/_parsel/test_parsel_crawler.py index 909563d822..5f74b7b262 100644 --- a/tests/unit/crawlers/_parsel/test_parsel_crawler.py +++ b/tests/unit/crawlers/_parsel/test_parsel_crawler.py @@ -8,6 +8,7 @@ from crawlee import ConcurrencySettings, Glob, HttpHeaders, Request, RequestTransformAction, SkippedReason from crawlee.crawlers import ParselCrawler +from crawlee.storages import RequestQueue if TYPE_CHECKING: from yarl import URL @@ -294,3 +295,107 @@ async def request_handler(context: ParselCrawlingContext) -> None: assert len(extracted_links) == 1 assert extracted_links[0] == str(server_url / 'page_1') + + +@pytest.mark.parametrize( + ('queue_name', 'queue_alias', 'by_id'), + [ + pytest.param('named-queue', None, False, id='with rq_name'), + pytest.param(None, 'alias-queue', False, id='with rq_alias'), + pytest.param('id-queue', None, True, id='with rq_id'), + ], +) +async def test_enqueue_links_with_rq_param( + server_url: URL, http_client: HttpClient, queue_name: str | None, queue_alias: str | None, *, by_id: bool +) -> None: + crawler = ParselCrawler(http_client=http_client) + rq = await RequestQueue.open(name=queue_name, alias=queue_alias) + if by_id: + queue_name = None + queue_id = rq.id + else: + queue_id = None + visit_urls: set[str] = set() + + @crawler.router.default_handler + async def handler(context: ParselCrawlingContext) -> None: + visit_urls.add(context.request.url) + await context.enqueue_links(rq_id=queue_id, rq_name=queue_name, rq_alias=queue_alias) + + await crawler.run([str(server_url / 'start_enqueue')]) + + requests_from_queue: list[str] = [] + while request := await rq.fetch_next_request(): + requests_from_queue.append(request.url) + + assert set(requests_from_queue) == {str(server_url / 'page_1'), str(server_url / 'sub_index')} + assert visit_urls == {str(server_url / 'start_enqueue')} + + await rq.drop() + + +@pytest.mark.parametrize( + ('queue_name', 'queue_alias', 'by_id'), + [ + pytest.param('named-queue', None, False, id='with rq_name'), + pytest.param(None, 'alias-queue', False, id='with rq_alias'), + pytest.param('id-queue', None, True, id='with rq_id'), + ], +) +async def test_enqueue_links_requests_with_rq_param( + server_url: URL, http_client: HttpClient, queue_name: str | None, 
queue_alias: str | None, *, by_id: bool +) -> None: + crawler = ParselCrawler(http_client=http_client) + rq = await RequestQueue.open(name=queue_name, alias=queue_alias) + if by_id: + queue_name = None + queue_id = rq.id + else: + queue_id = None + visit_urls: set[str] = set() + + check_requests: list[str] = [ + 'https://a.placeholder.com', + 'https://b.placeholder.com', + 'https://c.placeholder.com', + ] + + @crawler.router.default_handler + async def handler(context: ParselCrawlingContext) -> None: + visit_urls.add(context.request.url) + await context.enqueue_links( + requests=check_requests, rq_id=queue_id, rq_name=queue_name, rq_alias=queue_alias, strategy='all' + ) + + await crawler.run([str(server_url / 'start_enqueue')]) + + requests_from_queue: list[str] = [] + while request := await rq.fetch_next_request(): + requests_from_queue.append(request.url) + + assert set(requests_from_queue) == set(check_requests) + assert visit_urls == {str(server_url / 'start_enqueue')} + + await rq.drop() + + +@pytest.mark.parametrize( + ('queue_id', 'queue_name', 'queue_alias'), + [ + pytest.param('named-queue', 'alias-queue', None, id='rq_name and rq_alias'), + pytest.param('named-queue', None, 'id-queue', id='rq_name and rq_id'), + pytest.param(None, 'alias-queue', 'id-queue', id='rq_alias and rq_id'), + pytest.param('named-queue', 'alias-queue', 'id-queue', id='rq_name and rq_alias and rq_id'), + ], +) +async def test_enqueue_links_error_with_multi_params( + server_url: URL, http_client: HttpClient, queue_id: str | None, queue_name: str | None, queue_alias: str | None +) -> None: + crawler = ParselCrawler(http_client=http_client) + + @crawler.router.default_handler + async def handler(context: ParselCrawlingContext) -> None: + with pytest.raises(ValueError, match='Cannot use both `rq_name` and `rq_alias`'): + await context.enqueue_links(rq_id=queue_id, rq_name=queue_name, rq_alias=queue_alias) + + await crawler.run([str(server_url / 'start_enqueue')]) diff --git a/tests/unit/crawlers/_playwright/test_playwright_crawler.py b/tests/unit/crawlers/_playwright/test_playwright_crawler.py index 2f52cac163..0bde7d55fd 100644 --- a/tests/unit/crawlers/_playwright/test_playwright_crawler.py +++ b/tests/unit/crawlers/_playwright/test_playwright_crawler.py @@ -37,6 +37,7 @@ from crawlee.sessions import Session, SessionPool from crawlee.statistics import Statistics from crawlee.statistics._error_snapshotter import ErrorSnapshotter +from crawlee.storages import RequestQueue from tests.unit.server_endpoints import GENERIC_RESPONSE, HELLO_WORLD if TYPE_CHECKING: @@ -784,3 +785,107 @@ async def test_reduced_logs_from_playwright_navigation_timeout(caplog: pytest.Lo break assert found_playwright_message, 'Expected log message about request handler error was not found.' 
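The `enqueue_links` tests above and below all follow the same two-step shape: discovered links are parked in a separate queue during one run and drained afterwards. A rough end-to-end sketch of that pattern with `ParselCrawler`; the start URL is a placeholder, and binding the second crawler to the queue via a `request_manager` argument is an assumption of this sketch rather than something introduced by the diff:

```python
import asyncio

from crawlee.crawlers import ParselCrawler, ParselCrawlingContext
from crawlee.storages import RequestQueue


async def main() -> None:
    # Stage 1: discover links and park them in a separate, aliased queue.
    listing_crawler = ParselCrawler()

    @listing_crawler.router.default_handler
    async def discover(context: ParselCrawlingContext) -> None:
        # Links go to the 'details' queue instead of the crawler's own queue,
        # so this crawler finishes after visiting only the start URL.
        await context.enqueue_links(rq_alias='details')

    await listing_crawler.run(['https://start.placeholder.com'])  # placeholder start URL

    # Stage 2: process the parked requests with a second crawler bound to that
    # queue. Passing it as `request_manager` is an assumption of this sketch.
    details_queue = await RequestQueue.open(alias='details')
    detail_crawler = ParselCrawler(request_manager=details_queue)

    @detail_crawler.router.default_handler
    async def process(context: ParselCrawlingContext) -> None:
        context.log.info(f'Processing {context.request.url}')

    await detail_crawler.run()
    await details_queue.drop()


if __name__ == '__main__':
    asyncio.run(main())
```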
+ + +@pytest.mark.parametrize( + ('queue_name', 'queue_alias', 'by_id'), + [ + pytest.param('named-queue', None, False, id='with rq_name'), + pytest.param(None, 'alias-queue', False, id='with rq_alias'), + pytest.param('id-queue', None, True, id='with rq_id'), + ], +) +async def test_enqueue_links_with_rq_param( + server_url: URL, queue_name: str | None, queue_alias: str | None, *, by_id: bool +) -> None: + crawler = PlaywrightCrawler() + rq = await RequestQueue.open(name=queue_name, alias=queue_alias) + if by_id: + queue_name = None + queue_id = rq.id + else: + queue_id = None + visit_urls: set[str] = set() + + @crawler.router.default_handler + async def handler(context: PlaywrightCrawlingContext) -> None: + visit_urls.add(context.request.url) + await context.enqueue_links(rq_id=queue_id, rq_name=queue_name, rq_alias=queue_alias) + + await crawler.run([str(server_url / 'start_enqueue')]) + + requests_from_queue: list[str] = [] + while request := await rq.fetch_next_request(): + requests_from_queue.append(request.url) + + assert set(requests_from_queue) == {str(server_url / 'page_1'), str(server_url / 'sub_index')} + assert visit_urls == {str(server_url / 'start_enqueue')} + + await rq.drop() + + +@pytest.mark.parametrize( + ('queue_name', 'queue_alias', 'by_id'), + [ + pytest.param('named-queue', None, False, id='with rq_name'), + pytest.param(None, 'alias-queue', False, id='with rq_alias'), + pytest.param('id-queue', None, True, id='with rq_id'), + ], +) +async def test_enqueue_links_requests_with_rq_param( + server_url: URL, queue_name: str | None, queue_alias: str | None, *, by_id: bool +) -> None: + crawler = PlaywrightCrawler() + rq = await RequestQueue.open(name=queue_name, alias=queue_alias) + if by_id: + queue_name = None + queue_id = rq.id + else: + queue_id = None + visit_urls: set[str] = set() + + check_requests: list[str] = [ + 'https://a.placeholder.com', + 'https://b.placeholder.com', + 'https://c.placeholder.com', + ] + + @crawler.router.default_handler + async def handler(context: PlaywrightCrawlingContext) -> None: + visit_urls.add(context.request.url) + await context.enqueue_links( + requests=check_requests, rq_id=queue_id, rq_name=queue_name, rq_alias=queue_alias, strategy='all' + ) + + await crawler.run([str(server_url / 'start_enqueue')]) + + requests_from_queue: list[str] = [] + while request := await rq.fetch_next_request(): + requests_from_queue.append(request.url) + + assert set(requests_from_queue) == set(check_requests) + assert visit_urls == {str(server_url / 'start_enqueue')} + + await rq.drop() + + +@pytest.mark.parametrize( + ('queue_id', 'queue_name', 'queue_alias'), + [ + pytest.param('named-queue', 'alias-queue', None, id='rq_name and rq_alias'), + pytest.param('named-queue', None, 'id-queue', id='rq_name and rq_id'), + pytest.param(None, 'alias-queue', 'id-queue', id='rq_alias and rq_id'), + pytest.param('named-queue', 'alias-queue', 'id-queue', id='rq_name and rq_alias and rq_id'), + ], +) +async def test_enqueue_links_error_with_multi_params( + server_url: URL, queue_id: str | None, queue_name: str | None, queue_alias: str | None +) -> None: + crawler = PlaywrightCrawler() + + @crawler.router.default_handler + async def handler(context: PlaywrightCrawlingContext) -> None: + with pytest.raises(ValueError, match='Cannot use both `rq_name` and `rq_alias`'): + await context.enqueue_links(rq_id=queue_id, rq_name=queue_name, rq_alias=queue_alias) + + await crawler.run([str(server_url / 'start_enqueue')])
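Finally, the three targeting modes exercised by the error tests are mutually exclusive. A short illustration of the rule, reusing the queue names from the tests; the second `add_requests` call is expected to fail with the `ValueError` raised by the validation added in this diff:

```python
import asyncio

from crawlee import Request
from crawlee.crawlers import BasicCrawler, BasicCrawlingContext
from crawlee.storages import RequestQueue


async def main() -> None:
    rq = await RequestQueue.open(name='named-queue')
    crawler = BasicCrawler()

    @crawler.router.default_handler
    async def handler(context: BasicCrawlingContext) -> None:
        request = Request.from_url('https://a.placeholder.com')

        # Any single selector is fine: by id (obtained from an opened queue),
        # by name, or by alias.
        await context.add_requests([request], rq_id=rq.id)

        # Combining selectors is rejected by the validation added in this diff.
        try:
            await context.add_requests([request], rq_name='named-queue', rq_alias='alias-queue')
        except ValueError as exc:
            context.log.warning(f'Rejected: {exc}')

    await crawler.run(['https://start.placeholder.com'])
    await rq.drop()


if __name__ == '__main__':
    asyncio.run(main())
```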