From 110d9ea6f551970d2e6ef375a666fee0ad36fef3 Mon Sep 17 00:00:00 2001 From: danielshih Date: Mon, 4 May 2026 08:44:25 +0000 Subject: [PATCH 1/4] Add SQL builder utilities for PostgreSQL replication management - Introduced a new module `sql_builder` containing functions for building SQL statements related to replication slots and subscriptions. - Implemented quoting functions for identifiers and string literals to prevent SQL injection. - Added SQL builders for creating, altering, dropping, and managing replication slots and subscriptions. - Included tests for all new functionalities to ensure correctness and prevent regressions. - update SQL generation for START_REPLICATION to remove unnecessary parentheses and add null byte checks in quoting functions --- src/connection/libpq.rs | 286 +---- src/connection/native/connection.rs | 244 +---- src/lib.rs | 7 + src/sql_builder.rs | 1533 +++++++++++++++++++++++++++ 4 files changed, 1595 insertions(+), 475 deletions(-) create mode 100644 src/sql_builder.rs diff --git a/src/connection/libpq.rs b/src/connection/libpq.rs index 19cc319..a6ebbb0 100644 --- a/src/connection/libpq.rs +++ b/src/connection/libpq.rs @@ -48,44 +48,6 @@ use tokio::io::unix::AsyncFd; use tokio_util::sync::CancellationToken; use tracing::{debug, info, warn}; -/// Sanitize a string value for use in PostgreSQL replication protocol commands -/// by escaping single quotes (replacing ' with '') -/// -/// This prevents SQL injection when values are used in replication commands. -/// -/// # Arguments -/// * `value` - The string value to sanitize -/// -/// # Returns -/// A sanitized string safe for use in SQL string literals -#[inline] -fn sanitize_sql_string_value(value: &str) -> String { - value.replace('\'', "''") -} - -/// Sanitize a string value and wrap it in single quotes for SQL -#[inline] -fn quote_sql_string_value(value: &str) -> String { - format!("'{}'", sanitize_sql_string_value(value)) -} - -/// Quote a SQL identifier by escaping internal double quotes and wrapping in double quotes. -/// -/// In PostgreSQL, identifiers wrapped in double quotes must have any internal -/// double quotes escaped by doubling them (e.g., `"` becomes `""`). -/// -/// # Examples -/// ```ignore -/// assert_eq!(quote_sql_identifier("my_slot"), r#""my_slot""#); -/// assert_eq!(quote_sql_identifier(r#"a"b"#), r#""a""b""#); -/// ``` -#[inline] -fn quote_sql_identifier(identifier: &str) -> String { - format!("\"{}\"", identifier.replace('"', "\"\"")) -} - -pub use crate::types::INVALID_XLOG_REC_PTR; - /// Result of attempting to read from libpq's internal buffer #[derive(Debug)] enum ReadResult { @@ -341,27 +303,7 @@ impl PgReplicationConnection { start_lsn: XLogRecPtr, options: &[(&str, &str)], ) -> String { - let quoted_slot = quote_sql_identifier(slot_name); - let mut options_str = String::new(); - for (i, (key, value)) in options.iter().enumerate() { - if i > 0 { - options_str.push_str(", "); - } - let quoted_key = quote_sql_identifier(key); - let sanitized_value = sanitize_sql_string_value(value); - options_str.push_str(&format!("{quoted_key} '{sanitized_value}'")); - } - - if start_lsn == INVALID_XLOG_REC_PTR { - format!("START_REPLICATION SLOT {quoted_slot} LOGICAL 0/0 ({options_str})") - } else { - format!( - "START_REPLICATION SLOT {} LOGICAL {} ({})", - quoted_slot, - format_lsn(start_lsn), - options_str - ) - } + crate::sql_builder::build_start_replication_sql(slot_name, start_lsn, options) } /// Start logical replication @@ -631,15 +573,6 @@ impl PgReplicationConnection { } } - /// Helper: Build SQL options string from key-value pairs - fn build_sql_options(options: &[String]) -> String { - if options.is_empty() { - String::new() - } else { - format!(" ({})", options.join(", ")) - } - } - /// Check if the connection is still alive pub fn is_alive(&self) -> bool { if self.conn.is_null() { @@ -693,64 +626,7 @@ impl PgReplicationConnection { output_plugin: Option<&str>, options: &ReplicationSlotOptions, ) -> Result { - let mut parts: Vec<&str> = Vec::new(); - - // Quoted slot name — owned, kept alive for the borrow. - let quoted_slot = quote_sql_identifier(slot_name); - - parts.push("CREATE_REPLICATION_SLOT"); - parts.push("ed_slot); - - if options.temporary { - parts.push("TEMPORARY"); - } - - parts.push(slot_type.as_str()); - - // Owned strings that may be needed below; declared here so - // borrows into `parts` remain valid until the join. - let quoted_plugin: String; - - match slot_type { - SlotType::Physical => { - if options.reserve_wal { - parts.push("RESERVE_WAL"); - } - } - SlotType::Logical => { - let plugin = output_plugin.ok_or_else(|| { - ReplicationError::protocol( - "Output plugin required for LOGICAL slots".to_string(), - ) - })?; - quoted_plugin = quote_sql_identifier(plugin); - parts.push("ed_plugin); - - // Only ONE of TWO_PHASE / snapshot keywords is allowed. - if options.two_phase { - parts.push("TWO_PHASE"); - } else if let Some(ref snapshot) = options.snapshot { - match snapshot.as_str() { - "export" => parts.push("EXPORT_SNAPSHOT"), - "nothing" => parts.push("NOEXPORT_SNAPSHOT"), - "use" => parts.push("USE_SNAPSHOT"), - other => { - return Err(ReplicationError::config(format!( - "Invalid snapshot option '{}': \ - expected 'export', 'nothing', or 'use'", - other - ))); - } - } - } - - if options.failover { - parts.push("FAILOVER"); - } - } - } - - Ok(format!("{};", parts.join(" "))) + crate::sql_builder::build_create_slot_sql(slot_name, slot_type, output_plugin, options) } /// Build the SQL string for `ALTER_REPLICATION_SLOT`. @@ -759,28 +635,7 @@ impl PgReplicationConnection { two_phase: Option, failover: Option, ) -> Result { - let mut opts = Vec::new(); - - if let Some(tp) = two_phase { - opts.push(format!("TWO_PHASE {}", tp)); - } - - if let Some(failover_value) = failover { - opts.push(format!("FAILOVER {}", failover_value)); - } - - if opts.is_empty() { - return Err(ReplicationError::protocol( - "At least one option must be specified for ALTER_REPLICATION_SLOT".to_string(), - )); - } - - let options_str = Self::build_sql_options(&opts); - let quoted_slot = quote_sql_identifier(slot_name); - Ok(format!( - "ALTER_REPLICATION_SLOT {}{};", - quoted_slot, options_str - )) + crate::sql_builder::build_alter_slot_sql(slot_name, two_phase, failover) } /// Alter a replication slot (logical slots only) @@ -800,12 +655,7 @@ impl PgReplicationConnection { /// Build the SQL string for `DROP_REPLICATION_SLOT`. fn build_drop_slot_sql(slot_name: &str, wait: bool) -> String { - let quoted_slot = quote_sql_identifier(slot_name); - if wait { - format!("DROP_REPLICATION_SLOT {} WAIT;", quoted_slot) - } else { - format!("DROP_REPLICATION_SLOT {};", quoted_slot) - } + crate::sql_builder::build_drop_slot_sql(slot_name, wait) } /// Drop a replication slot @@ -837,8 +687,7 @@ impl PgReplicationConnection { /// Build the SQL string for `READ_REPLICATION_SLOT`. fn build_read_slot_sql(slot_name: &str) -> String { - let quoted_slot = quote_sql_identifier(slot_name); - format!("READ_REPLICATION_SLOT {};", quoted_slot) + crate::sql_builder::build_read_slot_sql(slot_name) } /// Read information about a replication slot @@ -885,27 +734,7 @@ impl PgReplicationConnection { start_lsn: XLogRecPtr, timeline_id: Option, ) -> String { - let mut sql = String::from("START_REPLICATION "); - - if let Some(slot) = slot_name { - let quoted_slot = quote_sql_identifier(slot); - sql.push_str(&format!("SLOT {} ", quoted_slot)); - } - - sql.push_str("PHYSICAL "); - - let lsn_str = if start_lsn == INVALID_XLOG_REC_PTR { - "0/0".to_string() - } else { - format_lsn(start_lsn) - }; - sql.push_str(&lsn_str); - - if let Some(tli) = timeline_id { - sql.push_str(&format!(" TIMELINE {}", tli)); - } - - sql + crate::sql_builder::build_start_physical_replication_sql(slot_name, start_lsn, timeline_id) } /// Start physical replication @@ -951,85 +780,7 @@ impl PgReplicationConnection { /// Build the SQL string for `BASE_BACKUP`. fn build_base_backup_sql(options: &BaseBackupOptions) -> String { - let mut opts = Vec::new(); - - if let Some(ref label) = options.label { - opts.push(format!("LABEL {}", quote_sql_string_value(label))); - } - - if let Some(ref target) = options.target { - opts.push(format!("TARGET {}", quote_sql_string_value(target))); - } - - if let Some(ref target_detail) = options.target_detail { - opts.push(format!( - "TARGET_DETAIL {}", - quote_sql_string_value(target_detail) - )); - } - - if options.progress { - opts.push("PROGRESS true".to_string()); - } - - if let Some(ref checkpoint) = options.checkpoint { - opts.push(format!("CHECKPOINT {}", quote_sql_string_value(checkpoint))); - } - - if options.wal { - opts.push("WAL true".to_string()); - } - - if options.wait { - opts.push("WAIT true".to_string()); - } - - if let Some(ref compression) = options.compression { - opts.push(format!( - "COMPRESSION {}", - quote_sql_string_value(compression) - )); - } - - if let Some(ref compression_detail) = options.compression_detail { - opts.push(format!( - "COMPRESSION_DETAIL {}", - quote_sql_string_value(compression_detail) - )); - } - - if let Some(max_rate) = options.max_rate { - opts.push(format!("MAX_RATE {}", max_rate)); - } - - if options.tablespace_map { - opts.push("TABLESPACE_MAP true".to_string()); - } - - if options.verify_checksums { - opts.push("VERIFY_CHECKSUMS true".to_string()); - } - - if let Some(ref manifest) = options.manifest { - opts.push(format!("MANIFEST {}", quote_sql_string_value(manifest))); - } - - if let Some(ref manifest_checksums) = options.manifest_checksums { - opts.push(format!( - "MANIFEST_CHECKSUMS {}", - quote_sql_string_value(manifest_checksums) - )); - } - - if options.incremental { - opts.push("INCREMENTAL".to_string()); - } - - if opts.is_empty() { - "BASE_BACKUP".to_string() - } else { - format!("BASE_BACKUP ({})", opts.join(", ")) - } + crate::sql_builder::build_base_backup_sql(options) } /// Start a base backup with options @@ -1280,6 +1031,21 @@ fn drain_buffered_messages( #[cfg(test)] mod tests { use super::*; + use crate::sql_builder::{quote_ident, quote_literal}; + use crate::INVALID_XLOG_REC_PTR; + + fn sanitize_sql_string_value(value: &str) -> String { + let quoted = quote_literal(value); + quoted[1..quoted.len() - 1].to_owned() + } + + fn quote_sql_string_value(value: &str) -> String { + quote_literal(value) + } + + fn quote_sql_identifier(identifier: &str) -> String { + quote_ident(identifier) + } #[test] fn test_sanitize_sql_string_value_no_quotes() { @@ -1403,14 +1169,14 @@ mod tests { #[test] fn test_build_sql_options_empty() { let options: Vec = vec![]; - let result = PgReplicationConnection::build_sql_options(&options); + let result = crate::sql_builder::build_sql_options(&options); assert_eq!(result, ""); } #[test] fn test_build_sql_options_single() { let options = vec!["proto_version '2'".to_string()]; - let result = PgReplicationConnection::build_sql_options(&options); + let result = crate::sql_builder::build_sql_options(&options); assert_eq!(result, " (proto_version '2')"); } @@ -1421,7 +1187,7 @@ mod tests { "publication_names '\"my_pub\"'".to_string(), "streaming 'on'".to_string(), ]; - let result = PgReplicationConnection::build_sql_options(&options); + let result = crate::sql_builder::build_sql_options(&options); assert_eq!( result, " (proto_version '2', publication_names '\"my_pub\"', streaming 'on')" @@ -2266,7 +2032,7 @@ mod tests { INVALID_XLOG_REC_PTR, &[], ); - assert_eq!(sql, r#"START_REPLICATION SLOT "slot1" LOGICAL 0/0 ()"#); + assert_eq!(sql, r#"START_REPLICATION SLOT "slot1" LOGICAL 0/0"#); } #[test] diff --git a/src/connection/native/connection.rs b/src/connection/native/connection.rs index f6a2471..15a6ace 100644 --- a/src/connection/native/connection.rs +++ b/src/connection/native/connection.rs @@ -19,7 +19,7 @@ use crate::error::{ReplicationError, Result}; use crate::protocol::build_hot_standby_feedback_message; use crate::types::{ format_lsn, system_time_to_postgres_timestamp, BaseBackupOptions, ReplicationSlotOptions, - SlotType, XLogRecPtr, INVALID_XLOG_REC_PTR, + SlotType, XLogRecPtr, }; /// Initial capacity for the read buffer (256 KiB). @@ -46,25 +46,6 @@ fn run_sync(fut: F) -> F::Output { } } -/// Sanitize a string value for use in PostgreSQL replication protocol commands -/// by escaping single quotes (replacing ' with '') -#[inline] -fn sanitize_sql_string_value(value: &str) -> String { - value.replace('\'', "''") -} - -/// Sanitize a string value and wrap it in single quotes for SQL -#[inline] -fn quote_sql_string_value(value: &str) -> String { - format!("'{}'", sanitize_sql_string_value(value)) -} - -/// Quote a SQL identifier by escaping internal double quotes. -#[inline] -fn quote_sql_identifier(identifier: &str) -> String { - format!("\"{}\"", identifier.replace('"', "\"\"")) -} - /// Pure-Rust PostgreSQL connection for replication. /// /// Provides the same public API as the libpq `PgReplicationConnection` @@ -174,27 +155,7 @@ impl NativeConnection { start_lsn: XLogRecPtr, options: &[(&str, &str)], ) -> Result<()> { - let quoted_slot = quote_sql_identifier(slot_name); - let mut options_str = String::new(); - for (i, (key, value)) in options.iter().enumerate() { - if i > 0 { - options_str.push_str(", "); - } - let quoted_key = quote_sql_identifier(key); - let sanitized_value = sanitize_sql_string_value(value); - options_str.push_str(&format!("{quoted_key} '{sanitized_value}'")); - } - - let sql = if start_lsn == INVALID_XLOG_REC_PTR { - format!("START_REPLICATION SLOT {quoted_slot} LOGICAL 0/0 ({options_str})") - } else { - format!( - "START_REPLICATION SLOT {} LOGICAL {} ({})", - quoted_slot, - format_lsn(start_lsn), - options_str - ) - }; + let sql = crate::sql_builder::build_start_replication_sql(slot_name, start_lsn, options); debug!("Starting replication: {}", sql); @@ -331,58 +292,7 @@ impl NativeConnection { output_plugin: Option<&str>, options: &ReplicationSlotOptions, ) -> Result { - let mut parts: Vec<&str> = Vec::new(); - let quoted_slot = quote_sql_identifier(slot_name); - - parts.push("CREATE_REPLICATION_SLOT"); - parts.push("ed_slot); - - if options.temporary { - parts.push("TEMPORARY"); - } - - parts.push(slot_type.as_str()); - - let quoted_plugin: String; - - match slot_type { - SlotType::Physical => { - if options.reserve_wal { - parts.push("RESERVE_WAL"); - } - } - SlotType::Logical => { - let plugin = output_plugin.ok_or_else(|| { - ReplicationError::protocol( - "Output plugin required for LOGICAL slots".to_string(), - ) - })?; - quoted_plugin = quote_sql_identifier(plugin); - parts.push("ed_plugin); - - if options.two_phase { - parts.push("TWO_PHASE"); - } else if let Some(ref snapshot) = options.snapshot { - match snapshot.as_str() { - "export" => parts.push("EXPORT_SNAPSHOT"), - "nothing" => parts.push("NOEXPORT_SNAPSHOT"), - "use" => parts.push("USE_SNAPSHOT"), - other => { - return Err(ReplicationError::config(format!( - "Invalid snapshot option '{}': expected 'export', 'nothing', or 'use'", - other - ))); - } - } - } - - if options.failover { - parts.push("FAILOVER"); - } - } - } - - Ok(format!("{};", parts.join(" "))) + crate::sql_builder::build_create_slot_sql(slot_name, slot_type, output_plugin, options) } /// Alter a replication slot (logical slots only). @@ -392,25 +302,7 @@ impl NativeConnection { two_phase: Option, failover: Option, ) -> Result { - let mut opts = Vec::new(); - - if let Some(tp) = two_phase { - opts.push(format!("TWO_PHASE {}", tp)); - } - - if let Some(failover_value) = failover { - opts.push(format!("FAILOVER {}", failover_value)); - } - - if opts.is_empty() { - return Err(ReplicationError::protocol( - "At least one option must be specified for ALTER_REPLICATION_SLOT".to_string(), - )); - } - - let options_str = Self::build_sql_options(&opts); - let quoted_slot = quote_sql_identifier(slot_name); - let sql = format!("ALTER_REPLICATION_SLOT {}{};", quoted_slot, options_str); + let sql = crate::sql_builder::build_alter_slot_sql(slot_name, two_phase, failover)?; debug!("Altering replication slot: {}", sql); let result = self.exec(&sql)?; @@ -418,21 +310,8 @@ impl NativeConnection { Ok(result) } - fn build_sql_options(options: &[String]) -> String { - if options.is_empty() { - String::new() - } else { - format!(" ({})", options.join(", ")) - } - } - fn build_drop_slot_sql(slot_name: &str, wait: bool) -> String { - let quoted_slot = quote_sql_identifier(slot_name); - if wait { - format!("DROP_REPLICATION_SLOT {} WAIT;", quoted_slot) - } else { - format!("DROP_REPLICATION_SLOT {};", quoted_slot) - } + crate::sql_builder::build_drop_slot_sql(slot_name, wait) } /// Drop a replication slot. @@ -454,8 +333,7 @@ impl NativeConnection { } fn build_read_slot_sql(slot_name: &str) -> String { - let quoted_slot = quote_sql_identifier(slot_name); - format!("READ_REPLICATION_SLOT {};", quoted_slot) + crate::sql_builder::build_read_slot_sql(slot_name) } /// Read information about a replication slot. @@ -497,25 +375,11 @@ impl NativeConnection { start_lsn: XLogRecPtr, timeline_id: Option, ) -> Result<()> { - let mut sql = String::from("START_REPLICATION "); - - if let Some(slot) = slot_name { - let quoted_slot = quote_sql_identifier(slot); - sql.push_str(&format!("SLOT {} ", quoted_slot)); - } - - sql.push_str("PHYSICAL "); - - let lsn_str = if start_lsn == INVALID_XLOG_REC_PTR { - "0/0".to_string() - } else { - format_lsn(start_lsn) - }; - sql.push_str(&lsn_str); - - if let Some(tli) = timeline_id { - sql.push_str(&format!(" TIMELINE {}", tli)); - } + let sql = crate::sql_builder::build_start_physical_replication_sql( + slot_name, + start_lsn, + timeline_id, + ); debug!("Starting physical replication: {}", sql); @@ -542,71 +406,7 @@ impl NativeConnection { /// Start a base backup with options. pub fn base_backup(&mut self, options: &BaseBackupOptions) -> Result { - let mut opts = Vec::new(); - - if let Some(ref label) = options.label { - opts.push(format!("LABEL {}", quote_sql_string_value(label))); - } - if let Some(ref target) = options.target { - opts.push(format!("TARGET {}", quote_sql_string_value(target))); - } - if let Some(ref target_detail) = options.target_detail { - opts.push(format!( - "TARGET_DETAIL {}", - quote_sql_string_value(target_detail) - )); - } - if options.progress { - opts.push("PROGRESS true".to_string()); - } - if let Some(ref checkpoint) = options.checkpoint { - opts.push(format!("CHECKPOINT {}", quote_sql_string_value(checkpoint))); - } - if options.wal { - opts.push("WAL true".to_string()); - } - if options.wait { - opts.push("WAIT true".to_string()); - } - if let Some(ref compression) = options.compression { - opts.push(format!( - "COMPRESSION {}", - quote_sql_string_value(compression) - )); - } - if let Some(ref compression_detail) = options.compression_detail { - opts.push(format!( - "COMPRESSION_DETAIL {}", - quote_sql_string_value(compression_detail) - )); - } - if let Some(max_rate) = options.max_rate { - opts.push(format!("MAX_RATE {}", max_rate)); - } - if options.tablespace_map { - opts.push("TABLESPACE_MAP true".to_string()); - } - if options.verify_checksums { - opts.push("VERIFY_CHECKSUMS true".to_string()); - } - if let Some(ref manifest) = options.manifest { - opts.push(format!("MANIFEST {}", quote_sql_string_value(manifest))); - } - if let Some(ref manifest_checksums) = options.manifest_checksums { - opts.push(format!( - "MANIFEST_CHECKSUMS {}", - quote_sql_string_value(manifest_checksums) - )); - } - if options.incremental { - opts.push("INCREMENTAL".to_string()); - } - - let sql = if opts.is_empty() { - "BASE_BACKUP".to_string() - } else { - format!("BASE_BACKUP ({})", opts.join(", ")) - }; + let sql = crate::sql_builder::build_base_backup_sql(options); debug!("Starting base backup: {}", sql); let result = self.exec(&sql)?; @@ -706,6 +506,20 @@ mod tests { use super::*; use crate::types::{ReplicationSlotOptions, SlotType}; + fn sanitize_sql_string_value(value: &str) -> String { + let quoted = crate::sql_builder::quote_literal(value); + // Strip surrounding quotes to get just the sanitized interior + quoted[1..quoted.len() - 1].to_string() + } + + fn quote_sql_string_value(value: &str) -> String { + crate::sql_builder::quote_literal(value) + } + + fn quote_sql_identifier(identifier: &str) -> String { + crate::sql_builder::quote_ident(identifier) + } + // === sanitize_sql_string_value === #[test] @@ -844,14 +658,14 @@ mod tests { #[test] fn test_build_sql_options_empty() { let options: Vec = vec![]; - assert_eq!(NativeConnection::build_sql_options(&options), ""); + assert_eq!(crate::sql_builder::build_sql_options(&options), ""); } #[test] fn test_build_sql_options_single() { let options = vec!["proto_version '2'".to_string()]; assert_eq!( - NativeConnection::build_sql_options(&options), + crate::sql_builder::build_sql_options(&options), " (proto_version '2')" ); } @@ -864,7 +678,7 @@ mod tests { "streaming 'on'".to_string(), ]; assert_eq!( - NativeConnection::build_sql_options(&options), + crate::sql_builder::build_sql_options(&options), " (proto_version '2', publication_names '\"my_pub\"', streaming 'on')" ); } diff --git a/src/lib.rs b/src/lib.rs index 6bd563a..bf2121d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -127,6 +127,7 @@ pub mod buffer; pub mod column_value; pub mod deserializer; pub mod error; +pub mod sql_builder; pub mod types; // Protocol implementation @@ -198,3 +199,9 @@ pub use connection::{PgReplicationConnection, PgResult}; // Re-export retry types pub use retry::{ExponentialBackoff, ReplicationConnectionRetry, RetryConfig}; + +// Re-export SQL builder utilities +pub use sql_builder::{ + build_create_subscription_sql, build_detach_slot_sql, build_disable_subscription_sql, + build_drop_subscription_sql, quote_ident, quote_literal, CreateSubscriptionOptions, +}; diff --git a/src/sql_builder.rs b/src/sql_builder.rs new file mode 100644 index 0000000..c925051 --- /dev/null +++ b/src/sql_builder.rs @@ -0,0 +1,1533 @@ +//! SQL builder utilities for PostgreSQL replication management. +//! +//! Provides safe quoting primitives and SQL statement builders for replication +//! slot management, subscription management, and base backup commands. +//! All quoting functions use pre-allocated buffers and char-by-char iteration +//! to avoid intermediate allocations. + +use crate::error::{ReplicationError, Result}; +use crate::types::{format_lsn, BaseBackupOptions, ReplicationSlotOptions, SlotType, XLogRecPtr}; + +/// Invalid/zero LSN pointer (re-imported for internal use). +const INVALID_XLOG_REC_PTR: u64 = 0; + +/// Quote a PostgreSQL identifier by wrapping in double quotes and escaping +/// internal double quotes (doubling them). +/// +/// # Panics +/// +/// Panics if `name` contains a null byte (`\0`), which is invalid in +/// PostgreSQL identifiers and could cause truncation-based injection via +/// the C-string wire protocol. +/// +/// # Example +/// +/// ``` +/// use pg_walstream::sql_builder::quote_ident; +/// +/// assert_eq!(quote_ident("my_slot"), r#""my_slot""#); +/// assert_eq!(quote_ident(r#"a"b"#), r#""a""b""#); +/// ``` +#[inline] +pub fn quote_ident(name: &str) -> String { + assert!( + !name.contains('\0'), + "SQL identifier must not contain null bytes" + ); + let mut out = String::with_capacity(name.len() + 2); + out.push('"'); + for ch in name.chars() { + if ch == '"' { + out.push('"'); + } + out.push(ch); + } + out.push('"'); + out +} + +/// Quote a PostgreSQL string literal by wrapping in single quotes and escaping +/// internal single quotes (doubling them). +/// +/// # Panics +/// +/// Panics if `value` contains a null byte (`\0`), which is invalid in +/// PostgreSQL string literals and could cause truncation-based injection via +/// the C-string wire protocol. +/// +/// # Example +/// +/// ``` +/// use pg_walstream::sql_builder::quote_literal; +/// +/// assert_eq!(quote_literal("hello"), "'hello'"); +/// assert_eq!(quote_literal("it's"), "'it''s'"); +/// ``` +#[inline] +pub fn quote_literal(value: &str) -> String { + assert!( + !value.contains('\0'), + "SQL literal must not contain null bytes" + ); + let mut out = String::with_capacity(value.len() + 2); + out.push('\''); + for ch in value.chars() { + if ch == '\'' { + out.push('\''); + } + out.push(ch); + } + out.push('\''); + out +} + +// ═══════════════════════════════════════════════════════════════════════════ +// Replication slot SQL builders +// ═══════════════════════════════════════════════════════════════════════════ + +/// Build the SQL for `CREATE_REPLICATION_SLOT`. +/// +/// # Example +/// +/// ``` +/// use pg_walstream::sql_builder::build_create_slot_sql; +/// use pg_walstream::types::{ReplicationSlotOptions, SlotType}; +/// +/// let opts = ReplicationSlotOptions::default(); +/// let sql = build_create_slot_sql("my_slot", SlotType::Logical, Some("pgoutput"), &opts).unwrap(); +/// assert_eq!(sql, r#"CREATE_REPLICATION_SLOT "my_slot" LOGICAL "pgoutput";"#); +/// ``` +pub fn build_create_slot_sql( + slot_name: &str, + slot_type: SlotType, + output_plugin: Option<&str>, + options: &ReplicationSlotOptions, +) -> Result { + let mut parts: Vec<&str> = Vec::with_capacity(6); + + let quoted_slot = quote_ident(slot_name); + + parts.push("CREATE_REPLICATION_SLOT"); + parts.push("ed_slot); + + if options.temporary { + parts.push("TEMPORARY"); + } + + parts.push(slot_type.as_str()); + + let quoted_plugin: String; + + match slot_type { + SlotType::Physical => { + if options.reserve_wal { + parts.push("RESERVE_WAL"); + } + } + SlotType::Logical => { + let plugin = output_plugin.ok_or_else(|| { + ReplicationError::protocol("Output plugin required for LOGICAL slots".to_string()) + })?; + quoted_plugin = quote_ident(plugin); + parts.push("ed_plugin); + + if options.two_phase { + parts.push("TWO_PHASE"); + } else if let Some(ref snapshot) = options.snapshot { + match snapshot.as_str() { + "export" => parts.push("EXPORT_SNAPSHOT"), + "nothing" => parts.push("NOEXPORT_SNAPSHOT"), + "use" => parts.push("USE_SNAPSHOT"), + other => { + return Err(ReplicationError::config(format!( + "Invalid snapshot option '{}': \ + expected 'export', 'nothing', or 'use'", + other + ))); + } + } + } + + if options.failover { + parts.push("FAILOVER"); + } + } + } + + Ok(format!("{};", parts.join(" "))) +} + +/// Build the SQL for `ALTER_REPLICATION_SLOT`. +/// +/// # Example +/// +/// ``` +/// use pg_walstream::sql_builder::build_alter_slot_sql; +/// +/// let sql = build_alter_slot_sql("my_slot", Some(true), None).unwrap(); +/// assert_eq!(sql, r#"ALTER_REPLICATION_SLOT "my_slot" (TWO_PHASE true);"#); +/// ``` +pub fn build_alter_slot_sql( + slot_name: &str, + two_phase: Option, + failover: Option, +) -> Result { + let mut opts = Vec::new(); + + if let Some(tp) = two_phase { + opts.push(format!("TWO_PHASE {}", tp)); + } + + if let Some(failover_value) = failover { + opts.push(format!("FAILOVER {}", failover_value)); + } + + if opts.is_empty() { + return Err(ReplicationError::protocol( + "At least one option must be specified for ALTER_REPLICATION_SLOT".to_string(), + )); + } + + let options_str = build_sql_options(&opts); + let quoted_slot = quote_ident(slot_name); + Ok(format!( + "ALTER_REPLICATION_SLOT {}{};", + quoted_slot, options_str + )) +} + +/// Build the SQL for `DROP_REPLICATION_SLOT`. +/// +/// # Example +/// +/// ``` +/// use pg_walstream::sql_builder::build_drop_slot_sql; +/// +/// assert_eq!(build_drop_slot_sql("my_slot", false), r#"DROP_REPLICATION_SLOT "my_slot";"#); +/// assert_eq!(build_drop_slot_sql("my_slot", true), r#"DROP_REPLICATION_SLOT "my_slot" WAIT;"#); +/// ``` +#[inline] +pub fn build_drop_slot_sql(slot_name: &str, wait: bool) -> String { + let quoted_slot = quote_ident(slot_name); + if wait { + format!("DROP_REPLICATION_SLOT {} WAIT;", quoted_slot) + } else { + format!("DROP_REPLICATION_SLOT {};", quoted_slot) + } +} + +/// Build the SQL for `READ_REPLICATION_SLOT`. +/// +/// # Example +/// +/// ``` +/// use pg_walstream::sql_builder::build_read_slot_sql; +/// +/// assert_eq!(build_read_slot_sql("my_slot"), r#"READ_REPLICATION_SLOT "my_slot";"#); +/// ``` +#[inline] +pub fn build_read_slot_sql(slot_name: &str) -> String { + let quoted_slot = quote_ident(slot_name); + format!("READ_REPLICATION_SLOT {};", quoted_slot) +} + +/// Build the SQL for `START_REPLICATION SLOT ... LOGICAL`. +/// +/// # Example +/// +/// ``` +/// use pg_walstream::sql_builder::build_start_replication_sql; +/// +/// let sql = build_start_replication_sql("my_slot", 0, &[("proto_version", "1")]); +/// assert_eq!(sql, r#"START_REPLICATION SLOT "my_slot" LOGICAL 0/0 ("proto_version" '1')"#); +/// ``` +pub fn build_start_replication_sql( + slot_name: &str, + start_lsn: XLogRecPtr, + options: &[(&str, &str)], +) -> String { + let quoted_slot = quote_ident(slot_name); + let lsn_str = if start_lsn == INVALID_XLOG_REC_PTR { + "0/0".to_string() + } else { + format_lsn(start_lsn) + }; + + if options.is_empty() { + return format!("START_REPLICATION SLOT {quoted_slot} LOGICAL {lsn_str}"); + } + + let mut options_str = String::new(); + for (i, (key, value)) in options.iter().enumerate() { + if i > 0 { + options_str.push_str(", "); + } + let quoted_key = quote_ident(key); + let quoted_value = quote_literal(value); + options_str.push_str("ed_key); + options_str.push(' '); + options_str.push_str("ed_value); + } + + format!("START_REPLICATION SLOT {quoted_slot} LOGICAL {lsn_str} ({options_str})") +} + +/// Build the SQL for `START_REPLICATION ... PHYSICAL`. +/// +/// # Example +/// +/// ``` +/// use pg_walstream::sql_builder::build_start_physical_replication_sql; +/// +/// let sql = build_start_physical_replication_sql(Some("my_slot"), 0, None); +/// assert_eq!(sql, r#"START_REPLICATION SLOT "my_slot" PHYSICAL 0/0"#); +/// ``` +pub fn build_start_physical_replication_sql( + slot_name: Option<&str>, + start_lsn: XLogRecPtr, + timeline_id: Option, +) -> String { + let mut sql = String::with_capacity(64); + sql.push_str("START_REPLICATION "); + + if let Some(slot) = slot_name { + let quoted_slot = quote_ident(slot); + sql.push_str("SLOT "); + sql.push_str("ed_slot); + sql.push(' '); + } + + sql.push_str("PHYSICAL "); + + if start_lsn == INVALID_XLOG_REC_PTR { + sql.push_str("0/0"); + } else { + sql.push_str(&format_lsn(start_lsn)); + } + + if let Some(tli) = timeline_id { + sql.push_str(" TIMELINE "); + sql.push_str(&tli.to_string()); + } + + sql +} + +/// Build the SQL for `BASE_BACKUP`. +/// +/// # Example +/// +/// ``` +/// use pg_walstream::sql_builder::build_base_backup_sql; +/// use pg_walstream::types::BaseBackupOptions; +/// +/// let opts = BaseBackupOptions::default(); +/// assert_eq!(build_base_backup_sql(&opts), "BASE_BACKUP"); +/// ``` +pub fn build_base_backup_sql(options: &BaseBackupOptions) -> String { + let mut opts = Vec::new(); + + if let Some(ref label) = options.label { + opts.push(format!("LABEL {}", quote_literal(label))); + } + + if let Some(ref target) = options.target { + opts.push(format!("TARGET {}", quote_literal(target))); + } + + if let Some(ref target_detail) = options.target_detail { + opts.push(format!("TARGET_DETAIL {}", quote_literal(target_detail))); + } + + if options.progress { + opts.push("PROGRESS true".to_string()); + } + + if let Some(ref checkpoint) = options.checkpoint { + opts.push(format!("CHECKPOINT {}", quote_literal(checkpoint))); + } + + if options.wal { + opts.push("WAL true".to_string()); + } + + if options.wait { + opts.push("WAIT true".to_string()); + } + + if let Some(ref compression) = options.compression { + opts.push(format!("COMPRESSION {}", quote_literal(compression))); + } + + if let Some(ref compression_detail) = options.compression_detail { + opts.push(format!( + "COMPRESSION_DETAIL {}", + quote_literal(compression_detail) + )); + } + + if let Some(max_rate) = options.max_rate { + opts.push(format!("MAX_RATE {}", max_rate)); + } + + if options.tablespace_map { + opts.push("TABLESPACE_MAP true".to_string()); + } + + if options.verify_checksums { + opts.push("VERIFY_CHECKSUMS true".to_string()); + } + + if let Some(ref manifest) = options.manifest { + opts.push(format!("MANIFEST {}", quote_literal(manifest))); + } + + if let Some(ref manifest_checksums) = options.manifest_checksums { + opts.push(format!( + "MANIFEST_CHECKSUMS {}", + quote_literal(manifest_checksums) + )); + } + + if options.incremental { + opts.push("INCREMENTAL".to_string()); + } + + if opts.is_empty() { + "BASE_BACKUP".to_string() + } else { + format!("BASE_BACKUP ({})", opts.join(", ")) + } +} + +/// Options for building a `CREATE SUBSCRIPTION` SQL statement. +/// +/// All fields borrow from the caller — no allocation or cloning required. +/// Use [`Default`] for the WITH-clause flags to get the typical migration +/// defaults (`create_slot = false`, `enabled = true`, `copy_data = false`). +/// +/// # Example +/// +/// ``` +/// use pg_walstream::sql_builder::CreateSubscriptionOptions; +/// +/// let opts = CreateSubscriptionOptions { +/// subscription_name: "my_sub", +/// connection_string: "host=localhost dbname=source", +/// publication: "my_pub", +/// slot_name: "my_slot", +/// ..Default::default() +/// }; +/// ``` +#[derive(Debug, Clone, Copy)] +pub struct CreateSubscriptionOptions<'a> { + pub subscription_name: &'a str, + pub connection_string: &'a str, + pub publication: &'a str, + pub slot_name: &'a str, + pub create_slot: bool, + pub enabled: bool, + pub copy_data: bool, +} + +impl<'a> Default for CreateSubscriptionOptions<'a> { + #[inline] + fn default() -> Self { + Self { + subscription_name: "", + connection_string: "", + publication: "", + slot_name: "", + create_slot: false, + enabled: true, + copy_data: false, + } + } +} + +/// Build a `CREATE SUBSCRIPTION` statement. +/// +/// # Example +/// +/// ``` +/// use pg_walstream::sql_builder::{build_create_subscription_sql, CreateSubscriptionOptions}; +/// +/// let opts = CreateSubscriptionOptions { +/// subscription_name: "my_sub", +/// connection_string: "host=localhost dbname=source", +/// publication: "my_pub", +/// slot_name: "my_slot", +/// ..Default::default() +/// }; +/// let sql = build_create_subscription_sql(&opts); +/// assert!(sql.starts_with("CREATE SUBSCRIPTION")); +/// assert!(sql.contains("create_slot = false")); +/// ``` +pub fn build_create_subscription_sql(opts: &CreateSubscriptionOptions<'_>) -> String { + let sub = quote_ident(opts.subscription_name); + let conn = quote_literal(opts.connection_string); + let pubname = quote_ident(opts.publication); + let slot = quote_literal(opts.slot_name); + + let create_slot_str = if opts.create_slot { "true" } else { "false" }; + let enabled_str = if opts.enabled { "true" } else { "false" }; + let copy_data_str = if opts.copy_data { "true" } else { "false" }; + + let estimated = "CREATE SUBSCRIPTION ".len() + + sub.len() + + " CONNECTION ".len() + + conn.len() + + " PUBLICATION ".len() + + pubname.len() + + " WITH (create_slot = , slot_name = , enabled = , copy_data = )".len() + + slot.len() + + create_slot_str.len() + + enabled_str.len() + + copy_data_str.len(); + + let mut sql = String::with_capacity(estimated); + sql.push_str("CREATE SUBSCRIPTION "); + sql.push_str(&sub); + sql.push_str(" CONNECTION "); + sql.push_str(&conn); + sql.push_str(" PUBLICATION "); + sql.push_str(&pubname); + sql.push_str(" WITH (create_slot = "); + sql.push_str(create_slot_str); + sql.push_str(", slot_name = "); + sql.push_str(&slot); + sql.push_str(", enabled = "); + sql.push_str(enabled_str); + sql.push_str(", copy_data = "); + sql.push_str(copy_data_str); + sql.push(')'); + sql +} + +/// Build an `ALTER SUBSCRIPTION ... DISABLE` statement. +/// +/// # Example +/// +/// ``` +/// use pg_walstream::sql_builder::build_disable_subscription_sql; +/// +/// let sql = build_disable_subscription_sql("my_sub"); +/// assert_eq!(sql, r#"ALTER SUBSCRIPTION "my_sub" DISABLE"#); +/// ``` +#[inline] +pub fn build_disable_subscription_sql(name: &str) -> String { + let quoted = quote_ident(name); + let mut sql = + String::with_capacity("ALTER SUBSCRIPTION ".len() + quoted.len() + " DISABLE".len()); + sql.push_str("ALTER SUBSCRIPTION "); + sql.push_str("ed); + sql.push_str(" DISABLE"); + sql +} + +/// Build an `ALTER SUBSCRIPTION ... SET (slot_name = NONE)` statement to detach a slot. +/// +/// # Example +/// +/// ``` +/// use pg_walstream::sql_builder::build_detach_slot_sql; +/// +/// let sql = build_detach_slot_sql("my_sub"); +/// assert_eq!(sql, r#"ALTER SUBSCRIPTION "my_sub" SET (slot_name = NONE)"#); +/// ``` +#[inline] +pub fn build_detach_slot_sql(name: &str) -> String { + let quoted = quote_ident(name); + let mut sql = String::with_capacity( + "ALTER SUBSCRIPTION ".len() + quoted.len() + " SET (slot_name = NONE)".len(), + ); + sql.push_str("ALTER SUBSCRIPTION "); + sql.push_str("ed); + sql.push_str(" SET (slot_name = NONE)"); + sql +} + +/// Build a `DROP SUBSCRIPTION` statement. +/// +/// # Example +/// +/// ``` +/// use pg_walstream::sql_builder::build_drop_subscription_sql; +/// +/// let sql = build_drop_subscription_sql("my_sub"); +/// assert_eq!(sql, r#"DROP SUBSCRIPTION "my_sub""#); +/// ``` +#[inline] +pub fn build_drop_subscription_sql(name: &str) -> String { + let quoted = quote_ident(name); + let mut sql = String::with_capacity("DROP SUBSCRIPTION ".len() + quoted.len()); + sql.push_str("DROP SUBSCRIPTION "); + sql.push_str("ed); + sql +} + +// ═══════════════════════════════════════════════════════════════════════════ +// Helpers +// ═══════════════════════════════════════════════════════════════════════════ + +/// Format a list of options as ` (opt1, opt2, ...)`. +/// Returns an empty string if the list is empty. +#[inline] +pub fn build_sql_options(options: &[String]) -> String { + if options.is_empty() { + String::new() + } else { + format!(" ({})", options.join(", ")) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // ── quote_ident ────────────────────────────────────────────────────── + + #[test] + fn quote_ident_simple() { + assert_eq!(quote_ident("my_slot"), r#""my_slot""#); + } + + #[test] + fn quote_ident_with_internal_double_quote() { + assert_eq!(quote_ident(r#"a"b"#), r#""a""b""#); + } + + #[test] + fn quote_ident_multiple_quotes() { + assert_eq!(quote_ident(r#"a""b"#), r#""a""""b""#); + } + + #[test] + fn quote_ident_empty() { + assert_eq!(quote_ident(""), r#""""#); + } + + #[test] + fn quote_ident_special_chars() { + assert_eq!( + quote_ident("slot; DROP TABLE users; --"), + r#""slot; DROP TABLE users; --""# + ); + } + + #[test] + fn quote_ident_unicode() { + assert_eq!(quote_ident("テスト"), r#""テスト""#); + } + + #[test] + fn quote_ident_mixed_unicode_and_quotes() { + assert_eq!(quote_ident(r#"名前"テスト"#), r#""名前""テスト""#); + } + + #[test] + #[should_panic(expected = "null bytes")] + fn quote_ident_rejects_null_byte() { + quote_ident("evil\0injection"); + } + + // ── quote_literal ──────────────────────────────────────────────────── + + #[test] + fn quote_literal_simple() { + assert_eq!(quote_literal("hello"), "'hello'"); + } + + #[test] + fn quote_literal_with_internal_single_quote() { + assert_eq!(quote_literal("it's"), "'it''s'"); + } + + #[test] + fn quote_literal_multiple_quotes() { + assert_eq!(quote_literal("a''b"), "'a''''b'"); + } + + #[test] + fn quote_literal_empty() { + assert_eq!(quote_literal(""), "''"); + } + + #[test] + fn quote_literal_sql_injection_attempt() { + assert_eq!( + quote_literal("'; DROP TABLE users; --"), + "'''; DROP TABLE users; --'" + ); + } + + #[test] + fn quote_literal_unicode() { + assert_eq!(quote_literal("日本語"), "'日本語'"); + } + + #[test] + fn quote_literal_newlines() { + assert_eq!(quote_literal("line1\nline2"), "'line1\nline2'"); + } + + #[test] + fn quote_literal_complex_injection() { + assert_eq!(quote_literal("value' OR '1'='1"), "'value'' OR ''1''=''1'"); + } + + #[test] + fn quote_literal_backslash_and_quote() { + assert_eq!(quote_literal("test\\'value"), "'test\\''value'"); + } + + #[test] + #[should_panic(expected = "null bytes")] + fn quote_literal_rejects_null_byte() { + quote_literal("evil\0injection"); + } + + // ── build_create_slot_sql ──────────────────────────────────────────── + + #[test] + fn create_slot_logical_default() { + let opts = ReplicationSlotOptions::default(); + let sql = + build_create_slot_sql("my_slot", SlotType::Logical, Some("pgoutput"), &opts).unwrap(); + assert_eq!( + sql, + r#"CREATE_REPLICATION_SLOT "my_slot" LOGICAL "pgoutput";"# + ); + } + + #[test] + fn create_slot_logical_temporary_export_snapshot() { + let opts = ReplicationSlotOptions { + temporary: true, + snapshot: Some("export".to_string()), + ..Default::default() + }; + let sql = + build_create_slot_sql("tmp_slot", SlotType::Logical, Some("pgoutput"), &opts).unwrap(); + assert_eq!( + sql, + r#"CREATE_REPLICATION_SLOT "tmp_slot" TEMPORARY LOGICAL "pgoutput" EXPORT_SNAPSHOT;"# + ); + } + + #[test] + fn create_slot_logical_noexport_snapshot() { + let opts = ReplicationSlotOptions { + snapshot: Some("nothing".to_string()), + ..Default::default() + }; + let sql = + build_create_slot_sql("slot", SlotType::Logical, Some("pgoutput"), &opts).unwrap(); + assert_eq!( + sql, + r#"CREATE_REPLICATION_SLOT "slot" LOGICAL "pgoutput" NOEXPORT_SNAPSHOT;"# + ); + } + + #[test] + fn create_slot_logical_use_snapshot() { + let opts = ReplicationSlotOptions { + snapshot: Some("use".to_string()), + ..Default::default() + }; + let sql = + build_create_slot_sql("slot", SlotType::Logical, Some("pgoutput"), &opts).unwrap(); + assert_eq!( + sql, + r#"CREATE_REPLICATION_SLOT "slot" LOGICAL "pgoutput" USE_SNAPSHOT;"# + ); + } + + #[test] + fn create_slot_logical_two_phase() { + let opts = ReplicationSlotOptions { + two_phase: true, + ..Default::default() + }; + let sql = + build_create_slot_sql("slot", SlotType::Logical, Some("pgoutput"), &opts).unwrap(); + assert_eq!( + sql, + r#"CREATE_REPLICATION_SLOT "slot" LOGICAL "pgoutput" TWO_PHASE;"# + ); + } + + #[test] + fn create_slot_logical_two_phase_overrides_snapshot() { + let opts = ReplicationSlotOptions { + two_phase: true, + snapshot: Some("export".to_string()), + ..Default::default() + }; + let sql = + build_create_slot_sql("slot", SlotType::Logical, Some("pgoutput"), &opts).unwrap(); + assert_eq!( + sql, + r#"CREATE_REPLICATION_SLOT "slot" LOGICAL "pgoutput" TWO_PHASE;"# + ); + } + + #[test] + fn create_slot_logical_failover() { + let opts = ReplicationSlotOptions { + failover: true, + ..Default::default() + }; + let sql = + build_create_slot_sql("slot", SlotType::Logical, Some("pgoutput"), &opts).unwrap(); + assert_eq!( + sql, + r#"CREATE_REPLICATION_SLOT "slot" LOGICAL "pgoutput" FAILOVER;"# + ); + } + + #[test] + fn create_slot_logical_export_with_failover() { + let opts = ReplicationSlotOptions { + snapshot: Some("export".to_string()), + failover: true, + ..Default::default() + }; + let sql = + build_create_slot_sql("slot", SlotType::Logical, Some("pgoutput"), &opts).unwrap(); + assert_eq!( + sql, + r#"CREATE_REPLICATION_SLOT "slot" LOGICAL "pgoutput" EXPORT_SNAPSHOT FAILOVER;"# + ); + } + + #[test] + fn create_slot_physical_reserve_wal() { + let opts = ReplicationSlotOptions { + reserve_wal: true, + ..Default::default() + }; + let sql = build_create_slot_sql("phys", SlotType::Physical, None, &opts).unwrap(); + assert_eq!( + sql, + r#"CREATE_REPLICATION_SLOT "phys" PHYSICAL RESERVE_WAL;"# + ); + } + + #[test] + fn create_slot_physical_default() { + let opts = ReplicationSlotOptions::default(); + let sql = build_create_slot_sql("phys", SlotType::Physical, None, &opts).unwrap(); + assert_eq!(sql, r#"CREATE_REPLICATION_SLOT "phys" PHYSICAL;"#); + } + + #[test] + fn create_slot_physical_temporary() { + let opts = ReplicationSlotOptions { + temporary: true, + ..Default::default() + }; + let sql = build_create_slot_sql("phys", SlotType::Physical, None, &opts).unwrap(); + assert_eq!(sql, r#"CREATE_REPLICATION_SLOT "phys" TEMPORARY PHYSICAL;"#); + } + + #[test] + fn create_slot_invalid_snapshot_value() { + let opts = ReplicationSlotOptions { + snapshot: Some("invalid".to_string()), + ..Default::default() + }; + let err = + build_create_slot_sql("slot", SlotType::Logical, Some("pgoutput"), &opts).unwrap_err(); + assert!(err.to_string().contains("Invalid snapshot option")); + } + + #[test] + fn create_slot_logical_missing_plugin() { + let opts = ReplicationSlotOptions::default(); + let err = build_create_slot_sql("slot", SlotType::Logical, None, &opts).unwrap_err(); + assert!(err.to_string().contains("Output plugin required")); + } + + #[test] + fn create_slot_name_injection() { + let opts = ReplicationSlotOptions::default(); + let sql = build_create_slot_sql( + r#"evil"PHYSICAL"#, + SlotType::Logical, + Some("test_decoding"), + &opts, + ) + .unwrap(); + assert_eq!( + sql, + r#"CREATE_REPLICATION_SLOT "evil""PHYSICAL" LOGICAL "test_decoding";"# + ); + } + + #[test] + fn create_slot_plugin_name_injection() { + let opts = ReplicationSlotOptions::default(); + let sql = + build_create_slot_sql("safe_slot", SlotType::Logical, Some(r#"bad"plugin"#), &opts) + .unwrap(); + assert_eq!( + sql, + r#"CREATE_REPLICATION_SLOT "safe_slot" LOGICAL "bad""plugin";"# + ); + } + + // ── build_alter_slot_sql ───────────────────────────────────────────── + + #[test] + fn alter_slot_two_phase() { + let sql = build_alter_slot_sql("my_slot", Some(true), None).unwrap(); + assert_eq!(sql, r#"ALTER_REPLICATION_SLOT "my_slot" (TWO_PHASE true);"#); + } + + #[test] + fn alter_slot_failover() { + let sql = build_alter_slot_sql("my_slot", None, Some(true)).unwrap(); + assert_eq!(sql, r#"ALTER_REPLICATION_SLOT "my_slot" (FAILOVER true);"#); + } + + #[test] + fn alter_slot_both() { + let sql = build_alter_slot_sql("my_slot", Some(false), Some(true)).unwrap(); + assert_eq!( + sql, + r#"ALTER_REPLICATION_SLOT "my_slot" (TWO_PHASE false, FAILOVER true);"# + ); + } + + #[test] + fn alter_slot_no_options_error() { + let err = build_alter_slot_sql("my_slot", None, None).unwrap_err(); + assert!(err.to_string().contains("At least one option")); + } + + #[test] + fn alter_slot_injection() { + let sql = build_alter_slot_sql(r#"evil"slot"#, Some(true), None).unwrap(); + assert!(sql.contains(r#""evil""slot""#)); + } + + // ── build_drop_slot_sql ────────────────────────────────────────────── + + #[test] + fn drop_slot_without_wait() { + assert_eq!( + build_drop_slot_sql("my_slot", false), + r#"DROP_REPLICATION_SLOT "my_slot";"# + ); + } + + #[test] + fn drop_slot_with_wait() { + assert_eq!( + build_drop_slot_sql("my_slot", true), + r#"DROP_REPLICATION_SLOT "my_slot" WAIT;"# + ); + } + + #[test] + fn drop_slot_injection() { + let sql = build_drop_slot_sql(r#"evil"slot"#, false); + assert_eq!(sql, r#"DROP_REPLICATION_SLOT "evil""slot";"#); + } + + // ── build_read_slot_sql ────────────────────────────────────────────── + + #[test] + fn read_slot_basic() { + assert_eq!( + build_read_slot_sql("my_slot"), + r#"READ_REPLICATION_SLOT "my_slot";"# + ); + } + + #[test] + fn read_slot_injection() { + assert_eq!( + build_read_slot_sql(r#"evil"slot"#), + r#"READ_REPLICATION_SLOT "evil""slot";"# + ); + } + + // ── build_start_replication_sql ────────────────────────────────────── + + #[test] + fn start_replication_zero_lsn() { + let sql = build_start_replication_sql( + "my_slot", + 0, + &[("proto_version", "1"), ("publication_names", "my_pub")], + ); + assert_eq!( + sql, + r#"START_REPLICATION SLOT "my_slot" LOGICAL 0/0 ("proto_version" '1', "publication_names" 'my_pub')"# + ); + } + + #[test] + fn start_replication_valid_lsn() { + let lsn: XLogRecPtr = 0x0000_0001_0000_0000; + let sql = build_start_replication_sql("test_slot", lsn, &[("proto_version", "2")]); + assert!(sql.contains("START_REPLICATION SLOT \"test_slot\" LOGICAL")); + assert!(sql.contains("(\"proto_version\" '2')")); + assert!(!sql.contains("0/0")); + } + + #[test] + fn start_replication_multiple_options() { + let sql = build_start_replication_sql( + "slot1", + 0, + &[ + ("proto_version", "1"), + ("publication_names", "pub1"), + ("messages", "true"), + ], + ); + assert!( + sql.contains(r#""proto_version" '1', "publication_names" 'pub1', "messages" 'true'"#) + ); + } + + #[test] + fn start_replication_empty_options() { + let sql = build_start_replication_sql("slot1", 0, &[]); + assert_eq!(sql, r#"START_REPLICATION SLOT "slot1" LOGICAL 0/0"#); + } + + #[test] + fn start_replication_option_injection() { + let sql = build_start_replication_sql(r#"evil"slot"#, 0, &[("key", "it's")]); + assert!(sql.contains(r#""evil""slot""#)); + assert!(sql.contains("'it''s'")); + } + + #[test] + fn start_replication_single_option() { + let sql = build_start_replication_sql("my_slot", 0, &[("proto_version", "1")]); + assert_eq!( + sql, + r#"START_REPLICATION SLOT "my_slot" LOGICAL 0/0 ("proto_version" '1')"# + ); + } + + // ── build_start_physical_replication_sql ────────────────────────────── + + #[test] + fn start_physical_with_slot_zero_lsn() { + let sql = build_start_physical_replication_sql(Some("my_slot"), 0, None); + assert_eq!(sql, r#"START_REPLICATION SLOT "my_slot" PHYSICAL 0/0"#); + } + + #[test] + fn start_physical_no_slot() { + let sql = build_start_physical_replication_sql(None, 0, None); + assert_eq!(sql, "START_REPLICATION PHYSICAL 0/0"); + } + + #[test] + fn start_physical_with_lsn() { + let lsn: XLogRecPtr = 0x0000_0001_0000_0000; + let sql = build_start_physical_replication_sql(Some("slot"), lsn, None); + assert!(sql.contains("PHYSICAL 1/0")); + assert!(!sql.contains("0/0")); + } + + #[test] + fn start_physical_with_timeline() { + let sql = build_start_physical_replication_sql(Some("slot"), 0, Some(3)); + assert_eq!( + sql, + r#"START_REPLICATION SLOT "slot" PHYSICAL 0/0 TIMELINE 3"# + ); + } + + #[test] + fn start_physical_no_slot_with_timeline() { + let sql = build_start_physical_replication_sql(None, 0, Some(5)); + assert_eq!(sql, "START_REPLICATION PHYSICAL 0/0 TIMELINE 5"); + } + + #[test] + fn start_physical_slot_injection() { + let sql = build_start_physical_replication_sql(Some(r#"evil"slot"#), 0, None); + assert!(sql.contains(r#"SLOT "evil""slot""#)); + } + + // ── build_base_backup_sql ──────────────────────────────────────────── + + #[test] + fn base_backup_default() { + let opts = BaseBackupOptions::default(); + assert_eq!(build_base_backup_sql(&opts), "BASE_BACKUP"); + } + + #[test] + fn base_backup_with_label() { + let opts = BaseBackupOptions { + label: Some("my_backup".to_string()), + ..Default::default() + }; + assert_eq!( + build_base_backup_sql(&opts), + "BASE_BACKUP (LABEL 'my_backup')" + ); + } + + #[test] + fn base_backup_with_target() { + let opts = BaseBackupOptions { + target: Some("client".to_string()), + ..Default::default() + }; + assert_eq!( + build_base_backup_sql(&opts), + "BASE_BACKUP (TARGET 'client')" + ); + } + + #[test] + fn base_backup_with_target_detail() { + let opts = BaseBackupOptions { + target: Some("server".to_string()), + target_detail: Some("/var/backups".to_string()), + ..Default::default() + }; + assert_eq!( + build_base_backup_sql(&opts), + "BASE_BACKUP (TARGET 'server', TARGET_DETAIL '/var/backups')" + ); + } + + #[test] + fn base_backup_with_progress() { + let opts = BaseBackupOptions { + progress: true, + ..Default::default() + }; + assert_eq!(build_base_backup_sql(&opts), "BASE_BACKUP (PROGRESS true)"); + } + + #[test] + fn base_backup_with_checkpoint() { + let opts = BaseBackupOptions { + checkpoint: Some("fast".to_string()), + ..Default::default() + }; + assert_eq!( + build_base_backup_sql(&opts), + "BASE_BACKUP (CHECKPOINT 'fast')" + ); + } + + #[test] + fn base_backup_with_wal() { + let opts = BaseBackupOptions { + wal: true, + ..Default::default() + }; + assert_eq!(build_base_backup_sql(&opts), "BASE_BACKUP (WAL true)"); + } + + #[test] + fn base_backup_with_wait() { + let opts = BaseBackupOptions { + wait: true, + ..Default::default() + }; + assert_eq!(build_base_backup_sql(&opts), "BASE_BACKUP (WAIT true)"); + } + + #[test] + fn base_backup_with_compression() { + let opts = BaseBackupOptions { + compression: Some("gzip".to_string()), + ..Default::default() + }; + assert_eq!( + build_base_backup_sql(&opts), + "BASE_BACKUP (COMPRESSION 'gzip')" + ); + } + + #[test] + fn base_backup_with_compression_detail() { + let opts = BaseBackupOptions { + compression: Some("zstd".to_string()), + compression_detail: Some("level=3".to_string()), + ..Default::default() + }; + assert_eq!( + build_base_backup_sql(&opts), + "BASE_BACKUP (COMPRESSION 'zstd', COMPRESSION_DETAIL 'level=3')" + ); + } + + #[test] + fn base_backup_with_max_rate() { + let opts = BaseBackupOptions { + max_rate: Some(1024), + ..Default::default() + }; + assert_eq!(build_base_backup_sql(&opts), "BASE_BACKUP (MAX_RATE 1024)"); + } + + #[test] + fn base_backup_with_tablespace_map() { + let opts = BaseBackupOptions { + tablespace_map: true, + ..Default::default() + }; + assert_eq!( + build_base_backup_sql(&opts), + "BASE_BACKUP (TABLESPACE_MAP true)" + ); + } + + #[test] + fn base_backup_with_verify_checksums() { + let opts = BaseBackupOptions { + verify_checksums: true, + ..Default::default() + }; + assert_eq!( + build_base_backup_sql(&opts), + "BASE_BACKUP (VERIFY_CHECKSUMS true)" + ); + } + + #[test] + fn base_backup_with_manifest() { + let opts = BaseBackupOptions { + manifest: Some("yes".to_string()), + ..Default::default() + }; + assert_eq!(build_base_backup_sql(&opts), "BASE_BACKUP (MANIFEST 'yes')"); + } + + #[test] + fn base_backup_with_manifest_checksums() { + let opts = BaseBackupOptions { + manifest: Some("yes".to_string()), + manifest_checksums: Some("SHA256".to_string()), + ..Default::default() + }; + assert_eq!( + build_base_backup_sql(&opts), + "BASE_BACKUP (MANIFEST 'yes', MANIFEST_CHECKSUMS 'SHA256')" + ); + } + + #[test] + fn base_backup_incremental() { + let opts = BaseBackupOptions { + incremental: true, + ..Default::default() + }; + assert_eq!(build_base_backup_sql(&opts), "BASE_BACKUP (INCREMENTAL)"); + } + + #[test] + fn base_backup_multiple_options() { + let opts = BaseBackupOptions { + label: Some("backup".to_string()), + progress: true, + wal: true, + verify_checksums: true, + ..Default::default() + }; + assert_eq!( + build_base_backup_sql(&opts), + "BASE_BACKUP (LABEL 'backup', PROGRESS true, WAL true, VERIFY_CHECKSUMS true)" + ); + } + + #[test] + fn base_backup_label_injection() { + let opts = BaseBackupOptions { + label: Some("evil'; DROP TABLE users; --".to_string()), + ..Default::default() + }; + assert_eq!( + build_base_backup_sql(&opts), + "BASE_BACKUP (LABEL 'evil''; DROP TABLE users; --')" + ); + } + + // ── build_sql_options ──────────────────────────────────────────────── + + #[test] + fn sql_options_empty() { + assert_eq!(build_sql_options(&[]), ""); + } + + #[test] + fn sql_options_single() { + let opts = vec!["proto_version '2'".to_string()]; + assert_eq!(build_sql_options(&opts), " (proto_version '2')"); + } + + #[test] + fn sql_options_multiple() { + let opts = vec![ + "proto_version '2'".to_string(), + "publication_names '\"my_pub\"'".to_string(), + "streaming 'on'".to_string(), + ]; + assert_eq!( + build_sql_options(&opts), + " (proto_version '2', publication_names '\"my_pub\"', streaming 'on')" + ); + } + + // ── build_create_subscription_sql ──────────────────────────────────── + + #[test] + fn create_subscription_basic() { + let opts = CreateSubscriptionOptions { + subscription_name: "my_sub", + connection_string: "host=localhost dbname=source", + publication: "my_pub", + slot_name: "my_slot", + ..Default::default() + }; + let sql = build_create_subscription_sql(&opts); + assert_eq!( + sql, + "CREATE SUBSCRIPTION \"my_sub\" \ + CONNECTION 'host=localhost dbname=source' \ + PUBLICATION \"my_pub\" \ + WITH (create_slot = false, slot_name = 'my_slot', enabled = true, copy_data = false)" + ); + } + + #[test] + fn create_subscription_with_special_chars() { + let opts = CreateSubscriptionOptions { + subscription_name: r#"sub"name"#, + connection_string: "host=db password='secret'", + publication: "pub", + slot_name: "slot'name", + ..Default::default() + }; + let sql = build_create_subscription_sql(&opts); + assert_eq!( + sql, + "CREATE SUBSCRIPTION \"sub\"\"name\" \ + CONNECTION 'host=db password=''secret''' \ + PUBLICATION \"pub\" \ + WITH (create_slot = false, slot_name = 'slot''name', enabled = true, copy_data = false)" + ); + } + + #[test] + fn create_subscription_empty_names() { + let opts = CreateSubscriptionOptions { + subscription_name: "", + connection_string: "", + publication: "", + slot_name: "", + ..Default::default() + }; + let sql = build_create_subscription_sql(&opts); + assert_eq!( + sql, + "CREATE SUBSCRIPTION \"\" \ + CONNECTION '' \ + PUBLICATION \"\" \ + WITH (create_slot = false, slot_name = '', enabled = true, copy_data = false)" + ); + } + + #[test] + fn create_subscription_create_slot_true() { + let opts = CreateSubscriptionOptions { + subscription_name: "sub", + connection_string: "host=localhost", + publication: "pub", + slot_name: "slot", + create_slot: true, + ..Default::default() + }; + let sql = build_create_subscription_sql(&opts); + assert!(sql.contains("create_slot = true")); + assert!(sql.contains("enabled = true")); + assert!(sql.contains("copy_data = false")); + } + + #[test] + fn create_subscription_copy_data_true() { + let opts = CreateSubscriptionOptions { + subscription_name: "sub", + connection_string: "host=localhost", + publication: "pub", + slot_name: "slot", + copy_data: true, + ..Default::default() + }; + let sql = build_create_subscription_sql(&opts); + assert!(sql.contains("copy_data = true")); + } + + #[test] + fn create_subscription_disabled() { + let opts = CreateSubscriptionOptions { + subscription_name: "sub", + connection_string: "host=localhost", + publication: "pub", + slot_name: "slot", + enabled: false, + ..Default::default() + }; + let sql = build_create_subscription_sql(&opts); + assert!(sql.contains("enabled = false")); + } + + #[test] + fn create_subscription_all_true() { + let opts = CreateSubscriptionOptions { + subscription_name: "sub", + connection_string: "host=localhost", + publication: "pub", + slot_name: "slot", + create_slot: true, + enabled: true, + copy_data: true, + }; + let sql = build_create_subscription_sql(&opts); + assert_eq!( + sql, + "CREATE SUBSCRIPTION \"sub\" \ + CONNECTION 'host=localhost' \ + PUBLICATION \"pub\" \ + WITH (create_slot = true, slot_name = 'slot', enabled = true, copy_data = true)" + ); + } + + #[test] + fn create_subscription_all_false() { + let opts = CreateSubscriptionOptions { + subscription_name: "sub", + connection_string: "host=localhost", + publication: "pub", + slot_name: "slot", + create_slot: false, + enabled: false, + copy_data: false, + }; + let sql = build_create_subscription_sql(&opts); + assert_eq!( + sql, + "CREATE SUBSCRIPTION \"sub\" \ + CONNECTION 'host=localhost' \ + PUBLICATION \"pub\" \ + WITH (create_slot = false, slot_name = 'slot', enabled = false, copy_data = false)" + ); + } + + // ── build_disable_subscription_sql ─────────────────────────────────── + + #[test] + fn disable_subscription_basic() { + assert_eq!( + build_disable_subscription_sql("my_sub"), + r#"ALTER SUBSCRIPTION "my_sub" DISABLE"# + ); + } + + #[test] + fn disable_subscription_with_quotes() { + assert_eq!( + build_disable_subscription_sql(r#"sub"name"#), + r#"ALTER SUBSCRIPTION "sub""name" DISABLE"# + ); + } + + // ── build_detach_slot_sql ──────────────────────────────────────────── + + #[test] + fn detach_slot_basic() { + assert_eq!( + build_detach_slot_sql("my_sub"), + r#"ALTER SUBSCRIPTION "my_sub" SET (slot_name = NONE)"# + ); + } + + #[test] + fn detach_slot_with_quotes() { + assert_eq!( + build_detach_slot_sql(r#"sub"x"#), + r#"ALTER SUBSCRIPTION "sub""x" SET (slot_name = NONE)"# + ); + } + + // ── build_drop_subscription_sql ────────────────────────────────────── + + #[test] + fn drop_subscription_basic() { + assert_eq!( + build_drop_subscription_sql("my_sub"), + r#"DROP SUBSCRIPTION "my_sub""# + ); + } + + #[test] + fn drop_subscription_with_quotes() { + assert_eq!( + build_drop_subscription_sql(r#"sub"name"#), + r#"DROP SUBSCRIPTION "sub""name""# + ); + } + + #[test] + fn drop_subscription_injection_attempt() { + assert_eq!( + build_drop_subscription_sql("evil\"; DROP TABLE users; --"), + "DROP SUBSCRIPTION \"evil\"\"; DROP TABLE users; --\"" + ); + } + + // ── Compatibility with existing behavior ───────────────────────────── + + #[test] + fn quote_ident_matches_legacy_behavior() { + let cases = [ + "my_slot", + r#"a"b"#, + r#"a""b"#, + "", + "slot; DROP TABLE users; --", + r#"evil"PHYSICAL"#, + ]; + for input in cases { + let legacy = format!("\"{}\"", input.replace('"', "\"\"")); + assert_eq!(quote_ident(input), legacy, "mismatch for input: {input:?}"); + } + } + + #[test] + fn quote_literal_matches_legacy_behavior() { + let cases = [ + "test_value", + "test'value", + "test'value'with'quotes", + "'; DROP TABLE users; --", + "", + "'", + "''", + ]; + for input in cases { + let legacy = format!("'{}'", input.replace('\'', "''")); + assert_eq!( + quote_literal(input), + legacy, + "mismatch for input: {input:?}" + ); + } + } +} From 196c4ae47d0d32725942077ba4b7fd4e8bfd2107 Mon Sep 17 00:00:00 2001 From: danielshih Date: Mon, 4 May 2026 09:00:02 +0000 Subject: [PATCH 2/4] refactor: simplify SQL command construction in sql_builder.rs style: format SQL statement in build_detach_slot_sql for improved readability --- src/sql_builder.rs | 76 ++++++++++------------------------------------ 1 file changed, 16 insertions(+), 60 deletions(-) diff --git a/src/sql_builder.rs b/src/sql_builder.rs index c925051..a9d18a0 100644 --- a/src/sql_builder.rs +++ b/src/sql_builder.rs @@ -257,17 +257,11 @@ pub fn build_start_replication_sql( return format!("START_REPLICATION SLOT {quoted_slot} LOGICAL {lsn_str}"); } - let mut options_str = String::new(); - for (i, (key, value)) in options.iter().enumerate() { - if i > 0 { - options_str.push_str(", "); - } - let quoted_key = quote_ident(key); - let quoted_value = quote_literal(value); - options_str.push_str("ed_key); - options_str.push(' '); - options_str.push_str("ed_value); - } + let options_str = options + .iter() + .map(|(k, v)| format!("{} {}", quote_ident(k), quote_literal(v))) + .collect::>() + .join(", "); format!("START_REPLICATION SLOT {quoted_slot} LOGICAL {lsn_str} ({options_str})") } @@ -473,35 +467,11 @@ pub fn build_create_subscription_sql(opts: &CreateSubscriptionOptions<'_>) -> St let enabled_str = if opts.enabled { "true" } else { "false" }; let copy_data_str = if opts.copy_data { "true" } else { "false" }; - let estimated = "CREATE SUBSCRIPTION ".len() - + sub.len() - + " CONNECTION ".len() - + conn.len() - + " PUBLICATION ".len() - + pubname.len() - + " WITH (create_slot = , slot_name = , enabled = , copy_data = )".len() - + slot.len() - + create_slot_str.len() - + enabled_str.len() - + copy_data_str.len(); - - let mut sql = String::with_capacity(estimated); - sql.push_str("CREATE SUBSCRIPTION "); - sql.push_str(&sub); - sql.push_str(" CONNECTION "); - sql.push_str(&conn); - sql.push_str(" PUBLICATION "); - sql.push_str(&pubname); - sql.push_str(" WITH (create_slot = "); - sql.push_str(create_slot_str); - sql.push_str(", slot_name = "); - sql.push_str(&slot); - sql.push_str(", enabled = "); - sql.push_str(enabled_str); - sql.push_str(", copy_data = "); - sql.push_str(copy_data_str); - sql.push(')'); - sql + format!( + "CREATE SUBSCRIPTION {sub} CONNECTION {conn} PUBLICATION {pubname} \ + WITH (create_slot = {create_slot_str}, slot_name = {slot}, \ + enabled = {enabled_str}, copy_data = {copy_data_str})" + ) } /// Build an `ALTER SUBSCRIPTION ... DISABLE` statement. @@ -516,13 +486,7 @@ pub fn build_create_subscription_sql(opts: &CreateSubscriptionOptions<'_>) -> St /// ``` #[inline] pub fn build_disable_subscription_sql(name: &str) -> String { - let quoted = quote_ident(name); - let mut sql = - String::with_capacity("ALTER SUBSCRIPTION ".len() + quoted.len() + " DISABLE".len()); - sql.push_str("ALTER SUBSCRIPTION "); - sql.push_str("ed); - sql.push_str(" DISABLE"); - sql + format!("ALTER SUBSCRIPTION {} DISABLE", quote_ident(name)) } /// Build an `ALTER SUBSCRIPTION ... SET (slot_name = NONE)` statement to detach a slot. @@ -537,14 +501,10 @@ pub fn build_disable_subscription_sql(name: &str) -> String { /// ``` #[inline] pub fn build_detach_slot_sql(name: &str) -> String { - let quoted = quote_ident(name); - let mut sql = String::with_capacity( - "ALTER SUBSCRIPTION ".len() + quoted.len() + " SET (slot_name = NONE)".len(), - ); - sql.push_str("ALTER SUBSCRIPTION "); - sql.push_str("ed); - sql.push_str(" SET (slot_name = NONE)"); - sql + format!( + "ALTER SUBSCRIPTION {} SET (slot_name = NONE)", + quote_ident(name) + ) } /// Build a `DROP SUBSCRIPTION` statement. @@ -559,11 +519,7 @@ pub fn build_detach_slot_sql(name: &str) -> String { /// ``` #[inline] pub fn build_drop_subscription_sql(name: &str) -> String { - let quoted = quote_ident(name); - let mut sql = String::with_capacity("DROP SUBSCRIPTION ".len() + quoted.len()); - sql.push_str("DROP SUBSCRIPTION "); - sql.push_str("ed); - sql + format!("DROP SUBSCRIPTION {}", quote_ident(name)) } // ═══════════════════════════════════════════════════════════════════════════ From f6ae52fcb412cc51caff7ab652ca3066ec1827ec Mon Sep 17 00:00:00 2001 From: danielshih Date: Mon, 4 May 2026 09:48:27 +0000 Subject: [PATCH 3/4] Refactor SQL builder functions to return Result types for error handling - Updated SQL builder functions to return Result instead of String, allowing for better error propagation. - Modified the NativeConnection methods to handle the Result type from SQL builder functions. - Adjusted tests to unwrap results from SQL builder functions, ensuring they handle potential errors. - Added tests to verify that null byte inputs are correctly rejected across various SQL builder functions. --- src/connection/libpq.rs | 110 +++--- src/connection/native/connection.rs | 32 +- src/sql_builder.rs | 499 +++++++++++++++++++--------- 3 files changed, 425 insertions(+), 216 deletions(-) diff --git a/src/connection/libpq.rs b/src/connection/libpq.rs index a6ebbb0..c53a6f0 100644 --- a/src/connection/libpq.rs +++ b/src/connection/libpq.rs @@ -302,7 +302,7 @@ impl PgReplicationConnection { slot_name: &str, start_lsn: XLogRecPtr, options: &[(&str, &str)], - ) -> String { + ) -> Result { crate::sql_builder::build_start_replication_sql(slot_name, start_lsn, options) } @@ -313,7 +313,7 @@ impl PgReplicationConnection { start_lsn: XLogRecPtr, options: &[(&str, &str)], ) -> Result<()> { - let sql = Self::build_start_replication_sql(slot_name, start_lsn, options); + let sql = Self::build_start_replication_sql(slot_name, start_lsn, options)?; debug!("Starting replication: {}", sql); let _result = self.exec(&sql)?; @@ -654,7 +654,7 @@ impl PgReplicationConnection { } /// Build the SQL string for `DROP_REPLICATION_SLOT`. - fn build_drop_slot_sql(slot_name: &str, wait: bool) -> String { + fn build_drop_slot_sql(slot_name: &str, wait: bool) -> Result { crate::sql_builder::build_drop_slot_sql(slot_name, wait) } @@ -668,7 +668,7 @@ impl PgReplicationConnection { /// * `wait` - If true, the command waits until the slot becomes inactive /// instead of returning an error when the slot is in use pub fn drop_replication_slot(&mut self, slot_name: &str, wait: bool) -> Result<()> { - let sql = Self::build_drop_slot_sql(slot_name, wait); + let sql = Self::build_drop_slot_sql(slot_name, wait)?; debug!("Dropping replication slot: {}", sql); let result = self.exec(&sql)?; @@ -686,7 +686,7 @@ impl PgReplicationConnection { } /// Build the SQL string for `READ_REPLICATION_SLOT`. - fn build_read_slot_sql(slot_name: &str) -> String { + fn build_read_slot_sql(slot_name: &str) -> Result { crate::sql_builder::build_read_slot_sql(slot_name) } @@ -700,7 +700,7 @@ impl PgReplicationConnection { &mut self, slot_name: &str, ) -> Result { - let sql = Self::build_read_slot_sql(slot_name); + let sql = Self::build_read_slot_sql(slot_name)?; debug!("Reading replication slot: {}", sql); let result = self.exec(&sql)?; @@ -733,7 +733,7 @@ impl PgReplicationConnection { slot_name: Option<&str>, start_lsn: XLogRecPtr, timeline_id: Option, - ) -> String { + ) -> Result { crate::sql_builder::build_start_physical_replication_sql(slot_name, start_lsn, timeline_id) } @@ -744,7 +744,7 @@ impl PgReplicationConnection { start_lsn: XLogRecPtr, timeline_id: Option, ) -> Result<()> { - let sql = Self::build_start_physical_replication_sql(slot_name, start_lsn, timeline_id); + let sql = Self::build_start_physical_replication_sql(slot_name, start_lsn, timeline_id)?; debug!("Starting physical replication: {}", sql); let _result = self.exec(&sql)?; @@ -779,13 +779,13 @@ impl PgReplicationConnection { } /// Build the SQL string for `BASE_BACKUP`. - fn build_base_backup_sql(options: &BaseBackupOptions) -> String { + fn build_base_backup_sql(options: &BaseBackupOptions) -> Result { crate::sql_builder::build_base_backup_sql(options) } /// Start a base backup with options pub fn base_backup(&mut self, options: &BaseBackupOptions) -> Result { - let base_backup_sql = Self::build_base_backup_sql(options); + let base_backup_sql = Self::build_base_backup_sql(options)?; debug!("Starting base backup: {}", base_backup_sql); let result = self.exec(&base_backup_sql)?; @@ -1035,16 +1035,16 @@ mod tests { use crate::INVALID_XLOG_REC_PTR; fn sanitize_sql_string_value(value: &str) -> String { - let quoted = quote_literal(value); + let quoted = quote_literal(value).unwrap(); quoted[1..quoted.len() - 1].to_owned() } fn quote_sql_string_value(value: &str) -> String { - quote_literal(value) + quote_literal(value).unwrap() } fn quote_sql_identifier(identifier: &str) -> String { - quote_ident(identifier) + quote_ident(identifier).unwrap() } #[test] @@ -1651,25 +1651,25 @@ mod tests { #[test] fn test_build_drop_slot_sql_without_wait() { - let sql = PgReplicationConnection::build_drop_slot_sql("my_slot", false); + let sql = PgReplicationConnection::build_drop_slot_sql("my_slot", false).unwrap(); assert_eq!(sql, r#"DROP_REPLICATION_SLOT "my_slot";"#); } #[test] fn test_build_drop_slot_sql_with_wait() { - let sql = PgReplicationConnection::build_drop_slot_sql("my_slot", true); + let sql = PgReplicationConnection::build_drop_slot_sql("my_slot", true).unwrap(); assert_eq!(sql, r#"DROP_REPLICATION_SLOT "my_slot" WAIT;"#); } #[test] fn test_build_drop_slot_sql_injection() { - let sql = PgReplicationConnection::build_drop_slot_sql(r#"evil"slot"#, false); + let sql = PgReplicationConnection::build_drop_slot_sql(r#"evil"slot"#, false).unwrap(); assert_eq!(sql, r#"DROP_REPLICATION_SLOT "evil""slot";"#); } #[test] fn test_build_drop_slot_sql_injection_with_wait() { - let sql = PgReplicationConnection::build_drop_slot_sql(r#"evil"slot"#, true); + let sql = PgReplicationConnection::build_drop_slot_sql(r#"evil"slot"#, true).unwrap(); assert_eq!(sql, r#"DROP_REPLICATION_SLOT "evil""slot" WAIT;"#); } @@ -1679,13 +1679,13 @@ mod tests { #[test] fn test_build_read_slot_sql_basic() { - let sql = PgReplicationConnection::build_read_slot_sql("my_slot"); + let sql = PgReplicationConnection::build_read_slot_sql("my_slot").unwrap(); assert_eq!(sql, r#"READ_REPLICATION_SLOT "my_slot";"#); } #[test] fn test_build_read_slot_sql_injection() { - let sql = PgReplicationConnection::build_read_slot_sql(r#"evil"slot"#); + let sql = PgReplicationConnection::build_read_slot_sql(r#"evil"slot"#).unwrap(); assert_eq!(sql, r#"READ_REPLICATION_SLOT "evil""slot";"#); } @@ -1785,7 +1785,7 @@ mod tests { #[test] fn test_base_backup_sql_default_options() { let opts = BaseBackupOptions::default(); - let sql = PgReplicationConnection::build_base_backup_sql(&opts); + let sql = PgReplicationConnection::build_base_backup_sql(&opts).unwrap(); assert_eq!(sql, "BASE_BACKUP"); } @@ -1795,7 +1795,7 @@ mod tests { label: Some("my_backup".to_string()), ..Default::default() }; - let sql = PgReplicationConnection::build_base_backup_sql(&opts); + let sql = PgReplicationConnection::build_base_backup_sql(&opts).unwrap(); assert_eq!(sql, "BASE_BACKUP (LABEL 'my_backup')"); } @@ -1805,7 +1805,7 @@ mod tests { target: Some("client".to_string()), ..Default::default() }; - let sql = PgReplicationConnection::build_base_backup_sql(&opts); + let sql = PgReplicationConnection::build_base_backup_sql(&opts).unwrap(); assert_eq!(sql, "BASE_BACKUP (TARGET 'client')"); } @@ -1816,7 +1816,7 @@ mod tests { target_detail: Some("/var/backups".to_string()), ..Default::default() }; - let sql = PgReplicationConnection::build_base_backup_sql(&opts); + let sql = PgReplicationConnection::build_base_backup_sql(&opts).unwrap(); assert_eq!( sql, "BASE_BACKUP (TARGET 'server', TARGET_DETAIL '/var/backups')" @@ -1829,7 +1829,7 @@ mod tests { progress: true, ..Default::default() }; - let sql = PgReplicationConnection::build_base_backup_sql(&opts); + let sql = PgReplicationConnection::build_base_backup_sql(&opts).unwrap(); assert_eq!(sql, "BASE_BACKUP (PROGRESS true)"); } @@ -1839,7 +1839,7 @@ mod tests { checkpoint: Some("fast".to_string()), ..Default::default() }; - let sql = PgReplicationConnection::build_base_backup_sql(&opts); + let sql = PgReplicationConnection::build_base_backup_sql(&opts).unwrap(); assert_eq!(sql, "BASE_BACKUP (CHECKPOINT 'fast')"); } @@ -1849,7 +1849,7 @@ mod tests { wal: true, ..Default::default() }; - let sql = PgReplicationConnection::build_base_backup_sql(&opts); + let sql = PgReplicationConnection::build_base_backup_sql(&opts).unwrap(); assert_eq!(sql, "BASE_BACKUP (WAL true)"); } @@ -1859,7 +1859,7 @@ mod tests { wait: true, ..Default::default() }; - let sql = PgReplicationConnection::build_base_backup_sql(&opts); + let sql = PgReplicationConnection::build_base_backup_sql(&opts).unwrap(); assert_eq!(sql, "BASE_BACKUP (WAIT true)"); } @@ -1869,7 +1869,7 @@ mod tests { compression: Some("gzip".to_string()), ..Default::default() }; - let sql = PgReplicationConnection::build_base_backup_sql(&opts); + let sql = PgReplicationConnection::build_base_backup_sql(&opts).unwrap(); assert_eq!(sql, "BASE_BACKUP (COMPRESSION 'gzip')"); } @@ -1880,7 +1880,7 @@ mod tests { compression_detail: Some("level=3".to_string()), ..Default::default() }; - let sql = PgReplicationConnection::build_base_backup_sql(&opts); + let sql = PgReplicationConnection::build_base_backup_sql(&opts).unwrap(); assert_eq!( sql, "BASE_BACKUP (COMPRESSION 'zstd', COMPRESSION_DETAIL 'level=3')" @@ -1893,7 +1893,7 @@ mod tests { max_rate: Some(32768), ..Default::default() }; - let sql = PgReplicationConnection::build_base_backup_sql(&opts); + let sql = PgReplicationConnection::build_base_backup_sql(&opts).unwrap(); assert_eq!(sql, "BASE_BACKUP (MAX_RATE 32768)"); } @@ -1903,7 +1903,7 @@ mod tests { tablespace_map: true, ..Default::default() }; - let sql = PgReplicationConnection::build_base_backup_sql(&opts); + let sql = PgReplicationConnection::build_base_backup_sql(&opts).unwrap(); assert_eq!(sql, "BASE_BACKUP (TABLESPACE_MAP true)"); } @@ -1913,7 +1913,7 @@ mod tests { verify_checksums: true, ..Default::default() }; - let sql = PgReplicationConnection::build_base_backup_sql(&opts); + let sql = PgReplicationConnection::build_base_backup_sql(&opts).unwrap(); assert_eq!(sql, "BASE_BACKUP (VERIFY_CHECKSUMS true)"); } @@ -1923,7 +1923,7 @@ mod tests { manifest: Some("yes".to_string()), ..Default::default() }; - let sql = PgReplicationConnection::build_base_backup_sql(&opts); + let sql = PgReplicationConnection::build_base_backup_sql(&opts).unwrap(); assert_eq!(sql, "BASE_BACKUP (MANIFEST 'yes')"); } @@ -1934,7 +1934,7 @@ mod tests { manifest_checksums: Some("SHA256".to_string()), ..Default::default() }; - let sql = PgReplicationConnection::build_base_backup_sql(&opts); + let sql = PgReplicationConnection::build_base_backup_sql(&opts).unwrap(); assert_eq!( sql, "BASE_BACKUP (MANIFEST 'yes', MANIFEST_CHECKSUMS 'SHA256')" @@ -1947,7 +1947,7 @@ mod tests { incremental: true, ..Default::default() }; - let sql = PgReplicationConnection::build_base_backup_sql(&opts); + let sql = PgReplicationConnection::build_base_backup_sql(&opts).unwrap(); assert_eq!(sql, "BASE_BACKUP (INCREMENTAL)"); } @@ -1961,7 +1961,7 @@ mod tests { verify_checksums: true, ..Default::default() }; - let sql = PgReplicationConnection::build_base_backup_sql(&opts); + let sql = PgReplicationConnection::build_base_backup_sql(&opts).unwrap(); assert_eq!( sql, "BASE_BACKUP (LABEL 'full_backup', PROGRESS true, CHECKPOINT 'fast', WAL true, VERIFY_CHECKSUMS true)" @@ -1974,7 +1974,7 @@ mod tests { label: Some("evil'label".to_string()), ..Default::default() }; - let sql = PgReplicationConnection::build_base_backup_sql(&opts); + let sql = PgReplicationConnection::build_base_backup_sql(&opts).unwrap(); assert_eq!(sql, "BASE_BACKUP (LABEL 'evil''label')"); } @@ -1988,7 +1988,8 @@ mod tests { "my_slot", INVALID_XLOG_REC_PTR, &[("proto_version", "1"), ("publication_names", "my_pub")], - ); + ) + .unwrap(); assert_eq!( sql, r#"START_REPLICATION SLOT "my_slot" LOGICAL 0/0 ("proto_version" '1', "publication_names" 'my_pub')"# @@ -2002,7 +2003,8 @@ mod tests { "test_slot", lsn, &[("proto_version", "2")], - ); + ) + .unwrap(); assert!(sql.contains("START_REPLICATION SLOT \"test_slot\" LOGICAL")); assert!(sql.contains("(\"proto_version\" '2')")); // Should NOT contain "0/0" since we provided a valid LSN @@ -2019,7 +2021,8 @@ mod tests { ("publication_names", "pub1"), ("messages", "true"), ], - ); + ) + .unwrap(); assert!( sql.contains(r#""proto_version" '1', "publication_names" 'pub1', "messages" 'true'"#) ); @@ -2031,7 +2034,8 @@ mod tests { "slot1", INVALID_XLOG_REC_PTR, &[], - ); + ) + .unwrap(); assert_eq!(sql, r#"START_REPLICATION SLOT "slot1" LOGICAL 0/0"#); } @@ -2041,7 +2045,8 @@ mod tests { r#"evil"slot"#, INVALID_XLOG_REC_PTR, &[("key", "it's")], - ); + ) + .unwrap(); // Slot name should be quoted, value should be sanitized assert!(sql.contains(r#""evil""slot""#)); assert!(sql.contains("'it''s'")); @@ -2053,7 +2058,8 @@ mod tests { "my_slot", INVALID_XLOG_REC_PTR, &[("proto_version", "1")], - ); + ) + .unwrap(); assert_eq!( sql, r#"START_REPLICATION SLOT "my_slot" LOGICAL 0/0 ("proto_version" '1')"# @@ -2127,7 +2133,8 @@ mod tests { None, INVALID_XLOG_REC_PTR, None, - ); + ) + .unwrap(); assert_eq!(sql, "START_REPLICATION PHYSICAL 0/0"); } @@ -2137,7 +2144,8 @@ mod tests { Some("phys_slot"), INVALID_XLOG_REC_PTR, None, - ); + ) + .unwrap(); assert_eq!(sql, r#"START_REPLICATION SLOT "phys_slot" PHYSICAL 0/0"#); } @@ -2147,14 +2155,16 @@ mod tests { None, INVALID_XLOG_REC_PTR, Some(3), - ); + ) + .unwrap(); assert_eq!(sql, "START_REPLICATION PHYSICAL 0/0 TIMELINE 3"); } #[test] fn test_physical_replication_sql_with_valid_lsn() { let lsn: XLogRecPtr = 0x0000_0001_0000_0000; // 1/0 - let sql = PgReplicationConnection::build_start_physical_replication_sql(None, lsn, None); + let sql = + PgReplicationConnection::build_start_physical_replication_sql(None, lsn, None).unwrap(); assert!(sql.starts_with("START_REPLICATION PHYSICAL ")); assert!(!sql.contains("0/0")); } @@ -2165,7 +2175,8 @@ mod tests { Some("my_slot"), INVALID_XLOG_REC_PTR, Some(2), - ); + ) + .unwrap(); assert_eq!( sql, r#"START_REPLICATION SLOT "my_slot" PHYSICAL 0/0 TIMELINE 2"# @@ -2178,7 +2189,8 @@ mod tests { Some(r#"evil"slot"#), INVALID_XLOG_REC_PTR, None, - ); + ) + .unwrap(); assert!(sql.contains(r#""evil""slot""#)); } } diff --git a/src/connection/native/connection.rs b/src/connection/native/connection.rs index 15a6ace..fa1905d 100644 --- a/src/connection/native/connection.rs +++ b/src/connection/native/connection.rs @@ -155,7 +155,7 @@ impl NativeConnection { start_lsn: XLogRecPtr, options: &[(&str, &str)], ) -> Result<()> { - let sql = crate::sql_builder::build_start_replication_sql(slot_name, start_lsn, options); + let sql = crate::sql_builder::build_start_replication_sql(slot_name, start_lsn, options)?; debug!("Starting replication: {}", sql); @@ -310,13 +310,13 @@ impl NativeConnection { Ok(result) } - fn build_drop_slot_sql(slot_name: &str, wait: bool) -> String { + fn build_drop_slot_sql(slot_name: &str, wait: bool) -> Result { crate::sql_builder::build_drop_slot_sql(slot_name, wait) } /// Drop a replication slot. pub fn drop_replication_slot(&mut self, slot_name: &str, wait: bool) -> Result<()> { - let sql = Self::build_drop_slot_sql(slot_name, wait); + let sql = Self::build_drop_slot_sql(slot_name, wait)?; debug!("Dropping replication slot: {}", sql); let result = self.exec(&sql)?; if !result.is_ok() { @@ -332,7 +332,7 @@ impl NativeConnection { Ok(()) } - fn build_read_slot_sql(slot_name: &str) -> String { + fn build_read_slot_sql(slot_name: &str) -> Result { crate::sql_builder::build_read_slot_sql(slot_name) } @@ -341,7 +341,7 @@ impl NativeConnection { &mut self, slot_name: &str, ) -> Result { - let sql = Self::build_read_slot_sql(slot_name); + let sql = Self::build_read_slot_sql(slot_name)?; debug!("Reading replication slot: {}", sql); let result = self.exec(&sql)?; if !result.is_ok() { @@ -379,7 +379,7 @@ impl NativeConnection { slot_name, start_lsn, timeline_id, - ); + )?; debug!("Starting physical replication: {}", sql); @@ -406,7 +406,7 @@ impl NativeConnection { /// Start a base backup with options. pub fn base_backup(&mut self, options: &BaseBackupOptions) -> Result { - let sql = crate::sql_builder::build_base_backup_sql(options); + let sql = crate::sql_builder::build_base_backup_sql(options)?; debug!("Starting base backup: {}", sql); let result = self.exec(&sql)?; @@ -507,17 +507,17 @@ mod tests { use crate::types::{ReplicationSlotOptions, SlotType}; fn sanitize_sql_string_value(value: &str) -> String { - let quoted = crate::sql_builder::quote_literal(value); + let quoted = crate::sql_builder::quote_literal(value).unwrap(); // Strip surrounding quotes to get just the sanitized interior quoted[1..quoted.len() - 1].to_string() } fn quote_sql_string_value(value: &str) -> String { - crate::sql_builder::quote_literal(value) + crate::sql_builder::quote_literal(value).unwrap() } fn quote_sql_identifier(identifier: &str) -> String { - crate::sql_builder::quote_ident(identifier) + crate::sql_builder::quote_ident(identifier).unwrap() } // === sanitize_sql_string_value === @@ -937,7 +937,7 @@ mod tests { #[test] fn test_build_drop_slot_sql_without_wait() { assert_eq!( - NativeConnection::build_drop_slot_sql("my_slot", false), + NativeConnection::build_drop_slot_sql("my_slot", false).unwrap(), r#"DROP_REPLICATION_SLOT "my_slot";"# ); } @@ -945,7 +945,7 @@ mod tests { #[test] fn test_build_drop_slot_sql_with_wait() { assert_eq!( - NativeConnection::build_drop_slot_sql("my_slot", true), + NativeConnection::build_drop_slot_sql("my_slot", true).unwrap(), r#"DROP_REPLICATION_SLOT "my_slot" WAIT;"# ); } @@ -953,7 +953,7 @@ mod tests { #[test] fn test_build_drop_slot_sql_injection() { assert_eq!( - NativeConnection::build_drop_slot_sql(r#"evil"slot"#, false), + NativeConnection::build_drop_slot_sql(r#"evil"slot"#, false).unwrap(), r#"DROP_REPLICATION_SLOT "evil""slot";"# ); } @@ -961,7 +961,7 @@ mod tests { #[test] fn test_build_drop_slot_sql_injection_with_wait() { assert_eq!( - NativeConnection::build_drop_slot_sql(r#"evil"slot"#, true), + NativeConnection::build_drop_slot_sql(r#"evil"slot"#, true).unwrap(), r#"DROP_REPLICATION_SLOT "evil""slot" WAIT;"# ); } @@ -971,7 +971,7 @@ mod tests { #[test] fn test_build_read_slot_sql_basic() { assert_eq!( - NativeConnection::build_read_slot_sql("my_slot"), + NativeConnection::build_read_slot_sql("my_slot").unwrap(), r#"READ_REPLICATION_SLOT "my_slot";"# ); } @@ -979,7 +979,7 @@ mod tests { #[test] fn test_build_read_slot_sql_injection() { assert_eq!( - NativeConnection::build_read_slot_sql(r#"evil"slot"#), + NativeConnection::build_read_slot_sql(r#"evil"slot"#).unwrap(), r#"READ_REPLICATION_SLOT "evil""slot";"# ); } diff --git a/src/sql_builder.rs b/src/sql_builder.rs index a9d18a0..9f98fb8 100644 --- a/src/sql_builder.rs +++ b/src/sql_builder.rs @@ -14,9 +14,9 @@ const INVALID_XLOG_REC_PTR: u64 = 0; /// Quote a PostgreSQL identifier by wrapping in double quotes and escaping /// internal double quotes (doubling them). /// -/// # Panics +/// # Errors /// -/// Panics if `name` contains a null byte (`\0`), which is invalid in +/// Returns an error if `name` contains a null byte (`\0`), which is invalid in /// PostgreSQL identifiers and could cause truncation-based injection via /// the C-string wire protocol. /// @@ -25,15 +25,17 @@ const INVALID_XLOG_REC_PTR: u64 = 0; /// ``` /// use pg_walstream::sql_builder::quote_ident; /// -/// assert_eq!(quote_ident("my_slot"), r#""my_slot""#); -/// assert_eq!(quote_ident(r#"a"b"#), r#""a""b""#); +/// assert_eq!(quote_ident("my_slot").unwrap(), r#""my_slot""#); +/// assert_eq!(quote_ident(r#"a"b"#).unwrap(), r#""a""b""#); +/// assert!(quote_ident("bad\0name").is_err()); /// ``` #[inline] -pub fn quote_ident(name: &str) -> String { - assert!( - !name.contains('\0'), - "SQL identifier must not contain null bytes" - ); +pub fn quote_ident(name: &str) -> Result { + if name.contains('\0') { + return Err(ReplicationError::config( + "SQL identifier must not contain null bytes".to_string(), + )); + } let mut out = String::with_capacity(name.len() + 2); out.push('"'); for ch in name.chars() { @@ -43,15 +45,15 @@ pub fn quote_ident(name: &str) -> String { out.push(ch); } out.push('"'); - out + Ok(out) } /// Quote a PostgreSQL string literal by wrapping in single quotes and escaping /// internal single quotes (doubling them). /// -/// # Panics +/// # Errors /// -/// Panics if `value` contains a null byte (`\0`), which is invalid in +/// Returns an error if `value` contains a null byte (`\0`), which is invalid in /// PostgreSQL string literals and could cause truncation-based injection via /// the C-string wire protocol. /// @@ -60,15 +62,17 @@ pub fn quote_ident(name: &str) -> String { /// ``` /// use pg_walstream::sql_builder::quote_literal; /// -/// assert_eq!(quote_literal("hello"), "'hello'"); -/// assert_eq!(quote_literal("it's"), "'it''s'"); +/// assert_eq!(quote_literal("hello").unwrap(), "'hello'"); +/// assert_eq!(quote_literal("it's").unwrap(), "'it''s'"); +/// assert!(quote_literal("bad\0value").is_err()); /// ``` #[inline] -pub fn quote_literal(value: &str) -> String { - assert!( - !value.contains('\0'), - "SQL literal must not contain null bytes" - ); +pub fn quote_literal(value: &str) -> Result { + if value.contains('\0') { + return Err(ReplicationError::config( + "SQL literal must not contain null bytes".to_string(), + )); + } let mut out = String::with_capacity(value.len() + 2); out.push('\''); for ch in value.chars() { @@ -78,7 +82,7 @@ pub fn quote_literal(value: &str) -> String { out.push(ch); } out.push('\''); - out + Ok(out) } // ═══════════════════════════════════════════════════════════════════════════ @@ -105,7 +109,7 @@ pub fn build_create_slot_sql( ) -> Result { let mut parts: Vec<&str> = Vec::with_capacity(6); - let quoted_slot = quote_ident(slot_name); + let quoted_slot = quote_ident(slot_name)?; parts.push("CREATE_REPLICATION_SLOT"); parts.push("ed_slot); @@ -128,7 +132,7 @@ pub fn build_create_slot_sql( let plugin = output_plugin.ok_or_else(|| { ReplicationError::protocol("Output plugin required for LOGICAL slots".to_string()) })?; - quoted_plugin = quote_ident(plugin); + quoted_plugin = quote_ident(plugin)?; parts.push("ed_plugin); if options.two_phase { @@ -189,7 +193,7 @@ pub fn build_alter_slot_sql( } let options_str = build_sql_options(&opts); - let quoted_slot = quote_ident(slot_name); + let quoted_slot = quote_ident(slot_name)?; Ok(format!( "ALTER_REPLICATION_SLOT {}{};", quoted_slot, options_str @@ -203,16 +207,16 @@ pub fn build_alter_slot_sql( /// ``` /// use pg_walstream::sql_builder::build_drop_slot_sql; /// -/// assert_eq!(build_drop_slot_sql("my_slot", false), r#"DROP_REPLICATION_SLOT "my_slot";"#); -/// assert_eq!(build_drop_slot_sql("my_slot", true), r#"DROP_REPLICATION_SLOT "my_slot" WAIT;"#); +/// assert_eq!(build_drop_slot_sql("my_slot", false).unwrap(), r#"DROP_REPLICATION_SLOT "my_slot";"#); +/// assert_eq!(build_drop_slot_sql("my_slot", true).unwrap(), r#"DROP_REPLICATION_SLOT "my_slot" WAIT;"#); /// ``` #[inline] -pub fn build_drop_slot_sql(slot_name: &str, wait: bool) -> String { - let quoted_slot = quote_ident(slot_name); +pub fn build_drop_slot_sql(slot_name: &str, wait: bool) -> Result { + let quoted_slot = quote_ident(slot_name)?; if wait { - format!("DROP_REPLICATION_SLOT {} WAIT;", quoted_slot) + Ok(format!("DROP_REPLICATION_SLOT {} WAIT;", quoted_slot)) } else { - format!("DROP_REPLICATION_SLOT {};", quoted_slot) + Ok(format!("DROP_REPLICATION_SLOT {};", quoted_slot)) } } @@ -223,12 +227,12 @@ pub fn build_drop_slot_sql(slot_name: &str, wait: bool) -> String { /// ``` /// use pg_walstream::sql_builder::build_read_slot_sql; /// -/// assert_eq!(build_read_slot_sql("my_slot"), r#"READ_REPLICATION_SLOT "my_slot";"#); +/// assert_eq!(build_read_slot_sql("my_slot").unwrap(), r#"READ_REPLICATION_SLOT "my_slot";"#); /// ``` #[inline] -pub fn build_read_slot_sql(slot_name: &str) -> String { - let quoted_slot = quote_ident(slot_name); - format!("READ_REPLICATION_SLOT {};", quoted_slot) +pub fn build_read_slot_sql(slot_name: &str) -> Result { + let quoted_slot = quote_ident(slot_name)?; + Ok(format!("READ_REPLICATION_SLOT {};", quoted_slot)) } /// Build the SQL for `START_REPLICATION SLOT ... LOGICAL`. @@ -238,15 +242,15 @@ pub fn build_read_slot_sql(slot_name: &str) -> String { /// ``` /// use pg_walstream::sql_builder::build_start_replication_sql; /// -/// let sql = build_start_replication_sql("my_slot", 0, &[("proto_version", "1")]); +/// let sql = build_start_replication_sql("my_slot", 0, &[("proto_version", "1")]).unwrap(); /// assert_eq!(sql, r#"START_REPLICATION SLOT "my_slot" LOGICAL 0/0 ("proto_version" '1')"#); /// ``` pub fn build_start_replication_sql( slot_name: &str, start_lsn: XLogRecPtr, options: &[(&str, &str)], -) -> String { - let quoted_slot = quote_ident(slot_name); +) -> Result { + let quoted_slot = quote_ident(slot_name)?; let lsn_str = if start_lsn == INVALID_XLOG_REC_PTR { "0/0".to_string() } else { @@ -254,16 +258,20 @@ pub fn build_start_replication_sql( }; if options.is_empty() { - return format!("START_REPLICATION SLOT {quoted_slot} LOGICAL {lsn_str}"); + return Ok(format!( + "START_REPLICATION SLOT {quoted_slot} LOGICAL {lsn_str}" + )); } - let options_str = options - .iter() - .map(|(k, v)| format!("{} {}", quote_ident(k), quote_literal(v))) - .collect::>() - .join(", "); + let mut options_parts = Vec::with_capacity(options.len()); + for (k, v) in options { + options_parts.push(format!("{} {}", quote_ident(k)?, quote_literal(v)?)); + } + let options_str = options_parts.join(", "); - format!("START_REPLICATION SLOT {quoted_slot} LOGICAL {lsn_str} ({options_str})") + Ok(format!( + "START_REPLICATION SLOT {quoted_slot} LOGICAL {lsn_str} ({options_str})" + )) } /// Build the SQL for `START_REPLICATION ... PHYSICAL`. @@ -273,19 +281,19 @@ pub fn build_start_replication_sql( /// ``` /// use pg_walstream::sql_builder::build_start_physical_replication_sql; /// -/// let sql = build_start_physical_replication_sql(Some("my_slot"), 0, None); +/// let sql = build_start_physical_replication_sql(Some("my_slot"), 0, None).unwrap(); /// assert_eq!(sql, r#"START_REPLICATION SLOT "my_slot" PHYSICAL 0/0"#); /// ``` pub fn build_start_physical_replication_sql( slot_name: Option<&str>, start_lsn: XLogRecPtr, timeline_id: Option, -) -> String { +) -> Result { let mut sql = String::with_capacity(64); sql.push_str("START_REPLICATION "); if let Some(slot) = slot_name { - let quoted_slot = quote_ident(slot); + let quoted_slot = quote_ident(slot)?; sql.push_str("SLOT "); sql.push_str("ed_slot); sql.push(' '); @@ -304,7 +312,7 @@ pub fn build_start_physical_replication_sql( sql.push_str(&tli.to_string()); } - sql + Ok(sql) } /// Build the SQL for `BASE_BACKUP`. @@ -316,21 +324,21 @@ pub fn build_start_physical_replication_sql( /// use pg_walstream::types::BaseBackupOptions; /// /// let opts = BaseBackupOptions::default(); -/// assert_eq!(build_base_backup_sql(&opts), "BASE_BACKUP"); +/// assert_eq!(build_base_backup_sql(&opts).unwrap(), "BASE_BACKUP"); /// ``` -pub fn build_base_backup_sql(options: &BaseBackupOptions) -> String { +pub fn build_base_backup_sql(options: &BaseBackupOptions) -> Result { let mut opts = Vec::new(); if let Some(ref label) = options.label { - opts.push(format!("LABEL {}", quote_literal(label))); + opts.push(format!("LABEL {}", quote_literal(label)?)); } if let Some(ref target) = options.target { - opts.push(format!("TARGET {}", quote_literal(target))); + opts.push(format!("TARGET {}", quote_literal(target)?)); } if let Some(ref target_detail) = options.target_detail { - opts.push(format!("TARGET_DETAIL {}", quote_literal(target_detail))); + opts.push(format!("TARGET_DETAIL {}", quote_literal(target_detail)?)); } if options.progress { @@ -338,7 +346,7 @@ pub fn build_base_backup_sql(options: &BaseBackupOptions) -> String { } if let Some(ref checkpoint) = options.checkpoint { - opts.push(format!("CHECKPOINT {}", quote_literal(checkpoint))); + opts.push(format!("CHECKPOINT {}", quote_literal(checkpoint)?)); } if options.wal { @@ -350,13 +358,13 @@ pub fn build_base_backup_sql(options: &BaseBackupOptions) -> String { } if let Some(ref compression) = options.compression { - opts.push(format!("COMPRESSION {}", quote_literal(compression))); + opts.push(format!("COMPRESSION {}", quote_literal(compression)?)); } if let Some(ref compression_detail) = options.compression_detail { opts.push(format!( "COMPRESSION_DETAIL {}", - quote_literal(compression_detail) + quote_literal(compression_detail)? )); } @@ -373,13 +381,13 @@ pub fn build_base_backup_sql(options: &BaseBackupOptions) -> String { } if let Some(ref manifest) = options.manifest { - opts.push(format!("MANIFEST {}", quote_literal(manifest))); + opts.push(format!("MANIFEST {}", quote_literal(manifest)?)); } if let Some(ref manifest_checksums) = options.manifest_checksums { opts.push(format!( "MANIFEST_CHECKSUMS {}", - quote_literal(manifest_checksums) + quote_literal(manifest_checksums)? )); } @@ -388,9 +396,9 @@ pub fn build_base_backup_sql(options: &BaseBackupOptions) -> String { } if opts.is_empty() { - "BASE_BACKUP".to_string() + Ok("BASE_BACKUP".to_string()) } else { - format!("BASE_BACKUP ({})", opts.join(", ")) + Ok(format!("BASE_BACKUP ({})", opts.join(", "))) } } @@ -453,25 +461,24 @@ impl<'a> Default for CreateSubscriptionOptions<'a> { /// slot_name: "my_slot", /// ..Default::default() /// }; -/// let sql = build_create_subscription_sql(&opts); +/// let sql = build_create_subscription_sql(&opts).unwrap(); /// assert!(sql.starts_with("CREATE SUBSCRIPTION")); -/// assert!(sql.contains("create_slot = false")); /// ``` -pub fn build_create_subscription_sql(opts: &CreateSubscriptionOptions<'_>) -> String { - let sub = quote_ident(opts.subscription_name); - let conn = quote_literal(opts.connection_string); - let pubname = quote_ident(opts.publication); - let slot = quote_literal(opts.slot_name); +pub fn build_create_subscription_sql(opts: &CreateSubscriptionOptions<'_>) -> Result { + let sub = quote_ident(opts.subscription_name)?; + let conn = quote_literal(opts.connection_string)?; + let pubname = quote_ident(opts.publication)?; + let slot = quote_literal(opts.slot_name)?; let create_slot_str = if opts.create_slot { "true" } else { "false" }; let enabled_str = if opts.enabled { "true" } else { "false" }; let copy_data_str = if opts.copy_data { "true" } else { "false" }; - format!( + Ok(format!( "CREATE SUBSCRIPTION {sub} CONNECTION {conn} PUBLICATION {pubname} \ WITH (create_slot = {create_slot_str}, slot_name = {slot}, \ enabled = {enabled_str}, copy_data = {copy_data_str})" - ) + )) } /// Build an `ALTER SUBSCRIPTION ... DISABLE` statement. @@ -481,12 +488,12 @@ pub fn build_create_subscription_sql(opts: &CreateSubscriptionOptions<'_>) -> St /// ``` /// use pg_walstream::sql_builder::build_disable_subscription_sql; /// -/// let sql = build_disable_subscription_sql("my_sub"); +/// let sql = build_disable_subscription_sql("my_sub").unwrap(); /// assert_eq!(sql, r#"ALTER SUBSCRIPTION "my_sub" DISABLE"#); /// ``` #[inline] -pub fn build_disable_subscription_sql(name: &str) -> String { - format!("ALTER SUBSCRIPTION {} DISABLE", quote_ident(name)) +pub fn build_disable_subscription_sql(name: &str) -> Result { + Ok(format!("ALTER SUBSCRIPTION {} DISABLE", quote_ident(name)?)) } /// Build an `ALTER SUBSCRIPTION ... SET (slot_name = NONE)` statement to detach a slot. @@ -496,15 +503,15 @@ pub fn build_disable_subscription_sql(name: &str) -> String { /// ``` /// use pg_walstream::sql_builder::build_detach_slot_sql; /// -/// let sql = build_detach_slot_sql("my_sub"); +/// let sql = build_detach_slot_sql("my_sub").unwrap(); /// assert_eq!(sql, r#"ALTER SUBSCRIPTION "my_sub" SET (slot_name = NONE)"#); /// ``` #[inline] -pub fn build_detach_slot_sql(name: &str) -> String { - format!( +pub fn build_detach_slot_sql(name: &str) -> Result { + Ok(format!( "ALTER SUBSCRIPTION {} SET (slot_name = NONE)", - quote_ident(name) - ) + quote_ident(name)? + )) } /// Build a `DROP SUBSCRIPTION` statement. @@ -514,12 +521,12 @@ pub fn build_detach_slot_sql(name: &str) -> String { /// ``` /// use pg_walstream::sql_builder::build_drop_subscription_sql; /// -/// let sql = build_drop_subscription_sql("my_sub"); +/// let sql = build_drop_subscription_sql("my_sub").unwrap(); /// assert_eq!(sql, r#"DROP SUBSCRIPTION "my_sub""#); /// ``` #[inline] -pub fn build_drop_subscription_sql(name: &str) -> String { - format!("DROP SUBSCRIPTION {}", quote_ident(name)) +pub fn build_drop_subscription_sql(name: &str) -> Result { + Ok(format!("DROP SUBSCRIPTION {}", quote_ident(name)?)) } // ═══════════════════════════════════════════════════════════════════════════ @@ -545,102 +552,103 @@ mod tests { #[test] fn quote_ident_simple() { - assert_eq!(quote_ident("my_slot"), r#""my_slot""#); + assert_eq!(quote_ident("my_slot").unwrap(), r#""my_slot""#); } #[test] fn quote_ident_with_internal_double_quote() { - assert_eq!(quote_ident(r#"a"b"#), r#""a""b""#); + assert_eq!(quote_ident(r#"a"b"#).unwrap(), r#""a""b""#); } #[test] fn quote_ident_multiple_quotes() { - assert_eq!(quote_ident(r#"a""b"#), r#""a""""b""#); + assert_eq!(quote_ident(r#"a""b"#).unwrap(), r#""a""""b""#); } #[test] fn quote_ident_empty() { - assert_eq!(quote_ident(""), r#""""#); + assert_eq!(quote_ident("").unwrap(), r#""""#); } #[test] fn quote_ident_special_chars() { assert_eq!( - quote_ident("slot; DROP TABLE users; --"), + quote_ident("slot; DROP TABLE users; --").unwrap(), r#""slot; DROP TABLE users; --""# ); } #[test] fn quote_ident_unicode() { - assert_eq!(quote_ident("テスト"), r#""テスト""#); + assert_eq!(quote_ident("テスト").unwrap(), r#""テスト""#); } #[test] fn quote_ident_mixed_unicode_and_quotes() { - assert_eq!(quote_ident(r#"名前"テスト"#), r#""名前""テスト""#); + assert_eq!(quote_ident(r#"名前"テスト"#).unwrap(), r#""名前""テスト""#); } #[test] - #[should_panic(expected = "null bytes")] fn quote_ident_rejects_null_byte() { - quote_ident("evil\0injection"); + assert!(quote_ident("evil\0injection").is_err()); } // ── quote_literal ──────────────────────────────────────────────────── #[test] fn quote_literal_simple() { - assert_eq!(quote_literal("hello"), "'hello'"); + assert_eq!(quote_literal("hello").unwrap(), "'hello'"); } #[test] fn quote_literal_with_internal_single_quote() { - assert_eq!(quote_literal("it's"), "'it''s'"); + assert_eq!(quote_literal("it's").unwrap(), "'it''s'"); } #[test] fn quote_literal_multiple_quotes() { - assert_eq!(quote_literal("a''b"), "'a''''b'"); + assert_eq!(quote_literal("a''b").unwrap(), "'a''''b'"); } #[test] fn quote_literal_empty() { - assert_eq!(quote_literal(""), "''"); + assert_eq!(quote_literal("").unwrap(), "''"); } #[test] fn quote_literal_sql_injection_attempt() { assert_eq!( - quote_literal("'; DROP TABLE users; --"), + quote_literal("'; DROP TABLE users; --").unwrap(), "'''; DROP TABLE users; --'" ); } #[test] fn quote_literal_unicode() { - assert_eq!(quote_literal("日本語"), "'日本語'"); + assert_eq!(quote_literal("日本語").unwrap(), "'日本語'"); } #[test] fn quote_literal_newlines() { - assert_eq!(quote_literal("line1\nline2"), "'line1\nline2'"); + assert_eq!(quote_literal("line1\nline2").unwrap(), "'line1\nline2'"); } #[test] fn quote_literal_complex_injection() { - assert_eq!(quote_literal("value' OR '1'='1"), "'value'' OR ''1''=''1'"); + assert_eq!( + quote_literal("value' OR '1'='1").unwrap(), + "'value'' OR ''1''=''1'" + ); } #[test] fn quote_literal_backslash_and_quote() { - assert_eq!(quote_literal("test\\'value"), "'test\\''value'"); + assert_eq!(quote_literal("test\\'value").unwrap(), "'test\\''value'"); } #[test] - #[should_panic(expected = "null bytes")] fn quote_literal_rejects_null_byte() { - quote_literal("evil\0injection"); + assert!(quote_literal("evil\0injection").is_err()); } // ── build_create_slot_sql ──────────────────────────────────────────── @@ -873,7 +881,7 @@ mod tests { #[test] fn drop_slot_without_wait() { assert_eq!( - build_drop_slot_sql("my_slot", false), + build_drop_slot_sql("my_slot", false).unwrap(), r#"DROP_REPLICATION_SLOT "my_slot";"# ); } @@ -881,14 +889,14 @@ mod tests { #[test] fn drop_slot_with_wait() { assert_eq!( - build_drop_slot_sql("my_slot", true), + build_drop_slot_sql("my_slot", true).unwrap(), r#"DROP_REPLICATION_SLOT "my_slot" WAIT;"# ); } #[test] fn drop_slot_injection() { - let sql = build_drop_slot_sql(r#"evil"slot"#, false); + let sql = build_drop_slot_sql(r#"evil"slot"#, false).unwrap(); assert_eq!(sql, r#"DROP_REPLICATION_SLOT "evil""slot";"#); } @@ -897,7 +905,7 @@ mod tests { #[test] fn read_slot_basic() { assert_eq!( - build_read_slot_sql("my_slot"), + build_read_slot_sql("my_slot").unwrap(), r#"READ_REPLICATION_SLOT "my_slot";"# ); } @@ -905,7 +913,7 @@ mod tests { #[test] fn read_slot_injection() { assert_eq!( - build_read_slot_sql(r#"evil"slot"#), + build_read_slot_sql(r#"evil"slot"#).unwrap(), r#"READ_REPLICATION_SLOT "evil""slot";"# ); } @@ -918,7 +926,8 @@ mod tests { "my_slot", 0, &[("proto_version", "1"), ("publication_names", "my_pub")], - ); + ) + .unwrap(); assert_eq!( sql, r#"START_REPLICATION SLOT "my_slot" LOGICAL 0/0 ("proto_version" '1', "publication_names" 'my_pub')"# @@ -928,7 +937,7 @@ mod tests { #[test] fn start_replication_valid_lsn() { let lsn: XLogRecPtr = 0x0000_0001_0000_0000; - let sql = build_start_replication_sql("test_slot", lsn, &[("proto_version", "2")]); + let sql = build_start_replication_sql("test_slot", lsn, &[("proto_version", "2")]).unwrap(); assert!(sql.contains("START_REPLICATION SLOT \"test_slot\" LOGICAL")); assert!(sql.contains("(\"proto_version\" '2')")); assert!(!sql.contains("0/0")); @@ -944,7 +953,8 @@ mod tests { ("publication_names", "pub1"), ("messages", "true"), ], - ); + ) + .unwrap(); assert!( sql.contains(r#""proto_version" '1', "publication_names" 'pub1', "messages" 'true'"#) ); @@ -952,20 +962,20 @@ mod tests { #[test] fn start_replication_empty_options() { - let sql = build_start_replication_sql("slot1", 0, &[]); + let sql = build_start_replication_sql("slot1", 0, &[]).unwrap(); assert_eq!(sql, r#"START_REPLICATION SLOT "slot1" LOGICAL 0/0"#); } #[test] fn start_replication_option_injection() { - let sql = build_start_replication_sql(r#"evil"slot"#, 0, &[("key", "it's")]); + let sql = build_start_replication_sql(r#"evil"slot"#, 0, &[("key", "it's")]).unwrap(); assert!(sql.contains(r#""evil""slot""#)); assert!(sql.contains("'it''s'")); } #[test] fn start_replication_single_option() { - let sql = build_start_replication_sql("my_slot", 0, &[("proto_version", "1")]); + let sql = build_start_replication_sql("my_slot", 0, &[("proto_version", "1")]).unwrap(); assert_eq!( sql, r#"START_REPLICATION SLOT "my_slot" LOGICAL 0/0 ("proto_version" '1')"# @@ -976,27 +986,27 @@ mod tests { #[test] fn start_physical_with_slot_zero_lsn() { - let sql = build_start_physical_replication_sql(Some("my_slot"), 0, None); + let sql = build_start_physical_replication_sql(Some("my_slot"), 0, None).unwrap(); assert_eq!(sql, r#"START_REPLICATION SLOT "my_slot" PHYSICAL 0/0"#); } #[test] fn start_physical_no_slot() { - let sql = build_start_physical_replication_sql(None, 0, None); + let sql = build_start_physical_replication_sql(None, 0, None).unwrap(); assert_eq!(sql, "START_REPLICATION PHYSICAL 0/0"); } #[test] fn start_physical_with_lsn() { let lsn: XLogRecPtr = 0x0000_0001_0000_0000; - let sql = build_start_physical_replication_sql(Some("slot"), lsn, None); + let sql = build_start_physical_replication_sql(Some("slot"), lsn, None).unwrap(); assert!(sql.contains("PHYSICAL 1/0")); assert!(!sql.contains("0/0")); } #[test] fn start_physical_with_timeline() { - let sql = build_start_physical_replication_sql(Some("slot"), 0, Some(3)); + let sql = build_start_physical_replication_sql(Some("slot"), 0, Some(3)).unwrap(); assert_eq!( sql, r#"START_REPLICATION SLOT "slot" PHYSICAL 0/0 TIMELINE 3"# @@ -1005,13 +1015,13 @@ mod tests { #[test] fn start_physical_no_slot_with_timeline() { - let sql = build_start_physical_replication_sql(None, 0, Some(5)); + let sql = build_start_physical_replication_sql(None, 0, Some(5)).unwrap(); assert_eq!(sql, "START_REPLICATION PHYSICAL 0/0 TIMELINE 5"); } #[test] fn start_physical_slot_injection() { - let sql = build_start_physical_replication_sql(Some(r#"evil"slot"#), 0, None); + let sql = build_start_physical_replication_sql(Some(r#"evil"slot"#), 0, None).unwrap(); assert!(sql.contains(r#"SLOT "evil""slot""#)); } @@ -1020,7 +1030,7 @@ mod tests { #[test] fn base_backup_default() { let opts = BaseBackupOptions::default(); - assert_eq!(build_base_backup_sql(&opts), "BASE_BACKUP"); + assert_eq!(build_base_backup_sql(&opts).unwrap(), "BASE_BACKUP"); } #[test] @@ -1030,7 +1040,7 @@ mod tests { ..Default::default() }; assert_eq!( - build_base_backup_sql(&opts), + build_base_backup_sql(&opts).unwrap(), "BASE_BACKUP (LABEL 'my_backup')" ); } @@ -1042,7 +1052,7 @@ mod tests { ..Default::default() }; assert_eq!( - build_base_backup_sql(&opts), + build_base_backup_sql(&opts).unwrap(), "BASE_BACKUP (TARGET 'client')" ); } @@ -1055,7 +1065,7 @@ mod tests { ..Default::default() }; assert_eq!( - build_base_backup_sql(&opts), + build_base_backup_sql(&opts).unwrap(), "BASE_BACKUP (TARGET 'server', TARGET_DETAIL '/var/backups')" ); } @@ -1066,7 +1076,10 @@ mod tests { progress: true, ..Default::default() }; - assert_eq!(build_base_backup_sql(&opts), "BASE_BACKUP (PROGRESS true)"); + assert_eq!( + build_base_backup_sql(&opts).unwrap(), + "BASE_BACKUP (PROGRESS true)" + ); } #[test] @@ -1076,7 +1089,7 @@ mod tests { ..Default::default() }; assert_eq!( - build_base_backup_sql(&opts), + build_base_backup_sql(&opts).unwrap(), "BASE_BACKUP (CHECKPOINT 'fast')" ); } @@ -1087,7 +1100,10 @@ mod tests { wal: true, ..Default::default() }; - assert_eq!(build_base_backup_sql(&opts), "BASE_BACKUP (WAL true)"); + assert_eq!( + build_base_backup_sql(&opts).unwrap(), + "BASE_BACKUP (WAL true)" + ); } #[test] @@ -1096,7 +1112,10 @@ mod tests { wait: true, ..Default::default() }; - assert_eq!(build_base_backup_sql(&opts), "BASE_BACKUP (WAIT true)"); + assert_eq!( + build_base_backup_sql(&opts).unwrap(), + "BASE_BACKUP (WAIT true)" + ); } #[test] @@ -1106,7 +1125,7 @@ mod tests { ..Default::default() }; assert_eq!( - build_base_backup_sql(&opts), + build_base_backup_sql(&opts).unwrap(), "BASE_BACKUP (COMPRESSION 'gzip')" ); } @@ -1119,7 +1138,7 @@ mod tests { ..Default::default() }; assert_eq!( - build_base_backup_sql(&opts), + build_base_backup_sql(&opts).unwrap(), "BASE_BACKUP (COMPRESSION 'zstd', COMPRESSION_DETAIL 'level=3')" ); } @@ -1130,7 +1149,10 @@ mod tests { max_rate: Some(1024), ..Default::default() }; - assert_eq!(build_base_backup_sql(&opts), "BASE_BACKUP (MAX_RATE 1024)"); + assert_eq!( + build_base_backup_sql(&opts).unwrap(), + "BASE_BACKUP (MAX_RATE 1024)" + ); } #[test] @@ -1140,7 +1162,7 @@ mod tests { ..Default::default() }; assert_eq!( - build_base_backup_sql(&opts), + build_base_backup_sql(&opts).unwrap(), "BASE_BACKUP (TABLESPACE_MAP true)" ); } @@ -1152,7 +1174,7 @@ mod tests { ..Default::default() }; assert_eq!( - build_base_backup_sql(&opts), + build_base_backup_sql(&opts).unwrap(), "BASE_BACKUP (VERIFY_CHECKSUMS true)" ); } @@ -1163,7 +1185,10 @@ mod tests { manifest: Some("yes".to_string()), ..Default::default() }; - assert_eq!(build_base_backup_sql(&opts), "BASE_BACKUP (MANIFEST 'yes')"); + assert_eq!( + build_base_backup_sql(&opts).unwrap(), + "BASE_BACKUP (MANIFEST 'yes')" + ); } #[test] @@ -1174,7 +1199,7 @@ mod tests { ..Default::default() }; assert_eq!( - build_base_backup_sql(&opts), + build_base_backup_sql(&opts).unwrap(), "BASE_BACKUP (MANIFEST 'yes', MANIFEST_CHECKSUMS 'SHA256')" ); } @@ -1185,7 +1210,10 @@ mod tests { incremental: true, ..Default::default() }; - assert_eq!(build_base_backup_sql(&opts), "BASE_BACKUP (INCREMENTAL)"); + assert_eq!( + build_base_backup_sql(&opts).unwrap(), + "BASE_BACKUP (INCREMENTAL)" + ); } #[test] @@ -1198,7 +1226,7 @@ mod tests { ..Default::default() }; assert_eq!( - build_base_backup_sql(&opts), + build_base_backup_sql(&opts).unwrap(), "BASE_BACKUP (LABEL 'backup', PROGRESS true, WAL true, VERIFY_CHECKSUMS true)" ); } @@ -1210,7 +1238,7 @@ mod tests { ..Default::default() }; assert_eq!( - build_base_backup_sql(&opts), + build_base_backup_sql(&opts).unwrap(), "BASE_BACKUP (LABEL 'evil''; DROP TABLE users; --')" ); } @@ -1252,7 +1280,7 @@ mod tests { slot_name: "my_slot", ..Default::default() }; - let sql = build_create_subscription_sql(&opts); + let sql = build_create_subscription_sql(&opts).unwrap(); assert_eq!( sql, "CREATE SUBSCRIPTION \"my_sub\" \ @@ -1271,7 +1299,7 @@ mod tests { slot_name: "slot'name", ..Default::default() }; - let sql = build_create_subscription_sql(&opts); + let sql = build_create_subscription_sql(&opts).unwrap(); assert_eq!( sql, "CREATE SUBSCRIPTION \"sub\"\"name\" \ @@ -1290,7 +1318,7 @@ mod tests { slot_name: "", ..Default::default() }; - let sql = build_create_subscription_sql(&opts); + let sql = build_create_subscription_sql(&opts).unwrap(); assert_eq!( sql, "CREATE SUBSCRIPTION \"\" \ @@ -1310,7 +1338,7 @@ mod tests { create_slot: true, ..Default::default() }; - let sql = build_create_subscription_sql(&opts); + let sql = build_create_subscription_sql(&opts).unwrap(); assert!(sql.contains("create_slot = true")); assert!(sql.contains("enabled = true")); assert!(sql.contains("copy_data = false")); @@ -1326,7 +1354,7 @@ mod tests { copy_data: true, ..Default::default() }; - let sql = build_create_subscription_sql(&opts); + let sql = build_create_subscription_sql(&opts).unwrap(); assert!(sql.contains("copy_data = true")); } @@ -1340,7 +1368,7 @@ mod tests { enabled: false, ..Default::default() }; - let sql = build_create_subscription_sql(&opts); + let sql = build_create_subscription_sql(&opts).unwrap(); assert!(sql.contains("enabled = false")); } @@ -1355,7 +1383,7 @@ mod tests { enabled: true, copy_data: true, }; - let sql = build_create_subscription_sql(&opts); + let sql = build_create_subscription_sql(&opts).unwrap(); assert_eq!( sql, "CREATE SUBSCRIPTION \"sub\" \ @@ -1376,7 +1404,7 @@ mod tests { enabled: false, copy_data: false, }; - let sql = build_create_subscription_sql(&opts); + let sql = build_create_subscription_sql(&opts).unwrap(); assert_eq!( sql, "CREATE SUBSCRIPTION \"sub\" \ @@ -1391,7 +1419,7 @@ mod tests { #[test] fn disable_subscription_basic() { assert_eq!( - build_disable_subscription_sql("my_sub"), + build_disable_subscription_sql("my_sub").unwrap(), r#"ALTER SUBSCRIPTION "my_sub" DISABLE"# ); } @@ -1399,7 +1427,7 @@ mod tests { #[test] fn disable_subscription_with_quotes() { assert_eq!( - build_disable_subscription_sql(r#"sub"name"#), + build_disable_subscription_sql(r#"sub"name"#).unwrap(), r#"ALTER SUBSCRIPTION "sub""name" DISABLE"# ); } @@ -1409,7 +1437,7 @@ mod tests { #[test] fn detach_slot_basic() { assert_eq!( - build_detach_slot_sql("my_sub"), + build_detach_slot_sql("my_sub").unwrap(), r#"ALTER SUBSCRIPTION "my_sub" SET (slot_name = NONE)"# ); } @@ -1417,7 +1445,7 @@ mod tests { #[test] fn detach_slot_with_quotes() { assert_eq!( - build_detach_slot_sql(r#"sub"x"#), + build_detach_slot_sql(r#"sub"x"#).unwrap(), r#"ALTER SUBSCRIPTION "sub""x" SET (slot_name = NONE)"# ); } @@ -1427,7 +1455,7 @@ mod tests { #[test] fn drop_subscription_basic() { assert_eq!( - build_drop_subscription_sql("my_sub"), + build_drop_subscription_sql("my_sub").unwrap(), r#"DROP SUBSCRIPTION "my_sub""# ); } @@ -1435,7 +1463,7 @@ mod tests { #[test] fn drop_subscription_with_quotes() { assert_eq!( - build_drop_subscription_sql(r#"sub"name"#), + build_drop_subscription_sql(r#"sub"name"#).unwrap(), r#"DROP SUBSCRIPTION "sub""name""# ); } @@ -1443,11 +1471,176 @@ mod tests { #[test] fn drop_subscription_injection_attempt() { assert_eq!( - build_drop_subscription_sql("evil\"; DROP TABLE users; --"), + build_drop_subscription_sql("evil\"; DROP TABLE users; --").unwrap(), "DROP SUBSCRIPTION \"evil\"\"; DROP TABLE users; --\"" ); } + // ── Null byte error propagation through builders ─────────────────── + + #[test] + fn create_slot_rejects_null_in_slot_name() { + let opts = ReplicationSlotOptions::default(); + let err = + build_create_slot_sql("slot\0name", SlotType::Logical, Some("pgoutput"), &opts) + .unwrap_err(); + assert!(err.to_string().contains("null bytes")); + } + + #[test] + fn create_slot_rejects_null_in_plugin() { + let opts = ReplicationSlotOptions::default(); + let err = + build_create_slot_sql("slot", SlotType::Logical, Some("pg\0output"), &opts) + .unwrap_err(); + assert!(err.to_string().contains("null bytes")); + } + + #[test] + fn alter_slot_rejects_null_in_slot_name() { + let err = build_alter_slot_sql("slot\0x", Some(true), None).unwrap_err(); + assert!(err.to_string().contains("null bytes")); + } + + #[test] + fn drop_slot_rejects_null_in_slot_name() { + let err = build_drop_slot_sql("slot\0x", false).unwrap_err(); + assert!(err.to_string().contains("null bytes")); + } + + #[test] + fn read_slot_rejects_null_in_slot_name() { + let err = build_read_slot_sql("slot\0x").unwrap_err(); + assert!(err.to_string().contains("null bytes")); + } + + #[test] + fn start_replication_rejects_null_in_slot_name() { + let err = + build_start_replication_sql("slot\0x", 0, &[("proto_version", "1")]).unwrap_err(); + assert!(err.to_string().contains("null bytes")); + } + + #[test] + fn start_replication_rejects_null_in_option_key() { + let err = + build_start_replication_sql("slot", 0, &[("key\0x", "value")]).unwrap_err(); + assert!(err.to_string().contains("null bytes")); + } + + #[test] + fn start_replication_rejects_null_in_option_value() { + let err = + build_start_replication_sql("slot", 0, &[("key", "val\0ue")]).unwrap_err(); + assert!(err.to_string().contains("null bytes")); + } + + #[test] + fn start_physical_rejects_null_in_slot_name() { + let err = build_start_physical_replication_sql(Some("slot\0x"), 0, None).unwrap_err(); + assert!(err.to_string().contains("null bytes")); + } + + #[test] + fn base_backup_rejects_null_in_label() { + let opts = BaseBackupOptions { + label: Some("label\0x".to_string()), + ..Default::default() + }; + let err = build_base_backup_sql(&opts).unwrap_err(); + assert!(err.to_string().contains("null bytes")); + } + + #[test] + fn base_backup_rejects_null_in_target() { + let opts = BaseBackupOptions { + target: Some("target\0x".to_string()), + ..Default::default() + }; + let err = build_base_backup_sql(&opts).unwrap_err(); + assert!(err.to_string().contains("null bytes")); + } + + #[test] + fn base_backup_rejects_null_in_compression() { + let opts = BaseBackupOptions { + compression: Some("gzip\0x".to_string()), + ..Default::default() + }; + let err = build_base_backup_sql(&opts).unwrap_err(); + assert!(err.to_string().contains("null bytes")); + } + + #[test] + fn create_subscription_rejects_null_in_name() { + let opts = CreateSubscriptionOptions { + subscription_name: "sub\0x", + connection_string: "host=localhost", + publication: "pub", + slot_name: "slot", + ..Default::default() + }; + let err = build_create_subscription_sql(&opts).unwrap_err(); + assert!(err.to_string().contains("null bytes")); + } + + #[test] + fn create_subscription_rejects_null_in_connection() { + let opts = CreateSubscriptionOptions { + subscription_name: "sub", + connection_string: "host=\0localhost", + publication: "pub", + slot_name: "slot", + ..Default::default() + }; + let err = build_create_subscription_sql(&opts).unwrap_err(); + assert!(err.to_string().contains("null bytes")); + } + + #[test] + fn create_subscription_rejects_null_in_publication() { + let opts = CreateSubscriptionOptions { + subscription_name: "sub", + connection_string: "host=localhost", + publication: "pub\0x", + slot_name: "slot", + ..Default::default() + }; + let err = build_create_subscription_sql(&opts).unwrap_err(); + assert!(err.to_string().contains("null bytes")); + } + + #[test] + fn create_subscription_rejects_null_in_slot_name() { + let opts = CreateSubscriptionOptions { + subscription_name: "sub", + connection_string: "host=localhost", + publication: "pub", + slot_name: "slot\0x", + ..Default::default() + }; + let err = build_create_subscription_sql(&opts).unwrap_err(); + assert!(err.to_string().contains("null bytes")); + } + + #[test] + fn disable_subscription_rejects_null() { + let err = build_disable_subscription_sql("sub\0x").unwrap_err(); + assert!(err.to_string().contains("null bytes")); + } + + #[test] + fn detach_slot_rejects_null() { + let err = build_detach_slot_sql("sub\0x").unwrap_err(); + assert!(err.to_string().contains("null bytes")); + } + + #[test] + fn drop_subscription_rejects_null() { + let err = build_drop_subscription_sql("sub\0x").unwrap_err(); + assert!(err.to_string().contains("null bytes")); + } + // ── Compatibility with existing behavior ───────────────────────────── #[test] @@ -1462,7 +1655,11 @@ mod tests { ]; for input in cases { let legacy = format!("\"{}\"", input.replace('"', "\"\"")); - assert_eq!(quote_ident(input), legacy, "mismatch for input: {input:?}"); + assert_eq!( + quote_ident(input).unwrap(), + legacy, + "mismatch for input: {input:?}" + ); } } @@ -1480,7 +1677,7 @@ mod tests { for input in cases { let legacy = format!("'{}'", input.replace('\'', "''")); assert_eq!( - quote_literal(input), + quote_literal(input).unwrap(), legacy, "mismatch for input: {input:?}" ); From 9fbf3b222ea3880550c72aed878c61a488249b23 Mon Sep 17 00:00:00 2001 From: danielshih Date: Mon, 4 May 2026 09:48:49 +0000 Subject: [PATCH 4/4] refactor(tests): streamline error handling for null byte checks in replication SQL functions --- src/sql_builder.rs | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/src/sql_builder.rs b/src/sql_builder.rs index 9f98fb8..0e461dc 100644 --- a/src/sql_builder.rs +++ b/src/sql_builder.rs @@ -1481,18 +1481,16 @@ mod tests { #[test] fn create_slot_rejects_null_in_slot_name() { let opts = ReplicationSlotOptions::default(); - let err = - build_create_slot_sql("slot\0name", SlotType::Logical, Some("pgoutput"), &opts) - .unwrap_err(); + let err = build_create_slot_sql("slot\0name", SlotType::Logical, Some("pgoutput"), &opts) + .unwrap_err(); assert!(err.to_string().contains("null bytes")); } #[test] fn create_slot_rejects_null_in_plugin() { let opts = ReplicationSlotOptions::default(); - let err = - build_create_slot_sql("slot", SlotType::Logical, Some("pg\0output"), &opts) - .unwrap_err(); + let err = build_create_slot_sql("slot", SlotType::Logical, Some("pg\0output"), &opts) + .unwrap_err(); assert!(err.to_string().contains("null bytes")); } @@ -1516,22 +1514,19 @@ mod tests { #[test] fn start_replication_rejects_null_in_slot_name() { - let err = - build_start_replication_sql("slot\0x", 0, &[("proto_version", "1")]).unwrap_err(); + let err = build_start_replication_sql("slot\0x", 0, &[("proto_version", "1")]).unwrap_err(); assert!(err.to_string().contains("null bytes")); } #[test] fn start_replication_rejects_null_in_option_key() { - let err = - build_start_replication_sql("slot", 0, &[("key\0x", "value")]).unwrap_err(); + let err = build_start_replication_sql("slot", 0, &[("key\0x", "value")]).unwrap_err(); assert!(err.to_string().contains("null bytes")); } #[test] fn start_replication_rejects_null_in_option_value() { - let err = - build_start_replication_sql("slot", 0, &[("key", "val\0ue")]).unwrap_err(); + let err = build_start_replication_sql("slot", 0, &[("key", "val\0ue")]).unwrap_err(); assert!(err.to_string().contains("null bytes")); }