diff --git a/plugins/flytekit-aws-emr-serverless/Dockerfile b/plugins/flytekit-aws-emr-serverless/Dockerfile
new file mode 100644
index 0000000000..4cc93b6a7b
--- /dev/null
+++ b/plugins/flytekit-aws-emr-serverless/Dockerfile
@@ -0,0 +1,20 @@
+# Reference worker image for AWS EMR Serverless tasks running in Pythonic
+# Spark mode. Script Spark and Hive jobs do not need flytekit on the worker
+# and can use the upstream EMR Serverless Spark image directly.
+#
+# Build:
+#   docker build --build-arg VERSION=<version> \
+#     -t <registry>/emr-serverless-flytekit:<version> .
+ARG EMR_BASE_IMAGE=public.ecr.aws/emr-serverless/spark/emr-7.0.0:latest
+FROM ${EMR_BASE_IMAGE}
+LABEL org.opencontainers.image.source=https://github.com/flyteorg/flytekit
+
+ARG VERSION
+
+USER root
+
+RUN pip3 install --no-cache-dir --ignore-installed \
+    "flytekit==${VERSION}" \
+    "flytekitplugins-awsemrserverless==${VERSION}"
+
+USER hadoop:hadoop
diff --git a/plugins/flytekit-aws-emr-serverless/README.md b/plugins/flytekit-aws-emr-serverless/README.md
new file mode 100644
index 0000000000..496de851c6
--- /dev/null
+++ b/plugins/flytekit-aws-emr-serverless/README.md
@@ -0,0 +1,120 @@
+# Flytekit AWS EMR Serverless Plugin
+
+A Flyte connector for [AWS EMR Serverless](https://docs.aws.amazon.com/emr/latest/EMR-Serverless-UserGuide/emr-serverless.html) that submits Spark and Hive jobs to an EMR Serverless application and tracks them through to completion.
+
+## Features
+
+- **Pythonic Spark mode**: write a Flyte `@task` whose body is regular PySpark; the plugin packages the user code, uploads it to S3, and runs it on EMR Serverless. No long-lived cluster to manage.
+- **Script Spark mode**: point at an existing `main.py` (or JAR) already in S3 and submit it directly.
+- **Hive mode**: submit a Hive query (inline or from S3) against an EMR Serverless application configured for Hive.
+- Async connector lifecycle (`create` / `get` / `delete`) so the connector pod stays light and many jobs can be tracked concurrently.
+- Honours Flyte task retries, timeouts, and cancellation, and surfaces EMR Serverless logs through the Flyte UI when log URIs are available.
+
+## Installation
+
+```bash
+pip install flytekitplugins-awsemrserverless
+```
+
+The connector is registered automatically with `flytekit` via the plugin entry point. Deploy it on a [`flyteconnector`](https://github.com/flyteorg/flyte/tree/master/charts/flyteconnector) pod that has this package installed and an IAM identity allowed to call EMR Serverless `StartJobRun` / `GetJobRun` / `CancelJobRun` and to read/write the script-staging S3 prefix.
+
+## Usage
+
+### Pythonic Spark task
+
+```python
+from flytekit import task, workflow
+from flytekitplugins.awsemrserverless import EMRServerless
+
+
+@task(
+    task_config=EMRServerless(
+        application_id="00fhabc12345",
+        execution_role_arn="arn:aws:iam::123456789012:role/EMRServerlessRole",
+        region="us-east-1",
+        spark_submit_parameters="--conf spark.executor.cores=2 --conf spark.executor.memory=4g",
+    ),
+)
+def spark_count() -> int:
+    from pyspark.sql import SparkSession
+
+    spark = SparkSession.builder.getOrCreate()
+    return spark.range(1_000_000).count()
+
+
+@workflow
+def wf() -> int:
+    return spark_count()
+```
+
+The plugin serializes the task body, uploads it to S3, and EMR Serverless runs it inside the worker image you have associated with the application.
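+
+In Pythonic mode the connector also needs an S3 bucket to stage the bootstrap entrypoint script. It resolves one from the `FLYTE_EMR_ENTRYPOINT_S3_BUCKET` environment variable, from the S3 monitoring log URI in `configuration_overrides`, or from `FLYTE_AWS_S3_BUCKET`, in that order. A minimal sketch of the second option (the bucket name and log prefix are placeholders):
+
+```python
+@task(
+    task_config=EMRServerless(
+        application_id="00fhabc12345",
+        execution_role_arn="arn:aws:iam::123456789012:role/EMRServerlessRole",
+        region="us-east-1",
+        configuration_overrides={
+            "monitoringConfiguration": {
+                # The connector reuses this bucket to stage the entrypoint script.
+                "s3MonitoringConfiguration": {"logUri": "s3://my-bucket/emr-logs/"},
+            },
+        },
+    ),
+)
+def spark_count_with_logs() -> int:
+    from pyspark.sql import SparkSession
+
+    spark = SparkSession.builder.getOrCreate()
+    return spark.range(1_000).count()
+```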
+
+### Script Spark task
+
+```python
+from flytekitplugins.awsemrserverless import EMRServerless, EMRServerlessSparkJobDriver
+
+
+@task(
+    task_config=EMRServerless(
+        application_id="00fhabc12345",
+        execution_role_arn="arn:aws:iam::123456789012:role/EMRServerlessRole",
+        region="us-east-1",
+        spark_job_driver=EMRServerlessSparkJobDriver(
+            entry_point="s3://my-bucket/scripts/main.py",
+            entry_point_arguments=["--date", "2025-01-01"],
+            spark_submit_parameters="--conf spark.executor.memory=4g",
+        ),
+    ),
+)
+def submit_script():
+    ...
+```
+
+### Hive task
+
+```python
+from flytekitplugins.awsemrserverless import EMRServerless, EMRServerlessHiveJobDriver
+
+
+@task(
+    task_config=EMRServerless(
+        application_id="00fhabc12345",
+        execution_role_arn="arn:aws:iam::123456789012:role/EMRServerlessRole",
+        region="us-east-1",
+        application_type="HIVE",
+        hive_job_driver=EMRServerlessHiveJobDriver(
+            query="SELECT COUNT(*) FROM my_table",
+        ),
+    ),
+)
+def hive_query():
+    ...
+```
+
+## Worker image
+
+For Pythonic Spark tasks the worker image must contain `flytekit` and this plugin so the executor can rehydrate the task object on the EMR Serverless side. Script Spark and Hive jobs do not require flytekit on the worker.
+
+A reference `Dockerfile` is shipped alongside this plugin; it builds on the public EMR Serverless Spark base image and installs the matching `flytekit` and `flytekitplugins-awsemrserverless` versions:
+
+```bash
+docker build \
+    --build-arg VERSION=<version> \
+    -t <registry>/emr-serverless-flytekit:<version> \
+    plugins/flytekit-aws-emr-serverless
+```
+
+Override the base with `--build-arg EMR_BASE_IMAGE=...` to track a different EMR release.
+
+## IAM
+
+The connector pod's IAM principal needs:
+
+- `emr-serverless:StartJobRun`, `emr-serverless:GetJobRun`, `emr-serverless:CancelJobRun` on the target application
+- `s3:GetObject` / `s3:PutObject` on the script-staging prefix (Pythonic mode)
+- `iam:PassRole` for the EMR Serverless execution role
+
+The execution role attached to the EMR Serverless application is the role the workers run as and needs whatever data-access permissions your jobs require.
+
+## Discussion
+
+Tracking issue: [flyteorg/flyte#7286](https://github.com/flyteorg/flyte/issues/7286).
diff --git a/plugins/flytekit-aws-emr-serverless/flytekitplugins/awsemrserverless/__init__.py b/plugins/flytekit-aws-emr-serverless/flytekitplugins/awsemrserverless/__init__.py
new file mode 100644
index 0000000000..abc6be9ffd
--- /dev/null
+++ b/plugins/flytekit-aws-emr-serverless/flytekitplugins/awsemrserverless/__init__.py
@@ -0,0 +1,29 @@
+"""
+.. currentmodule:: flytekitplugins.awsemrserverless
+
+This plugin enables running Spark and Hive jobs on AWS EMR Serverless from
+Flyte workflows. It exposes an async connector that handles the EMR
+Serverless job lifecycle (submit, poll, cancel) and a task config type.
+
+.. autosummary::
+   :template: custom.rst
+   :toctree: generated/
+
+   EMRServerless
+   EMRServerlessSparkJobDriver
+   EMRServerlessHiveJobDriver
+   EMRServerlessTask
+   EMRServerlessConnector
+   EMRServerlessJobMetadata
+"""
+
+from flytekitplugins.awsemrserverless.connector import (
+    EMRServerlessConnector,
+    EMRServerlessJobMetadata,
+)
+from flytekitplugins.awsemrserverless.task import (
+    EMRServerless,
+    EMRServerlessHiveJobDriver,
+    EMRServerlessSparkJobDriver,
+    EMRServerlessTask,
+)
diff --git a/plugins/flytekit-aws-emr-serverless/flytekitplugins/awsemrserverless/_entrypoint.py b/plugins/flytekit-aws-emr-serverless/flytekitplugins/awsemrserverless/_entrypoint.py
new file mode 100644
index 0000000000..c93288fcd3
--- /dev/null
+++ b/plugins/flytekit-aws-emr-serverless/flytekitplugins/awsemrserverless/_entrypoint.py
@@ -0,0 +1,203 @@
+"""
+EMR Serverless Pythonic-mode entrypoint script.
+
+This file is the canonical source of the bootstrap script that EMR
+Serverless workers execute as ``sparkSubmit.entryPoint`` for Pythonic
+tasks (i.e. tasks that do not provide an explicit ``spark_job_driver``).
+
+How it is delivered to EMR
+--------------------------
+
+The connector pod:
+
+1. reads this file from its own ``site-packages`` install;
+2. computes ``hashlib.sha256(content).hexdigest()[:12]`` and uploads
+   (idempotently) to ``s3://<bucket>/flyte/emr-serverless/entrypoint-<hash>.py``;
+3. passes that S3 URI as ``StartJobRun.jobDriver.sparkSubmit.entryPoint``.
+
+EMR Serverless then downloads this script onto the Spark driver and
+runs it with ``spark-submit``. ``sys.argv[1:]`` carries the actual
+``pyflyte-fast-execute`` (or ``pyflyte-execute``) invocation that
+should run inside the worker, plus the fast-registration distribution
+arguments.
+
+Why a custom entrypoint at all
+------------------------------
+
+EMR Serverless's API requires ``sparkSubmit.entryPoint`` to be a
+single Python file URI -- there is no "container as entrypoint"
+escape hatch (cf. SageMaker / Batch / ECS, which run the container
+itself). We need a thin shim that:
+
+* downloads the fast-registration tarball from Flyte's blob store,
+* invokes ``pyflyte-fast-execute`` with the right resolver arguments,
+* converts Flytekit's "exit-0-on-user-error" semantics into a non-zero
+  exit so EMR reports ``FAILED`` instead of ``SUCCESS``.
+
+This file deliberately has only ``flytekit`` as a runtime dependency
+(specifically ``flytekit.tools.fast_registration.download_distribution``)
+because it runs *inside the EMR worker*, not the connector pod.
+
+Editing this file
+-----------------
+
+Treat this file as part of the connector's *runtime contract* with
+EMR workers, not as plugin internals:
+
+* changes here propagate to every Pythonic-mode job on the next
+  connector deploy via the content hash in the S3 key;
+* the corresponding unit tests live in ``tests/test_entrypoint.py``
+  and exercise this module both as imported Python and as the
+  spawned subprocess EMR sees;
+* upstream alignment: this is the EMR analogue of
+  ``flytetools/flytekitplugins/databricks/entrypoint.py`` --
+  same shape, different transport (S3 instead of GitHub).
+"""
+
+import os
+import signal
+import subprocess
+import sys
+
+from flytekit.tools.fast_registration import download_distribution
+
+
+def _run_subprocess(cmd, env=None):
+    """Run ``cmd`` and forward SIGTERM, returning ``(returncode, stderr_text)``.
+
+    stdout streams through to the parent (Spark driver stdout); stderr is
+    captured so the caller can inspect it for Flytekit's user-error banner.
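+
+    Illustrative use (the argv is a placeholder)::
+
+        rc, stderr_text = _run_subprocess(
+            ["pyflyte-execute", "--inputs", "s3://bucket/inputs.pb"],
+            env=os.environ.copy(),
+        )
+        _exit_with_code(rc, stderr_text)
+    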
+ """ + p = subprocess.Popen(cmd, env=env, stderr=subprocess.PIPE, stdout=None) + signal.signal(signal.SIGTERM, lambda s, f: p.send_signal(s)) + _, stderr_bytes = p.communicate() + stderr_text = stderr_bytes.decode("utf-8", errors="replace") if stderr_bytes else "" + if stderr_text: + sys.stderr.write(stderr_text) + return p.returncode, stderr_text + + +def _exit_with_code(rc, stderr_text=""): + """Translate Flytekit subprocess exit semantics into EMR-correct exits. + + Flytekit's ``pyflyte-execute`` catches user exceptions, writes the + error to the Flyte output blob, and exits ``0`` -- by design for + K8s-based agents where FlytePropeller reads the output. + + In EMR Serverless the connector only polls EMR job state + (``SUCCESS`` / ``FAILED``) and cannot read the output blobs. If + ``pyflyte-execute`` exits ``0`` but the user function failed, EMR + reports ``SUCCESS`` and the connector wrongly reports ``SUCCEEDED``. + + Detect this by scanning stderr for Flytekit's error banner. When + found, force a non-zero exit so Spark fails the driver and EMR + reports ``FAILED``. + """ + if rc == 0 and "User Error Captured by Flyte" in stderr_text: + print( + "[flyte-entrypoint] pyflyte-execute exited 0 but stderr contains " + "a user error -- forcing non-zero exit so EMR reports FAILED", + file=sys.stderr, + ) + sys.exit(1) + if rc != 0: + print(f"[flyte-entrypoint] Task process exited with code {rc}", file=sys.stderr) + sys.exit(rc) + + +def _parse_fast_execute_args(args): + """Split a ``pyflyte-fast-execute ...`` argv into its three pieces. + + Returns ``(additional_distribution, dest_dir, task_cmd_start)`` + where ``task_cmd_start`` is the index in ``args`` where the + underlying ``pyflyte-execute ...`` command begins. + + Recognises the two-arg flag forms emitted by Flytekit: + ``--additional-distribution `` + ``--dest-dir `` + and the optional ``--`` separator before the inner command. + """ + additional_distribution = None + dest_dir = None + task_cmd_start = 0 + + i = 1 + while i < len(args): + if args[i] == "--additional-distribution" and i + 1 < len(args): + additional_distribution = args[i + 1] + i += 2 + elif args[i] == "--dest-dir" and i + 1 < len(args): + dest_dir = args[i + 1] + i += 2 + elif args[i] == "--": + task_cmd_start = i + 1 + break + else: + task_cmd_start = i + break + + return additional_distribution, dest_dir, task_cmd_start + + +def _build_resolver_command(task_execute_cmd, additional_distribution, dest_dir): + """Inject the fast-registration distribution args before ``--resolver``. + + ``pyflyte-execute`` resolves task callables via a resolver plugin + (default ``flytekit.core.python_auto_container.default_task_resolver``). + For fast-registered code, the resolver needs to know where the + extracted source tree lives, which we inject as + ``--dynamic-addl-distro`` / ``--dynamic-dest-dir`` immediately + before ``--resolver``. 
+ """ + cmd = [] + for arg in task_execute_cmd: + if arg == "--resolver": + cmd.extend( + [ + "--dynamic-addl-distro", + additional_distribution or "", + "--dynamic-dest-dir", + dest_dir or "", + ] + ) + cmd.append(arg) + return cmd + + +def main(): + args = sys.argv[1:] + if not args: + print("Usage: entrypoint.py pyflyte-fast-execute|pyflyte-execute ...", file=sys.stderr) + sys.exit(1) + + if args[0] == "pyflyte-fast-execute": + additional_distribution, dest_dir, task_cmd_start = _parse_fast_execute_args(args) + task_execute_cmd = list(args[task_cmd_start:]) + + if additional_distribution: + if not dest_dir: + dest_dir = os.getcwd() + download_distribution(additional_distribution, dest_dir) + + cmd = _build_resolver_command(task_execute_cmd, additional_distribution, dest_dir) + + env = os.environ.copy() + if dest_dir: + resolved = os.path.realpath(os.path.expanduser(dest_dir)) + env["PYTHONPATH"] = resolved + os.pathsep + env.get("PYTHONPATH", "") + rc, stderr_text = _run_subprocess(cmd, env) + _exit_with_code(rc, stderr_text) + + elif args[0] == "pyflyte-execute": + env = os.environ.copy() + env.setdefault("PYTHONPATH", os.getcwd()) + rc, stderr_text = _run_subprocess(args, env) + _exit_with_code(rc, stderr_text) + + else: + print(f"Unrecognized command: {args}", file=sys.stderr) + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/plugins/flytekit-aws-emr-serverless/flytekitplugins/awsemrserverless/boto_handler.py b/plugins/flytekit-aws-emr-serverless/flytekitplugins/awsemrserverless/boto_handler.py new file mode 100644 index 0000000000..0cfb82e2e0 --- /dev/null +++ b/plugins/flytekit-aws-emr-serverless/flytekitplugins/awsemrserverless/boto_handler.py @@ -0,0 +1,389 @@ +""" +Boto3 helper for EMR Serverless API operations. + +Provides a thin async wrapper around the boto3 EMR Serverless client, with +application-lifecycle management used by the connector. + +All outbound boto3 traffic is funneled through two private helpers so that +tests can mock the entire handler at a single boundary (matches the +flytekit-aws-sagemaker ``Boto3ConnectorMixin._call`` pattern): + +* :py:meth:`EMRServerlessHandler._call` -- single-shot API calls +* :py:meth:`EMRServerlessHandler._paginate` -- paginated list operations +""" + +import asyncio +import logging +import time +from typing import Any, Dict, List, Optional + +import boto3 +from botocore.config import Config as BotoConfig +from botocore.exceptions import ClientError + +logger = logging.getLogger(__name__) + +_BOTO_RETRY_CONFIG = BotoConfig(retries={"max_attempts": 3, "mode": "adaptive"}) + +_APP_STARTED = "STARTED" +_APP_TERMINAL = {"TERMINATED"} +_APP_NEEDS_START = {"CREATED", "STOPPED"} +_APP_TRANSITIONING = {"CREATING", "STARTING"} + +_JOB_TERMINAL = {"SUCCESS", "FAILED", "CANCELLED"} + + +class EMRServerlessHandler: + """ + Async-friendly wrapper around the boto3 EMR Serverless client. + + Retries are handled by botocore's built-in adaptive retry mode. + Blocking boto3 calls are offloaded to a thread-pool executor so they + do not block the asyncio event loop. 
+ """ + + def __init__(self, region: Optional[str] = None): + self.region = region + self._client: Optional[Any] = None + logger.debug("EMRServerlessHandler initialized: region=%s", region or "(default)") + + @property + def client(self) -> Any: + if self._client is None: + kwargs: Dict[str, Any] = { + "service_name": "emr-serverless", + "config": _BOTO_RETRY_CONFIG, + } + if self.region: + kwargs["region_name"] = self.region + self._client = boto3.client(**kwargs) + logger.debug( + "Boto3 EMR Serverless client created: region=%s", + self._client.meta.region_name, + ) + return self._client + + async def _call(self, method: str, **params: Any) -> Dict[str, Any]: + """Invoke a boto3 EMR Serverless client method in the thread-pool executor. + + This is the single chokepoint for all non-paginated boto3 API calls + (``method`` is the boto3 client method name and ``params`` are keyword + arguments forwarded to the call in the camelCase form the AWS SDK + expects). Tests can mock the entire handler surface by patching just + this method. Returns the raw response dict from the boto3 call. + """ + func = getattr(self.client, method) + loop = asyncio.get_event_loop() + return await loop.run_in_executor(None, lambda: func(**params)) + + async def _paginate(self, method: str, result_key: str, **params: Any) -> List[Dict[str, Any]]: + """Exhaust a boto3 paginator and return the concatenated items. + + Used for list operations (e.g. ``list_applications``) that may span + multiple pages. ``method`` is the boto3 client method to paginate, + ``result_key`` is the key under which each page stores items + (e.g. ``"applications"`` for ``list_applications``), and ``params`` + are keyword arguments forwarded to ``paginator.paginate()``. Tests + can mock this independently from ``_call``. Returns a flat list of + items across all pages. + """ + + def _collect() -> List[Dict[str, Any]]: + paginator = self.client.get_paginator(method) + items: List[Dict[str, Any]] = [] + for page in paginator.paginate(**params): + items.extend(page.get(result_key, [])) + return items + + loop = asyncio.get_event_loop() + return await loop.run_in_executor(None, _collect) + + # ------------------------------------------------------------------ + # Application management + # ------------------------------------------------------------------ + + async def get_application(self, application_id: str) -> Dict[str, Any]: + logger.debug("GetApplication: applicationId=%s", application_id) + resp = await self._call("get_application", applicationId=application_id) + app = resp.get("application", {}) + logger.debug( + "GetApplication response: id=%s, name=%s, state=%s", + app.get("applicationId"), + app.get("name"), + app.get("state"), + ) + return app + + async def find_application_by_name(self, name: str) -> Optional[str]: + """Find an active application ID by name. 
Returns None if not found.""" + logger.info("Searching for application by name: '%s'", name) + apps = await self._paginate( + "list_applications", + result_key="applications", + states=["CREATED", "STARTED", "STOPPED"], + ) + for app in apps: + if app.get("name") == name: + logger.info("Found application '%s' with ID: %s", name, app["id"]) + return app["id"] + logger.info("Application '%s' not found after searching %d app(s)", name, len(apps)) + return None + + async def create_application( + self, + name: str, + release_label: str, + application_type: str, + initial_capacity: Optional[Dict[str, Any]] = None, + maximum_capacity: Optional[Dict[str, Any]] = None, + network_configuration: Optional[Dict[str, Any]] = None, + image_configuration: Optional[Dict[str, Any]] = None, + tags: Optional[Dict[str, str]] = None, + architecture: Optional[str] = None, + runtime_configuration: Optional[list] = None, + scheduler_configuration: Optional[Dict[str, Any]] = None, + auto_stop_config: Optional[Dict[str, Any]] = None, + ) -> str: + logger.info( + "CreateApplication: name=%s, release_label=%s, type=%s, architecture=%s", + name, + release_label, + application_type, + architecture, + ) + params: Dict[str, Any] = { + "name": name, + "releaseLabel": release_label, + "type": application_type, + } + if initial_capacity: + params["initialCapacity"] = initial_capacity + logger.debug("CreateApplication: initialCapacity provided") + if maximum_capacity: + params["maximumCapacity"] = maximum_capacity + logger.debug("CreateApplication: maximumCapacity provided") + if network_configuration: + params["networkConfiguration"] = network_configuration + logger.debug("CreateApplication: networkConfiguration provided") + if image_configuration: + params["imageConfiguration"] = image_configuration + logger.debug("CreateApplication: imageConfiguration=%s", image_configuration) + if tags: + params["tags"] = tags + logger.debug("CreateApplication: tags=%s", tags) + if architecture: + params["architecture"] = architecture + if runtime_configuration: + params["runtimeConfiguration"] = runtime_configuration + logger.debug("CreateApplication: runtimeConfiguration provided") + if scheduler_configuration: + params["schedulerConfiguration"] = scheduler_configuration + logger.debug("CreateApplication: schedulerConfiguration=%s", scheduler_configuration) + if auto_stop_config: + params["autoStopConfiguration"] = auto_stop_config + logger.debug("CreateApplication: autoStopConfiguration=%s", auto_stop_config) + + resp = await self._call("create_application", **params) + app_id = resp.get("applicationId", "") + logger.info("CreateApplication succeeded: applicationId=%s, arn=%s", app_id, resp.get("arn", "")) + return app_id + + async def update_application( + self, + application_id: str, + image_configuration: Optional[Dict[str, Any]] = None, + maximum_capacity: Optional[Dict[str, Any]] = None, + auto_stop_config: Optional[Dict[str, Any]] = None, + runtime_configuration: Optional[list] = None, + scheduler_configuration: Optional[Dict[str, Any]] = None, + ) -> None: + """Update a mutable subset of application properties.""" + params: Dict[str, Any] = {"applicationId": application_id} + update_fields = [] + + if image_configuration: + params["imageConfiguration"] = image_configuration + update_fields.append("imageConfiguration") + if maximum_capacity: + params["maximumCapacity"] = maximum_capacity + update_fields.append("maximumCapacity") + if auto_stop_config: + params["autoStopConfiguration"] = auto_stop_config + 
update_fields.append("autoStopConfiguration") + if runtime_configuration: + params["runtimeConfiguration"] = runtime_configuration + update_fields.append("runtimeConfiguration") + if scheduler_configuration: + params["schedulerConfiguration"] = scheduler_configuration + update_fields.append("schedulerConfiguration") + + if len(params) <= 1: + logger.debug("UpdateApplication: no fields to update for %s, skipping", application_id) + return + + logger.info("UpdateApplication: applicationId=%s, fields=%s", application_id, update_fields) + await self._call("update_application", **params) + logger.info("UpdateApplication succeeded for %s", application_id) + + async def ensure_application_started( + self, + application_id: str, + timeout_seconds: int = 300, + poll_interval_seconds: int = 10, + ) -> None: + """Ensure the application reaches STARTED state, starting it if needed.""" + logger.info( + "EnsureApplicationStarted: applicationId=%s, timeout=%ds", + application_id, + timeout_seconds, + ) + start_time = time.monotonic() + + while True: + app = await self.get_application(application_id) + state = app.get("state", "") + + if state == _APP_STARTED: + elapsed = time.monotonic() - start_time + logger.info( + "Application %s is in STARTED state (waited %.1fs)", + application_id, + elapsed, + ) + return + + if state in _APP_TERMINAL: + logger.error( + "Application %s is in terminal state '%s', cannot be started", + application_id, + state, + ) + raise RuntimeError(f"Application {application_id} is in terminal state '{state}' and cannot be started") + + if state in _APP_NEEDS_START: + logger.info("Application %s is in state '%s', sending StartApplication request", application_id, state) + await self._call("start_application", applicationId=application_id) + + elapsed = time.monotonic() - start_time + if elapsed >= timeout_seconds: + logger.error( + "Timed out after %.1fs waiting for application %s to start (last state: %s)", + elapsed, + application_id, + state, + ) + raise TimeoutError( + f"Timed out after {timeout_seconds}s waiting for application {application_id} to start " + f"(last state: {state})" + ) + + logger.debug( + "Application %s state: %s, waiting %ds (elapsed: %.1fs / %ds)", + application_id, + state, + poll_interval_seconds, + elapsed, + timeout_seconds, + ) + await asyncio.sleep(poll_interval_seconds) + + # ------------------------------------------------------------------ + # Job management + # ------------------------------------------------------------------ + + async def start_job_run( + self, + application_id: str, + execution_role_arn: str, + job_driver: Dict[str, Any], + configuration_overrides: Optional[Dict[str, Any]] = None, + tags: Optional[Dict[str, str]] = None, + execution_timeout_minutes: int = 60, + name: Optional[str] = None, + retry_policy: Optional[Dict[str, Any]] = None, + ) -> str: + logger.info( + "StartJobRun: applicationId=%s, name=%s, timeout=%dm", + application_id, + name, + execution_timeout_minutes, + ) + params: Dict[str, Any] = { + "applicationId": application_id, + "executionRoleArn": execution_role_arn, + "jobDriver": job_driver, + "executionTimeoutMinutes": execution_timeout_minutes, + } + if configuration_overrides: + params["configurationOverrides"] = configuration_overrides + logger.debug("StartJobRun: configurationOverrides provided") + if tags: + params["tags"] = tags + if name: + params["name"] = name + if retry_policy: + params["retryPolicy"] = retry_policy + logger.debug("StartJobRun: retryPolicy=%s", retry_policy) + + 
logger.debug("StartJobRun: jobDriver type=%s", list(job_driver.keys())) + resp = await self._call("start_job_run", **params) + job_run_id = resp.get("jobRunId", "") + logger.info( + "StartJobRun succeeded: jobRunId=%s, applicationId=%s, arn=%s", + job_run_id, + application_id, + resp.get("arn", ""), + ) + return job_run_id + + async def get_job_run(self, application_id: str, job_run_id: str) -> Dict[str, Any]: + logger.debug("GetJobRun: applicationId=%s, jobRunId=%s", application_id, job_run_id) + resp = await self._call( + "get_job_run", + applicationId=application_id, + jobRunId=job_run_id, + ) + job = resp.get("jobRun", {}) + logger.debug( + "GetJobRun response: jobRunId=%s, state=%s", + job.get("jobRunId"), + job.get("state"), + ) + return job + + async def cancel_job_run(self, application_id: str, job_run_id: str) -> None: + """Cancel a running job. Idempotent -- safe to call on completed jobs.""" + logger.info("CancelJobRun: applicationId=%s, jobRunId=%s", application_id, job_run_id) + try: + job = await self.get_job_run(application_id, job_run_id) + current_state = job.get("state", "") + if current_state in _JOB_TERMINAL: + logger.info( + "CancelJobRun: job %s already in terminal state '%s', no action needed", + job_run_id, + current_state, + ) + return + logger.info("CancelJobRun: job %s is in state '%s', sending cancel request", job_run_id, current_state) + await self._call( + "cancel_job_run", + applicationId=application_id, + jobRunId=job_run_id, + ) + logger.info("CancelJobRun succeeded: jobRunId=%s", job_run_id) + except ClientError as e: + code = e.response.get("Error", {}).get("Code", "") + if code == "ResourceNotFoundException": + logger.warning( + "CancelJobRun: job %s not found (ResourceNotFoundException), may already be cleaned up", + job_run_id, + ) + return + if code == "ValidationException" and "cannot be cancelled" in str(e).lower(): + logger.info( + "CancelJobRun: job %s cannot be cancelled (ValidationException), likely already completed", + job_run_id, + ) + return + logger.error("CancelJobRun failed for job %s: %s (code: %s)", job_run_id, e, code) + raise diff --git a/plugins/flytekit-aws-emr-serverless/flytekitplugins/awsemrserverless/connector.py b/plugins/flytekit-aws-emr-serverless/flytekitplugins/awsemrserverless/connector.py new file mode 100644 index 0000000000..52c41a9f8b --- /dev/null +++ b/plugins/flytekit-aws-emr-serverless/flytekitplugins/awsemrserverless/connector.py @@ -0,0 +1,825 @@ +""" +Flyte Async Connector for AWS EMR Serverless. + +Handles the complete lifecycle of EMR Serverless jobs: +create (submit), get (poll status), and delete (cancel). 
+""" + +import hashlib +import logging +import os +import re +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, Optional + +from flyteidl.core.execution_pb2 import TaskExecution + +try: + from flytekit.extend.backend.base_connector import ( + AsyncConnectorBase, + ConnectorRegistry, + Resource, + ResourceMeta, + ) +except ModuleNotFoundError: + from flytekit.extend.backend.base_agent import ( + AgentRegistry as ConnectorRegistry, + ) + from flytekit.extend.backend.base_agent import ( + AsyncAgentBase as AsyncConnectorBase, + ) + from flytekit.extend.backend.base_agent import ( + Resource, + ResourceMeta, + ) + +from flytekitplugins.awsemrserverless.boto_handler import EMRServerlessHandler +from flytekitplugins.awsemrserverless.task import EMR_SERVERLESS_BASE_IMAGE, EMRServerless + +from flytekit.extend.backend.utils import convert_to_flyte_phase +from flytekit.models.core.execution import TaskLog +from flytekit.models.literals import LiteralMap +from flytekit.models.task import TaskTemplate + +logger = logging.getLogger(__name__) + +# Maps EMR Serverless job run states to the canonical flytekit phase strings +# accepted by ``convert_to_flyte_phase`` ("running", "success", "failed"). +# This mirrors the upstream flytekit plugin convention (see +# ``plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/connector.py``). +EMR_SERVERLESS_STATES = { + "PENDING": "Running", + "SCHEDULED": "Running", + "SUBMITTED": "Running", + "RUNNING": "Running", + "SUCCESS": "Success", + "FAILED": "Failed", + "CANCELLING": "Running", + "CANCELLED": "Failed", +} + +FLYTE_TAGS = {"Application": "flyte", "ManagedBy": "flyte-connector"} + +_APPLICATION_ID_RE = re.compile(r"^[0-9a-z]+$") + +_ENV_ALLOW_CREATE_APPLICATION = "FLYTE_EMR_ALLOW_CREATE_APPLICATION" +_ENV_APPLICATION_NAME_PREFIX = "FLYTE_EMR_APPLICATION_NAME_PREFIX" + + +@dataclass +class EMRServerlessJobMetadata(ResourceMeta): + """Metadata persisted by FlytePropeller between connector calls.""" + + application_id: str + job_run_id: str + region: str + created_application: bool = False + + +class EMRServerlessConnector(AsyncConnectorBase): + """ + Flyte Connector for AWS EMR Serverless. + + Supports two execution modes: + + * **Script mode** -- the user provides an explicit ``spark_job_driver`` + or ``hive_job_driver`` in the task config. The connector submits the + job as-is. + * **Pythonic mode** -- no job driver is provided. The connector reads + ``task_template.container`` (image + args) and constructs a + ``sparkSubmit`` job that runs the Flytekit entrypoint inside the + user's container image on EMR Serverless. 
+ """ + + name = "EMR Serverless Connector" + + def __init__(self): + super().__init__( + task_type_name="emr_serverless", + metadata_type=EMRServerlessJobMetadata, + ) + + def _get_handler(self, region: Optional[str] = None) -> EMRServerlessHandler: + return EMRServerlessHandler(region=region) + + @staticmethod + def _extract_config(task_template: TaskTemplate) -> EMRServerless: + custom = task_template.custom + if not custom: + raise ValueError("Task template has no custom configuration") + + if hasattr(custom, "fields"): + from google.protobuf.json_format import MessageToDict + + config_dict = MessageToDict(custom) + else: + config_dict = dict(custom) + + config = EMRServerless.from_dict(config_dict) + logger.debug( + "Extracted task config: application_id=%s, application_name=%s, " + "application_type=%s, region=%s, script_mode=%s, sync_image=%s", + config.application_id, + config.application_name, + config.application_type, + config.region, + config.is_script_mode, + config.sync_image, + ) + return config + + # The Pythonic-mode entrypoint script that EMR workers execute as + # ``sparkSubmit.entryPoint``. Lives in its own module so that: + # * it is reviewable / blame-able in version control, + # * it can be unit-tested as Python (see ``tests/test_entrypoint.py``), + # * the connector ships it via the same wheel it ships in, + # mirroring the upstream Databricks pattern of treating the + # entrypoint as a first-class plugin asset. + _ENTRYPOINT_PATH: Path = (Path(__file__).parent / "_entrypoint.py").resolve() + + @classmethod + def _read_entrypoint_bytes(cls) -> bytes: + """Read the entrypoint file as bytes. Cached on the class. + + Reading from disk (instead of importing the module and using + ``inspect.getsource``) keeps the byte-stream byte-for-byte + identical to what EMR will execute, which is the input to the + content hash. + """ + cached = getattr(cls, "_entrypoint_bytes_cache", None) + if cached is None: + cached = cls._ENTRYPOINT_PATH.read_bytes() + cls._entrypoint_bytes_cache = cached + return cached + + @staticmethod + def _resolve_entrypoint_bucket(config: EMRServerless) -> str: + """Determine the S3 bucket for the entrypoint script. + + Resolution order: + 1. FLYTE_EMR_ENTRYPOINT_S3_BUCKET env var + 2. The S3 monitoring log URI from configuration_overrides (reuses the same bucket) + 3. FLYTE_AWS_S3_BUCKET env var (common Flyte storage bucket) + """ + bucket = os.environ.get("FLYTE_EMR_ENTRYPOINT_S3_BUCKET") + if bucket: + resolved = bucket.removeprefix("s3://").split("/")[0] + logger.debug("Resolved entrypoint bucket from FLYTE_EMR_ENTRYPOINT_S3_BUCKET: %s", resolved) + return resolved + + if config.configuration_overrides: + log_uri = ( + config.configuration_overrides.get("monitoringConfiguration", {}) + .get("s3MonitoringConfiguration", {}) + .get("logUri", "") + ) + if log_uri.startswith("s3://"): + resolved = log_uri.removeprefix("s3://").split("/")[0] + logger.debug("Resolved entrypoint bucket from monitoringConfiguration logUri: %s", resolved) + return resolved + + bucket = os.environ.get("FLYTE_AWS_S3_BUCKET") + if bucket: + resolved = bucket.removeprefix("s3://").split("/")[0] + logger.debug("Resolved entrypoint bucket from FLYTE_AWS_S3_BUCKET: %s", resolved) + return resolved + + raise ValueError( + "Pythonic mode needs an S3 bucket to upload the Flytekit entrypoint. " + "Set FLYTE_EMR_ENTRYPOINT_S3_BUCKET env var, or add " + "monitoringConfiguration.s3MonitoringConfiguration.logUri to " + "configuration_overrides in your task config." 
+ ) + + def _ensure_entrypoint_on_s3(self, handler: EMRServerlessHandler, config: EMRServerless) -> str: + """Upload the Flytekit entrypoint to S3 if not present, return the s3:// URI. + + The S3 key is content-addressed (``entrypoint-.py``) + so each plugin version produces a unique, immutable artifact. + Old hashes remain in the bucket indefinitely so in-flight EMR + jobs that captured an older URI in their job spec keep working + across connector upgrades. + """ + content = self._read_entrypoint_bytes() + content_hash = hashlib.sha256(content).hexdigest()[:12] + + bucket = self._resolve_entrypoint_bucket(config) + + key = f"flyte/emr-serverless/entrypoint-{content_hash}.py" + s3_uri = f"s3://{bucket}/{key}" + + import boto3 + + s3 = boto3.client("s3", region_name=handler.region) + try: + s3.head_object(Bucket=bucket, Key=key) + logger.debug("Entrypoint already exists at %s", s3_uri) + except Exception: + logger.info("Uploading Flytekit entrypoint to %s", s3_uri) + s3.put_object(Bucket=bucket, Key=key, Body=content, ContentType="text/x-python") + + return s3_uri + + _PYTHONIC_SPARK_DEFAULTS = ( + "--conf spark.emr-serverless.driverEnv.PYSPARK_DRIVER_PYTHON=/usr/bin/python3 " + "--conf spark.emr-serverless.driverEnv.PYSPARK_PYTHON=/usr/bin/python3 " + "--conf spark.executorEnv.PYSPARK_PYTHON=/usr/bin/python3" + ) + + # Characters that would break ``--conf k=v`` tokenisation if embedded in + # an env value. EMR Serverless does not support quoting values here, + # so we drop the variable rather than risk a malformed sparkSubmit line. + _UNSAFE_ENV_VALUE_CHARS = re.compile(r"[\s='\"]") + + @staticmethod + def _build_flyte_env_vars( + task_execution_metadata: Optional[Any], + ) -> Dict[str, str]: + """Return a dict of FLYTE_INTERNAL_* env vars derived from the + task execution metadata, plus any user-supplied env vars from + ``task_execution_metadata.environment_variables``. + + Returns an empty dict when ``task_execution_metadata`` is ``None`` + (e.g. local / unit-test execution). 
+ """ + if task_execution_metadata is None: + return {} + + env: Dict[str, str] = {} + + task_exec_id = getattr(task_execution_metadata, "task_execution_id", None) + if task_exec_id is None: + return env + + task_id = getattr(task_exec_id, "task_id", None) + if task_id is not None: + env["FLYTE_INTERNAL_TASK_PROJECT"] = getattr(task_id, "project", "") or "" + env["FLYTE_INTERNAL_TASK_DOMAIN"] = getattr(task_id, "domain", "") or "" + env["FLYTE_INTERNAL_TASK_NAME"] = getattr(task_id, "name", "") or "" + env["FLYTE_INTERNAL_TASK_VERSION"] = getattr(task_id, "version", "") or "" + + node_exec_id = getattr(task_exec_id, "node_execution_id", None) + if node_exec_id is not None: + env["FLYTE_INTERNAL_NODE_ID"] = getattr(node_exec_id, "node_id", "") or "" + wf_exec_id = getattr(node_exec_id, "execution_id", None) + if wf_exec_id is not None: + env["FLYTE_INTERNAL_EXECUTION_ID"] = getattr(wf_exec_id, "name", "") or "" + env["FLYTE_INTERNAL_EXECUTION_PROJECT"] = getattr(wf_exec_id, "project", "") or "" + env["FLYTE_INTERNAL_EXECUTION_DOMAIN"] = getattr(wf_exec_id, "domain", "") or "" + + retry = getattr(task_exec_id, "retry_attempt", None) + if retry is not None: + env["FLYTE_INTERNAL_TASK_RETRY_ATTEMPT"] = str(retry) + + user_env = getattr(task_execution_metadata, "environment_variables", None) or {} + for k, v in user_env.items(): + if k and v is not None: + env[str(k)] = str(v) + + return {k: v for k, v in env.items() if v != ""} + + @classmethod + def _format_env_as_spark_conf(cls, env_vars: Dict[str, str]) -> str: + """Convert a dict of env vars into Spark driver+executor ``--conf`` flags. + + Produces entries of the form:: + + --conf spark.emr-serverless.driverEnv.KEY=VALUE + --conf spark.executorEnv.KEY=VALUE + + Values containing whitespace or quote characters are skipped with + a warning (EMR Serverless does not support escaping here). + """ + if not env_vars: + return "" + + parts = [] + for key, value in env_vars.items(): + if cls._UNSAFE_ENV_VALUE_CHARS.search(value): + logger.warning( + "Skipping env var %s: value contains characters that would " + "break sparkSubmitParameters tokenisation", + key, + ) + continue + parts.append(f"--conf spark.emr-serverless.driverEnv.{key}={value}") + parts.append(f"--conf spark.executorEnv.{key}={value}") + return " ".join(parts) + + @classmethod + def _append_flyte_env_to_spark_params( + cls, + existing: Optional[str], + task_execution_metadata: Optional[Any], + *, + enabled: bool, + ) -> Optional[str]: + """Append Flyte context env vars to an existing sparkSubmitParameters. + + Returns ``existing`` unchanged when ``enabled`` is False, when no + metadata is available, or when there are no env vars to inject. + """ + if not enabled: + return existing + env_vars = cls._build_flyte_env_vars(task_execution_metadata) + if not env_vars: + return existing + conf = cls._format_env_as_spark_conf(env_vars) + if not conf: + return existing + logger.info("Injecting %d Flyte env var(s) into sparkSubmitParameters", len(env_vars)) + logger.debug("Injected Flyte env vars: %s", sorted(env_vars.keys())) + return f"{existing} {conf}".strip() if existing else conf + + def _build_pythonic_job_driver( + self, task_template: TaskTemplate, config: EMRServerless, handler: EMRServerlessHandler + ) -> Dict[str, Any]: + """ + Build a sparkSubmit job driver from the Flyte container entrypoint. + + The Databricks connector fetches its entrypoint from a Git repo. 
+ EMR Serverless doesn't support Git sources, so we upload the + Flytekit entrypoint script to S3 and reference it from there. + ``container.args`` are passed as entrypoint arguments so that + the task function is executed inside the EMR Spark runtime. + + Pythonic mode requires a custom Docker image on the EMR Serverless + application that includes flytekit and its dependencies. The driver + and executor Python paths are set explicitly so the container's + Python (which has flytekit) is used instead of the EMR default. + """ + container = task_template.container + if container is None: + raise ValueError( + "Pythonic mode requires a container image. Either provide a " + "spark_job_driver for script mode, or ensure the task has a container image." + ) + + logger.info( + "Building Pythonic mode job driver: container_image=%s, args_count=%d", + container.image, + len(container.args) if container.args else 0, + ) + logger.debug("Pythonic mode container.args: %s", list(container.args) if container.args else []) + + entrypoint_s3_uri = self._ensure_entrypoint_on_s3(handler, config) + logger.info("Entrypoint S3 URI: %s", entrypoint_s3_uri) + + entry_point_args = list(container.args) if container.args else [] + + user_params = config.spark_submit_parameters + spark_submit_params = self._merge_spark_submit_params(user_params) + if user_params: + logger.info("Merged user spark_submit_parameters with Pythonic defaults") + logger.debug("Final sparkSubmitParameters: %s", spark_submit_params) + + spark_submit: Dict[str, Any] = { + "entryPoint": entrypoint_s3_uri, + "entryPointArguments": entry_point_args, + } + if spark_submit_params: + spark_submit["sparkSubmitParameters"] = spark_submit_params + + return {"sparkSubmit": spark_submit} + + def _merge_spark_submit_params(self, user_params: Optional[str]) -> str: + """Merge user-provided spark submit params with Pythonic mode defaults. + + User-provided values take precedence -- if the user already sets + e.g. ``spark.executorEnv.PYSPARK_PYTHON``, the default is skipped. + """ + defaults = self._PYTHONIC_SPARK_DEFAULTS + if not user_params: + return defaults + + merged_parts = [] + for token in defaults.split("--conf "): + token = token.strip() + if not token: + continue + key = token.split("=", 1)[0] + if key not in user_params: + merged_parts.append(f"--conf {token}") + + return f"{user_params} {' '.join(merged_parts)}".strip() + + @staticmethod + def _resolve_image_configuration( + task_template: TaskTemplate, + config: EMRServerless, + *, + for_create: bool = False, + ) -> Optional[Dict[str, Any]]: + """Derive EMR Serverless ``imageConfiguration`` following community patterns. + + Resolution order: + 1. Explicit ``image_configuration`` in the task config -- always wins. + 2. ``container_image`` on the task (``task_template.container.image``). + 3. For script mode, ``None`` is fine (default EMR image works). + 4. For Pythonic mode when creating a new app, fall back to the base image. + 5. For Pythonic mode on an existing app, raise with a clear message + (the app should already have a custom image). 
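
+
+        For example (the image URI is a placeholder), a task declared with
+        ``container_image="123456789012.dkr.ecr.us-east-1.amazonaws.com/emr-flytekit:v1"``
+        resolves to::
+
+            {"imageUri": "123456789012.dkr.ecr.us-east-1.amazonaws.com/emr-flytekit:v1"}
+        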
+ """ + logger.debug("Resolving image configuration: for_create=%s, script_mode=%s", for_create, config.is_script_mode) + + if config.image_configuration: + logger.info("Using explicit image_configuration from task config: %s", config.image_configuration) + return config.image_configuration + + container = task_template.container + if container and getattr(container, "image", None): + logger.info( + "Deriving imageConfiguration from container_image: %s", + container.image, + ) + return {"imageUri": container.image} + + if config.is_script_mode: + logger.debug("Script mode with no custom image; using default EMR image") + return None + + if for_create: + logger.info( + "No image specified for new application; using default base image: %s", + EMR_SERVERLESS_BASE_IMAGE, + ) + return {"imageUri": EMR_SERVERLESS_BASE_IMAGE} + + raise ValueError( + "Pythonic mode requires a custom Docker image with flytekit installed " + "on the EMR Serverless workers. Provide a container_image on the task:\n\n" + ' @task(task_config=EMRServerless(...), container_image="")\n\n' + "Or use an existing application_id whose application already has a " + "custom image configured." + ) + + def _merge_tags(self, user_tags: Optional[Dict[str, str]]) -> Dict[str, str]: + tags = dict(FLYTE_TAGS) + if user_tags: + tags.update(user_tags) + return tags + + @staticmethod + def _get_container_image(task_template: TaskTemplate) -> Optional[str]: + """Extract the container image URI from the task template.""" + container = task_template.container + if container and getattr(container, "image", None): + return container.image + return None + + @staticmethod + def _is_create_application_allowed() -> bool: + """Check the connector-level policy for application creation. + + Reads ``FLYTE_EMR_ALLOW_CREATE_APPLICATION`` (default ``"false"``). + Only ``"true"`` (case-insensitive) permits dynamic creation. + """ + raw_value = os.environ.get(_ENV_ALLOW_CREATE_APPLICATION, "false") + allowed = raw_value.strip().lower() == "true" + logger.debug( + "Application creation policy check: %s=%r, allowed=%s", + _ENV_ALLOW_CREATE_APPLICATION, + raw_value, + allowed, + ) + return allowed + + @staticmethod + def _apply_application_name_prefix(name: str) -> str: + """Prepend the connector-level prefix to the application name. + + Reads ``FLYTE_EMR_APPLICATION_NAME_PREFIX``. When set (e.g. + ``flyte-prod-``), the final name becomes ``flyte-prod-my-etl-app``. + """ + prefix = os.environ.get(_ENV_APPLICATION_NAME_PREFIX, "").strip() + if prefix and not name.startswith(prefix): + prefixed = f"{prefix}{name}" + logger.info("Applied application name prefix: '%s' -> '%s'", name, prefixed) + return prefixed + if prefix: + logger.debug("Application name '%s' already starts with prefix '%s', skipping", name, prefix) + return name + + async def _sync_application_image( + self, + handler: EMRServerlessHandler, + application_id: str, + task_template: TaskTemplate, + config: EMRServerless, + ) -> None: + """Update the application image if ``container_image`` differs from the app. + + The desired image is derived from (in order): + 1. ``task_template.container.image`` (the ``container_image`` decorator arg) + 2. ``config.image_configuration`` (explicit ``image_configuration`` in task_config) + + If neither is set, this is a no-op (script mode with default EMR image). 
+ """ + logger.debug("Image sync: checking application %s", application_id) + + if config.is_script_mode and not config.image_configuration: + logger.debug( + "Image sync: script mode with no explicit image_configuration, skipping for application %s", + application_id, + ) + return + + desired_uri = self._get_container_image(task_template) + if not desired_uri and config.image_configuration: + desired_uri = config.image_configuration.get("imageUri", "") + logger.debug("Image sync: using image_configuration as fallback: %s", desired_uri) + if not desired_uri: + logger.debug("Image sync: no desired image URI found, skipping for application %s", application_id) + return + + logger.debug("Image sync: desired image for application %s: %s", application_id, desired_uri) + app = await handler.get_application(application_id) + current_image = app.get("imageConfiguration", {}).get("imageUri", "") + logger.debug("Image sync: current image on application %s: %s", application_id, current_image or "(none)") + + if current_image == desired_uri: + logger.info( + "Image sync: application %s already has image %s, no update needed", application_id, desired_uri + ) + return + + logger.info( + "Image sync: updating application %s image: %s -> %s", + application_id, + current_image or "(none)", + desired_uri, + ) + await handler.update_application( + application_id=application_id, + image_configuration={"imageUri": desired_uri}, + ) + logger.info("Image sync: successfully updated application %s image", application_id) + + async def create( + self, + task_template: TaskTemplate, + inputs: Optional[LiteralMap] = None, + task_execution_metadata: Optional[Any] = None, + **kwargs: Any, + ) -> EMRServerlessJobMetadata: + task_name = task_template.id.name if task_template.id else "(unknown)" + logger.info("create() called for task: %s", task_name) + + if task_execution_metadata is None: + task_execution_metadata = kwargs.get("task_execution_metadata") + + config = self._extract_config(task_template) + handler = self._get_handler(config.region) + + application_id = config.application_id + created_application = False + + # --- Resolve application by name if the ID looks like a name --- + if application_id and not _APPLICATION_ID_RE.match(application_id): + app_name = application_id + logger.info( + "application_id '%s' does not look like an AWS ID, treating as application name", + app_name, + ) + resolved_id = await handler.find_application_by_name(app_name) + if resolved_id: + logger.info("Resolved application name '%s' to ID: %s", app_name, resolved_id) + application_id = resolved_id + else: + raise ValueError( + f"No active EMR Serverless application found with name '{app_name}'. " + f"application_id must be a valid AWS application ID (e.g. '00fm0lpr3kbcq60p') " + f"or the name of an existing application." + ) + + # --- Create a new application if needed --- + if not application_id: + logger.info("No application_id provided, checking if dynamic creation is possible") + + if not config.application_name: + logger.error("No application_id and no application_name specified, cannot proceed") + raise ValueError( + "No application_id provided and no application_name specified. " + "Either set application_id to an existing EMR Serverless application, " + "or set application_name to create a new one " + f"(requires {_ENV_ALLOW_CREATE_APPLICATION}=true on the connector)." 
+ ) + + if not self._is_create_application_allowed(): + logger.warning( + "Dynamic application creation requested for '%s' but blocked by connector policy", + config.application_name, + ) + raise ValueError( + "Dynamic application creation is disabled by the connector " + f"({_ENV_ALLOW_CREATE_APPLICATION} is not 'true'). " + "Contact your platform team to enable it, or set application_id " + "to an existing EMR Serverless application." + ) + + app_name = self._apply_application_name_prefix(config.application_name) + + image_config = self._resolve_image_configuration(task_template, config, for_create=True) + + logger.info( + "Creating new EMR Serverless application: name=%s, release_label=%s, type=%s, architecture=%s", + app_name, + config.release_label, + config.application_type, + config.architecture, + ) + application_id = await handler.create_application( + name=app_name, + release_label=config.release_label, + application_type=config.application_type, + initial_capacity=config.initial_capacity, + maximum_capacity=config.maximum_capacity, + network_configuration=config.network_configuration, + image_configuration=image_config, + tags=self._merge_tags(config.tags), + architecture=config.architecture, + runtime_configuration=config.runtime_configuration, + scheduler_configuration=config.scheduler_configuration, + auto_stop_config=config.auto_stop_config, + ) + created_application = True + logger.info("Created application %s successfully", application_id) + else: + logger.info("Using existing application: %s", application_id) + + # --- Sync image on existing apps if sync_image is enabled --- + if not created_application and config.sync_image: + logger.info("sync_image is enabled, checking application image for %s", application_id) + await self._sync_application_image(handler, application_id, task_template, config) + elif not created_application: + logger.debug("sync_image is disabled, skipping image sync for %s", application_id) + + logger.info("Ensuring application %s is in STARTED state", application_id) + await handler.ensure_application_started(application_id) + + # --- Build job driver --- + if config.is_script_mode: + logger.info("Building job driver in script mode") + job_driver = config.get_job_driver() + if config.spark_submit_parameters and "sparkSubmit" in job_driver: + existing = job_driver["sparkSubmit"].get("sparkSubmitParameters", "") + merged = f"{existing} {config.spark_submit_parameters}".strip() + job_driver["sparkSubmit"]["sparkSubmitParameters"] = merged + logger.info("Merged top-level spark_submit_parameters into script mode job driver") + logger.debug("Final sparkSubmitParameters: %s", merged) + if "sparkSubmit" in job_driver: + current = job_driver["sparkSubmit"].get("sparkSubmitParameters") + updated = self._append_flyte_env_to_spark_params( + current, + task_execution_metadata, + enabled=config.inject_flyte_env, + ) + if updated != current: + job_driver["sparkSubmit"]["sparkSubmitParameters"] = updated + elif config.inject_flyte_env: + logger.debug( + "Skipping Flyte env injection: non-Spark job driver (%s)", + next(iter(job_driver.keys()), "unknown"), + ) + else: + logger.info("Building job driver in Pythonic mode") + job_driver = self._build_pythonic_job_driver(task_template, config, handler) + if "sparkSubmit" in job_driver: + current = job_driver["sparkSubmit"].get("sparkSubmitParameters") + updated = self._append_flyte_env_to_spark_params( + current, + task_execution_metadata, + enabled=config.inject_flyte_env, + ) + if updated != current: + 
job_driver["sparkSubmit"]["sparkSubmitParameters"] = updated + + job_name = task_template.id.name if task_template.id else None + + effective_config_overrides = config.get_effective_configuration_overrides() + if effective_config_overrides: + logger.debug( + "Effective configuration overrides keys: %s", + list(effective_config_overrides.keys()), + ) + + logger.info( + "Submitting job run: application=%s, job_name=%s, execution_role=%s, timeout=%dm", + application_id, + job_name, + config.execution_role_arn, + config.execution_timeout_minutes, + ) + job_run_id = await handler.start_job_run( + application_id=application_id, + execution_role_arn=config.execution_role_arn, + job_driver=job_driver, + configuration_overrides=effective_config_overrides, + tags=self._merge_tags(config.tags), + execution_timeout_minutes=config.execution_timeout_minutes, + name=job_name, + retry_policy=config.retry_policy, + ) + + region = config.region or handler.client.meta.region_name + logger.info( + "Job submitted successfully: application=%s, job_run_id=%s, region=%s, created_application=%s", + application_id, + job_run_id, + region, + created_application, + ) + + return EMRServerlessJobMetadata( + application_id=application_id, + job_run_id=job_run_id, + region=region, + created_application=created_application, + ) + + async def get( + self, + resource_meta: EMRServerlessJobMetadata, + **kwargs: Any, + ) -> Resource: + logger.debug( + "get() called: application=%s, job_run_id=%s, region=%s", + resource_meta.application_id, + resource_meta.job_run_id, + resource_meta.region, + ) + handler = self._get_handler(resource_meta.region) + + try: + job = await handler.get_job_run( + application_id=resource_meta.application_id, + job_run_id=resource_meta.job_run_id, + ) + except Exception as e: + logger.warning( + "Failed to retrieve job %s on application %s: %s", + resource_meta.job_run_id, + resource_meta.application_id, + e, + ) + return Resource( + phase=TaskExecution.FAILED, + message=f"Job not found: {resource_meta.job_run_id}", + ) + + state = job.get("state", "UNKNOWN") + state_details = job.get("stateDetails", "") + phase = convert_to_flyte_phase(EMR_SERVERLESS_STATES.get(state, "Running")) + + message = f"EMR Serverless job state: {state}" + if state_details: + message = f"{message} - {state_details}" + + logger.info( + "Job %s status: state=%s, phase=%s", + resource_meta.job_run_id, + state, + phase, + ) + + log_links = self._get_log_links(resource_meta) + + return Resource(phase=phase, message=message, log_links=log_links) + + def _get_log_links(self, resource_meta: EMRServerlessJobMetadata) -> list: + region = resource_meta.region or "us-east-1" + console_url = ( + f"https://{region}.console.aws.amazon.com/emr/home?region={region}" + f"#/serverless/{resource_meta.application_id}/jobs/{resource_meta.job_run_id}" + ) + return [TaskLog(uri=console_url, name="EMR Serverless Console").to_flyte_idl()] + + async def delete( + self, + resource_meta: EMRServerlessJobMetadata, + **kwargs: Any, + ) -> None: + logger.info( + "delete() called: application=%s, job_run_id=%s, region=%s", + resource_meta.application_id, + resource_meta.job_run_id, + resource_meta.region, + ) + handler = self._get_handler(resource_meta.region) + try: + await handler.cancel_job_run( + application_id=resource_meta.application_id, + job_run_id=resource_meta.job_run_id, + ) + logger.info("Delete completed for job %s", resource_meta.job_run_id) + except Exception as e: + logger.warning( + "Failed to cancel job %s on application %s: %s", + 
resource_meta.job_run_id, + resource_meta.application_id, + e, + ) + + +ConnectorRegistry.register(EMRServerlessConnector()) diff --git a/plugins/flytekit-aws-emr-serverless/flytekitplugins/awsemrserverless/task.py b/plugins/flytekit-aws-emr-serverless/flytekitplugins/awsemrserverless/task.py new file mode 100644 index 0000000000..4b644563a2 --- /dev/null +++ b/plugins/flytekit-aws-emr-serverless/flytekitplugins/awsemrserverless/task.py @@ -0,0 +1,402 @@ +""" +Flyte Task definition for EMR Serverless. + +Provides the EMRServerless task configuration dataclass, job driver configs, +and the EMRServerlessTask class for defining EMR Serverless jobs in Flyte +workflows. Supports two execution modes: + +1. Script mode: User provides S3 paths to Spark/Hive scripts. +2. Pythonic mode: User writes a @task function and the connector uses the + Flyte container entrypoint to execute it on EMR Serverless. +""" + +import dataclasses +import logging +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Union + +from flytekit import PythonFunctionTask +from flytekit.configuration import SerializationSettings +from flytekit.core.context_manager import FlyteContextManager +from flytekit.extend import TaskPlugins +from flytekit.image_spec import ImageSpec + +try: + from flytekit.extend.backend.base_connector import AsyncConnectorExecutorMixin +except ModuleNotFoundError: + from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin as AsyncConnectorExecutorMixin + +logger = logging.getLogger(__name__) + + +@dataclass +class EMRServerlessSparkJobDriver: + """Spark job driver configuration for EMR Serverless. + + ``entry_point`` is the S3 path to the main application file (for example + ``s3://bucket/scripts/main.py``). ``entry_point_arguments`` is an optional + list of arguments forwarded to the application, and + ``spark_submit_parameters`` is the optional spark-submit parameter string + (for example ``--conf spark.executor.memory=4g``). + """ + + entry_point: str + entry_point_arguments: Optional[List[str]] = None + spark_submit_parameters: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert to EMR Serverless API format.""" + result: Dict[str, Any] = {"entryPoint": self.entry_point} + if self.entry_point_arguments: + result["entryPointArguments"] = self.entry_point_arguments + if self.spark_submit_parameters: + result["sparkSubmitParameters"] = self.spark_submit_parameters + return result + + +@dataclass +class EMRServerlessHiveJobDriver: + """Hive job driver configuration for EMR Serverless. + + ``query`` is the Hive query to execute, or an S3 path to a query file. + ``init_query_file`` is an optional S3 path to an initialization query + file, and ``parameters`` is an optional parameters string for the query. + """ + + query: str + init_query_file: Optional[str] = None + parameters: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert to EMR Serverless API format.""" + result: Dict[str, Any] = {"query": self.query} + if self.init_query_file: + result["initQueryFile"] = self.init_query_file + if self.parameters: + result["parameters"] = self.parameters + return result + + +_VALID_ARCHITECTURES = ("X86_64", "ARM64") + + +@dataclass +class EMRServerless: + """ + EMR Serverless task configuration. + + Use this to configure an EMR Serverless task. Tasks marked with this + configuration will execute on AWS EMR Serverless as Spark or Hive jobs. 
+ + Supports two execution modes: + + **Script mode** -- provide a ``spark_job_driver`` or ``hive_job_driver`` + pointing to scripts on S3:: + + @task( + task_config=EMRServerless( + application_id="00f5abc123def456", + execution_role_arn="arn:aws:iam::123456789012:role/Role", + spark_job_driver=EMRServerlessSparkJobDriver( + entry_point="s3://bucket/scripts/main.py", + ), + region="us-east-1", + ) + ) + def my_spark_job() -> str: + return "Job submitted" + + **Pythonic mode** -- omit the job driver and write Spark code directly + in the task function. The connector will use the Flyte container + entrypoint to execute it on EMR Serverless:: + + @task( + task_config=EMRServerless( + application_id="00f5abc123def456", + execution_role_arn="arn:aws:iam::123456789012:role/Role", + region="us-east-1", + ) + ) + def my_spark_job(): + from pyspark.sql import SparkSession + spark = SparkSession.builder.getOrCreate() + spark.range(100).write.parquet("s3://bucket/output") + + Field summary: + + * ``execution_role_arn`` (required): IAM role ARN used to run the job. + * ``application_id``: existing EMR Serverless application ID. When unset + and ``application_name`` is provided, the connector will attempt to + create a new application (subject to the connector-level + ``FLYTE_EMR_ALLOW_CREATE_APPLICATION`` policy). + * ``release_label``: EMR release label, default ``emr-7.0.0``. + * ``application_type``: ``"SPARK"`` or ``"HIVE"`` (default ``"SPARK"``). + * ``application_name``: name for a new application; when set without + ``application_id`` the connector creates one (prefixed by + ``FLYTE_EMR_APPLICATION_NAME_PREFIX`` if configured). + * ``sync_image`` (default ``True``): the connector will update the + application image if the task's ``container_image`` or + ``image_configuration`` differs from what the app currently has. + * ``spark_job_driver`` / ``hive_job_driver``: explicit job driver for + script mode; omit ``spark_job_driver`` to use Pythonic Spark mode. + * ``spark_submit_parameters``: extra ``--conf`` flags applied in both + script and Pythonic modes (merged with the connector defaults in + Pythonic mode). + * ``application_configuration`` / ``runtime_configuration`` / + ``scheduler_configuration`` / ``auto_stop_config``: forwarded to + ``CreateApplication`` / ``UpdateApplication``. + * ``architecture``: ``"X86_64"`` (default) or ``"ARM64"`` (Graviton). + * ``retry_policy``: job-level retry policy (requires EMR 7.1+). + * ``configuration_overrides``: application and monitoring overrides. + * ``tags``: resource tags applied to the application/jobs. + * ``execution_timeout_minutes``: max job execution time, default 60. + * ``initial_capacity`` / ``maximum_capacity`` / ``network_configuration`` + / ``image_configuration``: forwarded to the application. + * ``region``: AWS region (uses the boto3 default if not specified). + * ``inject_flyte_env`` (default ``True``): when set, the connector + appends ``spark.emr-serverless.driverEnv.*`` and ``spark.executorEnv.*`` + entries to ``sparkSubmitParameters`` so the Spark driver and executors + see Flyte runtime context as environment variables + (``FLYTE_INTERNAL_EXECUTION_ID``, ``FLYTE_INTERNAL_EXECUTION_PROJECT``, + ``FLYTE_INTERNAL_EXECUTION_DOMAIN``, + ``FLYTE_INTERNAL_TASK_{PROJECT,DOMAIN,NAME,VERSION}``, + ``FLYTE_INTERNAL_NODE_ID``, ``FLYTE_INTERNAL_TASK_RETRY_ATTEMPT``). + Any ``environment_variables`` set on ``task_execution_metadata`` (via + Flyte ``Environment`` policies) are also forwarded. Only applies to + Spark jobs. 
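+
+    For illustration, with ``inject_flyte_env=True`` the merged
+    ``sparkSubmitParameters`` end up carrying paired driver/executor entries
+    like ``spark.emr-serverless.driverEnv.FLYTE_INTERNAL_EXECUTION_ID=<id>``
+    and ``spark.executorEnv.FLYTE_INTERNAL_EXECUTION_ID=<id>``.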
+ """ + + execution_role_arn: str + application_id: Optional[str] = None + release_label: str = "emr-7.0.0" + application_type: str = "SPARK" + application_name: Optional[str] = None + + sync_image: bool = True + inject_flyte_env: bool = True + + spark_job_driver: Optional[EMRServerlessSparkJobDriver] = None + hive_job_driver: Optional[EMRServerlessHiveJobDriver] = None + + spark_submit_parameters: Optional[str] = None + application_configuration: Optional[List[Dict[str, Any]]] = None + runtime_configuration: Optional[List[Dict[str, Any]]] = None + scheduler_configuration: Optional[Dict[str, Any]] = None + auto_stop_config: Optional[Dict[str, Any]] = None + architecture: Optional[str] = None + retry_policy: Optional[Dict[str, Any]] = None + + configuration_overrides: Optional[Dict[str, Any]] = None + tags: Optional[Dict[str, str]] = None + execution_timeout_minutes: int = 60 + + initial_capacity: Optional[Dict[str, Any]] = None + maximum_capacity: Optional[Dict[str, Any]] = None + network_configuration: Optional[Dict[str, Any]] = None + image_configuration: Optional[Dict[str, Any]] = None + + region: Optional[str] = None + + def __post_init__(self) -> None: + if not self.execution_role_arn: + raise ValueError("execution_role_arn is required") + if self.application_type not in ("SPARK", "HIVE"): + raise ValueError(f"application_type must be 'SPARK' or 'HIVE', got '{self.application_type}'") + if self.application_type == "HIVE" and self.hive_job_driver is None: + raise ValueError("hive_job_driver is required when application_type is 'HIVE'") + if self.spark_job_driver and self.hive_job_driver: + raise ValueError("Only one of spark_job_driver or hive_job_driver can be specified") + if self.execution_timeout_minutes < 1: + raise ValueError("execution_timeout_minutes must be at least 1") + if self.architecture and self.architecture not in _VALID_ARCHITECTURES: + raise ValueError(f"architecture must be one of {_VALID_ARCHITECTURES}, got '{self.architecture}'") + + @property + def is_script_mode(self) -> bool: + """True when user provided an explicit job driver (script mode).""" + return self.spark_job_driver is not None or self.hive_job_driver is not None + + def get_job_driver(self) -> Dict[str, Any]: + """Get the job driver dict in EMR Serverless API format.""" + if self.spark_job_driver: + return {"sparkSubmit": self.spark_job_driver.to_dict()} + if self.hive_job_driver: + return {"hive": self.hive_job_driver.to_dict()} + raise ValueError("No job driver configured -- use Pythonic mode instead") + + def get_effective_configuration_overrides(self) -> Optional[Dict[str, Any]]: + """Build the merged ``configurationOverrides`` dict. + + Combines ``configuration_overrides`` with ``application_configuration`` + so users can set monitoring config via ``configuration_overrides`` and + Spark/Hive config via ``application_configuration`` independently. 
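+
+        Illustrative merge (values borrowed from the test suite)::
+
+            EMRServerless(
+                execution_role_arn="arn:aws:iam::123456789012:role/Role",
+                configuration_overrides={
+                    "monitoringConfiguration": {
+                        "s3MonitoringConfiguration": {"logUri": "s3://logs/"}
+                    }
+                },
+                application_configuration=[
+                    {"classification": "spark-defaults",
+                     "properties": {"spark.executor.cores": "8"}}
+                ],
+            ).get_effective_configuration_overrides()
+            # -> {"monitoringConfiguration": {"s3MonitoringConfiguration": {"logUri": "s3://logs/"}},
+            #     "applicationConfiguration": [{"classification": "spark-defaults",
+            #                                   "properties": {"spark.executor.cores": "8"}}]}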
+ """ + if not self.application_configuration and not self.configuration_overrides: + return None + + result = dict(self.configuration_overrides) if self.configuration_overrides else {} + + if self.application_configuration: + existing = result.get("applicationConfiguration", []) + merged = list(existing) + list(self.application_configuration) + result["applicationConfiguration"] = merged + + return result or None + + def to_dict(self) -> Dict[str, Any]: + """Serialize to a dict for ``task_template.custom``.""" + result: Dict[str, Any] = { + "execution_role_arn": self.execution_role_arn, + "release_label": self.release_label, + "application_type": self.application_type, + "execution_timeout_minutes": self.execution_timeout_minutes, + "sync_image": self.sync_image, + "inject_flyte_env": self.inject_flyte_env, + } + if self.application_id: + result["application_id"] = self.application_id + if self.application_name: + result["application_name"] = self.application_name + if self.spark_job_driver: + result["spark_job_driver"] = self.spark_job_driver.to_dict() + if self.hive_job_driver: + result["hive_job_driver"] = self.hive_job_driver.to_dict() + if self.spark_submit_parameters: + result["spark_submit_parameters"] = self.spark_submit_parameters + if self.application_configuration: + result["application_configuration"] = self.application_configuration + if self.runtime_configuration: + result["runtime_configuration"] = self.runtime_configuration + if self.scheduler_configuration: + result["scheduler_configuration"] = self.scheduler_configuration + if self.auto_stop_config: + result["auto_stop_config"] = self.auto_stop_config + if self.architecture: + result["architecture"] = self.architecture + if self.retry_policy: + result["retry_policy"] = self.retry_policy + if self.configuration_overrides: + result["configuration_overrides"] = self.configuration_overrides + if self.tags: + result["tags"] = self.tags + if self.initial_capacity: + result["initial_capacity"] = self.initial_capacity + if self.maximum_capacity: + result["maximum_capacity"] = self.maximum_capacity + if self.network_configuration: + result["network_configuration"] = self.network_configuration + if self.image_configuration: + result["image_configuration"] = self.image_configuration + if self.region: + result["region"] = self.region + return result + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "EMRServerless": + """Deserialize from a dict (inverse of ``to_dict``).""" + spark_driver = None + hive_driver = None + + if "spark_job_driver" in data: + d = data["spark_job_driver"] + spark_driver = EMRServerlessSparkJobDriver( + entry_point=d["entryPoint"], + entry_point_arguments=d.get("entryPointArguments"), + spark_submit_parameters=d.get("sparkSubmitParameters"), + ) + if "hive_job_driver" in data: + d = data["hive_job_driver"] + hive_driver = EMRServerlessHiveJobDriver( + query=d["query"], + init_query_file=d.get("initQueryFile"), + parameters=d.get("parameters"), + ) + + return cls( + execution_role_arn=data["execution_role_arn"], + application_id=data.get("application_id"), + release_label=data.get("release_label", "emr-7.0.0"), + application_type=data.get("application_type", "SPARK"), + application_name=data.get("application_name"), + sync_image=data.get("sync_image", True), + inject_flyte_env=data.get("inject_flyte_env", True), + spark_job_driver=spark_driver, + hive_job_driver=hive_driver, + spark_submit_parameters=data.get("spark_submit_parameters"), + 
application_configuration=data.get("application_configuration"), + runtime_configuration=data.get("runtime_configuration"), + scheduler_configuration=data.get("scheduler_configuration"), + auto_stop_config=data.get("auto_stop_config"), + architecture=data.get("architecture"), + retry_policy=data.get("retry_policy"), + configuration_overrides=data.get("configuration_overrides"), + tags=data.get("tags"), + execution_timeout_minutes=int(data.get("execution_timeout_minutes", 60)), + initial_capacity=data.get("initial_capacity"), + maximum_capacity=data.get("maximum_capacity"), + network_configuration=data.get("network_configuration"), + image_configuration=data.get("image_configuration"), + region=data.get("region"), + ) + + +EMR_SERVERLESS_BASE_IMAGE = "public.ecr.aws/emr-serverless/spark/emr-7.0.0:latest" + + +class EMRServerlessTask(AsyncConnectorExecutorMixin, PythonFunctionTask[EMRServerless]): + """ + EMR Serverless Task implementation. + + Extends ``PythonFunctionTask`` with ``AsyncConnectorExecutorMixin`` to + enable remote execution via the EMR Serverless connector and local + testing by mimicking FlytePropeller's connector calls. + + For Pythonic mode the task image must include ``flytekit``. When an + ``ImageSpec`` is provided without a ``base_image``, the EMR Serverless + Spark base image is used automatically (same pattern as + ``flytekit-spark``). Users can also pass a plain ECR URI string as + ``container_image``. + """ + + _TASK_TYPE = "emr_serverless" + + def __init__( + self, + task_config: EMRServerless, + task_function: Callable, + container_image: Optional[Union[str, ImageSpec]] = None, + **kwargs: Any, + ): + if isinstance(container_image, ImageSpec) and container_image.base_image is None: + container_image = dataclasses.replace(container_image, base_image=EMR_SERVERLESS_BASE_IMAGE) + + super().__init__( + task_config=task_config, + task_function=task_function, + task_type=self._TASK_TYPE, + container_image=container_image, + **kwargs, + ) + + def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: + """Serialize task configuration into ``task_template.custom``.""" + return self.task_config.to_dict() + + def execute(self, **kwargs: Any) -> Any: + """ + Execute the task. + + When running locally (e.g. ``pyflyte run`` without ``--remote``), + delegate to the connector mixin so it mimics the backend flow. + + When running on an EMR Serverless worker (dispatched by the + entrypoint), execute the user function directly -- the connector + already submitted the job; the worker just needs to run the code. 
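+
+        Note that "locally" the connector mixin still submits a real job:
+        ``create``/``get``/``delete`` run in-process exactly as the backend
+        would call them, so AWS credentials and a reachable application are
+        still required for local runs.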
+ """ + ctx = FlyteContextManager.current_context() + if ctx.execution_state and ctx.execution_state.is_local_execution(): + return AsyncConnectorExecutorMixin.execute(self, **kwargs) + return PythonFunctionTask.execute(self, **kwargs) + + +TaskPlugins.register_pythontask_plugin(EMRServerless, EMRServerlessTask) diff --git a/plugins/flytekit-aws-emr-serverless/setup.py b/plugins/flytekit-aws-emr-serverless/setup.py new file mode 100644 index 0000000000..8ba9163b94 --- /dev/null +++ b/plugins/flytekit-aws-emr-serverless/setup.py @@ -0,0 +1,49 @@ +from setuptools import setup + +PLUGIN_NAME = "awsemrserverless" + +microlib_name = f"flytekitplugins-{PLUGIN_NAME}" + +plugin_requires = [ + "flytekit>=1.14.6", + "aioboto3>=12.3.0", + "boto3>=1.28.0", +] + +__version__ = "0.0.0+develop" + +setup( + title="AWS EMR Serverless", + title_expanded="Flytekit AWS EMR Serverless Plugin", + name=microlib_name, + version=__version__, + author="flyteorg", + author_email="admin@flyte.org", + description="Flytekit AWS EMR Serverless Plugin: run Spark and Hive jobs from Flyte tasks", + long_description=open("README.md").read(), + long_description_content_type="text/markdown", + namespace_packages=["flytekitplugins"], + packages=[f"flytekitplugins.{PLUGIN_NAME}"], + install_requires=plugin_requires, + license="apache2", + python_requires=">=3.9", + classifiers=[ + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + ], + entry_points={ + "flytekit.plugins": [ + f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}", + ], + }, +) diff --git a/plugins/flytekit-aws-emr-serverless/tests/__init__.py b/plugins/flytekit-aws-emr-serverless/tests/__init__.py new file mode 100644 index 0000000000..199377458c --- /dev/null +++ b/plugins/flytekit-aws-emr-serverless/tests/__init__.py @@ -0,0 +1 @@ +"""Tests for the Flyte EMR Serverless connector.""" diff --git a/plugins/flytekit-aws-emr-serverless/tests/conftest.py b/plugins/flytekit-aws-emr-serverless/tests/conftest.py new file mode 100644 index 0000000000..67142fc5f9 --- /dev/null +++ b/plugins/flytekit-aws-emr-serverless/tests/conftest.py @@ -0,0 +1,157 @@ +""" +Pytest configuration and shared fixtures for EMR Serverless tests. + +The fixtures here follow the flytekit-aws-sagemaker testing pattern: boto3 +traffic is mocked at the single ``EMRServerlessHandler._call`` (and +``_paginate``) chokepoint, not at the boto3 client factory. This keeps +test setup dependency-free -- no ``moto``, no ``botocore.stub`` wiring. +""" + +from unittest.mock import AsyncMock, patch + +import pytest + + +@pytest.fixture +def mock_call(): + """Patch :py:meth:`EMRServerlessHandler._call` for a test. + + Yields an :class:`~unittest.mock.AsyncMock` so tests can set + ``return_value`` or ``side_effect`` and assert on ``call_args``. 
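+
+    Illustrative use (mirrors ``tests/test_boto_handler.py``)::
+
+        @pytest.mark.asyncio
+        async def test_get_job_run(mock_call, mock_job_run_running):
+            mock_call.return_value = mock_job_run_running
+            handler = EMRServerlessHandler()
+            job = await handler.get_job_run("app-123", "job-456")
+            assert job["state"] == "RUNNING"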
+ """ + with patch( + "flytekitplugins.awsemrserverless.boto_handler.EMRServerlessHandler._call", + new_callable=AsyncMock, + ) as m: + yield m + + +@pytest.fixture +def mock_paginate(): + """Patch :py:meth:`EMRServerlessHandler._paginate` for a test.""" + with patch( + "flytekitplugins.awsemrserverless.boto_handler.EMRServerlessHandler._paginate", + new_callable=AsyncMock, + ) as m: + yield m + + +@pytest.fixture +def sample_spark_job_driver(): + from flytekitplugins.awsemrserverless import EMRServerlessSparkJobDriver + + return EMRServerlessSparkJobDriver( + entry_point="s3://my-bucket/scripts/main.py", + entry_point_arguments=[ + "--input", + "s3://data/input", + "--output", + "s3://data/output", + ], + spark_submit_parameters="--conf spark.executor.memory=4g --conf spark.executor.cores=2", + ) + + +@pytest.fixture +def sample_hive_job_driver(): + from flytekitplugins.awsemrserverless import EMRServerlessHiveJobDriver + + return EMRServerlessHiveJobDriver( + query="s3://my-bucket/queries/analytics.sql", + init_query_file="s3://my-bucket/queries/init.sql", + parameters="--hiveconf hive.exec.parallel=true", + ) + + +@pytest.fixture +def sample_config(sample_spark_job_driver): + from flytekitplugins.awsemrserverless import EMRServerless + + return EMRServerless( + application_id="00f5abc123def456", + execution_role_arn="arn:aws:iam::123456789012:role/EMRServerlessRole", + spark_job_driver=sample_spark_job_driver, + region="us-east-1", + tags={"Environment": "test", "Team": "data-engineering"}, + execution_timeout_minutes=120, + ) + + +@pytest.fixture +def sample_task_config(sample_spark_job_driver): + from flytekitplugins.awsemrserverless import EMRServerless + + return EMRServerless( + application_id="00f5abc123def456", + execution_role_arn="arn:aws:iam::123456789012:role/EMRServerlessRole", + spark_job_driver=sample_spark_job_driver, + region="us-east-1", + ) + + +@pytest.fixture +def sample_job_metadata(): + from flytekitplugins.awsemrserverless import EMRServerlessJobMetadata + + return EMRServerlessJobMetadata( + application_id="00f5abc123def456", + job_run_id="00f5xyz789ghi012", + region="us-east-1", + created_application=False, + ) + + +# ---------------------------------------------------------------------- +# Boto response shape fixtures +# +# These match the shape returned by the boto3 client methods (what gets +# funneled through ``_call``). Tests feed them directly as the ``_call`` +# mock's return value. 
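+#
+# e.g. feeding ``mock_application_started`` to ``_call`` makes
+# ``handler.get_application(...)`` return the inner ``application`` dict,
+# so tests can assert on ``result["state"] == "STARTED"`` directly.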
+# ---------------------------------------------------------------------- + + +@pytest.fixture +def mock_job_run_success(): + return { + "jobRun": { + "applicationId": "00f5abc123def456", + "jobRunId": "00f5xyz789ghi012", + "state": "SUCCESS", + "stateDetails": "Job completed successfully", + } + } + + +@pytest.fixture +def mock_job_run_running(): + return { + "jobRun": { + "applicationId": "00f5abc123def456", + "jobRunId": "00f5xyz789ghi012", + "state": "RUNNING", + "stateDetails": "Job is running", + } + } + + +@pytest.fixture +def mock_job_run_failed(): + return { + "jobRun": { + "applicationId": "00f5abc123def456", + "jobRunId": "00f5xyz789ghi012", + "state": "FAILED", + "stateDetails": "Job failed due to OOM", + } + } + + +@pytest.fixture +def mock_application_started(): + return { + "application": { + "applicationId": "00f5abc123def456", + "state": "STARTED", + "name": "test-app", + } + } diff --git a/plugins/flytekit-aws-emr-serverless/tests/test_boto_handler.py b/plugins/flytekit-aws-emr-serverless/tests/test_boto_handler.py new file mode 100644 index 0000000000..cf2c75636e --- /dev/null +++ b/plugins/flytekit-aws-emr-serverless/tests/test_boto_handler.py @@ -0,0 +1,336 @@ +""" +Unit tests for the EMR Serverless boto3 handler. + +Tests the thin typed methods on :class:`EMRServerlessHandler` by patching +the single ``_call`` (and ``_paginate``) chokepoint that funnels all boto3 +traffic. This mirrors the flytekit-aws-sagemaker testing pattern +(``mock.patch("...Boto3ConnectorMixin._call")``). +""" + +from unittest.mock import AsyncMock, patch + +import pytest +from botocore.exceptions import ClientError + +from flytekitplugins.awsemrserverless.boto_handler import EMRServerlessHandler + + +class TestEMRServerlessHandlerInit: + def test_init_default(self): + handler = EMRServerlessHandler() + assert handler.region is None + + def test_init_with_region(self): + handler = EMRServerlessHandler(region="us-west-2") + assert handler.region == "us-west-2" + + +class TestEMRServerlessHandlerGetApplication: + @pytest.mark.asyncio + async def test_get_application_success(self, mock_call, mock_application_started): + mock_call.return_value = mock_application_started + + handler = EMRServerlessHandler() + result = await handler.get_application("00f5abc123def456") + + assert result["applicationId"] == "00f5abc123def456" + assert result["state"] == "STARTED" + mock_call.assert_called_once_with("get_application", applicationId="00f5abc123def456") + + @pytest.mark.asyncio + async def test_get_application_not_found(self, mock_call): + mock_call.side_effect = ClientError( + {"Error": {"Code": "ResourceNotFoundException", "Message": "Not found"}}, + "GetApplication", + ) + + handler = EMRServerlessHandler() + + with pytest.raises(ClientError): + await handler.get_application("nonexistent-app") + + +class TestEMRServerlessHandlerStartJobRun: + @pytest.mark.asyncio + async def test_start_job_run_success(self, mock_call): + mock_call.return_value = { + "applicationId": "app-123", + "jobRunId": "job-456", + } + + handler = EMRServerlessHandler() + job_driver = {"sparkSubmit": {"entryPoint": "s3://bucket/main.py"}} + + job_run_id = await handler.start_job_run( + application_id="app-123", + execution_role_arn="arn:aws:iam::123456789012:role/Role", + job_driver=job_driver, + ) + + assert job_run_id == "job-456" + mock_call.assert_called_once() + assert mock_call.call_args.args == ("start_job_run",) + + @pytest.mark.asyncio + async def test_start_job_run_with_all_options(self, mock_call): + 
mock_call.return_value = { + "applicationId": "app-123", + "jobRunId": "job-456", + } + + handler = EMRServerlessHandler() + job_driver = {"sparkSubmit": {"entryPoint": "s3://bucket/main.py"}} + config_overrides = {"monitoringConfiguration": {"s3MonitoringConfiguration": {"logUri": "s3://logs/"}}} + tags = {"Environment": "test"} + + await handler.start_job_run( + application_id="app-123", + execution_role_arn="arn:aws:iam::123456789012:role/Role", + job_driver=job_driver, + configuration_overrides=config_overrides, + tags=tags, + execution_timeout_minutes=120, + name="test-job", + ) + + call_kwargs = mock_call.call_args.kwargs + assert call_kwargs["configurationOverrides"] == config_overrides + assert call_kwargs["tags"] == tags + assert call_kwargs["executionTimeoutMinutes"] == 120 + assert call_kwargs["name"] == "test-job" + + @pytest.mark.asyncio + async def test_start_job_run_with_retry_policy(self, mock_call): + mock_call.return_value = {"applicationId": "app-123", "jobRunId": "job-456"} + + handler = EMRServerlessHandler() + await handler.start_job_run( + application_id="app-123", + execution_role_arn="arn:aws:iam::123456789012:role/Role", + job_driver={"sparkSubmit": {"entryPoint": "s3://bucket/main.py"}}, + retry_policy={"maxAttempts": 3}, + ) + + call_kwargs = mock_call.call_args.kwargs + assert call_kwargs["retryPolicy"]["maxAttempts"] == 3 + + @pytest.mark.asyncio + async def test_start_job_run_without_retry_policy(self, mock_call): + mock_call.return_value = {"applicationId": "app-123", "jobRunId": "job-456"} + + handler = EMRServerlessHandler() + await handler.start_job_run( + application_id="app-123", + execution_role_arn="arn:aws:iam::123456789012:role/Role", + job_driver={"sparkSubmit": {"entryPoint": "s3://bucket/main.py"}}, + ) + + call_kwargs = mock_call.call_args.kwargs + assert "retryPolicy" not in call_kwargs + + +class TestEMRServerlessHandlerGetJobRun: + @pytest.mark.asyncio + async def test_get_job_run_success(self, mock_call, mock_job_run_running): + mock_call.return_value = mock_job_run_running + + handler = EMRServerlessHandler() + result = await handler.get_job_run("app-123", "job-456") + + assert result["state"] == "RUNNING" + mock_call.assert_called_once_with( + "get_job_run", applicationId="app-123", jobRunId="job-456" + ) + + @pytest.mark.asyncio + async def test_get_job_run_not_found(self, mock_call): + mock_call.side_effect = ClientError( + {"Error": {"Code": "ResourceNotFoundException", "Message": "Not found"}}, + "GetJobRun", + ) + + handler = EMRServerlessHandler() + + with pytest.raises(ClientError): + await handler.get_job_run("app-123", "nonexistent-job") + + +class TestEMRServerlessHandlerCancelJobRun: + @pytest.mark.asyncio + async def test_cancel_job_run_success(self, mock_call, mock_job_run_running): + # First _call -> get_job_run returns running, second -> cancel_job_run returns {} + mock_call.side_effect = [mock_job_run_running, {}] + + handler = EMRServerlessHandler() + await handler.cancel_job_run("app-123", "job-456") + + assert mock_call.call_count == 2 + assert mock_call.call_args_list[0].args == ("get_job_run",) + assert mock_call.call_args_list[1].args == ("cancel_job_run",) + + @pytest.mark.asyncio + async def test_cancel_job_run_already_completed(self, mock_call, mock_job_run_success): + mock_call.return_value = mock_job_run_success + + handler = EMRServerlessHandler() + await handler.cancel_job_run("app-123", "job-456") + + # Only the get_job_run call; cancel_job_run is skipped because it's terminal. 
+ mock_call.assert_called_once() + assert mock_call.call_args.args == ("get_job_run",) + + @pytest.mark.asyncio + async def test_cancel_job_run_not_found_is_swallowed(self, mock_call): + mock_call.side_effect = ClientError( + {"Error": {"Code": "ResourceNotFoundException", "Message": "Not found"}}, + "GetJobRun", + ) + + handler = EMRServerlessHandler() + # Does not raise. + await handler.cancel_job_run("app-123", "nonexistent-job") + + +class TestEMRServerlessHandlerCreateApplication: + @pytest.mark.asyncio + async def test_create_application_with_new_params(self, mock_call): + mock_call.return_value = {"applicationId": "app-new"} + + handler = EMRServerlessHandler() + app_id = await handler.create_application( + name="test-app", + release_label="emr-7.0.0", + application_type="SPARK", + architecture="ARM64", + runtime_configuration=[{"classification": "spark-defaults", "properties": {}}], + scheduler_configuration={"maxConcurrentRuns": 5}, + auto_stop_config={"enabled": True, "idleTimeoutMinutes": 30}, + ) + + assert app_id == "app-new" + call_kwargs = mock_call.call_args.kwargs + assert call_kwargs["architecture"] == "ARM64" + assert call_kwargs["runtimeConfiguration"][0]["classification"] == "spark-defaults" + assert call_kwargs["schedulerConfiguration"]["maxConcurrentRuns"] == 5 + assert call_kwargs["autoStopConfiguration"]["enabled"] is True + + +class TestEMRServerlessHandlerUpdateApplication: + @pytest.mark.asyncio + async def test_update_application_image(self, mock_call): + mock_call.return_value = {} + + handler = EMRServerlessHandler() + await handler.update_application( + application_id="app-123", + image_configuration={"imageUri": "new-image:v2"}, + ) + + mock_call.assert_called_once() + call_kwargs = mock_call.call_args.kwargs + assert call_kwargs["applicationId"] == "app-123" + assert call_kwargs["imageConfiguration"]["imageUri"] == "new-image:v2" + + @pytest.mark.asyncio + async def test_update_application_noop(self, mock_call): + """When no mutable fields are provided, no API call is made.""" + handler = EMRServerlessHandler() + await handler.update_application(application_id="app-123") + + mock_call.assert_not_called() + + @pytest.mark.asyncio + async def test_update_application_multiple_fields(self, mock_call): + mock_call.return_value = {} + + handler = EMRServerlessHandler() + await handler.update_application( + application_id="app-123", + image_configuration={"imageUri": "image:v3"}, + maximum_capacity={"cpu": "200vCPU"}, + scheduler_configuration={"maxConcurrentRuns": 10}, + ) + + call_kwargs = mock_call.call_args.kwargs + assert call_kwargs["imageConfiguration"]["imageUri"] == "image:v3" + assert call_kwargs["maximumCapacity"]["cpu"] == "200vCPU" + assert call_kwargs["schedulerConfiguration"]["maxConcurrentRuns"] == 10 + + +class TestEMRServerlessHandlerFindApplicationByName: + @pytest.mark.asyncio + async def test_find_application_returns_id_when_present(self, mock_paginate): + mock_paginate.return_value = [ + {"id": "00f-other", "name": "other-app", "state": "STARTED"}, + {"id": "00f-target", "name": "my-app", "state": "STARTED"}, + ] + + handler = EMRServerlessHandler() + result = await handler.find_application_by_name("my-app") + + assert result == "00f-target" + mock_paginate.assert_called_once_with( + "list_applications", + result_key="applications", + states=["CREATED", "STARTED", "STOPPED"], + ) + + @pytest.mark.asyncio + async def test_find_application_returns_none_when_absent(self, mock_paginate): + mock_paginate.return_value = [ + {"id": "00f-other", 
"name": "other-app", "state": "STARTED"}, + ] + + handler = EMRServerlessHandler() + result = await handler.find_application_by_name("nonexistent") + + assert result is None + + @pytest.mark.asyncio + async def test_find_application_empty_list(self, mock_paginate): + mock_paginate.return_value = [] + + handler = EMRServerlessHandler() + result = await handler.find_application_by_name("anything") + + assert result is None + + +class TestEMRServerlessHandlerEnsureApplicationStarted: + @pytest.mark.asyncio + async def test_returns_immediately_when_started(self, mock_call): + mock_call.return_value = {"application": {"applicationId": "app-1", "state": "STARTED"}} + + handler = EMRServerlessHandler() + await handler.ensure_application_started("app-1") + + # Only GetApplication, no StartApplication. + mock_call.assert_called_once() + assert mock_call.call_args.args == ("get_application",) + + @pytest.mark.asyncio + async def test_starts_a_stopped_application(self, mock_call): + # First GetApplication -> STOPPED, StartApplication -> {}, then GetApplication -> STARTED. + mock_call.side_effect = [ + {"application": {"applicationId": "app-1", "state": "STOPPED"}}, + {}, + {"application": {"applicationId": "app-1", "state": "STARTED"}}, + ] + + handler = EMRServerlessHandler() + # Use a very small poll interval to keep the test fast. + with patch("flytekitplugins.awsemrserverless.boto_handler.asyncio.sleep", new_callable=AsyncMock): + await handler.ensure_application_started("app-1", poll_interval_seconds=0) + + methods_called = [c.args[0] for c in mock_call.call_args_list] + assert methods_called[0] == "get_application" + assert methods_called[1] == "start_application" + assert methods_called[2] == "get_application" + + @pytest.mark.asyncio + async def test_raises_on_terminal_state(self, mock_call): + mock_call.return_value = {"application": {"applicationId": "app-1", "state": "TERMINATED"}} + + handler = EMRServerlessHandler() + with pytest.raises(RuntimeError, match="terminal state"): + await handler.ensure_application_started("app-1") diff --git a/plugins/flytekit-aws-emr-serverless/tests/test_connector.py b/plugins/flytekit-aws-emr-serverless/tests/test_connector.py new file mode 100644 index 0000000000..fb0ba4d13e --- /dev/null +++ b/plugins/flytekit-aws-emr-serverless/tests/test_connector.py @@ -0,0 +1,1266 @@ +""" +Unit tests for EMR Serverless Connector. 
+""" + +import os +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from flyteidl.core.execution_pb2 import TaskExecution + +from flytekitplugins.awsemrserverless import ( + EMRServerlessConnector, + EMRServerlessJobMetadata, +) +from flytekit.extend.backend.utils import convert_to_flyte_phase + +from flytekitplugins.awsemrserverless.connector import ( + EMR_SERVERLESS_STATES, + _ENV_ALLOW_CREATE_APPLICATION, + _ENV_APPLICATION_NAME_PREFIX, +) + + +def _make_handler(**overrides) -> AsyncMock: + """Build a mock EMRServerlessHandler with sensible defaults.""" + handler = AsyncMock() + handler.ensure_application_started = AsyncMock() + handler.start_job_run = AsyncMock(return_value="job-123") + handler.create_application = AsyncMock(return_value="new-app-123") + handler.get_application = AsyncMock(return_value={ + "applicationId": "00f5abc123def456", + "state": "STARTED", + "imageConfiguration": {}, + }) + handler.update_application = AsyncMock() + handler.client.meta.region_name = "us-east-1" + handler.region = "us-east-1" + for k, v in overrides.items(): + setattr(handler, k, v) + return handler + + +class TestEMRServerlessJobMetadata: + def test_metadata_creation(self): + metadata = EMRServerlessJobMetadata( + application_id="app-123", + job_run_id="job-456", + region="us-east-1", + ) + assert metadata.application_id == "app-123" + assert metadata.job_run_id == "job-456" + assert metadata.region == "us-east-1" + assert metadata.created_application is False + + def test_metadata_with_created_app(self): + metadata = EMRServerlessJobMetadata( + application_id="app-123", + job_run_id="job-456", + region="us-west-2", + created_application=True, + ) + assert metadata.created_application is True + + +class TestEMRServerlessConnector: + def test_connector_initialization(self): + connector = EMRServerlessConnector() + assert connector.name == "EMR Serverless Connector" + + @pytest.mark.asyncio + async def test_create_with_existing_application(self, sample_config): + connector = EMRServerlessConnector() + mock_handler = _make_handler() + + mock_template = MagicMock() + mock_template.custom = sample_config.to_dict() + mock_template.id = MagicMock() + mock_template.id.name = "test-task" + + with patch.object(connector, "_get_handler", return_value=mock_handler): + with patch.object(connector, "_extract_config", return_value=sample_config): + metadata = await connector.create(mock_template) + + assert isinstance(metadata, EMRServerlessJobMetadata) + assert metadata.application_id == sample_config.application_id + assert metadata.job_run_id == "job-123" + mock_handler.ensure_application_started.assert_called_once() + mock_handler.start_job_run.assert_called_once() + + @pytest.mark.asyncio + async def test_create_with_new_application(self, sample_spark_job_driver): + """application_name + env var true => create succeeds.""" + from flytekitplugins.awsemrserverless import EMRServerless + + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + spark_job_driver=sample_spark_job_driver, + application_name="test-app", + region="us-east-1", + ) + + connector = EMRServerlessConnector() + mock_handler = _make_handler() + + mock_template = MagicMock() + mock_template.custom = config.to_dict() + mock_template.id = MagicMock() + mock_template.id.name = "test-task" + + with patch.dict(os.environ, {_ENV_ALLOW_CREATE_APPLICATION: "true"}): + with patch.object(connector, "_get_handler", return_value=mock_handler): + with patch.object(connector, "_extract_config", 
return_value=config): + metadata = await connector.create(mock_template) + + assert metadata.application_id == "new-app-123" + assert metadata.created_application is True + mock_handler.create_application.assert_called_once() + assert mock_handler.create_application.call_args.kwargs["name"] == "test-app" + + @pytest.mark.asyncio + async def test_create_fails_no_app_id_no_app_name(self, sample_spark_job_driver): + """Without application_id or application_name, create must fail.""" + from flytekitplugins.awsemrserverless import EMRServerless + + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + spark_job_driver=sample_spark_job_driver, + ) + + connector = EMRServerlessConnector() + mock_handler = _make_handler() + + mock_template = MagicMock() + mock_template.custom = config.to_dict() + mock_template.id = MagicMock() + mock_template.id.name = "test-task" + + with patch.dict(os.environ, {_ENV_ALLOW_CREATE_APPLICATION: "true"}): + with patch.object(connector, "_get_handler", return_value=mock_handler): + with patch.object(connector, "_extract_config", return_value=config): + with pytest.raises(ValueError, match="No application_id provided and no application_name"): + await connector.create(mock_template) + + @pytest.mark.asyncio + async def test_create_blocked_by_connector_env_var(self, sample_spark_job_driver): + """When connector env var is false (default), creation is blocked even with application_name.""" + from flytekitplugins.awsemrserverless import EMRServerless + + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + spark_job_driver=sample_spark_job_driver, + application_name="my-app", + ) + + connector = EMRServerlessConnector() + mock_handler = _make_handler() + + mock_template = MagicMock() + mock_template.custom = config.to_dict() + mock_template.id = MagicMock() + mock_template.id.name = "test-task" + + with patch.dict(os.environ, {_ENV_ALLOW_CREATE_APPLICATION: "false"}): + with patch.object(connector, "_get_handler", return_value=mock_handler): + with patch.object(connector, "_extract_config", return_value=config): + with pytest.raises(ValueError, match="disabled by the connector"): + await connector.create(mock_template) + + @pytest.mark.asyncio + async def test_create_blocked_when_env_var_unset(self, sample_spark_job_driver): + """When the env var is not set, default is false => creation blocked.""" + from flytekitplugins.awsemrserverless import EMRServerless + + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + spark_job_driver=sample_spark_job_driver, + application_name="my-app", + ) + + connector = EMRServerlessConnector() + mock_handler = _make_handler() + + mock_template = MagicMock() + mock_template.custom = config.to_dict() + mock_template.id = MagicMock() + mock_template.id.name = "test-task" + + env = os.environ.copy() + env.pop(_ENV_ALLOW_CREATE_APPLICATION, None) + with patch.dict(os.environ, env, clear=True): + with patch.object(connector, "_get_handler", return_value=mock_handler): + with patch.object(connector, "_extract_config", return_value=config): + with pytest.raises(ValueError, match="disabled by the connector"): + await connector.create(mock_template) + + @pytest.mark.asyncio + async def test_create_applies_application_name_prefix(self, sample_spark_job_driver): + """FLYTE_EMR_APPLICATION_NAME_PREFIX should be prepended to the app name.""" + from flytekitplugins.awsemrserverless import EMRServerless + + config = EMRServerless( + 
execution_role_arn="arn:aws:iam::123456789012:role/Role", + spark_job_driver=sample_spark_job_driver, + application_name="my-etl-app", + region="us-east-1", + ) + + connector = EMRServerlessConnector() + mock_handler = _make_handler() + + mock_template = MagicMock() + mock_template.custom = config.to_dict() + mock_template.id = MagicMock() + mock_template.id.name = "test-task" + + with patch.dict(os.environ, { + _ENV_ALLOW_CREATE_APPLICATION: "true", + _ENV_APPLICATION_NAME_PREFIX: "flyte-prod-", + }): + with patch.object(connector, "_get_handler", return_value=mock_handler): + with patch.object(connector, "_extract_config", return_value=config): + await connector.create(mock_template) + + call_kwargs = mock_handler.create_application.call_args.kwargs + assert call_kwargs["name"] == "flyte-prod-my-etl-app" + + @pytest.mark.asyncio + async def test_create_no_double_prefix(self, sample_spark_job_driver): + """If application_name already starts with the prefix, don't double-prefix.""" + from flytekitplugins.awsemrserverless import EMRServerless + + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + spark_job_driver=sample_spark_job_driver, + application_name="flyte-prod-my-app", + region="us-east-1", + ) + + connector = EMRServerlessConnector() + mock_handler = _make_handler() + + mock_template = MagicMock() + mock_template.custom = config.to_dict() + mock_template.id = MagicMock() + mock_template.id.name = "test-task" + + with patch.dict(os.environ, { + _ENV_ALLOW_CREATE_APPLICATION: "true", + _ENV_APPLICATION_NAME_PREFIX: "flyte-prod-", + }): + with patch.object(connector, "_get_handler", return_value=mock_handler): + with patch.object(connector, "_extract_config", return_value=config): + await connector.create(mock_template) + + call_kwargs = mock_handler.create_application.call_args.kwargs + assert call_kwargs["name"] == "flyte-prod-my-app" + + @pytest.mark.asyncio + async def test_create_no_prefix_when_unset(self, sample_spark_job_driver): + """Without the prefix env var, the application name is used as-is.""" + from flytekitplugins.awsemrserverless import EMRServerless + + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + spark_job_driver=sample_spark_job_driver, + application_name="my-raw-name", + region="us-east-1", + ) + + connector = EMRServerlessConnector() + mock_handler = _make_handler() + + mock_template = MagicMock() + mock_template.custom = config.to_dict() + mock_template.id = MagicMock() + mock_template.id.name = "test-task" + + env = os.environ.copy() + env.pop(_ENV_APPLICATION_NAME_PREFIX, None) + env[_ENV_ALLOW_CREATE_APPLICATION] = "true" + with patch.dict(os.environ, env, clear=True): + with patch.object(connector, "_get_handler", return_value=mock_handler): + with patch.object(connector, "_extract_config", return_value=config): + await connector.create(mock_template) + + call_kwargs = mock_handler.create_application.call_args.kwargs + assert call_kwargs["name"] == "my-raw-name" + + @pytest.mark.asyncio + async def test_create_passes_architecture_and_scheduler(self, sample_spark_job_driver): + """Verify new app-creation params are forwarded to the handler.""" + from flytekitplugins.awsemrserverless import EMRServerless + + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + spark_job_driver=sample_spark_job_driver, + application_name="test-app", + architecture="ARM64", + scheduler_configuration={"maxConcurrentRuns": 5}, + auto_stop_config={"enabled": True, 
"idleTimeoutMinutes": 30}, + runtime_configuration=[{"classification": "spark-defaults", "properties": {"spark.executor.cores": "8"}}], + region="us-east-1", + ) + + connector = EMRServerlessConnector() + mock_handler = _make_handler() + + mock_template = MagicMock() + mock_template.custom = config.to_dict() + mock_template.id = MagicMock() + mock_template.id.name = "test-task" + + with patch.dict(os.environ, {_ENV_ALLOW_CREATE_APPLICATION: "true"}): + with patch.object(connector, "_get_handler", return_value=mock_handler): + with patch.object(connector, "_extract_config", return_value=config): + await connector.create(mock_template) + + call_kwargs = mock_handler.create_application.call_args.kwargs + assert call_kwargs["architecture"] == "ARM64" + assert call_kwargs["scheduler_configuration"] == {"maxConcurrentRuns": 5} + assert call_kwargs["auto_stop_config"] == {"enabled": True, "idleTimeoutMinutes": 30} + assert call_kwargs["runtime_configuration"][0]["classification"] == "spark-defaults" + + @pytest.mark.asyncio + async def test_create_with_retry_policy(self, sample_config): + """Verify retry_policy is forwarded to start_job_run.""" + from dataclasses import replace + + config = replace(sample_config, retry_policy={"maxAttempts": 3}) + + connector = EMRServerlessConnector() + mock_handler = _make_handler() + + mock_template = MagicMock() + mock_template.custom = config.to_dict() + mock_template.id = MagicMock() + mock_template.id.name = "test-task" + + with patch.object(connector, "_get_handler", return_value=mock_handler): + with patch.object(connector, "_extract_config", return_value=config): + await connector.create(mock_template) + + call_kwargs = mock_handler.start_job_run.call_args.kwargs + assert call_kwargs["retry_policy"] == {"maxAttempts": 3} + + @pytest.mark.asyncio + async def test_create_syncs_image_from_container_image(self): + """When sync_image=True and container_image differs from app, update_application is called. + + Script mode skips sync unless image_configuration is explicitly set, + so this test uses a script-mode config with image_configuration to + exercise the sync path. 
+ """ + from flytekitplugins.awsemrserverless import ( + EMRServerless, + EMRServerlessSparkJobDriver, + ) + + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + application_id="00abc123", + spark_job_driver=EMRServerlessSparkJobDriver(entry_point="s3://bucket/main.py"), + image_configuration={"imageUri": "123456.dkr.ecr.us-east-1.amazonaws.com/spark:v2"}, + sync_image=True, + region="us-east-1", + ) + + connector = EMRServerlessConnector() + mock_handler = _make_handler() + mock_handler.get_application = AsyncMock(return_value={ + "applicationId": "00abc123", + "state": "STARTED", + "imageConfiguration": {"imageUri": "123456.dkr.ecr.us-east-1.amazonaws.com/spark:v1"}, + }) + + mock_template = MagicMock() + mock_template.custom = config.to_dict() + mock_template.id = MagicMock() + mock_template.id.name = "test-task" + mock_template.container = MagicMock() + mock_template.container.image = "123456.dkr.ecr.us-east-1.amazonaws.com/spark:v2" + + with patch.object(connector, "_get_handler", return_value=mock_handler): + with patch.object(connector, "_extract_config", return_value=config): + await connector.create(mock_template) + + mock_handler.update_application.assert_called_once_with( + application_id="00abc123", + image_configuration={"imageUri": "123456.dkr.ecr.us-east-1.amazonaws.com/spark:v2"}, + ) + + @pytest.mark.asyncio + async def test_create_syncs_image_falls_back_to_image_configuration(self): + """When container_image is absent, sync_image falls back to image_configuration.""" + from flytekitplugins.awsemrserverless import ( + EMRServerless, + EMRServerlessSparkJobDriver, + ) + + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + application_id="00abc123", + spark_job_driver=EMRServerlessSparkJobDriver(entry_point="s3://bucket/main.py"), + image_configuration={"imageUri": "new-image:v2"}, + sync_image=True, + region="us-east-1", + ) + + connector = EMRServerlessConnector() + mock_handler = _make_handler() + mock_handler.get_application = AsyncMock(return_value={ + "applicationId": "00abc123", + "state": "STARTED", + "imageConfiguration": {"imageUri": "old-image:v1"}, + }) + + mock_template = MagicMock() + mock_template.custom = config.to_dict() + mock_template.id = MagicMock() + mock_template.id.name = "test-task" + mock_template.container = None + + with patch.object(connector, "_get_handler", return_value=mock_handler): + with patch.object(connector, "_extract_config", return_value=config): + await connector.create(mock_template) + + mock_handler.update_application.assert_called_once_with( + application_id="00abc123", + image_configuration={"imageUri": "new-image:v2"}, + ) + + @pytest.mark.asyncio + async def test_create_skips_image_sync_when_same(self): + """When container_image matches the application, update is skipped.""" + from flytekitplugins.awsemrserverless import ( + EMRServerless, + EMRServerlessSparkJobDriver, + ) + + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + application_id="00abc123", + spark_job_driver=EMRServerlessSparkJobDriver(entry_point="s3://bucket/main.py"), + sync_image=True, + region="us-east-1", + ) + + connector = EMRServerlessConnector() + mock_handler = _make_handler() + mock_handler.get_application = AsyncMock(return_value={ + "applicationId": "00abc123", + "state": "STARTED", + "imageConfiguration": {"imageUri": "same-image:v1"}, + }) + + mock_template = MagicMock() + mock_template.custom = config.to_dict() + mock_template.id = MagicMock() 
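+        # container.image (set just below) matches the app's current
+        # imageConfiguration, so no update_application call is expected.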
+ mock_template.id.name = "test-task" + mock_template.container = MagicMock() + mock_template.container.image = "same-image:v1" + + with patch.object(connector, "_get_handler", return_value=mock_handler): + with patch.object(connector, "_extract_config", return_value=config): + await connector.create(mock_template) + + mock_handler.update_application.assert_not_called() + + @pytest.mark.asyncio + async def test_create_skips_image_sync_when_disabled(self): + """When sync_image=False, update_application should not be called.""" + from flytekitplugins.awsemrserverless import ( + EMRServerless, + EMRServerlessSparkJobDriver, + ) + + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + application_id="00abc123", + spark_job_driver=EMRServerlessSparkJobDriver(entry_point="s3://bucket/main.py"), + sync_image=False, + region="us-east-1", + ) + + connector = EMRServerlessConnector() + mock_handler = _make_handler() + + mock_template = MagicMock() + mock_template.custom = config.to_dict() + mock_template.id = MagicMock() + mock_template.id.name = "test-task" + mock_template.container = MagicMock() + mock_template.container.image = "new-image:v2" + + with patch.object(connector, "_get_handler", return_value=mock_handler): + with patch.object(connector, "_extract_config", return_value=config): + await connector.create(mock_template) + + mock_handler.update_application.assert_not_called() + mock_handler.get_application.assert_not_called() + + @pytest.mark.asyncio + async def test_create_image_sync_noop_when_no_image(self): + """When there's no container_image or image_configuration, sync is a no-op.""" + from flytekitplugins.awsemrserverless import ( + EMRServerless, + EMRServerlessSparkJobDriver, + ) + + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + application_id="00abc123", + spark_job_driver=EMRServerlessSparkJobDriver(entry_point="s3://bucket/main.py"), + sync_image=True, + region="us-east-1", + ) + + connector = EMRServerlessConnector() + mock_handler = _make_handler() + + mock_template = MagicMock() + mock_template.custom = config.to_dict() + mock_template.id = MagicMock() + mock_template.id.name = "test-task" + mock_template.container = None + + with patch.object(connector, "_get_handler", return_value=mock_handler): + with patch.object(connector, "_extract_config", return_value=config): + await connector.create(mock_template) + + mock_handler.update_application.assert_not_called() + mock_handler.get_application.assert_not_called() + + @pytest.mark.asyncio + async def test_create_merges_spark_submit_params_in_script_mode(self): + """Top-level spark_submit_parameters should be appended to the job driver.""" + from flytekitplugins.awsemrserverless import ( + EMRServerless, + EMRServerlessSparkJobDriver, + ) + + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + application_id="00abc123", + spark_job_driver=EMRServerlessSparkJobDriver( + entry_point="s3://bucket/main.py", + spark_submit_parameters="--conf spark.executor.memory=4g", + ), + spark_submit_parameters="--conf spark.driver.memory=8g", + sync_image=False, + region="us-east-1", + ) + + connector = EMRServerlessConnector() + mock_handler = _make_handler() + + mock_template = MagicMock() + mock_template.custom = config.to_dict() + mock_template.id = MagicMock() + mock_template.id.name = "test-task" + + with patch.object(connector, "_get_handler", return_value=mock_handler): + with patch.object(connector, "_extract_config", 
return_value=config): + await connector.create(mock_template) + + call_kwargs = mock_handler.start_job_run.call_args.kwargs + submit_params = call_kwargs["job_driver"]["sparkSubmit"]["sparkSubmitParameters"] + assert "spark.executor.memory=4g" in submit_params + assert "spark.driver.memory=8g" in submit_params + + @pytest.mark.asyncio + async def test_create_merges_application_configuration(self): + """application_configuration should be merged into configuration_overrides.""" + from flytekitplugins.awsemrserverless import ( + EMRServerless, + EMRServerlessSparkJobDriver, + ) + + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + application_id="00abc123", + spark_job_driver=EMRServerlessSparkJobDriver(entry_point="s3://bucket/main.py"), + configuration_overrides={ + "monitoringConfiguration": {"s3MonitoringConfiguration": {"logUri": "s3://logs/"}} + }, + application_configuration=[ + {"classification": "spark-defaults", "properties": {"spark.executor.cores": "8"}} + ], + sync_image=False, + region="us-east-1", + ) + + connector = EMRServerlessConnector() + mock_handler = _make_handler() + + mock_template = MagicMock() + mock_template.custom = config.to_dict() + mock_template.id = MagicMock() + mock_template.id.name = "test-task" + mock_template.container = None + + with patch.object(connector, "_get_handler", return_value=mock_handler): + with patch.object(connector, "_extract_config", return_value=config): + await connector.create(mock_template) + + call_kwargs = mock_handler.start_job_run.call_args.kwargs + config_overrides = call_kwargs["configuration_overrides"] + assert "monitoringConfiguration" in config_overrides + assert len(config_overrides["applicationConfiguration"]) == 1 + assert config_overrides["applicationConfiguration"][0]["classification"] == "spark-defaults" + + @pytest.mark.asyncio + async def test_create_pythonic_mode(self): + """Test create with Pythonic mode (no explicit job driver).""" + from flytekitplugins.awsemrserverless import EMRServerless + + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + application_id="00abc123", + sync_image=False, + region="us-east-1", + ) + + connector = EMRServerlessConnector() + mock_handler = _make_handler() + + mock_container = MagicMock() + mock_container.image = "my-image:latest" + mock_container.args = [ + "pyflyte-execute", + "--task-module", + "my.module", + "--task-name", + "my_task", + ] + + mock_template = MagicMock() + mock_template.custom = config.to_dict() + mock_template.id = MagicMock() + mock_template.id.name = "test-task" + mock_template.container = mock_container + + with patch.object(connector, "_get_handler", return_value=mock_handler): + with patch.object(connector, "_extract_config", return_value=config): + with patch.object( + connector, "_ensure_entrypoint_on_s3", return_value="s3://bucket/flyte/entrypoint.py" + ): + metadata = await connector.create(mock_template) + + assert metadata.job_run_id == "job-123" + call_kwargs = mock_handler.start_job_run.call_args.kwargs + assert "sparkSubmit" in call_kwargs["job_driver"] + assert call_kwargs["job_driver"]["sparkSubmit"]["entryPoint"] == "s3://bucket/flyte/entrypoint.py" + assert "pyflyte-execute" in call_kwargs["job_driver"]["sparkSubmit"]["entryPointArguments"] + + @pytest.mark.asyncio + async def test_create_pythonic_mode_with_spark_submit_params(self): + """spark_submit_parameters should be merged in Pythonic mode.""" + from flytekitplugins.awsemrserverless import EMRServerless + + config 
= EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + application_id="00abc123", + spark_submit_parameters="--conf spark.executor.memory=16g --conf spark.executor.cores=8", + sync_image=False, + region="us-east-1", + ) + + connector = EMRServerlessConnector() + mock_handler = _make_handler() + + mock_container = MagicMock() + mock_container.image = "my-image:latest" + mock_container.args = ["pyflyte-execute", "--task-module", "my.module"] + + mock_template = MagicMock() + mock_template.custom = config.to_dict() + mock_template.id = MagicMock() + mock_template.id.name = "test-task" + mock_template.container = mock_container + + with patch.object(connector, "_get_handler", return_value=mock_handler): + with patch.object(connector, "_extract_config", return_value=config): + with patch.object( + connector, "_ensure_entrypoint_on_s3", return_value="s3://bucket/flyte/entrypoint.py" + ): + await connector.create(mock_template) + + call_kwargs = mock_handler.start_job_run.call_args.kwargs + submit_params = call_kwargs["job_driver"]["sparkSubmit"]["sparkSubmitParameters"] + assert "spark.executor.memory=16g" in submit_params + assert "spark.executor.cores=8" in submit_params + assert "PYSPARK_PYTHON" in submit_params + + @pytest.mark.asyncio + async def test_create_new_app_gets_default_image_in_pythonic_mode(self): + """When creating an app in Pythonic mode with no image, a default is used.""" + from flytekitplugins.awsemrserverless import EMRServerless + from flytekitplugins.awsemrserverless.task import EMR_SERVERLESS_BASE_IMAGE + + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + application_name="pythonic-app", + region="us-east-1", + ) + + connector = EMRServerlessConnector() + mock_handler = _make_handler() + + mock_template = MagicMock() + mock_template.custom = config.to_dict() + mock_template.id = MagicMock() + mock_template.id.name = "test-task" + mock_template.container = MagicMock() + mock_template.container.image = None + mock_template.container.args = ["pyflyte-execute"] + + with patch.dict(os.environ, {_ENV_ALLOW_CREATE_APPLICATION: "true"}): + with patch.object(connector, "_get_handler", return_value=mock_handler): + with patch.object(connector, "_extract_config", return_value=config): + with patch.object( + connector, "_ensure_entrypoint_on_s3", return_value="s3://bucket/entrypoint.py" + ): + await connector.create(mock_template) + + call_kwargs = mock_handler.create_application.call_args.kwargs + assert call_kwargs["image_configuration"]["imageUri"] == EMR_SERVERLESS_BASE_IMAGE + + @pytest.mark.asyncio + async def test_get_running_job(self, sample_job_metadata): + connector = EMRServerlessConnector() + + mock_handler = AsyncMock() + mock_handler.get_job_run = AsyncMock( + return_value={ + "state": "RUNNING", + "stateDetails": "Job is executing", + } + ) + + with patch.object(connector, "_get_handler", return_value=mock_handler): + resource = await connector.get(sample_job_metadata) + + assert resource.phase == TaskExecution.RUNNING + assert "RUNNING" in resource.message + + @pytest.mark.asyncio + async def test_get_successful_job(self, sample_job_metadata): + connector = EMRServerlessConnector() + + mock_handler = AsyncMock() + mock_handler.get_job_run = AsyncMock( + return_value={ + "state": "SUCCESS", + "stateDetails": "Job completed successfully", + } + ) + + with patch.object(connector, "_get_handler", return_value=mock_handler): + resource = await connector.get(sample_job_metadata) + + assert resource.phase == 
TaskExecution.SUCCEEDED + + @pytest.mark.asyncio + async def test_get_failed_job(self, sample_job_metadata): + connector = EMRServerlessConnector() + + mock_handler = AsyncMock() + mock_handler.get_job_run = AsyncMock( + return_value={ + "state": "FAILED", + "stateDetails": "OutOfMemoryError", + } + ) + + with patch.object(connector, "_get_handler", return_value=mock_handler): + resource = await connector.get(sample_job_metadata) + + assert resource.phase == TaskExecution.FAILED + assert "OutOfMemoryError" in resource.message + + @pytest.mark.asyncio + async def test_get_cancelled_job(self, sample_job_metadata): + connector = EMRServerlessConnector() + + mock_handler = AsyncMock() + mock_handler.get_job_run = AsyncMock( + return_value={ + "state": "CANCELLED", + "stateDetails": "User cancelled", + } + ) + + with patch.object(connector, "_get_handler", return_value=mock_handler): + resource = await connector.get(sample_job_metadata) + + assert resource.phase == TaskExecution.FAILED + + @pytest.mark.asyncio + async def test_get_job_not_found(self, sample_job_metadata): + connector = EMRServerlessConnector() + + mock_handler = AsyncMock() + mock_handler.get_job_run = AsyncMock(side_effect=RuntimeError("Job not found")) + + with patch.object(connector, "_get_handler", return_value=mock_handler): + resource = await connector.get(sample_job_metadata) + + assert resource.phase == TaskExecution.FAILED + assert "not found" in resource.message.lower() + + @pytest.mark.asyncio + async def test_get_pending_states(self, sample_job_metadata): + connector = EMRServerlessConnector() + + for state in ["PENDING", "SCHEDULED", "SUBMITTED"]: + mock_handler = AsyncMock() + mock_handler.get_job_run = AsyncMock(return_value={"state": state, "stateDetails": ""}) + + with patch.object(connector, "_get_handler", return_value=mock_handler): + resource = await connector.get(sample_job_metadata) + + assert resource.phase == TaskExecution.RUNNING, f"State {state} should map to RUNNING" + + @pytest.mark.asyncio + async def test_delete_running_job(self, sample_job_metadata): + connector = EMRServerlessConnector() + + mock_handler = AsyncMock() + mock_handler.cancel_job_run = AsyncMock() + + with patch.object(connector, "_get_handler", return_value=mock_handler): + await connector.delete(sample_job_metadata) + + mock_handler.cancel_job_run.assert_called_once_with( + application_id=sample_job_metadata.application_id, + job_run_id=sample_job_metadata.job_run_id, + ) + + @pytest.mark.asyncio + async def test_delete_handles_errors_gracefully(self, sample_job_metadata): + connector = EMRServerlessConnector() + + mock_handler = AsyncMock() + mock_handler.cancel_job_run = AsyncMock(side_effect=Exception("Network error")) + + with patch.object(connector, "_get_handler", return_value=mock_handler): + await connector.delete(sample_job_metadata) + + +class TestEMRServerlessStateMapping: + def test_all_states_mapped(self): + expected_states = [ + "PENDING", + "SCHEDULED", + "SUBMITTED", + "RUNNING", + "SUCCESS", + "FAILED", + "CANCELLING", + "CANCELLED", + ] + for state in expected_states: + assert state in EMR_SERVERLESS_STATES, f"Missing mapping for {state}" + + def test_running_states_map_to_running(self): + running_states = ["PENDING", "SCHEDULED", "SUBMITTED", "RUNNING", "CANCELLING"] + for state in running_states: + assert EMR_SERVERLESS_STATES[state] == "Running" + assert convert_to_flyte_phase(EMR_SERVERLESS_STATES[state]) == TaskExecution.RUNNING + + def test_success_maps_to_succeeded(self): + assert 
EMR_SERVERLESS_STATES["SUCCESS"] == "Success" + assert convert_to_flyte_phase(EMR_SERVERLESS_STATES["SUCCESS"]) == TaskExecution.SUCCEEDED + + def test_failure_states_map_to_failed(self): + for state in ["FAILED", "CANCELLED"]: + assert EMR_SERVERLESS_STATES[state] == "Failed" + assert convert_to_flyte_phase(EMR_SERVERLESS_STATES[state]) == TaskExecution.FAILED + + +def _make_task_execution_metadata( + *, + exec_name: str = "abc123def456", + exec_project: str = "flytetest", + exec_domain: str = "development", + task_project: str = "flytetest", + task_domain: str = "development", + task_name: str = "examples.script_spark_job.spark_demo_task", + task_version: str = "v1", + node_id: str = "n0", + retry_attempt: int = 0, + environment_variables: dict = None, +) -> MagicMock: + """Build a mock TaskExecutionMetadata with the nested id graph populated.""" + meta = MagicMock() + meta.task_execution_id.task_id.project = task_project + meta.task_execution_id.task_id.domain = task_domain + meta.task_execution_id.task_id.name = task_name + meta.task_execution_id.task_id.version = task_version + meta.task_execution_id.node_execution_id.node_id = node_id + meta.task_execution_id.node_execution_id.execution_id.name = exec_name + meta.task_execution_id.node_execution_id.execution_id.project = exec_project + meta.task_execution_id.node_execution_id.execution_id.domain = exec_domain + meta.task_execution_id.retry_attempt = retry_attempt + meta.environment_variables = environment_variables or {} + return meta + + +class TestFlyteEnvInjection: + """Tests for the Flyte context env var injection feature.""" + + def test_build_flyte_env_vars_none_metadata(self): + assert EMRServerlessConnector._build_flyte_env_vars(None) == {} + + def test_build_flyte_env_vars_full(self): + meta = _make_task_execution_metadata() + env = EMRServerlessConnector._build_flyte_env_vars(meta) + assert env["FLYTE_INTERNAL_EXECUTION_ID"] == "abc123def456" + assert env["FLYTE_INTERNAL_EXECUTION_PROJECT"] == "flytetest" + assert env["FLYTE_INTERNAL_EXECUTION_DOMAIN"] == "development" + assert env["FLYTE_INTERNAL_TASK_PROJECT"] == "flytetest" + assert env["FLYTE_INTERNAL_TASK_DOMAIN"] == "development" + assert env["FLYTE_INTERNAL_TASK_NAME"] == "examples.script_spark_job.spark_demo_task" + assert env["FLYTE_INTERNAL_TASK_VERSION"] == "v1" + assert env["FLYTE_INTERNAL_NODE_ID"] == "n0" + assert env["FLYTE_INTERNAL_TASK_RETRY_ATTEMPT"] == "0" + + def test_build_flyte_env_vars_includes_user_env(self): + meta = _make_task_execution_metadata(environment_variables={"MY_VAR": "hello"}) + env = EMRServerlessConnector._build_flyte_env_vars(meta) + assert env["MY_VAR"] == "hello" + + def test_build_flyte_env_vars_drops_empty_values(self): + meta = _make_task_execution_metadata(task_version="") + env = EMRServerlessConnector._build_flyte_env_vars(meta) + assert "FLYTE_INTERNAL_TASK_VERSION" not in env + assert "FLYTE_INTERNAL_EXECUTION_ID" in env + + def test_format_env_as_spark_conf_empty(self): + assert EMRServerlessConnector._format_env_as_spark_conf({}) == "" + + def test_format_env_as_spark_conf_produces_driver_and_executor(self): + conf = EMRServerlessConnector._format_env_as_spark_conf( + {"FLYTE_INTERNAL_EXECUTION_ID": "abc123"} + ) + assert "spark.emr-serverless.driverEnv.FLYTE_INTERNAL_EXECUTION_ID=abc123" in conf + assert "spark.executorEnv.FLYTE_INTERNAL_EXECUTION_ID=abc123" in conf + + def test_format_env_as_spark_conf_skips_unsafe_values(self): + conf = EMRServerlessConnector._format_env_as_spark_conf( + {"SAFE": "ok", "BAD_SPACES": 
"has spaces", "BAD_QUOTE": 'has"quote'} + ) + assert "SAFE=ok" in conf + assert "BAD_SPACES" not in conf + assert "BAD_QUOTE" not in conf + + def test_append_flyte_env_disabled(self): + meta = _make_task_execution_metadata() + result = EMRServerlessConnector._append_flyte_env_to_spark_params( + "--conf spark.executor.memory=4g", meta, enabled=False, + ) + assert result == "--conf spark.executor.memory=4g" + + def test_append_flyte_env_none_metadata(self): + result = EMRServerlessConnector._append_flyte_env_to_spark_params( + "--conf spark.executor.memory=4g", None, enabled=True, + ) + assert result == "--conf spark.executor.memory=4g" + + def test_append_flyte_env_preserves_existing_params(self): + meta = _make_task_execution_metadata() + result = EMRServerlessConnector._append_flyte_env_to_spark_params( + "--conf spark.executor.memory=4g", meta, enabled=True, + ) + assert result.startswith("--conf spark.executor.memory=4g ") + assert "FLYTE_INTERNAL_EXECUTION_ID=abc123def456" in result + + def test_append_flyte_env_when_existing_is_none(self): + meta = _make_task_execution_metadata() + result = EMRServerlessConnector._append_flyte_env_to_spark_params( + None, meta, enabled=True, + ) + assert "FLYTE_INTERNAL_EXECUTION_ID=abc123def456" in result + assert not result.startswith(" ") + + @pytest.mark.asyncio + async def test_create_injects_flyte_env_in_script_mode(self): + from flytekitplugins.awsemrserverless import ( + EMRServerless, + EMRServerlessSparkJobDriver, + ) + + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + application_id="00abc123", + spark_job_driver=EMRServerlessSparkJobDriver( + entry_point="s3://bucket/main.py", + spark_submit_parameters="--conf spark.executor.memory=4g", + ), + sync_image=False, + region="us-east-1", + ) + + connector = EMRServerlessConnector() + mock_handler = _make_handler() + mock_template = MagicMock() + mock_template.custom = config.to_dict() + mock_template.id = MagicMock() + mock_template.id.name = "test-task" + meta = _make_task_execution_metadata(exec_name="exec-xyz-789") + + with patch.object(connector, "_get_handler", return_value=mock_handler): + with patch.object(connector, "_extract_config", return_value=config): + await connector.create(mock_template, task_execution_metadata=meta) + + call_kwargs = mock_handler.start_job_run.call_args.kwargs + submit_params = call_kwargs["job_driver"]["sparkSubmit"]["sparkSubmitParameters"] + assert "spark.executor.memory=4g" in submit_params + assert "spark.emr-serverless.driverEnv.FLYTE_INTERNAL_EXECUTION_ID=exec-xyz-789" in submit_params + assert "spark.executorEnv.FLYTE_INTERNAL_EXECUTION_ID=exec-xyz-789" in submit_params + + @pytest.mark.asyncio + async def test_create_respects_inject_flyte_env_false(self): + from flytekitplugins.awsemrserverless import ( + EMRServerless, + EMRServerlessSparkJobDriver, + ) + + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + application_id="00abc123", + spark_job_driver=EMRServerlessSparkJobDriver( + entry_point="s3://bucket/main.py", + spark_submit_parameters="--conf spark.executor.memory=4g", + ), + sync_image=False, + region="us-east-1", + inject_flyte_env=False, + ) + + connector = EMRServerlessConnector() + mock_handler = _make_handler() + mock_template = MagicMock() + mock_template.custom = config.to_dict() + mock_template.id = MagicMock() + mock_template.id.name = "test-task" + meta = _make_task_execution_metadata() + + with patch.object(connector, "_get_handler", 
return_value=mock_handler): + with patch.object(connector, "_extract_config", return_value=config): + await connector.create(mock_template, task_execution_metadata=meta) + + call_kwargs = mock_handler.start_job_run.call_args.kwargs + submit_params = call_kwargs["job_driver"]["sparkSubmit"]["sparkSubmitParameters"] + assert "FLYTE_INTERNAL_EXECUTION_ID" not in submit_params + + @pytest.mark.asyncio + async def test_create_injects_flyte_env_in_pythonic_mode(self): + from flytekitplugins.awsemrserverless import EMRServerless + + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + application_id="00abc123", + sync_image=False, + region="us-east-1", + ) + + connector = EMRServerlessConnector() + mock_handler = _make_handler() + + mock_container = MagicMock() + mock_container.image = "my-image:latest" + mock_container.args = ["pyflyte-execute", "--task-module", "m"] + + mock_template = MagicMock() + mock_template.custom = config.to_dict() + mock_template.id = MagicMock() + mock_template.id.name = "test-task" + mock_template.container = mock_container + + meta = _make_task_execution_metadata(exec_name="py-exec-001") + + with patch.object(connector, "_get_handler", return_value=mock_handler): + with patch.object(connector, "_extract_config", return_value=config): + with patch.object( + connector, "_ensure_entrypoint_on_s3", + return_value="s3://bucket/entrypoint.py", + ): + await connector.create(mock_template, task_execution_metadata=meta) + + call_kwargs = mock_handler.start_job_run.call_args.kwargs + submit_params = call_kwargs["job_driver"]["sparkSubmit"]["sparkSubmitParameters"] + assert "spark.emr-serverless.driverEnv.FLYTE_INTERNAL_EXECUTION_ID=py-exec-001" in submit_params + assert "spark.executorEnv.FLYTE_INTERNAL_EXECUTION_ID=py-exec-001" in submit_params + + @pytest.mark.asyncio + async def test_create_handles_missing_metadata_gracefully(self): + """Missing task_execution_metadata (e.g. local mode) must not raise.""" + from flytekitplugins.awsemrserverless import ( + EMRServerless, + EMRServerlessSparkJobDriver, + ) + + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + application_id="00abc123", + spark_job_driver=EMRServerlessSparkJobDriver(entry_point="s3://b/m.py"), + sync_image=False, + region="us-east-1", + ) + + connector = EMRServerlessConnector() + mock_handler = _make_handler() + mock_template = MagicMock() + mock_template.custom = config.to_dict() + mock_template.id = MagicMock() + mock_template.id.name = "test-task" + + with patch.object(connector, "_get_handler", return_value=mock_handler): + with patch.object(connector, "_extract_config", return_value=config): + await connector.create(mock_template) + + call_kwargs = mock_handler.start_job_run.call_args.kwargs + submit_params = call_kwargs["job_driver"]["sparkSubmit"].get("sparkSubmitParameters") + if submit_params is not None: + assert "FLYTE_INTERNAL_EXECUTION_ID" not in submit_params + + +class TestEntrypointBucketResolution: + """Resolution order for the Pythonic-mode entrypoint S3 bucket. + + Documented contract (see :py:meth:`EMRServerlessConnector._resolve_entrypoint_bucket`): + + 1. ``FLYTE_EMR_ENTRYPOINT_S3_BUCKET`` env var (preferred). + 2. ``configuration_overrides.monitoringConfiguration.s3MonitoringConfiguration.logUri`` + (reuses the same bucket the user already uses for monitoring logs). + 3. ``FLYTE_AWS_S3_BUCKET`` env var (Flyte's general-purpose bucket). + 4. Otherwise, raise -- Pythonic mode requires a bucket.
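+ + A standalone sketch of that order (illustrative only -- apart from the two env var names, every name here is an assumption, not the connector's actual code):: + + import os + + def resolve_bucket(config): + explicit = os.environ.get("FLYTE_EMR_ENTRYPOINT_S3_BUCKET", "") + if explicit: + return explicit.removeprefix("s3://").split("/", 1)[0] + overrides = config.configuration_overrides or {} + log_uri = ( + overrides.get("monitoringConfiguration", {}) + .get("s3MonitoringConfiguration", {}) + .get("logUri", "") + ) + if log_uri: + return log_uri.removeprefix("s3://").split("/", 1)[0] + fallback = os.environ.get("FLYTE_AWS_S3_BUCKET", "") + if fallback: + return fallback + raise ValueError("No entrypoint bucket configured")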
+ """ + + def test_env_var_takes_priority(self, sample_task_config): + with patch.dict( + os.environ, + {"FLYTE_EMR_ENTRYPOINT_S3_BUCKET": "explicit-bucket", "FLYTE_AWS_S3_BUCKET": "fallback"}, + clear=False, + ): + assert ( + EMRServerlessConnector._resolve_entrypoint_bucket(sample_task_config) == "explicit-bucket" + ) + + def test_env_var_strips_s3_prefix(self, sample_task_config): + with patch.dict( + os.environ, {"FLYTE_EMR_ENTRYPOINT_S3_BUCKET": "s3://my-bucket/some/prefix"}, clear=False + ): + assert ( + EMRServerlessConnector._resolve_entrypoint_bucket(sample_task_config) == "my-bucket" + ) + + def test_falls_back_to_monitoring_log_uri(self, sample_task_config): + sample_task_config.configuration_overrides = { + "monitoringConfiguration": { + "s3MonitoringConfiguration": {"logUri": "s3://monitoring-bucket/logs"} + } + } + with patch.dict( + os.environ, {"FLYTE_EMR_ENTRYPOINT_S3_BUCKET": "", "FLYTE_AWS_S3_BUCKET": ""}, clear=False + ): + os.environ.pop("FLYTE_EMR_ENTRYPOINT_S3_BUCKET", None) + os.environ.pop("FLYTE_AWS_S3_BUCKET", None) + assert ( + EMRServerlessConnector._resolve_entrypoint_bucket(sample_task_config) + == "monitoring-bucket" + ) + + def test_falls_back_to_flyte_aws_s3_bucket(self, sample_task_config): + with patch.dict( + os.environ, {"FLYTE_AWS_S3_BUCKET": "flyte-bucket"}, clear=False + ): + os.environ.pop("FLYTE_EMR_ENTRYPOINT_S3_BUCKET", None) + assert ( + EMRServerlessConnector._resolve_entrypoint_bucket(sample_task_config) == "flyte-bucket" + ) + + def test_raises_when_no_source_configured(self, sample_task_config): + with patch.dict(os.environ, {}, clear=False): + os.environ.pop("FLYTE_EMR_ENTRYPOINT_S3_BUCKET", None) + os.environ.pop("FLYTE_AWS_S3_BUCKET", None) + with pytest.raises(ValueError, match="entrypoint"): + EMRServerlessConnector._resolve_entrypoint_bucket(sample_task_config) + + +class TestEnsureEntrypointOnS3: + """The connector uploads the entrypoint exactly once per content-hash, + keyed under ``flyte/emr-serverless/entrypoint-.py``.""" + + def _make_handler_with_region(self): + h = MagicMock() + h.region = "us-east-1" + return h + + def test_skips_upload_when_already_present(self, sample_task_config): + from flytekitplugins.awsemrserverless.connector import EMRServerlessConnector + + connector = EMRServerlessConnector() + handler = self._make_handler_with_region() + + s3_client = MagicMock() + s3_client.head_object.return_value = {"ContentLength": 1234} + + with ( + patch.dict(os.environ, {"FLYTE_EMR_ENTRYPOINT_S3_BUCKET": "test-bucket"}, clear=False), + patch("boto3.client", return_value=s3_client), + ): + uri = connector._ensure_entrypoint_on_s3(handler, sample_task_config) + + assert uri.startswith("s3://test-bucket/flyte/emr-serverless/entrypoint-") + assert uri.endswith(".py") + s3_client.head_object.assert_called_once() + s3_client.put_object.assert_not_called() + + def test_uploads_when_missing(self, sample_task_config): + from flytekitplugins.awsemrserverless.connector import EMRServerlessConnector + + connector = EMRServerlessConnector() + handler = self._make_handler_with_region() + + s3_client = MagicMock() + s3_client.head_object.side_effect = Exception("NotFound") + + with ( + patch.dict(os.environ, {"FLYTE_EMR_ENTRYPOINT_S3_BUCKET": "test-bucket"}, clear=False), + patch("boto3.client", return_value=s3_client), + ): + uri = connector._ensure_entrypoint_on_s3(handler, sample_task_config) + + s3_client.put_object.assert_called_once() + put_kwargs = s3_client.put_object.call_args.kwargs + assert put_kwargs["Bucket"] == "test-bucket" + 
+ + def test_uploads_when_missing(self, sample_task_config): + from flytekitplugins.awsemrserverless.connector import EMRServerlessConnector + + connector = EMRServerlessConnector() + handler = self._make_handler_with_region() + + s3_client = MagicMock() + s3_client.head_object.side_effect = Exception("NotFound") + + with ( + patch.dict(os.environ, {"FLYTE_EMR_ENTRYPOINT_S3_BUCKET": "test-bucket"}, clear=False), + patch("boto3.client", return_value=s3_client), + ): + uri = connector._ensure_entrypoint_on_s3(handler, sample_task_config) + + s3_client.put_object.assert_called_once() + put_kwargs = s3_client.put_object.call_args.kwargs + assert put_kwargs["Bucket"] == "test-bucket" + assert put_kwargs["Key"].startswith("flyte/emr-serverless/entrypoint-") + assert put_kwargs["ContentType"] == "text/x-python" + # The body is the byte-identical content of _entrypoint.py + from flytekitplugins.awsemrserverless import _entrypoint + from pathlib import Path + + assert put_kwargs["Body"] == Path(_entrypoint.__file__).read_bytes() + assert uri == f"s3://test-bucket/{put_kwargs['Key']}" + + def test_uri_is_deterministic_across_calls(self, sample_task_config): + """Two consecutive calls produce the same URI -- this is what + makes EMR's job specs stable across connector restarts.""" + from flytekitplugins.awsemrserverless.connector import EMRServerlessConnector + + connector = EMRServerlessConnector() + handler = self._make_handler_with_region() + s3_client = MagicMock() + s3_client.head_object.return_value = {"ContentLength": 1} + + with ( + patch.dict(os.environ, {"FLYTE_EMR_ENTRYPOINT_S3_BUCKET": "b"}, clear=False), + patch("boto3.client", return_value=s3_client), + ): + uri1 = connector._ensure_entrypoint_on_s3(handler, sample_task_config) + uri2 = connector._ensure_entrypoint_on_s3(handler, sample_task_config) + + assert uri1 == uri2 diff --git a/plugins/flytekit-aws-emr-serverless/tests/test_entrypoint.py b/plugins/flytekit-aws-emr-serverless/tests/test_entrypoint.py new file mode 100644 index 0000000000..0e8b5f7e61 --- /dev/null +++ b/plugins/flytekit-aws-emr-serverless/tests/test_entrypoint.py @@ -0,0 +1,338 @@ +""" +Unit tests for the EMR Serverless Pythonic-mode entrypoint script. + +These tests exercise the entrypoint module on two levels: + +* **Pure unit tests** of the helper functions (``_parse_fast_execute_args``, + ``_build_resolver_command``, ``_exit_with_code``) by importing them + directly. These are fast, deterministic, and require no subprocess. +* **Subprocess tests** that run ``python`` directly on the actual file the + connector uploads to S3, verifying the wire-level CLI behaviour EMR + will see. + +The shape of these tests mirrors how Databricks's ``flytetools`` +maintains test coverage for ``flytekitplugins/databricks/entrypoint.py``. +""" + +import hashlib +import subprocess +import sys +from pathlib import Path +from unittest.mock import patch + +import pytest + +from flytekitplugins.awsemrserverless import _entrypoint + + +ENTRYPOINT_PATH: Path = Path(_entrypoint.__file__).resolve() + + +# --------------------------------------------------------------------------- +# Module-level invariants +# --------------------------------------------------------------------------- + + +class TestEntrypointModuleInvariants: + """Sanity checks on the entrypoint module itself.""" + + def test_entrypoint_file_exists(self): + assert ENTRYPOINT_PATH.is_file(), ( + f"_entrypoint.py must exist at {ENTRYPOINT_PATH}; the connector uploads this file to S3 and EMR runs it." + ) + + def test_entrypoint_is_valid_python(self): + source = ENTRYPOINT_PATH.read_text() + compile(source, str(ENTRYPOINT_PATH), "exec") + + def test_entrypoint_has_main(self): + assert callable(_entrypoint.main) + + def test_entrypoint_minimal_runtime_imports(self): + """The entrypoint runs inside the EMR worker; its module-level + imports must be available there.
``flytekit`` is the only + non-stdlib module we should import (specifically + ``download_distribution``).""" + source = ENTRYPOINT_PATH.read_text() + assert "from flytekit.tools.fast_registration import download_distribution" in source + assert "import boto3" not in source, ( + "_entrypoint.py must not import boto3 -- it runs in the EMR " + "worker which may not have boto3 installed at the version " + "the connector uses." + ) + + +# --------------------------------------------------------------------------- +# _parse_fast_execute_args +# --------------------------------------------------------------------------- + + +class TestParseFastExecuteArgs: + """The parser splits a ``pyflyte-fast-execute ...`` argv into its parts.""" + + def test_full_argv(self): + args = [ + "pyflyte-fast-execute", + "--additional-distribution", + "s3://bucket/fast-abc.tar.gz", + "--dest-dir", + "/tmp/work", + "pyflyte-execute", + "--inputs", + "s3://in", + "--output-prefix", + "s3://out", + ] + addl, dest, start = _entrypoint._parse_fast_execute_args(args) + assert addl == "s3://bucket/fast-abc.tar.gz" + assert dest == "/tmp/work" + assert args[start] == "pyflyte-execute" + + def test_only_additional_distribution(self): + args = [ + "pyflyte-fast-execute", + "--additional-distribution", + "s3://bucket/fast.tar.gz", + "pyflyte-execute", + "--inputs", + "s3://in", + ] + addl, dest, start = _entrypoint._parse_fast_execute_args(args) + assert addl == "s3://bucket/fast.tar.gz" + assert dest is None + assert args[start] == "pyflyte-execute" + + def test_no_distribution_flags(self): + args = ["pyflyte-fast-execute", "pyflyte-execute", "--inputs", "s3://in"] + addl, dest, start = _entrypoint._parse_fast_execute_args(args) + assert addl is None + assert dest is None + assert args[start] == "pyflyte-execute" + + def test_explicit_double_dash_separator(self): + args = [ + "pyflyte-fast-execute", + "--additional-distribution", + "s3://b/a.tar.gz", + "--", + "pyflyte-execute", + "--inputs", + "s3://in", + ] + addl, dest, start = _entrypoint._parse_fast_execute_args(args) + assert addl == "s3://b/a.tar.gz" + assert dest is None + assert args[start] == "pyflyte-execute" + + def test_dangling_flag_does_not_crash(self): + """``--additional-distribution`` with no value following must not crash. + + With no value, the loop falls through to the catch-all ``else`` branch + and treats the flag itself as the start of the task command. This is + recoverable nonsense (the inner subprocess will fail cleanly) -- the + important property is that the parser doesn't index past the end. 
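+ + Worked example (expected values inferred from the assertions below, + not from reading the implementation):: + + addl, dest, start = _entrypoint._parse_fast_execute_args( + ["pyflyte-fast-execute", "--additional-distribution"] + ) + # addl is None, dest is None, and start stays in bounds, so + # args[start] is a well-defined (if useless) task command.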
+ """ + args = ["pyflyte-fast-execute", "--additional-distribution"] + addl, dest, start = _entrypoint._parse_fast_execute_args(args) + assert addl is None + assert dest is None + assert 0 <= start < len(args) + + +# --------------------------------------------------------------------------- +# _build_resolver_command +# --------------------------------------------------------------------------- + + +class TestBuildResolverCommand: + """The resolver builder must inject ``--dynamic-addl-distro`` / + ``--dynamic-dest-dir`` immediately before any ``--resolver`` flag.""" + + def test_injects_dynamic_args_before_resolver(self): + task_cmd = [ + "pyflyte-execute", + "--inputs", + "s3://in", + "--resolver", + "flytekit.core.python_auto_container.default_task_resolver", + ] + out = _entrypoint._build_resolver_command(task_cmd, "s3://b/fast.tar.gz", "/tmp/work") + assert "--dynamic-addl-distro" in out + assert "s3://b/fast.tar.gz" in out + assert "--dynamic-dest-dir" in out + assert "/tmp/work" in out + assert out.index("--dynamic-addl-distro") < out.index("--resolver") + assert out.index("--dynamic-dest-dir") < out.index("--resolver") + + def test_no_resolver_flag_means_no_injection(self): + task_cmd = ["pyflyte-execute", "--inputs", "s3://in"] + out = _entrypoint._build_resolver_command(task_cmd, "s3://x", "/tmp") + assert out == task_cmd + + def test_handles_none_distribution_args(self): + task_cmd = ["pyflyte-execute", "--resolver", "x"] + out = _entrypoint._build_resolver_command(task_cmd, None, None) + assert "--dynamic-addl-distro" in out + assert "--dynamic-dest-dir" in out + addl_idx = out.index("--dynamic-addl-distro") + dest_idx = out.index("--dynamic-dest-dir") + assert out[addl_idx + 1] == "" + assert out[dest_idx + 1] == "" + + +# --------------------------------------------------------------------------- +# _exit_with_code +# --------------------------------------------------------------------------- + + +class TestExitWithCode: + """The exit handler bridges Flytekit's exit semantics to EMR's + state-reporting model. 
+ + The contract: + * non-zero rc → exit with rc (EMR sees FAILED) + * rc=0 with user-error banner in stderr → exit 1 (EMR sees FAILED) + * rc=0 without banner → exit 0 (EMR sees SUCCESS) + """ + + def test_nonzero_rc_propagates(self): + with pytest.raises(SystemExit) as exc: + _entrypoint._exit_with_code(42, "") + assert exc.value.code == 42 + + def test_zero_rc_clean_exits_zero(self): + with pytest.raises(SystemExit) as exc: + _entrypoint._exit_with_code(0, "all good") + assert exc.value.code == 0 + + def test_zero_rc_with_user_error_forces_failure(self): + """Flytekit catches user exceptions and exits 0; we must override + that to ensure EMR reports FAILED, not SUCCESS.""" + stderr = "Traceback ...\nUser Error Captured by Flyte: TypeError: x must be int\n" + with pytest.raises(SystemExit) as exc: + _entrypoint._exit_with_code(0, stderr) + assert exc.value.code == 1 + + +# --------------------------------------------------------------------------- +# main() dispatch +# --------------------------------------------------------------------------- + + +class TestMainDispatch: + """``main`` reads ``sys.argv`` and dispatches to the right branch.""" + + def test_no_args_exits_nonzero(self): + with patch.object(sys, "argv", ["entrypoint.py"]): + with pytest.raises(SystemExit) as exc: + _entrypoint.main() + assert exc.value.code == 1 + + def test_unknown_command_exits_nonzero(self): + with patch.object(sys, "argv", ["entrypoint.py", "unknown-cmd"]): + with pytest.raises(SystemExit) as exc: + _entrypoint.main() + assert exc.value.code == 1 + + def test_pyflyte_execute_branch_invokes_subprocess(self): + """When given ``pyflyte-execute ...``, main runs it as a subprocess + with PYTHONPATH defaulted to cwd and exits with rc=0.""" + argv = ["entrypoint.py", "pyflyte-execute", "--inputs", "s3://x"] + with ( + patch.object(sys, "argv", argv), + patch.object(_entrypoint, "_run_subprocess", return_value=(0, "")) as m_run, + ): + with pytest.raises(SystemExit) as exc: + _entrypoint.main() + assert exc.value.code == 0 + cmd_arg, env_arg = m_run.call_args[0] + assert cmd_arg == ["pyflyte-execute", "--inputs", "s3://x"] + assert "PYTHONPATH" in env_arg + + def test_fast_execute_branch_downloads_and_invokes(self): + """When given ``pyflyte-fast-execute ...``, main downloads the + distribution, builds the resolver-aware command, and exits.""" + argv = [ + "entrypoint.py", + "pyflyte-fast-execute", + "--additional-distribution", + "s3://b/fast.tar.gz", + "--dest-dir", + "/tmp/work", + "pyflyte-execute", + "--resolver", + "flytekit.core.python_auto_container.default_task_resolver", + ] + with ( + patch.object(sys, "argv", argv), + patch.object(_entrypoint, "download_distribution") as m_dl, + patch.object(_entrypoint, "_run_subprocess", return_value=(0, "")) as m_run, + ): + with pytest.raises(SystemExit) as exc: + _entrypoint.main() + assert exc.value.code == 0 + m_dl.assert_called_once_with("s3://b/fast.tar.gz", "/tmp/work") + cmd_arg, env_arg = m_run.call_args[0] + assert "--dynamic-addl-distro" in cmd_arg + assert "s3://b/fast.tar.gz" in cmd_arg + assert "/tmp/work" in env_arg["PYTHONPATH"] + + +# --------------------------------------------------------------------------- +# End-to-end CLI behaviour (run the file as EMR will) +# --------------------------------------------------------------------------- + + +class TestEntrypointAsScript: + """Run the entrypoint file as a subprocess to verify the wire-level + behaviour EMR's spark-submit will observe.""" + + def test_no_args_exits_one_with_usage(self): + proc = 
subprocess.run( + [sys.executable, str(ENTRYPOINT_PATH)], + capture_output=True, + text=True, + ) + assert proc.returncode == 1 + assert "Usage" in proc.stderr + + def test_unknown_command_exits_one(self): + proc = subprocess.run( + [sys.executable, str(ENTRYPOINT_PATH), "definitely-not-a-flyte-command"], + capture_output=True, + text=True, + ) + assert proc.returncode == 1 + assert "Unrecognized command" in proc.stderr + + +# --------------------------------------------------------------------------- +# Connector ↔ entrypoint coupling +# --------------------------------------------------------------------------- + + +class TestConnectorEntrypointWiring: + """Ensure the connector reads exactly this file and produces a stable, + content-addressed hash.""" + + def test_connector_path_points_at_this_module(self): + from flytekitplugins.awsemrserverless.connector import EMRServerlessConnector + + assert EMRServerlessConnector._ENTRYPOINT_PATH == ENTRYPOINT_PATH + + def test_connector_reads_byte_identical_content(self): + from flytekitplugins.awsemrserverless.connector import EMRServerlessConnector + + EMRServerlessConnector._entrypoint_bytes_cache = None # bust cache + assert EMRServerlessConnector._read_entrypoint_bytes() == ENTRYPOINT_PATH.read_bytes() + + def test_hash_is_stable_for_unchanged_content(self): + """The S3 key contains the first 12 hex chars of sha256(content); + as long as the file content is unchanged, the key is unchanged.""" + from flytekitplugins.awsemrserverless.connector import EMRServerlessConnector + + EMRServerlessConnector._entrypoint_bytes_cache = None + content = EMRServerlessConnector._read_entrypoint_bytes() + h1 = hashlib.sha256(content).hexdigest()[:12] + h2 = hashlib.sha256(ENTRYPOINT_PATH.read_bytes()).hexdigest()[:12] + assert h1 == h2 diff --git a/plugins/flytekit-aws-emr-serverless/tests/test_task.py b/plugins/flytekit-aws-emr-serverless/tests/test_task.py new file mode 100644 index 0000000000..226997c74a --- /dev/null +++ b/plugins/flytekit-aws-emr-serverless/tests/test_task.py @@ -0,0 +1,346 @@ +""" +Unit tests for EMR Serverless Task and configuration. 
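+ +One contract exercised throughout: ``EMRServerless`` must survive a +``to_dict`` / ``from_dict`` roundtrip, because the connector rebuilds the +config from the task template's ``custom`` field. A minimal sketch (field +values are arbitrary):: + + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + application_id="00abc123", + ) + assert EMRServerless.from_dict(config.to_dict()).application_id == "00abc123"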
+""" + +from unittest.mock import MagicMock + +import pytest +from flytekit import task + +from flytekitplugins.awsemrserverless import ( + EMRServerless, + EMRServerlessHiveJobDriver, + EMRServerlessSparkJobDriver, + EMRServerlessTask, +) + + +class TestEMRServerlessSparkJobDriver: + def test_basic_creation(self): + driver = EMRServerlessSparkJobDriver(entry_point="s3://bucket/main.py") + assert driver.entry_point == "s3://bucket/main.py" + assert driver.entry_point_arguments is None + assert driver.spark_submit_parameters is None + + def test_full_creation(self): + driver = EMRServerlessSparkJobDriver( + entry_point="s3://bucket/main.py", + entry_point_arguments=["--arg1", "value1"], + spark_submit_parameters="--conf spark.executor.memory=4g", + ) + assert driver.entry_point_arguments == ["--arg1", "value1"] + + def test_to_dict_basic(self): + driver = EMRServerlessSparkJobDriver(entry_point="s3://bucket/main.py") + assert driver.to_dict() == {"entryPoint": "s3://bucket/main.py"} + + def test_to_dict_full(self): + driver = EMRServerlessSparkJobDriver( + entry_point="s3://bucket/main.py", + entry_point_arguments=["--arg1", "value1"], + spark_submit_parameters="--conf spark.executor.memory=4g", + ) + result = driver.to_dict() + assert result["entryPoint"] == "s3://bucket/main.py" + assert result["entryPointArguments"] == ["--arg1", "value1"] + assert result["sparkSubmitParameters"] == "--conf spark.executor.memory=4g" + + +class TestEMRServerlessHiveJobDriver: + def test_basic_creation(self): + driver = EMRServerlessHiveJobDriver(query="SELECT * FROM table") + assert driver.query == "SELECT * FROM table" + assert driver.init_query_file is None + + def test_to_dict(self): + driver = EMRServerlessHiveJobDriver( + query="s3://bucket/query.sql", + init_query_file="s3://bucket/init.sql", + ) + result = driver.to_dict() + assert result == { + "query": "s3://bucket/query.sql", + "initQueryFile": "s3://bucket/init.sql", + } + + +class TestEMRServerless: + def test_basic_spark_config(self): + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/EMRRole", + spark_job_driver=EMRServerlessSparkJobDriver(entry_point="s3://bucket/main.py"), + ) + assert config.application_type == "SPARK" + assert config.is_script_mode is True + + def test_pythonic_mode_config(self): + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/EMRRole", + ) + assert config.is_script_mode is False + + def test_hive_config(self): + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/EMRRole", + application_type="HIVE", + hive_job_driver=EMRServerlessHiveJobDriver(query="s3://bucket/query.sql"), + ) + assert config.application_type == "HIVE" + assert config.is_script_mode is True + + def test_missing_execution_role(self): + with pytest.raises(ValueError, match="execution_role_arn is required"): + EMRServerless(execution_role_arn="") + + def test_invalid_application_type(self): + with pytest.raises(ValueError, match="application_type must be"): + EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + application_type="INVALID", + ) + + def test_missing_hive_driver(self): + with pytest.raises(ValueError, match="hive_job_driver is required"): + EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + application_type="HIVE", + ) + + def test_both_drivers_specified(self): + with pytest.raises(ValueError, match="Only one of"): + EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + 
spark_job_driver=EMRServerlessSparkJobDriver(entry_point="s3://bucket/main.py"), + hive_job_driver=EMRServerlessHiveJobDriver(query="SELECT 1"), + ) + + def test_invalid_timeout(self): + with pytest.raises(ValueError, match="execution_timeout_minutes"): + EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + spark_job_driver=EMRServerlessSparkJobDriver(entry_point="s3://bucket/main.py"), + execution_timeout_minutes=0, + ) + + def test_get_job_driver_spark(self, sample_spark_job_driver): + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + spark_job_driver=sample_spark_job_driver, + ) + result = config.get_job_driver() + assert "sparkSubmit" in result + + def test_get_job_driver_hive(self, sample_hive_job_driver): + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + application_type="HIVE", + hive_job_driver=sample_hive_job_driver, + ) + result = config.get_job_driver() + assert "hive" in result + + def test_to_dict(self, sample_config): + result = sample_config.to_dict() + assert result["execution_role_arn"] == "arn:aws:iam::123456789012:role/EMRServerlessRole" + assert result["application_id"] == "00f5abc123def456" + assert result["region"] == "us-east-1" + assert "spark_job_driver" in result + + def test_from_dict_roundtrip(self, sample_config): + config_dict = sample_config.to_dict() + restored = EMRServerless.from_dict(config_dict) + assert restored.execution_role_arn == sample_config.execution_role_arn + assert restored.application_id == sample_config.application_id + assert restored.region == sample_config.region + + def test_full_config_with_all_options(self, sample_spark_job_driver): + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/EMRRole", + application_id="00abc123", + release_label="emr-6.15.0", + application_name="my-app", + spark_job_driver=sample_spark_job_driver, + configuration_overrides={ + "monitoringConfiguration": {"s3MonitoringConfiguration": {"logUri": "s3://logs/"}} + }, + tags={"Env": "prod"}, + execution_timeout_minutes=240, + initial_capacity={"DRIVER": {"workerCount": 1}}, + maximum_capacity={"cpu": "100vCPU"}, + network_configuration={"subnetIds": ["subnet-123"]}, + image_configuration={"imageUri": "123456789012.dkr.ecr.us-east-1.amazonaws.com/spark:latest"}, + region="us-west-2", + ) + d = config.to_dict() + assert d["release_label"] == "emr-6.15.0" + assert d["application_name"] == "my-app" + assert d["tags"]["Env"] == "prod" + assert d["network_configuration"]["subnetIds"] == ["subnet-123"] + + def test_sync_image_default(self): + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/EMRRole", + application_id="00abc123", + ) + assert config.sync_image is True + + def test_architecture_validation(self): + with pytest.raises(ValueError, match="architecture must be one of"): + EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + architecture="INVALID", + ) + + def test_valid_architectures(self): + for arch in ("X86_64", "ARM64"): + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + architecture=arch, + ) + assert config.architecture == arch + + def test_new_fields_to_dict_roundtrip(self): + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + sync_image=False, + spark_submit_parameters="--conf spark.executor.memory=16g", + application_configuration=[ + {"classification": "spark-defaults", "properties": {"spark.executor.cores": 
"8"}} + ], + runtime_configuration=[ + {"classification": "spark-defaults", "properties": {"spark.dynamicAllocation.enabled": "true"}} + ], + scheduler_configuration={"maxConcurrentRuns": 5, "queueTimeoutMinutes": 30}, + auto_stop_config={"enabled": True, "idleTimeoutMinutes": 30}, + architecture="ARM64", + retry_policy={"maxAttempts": 3}, + ) + d = config.to_dict() + restored = EMRServerless.from_dict(d) + + assert restored.sync_image is False + assert restored.spark_submit_parameters == "--conf spark.executor.memory=16g" + assert len(restored.application_configuration) == 1 + assert restored.application_configuration[0]["classification"] == "spark-defaults" + assert len(restored.runtime_configuration) == 1 + assert restored.scheduler_configuration == {"maxConcurrentRuns": 5, "queueTimeoutMinutes": 30} + assert restored.auto_stop_config == {"enabled": True, "idleTimeoutMinutes": 30} + assert restored.architecture == "ARM64" + assert restored.retry_policy == {"maxAttempts": 3} + + def test_effective_configuration_overrides_merge(self): + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + configuration_overrides={ + "monitoringConfiguration": {"s3MonitoringConfiguration": {"logUri": "s3://logs/"}} + }, + application_configuration=[ + {"classification": "spark-defaults", "properties": {"spark.executor.cores": "8"}} + ], + ) + result = config.get_effective_configuration_overrides() + assert "monitoringConfiguration" in result + assert len(result["applicationConfiguration"]) == 1 + + def test_effective_configuration_overrides_none(self): + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + ) + assert config.get_effective_configuration_overrides() is None + + def test_effective_configuration_overrides_only_app_config(self): + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/Role", + application_configuration=[ + {"classification": "spark", "properties": {"dynamicAllocationOptimization": "true"}} + ], + ) + result = config.get_effective_configuration_overrides() + assert len(result["applicationConfiguration"]) == 1 + + +class TestEMRServerlessTask: + def test_task_type(self): + assert EMRServerlessTask._TASK_TYPE == "emr_serverless" + + def test_task_creation_with_decorator(self, sample_task_config): + @task(task_config=sample_task_config) + def my_spark_job() -> str: + return "done" + + assert isinstance(my_spark_job, EMRServerlessTask) + assert my_spark_job.task_config == sample_task_config + + def test_get_custom(self, sample_task_config): + @task(task_config=sample_task_config) + def my_spark_job() -> str: + return "done" + + mock_settings = MagicMock() + custom = my_spark_job.get_custom(mock_settings) + + assert isinstance(custom, dict) + assert custom["execution_role_arn"] == sample_task_config.execution_role_arn + assert custom["application_id"] == sample_task_config.application_id + assert "spark_job_driver" in custom + + def test_get_custom_with_hive(self): + task_config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/EMRRole", + application_type="HIVE", + hive_job_driver=EMRServerlessHiveJobDriver( + query="s3://bucket/query.sql", + init_query_file="s3://bucket/init.sql", + ), + ) + + @task(task_config=task_config) + def hive_task() -> str: + return "done" + + custom = hive_task.get_custom(MagicMock()) + assert custom["application_type"] == "HIVE" + assert custom["hive_job_driver"]["query"] == "s3://bucket/query.sql" + + def 
test_task_inherits_from_correct_classes(self): + from flytekit.core.python_function_task import PythonFunctionTask + + try: + from flytekit.extend.backend.base_connector import AsyncConnectorExecutorMixin + except ModuleNotFoundError: + from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin as AsyncConnectorExecutorMixin + + assert issubclass(EMRServerlessTask, PythonFunctionTask) + assert issubclass(EMRServerlessTask, AsyncConnectorExecutorMixin) + + def test_task_decorator_produces_emr_serverless_task(self): + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/EMRRole", + application_id="app-123", + spark_job_driver=EMRServerlessSparkJobDriver(entry_point="s3://bucket/main.py"), + ) + + @task(task_config=config) + def my_emr_task() -> str: + return "done" + + assert isinstance(my_emr_task, EMRServerlessTask) + assert my_emr_task._task_type == "emr_serverless" + + def test_pythonic_mode_task(self): + """Test that a task with no job driver creates successfully (Pythonic mode).""" + config = EMRServerless( + execution_role_arn="arn:aws:iam::123456789012:role/EMRRole", + application_id="app-123", + region="us-east-1", + ) + + @task(task_config=config) + def my_pythonic_spark_job() -> str: + return "done" + + assert isinstance(my_pythonic_spark_job, EMRServerlessTask) + + custom = my_pythonic_spark_job.get_custom(MagicMock()) + assert "spark_job_driver" not in custom + assert custom["application_type"] == "SPARK"