From d2edd5de07d4f9214fb7aea1732e75105d077865 Mon Sep 17 00:00:00 2001
From: 1fanwang <1fannnw@gmail.com>
Date: Mon, 27 Apr 2026 02:02:19 -0700
Subject: [PATCH] Carry pod_template override through map_task serialization
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The pod_template passed via
map_task(...).with_overrides(pod_template=PodTemplate(...)) was being
silently dropped at serialization. The override is accepted at workflow
build time and stored on the Node, but
get_serializable_array_node_map_task never read node._pod_template —
TaskNodeOverrides was emitted with only
resources/extended_resources/container_image, so when pod_template was
the only override the registered task spec had an empty overrides {}
block.

This is the same class of bug as #6463 / PR #3270, which was fixed for
regular tasks, but in the array-node serialization path. The map-task
path needs two implementation tweaks vs. the regular-task mirror:

* Use entity.get_container() (not _get_container) so prepare_target()
  substitutes the map-task command (pyflyte-map-execute) into the
  container args before the override pod spec is built.
* For fast-serialization, prefix container args post-build (matching
  get_serializable_task) instead of swapping command_fn via
  _fast_serialize_command_fn — the latter wraps the inherited
  pyflyte-execute default, which is wrong for map_task.

Adds a parametrized regression test covering both fast-registration
enabled and disabled.

Fixes https://github.com/flyteorg/flyte/issues/7076 Signed-off-by: 1fanwang <1fannnw@gmail.com> --- flytekit/tools/translator.py | 23 +++++++++ tests/flytekit/unit/test_translator.py | 65 ++++++++++++++++++++++++-- 2 files changed, 84 insertions(+), 4 deletions(-) diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py index e75ea30b21..6a50bfbd0e 100644 --- a/flytekit/tools/translator.py +++ b/flytekit/tools/translator.py @@ -648,12 +648,35 @@ def get_serializable_array_node_map_task( # TODO Add support for other flyte entities entity = node.flyte_entity task_spec = get_serializable(entity_mapping, settings, entity, options) + + override_pod_spec = {} + if node._pod_template is not None: + # get_container (not _get_container) goes through prepare_target() so the + # container args carry the map-task command rather than pyflyte-execute. + container = entity.get_container(settings) + if settings.should_fast_serialize() and container is not None: + # Mirror get_serializable_task: prefix args post-build rather than + # swapping command_fn, since _fast_serialize_command_fn would wrap + # the inherited pyflyte-execute default (wrong for map_task). 
+ container._args = prefix_with_fast_execute(settings, container.args) + override_pod_spec = _serialize_pod_spec(node._pod_template, container, settings) + task_node = workflow_model.TaskNode( reference_id=task_spec.template.id, overrides=TaskNodeOverrides( resources=node._resources, extended_resources=node._extended_resources, container_image=node._container_image, + pod_template=PodTemplate( + pod_spec=override_pod_spec, + labels=node._pod_template.labels if node._pod_template.labels else None, + annotations=node._pod_template.annotations if node._pod_template.annotations else None, + primary_container_name=node._pod_template.primary_container_name + if node._pod_template.primary_container_name + else None, + ) + if node._pod_template + else None, ), ) node = workflow_model.Node( diff --git a/tests/flytekit/unit/test_translator.py b/tests/flytekit/unit/test_translator.py index f446db0013..4e074122dd 100644 --- a/tests/flytekit/unit/test_translator.py +++ b/tests/flytekit/unit/test_translator.py @@ -1,8 +1,12 @@ import typing from collections import OrderedDict +import pytest +from kubernetes import client +from kubernetes.client import V1Container, V1PodSpec + import flytekit.configuration -from flytekit import ContainerTask, Resources, PodTemplate +from flytekit import ContainerTask, PodTemplate, Resources, map_task from flytekit.configuration import FastSerializationSettings, Image, ImageConfig from flytekit.core.base_task import kwtypes from flytekit.core.launch_plan import LaunchPlan, ReferenceLaunchPlan @@ -13,9 +17,6 @@ from flytekit.models.core import identifier as identifier_models from flytekit.models.task import Resources as resource_model from flytekit.tools.translator import get_serializable -from kubernetes import client -from kubernetes.client import V1PodSpec, V1Container -import pytest default_img = Image(name="default", fqn="test", tag="tag") serialization_settings = flytekit.configuration.SerializationSettings( @@ -234,3 +235,59 @@ def wf(): 
assert len(pod_template_override.pod_spec['containers']) == 1 container = pod_template_override.pod_spec['containers'][0] assert container['env'] == [{'name': 'MY_KEY', 'value': 'MY_VALUE'}] + + +@pytest.mark.parametrize( + "fast_registration_enabled", + [ + pytest.param(True, id="fast registration enabled"), + pytest.param(False, id="fast registration disabled"), + ], +) +def test_map_task_with_pod_template_override(fast_registration_enabled: bool): + # Regression test for https://github.com/flyteorg/flyte/issues/7076 + # map_task(...).with_overrides(pod_template=...) was silently dropped at serialization. + custom_pod_template = PodTemplate( + primary_container_name="primary", + labels={"lKeyA": "lValA"}, + annotations={"aKeyA": "aValA"}, + pod_spec=V1PodSpec( + containers=[ + V1Container( + name="primary", + env=[client.V1EnvVar(name="MY_KEY", value="MY_VALUE")], + ) + ] + ), + ) + + @task + def t(a: int) -> str: + return str(a) + + @workflow + def wf(xs: typing.List[int]): + map_task(t)(a=xs).with_overrides(pod_template=custom_pod_template) + + settings = ( + serialization_settings.new_builder() + .with_fast_serialization_settings(FastSerializationSettings(enabled=fast_registration_enabled)) + .build() + ) + + wf_spec = get_serializable(OrderedDict(), settings, wf) + assert len(wf_spec.template.nodes) == 1 + node = wf_spec.template.nodes[0] + # map_task is serialized as an array_node wrapping an inner task node + assert node.array_node is not None + inner_task_node = node.array_node.node.task_node + assert inner_task_node is not None + assert inner_task_node.overrides.pod_template is not None + pod_template_override = inner_task_node.overrides.pod_template + assert pod_template_override.primary_container_name == "primary" + assert pod_template_override.labels == {"lKeyA": "lValA"} + assert pod_template_override.annotations == {"aKeyA": "aValA"} + assert pod_template_override.pod_spec # validate not empty + assert 
len(pod_template_override.pod_spec["containers"]) == 1 + container = pod_template_override.pod_spec["containers"][0] + assert {"name": "MY_KEY", "value": "MY_VALUE"} in container["env"]