Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ tokio-test = "0.4.4"
tokio-tungstenite = "0.28"
tokio-util = "0.7"
thiserror = "2"
trzcina = "=0.2.1"
trzcina = "=0.3.0"
url = { version = "2.5", features = ["serde"] }
paddler = { version = "4.0.0", path = "paddler" }
paddler_bootstrap = { version = "4.0.0", path = "paddler_bootstrap" }
Expand Down
254 changes: 156 additions & 98 deletions paddler/src/agent/llamacpp_arbiter_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,101 +27,126 @@ use crate::agent_applicable_state::AgentApplicableState;
use crate::agent_applicable_state_holder::AgentApplicableStateHolder;
use crate::slot_aggregated_status_manager::SlotAggregatedStatusManager;

/// Agent-side service state: owns the handle of the continuous-batch arbiter
/// and the receiving ends of the request channels whose messages are relayed
/// to that arbiter's scheduler.
pub struct LlamaCppArbiterService {
/// Applicable state to be applied. It is overwritten from the holder's
/// subscription whenever that subscription reports a change; when it is
/// `None`, no arbiter is built.
pub agent_applicable_state: Option<AgentApplicableState>,
/// Source of applicable-state updates; `run` subscribes to it.
pub agent_applicable_state_holder: Arc<AgentApplicableStateHolder>,
/// Optional agent name handed to
/// `ContinuousBatchArbiter::build_from_applicable_state`.
pub agent_name: Option<String>,
/// Incoming requests that are forwarded to the scheduler as
/// `ContinuousBatchSchedulerCommand::ContinueFromConversationHistory`.
pub continue_from_conversation_history_request_rx:
mpsc::UnboundedReceiver<ContinueFromConversationHistoryRequest>,
/// Incoming requests that are forwarded to the scheduler as
/// `ContinuousBatchSchedulerCommand::ContinueFromRawPrompt`.
pub continue_from_raw_prompt_request_rx: mpsc::UnboundedReceiver<ContinueFromRawPromptRequest>,
/// Slot count handed to `ContinuousBatchArbiter::build_from_applicable_state`.
// NOTE(review): presumably the number of slots the arbiter should run — confirm.
pub desired_slots_total: i32,
/// Incoming requests that are forwarded to the scheduler as
/// `ContinuousBatchSchedulerCommand::GenerateEmbeddingBatch`.
pub generate_embedding_batch_request_rx: mpsc::UnboundedReceiver<GenerateEmbeddingBatchRequest>,
/// Handle of the currently spawned arbiter, if any. Commands are sent through
/// its `command_tx`; it is taken out and shut down before a new state is
/// applied and when the service stops.
pub continuous_batch_arbiter_handle: Option<ContinuousBatchArbiterHandle>,
/// Shared holder handed to the arbiter builder.
pub model_metadata_holder: Arc<ModelMetadataHolder>,
/// Shared status manager: carries the state application status and is reset
/// before a new arbiter is built.
pub slot_aggregated_status_manager: Arc<SlotAggregatedStatusManager>,
}

impl LlamaCppArbiterService {
async fn apply_state(&mut self, shutdown: &CancellationToken) -> Result<()> {
self.wait_for_in_flight_requests_to_finish(shutdown).await?;
self.tear_down_arbiter()?;

if let Some(applicable_state) = self.agent_applicable_state.clone() {
self.slot_aggregated_status_manager.reset();

match ContinuousBatchArbiter::build_from_applicable_state(
applicable_state,
self.agent_name.clone(),
self.desired_slots_total,
self.model_metadata_holder.clone(),
self.slot_aggregated_status_manager.clone(),
) {
ContinuousBatchArbiterBuildOutcome::ReadyToSpawn(arbiter) => {
self.continuous_batch_arbiter_handle = Some(arbiter.spawn().await?);
info!("Reconciled state change applied successfully");
}
ContinuousBatchArbiterBuildOutcome::NoModelConfigured => {
warn!(
"No model configured in applicable state; skipping llama.cpp initialization"
);
}
async fn apply_state(
shutdown: &CancellationToken,
agent_applicable_state: Option<&AgentApplicableState>,
agent_name: Option<&str>,
desired_slots_total: i32,
model_metadata_holder: &Arc<ModelMetadataHolder>,
slot_aggregated_status_manager: &Arc<SlotAggregatedStatusManager>,
continuous_batch_arbiter_handle: &mut Option<ContinuousBatchArbiterHandle>,
) -> Result<()> {
wait_for_in_flight_requests_to_finish(
shutdown,
continuous_batch_arbiter_handle.as_ref(),
slot_aggregated_status_manager,
)
.await?;
shutdown_arbiter_handle(continuous_batch_arbiter_handle).await?;

if let Some(applicable_state) = agent_applicable_state.cloned() {
slot_aggregated_status_manager.reset();

match ContinuousBatchArbiter::build_from_applicable_state(
applicable_state,
agent_name.map(str::to_owned),
desired_slots_total,
model_metadata_holder.clone(),
slot_aggregated_status_manager.clone(),
) {
ContinuousBatchArbiterBuildOutcome::ReadyToSpawn(arbiter) => {
*continuous_batch_arbiter_handle = Some(arbiter.spawn().await?);
info!("Reconciled state change applied successfully");
}
ContinuousBatchArbiterBuildOutcome::NoModelConfigured => {
warn!(
"No model configured in applicable state; skipping llama.cpp initialization"
);
}
}

self.slot_aggregated_status_manager
.slot_aggregated_status
.set_state_application_status(AgentStateApplicationStatus::Applied);

Ok(())
}

async fn wait_for_in_flight_requests_to_finish(
&self,
shutdown: &CancellationToken,
) -> Result<()> {
if self.continuous_batch_arbiter_handle.is_some() {
drain_in_flight_requests(&self.slot_aggregated_status_manager, shutdown).await?;
}
slot_aggregated_status_manager
.slot_aggregated_status
.set_state_application_status(AgentStateApplicationStatus::Applied);

Ok(())
}
Ok(())
}

fn tear_down_arbiter(&mut self) -> Result<()> {
if let Some(arbiter_handle) = self.continuous_batch_arbiter_handle.take() {
arbiter_handle
.shutdown()
.context("Unable to stop arbiter controller")?;
fn forward_command(
continuous_batch_arbiter_handle: Option<&ContinuousBatchArbiterHandle>,
command: ContinuousBatchSchedulerCommand,
) {
if let Some(arbiter_handle) = continuous_batch_arbiter_handle {
if let Err(err) = arbiter_handle.command_tx.send(command) {
error!("Failed to forward command to scheduler: {err}");
}

Ok(())
} else {
error!("ContinuousBatchArbiterHandle is not initialized");
}
}

fn forward_command(&self, command: ContinuousBatchSchedulerCommand) {
if let Some(arbiter_handle) = &self.continuous_batch_arbiter_handle {
if let Err(err) = arbiter_handle.command_tx.send(command) {
error!("Failed to forward command to scheduler: {err}");
}
} else {
error!("ContinuousBatchArbiterHandle is not initialized");
}
/// Takes the arbiter handle out of `continuous_batch_arbiter_handle` (leaving
/// `None` behind) and shuts it down.
///
/// The handle's `shutdown` is invoked through `tokio::task::spawn_blocking`.
// NOTE(review): presumably because `shutdown` blocks the calling thread — confirm.
///
/// # Errors
///
/// Returns an error when the spawned task cannot be joined or when `shutdown`
/// itself fails. Returns `Ok(())` immediately when no handle is present.
async fn shutdown_arbiter_handle(
    continuous_batch_arbiter_handle: &mut Option<ContinuousBatchArbiterHandle>,
) -> Result<()> {
    match continuous_batch_arbiter_handle.take() {
        // Nothing was running, so there is nothing to stop.
        None => Ok(()),
        Some(arbiter_handle) => {
            let join_outcome =
                tokio::task::spawn_blocking(move || arbiter_handle.shutdown()).await;
            let shutdown_outcome = join_outcome.context("Arbiter shutdown task panicked")?;

            shutdown_outcome.context("Arbiter shutdown returned an error")
        }
    }
}

async fn try_to_apply_state(
shutdown: &CancellationToken,
agent_applicable_state: Option<&AgentApplicableState>,
agent_name: Option<&str>,
desired_slots_total: i32,
model_metadata_holder: &Arc<ModelMetadataHolder>,
slot_aggregated_status_manager: &Arc<SlotAggregatedStatusManager>,
continuous_batch_arbiter_handle: &mut Option<ContinuousBatchArbiterHandle>,
) {
if let Err(err) = apply_state(
shutdown,
agent_applicable_state,
agent_name,
desired_slots_total,
model_metadata_holder,
slot_aggregated_status_manager,
continuous_batch_arbiter_handle,
)
.await
{
error!("Failed to apply reconciled state change: {err}");
}
}

async fn try_to_apply_state(&mut self, shutdown: &CancellationToken) {
if let Err(err) = self.apply_state(shutdown).await {
error!("Failed to apply reconciled state change: {err}");
}
async fn wait_for_in_flight_requests_to_finish(
shutdown: &CancellationToken,
continuous_batch_arbiter_handle: Option<&ContinuousBatchArbiterHandle>,
slot_aggregated_status_manager: &Arc<SlotAggregatedStatusManager>,
) -> Result<()> {
if continuous_batch_arbiter_handle.is_some() {
drain_in_flight_requests(slot_aggregated_status_manager, shutdown).await?;
}

async fn shutdown_arbiter_handle(&mut self) -> Result<()> {
let Some(handle) = self.continuous_batch_arbiter_handle.take() else {
return Ok(());
};
Ok(())
}

tokio::task::spawn_blocking(move || handle.shutdown())
.await
.context("Arbiter shutdown task panicked")?
.context("Arbiter shutdown returned an error")
}
/// Agent-side service state: owns the handle of the continuous-batch arbiter
/// and the receiving ends of the request channels whose messages are relayed
/// to that arbiter's scheduler.
pub struct LlamaCppArbiterService {
/// Applicable state to be applied. It is overwritten from the holder's
/// subscription whenever that subscription reports a change; when it is
/// `None`, no arbiter is built.
pub agent_applicable_state: Option<AgentApplicableState>,
/// Source of applicable-state updates; `run` subscribes to it.
pub agent_applicable_state_holder: Arc<AgentApplicableStateHolder>,
/// Optional agent name handed to
/// `ContinuousBatchArbiter::build_from_applicable_state`.
pub agent_name: Option<String>,
/// Incoming requests that are forwarded to the scheduler as
/// `ContinuousBatchSchedulerCommand::ContinueFromConversationHistory`.
pub continue_from_conversation_history_request_rx:
mpsc::UnboundedReceiver<ContinueFromConversationHistoryRequest>,
/// Incoming requests that are forwarded to the scheduler as
/// `ContinuousBatchSchedulerCommand::ContinueFromRawPrompt`.
pub continue_from_raw_prompt_request_rx: mpsc::UnboundedReceiver<ContinueFromRawPromptRequest>,
/// Slot count handed to `ContinuousBatchArbiter::build_from_applicable_state`.
// NOTE(review): presumably the number of slots the arbiter should run — confirm.
pub desired_slots_total: i32,
/// Incoming requests that are forwarded to the scheduler as
/// `ContinuousBatchSchedulerCommand::GenerateEmbeddingBatch`.
pub generate_embedding_batch_request_rx: mpsc::UnboundedReceiver<GenerateEmbeddingBatchRequest>,
/// Handle of the currently spawned arbiter, if any. Commands are sent through
/// its `command_tx`; it is taken out and shut down before a new state is
/// applied and when the service stops.
pub continuous_batch_arbiter_handle: Option<ContinuousBatchArbiterHandle>,
/// Shared holder handed to the arbiter builder.
pub model_metadata_holder: Arc<ModelMetadataHolder>,
/// Shared status manager: carries the state application status and is reset
/// before a new arbiter is built.
pub slot_aggregated_status_manager: Arc<SlotAggregatedStatusManager>,
}

#[async_trait]
Expand All @@ -130,8 +155,21 @@ impl Service for LlamaCppArbiterService {
"agent::llamacpp_arbiter_service"
}

async fn run(&mut self, shutdown: CancellationToken) -> Result<()> {
let mut reconciled_state = self.agent_applicable_state_holder.subscribe();
async fn run(self: Box<Self>, shutdown: CancellationToken) -> Result<()> {
let Self {
mut agent_applicable_state,
agent_applicable_state_holder,
agent_name,
mut continue_from_conversation_history_request_rx,
mut continue_from_raw_prompt_request_rx,
desired_slots_total,
mut generate_embedding_batch_request_rx,
mut continuous_batch_arbiter_handle,
model_metadata_holder,
slot_aggregated_status_manager,
} = *self;

let mut reconciled_state = agent_applicable_state_holder.subscribe();
let mut ticker = interval(Duration::from_secs(1));

ticker.set_missed_tick_behavior(MissedTickBehavior::Delay);
Expand All @@ -141,10 +179,10 @@ impl Service for LlamaCppArbiterService {
biased;
() = shutdown.cancelled() => break Ok(()),
_ = ticker.tick() => {
let current_status = self.slot_aggregated_status_manager.slot_aggregated_status.get_state_application_status()?;
let current_status = slot_aggregated_status_manager.slot_aggregated_status.get_state_application_status()?;

if current_status.should_try_to_apply() {
self.slot_aggregated_status_manager
slot_aggregated_status_manager
.slot_aggregated_status
.set_state_application_status(
if matches!(current_status, AgentStateApplicationStatus::AttemptedAndRetrying) {
Expand All @@ -154,36 +192,55 @@ impl Service for LlamaCppArbiterService {
}
);

self.try_to_apply_state(&shutdown).await;
try_to_apply_state(
&shutdown,
agent_applicable_state.as_ref(),
agent_name.as_deref(),
desired_slots_total,
&model_metadata_holder,
&slot_aggregated_status_manager,
&mut continuous_batch_arbiter_handle,
).await;
}
}
_ = reconciled_state.changed() => {
self.agent_applicable_state.clone_from(&reconciled_state.borrow_and_update());
self.slot_aggregated_status_manager
agent_applicable_state.clone_from(&reconciled_state.borrow_and_update());
slot_aggregated_status_manager
.slot_aggregated_status
.set_state_application_status(AgentStateApplicationStatus::Fresh);

self.try_to_apply_state(&shutdown).await;
try_to_apply_state(
&shutdown,
agent_applicable_state.as_ref(),
agent_name.as_deref(),
desired_slots_total,
&model_metadata_holder,
&slot_aggregated_status_manager,
&mut continuous_batch_arbiter_handle,
).await;
}
Some(request) = self.continue_from_conversation_history_request_rx.recv() => {
self.forward_command(
Some(request) = continue_from_conversation_history_request_rx.recv() => {
forward_command(
continuous_batch_arbiter_handle.as_ref(),
ContinuousBatchSchedulerCommand::ContinueFromConversationHistory(request),
);
}
Some(request) = self.continue_from_raw_prompt_request_rx.recv() => {
self.forward_command(
Some(request) = continue_from_raw_prompt_request_rx.recv() => {
forward_command(
continuous_batch_arbiter_handle.as_ref(),
ContinuousBatchSchedulerCommand::ContinueFromRawPrompt(request),
);
}
Some(request) = self.generate_embedding_batch_request_rx.recv() => {
self.forward_command(
Some(request) = generate_embedding_batch_request_rx.recv() => {
forward_command(
continuous_batch_arbiter_handle.as_ref(),
ContinuousBatchSchedulerCommand::GenerateEmbeddingBatch(request),
);
}
}
};

if let Err(err) = self.shutdown_arbiter_handle().await {
if let Err(err) = shutdown_arbiter_handle(&mut continuous_batch_arbiter_handle).await {
error!("Failed to shut down arbiter cleanly: {err:#}");
}

Expand Down Expand Up @@ -211,7 +268,7 @@ mod tests {
let (generate_embedding_batch_request_tx, generate_embedding_batch_request_rx) =
mpsc::unbounded_channel();

let mut service = LlamaCppArbiterService {
let service = LlamaCppArbiterService {
agent_applicable_state: None,
agent_applicable_state_holder: Arc::new(AgentApplicableStateHolder::default()),
agent_name: None,
Expand All @@ -227,7 +284,8 @@ mod tests {
let shutdown = CancellationToken::new();
let task_token = shutdown.clone();

let mut join_handle = tokio::spawn(async move { service.run(task_token).await });
let mut join_handle =
tokio::spawn(async move { Box::new(service).run(task_token).await });

drop(continue_from_conversation_history_request_tx);
drop(continue_from_raw_prompt_request_tx);
Expand Down
2 changes: 1 addition & 1 deletion paddler/src/agent/management_socket_client_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,7 @@ impl Service for ManagementSocketClientService {
"agent::management_socket_client_service"
}

async fn run(&mut self, shutdown: CancellationToken) -> Result<()> {
async fn run(self: Box<Self>, shutdown: CancellationToken) -> Result<()> {
let mut ticker = interval(Duration::from_secs(1));

ticker.set_missed_tick_behavior(MissedTickBehavior::Delay);
Expand Down
Loading
Loading