diff --git a/crates/dkdc-db-core/src/db.rs b/crates/dkdc-db-core/src/db.rs index a7502f1..a4a9883 100644 --- a/crates/dkdc-db-core/src/db.rs +++ b/crates/dkdc-db-core/src/db.rs @@ -12,17 +12,42 @@ const WAL_MODE_PRAGMA: &str = "journal_mode"; const WAL_MODE_VALUE: &str = "'wal'"; /// Try to extract a table name from a simple SELECT query for PRAGMA fallback. -/// Handles `SELECT ... FROM table_name` patterns. +/// Handles `SELECT ... FROM table_name` patterns, including schema-qualified +/// names (`schema.table`) and quoted identifiers (`"My Table"`). +/// Returns `None` for subqueries and other non-simple patterns. fn extract_table_name(sql: &str) -> Option { let upper = sql.to_uppercase(); let from_idx = upper.find(" FROM ")?; let after_from = sql[from_idx + 6..].trim_start(); - // Take the first word (table name), stop at whitespace, comma, semicolon, or parenthesis - let name: String = after_from - .chars() - .take_while(|c| c.is_alphanumeric() || *c == '_') - .collect(); - if name.is_empty() { None } else { Some(name) } + + // Subqueries start with '(' — bail out + if after_from.starts_with('(') { + return None; + } + + // Handle quoted identifier: "table name" + let raw_name = if let Some(rest) = after_from.strip_prefix('"') { + let end = rest.find('"')?; + &rest[..end] + } else { + // Unquoted: take chars valid in identifiers (alphanumeric, underscore, dot for schema) + let end = after_from + .find(|c: char| !(c.is_alphanumeric() || c == '_' || c == '.')) + .unwrap_or(after_from.len()); + &after_from[..end] + }; + + if raw_name.is_empty() { + return None; + } + + // For schema-qualified names (schema.table), take the last part + let table = raw_name.rsplit('.').next().unwrap_or(raw_name); + if table.is_empty() { + None + } else { + Some(table.to_string()) + } } /// Enable WAL mode on a connection for concurrent read+write. @@ -177,3 +202,78 @@ impl DkdcDb { Ok(self.db.connect()?) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_extract_simple_table() { + assert_eq!( + extract_table_name("SELECT * FROM users"), + Some("users".to_string()) + ); + } + + #[test] + fn test_extract_table_with_where() { + assert_eq!( + extract_table_name("SELECT id FROM orders WHERE id = 1"), + Some("orders".to_string()) + ); + } + + #[test] + fn test_extract_schema_qualified() { + assert_eq!( + extract_table_name("SELECT * FROM myschema.mytable"), + Some("mytable".to_string()) + ); + } + + #[test] + fn test_extract_quoted_identifier() { + assert_eq!( + extract_table_name("SELECT * FROM \"My Table\""), + Some("My Table".to_string()) + ); + } + + #[test] + fn test_extract_quoted_with_schema() { + // Quoted name with dot inside quotes is treated as the full name + assert_eq!( + extract_table_name("SELECT * FROM \"my.table\""), + Some("table".to_string()) + ); + } + + #[test] + fn test_extract_subquery_returns_none() { + assert_eq!( + extract_table_name("SELECT * FROM (SELECT id FROM users)"), + None + ); + } + + #[test] + fn test_extract_no_from_returns_none() { + assert_eq!(extract_table_name("SELECT 1 + 1"), None); + } + + #[test] + fn test_extract_case_insensitive_from() { + assert_eq!( + extract_table_name("select * from events"), + Some("events".to_string()) + ); + } + + #[test] + fn test_extract_table_with_semicolon() { + assert_eq!( + extract_table_name("SELECT * FROM users;"), + Some("users".to_string()) + ); + } +}