diff --git a/src/crawlee/__init__.py b/src/crawlee/__init__.py index d8433a5e96..508835e008 100644 --- a/src/crawlee/__init__.py +++ b/src/crawlee/__init__.py @@ -1,6 +1,6 @@ from importlib import metadata -from ._request import Request, RequestOptions +from ._request import Request, RequestOptions, RequestState from ._service_locator import service_locator from ._types import ConcurrencySettings, EnqueueStrategy, HttpHeaders, RequestTransformAction, SkippedReason from ._utils.globs import Glob @@ -14,6 +14,7 @@ 'HttpHeaders', 'Request', 'RequestOptions', + 'RequestState', 'RequestTransformAction', 'SkippedReason', 'service_locator', diff --git a/src/crawlee/_request.py b/src/crawlee/_request.py index 1e6194dde0..c8c383915d 100644 --- a/src/crawlee/_request.py +++ b/src/crawlee/_request.py @@ -41,7 +41,7 @@ class CrawleeRequestData(BaseModel): enqueue_strategy: Annotated[EnqueueStrategy | None, Field(alias='enqueueStrategy')] = None """The strategy that was used for enqueuing the request.""" - state: RequestState | None = None + state: RequestState = RequestState.UNPROCESSED """Describes the request's current lifecycle state.""" session_rotation_count: Annotated[int | None, Field(alias='sessionRotationCount')] = None @@ -352,7 +352,7 @@ def crawl_depth(self, new_value: int) -> None: self.crawlee_data.crawl_depth = new_value @property - def state(self) -> RequestState | None: + def state(self) -> RequestState: """Crawlee-specific request handling state.""" return self.crawlee_data.state diff --git a/src/crawlee/crawlers/_abstract_http/_abstract_http_crawler.py b/src/crawlee/crawlers/_abstract_http/_abstract_http_crawler.py index 8f3b1451ba..a3e303f088 100644 --- a/src/crawlee/crawlers/_abstract_http/_abstract_http_crawler.py +++ b/src/crawlee/crawlers/_abstract_http/_abstract_http_crawler.py @@ -10,7 +10,7 @@ from pydantic import ValidationError from typing_extensions import NotRequired, TypeVar -from crawlee._request import Request, RequestOptions +from crawlee._request import Request, RequestOptions, RequestState from crawlee._utils.docs import docs_group from crawlee._utils.time import SharedTimeout from crawlee._utils.urls import to_absolute_url_iterator @@ -257,6 +257,7 @@ async def _make_http_request(self, context: BasicCrawlingContext) -> AsyncGenera timeout=remaining_timeout, ) + context.request.state = RequestState.AFTER_NAV yield HttpCrawlingContext.from_basic_crawling_context(context=context, http_response=result.http_response) async def _handle_status_code_response( diff --git a/src/crawlee/crawlers/_basic/_basic_crawler.py b/src/crawlee/crawlers/_basic/_basic_crawler.py index 5b3e797d41..2879685fda 100644 --- a/src/crawlee/crawlers/_basic/_basic_crawler.py +++ b/src/crawlee/crawlers/_basic/_basic_crawler.py @@ -1152,6 +1152,7 @@ async def _handle_request_retries( await request_manager.reclaim_request(request) else: + request.state = RequestState.ERROR await self._mark_request_as_handled(request) await self._handle_failed_request(context, error) self._statistics.record_request_processing_failure(request.unique_key) @@ -1167,8 +1168,6 @@ async def _handle_request_error(self, context: TCrawlingContext | BasicCrawlingC f'{self._internal_timeout.total_seconds()} seconds', logger=self._logger, ) - - context.request.state = RequestState.DONE except UserDefinedErrorHandlerError: context.request.state = RequestState.ERROR raise @@ -1201,8 +1200,8 @@ async def _handle_skipped_request( self, request: Request | str, reason: SkippedReason, *, need_mark: bool = False ) -> None: if need_mark and isinstance(request, Request): - await self._mark_request_as_handled(request) request.state = RequestState.SKIPPED + await self._mark_request_as_handled(request) url = request.url if isinstance(request, Request) else request @@ -1403,8 +1402,6 @@ async def __run_task_function(self) -> None: self._statistics.record_request_processing_start(request.unique_key) try: - request.state = RequestState.REQUEST_HANDLER - self._check_request_collision(context.request, context.session) try: @@ -1414,10 +1411,10 @@ async def __run_task_function(self) -> None: await self._commit_request_handler_result(context) - await self._mark_request_as_handled(request) - request.state = RequestState.DONE + await self._mark_request_as_handled(request) + if context.session and context.session.is_usable: context.session.mark_good() @@ -1483,6 +1480,7 @@ async def __run_task_function(self) -> None: raise async def _run_request_handler(self, context: BasicCrawlingContext) -> None: + context.request.state = RequestState.BEFORE_NAV await self._context_pipeline( context, lambda final_context: wait_for( diff --git a/src/crawlee/crawlers/_playwright/_playwright_crawler.py b/src/crawlee/crawlers/_playwright/_playwright_crawler.py index dd8fa72949..2b16261ebe 100644 --- a/src/crawlee/crawlers/_playwright/_playwright_crawler.py +++ b/src/crawlee/crawlers/_playwright/_playwright_crawler.py @@ -13,7 +13,7 @@ from typing_extensions import NotRequired, TypedDict, TypeVar from crawlee import service_locator -from crawlee._request import Request, RequestOptions +from crawlee._request import Request, RequestOptions, RequestState from crawlee._types import ( BasicCrawlingContext, ConcurrencySettings, @@ -323,6 +323,7 @@ async def _navigate( response = await context.page.goto( context.request.url, timeout=remaining_timeout.total_seconds() * 1000 ) + context.request.state = RequestState.AFTER_NAV except playwright.async_api.TimeoutError as exc: raise asyncio.TimeoutError from exc diff --git a/src/crawlee/router.py b/src/crawlee/router.py index a6278d98e1..d8d43528e4 100644 --- a/src/crawlee/router.py +++ b/src/crawlee/router.py @@ -3,6 +3,7 @@ from collections.abc import Awaitable, Callable from typing import Generic, TypeVar +from crawlee._request import RequestState from crawlee._types import BasicCrawlingContext from crawlee._utils.docs import docs_group @@ -89,6 +90,7 @@ def wrapper(handler: Callable[[TCrawlingContext], Awaitable]) -> Callable[[TCraw async def __call__(self, context: TCrawlingContext) -> None: """Invoke a request handler that matches the request label (or the default).""" + context.request.state = RequestState.REQUEST_HANDLER if context.request.label is None or context.request.label not in self._handlers_by_label: if self._default_handler is None: raise RuntimeError( diff --git a/tests/unit/crawlers/_basic/test_basic_crawler.py b/tests/unit/crawlers/_basic/test_basic_crawler.py index e9606fb10c..87ee9d14fa 100644 --- a/tests/unit/crawlers/_basic/test_basic_crawler.py +++ b/tests/unit/crawlers/_basic/test_basic_crawler.py @@ -1829,5 +1829,5 @@ async def error_handler(context: BasicCrawlingContext, error: Exception) -> Requ assert original_request.was_already_handled assert error_request is not None - assert error_request.state == RequestState.REQUEST_HANDLER + assert error_request.state == RequestState.DONE assert error_request.was_already_handled diff --git a/tests/unit/crawlers/_http/test_http_crawler.py b/tests/unit/crawlers/_http/test_http_crawler.py index a88ae074be..21bfde2eaf 100644 --- a/tests/unit/crawlers/_http/test_http_crawler.py +++ b/tests/unit/crawlers/_http/test_http_crawler.py @@ -7,10 +7,11 @@ import pytest -from crawlee import ConcurrencySettings, Request +from crawlee import ConcurrencySettings, Request, RequestState from crawlee.crawlers import HttpCrawler from crawlee.sessions import SessionPool from crawlee.statistics import Statistics +from crawlee.storages import RequestQueue from tests.unit.server_endpoints import HELLO_WORLD if TYPE_CHECKING: @@ -577,3 +578,57 @@ async def request_handler(context: HttpCrawlingContext) -> None: assert len(kvs_content) == 1 assert content_key.endswith('.html') assert kvs_content[content_key] == HELLO_WORLD.decode('utf8') + + +async def test_request_state(server_url: URL) -> None: + queue = await RequestQueue.open(alias='http_request_state') + crawler = HttpCrawler(request_manager=queue) + + success_request = Request.from_url(str(server_url)) + assert success_request.state == RequestState.UNPROCESSED + + error_request = Request.from_url(str(server_url / 'error'), user_data={'cause_error': True}) + + requests_states: dict[str, dict[str, RequestState]] = {success_request.unique_key: {}, error_request.unique_key: {}} + + @crawler.pre_navigation_hook + async def pre_navigation_hook(context: BasicCrawlingContext) -> None: + requests_states[context.request.unique_key]['pre_navigation'] = context.request.state + + @crawler.router.default_handler + async def request_handler(context: HttpCrawlingContext) -> None: + if context.request.user_data.get('cause_error'): + raise ValueError('Caused error as requested') + requests_states[context.request.unique_key]['request_handler'] = context.request.state + + @crawler.error_handler + async def error_handler(context: BasicCrawlingContext, _error: Exception) -> None: + requests_states[context.request.unique_key]['error_handler'] = context.request.state + + @crawler.failed_request_handler + async def failed_request_handler(context: BasicCrawlingContext, _error: Exception) -> None: + requests_states[context.request.unique_key]['failed_request_handler'] = context.request.state + + await crawler.run([success_request, error_request]) + + handled_success_request = await queue.get_request(success_request.unique_key) + + assert handled_success_request is not None + assert handled_success_request.state == RequestState.DONE + + assert requests_states[success_request.unique_key] == { + 'pre_navigation': RequestState.BEFORE_NAV, + 'request_handler': RequestState.REQUEST_HANDLER, + } + + handled_error_request = await queue.get_request(error_request.unique_key) + assert handled_error_request is not None + assert handled_error_request.state == RequestState.ERROR + + assert requests_states[error_request.unique_key] == { + 'pre_navigation': RequestState.BEFORE_NAV, + 'error_handler': RequestState.ERROR_HANDLER, + 'failed_request_handler': RequestState.ERROR, + } + + await queue.drop() diff --git a/tests/unit/crawlers/_playwright/test_playwright_crawler.py b/tests/unit/crawlers/_playwright/test_playwright_crawler.py index 96c799eac4..134f699161 100644 --- a/tests/unit/crawlers/_playwright/test_playwright_crawler.py +++ b/tests/unit/crawlers/_playwright/test_playwright_crawler.py @@ -19,6 +19,7 @@ Glob, HttpHeaders, Request, + RequestState, RequestTransformAction, SkippedReason, service_locator, @@ -991,3 +992,57 @@ async def test_slow_navigation_does_not_count_toward_handler_timeout(server_url: assert result.requests_failed == 0 assert result.requests_finished == 1 assert request_handler.call_count == 1 + + +async def test_request_state(server_url: URL) -> None: + queue = await RequestQueue.open(alias='playwright_request_state') + crawler = PlaywrightCrawler(request_manager=queue) + + success_request = Request.from_url(str(server_url)) + assert success_request.state == RequestState.UNPROCESSED + + error_request = Request.from_url(str(server_url / 'error'), user_data={'cause_error': True}) + + requests_states: dict[str, dict[str, RequestState]] = {success_request.unique_key: {}, error_request.unique_key: {}} + + @crawler.pre_navigation_hook + async def pre_navigation_hook(context: PlaywrightPreNavCrawlingContext) -> None: + requests_states[context.request.unique_key]['pre_navigation'] = context.request.state + + @crawler.router.default_handler + async def request_handler(context: PlaywrightCrawlingContext) -> None: + if context.request.user_data.get('cause_error'): + raise ValueError('Caused error as requested') + requests_states[context.request.unique_key]['request_handler'] = context.request.state + + @crawler.error_handler + async def error_handler(context: BasicCrawlingContext, _error: Exception) -> None: + requests_states[context.request.unique_key]['error_handler'] = context.request.state + + @crawler.failed_request_handler + async def failed_request_handler(context: BasicCrawlingContext, _error: Exception) -> None: + requests_states[context.request.unique_key]['failed_request_handler'] = context.request.state + + await crawler.run([success_request, error_request]) + + handled_success_request = await queue.get_request(success_request.unique_key) + + assert handled_success_request is not None + assert handled_success_request.state == RequestState.DONE + + assert requests_states[success_request.unique_key] == { + 'pre_navigation': RequestState.BEFORE_NAV, + 'request_handler': RequestState.REQUEST_HANDLER, + } + + handled_error_request = await queue.get_request(error_request.unique_key) + assert handled_error_request is not None + assert handled_error_request.state == RequestState.ERROR + + assert requests_states[error_request.unique_key] == { + 'pre_navigation': RequestState.BEFORE_NAV, + 'error_handler': RequestState.ERROR_HANDLER, + 'failed_request_handler': RequestState.ERROR, + } + + await queue.drop()