diff --git a/config/tests/deep-segments-config.data.xml b/config/tests/deep-segments-config.data.xml index f8833549a..796e32b0f 100644 --- a/config/tests/deep-segments-config.data.xml +++ b/config/tests/deep-segments-config.data.xml @@ -218,6 +218,7 @@ + diff --git a/config/tests/nestedConfig.data.xml b/config/tests/nestedConfig.data.xml index 4761605e9..23592ceb8 100644 --- a/config/tests/nestedConfig.data.xml +++ b/config/tests/nestedConfig.data.xml @@ -209,6 +209,7 @@ + diff --git a/config/tests/one-controller-config.data.xml b/config/tests/one-controller-config.data.xml index 6193886c5..7b3ee77c0 100644 --- a/config/tests/one-controller-config.data.xml +++ b/config/tests/one-controller-config.data.xml @@ -116,6 +116,7 @@ + diff --git a/src/drunc/controller/controller_driver.py b/src/drunc/controller/controller_driver.py index e31b0a544..187783399 100644 --- a/src/drunc/controller/controller_driver.py +++ b/src/drunc/controller/controller_driver.py @@ -36,7 +36,9 @@ def __init__(self, address: str, token: Token): options = [ ("grpc.keepalive_time_ms", 60000) # pings the server every 60 seconds ] - self.channel = grpc.insecure_channel(self.address, options=options) + # The 'ipv4:' prefix forces IPv4 resolution, which helps avoid Kubernetes hairpinning issues + target_address = f"ipv4:{self.address}" + self.channel = grpc.insecure_channel(target_address, options=options) self.stub = ControllerStub(self.channel) self.token = Token() self.token.CopyFrom(token) diff --git a/src/drunc/controller/interface/shell_utils.py b/src/drunc/controller/interface/shell_utils.py index 4a453812e..dc7e2fd81 100644 --- a/src/drunc/controller/interface/shell_utils.py +++ b/src/drunc/controller/interface/shell_utils.py @@ -1,4 +1,5 @@ import datetime +import ipaddress import logging import os import socket @@ -7,7 +8,7 @@ from collections.abc import Sequence from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass -from functools import partial +from functools import lru_cache, partial from urllib.parse import urlparse import click @@ -113,8 +114,8 @@ def update_endpoint(endpoint: str) -> str: ip_address = urlparse(endpoint).hostname if not ip_address: return "" - hostname, _, _ = socket.gethostbyaddr(ip_address) - return endpoint.replace(ip_address, hostname) + resolved_host = get_hostname_smart(ip_address) + return endpoint.replace(ip_address, resolved_host) table.add_row( prefix + status_response.name, @@ -709,3 +710,50 @@ def grab_default_value_from_env(argument_name): )(cmd) return cmd, cmd_name + + +@lru_cache(maxsize=1024) +def is_private_ip(ip_str: str) -> bool: + """ + Checks if an IP address is private (RFC 1918), loopback, or link-local. + These IPs will almost never have a public reverse DNS record. + """ + if not ip_str: + return True + try: + ip_obj = ipaddress.ip_address(ip_str) + # .is_private = 10.x, 172.16-31.x, 192.168.x + # .is_loopback = 127.x.x.x + # .is_link_local = 169.254.x.x + return ip_obj.is_private or ip_obj.is_loopback or ip_obj.is_link_local + except ValueError: + # Not 'valid' IP address -> treat as private + return True + + +@lru_cache(maxsize=4096) +def get_hostname_smart(ip_address: str, timeout_seconds: float = 0.2) -> str: + """ + Resolves an IP to a hostname, with optimizations: + 1. Caches all results. + 2. Immediately skips private/internal IPs (like K8s). + 3. Uses a short timeout for public IPs. + """ + + # If private IP (k8s), don't try to resolve it + if is_private_ip(ip_address): + return ip_address + + # If public IP, try to resolve it. 
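+    # NOTE: socket.setdefaulttimeout() mutates process-wide state, which is why
+    # the finally block below restores the previous value after the lookup.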
+ original_timeout = socket.getdefaulttimeout() + try: + socket.setdefaulttimeout(timeout_seconds) + + hostname, _, _ = socket.gethostbyaddr(ip_address) + return hostname + + except (socket.herror, socket.gaierror, socket.timeout): + return ip_address + + finally: + socket.setdefaulttimeout(original_timeout) diff --git a/src/drunc/data/process_manager/k8s.json b/src/drunc/data/process_manager/k8s.json index d37750c5d..2fedaf06c 100644 --- a/src/drunc/data/process_manager/k8s.json +++ b/src/drunc/data/process_manager/k8s.json @@ -42,7 +42,9 @@ "checking": { "watcher_retry_sleep": 5, "pod_status_check_sleep": 1, - "host_cache_expiry": 300 + "host_cache_expiry": 300, + "grpc_startup_timeout": 30, + "socket_retry_timeout": 1.0 } } } diff --git a/src/drunc/process_manager/k8s_process_manager.py b/src/drunc/process_manager/k8s_process_manager.py index 584796beb..ee1feb1bb 100644 --- a/src/drunc/process_manager/k8s_process_manager.py +++ b/src/drunc/process_manager/k8s_process_manager.py @@ -3,7 +3,10 @@ import os import re import signal +import socket import threading +import urllib.error +import urllib.request import uuid from time import sleep, time @@ -189,6 +192,8 @@ def __init__(self, configuration, **kwargs) -> None: self.watcher_retry_sleep = checking.get("watcher_retry_sleep", 5) self.pod_status_check_sleep = checking.get("pod_status_check_sleep", 1) self._host_cache_expiry = checking.get("host_cache_expiry", 300) + self.grpc_startup_timeout = checking.get("grpc_startup_timeout", 30) + self.socket_retry_timeout = checking.get("socket_retry_timeout", 1.0) self.log.debug(f"Using kill_timeout of {self.kill_timeout} seconds.") @@ -461,6 +466,7 @@ def _create_nodeport_service(self, podname, session, pod_uid) -> None: ), spec=client.V1ServiceSpec( type="NodePort", + external_traffic_policy="Local", selector={"app": podname}, ports=[ client.V1ServicePort( @@ -480,12 +486,37 @@ def _create_nodeport_service(self, podname, session, pod_uid) -> None: f'Created NodePort service "{session}.{podname}" on port {self.connection_server_port} ' f"(NodePort: {self.connection_server_node_port} for external access)" ) + except self._api_error_v1_api as e: - if e.status != 409: - self.log.error(f"Failed to create NodePort service for {podname}: {e}") + is_port_conflict = False + + # Check for 422="Unprocessable Entity" or 409="Conflict" status + if e.status == 422 or e.status == 409: + if e.body and ( + "provided nodeport is already allocated" in e.body.lower() + or "port is already in use" in e.body.lower() + ): + is_port_conflict = True + + if is_port_conflict: + port = self.connection_server_node_port + error_message = ( + f"NodePort {port} is already in use by another service. " + f"Cannot start '{podname}'." 
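+                    # Naming the conflicting port points operators straight at the offending service.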
+ ) + self.log.error(error_message) + raise DruncK8sException(error_message) from e + else: + # other K8s API error + error_message = f"Failed to create NodePort service for {podname}: {e.reason} ({e.status})" + self.log.error(error_message) + raise DruncK8sException(error_message) from e + + def _build_pod_main_container( + self, podname: str, boot_request: BootRequest, lcs_port: int | None + ) -> client.V1Container: + """Builds the primary V1Container manifest, including command and preStop hook.""" - def _create_pod(self, podname, session, boot_request: BootRequest) -> None: - """Constructs and creates a Kubernetes Pod manifest.""" pod_image = self.configuration.data.image exec_and_args_list = boot_request.process_description.executable_and_arguments @@ -506,6 +537,12 @@ def _create_pod(self, podname, session, boot_request: BootRequest) -> None: command_parts.append(prefix + " ".join([e_and_a.exec] + list(e_and_a.args))) main_command_str = " && ".join(command_parts) + container_ports = [] + if podname == self.connection_server_name and lcs_port is not None: + container_ports.append( + client.V1ContainerPort(container_port=lcs_port, name="http-port") + ) + # Only add preStop hook for C++ applications (non-controllers) lifecycle_hook = None if "controller" not in podname and podname != self.connection_server_name: @@ -525,59 +562,59 @@ def _create_pod(self, podname, session, boot_request: BootRequest) -> None: f"'{podname}' identified as a Python app, no preStop hook needed." ) - # Create container with conditional lifecycle hook - container_kwargs = { - "name": podname, - "image": pod_image, - "command": ["/bin/sh", "-c"], - "args": [main_command_str], - "env": [ + main_container = client.V1Container( + name=podname, + image=pod_image, + command=["/bin/sh", "-c"], + args=[main_command_str], + env=[ client.V1EnvVar(name=k, value=v) for k, v in boot_request.process_description.env.items() ], - "ports": [], - "volume_mounts": [ + lifecycle=lifecycle_hook, + ports=container_ports, + volume_mounts=[ client.V1VolumeMount(name="nfs", mount_path="/nfs"), client.V1VolumeMount(name="cvmfs", mount_path="/cvmfs"), ], - "working_dir": boot_request.process_description.process_execution_directory, - "security_context": client.V1SecurityContext( + working_dir=boot_request.process_description.process_execution_directory, + security_context=client.V1SecurityContext( run_as_user=os.getuid(), run_as_group=os.getgid() ), - } - - # Only add lifecycle hook for C++ applications - if lifecycle_hook is not None: - container_kwargs["lifecycle"] = lifecycle_hook - - main_container = client.V1Container(**container_kwargs) - - all_containers = [main_container] + ) + return main_container + def _get_pod_node_selector( + self, podname: str, restriction: ProcessRestriction + ) -> dict: + """Verifies the target host and returns the Kubernetes node selector.""" node_selector = {} - if boot_request.process_restriction.allowed_hosts: - target_host = boot_request.process_restriction.allowed_hosts[0] - # Resolve localhost to actual hostname for Kubernetes node selection + if restriction.allowed_hosts: + target_host = restriction.allowed_hosts[0] + if target_host == "localhost": target_host = resolve_localhost_to_hostname(target_host) self.log.info( f"Resolved localhost to '{target_host}' for node selection" ) - # Verify the target host is available in the cluster before scheduling self._verify_host_in_cluster(target_host) node_selector = {"kubernetes.io/hostname": target_host} self.log.info( f"Pod '{podname}' will be 
scheduled on node '{target_host}' (from boot request)" ) + return node_selector - host_aliases = [] + def _get_pod_host_aliases( + self, podname: str, session: str + ) -> list[client.V1HostAlias] | None: + """Gets the ClusterIP of the connection server and prepares host aliases.""" + host_aliases = None if ( podname != self.connection_server_name and self.local_connection_server_is_booted ): - # Wait for service to get ClusterIP connection_server_ip = None retry_count = 0 max_retries = 10 @@ -598,8 +635,18 @@ def _create_pod(self, podname, session, boot_request: BootRequest) -> None: self.log.warning( f"Could not get connection server ClusterIP for pod '{podname}'" ) - - pod_manifest = client.V1Pod( + return host_aliases + + def _build_pod_manifest( + self, + podname: str, + session: str, + main_container: client.V1Container, + node_selector: dict, + host_aliases: list[client.V1HostAlias] | None, + ) -> client.V1Pod: + """Assembles the final V1Pod object.""" + return client.V1Pod( api_version="v1", kind="Pod", metadata=self._meta_v1_api( @@ -614,7 +661,7 @@ def _create_pod(self, podname, session, boot_request: BootRequest) -> None: node_selector=node_selector, termination_grace_period_seconds=self.kill_timeout, restart_policy="Never", - containers=all_containers, + containers=[main_container], host_aliases=host_aliases if host_aliases else None, volumes=[ client.V1Volume( @@ -628,48 +675,133 @@ def _create_pod(self, podname, session, boot_request: BootRequest) -> None: ), ) - try: - start_time = time() - pod_uid = None + def _execute_pod_creation_api( + self, session: str, podname: str, pod_manifest: client.V1Pod + ) -> str: + """Executes the API call to create the pod, handling 409 conflict during restarts.""" + start_time = time() - while True: - try: - created_pod = self._core_v1_api.create_namespaced_pod( - session, pod_manifest - ) - self.log.info(f'Creating pod "{session}.{podname}"') - pod_uid = created_pod.metadata.uid - break + while True: + try: + created_pod = self._core_v1_api.create_namespaced_pod( + session, pod_manifest + ) + self.log.info(f'Creating pod "{session}.{podname}"') + return created_pod.metadata.uid - # this covers restart where we need to wait for cleanup - except self._api_error_v1_api as e: - is_409_conflict = e.status == 409 + except self._api_error_v1_api as e: + is_409_conflict = e.status == 409 + elapsed_time = time() - start_time - if ( - is_409_conflict - and time() - start_time < self.restart_cleanup_time - ): - sleep(self.restart_cleanup_polling) - continue - raise e + if is_409_conflict and elapsed_time < self.restart_cleanup_time: + sleep(self.restart_cleanup_polling) + continue - if podname == self.connection_server_name: + if is_409_conflict: + error_message = ( + f"Timeout (>{self.restart_cleanup_time}s) waiting for old pod object " + f'"{session}/{podname}" to be fully deleted. Could not restart pod.' + ) + self.log.error(error_message) + raise DruncK8sException(error_message) from e + + raise e + + def _create_associated_service( + self, + podname: str, + session: str, + pod_uid: str, + boot_request: BootRequest, + lcs_port: int | None, + ) -> None: + """Calls the appropriate service creation method based on pod type.""" + if podname == self.connection_server_name: + if lcs_port is None: + raise DruncK8sException( + "LCS service creation failed: port was not extracted." 
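+                    # Defensive guard: _create_pod extracts the LCS port before building
+                    # the container, so reaching this branch indicates a programming error.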
+ ) + + # If LCS, call nodeport service creation + self._create_nodeport_service(podname, session, pod_uid) + + elif "root-controller" in podname: + self.log.info( + f"'{podname}' is the root controller, checking for NodePort service." + ) + port = self._extract_port_from_cmd(boot_request) + if port: + self.log.info(f"Extracted port {port} for '{podname}' NodePort.") + self.connection_server_port = port + self.connection_server_node_port = port self._create_nodeport_service(podname, session, pod_uid) else: + self.log.warning( + f"Could not extract port for '{podname}', falling back to headless." + ) self._create_headless_service(podname, session, pod_uid) - except self._api_error_v1_api as e: - error_message = f'Couldn\'t create resources for pod "{session}.{podname}". Reason: {e.reason}. Kubernetes API Error: ({e.status})' + else: + self._create_headless_service(podname, session, pod_uid) - if e.status == 409 and time() - start_time >= self.restart_cleanup_time: - error_message = ( - f"Timeout (>{self.restart_cleanup_time}s) waiting for old pod object " - f'"{session}/{podname}" to be fully deleted. Could not restart pod.' - ) + def _create_pod(self, podname, session, boot_request: BootRequest) -> None: + """Constructs and creates a Kubernetes Pod manifest and its associated service.""" + try: + lcs_port = None + # Early Port Extraction and Class Variable Setup for LCS + if podname == self.connection_server_name: + lcs_port = self._extract_port_from_cmd(boot_request) + if lcs_port: + self.connection_server_port = lcs_port + self.connection_server_node_port = lcs_port + else: + raise DruncK8sException( + f"Could not extract port for LCS '{podname}'." + ) + + # Build the main container manifest + main_container = self._build_pod_main_container( + podname, boot_request, lcs_port + ) + + # Node_selector, host_aliases, pod_manifest + node_selector = self._get_pod_node_selector( + podname, boot_request.process_restriction + ) + host_aliases = self._get_pod_host_aliases(podname, session) + pod_manifest = self._build_pod_manifest( + podname, + session, + main_container, + node_selector, + host_aliases, + ) + + # Execute the pod creation API call + pod_uid = self._execute_pod_creation_api(session, podname, pod_manifest) + + # Create associated service + self._create_associated_service( + podname, session, pod_uid, boot_request, lcs_port + ) + + except self._api_error_v1_api as e: + # *other* K8s errors (e.g., 400, 403, 500) + error_message = f'Couldn\'t create resources for pod "{session}.{podname}". Reason: {e.reason}. Kubernetes API Error: ({e.status})' self.log.error(error_message) raise DruncK8sException(error_message) from e + except DruncK8sException: + # any other DruncK8sException + raise + + except Exception as e: + # generic catch-all + raise DruncK8sException( + f"Failed to create pod '{session}.{podname}': {e}" + ) from e + def _get_connection_server_cluster_ip(self, session) -> str: """Gets the ClusterIP of the connection server service.""" try: @@ -682,19 +814,78 @@ def _get_connection_server_cluster_ip(self, session) -> str: return None def _extract_port_from_cmd(self, boot_request) -> int | None: - # Find the gunicorn port argument from exec_and_args_list + """ + Parses the boot request's command arguments to find a port. + It must cover Gunicorn (hardcoded and env var) and drunc-controller. 
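+        Returns the first non-zero port found, or None if no usable port is present.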
+ """ + # Check all command parts for a port argument for e_and_a in boot_request.process_description.executable_and_arguments: - if "gunicorn" in e_and_a.exec or ( - "gunicorn" in " ".join(list(e_and_a.args)) - ): - all_args = [e_and_a.exec] + list(e_and_a.args) - arg_str = " ".join(all_args) - match = re.search(r"-b\s+[\w\.]+:(\d+)", arg_str) - if not match: - # Try to match '--bind' - match = re.search(r"--bind[\s=]+[\w\.]+:(\d+)", arg_str) + all_args = [e_and_a.exec] + list(e_and_a.args) + arg_str = " ".join(all_args) + + # Check for gunicorn bind syntax (for local-connection-server) + if "gunicorn" in arg_str: + match_hardcoded = re.search(r"(-b|--bind)[\s=]+[\w\.]+:(\d+)", arg_str) + + if match_hardcoded: + port = int(match_hardcoded.group(2)) + if port != 0: + self.log.info( + f"Extracted hardcoded gunicorn port {port} from command." + ) + return port + + # Match environment variable port: e.g., --bind=0.0.0.0:${CONNECTION_PORT} + match_var = re.search(r"(-b|--bind)[\s=]+[\w\.]+:\$\{(\w+)\}", arg_str) + + if match_var: + var_name = match_var.group(2) + # Look up the value in the environment variables + port_val = boot_request.process_description.env.get(var_name) + + if port_val is not None: + try: + port = int(port_val) + if port != 0: + self.log.info( + f"Extracted gunicorn port {port} from environment variable '{var_name}'." + ) + return port + except ValueError: + self.log.error( + f"Environment variable '{var_name}' ('{port_val}') is not an integer port." + ) + else: + self.log.warning( + f"Extracted port variable '{var_name}' but it was not found in environment map." + ) + + # Check for drunc-controller --port syntax (unchanged) + if "controller" in arg_str: + match = re.search(r"--port[\s=]+(\d+)", arg_str) + if match: + port = int(match.group(1)) + if port != 0: + self.log.info( + f"Extracted drunc-controller port {port} from command." + ) + return port + + # Check for drunc-controller -c grpc://... syntax (unchanged) + if "controller" in arg_str: + match = re.search(r"-c\s+[\"\']?grpc:\/\/[^:]+:(\d+)[\"\']?", arg_str) if match: - return int(match.group(1)) + port = int(match.group(1)) + if port != 0: + self.log.info( + f"Extracted drunc-controller gRPC port {port} from command." + ) + return port + else: + self.log.warning( + "Controller gRPC port is 0, cannot create NodePort." + ) + return None def _get_process_uid(self, query: ProcessQuery, order_by: str = None) -> list[str]: @@ -759,80 +950,198 @@ def _boot_impl(self, boot_request: BootRequest) -> ProcessInstanceList: process = self.__boot(boot_request, this_uuid) return ProcessInstanceList(values=[process]) - def __boot(self, boot_request: BootRequest, uuid: str) -> ProcessInstance: - """ - Internal boot method. Handles pod creation and special logic for the connection server. - - For the connection server: Wait for it to be ready and check the NodePort service - - For all other pods: Boot is NON-BLOCKING. - """ - session = boot_request.process_description.metadata.session - podname = boot_request.process_description.metadata.name - + def _run_pre_boot_checks( + self, session: str, podname: str, boot_request: BootRequest + ) -> None: + """Performs initial validation.""" if not validate_k8s_session_name(session): raise DruncK8sNamespaceException( f'Invalid session/namespace name "{session}". Must match RFC1123 label: ' "lowercase alphanumeric or '-', start/end with alphanumeric, max 63 chars." 
) - if boot_request.process_restriction.allowed_hosts: - hostname = boot_request.process_restriction.allowed_hosts[0] - boot_request.process_description.metadata.hostname = hostname + def _wait_for_pod_api_ready( + self, podname: str, session: str, timeout: float + ) -> str: + """ + [HELPER] Blocking wait for a pod to be 'Running' and 'Ready' + in the K8s API. + Returns the node_name on success. + Raises DruncK8sException on timeout. + """ + self.log.info( + f"Stage 1: Waiting for '{podname}' pod to be Running and Ready..." + ) + start_time = time() - if uuid in self.boot_request: - raise DruncK8sPodException(f'"{session}.{podname}":{uuid} already exists!') + while time() - start_time < timeout: + try: + pod_status = self._core_v1_api.read_namespaced_pod_status( + podname, session + ) + if pod_status.status.phase == "Running": + is_ready = False + if pod_status.status.conditions: + for condition in pod_status.status.conditions: + if condition.type == "Ready" and condition.status == "True": + is_ready = True + break + + if is_ready: + node_name = pod_status.spec.node_name + self.log.info( + f"Stage 1: Pod '{podname}' is API Ready on node {node_name}." + ) + return node_name # Success! - # Extract ports for LCS - if podname == self.connection_server_name: - self.log.info(f"Waiting for '{podname}' to become ready...") + except self._api_error_v1_api as e: + if e.status == 404: + # Pod not created yet, this is expected, continue loop + pass + else: + # Re-raise other K8s API errors + raise e - port = None - env_vars = boot_request.process_description.env + sleep(self.pod_status_check_sleep) - if "CONNECTION_PORT" in env_vars: - port_str = env_vars["CONNECTION_PORT"] - try: - port = int(port_str) - self.log.info( - f"Using port {port} from 'CONNECTION_PORT' environment variable." - ) - except (ValueError, TypeError): - raise DruncK8sException( - f"The provided CONNECTION_PORT '{port_str}' is not a valid integer." - ) + # If we exit the loop, it's a timeout + raise DruncK8sException( + f"'{podname}' pod did not become API Ready in {timeout} seconds." + ) - if port is None: - self.log.info( - "CONNECTION_PORT not found in env, falling back to parsing gunicorn command." - ) - port = self._extract_port_from_cmd(boot_request) + def _wait_for_nodeport_http_ready(self, url: str, timeout: float) -> None: + """ + [HELPER] Blocking wait for a NodePort URL to be reachable via HTTP. + Raises DruncK8sException on timeout. + """ + self.log.info(f"Stage 2: Waiting for NodePort {url} to be reachable...") + start_time = time() - if port: - self.connection_server_port = port - self.connection_server_node_port = port - else: - raise DruncK8sException( - "Could not determine connection server port from 'CONNECTION_PORT' env var or gunicorn command." - ) + while time() - start_time < timeout: + try: + urllib.request.urlopen(url, timeout=1) + self.log.info(f"Stage 2: NodePort {url} is now active.") + return # Success! 
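+            # Any of the errors below just means the service is not routable yet;
+            # keep polling until the deadline expires.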
+ except ( + urllib.error.URLError, + ConnectionRefusedError, + TimeoutError, + OSError, + ) as e: + self.log.debug(f"NodePort not ready yet ({e}), retrying...") + sleep(self.pod_status_check_sleep) - # Check for NodePort collision - api = self._core_v1_api - all_services = api.list_service_for_all_namespaces() - for svc in all_services.items: - if not svc.spec.type == "NodePort": - continue - for p in svc.spec.ports: - if p.node_port == self.connection_server_node_port and ( - svc.metadata.namespace != session - or svc.metadata.name != podname - ): - raise DruncK8sException( - f"NodePort {self.connection_server_node_port} is already in use by service " - f"{svc.metadata.name} in namespace {svc.metadata.namespace}. " - "Cannot start another local connection server with the same port." + raise DruncK8sException( + f"NodePort {url} did not become reachable in {timeout} seconds." + ) + + def _wait_for_nodeport_tcp_ready( + self, node_name: str, port: int, timeout: float + ) -> None: + """ + [HELPER] Blocking wait for a NodePort to be reachable via TCP socket. + Raises DruncK8sException on timeout. + """ + self.log.info( + f"Stage 2: Waiting for NodePort {node_name}:{port} to be reachable..." + ) + start_time = time() + + while time() - start_time < timeout: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.settimeout(self.socket_retry_timeout) + result = sock.connect_ex((node_name, port)) + + if result == 0: + self.log.info( + f"Stage 2: NodePort {node_name}:{port} is active (TCP connect success)." + ) + return + else: + self.log.debug( + f"NodePort {node_name}:{port} not ready yet (socket error {result}), retrying..." ) - self._create_namespace(session) + except socket.gaierror as e: + self.log.warning( + f"Failed to resolve hostname '{node_name}': {e}. Retrying..." + ) + except Exception as e: + self.log.debug( + f"NodePort not ready yet (Socket error: {e}), retrying..." + ) + + sleep(self.pod_status_check_sleep) + + raise DruncK8sException( + f"NodePort {node_name}:{port} did not become reachable in {timeout} seconds." + ) + + def _wait_for_lcs_readiness(self, podname: str, session: str) -> None: + """Blocking two-stage wait for the Local Connection Server (NodePort) to be fully ready.""" + self.log.info(f"Waiting for LCS '{podname}' to be fully ready...") + start_time = time() + total_timeout = self.pod_ready_timeout + + # --- STAGE 1: Wait for Pod to be Running/Ready in K8s API --- + node_name = self._wait_for_pod_api_ready(podname, session, total_timeout) + + # --- STAGE 2: Wait for NodePort to be externally reachable (using HTTP urllib) --- + url = f"http://{node_name}:{self.connection_server_node_port}" + + # Calculate remaining time for stage 2, preserving original logic + elapsed_stage1 = time() - start_time + remaining_time = total_timeout - elapsed_stage1 + + if remaining_time <= 0: + raise DruncK8sException( + f"NodePort {url} check failed: No time left after API readiness." + ) + + self._wait_for_nodeport_http_ready(url, remaining_time) + + self.local_connection_server_is_booted = True + self.log.info(f"Connection server '{podname}' is fully ready.") + + def _wait_for_controller_readiness( + self, podname: str, session: str, boot_request: BootRequest + ) -> None: + """Blocking two-stage wait for Drunc Controller (NodePort) to be fully ready.""" + self.log.info( + f"Waiting for controller '{podname}' (NodePort) to become ready..." 
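+            # Controllers speak gRPC, so stage 2 probes the NodePort with a raw
+            # TCP connect rather than an HTTP request.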
+ ) + + controller_port = self._extract_port_from_cmd(boot_request) + if not controller_port or controller_port == 0: + raise DruncK8sException( + f"Cannot wait for '{podname}', port is 0 or missing." + ) + + # --- STAGE 1: Wait for Pod to be Running/Ready in K8s API --- + node_name = self._wait_for_pod_api_ready( + podname, session, self.pod_ready_timeout + ) + + # --- STAGE 2: Wait for NodePort to be externally reachable (using TCP socket) --- + self._wait_for_nodeport_tcp_ready( + node_name, controller_port, self.grpc_startup_timeout + ) + + self.log.info(f"Drunc controller '{podname}' is fully ready.") + + def __boot(self, boot_request: BootRequest, uuid: str) -> ProcessInstance: + """ + Internal boot method. Handles pre-checks, pod creation, and blocking wait for critical services. + """ + session = boot_request.process_description.metadata.session + podname = boot_request.process_description.metadata.name + + # Pre-checks (Session validation, NodePort collision) + self._run_pre_boot_checks(session, podname, boot_request) + # Resource Creation (Namespace, Pod, Labels) + self._create_namespace(session) self.boot_request[uuid] = BootRequest() self.boot_request[uuid].CopyFrom(boot_request) @@ -840,50 +1149,20 @@ def __boot(self, boot_request: BootRequest, uuid: str) -> ProcessInstance: self._add_label(podname, "pod", "uuid", uuid, session=session) self.log.info(f'"{session}.{podname}":{uuid} boot request sent.') - # Special handling only for the connection server + # Special handling and blocking wait for critical processes if podname == self.connection_server_name: - self.log.info(f"Waiting for '{podname}' to become ready...") - - start_time = time() - while time() - start_time < self.pod_ready_timeout: - try: - pod_status = self._core_v1_api.read_namespaced_pod_status( - podname, session - ) - if ( - pod_status.status.phase == "Running" - and pod_status.status.pod_ip - ): - self.log.info( - f"'{podname}' is ready with IP {pod_status.status.pod_ip}." - ) - self.local_connection_server_is_booted = True - - # Log connection information using the NodePort service - node_name = pod_status.spec.node_name - self.log.info(f"Connection server '{podname}' is ready.") - self.log.info( - f" -> For internal cluster access: 'http://localhost:{self.connection_server_port}'" - ) - self.log.info( - f" -> For external access, use NodePort {self.connection_server_node_port} on any cluster node IP (e.g., http://{node_name}:{self.connection_server_node_port})" - ) - - break - except self._api_error_v1_api as e: - if e.status == 404: - pass - else: - raise e - sleep(self.pod_status_check_sleep) - else: - raise DruncK8sException( - f"'{podname}' did not become ready in {self.pod_ready_timeout} seconds." 
- ) - - pd, pr, pu = ProcessDescription(), ProcessRestriction(), ProcessUUID(uuid=uuid) - pd.CopyFrom(self.boot_request[uuid].process_description) - pr.CopyFrom(self.boot_request[uuid].process_restriction) + self._wait_for_lcs_readiness(podname, session) + elif "root-controller" in podname: + self._wait_for_controller_readiness(podname, session, boot_request) + + # Post-Process + pd, pr, pu = ( + ProcessDescription(), + ProcessRestriction(), + ProcessUUID(uuid=uuid), + ) + pd.CopyFrom(boot_request.process_description) + pr.CopyFrom(boot_request.process_restriction) return ProcessInstance( process_description=pd, diff --git a/src/drunc/process_manager/process_manager_driver.py b/src/drunc/process_manager/process_manager_driver.py index 489f2b191..4feadbc30 100644 --- a/src/drunc/process_manager/process_manager_driver.py +++ b/src/drunc/process_manager/process_manager_driver.py @@ -28,7 +28,10 @@ from drunc.connectivity_service.exceptions import ApplicationLookupUnsuccessful from drunc.controller.utils import get_segment_lookup_timeout from drunc.exceptions import DruncSetupException, DruncShellException -from drunc.process_manager.utils import get_log_path, get_rte_script +from drunc.process_manager.utils import ( + get_log_path, + get_rte_script, +) from drunc.utils.grpc_utils import ( copy_token, extract_grpc_rich_error, @@ -350,6 +353,14 @@ def _connect_to_service( if session_dal.connectivity_service: connection_server = session_dal.connectivity_service.host connection_port = session_dal.connectivity_service.service.port + + if connection_server == "localhost": + resolved_server = resolve_localhost_to_hostname(connection_server) + self.log.debug( + f"Resolved connection server 'localhost' to '{resolved_server}' to avoid K8s hairpinning." + ) + connection_server = resolved_server + client = ConnectivityServiceClient( session_name, f"{connection_server}:{connection_port}" ) @@ -366,19 +377,30 @@ def _discover_controller( ): """ Attempts to discover the controller address after booting applications. + Tries dynamic lookup via connectivity service first, then falls back + to static OKS configuration. 
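+        Either way, the address is returned as a plain "host:port" string with no scheme prefix.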
""" - top_controller_name = session_dal.segment.controller.id + try: + top_controller_name = session_dal.segment.controller.id + except AttributeError as e: + self.log.error(f"Could not determine controller name from OKS: {e}") + top_controller_name = "Unknown-Controller" # Set a default def get_controller_address(session_dal, session_name): from drunc.process_manager.oks_parser import collect_variables env = {} collect_variables(session_dal.environment, env) + + # 1: Try dynamic lookup via Connectivity Service if csc: + self.log.debug( + f"Attempting to discover controller '{top_controller_name}' via connectivity service at {connection_server}:{connection_port}" + ) try: timeout = ( get_segment_lookup_timeout(session_dal.segment, 60) + 60 - ) # root-controller timout to find all its children + 60s for the root controller to start itself + ) # root-controller timeout to find all its children + 60s for the root controller to start itself self.log.debug( f"Using a timeout of {timeout}s to find the [green]{top_controller_name}[/] on the connectivity service" ) @@ -390,33 +412,149 @@ def get_controller_address(session_dal, session_name): progress_bar=True, title=f"Looking for [green]{top_controller_name}[/] on the connectivity service...", ) + + address = uri.replace("grpc://", "") + self.log.debug( + f"Successfully discovered controller '{top_controller_name}' via connectivity service: {address}" + ) + return address + except ApplicationLookupUnsuccessful: + self.log.warning( + f"Connectivity service lookup failed: Application '{top_controller_name}' not found." + ) + # Log the original failure details self._log_controller_lookup_failure( session_name, top_controller_name, connection_server, connection_port, ) - return + self.log.warning( + "Falling back to static OKS configuration for address resolution." + ) - return uri.replace("grpc://", "") + except Exception as e: + self.log.error( + f"An unexpected error occurred during connectivity service lookup: {e}. " + "Falling back to static OKS configuration." + ) + + else: + self.log.warning( + "Connectivity service client (csc) is not available. Using static OKS configuration only." + ) + + # 2: Fallback to static OKS configuration + self.log.debug( + "Attempting to resolve controller address from static OKS configuration." + ) - service_id = top_controller_name + "_control" port_number = None protocol = None + service_found = None + + try: + self.log.debug( + f"Top controller name from OKS config: '{top_controller_name}'" + ) - for service in session_dal.segment.controller.exposes_service: - if service.id == service_id: - port_number = service.port - protocol = service.protocol - break + if ( + not hasattr(session_dal.segment.controller, "exposes_service") + or not session_dal.segment.controller.exposes_service + ): + self.log.error( + f"Controller '{top_controller_name}' in OKS config has no 'exposes_service' relationship defined or it's empty." 
+ ) + return None + + self.log.debug( + f"Controller '{top_controller_name}' exposes services: {[s.id for s in session_dal.segment.controller.exposes_service]}" + ) + + # Get the first (and presumably only) control service linked + service_found = next( + iter(session_dal.segment.controller.exposes_service), None + ) + + if service_found: + self.log.debug( + f"Found linked control service object with ID: '{service_found.id}'" + ) + if ( + hasattr(service_found, "port") + and service_found.port is not None + ): + port_number = service_found.port + self.log.debug( + f"Extracted port from service '{service_found.id}': {port_number}" + ) + else: + self.log.error( + f"Service object '{service_found.id}' is missing the 'port' attribute or it's null." + ) + + if hasattr(service_found, "protocol") and service_found.protocol: + protocol = service_found.protocol + self.log.debug( + f"Extracted protocol from service '{service_found.id}': {protocol}" + ) + else: + self.log.error( + f"Service object '{service_found.id}' is missing the 'protocol' attribute or it's empty." + ) + + else: + self.log.error( + f"Could not retrieve the first service object from 'exposes_service' for controller '{top_controller_name}'." + ) + return None + + except AttributeError as e: + self.log.error( + f"Error accessing OKS configuration attributes: {e}. Check structure around session_dal.segment.controller." + ) + return None + except Exception as e: + self.log.error( + f"Unexpected error during service discovery from OKS: {e}" + ) + return None + + # Check if we successfully got a port and protocol if port_number is None or protocol is None: + self.log.error( + f"Failed to extract valid port ({port_number}) or protocol ({protocol}) for service '{service_found.id if service_found else 'N/A'}'. Cannot determine controller address." + ) + return None + + # Resolve the IP address of the host where the controller runs + try: + host_id = session_dal.segment.controller.runs_on.runs_on.id + self.log.debug(f"Controller runs on host ID: '{host_id}'") + ip = resolve_localhost_and_127_ip_to_network_ip(host_id) + self.log.debug(f"Resolved host ID '{host_id}' to IP: {ip}") + except AttributeError as e: + self.log.error( + f"Error accessing OKS configuration attributes for host resolution: {e}. Check structure around session_dal.segment.controller.runs_on." + ) + return None + except Exception as e: + self.log.error(f"Unexpected error during host IP resolution: {e}") + return None + + if not ip: + self.log.error( + f"Host ID '{host_id}' resolved to an empty or invalid IP address." 
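+                # With no host IP, the static OKS fallback cannot produce an address either.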
+ ) return None - ip = resolve_localhost_and_127_ip_to_network_ip( - session_dal.segment.controller.runs_on.runs_on.id + # If all checks passed, return the address + final_address = f"{ip}:{port_number}" + self.log.debug( + f"Successfully resolved controller address from OKS config: {final_address}" ) - return f"{ip}:{port_number}" + return final_address def keyboard_interrupt_on_sigint(signal, frame): self.log.warning("Interrupted") @@ -431,7 +569,7 @@ def keyboard_interrupt_on_sigint(signal, frame): connection_server = session_dal.connectivity_service.host connection_port = session_dal.connectivity_service.service.port self._log_controller_interrupt( - self, top_controller_name, connection_server, connection_port + top_controller_name, connection_server, connection_port ) else: self.log.warning( diff --git a/src/drunc/utils/grpc_utils.py b/src/drunc/utils/grpc_utils.py index 972a91e2a..67a447280 100644 --- a/src/drunc/utils/grpc_utils.py +++ b/src/drunc/utils/grpc_utils.py @@ -5,7 +5,7 @@ from druncschema.generic_pb2 import PlainText from druncschema.request_response_pb2 import Response, ResponseFlag from druncschema.token_pb2 import Token -from google.protobuf import any_pb2 +from google.protobuf import any_pb2, json_format from google.protobuf.descriptor import FieldDescriptor from google.protobuf.message import Message from google.rpc import code_pb2, error_details_pb2 @@ -311,3 +311,22 @@ def extract_grpc_rich_error(grpc_error: grpc.RpcError) -> GrpcErrorDetails: return GrpcErrorDetails( code=code, message=status.message or "No message", details=error_details ) + + +def grpc_proto_to_dict(proto_message: Message) -> dict: + """ + Converts a gRPC Protobuf message object to a Python dictionary. + """ + return json_format.MessageToDict( + proto_message, + preserving_proto_field_name=True, + # Removed: including_default_value_fields=True + ) + + +def dict_to_grpc_proto(data: dict, proto_class_instance: Message) -> Message: + """ + Converts a Python dictionary into an instance of a gRPC Protobuf message. + 'proto_class_instance' should be an empty instance, e.g., Token() + """ + return json_format.ParseDict(data, proto_class_instance, ignore_unknown_fields=True) diff --git a/tests/controller/test_controller_driver.py b/tests/controller/test_controller_driver.py index 571b49b55..9be7b2abf 100644 --- a/tests/controller/test_controller_driver.py +++ b/tests/controller/test_controller_driver.py @@ -6,10 +6,18 @@ from drunc.controller.controller_driver import ControllerDriver from drunc.exceptions import DruncException from drunc.utils.shell_utils import create_dummy_token_from_uname +import json +import pprint def setup_controller_driver(processes_and_logs, dal, session_name) -> ControllerDriver: connectivity_service_port = dal.connectivity_service.service.port + try: + print(json.dumps(processes_and_logs, indent=2, sort_keys=True, default=str)) + except Exception: + pprint.pprint(processes_and_logs, width=120, sort_dicts=True) + print(f"{dal=}") + print(f"{session_name=}") csc = ConnectivityServiceClient( session_name, diff --git a/tests/process_manager/test_process_manager_driver.py b/tests/process_manager/test_process_manager_driver.py index 6357efbbc..77c712dd5 100644 --- a/tests/process_manager/test_process_manager_driver.py +++ b/tests/process_manager/test_process_manager_driver.py @@ -9,6 +9,7 @@ or if a bug was introduced. 
""" +import socket from unittest.mock import MagicMock, patch import grpc @@ -485,15 +486,22 @@ def test_connect_to_service_success(mock_client_class, mock_driver): mock_client_instance = MagicMock() mock_client_class.return_value = mock_client_instance + pytest_hostname = socket.gethostname() + mock_session_dal = MagicMock() mock_session_dal.connectivity_service.host = "localhost" mock_session_dal.connectivity_service.service.port = 1234 - result = mock_driver._connect_to_service(mock_session_dal, "session1") + result_localhost = mock_driver._connect_to_service(mock_session_dal, "session1") + mock_client_class.assert_called_once_with("session1", f"{pytest_hostname}:1234") + assert result_localhost == (mock_client_instance, pytest_hostname, 1234) - mock_client_class.assert_called_once_with("session1", "localhost:1234") + mock_session_dal.connectivity_service.host = pytest_hostname + result_pytest_hostname = mock_driver._connect_to_service(mock_session_dal, "session2") + mock_client_class.assert_called_with("session2", f"{pytest_hostname}:1234") + assert result_pytest_hostname == (mock_client_instance, pytest_hostname, 1234) - assert result == (mock_client_instance, "localhost", 1234) + mock_client_class.assert_called_with("session2", f"{pytest_hostname}:1234") def test_connect_to_service_none(mock_driver):