diff --git a/Cargo.lock b/Cargo.lock index fb996fa5..e10e688c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4936,6 +4936,7 @@ dependencies = [ "tokio-test", "tokio-tungstenite", "tokio-util", + "trzcina", "url", ] @@ -4945,6 +4946,7 @@ version = "4.0.0" dependencies = [ "actix-web", "anyhow", + "async-trait", "log", "nanoid", "paddler", @@ -4952,6 +4954,7 @@ dependencies = [ "tempfile", "tokio", "tokio-util", + "trzcina", ] [[package]] @@ -4984,6 +4987,7 @@ dependencies = [ "paddler_types", "tokio", "tokio-util", + "trzcina", ] [[package]] @@ -7367,6 +7371,20 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "trzcina" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55cd35208b88a2f0f7eb890af3034ba5676ec11cd9a61ef73248b400ae452fea" +dependencies = [ + "anyhow", + "async-trait", + "futures-util", + "log", + "tokio", + "tokio-util", +] + [[package]] name = "ttf-parser" version = "0.25.1" diff --git a/Cargo.toml b/Cargo.toml index ba364156..348c77c1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -70,6 +70,7 @@ tokio-test = "0.4.4" tokio-tungstenite = "0.28" tokio-util = "0.7" thiserror = "2" +trzcina = "=0.2.1" url = { version = "2.5", features = ["serde"] } paddler = { version = "4.0.0", path = "paddler" } paddler_bootstrap = { version = "4.0.0", path = "paddler_bootstrap" } diff --git a/paddler/Cargo.toml b/paddler/Cargo.toml index 117950c3..a6277801 100644 --- a/paddler/Cargo.toml +++ b/paddler/Cargo.toml @@ -51,6 +51,7 @@ tokio = { workspace = true } tokio-stream = { workspace = true } tokio-tungstenite = { workspace = true } tokio-util = { workspace = true } +trzcina = { workspace = true } url = { workspace = true } # web dashboard deps diff --git a/paddler/src/agent/llamacpp_arbiter_service.rs b/paddler/src/agent/llamacpp_arbiter_service.rs index 8d60fb56..16096d78 100644 --- a/paddler/src/agent/llamacpp_arbiter_service.rs +++ b/paddler/src/agent/llamacpp_arbiter_service.rs @@ -12,6 +12,7 @@ use tokio::time::Duration; use tokio::time::MissedTickBehavior; use tokio::time::interval; use tokio_util::sync::CancellationToken; +use trzcina::Service; use crate::agent::continue_from_conversation_history_request::ContinueFromConversationHistoryRequest; use crate::agent::continue_from_raw_prompt_request::ContinueFromRawPromptRequest; @@ -24,7 +25,6 @@ use crate::agent::generate_embedding_batch_request::GenerateEmbeddingBatchReques use crate::agent::model_metadata_holder::ModelMetadataHolder; use crate::agent_applicable_state::AgentApplicableState; use crate::agent_applicable_state_holder::AgentApplicableStateHolder; -use crate::service::Service; use crate::slot_aggregated_status_manager::SlotAggregatedStatusManager; pub struct LlamaCppArbiterService { diff --git a/paddler/src/agent/management_socket_client_service.rs b/paddler/src/agent/management_socket_client_service.rs index 155f2f04..ebe16a5b 100644 --- a/paddler/src/agent/management_socket_client_service.rs +++ b/paddler/src/agent/management_socket_client_service.rs @@ -1,6 +1,5 @@ use std::sync::Arc; -use actix_web::rt; use actix_web::web::Bytes; use anyhow::Context; use anyhow::Result; @@ -18,6 +17,7 @@ use tokio::time::interval; use tokio_tungstenite::connect_async; use tokio_tungstenite::tungstenite::protocol::Message; use tokio_util::sync::CancellationToken; +use trzcina::Service; use paddler_types::agent_desired_state::AgentDesiredState; use paddler_types::jsonrpc::Error as JsonRpcError; @@ -42,7 +42,6 @@ use crate::balancer::management_service::http_route::api::ws_agent_socket::jsonr use crate::balancer::management_service::http_route::api::ws_agent_socket::jsonrpc::notification_params::RegisterAgentParams; use crate::balancer::management_service::http_route::api::ws_agent_socket::jsonrpc::notification_params::UpdateAgentStatusParams; use crate::produces_snapshot::ProducesSnapshot; -use crate::service::Service; use crate::slot_aggregated_status::SlotAggregatedStatus; use crate::subscribes_to_updates::SubscribesToUpdates as _; @@ -266,7 +265,7 @@ impl ManagementSocketClientService { Message::Text(text) => { let connection_close = incoming_message_context.connection_close.clone(); - rt::spawn(async move { + tokio::spawn(async move { tokio::select! { () = connection_close.cancelled() => { info!("Connection close signal received, shutting down"); @@ -327,7 +326,7 @@ impl ManagementSocketClientService { let forward_connection_close = connection_close.clone(); let forward_shutdown = shutdown.clone(); - let message_forward_handle = rt::spawn(async move { + let message_forward_handle = tokio::spawn(async move { loop { tokio::select! { () = forward_connection_close.cancelled() => { diff --git a/paddler/src/agent/reconciliation_service.rs b/paddler/src/agent/reconciliation_service.rs index e4a14574..0e134900 100644 --- a/paddler/src/agent/reconciliation_service.rs +++ b/paddler/src/agent/reconciliation_service.rs @@ -9,11 +9,11 @@ use tokio::time::Duration; use tokio::time::MissedTickBehavior; use tokio::time::interval; use tokio_util::sync::CancellationToken; +use trzcina::Service; use crate::agent_applicable_state_holder::AgentApplicableStateHolder; use crate::agent_issue_fix::AgentIssueFix; use crate::converts_to_applicable_state::ConvertsToApplicableState as _; -use crate::service::Service; use crate::slot_aggregated_status::SlotAggregatedStatus; pub struct ReconciliationService { diff --git a/paddler/src/balancer/compatibility/openai_service/mod.rs b/paddler/src/balancer/compatibility/openai_service/mod.rs index 03ea1cf7..b39d0041 100644 --- a/paddler/src/balancer/compatibility/openai_service/mod.rs +++ b/paddler/src/balancer/compatibility/openai_service/mod.rs @@ -10,6 +10,8 @@ use actix_web::web::Data; use anyhow::Result; use async_trait::async_trait; use tokio_util::sync::CancellationToken; +use trzcina::Service; +use trzcina::ServiceShutdownOptions; use crate::balancer::buffered_request_manager::BufferedRequestManager; use crate::balancer::compatibility::openai_service::app_data::AppData; @@ -17,12 +19,12 @@ use crate::balancer::compatibility::openai_service::configuration::Configuration use crate::balancer::http_route as common_http_route; use crate::balancer::inference_service::configuration::Configuration as InferenceServiceConfiguration; use crate::create_cors_middleware::create_cors_middleware; -use crate::service::Service; pub struct OpenAIService { pub buffered_request_manager: Arc, pub inference_service_configuration: InferenceServiceConfiguration, pub openai_service_configuration: OpenAIServiceConfiguration, + pub shutdown_options: ServiceShutdownOptions, } #[async_trait] @@ -54,6 +56,7 @@ impl Service for OpenAIService { .shutdown_signal(async move { shutdown.cancelled().await; }) + .shutdown_timeout(self.shutdown_options.cooperative_deadline.as_secs()) .disable_signals() .bind(self.openai_service_configuration.addr) .expect("Unable to bind server to address") diff --git a/paddler/src/balancer/inference_service/mod.rs b/paddler/src/balancer/inference_service/mod.rs index 844a004a..e22b535a 100644 --- a/paddler/src/balancer/inference_service/mod.rs +++ b/paddler/src/balancer/inference_service/mod.rs @@ -10,6 +10,8 @@ use actix_web::web::Data; use anyhow::Result; use async_trait::async_trait; use tokio_util::sync::CancellationToken; +use trzcina::Service; +use trzcina::ServiceShutdownOptions; use crate::balancer::agent_controller_pool::AgentControllerPool; use crate::balancer::buffered_request_manager::BufferedRequestManager; @@ -20,13 +22,13 @@ use crate::balancer::inference_service::configuration::Configuration as Inferenc use crate::balancer::web_admin_panel_service::configuration::Configuration as WebAdminPanelServiceConfiguration; use crate::balancer_applicable_state_holder::BalancerApplicableStateHolder; use crate::create_cors_middleware::create_cors_middleware; -use crate::service::Service; pub struct InferenceService { pub agent_controller_pool: Arc, pub balancer_applicable_state_holder: Arc, pub buffered_request_manager: Arc, pub configuration: InferenceServiceConfiguration, + pub shutdown_options: ServiceShutdownOptions, #[cfg(feature = "web_admin_panel")] pub web_admin_panel_service_configuration: Option, } @@ -70,6 +72,7 @@ impl Service for InferenceService { .shutdown_signal(async move { shutdown.cancelled().await; }) + .shutdown_timeout(self.shutdown_options.cooperative_deadline.as_secs()) .disable_signals() .bind(self.configuration.addr) .expect("Unable to bind server to address") diff --git a/paddler/src/balancer/management_service/mod.rs b/paddler/src/balancer/management_service/mod.rs index 6c407c2d..72131e0f 100644 --- a/paddler/src/balancer/management_service/mod.rs +++ b/paddler/src/balancer/management_service/mod.rs @@ -10,6 +10,8 @@ use actix_web::web::Data; use anyhow::Result; use async_trait::async_trait; use tokio_util::sync::CancellationToken; +use trzcina::Service; +use trzcina::ServiceShutdownOptions; use crate::balancer::agent_controller_pool::AgentControllerPool; use crate::balancer::buffered_request_manager::BufferedRequestManager; @@ -25,7 +27,6 @@ use crate::balancer::state_database::StateDatabase; use crate::balancer::web_admin_panel_service::configuration::Configuration as WebAdminPanelServiceConfiguration; use crate::balancer_applicable_state_holder::BalancerApplicableStateHolder; use crate::create_cors_middleware::create_cors_middleware; -use crate::service::Service; pub struct ManagementService { pub agent_controller_pool: Arc, @@ -36,6 +37,7 @@ pub struct ManagementService { pub embedding_sender_collection: Arc, pub generate_tokens_sender_collection: Arc, pub model_metadata_sender_collection: Arc, + pub shutdown_options: ServiceShutdownOptions, pub state_database: Arc, pub statsd_prefix: String, #[cfg(feature = "web_admin_panel")] @@ -95,6 +97,7 @@ impl Service for ManagementService { .shutdown_signal(async move { shutdown.cancelled().await; }) + .shutdown_timeout(self.shutdown_options.cooperative_deadline.as_secs()) .disable_signals() .bind(self.configuration.addr) .expect("Unable to bind server to address") diff --git a/paddler/src/balancer/reconciliation_service.rs b/paddler/src/balancer/reconciliation_service.rs index 11f95b03..928ab0f7 100644 --- a/paddler/src/balancer/reconciliation_service.rs +++ b/paddler/src/balancer/reconciliation_service.rs @@ -9,11 +9,11 @@ use tokio::time::Duration; use tokio::time::MissedTickBehavior; use tokio::time::interval; use tokio_util::sync::CancellationToken; +use trzcina::Service; use crate::balancer::agent_controller_pool::AgentControllerPool; use crate::balancer_applicable_state_holder::BalancerApplicableStateHolder; use crate::converts_to_applicable_state::ConvertsToApplicableState as _; -use crate::service::Service; use crate::sets_desired_state::SetsDesiredState as _; pub struct ReconciliationService { diff --git a/paddler/src/balancer/statsd_service/mod.rs b/paddler/src/balancer/statsd_service/mod.rs index d1810d47..509d734f 100644 --- a/paddler/src/balancer/statsd_service/mod.rs +++ b/paddler/src/balancer/statsd_service/mod.rs @@ -12,12 +12,12 @@ use log::error; use tokio::time::MissedTickBehavior; use tokio::time::interval; use tokio_util::sync::CancellationToken; +use trzcina::Service; use crate::balancer::agent_controller_pool::AgentControllerPool; use crate::balancer::agent_controller_pool_total_slots::AgentControllerPoolTotalSlots; use crate::balancer::buffered_request_manager::BufferedRequestManager; use crate::balancer::statsd_service::configuration::Configuration as StatsdServiceConfiguration; -use crate::service::Service; pub struct StatsdService { pub agent_controller_pool: Arc, diff --git a/paddler/src/balancer/web_admin_panel_service/mod.rs b/paddler/src/balancer/web_admin_panel_service/mod.rs index 715464a9..fff4829d 100644 --- a/paddler/src/balancer/web_admin_panel_service/mod.rs +++ b/paddler/src/balancer/web_admin_panel_service/mod.rs @@ -9,13 +9,15 @@ use actix_web::web::Data; use anyhow::Result; use async_trait::async_trait; use tokio_util::sync::CancellationToken; +use trzcina::Service; +use trzcina::ServiceShutdownOptions; use crate::balancer::web_admin_panel_service::app_data::AppData; use crate::balancer::web_admin_panel_service::configuration::Configuration as WebAdminPanelServiceConfiguration; -use crate::service::Service; pub struct WebAdminPanelService { pub configuration: WebAdminPanelServiceConfiguration, + pub shutdown_options: ServiceShutdownOptions, } #[async_trait] @@ -40,6 +42,7 @@ impl Service for WebAdminPanelService { .shutdown_signal(async move { shutdown.cancelled().await; }) + .shutdown_timeout(self.shutdown_options.cooperative_deadline.as_secs()) .disable_signals() .bind(self.configuration.addr) .expect("Unable to bind server to address") diff --git a/paddler/src/lib.rs b/paddler/src/lib.rs index b6516edb..5f97aef1 100644 --- a/paddler/src/lib.rs +++ b/paddler/src/lib.rs @@ -29,8 +29,6 @@ pub mod resolve_desired_model; pub mod resolved_socket_addr; pub mod resolves_model_source; pub mod sends_rpc_message; -pub mod service; -pub mod service_manager; pub mod sets_desired_state; pub mod slot_aggregated_status; pub mod slot_aggregated_status_download_progress; diff --git a/paddler/src/service.rs b/paddler/src/service.rs deleted file mode 100644 index 75ca6a35..00000000 --- a/paddler/src/service.rs +++ /dev/null @@ -1,10 +0,0 @@ -use anyhow::Result; -use async_trait::async_trait; -use tokio_util::sync::CancellationToken; - -#[async_trait] -pub trait Service: Send + 'static { - fn name(&self) -> &'static str; - - async fn run(&mut self, shutdown: CancellationToken) -> Result<()>; -} diff --git a/paddler/src/service_manager.rs b/paddler/src/service_manager.rs deleted file mode 100644 index b07eec1a..00000000 --- a/paddler/src/service_manager.rs +++ /dev/null @@ -1,337 +0,0 @@ -use std::collections::BTreeSet; - -use actix_web::rt; -use actix_web::rt::task::JoinError; -use anyhow::Result; -use anyhow::anyhow; -use futures::stream::FuturesUnordered; -use futures::stream::StreamExt; -use log::error; -use log::info; -use tokio_util::sync::CancellationToken; - -use crate::service::Service; - -struct ServiceDrainEvent { - join_result: Result, JoinError>, - name: String, -} - -impl ServiceDrainEvent { - fn into_service_error(self) -> Option { - match self.join_result { - Ok(Ok(())) => None, - Ok(Err(service_error)) => Some(service_error), - Err(join_error) => Some(anyhow!("service task panicked: {join_error}")), - } - } -} - -#[derive(Default)] -pub struct ServiceManager { - services: Vec>, -} - -impl ServiceManager { - pub fn add_service(&mut self, service: TService) { - self.services.push(Box::new(service)); - } - - pub async fn run_forever(self, cancellation_token: CancellationToken) -> Result<()> { - let mut service_handles = FuturesUnordered::new(); - let mut pending_service_names: BTreeSet = BTreeSet::new(); - - for mut service in self.services { - let service_name = service.name().to_owned(); - - pending_service_names.insert(service_name.clone()); - - let task_token = cancellation_token.clone(); - let event_name = service_name.clone(); - - service_handles.push(async move { - let join_result = rt::spawn(async move { - info!("{service_name}: Starting"); - - let result = service.run(task_token).await; - - match &result { - Ok(()) => info!("{service_name}: Stopped"), - Err(service_error) => error!("{service_name}: {service_error}"), - } - - result - }) - .await; - - ServiceDrainEvent { - join_result, - name: event_name, - } - }); - } - - let mut first_error: Option = None; - - tokio::select! { - () = cancellation_token.cancelled() => {} - Some(event) = service_handles.next() => { - pending_service_names.remove(&event.name); - first_error = event.into_service_error(); - } - } - - info!( - "run_forever: shutdown triggered; draining {} service(s): {:?}", - pending_service_names.len(), - pending_service_names - ); - - cancellation_token.cancel(); - - while let Some(event) = service_handles.next().await { - pending_service_names.remove(&event.name); - - info!( - "run_forever: {name} drained; remaining: {pending_service_names:?}", - name = event.name - ); - - if let Some(service_error) = event.into_service_error() - && first_error.is_none() - { - first_error = Some(service_error); - } - } - - info!("run_forever: all services drained"); - - first_error.map_or_else(|| Ok(()), Err) - } -} - -#[cfg(test)] -mod tests { - use std::sync::Arc; - - use async_trait::async_trait; - use thiserror::Error; - use tokio::sync::Notify; - - use super::*; - - #[derive(Debug, Error)] - #[error("intentional test failure")] - struct TestFailureMarker; - - struct NeverExitingService { - ready: Arc, - } - - #[async_trait] - impl Service for NeverExitingService { - fn name(&self) -> &'static str { - "test::never_exiting_service" - } - - async fn run(&mut self, shutdown: CancellationToken) -> Result<()> { - self.ready.notify_one(); - - shutdown.cancelled().await; - - Ok(()) - } - } - - struct FailingOnDemandService { - fail: Arc, - } - - #[async_trait] - impl Service for FailingOnDemandService { - fn name(&self) -> &'static str { - "test::failing_on_demand_service" - } - - async fn run(&mut self, _shutdown: CancellationToken) -> Result<()> { - self.fail.notified().await; - - Err(TestFailureMarker.into()) - } - } - - struct ImmediatelyFailingService; - - #[async_trait] - impl Service for ImmediatelyFailingService { - fn name(&self) -> &'static str { - "test::immediately_failing_service" - } - - async fn run(&mut self, _shutdown: CancellationToken) -> Result<()> { - Err(TestFailureMarker.into()) - } - } - - struct ImmediatelySuccessService; - - #[async_trait] - impl Service for ImmediatelySuccessService { - fn name(&self) -> &'static str { - "test::immediately_success_service" - } - - async fn run(&mut self, _shutdown: CancellationToken) -> Result<()> { - Ok(()) - } - } - - #[actix_web::test] - async fn err_exit_cascades_to_peers() -> Result<()> { - let ready = Arc::new(Notify::new()); - let fail = Arc::new(Notify::new()); - let shutdown = CancellationToken::new(); - - let mut manager = ServiceManager::default(); - manager.add_service(NeverExitingService { - ready: ready.clone(), - }); - manager.add_service(FailingOnDemandService { fail: fail.clone() }); - - let manager_handle = actix_web::rt::spawn(manager.run_forever(shutdown)); - - ready.notified().await; - fail.notify_one(); - - let error = match manager_handle.await? { - Ok(()) => { - return Err(anyhow!( - "run_forever should surface the failing service's error" - )); - } - Err(service_error) => service_error, - }; - - error - .downcast_ref::() - .ok_or_else(|| anyhow!("expected TestFailureMarker, got: {error:?}"))?; - - Ok(()) - } - - #[actix_web::test] - async fn ok_exit_cascades_to_peers() -> Result<()> { - let ready = Arc::new(Notify::new()); - let shutdown = CancellationToken::new(); - - let mut manager = ServiceManager::default(); - manager.add_service(NeverExitingService { - ready: ready.clone(), - }); - manager.add_service(ImmediatelySuccessService); - - let manager_handle = actix_web::rt::spawn(manager.run_forever(shutdown)); - - ready.notified().await; - - manager_handle.await??; - - Ok(()) - } - - #[actix_web::test] - async fn fast_failure_cascades_to_late_subscriber() -> Result<()> { - let ready = Arc::new(Notify::new()); - let shutdown = CancellationToken::new(); - - let mut manager = ServiceManager::default(); - manager.add_service(ImmediatelyFailingService); - manager.add_service(NeverExitingService { - ready: ready.clone(), - }); - - let manager_handle = actix_web::rt::spawn(manager.run_forever(shutdown)); - - ready.notified().await; - - let error = match manager_handle.await? { - Ok(()) => { - return Err(anyhow!( - "run_forever should surface the failing service's error" - )); - } - Err(service_error) => service_error, - }; - - error - .downcast_ref::() - .ok_or_else(|| anyhow!("expected TestFailureMarker, got: {error:?}"))?; - - Ok(()) - } - - #[actix_web::test] - async fn drains_all_services_on_external_cancel() -> Result<()> { - let ready = Arc::new(Notify::new()); - let shutdown = CancellationToken::new(); - - let mut manager = ServiceManager::default(); - manager.add_service(NeverExitingService { - ready: ready.clone(), - }); - manager.add_service(ImmediatelySuccessService); - manager.add_service(ImmediatelySuccessService); - manager.add_service(ImmediatelySuccessService); - - let manager_handle = actix_web::rt::spawn(manager.run_forever(shutdown.clone())); - - ready.notified().await; - shutdown.cancel(); - - manager_handle.await??; - - Ok(()) - } - - #[actix_web::test] - async fn all_services_exit_before_cancel_is_idempotent() -> Result<()> { - let shutdown = CancellationToken::new(); - - let mut manager = ServiceManager::default(); - manager.add_service(ImmediatelySuccessService); - manager.add_service(ImmediatelySuccessService); - manager.add_service(ImmediatelySuccessService); - - let manager_handle = actix_web::rt::spawn(manager.run_forever(shutdown)); - - manager_handle.await??; - - Ok(()) - } - - #[actix_web::test] - async fn external_shutdown_still_works() -> Result<()> { - let ready_first = Arc::new(Notify::new()); - let ready_second = Arc::new(Notify::new()); - let shutdown = CancellationToken::new(); - - let mut manager = ServiceManager::default(); - manager.add_service(NeverExitingService { - ready: ready_first.clone(), - }); - manager.add_service(NeverExitingService { - ready: ready_second.clone(), - }); - - let manager_handle = actix_web::rt::spawn(manager.run_forever(shutdown.clone())); - - ready_first.notified().await; - ready_second.notified().await; - - shutdown.cancel(); - - manager_handle.await??; - - Ok(()) - } -} diff --git a/paddler_bootstrap/Cargo.toml b/paddler_bootstrap/Cargo.toml index 4e495b64..32d86412 100644 --- a/paddler_bootstrap/Cargo.toml +++ b/paddler_bootstrap/Cargo.toml @@ -9,12 +9,14 @@ license.workspace = true [dependencies] actix-web = { workspace = true } anyhow = { workspace = true } +async-trait = { workspace = true } log = { workspace = true } nanoid = { workspace = true } paddler = { workspace = true } paddler_types = { workspace = true } tokio = { workspace = true } tokio-util = { workspace = true } +trzcina = { workspace = true } [dev-dependencies] tempfile = { workspace = true } diff --git a/paddler_bootstrap/src/agent_runner.rs b/paddler_bootstrap/src/agent_runner.rs index e168192f..758e4cde 100644 --- a/paddler_bootstrap/src/agent_runner.rs +++ b/paddler_bootstrap/src/agent_runner.rs @@ -4,9 +4,10 @@ use std::sync::Arc; use anyhow::Result; use paddler::slot_aggregated_status::SlotAggregatedStatus; use tokio_util::sync::CancellationToken; +use trzcina::ServiceManager; +use trzcina::ServiceShutdownOptions; -use crate::bootstrapped_agent_handle::BootstrappedAgentHandle; -use crate::bootstrapped_agent_handle::bootstrap_agent; +use crate::agent_service_bundle::AgentServiceBundle; use crate::service_thread::ServiceThread; pub struct AgentRunnerParams { @@ -31,13 +32,18 @@ impl AgentRunner { slots, }: AgentRunnerParams, ) -> Self { - let BootstrappedAgentHandle { - service_manager, - slot_aggregated_status, - } = bootstrap_agent(agent_name, &management_address, slots); + let bundle = AgentServiceBundle::new(agent_name, &management_address, slots); + let slot_aggregated_status = bundle.slot_aggregated_status.clone(); let thread = ServiceThread::spawn(cancellation_token, move |task_shutdown| async move { - service_manager.run_forever(task_shutdown).await + let mut service_manager = ServiceManager::default(); + service_manager.register_bundle(bundle).await?; + service_manager + .start(task_shutdown) + .run_to_completion(ServiceShutdownOptions::default()) + .await + .into_result() + .map_err(anyhow::Error::from) }); Self { diff --git a/paddler_bootstrap/src/agent_service_bundle.rs b/paddler_bootstrap/src/agent_service_bundle.rs new file mode 100644 index 00000000..ab1573bc --- /dev/null +++ b/paddler_bootstrap/src/agent_service_bundle.rs @@ -0,0 +1,103 @@ +use std::sync::Arc; + +use anyhow::Result; +use async_trait::async_trait; +use nanoid::nanoid; +use paddler::agent::continue_from_conversation_history_request::ContinueFromConversationHistoryRequest; +use paddler::agent::continue_from_raw_prompt_request::ContinueFromRawPromptRequest; +use paddler::agent::generate_embedding_batch_request::GenerateEmbeddingBatchRequest; +use paddler::agent::llamacpp_arbiter_service::LlamaCppArbiterService; +use paddler::agent::management_socket_client_service::ManagementSocketClientService; +use paddler::agent::model_metadata_holder::ModelMetadataHolder; +use paddler::agent::reconciliation_service::ReconciliationService; +use paddler::agent_applicable_state_holder::AgentApplicableStateHolder; +use paddler::slot_aggregated_status::SlotAggregatedStatus; +use paddler::slot_aggregated_status_manager::SlotAggregatedStatusManager; +use paddler_types::agent_desired_state::AgentDesiredState; +use tokio::sync::mpsc; +use trzcina::Service; +use trzcina::ServiceBundle; + +pub struct AgentServiceBundle { + pub slot_aggregated_status: Arc, + llamacpp_arbiter_service: LlamaCppArbiterService, + management_socket_client_service: ManagementSocketClientService, + reconciliation_service: ReconciliationService, +} + +impl AgentServiceBundle { + #[must_use] + pub fn new(agent_name: Option, management_address: &str, slots: i32) -> Self { + let (agent_desired_state_tx, agent_desired_state_rx) = + mpsc::unbounded_channel::(); + let ( + continue_from_conversation_history_request_tx, + continue_from_conversation_history_request_rx, + ) = mpsc::unbounded_channel::(); + let (continue_from_raw_prompt_request_tx, continue_from_raw_prompt_request_rx) = + mpsc::unbounded_channel::(); + let (generate_embedding_batch_request_tx, generate_embedding_batch_request_rx) = + mpsc::unbounded_channel::(); + + let agent_applicable_state_holder = Arc::new(AgentApplicableStateHolder::default()); + let model_metadata_holder = Arc::new(ModelMetadataHolder::default()); + let slot_aggregated_status_manager = Arc::new(SlotAggregatedStatusManager::new(slots)); + let slot_aggregated_status = slot_aggregated_status_manager.slot_aggregated_status.clone(); + + let llamacpp_arbiter_service = LlamaCppArbiterService { + agent_applicable_state: None, + agent_applicable_state_holder: agent_applicable_state_holder.clone(), + agent_name: agent_name.clone(), + continue_from_conversation_history_request_rx, + continue_from_raw_prompt_request_rx, + desired_slots_total: slots, + generate_embedding_batch_request_rx, + continuous_batch_arbiter_handle: None, + model_metadata_holder: model_metadata_holder.clone(), + slot_aggregated_status_manager, + }; + + let management_socket_client_service = ManagementSocketClientService { + agent_applicable_state_holder: agent_applicable_state_holder.clone(), + agent_desired_state_tx, + continue_from_conversation_history_request_tx, + continue_from_raw_prompt_request_tx, + generate_embedding_batch_request_tx, + model_metadata_holder, + name: agent_name, + receive_stream_stopper_collection: Arc::default(), + slot_aggregated_status: slot_aggregated_status.clone(), + socket_url: format!( + "ws://{}/api/v1/agent_socket/{}", + management_address, + nanoid!() + ), + }; + + let reconciliation_service = ReconciliationService { + agent_applicable_state_holder, + agent_desired_state: None, + agent_desired_state_rx, + is_converted_to_applicable_state: false, + slot_aggregated_status: slot_aggregated_status.clone(), + }; + + Self { + slot_aggregated_status, + llamacpp_arbiter_service, + management_socket_client_service, + reconciliation_service, + } + } +} + +#[async_trait] +impl ServiceBundle for AgentServiceBundle { + async fn services(self) -> Result>> { + Ok(vec![ + Box::new(self.llamacpp_arbiter_service), + Box::new(self.management_socket_client_service), + Box::new(self.reconciliation_service), + ]) + } +} diff --git a/paddler_bootstrap/src/balancer_runner.rs b/paddler_bootstrap/src/balancer_runner.rs index 6e126403..76bdbb34 100644 --- a/paddler_bootstrap/src/balancer_runner.rs +++ b/paddler_bootstrap/src/balancer_runner.rs @@ -15,10 +15,11 @@ use paddler::balancer_applicable_state_holder::BalancerApplicableStateHolder; use paddler_types::balancer_desired_state::BalancerDesiredState; use tokio::sync::broadcast; use tokio_util::sync::CancellationToken; +use trzcina::ServiceManager; +use trzcina::ServiceShutdownOptions; -use crate::bootstrapped_balancer_handle::BalancerBootstrapConfig; -use crate::bootstrapped_balancer_handle::BootstrappedBalancerHandle; -use crate::bootstrapped_balancer_handle::bootstrap_balancer; +use crate::balancer_service_bundle::BalancerBootstrapConfig; +use crate::balancer_service_bundle::BalancerServiceBundle; use crate::service_thread::ServiceThread; pub struct BalancerRunnerParams { @@ -59,18 +60,15 @@ impl BalancerRunner { web_admin_panel_service_configuration, }: BalancerRunnerParams, ) -> Result { - let BootstrappedBalancerHandle { - agent_controller_pool, - balancer_applicable_state_holder, - balancer_desired_state_tx, - service_manager, - state_database, - } = bootstrap_balancer(BalancerBootstrapConfig { + let shutdown_options = ServiceShutdownOptions::default(); + + let bundle = BalancerServiceBundle::new(BalancerBootstrapConfig { buffered_request_timeout, inference_service_configuration, management_service_configuration, max_buffered_requests, openai_service_configuration, + shutdown_options: shutdown_options.clone(), state_database_type, statsd_prefix, statsd_service_configuration, @@ -79,10 +77,20 @@ impl BalancerRunner { }) .await?; - let initial_desired_state = state_database.read_balancer_desired_state().await?; + let agent_controller_pool = bundle.agent_controller_pool.clone(); + let balancer_applicable_state_holder = bundle.balancer_applicable_state_holder.clone(); + let balancer_desired_state_tx = bundle.balancer_desired_state_tx.clone(); + let initial_desired_state = bundle.initial_desired_state.clone(); let thread = ServiceThread::spawn(cancellation_token, move |task_shutdown| async move { - service_manager.run_forever(task_shutdown).await + let mut service_manager = ServiceManager::default(); + service_manager.register_bundle(bundle).await?; + service_manager + .start(task_shutdown) + .run_to_completion(shutdown_options) + .await + .into_result() + .map_err(anyhow::Error::from) }); Ok(Self { diff --git a/paddler_bootstrap/src/balancer_service_bundle.rs b/paddler_bootstrap/src/balancer_service_bundle.rs new file mode 100644 index 00000000..4b852e10 --- /dev/null +++ b/paddler_bootstrap/src/balancer_service_bundle.rs @@ -0,0 +1,204 @@ +use std::sync::Arc; +use std::time::Duration; + +use anyhow::Result; +use async_trait::async_trait; +use paddler::balancer::agent_controller_pool::AgentControllerPool; +use paddler::balancer::buffered_request_manager::BufferedRequestManager; +use paddler::balancer::chat_template_override_sender_collection::ChatTemplateOverrideSenderCollection; +use paddler::balancer::compatibility::openai_service::OpenAIService; +use paddler::balancer::compatibility::openai_service::configuration::Configuration as OpenAIServiceConfiguration; +use paddler::balancer::embedding_sender_collection::EmbeddingSenderCollection; +use paddler::balancer::generate_tokens_sender_collection::GenerateTokensSenderCollection; +use paddler::balancer::inference_service::InferenceService; +use paddler::balancer::inference_service::configuration::Configuration as InferenceServiceConfiguration; +use paddler::balancer::management_service::ManagementService; +use paddler::balancer::management_service::configuration::Configuration as ManagementServiceConfiguration; +use paddler::balancer::model_metadata_sender_collection::ModelMetadataSenderCollection; +use paddler::balancer::reconciliation_service::ReconciliationService; +use paddler::balancer::state_database::File; +use paddler::balancer::state_database::Memory; +use paddler::balancer::state_database::StateDatabase; +use paddler::balancer::state_database_type::StateDatabaseType; +use paddler::balancer::statsd_service::StatsdService; +use paddler::balancer::statsd_service::configuration::Configuration as StatsdServiceConfiguration; +#[cfg(feature = "web_admin_panel")] +use paddler::balancer::web_admin_panel_service::WebAdminPanelService; +#[cfg(feature = "web_admin_panel")] +use paddler::balancer::web_admin_panel_service::configuration::Configuration as WebAdminPanelServiceConfiguration; +use paddler::balancer_applicable_state_holder::BalancerApplicableStateHolder; +use paddler_types::balancer_desired_state::BalancerDesiredState; +use tokio::sync::broadcast; +use trzcina::Service; +use trzcina::ServiceBundle; +use trzcina::ServiceShutdownOptions; + +pub struct BalancerBootstrapConfig { + pub buffered_request_timeout: Duration, + pub inference_service_configuration: InferenceServiceConfiguration, + pub management_service_configuration: ManagementServiceConfiguration, + pub max_buffered_requests: i32, + pub openai_service_configuration: Option, + pub shutdown_options: ServiceShutdownOptions, + pub state_database_type: StateDatabaseType, + pub statsd_prefix: String, + pub statsd_service_configuration: Option, + #[cfg(feature = "web_admin_panel")] + pub web_admin_panel_service_configuration: Option, +} + +pub struct BalancerServiceBundle { + pub agent_controller_pool: Arc, + pub balancer_applicable_state_holder: Arc, + pub balancer_desired_state_tx: broadcast::Sender, + pub initial_desired_state: BalancerDesiredState, + pub state_database: Arc, + inference_service: InferenceService, + management_service: ManagementService, + reconciliation_service: ReconciliationService, + openai_service: Option, + statsd_service: Option, + #[cfg(feature = "web_admin_panel")] + web_admin_panel_service: Option, +} + +impl BalancerServiceBundle { + pub async fn new( + BalancerBootstrapConfig { + buffered_request_timeout, + inference_service_configuration, + management_service_configuration, + max_buffered_requests, + openai_service_configuration, + shutdown_options, + state_database_type, + statsd_prefix, + statsd_service_configuration, + #[cfg(feature = "web_admin_panel")] + web_admin_panel_service_configuration, + }: BalancerBootstrapConfig, + ) -> Result { + let (balancer_desired_state_tx, balancer_desired_state_rx) = broadcast::channel(100); + + let agent_controller_pool = Arc::new(AgentControllerPool::default()); + let balancer_applicable_state_holder = Arc::new(BalancerApplicableStateHolder::default()); + let buffered_request_manager = Arc::new(BufferedRequestManager::new( + agent_controller_pool.clone(), + buffered_request_timeout, + max_buffered_requests, + )); + let chat_template_override_sender_collection = + Arc::new(ChatTemplateOverrideSenderCollection::default()); + let embedding_sender_collection = Arc::new(EmbeddingSenderCollection::default()); + let generate_tokens_sender_collection = Arc::new(GenerateTokensSenderCollection::default()); + let model_metadata_sender_collection = Arc::new(ModelMetadataSenderCollection::default()); + let state_database: Arc = match state_database_type { + StateDatabaseType::File(path) => { + Arc::new(File::new(balancer_desired_state_tx.clone(), path)) + } + StateDatabaseType::Memory(initial_desired_state) => Arc::new(Memory::new( + balancer_desired_state_tx.clone(), + *initial_desired_state, + )), + }; + + let initial_desired_state = state_database.read_balancer_desired_state().await?; + + let inference_service = InferenceService { + agent_controller_pool: agent_controller_pool.clone(), + balancer_applicable_state_holder: balancer_applicable_state_holder.clone(), + buffered_request_manager: buffered_request_manager.clone(), + configuration: inference_service_configuration.clone(), + shutdown_options: shutdown_options.clone(), + #[cfg(feature = "web_admin_panel")] + web_admin_panel_service_configuration: web_admin_panel_service_configuration.clone(), + }; + + let management_service = ManagementService { + agent_controller_pool: agent_controller_pool.clone(), + balancer_applicable_state_holder: balancer_applicable_state_holder.clone(), + buffered_request_manager: buffered_request_manager.clone(), + chat_template_override_sender_collection, + configuration: management_service_configuration, + embedding_sender_collection, + generate_tokens_sender_collection, + model_metadata_sender_collection, + shutdown_options: shutdown_options.clone(), + state_database: state_database.clone(), + statsd_prefix, + #[cfg(feature = "web_admin_panel")] + web_admin_panel_service_configuration: web_admin_panel_service_configuration.clone(), + }; + + let reconciliation_service = ReconciliationService { + agent_controller_pool: agent_controller_pool.clone(), + balancer_applicable_state_holder: balancer_applicable_state_holder.clone(), + balancer_desired_state: initial_desired_state.clone(), + balancer_desired_state_rx, + is_converted_to_applicable_state: false, + }; + + let openai_service = openai_service_configuration.map(|openai_service_configuration| { + OpenAIService { + buffered_request_manager: buffered_request_manager.clone(), + inference_service_configuration, + openai_service_configuration, + shutdown_options: shutdown_options.clone(), + } + }); + + let statsd_service = statsd_service_configuration.map(|configuration| StatsdService { + agent_controller_pool: agent_controller_pool.clone(), + buffered_request_manager, + configuration, + }); + + #[cfg(feature = "web_admin_panel")] + let web_admin_panel_service = + web_admin_panel_service_configuration.map(|configuration| WebAdminPanelService { + configuration, + shutdown_options: shutdown_options.clone(), + }); + + Ok(Self { + agent_controller_pool, + balancer_applicable_state_holder, + balancer_desired_state_tx, + initial_desired_state, + state_database, + inference_service, + management_service, + reconciliation_service, + openai_service, + statsd_service, + #[cfg(feature = "web_admin_panel")] + web_admin_panel_service, + }) + } +} + +#[async_trait] +impl ServiceBundle for BalancerServiceBundle { + async fn services(self) -> Result>> { + let mut services: Vec> = vec![ + Box::new(self.inference_service), + Box::new(self.management_service), + Box::new(self.reconciliation_service), + ]; + + if let Some(service) = self.openai_service { + services.push(Box::new(service)); + } + + if let Some(service) = self.statsd_service { + services.push(Box::new(service)); + } + + #[cfg(feature = "web_admin_panel")] + if let Some(service) = self.web_admin_panel_service { + services.push(Box::new(service)); + } + + Ok(services) + } +} diff --git a/paddler_bootstrap/src/bootstrapped_agent_handle.rs b/paddler_bootstrap/src/bootstrapped_agent_handle.rs deleted file mode 100644 index 9d87fdb5..00000000 --- a/paddler_bootstrap/src/bootstrapped_agent_handle.rs +++ /dev/null @@ -1,92 +0,0 @@ -use std::sync::Arc; - -use nanoid::nanoid; -use paddler::agent::continue_from_conversation_history_request::ContinueFromConversationHistoryRequest; -use paddler::agent::continue_from_raw_prompt_request::ContinueFromRawPromptRequest; -use paddler::agent::generate_embedding_batch_request::GenerateEmbeddingBatchRequest; -use paddler::agent::llamacpp_arbiter_service::LlamaCppArbiterService; -use paddler::agent::management_socket_client_service::ManagementSocketClientService; -use paddler::agent::model_metadata_holder::ModelMetadataHolder; -use paddler::agent::reconciliation_service::ReconciliationService; -use paddler::agent_applicable_state_holder::AgentApplicableStateHolder; -use paddler::service_manager::ServiceManager; -use paddler::slot_aggregated_status::SlotAggregatedStatus; -use paddler::slot_aggregated_status_manager::SlotAggregatedStatusManager; -use paddler_types::agent_desired_state::AgentDesiredState; -use tokio::sync::mpsc; - -pub struct BootstrappedAgentHandle { - pub service_manager: ServiceManager, - pub slot_aggregated_status: Arc, -} - -pub fn bootstrap_agent( - agent_name: Option, - management_address: &str, - slots: i32, -) -> BootstrappedAgentHandle { - let (agent_desired_state_tx, agent_desired_state_rx) = - mpsc::unbounded_channel::(); - let ( - continue_from_conversation_history_request_tx, - continue_from_conversation_history_request_rx, - ) = mpsc::unbounded_channel::(); - let (continue_from_raw_prompt_request_tx, continue_from_raw_prompt_request_rx) = - mpsc::unbounded_channel::(); - let (generate_embedding_batch_request_tx, generate_embedding_batch_request_rx) = - mpsc::unbounded_channel::(); - - let agent_applicable_state_holder = Arc::new(AgentApplicableStateHolder::default()); - let model_metadata_holder = Arc::new(ModelMetadataHolder::default()); - let mut service_manager = ServiceManager::default(); - let slot_aggregated_status_manager = Arc::new(SlotAggregatedStatusManager::new(slots)); - - service_manager.add_service(LlamaCppArbiterService { - agent_applicable_state: None, - agent_applicable_state_holder: agent_applicable_state_holder.clone(), - agent_name: agent_name.clone(), - continue_from_conversation_history_request_rx, - continue_from_raw_prompt_request_rx, - desired_slots_total: slots, - generate_embedding_batch_request_rx, - continuous_batch_arbiter_handle: None, - model_metadata_holder: model_metadata_holder.clone(), - slot_aggregated_status_manager: slot_aggregated_status_manager.clone(), - }); - - service_manager.add_service(ManagementSocketClientService { - agent_applicable_state_holder: agent_applicable_state_holder.clone(), - agent_desired_state_tx, - continue_from_conversation_history_request_tx, - continue_from_raw_prompt_request_tx, - generate_embedding_batch_request_tx, - model_metadata_holder, - name: agent_name, - receive_stream_stopper_collection: Arc::default(), - slot_aggregated_status: slot_aggregated_status_manager - .slot_aggregated_status - .clone(), - socket_url: format!( - "ws://{}/api/v1/agent_socket/{}", - management_address, - nanoid!() - ), - }); - - service_manager.add_service(ReconciliationService { - agent_applicable_state_holder, - agent_desired_state: None, - agent_desired_state_rx, - is_converted_to_applicable_state: false, - slot_aggregated_status: slot_aggregated_status_manager - .slot_aggregated_status - .clone(), - }); - - BootstrappedAgentHandle { - service_manager, - slot_aggregated_status: slot_aggregated_status_manager - .slot_aggregated_status - .clone(), - } -} diff --git a/paddler_bootstrap/src/bootstrapped_balancer_handle.rs b/paddler_bootstrap/src/bootstrapped_balancer_handle.rs deleted file mode 100644 index 1e208648..00000000 --- a/paddler_bootstrap/src/bootstrapped_balancer_handle.rs +++ /dev/null @@ -1,154 +0,0 @@ -use std::sync::Arc; -use std::time::Duration; - -use paddler::balancer::agent_controller_pool::AgentControllerPool; -use paddler::balancer::buffered_request_manager::BufferedRequestManager; -use paddler::balancer::chat_template_override_sender_collection::ChatTemplateOverrideSenderCollection; -use paddler::balancer::compatibility::openai_service::OpenAIService; -use paddler::balancer::compatibility::openai_service::configuration::Configuration as OpenAIServiceConfiguration; -use paddler::balancer::embedding_sender_collection::EmbeddingSenderCollection; -use paddler::balancer::generate_tokens_sender_collection::GenerateTokensSenderCollection; -use paddler::balancer::inference_service::InferenceService; -use paddler::balancer::inference_service::configuration::Configuration as InferenceServiceConfiguration; -use paddler::balancer::management_service::ManagementService; -use paddler::balancer::management_service::configuration::Configuration as ManagementServiceConfiguration; -use paddler::balancer::model_metadata_sender_collection::ModelMetadataSenderCollection; -use paddler::balancer::reconciliation_service::ReconciliationService; -use paddler::balancer::state_database::File; -use paddler::balancer::state_database::Memory; -use paddler::balancer::state_database::StateDatabase; -use paddler::balancer::state_database_type::StateDatabaseType; -use paddler::balancer::statsd_service::StatsdService; -use paddler::balancer::statsd_service::configuration::Configuration as StatsdServiceConfiguration; -#[cfg(feature = "web_admin_panel")] -use paddler::balancer::web_admin_panel_service::WebAdminPanelService; -#[cfg(feature = "web_admin_panel")] -use paddler::balancer::web_admin_panel_service::configuration::Configuration as WebAdminPanelServiceConfiguration; -use paddler::balancer_applicable_state_holder::BalancerApplicableStateHolder; -use paddler::service_manager::ServiceManager; -use paddler_types::balancer_desired_state::BalancerDesiredState; -use tokio::sync::broadcast; - -pub struct BalancerBootstrapConfig { - pub buffered_request_timeout: Duration, - pub inference_service_configuration: InferenceServiceConfiguration, - pub management_service_configuration: ManagementServiceConfiguration, - pub max_buffered_requests: i32, - pub openai_service_configuration: Option, - pub state_database_type: StateDatabaseType, - pub statsd_prefix: String, - pub statsd_service_configuration: Option, - #[cfg(feature = "web_admin_panel")] - pub web_admin_panel_service_configuration: Option, -} - -pub struct BootstrappedBalancerHandle { - pub agent_controller_pool: Arc, - pub balancer_applicable_state_holder: Arc, - pub balancer_desired_state_tx: broadcast::Sender, - pub service_manager: ServiceManager, - pub state_database: Arc, -} - -pub async fn bootstrap_balancer( - BalancerBootstrapConfig { - buffered_request_timeout, - inference_service_configuration, - management_service_configuration, - max_buffered_requests, - openai_service_configuration, - state_database_type, - statsd_prefix, - statsd_service_configuration, - #[cfg(feature = "web_admin_panel")] - web_admin_panel_service_configuration, - }: BalancerBootstrapConfig, -) -> anyhow::Result { - let (balancer_desired_state_tx, balancer_desired_state_rx) = broadcast::channel(100); - - let agent_controller_pool = Arc::new(AgentControllerPool::default()); - let balancer_applicable_state_holder = Arc::new(BalancerApplicableStateHolder::default()); - let buffered_request_manager = Arc::new(BufferedRequestManager::new( - agent_controller_pool.clone(), - buffered_request_timeout, - max_buffered_requests, - )); - let chat_template_override_sender_collection = - Arc::new(ChatTemplateOverrideSenderCollection::default()); - let embedding_sender_collection = Arc::new(EmbeddingSenderCollection::default()); - let generate_tokens_sender_collection = Arc::new(GenerateTokensSenderCollection::default()); - let model_metadata_sender_collection = Arc::new(ModelMetadataSenderCollection::default()); - let mut service_manager = ServiceManager::default(); - let state_database: Arc = match state_database_type { - StateDatabaseType::File(path) => { - Arc::new(File::new(balancer_desired_state_tx.clone(), path)) - } - StateDatabaseType::Memory(initial_desired_state) => Arc::new(Memory::new( - balancer_desired_state_tx.clone(), - *initial_desired_state, - )), - }; - - service_manager.add_service(InferenceService { - agent_controller_pool: agent_controller_pool.clone(), - balancer_applicable_state_holder: balancer_applicable_state_holder.clone(), - buffered_request_manager: buffered_request_manager.clone(), - configuration: inference_service_configuration.clone(), - #[cfg(feature = "web_admin_panel")] - web_admin_panel_service_configuration: web_admin_panel_service_configuration.clone(), - }); - - service_manager.add_service(ManagementService { - agent_controller_pool: agent_controller_pool.clone(), - balancer_applicable_state_holder: balancer_applicable_state_holder.clone(), - buffered_request_manager: buffered_request_manager.clone(), - chat_template_override_sender_collection, - configuration: management_service_configuration, - embedding_sender_collection, - generate_tokens_sender_collection, - model_metadata_sender_collection, - state_database: state_database.clone(), - statsd_prefix, - #[cfg(feature = "web_admin_panel")] - web_admin_panel_service_configuration: web_admin_panel_service_configuration.clone(), - }); - - service_manager.add_service(ReconciliationService { - agent_controller_pool: agent_controller_pool.clone(), - balancer_applicable_state_holder: balancer_applicable_state_holder.clone(), - balancer_desired_state: state_database.read_balancer_desired_state().await?, - balancer_desired_state_rx, - is_converted_to_applicable_state: false, - }); - - if let Some(openai_configuration) = openai_service_configuration { - service_manager.add_service(OpenAIService { - buffered_request_manager: buffered_request_manager.clone(), - inference_service_configuration, - openai_service_configuration: openai_configuration, - }); - } - - if let Some(statsd_configuration) = statsd_service_configuration { - service_manager.add_service(StatsdService { - agent_controller_pool: agent_controller_pool.clone(), - buffered_request_manager: buffered_request_manager.clone(), - configuration: statsd_configuration, - }); - } - - #[cfg(feature = "web_admin_panel")] - if let Some(web_admin_panel_configuration) = web_admin_panel_service_configuration { - service_manager.add_service(WebAdminPanelService { - configuration: web_admin_panel_configuration, - }); - } - - Ok(BootstrappedBalancerHandle { - agent_controller_pool, - balancer_applicable_state_holder, - balancer_desired_state_tx, - service_manager, - state_database, - }) -} diff --git a/paddler_bootstrap/src/lib.rs b/paddler_bootstrap/src/lib.rs index 82fc6c23..f98af427 100644 --- a/paddler_bootstrap/src/lib.rs +++ b/paddler_bootstrap/src/lib.rs @@ -1,6 +1,6 @@ pub mod agent_runner; +pub mod agent_service_bundle; pub mod balancer_runner; -mod bootstrapped_agent_handle; -mod bootstrapped_balancer_handle; +pub mod balancer_service_bundle; pub mod service_thread; pub mod shutdown_signal; diff --git a/paddler_cli/Cargo.toml b/paddler_cli/Cargo.toml index 6c93cba8..8bec93a2 100644 --- a/paddler_cli/Cargo.toml +++ b/paddler_cli/Cargo.toml @@ -24,6 +24,7 @@ paddler_bootstrap = { workspace = true } paddler_types = { workspace = true } tokio = { workspace = true } tokio-util = { workspace = true } +trzcina = { workspace = true } # web dashboard deps esbuild-metafile = { workspace = true, optional = true } diff --git a/paddler_cli/src/cmd/agent.rs b/paddler_cli/src/cmd/agent.rs index 43a67945..198268ab 100644 --- a/paddler_cli/src/cmd/agent.rs +++ b/paddler_cli/src/cmd/agent.rs @@ -2,9 +2,10 @@ use anyhow::Result; use async_trait::async_trait; use clap::Parser; use paddler::resolved_socket_addr::ResolvedSocketAddr; -use paddler_bootstrap::agent_runner::AgentRunner; -use paddler_bootstrap::agent_runner::AgentRunnerParams; +use paddler_bootstrap::agent_service_bundle::AgentServiceBundle; use tokio_util::sync::CancellationToken; +use trzcina::ServiceManager; +use trzcina::ServiceShutdownOptions; use super::handler::Handler; use super::value_parser::parse_socket_addr; @@ -27,13 +28,21 @@ pub struct Agent { #[async_trait] impl Handler for Agent { async fn handle(&self, shutdown: CancellationToken) -> Result<()> { - let mut runner = AgentRunner::start(AgentRunnerParams { - agent_name: self.name.clone(), - management_address: self.management_addr.socket_addr.to_string(), - cancellation_token: shutdown, - slots: self.slots, - }); - - runner.wait_for_completion().await + let bundle = AgentServiceBundle::new( + self.name.clone(), + &self.management_addr.socket_addr.to_string(), + self.slots, + ); + + let mut service_manager = ServiceManager::default(); + + service_manager.register_bundle(bundle).await?; + + service_manager + .start(shutdown) + .run_to_completion(ServiceShutdownOptions::default()) + .await + .into_result() + .map_err(anyhow::Error::from) } } diff --git a/paddler_cli/src/cmd/balancer.rs b/paddler_cli/src/cmd/balancer.rs index c2e25aab..94d2248b 100644 --- a/paddler_cli/src/cmd/balancer.rs +++ b/paddler_cli/src/cmd/balancer.rs @@ -13,9 +13,11 @@ use paddler::balancer::web_admin_panel_service::configuration::Configuration as #[cfg(feature = "web_admin_panel")] use paddler::balancer::web_admin_panel_service::template_data::TemplateData; use paddler::resolved_socket_addr::ResolvedSocketAddr; -use paddler_bootstrap::balancer_runner::BalancerRunner; -use paddler_bootstrap::balancer_runner::BalancerRunnerParams; +use paddler_bootstrap::balancer_service_bundle::BalancerBootstrapConfig; +use paddler_bootstrap::balancer_service_bundle::BalancerServiceBundle; use tokio_util::sync::CancellationToken; +use trzcina::ServiceManager; +use trzcina::ServiceShutdownOptions; use super::handler::Handler; use super::value_parser::parse_duration; @@ -111,7 +113,9 @@ impl Balancer { #[async_trait] impl Handler for Balancer { async fn handle(&self, shutdown: CancellationToken) -> Result<()> { - let mut runner = BalancerRunner::start(BalancerRunnerParams { + let shutdown_options = ServiceShutdownOptions::default(); + + let bundle = BalancerServiceBundle::new(BalancerBootstrapConfig { buffered_request_timeout: self.buffered_request_timeout, inference_service_configuration: InferenceServiceConfiguration { addr: self.inference_addr.socket_addr, @@ -128,7 +132,7 @@ impl Handler for Balancer { addr: compat_openai_addr.socket_addr, }, ), - cancellation_token: shutdown, + shutdown_options: shutdown_options.clone(), state_database_type: self.state_database.clone(), statsd_prefix: self.statsd_prefix.clone(), statsd_service_configuration: self.statsd_addr.clone().map(|statsd_addr| { @@ -143,6 +147,14 @@ impl Handler for Balancer { }) .await?; - runner.wait_for_completion().await + let mut service_manager = ServiceManager::default(); + service_manager.register_bundle(bundle).await?; + + service_manager + .start(shutdown) + .run_to_completion(shutdown_options) + .await + .into_result() + .map_err(anyhow::Error::from) } } diff --git a/paddler_cli/src/main.rs b/paddler_cli/src/main.rs index 0e06b294..56fb8e46 100644 --- a/paddler_cli/src/main.rs +++ b/paddler_cli/src/main.rs @@ -38,7 +38,7 @@ enum Commands { Balancer(Balancer), } -#[tokio::main] +#[actix_web::main] async fn main() -> Result<()> { env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); @@ -55,12 +55,12 @@ async fn main() -> Result<()> { }); match Cli::parse().command { - Some(Commands::Agent(handler)) => Ok(handler.handle(shutdown).await?), + Some(Commands::Agent(handler)) => handler.handle(shutdown).await, Some(Commands::Balancer(handler)) => { #[cfg(feature = "web_admin_panel")] initialize_instance(ESBUILD_META_CONTENTS); - Ok(handler.handle(shutdown).await?) + handler.handle(shutdown).await } None => Ok(()), }