diff --git a/src/drunc/utils/__init__.py b/src/drunc/utils/__init__.py
index 786dda6c3..4ebf79608 100644
--- a/src/drunc/utils/__init__.py
+++ b/src/drunc/utils/__init__.py
@@ -1,3 +1,5 @@
+"""drunc utilities module."""
+
from drunc.utils.utils import get_logger
# Initialise utils logger with Rich handler
diff --git a/src/drunc/utils/configuration.py b/src/drunc/utils/configuration.py
index 99d08db7b..d5747a7e6 100644
--- a/src/drunc/utils/configuration.py
+++ b/src/drunc/utils/configuration.py
@@ -1,6 +1,9 @@
+"""Configuration utilities for DRUNC."""
+
import json
import os
from enum import Enum
+from typing import Protocol, cast
import conffwk
@@ -9,6 +12,8 @@
class ConfTypes(Enum):
+ """Enumeration of supported configuration types."""
+
Unknown = 0
# End product
@@ -21,6 +26,17 @@ class ConfTypes(Enum):
def CLI_to_ConfTypes(scheme: str) -> ConfTypes:
+ """Convert a CLI scheme string to a ConfTypes enum.
+
+ Args:
+ scheme: The scheme string ("file", "oksconflibs", or "").
+
+ Returns:
+ ConfTypes: The corresponding configuration type.
+
+ Raises:
+ DruncSetupException: If the scheme is not recognized.
+ """
match scheme:
case "file":
return ConfTypes.JsonFileName
@@ -31,20 +47,43 @@ def CLI_to_ConfTypes(scheme: str) -> ConfTypes:
def parse_conf_url(url: str) -> tuple[str, ConfTypes]:
+ """Parse a configuration URL into scheme and type.
+
+ Args:
+ url: The configuration URL (format: "scheme:filename").
+
+ Returns:
+ tuple[str, ConfTypes]: A tuple of (url, conf_type).
+ """
scheme, filename = url.split(":")
t = CLI_to_ConfTypes(scheme)
return url, t
class ConfigurationNotFound(DruncSetupException):
- def __init__(self, requested_path):
+ """Exception raised when configuration is not found."""
+
+ def __init__(self, requested_path: str) -> None:
+ """Initialize the ConfigurationNotFound exception.
+
+ Args:
+ requested_path: The path to the configuration that was not found.
+ """
super().__init__(
f"The configuration '{requested_path}' is not in $DUNEDAQ_DB_PATH, perhaps you forgot to 'dbt-workarea-env && dbt-build'?"
)
class ConfTypeNotSupported(DruncSetupException):
- def __init__(self, conf_type: ConfTypes, class_name: str):
+ """Exception raised when a configuration type is not supported."""
+
+ def __init__(self, conf_type: ConfTypes, class_name: str) -> None:
+ """Initialize the ConfTypeNotSupported exception.
+
+ Args:
+ conf_type: The configuration type that is not supported.
+ class_name: The name of the class where this type is not supported.
+ """
if not isinstance(class_name, str):
class_name = class_name.__class__.__name__
message = f"'{conf_type}' is not supported by '{class_name}'"
@@ -52,22 +91,62 @@ def __init__(self, conf_type: ConfTypes, class_name: str):
class OKSKey:
- def __init__(self, schema_file: str, class_name: str, obj_uid: str, session: str):
+ """Key information for accessing OKS configuration objects."""
+
+ def __init__(
+ self, schema_file: str, class_name: str, obj_uid: str, session: str
+ ) -> None:
+ """Initialize an OKSKey.
+
+ Args:
+ schema_file: The OKS schema file path.
+ class_name: The class name in the OKS schema.
+ obj_uid: The unique identifier for the object.
+ session: The session name.
+ """
self.schema_file = schema_file
self.class_name = class_name
self.obj_uid = obj_uid
self.session = session
+class _DataTypeName(Protocol):
+ _name_: str
+
+
+class _ConfigurationData(Protocol):
+ type: _DataTypeName
+ broadcaster: object
+ authoriser: object
+
+
class ConfHandler:
+ """Handler for loading and parsing DRUNC configurations.
+
+ Supports multiple configuration types including JSON files, Protobuf messages,
+ and OKS.
+ """
+
def __init__(
self,
- data=None,
- type=ConfTypes.PyObject,
- oks_key: OKSKey = None,
- *args,
- **kwargs,
- ):
+ data: object = None,
+ type: ConfTypes = ConfTypes.PyObject,
+ oks_key: OKSKey | None = None,
+ *args: object,
+ **kwargs: object,
+ ) -> None:
+ """Initialize a ConfHandler.
+
+ Args:
+ data: The configuration data. Defaults to None.
+ type: The configuration type. Defaults to PyObject.
+ oks_key: OKS key if using OKS configuration. Defaults to None.
+ *args: Additional positional arguments.
+ **kwargs: Additional keyword arguments.
+
+ Raises:
+ DruncSetupException: If OKS type is used without an OKS key.
+ """
self.class_name = self.__class__.__name__
self.log = get_logger("utils." + self.class_name)
self.initial_type = type
@@ -84,26 +163,52 @@ def __init__(
self.oks_key = oks_key
self.validate_and_parse_configuration_location(*args, **kwargs)
- def get_data(self):
+ def get_data(self) -> object:
+ """Get the configuration data.
+
+ Returns:
+ Any: The stored configuration data.
+ """
return self.data
- def get_data_type_name(self):
- return self.get_data().type._name_
+ def get_data_type_name(self) -> str:
+ """Get the type name of the configuration data.
- def get_data_broadcaster(self):
- return self.get_data().broadcaster
+ Returns:
+ str: The name of the data type.
+ """
+ return str(cast(_ConfigurationData, self.get_data()).type._name_)
- def get_data_authoriser(self):
- return self.get_data().authoriser
+ def get_data_broadcaster(self) -> object:
+ """Get the broadcaster from the configuration data.
- def copy_oks_key(self):
+ Returns:
+ Any: The broadcaster object.
+ """
+ return cast(_ConfigurationData, self.get_data()).broadcaster
+
+ def get_data_authoriser(self) -> object:
+ """Get the authoriser from the configuration data.
+
+ Returns:
+ Any: The authoriser object.
+ """
+ return cast(_ConfigurationData, self.get_data()).authoriser
+
+ def copy_oks_key(self) -> OKSKey | None:
+ """Get a copy of the OKS key if one exists.
+
+ Returns:
+ OKSKey | None: The OKS key, or None if not using OKS configuration.
+ """
return self.oks_key
- def _parse_oks_file(self, oks_path):
+ def _parse_oks_file(self, oks_path: str) -> object:
try:
self.oks_path = oks_path
self.log.debug(f"Using {self.oks_path} to configure")
self.db = conffwk.Configuration(self.oks_path)
+ assert self.oks_key is not None, "OKS key is required for OKS configuration"
return self.db.get_dal(
class_name=self.oks_key.class_name, uid=self.oks_key.obj_uid
)
@@ -118,16 +223,26 @@ def _parse_oks_file(self, oks_path):
"OKS params where not passed to this ConfigurationHandler, cannot parse OKS configurations"
) from e
- def _post_process_oks(self):
+ def _post_process_oks(self, *args: object, **kwargs: object) -> None:
pass
- def _parse_pbany(self, pbany_data):
- raise ConfTypeNotSupported(ConfTypes.ProtobufAny, self)
+ def _parse_pbany(self, pbany_data: object) -> object:
+ raise ConfTypeNotSupported(ConfTypes.ProtobufAny, self.class_name)
+
+ def _parse_dict(self, data: dict[str, object]) -> object:
+ raise ConfTypeNotSupported(ConfTypes.JsonFileName, self.class_name)
+
+ def validate_and_parse_configuration_location(
+ self, *args: object, **kwargs: object
+ ) -> None:
+ """Validate and parse the configuration from the provided location.
- def _parse_dict(self, data):
- raise ConfTypeNotSupported(ConfTypes.JsonFileName, self)
+ Supports JsonFileName, OKSFileName, and PyObject types.
- def validate_and_parse_configuration_location(self, *args, **kwargs):
+ Args:
+ *args: Additional positional arguments.
+ **kwargs: Additional keyword arguments.
+ """
match self.initial_type:
case ConfTypes.PyObject:
self.data = self.initial_data
@@ -135,8 +250,8 @@ def validate_and_parse_configuration_location(self, *args, **kwargs):
self._post_process_oks(*args, **kwargs)
case ConfTypes.JsonFileName:
- resolved = expand_path(self.initial_data, True)
- if not os.path.exists(expand_path(self.initial_data)):
+ resolved = expand_path(cast(str, self.initial_data), True)
+ if not os.path.exists(expand_path(cast(str, self.initial_data))):
raise DruncSetupException(
f"Location {resolved} ({self.initial_data}) is empty!"
)
@@ -148,7 +263,7 @@ def validate_and_parse_configuration_location(self, *args, **kwargs):
self._post_process_oks(*args, **kwargs)
case ConfTypes.OKSFileName:
- self.data = self._parse_oks_file(self.initial_data)
+ self.data = self._parse_oks_file(cast(str, self.initial_data))
self.type = ConfTypes.PyObject
self._post_process_oks(*args, **kwargs)
diff --git a/src/drunc/utils/flask_manager.py b/src/drunc/utils/flask_manager.py
index 80c504e06..cc0450f68 100644
--- a/src/drunc/utils/flask_manager.py
+++ b/src/drunc/utils/flask_manager.py
@@ -1,27 +1,74 @@
+"""Flask application manager utilities for DRUNC."""
+
import os
import signal
import threading
import time
from multiprocessing import Process
-from typing import NoReturn
+from typing import TYPE_CHECKING, Protocol
-import gunicorn.app.base
import psutil
import requests
from flask import Flask, jsonify, make_response, request
-from flask_restful import Api, Resource
+
+if TYPE_CHECKING:
+
+ class _GunicornConfig(Protocol):
+ settings: dict[str, object]
+
+ def set(self, key: str, value: object) -> None: ...
+
+ class _BaseApplication:
+ cfg: _GunicornConfig
+
+ def __init__(self, *args: object, **kwargs: object) -> None: ...
+ def run(self) -> None: ...
+
+ class _Resource:
+ pass
+
+ class Api:
+ """Typing stub for flask_restful.Api."""
+
+ def __init__(self, app: Flask) -> None:
+ """Initialize the API with a Flask application."""
+ ...
+
+ def add_resource(
+ self, resource: type[_Resource], *urls: str, **kwargs: object
+ ) -> None:
+ """Register a resource class on one or more URL routes."""
+ ...
+
+else:
+ from flask_restful import Api
+ from flask_restful import Resource as _Resource
+ from gunicorn.app.base import BaseApplication as _BaseApplication
from drunc.exceptions import DruncCommandException
from drunc.utils.utils import get_logger, get_new_port
-class GunicornStandaloneApplication(gunicorn.app.base.BaseApplication):
- def __init__(self, app, options=None):
+class GunicornStandaloneApplication(_BaseApplication):
+ """Standalone Gunicorn application wrapper."""
+
+ def __init__(
+ self,
+ app: Flask,
+ options: dict[str, object] | None = None,
+ ) -> None:
+ """Initialize a GunicornStandaloneApplication.
+
+ Args:
+ app: The Flask application to run.
+ options: Configuration options for Gunicorn. Defaults to None.
+ """
self.options = options or {}
self.application = app
super().__init__()
- def load_config(self):
+ def load_config(self) -> None:
+ """Load Gunicorn configuration from options."""
config = {
key: value
for key, value in self.options.items()
@@ -30,23 +77,32 @@ def load_config(self):
for key, value in config.items():
self.cfg.set(key.lower(), value)
- def load(self):
+ def load(self) -> Flask:
+ """Load the Flask application.
+
+ Returns:
+ Flask: The Flask application.
+ """
return self.application
class CannotStartFlaskManager(DruncCommandException):
+ """Exception raised when the Flask manager cannot start."""
+
pass
class FlaskManager(threading.Thread):
- """This class is a manager for flask.
- It allows to have a Flask server under a thread, start and stop it.
- Note that it creates another -trivial- endpoint accessible at the route /readystatus.
- This is used to poll if the service is up, however the user can provide it, and
+ """Manager for Flask applications running in a separate thread.
+
+ It allows to have a Flask server under a thread,
+ start and stop it. Note that it creates another endpoint accessible at the route
+ /readystatus. This is used to poll if the service is up, however the user can
+ provide it.
To use this code, one can use the following example:
-
+ ```python
from flask import Flask
from flask_restful import Api
app = Flask('some-name')
@@ -66,27 +122,44 @@ class FlaskManager(threading.Thread):
while not manager.is_ready():
from time import sleep
sleep(0.1)
-
+ ```
Then, later on, to stop it:
-
+
+ ```python
manager.stop()
-
+ ```
"""
- def __init__(self, name, app, port, workers=1, host="0.0.0.0"):
+ def __init__(
+ self,
+ name: str,
+ app: Flask,
+ port: int,
+ workers: int = 1,
+ host: str = "0.0.0.0",
+ ) -> None:
+ """Initialize a FlaskManager.
+
+ Args:
+ name: The name of the Flask manager.
+ app: The Flask application to manage.
+ port: The port to run the Flask server on.
+ workers: The number of Gunicorn workers. Defaults to 1.
+ host: The host address to bind to. Defaults to "0.0.0.0".
+ """
super(FlaskManager, self).__init__(daemon=True)
self.log = get_logger(f"{name}-flaskmanager", stream_handlers=True)
self.name = name
self.app = app
- self.prod_app = None
- self.flask = None
+ self.prod_app: GunicornStandaloneApplication | None = None
+ self.flask: Process | None = None
self.host = host
self.port = port
self.workers = workers
- self.gunicorn_pid = None
+ self.gunicorn_pid: int | None = None
self.ready = False
self.joined = False
self.ready_lock = threading.Lock()
@@ -97,7 +170,7 @@ def _create_flask(self) -> Process:
if "get_ready_status" in rule.endpoint:
need_ready = False
- def get_ready_status():
+ def get_ready_status() -> str:
return "ready"
if need_ready:
@@ -112,8 +185,10 @@ def get_ready_status():
"workers": self.workers,
},
)
+ prod_app = self.prod_app
+ assert prod_app is not None, "GunicornStandaloneApplication creation failed"
- def run_gunicorn_with_signal_handling():
+ def run_gunicorn_with_signal_handling() -> None:
"""Run gunicorn with SIGHUP ignored to prevent reload on shutdown.
This prevents gunicorn from reloading when the parent process receives SIGHUP.
@@ -127,7 +202,7 @@ def run_gunicorn_with_signal_handling():
# May fail if already in a process group or on some systems, ignore
pass
- self.prod_app.run()
+ prod_app.run()
thread_name = f"{self.name}_thread"
flask_srv = Process( # Indeed, we've just forked this sucker
@@ -185,21 +260,33 @@ def run_gunicorn_with_signal_handling():
return flask_srv
- def __del__(self):
+ def __del__(self) -> None:
+ """Cleanup when the FlaskManager is destroyed."""
self.stop()
- def stop(self) -> NoReturn:
- # gunicorn is forked, so we need to now need send signal ourselves
+ def stop(self) -> None:
+ """Stop the Flask manager and terminate the Gunicorn process.
+
+ Sends SIGTERM to the Gunicorn process and joins the Flask process thread.
+ """
if self.gunicorn_pid:
gunicorn_proc = psutil.Process(self.gunicorn_pid)
# https://github.com/benoitc/gunicorn/blob/ab9c8301cb9ae573ba597154ddeea16f0326fc15/docs/source/signals.rst#master-process
# TOTAL DESTRUCTION
gunicorn_proc.send_signal(signal.SIGTERM)
- self.flask.terminate()
+ if self.flask is not None:
+ self.flask.terminate()
self.join()
- def restart_renew(self):
+ def restart_renew(self) -> "FlaskManager":
+ """Restart and renew the Flask manager.
+
+ Stops the current instance and creates a new one with the same configuration.
+
+ Returns:
+ FlaskManager: A new FlaskManager instance with the same settings.
+ """
# well, we cannot really do that.
# we have to hack it a bit:
# unfortunately, this means you need to do:
@@ -217,15 +304,25 @@ def restart_renew(self):
time.sleep(0.1)
return fm
- def is_ready(self):
+ def is_ready(self) -> bool:
+ """Check if the Flask manager is ready to serve requests.
+
+ Returns:
+ bool: True if ready, False otherwise.
+ """
with self.ready_lock:
return self.ready
- def is_terminated(self):
+ def is_terminated(self) -> bool:
+ """Check if the Flask manager has been terminated.
+
+ Returns:
+ bool: True if terminated, False otherwise.
+ """
with self.ready_lock:
return self.joined
- def _create_and_join_flask(self):
+ def _create_and_join_flask(self) -> None:
with self.ready_lock:
self.ready = False
self.joined = False
@@ -238,16 +335,25 @@ def _create_and_join_flask(self):
self.log.info(f"{self.name}-flaskmanager terminated")
- def run(self) -> NoReturn:
+ def run(self) -> None:
+ """Run the Flask server in the thread.
+
+ This method is called when the thread is started.
+ """
self._create_and_join_flask()
-def main():
- class DummyEndpoint(Resource):
- def post(self):
+def main() -> None:
+ """Main entry point for demonstrating the FlaskManager.
+
+ Creates a simple Flask application with a dummy endpoint and starts it.
+ """
+
+ class DummyEndpoint(_Resource):
+ def post(self) -> None:
print(request)
- def get(self):
+ def get(self) -> object:
return make_response(jsonify({"weeeee": "wooo"}))
app = Flask("test-app")
diff --git a/src/drunc/utils/grpc_utils.py b/src/drunc/utils/grpc_utils.py
index 6ecedc48c..14f4c5bd7 100644
--- a/src/drunc/utils/grpc_utils.py
+++ b/src/drunc/utils/grpc_utils.py
@@ -1,5 +1,9 @@
+"""gRPC utilities for DRUNC."""
+
+from __future__ import annotations
+
from dataclasses import dataclass
-from typing import List, NoReturn, Optional
+from typing import Callable, NoReturn, cast
import grpc
from druncschema.generic_pb2 import PlainText
@@ -20,7 +24,15 @@
class UnpackingError(DruncCommandException):
- def __init__(self, data, format):
+ """Exception raised when unpacking gRPC messages fails."""
+
+ def __init__(self, data: object, format: type[Message]) -> None:
+ """Initialize the UnpackingError.
+
+ Args:
+ data: The data that failed to unpack.
+ format: The expected format.
+ """
self.data = data
self.format = format
@@ -50,13 +62,33 @@ def unpack_error_response(name: str, text: str, token: Token) -> Response:
)
-def pack_to_any(data) -> any_pb2.Any:
+def pack_to_any(data: Message) -> any_pb2.Any:
+ """Pack a protobuf message into an Any message.
+
+ Args:
+ data: The protobuf message to pack.
+
+ Returns:
+ any_pb2.Any: The packed message.
+ """
any = any_pb2.Any()
any.Pack(data)
return any
-def unpack_any(data, format):
+def unpack_any(data: any_pb2.Any, format: type[Message]) -> Message:
+ """Unpack an Any message into a specific protobuf format.
+
+ Args:
+ data: The Any message to unpack.
+ format: The protobuf message type to unpack into.
+
+ Returns:
+ Message: The unpacked message.
+
+ Raises:
+ UnpackingError: If the message cannot be unpacked into the specified format.
+ """
if not data.Is(format.DESCRIPTOR):
raise UnpackingError(data, format)
req = format()
@@ -65,13 +97,27 @@ def unpack_any(data, format):
class ServerUnreachable(DruncException):
- def __init__(self, message):
+ """Exception raised when the gRPC server is unreachable."""
+
+ def __init__(self, message: str) -> None:
+ """Initialize the ServerUnreachable exception.
+
+ Args:
+ message: The error message.
+ """
self.message = message
super(ServerUnreachable, self).__init__(message)
class ServerTimeout(DruncException):
- def __init__(self, message):
+ """Exception raised when the gRPC server times out."""
+
+ def __init__(self, message: str) -> None:
+ """Initialize the ServerTimeout exception.
+
+ Args:
+ message: The error message.
+ """
self.message = message
super(ServerTimeout, self).__init__(message)
@@ -97,7 +143,7 @@ def server_is_reachable(grpc_error: grpc.RpcError) -> bool:
return True
-def rethrow_if_unreachable_server(grpc_error: grpc.RpcError) -> NoReturn:
+def rethrow_if_unreachable_server(grpc_error: grpc.RpcError) -> None:
"""
Raise a ServerUnreachable exception if the gRPC error indicates the server is unreachable.
@@ -114,7 +160,7 @@ def rethrow_if_unreachable_server(grpc_error: grpc.RpcError) -> NoReturn:
raise ServerUnreachable(grpc_error._details) from grpc_error
-def rethrow_if_timeout(grpc_error: grpc.RpcError) -> NoReturn:
+def rethrow_if_timeout(grpc_error: grpc.RpcError) -> None:
"""
Raise a ServerTimeout if timeout.
@@ -135,6 +181,7 @@ def handle_grpc_error(error: grpc.RpcError) -> NoReturn:
Args:
error: The gRPC error to handle.
+
Raises:
A custom exception if the error matches a known category, or the original
gRPC error if no classification applies.
@@ -144,12 +191,11 @@ def handle_grpc_error(error: grpc.RpcError) -> NoReturn:
raise error
-def interrupt_if_unreachable_server(grpc_error: grpc.RpcError) -> Optional[str]:
- """
- Interrupt if server is not reachable and return the error details.
+def interrupt_if_unreachable_server(grpc_error: grpc.RpcError) -> str | None:
+ """Interrupt if server is not reachable and return the error details.
Args:
- grpc_error (grpc.RpcError): The gRPC error
+ grpc_error: The gRPC error
Returns:
str | None: The internal error details if the server is unreachable and details are available;
@@ -157,9 +203,10 @@ def interrupt_if_unreachable_server(grpc_error: grpc.RpcError) -> Optional[str]:
"""
if not server_is_reachable(grpc_error):
if hasattr(grpc_error, "_state"):
- return grpc_error._state.details
+ return str(grpc_error._state.details)
elif hasattr(grpc_error, "_details"):
- return grpc_error._details
+ return str(grpc_error._details)
+ return None
def copy_token(token: Token) -> Token:
@@ -176,10 +223,19 @@ def copy_token(token: Token) -> Token:
return token_copy
-def dict_to_grpc_proto(data: dict, proto_class_instance: Message) -> Message:
- """
- Converts a Python dictionary into an instance of a gRPC Protobuf message.
+def dict_to_grpc_proto(
+ data: dict[str, object], proto_class_instance: Message
+) -> Message:
+ """Converts a Python dictionary into an instance of a gRPC Protobuf message.
+
'proto_class_instance' should be an empty instance, e.g., Token()
+
+ Args:
+ data: The dictionary to convert.
+ proto_class_instance: An empty instance of the target protobuf message type.
+
+ Returns:
+ Message: The converted protobuf message.
"""
return json_format.ParseDict(data, proto_class_instance, ignore_unknown_fields=True)
@@ -199,21 +255,19 @@ class GrpcErrorDetails:
Attributes:
code (str): The gRPC status code name (e.g., "NOT_FOUND")
message (str): The error message from the gRPC status
- details (List[str]): A list of formatted error detail strings
+ details: A list of formatted error detail strings or protobuf Messages.
"""
code: str
message: str
- details: List[str]
+ details: list[str | Message]
- def __str__(self):
- """
- Return a human-readable string representation of the error.
- """
+ def __str__(self) -> str:
+ """Return a human-readable string representation of the error."""
lines = [f"[{self.code}] {self.message}"]
for detail in self.details:
# If it's a Proto message format the error detail
- if hasattr(detail, "DESCRIPTOR"):
+ if isinstance(detail, Message):
lines.extend(format_error_details(detail))
else:
lines.append(str(detail))
@@ -312,7 +366,7 @@ def extract_grpc_rich_error(grpc_error: grpc.RpcError) -> GrpcErrorDetails:
"""
code = grpc_error.code().name if grpc_error.code() else "UNKNOWN"
try:
- status = rpc_status.from_call(grpc_error)
+ status = rpc_status.from_call(cast(grpc.Call, grpc_error))
except NotImplementedError:
return GrpcErrorDetails(code=code, message="No message", details=[])
@@ -342,9 +396,9 @@ def extract_grpc_rich_error(grpc_error: grpc.RpcError) -> GrpcErrorDetails:
def abort_with_rich_error_status(
context: grpc.ServicerContext,
- grpc_error_code: code_pb2.Code,
+ grpc_error_code: int,
message: str,
- error_obj: Message,
+ error_obj: object,
) -> NoReturn:
"""
Aborts the current gRPC call with a rich error status containing
@@ -375,24 +429,37 @@ def abort_with_rich_error_status(
raise Exception(f"Aborting with status: {message}")
-class RichErrorServerInterceptor(grpc.ServerInterceptor):
+class RichErrorServerInterceptor:
"""
A gRPC server interceptor that catches exceptions and converts them into
- rich error statuses with structured error details."""
+ rich error statuses with structured error details.
+ """
- def intercept_service(self, continuation, handler_call_details):
+ def intercept_service(
+ self,
+ continuation: Callable[
+ [grpc.HandlerCallDetails],
+ grpc.RpcMethodHandler[object, object] | None,
+ ],
+ handler_call_details: grpc.HandlerCallDetails,
+ ) -> grpc.RpcMethodHandler[object, object] | None:
"""
Intercept gRPC service calls to handle exceptions and convert them
into rich error statuses.
"""
handler = continuation(handler_call_details)
+ if handler is None:
+ return None
- def error_wrapper(request, context):
+ def error_wrapper(request: object, context: grpc.ServicerContext) -> object:
try:
- return handler.unary_unary(request, context)
+ unary_unary = handler.unary_unary
+ if unary_unary is None:
+ return handler
+ return unary_unary(request, context)
except DruncSetupException as e:
- detail_obj = error_details_pb2.PreconditionFailure(
+ detail_obj_precondition = error_details_pb2.PreconditionFailure(
violations=[
error_details_pb2.PreconditionFailure.Violation(
type="MISSING OR INVALID",
@@ -402,39 +469,48 @@ def error_wrapper(request, context):
]
)
abort_with_rich_error_status(
- context, e.grpc_error_code, str(e), detail_obj
+ context,
+ int(e.grpc_error_code),
+ str(e),
+ detail_obj_precondition,
)
except DruncNotImplementedException as e:
- detail_obj = error_details_pb2.ErrorInfo(
+ detail_obj_not_implemented = error_details_pb2.ErrorInfo(
reason="NOT_IMPLEMENTED",
domain="server",
metadata={},
)
abort_with_rich_error_status(
- context, e.grpc_error_code, str(e), detail_obj
+ context,
+ int(e.grpc_error_code),
+ str(e),
+ detail_obj_not_implemented,
)
except DruncCommandException as e:
exception_data = e.detail_kwargs
- detail_obj = error_details_pb2.ErrorInfo(
+ detail_obj_command = error_details_pb2.ErrorInfo(
reason=str(e.message),
domain=str(
exception_data.get("domain", ""),
),
)
abort_with_rich_error_status(
- context, e.grpc_error_code, str(e), detail_obj
+ context,
+ int(e.grpc_error_code),
+ str(e),
+ detail_obj_command,
)
except Exception as e:
# Fallback
- detail_obj = error_details_pb2.ErrorInfo(
+ detail_obj_fallback = error_details_pb2.ErrorInfo(
reason="Unexpected error",
domain="server",
metadata={"original_error": str(type(e))},
)
abort_with_rich_error_status(
- context, code_pb2.INTERNAL, str(e), detail_obj
+ context, int(code_pb2.INTERNAL), str(e), detail_obj_fallback
)
if handler.unary_unary:
diff --git a/src/drunc/utils/shell_utils.py b/src/drunc/utils/shell_utils.py
index d1443968a..02c9ab46f 100644
--- a/src/drunc/utils/shell_utils.py
+++ b/src/drunc/utils/shell_utils.py
@@ -1,6 +1,9 @@
+"""Shell utilities for DRUNC."""
+
import abc
import getpass
-from collections.abc import Mapping
+from collections.abc import MutableMapping
+from typing import Callable, ParamSpec, Protocol, TypeVar, cast
import click
from druncschema.token_pb2 import Token
@@ -10,13 +13,80 @@
from drunc.utils.utils import get_logger
+class CommandLike(Protocol):
+ """Protocol for command-like objects."""
+
+ name: str
+
+
+class SequenceLike(Protocol):
+ """Protocol for sequence-like objects."""
+
+ id: str
+
+
+class FSMDescriptionLike(Protocol):
+ """Protocol for FSM description-like objects."""
+
+ commands: list[CommandLike]
+ sequences: list[SequenceLike]
+
+
+class DescribeFSMReplyLike(Protocol):
+ """Protocol for describe FSM reply-like objects."""
+
+ description: FSMDescriptionLike
+
+
+class StatusLike(Protocol):
+ """Protocol for status-like objects."""
+
+ state: str
+ in_error: bool
+
+
+class StatusReplyLike(Protocol):
+ """Protocol for status reply-like objects."""
+
+ status: StatusLike
+
+
+class ControllerDriverProtocol(Protocol):
+ """Protocol for controller driver objects."""
+
+ def status(self) -> StatusReplyLike:
+ """Get the current status.
+
+ Returns:
+ StatusReplyLike: The current status.
+ """
+ ...
+
+ def describe_fsm(self) -> DescribeFSMReplyLike:
+ """Describe the FSM.
+
+ Returns:
+ DescribeFSMReplyLike: The FSM description.
+ """
+ ...
+
+
+P = ParamSpec("P")
+R = TypeVar("R")
+
+
class InterruptedCommand(DruncShellException):
- """This exception gets thrown if we don't want to have a full stack, but still want to interrupt a **shell** command"""
+ """Exception thrown to interrupt a shell command without a full stack trace."""
pass
def create_dummy_token_from_uname() -> Token:
+ """Create a dummy token from the current username.
+
+ Returns:
+ Token: A dummy token with the current username.
+ """
user = getpass.getuser()
return (
Token( # fake token, but should be figured out from the environment/authoriser
@@ -25,8 +95,14 @@ def create_dummy_token_from_uname() -> Token:
)
-def add_traceback_flag():
- def wrapper(f0):
+def add_traceback_flag() -> Callable[[Callable[P, R]], Callable[P, R]]:
+ """Add a traceback flag to a command.
+
+ Returns:
+ Callable: A decorator that adds the traceback flag.
+ """
+
+ def wrapper(f0: Callable[P, R]) -> Callable[P, R]:
f1 = click.option(
"-t/-nt",
"--traceback/--no-traceback",
@@ -39,14 +115,35 @@ def wrapper(f0):
class DecodedResponse:
- ## Warning! This should be kept in sync with druncschema/request_response.proto/Response class
+ """Decoded response object.
+
+ Warning: This should be kept in sync with
+ druncschema/request_response.proto/Response class
+ """
+
name = None
token = None
data = None
flag = None
- children = []
+ children: list["DecodedResponse"] = []
+
+ def __init__(
+ self,
+ name: str,
+ token: Token,
+ flag: object,
+ data: object | None = None,
+ children: list["DecodedResponse"] | None = None,
+ ) -> None:
+ """Initialize a DecodedResponse.
- def __init__(self, name, token, flag, data=None, children=None):
+ Args:
+ name: The name of the response.
+ token: The token associated with the response.
+ flag: The response flag.
+ data: The response data. Defaults to None.
+ children: Child responses. Defaults to None.
+ """
self.name = name
self.token = token
self.flag = flag
@@ -57,29 +154,56 @@ def __init__(self, name, token, flag, data=None, children=None):
self.children = children
@staticmethod
- def str(obj, prefix=""):
+ def to_string(obj: "DecodedResponse", prefix: str = "") -> str:
+ """Convert a DecodedResponse to a string representation.
+
+ Args:
+ obj: The DecodedResponse to convert.
+ prefix: A prefix to add to the string. Defaults to empty string.
+
+ Returns:
+ str: The string representation of the response.
+ """
text = (
f"{prefix} {obj.name} -> response flag={obj.flag} type={type(obj.data)}\n"
)
for v in obj.children:
if v is None:
continue
- text += DecodedResponse.str(v, prefix + " ")
+ text += DecodedResponse.to_string(v, prefix + " ")
return text
- def __str__(self):
- return DecodedResponse.str(self)
+ def __str__(self) -> str:
+ """Return string representation of the DecodedResponse.
+
+ Returns:
+ str: The string representation.
+ """
+ return DecodedResponse.to_string(self)
class ShellContext:
- def _reset(self, name: str, token_args: dict = {}, driver_args: dict = {}):
+ """Base class for shell contexts."""
+
+ def _reset(
+ self,
+ name: str,
+ token_args: dict[str, object] = {},
+ driver_args: dict[str, object] = {},
+ ) -> None:
self._console = Console()
self._token = self.create_token(**token_args)
- self._drivers: Mapping[str, object] = self.create_drivers(**driver_args)
+ self._drivers: MutableMapping[str, object] = self.create_drivers(**driver_args)
+
+ def __init__(self, *args: object, **kwargs: object) -> None:
+ """Initialize the shell context.
- def __init__(self, *args, **kwargs):
+ Args:
+ *args: Additional positional arguments.
+ **kwargs: Additional keyword arguments.
+ """
log = get_logger("utils.ShellContext")
- self.dynamic_commands = set()
+ self.dynamic_commands: set[str] = set()
try:
self.reset(*args, **kwargs)
except Exception as e:
@@ -87,27 +211,71 @@ def __init__(self, *args, **kwargs):
exit(1)
@abc.abstractmethod
- def reset(self, **kwargs):
+ def reset(self, **kwargs: object) -> None:
+ """Reset the shell context.
+
+ Args:
+ **kwargs: Additional keyword arguments.
+ """
pass
@abc.abstractmethod
- def create_drivers(self, **kwargs) -> Mapping[str, object]:
+ def create_drivers(self, **kwargs: object) -> MutableMapping[str, object]:
+ """Create drivers for the context.
+
+ Args:
+ **kwargs: Additional keyword arguments.
+
+ Returns:
+ MutableMapping[str, object]: A mapping of driver names to driver objects.
+ """
pass
@abc.abstractmethod
- def create_token(self, **kwargs) -> Token:
+ def create_token(self, **kwargs: object) -> Token:
+ """Create a token for the context.
+
+ Args:
+ **kwargs: Additional keyword arguments.
+
+ Returns:
+ Token: A token object.
+ """
pass
@abc.abstractmethod
def terminate(self) -> None:
+ """Terminate the shell context."""
pass
def set_driver(self, name: str, driver: object) -> None:
+ """Set a driver in the context.
+
+ Args:
+ name: The name of the driver.
+ driver: The driver object.
+
+ Raises:
+ DruncShellException: If a driver with the same name already exists.
+ """
if name in self._drivers:
raise DruncShellException(f"Driver {name} already present in this context")
self._drivers[name] = driver
- def get_driver(self, name: str = None, quiet_fail: bool = False) -> object:
+ def get_driver(self, name: str | None = None, quiet_fail: bool = False) -> object:
+ """Get a driver from the context.
+
+ Args:
+ name: The name of the driver. If None, returns the only driver if there is exactly one.
+ quiet_fail: If True, return None on failure instead of raising an exception.
+
+ Returns:
+ object: The driver object, or None if quiet_fail is True and the driver is not found.
+
+ Raises:
+ DruncShellException: If there are multiple drivers and no name is specified.
+ SystemExit: If the driver is not found and quiet_fail is False.
+ """
try:
if name:
return self._drivers[name]
@@ -127,9 +295,22 @@ def get_driver(self, name: str = None, quiet_fail: bool = False) -> object:
) # used to avoid having to catch multiple Attribute errors when this function gets called
def has_driver(self, name: str) -> bool:
+ """Check if a driver exists in the context.
+
+ Args:
+ name: The name of the driver.
+
+ Returns:
+ bool: True if the driver exists, False otherwise.
+ """
return name in self._drivers
def delete_driver(self, name: str) -> None:
+ """Delete a driver from the context.
+
+ Args:
+ name: The name of the driver to delete.
+ """
log = get_logger("utils.ShellContext")
if name in self._drivers:
log.info(f"You will not be able to issue commands to the {name} anymore.")
@@ -137,18 +318,37 @@ def delete_driver(self, name: str) -> None:
log.info(f"{name.capitalize()} driver has been deleted.")
def get_token(self) -> Token:
+ """Get the token from the context.
+
+ Returns:
+ Token: The token object.
+ """
return self._token
- def print(self, *args, **kwargs) -> None:
- self._console.print(*args, **kwargs) # rich tables require console printing
+ def print(self, *args: object, **kwargs: object) -> None:
+ """Print to the console.
+
+ Args:
+ *args: Positional arguments to pass to the console.
+ **kwargs: Keyword arguments to pass to the console.
+ """
+ self._console.print(*args, **kwargs) # type: ignore[arg-type]
+
+ def rule(self, *args: object, **kwargs: object) -> None:
+ """Print a rule to the console.
- def rule(self, *args, **kwargs) -> None:
- self._console.rule(*args, **kwargs)
+ Args:
+ *args: Positional arguments to pass to the console.
+ **kwargs: Keyword arguments to pass to the console.
+ """
+ self._console.rule(*args, **kwargs) # type: ignore[arg-type]
def print_status_summary(self) -> None:
+ """Print a summary of the FSM status and available transitions."""
log = get_logger("utils.ShellContext")
- status = self.get_driver("controller").status().status
- describe_fsm = self.get_driver("controller").describe_fsm().description
+ controller = cast(ControllerDriverProtocol, self.get_driver("controller"))
+ status = controller.status().status
+ describe_fsm = controller.describe_fsm().description
current_state = status.state
if status.in_error:
log.error(
diff --git a/src/drunc/utils/utils.py b/src/drunc/utils/utils.py
index 08d085670..5688e8165 100644
--- a/src/drunc/utils/utils.py
+++ b/src/drunc/utils/utils.py
@@ -1,3 +1,5 @@
+"""A set of utility functions for drunc."""
+
import ctypes
import logging
import os
@@ -11,11 +13,12 @@
from contextlib import closing
from datetime import datetime
from enum import Enum
-from urllib.parse import urlparse
+from typing import Protocol, cast
+from urllib.parse import ParseResult, urlparse
-from click import BadParameter
+from click import BadParameter, Context, Parameter
from daqpytools.logging import get_daq_logger, setup_root_logger
-from requests import delete, get, patch, post
+from requests import Response, delete, get, patch, post
from rich.progress import (
BarColumn,
Progress,
@@ -34,8 +37,8 @@
def get_root_logger(log_level: str) -> logging.Logger:
- """
- Set up the base logger which all other loggers will inherit.
+ """Set up the base logger which all other loggers will inherit.
+
This base logger is named the 'drunc' logger, and functions similarly to the root
logger. It should have no handlers attached to it.
@@ -49,21 +52,33 @@ def get_root_logger(log_level: str) -> logging.Logger:
return setup_root_logger("drunc", log_level)
-def get_logger(logger_name: str, *args, **kwargs) -> logging.Logger:
- """Returns / constructs default logging instances. Prepends all loggers with 'drunc'
- to inherit from the root 'drunc' logger.
- Wraps to the daqpytools implementation, see for more details
-
- Args:
- logger_name (str): Name of the logger
- args, kwargs: Passed without modification to the daqpytools implementation
- """
- return get_daq_logger(f"drunc.{logger_name}", *args, **kwargs)
+def get_logger(
+ logger_name: str,
+ log_level: int | str = logging.NOTSET,
+ use_parent_handlers: bool = True,
+ rich_handler: bool = False,
+ file_handler_path: str | None = None,
+ stream_handlers: bool = False,
+ ers_kafka_session: str | None = None,
+ throttle: bool = False,
+ **extras: object,
+) -> logging.Logger:
+ """Get a logger instance for the given logger name."""
+ return get_daq_logger(
+ f"drunc.{logger_name}",
+ log_level,
+ use_parent_handlers,
+ rich_handler,
+ file_handler_path,
+ stream_handlers,
+ ers_kafka_session,
+ throttle,
+ **extras,
+ )
def strip_non_drunc_loggers() -> None:
- """
- Strip out all the basicConfig handlers from other repositories, which define
+ """Strip out all the basicConfig handlers from other repositories, which define
handlers with the root logger.
"""
root = logging.getLogger()
@@ -71,37 +86,92 @@ def strip_non_drunc_loggers() -> None:
root.handlers.clear()
-def get_random_string(length):
+def get_random_string(length: int) -> str:
+ """Generate a random string of lowercase ASCII letters.
+
+ Args:
+ length (int): The desired length of the random string.
+
+ Returns:
+ str: A random string of the specified length.
+ """
letters = string.ascii_lowercase
return "".join(random.choice(letters) for i in range(length))
-def regex_match(regex, string):
+def regex_match(regex: str, string: str) -> bool:
+ """Check if a regex pattern matches a string.
+
+ Args:
+ regex (str): The regular expression pattern.
+ string (str): The string to match against.
+
+ Returns:
+ bool: True if the pattern matches, False otherwise.
+ """
return re.match(regex, string) is not None
-def get_new_port():
+def get_new_port() -> int:
+ """Get an available port number.
+
+ Returns:
+ int: An available port number.
+ """
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
s.bind(("", 0))
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- return s.getsockname()[1]
+ return int(s.getsockname()[1])
+
+
+def now_str(posix_friendly: bool = False) -> str:
+ """Get the current time as a formatted string.
+ Args:
+ posix_friendly (bool): If True, use POSIX-friendly format. Defaults to False.
-def now_str(posix_friendly=False):
+ Returns:
+ str: The current time as a formatted string.
+ """
if not posix_friendly:
return datetime.now().strftime("%m/%d/%Y,%H:%M:%S")
else:
return datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
-def expand_path(path, turn_to_abs_path=False):
+def expand_path(path: str, turn_to_abs_path: bool = False) -> str:
+ """Expand a path with user and environment variables.
+
+ Args:
+ path (str): The path to expand.
+ turn_to_abs_path (bool): If True, also convert to absolute path.
+ Defaults to False.
+
+ Returns:
+ str: The expanded path.
+ """
if turn_to_abs_path:
return os.path.abspath(os.path.expanduser(os.path.expandvars(path)))
return os.path.expanduser(os.path.expandvars(path))
-def validate_command_facility(ctx, param, value):
- parsed = ""
+def validate_command_facility(
+ ctx: Context | None, param: Parameter | None, value: str
+) -> str:
+ """Validate a command facility parameter.
+
+ Args:
+ ctx (Any): Click context.
+ param (Any): Click parameter.
+ value (str): The value to validate.
+
+ Returns:
+ str: The validated netloc.
+
+ Raises:
+ BadParameter: If the value is invalid.
+ """
+ parsed: ParseResult
try:
parsed = urlparse(value)
except Exception as e:
@@ -126,8 +196,7 @@ def validate_command_facility(ctx, param, value):
def address_regex(address: str, hostname_or_ip: str) -> str:
- """
- Replace 127.x.x.x and 0.x.x.x IPs with the provided hostname
+ """Replace 127.x.x.x and 0.x.x.x IPs with the provided hostname.
This is useful when a service binds to localhost or 127.x.x.x, but we
want to access it using the hostname or network IP.
@@ -140,7 +209,7 @@ def address_regex(address: str, hostname_or_ip: str) -> str:
str: The address with 127.x.x.x and 0.x.x.x replaced by the hostname or IP.
"""
- ip_match: re.Match = re.search(
+ ip_match: re.Match[str] | None = re.search(
r"((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)",
address,
)
@@ -201,7 +270,15 @@ def resolve_localhost_and_127_ip_to_network_ip(address: str) -> str:
return address_regex(address, this_ip)
-def host_is_local(host):
+def host_is_local(host: str) -> bool:
+ """Check if a host is local.
+
+ Args:
+ host (str): The hostname or IP to check.
+
+ Returns:
+ bool: True if the host is local, False otherwise.
+ """
if host in [
"localhost",
socket.gethostname(),
@@ -215,15 +292,21 @@ def host_is_local(host):
return False
-def pid_info_str():
+def pid_info_str() -> str:
+ """Get a string with process ID information.
+
+ Returns:
+ str: A string containing the parent and current process IDs.
+ """
return f"Parent's PID: {os.getppid()} | This PID: {os.getpid()}"
-def ignore_sigint_sighandler():
+def ignore_sigint_sighandler() -> None:
+ """Ignore SIGINT (Ctrl+C) signals."""
signal.signal(signal.SIGINT, signal.SIG_IGN)
-def parent_death_pact(signal=signal.SIGHUP):
+def parent_death_pact(signal: int = signal.SIGHUP) -> None:
"""Commit to kill current process when parent process dies.
Each time you spawn a new process, run this to set signal
handler appropriately (e.g put it at the beginning of each
@@ -240,36 +323,82 @@ def parent_death_pact(signal=signal.SIGHUP):
class IncorrectAddress(DruncException):
+ """Exception raised when an address is invalid."""
+
pass
-def https_or_http_present(address: str):
+def https_or_http_present(address: str) -> None:
+ """Validate that an address starts with http:// or https://.
+
+ Args:
+ address (str): The address to validate.
+
+ Raises:
+ IncorrectAddress: If the address does not start with http:// or https://.
+ """
if not address.startswith("https://") and not address.startswith("http://"):
raise IncorrectAddress("Endpoint should start with http:// or https://")
-def http_post(address, data, as_json=True, ignore_errors=False, **post_kwargs):
+def http_post(
+ address: str,
+ data: object,
+ as_json: bool = True,
+ ignore_errors: bool = False,
+ **post_kwargs: object,
+) -> Response:
+ """Send an HTTP POST request.
+
+ Args:
+ address (str): The URL to send the request to.
+ data (Any): The data to send in the request body.
+ as_json (bool): If True, send data as JSON. Defaults to True.
+ ignore_errors (bool): If True, do not raise exceptions for HTTP errors. Defaults to False.
+ **post_kwargs: Additional keyword arguments to pass to requests.post.
+
+ Returns:
+ Response: The response from the server.
+ """
https_or_http_present(address)
if as_json:
- r = post(address, json=data, **post_kwargs)
+ r = post(address, json=data, **post_kwargs) # type: ignore[arg-type]
else:
- r = post(address, data=data, **post_kwargs)
+ r = post(address, data=data, **post_kwargs) # type: ignore[arg-type]
if not ignore_errors:
r.raise_for_status()
return r
-def http_get(address, data, as_json=True, ignore_errors=False, **post_kwargs):
+def http_get(
+ address: str,
+ data: object,
+ as_json: bool = True,
+ ignore_errors: bool = False,
+ **post_kwargs: object,
+) -> Response:
+ """Send an HTTP GET request.
+
+ Args:
+ address (str): The URL to send the request to.
+ data (Any): The data to send in the request body.
+ as_json (bool): If True, send data as JSON. Defaults to True.
+ ignore_errors (bool): If True, do not raise exceptions for HTTP errors. Defaults to False.
+ **post_kwargs: Additional keyword arguments to pass to requests.get.
+
+ Returns:
+ Response: The response from the server.
+ """
https_or_http_present(address)
log = get_logger("utils.http_get")
log.debug(f"GETTING {address} {data}")
if as_json:
- r = get(address, json=data, **post_kwargs)
+ r = get(address, json=data, **post_kwargs) # type: ignore[arg-type]
else:
- r = get(address, data=data, **post_kwargs)
+ r = get(address, data=data, **post_kwargs) # type: ignore[arg-type]
log.debug(r.text)
log.debug(r.status_code)
@@ -280,32 +409,71 @@ def http_get(address, data, as_json=True, ignore_errors=False, **post_kwargs):
return r
-def http_patch(address, data, as_json=True, ignore_errors=False, **post_kwargs):
+def http_patch(
+ address: str,
+ data: object,
+ as_json: bool = True,
+ ignore_errors: bool = False,
+ **post_kwargs: object,
+) -> Response:
+ """Send an HTTP PATCH request.
+
+ Args:
+ address (str): The URL to send the request to.
+ data (Any): The data to send in the request body.
+ as_json (bool): If True, send data as JSON. Defaults to True.
+ ignore_errors (bool): If True, do not raise exceptions for HTTP errors. Defaults to False.
+ **post_kwargs: Additional keyword arguments to pass to requests.patch.
+
+ Returns:
+ Response: The response from the server.
+ """
https_or_http_present(address)
if as_json:
- r = patch(address, json=data, **post_kwargs)
+ r = patch(address, json=data, **post_kwargs) # type: ignore[arg-type]
else:
- r = patch(address, data=data, **post_kwargs)
+ r = patch(address, data=data, **post_kwargs) # type: ignore[arg-type]
if not ignore_errors:
r.raise_for_status()
return r
-def http_delete(address, data, as_json=True, ignore_errors=False, **post_kwargs):
+def http_delete(
+ address: str,
+ data: object,
+ as_json: bool = True,
+ ignore_errors: bool = False,
+ **post_kwargs: object,
+) -> None:
+ """Send an HTTP DELETE request.
+
+ Args:
+ address (str): The URL to send the request to.
+ data (Any): The data to send in the request body.
+ as_json (bool): If True, send data as JSON. Defaults to True.
+ ignore_errors (bool): If True, do not raise exceptions for HTTP errors. Defaults to False.
+ **post_kwargs: Additional keyword arguments to pass to requests.delete.
+ """
https_or_http_present(address)
if as_json:
- r = delete(address, json=data, **post_kwargs)
+ r = delete(address, json=data, **post_kwargs) # type: ignore[arg-type]
else:
- r = delete(address, data=data, **post_kwargs)
+ r = delete(address, data=data, **post_kwargs) # type: ignore[arg-type]
if not ignore_errors:
r.raise_for_status()
+class _ConnectivityService(Protocol):
+ def resolve(self, name: str, message_type: str) -> list[dict[str, object]]: ...
+
+
class ControlType(Enum):
+ """Enumeration of control types for DUNE DAQ services."""
+
Unknown = 0
gRPC = 1
REST_API = 2
@@ -313,6 +481,17 @@ class ControlType(Enum):
def get_control_type_and_uri_from_cli(cli_args: list[str]) -> tuple[ControlType, str]:
+ """Extract control type and URI from CLI arguments.
+
+ Args:
+ cli_args (list[str]): The CLI arguments to parse.
+
+ Returns:
+ tuple[ControlType, str]: A tuple of (control_type, uri).
+
+ Raises:
+ DruncSetupException: If protocol is not 'grpc://' or 'rest://'.
+ """
for arg in cli_args:
if arg.startswith("rest://"):
uri = arg.replace("rest://", "")
@@ -326,18 +505,34 @@ def get_control_type_and_uri_from_cli(cli_args: list[str]) -> tuple[ControlType,
def get_control_type_and_uri_from_connectivity_service(
- connectivity_service,
+ connectivity_service: _ConnectivityService,
name: str,
timeout: int = 10, # seconds
retry_wait: float = 0.1, # seconds
progress_bar: bool = False,
- title: str = None,
+ title: str | None = None,
) -> tuple[ControlType, str]:
- uris = []
+ """Get control type and URI from connectivity service.
+
+ Args:
+ connectivity_service (object): The connectivity service instance.
+ name (str): The name of the service to resolve.
+ timeout (int): Maximum time to wait for resolution in seconds. Defaults to 10.
+ retry_wait (float): Time to wait between retries in seconds. Defaults to 0.1.
+ progress_bar (bool): Whether to display a progress bar. Defaults to False.
+ title (str | None): Title for the progress bar. Defaults to None.
+
+ Returns:
+ tuple[ControlType, str]: A tuple of (control_type, uri).
+
+ Raises:
+ ApplicationLookupUnsuccessful: If the URI cannot be resolved.
+ """
+ uris: list[dict[str, object]] = []
logger = get_logger("utils.get_control_type_and_uri_from_connectivity_service")
start = time.time()
- elapsed = 0
+ elapsed = 0.0
if progress_bar:
with Progress(
@@ -395,21 +590,25 @@ def get_control_type_and_uri_from_connectivity_service(
f"Could not resolve the URI for '{name}_control' in the connectivity service, got response {uris}"
)
- uri = uris[0]["uri"]
+ uri = cast(str, uris[0]["uri"])
return get_control_type_and_uri_from_cli([uri])
-def print_with_timestamp(message):
+def print_with_timestamp(message: str) -> None:
+ """Print a message with a timestamp.
+
+ Args:
+ message (str): The message to print.
+ """
now = datetime.now()
now_str = now.isoformat()
print(f"{now_str}: {message}")
def format_name_for_cli(name: str) -> str:
- """
- Format a command name or argument name to be CLI-friendly by replacing underscores
- with hyphens and converting to lowercase.
+ """Format a command name or argument name to be CLI-friendly by replacing
+ underscores with hyphens and converting to lowercase.
Args:
name (str): The original command name.
@@ -421,16 +620,16 @@ def format_name_for_cli(name: str) -> str:
def resolve_target_ip(host: str) -> str | None:
- """
- Intelligently resolves the host.
+ """Intelligently resolve a host to its IP address.
+
If host is 'localhost' or '127.0.0.1', it finds the actual LAN IP.
Args:
- host - the name of the host to reolve to LAN IP
+ host (str): The name of the host to resolve to LAN IP.
Returns:
- str - LAN IP of the host
- None - if the host could not be resolved, None is returned
+ str: LAN IP of the host.
+ None: If the host could not be resolved, None is returned.
"""
log = get_logger("utils.resolve_target_ip")
@@ -448,7 +647,7 @@ def resolve_target_ip(host: str) -> str | None:
# blocked from sending data outside of the LAN. Use connect - this does
# send any data, just establishes the connection.
s.connect(("10.255.255.255", 1))
- return s.getsockname()[0]
+ return str(s.getsockname()[0])
except Exception:
# Return the loopback address.
log.warning(f"Failed to resolve the IP address of {host}")
@@ -464,17 +663,15 @@ def resolve_target_ip(host: str) -> str | None:
def is_port_available(host: str, port: int, timeout: int = 2) -> bool:
- """
- Check if the given port number on a specified host is available.
+ """Check if the given port number on a specified host is available.
Args:
- host - the host name to check
- port - the port number to check
- timeout - timeout of attempting to establish the connection
+ host (str): The host name to check.
+ port (int): The port number to check.
+ timeout (int): Timeout of attempting to establish the connection. Defaults to 2.
Returns:
- true - the port is available
- false - the port is not available
+ bool: True if the port is available, False otherwise.
"""
log = get_logger("utils.is_port_available")
@@ -511,13 +708,12 @@ def is_port_available(host: str, port: int, timeout: int = 2) -> bool:
def file_is_read_only(file_path: str) -> bool:
- """
- Runs checks to see if the file path is read only.
+ """Check if a file is read-only.
Args:
- file_path - path of file to read
+ file_path (str): Path of file to check.
Returns:
- bool - true is file is read only, false otherwise
+ bool: True if the file is read-only, False otherwise.
"""
return not os.access(file_path, os.W_OK)