diff --git a/Cargo.toml b/Cargo.toml index 9361c4c..108d893 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,7 +7,6 @@ edition = "2024" server = [ "axum", "base32", - "futures", "hex", "libc", "log", @@ -31,7 +30,6 @@ daemon = [ "sdnotify", "tempfile", "tokio-tasks", - "futures", ] nix = ["dep:nix"] @@ -52,7 +50,7 @@ cgroups-rs = {version = "0.2", optional=true} chrono = {version = "0.4", default-features = false, features = ["std", "clock"], optional=true} clap = {version = "4", default-features = false, features=['std', 'derive', 'help', 'suggestions', 'usage', 'color']} dirs = "6" -futures = {version = "0.3", optional = true} +futures = {version = "0.3" } futures-util = "0.3" hex = {version = "0.4", optional = true} http-body-util = "0.1" diff --git a/src/action_types.rs b/src/action_types.rs index 066903d..06f213e 100644 --- a/src/action_types.rs +++ b/src/action_types.rs @@ -1094,6 +1094,94 @@ pub struct IRunCommandFinished { pub status: i32, } +#[derive(Serialize, Deserialize, Clone, Debug, TS)] +pub struct ISocketConnect { + pub msg_id: u64, + pub socket_id: u64, + pub host: String, + pub dst: String, +} + +#[derive(Serialize, Deserialize, Clone, Debug, TS)] +pub struct IResponse { + pub msg_id: u64, + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, +} + +#[derive(Serialize, Deserialize, Clone, Debug, TS)] +pub struct ISocketClose { + pub msg_id: u64, + pub socket_id: u64, +} + +#[derive(Serialize, Deserialize, Clone, Debug, TS)] +pub struct ISocketSend { + pub msg_id: u64, + pub socket_id: u64, + // Base 64 encoded binary data to send on port, if non close the write part of the socket + pub data: Option, +} + +#[derive(Serialize, Deserialize, Clone, Debug, TS)] +pub struct ISocketRecv { + pub socket_id: u64, + // Base 64 encoded binary data read, if none no more data will be sent + pub data: Option, +} + +#[derive(Serialize, Deserialize, Clone, Debug, TS)] +pub struct ICommandSpawn { + pub msg_id: u64, + pub command_id: u64, + pub host: String, + pub program: String, + pub args: Vec, + pub forward_stdin: bool, + pub forward_stdout: bool, + pub forward_stderr: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub env: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub cwd: Option, +} + +#[derive(Serialize, Deserialize, Clone, Debug, TS)] +pub struct ICommandStdin { + pub msg_id: u64, + pub command_id: u64, + // Base64 encode binary data for fd, if none close fd + pub data: Option, +} + +#[derive(Serialize, Deserialize, Clone, Debug, TS)] +pub struct ICommandSignal { + pub msg_id: u64, + pub command_id: u64, + pub signal: i32, +} + +#[derive(Serialize, Deserialize, Clone, Debug, TS)] +pub struct ICommandStdout { + pub command_id: u64, + // Base64 encode binary data from fd, if none fd is closed + pub data: Option, +} + +#[derive(Serialize, Deserialize, Clone, Debug, TS)] +pub struct ICommandStderr { + pub command_id: u64, + // Base64 encode binary data from fd, if none fd is closed + pub data: Option, +} + +#[derive(Serialize, Deserialize, Clone, Debug, TS)] +pub struct ICommandFinished { + pub command_id: u64, + pub code: i32, + pub signal: Option, +} + #[derive(Serialize, Deserialize, Clone, Debug, TS)] #[serde(tag = "type", rename_all = "PascalCase")] pub enum IServerAction { @@ -1131,6 +1219,11 @@ pub enum IServerAction { SetPage(ISetPageAction), ToggleDeploymentObject(IToggleDeploymentObject), GetSecretRes(IGetSecretRes), + Response(IResponse), + SocketRecv(ISocketRecv), + CommandStdout(ICommandStdout), + CommandStderr(ICommandStderr), + CommandFinished(ICommandFinished), } impl IServerAction { @@ -1170,6 +1263,11 @@ impl IServerAction { IServerAction::SetMessagesDismissed(_) => "SetMessagesDismissed", IServerAction::SetPage(_) => "SetPage", IServerAction::ToggleDeploymentObject(_) => "ToggleDeploymentObject", + IServerAction::Response(_) => "Response", + IServerAction::SocketRecv(_) => "SocketRecv", + IServerAction::CommandStdout(_) => "CommandStdout", + IServerAction::CommandStderr(_) => "CommandStderr", + IServerAction::CommandFinished(_) => "CommandFinished", } } } @@ -1214,6 +1312,12 @@ pub enum IClientAction { StopDeployment(IStopDeployment), ToggleDeploymentObject(IToggleDeploymentObject), GetSecret(IGetSecret), + SocketConnect(ISocketConnect), + SocketClose(ISocketClose), + SocketSend(ISocketSend), + CommandSpawn(ICommandSpawn), + CommandStdin(ICommandStdin), + CommandSignal(ICommandSignal), } impl IClientAction { @@ -1256,6 +1360,60 @@ impl IClientAction { IClientAction::ToggleDeploymentObject(_) => "ToggleDeploymentObject", IClientAction::MarkDeployed(_) => "MarkDeployed", IClientAction::GetSecret(_) => "GetSecret", + IClientAction::SocketConnect(_) => "SocketConnect", + IClientAction::SocketClose(_) => "SocketClose", + IClientAction::SocketSend(_) => "SocketSend", + IClientAction::CommandSpawn(_) => "CommandSpawn", + IClientAction::CommandStdin(_) => "CommandStdin", + IClientAction::CommandSignal(_) => "CommandSignal", + } + } + + pub fn msg_id(&self) -> Option { + match self { + IClientAction::CancelDeployment(_) => None, + IClientAction::Debug(_) => None, + IClientAction::DeleteObject(_) => None, + IClientAction::DeployObject(_) => None, + IClientAction::MarkDeployed(_) => None, + IClientAction::DockerContainerForget(_) => None, + IClientAction::DockerImageSetPin(_) => None, + IClientAction::DockerImageTagSetPin(_) => None, + IClientAction::DockerListDeploymentHistory(_) => None, + IClientAction::DockerListDeployments(_) => None, + IClientAction::DockerListImageByHash(_) => None, + IClientAction::DockerListImageTagHistory(_) => None, + IClientAction::DockerListImageTags(_) => None, + IClientAction::FetchObject(_) => None, + IClientAction::GenerateKey(_) => None, + IClientAction::GetObjectHistory(_) => None, + IClientAction::GetObjectId(_) => None, + IClientAction::Login(_) => None, + IClientAction::Logout(_) => None, + IClientAction::MessageTextReq(_) => None, + IClientAction::ModifiedFilesList(_) => None, + IClientAction::ModifiedFilesResolve(_) => None, + IClientAction::ModifiedFilesScan(_) => None, + IClientAction::RequestAuthStatus(_) => None, + IClientAction::RequestInitialState(_) => None, + IClientAction::ResetServerState(_) => None, + IClientAction::RunCommand(_) => None, + IClientAction::RunCommandTerminate(_) => None, + IClientAction::SaveObject(_) => None, + IClientAction::Search(_) => None, + IClientAction::ServiceDeployStart(_) => None, + IClientAction::ServiceRedeployStart(_) => None, + IClientAction::SetMessageDismissed(_) => None, + IClientAction::StartDeployment(_) => None, + IClientAction::StopDeployment(_) => None, + IClientAction::ToggleDeploymentObject(_) => None, + IClientAction::SocketConnect(act) => Some(act.msg_id), + IClientAction::SocketClose(act) => Some(act.msg_id), + IClientAction::SocketSend(act) => Some(act.msg_id), + IClientAction::CommandSpawn(act) => Some(act.msg_id), + IClientAction::CommandStdin(act) => Some(act.msg_id), + IClientAction::CommandSignal(act) => Some(act.msg_id), + IClientAction::GetSecret(_) => None, } } } @@ -1361,6 +1519,17 @@ pub fn export_ts() -> Vec { IGetObjectIdRes::export_to_string().unwrap(), IServerAction::export_to_string().unwrap(), IClientAction::export_to_string().unwrap(), + IResponse::export_to_string().unwrap(), + ISocketConnect::export_to_string().unwrap(), + ISocketClose::export_to_string().unwrap(), + ISocketSend::export_to_string().unwrap(), + ICommandSpawn::export_to_string().unwrap(), + ICommandStdin::export_to_string().unwrap(), + ICommandSignal::export_to_string().unwrap(), + ISocketRecv::export_to_string().unwrap(), + ICommandStdout::export_to_string().unwrap(), + ICommandStderr::export_to_string().unwrap(), + ICommandFinished::export_to_string().unwrap(), ] } diff --git a/src/bin/sadmin/client_daemon.rs b/src/bin/sadmin/client_daemon.rs index c506e5c..9e7b59e 100644 --- a/src/bin/sadmin/client_daemon.rs +++ b/src/bin/sadmin/client_daemon.rs @@ -20,6 +20,7 @@ use base64::{Engine, prelude::BASE64_STANDARD}; use bytes::BytesMut; use futures::{future, pin_mut}; use log::{debug, error, info, warn}; +use nix::{sys::signal::Signal, unistd::Pid}; use reqwest::Url; use serde::Deserialize; use tokio::{ @@ -28,7 +29,8 @@ use tokio::{ TcpStream, UnixStream, unix::{OwnedReadHalf, OwnedWriteHalf}, }, - process::ChildStdin, + pin, + process::{Child, ChildStdin}, select, sync::{ Notify, @@ -37,12 +39,12 @@ use tokio::{ time::timeout, }; use tokio_rustls::{TlsConnector, client::TlsStream, rustls}; -use tokio_tasks::{CancelledError, RunToken, TaskBase, TaskBuilder, cancelable}; +use tokio_tasks::{CancelledError, RunToken, Task, TaskBase, TaskBuilder, cancelable}; use sadmin2::client_message::{ - ClientHostMessage, DataMessage, DataSource, DeployServiceMessage, FailureMessage, FailureType, - HostClientMessage, RunInstantMessage, RunInstantStdinOutputType, RunScriptMessage, - RunScriptOutType, RunScriptStdinType, SuccessMessage, + ClientHostMessage, CommandSpawnMessage, DataMessage, DataSource, DeployServiceMessage, + FailureMessage, FailureType, HostClientMessage, RunInstantMessage, RunInstantStdinOutputType, + RunScriptMessage, RunScriptOutType, RunScriptStdinType, SuccessMessage, }; use sadmin2::service_description::ServiceDescription; @@ -103,6 +105,21 @@ pub struct ClientDaemon { pub type PersistMessageSender = tokio::sync::oneshot::Sender<(persist_daemon::Message, Option)>; +enum SocketWrite { + Unix(tokio::net::unix::OwnedWriteHalf), + Tcp(tokio::net::tcp::OwnedWriteHalf), +} + +enum SocketRead { + Unix(tokio::net::unix::OwnedReadHalf), + Tcp(tokio::net::tcp::OwnedReadHalf), +} + +pub struct Socket { + task: Arc>, + write: tokio::sync::Mutex>, +} + pub struct Client { connector: TlsConnector, pub config: Config, @@ -119,6 +136,11 @@ pub struct Client { persist_sender: tokio::sync::Mutex, password: String, metrics_token: Option, + sockets: Mutex>>, + + command_pids: Mutex>, + command_stdins: Mutex>>>, + pub db: Mutex, pub dead_process_handlers: Mutex>>, @@ -660,6 +682,392 @@ impl Client { } } + async fn send_result(self: Arc, id: u64, r: Result<()>) { + match r { + Ok(_) => { + self.send_message(ClientHostMessage::Success(SuccessMessage { + id, + ..Default::default() + })) + .await; + } + Err(e) => { + self.send_message(ClientHostMessage::Failure(FailureMessage { + id, + message: Some(format!("{:?}", e)), + ..Default::default() + })) + .await; + } + } + } + + async fn handle_socket( + self: Arc, + socket_id: u64, + rt: RunToken, + mut r: SocketRead, + ) -> Result<()> { + let mut buf = BytesMut::with_capacity(1024 * 64); + loop { + let r = match &mut r { + SocketRead::Unix(r) => cancelable(&rt, r.read_buf(&mut buf)).await, + SocketRead::Tcp(r) => cancelable(&rt, r.read_buf(&mut buf)).await, + }; + match r { + Ok(Ok(0)) => break, + Ok(Ok(_)) => (), + Ok(Err(_)) => break, + Err(_) => { + self.sockets.lock().unwrap().remove(&socket_id); + return Ok(()); + } + } + let data = BASE64_STANDARD.encode(&buf); + self.send_message(ClientHostMessage::SocketRecv { + socket_id, + data: Some(data), + }) + .await; + buf.clear(); + } + self.send_message(ClientHostMessage::SocketRecv { + socket_id, + data: None, + }) + .await; + self.sockets.lock().unwrap().remove(&socket_id); + Ok(()) + } + + async fn handle_socket_connect_inner( + self: &Arc, + socket_id: u64, + dst: String, + ) -> Result<()> { + if self.sockets.lock().unwrap().contains_key(&socket_id) { + bail!("socket_id already in use"); + } + let (r, w) = if dst.contains(':') && !dst.contains("/") { + let s = tokio::net::TcpStream::connect(&dst) + .await + .with_context(|| format!("Unable to connect to {}", dst))?; + let (r, w) = s.into_split(); + (SocketRead::Tcp(r), SocketWrite::Tcp(w)) + } else { + let s = tokio::net::UnixStream::connect(&dst) + .await + .with_context(|| format!("Unable to connect to {}", dst))?; + let (r, w) = s.into_split(); + (SocketRead::Unix(r), SocketWrite::Unix(w)) + }; + let s2 = self.clone(); + let task = TaskBuilder::new(format!("handle_tcp_socket_{}", socket_id)) + .shutdown_order(-99) + .create(|rt| async move { s2.handle_socket(socket_id, rt, r).await }); + + self.sockets.lock().unwrap().insert( + socket_id, + Arc::new(Socket { + task, + write: tokio::sync::Mutex::new(Some(w)), + }), + ); + Ok(()) + } + + async fn handle_socket_connect(self: Arc, id: u64, socket_id: u64, dst: String) { + let r = self.handle_socket_connect_inner(socket_id, dst).await; + self.send_result(id, r).await; + } + + async fn handle_socket_close_inner(self: &Arc, socket_id: u64) -> Result<()> { + let conn = self + .sockets + .lock() + .unwrap() + .remove(&socket_id) + .with_context(|| format!("Unknown socket {}", socket_id))?; + conn.task.run_token().cancel(); + if let Err(e) = conn.task.clone().wait().await { + match e { + tokio_tasks::WaitError::HandleUnset(e) => bail!("Handle unset {}", e), + tokio_tasks::WaitError::JoinError(e) => bail!("Join error {:?}", e), + tokio_tasks::WaitError::TaskFailure(_) => (), + } + } + Ok(()) + } + + async fn handle_socket_close(self: Arc, id: u64, socket_id: u64) { + let r = self.handle_socket_close_inner(socket_id).await; + self.send_result(id, r).await; + } + + async fn handle_socket_send_inner( + self: &Arc, + socket_id: u64, + data: Option, + ) -> Result<()> { + let conn = self + .sockets + .lock() + .unwrap() + .get(&socket_id) + .with_context(|| format!("Unknown socket {}", socket_id))? + .clone(); + if let Some(data) = data { + let mut conn = conn.write.lock().await; + let Some(w) = &mut *conn else { + bail!("Write half closed"); + }; + match w { + SocketWrite::Unix(w) => { + w.write_all(&BASE64_STANDARD.decode(&data)?).await?; + w.flush().await?; + } + SocketWrite::Tcp(w) => { + w.write_all(&BASE64_STANDARD.decode(&data)?).await?; + w.flush().await?; + } + } + } else { + let w = conn.write.lock().await.take(); + if let Some(mut w) = w { + match &mut w { + SocketWrite::Unix(w) => w.shutdown().await?, + SocketWrite::Tcp(w) => w.shutdown().await?, + } + } + } + Ok(()) + } + + async fn handle_socket_send(self: Arc, id: u64, socket_id: u64, data: Option) { + let r = self.handle_socket_send_inner(socket_id, data).await; + self.send_result(id, r).await; + } + + pub async fn handle_command( + self: Arc, + rt: RunToken, + mut child: Child, + command_id: u64, + ) -> Result<()> { + let stdout = child.stdout.take(); + let stderr = child.stderr.take(); + + let mut do_handle_stdout = true; + let s2: Arc = self.clone(); + let handle_stdout = async move { + let Some(mut fd) = stdout else { + return Ok::<_, anyhow::Error>(()); + }; + let mut buf = BytesMut::with_capacity(64 * 1024); + loop { + buf.clear(); + match fd.read_buf(&mut buf).await { + Ok(0) => break, + Ok(_) => { + s2.send_message(ClientHostMessage::CommandStdout { + command_id, + data: Some(BASE64_STANDARD.encode(&buf)), + }) + .await; + } + Err(e) => bail!("Failed to read from child {:?}", e), + } + } + s2.send_message(ClientHostMessage::CommandStdout { + command_id, + data: None, + }) + .await; + Ok(()) + }; + + let mut do_handle_stderr = true; + let s2: Arc = self.clone(); + let handle_stderr = async move { + let Some(mut fd) = stderr else { + return Ok::<_, anyhow::Error>(()); + }; + let mut buf = BytesMut::with_capacity(64 * 1024); + loop { + buf.clear(); + match fd.read_buf(&mut buf).await { + Ok(0) => break, + Ok(_) => { + s2.send_message(ClientHostMessage::CommandStderr { + command_id, + data: Some(BASE64_STANDARD.encode(&buf)), + }) + .await; + } + Err(e) => bail!("Failed to read from child {:?}", e), + } + } + s2.send_message(ClientHostMessage::CommandStderr { + command_id, + data: None, + }) + .await; + Ok(()) + }; + + pin!(handle_stdout, handle_stderr); + + while do_handle_stdout || do_handle_stderr { + select! { + _ = rt.cancelled() => { + return Ok(()) + }, + r = &mut handle_stdout, if do_handle_stdout => { + r?; + do_handle_stdout = false; + }, + r = &mut handle_stderr, if do_handle_stderr => { + r?; + do_handle_stderr = false; + } + } + } + + let r = cancelable(&rt, child.wait()).await; + self.command_pids.lock().unwrap().remove(&command_id); + self.command_stdins.lock().unwrap().remove(&command_id); + let w = r??; + let code = w.code().unwrap_or_default(); + let signal = w.signal(); + + self.send_message(ClientHostMessage::CommandFinished { + command_id, + code, + signal, + }) + .await; + Ok(()) + } + + pub async fn handle_command_spawn_inner( + self: &Arc, + msg: CommandSpawnMessage, + ) -> Result<()> { + if self + .command_pids + .lock() + .unwrap() + .contains_key(&msg.command_id) + { + bail!("command_id is is use"); + } + let mut cmd = tokio::process::Command::new(msg.program); + cmd.args(msg.args); + if let Some(env) = msg.env { + cmd.envs(env); + } + if let Some(cwd) = msg.cwd { + cmd.current_dir(cwd); + } + + if msg.forward_stdin { + cmd.stdin(Stdio::piped()); + } else { + cmd.stdin(Stdio::null()); + } + if msg.forward_stdout { + cmd.stdout(Stdio::piped()); + } else { + cmd.stdout(Stdio::null()); + } + if msg.forward_stderr { + cmd.stderr(Stdio::piped()); + } else { + cmd.stderr(Stdio::null()); + } + cmd.kill_on_drop(true); + + let mut child = cmd.spawn().context("Failed to spawn command")?; + + self.command_pids + .lock() + .unwrap() + .insert(msg.command_id, child.id().context("missing pid")?); + if let Some(stdin) = child.stdin.take() { + self.command_stdins + .lock() + .unwrap() + .insert(msg.command_id, Arc::new(tokio::sync::Mutex::new(stdin))); + } + + let s2 = self.clone(); + TaskBuilder::new(format!("handle_command_{}", msg.command_id)) + .shutdown_order(0) + .create(|rt| async move { s2.handle_command(rt, child, msg.command_id).await }); + + Ok(()) + } + + pub async fn handle_command_spawn(self: Arc, msg: CommandSpawnMessage) { + let id = msg.id; + let r = self.handle_command_spawn_inner(msg).await; + self.send_result(id, r).await; + } + + pub async fn handle_command_stdin_inner( + self: &Arc, + command_id: u64, + data: Option, + ) -> Result<()> { + if let Some(data) = data { + let data = BASE64_STANDARD.decode(&data)?; + let Some(stdin) = self + .command_stdins + .lock() + .unwrap() + .get(&command_id) + .cloned() + else { + bail!("Stdin is closed"); + }; + let mut stdin = stdin.lock().await; + stdin.write_all(&data).await?; + stdin.flush().await?; + } else { + self.command_stdins.lock().unwrap().remove(&command_id); + } + Ok(()) + } + + pub async fn handle_command_stdin( + self: Arc, + id: u64, + command_id: u64, + data: Option, + ) { + let r = self.handle_command_stdin_inner(command_id, data).await; + self.send_result(id, r).await; + } + + pub async fn handle_command_signal_inner( + self: &Arc, + command_id: u64, + signal: i32, + ) -> Result<()> { + let Some(pid) = self.command_pids.lock().unwrap().get(&command_id).copied() else { + bail!("Command not found"); + }; + let signal = Signal::try_from(signal)?; + let pid = Pid::from_raw(pid as libc::pid_t); + nix::sys::signal::kill(pid, signal).context("Kill failed")?; + Ok(()) + } + + pub async fn handle_command_signal(self: Arc, id: u64, command_id: u64, signal: i32) { + let r = self.handle_command_signal_inner(command_id, signal).await; + self.send_result(id, r).await; + } + fn handle_message(self: &Arc, message: HostClientMessage) { match message { HostClientMessage::Data(d) => { @@ -722,6 +1130,36 @@ impl Client { } => { tokio::spawn(self.clone().handle_write_file(id, path, content, mode)); } + HostClientMessage::SocketConnect { id, socket_id, dst } => { + tokio::spawn(self.clone().handle_socket_connect(id, socket_id, dst)); + } + HostClientMessage::SocketClose { id, socket_id } => { + tokio::spawn(self.clone().handle_socket_close(id, socket_id)); + } + HostClientMessage::SocketSend { + id, + socket_id, + data, + } => { + tokio::spawn(self.clone().handle_socket_send(id, socket_id, data)); + } + HostClientMessage::CommandSpawn(msg) => { + tokio::spawn(self.clone().handle_command_spawn(msg)); + } + HostClientMessage::CommandStdin { + id, + command_id, + data, + } => { + tokio::spawn(self.clone().handle_command_stdin(id, command_id, data)); + } + HostClientMessage::CommandSignal { + id, + command_id, + signal, + } => { + tokio::spawn(self.clone().handle_command_signal(id, command_id, signal)); + } } } @@ -1645,6 +2083,9 @@ pub async fn client_daemon(config: Config, args: ClientDaemon) -> Result<()> { journal_socket, password, metrics_token, + sockets: Default::default(), + command_pids: Default::default(), + command_stdins: Default::default(), }); TaskBuilder::new("run_control") diff --git a/src/bin/sadmin/connection.rs b/src/bin/sadmin/connection.rs index afd49fc..5456d5c 100644 --- a/src/bin/sadmin/connection.rs +++ b/src/bin/sadmin/connection.rs @@ -1,5 +1,6 @@ use anyhow::{Context, Result, bail}; use base64::{Engine, prelude::BASE64_STANDARD}; +use bytes::Bytes; use futures_util::{SinkExt, StreamExt}; use sadmin2::action_types::{ IClientAction, IGenerateKey, ILogin, IRequestAuthStatus, IServerAction, Ref, @@ -11,6 +12,7 @@ use std::{ fs::OpenOptions, io::{BufRead, Write}, path::PathBuf, + sync::atomic::AtomicU64, }; use tokio_tungstenite::tungstenite::Message as WSMessage; @@ -37,6 +39,126 @@ pub struct Config { type Wss = tokio_tungstenite::WebSocketStream>; +pub struct ConnectionSend { + send: futures::stream::SplitSink, +} + +impl ConnectionSend { + pub async fn send_message_str(&mut self, msg: String) -> Result<()> { + self.send.send(WSMessage::text(msg)).await?; + Ok(()) + } + + pub async fn ping(&mut self) -> Result<()> { + self.send + .send(WSMessage::Ping(([42, 41]).as_slice().into())) + .await?; + Ok(()) + } + + pub async fn pong(&mut self, v: Bytes) -> Result<()> { + self.send.send(WSMessage::Pong(v)).await?; + Ok(()) + } + + pub async fn close(&mut self) -> Result<()> { + self.send.close().await?; + Ok(()) + } + + pub fn into2(self) -> std::sync::Arc { + std::sync::Arc::new(ConnectionSend2 { + idc: AtomicU64::new(2), + response_handlers: Default::default(), + send: tokio::sync::Mutex::new(self), + }) + } +} + +pub struct ConnectionSend2 { + idc: AtomicU64, + response_handlers: std::sync::Mutex< + std::collections::HashMap>, + >, + send: tokio::sync::Mutex, +} + +impl ConnectionSend2 { + pub fn next_id(&self) -> u64 { + self.idc.fetch_add(1, std::sync::atomic::Ordering::SeqCst) + } + + pub async fn send_message_with_response(&self, msg: &IClientAction) -> Result { + let msg_id = msg.msg_id().context("No message id")?; + let m = serde_json::to_string(msg)?; + let (s, r) = tokio::sync::oneshot::channel(); + self.response_handlers.lock().unwrap().insert(msg_id, s); + self.send.lock().await.send_message_str(m).await?; + let r = r.await.context("r failed")?; + if let IServerAction::Response(r) = &r { + if let Some(e) = &r.error { + bail!("Remote error: {}", e); + } + } + Ok(r) + } + + pub fn handle_response(&self, msg_id: u64, act: IServerAction) { + if let Some(v) = self.response_handlers.lock().unwrap().remove(&msg_id) { + let _ = v.send(act); + } + } + + pub async fn ping(&self) -> Result<()> { + self.send.lock().await.ping().await + } + + pub async fn pong(&self, v: Bytes) -> Result<()> { + self.send.lock().await.pong(v).await + } + + pub async fn close(&self) -> Result<()> { + self.send.lock().await.close().await + } +} + +#[allow(clippy::large_enum_variant)] +pub enum ConnectionRecvRes { + Message(IServerAction), + SendPong(Bytes), +} + +pub struct ConnectionRecv { + recv: futures::stream::SplitStream, +} + +impl ConnectionRecv { + pub async fn recv(&mut self) -> Result { + loop { + let msg = match self.recv.next().await.context("Expected package")?? { + WSMessage::Text(msg) => msg.into(), + WSMessage::Binary(msg) => msg, + WSMessage::Ping(v) => { + return Ok(ConnectionRecvRes::SendPong(v)); + } + WSMessage::Pong(_) => continue, + WSMessage::Close(_) => continue, + WSMessage::Frame(_) => continue, + }; + match serde_json::from_slice(&msg) { + Ok(v) => break Ok(ConnectionRecvRes::Message(v)), + Err(e) => eprintln!( + "Invalid message: {:?} at {}:{}\n{}", + e, + e.line(), + e.column(), + String::from_utf8_lossy(&msg) + ), + } + } + } +} + pub struct Connection { pub cookie_file: PathBuf, ca_file: PathBuf, @@ -363,4 +485,9 @@ impl Connection { Ok(()) } + + pub fn split(self) -> (ConnectionSend, ConnectionRecv) { + let (send, recv) = self.stream.split(); + (ConnectionSend { send }, ConnectionRecv { recv }) + } } diff --git a/src/bin/sadmin/main.rs b/src/bin/sadmin/main.rs index eaa3ffb..eccc8f2 100644 --- a/src/bin/sadmin/main.rs +++ b/src/bin/sadmin/main.rs @@ -27,6 +27,7 @@ mod list_deployments; mod list_images; #[cfg(feature = "daemon")] mod persist_daemon; +mod port; mod run; #[cfg(feature = "daemon")] mod service_control; @@ -37,6 +38,8 @@ mod upgrade; use run::{Run, Shell}; +use crate::port::ProxySocket; + #[derive(clap::Parser)] #[command(name = "sadmin")] #[command(version = include_str!("../../version.txt"))] @@ -105,6 +108,7 @@ enum Action { Run(Run), DebugServer, GetSecret(GetSecret), + ProxySocket(ProxySocket), } async fn auth(config: Config) -> Result<()> { @@ -235,5 +239,6 @@ async fn main() -> Result<()> { Action::Run(args) => run::run(config, args).await, Action::DebugServer => debug_server(config).await, Action::GetSecret(args) => get_secret(config, args).await, + Action::ProxySocket(act) => port::proxy(config, act).await, } } diff --git a/src/bin/sadmin/port.rs b/src/bin/sadmin/port.rs new file mode 100644 index 0000000..5e9d812 --- /dev/null +++ b/src/bin/sadmin/port.rs @@ -0,0 +1,164 @@ +use std::{ + collections::HashMap, + net::{IpAddr, Ipv4Addr, SocketAddr}, + sync::Arc, + time::Duration, +}; + +use anyhow::{Context, Result}; + +use crate::connection::{Config, Connection, ConnectionRecvRes, ConnectionSend2}; +use base64::{Engine, prelude::BASE64_STANDARD}; +use bytes::BytesMut; +use sadmin2::action_types::{IClientAction, IServerAction, ISocketConnect, ISocketSend}; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + net::tcp::OwnedReadHalf, + pin, select, + sync::mpsc::UnboundedSender, +}; + +/// Proxy a port tcp on a remote machine +#[derive(clap::Parser)] +pub struct ProxySocket { + host: String, + local_port: u16, + /// Either host:port or unix socket path + destination: String, +} + +pub async fn handle_socket_inner( + mut socket_read: OwnedReadHalf, + send: Arc, + socket_id: u64, + cmd: &'static ProxySocket, + writer_shutdown_s: UnboundedSender<(u64, OwnedReadHalf)>, +) -> Result<()> { + if let Err(e) = send + .send_message_with_response(&IClientAction::SocketConnect(ISocketConnect { + msg_id: send.next_id(), + socket_id, + host: cmd.host.clone(), + dst: cmd.destination.clone(), + })) + .await + { + let _ = writer_shutdown_s.send((socket_id, socket_read)); + return Err(e).context("Unable to open port"); + } + + let mut buf = BytesMut::with_capacity(1024 * 64); + loop { + buf.clear(); + let v = match socket_read.read_buf(&mut buf).await { + Ok(0) => None, + Ok(_) => Some(BASE64_STANDARD.encode(&buf)), + Err(_) => None, + }; + let eof = v.is_none(); + if let Err(e) = send + .send_message_with_response(&IClientAction::SocketSend(ISocketSend { + msg_id: send.next_id(), + socket_id, + data: v, + })) + .await + { + let _ = writer_shutdown_s.send((socket_id, socket_read)); + return Err(e).context("Unable send bytes"); + } + if eof { + return Ok(()); + } + } +} + +pub async fn handle_socket( + socket_read: OwnedReadHalf, + send: Arc, + socket_id: u64, + cmd: &'static ProxySocket, + writer_shutdown_s: UnboundedSender<(u64, OwnedReadHalf)>, +) { + if let Err(e) = handle_socket_inner(socket_read, send, socket_id, cmd, writer_shutdown_s).await + { + eprintln!("Error in handle socket inner {:?}", e); + } +} + +pub async fn proxy(config: Config, cmd: ProxySocket) -> Result<()> { + let cmd = Box::leak(Box::new(cmd)); + let mut c = Connection::open(config, false).await?; + c.prompt_auth().await?; + let (send, mut recv) = c.split(); + + let mut socket_id = 1; + + let sock = tokio::net::TcpSocket::new_v4()?; + let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), cmd.local_port); + sock.set_reuseaddr(true)?; + sock.bind(addr)?; + + let sig = tokio::signal::ctrl_c(); + pin!(sig); + let mut interval = tokio::time::interval(Duration::from_secs(66)); + let listener = sock.listen(1024)?; + + let (writer_shutdown_s, mut writer_shutdown_r) = tokio::sync::mpsc::unbounded_channel(); + + let mut socket_write_halfs = HashMap::new(); + let send = send.into2(); + loop { + select! { + _ = &mut sig => {break}, + r = listener.accept() => { + let (socket, addr) = r?; + socket_id += 1; + let (socket_read, socket_write) = socket.into_split(); + socket_write_halfs.insert(socket_id, socket_write); + tokio::task::spawn(handle_socket(socket_read, send.clone(), socket_id, cmd, writer_shutdown_s.clone())); + println!("Accepting proxy connection from {:?}", addr); + } + r = recv.recv() => { + match r? { + ConnectionRecvRes::Message(act) => { + match act { + IServerAction::Response(act) => { + send.handle_response(act.msg_id, IServerAction::Response(act)) + } + IServerAction::SocketRecv(act) => { + match act.data { + None => {socket_write_halfs.remove(&act.socket_id);} + Some(data) => { + if let Some(w) = socket_write_halfs.get_mut(&act.socket_id) { + w.write_all(&BASE64_STANDARD.decode(&data)?).await?; + w.flush().await?; + } + } + } + } + _ => () + } + } + ConnectionRecvRes::SendPong(bytes) => { + send.pong(bytes).await?; + } + } + } + r = writer_shutdown_r.recv() => { + if let Some((socket_id, read_half)) = r { + if let Some(write_half) = socket_write_halfs.remove(&socket_id) { + read_half.reunite(write_half)?.shutdown().await?; + } + } + } + _ = interval.tick() => { + send.ping().await?; + } + } + } + println!("Stopping"); + send.close().await?; + + Ok(()) +} diff --git a/src/bin/sadmin/run.rs b/src/bin/sadmin/run.rs index 02ca055..05bb99e 100644 --- a/src/bin/sadmin/run.rs +++ b/src/bin/sadmin/run.rs @@ -1,13 +1,21 @@ -use anyhow::{Result, bail}; +use anyhow::{Context, Result, bail}; +use base64::Engine; +use base64::prelude::BASE64_STANDARD; +use bytes::BytesMut; use futures_util::{SinkExt, StreamExt}; -use sadmin2::action_types::{IClientAction, IRequestInitialState, IServerAction, ObjectType}; -use std::time::Duration; +use sadmin2::action_types::{ + IClientAction, ICommandSignal, ICommandSpawn, ICommandStdin, IRequestInitialState, + IServerAction, ObjectType, +}; +use std::collections::HashMap; +use std::sync::atomic::AtomicU64; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpStream; use tokio::signal; +use tokio::{pin, select}; use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, tungstenite}; -use crate::connection::{Config, Connection}; +use crate::connection::{Config, Connection, ConnectionRecvRes}; /// Deauthenticate your user #[derive(clap::Parser)] @@ -19,6 +27,23 @@ pub struct Shell { #[derive(clap::Parser)] pub struct Run { host: String, + /// Run Command with this env format is k=v can be specified multiple times + /// if the value is omitted copy it from the env of this process + #[clap(long)] + env: Vec, + /// Run command with this work dir + #[clap(long)] + cwd: Option, + /// Do not forward stdout + #[clap(long)] + no_stdout: bool, + /// Do not forward stderr + #[clap(long)] + no_stderr: bool, + // Do not forward stdin + #[clap(long)] + no_stdin: bool, + command: String, #[arg(trailing_var_arg = true, allow_hyphen_values = true)] args: Vec, @@ -175,51 +200,157 @@ pub async fn shell(config: Config, args: Shell) -> Result<()> { } pub async fn run(config: Config, args: Run) -> Result<()> { - let (mut send, mut recv) = connect(config, args.host, 60, 100).await?.split(); - - let mut cmd = String::new(); - for arg in [args.command].iter().chain(&args.args) { - if cmd.is_empty() { - cmd.push('d'); - } else { - cmd.push(' '); - } - for c in arg.chars() { - if matches!(c, '$' | '`' | '\\' | '\"') { - cmd.push('\\'); + let mut c = Connection::open(config, false).await?; + c.prompt_auth().await?; + let (send, mut recv) = c.split(); + + let send = send.into2(); + const SPAWN_MSG_ID: u64 = 1; + let next_msg_id = AtomicU64::new(SPAWN_MSG_ID + 1); + let next_msg_id = &next_msg_id; + let command_id = 0; + + let env = if args.env.is_empty() { + None + } else { + let mut env = HashMap::new(); + for v in args.env { + match v.split_once("=") { + Some((k, v)) => { + env.insert(k.to_string(), v.to_string()); + } + None => { + if let Ok(w) = std::env::var(&v) { + env.insert(v, w); + } + } } - cmd.push(c); } - } - cmd.push('\n'); - cmd.push('\0'); - - let send_command = async { - send.send(tungstenite::Message::text(cmd)).await?; - send.send(tungstenite::Message::text("d\x04\0")).await?; - tokio::time::sleep(Duration::from_secs(100000000)).await; - Ok::<_, anyhow::Error>(()) + Some(env) }; - let handle_stdout = async { - while let Some(data) = recv.next().await { - match data? { - tungstenite::Message::Text(t) => { - let mut stdout = tokio::io::stdout(); - stdout.write_all(t.as_bytes()).await?; - stdout.flush().await?; - } - tungstenite::Message::Close(_) => break, - _ => (), + // Send spawn message + let msg = IClientAction::CommandSpawn(ICommandSpawn { + msg_id: SPAWN_MSG_ID, + command_id, + host: args.host, + program: args.command, + args: args.args, + forward_stdin: !args.no_stdin, + forward_stdout: !args.no_stdout, + forward_stderr: !args.no_stdout, + env, + cwd: args.cwd, + }); + let mut do_send_spawn = true; + let send_spawn = send.send_message_with_response(&msg); + pin!(send_spawn); + + let mut stdout = tokio::io::stdout(); + let mut stderr = tokio::io::stdout(); + + let send2 = send.clone(); + let mut do_process_stdin = false; + let process_stdin = async move { + let mut stdin_buf = BytesMut::with_capacity(64 * 1024); + let mut stdin = tokio::io::stdin(); + loop { + let r = stdin.read_buf(&mut stdin_buf).await?; + let data = if r == 0 { + None + } else { + let data = BASE64_STANDARD.encode(&stdin_buf); + stdin_buf.clear(); + Some(data) + }; + let msg_id = next_msg_id.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + let stop = data.is_none(); + send2 + .send_message_with_response(&IClientAction::CommandStdin(ICommandStdin { + msg_id, + command_id, + data, + })) + .await + .context("Failed sending stdin to remote command")?; + if stop { + return Ok::<_, anyhow::Error>(()); } } - Ok::<(), anyhow::Error>(()) }; + pin!(process_stdin); - tokio::select! { - _ = send_command => {}, - _ = handle_stdout => {} + let send2 = send.clone(); + let mut do_process_ctrl_c: bool = true; + let process_ctrl_c = async move { + signal::ctrl_c().await.context("Reading ctrl+c")?; + let msg_id = next_msg_id.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + send2 + .send_message_with_response(&IClientAction::CommandSignal(ICommandSignal { + msg_id, + command_id, + signal: 2, + })) + .await + .context("Failed sending sigint to remote command")?; + Ok::<_, anyhow::Error>(()) }; + pin!(process_ctrl_c); - Ok(()) + loop { + select! { + r = &mut send_spawn, if do_send_spawn => { + r.context("Unable to spawn remote process")?; + do_send_spawn = false; + } + r = &mut process_stdin, if do_process_stdin => { + r.context("Process stdin failed")?; + do_process_stdin = false; + } + r = recv.recv() => { + match r.context("Failed to recv message from backend")? { + ConnectionRecvRes::Message(IServerAction::Response(r)) => { + if r.msg_id == SPAWN_MSG_ID { + if let Some(e) = r.error { + bail!("Failed to spawn process: {}", e); + } + do_process_stdin = !args.no_stdin; + } else { + send.handle_response(r.msg_id, IServerAction::Response(r)); + } + }, + ConnectionRecvRes::Message(IServerAction::CommandFinished(a)) => { + if let Some(signal) = a.signal { + if signal != 2 || do_process_ctrl_c { + eprintln!("Command finished with signal: {}", signal); + } + } + std::process::exit(a.code); + } + ConnectionRecvRes::Message(IServerAction::CommandStdout(a)) => { + if let Some(d) = a.data { + stdout.write_all(&BASE64_STANDARD.decode(&d)?).await?; + stdout.flush().await?; + } else { + //let _ = close(1).context("Closing stdout"); + } + } + ConnectionRecvRes::Message(IServerAction::CommandStderr(a)) => { + if let Some(d) = a.data { + stderr.write_all(&BASE64_STANDARD.decode(&d)?).await?; + stderr.flush().await?; + } else { + //let _ = close(2).context("Closing stderr"); + } + } + ConnectionRecvRes::Message(_) => (), + ConnectionRecvRes::SendPong(v) => send.pong(v).await?, + } + } + r = &mut process_ctrl_c, if do_process_ctrl_c => { + r?; + do_process_ctrl_c = false; + } + } + } } diff --git a/src/bin/server/hostclient.rs b/src/bin/server/hostclient.rs index 06b688c..0b026d2 100644 --- a/src/bin/server/hostclient.rs +++ b/src/bin/server/hostclient.rs @@ -33,7 +33,7 @@ use crate::{ action_types::{IHostDown, IHostUp, IServerAction}, crt, crypt, db, state::State, - webclient, + webclient::{self}, }; use sadmin2::type_types::HOST_ID; @@ -70,9 +70,16 @@ pub struct HostClient { hostname: String, writer: TMutex>>, job_sinks: Mutex>>, + message_handlers: Mutex>>, killed_jobs: Mutex>, next_job_id: AtomicU64, run_token: RunToken, + next_command_id: AtomicU64, + next_socket_id: AtomicU64, + pub command_message_handlers: + Mutex>>, + pub socket_message_handlers: + Mutex>>, } async fn write_all_and_flush(v: &mut WriteHalf>, data: &[u8]) -> Result<()> { @@ -105,6 +112,16 @@ impl HostClient { .fetch_add(1, std::sync::atomic::Ordering::SeqCst) } + pub fn next_command_id(&self) -> u64 { + self.next_command_id + .fetch_add(1, std::sync::atomic::Ordering::SeqCst) + } + + pub fn next_socket_id(&self) -> u64 { + self.next_socket_id + .fetch_add(1, std::sync::atomic::Ordering::SeqCst) + } + pub async fn send_message(&self, msg: &HostClientMessage) -> Result<()> { let mut msg = serde_json::to_vec(msg)?; msg.push(0x1e); @@ -135,6 +152,76 @@ impl HostClient { Ok(()) } + pub async fn send_message_with_response( + &self, + msg: &HostClientMessage, + ) -> Result { + let id = msg.job_id().context("Missing job id")?; + let mut msg = serde_json::to_vec(msg)?; + msg.push(0x1e); + let mut writer = cancelable(&self.run_token, self.writer.lock()).await?; + + let (send, recv) = tokio::sync::oneshot::channel(); + self.message_handlers.lock().unwrap().insert(id, send); + + match cancelable( + &self.run_token, + tokio::time::timeout( + Duration::from_secs(60), + write_all_and_flush(&mut writer, &msg), + ), + ) + .await + { + Ok(Ok(Ok(()))) => (), + Ok(Ok(Err(e))) => { + self.run_token.cancel(); + self.message_handlers.lock().unwrap().remove(&id); + bail!("Failure sending message to {}: {:?}", self.hostname, e); + } + Ok(Err(_)) => { + self.run_token.cancel(); + self.message_handlers.lock().unwrap().remove(&id); + bail!("Timeout sending message to {}", self.hostname) + } + Err(_) => { + self.message_handlers.lock().unwrap().remove(&id); + bail!("Host client aborted"); + } + } + std::mem::drop(writer); + + match cancelable( + &self.run_token, + tokio::time::timeout(Duration::from_secs(60), recv), + ) + .await + { + Ok(Ok(Ok(ClientHostMessage::Failure(f)))) => { + bail!( + "failure on host {}: {}", + self.hostname, + f.message.as_deref().unwrap_or_default() + ); + } + Ok(Ok(Ok(r))) => Ok(r), + Ok(Ok(Err(e))) => { + self.run_token.cancel(); + self.message_handlers.lock().unwrap().remove(&id); + bail!("Failure receiving message from {}: {:?}", self.hostname, e); + } + Ok(Err(_)) => { + self.run_token.cancel(); + self.message_handlers.lock().unwrap().remove(&id); + bail!("Timeout receiving message from {}", self.hostname) + } + Err(_) => { + self.message_handlers.lock().unwrap().remove(&id); + bail!("Host client aborted"); + } + } + } + pub async fn start_job(self: &Arc, msg: &HostClientMessage) -> Result { let Some(id) = msg.job_id() else { bail!("Not a job message") @@ -219,12 +306,56 @@ impl HostClient { pong_time = ping_time + FOREVER; } } + ClientHostMessage::SocketRecv{ socket_id, data } => { + match self.socket_message_handlers.lock().unwrap().get(&socket_id) { + None => warn!("Get recv for unknows socket {}", socket_id), + Some(v) => { + if let Err(e) = v.send(ClientHostMessage::SocketRecv{ socket_id, data }) { + warn!("Failed forwarding recv message for socket {}: {:?}", socket_id, e); + } + } + } + } + ClientHostMessage::CommandStdout{ command_id, data } => { + match self.command_message_handlers.lock().unwrap().get(&command_id) { + None => warn!("Get stdout for unknown command {}", command_id), + Some(v) => { + if let Err(e) = v.send(ClientHostMessage::CommandStdout{ command_id, data }) { + warn!("Failed forwarding stdout message to command {}: {:?}", command_id, e); + } + } + } + } + ClientHostMessage::CommandStderr{ command_id, data } => { + match self.command_message_handlers.lock().unwrap().get(&command_id) { + None => warn!("Get stderr for unknown command {}", command_id), + Some(v) => { + if let Err(e) = v.send(ClientHostMessage::CommandStderr{ command_id, data }) { + warn!("Failed forwarding stderr message to command {}: {:?}", command_id, e); + } + } + } + } + ClientHostMessage::CommandFinished{ command_id, code, signal } => { + match self.command_message_handlers.lock().unwrap().get(&command_id) { + None => warn!("Get finished for unknown command {}", command_id), + Some(v) => { + if let Err(e) = v.send(ClientHostMessage::CommandFinished{ command_id, code, signal }) { + warn!("Failed forwarding finished message to command {}: {:?}", command_id, e); + } + } + } + } msg => { if let Some(id) = msg.job_id() { if let Some(job) = self.job_sinks.lock().unwrap().get(&id) { if let Err(e) = job.send(msg) { error!("Unable to handle job message: {:?}", e); } + } else if let Some(s) = self.message_handlers.lock().unwrap().remove(&id) { + if let Err(e) = s.send(msg) { + error!("Unable to handle job message: {:?}", e); + } } else if self.clone().spawn_kill_job(id) { error!("Got message from unknown job {} on host {}", id, self.hostname); } @@ -478,9 +609,14 @@ async fn handle_host_client( hostname, writer: TMutex::new(writer), job_sinks: Default::default(), + message_handlers: Default::default(), next_job_id: AtomicU64::new(j as u64), run_token: run_token.clone(), killed_jobs: Default::default(), + next_command_id: AtomicU64::new(1), + next_socket_id: AtomicU64::new(1), + command_message_handlers: Default::default(), + socket_message_handlers: Default::default(), }); if let Some(c) = state.host_clients.lock().unwrap().insert(id, hc.clone()) { info!( diff --git a/src/bin/server/webclient.rs b/src/bin/server/webclient.rs index 69ea0b7..7c94679 100644 --- a/src/bin/server/webclient.rs +++ b/src/bin/server/webclient.rs @@ -9,9 +9,9 @@ use log::{error, info, warn}; use serde::{Deserialize, Serialize}; use sqlx_type::{query, query_as}; use std::{ - collections::HashMap, + collections::{HashMap, hash_map::Entry}, net::SocketAddr, - sync::{Arc, Mutex}, + sync::{Arc, Mutex, Weak}, time::Duration, }; use tokio::net::TcpListener; @@ -36,7 +36,7 @@ use crate::{ docker::{deploy_service, list_deployment_history, list_deployments, redploy_service}, docker_web, get_auth::get_auth, - hostclient::JobHandle, + hostclient::{HostClient, JobHandle}, modified_files, msg, setup, state::State, terminal, @@ -56,12 +56,13 @@ use axum::{ }; use sadmin2::{ action_types::{ - IClientAction, IGetSecretRes, IRunCommand, IRunCommandFinished, IRunCommandOutput, - IServerAction, + IClientAction, ICommandFinished, ICommandSignal, ICommandSpawn, ICommandStderr, + ICommandStdin, ICommandStdout, IGetSecretRes, IResponse, IRunCommand, IRunCommandFinished, + IRunCommandOutput, IServerAction, ISocketClose, ISocketConnect, ISocketRecv, ISocketSend, }, client_message::{ - ClientHostMessage, DataSource, HostClientMessage, RunScriptMessage, RunScriptOutType, - RunScriptStdinType, + ClientHostMessage, CommandSpawnMessage, DataSource, HostClientMessage, RunScriptMessage, + RunScriptOutType, RunScriptStdinType, }, finite_float::ToFinite, page_types::{IObjectPage, IPage}, @@ -98,6 +99,8 @@ pub struct WebClient { auth: Mutex, run_token: RunToken, command_tokens: Mutex>, + pub sockets: Mutex)>>, + pub commands: Mutex)>>, } impl WebClient { @@ -460,6 +463,353 @@ impl WebClient { } } + async fn handle_socket_messages_inner( + &self, + rt: RunToken, + socket_id: u64, + mut r: tokio::sync::mpsc::UnboundedReceiver, + ) -> Result<()> { + while let Ok(Ok(Some(m))) = cancelable(&self.run_token, cancelable(&rt, r.recv())).await { + if let ClientHostMessage::SocketRecv { data, .. } = m { + let end = data.is_none(); + self.send_message( + &rt, + IServerAction::SocketRecv(ISocketRecv { socket_id, data }), + ) + .await?; + if end { + break; + } + } + } + Ok(()) + } + + async fn handle_socket_messages( + self: Arc, + rt: RunToken, + r: tokio::sync::mpsc::UnboundedReceiver, + socket_id: u64, + host_socket_id: u64, + host: Weak, + ) -> Result<()> { + let r = self.handle_socket_messages_inner(rt, socket_id, r).await; + if let Some(host) = host.upgrade() { + host.socket_message_handlers + .lock() + .unwrap() + .remove(&host_socket_id); + } + r + } + + pub async fn handle_socket_connect( + self: &Arc, + rt: &RunToken, + state: &State, + act: ISocketConnect, + ) -> Result<()> { + let mut host = None; + for hc in state.host_clients.lock().unwrap().values() { + if hc.hostname() == act.host { + host = Some(hc.clone()); + } + } + let Some(host) = host else { + bail!("Unable to find host"); + }; + let socket_id = host.next_socket_id(); + match self.sockets.lock().unwrap().entry(act.socket_id) { + Entry::Occupied(_) => bail!("Socket id in use"), + Entry::Vacant(e) => { + e.insert((socket_id, Arc::downgrade(&host))); + } + } + + let (s, r) = tokio::sync::mpsc::unbounded_channel(); + TaskBuilder::new("socket_message_forwarder") + .shutdown_order(-1) + .create(|rt| { + self.clone().handle_socket_messages( + rt, + r, + act.socket_id, + socket_id, + Arc::downgrade(&host), + ) + }); + + host.socket_message_handlers + .lock() + .unwrap() + .insert(socket_id, s); + + let r = cancelable( + rt, + host.send_message_with_response(&HostClientMessage::SocketConnect { + id: host.next_job_id(), + socket_id, + dst: act.dst, + }), + ) + .await; + if !matches!(r, Ok(Ok(_))) { + self.sockets.lock().unwrap().remove(&act.socket_id); + host.socket_message_handlers + .lock() + .unwrap() + .remove(&socket_id); + } + if let Ok(r) = r { + r?; + } + Ok(()) + } + + pub async fn handle_socket_close( + self: &Arc, + rt: &RunToken, + act: ISocketClose, + ) -> Result<()> { + let (socket_id, host) = self + .sockets + .lock() + .unwrap() + .remove(&act.socket_id) + .with_context(|| format!("Unknown socket_id {}", act.socket_id))?; + let host = host.upgrade().context("Dead host")?; + host.socket_message_handlers + .lock() + .unwrap() + .remove(&socket_id); + cancelable( + rt, + host.send_message_with_response(&HostClientMessage::SocketClose { + id: host.next_job_id(), + socket_id, + }), + ) + .await??; + Ok(()) + } + + pub async fn handle_socket_send( + self: &Arc, + rt: &RunToken, + act: ISocketSend, + ) -> Result<()> { + let (socket_id, host) = match self.sockets.lock().unwrap().entry(act.socket_id) { + Entry::Occupied(mut e) => match e.get_mut().1.upgrade() { + Some(v) => (e.get().0, v), + None => { + e.remove(); + bail!("Dead host") + } + }, + Entry::Vacant(_) => bail!("Unknown socket_id {}", act.socket_id), + }; + cancelable( + rt, + host.send_message_with_response(&HostClientMessage::SocketSend { + id: host.next_job_id(), + socket_id, + data: act.data, + }), + ) + .await??; + Ok(()) + } + + async fn handle_command_messages_inner( + &self, + rt: RunToken, + mut r: tokio::sync::mpsc::UnboundedReceiver, + command_id: u64, + ) -> Result<()> { + while let Ok(Ok(Some(m))) = cancelable(&self.run_token, cancelable(&rt, r.recv())).await { + match m { + ClientHostMessage::CommandStdout { data, .. } => { + self.send_message( + &rt, + IServerAction::CommandStdout(ICommandStdout { command_id, data }), + ) + .await?; + } + ClientHostMessage::CommandStderr { data, .. } => { + self.send_message( + &rt, + IServerAction::CommandStderr(ICommandStderr { command_id, data }), + ) + .await?; + } + ClientHostMessage::CommandFinished { code, signal, .. } => { + self.send_message( + &rt, + IServerAction::CommandFinished(ICommandFinished { + command_id, + code, + signal, + }), + ) + .await?; + } + _ => (), + } + } + Ok(()) + } + + pub async fn handle_command_messages( + self: Arc, + rt: RunToken, + r: tokio::sync::mpsc::UnboundedReceiver, + command_id: u64, + host_command_id: u64, + h: Weak, + ) -> Result<()> { + let r = self.handle_command_messages_inner(rt, r, command_id).await; + if let Some(h) = h.upgrade() { + h.command_message_handlers + .lock() + .unwrap() + .remove(&host_command_id); + } + r + } + + pub async fn handle_command_spawn( + self: &Arc, + rt: &RunToken, + state: &State, + act: ICommandSpawn, + ) -> Result<()> { + let mut host = None; + for hc in state.host_clients.lock().unwrap().values() { + if hc.hostname() == act.host { + host = Some(hc.clone()); + } + } + let Some(host) = host else { + bail!("Unable to find host"); + }; + let command_id = host.next_command_id(); + match self.commands.lock().unwrap().entry(act.command_id) { + Entry::Occupied(_) => bail!("command_id in use"), + Entry::Vacant(e) => { + e.insert((command_id, Arc::downgrade(&host))); + } + } + + let (s, r) = tokio::sync::mpsc::unbounded_channel(); + + TaskBuilder::new("command_message_forwarder") + .shutdown_order(-1) + .create(|rt| { + self.clone().handle_command_messages( + rt, + r, + act.command_id, + command_id, + Arc::downgrade(&host), + ) + }); + + host.command_message_handlers + .lock() + .unwrap() + .insert(command_id, s); + + let r = cancelable( + rt, + host.send_message_with_response(&HostClientMessage::CommandSpawn( + CommandSpawnMessage { + id: host.next_job_id(), + command_id, + program: act.program, + args: act.args, + env: act.env, + cwd: act.cwd, + forward_stdin: act.forward_stdin, + forward_stdout: act.forward_stdout, + forward_stderr: act.forward_stderr, + }, + )), + ) + .await; + if !matches!(r, Ok(Ok(_))) { + self.commands.lock().unwrap().remove(&act.command_id); + host.command_message_handlers + .lock() + .unwrap() + .remove(&command_id); + } + Ok(()) + } + + pub async fn handle_command_stdin( + self: &Arc, + rt: &RunToken, + act: ICommandStdin, + ) -> Result<()> { + let (command_id, host) = match self.commands.lock().unwrap().entry(act.command_id) { + Entry::Occupied(mut e) => match e.get_mut().1.upgrade() { + Some(v) => (e.get().0, v), + None => { + e.remove(); + bail!("Dead host") + } + }, + Entry::Vacant(_) => bail!("Unknown command_id {}", act.command_id), + }; + cancelable( + rt, + host.send_message_with_response(&HostClientMessage::CommandStdin { + id: host.next_job_id(), + command_id, + data: act.data, + }), + ) + .await??; + Ok(()) + } + + pub async fn handle_command_signal( + self: &Arc, + rt: &RunToken, + act: ICommandSignal, + ) -> Result<()> { + let (command_id, host) = match self.commands.lock().unwrap().entry(act.command_id) { + Entry::Occupied(mut e) => match e.get_mut().1.upgrade() { + Some(v) => (e.get().0, v), + None => { + e.remove(); + bail!("Dead host") + } + }, + Entry::Vacant(_) => bail!("Unknown command_id {}", act.command_id), + }; + cancelable( + rt, + host.send_message_with_response(&HostClientMessage::CommandSignal { + id: host.next_job_id(), + command_id, + signal: act.signal, + }), + ) + .await??; + Ok(()) + } + + pub async fn send_response(&self, rt: &RunToken, msg_id: u64, r: Result<()>) -> Result<()> { + let error = match r { + Ok(_) => None, + Err(e) => Some(format!("{:?}", e)), + }; + self.send_message(rt, IServerAction::Response(IResponse { msg_id, error })) + .await?; + Ok(()) + } + pub async fn handle_run_command_inner( &self, state: &State, @@ -572,7 +922,7 @@ os.execv(sys.argv[1], sys.argv[1:]) } pub async fn handle_message( - &self, + self: &Arc, state: &State, rt: RunToken, act: IClientAction, @@ -1524,6 +1874,84 @@ os.execv(sys.argv[1], sys.argv[1:]) ) .await?; } + IClientAction::SocketConnect(act) => { + if !self.get_auth().admin { + self.close(403).await?; + return Ok(()); + }; + if state.read_only { + self.close(503).await?; + return Ok(()); + } + let msg_id = act.msg_id; + let r = self.handle_socket_connect(&rt, state, act).await; + self.send_response(&rt, msg_id, r).await?; + } + IClientAction::SocketClose(act) => { + if !self.get_auth().admin { + self.close(403).await?; + return Ok(()); + }; + if state.read_only { + self.close(503).await?; + return Ok(()); + } + let msg_id = act.msg_id; + let r = self.handle_socket_close(&rt, act).await; + self.send_response(&rt, msg_id, r).await?; + } + IClientAction::SocketSend(act) => { + if !self.get_auth().admin { + self.close(403).await?; + return Ok(()); + }; + if state.read_only { + self.close(503).await?; + return Ok(()); + } + let msg_id = act.msg_id; + let r = self.handle_socket_send(&rt, act).await; + self.send_response(&rt, msg_id, r).await?; + } + IClientAction::CommandSpawn(act) => { + if !self.get_auth().admin { + self.close(403).await?; + return Ok(()); + }; + if state.read_only { + self.close(503).await?; + return Ok(()); + } + let msg_id = act.msg_id; + let r = self.handle_command_spawn(&rt, state, act).await; + self.send_response(&rt, msg_id, r).await?; + } + IClientAction::CommandSignal(act) => { + if !self.get_auth().admin { + self.close(403).await?; + return Ok(()); + }; + if state.read_only { + self.close(503).await?; + return Ok(()); + } + let msg_id = act.msg_id; + let r = self.handle_command_signal(&rt, act).await; + self.send_response(&rt, msg_id, r).await?; + } + IClientAction::CommandStdin(act) => { + if !self.get_auth().admin { + self.close(403).await?; + return Ok(()); + }; + if state.read_only { + self.close(503).await?; + return Ok(()); + } + let msg_id = act.msg_id; + let r = self.handle_command_stdin(&rt, act).await; + self.send_response(&rt, msg_id, r).await?; + } } Ok(()) } @@ -1593,8 +2021,10 @@ async fn handle_webclient(websocket: WebSocket, state: Arc, remote: Strin remote, sink: TMutex::new(sink), auth: Default::default(), - run_token, + run_token: run_token.clone(), command_tokens: Default::default(), + commands: Default::default(), + sockets: Default::default(), }); state .web_clients @@ -1604,7 +2034,46 @@ async fn handle_webclient(websocket: WebSocket, state: Arc, remote: Strin let e = webclient.handle_messages(&state, source).await; info!("Web client disconnected {}", webclient.remote); + let sockets = std::mem::take(&mut *webclient.sockets.lock().unwrap()); + let commands = std::mem::take(&mut *webclient.commands.lock().unwrap()); state.web_clients.lock().unwrap().remove(&CmpRef(webclient)); + run_token.cancel(); + + for (socket_id, host) in sockets.into_values() { + if let Some(host) = host.upgrade() { + host.socket_message_handlers + .lock() + .unwrap() + .remove(&socket_id); + if let Err(e) = host + .send_message(&HostClientMessage::SocketClose { + id: host.next_job_id(), + socket_id, + }) + .await + { + warn!("Unable to sent host socket close message {e}"); + } + } + } + for (command_id, host) in commands.into_values() { + if let Some(host) = host.upgrade() { + host.command_message_handlers + .lock() + .unwrap() + .remove(&command_id); + if let Err(e) = host + .send_message(&HostClientMessage::CommandSignal { + id: host.next_job_id(), + command_id, + signal: 2, + }) + .await + { + warn!("Unable to send command kill {e}"); + } + } + } e?; Ok(()) } diff --git a/src/client_message.rs b/src/client_message.rs index 201e394..d76c4f4 100644 --- a/src/client_message.rs +++ b/src/client_message.rs @@ -102,7 +102,7 @@ pub struct FailureMessage { pub message: Option, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Default)] pub struct SuccessMessage { pub id: u64, #[serde(default, skip_serializing_if = "Option::is_none")] @@ -125,6 +125,21 @@ pub struct DeployServiceMessage { pub user: Option, } +#[derive(Debug, Serialize, Deserialize)] +pub struct CommandSpawnMessage { + pub id: u64, + pub command_id: u64, + pub program: String, + pub args: Vec, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub env: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub cwd: Option, + pub forward_stdin: bool, + pub forward_stdout: bool, + pub forward_stderr: bool, +} + #[derive(Debug, Serialize, Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] pub enum HostClientMessage { @@ -149,6 +164,31 @@ pub enum HostClientMessage { content: String, mode: Option, }, + SocketConnect { + id: u64, + socket_id: u64, + dst: String, + }, + SocketClose { + id: u64, + socket_id: u64, + }, + SocketSend { + id: u64, + socket_id: u64, + data: Option, + }, + CommandSpawn(CommandSpawnMessage), + CommandStdin { + id: u64, + command_id: u64, + data: Option, + }, + CommandSignal { + id: u64, + command_id: u64, + signal: i32, + }, } impl HostClientMessage { @@ -164,6 +204,12 @@ impl HostClientMessage { HostClientMessage::Ping { .. } => None, HostClientMessage::ReadFile { id, .. } => Some(*id), HostClientMessage::WriteFile { id, .. } => Some(*id), + HostClientMessage::SocketConnect { id, .. } => Some(*id), + HostClientMessage::SocketClose { id, .. } => Some(*id), + HostClientMessage::SocketSend { id, .. } => Some(*id), + HostClientMessage::CommandSpawn(msg) => Some(msg.id), + HostClientMessage::CommandStdin { id, .. } => Some(*id), + HostClientMessage::CommandSignal { id, .. } => Some(*id), } } @@ -177,6 +223,12 @@ impl HostClientMessage { HostClientMessage::DeployService(_) => "deploy_service", HostClientMessage::ReadFile { .. } => "read_file", HostClientMessage::WriteFile { .. } => "read_file", + HostClientMessage::SocketConnect { .. } => "socket_connect", + HostClientMessage::SocketClose { .. } => "socket_close", + HostClientMessage::SocketSend { .. } => "socket_send", + HostClientMessage::CommandSpawn(_) => "command_run", + HostClientMessage::CommandStdin { .. } => "command_stdin", + HostClientMessage::CommandSignal { .. } => "command_signal", } } } @@ -199,6 +251,26 @@ pub enum ClientHostMessage { // Base64 encoded content: String, }, + SocketRecv { + socket_id: u64, + // Base64 encoded + data: Option, + }, + CommandStdout { + command_id: u64, + // Base64 encoded + data: Option, + }, + CommandStderr { + command_id: u64, + // Base64 encoded + data: Option, + }, + CommandFinished { + command_id: u64, + code: i32, + signal: Option, + }, } impl ClientHostMessage { @@ -209,6 +281,10 @@ impl ClientHostMessage { ClientHostMessage::Data(data_message) => Some(data_message.id), ClientHostMessage::Auth { .. } | ClientHostMessage::Pong { .. } => None, ClientHostMessage::ReadFileResult { id, .. } => Some(*id), + ClientHostMessage::SocketRecv { .. } => None, + ClientHostMessage::CommandStdout { .. } => None, + ClientHostMessage::CommandStderr { .. } => None, + ClientHostMessage::CommandFinished { .. } => None, } } @@ -220,6 +296,10 @@ impl ClientHostMessage { ClientHostMessage::Success(_) => "success", ClientHostMessage::Data(_) => "data", ClientHostMessage::ReadFileResult { .. } => "read_file_result", + ClientHostMessage::SocketRecv { .. } => "socket_recv", + ClientHostMessage::CommandStdout { .. } => "command_stdout", + ClientHostMessage::CommandStderr { .. } => "command_stderr", + ClientHostMessage::CommandFinished { .. } => "command_finished", } } }