Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions flytekit/tools/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,12 +648,35 @@ def get_serializable_array_node_map_task(
# TODO Add support for other flyte entities
entity = node.flyte_entity
task_spec = get_serializable(entity_mapping, settings, entity, options)

override_pod_spec = {}
if node._pod_template is not None:
# get_container (not _get_container) goes through prepare_target() so the
# container args carry the map-task command rather than pyflyte-execute.
container = entity.get_container(settings)
if settings.should_fast_serialize() and container is not None:
# Mirror get_serializable_task: prefix args post-build rather than
# swapping command_fn, since _fast_serialize_command_fn would wrap
# the inherited pyflyte-execute default (wrong for map_task).
container._args = prefix_with_fast_execute(settings, container.args)
override_pod_spec = _serialize_pod_spec(node._pod_template, container, settings)

task_node = workflow_model.TaskNode(
reference_id=task_spec.template.id,
overrides=TaskNodeOverrides(
resources=node._resources,
extended_resources=node._extended_resources,
container_image=node._container_image,
pod_template=PodTemplate(
pod_spec=override_pod_spec,
labels=node._pod_template.labels if node._pod_template.labels else None,
annotations=node._pod_template.annotations if node._pod_template.annotations else None,
primary_container_name=node._pod_template.primary_container_name
if node._pod_template.primary_container_name
else None,
)
if node._pod_template
else None,
),
)
node = workflow_model.Node(
Expand Down
65 changes: 61 additions & 4 deletions tests/flytekit/unit/test_translator.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import typing
from collections import OrderedDict

import pytest
from kubernetes import client
from kubernetes.client import V1Container, V1PodSpec

import flytekit.configuration
from flytekit import ContainerTask, Resources, PodTemplate
from flytekit import ContainerTask, PodTemplate, Resources, map_task
from flytekit.configuration import FastSerializationSettings, Image, ImageConfig
from flytekit.core.base_task import kwtypes
from flytekit.core.launch_plan import LaunchPlan, ReferenceLaunchPlan
Expand All @@ -13,9 +17,6 @@
from flytekit.models.core import identifier as identifier_models
from flytekit.models.task import Resources as resource_model
from flytekit.tools.translator import get_serializable
from kubernetes import client
from kubernetes.client import V1PodSpec, V1Container
import pytest

default_img = Image(name="default", fqn="test", tag="tag")
serialization_settings = flytekit.configuration.SerializationSettings(
Expand Down Expand Up @@ -234,3 +235,59 @@ def wf():
assert len(pod_template_override.pod_spec['containers']) == 1
container = pod_template_override.pod_spec['containers'][0]
assert container['env'] == [{'name': 'MY_KEY', 'value': 'MY_VALUE'}]


@pytest.mark.parametrize(
"fast_registration_enabled",
[
pytest.param(True, id="fast registration enabled"),
pytest.param(False, id="fast registration disabled"),
],
)
def test_map_task_with_pod_template_override(fast_registration_enabled: bool):
# Regression test for https://github.com/flyteorg/flyte/issues/7076
# map_task(...).with_overrides(pod_template=...) was silently dropped at serialization.
custom_pod_template = PodTemplate(
primary_container_name="primary",
labels={"lKeyA": "lValA"},
annotations={"aKeyA": "aValA"},
pod_spec=V1PodSpec(
containers=[
V1Container(
name="primary",
env=[client.V1EnvVar(name="MY_KEY", value="MY_VALUE")],
)
]
),
)

@task
def t(a: int) -> str:
return str(a)

@workflow
def wf(xs: typing.List[int]):
map_task(t)(a=xs).with_overrides(pod_template=custom_pod_template)

settings = (
serialization_settings.new_builder()
.with_fast_serialization_settings(FastSerializationSettings(enabled=fast_registration_enabled))
.build()
)

wf_spec = get_serializable(OrderedDict(), settings, wf)
assert len(wf_spec.template.nodes) == 1
node = wf_spec.template.nodes[0]
# map_task is serialized as an array_node wrapping an inner task node
assert node.array_node is not None
inner_task_node = node.array_node.node.task_node
assert inner_task_node is not None
assert inner_task_node.overrides.pod_template is not None
pod_template_override = inner_task_node.overrides.pod_template
assert pod_template_override.primary_container_name == "primary"
assert pod_template_override.labels == {"lKeyA": "lValA"}
assert pod_template_override.annotations == {"aKeyA": "aValA"}
assert pod_template_override.pod_spec # validate not empty
assert len(pod_template_override.pod_spec["containers"]) == 1
container = pod_template_override.pod_spec["containers"][0]
assert {"name": "MY_KEY", "value": "MY_VALUE"} in container["env"]