diff --git a/pyproject.toml b/pyproject.toml index 8d8313ab..68038ae4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ classifiers = [ ] license = { text = "BSD-3-Clause" } requires-python = ">=3.10" -dependencies = ["bidict", "pika", "setuptools", "stomp-py>=7"] +dependencies = ["bidict", "pika", "setuptools", "stomp-py>=7", "opentelemetry-api~=1.20.0", "opentelemetry-sdk~=1.20.0", "opentelemetry-exporter-otlp-proto-http~=1.20.0" ] [project.urls] Download = "https://github.com/DiamondLightSource/python-workflows/releases" @@ -53,6 +53,7 @@ OfflineTransport = "workflows.transport.offline_transport:OfflineTransport" pika = "workflows.util.zocalo.configuration:Pika" stomp = "workflows.util.zocalo.configuration:Stomp" transport = "workflows.util.zocalo.configuration:DefaultTransport" +opentelemetry = "workflows.util.zocalo.configuration:OTEL" [project.scripts] "workflows.validate_recipe" = "workflows.recipe.validate:main" diff --git a/requirements_dev.txt b/requirements_dev.txt index 8207c45b..89483de8 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -6,4 +6,4 @@ pytest-cov==6.0.0 pytest-mock==3.14.0 pytest-timeout==2.3.1 stomp-py==8.1.2 -websocket-client==1.8.0 +websocket-client==1.8.0 \ No newline at end of file diff --git a/src/workflows/recipe/__init__.py b/src/workflows/recipe/__init__.py index 0f1973f4..7bbe3ed1 100644 --- a/src/workflows/recipe/__init__.py +++ b/src/workflows/recipe/__init__.py @@ -3,8 +3,11 @@ import functools import logging from collections.abc import Callable +from contextlib import ExitStack from typing import Any +from opentelemetry import trace + from workflows.recipe.recipe import Recipe from workflows.recipe.validate import validate_recipe from workflows.recipe.wrapper import RecipeWrapper @@ -69,10 +72,35 @@ def unwrap_recipe(header, message): message = mangle_for_receiving(message) if header.get("workflows-recipe") in {True, "True", "true", 1}: rw = RecipeWrapper(message=message, transport=transport_layer) - if log_extender and rw.environment and rw.environment.get("ID"): - with log_extender("recipe_ID", rw.environment["ID"]): + + if log_extender and rw.environment["ID"]: + # Extract recipe ID from environment and add to current span + span = trace.get_current_span() + recipe_id = rw.environment["ID"] + span.set_attribute("recipe_id", recipe_id) + + # Extract span_id and trace_id for logging + span_context = span.get_span_context() + otel_logs = None + if span_context.is_valid: + span_id = span_context.span_id + trace_id = span_context.trace_id + + otel_logs = { + "span_id": span_id, + "trace_id": trace_id, + "recipe_id": recipe_id, + } + + with ExitStack() as stack: + # Configure the context depending on if service is emitting spans + stack.enter_context(log_extender("recipe_ID", recipe_id)) + if otel_logs: + stack.enter_context(log_extender("otel_logs", otel_logs)) return callback(rw, header, message.get("payload")) + return callback(rw, header, message.get("payload")) + if allow_non_recipe_messages: return callback(None, header, message) # self.log.warning('Discarding non-recipe message:\n' + \ diff --git a/src/workflows/services/common_service.py b/src/workflows/services/common_service.py index de2ef704..9223fc3a 100644 --- a/src/workflows/services/common_service.py +++ b/src/workflows/services/common_service.py @@ -9,8 +9,15 @@ import time from typing import Any +from opentelemetry import trace +from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter +from opentelemetry.sdk.resources import SERVICE_NAME, Resource +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import BatchSpanProcessor + import workflows import workflows.logging +from workflows.transport.middleware.otel_tracing import OTELTracingMiddleware class Status(enum.Enum): @@ -185,6 +192,40 @@ def start_transport(self): self.transport.subscription_callback_set_intercept( self._transport_interceptor ) + + # Configure OTELTracing if configuration is available + otel_config = ( + self.config._opentelemetry + if self.config and hasattr(self.config, "_opentelemetry") + else None + ) + if otel_config: + # Configure OTELTracing + resource = Resource.create( + { + SERVICE_NAME: self._service_name, + } + ) + + self.log.debug("Configuring OTELTracing") + provider = TracerProvider(resource=resource) + trace.set_tracer_provider(provider) + + # Configure BatchProcessor and OTLPSpanExporter using config values + otlp_exporter = OTLPSpanExporter( + endpoint=otel_config["endpoint"], + timeout=otel_config.get("timeout", 10), + ) + span_processor = BatchSpanProcessor(otlp_exporter) + provider.add_span_processor(span_processor) + + # Add OTELTracingMiddleware to the transport layer + tracer = trace.get_tracer(__name__) + otel_middleware = OTELTracingMiddleware( + tracer, service_name=self._service_name + ) + self._transport.add_middleware(otel_middleware) + metrics = self._environment.get("metrics") if metrics: import prometheus_client diff --git a/src/workflows/transport/middleware/otel_tracing.py b/src/workflows/transport/middleware/otel_tracing.py new file mode 100644 index 00000000..f27a5839 --- /dev/null +++ b/src/workflows/transport/middleware/otel_tracing.py @@ -0,0 +1,227 @@ +from __future__ import annotations + +import functools +from collections.abc import Callable + +from opentelemetry import trace +from opentelemetry.context import Context +from opentelemetry.propagate import extract, inject + +from workflows.transport.common_transport import MessageCallback, TemporarySubscription + + +class OTELTracingMiddleware: + def __init__(self, tracer: trace.Tracer, service_name: str): + self.tracer = tracer + self.service_name = service_name + + def _set_span_attributes(self, span, **attributes): + """Helper method to set common span attributes""" + span.set_attribute("service_name", self.service_name) + for key, value in attributes.items(): + if value is not None: + span.set_attribute(key, value) + + def send(self, call_next: Callable, destination: str, message, **kwargs): + # Get current span context (may be None if this is the root span) + current_span = trace.get_current_span() + parent_context = ( + trace.set_span_in_context(current_span) if current_span else None + ) + + with self.tracer.start_as_current_span( + "transport.send", + context=parent_context, + ) as span: + self._set_span_attributes(span, destination=destination) + + # Inject the current trace context into the message headers + headers = kwargs.get("headers", {}) + if headers is None: + headers = {} + inject(headers) # This modifies headers in-place + kwargs["headers"] = headers + + return call_next(destination, message, **kwargs) + + def subscribe( + self, call_next: Callable, channel: str, callback: Callable, **kwargs + ) -> int: + @functools.wraps(callback) + def wrapped_callback(header, message): + # Extract trace context from message headers + ctx = extract(header) if header else Context() + + # Start a new span with the extracted context + with self.tracer.start_as_current_span( + "transport.subscribe", + context=ctx, + ) as span: + self._set_span_attributes(span, channel=channel) + + # Call the original callback - this will process the message + # and potentially call send() which will pick up this context + return callback(header, message) + + return call_next(channel, wrapped_callback, **kwargs) + + def subscribe_broadcast( + self, call_next: Callable, channel: str, callback: Callable, **kwargs + ) -> int: + @functools.wraps(callback) + def wrapped_callback(header, message): + # Extract trace context from message headers + ctx = extract(header) if header else Context() + + # Start a new span with the extracted context + with self.tracer.start_as_current_span( + "transport.subscribe_broadcast", + context=ctx, + ) as span: + self._set_span_attributes(span, channel=channel) + + return callback(header, message) + + return call_next(channel, wrapped_callback, **kwargs) + + def subscribe_temporary( + self, + call_next: Callable, + channel_hint: str | None, + callback: MessageCallback, + **kwargs, + ) -> TemporarySubscription: + @functools.wraps(callback) + def wrapped_callback(header, message): + # Extract trace context from message headers + ctx = extract(header) if header else Context() + + # Start a new span with the extracted context + with self.tracer.start_as_current_span( + "transport.subscribe_temporary", + context=ctx, + ) as span: + self._set_span_attributes(span, channel_hint=channel_hint) + + return callback(header, message) + + return call_next(channel_hint, wrapped_callback, **kwargs) + + def unsubscribe( + self, + call_next: Callable, + subscription: int, + drop_callback_reference=False, + **kwargs, + ): + # Get current span context + current_span = trace.get_current_span() + current_context = ( + trace.set_span_in_context(current_span) if current_span else Context() + ) + + with self.tracer.start_as_current_span( + "transport.unsubscribe", + context=current_context, + ) as span: + self._set_span_attributes(span, subscription_id=subscription) + + call_next( + subscription, drop_callback_reference=drop_callback_reference, **kwargs + ) + + def ack( + self, + call_next: Callable, + message, + subscription_id: int | None = None, + **kwargs, + ): + # Get current span context + current_span = trace.get_current_span() + current_context = ( + trace.set_span_in_context(current_span) if current_span else Context() + ) + + with self.tracer.start_as_current_span( + "transport.ack", + context=current_context, + ) as span: + self._set_span_attributes(span, subscription_id=subscription_id) + + call_next(message, subscription_id=subscription_id, **kwargs) + + def nack( + self, + call_next: Callable, + message, + subscription_id: int | None = None, + **kwargs, + ): + # Get current span context + current_span = trace.get_current_span() + current_context = ( + trace.set_span_in_context(current_span) if current_span else Context() + ) + + with self.tracer.start_as_current_span( + "transport.nack", + context=current_context, + ) as span: + self._set_span_attributes(span, subscription_id=subscription_id) + + call_next(message, subscription_id=subscription_id, **kwargs) + + def transaction_begin( + self, call_next: Callable, subscription_id: int | None = None, **kwargs + ) -> int: + """Start a new transaction span""" + # Get current span context (may be None if this is the root span) + current_span = trace.get_current_span() + current_context = ( + trace.set_span_in_context(current_span) if current_span else Context() + ) + + with self.tracer.start_as_current_span( + "transaction.begin", + context=current_context, + ) as span: + self._set_span_attributes(span, subscription_id=subscription_id) + + return call_next(subscription_id=subscription_id, **kwargs) + + def transaction_abort( + self, call_next: Callable, transaction_id: int | None = None, **kwargs + ): + """Abort a transaction span""" + # Get current span context + current_span = trace.get_current_span() + current_context = ( + trace.set_span_in_context(current_span) if current_span else Context() + ) + + with self.tracer.start_as_current_span( + "transaction.abort", + context=current_context, + ) as span: + self._set_span_attributes(span, transaction_id=transaction_id) + + call_next(transaction_id=transaction_id, **kwargs) + + def transaction_commit( + self, call_next: Callable, transaction_id: int | None = None, **kwargs + ): + """Commit a transaction span""" + # Get current span context + current_span = trace.get_current_span() + current_context = ( + trace.set_span_in_context(current_span) if current_span else Context() + ) + + with self.tracer.start_as_current_span( + "transaction.commit", + context=current_context, + ) as span: + self._set_span_attributes(span, transaction_id=transaction_id) + + call_next(transaction_id=transaction_id, **kwargs) diff --git a/src/workflows/util/zocalo/configuration.py b/src/workflows/util/zocalo/configuration.py index 08a600aa..359fe37c 100644 --- a/src/workflows/util/zocalo/configuration.py +++ b/src/workflows/util/zocalo/configuration.py @@ -8,6 +8,26 @@ from workflows.transport.stomp_transport import StompTransport +class OTEL: + """A Zocalo configuration plugin to pre-populate OTELTracing config defaults""" + + class Schema(PluginSchema): + host = fields.Str(required=True) + port = fields.Int(required=True) + timeout = fields.Int(required=False, load_default=10) + + # Store configuration for access by services + config = {} + + @staticmethod + def activate(configuration): + # Build the full endpoint URL + endpoint = f"https://{configuration['host']}:{configuration['port']}/v1/traces" + OTEL.config["endpoint"] = endpoint + OTEL.config["timeout"] = configuration.get("timeout", 10) + return OTEL.config + + class Stomp: """A Zocalo configuration plugin to pre-populate StompTransport config defaults"""