# feat: Add max_crawl_depth option to BasicCrawler (#637)
### Description


- Implements "max crawl depth": a new `max_crawl_depth` option on `BasicCrawler`. Requests at the configured maximum depth are still processed, but links discovered there are not enqueued (see the usage sketch below).

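A minimal usage sketch of the new option (not part of the diff). The import path follows the module layout in this PR (`crawlee.basic_crawler`) and the URLs are placeholders; since `BasicCrawler` does not fetch page content itself, the handler just prints the depth and enqueues one more link:

```python
import asyncio

from crawlee.basic_crawler import BasicCrawler, BasicCrawlingContext


async def main() -> None:
    # Links discovered at depth 1 would land at depth 2 and are silently dropped.
    crawler = BasicCrawler(max_crawl_depth=1)

    @crawler.router.default_handler
    async def handler(context: BasicCrawlingContext) -> None:
        print(f'Handling {context.request.url} at depth {context.request.crawl_depth}')
        # Placeholder link; it is enqueued only while the current request's depth is below the limit.
        await context.add_requests(['https://crawlee.dev/docs'])

    await crawler.run(['https://crawlee.dev'])


if __name__ == '__main__':
    asyncio.run(main())
```

With `max_crawl_depth=1`, the start request (depth 0) and the link it adds (depth 1) are processed, while anything enqueued from the depth-1 handler is dropped.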
### Issues


- Closes: #460 

### Testing


- Added tests

### Checklist

- [ ] CI passed
Prathamesh010 authored Nov 4, 2024
1 parent fcf7f5e commit 77deaa9
Showing 3 changed files with 63 additions and 8 deletions.
#### src/crawlee/_request.py (12 additions, 0 deletions)

@@ -62,6 +62,9 @@ class CrawleeRequestData(BaseModel):
     forefront: Annotated[bool, Field()] = False
     """Indicate whether the request should be enqueued at the front of the queue."""
 
+    crawl_depth: Annotated[int, Field(alias='crawlDepth')] = 0
+    """The depth of the request in the crawl tree."""
+
 
 class UserData(BaseModel, MutableMapping[str, JsonSerializable]):
     """Represents the `user_data` part of a Request.
@@ -360,6 +363,15 @@ def crawlee_data(self) -> CrawleeRequestData:
 
         return user_data.crawlee_data
 
+    @property
+    def crawl_depth(self) -> int:
+        """The depth of the request in the crawl tree."""
+        return self.crawlee_data.crawl_depth
+
+    @crawl_depth.setter
+    def crawl_depth(self, new_value: int) -> None:
+        self.crawlee_data.crawl_depth = new_value
+
     @property
     def state(self) -> RequestState | None:
         """Crawlee-specific request handling state."""
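As a quick illustration of the property pair added above (a sketch, not part of the commit; it assumes `Request` is importable from the package root):

```python
from crawlee import Request

request = Request.from_url('https://example.com')
assert request.crawl_depth == 0  # new requests default to depth 0

# The setter writes through to CrawleeRequestData.crawl_depth, serialized under the 'crawlDepth' alias.
request.crawl_depth = 3
assert request.crawlee_data.crawl_depth == 3
```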
#### src/crawlee/basic_crawler/_basic_crawler.py (22 additions, 8 deletions)

@@ -120,6 +120,10 @@ class BasicCrawlerOptions(TypedDict, Generic[TCrawlingContext]):
     configure_logging: NotRequired[bool]
     """If True, the crawler will set up logging infrastructure automatically."""
 
+    max_crawl_depth: NotRequired[int | None]
+    """Limits crawl depth from 0 (initial requests) up to the specified `max_crawl_depth`.
+    Requests at the maximum depth are processed, but no further links are enqueued."""
+
     _context_pipeline: NotRequired[ContextPipeline[TCrawlingContext]]
     """Enables extending the request lifecycle and modifying the crawling context. Intended for use by
     subclasses rather than direct instantiation of `BasicCrawler`."""
@@ -174,6 +178,7 @@ def __init__(
         statistics: Statistics | None = None,
         event_manager: EventManager | None = None,
         configure_logging: bool = True,
+        max_crawl_depth: int | None = None,
         _context_pipeline: ContextPipeline[TCrawlingContext] | None = None,
         _additional_context_managers: Sequence[AsyncContextManager] | None = None,
         _logger: logging.Logger | None = None,
@@ -201,6 +206,7 @@ def __init__(
             statistics: A custom `Statistics` instance, allowing the use of non-default configuration.
             event_manager: A custom `EventManager` instance, allowing the use of non-default configuration.
            configure_logging: If True, the crawler will set up logging infrastructure automatically.
+            max_crawl_depth: Maximum crawl depth. If set, the crawler will stop crawling after reaching this depth.
             _context_pipeline: Enables extending the request lifecycle and modifying the crawling context.
                 Intended for use by subclasses rather than direct instantiation of `BasicCrawler`.
             _additional_context_managers: Additional context managers used throughout the crawler lifecycle.
@@ -283,6 +289,7 @@ def __init__(
 
         self._running = False
         self._has_finished_before = False
+        self._max_crawl_depth = max_crawl_depth
 
     @property
     def log(self) -> logging.Logger:
@@ -841,14 +848,21 @@ async def _commit_request_handler_result(
                 else:
                     dst_request = Request.from_base_request_data(request)
 
-                if self._check_enqueue_strategy(
-                    add_requests_call.get('strategy', EnqueueStrategy.ALL),
-                    target_url=urlparse(dst_request.url),
-                    origin_url=urlparse(origin),
-                ) and self._check_url_patterns(
-                    dst_request.url,
-                    add_requests_call.get('include', None),
-                    add_requests_call.get('exclude', None),
+                # Update the crawl depth of the request.
+                dst_request.crawl_depth = context.request.crawl_depth + 1
+
+                if (
+                    (self._max_crawl_depth is None or dst_request.crawl_depth <= self._max_crawl_depth)
+                    and self._check_enqueue_strategy(
+                        add_requests_call.get('strategy', EnqueueStrategy.ALL),
+                        target_url=urlparse(dst_request.url),
+                        origin_url=urlparse(origin),
+                    )
+                    and self._check_url_patterns(
+                        dst_request.url,
+                        add_requests_call.get('include', None),
+                        add_requests_call.get('exclude', None),
+                    )
                 ):
                     requests.append(dst_request)
 
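The gate added to `_commit_request_handler_result` reduces to a single depth check: a discovered link inherits its parent's depth plus one and is enqueued only if that value does not exceed `max_crawl_depth` (or if no limit is set). A simplified standalone model of that check, using a hypothetical helper name:

```python
def may_enqueue(parent_depth: int, max_crawl_depth: int | None) -> bool:
    """Model of the depth condition from the diff: a discovered link gets parent depth + 1."""
    child_depth = parent_depth + 1
    return max_crawl_depth is None or child_depth <= max_crawl_depth


assert may_enqueue(parent_depth=0, max_crawl_depth=1)      # child at depth 1 is still enqueued
assert not may_enqueue(parent_depth=1, max_crawl_depth=1)  # child would be at depth 2 and is dropped
assert may_enqueue(parent_depth=10, max_crawl_depth=None)  # no limit configured
```

Because the depth check is the first operand of the `and` chain, the existing strategy and URL-pattern checks only run for requests that pass it; requests that fail are simply never appended to the batch handed to the request provider.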
#### tests/unit/basic_crawler/test_basic_crawler.py (29 additions, 0 deletions)

@@ -681,6 +681,35 @@ async def handler(context: BasicCrawlingContext) -> None:
     assert stats.requests_finished == 3
 
+
+async def test_max_crawl_depth(httpbin: str) -> None:
+    processed_urls = []
+
+    start_request = Request.from_url('https://someplace.com/', label='start')
+    start_request.crawl_depth = 2
+
+    # Set max_concurrency to 1 so requests are processed in a deterministic order while testing max_crawl_depth.
+    crawler = BasicCrawler(
+        concurrency_settings=ConcurrencySettings(max_concurrency=1),
+        max_crawl_depth=2,
+        request_provider=RequestList([start_request]),
+    )
+
+    @crawler.router.handler('start')
+    async def start_handler(context: BasicCrawlingContext) -> None:
+        processed_urls.append(context.request.url)
+        await context.add_requests(['https://someplace.com/too-deep'])
+
+    @crawler.router.default_handler
+    async def handler(context: BasicCrawlingContext) -> None:
+        processed_urls.append(context.request.url)
+
+    stats = await crawler.run()
+
+    assert len(processed_urls) == 1
+    assert stats.requests_total == 1
+    assert stats.requests_finished == 1
+
 
 def test_crawler_log() -> None:
     crawler = BasicCrawler()
     assert isinstance(crawler.log, logging.Logger)
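The test pins the start request at `crawl_depth = 2`, equal to `max_crawl_depth`, so the request itself is still handled, while the link added from its handler would land at depth 3 and is never enqueued. That is why exactly one URL is processed and the statistics report a single finished request.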
