diff --git a/src/crawlee/crawlers/_abstract_http/_abstract_http_crawler.py b/src/crawlee/crawlers/_abstract_http/_abstract_http_crawler.py
index 5d05098886..c110b30104 100644
--- a/src/crawlee/crawlers/_abstract_http/_abstract_http_crawler.py
+++ b/src/crawlee/crawlers/_abstract_http/_abstract_http_crawler.py
@@ -4,6 +4,7 @@
 import logging
 from abc import ABC
 from typing import TYPE_CHECKING, Any, Callable, Generic, Union
+from urllib.parse import urlparse
 
 from pydantic import ValidationError
 from typing_extensions import TypeVar
@@ -155,15 +156,21 @@ async def extract_links(
             | None = None,
             **kwargs: Unpack[EnqueueLinksKwargs],
         ) -> list[Request]:
-            kwargs.setdefault('strategy', 'same-hostname')
-
             requests = list[Request]()
             skipped = list[str]()
             base_user_data = user_data or {}
 
             robots_txt_file = await self._get_robots_txt_file_for_url(context.request.url)
 
+            strategy = kwargs.get('strategy', 'same-hostname')
+            include_globs = kwargs.get('include')
+            exclude_globs = kwargs.get('exclude')
+            limit_requests = kwargs.get('limit')
+
             for link in self._parser.find_links(parsed_content, selector=selector):
+                if limit_requests and len(requests) >= limit_requests:
+                    break
+
                 url = link
                 if not is_url_absolute(url):
                     base_url = context.request.loaded_url or context.request.url
@@ -173,26 +180,31 @@ async def extract_links(
                     skipped.append(url)
                     continue
 
-                request_options = RequestOptions(url=url, user_data={**base_user_data}, label=label)
-
-                if transform_request_function:
-                    transform_request_options = transform_request_function(request_options)
-                    if transform_request_options == 'skip':
+                if self._check_enqueue_strategy(
+                    strategy,
+                    target_url=urlparse(url),
+                    origin_url=urlparse(context.request.url),
+                ) and self._check_url_patterns(url, include_globs, exclude_globs):
+                    request_options = RequestOptions(url=url, user_data={**base_user_data}, label=label)
+
+                    if transform_request_function:
+                        transform_request_options = transform_request_function(request_options)
+                        if transform_request_options == 'skip':
+                            continue
+                        if transform_request_options != 'unchanged':
+                            request_options = transform_request_options
+
+                    try:
+                        request = Request.from_url(**request_options)
+                    except ValidationError as exc:
+                        context.log.debug(
+                            f'Skipping URL "{url}" due to invalid format: {exc}. '
+                            'This may be caused by a malformed URL or unsupported URL scheme. '
+                            'Please ensure the URL is correct and retry.'
+                        )
                         continue
-                    if transform_request_options != 'unchanged':
-                        request_options = transform_request_options
-
-                try:
-                    request = Request.from_url(**request_options)
-                except ValidationError as exc:
-                    context.log.debug(
-                        f'Skipping URL "{url}" due to invalid format: {exc}. '
-                        'This may be caused by a malformed URL or unsupported URL scheme. '
-                        'Please ensure the URL is correct and retry.'
-                    )
-                    continue
 
-                requests.append(request)
+                    requests.append(request)
 
             if skipped:
                 skipped_tasks = [
diff --git a/src/crawlee/crawlers/_playwright/_playwright_crawler.py b/src/crawlee/crawlers/_playwright/_playwright_crawler.py
index f07206e743..24d8e3537a 100644
--- a/src/crawlee/crawlers/_playwright/_playwright_crawler.py
+++ b/src/crawlee/crawlers/_playwright/_playwright_crawler.py
@@ -4,6 +4,7 @@
 import logging
 from functools import partial
 from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union
+from urllib.parse import urlparse
 
 from pydantic import ValidationError
 from typing_extensions import NotRequired, TypedDict, TypeVar
@@ -295,8 +296,6 @@ async def extract_links(
 
             The `PlaywrightCrawler` implementation of the `ExtractLinksFunction` function.
             """
-            kwargs.setdefault('strategy', 'same-hostname')
-
             requests = list[Request]()
             skipped = list[str]()
             base_user_data = user_data or {}
@@ -305,7 +304,15 @@
 
             robots_txt_file = await self._get_robots_txt_file_for_url(context.request.url)
 
+            strategy = kwargs.get('strategy', 'same-hostname')
+            include_globs = kwargs.get('include')
+            exclude_globs = kwargs.get('exclude')
+            limit_requests = kwargs.get('limit')
+
             for element in elements:
+                if limit_requests and len(requests) >= limit_requests:
+                    break
+
                 url = await element.get_attribute('href')
 
                 if url:
@@ -319,26 +326,31 @@ async def extract_links(
                         skipped.append(url)
                         continue
 
-                    request_option = RequestOptions({'url': url, 'user_data': {**base_user_data}, 'label': label})
-
-                    if transform_request_function:
-                        transform_request_option = transform_request_function(request_option)
-                        if transform_request_option == 'skip':
+                    if self._check_enqueue_strategy(
+                        strategy,
+                        target_url=urlparse(url),
+                        origin_url=urlparse(context.request.url),
+                    ) and self._check_url_patterns(url, include_globs, exclude_globs):
+                        request_option = RequestOptions({'url': url, 'user_data': {**base_user_data}, 'label': label})
+
+                        if transform_request_function:
+                            transform_request_option = transform_request_function(request_option)
+                            if transform_request_option == 'skip':
+                                continue
+                            if transform_request_option != 'unchanged':
+                                request_option = transform_request_option
+
+                        try:
+                            request = Request.from_url(**request_option)
+                        except ValidationError as exc:
+                            context.log.debug(
+                                f'Skipping URL "{url}" due to invalid format: {exc}. '
+                                'This may be caused by a malformed URL or unsupported URL scheme. '
+                                'Please ensure the URL is correct and retry.'
+                            )
                             continue
-                        if transform_request_option != 'unchanged':
-                            request_option = transform_request_option
-
-                    try:
-                        request = Request.from_url(**request_option)
-                    except ValidationError as exc:
-                        context.log.debug(
-                            f'Skipping URL "{url}" due to invalid format: {exc}. '
-                            'This may be caused by a malformed URL or unsupported URL scheme. '
-                            'Please ensure the URL is correct and retry.'
-                        )
-                        continue
 
-                    requests.append(request)
+                        requests.append(request)
 
             if skipped:
                 skipped_tasks = [
diff --git a/tests/unit/crawlers/_beautifulsoup/test_beautifulsoup_crawler.py b/tests/unit/crawlers/_beautifulsoup/test_beautifulsoup_crawler.py
index 4a0949b831..37f2d1b8ed 100644
--- a/tests/unit/crawlers/_beautifulsoup/test_beautifulsoup_crawler.py
+++ b/tests/unit/crawlers/_beautifulsoup/test_beautifulsoup_crawler.py
@@ -3,7 +3,7 @@
 from typing import TYPE_CHECKING
 from unittest import mock
 
-from crawlee import ConcurrencySettings, HttpHeaders, RequestTransformAction, SkippedReason
+from crawlee import ConcurrencySettings, Glob, HttpHeaders, RequestTransformAction, SkippedReason
 from crawlee.crawlers import BeautifulSoupCrawler, BeautifulSoupCrawlingContext
 
 if TYPE_CHECKING:
@@ -183,3 +183,18 @@ async def skipped_hook(url: str, _reason: SkippedReason) -> None:
         str(server_url / 'page_2'),
         str(server_url / 'page_3'),
     }
+
+
+async def test_extract_links(server_url: URL, http_client: HttpClient) -> None:
+    crawler = BeautifulSoupCrawler(http_client=http_client)
+    extracted_links: list[str] = []
+
+    @crawler.router.default_handler
+    async def request_handler(context: BeautifulSoupCrawlingContext) -> None:
+        links = await context.extract_links(exclude=[Glob(f'{server_url}sub_index')])
+        extracted_links.extend(request.url for request in links)
+
+    await crawler.run([str(server_url / 'start_enqueue')])
+
+    assert len(extracted_links) == 1
+    assert extracted_links[0] == str(server_url / 'page_1')
diff --git a/tests/unit/crawlers/_parsel/test_parsel_crawler.py b/tests/unit/crawlers/_parsel/test_parsel_crawler.py
index 5b66564920..0c36e6dd9b 100644
--- a/tests/unit/crawlers/_parsel/test_parsel_crawler.py
+++ b/tests/unit/crawlers/_parsel/test_parsel_crawler.py
@@ -6,7 +6,7 @@
 
 import pytest
 
-from crawlee import ConcurrencySettings, HttpHeaders, Request, RequestTransformAction, SkippedReason
+from crawlee import ConcurrencySettings, Glob, HttpHeaders, Request, RequestTransformAction, SkippedReason
 from crawlee.crawlers import ParselCrawler
 
 if TYPE_CHECKING:
@@ -279,3 +279,18 @@ async def skipped_hook(url: str, _reason: SkippedReason) -> None:
         str(server_url / 'page_2'),
         str(server_url / 'page_3'),
     }
+
+
+async def test_extract_links(server_url: URL, http_client: HttpClient) -> None:
+    crawler = ParselCrawler(http_client=http_client)
+    extracted_links: list[str] = []
+
+    @crawler.router.default_handler
+    async def request_handler(context: ParselCrawlingContext) -> None:
+        links = await context.extract_links(exclude=[Glob(f'{server_url}sub_index')])
+        extracted_links.extend(request.url for request in links)
+
+    await crawler.run([str(server_url / 'start_enqueue')])
+
+    assert len(extracted_links) == 1
+    assert extracted_links[0] == str(server_url / 'page_1')
diff --git a/tests/unit/crawlers/_playwright/test_playwright_crawler.py b/tests/unit/crawlers/_playwright/test_playwright_crawler.py
index 5b7af71bf4..bce64f4f02 100644
--- a/tests/unit/crawlers/_playwright/test_playwright_crawler.py
+++ b/tests/unit/crawlers/_playwright/test_playwright_crawler.py
@@ -11,7 +11,15 @@
 
 import pytest
 
-from crawlee import ConcurrencySettings, HttpHeaders, Request, RequestTransformAction, SkippedReason, service_locator
+from crawlee import (
+    ConcurrencySettings,
+    Glob,
+    HttpHeaders,
+    Request,
+    RequestTransformAction,
+    SkippedReason,
+    service_locator,
+)
 from crawlee.configuration import Configuration
 from crawlee.crawlers import PlaywrightCrawler
 from crawlee.fingerprint_suite import (
@@ -698,3 +706,18 @@ async def test_overwrite_configuration() -> None:
     PlaywrightCrawler(configuration=configuration)
     used_configuration = service_locator.get_configuration()
     assert used_configuration is configuration
+
+
+async def test_extract_links(server_url: URL) -> None:
+    crawler = PlaywrightCrawler()
+    extracted_links: list[str] = []
+
+    @crawler.router.default_handler
+    async def request_handler(context: PlaywrightCrawlingContext) -> None:
+        links = await context.extract_links(exclude=[Glob(f'{server_url}sub_index')])
+        extracted_links.extend(request.url for request in links)
+
+    await crawler.run([str(server_url / 'start_enqueue')])
+
+    assert len(extracted_links) == 1
+    assert extracted_links[0] == str(server_url / 'page_1')
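
A minimal usage sketch (not part of the patch) of the behaviour these changes enable: `extract_links` now applies the `strategy`, `include`/`exclude`, and `limit` keyword arguments itself, so the returned requests are already filtered before they are passed on. The start URL, glob pattern, and limit below are placeholders, and the snippet assumes crawlee is installed with the parsel extra.

import asyncio

from crawlee import Glob
from crawlee.crawlers import ParselCrawler, ParselCrawlingContext


async def main() -> None:
    crawler = ParselCrawler(max_requests_per_crawl=10)

    @crawler.router.default_handler
    async def handler(context: ParselCrawlingContext) -> None:
        # extract_links applies the enqueue strategy, include/exclude patterns,
        # and limit directly, instead of returning every link on the page.
        requests = await context.extract_links(
            strategy='same-hostname',
            exclude=[Glob('https://crawlee.dev/docs/**')],  # placeholder pattern
            limit=5,
        )
        await context.add_requests(requests)

    await crawler.run(['https://crawlee.dev'])  # placeholder start URL


if __name__ == '__main__':
    asyncio.run(main())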