From 83ecd4403c74effe5b427deb8799f465a9d16ab6 Mon Sep 17 00:00:00 2001 From: Rohit Sharma Date: Sun, 26 Apr 2026 21:28:04 +0100 Subject: [PATCH 1/3] Add AWS EMR Serverless connector plugin (flytekitplugins-awsemrserverless) Adds a new plugin that lets Flyte tasks run Spark and Hive jobs on AWS EMR Serverless via an async connector. Three task modes are supported: - Pythonic Spark: regular @task body, packaged and uploaded to S3 by the connector and executed on EMR Serverless workers. - Script Spark: point at an existing s3://.../main.py (or JAR) and submit it directly. - Hive: submit a Hive query against a Hive-configured EMR Serverless application. The connector implements the standard create / get / delete lifecycle on top of EMR Serverless StartJobRun / GetJobRun / CancelJobRun, maps EMR Serverless states to Flyte task phases, and surfaces log URIs to the Flyte UI when available. It is intended to be deployed on the existing upstream flyteconnector pod with this package installed and an IAM identity allowed to call EMR Serverless. Layout follows the existing flytekit-aws-batch / flytekit-aws-sagemaker plugins (setup.py, flytekitplugins/, tests/, README.md). 138 unit tests cover the connector lifecycle, task config, boto handler, and worker entrypoint. 
Tracking issue: flyteorg/flyte#7286 Signed-off-by: Rohit Sharma Made-with: Cursor Signed-off-by: Rohit Sharma --- plugins/flytekit-aws-emr-serverless/README.md | 109 ++ .../awsemrserverless/__init__.py | 29 + .../awsemrserverless/_entrypoint.py | 203 +++ .../awsemrserverless/boto_handler.py | 398 ++++++ .../awsemrserverless/connector.py | 825 +++++++++++ .../flytekitplugins/awsemrserverless/task.py | 421 ++++++ plugins/flytekit-aws-emr-serverless/setup.py | 49 + .../tests/__init__.py | 1 + .../tests/conftest.py | 157 ++ .../tests/test_boto_handler.py | 336 +++++ .../tests/test_connector.py | 1266 +++++++++++++++++ .../tests/test_entrypoint.py | 338 +++++ .../tests/test_task.py | 346 +++++ 13 files changed, 4478 insertions(+) create mode 100644 plugins/flytekit-aws-emr-serverless/README.md create mode 100644 plugins/flytekit-aws-emr-serverless/flytekitplugins/awsemrserverless/__init__.py create mode 100644 plugins/flytekit-aws-emr-serverless/flytekitplugins/awsemrserverless/_entrypoint.py create mode 100644 plugins/flytekit-aws-emr-serverless/flytekitplugins/awsemrserverless/boto_handler.py create mode 100644 plugins/flytekit-aws-emr-serverless/flytekitplugins/awsemrserverless/connector.py create mode 100644 plugins/flytekit-aws-emr-serverless/flytekitplugins/awsemrserverless/task.py create mode 100644 plugins/flytekit-aws-emr-serverless/setup.py create mode 100644 plugins/flytekit-aws-emr-serverless/tests/__init__.py create mode 100644 plugins/flytekit-aws-emr-serverless/tests/conftest.py create mode 100644 plugins/flytekit-aws-emr-serverless/tests/test_boto_handler.py create mode 100644 plugins/flytekit-aws-emr-serverless/tests/test_connector.py create mode 100644 plugins/flytekit-aws-emr-serverless/tests/test_entrypoint.py create mode 100644 plugins/flytekit-aws-emr-serverless/tests/test_task.py diff --git a/plugins/flytekit-aws-emr-serverless/README.md b/plugins/flytekit-aws-emr-serverless/README.md new file mode 100644 index 0000000000..3ecda1ad88 --- /dev/null 
+++ b/plugins/flytekit-aws-emr-serverless/README.md @@ -0,0 +1,109 @@ +# Flytekit AWS EMR Serverless Plugin + +A Flyte connector for [AWS EMR Serverless](https://docs.aws.amazon.com/emr/latest/EMR-Serverless-UserGuide/emr-serverless.html) that submits Spark and Hive jobs to an EMR Serverless application and tracks them through to completion. + +## Features + +- **Pythonic Spark mode**: write a Flyte `@task` whose body is regular PySpark; the plugin packages the user code, uploads it to S3, and runs it on EMR Serverless. No long-lived cluster to manage. +- **Script Spark mode**: point at an existing `main.py` (or JAR) already in S3 and submit it directly. +- **Hive mode**: submit a Hive query (inline or from S3) against an EMR Serverless application configured for Hive. +- Async connector lifecycle (`create` / `get` / `delete`) so the connector pod stays light and many jobs can be tracked concurrently. +- Honours Flyte task retries, timeouts, and cancellation, and surfaces EMR Serverless logs through the Flyte UI when log URIs are available. + +## Installation + +```bash +pip install flytekitplugins-awsemrserverless +``` + +The connector is registered automatically with `flytekit` via the plugin entry point. Deploy it on a [`flyteconnector`](https://github.com/flyteorg/flyte/tree/master/charts/flyteconnector) pod that has this package installed and an IAM identity allowed to call EMR Serverless `StartJobRun` / `GetJobRun` / `CancelJobRun` and to read/write the script-staging S3 prefix. 
+ +## Usage + +### Pythonic Spark task + +```python +from flytekit import task, workflow +from flytekitplugins.awsemrserverless import EMRServerless, EMRServerlessSparkJobDriver + + +@task( + task_config=EMRServerless( + application_id="00fhabc12345", + execution_role_arn="arn:aws:iam::123456789012:role/EMRServerlessRole", + region="us-east-1", + job_driver=EMRServerlessSparkJobDriver( + spark_submit_parameters="--conf spark.executor.cores=2 --conf spark.executor.memory=4g", + ), + ), +) +def spark_count() -> int: + from pyspark.sql import SparkSession + + spark = SparkSession.builder.getOrCreate() + return spark.range(1_000_000).count() + + +@workflow +def wf() -> int: + return spark_count() +``` + +The plugin serializes the task body, uploads it to S3, and EMR Serverless runs it inside the worker image you have associated with the application. + +### Script Spark task + +```python +@task( + task_config=EMRServerless( + application_id="00fhabc12345", + execution_role_arn="arn:aws:iam::123456789012:role/EMRServerlessRole", + region="us-east-1", + job_driver=EMRServerlessSparkJobDriver( + entry_point="s3://my-bucket/scripts/main.py", + entry_point_arguments=["--date", "2025-01-01"], + spark_submit_parameters="--conf spark.executor.memory=4g", + ), + ), +) +def submit_script(): + ... +``` + +### Hive task + +```python +from flytekitplugins.awsemrserverless import EMRServerless, EMRServerlessHiveJobDriver + + +@task( + task_config=EMRServerless( + application_id="00fhabc12345", + execution_role_arn="arn:aws:iam::123456789012:role/EMRServerlessRole", + region="us-east-1", + job_driver=EMRServerlessHiveJobDriver( + query="SELECT COUNT(*) FROM my_table", + ), + ), +) +def hive_query(): + ... +``` + +## Worker image + +For Pythonic Spark tasks the worker image must contain `flytekit` and this plugin so the executor can rehydrate the task object on the EMR Serverless side. Build it from an EMR Serverless Spark base image and `pip install flytekitplugins-awsemrserverless`. 
Script Spark and Hive jobs do not require flytekit on the worker. + +## IAM + +The connector pod's IAM principal needs: + +- `emr-serverless:StartJobRun`, `emr-serverless:GetJobRun`, `emr-serverless:CancelJobRun` on the target application +- `s3:GetObject` / `s3:PutObject` on the script-staging prefix (Pythonic mode) +- `iam:PassRole` for the EMR Serverless execution role + +The execution role attached to the EMR Serverless application is the role the workers run as and needs whatever data-access permissions your jobs require. + +## Discussion + +Tracking issue: [flyteorg/flyte#7286](https://github.com/flyteorg/flyte/issues/7286). diff --git a/plugins/flytekit-aws-emr-serverless/flytekitplugins/awsemrserverless/__init__.py b/plugins/flytekit-aws-emr-serverless/flytekitplugins/awsemrserverless/__init__.py new file mode 100644 index 0000000000..abc6be9ffd --- /dev/null +++ b/plugins/flytekit-aws-emr-serverless/flytekitplugins/awsemrserverless/__init__.py @@ -0,0 +1,29 @@ +""" +.. currentmodule:: flytekitplugins.awsemrserverless + +This plugin enables running Spark and Hive jobs on AWS EMR Serverless from +Flyte workflows. It exposes an async connector that handles the EMR +Serverless job lifecycle (submit, poll, cancel) and a task config type. + +.. 
autosummary:: + :template: custom.rst + :toctree: generated/ + + EMRServerless + EMRServerlessSparkJobDriver + EMRServerlessHiveJobDriver + EMRServerlessTask + EMRServerlessConnector + EMRServerlessJobMetadata +""" + +from flytekitplugins.awsemrserverless.connector import ( + EMRServerlessConnector, + EMRServerlessJobMetadata, +) +from flytekitplugins.awsemrserverless.task import ( + EMRServerless, + EMRServerlessHiveJobDriver, + EMRServerlessSparkJobDriver, + EMRServerlessTask, +) diff --git a/plugins/flytekit-aws-emr-serverless/flytekitplugins/awsemrserverless/_entrypoint.py b/plugins/flytekit-aws-emr-serverless/flytekitplugins/awsemrserverless/_entrypoint.py new file mode 100644 index 0000000000..c93288fcd3 --- /dev/null +++ b/plugins/flytekit-aws-emr-serverless/flytekitplugins/awsemrserverless/_entrypoint.py @@ -0,0 +1,203 @@ +""" +EMR Serverless Pythonic-mode entrypoint script. + +This file is the canonical source of the bootstrap script that EMR +Serverless workers execute as ``sparkSubmit.entryPoint`` for Pythonic +tasks (i.e. tasks that do not provide an explicit ``spark_job_driver``). + +How it is delivered to EMR +-------------------------- + +The connector pod: + +1. reads this file from its own ``site-packages`` install; +2. computes ``hashlib.sha256(content)[:12]`` and uploads (idempotently) + to ``s3:///flyte/emr-serverless/entrypoint-.py``; +3. passes that S3 URI as ``StartJobRun.jobDriver.sparkSubmit.entryPoint``. + +EMR Serverless then downloads this script onto the Spark driver and +runs it with ``spark-submit``. ``sys.argv[1:]`` carries the actual +``pyflyte-fast-execute`` (or ``pyflyte-execute``) invocation that +should run inside the worker, plus the fast-registration distribution +arguments. + +Why a custom entrypoint at all +------------------------------ + +EMR Serverless's API requires ``sparkSubmit.entryPoint`` to be a +single Python file URI -- there is no "container as entrypoint" +escape hatch (cf. 
SageMaker / Batch / ECS, which run the container +itself). We need a thin shim that: + +* downloads the fast-registration tarball from Flyte's blob store, +* invokes ``pyflyte-fast-execute`` with the right resolver arguments, +* converts Flytekit's "exit-0-on-user-error" semantics into a non-zero + exit so EMR reports ``FAILED`` instead of ``SUCCESS``. + +This file deliberately has only ``flytekit`` as a runtime dependency +(specifically ``flytekit.tools.fast_registration.download_distribution``) +because it runs *inside the EMR worker*, not the connector pod. + +Editing this file +----------------- + +Treat this file as part of the connector's *runtime contract* with +EMR workers, not as plugin internals: + +* changes here propagate to every Pythonic-mode job on the next + connector deploy via the content hash in the S3 key; +* the corresponding unit tests live in ``tests/test_entrypoint.py`` + and exercise this module both as imported Python and as the + spawned subprocess EMR sees; +* upstream alignment: this is the EMR analogue of + ``flytetools/flytekitplugins/databricks/entrypoint.py`` -- + same shape, different transport (S3 instead of GitHub). +""" + +import os +import signal +import subprocess +import sys + +from flytekit.tools.fast_registration import download_distribution + + +def _run_subprocess(cmd, env=None): + """Run ``cmd`` and forward SIGTERM, returning ``(returncode, stderr_text)``. + + stdout streams through to the parent (Spark driver stdout); stderr is + captured so the caller can inspect it for Flytekit's user-error banner. 
+ """ + p = subprocess.Popen(cmd, env=env, stderr=subprocess.PIPE, stdout=None) + signal.signal(signal.SIGTERM, lambda s, f: p.send_signal(s)) + _, stderr_bytes = p.communicate() + stderr_text = stderr_bytes.decode("utf-8", errors="replace") if stderr_bytes else "" + if stderr_text: + sys.stderr.write(stderr_text) + return p.returncode, stderr_text + + +def _exit_with_code(rc, stderr_text=""): + """Translate Flytekit subprocess exit semantics into EMR-correct exits. + + Flytekit's ``pyflyte-execute`` catches user exceptions, writes the + error to the Flyte output blob, and exits ``0`` -- by design for + K8s-based agents where FlytePropeller reads the output. + + In EMR Serverless the connector only polls EMR job state + (``SUCCESS`` / ``FAILED``) and cannot read the output blobs. If + ``pyflyte-execute`` exits ``0`` but the user function failed, EMR + reports ``SUCCESS`` and the connector wrongly reports ``SUCCEEDED``. + + Detect this by scanning stderr for Flytekit's error banner. When + found, force a non-zero exit so Spark fails the driver and EMR + reports ``FAILED``. + """ + if rc == 0 and "User Error Captured by Flyte" in stderr_text: + print( + "[flyte-entrypoint] pyflyte-execute exited 0 but stderr contains " + "a user error -- forcing non-zero exit so EMR reports FAILED", + file=sys.stderr, + ) + sys.exit(1) + if rc != 0: + print(f"[flyte-entrypoint] Task process exited with code {rc}", file=sys.stderr) + sys.exit(rc) + + +def _parse_fast_execute_args(args): + """Split a ``pyflyte-fast-execute ...`` argv into its three pieces. + + Returns ``(additional_distribution, dest_dir, task_cmd_start)`` + where ``task_cmd_start`` is the index in ``args`` where the + underlying ``pyflyte-execute ...`` command begins. + + Recognises the two-arg flag forms emitted by Flytekit: + ``--additional-distribution `` + ``--dest-dir `` + and the optional ``--`` separator before the inner command. 
+ """ + additional_distribution = None + dest_dir = None + task_cmd_start = 0 + + i = 1 + while i < len(args): + if args[i] == "--additional-distribution" and i + 1 < len(args): + additional_distribution = args[i + 1] + i += 2 + elif args[i] == "--dest-dir" and i + 1 < len(args): + dest_dir = args[i + 1] + i += 2 + elif args[i] == "--": + task_cmd_start = i + 1 + break + else: + task_cmd_start = i + break + + return additional_distribution, dest_dir, task_cmd_start + + +def _build_resolver_command(task_execute_cmd, additional_distribution, dest_dir): + """Inject the fast-registration distribution args before ``--resolver``. + + ``pyflyte-execute`` resolves task callables via a resolver plugin + (default ``flytekit.core.python_auto_container.default_task_resolver``). + For fast-registered code, the resolver needs to know where the + extracted source tree lives, which we inject as + ``--dynamic-addl-distro`` / ``--dynamic-dest-dir`` immediately + before ``--resolver``. + """ + cmd = [] + for arg in task_execute_cmd: + if arg == "--resolver": + cmd.extend( + [ + "--dynamic-addl-distro", + additional_distribution or "", + "--dynamic-dest-dir", + dest_dir or "", + ] + ) + cmd.append(arg) + return cmd + + +def main(): + args = sys.argv[1:] + if not args: + print("Usage: entrypoint.py pyflyte-fast-execute|pyflyte-execute ...", file=sys.stderr) + sys.exit(1) + + if args[0] == "pyflyte-fast-execute": + additional_distribution, dest_dir, task_cmd_start = _parse_fast_execute_args(args) + task_execute_cmd = list(args[task_cmd_start:]) + + if additional_distribution: + if not dest_dir: + dest_dir = os.getcwd() + download_distribution(additional_distribution, dest_dir) + + cmd = _build_resolver_command(task_execute_cmd, additional_distribution, dest_dir) + + env = os.environ.copy() + if dest_dir: + resolved = os.path.realpath(os.path.expanduser(dest_dir)) + env["PYTHONPATH"] = resolved + os.pathsep + env.get("PYTHONPATH", "") + rc, stderr_text = _run_subprocess(cmd, env) + 
_exit_with_code(rc, stderr_text) + + elif args[0] == "pyflyte-execute": + env = os.environ.copy() + env.setdefault("PYTHONPATH", os.getcwd()) + rc, stderr_text = _run_subprocess(args, env) + _exit_with_code(rc, stderr_text) + + else: + print(f"Unrecognized command: {args}", file=sys.stderr) + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/plugins/flytekit-aws-emr-serverless/flytekitplugins/awsemrserverless/boto_handler.py b/plugins/flytekit-aws-emr-serverless/flytekitplugins/awsemrserverless/boto_handler.py new file mode 100644 index 0000000000..c686667ce1 --- /dev/null +++ b/plugins/flytekit-aws-emr-serverless/flytekitplugins/awsemrserverless/boto_handler.py @@ -0,0 +1,398 @@ +""" +Boto3 helper for EMR Serverless API operations. + +Provides a thin async wrapper around the boto3 EMR Serverless client, with +application-lifecycle management used by the connector. + +All outbound boto3 traffic is funneled through two private helpers so that +tests can mock the entire handler at a single boundary (matches the +flytekit-aws-sagemaker ``Boto3ConnectorMixin._call`` pattern): + +* :py:meth:`EMRServerlessHandler._call` -- single-shot API calls +* :py:meth:`EMRServerlessHandler._paginate` -- paginated list operations +""" + +import asyncio +import logging +import time +from typing import Any, Dict, List, Optional + +import boto3 +from botocore.config import Config as BotoConfig +from botocore.exceptions import ClientError + +logger = logging.getLogger(__name__) + +_BOTO_RETRY_CONFIG = BotoConfig(retries={"max_attempts": 3, "mode": "adaptive"}) + +_APP_STARTED = "STARTED" +_APP_TERMINAL = {"TERMINATED"} +_APP_NEEDS_START = {"CREATED", "STOPPED"} +_APP_TRANSITIONING = {"CREATING", "STARTING"} + +_JOB_TERMINAL = {"SUCCESS", "FAILED", "CANCELLED"} + + +class EMRServerlessHandler: + """ + Async-friendly wrapper around the boto3 EMR Serverless client. + + Retries are handled by botocore's built-in adaptive retry mode. 
+ Blocking boto3 calls are offloaded to a thread-pool executor so they + do not block the asyncio event loop. + """ + + def __init__(self, region: Optional[str] = None): + self.region = region + self._client: Optional[Any] = None + logger.debug("EMRServerlessHandler initialized: region=%s", region or "(default)") + + @property + def client(self) -> Any: + if self._client is None: + kwargs: Dict[str, Any] = { + "service_name": "emr-serverless", + "config": _BOTO_RETRY_CONFIG, + } + if self.region: + kwargs["region_name"] = self.region + self._client = boto3.client(**kwargs) + logger.debug( + "Boto3 EMR Serverless client created: region=%s", + self._client.meta.region_name, + ) + return self._client + + async def _call(self, method: str, **params: Any) -> Dict[str, Any]: + """Invoke a boto3 EMR Serverless client method in the thread-pool executor. + + This is the single chokepoint for all non-paginated boto3 API calls. + Tests can mock the entire handler surface by patching just this method. + + Args: + method: The boto3 client method name (e.g. ``"get_application"``). + **params: Keyword arguments forwarded to the boto3 call in the + camelCase form the AWS SDK expects. + + Returns: + The raw response dict from the boto3 call. + """ + func = getattr(self.client, method) + loop = asyncio.get_event_loop() + return await loop.run_in_executor(None, lambda: func(**params)) + + async def _paginate(self, method: str, result_key: str, **params: Any) -> List[Dict[str, Any]]: + """Exhaust a boto3 paginator and return the concatenated items. + + Used for list operations (e.g. ``list_applications``) that may span + multiple pages. Tests can mock this independently from ``_call``. + + Args: + method: The boto3 client method to paginate (e.g. ``"list_applications"``). + result_key: The key under which each page stores items + (e.g. ``"applications"`` for ``list_applications``). + **params: Keyword arguments forwarded to ``paginator.paginate()``. 
+ + Returns: + A flat list of items across all pages. + """ + + def _collect() -> List[Dict[str, Any]]: + paginator = self.client.get_paginator(method) + items: List[Dict[str, Any]] = [] + for page in paginator.paginate(**params): + items.extend(page.get(result_key, [])) + return items + + loop = asyncio.get_event_loop() + return await loop.run_in_executor(None, _collect) + + # ------------------------------------------------------------------ + # Application management + # ------------------------------------------------------------------ + + async def get_application(self, application_id: str) -> Dict[str, Any]: + logger.debug("GetApplication: applicationId=%s", application_id) + resp = await self._call("get_application", applicationId=application_id) + app = resp.get("application", {}) + logger.debug( + "GetApplication response: id=%s, name=%s, state=%s", + app.get("applicationId"), + app.get("name"), + app.get("state"), + ) + return app + + async def find_application_by_name(self, name: str) -> Optional[str]: + """Find an active application ID by name. 
Returns None if not found.""" + logger.info("Searching for application by name: '%s'", name) + apps = await self._paginate( + "list_applications", + result_key="applications", + states=["CREATED", "STARTED", "STOPPED"], + ) + for app in apps: + if app.get("name") == name: + logger.info("Found application '%s' with ID: %s", name, app["id"]) + return app["id"] + logger.info("Application '%s' not found after searching %d app(s)", name, len(apps)) + return None + + async def create_application( + self, + name: str, + release_label: str, + application_type: str, + initial_capacity: Optional[Dict[str, Any]] = None, + maximum_capacity: Optional[Dict[str, Any]] = None, + network_configuration: Optional[Dict[str, Any]] = None, + image_configuration: Optional[Dict[str, Any]] = None, + tags: Optional[Dict[str, str]] = None, + architecture: Optional[str] = None, + runtime_configuration: Optional[list] = None, + scheduler_configuration: Optional[Dict[str, Any]] = None, + auto_stop_config: Optional[Dict[str, Any]] = None, + ) -> str: + logger.info( + "CreateApplication: name=%s, release_label=%s, type=%s, architecture=%s", + name, + release_label, + application_type, + architecture, + ) + params: Dict[str, Any] = { + "name": name, + "releaseLabel": release_label, + "type": application_type, + } + if initial_capacity: + params["initialCapacity"] = initial_capacity + logger.debug("CreateApplication: initialCapacity provided") + if maximum_capacity: + params["maximumCapacity"] = maximum_capacity + logger.debug("CreateApplication: maximumCapacity provided") + if network_configuration: + params["networkConfiguration"] = network_configuration + logger.debug("CreateApplication: networkConfiguration provided") + if image_configuration: + params["imageConfiguration"] = image_configuration + logger.debug("CreateApplication: imageConfiguration=%s", image_configuration) + if tags: + params["tags"] = tags + logger.debug("CreateApplication: tags=%s", tags) + if architecture: + 
params["architecture"] = architecture + if runtime_configuration: + params["runtimeConfiguration"] = runtime_configuration + logger.debug("CreateApplication: runtimeConfiguration provided") + if scheduler_configuration: + params["schedulerConfiguration"] = scheduler_configuration + logger.debug("CreateApplication: schedulerConfiguration=%s", scheduler_configuration) + if auto_stop_config: + params["autoStopConfiguration"] = auto_stop_config + logger.debug("CreateApplication: autoStopConfiguration=%s", auto_stop_config) + + resp = await self._call("create_application", **params) + app_id = resp.get("applicationId", "") + logger.info("CreateApplication succeeded: applicationId=%s, arn=%s", app_id, resp.get("arn", "")) + return app_id + + async def update_application( + self, + application_id: str, + image_configuration: Optional[Dict[str, Any]] = None, + maximum_capacity: Optional[Dict[str, Any]] = None, + auto_stop_config: Optional[Dict[str, Any]] = None, + runtime_configuration: Optional[list] = None, + scheduler_configuration: Optional[Dict[str, Any]] = None, + ) -> None: + """Update a mutable subset of application properties.""" + params: Dict[str, Any] = {"applicationId": application_id} + update_fields = [] + + if image_configuration: + params["imageConfiguration"] = image_configuration + update_fields.append("imageConfiguration") + if maximum_capacity: + params["maximumCapacity"] = maximum_capacity + update_fields.append("maximumCapacity") + if auto_stop_config: + params["autoStopConfiguration"] = auto_stop_config + update_fields.append("autoStopConfiguration") + if runtime_configuration: + params["runtimeConfiguration"] = runtime_configuration + update_fields.append("runtimeConfiguration") + if scheduler_configuration: + params["schedulerConfiguration"] = scheduler_configuration + update_fields.append("schedulerConfiguration") + + if len(params) <= 1: + logger.debug("UpdateApplication: no fields to update for %s, skipping", application_id) + return + + 
logger.info("UpdateApplication: applicationId=%s, fields=%s", application_id, update_fields) + await self._call("update_application", **params) + logger.info("UpdateApplication succeeded for %s", application_id) + + async def ensure_application_started( + self, + application_id: str, + timeout_seconds: int = 300, + poll_interval_seconds: int = 10, + ) -> None: + """Ensure the application reaches STARTED state, starting it if needed.""" + logger.info( + "EnsureApplicationStarted: applicationId=%s, timeout=%ds", + application_id, + timeout_seconds, + ) + start_time = time.monotonic() + + while True: + app = await self.get_application(application_id) + state = app.get("state", "") + + if state == _APP_STARTED: + elapsed = time.monotonic() - start_time + logger.info( + "Application %s is in STARTED state (waited %.1fs)", + application_id, + elapsed, + ) + return + + if state in _APP_TERMINAL: + logger.error( + "Application %s is in terminal state '%s', cannot be started", + application_id, + state, + ) + raise RuntimeError(f"Application {application_id} is in terminal state '{state}' and cannot be started") + + if state in _APP_NEEDS_START: + logger.info("Application %s is in state '%s', sending StartApplication request", application_id, state) + await self._call("start_application", applicationId=application_id) + + elapsed = time.monotonic() - start_time + if elapsed >= timeout_seconds: + logger.error( + "Timed out after %.1fs waiting for application %s to start (last state: %s)", + elapsed, + application_id, + state, + ) + raise TimeoutError( + f"Timed out after {timeout_seconds}s waiting for application {application_id} to start " + f"(last state: {state})" + ) + + logger.debug( + "Application %s state: %s, waiting %ds (elapsed: %.1fs / %ds)", + application_id, + state, + poll_interval_seconds, + elapsed, + timeout_seconds, + ) + await asyncio.sleep(poll_interval_seconds) + + # ------------------------------------------------------------------ + # Job management + 
# ------------------------------------------------------------------ + + async def start_job_run( + self, + application_id: str, + execution_role_arn: str, + job_driver: Dict[str, Any], + configuration_overrides: Optional[Dict[str, Any]] = None, + tags: Optional[Dict[str, str]] = None, + execution_timeout_minutes: int = 60, + name: Optional[str] = None, + retry_policy: Optional[Dict[str, Any]] = None, + ) -> str: + logger.info( + "StartJobRun: applicationId=%s, name=%s, timeout=%dm", + application_id, + name, + execution_timeout_minutes, + ) + params: Dict[str, Any] = { + "applicationId": application_id, + "executionRoleArn": execution_role_arn, + "jobDriver": job_driver, + "executionTimeoutMinutes": execution_timeout_minutes, + } + if configuration_overrides: + params["configurationOverrides"] = configuration_overrides + logger.debug("StartJobRun: configurationOverrides provided") + if tags: + params["tags"] = tags + if name: + params["name"] = name + if retry_policy: + params["retryPolicy"] = retry_policy + logger.debug("StartJobRun: retryPolicy=%s", retry_policy) + + logger.debug("StartJobRun: jobDriver type=%s", list(job_driver.keys())) + resp = await self._call("start_job_run", **params) + job_run_id = resp.get("jobRunId", "") + logger.info( + "StartJobRun succeeded: jobRunId=%s, applicationId=%s, arn=%s", + job_run_id, + application_id, + resp.get("arn", ""), + ) + return job_run_id + + async def get_job_run(self, application_id: str, job_run_id: str) -> Dict[str, Any]: + logger.debug("GetJobRun: applicationId=%s, jobRunId=%s", application_id, job_run_id) + resp = await self._call( + "get_job_run", + applicationId=application_id, + jobRunId=job_run_id, + ) + job = resp.get("jobRun", {}) + logger.debug( + "GetJobRun response: jobRunId=%s, state=%s", + job.get("jobRunId"), + job.get("state"), + ) + return job + + async def cancel_job_run(self, application_id: str, job_run_id: str) -> None: + """Cancel a running job. 
Idempotent -- safe to call on completed jobs.""" + logger.info("CancelJobRun: applicationId=%s, jobRunId=%s", application_id, job_run_id) + try: + job = await self.get_job_run(application_id, job_run_id) + current_state = job.get("state", "") + if current_state in _JOB_TERMINAL: + logger.info( + "CancelJobRun: job %s already in terminal state '%s', no action needed", + job_run_id, + current_state, + ) + return + logger.info("CancelJobRun: job %s is in state '%s', sending cancel request", job_run_id, current_state) + await self._call( + "cancel_job_run", + applicationId=application_id, + jobRunId=job_run_id, + ) + logger.info("CancelJobRun succeeded: jobRunId=%s", job_run_id) + except ClientError as e: + code = e.response.get("Error", {}).get("Code", "") + if code == "ResourceNotFoundException": + logger.warning( + "CancelJobRun: job %s not found (ResourceNotFoundException), may already be cleaned up", + job_run_id, + ) + return + if code == "ValidationException" and "cannot be cancelled" in str(e).lower(): + logger.info( + "CancelJobRun: job %s cannot be cancelled (ValidationException), likely already completed", + job_run_id, + ) + return + logger.error("CancelJobRun failed for job %s: %s (code: %s)", job_run_id, e, code) + raise diff --git a/plugins/flytekit-aws-emr-serverless/flytekitplugins/awsemrserverless/connector.py b/plugins/flytekit-aws-emr-serverless/flytekitplugins/awsemrserverless/connector.py new file mode 100644 index 0000000000..11cefbb026 --- /dev/null +++ b/plugins/flytekit-aws-emr-serverless/flytekitplugins/awsemrserverless/connector.py @@ -0,0 +1,825 @@ +""" +Flyte Async Connector for AWS EMR Serverless. + +Handles the complete lifecycle of EMR Serverless jobs: +create (submit), get (poll status), and delete (cancel). 
+""" + +import hashlib +import logging +import os +import re +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, Optional + +from flyteidl.core.execution_pb2 import TaskExecution + +try: + from flytekit.extend.backend.base_connector import ( + AsyncConnectorBase, + ConnectorRegistry, + Resource, + ResourceMeta, + ) +except ModuleNotFoundError: + from flytekit.extend.backend.base_agent import ( + AgentRegistry as ConnectorRegistry, + ) + from flytekit.extend.backend.base_agent import ( + AsyncAgentBase as AsyncConnectorBase, + ) + from flytekit.extend.backend.base_agent import ( + Resource, + ResourceMeta, + ) + +from flytekit.extend.backend.utils import convert_to_flyte_phase +from flytekit.models.core.execution import TaskLog +from flytekit.models.literals import LiteralMap +from flytekit.models.task import TaskTemplate + +from flytekitplugins.awsemrserverless.boto_handler import EMRServerlessHandler +from flytekitplugins.awsemrserverless.task import EMR_SERVERLESS_BASE_IMAGE, EMRServerless + +logger = logging.getLogger(__name__) + +# Maps EMR Serverless job run states to the canonical flytekit phase strings +# accepted by ``convert_to_flyte_phase`` ("running", "success", "failed"). +# This mirrors the upstream flytekit plugin convention (see +# ``plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/connector.py``). 
+EMR_SERVERLESS_STATES = { + "PENDING": "Running", + "SCHEDULED": "Running", + "SUBMITTED": "Running", + "RUNNING": "Running", + "SUCCESS": "Success", + "FAILED": "Failed", + "CANCELLING": "Running", + "CANCELLED": "Failed", +} + +FLYTE_TAGS = {"Application": "flyte", "ManagedBy": "flyte-connector"} + +_APPLICATION_ID_RE = re.compile(r"^[0-9a-z]+$") + +_ENV_ALLOW_CREATE_APPLICATION = "FLYTE_EMR_ALLOW_CREATE_APPLICATION" +_ENV_APPLICATION_NAME_PREFIX = "FLYTE_EMR_APPLICATION_NAME_PREFIX" + + +@dataclass +class EMRServerlessJobMetadata(ResourceMeta): + """Metadata persisted by FlytePropeller between connector calls.""" + + application_id: str + job_run_id: str + region: str + created_application: bool = False + + +class EMRServerlessConnector(AsyncConnectorBase): + """ + Flyte Connector for AWS EMR Serverless. + + Supports two execution modes: + + * **Script mode** -- the user provides an explicit ``spark_job_driver`` + or ``hive_job_driver`` in the task config. The connector submits the + job as-is. + * **Pythonic mode** -- no job driver is provided. The connector reads + ``task_template.container`` (image + args) and constructs a + ``sparkSubmit`` job that runs the Flytekit entrypoint inside the + user's container image on EMR Serverless. 
+ """ + + name = "EMR Serverless Connector" + + def __init__(self): + super().__init__( + task_type_name="emr_serverless", + metadata_type=EMRServerlessJobMetadata, + ) + + def _get_handler(self, region: Optional[str] = None) -> EMRServerlessHandler: + return EMRServerlessHandler(region=region) + + @staticmethod + def _extract_config(task_template: TaskTemplate) -> EMRServerless: + custom = task_template.custom + if not custom: + raise ValueError("Task template has no custom configuration") + + if hasattr(custom, "fields"): + from google.protobuf.json_format import MessageToDict + + config_dict = MessageToDict(custom) + else: + config_dict = dict(custom) + + config = EMRServerless.from_dict(config_dict) + logger.debug( + "Extracted task config: application_id=%s, application_name=%s, " + "application_type=%s, region=%s, script_mode=%s, sync_image=%s", + config.application_id, + config.application_name, + config.application_type, + config.region, + config.is_script_mode, + config.sync_image, + ) + return config + + # The Pythonic-mode entrypoint script that EMR workers execute as + # ``sparkSubmit.entryPoint``. Lives in its own module so that: + # * it is reviewable / blame-able in version control, + # * it can be unit-tested as Python (see ``tests/test_entrypoint.py``), + # * the connector ships it via the same wheel it ships in, + # mirroring the upstream Databricks pattern of treating the + # entrypoint as a first-class plugin asset. + _ENTRYPOINT_PATH: Path = (Path(__file__).parent / "_entrypoint.py").resolve() + + @classmethod + def _read_entrypoint_bytes(cls) -> bytes: + """Read the entrypoint file as bytes. Cached on the class. + + Reading from disk (instead of importing the module and using + ``inspect.getsource``) keeps the byte-stream byte-for-byte + identical to what EMR will execute, which is the input to the + content hash. 
+ """ + cached = getattr(cls, "_entrypoint_bytes_cache", None) + if cached is None: + cached = cls._ENTRYPOINT_PATH.read_bytes() + cls._entrypoint_bytes_cache = cached + return cached + + @staticmethod + def _resolve_entrypoint_bucket(config: EMRServerless) -> str: + """Determine the S3 bucket for the entrypoint script. + + Resolution order: + 1. FLYTE_EMR_ENTRYPOINT_S3_BUCKET env var + 2. The S3 monitoring log URI from configuration_overrides (reuses the same bucket) + 3. FLYTE_AWS_S3_BUCKET env var (common Flyte storage bucket) + """ + bucket = os.environ.get("FLYTE_EMR_ENTRYPOINT_S3_BUCKET") + if bucket: + resolved = bucket.removeprefix("s3://").split("/")[0] + logger.debug("Resolved entrypoint bucket from FLYTE_EMR_ENTRYPOINT_S3_BUCKET: %s", resolved) + return resolved + + if config.configuration_overrides: + log_uri = ( + config.configuration_overrides.get("monitoringConfiguration", {}) + .get("s3MonitoringConfiguration", {}) + .get("logUri", "") + ) + if log_uri.startswith("s3://"): + resolved = log_uri.removeprefix("s3://").split("/")[0] + logger.debug("Resolved entrypoint bucket from monitoringConfiguration logUri: %s", resolved) + return resolved + + bucket = os.environ.get("FLYTE_AWS_S3_BUCKET") + if bucket: + resolved = bucket.removeprefix("s3://").split("/")[0] + logger.debug("Resolved entrypoint bucket from FLYTE_AWS_S3_BUCKET: %s", resolved) + return resolved + + raise ValueError( + "Pythonic mode needs an S3 bucket to upload the Flytekit entrypoint. " + "Set FLYTE_EMR_ENTRYPOINT_S3_BUCKET env var, or add " + "monitoringConfiguration.s3MonitoringConfiguration.logUri to " + "configuration_overrides in your task config." + ) + + def _ensure_entrypoint_on_s3(self, handler: EMRServerlessHandler, config: EMRServerless) -> str: + """Upload the Flytekit entrypoint to S3 if not present, return the s3:// URI. + + The S3 key is content-addressed (``entrypoint-.py``) + so each plugin version produces a unique, immutable artifact. 
+ Old hashes remain in the bucket indefinitely so in-flight EMR + jobs that captured an older URI in their job spec keep working + across connector upgrades. + """ + content = self._read_entrypoint_bytes() + content_hash = hashlib.sha256(content).hexdigest()[:12] + + bucket = self._resolve_entrypoint_bucket(config) + + key = f"flyte/emr-serverless/entrypoint-{content_hash}.py" + s3_uri = f"s3://{bucket}/{key}" + + import boto3 + + s3 = boto3.client("s3", region_name=handler.region) + try: + s3.head_object(Bucket=bucket, Key=key) + logger.debug("Entrypoint already exists at %s", s3_uri) + except Exception: + logger.info("Uploading Flytekit entrypoint to %s", s3_uri) + s3.put_object(Bucket=bucket, Key=key, Body=content, ContentType="text/x-python") + + return s3_uri + + _PYTHONIC_SPARK_DEFAULTS = ( + "--conf spark.emr-serverless.driverEnv.PYSPARK_DRIVER_PYTHON=/usr/bin/python3 " + "--conf spark.emr-serverless.driverEnv.PYSPARK_PYTHON=/usr/bin/python3 " + "--conf spark.executorEnv.PYSPARK_PYTHON=/usr/bin/python3" + ) + + # Characters that would break ``--conf k=v`` tokenisation if embedded in + # an env value. EMR Serverless does not support quoting values here, + # so we drop the variable rather than risk a malformed sparkSubmit line. + _UNSAFE_ENV_VALUE_CHARS = re.compile(r"[\s='\"]") + + @staticmethod + def _build_flyte_env_vars( + task_execution_metadata: Optional[Any], + ) -> Dict[str, str]: + """Return a dict of FLYTE_INTERNAL_* env vars derived from the + task execution metadata, plus any user-supplied env vars from + ``task_execution_metadata.environment_variables``. + + Returns an empty dict when ``task_execution_metadata`` is ``None`` + (e.g. local / unit-test execution). 
+ """ + if task_execution_metadata is None: + return {} + + env: Dict[str, str] = {} + + task_exec_id = getattr(task_execution_metadata, "task_execution_id", None) + if task_exec_id is None: + return env + + task_id = getattr(task_exec_id, "task_id", None) + if task_id is not None: + env["FLYTE_INTERNAL_TASK_PROJECT"] = getattr(task_id, "project", "") or "" + env["FLYTE_INTERNAL_TASK_DOMAIN"] = getattr(task_id, "domain", "") or "" + env["FLYTE_INTERNAL_TASK_NAME"] = getattr(task_id, "name", "") or "" + env["FLYTE_INTERNAL_TASK_VERSION"] = getattr(task_id, "version", "") or "" + + node_exec_id = getattr(task_exec_id, "node_execution_id", None) + if node_exec_id is not None: + env["FLYTE_INTERNAL_NODE_ID"] = getattr(node_exec_id, "node_id", "") or "" + wf_exec_id = getattr(node_exec_id, "execution_id", None) + if wf_exec_id is not None: + env["FLYTE_INTERNAL_EXECUTION_ID"] = getattr(wf_exec_id, "name", "") or "" + env["FLYTE_INTERNAL_EXECUTION_PROJECT"] = getattr(wf_exec_id, "project", "") or "" + env["FLYTE_INTERNAL_EXECUTION_DOMAIN"] = getattr(wf_exec_id, "domain", "") or "" + + retry = getattr(task_exec_id, "retry_attempt", None) + if retry is not None: + env["FLYTE_INTERNAL_TASK_RETRY_ATTEMPT"] = str(retry) + + user_env = getattr(task_execution_metadata, "environment_variables", None) or {} + for k, v in user_env.items(): + if k and v is not None: + env[str(k)] = str(v) + + return {k: v for k, v in env.items() if v != ""} + + @classmethod + def _format_env_as_spark_conf(cls, env_vars: Dict[str, str]) -> str: + """Convert a dict of env vars into Spark driver+executor ``--conf`` flags. + + Produces entries of the form:: + + --conf spark.emr-serverless.driverEnv.KEY=VALUE + --conf spark.executorEnv.KEY=VALUE + + Values containing whitespace or quote characters are skipped with + a warning (EMR Serverless does not support escaping here). 
+ """ + if not env_vars: + return "" + + parts = [] + for key, value in env_vars.items(): + if cls._UNSAFE_ENV_VALUE_CHARS.search(value): + logger.warning( + "Skipping env var %s: value contains characters that would " + "break sparkSubmitParameters tokenisation", + key, + ) + continue + parts.append(f"--conf spark.emr-serverless.driverEnv.{key}={value}") + parts.append(f"--conf spark.executorEnv.{key}={value}") + return " ".join(parts) + + @classmethod + def _append_flyte_env_to_spark_params( + cls, + existing: Optional[str], + task_execution_metadata: Optional[Any], + *, + enabled: bool, + ) -> Optional[str]: + """Append Flyte context env vars to an existing sparkSubmitParameters. + + Returns ``existing`` unchanged when ``enabled`` is False, when no + metadata is available, or when there are no env vars to inject. + """ + if not enabled: + return existing + env_vars = cls._build_flyte_env_vars(task_execution_metadata) + if not env_vars: + return existing + conf = cls._format_env_as_spark_conf(env_vars) + if not conf: + return existing + logger.info("Injecting %d Flyte env var(s) into sparkSubmitParameters", len(env_vars)) + logger.debug("Injected Flyte env vars: %s", sorted(env_vars.keys())) + return f"{existing} {conf}".strip() if existing else conf + + def _build_pythonic_job_driver( + self, task_template: TaskTemplate, config: EMRServerless, handler: EMRServerlessHandler + ) -> Dict[str, Any]: + """ + Build a sparkSubmit job driver from the Flyte container entrypoint. + + The Databricks connector fetches its entrypoint from a Git repo. + EMR Serverless doesn't support Git sources, so we upload the + Flytekit entrypoint script to S3 and reference it from there. + ``container.args`` are passed as entrypoint arguments so that + the task function is executed inside the EMR Spark runtime. + + Pythonic mode requires a custom Docker image on the EMR Serverless + application that includes flytekit and its dependencies. 
The driver + and executor Python paths are set explicitly so the container's + Python (which has flytekit) is used instead of the EMR default. + """ + container = task_template.container + if container is None: + raise ValueError( + "Pythonic mode requires a container image. Either provide a " + "spark_job_driver for script mode, or ensure the task has a container image." + ) + + logger.info( + "Building Pythonic mode job driver: container_image=%s, args_count=%d", + container.image, + len(container.args) if container.args else 0, + ) + logger.debug("Pythonic mode container.args: %s", list(container.args) if container.args else []) + + entrypoint_s3_uri = self._ensure_entrypoint_on_s3(handler, config) + logger.info("Entrypoint S3 URI: %s", entrypoint_s3_uri) + + entry_point_args = list(container.args) if container.args else [] + + user_params = config.spark_submit_parameters + spark_submit_params = self._merge_spark_submit_params(user_params) + if user_params: + logger.info("Merged user spark_submit_parameters with Pythonic defaults") + logger.debug("Final sparkSubmitParameters: %s", spark_submit_params) + + spark_submit: Dict[str, Any] = { + "entryPoint": entrypoint_s3_uri, + "entryPointArguments": entry_point_args, + } + if spark_submit_params: + spark_submit["sparkSubmitParameters"] = spark_submit_params + + return {"sparkSubmit": spark_submit} + + def _merge_spark_submit_params(self, user_params: Optional[str]) -> str: + """Merge user-provided spark submit params with Pythonic mode defaults. + + User-provided values take precedence -- if the user already sets + e.g. ``spark.executorEnv.PYSPARK_PYTHON``, the default is skipped. 
+ """ + defaults = self._PYTHONIC_SPARK_DEFAULTS + if not user_params: + return defaults + + merged_parts = [] + for token in defaults.split("--conf "): + token = token.strip() + if not token: + continue + key = token.split("=", 1)[0] + if key not in user_params: + merged_parts.append(f"--conf {token}") + + return f"{user_params} {' '.join(merged_parts)}".strip() + + @staticmethod + def _resolve_image_configuration( + task_template: TaskTemplate, + config: EMRServerless, + *, + for_create: bool = False, + ) -> Optional[Dict[str, Any]]: + """Derive EMR Serverless ``imageConfiguration`` following community patterns. + + Resolution order: + 1. Explicit ``image_configuration`` in the task config -- always wins. + 2. ``container_image`` on the task (``task_template.container.image``). + 3. For script mode, ``None`` is fine (default EMR image works). + 4. For Pythonic mode when creating a new app, fall back to the base image. + 5. For Pythonic mode on an existing app, raise with a clear message + (the app should already have a custom image). 
+ """ + logger.debug("Resolving image configuration: for_create=%s, script_mode=%s", for_create, config.is_script_mode) + + if config.image_configuration: + logger.info("Using explicit image_configuration from task config: %s", config.image_configuration) + return config.image_configuration + + container = task_template.container + if container and getattr(container, "image", None): + logger.info( + "Deriving imageConfiguration from container_image: %s", + container.image, + ) + return {"imageUri": container.image} + + if config.is_script_mode: + logger.debug("Script mode with no custom image; using default EMR image") + return None + + if for_create: + logger.info( + "No image specified for new application; using default base image: %s", + EMR_SERVERLESS_BASE_IMAGE, + ) + return {"imageUri": EMR_SERVERLESS_BASE_IMAGE} + + raise ValueError( + "Pythonic mode requires a custom Docker image with flytekit installed " + "on the EMR Serverless workers. Provide a container_image on the task:\n\n" + ' @task(task_config=EMRServerless(...), container_image="")\n\n' + "Or use an existing application_id whose application already has a " + "custom image configured." + ) + + def _merge_tags(self, user_tags: Optional[Dict[str, str]]) -> Dict[str, str]: + tags = dict(FLYTE_TAGS) + if user_tags: + tags.update(user_tags) + return tags + + @staticmethod + def _get_container_image(task_template: TaskTemplate) -> Optional[str]: + """Extract the container image URI from the task template.""" + container = task_template.container + if container and getattr(container, "image", None): + return container.image + return None + + @staticmethod + def _is_create_application_allowed() -> bool: + """Check the connector-level policy for application creation. + + Reads ``FLYTE_EMR_ALLOW_CREATE_APPLICATION`` (default ``"false"``). + Only ``"true"`` (case-insensitive) permits dynamic creation. 
+ """ + raw_value = os.environ.get(_ENV_ALLOW_CREATE_APPLICATION, "false") + allowed = raw_value.strip().lower() == "true" + logger.debug( + "Application creation policy check: %s=%r, allowed=%s", + _ENV_ALLOW_CREATE_APPLICATION, + raw_value, + allowed, + ) + return allowed + + @staticmethod + def _apply_application_name_prefix(name: str) -> str: + """Prepend the connector-level prefix to the application name. + + Reads ``FLYTE_EMR_APPLICATION_NAME_PREFIX``. When set (e.g. + ``flyte-prod-``), the final name becomes ``flyte-prod-my-etl-app``. + """ + prefix = os.environ.get(_ENV_APPLICATION_NAME_PREFIX, "").strip() + if prefix and not name.startswith(prefix): + prefixed = f"{prefix}{name}" + logger.info("Applied application name prefix: '%s' -> '%s'", name, prefixed) + return prefixed + if prefix: + logger.debug("Application name '%s' already starts with prefix '%s', skipping", name, prefix) + return name + + async def _sync_application_image( + self, + handler: EMRServerlessHandler, + application_id: str, + task_template: TaskTemplate, + config: EMRServerless, + ) -> None: + """Update the application image if ``container_image`` differs from the app. + + The desired image is derived from (in order): + 1. ``task_template.container.image`` (the ``container_image`` decorator arg) + 2. ``config.image_configuration`` (explicit ``image_configuration`` in task_config) + + If neither is set, this is a no-op (script mode with default EMR image). 
+ """ + logger.debug("Image sync: checking application %s", application_id) + + if config.is_script_mode and not config.image_configuration: + logger.debug( + "Image sync: script mode with no explicit image_configuration, skipping for application %s", + application_id, + ) + return + + desired_uri = self._get_container_image(task_template) + if not desired_uri and config.image_configuration: + desired_uri = config.image_configuration.get("imageUri", "") + logger.debug("Image sync: using image_configuration as fallback: %s", desired_uri) + if not desired_uri: + logger.debug("Image sync: no desired image URI found, skipping for application %s", application_id) + return + + logger.debug("Image sync: desired image for application %s: %s", application_id, desired_uri) + app = await handler.get_application(application_id) + current_image = app.get("imageConfiguration", {}).get("imageUri", "") + logger.debug("Image sync: current image on application %s: %s", application_id, current_image or "(none)") + + if current_image == desired_uri: + logger.info( + "Image sync: application %s already has image %s, no update needed", application_id, desired_uri + ) + return + + logger.info( + "Image sync: updating application %s image: %s -> %s", + application_id, + current_image or "(none)", + desired_uri, + ) + await handler.update_application( + application_id=application_id, + image_configuration={"imageUri": desired_uri}, + ) + logger.info("Image sync: successfully updated application %s image", application_id) + + async def create( + self, + task_template: TaskTemplate, + inputs: Optional[LiteralMap] = None, + task_execution_metadata: Optional[Any] = None, + **kwargs: Any, + ) -> EMRServerlessJobMetadata: + task_name = task_template.id.name if task_template.id else "(unknown)" + logger.info("create() called for task: %s", task_name) + + if task_execution_metadata is None: + task_execution_metadata = kwargs.get("task_execution_metadata") + + config = 
self._extract_config(task_template) + handler = self._get_handler(config.region) + + application_id = config.application_id + created_application = False + + # --- Resolve application by name if the ID looks like a name --- + if application_id and not _APPLICATION_ID_RE.match(application_id): + app_name = application_id + logger.info( + "application_id '%s' does not look like an AWS ID, treating as application name", + app_name, + ) + resolved_id = await handler.find_application_by_name(app_name) + if resolved_id: + logger.info("Resolved application name '%s' to ID: %s", app_name, resolved_id) + application_id = resolved_id + else: + raise ValueError( + f"No active EMR Serverless application found with name '{app_name}'. " + f"application_id must be a valid AWS application ID (e.g. '00fm0lpr3kbcq60p') " + f"or the name of an existing application." + ) + + # --- Create a new application if needed --- + if not application_id: + logger.info("No application_id provided, checking if dynamic creation is possible") + + if not config.application_name: + logger.error("No application_id and no application_name specified, cannot proceed") + raise ValueError( + "No application_id provided and no application_name specified. " + "Either set application_id to an existing EMR Serverless application, " + "or set application_name to create a new one " + f"(requires {_ENV_ALLOW_CREATE_APPLICATION}=true on the connector)." + ) + + if not self._is_create_application_allowed(): + logger.warning( + "Dynamic application creation requested for '%s' but blocked by connector policy", + config.application_name, + ) + raise ValueError( + "Dynamic application creation is disabled by the connector " + f"({_ENV_ALLOW_CREATE_APPLICATION} is not 'true'). " + "Contact your platform team to enable it, or set application_id " + "to an existing EMR Serverless application." 
+ ) + + app_name = self._apply_application_name_prefix(config.application_name) + + image_config = self._resolve_image_configuration(task_template, config, for_create=True) + + logger.info( + "Creating new EMR Serverless application: name=%s, release_label=%s, type=%s, architecture=%s", + app_name, + config.release_label, + config.application_type, + config.architecture, + ) + application_id = await handler.create_application( + name=app_name, + release_label=config.release_label, + application_type=config.application_type, + initial_capacity=config.initial_capacity, + maximum_capacity=config.maximum_capacity, + network_configuration=config.network_configuration, + image_configuration=image_config, + tags=self._merge_tags(config.tags), + architecture=config.architecture, + runtime_configuration=config.runtime_configuration, + scheduler_configuration=config.scheduler_configuration, + auto_stop_config=config.auto_stop_config, + ) + created_application = True + logger.info("Created application %s successfully", application_id) + else: + logger.info("Using existing application: %s", application_id) + + # --- Sync image on existing apps if sync_image is enabled --- + if not created_application and config.sync_image: + logger.info("sync_image is enabled, checking application image for %s", application_id) + await self._sync_application_image(handler, application_id, task_template, config) + elif not created_application: + logger.debug("sync_image is disabled, skipping image sync for %s", application_id) + + logger.info("Ensuring application %s is in STARTED state", application_id) + await handler.ensure_application_started(application_id) + + # --- Build job driver --- + if config.is_script_mode: + logger.info("Building job driver in script mode") + job_driver = config.get_job_driver() + if config.spark_submit_parameters and "sparkSubmit" in job_driver: + existing = job_driver["sparkSubmit"].get("sparkSubmitParameters", "") + merged = f"{existing} 
{config.spark_submit_parameters}".strip() + job_driver["sparkSubmit"]["sparkSubmitParameters"] = merged + logger.info("Merged top-level spark_submit_parameters into script mode job driver") + logger.debug("Final sparkSubmitParameters: %s", merged) + if "sparkSubmit" in job_driver: + current = job_driver["sparkSubmit"].get("sparkSubmitParameters") + updated = self._append_flyte_env_to_spark_params( + current, + task_execution_metadata, + enabled=config.inject_flyte_env, + ) + if updated != current: + job_driver["sparkSubmit"]["sparkSubmitParameters"] = updated + elif config.inject_flyte_env: + logger.debug( + "Skipping Flyte env injection: non-Spark job driver (%s)", + next(iter(job_driver.keys()), "unknown"), + ) + else: + logger.info("Building job driver in Pythonic mode") + job_driver = self._build_pythonic_job_driver(task_template, config, handler) + if "sparkSubmit" in job_driver: + current = job_driver["sparkSubmit"].get("sparkSubmitParameters") + updated = self._append_flyte_env_to_spark_params( + current, + task_execution_metadata, + enabled=config.inject_flyte_env, + ) + if updated != current: + job_driver["sparkSubmit"]["sparkSubmitParameters"] = updated + + job_name = task_template.id.name if task_template.id else None + + effective_config_overrides = config.get_effective_configuration_overrides() + if effective_config_overrides: + logger.debug( + "Effective configuration overrides keys: %s", + list(effective_config_overrides.keys()), + ) + + logger.info( + "Submitting job run: application=%s, job_name=%s, execution_role=%s, timeout=%dm", + application_id, + job_name, + config.execution_role_arn, + config.execution_timeout_minutes, + ) + job_run_id = await handler.start_job_run( + application_id=application_id, + execution_role_arn=config.execution_role_arn, + job_driver=job_driver, + configuration_overrides=effective_config_overrides, + tags=self._merge_tags(config.tags), + execution_timeout_minutes=config.execution_timeout_minutes, + name=job_name, + 
retry_policy=config.retry_policy, + ) + + region = config.region or handler.client.meta.region_name + logger.info( + "Job submitted successfully: application=%s, job_run_id=%s, region=%s, created_application=%s", + application_id, + job_run_id, + region, + created_application, + ) + + return EMRServerlessJobMetadata( + application_id=application_id, + job_run_id=job_run_id, + region=region, + created_application=created_application, + ) + + async def get( + self, + resource_meta: EMRServerlessJobMetadata, + **kwargs: Any, + ) -> Resource: + logger.debug( + "get() called: application=%s, job_run_id=%s, region=%s", + resource_meta.application_id, + resource_meta.job_run_id, + resource_meta.region, + ) + handler = self._get_handler(resource_meta.region) + + try: + job = await handler.get_job_run( + application_id=resource_meta.application_id, + job_run_id=resource_meta.job_run_id, + ) + except Exception as e: + logger.warning( + "Failed to retrieve job %s on application %s: %s", + resource_meta.job_run_id, + resource_meta.application_id, + e, + ) + return Resource( + phase=TaskExecution.FAILED, + message=f"Job not found: {resource_meta.job_run_id}", + ) + + state = job.get("state", "UNKNOWN") + state_details = job.get("stateDetails", "") + phase = convert_to_flyte_phase(EMR_SERVERLESS_STATES.get(state, "Running")) + + message = f"EMR Serverless job state: {state}" + if state_details: + message = f"{message} - {state_details}" + + logger.info( + "Job %s status: state=%s, phase=%s", + resource_meta.job_run_id, + state, + phase, + ) + + log_links = self._get_log_links(resource_meta) + + return Resource(phase=phase, message=message, log_links=log_links) + + def _get_log_links(self, resource_meta: EMRServerlessJobMetadata) -> list: + region = resource_meta.region or "us-east-1" + console_url = ( + f"https://{region}.console.aws.amazon.com/emr/home?region={region}" + f"#/serverless/{resource_meta.application_id}/jobs/{resource_meta.job_run_id}" + ) + return 
[TaskLog(uri=console_url, name="EMR Serverless Console").to_flyte_idl()] + + async def delete( + self, + resource_meta: EMRServerlessJobMetadata, + **kwargs: Any, + ) -> None: + logger.info( + "delete() called: application=%s, job_run_id=%s, region=%s", + resource_meta.application_id, + resource_meta.job_run_id, + resource_meta.region, + ) + handler = self._get_handler(resource_meta.region) + try: + await handler.cancel_job_run( + application_id=resource_meta.application_id, + job_run_id=resource_meta.job_run_id, + ) + logger.info("Delete completed for job %s", resource_meta.job_run_id) + except Exception as e: + logger.warning( + "Failed to cancel job %s on application %s: %s", + resource_meta.job_run_id, + resource_meta.application_id, + e, + ) + + +ConnectorRegistry.register(EMRServerlessConnector()) diff --git a/plugins/flytekit-aws-emr-serverless/flytekitplugins/awsemrserverless/task.py b/plugins/flytekit-aws-emr-serverless/flytekitplugins/awsemrserverless/task.py new file mode 100644 index 0000000000..3a92a3cab0 --- /dev/null +++ b/plugins/flytekit-aws-emr-serverless/flytekitplugins/awsemrserverless/task.py @@ -0,0 +1,421 @@ +""" +Flyte Task definition for EMR Serverless. + +Provides the EMRServerless task configuration dataclass, job driver configs, +and the EMRServerlessTask class for defining EMR Serverless jobs in Flyte +workflows. Supports two execution modes: + +1. Script mode: User provides S3 paths to Spark/Hive scripts. +2. Pythonic mode: User writes a @task function and the connector uses the + Flyte container entrypoint to execute it on EMR Serverless. 
"""

import dataclasses
import logging
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Union

from flytekit import PythonFunctionTask
from flytekit.configuration import SerializationSettings
from flytekit.core.context_manager import FlyteContextManager
from flytekit.extend import TaskPlugins
from flytekit.image_spec import ImageSpec

# Connector/agent naming changed across flytekit versions; prefer the new
# "connector" module and fall back to the legacy "agent" one.
try:
    from flytekit.extend.backend.base_connector import AsyncConnectorExecutorMixin
except ModuleNotFoundError:
    from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin as AsyncConnectorExecutorMixin

logger = logging.getLogger(__name__)


@dataclass
class EMRServerlessSparkJobDriver:
    """
    Spark job driver configuration for EMR Serverless.

    Attributes:
        entry_point: S3 path to the main application file
            (e.g., ``s3://bucket/scripts/main.py``).
        entry_point_arguments: Arguments to pass to the main application.
        spark_submit_parameters: Spark submit parameters string
            (e.g., ``--conf spark.executor.memory=4g``).
    """

    entry_point: str
    entry_point_arguments: Optional[List[str]] = None
    spark_submit_parameters: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert to EMR Serverless API format (``sparkSubmit`` keys);
        optional fields are omitted when unset."""
        result: Dict[str, Any] = {"entryPoint": self.entry_point}
        if self.entry_point_arguments:
            result["entryPointArguments"] = self.entry_point_arguments
        if self.spark_submit_parameters:
            result["sparkSubmitParameters"] = self.spark_submit_parameters
        return result


@dataclass
class EMRServerlessHiveJobDriver:
    """
    Hive job driver configuration for EMR Serverless.

    Attributes:
        query: The Hive query to execute, or S3 path to a query file.
        init_query_file: S3 path to an initialization query file.
        parameters: Parameters string for the query.
    """

    query: str
    init_query_file: Optional[str] = None
    parameters: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert to EMR Serverless API format (``hive`` keys);
        optional fields are omitted when unset."""
        result: Dict[str, Any] = {"query": self.query}
        if self.init_query_file:
            result["initQueryFile"] = self.init_query_file
        if self.parameters:
            result["parameters"] = self.parameters
        return result


# Architectures accepted by the EMR Serverless CreateApplication API.
_VALID_ARCHITECTURES = ("X86_64", "ARM64")


@dataclass
class EMRServerless:
    """
    EMR Serverless task configuration.

    Use this to configure an EMR Serverless task. Tasks marked with this
    configuration will execute on AWS EMR Serverless as Spark or Hive jobs.

    Supports two execution modes:

    **Script mode** -- provide a ``spark_job_driver`` or ``hive_job_driver``
    pointing to scripts on S3::

        @task(
            task_config=EMRServerless(
                application_id="00f5abc123def456",
                execution_role_arn="arn:aws:iam::123456789012:role/Role",
                spark_job_driver=EMRServerlessSparkJobDriver(
                    entry_point="s3://bucket/scripts/main.py",
                ),
                region="us-east-1",
            )
        )
        def my_spark_job() -> str:
            return "Job submitted"

    **Pythonic mode** -- omit the job driver and write Spark code directly
    in the task function. The connector will use the Flyte container
    entrypoint to execute it on EMR Serverless::

        @task(
            task_config=EMRServerless(
                application_id="00f5abc123def456",
                execution_role_arn="arn:aws:iam::123456789012:role/Role",
                region="us-east-1",
            )
        )
        def my_spark_job():
            from pyspark.sql import SparkSession
            spark = SparkSession.builder.getOrCreate()
            spark.range(100).write.parquet("s3://bucket/output")

    Attributes:
        execution_role_arn: IAM role ARN for job execution (required).
        application_id: Existing EMR Serverless application ID. If not
            provided and ``application_name`` is set, the connector will
            attempt to create a new application (subject to the
            connector-level ``FLYTE_EMR_ALLOW_CREATE_APPLICATION`` policy).
        release_label: EMR release label (default: ``emr-7.0.0``).
        application_type: ``"SPARK"`` or ``"HIVE"`` (default: ``"SPARK"``).
        application_name: Name for a new application. When
            ``application_id`` is not set and ``application_name`` is
            provided, the connector will create a new application with
            this name (prefixed by the connector's
            ``FLYTE_EMR_APPLICATION_NAME_PREFIX`` if configured).
        sync_image: When ``True`` (the default), the connector will update
            the application's image if the task's ``container_image`` or
            ``image_configuration`` differs from what the app currently has.
        spark_job_driver: Spark job configuration. When omitted for a SPARK
            application, Pythonic mode is used.
        hive_job_driver: Hive job configuration.
        spark_submit_parameters: Extra ``--conf`` flags to pass to
            ``sparkSubmitParameters`` for *both* script and Pythonic modes.
            In Pythonic mode these are merged with the connector defaults.
        application_configuration: ``applicationConfiguration`` list for
            ``configurationOverrides``. Merged with any entries already
            in ``configuration_overrides``.
        runtime_configuration: ``runtimeConfiguration`` passed to
            ``CreateApplication`` or ``UpdateApplication``.
        scheduler_configuration: ``SchedulerConfiguration`` for job
            concurrency and queuing (e.g.
            ``{"maxConcurrentRuns": 5, "queueTimeoutMinutes": 30}``).
        auto_stop_config: ``AutoStopConfiguration`` for the application
            (e.g. ``{"enabled": True, "idleTimeoutMinutes": 30}``).
        architecture: Processor architecture for the application.
            ``"X86_64"`` (default) or ``"ARM64"`` (Graviton).
        retry_policy: Job-level retry policy for resiliency (e.g.
            ``{"maxAttempts": 3}``). Requires EMR 7.1+.
        configuration_overrides: Application and monitoring configuration.
        tags: Resource tags.
        execution_timeout_minutes: Maximum job execution time (default: 60).
        initial_capacity: Pre-initialized worker capacity.
        maximum_capacity: Auto-scaling limits.
        network_configuration: VPC, subnet, and security group settings.
        image_configuration: Custom container image settings.
        region: AWS region (uses boto3 default if not specified).
        inject_flyte_env: When ``True`` (default), the connector appends
            ``--conf spark.emr-serverless.driverEnv.*`` and
            ``--conf spark.executorEnv.*`` entries to ``sparkSubmitParameters``
            so the Spark driver and executors see Flyte runtime context as
            environment variables: ``FLYTE_INTERNAL_EXECUTION_ID``,
            ``FLYTE_INTERNAL_EXECUTION_PROJECT``,
            ``FLYTE_INTERNAL_EXECUTION_DOMAIN``,
            ``FLYTE_INTERNAL_TASK_{PROJECT,DOMAIN,NAME,VERSION}``,
            ``FLYTE_INTERNAL_NODE_ID``, and
            ``FLYTE_INTERNAL_TASK_RETRY_ATTEMPT``. Any
            ``environment_variables`` set on ``task_execution_metadata``
            (via Flyte ``Environment`` policies) are also forwarded. Only
            applies to Spark jobs.
    """

    # -- Required / application identity --------------------------------
    execution_role_arn: str
    application_id: Optional[str] = None
    release_label: str = "emr-7.0.0"
    application_type: str = "SPARK"
    application_name: Optional[str] = None

    # -- Connector behavior toggles -------------------------------------
    sync_image: bool = True
    inject_flyte_env: bool = True

    # -- Job drivers (at most one; neither => Pythonic mode) -------------
    spark_job_driver: Optional[EMRServerlessSparkJobDriver] = None
    hive_job_driver: Optional[EMRServerlessHiveJobDriver] = None

    # -- Job / application tuning ----------------------------------------
    spark_submit_parameters: Optional[str] = None
    application_configuration: Optional[List[Dict[str, Any]]] = None
    runtime_configuration: Optional[List[Dict[str, Any]]] = None
    scheduler_configuration: Optional[Dict[str, Any]] = None
    auto_stop_config: Optional[Dict[str, Any]] = None
    architecture: Optional[str] = None
    retry_policy: Optional[Dict[str, Any]] = None

    configuration_overrides: Optional[Dict[str, Any]] = None
    tags: Optional[Dict[str, str]] = None
    execution_timeout_minutes: int = 60

    # -- Capacity / networking / image ------------------------------------
    initial_capacity: Optional[Dict[str, Any]] = None
    maximum_capacity: Optional[Dict[str, Any]] = None
    network_configuration: Optional[Dict[str, Any]] = None
    image_configuration: Optional[Dict[str, Any]] = None

    region: Optional[str] = None
+ + def __post_init__(self) -> None: + if not self.execution_role_arn: + raise ValueError("execution_role_arn is required") + if self.application_type not in ("SPARK", "HIVE"): + raise ValueError(f"application_type must be 'SPARK' or 'HIVE', got '{self.application_type}'") + if self.application_type == "HIVE" and self.hive_job_driver is None: + raise ValueError("hive_job_driver is required when application_type is 'HIVE'") + if self.spark_job_driver and self.hive_job_driver: + raise ValueError("Only one of spark_job_driver or hive_job_driver can be specified") + if self.execution_timeout_minutes < 1: + raise ValueError("execution_timeout_minutes must be at least 1") + if self.architecture and self.architecture not in _VALID_ARCHITECTURES: + raise ValueError(f"architecture must be one of {_VALID_ARCHITECTURES}, got '{self.architecture}'") + + @property + def is_script_mode(self) -> bool: + """True when user provided an explicit job driver (script mode).""" + return self.spark_job_driver is not None or self.hive_job_driver is not None + + def get_job_driver(self) -> Dict[str, Any]: + """Get the job driver dict in EMR Serverless API format.""" + if self.spark_job_driver: + return {"sparkSubmit": self.spark_job_driver.to_dict()} + if self.hive_job_driver: + return {"hive": self.hive_job_driver.to_dict()} + raise ValueError("No job driver configured -- use Pythonic mode instead") + + def get_effective_configuration_overrides(self) -> Optional[Dict[str, Any]]: + """Build the merged ``configurationOverrides`` dict. + + Combines ``configuration_overrides`` with ``application_configuration`` + so users can set monitoring config via ``configuration_overrides`` and + Spark/Hive config via ``application_configuration`` independently. 
+ """ + if not self.application_configuration and not self.configuration_overrides: + return None + + result = dict(self.configuration_overrides) if self.configuration_overrides else {} + + if self.application_configuration: + existing = result.get("applicationConfiguration", []) + merged = list(existing) + list(self.application_configuration) + result["applicationConfiguration"] = merged + + return result or None + + def to_dict(self) -> Dict[str, Any]: + """Serialize to a dict for ``task_template.custom``.""" + result: Dict[str, Any] = { + "execution_role_arn": self.execution_role_arn, + "release_label": self.release_label, + "application_type": self.application_type, + "execution_timeout_minutes": self.execution_timeout_minutes, + "sync_image": self.sync_image, + "inject_flyte_env": self.inject_flyte_env, + } + if self.application_id: + result["application_id"] = self.application_id + if self.application_name: + result["application_name"] = self.application_name + if self.spark_job_driver: + result["spark_job_driver"] = self.spark_job_driver.to_dict() + if self.hive_job_driver: + result["hive_job_driver"] = self.hive_job_driver.to_dict() + if self.spark_submit_parameters: + result["spark_submit_parameters"] = self.spark_submit_parameters + if self.application_configuration: + result["application_configuration"] = self.application_configuration + if self.runtime_configuration: + result["runtime_configuration"] = self.runtime_configuration + if self.scheduler_configuration: + result["scheduler_configuration"] = self.scheduler_configuration + if self.auto_stop_config: + result["auto_stop_config"] = self.auto_stop_config + if self.architecture: + result["architecture"] = self.architecture + if self.retry_policy: + result["retry_policy"] = self.retry_policy + if self.configuration_overrides: + result["configuration_overrides"] = self.configuration_overrides + if self.tags: + result["tags"] = self.tags + if self.initial_capacity: + result["initial_capacity"] = 
self.initial_capacity + if self.maximum_capacity: + result["maximum_capacity"] = self.maximum_capacity + if self.network_configuration: + result["network_configuration"] = self.network_configuration + if self.image_configuration: + result["image_configuration"] = self.image_configuration + if self.region: + result["region"] = self.region + return result + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "EMRServerless": + """Deserialize from a dict (inverse of ``to_dict``).""" + spark_driver = None + hive_driver = None + + if "spark_job_driver" in data: + d = data["spark_job_driver"] + spark_driver = EMRServerlessSparkJobDriver( + entry_point=d["entryPoint"], + entry_point_arguments=d.get("entryPointArguments"), + spark_submit_parameters=d.get("sparkSubmitParameters"), + ) + if "hive_job_driver" in data: + d = data["hive_job_driver"] + hive_driver = EMRServerlessHiveJobDriver( + query=d["query"], + init_query_file=d.get("initQueryFile"), + parameters=d.get("parameters"), + ) + + return cls( + execution_role_arn=data["execution_role_arn"], + application_id=data.get("application_id"), + release_label=data.get("release_label", "emr-7.0.0"), + application_type=data.get("application_type", "SPARK"), + application_name=data.get("application_name"), + sync_image=data.get("sync_image", True), + inject_flyte_env=data.get("inject_flyte_env", True), + spark_job_driver=spark_driver, + hive_job_driver=hive_driver, + spark_submit_parameters=data.get("spark_submit_parameters"), + application_configuration=data.get("application_configuration"), + runtime_configuration=data.get("runtime_configuration"), + scheduler_configuration=data.get("scheduler_configuration"), + auto_stop_config=data.get("auto_stop_config"), + architecture=data.get("architecture"), + retry_policy=data.get("retry_policy"), + configuration_overrides=data.get("configuration_overrides"), + tags=data.get("tags"), + execution_timeout_minutes=int(data.get("execution_timeout_minutes", 60)), + 
initial_capacity=data.get("initial_capacity"), + maximum_capacity=data.get("maximum_capacity"), + network_configuration=data.get("network_configuration"), + image_configuration=data.get("image_configuration"), + region=data.get("region"), + ) + + +EMR_SERVERLESS_BASE_IMAGE = "public.ecr.aws/emr-serverless/spark/emr-7.0.0:latest" + + +class EMRServerlessTask(AsyncConnectorExecutorMixin, PythonFunctionTask[EMRServerless]): + """ + EMR Serverless Task implementation. + + Extends ``PythonFunctionTask`` with ``AsyncConnectorExecutorMixin`` to + enable remote execution via the EMR Serverless connector and local + testing by mimicking FlytePropeller's connector calls. + + For Pythonic mode the task image must include ``flytekit``. When an + ``ImageSpec`` is provided without a ``base_image``, the EMR Serverless + Spark base image is used automatically (same pattern as + ``flytekit-spark``). Users can also pass a plain ECR URI string as + ``container_image``. + """ + + _TASK_TYPE = "emr_serverless" + + def __init__( + self, + task_config: EMRServerless, + task_function: Callable, + container_image: Optional[Union[str, ImageSpec]] = None, + **kwargs: Any, + ): + if isinstance(container_image, ImageSpec) and container_image.base_image is None: + container_image = dataclasses.replace(container_image, base_image=EMR_SERVERLESS_BASE_IMAGE) + + super().__init__( + task_config=task_config, + task_function=task_function, + task_type=self._TASK_TYPE, + container_image=container_image, + **kwargs, + ) + + def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: + """Serialize task configuration into ``task_template.custom``.""" + return self.task_config.to_dict() + + def execute(self, **kwargs: Any) -> Any: + """ + Execute the task. + + When running locally (e.g. ``pyflyte run`` without ``--remote``), + delegate to the connector mixin so it mimics the backend flow. 
from setuptools import setup

PLUGIN_NAME = "awsemrserverless"

# Distribution name: flytekitplugins-awsemrserverless.
microlib_name = f"flytekitplugins-{PLUGIN_NAME}"

plugin_requires = [
    "flytekit>=1.14.6",
    "aioboto3>=12.3.0",
    "boto3>=1.28.0",
]

__version__ = "0.0.0+develop"

# Read the long description with an explicit encoding and a context manager so
# the file handle is closed deterministically; the previous
# open("README.md").read() leaked the handle and decoded with the locale
# encoding, which breaks on non-UTF-8 build hosts.
with open("README.md", encoding="utf-8") as _readme:
    _long_description = _readme.read()

setup(
    title="AWS EMR Serverless",
    title_expanded="Flytekit AWS EMR Serverless Plugin",
    name=microlib_name,
    version=__version__,
    author="flyteorg",
    author_email="admin@flyte.org",
    description="Flytekit AWS EMR Serverless Plugin: run Spark and Hive jobs from Flyte tasks",
    long_description=_long_description,
    long_description_content_type="text/markdown",
    # NOTE(review): namespace_packages is deprecated in modern setuptools in
    # favour of implicit namespace packages (PEP 420); kept as-is for
    # consistency with the sibling flytekit plugins' setup.py files.
    namespace_packages=["flytekitplugins"],
    packages=[f"flytekitplugins.{PLUGIN_NAME}"],
    install_requires=plugin_requires,
    license="apache2",
    python_requires=">=3.9",
    classifiers=[
        "Intended Audience :: Science/Research",
        "Intended Audience :: Developers",
        "License :: OSI Approved :: Apache Software License",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
        "Programming Language :: Python :: 3.12",
        "Topic :: Scientific/Engineering",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "Topic :: Software Development",
        "Topic :: Software Development :: Libraries",
        "Topic :: Software Development :: Libraries :: Python Modules",
    ],
    entry_points={
        "flytekit.plugins": [
            f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}",
        ],
    },
)
@pytest.fixture
def mock_call():
    """Patch :py:meth:`EMRServerlessHandler._call` for a test.

    Yields an :class:`~unittest.mock.AsyncMock` so tests can set
    ``return_value`` or ``side_effect`` and assert on ``call_args``.
    """
    target = "flytekitplugins.awsemrserverless.boto_handler.EMRServerlessHandler._call"
    with patch(target, new_callable=AsyncMock) as mocked:
        yield mocked


@pytest.fixture
def mock_paginate():
    """Patch :py:meth:`EMRServerlessHandler._paginate` for a test."""
    target = "flytekitplugins.awsemrserverless.boto_handler.EMRServerlessHandler._paginate"
    with patch(target, new_callable=AsyncMock) as mocked:
        yield mocked


@pytest.fixture
def sample_spark_job_driver():
    """A Spark job driver pointing at an S3 script, with args and submit params."""
    from flytekitplugins.awsemrserverless import EMRServerlessSparkJobDriver

    args = [
        "--input",
        "s3://data/input",
        "--output",
        "s3://data/output",
    ]
    return EMRServerlessSparkJobDriver(
        entry_point="s3://my-bucket/scripts/main.py",
        entry_point_arguments=args,
        spark_submit_parameters="--conf spark.executor.memory=4g --conf spark.executor.cores=2",
    )


@pytest.fixture
def sample_hive_job_driver():
    """A Hive job driver with query, init file, and hiveconf parameters."""
    from flytekitplugins.awsemrserverless import EMRServerlessHiveJobDriver

    return EMRServerlessHiveJobDriver(
        query="s3://my-bucket/queries/analytics.sql",
        init_query_file="s3://my-bucket/queries/init.sql",
        parameters="--hiveconf hive.exec.parallel=true",
    )


@pytest.fixture
def sample_config(sample_spark_job_driver):
    """A fully populated script-mode Spark config (tags, timeout, region)."""
    from flytekitplugins.awsemrserverless import EMRServerless

    return EMRServerless(
        application_id="00f5abc123def456",
        execution_role_arn="arn:aws:iam::123456789012:role/EMRServerlessRole",
        spark_job_driver=sample_spark_job_driver,
        region="us-east-1",
        tags={"Environment": "test", "Team": "data-engineering"},
        execution_timeout_minutes=120,
    )


@pytest.fixture
def sample_task_config(sample_spark_job_driver):
    """A minimal script-mode Spark config for task-level tests."""
    from flytekitplugins.awsemrserverless import EMRServerless

    return EMRServerless(
        application_id="00f5abc123def456",
        execution_role_arn="arn:aws:iam::123456789012:role/EMRServerlessRole",
        spark_job_driver=sample_spark_job_driver,
        region="us-east-1",
    )


@pytest.fixture
def sample_job_metadata():
    """Connector resource metadata for an already-submitted job run."""
    from flytekitplugins.awsemrserverless import EMRServerlessJobMetadata

    return EMRServerlessJobMetadata(
        application_id="00f5abc123def456",
        job_run_id="00f5xyz789ghi012",
        region="us-east-1",
        created_application=False,
    )


# ----------------------------------------------------------------------
# Boto response shape fixtures
#
# These match the shape returned by the boto3 client methods (what gets
# funneled through ``_call``). Tests feed them directly as the ``_call``
# mock's return value.
# ----------------------------------------------------------------------


def _job_run(state, details):
    """Build a GetJobRun-shaped response dict for the given state."""
    return {
        "jobRun": {
            "applicationId": "00f5abc123def456",
            "jobRunId": "00f5xyz789ghi012",
            "state": state,
            "stateDetails": details,
        }
    }


@pytest.fixture
def mock_job_run_success():
    return _job_run("SUCCESS", "Job completed successfully")


@pytest.fixture
def mock_job_run_running():
    return _job_run("RUNNING", "Job is running")


@pytest.fixture
def mock_job_run_failed():
    return _job_run("FAILED", "Job failed due to OOM")


@pytest.fixture
def mock_application_started():
    return {
        "application": {
            "applicationId": "00f5abc123def456",
            "state": "STARTED",
            "name": "test-app",
        }
    }
"""

from unittest.mock import AsyncMock, patch

import pytest
from botocore.exceptions import ClientError

from flytekitplugins.awsemrserverless.boto_handler import EMRServerlessHandler


class TestEMRServerlessHandlerInit:
    """Constructor behavior: region defaults to None unless supplied."""

    def test_init_default(self):
        handler = EMRServerlessHandler()
        assert handler.region is None

    def test_init_with_region(self):
        handler = EMRServerlessHandler(region="us-west-2")
        assert handler.region == "us-west-2"


class TestEMRServerlessHandlerGetApplication:
    """get_application: unwraps the boto response and propagates ClientError."""

    @pytest.mark.asyncio
    async def test_get_application_success(self, mock_call, mock_application_started):
        mock_call.return_value = mock_application_started

        handler = EMRServerlessHandler()
        result = await handler.get_application("00f5abc123def456")

        assert result["applicationId"] == "00f5abc123def456"
        assert result["state"] == "STARTED"
        mock_call.assert_called_once_with("get_application", applicationId="00f5abc123def456")

    @pytest.mark.asyncio
    async def test_get_application_not_found(self, mock_call):
        mock_call.side_effect = ClientError(
            {"Error": {"Code": "ResourceNotFoundException", "Message": "Not found"}},
            "GetApplication",
        )

        handler = EMRServerlessHandler()

        with pytest.raises(ClientError):
            await handler.get_application("nonexistent-app")


class TestEMRServerlessHandlerStartJobRun:
    """start_job_run: returns the job run id and forwards optional kwargs."""

    @pytest.mark.asyncio
    async def test_start_job_run_success(self, mock_call):
        mock_call.return_value = {
            "applicationId": "app-123",
            "jobRunId": "job-456",
        }

        handler = EMRServerlessHandler()
        job_driver = {"sparkSubmit": {"entryPoint": "s3://bucket/main.py"}}

        job_run_id = await handler.start_job_run(
            application_id="app-123",
            execution_role_arn="arn:aws:iam::123456789012:role/Role",
            job_driver=job_driver,
        )

        assert job_run_id == "job-456"
        mock_call.assert_called_once()
        assert mock_call.call_args.args == ("start_job_run",)

    @pytest.mark.asyncio
    async def test_start_job_run_with_all_options(self, mock_call):
        mock_call.return_value = {
            "applicationId": "app-123",
            "jobRunId": "job-456",
        }

        handler = EMRServerlessHandler()
        job_driver = {"sparkSubmit": {"entryPoint": "s3://bucket/main.py"}}
        config_overrides = {"monitoringConfiguration": {"s3MonitoringConfiguration": {"logUri": "s3://logs/"}}}
        tags = {"Environment": "test"}

        await handler.start_job_run(
            application_id="app-123",
            execution_role_arn="arn:aws:iam::123456789012:role/Role",
            job_driver=job_driver,
            configuration_overrides=config_overrides,
            tags=tags,
            execution_timeout_minutes=120,
            name="test-job",
        )

        # Optional snake_case kwargs must map to the camelCase API fields.
        call_kwargs = mock_call.call_args.kwargs
        assert call_kwargs["configurationOverrides"] == config_overrides
        assert call_kwargs["tags"] == tags
        assert call_kwargs["executionTimeoutMinutes"] == 120
        assert call_kwargs["name"] == "test-job"

    @pytest.mark.asyncio
    async def test_start_job_run_with_retry_policy(self, mock_call):
        mock_call.return_value = {"applicationId": "app-123", "jobRunId": "job-456"}

        handler = EMRServerlessHandler()
        await handler.start_job_run(
            application_id="app-123",
            execution_role_arn="arn:aws:iam::123456789012:role/Role",
            job_driver={"sparkSubmit": {"entryPoint": "s3://bucket/main.py"}},
            retry_policy={"maxAttempts": 3},
        )

        call_kwargs = mock_call.call_args.kwargs
        assert call_kwargs["retryPolicy"]["maxAttempts"] == 3

    @pytest.mark.asyncio
    async def test_start_job_run_without_retry_policy(self, mock_call):
        mock_call.return_value = {"applicationId": "app-123", "jobRunId": "job-456"}

        handler = EMRServerlessHandler()
        await handler.start_job_run(
            application_id="app-123",
            execution_role_arn="arn:aws:iam::123456789012:role/Role",
            job_driver={"sparkSubmit": {"entryPoint": "s3://bucket/main.py"}},
        )

        # retryPolicy must be omitted entirely when not requested.
        call_kwargs = mock_call.call_args.kwargs
        assert "retryPolicy" not in call_kwargs


class TestEMRServerlessHandlerGetJobRun:
    """get_job_run: unwraps the jobRun payload and propagates ClientError."""

    @pytest.mark.asyncio
    async def test_get_job_run_success(self, mock_call, mock_job_run_running):
        mock_call.return_value = mock_job_run_running

        handler = EMRServerlessHandler()
        result = await handler.get_job_run("app-123", "job-456")

        assert result["state"] == "RUNNING"
        mock_call.assert_called_once_with(
            "get_job_run", applicationId="app-123", jobRunId="job-456"
        )

    @pytest.mark.asyncio
    async def test_get_job_run_not_found(self, mock_call):
        mock_call.side_effect = ClientError(
            {"Error": {"Code": "ResourceNotFoundException", "Message": "Not found"}},
            "GetJobRun",
        )

        handler = EMRServerlessHandler()

        with pytest.raises(ClientError):
            await handler.get_job_run("app-123", "nonexistent-job")


class TestEMRServerlessHandlerCancelJobRun:
    """cancel_job_run: checks state first, skips terminal jobs, swallows 404."""

    @pytest.mark.asyncio
    async def test_cancel_job_run_success(self, mock_call, mock_job_run_running):
        # First _call -> get_job_run returns running, second -> cancel_job_run returns {}
        mock_call.side_effect = [mock_job_run_running, {}]

        handler = EMRServerlessHandler()
        await handler.cancel_job_run("app-123", "job-456")

        assert mock_call.call_count == 2
        assert mock_call.call_args_list[0].args == ("get_job_run",)
        assert mock_call.call_args_list[1].args == ("cancel_job_run",)

    @pytest.mark.asyncio
    async def test_cancel_job_run_already_completed(self, mock_call, mock_job_run_success):
        mock_call.return_value = mock_job_run_success

        handler = EMRServerlessHandler()
        await handler.cancel_job_run("app-123", "job-456")

        # Only the get_job_run call; cancel_job_run is skipped because it's terminal.
        mock_call.assert_called_once()
        assert mock_call.call_args.args == ("get_job_run",)

    @pytest.mark.asyncio
    async def test_cancel_job_run_not_found_is_swallowed(self, mock_call):
        mock_call.side_effect = ClientError(
            {"Error": {"Code": "ResourceNotFoundException", "Message": "Not found"}},
            "GetJobRun",
        )

        handler = EMRServerlessHandler()
        # Does not raise.
        await handler.cancel_job_run("app-123", "nonexistent-job")


class TestEMRServerlessHandlerCreateApplication:
    """create_application: snake_case kwargs map to camelCase API fields."""

    @pytest.mark.asyncio
    async def test_create_application_with_new_params(self, mock_call):
        mock_call.return_value = {"applicationId": "app-new"}

        handler = EMRServerlessHandler()
        app_id = await handler.create_application(
            name="test-app",
            release_label="emr-7.0.0",
            application_type="SPARK",
            architecture="ARM64",
            runtime_configuration=[{"classification": "spark-defaults", "properties": {}}],
            scheduler_configuration={"maxConcurrentRuns": 5},
            auto_stop_config={"enabled": True, "idleTimeoutMinutes": 30},
        )

        assert app_id == "app-new"
        call_kwargs = mock_call.call_args.kwargs
        assert call_kwargs["architecture"] == "ARM64"
        assert call_kwargs["runtimeConfiguration"][0]["classification"] == "spark-defaults"
        assert call_kwargs["schedulerConfiguration"]["maxConcurrentRuns"] == 5
        assert call_kwargs["autoStopConfiguration"]["enabled"] is True


class TestEMRServerlessHandlerUpdateApplication:
    """update_application: partial updates; no API call when nothing to change."""

    @pytest.mark.asyncio
    async def test_update_application_image(self, mock_call):
        mock_call.return_value = {}

        handler = EMRServerlessHandler()
        await handler.update_application(
            application_id="app-123",
            image_configuration={"imageUri": "new-image:v2"},
        )

        mock_call.assert_called_once()
        call_kwargs = mock_call.call_args.kwargs
        assert call_kwargs["applicationId"] == "app-123"
        assert call_kwargs["imageConfiguration"]["imageUri"] == "new-image:v2"

    @pytest.mark.asyncio
    async def test_update_application_noop(self, mock_call):
        """When no mutable fields are provided, no API call is made."""
        handler = EMRServerlessHandler()
        await handler.update_application(application_id="app-123")

        mock_call.assert_not_called()

    @pytest.mark.asyncio
    async def test_update_application_multiple_fields(self, mock_call):
        mock_call.return_value = {}

        handler = EMRServerlessHandler()
        await handler.update_application(
            application_id="app-123",
            image_configuration={"imageUri": "image:v3"},
            maximum_capacity={"cpu": "200vCPU"},
            scheduler_configuration={"maxConcurrentRuns": 10},
        )

        call_kwargs = mock_call.call_args.kwargs
        assert call_kwargs["imageConfiguration"]["imageUri"] == "image:v3"
        assert call_kwargs["maximumCapacity"]["cpu"] == "200vCPU"
        assert call_kwargs["schedulerConfiguration"]["maxConcurrentRuns"] == 10


class TestEMRServerlessHandlerFindApplicationByName:
    """find_application_by_name: paginated name lookup over active states."""

    @pytest.mark.asyncio
    async def test_find_application_returns_id_when_present(self, mock_paginate):
        mock_paginate.return_value = [
            {"id": "00f-other", "name": "other-app", "state": "STARTED"},
            {"id": "00f-target", "name": "my-app", "state": "STARTED"},
        ]

        handler = EMRServerlessHandler()
        result = await handler.find_application_by_name("my-app")

        assert result == "00f-target"
        mock_paginate.assert_called_once_with(
            "list_applications",
            result_key="applications",
            states=["CREATED", "STARTED", "STOPPED"],
        )

    @pytest.mark.asyncio
    async def test_find_application_returns_none_when_absent(self, mock_paginate):
        mock_paginate.return_value = [
            {"id": "00f-other", "name": "other-app", "state": "STARTED"},
        ]

        handler = EMRServerlessHandler()
        result = await handler.find_application_by_name("nonexistent")

        assert result is None

    @pytest.mark.asyncio
    async def test_find_application_empty_list(self, mock_paginate):
        mock_paginate.return_value = []

        handler = EMRServerlessHandler()
        result = await handler.find_application_by_name("anything")

        assert result is None


class TestEMRServerlessHandlerEnsureApplicationStarted:
    """ensure_application_started: idempotent start with polling."""

    @pytest.mark.asyncio
    async def test_returns_immediately_when_started(self, mock_call):
        mock_call.return_value = {"application": {"applicationId": "app-1", "state": "STARTED"}}

        handler = EMRServerlessHandler()
        await handler.ensure_application_started("app-1")

        # Only GetApplication, no StartApplication.
        mock_call.assert_called_once()
        assert mock_call.call_args.args == ("get_application",)

    @pytest.mark.asyncio
    async def test_starts_a_stopped_application(self, mock_call):
        # First GetApplication -> STOPPED, StartApplication -> {}, then GetApplication -> STARTED.
        mock_call.side_effect = [
            {"application": {"applicationId": "app-1", "state": "STOPPED"}},
            {},
            {"application": {"applicationId": "app-1", "state": "STARTED"}},
        ]

        handler = EMRServerlessHandler()
        # Use a very small poll interval to keep the test fast.
        with patch("flytekitplugins.awsemrserverless.boto_handler.asyncio.sleep", new_callable=AsyncMock):
            await handler.ensure_application_started("app-1", poll_interval_seconds=0)

        methods_called = [c.args[0] for c in mock_call.call_args_list]
        assert methods_called[0] == "get_application"
        assert methods_called[1] == "start_application"
        assert methods_called[2] == "get_application"

    @pytest.mark.asyncio
    async def test_raises_on_terminal_state(self, mock_call):
        mock_call.return_value = {"application": {"applicationId": "app-1", "state": "TERMINATED"}}

        handler = EMRServerlessHandler()
        with pytest.raises(RuntimeError, match="terminal state"):
            await handler.ensure_application_started("app-1")
"""

import os
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
# NOTE(review): TaskExecution, convert_to_flyte_phase, and
# EMR_SERVERLESS_STATES are not referenced in this part of the module --
# presumably used by later tests in this file; verify before removing.
from flyteidl.core.execution_pb2 import TaskExecution

from flytekitplugins.awsemrserverless import (
    EMRServerlessConnector,
    EMRServerlessJobMetadata,
)
from flytekit.extend.backend.utils import convert_to_flyte_phase

from flytekitplugins.awsemrserverless.connector import (
    EMR_SERVERLESS_STATES,
    _ENV_ALLOW_CREATE_APPLICATION,
    _ENV_APPLICATION_NAME_PREFIX,
)


def _make_handler(**overrides) -> AsyncMock:
    """Build a mock EMRServerlessHandler with sensible defaults."""
    handler = AsyncMock()
    handler.ensure_application_started = AsyncMock()
    handler.start_job_run = AsyncMock(return_value="job-123")
    handler.create_application = AsyncMock(return_value="new-app-123")
    handler.get_application = AsyncMock(return_value={
        "applicationId": "00f5abc123def456",
        "state": "STARTED",
        "imageConfiguration": {},
    })
    handler.update_application = AsyncMock()
    handler.client.meta.region_name = "us-east-1"
    handler.region = "us-east-1"
    for k, v in overrides.items():
        setattr(handler, k, v)
    return handler


class TestEMRServerlessJobMetadata:
    """Resource metadata dataclass used between create/get/delete calls."""

    def test_metadata_creation(self):
        metadata = EMRServerlessJobMetadata(
            application_id="app-123",
            job_run_id="job-456",
            region="us-east-1",
        )
        assert metadata.application_id == "app-123"
        assert metadata.job_run_id == "job-456"
        assert metadata.region == "us-east-1"
        assert metadata.created_application is False

    def test_metadata_with_created_app(self):
        metadata = EMRServerlessJobMetadata(
            application_id="app-123",
            job_run_id="job-456",
            region="us-west-2",
            created_application=True,
        )
        assert metadata.created_application is True


class TestEMRServerlessConnector:
    """Connector create() lifecycle and application-creation policy."""

    def test_connector_initialization(self):
        connector = EMRServerlessConnector()
        assert connector.name == "EMR Serverless Connector"

    @pytest.mark.asyncio
    async def test_create_with_existing_application(self, sample_config):
        connector = EMRServerlessConnector()
        mock_handler = _make_handler()

        mock_template = MagicMock()
        mock_template.custom = sample_config.to_dict()
        mock_template.id = MagicMock()
        mock_template.id.name = "test-task"

        with patch.object(connector, "_get_handler", return_value=mock_handler):
            with patch.object(connector, "_extract_config", return_value=sample_config):
                metadata = await connector.create(mock_template)

        assert isinstance(metadata, EMRServerlessJobMetadata)
        assert metadata.application_id == sample_config.application_id
        assert metadata.job_run_id == "job-123"
        mock_handler.ensure_application_started.assert_called_once()
        mock_handler.start_job_run.assert_called_once()

    @pytest.mark.asyncio
    async def test_create_with_new_application(self, sample_spark_job_driver):
        """application_name + env var true => create succeeds."""
        from flytekitplugins.awsemrserverless import EMRServerless

        config = EMRServerless(
            execution_role_arn="arn:aws:iam::123456789012:role/Role",
            spark_job_driver=sample_spark_job_driver,
            application_name="test-app",
            region="us-east-1",
        )

        connector = EMRServerlessConnector()
        mock_handler = _make_handler()

        mock_template = MagicMock()
        mock_template.custom = config.to_dict()
        mock_template.id = MagicMock()
        mock_template.id.name = "test-task"

        with patch.dict(os.environ, {_ENV_ALLOW_CREATE_APPLICATION: "true"}):
            with patch.object(connector, "_get_handler", return_value=mock_handler):
                with patch.object(connector, "_extract_config", return_value=config):
                    metadata = await connector.create(mock_template)

        assert metadata.application_id == "new-app-123"
        assert metadata.created_application is True
        mock_handler.create_application.assert_called_once()
        assert mock_handler.create_application.call_args.kwargs["name"] == "test-app"

    @pytest.mark.asyncio
    async def test_create_fails_no_app_id_no_app_name(self, sample_spark_job_driver):
        """Without application_id or application_name, create must fail."""
        from flytekitplugins.awsemrserverless import EMRServerless

        config = EMRServerless(
            execution_role_arn="arn:aws:iam::123456789012:role/Role",
            spark_job_driver=sample_spark_job_driver,
        )

        connector = EMRServerlessConnector()
        mock_handler = _make_handler()

        mock_template = MagicMock()
        mock_template.custom = config.to_dict()
        mock_template.id = MagicMock()
        mock_template.id.name = "test-task"

        with patch.dict(os.environ, {_ENV_ALLOW_CREATE_APPLICATION: "true"}):
            with patch.object(connector, "_get_handler", return_value=mock_handler):
                with patch.object(connector, "_extract_config", return_value=config):
                    with pytest.raises(ValueError, match="No application_id provided and no application_name"):
                        await connector.create(mock_template)

    @pytest.mark.asyncio
    async def test_create_blocked_by_connector_env_var(self, sample_spark_job_driver):
        """When connector env var is false (default), creation is blocked even with application_name."""
        from flytekitplugins.awsemrserverless import EMRServerless

        config = EMRServerless(
            execution_role_arn="arn:aws:iam::123456789012:role/Role",
            spark_job_driver=sample_spark_job_driver,
            application_name="my-app",
        )

        connector = EMRServerlessConnector()
        mock_handler = _make_handler()

        mock_template = MagicMock()
        mock_template.custom = config.to_dict()
        mock_template.id = MagicMock()
        mock_template.id.name = "test-task"

        with patch.dict(os.environ, {_ENV_ALLOW_CREATE_APPLICATION: "false"}):
            with patch.object(connector, "_get_handler", return_value=mock_handler):
                with patch.object(connector, "_extract_config", return_value=config):
                    with pytest.raises(ValueError, match="disabled by the connector"):
                        await connector.create(mock_template)

    @pytest.mark.asyncio
    async def test_create_blocked_when_env_var_unset(self, sample_spark_job_driver):
        """When the env var is not set, default is false => creation blocked."""
        from flytekitplugins.awsemrserverless import EMRServerless

        config = EMRServerless(
            execution_role_arn="arn:aws:iam::123456789012:role/Role",
            spark_job_driver=sample_spark_job_driver,
            application_name="my-app",
        )

        connector = EMRServerlessConnector()
        mock_handler = _make_handler()

        mock_template = MagicMock()
        mock_template.custom = config.to_dict()
        mock_template.id = MagicMock()
        mock_template.id.name = "test-task"

        # Rebuild the environment without the flag to exercise the default path.
        env = os.environ.copy()
        env.pop(_ENV_ALLOW_CREATE_APPLICATION, None)
        with patch.dict(os.environ, env, clear=True):
            with patch.object(connector, "_get_handler", return_value=mock_handler):
                with patch.object(connector, "_extract_config", return_value=config):
                    with pytest.raises(ValueError, match="disabled by the connector"):
                        await connector.create(mock_template)

    @pytest.mark.asyncio
    async def test_create_applies_application_name_prefix(self, sample_spark_job_driver):
        """FLYTE_EMR_APPLICATION_NAME_PREFIX should be prepended to the app name."""
        from flytekitplugins.awsemrserverless import EMRServerless

        config = EMRServerless(
            execution_role_arn="arn:aws:iam::123456789012:role/Role",
            spark_job_driver=sample_spark_job_driver,
            application_name="my-etl-app",
            region="us-east-1",
        )

        connector = EMRServerlessConnector()
        mock_handler = _make_handler()

        mock_template = MagicMock()
        mock_template.custom = config.to_dict()
        mock_template.id = MagicMock()
        mock_template.id.name = "test-task"

        with patch.dict(os.environ, {
            _ENV_ALLOW_CREATE_APPLICATION: "true",
            _ENV_APPLICATION_NAME_PREFIX: "flyte-prod-",
        }):
            with patch.object(connector, "_get_handler", return_value=mock_handler):
                with patch.object(connector, "_extract_config", return_value=config):
                    await connector.create(mock_template)

        call_kwargs = mock_handler.create_application.call_args.kwargs
        assert call_kwargs["name"] == "flyte-prod-my-etl-app"

    @pytest.mark.asyncio
    async def test_create_no_double_prefix(self,
sample_spark_job_driver): + """If application_name already starts with the prefix, don't double-prefix.""" + from flytekitplugins.awsemrserverless import EMRServerless + + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + spark_job_driver=sample_spark_job_driver, + application_name="flyte-prod-my-app", + region="us-east-1", + ) + + connector = EMRServerlessConnector() + mock_handler = _make_handler() + + mock_template = MagicMock() + mock_template.custom = config.to_dict() + mock_template.id = MagicMock() + mock_template.id.name = "test-task" + + with patch.dict(os.environ, { + _ENV_ALLOW_CREATE_APPLICATION: "true", + _ENV_APPLICATION_NAME_PREFIX: "flyte-prod-", + }): + with patch.object(connector, "_get_handler", return_value=mock_handler): + with patch.object(connector, "_extract_config", return_value=config): + await connector.create(mock_template) + + call_kwargs = mock_handler.create_application.call_args.kwargs + assert call_kwargs["name"] == "flyte-prod-my-app" + + @pytest.mark.asyncio + async def test_create_no_prefix_when_unset(self, sample_spark_job_driver): + """Without the prefix env var, the application name is used as-is.""" + from flytekitplugins.awsemrserverless import EMRServerless + + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + spark_job_driver=sample_spark_job_driver, + application_name="my-raw-name", + region="us-east-1", + ) + + connector = EMRServerlessConnector() + mock_handler = _make_handler() + + mock_template = MagicMock() + mock_template.custom = config.to_dict() + mock_template.id = MagicMock() + mock_template.id.name = "test-task" + + env = os.environ.copy() + env.pop(_ENV_APPLICATION_NAME_PREFIX, None) + env[_ENV_ALLOW_CREATE_APPLICATION] = "true" + with patch.dict(os.environ, env, clear=True): + with patch.object(connector, "_get_handler", return_value=mock_handler): + with patch.object(connector, "_extract_config", return_value=config): + await 
connector.create(mock_template) + + call_kwargs = mock_handler.create_application.call_args.kwargs + assert call_kwargs["name"] == "my-raw-name" + + @pytest.mark.asyncio + async def test_create_passes_architecture_and_scheduler(self, sample_spark_job_driver): + """Verify new app-creation params are forwarded to the handler.""" + from flytekitplugins.awsemrserverless import EMRServerless + + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + spark_job_driver=sample_spark_job_driver, + application_name="test-app", + architecture="ARM64", + scheduler_configuration={"maxConcurrentRuns": 5}, + auto_stop_config={"enabled": True, "idleTimeoutMinutes": 30}, + runtime_configuration=[{"classification": "spark-defaults", "properties": {"spark.executor.cores": "8"}}], + region="us-east-1", + ) + + connector = EMRServerlessConnector() + mock_handler = _make_handler() + + mock_template = MagicMock() + mock_template.custom = config.to_dict() + mock_template.id = MagicMock() + mock_template.id.name = "test-task" + + with patch.dict(os.environ, {_ENV_ALLOW_CREATE_APPLICATION: "true"}): + with patch.object(connector, "_get_handler", return_value=mock_handler): + with patch.object(connector, "_extract_config", return_value=config): + await connector.create(mock_template) + + call_kwargs = mock_handler.create_application.call_args.kwargs + assert call_kwargs["architecture"] == "ARM64" + assert call_kwargs["scheduler_configuration"] == {"maxConcurrentRuns": 5} + assert call_kwargs["auto_stop_config"] == {"enabled": True, "idleTimeoutMinutes": 30} + assert call_kwargs["runtime_configuration"][0]["classification"] == "spark-defaults" + + @pytest.mark.asyncio + async def test_create_with_retry_policy(self, sample_config): + """Verify retry_policy is forwarded to start_job_run.""" + from dataclasses import replace + + config = replace(sample_config, retry_policy={"maxAttempts": 3}) + + connector = EMRServerlessConnector() + mock_handler = _make_handler() 
+ + mock_template = MagicMock() + mock_template.custom = config.to_dict() + mock_template.id = MagicMock() + mock_template.id.name = "test-task" + + with patch.object(connector, "_get_handler", return_value=mock_handler): + with patch.object(connector, "_extract_config", return_value=config): + await connector.create(mock_template) + + call_kwargs = mock_handler.start_job_run.call_args.kwargs + assert call_kwargs["retry_policy"] == {"maxAttempts": 3} + + @pytest.mark.asyncio + async def test_create_syncs_image_from_container_image(self): + """When sync_image=True and container_image differs from app, update_application is called. + + Script mode skips sync unless image_configuration is explicitly set, + so this test uses a script-mode config with image_configuration to + exercise the sync path. + """ + from flytekitplugins.awsemrserverless import ( + EMRServerless, + EMRServerlessSparkJobDriver, + ) + + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + application_id="00abc123", + spark_job_driver=EMRServerlessSparkJobDriver(entry_point="s3://bucket/main.py"), + image_configuration={"imageUri": "123456.dkr.ecr.us-east-1.amazonaws.com/spark:v2"}, + sync_image=True, + region="us-east-1", + ) + + connector = EMRServerlessConnector() + mock_handler = _make_handler() + mock_handler.get_application = AsyncMock(return_value={ + "applicationId": "00abc123", + "state": "STARTED", + "imageConfiguration": {"imageUri": "123456.dkr.ecr.us-east-1.amazonaws.com/spark:v1"}, + }) + + mock_template = MagicMock() + mock_template.custom = config.to_dict() + mock_template.id = MagicMock() + mock_template.id.name = "test-task" + mock_template.container = MagicMock() + mock_template.container.image = "123456.dkr.ecr.us-east-1.amazonaws.com/spark:v2" + + with patch.object(connector, "_get_handler", return_value=mock_handler): + with patch.object(connector, "_extract_config", return_value=config): + await connector.create(mock_template) + + 
mock_handler.update_application.assert_called_once_with( + application_id="00abc123", + image_configuration={"imageUri": "123456.dkr.ecr.us-east-1.amazonaws.com/spark:v2"}, + ) + + @pytest.mark.asyncio + async def test_create_syncs_image_falls_back_to_image_configuration(self): + """When container_image is absent, sync_image falls back to image_configuration.""" + from flytekitplugins.awsemrserverless import ( + EMRServerless, + EMRServerlessSparkJobDriver, + ) + + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + application_id="00abc123", + spark_job_driver=EMRServerlessSparkJobDriver(entry_point="s3://bucket/main.py"), + image_configuration={"imageUri": "new-image:v2"}, + sync_image=True, + region="us-east-1", + ) + + connector = EMRServerlessConnector() + mock_handler = _make_handler() + mock_handler.get_application = AsyncMock(return_value={ + "applicationId": "00abc123", + "state": "STARTED", + "imageConfiguration": {"imageUri": "old-image:v1"}, + }) + + mock_template = MagicMock() + mock_template.custom = config.to_dict() + mock_template.id = MagicMock() + mock_template.id.name = "test-task" + mock_template.container = None + + with patch.object(connector, "_get_handler", return_value=mock_handler): + with patch.object(connector, "_extract_config", return_value=config): + await connector.create(mock_template) + + mock_handler.update_application.assert_called_once_with( + application_id="00abc123", + image_configuration={"imageUri": "new-image:v2"}, + ) + + @pytest.mark.asyncio + async def test_create_skips_image_sync_when_same(self): + """When container_image matches the application, update is skipped.""" + from flytekitplugins.awsemrserverless import ( + EMRServerless, + EMRServerlessSparkJobDriver, + ) + + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + application_id="00abc123", + spark_job_driver=EMRServerlessSparkJobDriver(entry_point="s3://bucket/main.py"), + sync_image=True, + 
region="us-east-1", + ) + + connector = EMRServerlessConnector() + mock_handler = _make_handler() + mock_handler.get_application = AsyncMock(return_value={ + "applicationId": "00abc123", + "state": "STARTED", + "imageConfiguration": {"imageUri": "same-image:v1"}, + }) + + mock_template = MagicMock() + mock_template.custom = config.to_dict() + mock_template.id = MagicMock() + mock_template.id.name = "test-task" + mock_template.container = MagicMock() + mock_template.container.image = "same-image:v1" + + with patch.object(connector, "_get_handler", return_value=mock_handler): + with patch.object(connector, "_extract_config", return_value=config): + await connector.create(mock_template) + + mock_handler.update_application.assert_not_called() + + @pytest.mark.asyncio + async def test_create_skips_image_sync_when_disabled(self): + """When sync_image=False, update_application should not be called.""" + from flytekitplugins.awsemrserverless import ( + EMRServerless, + EMRServerlessSparkJobDriver, + ) + + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + application_id="00abc123", + spark_job_driver=EMRServerlessSparkJobDriver(entry_point="s3://bucket/main.py"), + sync_image=False, + region="us-east-1", + ) + + connector = EMRServerlessConnector() + mock_handler = _make_handler() + + mock_template = MagicMock() + mock_template.custom = config.to_dict() + mock_template.id = MagicMock() + mock_template.id.name = "test-task" + mock_template.container = MagicMock() + mock_template.container.image = "new-image:v2" + + with patch.object(connector, "_get_handler", return_value=mock_handler): + with patch.object(connector, "_extract_config", return_value=config): + await connector.create(mock_template) + + mock_handler.update_application.assert_not_called() + mock_handler.get_application.assert_not_called() + + @pytest.mark.asyncio + async def test_create_image_sync_noop_when_no_image(self): + """When there's no container_image or 
image_configuration, sync is a no-op.""" + from flytekitplugins.awsemrserverless import ( + EMRServerless, + EMRServerlessSparkJobDriver, + ) + + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + application_id="00abc123", + spark_job_driver=EMRServerlessSparkJobDriver(entry_point="s3://bucket/main.py"), + sync_image=True, + region="us-east-1", + ) + + connector = EMRServerlessConnector() + mock_handler = _make_handler() + + mock_template = MagicMock() + mock_template.custom = config.to_dict() + mock_template.id = MagicMock() + mock_template.id.name = "test-task" + mock_template.container = None + + with patch.object(connector, "_get_handler", return_value=mock_handler): + with patch.object(connector, "_extract_config", return_value=config): + await connector.create(mock_template) + + mock_handler.update_application.assert_not_called() + mock_handler.get_application.assert_not_called() + + @pytest.mark.asyncio + async def test_create_merges_spark_submit_params_in_script_mode(self): + """Top-level spark_submit_parameters should be appended to the job driver.""" + from flytekitplugins.awsemrserverless import ( + EMRServerless, + EMRServerlessSparkJobDriver, + ) + + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + application_id="00abc123", + spark_job_driver=EMRServerlessSparkJobDriver( + entry_point="s3://bucket/main.py", + spark_submit_parameters="--conf spark.executor.memory=4g", + ), + spark_submit_parameters="--conf spark.driver.memory=8g", + sync_image=False, + region="us-east-1", + ) + + connector = EMRServerlessConnector() + mock_handler = _make_handler() + + mock_template = MagicMock() + mock_template.custom = config.to_dict() + mock_template.id = MagicMock() + mock_template.id.name = "test-task" + + with patch.object(connector, "_get_handler", return_value=mock_handler): + with patch.object(connector, "_extract_config", return_value=config): + await connector.create(mock_template) + + 
call_kwargs = mock_handler.start_job_run.call_args.kwargs + submit_params = call_kwargs["job_driver"]["sparkSubmit"]["sparkSubmitParameters"] + assert "spark.executor.memory=4g" in submit_params + assert "spark.driver.memory=8g" in submit_params + + @pytest.mark.asyncio + async def test_create_merges_application_configuration(self): + """application_configuration should be merged into configuration_overrides.""" + from flytekitplugins.awsemrserverless import ( + EMRServerless, + EMRServerlessSparkJobDriver, + ) + + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + application_id="00abc123", + spark_job_driver=EMRServerlessSparkJobDriver(entry_point="s3://bucket/main.py"), + configuration_overrides={ + "monitoringConfiguration": {"s3MonitoringConfiguration": {"logUri": "s3://logs/"}} + }, + application_configuration=[ + {"classification": "spark-defaults", "properties": {"spark.executor.cores": "8"}} + ], + sync_image=False, + region="us-east-1", + ) + + connector = EMRServerlessConnector() + mock_handler = _make_handler() + + mock_template = MagicMock() + mock_template.custom = config.to_dict() + mock_template.id = MagicMock() + mock_template.id.name = "test-task" + mock_template.container = None + + with patch.object(connector, "_get_handler", return_value=mock_handler): + with patch.object(connector, "_extract_config", return_value=config): + await connector.create(mock_template) + + call_kwargs = mock_handler.start_job_run.call_args.kwargs + config_overrides = call_kwargs["configuration_overrides"] + assert "monitoringConfiguration" in config_overrides + assert len(config_overrides["applicationConfiguration"]) == 1 + assert config_overrides["applicationConfiguration"][0]["classification"] == "spark-defaults" + + @pytest.mark.asyncio + async def test_create_pythonic_mode(self): + """Test create with Pythonic mode (no explicit job driver).""" + from flytekitplugins.awsemrserverless import EMRServerless + + config = 
EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + application_id="00abc123", + sync_image=False, + region="us-east-1", + ) + + connector = EMRServerlessConnector() + mock_handler = _make_handler() + + mock_container = MagicMock() + mock_container.image = "my-image:latest" + mock_container.args = [ + "pyflyte-execute", + "--task-module", + "my.module", + "--task-name", + "my_task", + ] + + mock_template = MagicMock() + mock_template.custom = config.to_dict() + mock_template.id = MagicMock() + mock_template.id.name = "test-task" + mock_template.container = mock_container + + with patch.object(connector, "_get_handler", return_value=mock_handler): + with patch.object(connector, "_extract_config", return_value=config): + with patch.object( + connector, "_ensure_entrypoint_on_s3", return_value="s3://bucket/flyte/entrypoint.py" + ): + metadata = await connector.create(mock_template) + + assert metadata.job_run_id == "job-123" + call_kwargs = mock_handler.start_job_run.call_args.kwargs + assert "sparkSubmit" in call_kwargs["job_driver"] + assert call_kwargs["job_driver"]["sparkSubmit"]["entryPoint"] == "s3://bucket/flyte/entrypoint.py" + assert "pyflyte-execute" in call_kwargs["job_driver"]["sparkSubmit"]["entryPointArguments"] + + @pytest.mark.asyncio + async def test_create_pythonic_mode_with_spark_submit_params(self): + """spark_submit_parameters should be merged in Pythonic mode.""" + from flytekitplugins.awsemrserverless import EMRServerless + + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + application_id="00abc123", + spark_submit_parameters="--conf spark.executor.memory=16g --conf spark.executor.cores=8", + sync_image=False, + region="us-east-1", + ) + + connector = EMRServerlessConnector() + mock_handler = _make_handler() + + mock_container = MagicMock() + mock_container.image = "my-image:latest" + mock_container.args = ["pyflyte-execute", "--task-module", "my.module"] + + mock_template = 
MagicMock() + mock_template.custom = config.to_dict() + mock_template.id = MagicMock() + mock_template.id.name = "test-task" + mock_template.container = mock_container + + with patch.object(connector, "_get_handler", return_value=mock_handler): + with patch.object(connector, "_extract_config", return_value=config): + with patch.object( + connector, "_ensure_entrypoint_on_s3", return_value="s3://bucket/flyte/entrypoint.py" + ): + await connector.create(mock_template) + + call_kwargs = mock_handler.start_job_run.call_args.kwargs + submit_params = call_kwargs["job_driver"]["sparkSubmit"]["sparkSubmitParameters"] + assert "spark.executor.memory=16g" in submit_params + assert "spark.executor.cores=8" in submit_params + assert "PYSPARK_PYTHON" in submit_params + + @pytest.mark.asyncio + async def test_create_new_app_gets_default_image_in_pythonic_mode(self): + """When creating an app in Pythonic mode with no image, a default is used.""" + from flytekitplugins.awsemrserverless import EMRServerless + from flytekitplugins.awsemrserverless.task import EMR_SERVERLESS_BASE_IMAGE + + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + application_name="pythonic-app", + region="us-east-1", + ) + + connector = EMRServerlessConnector() + mock_handler = _make_handler() + + mock_template = MagicMock() + mock_template.custom = config.to_dict() + mock_template.id = MagicMock() + mock_template.id.name = "test-task" + mock_template.container = MagicMock() + mock_template.container.image = None + mock_template.container.args = ["pyflyte-execute"] + + with patch.dict(os.environ, {_ENV_ALLOW_CREATE_APPLICATION: "true"}): + with patch.object(connector, "_get_handler", return_value=mock_handler): + with patch.object(connector, "_extract_config", return_value=config): + with patch.object( + connector, "_ensure_entrypoint_on_s3", return_value="s3://bucket/entrypoint.py" + ): + await connector.create(mock_template) + + call_kwargs = 
mock_handler.create_application.call_args.kwargs + assert call_kwargs["image_configuration"]["imageUri"] == EMR_SERVERLESS_BASE_IMAGE + + @pytest.mark.asyncio + async def test_get_running_job(self, sample_job_metadata): + connector = EMRServerlessConnector() + + mock_handler = AsyncMock() + mock_handler.get_job_run = AsyncMock( + return_value={ + "state": "RUNNING", + "stateDetails": "Job is executing", + } + ) + + with patch.object(connector, "_get_handler", return_value=mock_handler): + resource = await connector.get(sample_job_metadata) + + assert resource.phase == TaskExecution.RUNNING + assert "RUNNING" in resource.message + + @pytest.mark.asyncio + async def test_get_successful_job(self, sample_job_metadata): + connector = EMRServerlessConnector() + + mock_handler = AsyncMock() + mock_handler.get_job_run = AsyncMock( + return_value={ + "state": "SUCCESS", + "stateDetails": "Job completed successfully", + } + ) + + with patch.object(connector, "_get_handler", return_value=mock_handler): + resource = await connector.get(sample_job_metadata) + + assert resource.phase == TaskExecution.SUCCEEDED + + @pytest.mark.asyncio + async def test_get_failed_job(self, sample_job_metadata): + connector = EMRServerlessConnector() + + mock_handler = AsyncMock() + mock_handler.get_job_run = AsyncMock( + return_value={ + "state": "FAILED", + "stateDetails": "OutOfMemoryError", + } + ) + + with patch.object(connector, "_get_handler", return_value=mock_handler): + resource = await connector.get(sample_job_metadata) + + assert resource.phase == TaskExecution.FAILED + assert "OutOfMemoryError" in resource.message + + @pytest.mark.asyncio + async def test_get_cancelled_job(self, sample_job_metadata): + connector = EMRServerlessConnector() + + mock_handler = AsyncMock() + mock_handler.get_job_run = AsyncMock( + return_value={ + "state": "CANCELLED", + "stateDetails": "User cancelled", + } + ) + + with patch.object(connector, "_get_handler", return_value=mock_handler): + resource = 
await connector.get(sample_job_metadata) + + assert resource.phase == TaskExecution.FAILED + + @pytest.mark.asyncio + async def test_get_job_not_found(self, sample_job_metadata): + connector = EMRServerlessConnector() + + mock_handler = AsyncMock() + mock_handler.get_job_run = AsyncMock(side_effect=RuntimeError("Job not found")) + + with patch.object(connector, "_get_handler", return_value=mock_handler): + resource = await connector.get(sample_job_metadata) + + assert resource.phase == TaskExecution.FAILED + assert "not found" in resource.message.lower() + + @pytest.mark.asyncio + async def test_get_pending_states(self, sample_job_metadata): + connector = EMRServerlessConnector() + + for state in ["PENDING", "SCHEDULED", "SUBMITTED"]: + mock_handler = AsyncMock() + mock_handler.get_job_run = AsyncMock(return_value={"state": state, "stateDetails": ""}) + + with patch.object(connector, "_get_handler", return_value=mock_handler): + resource = await connector.get(sample_job_metadata) + + assert resource.phase == TaskExecution.RUNNING, f"State {state} should map to RUNNING" + + @pytest.mark.asyncio + async def test_delete_running_job(self, sample_job_metadata): + connector = EMRServerlessConnector() + + mock_handler = AsyncMock() + mock_handler.cancel_job_run = AsyncMock() + + with patch.object(connector, "_get_handler", return_value=mock_handler): + await connector.delete(sample_job_metadata) + + mock_handler.cancel_job_run.assert_called_once_with( + application_id=sample_job_metadata.application_id, + job_run_id=sample_job_metadata.job_run_id, + ) + + @pytest.mark.asyncio + async def test_delete_handles_errors_gracefully(self, sample_job_metadata): + connector = EMRServerlessConnector() + + mock_handler = AsyncMock() + mock_handler.cancel_job_run = AsyncMock(side_effect=Exception("Network error")) + + with patch.object(connector, "_get_handler", return_value=mock_handler): + await connector.delete(sample_job_metadata) + + +class TestEMRServerlessStateMapping: + def 
test_all_states_mapped(self): + expected_states = [ + "PENDING", + "SCHEDULED", + "SUBMITTED", + "RUNNING", + "SUCCESS", + "FAILED", + "CANCELLING", + "CANCELLED", + ] + for state in expected_states: + assert state in EMR_SERVERLESS_STATES, f"Missing mapping for {state}" + + def test_running_states_map_to_running(self): + running_states = ["PENDING", "SCHEDULED", "SUBMITTED", "RUNNING", "CANCELLING"] + for state in running_states: + assert EMR_SERVERLESS_STATES[state] == "Running" + assert convert_to_flyte_phase(EMR_SERVERLESS_STATES[state]) == TaskExecution.RUNNING + + def test_success_maps_to_succeeded(self): + assert EMR_SERVERLESS_STATES["SUCCESS"] == "Success" + assert convert_to_flyte_phase(EMR_SERVERLESS_STATES["SUCCESS"]) == TaskExecution.SUCCEEDED + + def test_failure_states_map_to_failed(self): + for state in ["FAILED", "CANCELLED"]: + assert EMR_SERVERLESS_STATES[state] == "Failed" + assert convert_to_flyte_phase(EMR_SERVERLESS_STATES[state]) == TaskExecution.FAILED + + +def _make_task_execution_metadata( + *, + exec_name: str = "abc123def456", + exec_project: str = "flytetest", + exec_domain: str = "development", + task_project: str = "flytetest", + task_domain: str = "development", + task_name: str = "examples.script_spark_job.spark_demo_task", + task_version: str = "v1", + node_id: str = "n0", + retry_attempt: int = 0, + environment_variables: dict = None, +) -> MagicMock: + """Build a mock TaskExecutionMetadata with the nested id graph populated.""" + meta = MagicMock() + meta.task_execution_id.task_id.project = task_project + meta.task_execution_id.task_id.domain = task_domain + meta.task_execution_id.task_id.name = task_name + meta.task_execution_id.task_id.version = task_version + meta.task_execution_id.node_execution_id.node_id = node_id + meta.task_execution_id.node_execution_id.execution_id.name = exec_name + meta.task_execution_id.node_execution_id.execution_id.project = exec_project + 
meta.task_execution_id.node_execution_id.execution_id.domain = exec_domain + meta.task_execution_id.retry_attempt = retry_attempt + meta.environment_variables = environment_variables or {} + return meta + + +class TestFlyteEnvInjection: + """Tests for the Flyte context env var injection feature.""" + + def test_build_flyte_env_vars_none_metadata(self): + assert EMRServerlessConnector._build_flyte_env_vars(None) == {} + + def test_build_flyte_env_vars_full(self): + meta = _make_task_execution_metadata() + env = EMRServerlessConnector._build_flyte_env_vars(meta) + assert env["FLYTE_INTERNAL_EXECUTION_ID"] == "abc123def456" + assert env["FLYTE_INTERNAL_EXECUTION_PROJECT"] == "flytetest" + assert env["FLYTE_INTERNAL_EXECUTION_DOMAIN"] == "development" + assert env["FLYTE_INTERNAL_TASK_PROJECT"] == "flytetest" + assert env["FLYTE_INTERNAL_TASK_DOMAIN"] == "development" + assert env["FLYTE_INTERNAL_TASK_NAME"] == "examples.script_spark_job.spark_demo_task" + assert env["FLYTE_INTERNAL_TASK_VERSION"] == "v1" + assert env["FLYTE_INTERNAL_NODE_ID"] == "n0" + assert env["FLYTE_INTERNAL_TASK_RETRY_ATTEMPT"] == "0" + + def test_build_flyte_env_vars_includes_user_env(self): + meta = _make_task_execution_metadata(environment_variables={"MY_VAR": "hello"}) + env = EMRServerlessConnector._build_flyte_env_vars(meta) + assert env["MY_VAR"] == "hello" + + def test_build_flyte_env_vars_drops_empty_values(self): + meta = _make_task_execution_metadata(task_version="") + env = EMRServerlessConnector._build_flyte_env_vars(meta) + assert "FLYTE_INTERNAL_TASK_VERSION" not in env + assert "FLYTE_INTERNAL_EXECUTION_ID" in env + + def test_format_env_as_spark_conf_empty(self): + assert EMRServerlessConnector._format_env_as_spark_conf({}) == "" + + def test_format_env_as_spark_conf_produces_driver_and_executor(self): + conf = EMRServerlessConnector._format_env_as_spark_conf( + {"FLYTE_INTERNAL_EXECUTION_ID": "abc123"} + ) + assert 
"spark.emr-serverless.driverEnv.FLYTE_INTERNAL_EXECUTION_ID=abc123" in conf + assert "spark.executorEnv.FLYTE_INTERNAL_EXECUTION_ID=abc123" in conf + + def test_format_env_as_spark_conf_skips_unsafe_values(self): + conf = EMRServerlessConnector._format_env_as_spark_conf( + {"SAFE": "ok", "BAD_SPACES": "has spaces", "BAD_QUOTE": 'has"quote'} + ) + assert "SAFE=ok" in conf + assert "BAD_SPACES" not in conf + assert "BAD_QUOTE" not in conf + + def test_append_flyte_env_disabled(self): + meta = _make_task_execution_metadata() + result = EMRServerlessConnector._append_flyte_env_to_spark_params( + "--conf spark.executor.memory=4g", meta, enabled=False, + ) + assert result == "--conf spark.executor.memory=4g" + + def test_append_flyte_env_none_metadata(self): + result = EMRServerlessConnector._append_flyte_env_to_spark_params( + "--conf spark.executor.memory=4g", None, enabled=True, + ) + assert result == "--conf spark.executor.memory=4g" + + def test_append_flyte_env_preserves_existing_params(self): + meta = _make_task_execution_metadata() + result = EMRServerlessConnector._append_flyte_env_to_spark_params( + "--conf spark.executor.memory=4g", meta, enabled=True, + ) + assert result.startswith("--conf spark.executor.memory=4g ") + assert "FLYTE_INTERNAL_EXECUTION_ID=abc123def456" in result + + def test_append_flyte_env_when_existing_is_none(self): + meta = _make_task_execution_metadata() + result = EMRServerlessConnector._append_flyte_env_to_spark_params( + None, meta, enabled=True, + ) + assert "FLYTE_INTERNAL_EXECUTION_ID=abc123def456" in result + assert not result.startswith(" ") + + @pytest.mark.asyncio + async def test_create_injects_flyte_env_in_script_mode(self): + from flytekitplugins.awsemrserverless import ( + EMRServerless, + EMRServerlessSparkJobDriver, + ) + + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + application_id="00abc123", + spark_job_driver=EMRServerlessSparkJobDriver( + entry_point="s3://bucket/main.py", + 
spark_submit_parameters="--conf spark.executor.memory=4g", + ), + sync_image=False, + region="us-east-1", + ) + + connector = EMRServerlessConnector() + mock_handler = _make_handler() + mock_template = MagicMock() + mock_template.custom = config.to_dict() + mock_template.id = MagicMock() + mock_template.id.name = "test-task" + meta = _make_task_execution_metadata(exec_name="exec-xyz-789") + + with patch.object(connector, "_get_handler", return_value=mock_handler): + with patch.object(connector, "_extract_config", return_value=config): + await connector.create(mock_template, task_execution_metadata=meta) + + call_kwargs = mock_handler.start_job_run.call_args.kwargs + submit_params = call_kwargs["job_driver"]["sparkSubmit"]["sparkSubmitParameters"] + assert "spark.executor.memory=4g" in submit_params + assert "spark.emr-serverless.driverEnv.FLYTE_INTERNAL_EXECUTION_ID=exec-xyz-789" in submit_params + assert "spark.executorEnv.FLYTE_INTERNAL_EXECUTION_ID=exec-xyz-789" in submit_params + + @pytest.mark.asyncio + async def test_create_respects_inject_flyte_env_false(self): + from flytekitplugins.awsemrserverless import ( + EMRServerless, + EMRServerlessSparkJobDriver, + ) + + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + application_id="00abc123", + spark_job_driver=EMRServerlessSparkJobDriver( + entry_point="s3://bucket/main.py", + spark_submit_parameters="--conf spark.executor.memory=4g", + ), + sync_image=False, + region="us-east-1", + inject_flyte_env=False, + ) + + connector = EMRServerlessConnector() + mock_handler = _make_handler() + mock_template = MagicMock() + mock_template.custom = config.to_dict() + mock_template.id = MagicMock() + mock_template.id.name = "test-task" + meta = _make_task_execution_metadata() + + with patch.object(connector, "_get_handler", return_value=mock_handler): + with patch.object(connector, "_extract_config", return_value=config): + await connector.create(mock_template, 
task_execution_metadata=meta) + + call_kwargs = mock_handler.start_job_run.call_args.kwargs + submit_params = call_kwargs["job_driver"]["sparkSubmit"]["sparkSubmitParameters"] + assert "FLYTE_INTERNAL_EXECUTION_ID" not in submit_params + + @pytest.mark.asyncio + async def test_create_injects_flyte_env_in_pythonic_mode(self): + from flytekitplugins.awsemrserverless import EMRServerless + + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + application_id="00abc123", + sync_image=False, + region="us-east-1", + ) + + connector = EMRServerlessConnector() + mock_handler = _make_handler() + + mock_container = MagicMock() + mock_container.image = "my-image:latest" + mock_container.args = ["pyflyte-execute", "--task-module", "m"] + + mock_template = MagicMock() + mock_template.custom = config.to_dict() + mock_template.id = MagicMock() + mock_template.id.name = "test-task" + mock_template.container = mock_container + + meta = _make_task_execution_metadata(exec_name="py-exec-001") + + with patch.object(connector, "_get_handler", return_value=mock_handler): + with patch.object(connector, "_extract_config", return_value=config): + with patch.object( + connector, "_ensure_entrypoint_on_s3", + return_value="s3://bucket/entrypoint.py", + ): + await connector.create(mock_template, task_execution_metadata=meta) + + call_kwargs = mock_handler.start_job_run.call_args.kwargs + submit_params = call_kwargs["job_driver"]["sparkSubmit"]["sparkSubmitParameters"] + assert "spark.emr-serverless.driverEnv.FLYTE_INTERNAL_EXECUTION_ID=py-exec-001" in submit_params + assert "spark.executorEnv.FLYTE_INTERNAL_EXECUTION_ID=py-exec-001" in submit_params + + @pytest.mark.asyncio + async def test_create_handles_missing_metadata_gracefully(self): + """No task_execution_metadata (e.g. 
local mode) should not raise.""" + from flytekitplugins.awsemrserverless import ( + EMRServerless, + EMRServerlessSparkJobDriver, + ) + + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + application_id="00abc123", + spark_job_driver=EMRServerlessSparkJobDriver(entry_point="s3://b/m.py"), + sync_image=False, + region="us-east-1", + ) + + connector = EMRServerlessConnector() + mock_handler = _make_handler() + mock_template = MagicMock() + mock_template.custom = config.to_dict() + mock_template.id = MagicMock() + mock_template.id.name = "test-task" + + with patch.object(connector, "_get_handler", return_value=mock_handler): + with patch.object(connector, "_extract_config", return_value=config): + await connector.create(mock_template) + + call_kwargs = mock_handler.start_job_run.call_args.kwargs + submit_params = call_kwargs["job_driver"]["sparkSubmit"].get("sparkSubmitParameters") + if submit_params is not None: + assert "FLYTE_INTERNAL_EXECUTION_ID" not in submit_params + + +class TestEntrypointBucketResolution: + """Resolution order for the Pythonic-mode entrypoint S3 bucket. + + Documented contract (see :py:meth:`EMRServerlessConnector._resolve_entrypoint_bucket`): + + 1. ``FLYTE_EMR_ENTRYPOINT_S3_BUCKET`` env var (preferred). + 2. ``configuration_overrides.monitoringConfiguration.s3MonitoringConfiguration.logUri`` + (reuses the same bucket the user already uses for monitoring logs). + 3. ``FLYTE_AWS_S3_BUCKET`` env var (Flyte's general-purpose bucket). + 4. Otherwise, raise -- Pythonic mode requires a bucket. 
+ """ + + def test_env_var_takes_priority(self, sample_task_config): + with patch.dict( + os.environ, + {"FLYTE_EMR_ENTRYPOINT_S3_BUCKET": "explicit-bucket", "FLYTE_AWS_S3_BUCKET": "fallback"}, + clear=False, + ): + assert ( + EMRServerlessConnector._resolve_entrypoint_bucket(sample_task_config) == "explicit-bucket" + ) + + def test_env_var_strips_s3_prefix(self, sample_task_config): + with patch.dict( + os.environ, {"FLYTE_EMR_ENTRYPOINT_S3_BUCKET": "s3://my-bucket/some/prefix"}, clear=False + ): + assert ( + EMRServerlessConnector._resolve_entrypoint_bucket(sample_task_config) == "my-bucket" + ) + + def test_falls_back_to_monitoring_log_uri(self, sample_task_config): + sample_task_config.configuration_overrides = { + "monitoringConfiguration": { + "s3MonitoringConfiguration": {"logUri": "s3://monitoring-bucket/logs"} + } + } + with patch.dict( + os.environ, {"FLYTE_EMR_ENTRYPOINT_S3_BUCKET": "", "FLYTE_AWS_S3_BUCKET": ""}, clear=False + ): + os.environ.pop("FLYTE_EMR_ENTRYPOINT_S3_BUCKET", None) + os.environ.pop("FLYTE_AWS_S3_BUCKET", None) + assert ( + EMRServerlessConnector._resolve_entrypoint_bucket(sample_task_config) + == "monitoring-bucket" + ) + + def test_falls_back_to_flyte_aws_s3_bucket(self, sample_task_config): + with patch.dict( + os.environ, {"FLYTE_AWS_S3_BUCKET": "flyte-bucket"}, clear=False + ): + os.environ.pop("FLYTE_EMR_ENTRYPOINT_S3_BUCKET", None) + assert ( + EMRServerlessConnector._resolve_entrypoint_bucket(sample_task_config) == "flyte-bucket" + ) + + def test_raises_when_no_source_configured(self, sample_task_config): + with patch.dict(os.environ, {}, clear=False): + os.environ.pop("FLYTE_EMR_ENTRYPOINT_S3_BUCKET", None) + os.environ.pop("FLYTE_AWS_S3_BUCKET", None) + with pytest.raises(ValueError, match="entrypoint"): + EMRServerlessConnector._resolve_entrypoint_bucket(sample_task_config) + + +class TestEnsureEntrypointOnS3: + """The connector uploads the entrypoint exactly once per content-hash, + keyed under 
``flyte/emr-serverless/entrypoint-.py``.""" + + def _make_handler_with_region(self): + h = MagicMock() + h.region = "us-east-1" + return h + + def test_skips_upload_when_already_present(self, sample_task_config): + from flytekitplugins.awsemrserverless.connector import EMRServerlessConnector + + connector = EMRServerlessConnector() + handler = self._make_handler_with_region() + + s3_client = MagicMock() + s3_client.head_object.return_value = {"ContentLength": 1234} + + with ( + patch.dict(os.environ, {"FLYTE_EMR_ENTRYPOINT_S3_BUCKET": "test-bucket"}, clear=False), + patch("boto3.client", return_value=s3_client), + ): + uri = connector._ensure_entrypoint_on_s3(handler, sample_task_config) + + assert uri.startswith("s3://test-bucket/flyte/emr-serverless/entrypoint-") + assert uri.endswith(".py") + s3_client.head_object.assert_called_once() + s3_client.put_object.assert_not_called() + + def test_uploads_when_missing(self, sample_task_config): + from flytekitplugins.awsemrserverless.connector import EMRServerlessConnector + + connector = EMRServerlessConnector() + handler = self._make_handler_with_region() + + s3_client = MagicMock() + s3_client.head_object.side_effect = Exception("NotFound") + + with ( + patch.dict(os.environ, {"FLYTE_EMR_ENTRYPOINT_S3_BUCKET": "test-bucket"}, clear=False), + patch("boto3.client", return_value=s3_client), + ): + uri = connector._ensure_entrypoint_on_s3(handler, sample_task_config) + + s3_client.put_object.assert_called_once() + put_kwargs = s3_client.put_object.call_args.kwargs + assert put_kwargs["Bucket"] == "test-bucket" + assert put_kwargs["Key"].startswith("flyte/emr-serverless/entrypoint-") + assert put_kwargs["ContentType"] == "text/x-python" + # The body is the byte-identical content of _entrypoint.py + from flytekitplugins.awsemrserverless import _entrypoint + from pathlib import Path + + assert put_kwargs["Body"] == Path(_entrypoint.__file__).read_bytes() + assert uri == f"s3://test-bucket/{put_kwargs['Key']}" + + def 
test_uri_is_deterministic_across_calls(self, sample_task_config): + """Two consecutive calls produce the same URI -- this is what + makes EMR's job specs stable across connector restarts.""" + from flytekitplugins.awsemrserverless.connector import EMRServerlessConnector + + connector = EMRServerlessConnector() + handler = self._make_handler_with_region() + s3_client = MagicMock() + s3_client.head_object.return_value = {"ContentLength": 1} + + with ( + patch.dict(os.environ, {"FLYTE_EMR_ENTRYPOINT_S3_BUCKET": "b"}, clear=False), + patch("boto3.client", return_value=s3_client), + ): + uri1 = connector._ensure_entrypoint_on_s3(handler, sample_task_config) + uri2 = connector._ensure_entrypoint_on_s3(handler, sample_task_config) + + assert uri1 == uri2 diff --git a/plugins/flytekit-aws-emr-serverless/tests/test_entrypoint.py b/plugins/flytekit-aws-emr-serverless/tests/test_entrypoint.py new file mode 100644 index 0000000000..0e8b5f7e61 --- /dev/null +++ b/plugins/flytekit-aws-emr-serverless/tests/test_entrypoint.py @@ -0,0 +1,338 @@ +""" +Unit tests for the EMR Serverless Pythonic-mode entrypoint script. + +These tests exercise the entrypoint module on two levels: + +* **Pure unit tests** of the helper functions (``_parse_fast_execute_args``, + ``_build_resolver_command``, ``_exit_with_code``) by importing them + directly. These are fast, deterministic, and require no subprocess. +* **Subprocess tests** that run ``python -m`` on the actual file the + connector uploads to S3, verifying the wire-level CLI behaviour EMR + will see. + +The shape of these tests mirrors how Databricks's ``flytetools`` +maintains test coverage for ``flytekitplugins/databricks/entrypoint.py``. 
+""" + +import hashlib +import subprocess +import sys +from pathlib import Path +from unittest.mock import patch + +import pytest + +from flytekitplugins.awsemrserverless import _entrypoint + + +ENTRYPOINT_PATH: Path = Path(_entrypoint.__file__).resolve() + + +# --------------------------------------------------------------------------- +# Module-level invariants +# --------------------------------------------------------------------------- + + +class TestEntrypointModuleInvariants: + """Sanity checks on the entrypoint module itself.""" + + def test_entrypoint_file_exists(self): + assert ENTRYPOINT_PATH.is_file(), ( + f"_entrypoint.py must exist at {ENTRYPOINT_PATH}; the connector uploads this file to S3 and EMR runs it." + ) + + def test_entrypoint_is_valid_python(self): + source = ENTRYPOINT_PATH.read_text() + compile(source, str(ENTRYPOINT_PATH), "exec") + + def test_entrypoint_has_main(self): + assert callable(_entrypoint.main) + + def test_entrypoint_minimal_runtime_imports(self): + """The entrypoint runs inside the EMR worker; its module-level + imports must be available there. ``flytekit`` is the only + non-stdlib module we should import (specifically + ``download_distribution``).""" + source = ENTRYPOINT_PATH.read_text() + assert "from flytekit.tools.fast_registration import download_distribution" in source + assert "import boto3" not in source, ( + "_entrypoint.py must not import boto3 -- it runs in the EMR " + "worker which may not have boto3 installed at the version " + "the connector uses." 
+ ) + + +# --------------------------------------------------------------------------- +# _parse_fast_execute_args +# --------------------------------------------------------------------------- + + +class TestParseFastExecuteArgs: + """The parser splits a ``pyflyte-fast-execute ...`` argv into its parts.""" + + def test_full_argv(self): + args = [ + "pyflyte-fast-execute", + "--additional-distribution", + "s3://bucket/fast-abc.tar.gz", + "--dest-dir", + "/tmp/work", + "pyflyte-execute", + "--inputs", + "s3://in", + "--output-prefix", + "s3://out", + ] + addl, dest, start = _entrypoint._parse_fast_execute_args(args) + assert addl == "s3://bucket/fast-abc.tar.gz" + assert dest == "/tmp/work" + assert args[start] == "pyflyte-execute" + + def test_only_additional_distribution(self): + args = [ + "pyflyte-fast-execute", + "--additional-distribution", + "s3://bucket/fast.tar.gz", + "pyflyte-execute", + "--inputs", + "s3://in", + ] + addl, dest, start = _entrypoint._parse_fast_execute_args(args) + assert addl == "s3://bucket/fast.tar.gz" + assert dest is None + assert args[start] == "pyflyte-execute" + + def test_no_distribution_flags(self): + args = ["pyflyte-fast-execute", "pyflyte-execute", "--inputs", "s3://in"] + addl, dest, start = _entrypoint._parse_fast_execute_args(args) + assert addl is None + assert dest is None + assert args[start] == "pyflyte-execute" + + def test_explicit_double_dash_separator(self): + args = [ + "pyflyte-fast-execute", + "--additional-distribution", + "s3://b/a.tar.gz", + "--", + "pyflyte-execute", + "--inputs", + "s3://in", + ] + addl, dest, start = _entrypoint._parse_fast_execute_args(args) + assert addl == "s3://b/a.tar.gz" + assert dest is None + assert args[start] == "pyflyte-execute" + + def test_dangling_flag_does_not_crash(self): + """``--additional-distribution`` with no value following must not crash. 
+ + With no value, the loop falls through to the catch-all ``else`` branch + and treats the flag itself as the start of the task command. This is + recoverable nonsense (the inner subprocess will fail cleanly) -- the + important property is that the parser doesn't index past the end. + """ + args = ["pyflyte-fast-execute", "--additional-distribution"] + addl, dest, start = _entrypoint._parse_fast_execute_args(args) + assert addl is None + assert dest is None + assert 0 <= start < len(args) + + +# --------------------------------------------------------------------------- +# _build_resolver_command +# --------------------------------------------------------------------------- + + +class TestBuildResolverCommand: + """The resolver builder must inject ``--dynamic-addl-distro`` / + ``--dynamic-dest-dir`` immediately before any ``--resolver`` flag.""" + + def test_injects_dynamic_args_before_resolver(self): + task_cmd = [ + "pyflyte-execute", + "--inputs", + "s3://in", + "--resolver", + "flytekit.core.python_auto_container.default_task_resolver", + ] + out = _entrypoint._build_resolver_command(task_cmd, "s3://b/fast.tar.gz", "/tmp/work") + assert "--dynamic-addl-distro" in out + assert "s3://b/fast.tar.gz" in out + assert "--dynamic-dest-dir" in out + assert "/tmp/work" in out + assert out.index("--dynamic-addl-distro") < out.index("--resolver") + assert out.index("--dynamic-dest-dir") < out.index("--resolver") + + def test_no_resolver_flag_means_no_injection(self): + task_cmd = ["pyflyte-execute", "--inputs", "s3://in"] + out = _entrypoint._build_resolver_command(task_cmd, "s3://x", "/tmp") + assert out == task_cmd + + def test_handles_none_distribution_args(self): + task_cmd = ["pyflyte-execute", "--resolver", "x"] + out = _entrypoint._build_resolver_command(task_cmd, None, None) + assert "--dynamic-addl-distro" in out + assert "--dynamic-dest-dir" in out + addl_idx = out.index("--dynamic-addl-distro") + dest_idx = out.index("--dynamic-dest-dir") + assert out[addl_idx 
+ 1] == "" + assert out[dest_idx + 1] == "" + + +# --------------------------------------------------------------------------- +# _exit_with_code +# --------------------------------------------------------------------------- + + +class TestExitWithCode: + """The exit handler bridges Flytekit's exit semantics to EMR's + state-reporting model. + + The contract: + * non-zero rc → exit with rc (EMR sees FAILED) + * rc=0 with user-error banner in stderr → exit 1 (EMR sees FAILED) + * rc=0 without banner → exit 0 (EMR sees SUCCESS) + """ + + def test_nonzero_rc_propagates(self): + with pytest.raises(SystemExit) as exc: + _entrypoint._exit_with_code(42, "") + assert exc.value.code == 42 + + def test_zero_rc_clean_exits_zero(self): + with pytest.raises(SystemExit) as exc: + _entrypoint._exit_with_code(0, "all good") + assert exc.value.code == 0 + + def test_zero_rc_with_user_error_forces_failure(self): + """Flytekit catches user exceptions and exits 0; we must override + that to ensure EMR reports FAILED, not SUCCESS.""" + stderr = "Traceback ...\nUser Error Captured by Flyte: TypeError: x must be int\n" + with pytest.raises(SystemExit) as exc: + _entrypoint._exit_with_code(0, stderr) + assert exc.value.code == 1 + + +# --------------------------------------------------------------------------- +# main() dispatch +# --------------------------------------------------------------------------- + + +class TestMainDispatch: + """``main`` reads ``sys.argv`` and dispatches to the right branch.""" + + def test_no_args_exits_nonzero(self): + with patch.object(sys, "argv", ["entrypoint.py"]): + with pytest.raises(SystemExit) as exc: + _entrypoint.main() + assert exc.value.code == 1 + + def test_unknown_command_exits_nonzero(self): + with patch.object(sys, "argv", ["entrypoint.py", "unknown-cmd"]): + with pytest.raises(SystemExit) as exc: + _entrypoint.main() + assert exc.value.code == 1 + + def test_pyflyte_execute_branch_invokes_subprocess(self): + """When given ``pyflyte-execute 
...``, main runs it as a subprocess + with PYTHONPATH defaulted to cwd and exits with rc=0.""" + argv = ["entrypoint.py", "pyflyte-execute", "--inputs", "s3://x"] + with ( + patch.object(sys, "argv", argv), + patch.object(_entrypoint, "_run_subprocess", return_value=(0, "")) as m_run, + ): + with pytest.raises(SystemExit) as exc: + _entrypoint.main() + assert exc.value.code == 0 + cmd_arg, env_arg = m_run.call_args[0] + assert cmd_arg == ["pyflyte-execute", "--inputs", "s3://x"] + assert "PYTHONPATH" in env_arg + + def test_fast_execute_branch_downloads_and_invokes(self): + """When given ``pyflyte-fast-execute ...``, main downloads the + distribution, builds the resolver-aware command, and exits.""" + argv = [ + "entrypoint.py", + "pyflyte-fast-execute", + "--additional-distribution", + "s3://b/fast.tar.gz", + "--dest-dir", + "/tmp/work", + "pyflyte-execute", + "--resolver", + "flytekit.core.python_auto_container.default_task_resolver", + ] + with ( + patch.object(sys, "argv", argv), + patch.object(_entrypoint, "download_distribution") as m_dl, + patch.object(_entrypoint, "_run_subprocess", return_value=(0, "")) as m_run, + ): + with pytest.raises(SystemExit) as exc: + _entrypoint.main() + assert exc.value.code == 0 + m_dl.assert_called_once_with("s3://b/fast.tar.gz", "/tmp/work") + cmd_arg, env_arg = m_run.call_args[0] + assert "--dynamic-addl-distro" in cmd_arg + assert "s3://b/fast.tar.gz" in cmd_arg + assert "/tmp/work" in env_arg["PYTHONPATH"] + + +# --------------------------------------------------------------------------- +# End-to-end CLI behaviour (run the file as EMR will) +# --------------------------------------------------------------------------- + + +class TestEntrypointAsScript: + """Run the entrypoint file as a subprocess to verify the wire-level + behaviour EMR's spark-submit will observe.""" + + def test_no_args_exits_one_with_usage(self): + proc = subprocess.run( + [sys.executable, str(ENTRYPOINT_PATH)], + capture_output=True, + text=True, + ) 
+ assert proc.returncode == 1 + assert "Usage" in proc.stderr + + def test_unknown_command_exits_one(self): + proc = subprocess.run( + [sys.executable, str(ENTRYPOINT_PATH), "definitely-not-a-flyte-command"], + capture_output=True, + text=True, + ) + assert proc.returncode == 1 + assert "Unrecognized command" in proc.stderr + + +# --------------------------------------------------------------------------- +# Connector ↔ entrypoint coupling +# --------------------------------------------------------------------------- + + +class TestConnectorEntrypointWiring: + """Ensure the connector reads exactly this file and produces a stable, + content-addressed hash.""" + + def test_connector_path_points_at_this_module(self): + from flytekitplugins.awsemrserverless.connector import EMRServerlessConnector + + assert EMRServerlessConnector._ENTRYPOINT_PATH == ENTRYPOINT_PATH + + def test_connector_reads_byte_identical_content(self): + from flytekitplugins.awsemrserverless.connector import EMRServerlessConnector + + EMRServerlessConnector._entrypoint_bytes_cache = None # bust cache + assert EMRServerlessConnector._read_entrypoint_bytes() == ENTRYPOINT_PATH.read_bytes() + + def test_hash_is_stable_for_unchanged_content(self): + """The S3 key contains the first 12 hex chars of sha256(content); + as long as the file content is unchanged, the key is unchanged.""" + from flytekitplugins.awsemrserverless.connector import EMRServerlessConnector + + EMRServerlessConnector._entrypoint_bytes_cache = None + content = EMRServerlessConnector._read_entrypoint_bytes() + h1 = hashlib.sha256(content).hexdigest()[:12] + h2 = hashlib.sha256(ENTRYPOINT_PATH.read_bytes()).hexdigest()[:12] + assert h1 == h2 diff --git a/plugins/flytekit-aws-emr-serverless/tests/test_task.py b/plugins/flytekit-aws-emr-serverless/tests/test_task.py new file mode 100644 index 0000000000..226997c74a --- /dev/null +++ b/plugins/flytekit-aws-emr-serverless/tests/test_task.py @@ -0,0 +1,346 @@ +""" +Unit tests for EMR 
Serverless Task and configuration. +""" + +from unittest.mock import MagicMock + +import pytest +from flytekit import task + +from flytekitplugins.awsemrserverless import ( + EMRServerless, + EMRServerlessHiveJobDriver, + EMRServerlessSparkJobDriver, + EMRServerlessTask, +) + + +class TestEMRServerlessSparkJobDriver: + def test_basic_creation(self): + driver = EMRServerlessSparkJobDriver(entry_point="s3://bucket/main.py") + assert driver.entry_point == "s3://bucket/main.py" + assert driver.entry_point_arguments is None + assert driver.spark_submit_parameters is None + + def test_full_creation(self): + driver = EMRServerlessSparkJobDriver( + entry_point="s3://bucket/main.py", + entry_point_arguments=["--arg1", "value1"], + spark_submit_parameters="--conf spark.executor.memory=4g", + ) + assert driver.entry_point_arguments == ["--arg1", "value1"] + + def test_to_dict_basic(self): + driver = EMRServerlessSparkJobDriver(entry_point="s3://bucket/main.py") + assert driver.to_dict() == {"entryPoint": "s3://bucket/main.py"} + + def test_to_dict_full(self): + driver = EMRServerlessSparkJobDriver( + entry_point="s3://bucket/main.py", + entry_point_arguments=["--arg1", "value1"], + spark_submit_parameters="--conf spark.executor.memory=4g", + ) + result = driver.to_dict() + assert result["entryPoint"] == "s3://bucket/main.py" + assert result["entryPointArguments"] == ["--arg1", "value1"] + assert result["sparkSubmitParameters"] == "--conf spark.executor.memory=4g" + + +class TestEMRServerlessHiveJobDriver: + def test_basic_creation(self): + driver = EMRServerlessHiveJobDriver(query="SELECT * FROM table") + assert driver.query == "SELECT * FROM table" + assert driver.init_query_file is None + + def test_to_dict(self): + driver = EMRServerlessHiveJobDriver( + query="s3://bucket/query.sql", + init_query_file="s3://bucket/init.sql", + ) + result = driver.to_dict() + assert result == { + "query": "s3://bucket/query.sql", + "initQueryFile": "s3://bucket/init.sql", + } + + +class 
TestEMRServerless: + def test_basic_spark_config(self): + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/EMRRole", + spark_job_driver=EMRServerlessSparkJobDriver(entry_point="s3://bucket/main.py"), + ) + assert config.application_type == "SPARK" + assert config.is_script_mode is True + + def test_pythonic_mode_config(self): + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/EMRRole", + ) + assert config.is_script_mode is False + + def test_hive_config(self): + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/EMRRole", + application_type="HIVE", + hive_job_driver=EMRServerlessHiveJobDriver(query="s3://bucket/query.sql"), + ) + assert config.application_type == "HIVE" + assert config.is_script_mode is True + + def test_missing_execution_role(self): + with pytest.raises(ValueError, match="execution_role_arn is required"): + EMRServerless(execution_role_arn="") + + def test_invalid_application_type(self): + with pytest.raises(ValueError, match="application_type must be"): + EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + application_type="INVALID", + ) + + def test_missing_hive_driver(self): + with pytest.raises(ValueError, match="hive_job_driver is required"): + EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + application_type="HIVE", + ) + + def test_both_drivers_specified(self): + with pytest.raises(ValueError, match="Only one of"): + EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + spark_job_driver=EMRServerlessSparkJobDriver(entry_point="s3://bucket/main.py"), + hive_job_driver=EMRServerlessHiveJobDriver(query="SELECT 1"), + ) + + def test_invalid_timeout(self): + with pytest.raises(ValueError, match="execution_timeout_minutes"): + EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + spark_job_driver=EMRServerlessSparkJobDriver(entry_point="s3://bucket/main.py"), + 
execution_timeout_minutes=0, + ) + + def test_get_job_driver_spark(self, sample_spark_job_driver): + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + spark_job_driver=sample_spark_job_driver, + ) + result = config.get_job_driver() + assert "sparkSubmit" in result + + def test_get_job_driver_hive(self, sample_hive_job_driver): + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + application_type="HIVE", + hive_job_driver=sample_hive_job_driver, + ) + result = config.get_job_driver() + assert "hive" in result + + def test_to_dict(self, sample_config): + result = sample_config.to_dict() + assert result["execution_role_arn"] == "arn:aws:iam::123456789012:role/EMRServerlessRole" + assert result["application_id"] == "00f5abc123def456" + assert result["region"] == "us-east-1" + assert "spark_job_driver" in result + + def test_from_dict_roundtrip(self, sample_config): + config_dict = sample_config.to_dict() + restored = EMRServerless.from_dict(config_dict) + assert restored.execution_role_arn == sample_config.execution_role_arn + assert restored.application_id == sample_config.application_id + assert restored.region == sample_config.region + + def test_full_config_with_all_options(self, sample_spark_job_driver): + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/EMRRole", + application_id="00abc123", + release_label="emr-6.15.0", + application_name="my-app", + spark_job_driver=sample_spark_job_driver, + configuration_overrides={ + "monitoringConfiguration": {"s3MonitoringConfiguration": {"logUri": "s3://logs/"}} + }, + tags={"Env": "prod"}, + execution_timeout_minutes=240, + initial_capacity={"DRIVER": {"workerCount": 1}}, + maximum_capacity={"cpu": "100vCPU"}, + network_configuration={"subnetIds": ["subnet-123"]}, + image_configuration={"imageUri": "123456789012.dkr.ecr.us-east-1.amazonaws.com/spark:latest"}, + region="us-west-2", + ) + d = config.to_dict() + assert 
d["release_label"] == "emr-6.15.0" + assert d["application_name"] == "my-app" + assert d["tags"]["Env"] == "prod" + assert d["network_configuration"]["subnetIds"] == ["subnet-123"] + + def test_sync_image_default(self): + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/EMRRole", + application_id="00abc123", + ) + assert config.sync_image is True + + def test_architecture_validation(self): + with pytest.raises(ValueError, match="architecture must be one of"): + EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + architecture="INVALID", + ) + + def test_valid_architectures(self): + for arch in ("X86_64", "ARM64"): + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + architecture=arch, + ) + assert config.architecture == arch + + def test_new_fields_to_dict_roundtrip(self): + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + sync_image=False, + spark_submit_parameters="--conf spark.executor.memory=16g", + application_configuration=[ + {"classification": "spark-defaults", "properties": {"spark.executor.cores": "8"}} + ], + runtime_configuration=[ + {"classification": "spark-defaults", "properties": {"spark.dynamicAllocation.enabled": "true"}} + ], + scheduler_configuration={"maxConcurrentRuns": 5, "queueTimeoutMinutes": 30}, + auto_stop_config={"enabled": True, "idleTimeoutMinutes": 30}, + architecture="ARM64", + retry_policy={"maxAttempts": 3}, + ) + d = config.to_dict() + restored = EMRServerless.from_dict(d) + + assert restored.sync_image is False + assert restored.spark_submit_parameters == "--conf spark.executor.memory=16g" + assert len(restored.application_configuration) == 1 + assert restored.application_configuration[0]["classification"] == "spark-defaults" + assert len(restored.runtime_configuration) == 1 + assert restored.scheduler_configuration == {"maxConcurrentRuns": 5, "queueTimeoutMinutes": 30} + assert restored.auto_stop_config == 
{"enabled": True, "idleTimeoutMinutes": 30} + assert restored.architecture == "ARM64" + assert restored.retry_policy == {"maxAttempts": 3} + + def test_effective_configuration_overrides_merge(self): + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + configuration_overrides={ + "monitoringConfiguration": {"s3MonitoringConfiguration": {"logUri": "s3://logs/"}} + }, + application_configuration=[ + {"classification": "spark-defaults", "properties": {"spark.executor.cores": "8"}} + ], + ) + result = config.get_effective_configuration_overrides() + assert "monitoringConfiguration" in result + assert len(result["applicationConfiguration"]) == 1 + + def test_effective_configuration_overrides_none(self): + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + ) + assert config.get_effective_configuration_overrides() is None + + def test_effective_configuration_overrides_only_app_config(self): + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + application_configuration=[ + {"classification": "spark", "properties": {"dynamicAllocationOptimization": "true"}} + ], + ) + result = config.get_effective_configuration_overrides() + assert len(result["applicationConfiguration"]) == 1 + + +class TestEMRServerlessTask: + def test_task_type(self): + assert EMRServerlessTask._TASK_TYPE == "emr_serverless" + + def test_task_creation_with_decorator(self, sample_task_config): + @task(task_config=sample_task_config) + def my_spark_job() -> str: + return "done" + + assert isinstance(my_spark_job, EMRServerlessTask) + assert my_spark_job.task_config == sample_task_config + + def test_get_custom(self, sample_task_config): + @task(task_config=sample_task_config) + def my_spark_job() -> str: + return "done" + + mock_settings = MagicMock() + custom = my_spark_job.get_custom(mock_settings) + + assert isinstance(custom, dict) + assert custom["execution_role_arn"] == 
sample_task_config.execution_role_arn + assert custom["application_id"] == sample_task_config.application_id + assert "spark_job_driver" in custom + + def test_get_custom_with_hive(self): + task_config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/EMRRole", + application_type="HIVE", + hive_job_driver=EMRServerlessHiveJobDriver( + query="s3://bucket/query.sql", + init_query_file="s3://bucket/init.sql", + ), + ) + + @task(task_config=task_config) + def hive_task() -> str: + return "done" + + custom = hive_task.get_custom(MagicMock()) + assert custom["application_type"] == "HIVE" + assert custom["hive_job_driver"]["query"] == "s3://bucket/query.sql" + + def test_task_inherits_from_correct_classes(self): + from flytekit.core.python_function_task import PythonFunctionTask + + try: + from flytekit.extend.backend.base_connector import AsyncConnectorExecutorMixin + except ModuleNotFoundError: + from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin as AsyncConnectorExecutorMixin + + assert issubclass(EMRServerlessTask, PythonFunctionTask) + assert issubclass(EMRServerlessTask, AsyncConnectorExecutorMixin) + + def test_task_decorator_produces_emr_serverless_task(self): + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/EMRRole", + application_id="app-123", + spark_job_driver=EMRServerlessSparkJobDriver(entry_point="s3://bucket/main.py"), + ) + + @task(task_config=config) + def my_emr_task() -> str: + return "done" + + assert isinstance(my_emr_task, EMRServerlessTask) + assert my_emr_task._task_type == "emr_serverless" + + def test_pythonic_mode_task(self): + """Test that a task with no job driver creates successfully (Pythonic mode).""" + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/EMRRole", + application_id="app-123", + region="us-east-1", + ) + + @task(task_config=config) + def my_pythonic_spark_job() -> str: + return "done" + + assert isinstance(my_pythonic_spark_job, 
EMRServerlessTask) + + custom = my_pythonic_spark_job.get_custom(MagicMock()) + assert "spark_job_driver" not in custom + assert custom["application_type"] == "SPARK" From d808707811ab36764db661c79e7e12dc2ead1dad Mon Sep 17 00:00:00 2001 From: Rohit Sharma Date: Sun, 26 Apr 2026 21:38:43 +0100 Subject: [PATCH 2/3] Fix lint: simplify pydoclint-validated docstring sections The initial commit failed CI lint because pydoclint (configured with --style=google and --arg-type-hints-in-docstring=True via the upstream pre-commit hook) requires Args / Returns / Attributes sections to carry type hints that match the signatures and dataclass fields exactly. Convert the structured Args/Returns sections in EMRServerlessHandler helpers and the Attributes lists on the EMRServerless* dataclasses into prose so pydoclint has nothing structured to validate against. The type information remains available on the function signatures and dataclass fields. Also re-order an import block flagged by ruff I001. 138 tests still pass. Made-with: Cursor Signed-off-by: Rohit Sharma --- .../awsemrserverless/boto_handler.py | 31 ++--- .../awsemrserverless/connector.py | 6 +- .../flytekitplugins/awsemrserverless/task.py | 125 ++++++++---------- 3 files changed, 67 insertions(+), 95 deletions(-) diff --git a/plugins/flytekit-aws-emr-serverless/flytekitplugins/awsemrserverless/boto_handler.py b/plugins/flytekit-aws-emr-serverless/flytekitplugins/awsemrserverless/boto_handler.py index c686667ce1..0cfb82e2e0 100644 --- a/plugins/flytekit-aws-emr-serverless/flytekitplugins/awsemrserverless/boto_handler.py +++ b/plugins/flytekit-aws-emr-serverless/flytekitplugins/awsemrserverless/boto_handler.py @@ -66,16 +66,11 @@ def client(self) -> Any: async def _call(self, method: str, **params: Any) -> Dict[str, Any]: """Invoke a boto3 EMR Serverless client method in the thread-pool executor. - This is the single chokepoint for all non-paginated boto3 API calls. 
- Tests can mock the entire handler surface by patching just this method. - - Args: - method: The boto3 client method name (e.g. ``"get_application"``). - **params: Keyword arguments forwarded to the boto3 call in the - camelCase form the AWS SDK expects. - - Returns: - The raw response dict from the boto3 call. + This is the single chokepoint for all non-paginated boto3 API calls + (``method`` is the boto3 client method name and ``params`` are keyword + arguments forwarded to the call in the camelCase form the AWS SDK + expects). Tests can mock the entire handler surface by patching just + this method. Returns the raw response dict from the boto3 call. """ func = getattr(self.client, method) loop = asyncio.get_event_loop() @@ -85,16 +80,12 @@ async def _paginate(self, method: str, result_key: str, **params: Any) -> List[D """Exhaust a boto3 paginator and return the concatenated items. Used for list operations (e.g. ``list_applications``) that may span - multiple pages. Tests can mock this independently from ``_call``. - - Args: - method: The boto3 client method to paginate (e.g. ``"list_applications"``). - result_key: The key under which each page stores items - (e.g. ``"applications"`` for ``list_applications``). - **params: Keyword arguments forwarded to ``paginator.paginate()``. - - Returns: - A flat list of items across all pages. + multiple pages. ``method`` is the boto3 client method to paginate, + ``result_key`` is the key under which each page stores items + (e.g. ``"applications"`` for ``list_applications``), and ``params`` + are keyword arguments forwarded to ``paginator.paginate()``. Tests + can mock this independently from ``_call``. Returns a flat list of + items across all pages. 
""" def _collect() -> List[Dict[str, Any]]: diff --git a/plugins/flytekit-aws-emr-serverless/flytekitplugins/awsemrserverless/connector.py b/plugins/flytekit-aws-emr-serverless/flytekitplugins/awsemrserverless/connector.py index 11cefbb026..52c41a9f8b 100644 --- a/plugins/flytekit-aws-emr-serverless/flytekitplugins/awsemrserverless/connector.py +++ b/plugins/flytekit-aws-emr-serverless/flytekitplugins/awsemrserverless/connector.py @@ -34,14 +34,14 @@ ResourceMeta, ) +from flytekitplugins.awsemrserverless.boto_handler import EMRServerlessHandler +from flytekitplugins.awsemrserverless.task import EMR_SERVERLESS_BASE_IMAGE, EMRServerless + from flytekit.extend.backend.utils import convert_to_flyte_phase from flytekit.models.core.execution import TaskLog from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskTemplate -from flytekitplugins.awsemrserverless.boto_handler import EMRServerlessHandler -from flytekitplugins.awsemrserverless.task import EMR_SERVERLESS_BASE_IMAGE, EMRServerless - logger = logging.getLogger(__name__) # Maps EMR Serverless job run states to the canonical flytekit phase strings diff --git a/plugins/flytekit-aws-emr-serverless/flytekitplugins/awsemrserverless/task.py b/plugins/flytekit-aws-emr-serverless/flytekitplugins/awsemrserverless/task.py index 3a92a3cab0..4b644563a2 100644 --- a/plugins/flytekit-aws-emr-serverless/flytekitplugins/awsemrserverless/task.py +++ b/plugins/flytekit-aws-emr-serverless/flytekitplugins/awsemrserverless/task.py @@ -31,15 +31,13 @@ @dataclass class EMRServerlessSparkJobDriver: - """ - Spark job driver configuration for EMR Serverless. - - Attributes: - entry_point: S3 path to the main application file - (e.g., ``s3://bucket/scripts/main.py``). - entry_point_arguments: Arguments to pass to the main application. - spark_submit_parameters: Spark submit parameters string - (e.g., ``--conf spark.executor.memory=4g``). + """Spark job driver configuration for EMR Serverless. 
+ + ``entry_point`` is the S3 path to the main application file (for example + ``s3://bucket/scripts/main.py``). ``entry_point_arguments`` is an optional + list of arguments forwarded to the application, and + ``spark_submit_parameters`` is the optional spark-submit parameter string + (for example ``--conf spark.executor.memory=4g``). """ entry_point: str @@ -58,13 +56,11 @@ def to_dict(self) -> Dict[str, Any]: @dataclass class EMRServerlessHiveJobDriver: - """ - Hive job driver configuration for EMR Serverless. + """Hive job driver configuration for EMR Serverless. - Attributes: - query: The Hive query to execute, or S3 path to a query file. - init_query_file: S3 path to an initialization query file. - parameters: Parameters string for the query. + ``query`` is the Hive query to execute, or an S3 path to a query file. + ``init_query_file`` is an optional S3 path to an initialization query + file, and ``parameters`` is an optional parameters string for the query. """ query: str @@ -126,63 +122,48 @@ def my_spark_job(): spark = SparkSession.builder.getOrCreate() spark.range(100).write.parquet("s3://bucket/output") - Attributes: - execution_role_arn: IAM role ARN for job execution (required). - application_id: Existing EMR Serverless application ID. If not - provided and ``application_name`` is set, the connector will - attempt to create a new application (subject to the - connector-level ``FLYTE_EMR_ALLOW_CREATE_APPLICATION`` policy). - release_label: EMR release label (default: ``emr-7.0.0``). - application_type: ``"SPARK"`` or ``"HIVE"`` (default: ``"SPARK"``). - application_name: Name for a new application. When - ``application_id`` is not set and ``application_name`` is - provided, the connector will create a new application with - this name (prefixed by the connector's - ``FLYTE_EMR_APPLICATION_NAME_PREFIX`` if configured). 
- sync_image: When ``True`` (the default), the connector will update - the application's image if the task's ``container_image`` or - ``image_configuration`` differs from what the app currently has. - spark_job_driver: Spark job configuration. When omitted for a SPARK - application, Pythonic mode is used. - hive_job_driver: Hive job configuration. - spark_submit_parameters: Extra ``--conf`` flags to pass to - ``sparkSubmitParameters`` for *both* script and Pythonic modes. - In Pythonic mode these are merged with the connector defaults. - application_configuration: ``applicationConfiguration`` list for - ``configurationOverrides``. Merged with any entries already - in ``configuration_overrides``. - runtime_configuration: ``runtimeConfiguration`` passed to - ``CreateApplication`` or ``UpdateApplication``. - scheduler_configuration: ``SchedulerConfiguration`` for job - concurrency and queuing (e.g. - ``{"maxConcurrentRuns": 5, "queueTimeoutMinutes": 30}``). - auto_stop_config: ``AutoStopConfiguration`` for the application - (e.g. ``{"enabled": True, "idleTimeoutMinutes": 30}``). - architecture: Processor architecture for the application. - ``"X86_64"`` (default) or ``"ARM64"`` (Graviton). - retry_policy: Job-level retry policy for resiliency (e.g. - ``{"maxAttempts": 3}``). Requires EMR 7.1+. - configuration_overrides: Application and monitoring configuration. - tags: Resource tags. - execution_timeout_minutes: Maximum job execution time (default: 60). - initial_capacity: Pre-initialized worker capacity. - maximum_capacity: Auto-scaling limits. - network_configuration: VPC, subnet, and security group settings. - image_configuration: Custom container image settings. - region: AWS region (uses boto3 default if not specified). 
- inject_flyte_env: When ``True`` (default), the connector appends - ``--conf spark.emr-serverless.driverEnv.*`` and - ``--conf spark.executorEnv.*`` entries to ``sparkSubmitParameters`` - so the Spark driver and executors see Flyte runtime context as - environment variables: ``FLYTE_INTERNAL_EXECUTION_ID``, - ``FLYTE_INTERNAL_EXECUTION_PROJECT``, - ``FLYTE_INTERNAL_EXECUTION_DOMAIN``, - ``FLYTE_INTERNAL_TASK_{PROJECT,DOMAIN,NAME,VERSION}``, - ``FLYTE_INTERNAL_NODE_ID``, and - ``FLYTE_INTERNAL_TASK_RETRY_ATTEMPT``. Any - ``environment_variables`` set on ``task_execution_metadata`` - (via Flyte ``Environment`` policies) are also forwarded. Only - applies to Spark jobs. + Field summary: + + * ``execution_role_arn`` (required): IAM role ARN used to run the job. + * ``application_id``: existing EMR Serverless application ID. When unset + and ``application_name`` is provided, the connector will attempt to + create a new application (subject to the connector-level + ``FLYTE_EMR_ALLOW_CREATE_APPLICATION`` policy). + * ``release_label``: EMR release label, default ``emr-7.0.0``. + * ``application_type``: ``"SPARK"`` or ``"HIVE"`` (default ``"SPARK"``). + * ``application_name``: name for a new application; when set without + ``application_id`` the connector creates one (prefixed by + ``FLYTE_EMR_APPLICATION_NAME_PREFIX`` if configured). + * ``sync_image`` (default ``True``): the connector will update the + application image if the task's ``container_image`` or + ``image_configuration`` differs from what the app currently has. + * ``spark_job_driver`` / ``hive_job_driver``: explicit job driver for + script mode; omit ``spark_job_driver`` to use Pythonic Spark mode. + * ``spark_submit_parameters``: extra ``--conf`` flags applied in both + script and Pythonic modes (merged with the connector defaults in + Pythonic mode). 
+ * ``application_configuration`` / ``runtime_configuration`` / + ``scheduler_configuration`` / ``auto_stop_config``: forwarded to + ``CreateApplication`` / ``UpdateApplication``. + * ``architecture``: ``"X86_64"`` (default) or ``"ARM64"`` (Graviton). + * ``retry_policy``: job-level retry policy (requires EMR 7.1+). + * ``configuration_overrides``: application and monitoring overrides. + * ``tags``: resource tags applied to the application/jobs. + * ``execution_timeout_minutes``: max job execution time, default 60. + * ``initial_capacity`` / ``maximum_capacity`` / ``network_configuration`` + / ``image_configuration``: forwarded to the application. + * ``region``: AWS region (uses the boto3 default if not specified). + * ``inject_flyte_env`` (default ``True``): when set, the connector + appends ``spark.emr-serverless.driverEnv.*`` and ``spark.executorEnv.*`` + entries to ``sparkSubmitParameters`` so the Spark driver and executors + see Flyte runtime context as environment variables + (``FLYTE_INTERNAL_EXECUTION_ID``, ``FLYTE_INTERNAL_EXECUTION_PROJECT``, + ``FLYTE_INTERNAL_EXECUTION_DOMAIN``, + ``FLYTE_INTERNAL_TASK_{PROJECT,DOMAIN,NAME,VERSION}``, + ``FLYTE_INTERNAL_NODE_ID``, ``FLYTE_INTERNAL_TASK_RETRY_ATTEMPT``). + Any ``environment_variables`` set on ``task_execution_metadata`` (via + Flyte ``Environment`` policies) are also forwarded. Only applies to + Spark jobs. """ execution_role_arn: str From d34c10267568b3ec4ecadac3db28fbe8410e94bb Mon Sep 17 00:00:00 2001 From: Rohit Sharma Date: Sun, 26 Apr 2026 22:10:21 +0100 Subject: [PATCH 3/3] Add reference worker Dockerfile for Pythonic Spark mode Pythonic Spark tasks need flytekit and this plugin on the EMR Serverless worker so pyflyte-execute can rehydrate the task module. Add a reference Dockerfile alongside the plugin that builds on the public EMR Serverless Spark base image and installs the matching flytekit / flytekitplugins-awsemrserverless versions, mirroring the Dockerfile that ships with flytekit-spark. 
Document the build command in the plugin README. Adopters who only run
Script Spark or Hive jobs can keep using the upstream EMR Serverless Spark
image and ignore this file.

Signed-off-by: Rohit Sharma
Made-with: Cursor
---
 .../flytekit-aws-emr-serverless/Dockerfile    | 20 +++++++++++++++++++
 plugins/flytekit-aws-emr-serverless/README.md | 13 +++++++++++-
 2 files changed, 32 insertions(+), 1 deletion(-)
 create mode 100644 plugins/flytekit-aws-emr-serverless/Dockerfile

diff --git a/plugins/flytekit-aws-emr-serverless/Dockerfile b/plugins/flytekit-aws-emr-serverless/Dockerfile
new file mode 100644
index 0000000000..4cc93b6a7b
--- /dev/null
+++ b/plugins/flytekit-aws-emr-serverless/Dockerfile
@@ -0,0 +1,20 @@
+# Reference worker image for AWS EMR Serverless tasks running in Pythonic
+# Spark mode. Script Spark and Hive jobs do not need flytekit on the worker
+# and can use the upstream EMR Serverless Spark image directly.
+#
+# Build:
+#   docker build --build-arg VERSION=<flytekit-version> \
+#     -t <registry>/emr-serverless-flytekit:<tag> .
+ARG EMR_BASE_IMAGE=public.ecr.aws/emr-serverless/spark/emr-7.0.0:latest
+FROM ${EMR_BASE_IMAGE}
+LABEL org.opencontainers.image.source=https://github.com/flyteorg/flytekit
+
+ARG VERSION
+
+USER root
+
+RUN pip3 install --no-cache-dir --ignore-installed \
+    "flytekit==${VERSION}" \
+    "flytekitplugins-awsemrserverless==${VERSION}"
+
+USER hadoop:hadoop
diff --git a/plugins/flytekit-aws-emr-serverless/README.md b/plugins/flytekit-aws-emr-serverless/README.md
index 3ecda1ad88..496de851c6 100644
--- a/plugins/flytekit-aws-emr-serverless/README.md
+++ b/plugins/flytekit-aws-emr-serverless/README.md
@@ -92,7 +92,18 @@ def hive_query():
 
 ## Worker image
 
-For Pythonic Spark tasks the worker image must contain `flytekit` and this plugin so the executor can rehydrate the task object on the EMR Serverless side. Build it from an EMR Serverless Spark base image and `pip install flytekitplugins-awsemrserverless`. Script Spark and Hive jobs do not require flytekit on the worker.
+For Pythonic Spark tasks the worker image must contain `flytekit` and this plugin so the executor can rehydrate the task object on the EMR Serverless side. Script Spark and Hive jobs do not require flytekit on the worker.
+
+A reference `Dockerfile` is shipped alongside this plugin; it builds on the public EMR Serverless Spark base image and installs the matching `flytekit` and `flytekitplugins-awsemrserverless` versions:
+
+```bash
+docker build \
+  --build-arg VERSION=<flytekit-version> \
+  -t <registry>/emr-serverless-flytekit:<tag> \
+  plugins/flytekit-aws-emr-serverless
+```
+
+Override the base with `--build-arg EMR_BASE_IMAGE=...` to track a different EMR release.
 
 ## IAM