diff --git a/tap_postgres/_wal_helpers.py b/tap_postgres/_wal_helpers.py new file mode 100644 index 00000000..7c45480f --- /dev/null +++ b/tap_postgres/_wal_helpers.py @@ -0,0 +1,111 @@ +"""Helper functions for LOG_BASED replication.""" + +from __future__ import annotations + +import json +import re +import typing as t + +import psycopg2 + +if t.TYPE_CHECKING: + from psycopg2 import extras + + +# wal2json emits enum type names with unescaped double quotes, producing invalid JSON +# like "type":""EnumName"" ... strip the extra quotes via regex and re-parse +_WAL2JSON_ENUM_QUOTE_RE = re.compile(r'"type":""([^"]+)""') + + +def normalize_fqn(schema: str, table: str) -> str: + """Generate canonical, fully-qualified name for dispatch. + + This is the source of truth for matching a WAL message to a registered stream. + Both sides -- the tap (when registering streams with the reader) and the WAL reader + (when dispatching a parsed payload) -- *must* call this function with the raw schema + and table name strings. + + wal2json's format-version=2 output includes ``"schema"`` and ``"table"`` fields + as the raw, unquoted identifier names (wal2json reports whatever Postgres has stored). + Therefore, use the raw names joined by a single dot, with no quoting and no case folding. + + Do *not* use ``SQLStream.fully_qualified_name`` for dispatch. + """ + return f"{schema}.{table}" + + +def escape_for_add_tables(identifier: str) -> str: + """Escape a schema or table name for use in wal2json's ``add-tables``. + + wal2json's ``add-tables`` option takes a comma-separated list of "schema.table" entries. + Backslash is the escape character; comma and dot within an identifier must be escaped, + and backslash itself must be doubled. + + References: + - https://github.com/eulerto/wal2json#parameters + """ + return identifier.replace("\\", "\\\\").replace(",", "\\,").replace(".", "\\.") + + +def build_add_tables_option(fqn_pairs: list[tuple[str, str]]) -> str: + """Build the wal2json ``add-tables`` option from a list of (schema, table). + + Each identifier is escaped with ``escape_for_add_tables`` and joined with + the appropriate separators. For example:: + + >>> build_add_tables_option([("public", "users"), ("public", "orders")]) + 'public.users,public.orders' + """ + return ",".join( + f"{escape_for_add_tables(schema)}.{escape_for_add_tables(table)}" + for schema, table in fqn_pairs + ) + + +def parse_wal_message(raw_payload: str, cursor: extras.ReplicationCursor | None) -> dict | None: + """Parse a raw wal2json JSON payload into a Python dict. + + Handles the known wal2json enum-quoting bug via one retry after regex repair. + Returns None if the payload can't be decoded even after repair, in which case + the caller should log and skip. + + When ``cursor`` is provided, pre-parses ``text[]`` column values into Python lists + using psycopg2's ``STRINGARRAY`` type caster. This must be done while the cursor is alive, + since ``STRINGARRAY`` reads connection-level encoding info from it. + """ + try: + payload = json.loads(raw_payload) + except json.JSONDecodeError: + try: + payload = json.loads(fix_wal2json_enum_quotes(raw_payload)) + except json.JSONDecodeError: + return None + + if cursor is not None: + pre_parse_text_arrays(payload, cursor) + + return payload + + +def fix_wal2json_enum_quotes(payload: str) -> str: + """Repair the wal2json enum-quoting bug in a raw JSON payload. + + wal2json outputs enum type names with unescaped double quotes (e.g. "type":""EnumName""), + which is invalid JSON. 
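+    A quick doctest sketch (``mood`` is a hypothetical enum name)::
+
+        >>> fix_wal2json_enum_quotes('{"type":""mood"","value":"happy"}')
+        '{"type":"mood","value":"happy"}'
+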
The repair normalizes this to "type":"EnumName" so a second ``json.loads`` + attempt succeeds. + """ + return _WAL2JSON_ENUM_QUOTE_RE.sub(r'"type":"\1"', payload) + + +def pre_parse_text_arrays(payload: dict, cursor: extras.ReplicationCursor) -> None: + """Pre-parse ``text[]`` column values in a wal2json payload, in place. + + wal2json returns ``text[]`` values as Postgres's array literal string (e.g. '{a,b,c}'). + Converting to a Python list requires ``psycopg2.extensions.STRINGARRAY``, which needs + a live cursor for encoding context. Calling this during the WAL read means downstream code + -- ``consume()``, etc. -- can operate on plain Python lists with no cursor dependency. + """ + for key in ("columns", "identity"): + for column in payload.get(key, ()) or (): + if column.get("type") == "text[]" and column.get("value") is not None: + column["value"] = psycopg2.extensions.STRINGARRAY(column["value"], cursor)
diff --git a/tap_postgres/client.py b/tap_postgres/client.py index 9d1b85cd..98c3b625 100644 --- a/tap_postgres/client.py +++ b/tap_postgres/client.py
@@ -7,8 +7,6 @@ import datetime import functools -import json -import re import select import sys import typing as t
@@ -23,6 +21,8 @@ from singer_sdk.sql.connector import SQLToJSONSchema from sqlalchemy.dialects import postgresql +from tap_postgres._wal_helpers import parse_wal_message + if sys.version_info >= (3, 12): from typing import override else:
@@ -41,6 +41,13 @@ from tap_postgres.connection_parameters import ConnectionParameters +_UPSERT_ACTIONS = {"I", "U"} +_DELETE_ACTIONS = {"D"} +_TRUNCATE_ACTIONS = {"T"} +_TRANSACTION_ACTIONS = {"B", "C"} +_NUMERIC_TYPES = ("int", "numeric", "decimal", "real", "double", "float", "bigint", "smallint") + + def _now_utc() -> str: """Return the current UTC time as a string.""" return datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
@@ -234,8 +241,6 @@ class PostgresLogBasedStream(SQLStream): replication_key = "_sdc_lsn" is_sorted = True - _WAL2JSON_ENUM_QUOTE_RE = re.compile(r'"type":""([^"]+)""') - connection_parameters: ConnectionParameters def __init__(
@@ -247,6 +252,9 @@ def __init__( ) -> None: """Initialize Postgres log-based stream.""" self.connection_parameters = connection_parameters + # track whether this stream's SCHEMA message has already been emitted + # to avoid emitting duplicate SCHEMA messages when running via the shared WAL reader + self._schema_message_written = False super().__init__(tap, catalog_entry, connector)
@@ -307,17 +315,37 @@ def _increment_stream_state( @override def get_records(self, context: Context | None) -> Iterable[dict[str, t.Any]]: + """Iterate records for this LOG_BASED stream. + + When using the single-connection WAL reader, the first call into ``get_records()`` + from any LOG_BASED stream triggers ``TapPostgres._sync_log_based_streams_shared()``, + which emits records for *all* selected LOG_BASED streams via their ``emit_record()`` + adapter, bypassing this generator. Subsequent sibling calls are no-ops. + + Yielding nothing is fine, since an empty iterable produces zero additional record + messages from the SDK's per-stream loop, while leaving SCHEMA emission, metrics, + and state finalization intact.
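+        A minimal config sketch opting in to the shared reader (setting names as
+        declared in ``tap.py``; shown as tap config JSON)::
+
+            {"replication_method": "LOG_BASED", "log_based_single_connection": true}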
+ """ + if not self._tap.config.get("log_based_single_connection", True): + yield from self._get_records_per_stream(context) + return + + if not self._tap._shared_wal_run_completed: + self._tap._sync_log_based_streams_shared() + self._tap._shared_wal_run_completed = True + + return + + def _get_records_per_stream(self, context: Context | None) -> Iterable[dict[str, t.Any]]: """Return a generator of row-type dictionary objects. - Runs a long-lived replication session (up to - ``replication_max_run_seconds``, default 600 s) so the tap can drain - large WAL backlogs in a single sync. Sends periodic flush feedback + Runs a long-lived replication session (up to ``replication_max_run_seconds``) + so the tap can drain large WAL backlogs in a single sync. Sends periodic flush feedback while yielding records so the slot releases retained WAL incrementally. After the loop ends -- either because no data messages arrived for - ``replication_idle_exit_seconds`` (default 60 s) or the time budget is - exhausted -- the slot is advanced to the current WAL tip to prevent - unbounded WAL retention. + ``replication_idle_exit_seconds`` or total time budget is exhausted -- the slot + is advanced to the current WAL tip to prevent unbounded WAL retention. """ status_interval = 10 max_run_seconds = self.config["replication_max_run_seconds"] @@ -366,10 +394,12 @@ def get_records(self, context: Context | None) -> Iterable[dict[str, t.Any]]: message = logical_replication_cursor.read_message() if message: last_data_message = datetime.datetime.now() - row = self.consume(message, logical_replication_cursor) - if row: - records_yielded += 1 - yield row + payload = parse_wal_message(message.payload, logical_replication_cursor) + if payload is not None: + row = self.consume(payload, message.data_start) + if row: + records_yielded += 1 + yield row if ( datetime.datetime.now() - last_feedback_time ).total_seconds() >= feedback_interval: @@ -380,6 +410,11 @@ def get_records(self, context: Context | None) -> Iterable[dict[str, t.Any]]: last_feedback_time = datetime.datetime.now() except Exception: pass + else: + self.logger.warning( + "A message payload of %s could not be converted to JSON", + message.payload, + ) continue try: @@ -413,6 +448,33 @@ def get_records(self, context: Context | None) -> Iterable[dict[str, t.Any]]: logical_replication_cursor.close() logical_replication_connection.close() + @override + def _write_schema_message(self) -> None: + """Emit a SCHEMA message at most once per stream lifetime. + + ``TapPostgres._sync_log_based_streams_shared`` pre-writes schemas for every stream + up front (so RECORDs from siblings can't precede their own SCHEMA), and the SDK's + ``Stream.sync()`` calls this method again as part of the per-stream sync loop. + Without this guard, every LOG_BASED stream emits its SCHEMA twice. + """ + if self._schema_message_written: + return + super()._write_schema_message() + self._schema_message_written = True + + def emit_record(self, record: dict, *, context: Context | None = None) -> None: + """Emit one record as a Singer RECORD message and advance state. + + This is meant to decouple ``SingleConnectionWALReader`` from singer-sdk's + per-record internals. It does the following: stream map transformation => + type conformance => emission of one or more RECORD messages to stdout => + advancement of stream's replication bookmark. + + STATE message emission is *not* done here; that's the caller's responsibility. 
+ """ + self._write_record_message(record) + self._increment_stream_state(record, context=context) + def _advance_slot_and_state( self, replication_cursor: extras.ReplicationCursor, @@ -496,111 +558,78 @@ def _query_current_wal_lsn(self) -> int | None: self.logger.warning("Could not query current WAL LSN: %s", exc) return None - def consume(self, message, cursor: extras.ReplicationCursor) -> dict | None: - """Ingest WAL message.""" - try: - message_payload = json.loads(message.payload) - except json.JSONDecodeError: - # wal2json outputs PostgreSQL enum types with unescaped quotes - # e.g., "type":""EnumName"" instead of "type":"EnumName" - # Try to fix this by removing the extra quotes around type values - fixed_payload = self._fix_wal2json_enum_quotes(message.payload) - try: - message_payload = json.loads(fixed_payload) - except json.JSONDecodeError: - self.logger.warning( - "A message payload of %s could not be converted to JSON", - message.payload, - ) - return {} - - row = {} - - upsert_actions = {"I", "U"} - delete_actions = {"D"} - truncate_actions = {"T"} - transaction_actions = {"B", "C"} - - if message_payload["action"] in upsert_actions: - for column in message_payload["columns"]: - row.update({column["name"]: self._parse_column_value(column, cursor)}) - row.update({"_sdc_deleted_at": None}) - row.update({"_sdc_lsn": message.data_start}) - elif message_payload["action"] in delete_actions: - for column in message_payload["identity"]: - row.update({column["name"]: self._parse_column_value(column, cursor)}) - row.update( - { - "_sdc_deleted_at": _now_utc(), - "_sdc_lsn": message.data_start, - } - ) - elif message_payload["action"] in truncate_actions: + def consume(self, payload: dict, lsn: int) -> dict | None: + """Build a Singer row dict from a parsed wal2json payload. + + Returns: + A dict suitable for emission as a RECORD, or None for non-data messages + (truncate, transaction begin/commit, unrecognized actions that were logged) + """ + action = payload["action"] + + if action in _UPSERT_ACTIONS: + row = { + column["name"]: self._parse_column_value(column) for column in payload["columns"] + } + row["_sdc_deleted_at"] = None + row["_sdc_lsn"] = lsn + return row + + if action in _DELETE_ACTIONS: + row = { + column["name"]: self._parse_column_value(column) for column in payload["identity"] + } + row["_sdc_deleted_at"] = _now_utc() + row["_sdc_lsn"] = lsn + return row + + if action in _TRUNCATE_ACTIONS: self.logger.debug( - ( - "A message payload of %s (corresponding to a truncate action) " - "could not be processed." - ), - message.payload, + "A message payload of %s (corresponding to a truncate action) " + "could not be processed.", + payload, ) - elif message_payload["action"] in transaction_actions: + return None + + if action in _TRANSACTION_ACTIONS: self.logger.debug( - ( - "A message payload of %s (corresponding to a transaction beginning " - "or commit) could not be processed." - ), - message.payload, - ) - else: - raise RuntimeError( - ( - "A message payload of %s (corresponding to an unknown action type) " - "could not be processed." - ), - message.payload, + "A message payload of %s (corresponding to a transaction begin " + "or commit) could not be processed.", + payload, ) + return None - return row + raise RuntimeError( + f"A message payload of {payload!r} (corresponding to an unknown " + f"action type {action!r}) could not be processed." + ) - def _fix_wal2json_enum_quotes(self, payload: str) -> str: - """Fix malformed JSON from wal2json for PostgreSQL enum types. 
+ def _parse_column_value(self, column: dict) -> t.Any: + """Parse a single wal2json column dict into a Python value. - wal2json outputs enum type names with unescaped quotes, e.g.: - "type":""EnumName"" - This is invalid JSON. We fix it by removing the extra quotes: - "type":"EnumName" + Handles nullability, numeric empty strings, and ``text[]``. A raw ``text[]`` string + here means the pre-parse step was skipped -- a programming error -- but it is handled + gracefully with a last-resort parse. """ - return self._WAL2JSON_ENUM_QUOTE_RE.sub(r'"type":"\1"', payload) - - def _parse_column_value(self, column, cursor: extras.ReplicationCursor) -> t.Any: - # When using log based replication, the wal2json output for columns of - # array types returns a string encoded in sql format, e.g. '{a,b}' - # https://github.com/eulerto/wal2json/issues/221#issuecomment-1025143441 - if column["type"] == "text[]": - value = column.get("value") - if value is None: - return None - return psycopg2.extensions.STRINGARRAY(value, cursor) - - # Handle null values explicitly. - # wal2json represents nulls as JSON null, which becomes None in Python. value = column.get("value") if value is None: return None - # For numeric types, check if empty string should be treated as null. column_type = column.get("type", "") - numeric_types = [ - "int", - "numeric", - "decimal", - "real", - "double", - "float", - "bigint", - "smallint", - ] - if value == "" and any(numeric_type in column_type for numeric_type in numeric_types): + + if column_type == "text[]": + if isinstance(value, list): + return value + # fallback, reachable only if a caller forgot to call pre_parse_text_arrays + # STRINGARRAY with cursor=None works for UTF-8 connections + # which is the majority case + self.logger.warning( + "Encountered unparsed text[] value in _parse_column_value; falling back " + "to cursor-less STRINGARRAY parse. This indicates a missing call " + "to pre_parse_text_arrays()." + ) + return psycopg2.extensions.STRINGARRAY(value, None) + + if value == "" and any(numeric_type in column_type for numeric_type in _NUMERIC_TYPES): return None return value
diff --git a/tap_postgres/tap.py b/tap_postgres/tap.py index 782ef322..c7fa7427 100644 --- a/tap_postgres/tap.py +++ b/tap_postgres/tap.py
@@ -25,6 +25,7 @@ PostgresStream, ) from tap_postgres.connection_parameters import ConnectionParameters +from tap_postgres.wal_reader import SingleConnectionWALReader if TYPE_CHECKING: from collections.abc import Sequence
@@ -223,6 +224,7 @@ def __init__( See https://github.com/MeltanoLabs/tap-postgres/issues/141 """ super().__init__(*args, **kwargs) + self._shared_wal_run_completed = False assert (self.config.get("sqlalchemy_url") is not None) or ( self.config.get("host") is not None and self.config.get("port") is not None
@@ -532,6 +534,17 @@ def __init__( "this choice. One of `FULL_TABLE`, `INCREMENTAL`, or `LOG_BASED`." ), ), + th.Property( + "log_based_single_connection", + th.BooleanType, + default=False, + description=( + "Use a single replication connection to sync all LOG_BASED streams " + "in one pass over the WAL. This avoids redundant WAL scans when " + "multiple tables use LOG_BASED replication. Only applicable when " + "replication_method is LOG_BASED." + ), + ), ).to_dict() @cached_property
@@ -751,3 +764,42 @@ def discover_streams(self) -> Sequence[Stream]: else: streams.append(PostgresStream(self, catalog_entry, connector=connector)) return streams + + def _sync_log_based_streams_shared(self) -> None: + """Run the single-connection WAL reader across all selected LOG_BASED streams.
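+
+        Rough call sequence, assuming the ``log_based_single_connection`` flag is enabled::
+
+            PostgresLogBasedStream.get_records()
+                -> TapPostgres._sync_log_based_streams_shared()
+                    -> stream._write_schema_message()   # every LOG_BASED stream
+                    -> SingleConnectionWALReader.run()
+                    -> TapPostgres._write_state_checkpoint()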
+ + Called once per tap invocation, on the first call into ``PostgresLogBasedStream.get_records()`` + from any LOG_BASED stream. Sibling streams' ``get_records()`` calls become no-ops + via the ``_shared_wal_run_completed`` flag on the tap instance. + """ + streams = [ + s for s in self.streams.values() if isinstance(s, PostgresLogBasedStream) and s.selected + ] + if not streams: + return + + # schema messages for all LOG_BASED streams must be on the wire *before* + # any RECORD message, since ``Stream.sync()`` only writes schema for the single stream + # whose sync triggered the shared run; siblings would otherwise see RECORDs arrive before SCHEMA + for stream in streams: + stream._write_schema_message() + + reader = SingleConnectionWALReader( + connection_parameters=self.connection_parameters, + replication_slot_name=self.config["replication_slot_name"], + max_run_seconds=self.config["replication_max_run_seconds"], + idle_exit_seconds=self.config["replication_idle_exit_seconds"], + streams=streams, + state_flush_callback=self._write_state_checkpoint, + logger=self.logger, + ) + reader.run() + self._write_state_checkpoint() + + def _write_state_checkpoint(self) -> None: + """Emit a Singer STATE message reflecting current bookmarks. + + Called on a 30s cadence by ``SingleConnectionWALReader``, using the same + state-writing mechanism the SDK invokes between streams in the default sync loop. + """ + self._state_writer.write_state(self.state)
diff --git a/tap_postgres/wal_reader.py b/tap_postgres/wal_reader.py new file mode 100644 index 00000000..574ae2ed --- /dev/null +++ b/tap_postgres/wal_reader.py
@@ -0,0 +1,344 @@ +"""Single-connection WAL reader for LOG_BASED streams. + +One ``SingleConnectionWALReader`` replaces N per-stream replication connections. +It opens one ``LogicalReplicationConnection``, starts replication with ``add-tables`` +listing every LOG_BASED stream's table, and dispatches each incoming wal2json message +to the owning stream, which builds the row via ``consume()`` and emits it immediately +as a Singer RECORD via ``emit_record()``. +""" + +import datetime +import logging +import select +import typing as t +from collections.abc import Callable + +import psycopg2 +from psycopg2 import extras + +from tap_postgres._wal_helpers import build_add_tables_option, normalize_fqn, parse_wal_message + +if t.TYPE_CHECKING: + from tap_postgres.client import PostgresLogBasedStream + from tap_postgres.connection_parameters import ConnectionParameters + + +class SingleConnectionWALReader: + """Reads the WAL once and emits records from all LOG_BASED streams inline. + + Initialize one instance and call ``run()`` once per tap run.
Responsible for: + + - opening one ``LogicalReplicationConnection`` + - starting replication with ``start_lsn = min(bookmark for all streams)`` + and ``add-tables`` listing every registered table + - reading wal2json messages in a loop, bounded by ``replication_max_run_seconds`` + and exiting early after ``replication_idle_exit_seconds`` of no data + - parsing each message, dispatching to the correct stream via ``normalize_fqn()``, + and calling ``stream.emit_record(row)`` for records that pass the per-stream + LSN filter + - emitting STATE messages on a 30s cadence via ``state_flush_callback`` + - sending replication-slot feedback on a 30s cadence, using the maximum LSN + seen so far as the flush point + - advancing the slot and per-stream bookmarks to the server's current WAL tip + on exit (same as ``_advance_slot_and_state``, but for all streams at once) + + Per-stream LSN filtering: when ``min(start_lsn)`` is used to open the connection, + some messages will have LSN below some streams' individual bookmarks. Those messages + are already-processed for those streams and must be dropped, *not* re-emitted. + Each stream's own ``start_lsn`` is used to filter dispatches to it. + """ + + STATUS_INTERVAL = 10 # seconds between server keep-alives + SELECT_TIMEOUT = 1.0 # seconds to block in select() when no message + FEEDBACK_INTERVAL = 30 # seconds between send_feedback calls + STATE_FLUSH_INTERVAL = 30 # seconds between STATE message emissions + + def __init__( + self, + *, + connection_parameters: "ConnectionParameters", + replication_slot_name: str, + max_run_seconds: int, + idle_exit_seconds: int, + streams: list["PostgresLogBasedStream"], + state_flush_callback: Callable[[], None], + logger: logging.Logger, + ) -> None: + """Initialize WAL reader. + + Args: + connection_parameters: Database connection parameters (shared with the tap) + replication_slot_name: Name of the wal2json replication slot + max_run_seconds: Hard upper bound on run duration + idle_exit_seconds: Exit if no data messages arrive for this long + streams: All selected LOG_BASED streams. Must be non-empty. + state_flush_callback: Called every ``STATE_FLUSH_INTERVAL`` seconds to emit + a Singer STATE message. The callback reads the tap's current state dict + and writes the message. + logger: Logger for progress and slot-advancement messages. + """ + if not streams: + raise ValueError("SingleConnectionWALReader requires ≥1 stream") + + self._connection_parameters = connection_parameters + self._replication_slot_name = replication_slot_name + self._max_run_seconds = max_run_seconds + self._idle_exit_seconds = idle_exit_seconds + self._state_flush_callback = state_flush_callback + self._logger = logger + + # dispatch mapping: normalized FQN => (stream, start_lsn) + # start_lsn is captured at construction so filtering is cheap + # and doesn't re-read state mid-run + self._streams_by_fqn: dict[str, tuple[PostgresLogBasedStream, int]] = {} + for stream in streams: + fqn_obj = stream.fully_qualified_name + if fqn_obj.schema is None: + raise ValueError( + f"Stream {stream.name!r} has no schema in its fully-qualified name; " + f"cannot register with the single-connection WAL reader" + ) + fqn = normalize_fqn(fqn_obj.schema, fqn_obj.table) + bookmark = stream.get_starting_replication_key_value(context=None) + start_lsn = bookmark if bookmark is not None else 0 + if fqn in self._streams_by_fqn: + raise ValueError( + f"Duplicate fully-qualified name {fqn!r} among LOG_BASED " + f"streams; each table may be selected only once." 
+ ) + self._streams_by_fqn[fqn] = (stream, start_lsn) + + self.records_emitted = 0 + self.records_filtered_by_lsn = 0 + self.records_unroutable = 0 + self.records_malformed = 0 + # per-FQN counters, useful for "is stream X actually getting any data?" debugging + # initialized for every registered stream so the dict is complete even if zero records + self.records_emitted_by_fqn: dict[str, int] = dict.fromkeys(self._streams_by_fqn, 0) + + def run(self) -> None: + """Execute single-connection WAL read loop. + + This is synchronous and blocks until either: no data message has arrived + for ``idle_exit_seconds`` OR ``max_run_seconds`` has elapsed. On exit, advances + replication slot to the current WAL tip and updates streams' bookmarks to that tip. + """ + global_start_lsn = min(start_lsn for _, start_lsn in self._streams_by_fqn.values()) + fqn_objs = [stream.fully_qualified_name for stream, _ in self._streams_by_fqn.values()] + add_tables = build_add_tables_option( + [(fqn_obj.schema, fqn_obj.table) for fqn_obj in fqn_objs] + ) + self._logger.info( + "Starting single-connection WAL read for %d stream(s) from LSN %d", + len(self._streams_by_fqn), + global_start_lsn, + ) + + conn = psycopg2.connect( + self._connection_parameters.render_as_psycopg2_dsn(), + connection_factory=extras.LogicalReplicationConnection, + ) + cursor = conn.cursor() + try: + cursor.send_feedback(flush_lsn=global_start_lsn) + cursor.start_replication( + slot_name=self._replication_slot_name, + decode=True, + start_lsn=global_start_lsn, + status_interval=self.STATUS_INTERVAL, + options={ + "format-version": 2, + "include-transaction": False, + "add-tables": add_tables, + }, + ) + self._run_loop(cursor) + self._advance_slot_and_state_all(cursor, global_start_lsn) + finally: + cursor.close() + conn.close() + + self._logger.info( + "WAL read complete: %d records emitted, %d filtered by per-stream LSN, " + "%d unroutable, %d malformed", + self.records_emitted, + self.records_filtered_by_lsn, + self.records_unroutable, + self.records_malformed, + ) + self._logger.info( + "Per-stream record counts: %s", + {fqn: self.records_emitted_by_fqn[fqn] for fqn in self._streams_by_fqn}, + ) + + def _run_loop(self, cursor: extras.ReplicationCursor) -> None: + """Inner read / dispatch / periodic-flush loop.""" + run_start = datetime.datetime.now() + last_data_message = run_start + last_feedback = run_start + last_state_flush = run_start + max_lsn_seen = 0 + + while True: + now = datetime.datetime.now() + # total time budget check + if (now - run_start).total_seconds() > self._max_run_seconds: + self._logger.info( + "Reached max run time of %d seconds (%d records emitted)", + self._max_run_seconds, + self.records_emitted, + ) + break + + # periodic STATE emission + if (now - last_state_flush).total_seconds() >= self.STATE_FLUSH_INTERVAL: + self._state_flush_callback() + last_state_flush = now + + # periodic replication-slot feedback + if max_lsn_seen > 0 and (now - last_feedback).total_seconds() >= self.FEEDBACK_INTERVAL: + try: + cursor.send_feedback(flush_lsn=max_lsn_seen) + last_feedback = now + except Exception as exc: + self._logger.warning("send_feedback failed: %s", exc) + + # read the next WAL message + message = cursor.read_message() + if message is not None: + last_data_message = datetime.datetime.now() + self._dispatch(cursor, message) + max_lsn_seen = max(max_lsn_seen, message.data_start) + continue + + # no message available -- block briefly and check idle exit + try: + ready = select.select([cursor], [], [], 
self.SELECT_TIMEOUT)[0] + except InterruptedError: + ready = [cursor] + + if not ready: + data_idle = (datetime.datetime.now() - last_data_message).total_seconds() + if data_idle >= self._idle_exit_seconds: + self._logger.info( + "No data for %.0f s, ending WAL read (%d records emitted in %.0f s)", + data_idle, + self.records_emitted, + (datetime.datetime.now() - run_start).total_seconds(), + ) + break + + def _dispatch( + self, cursor: extras.ReplicationCursor, message: extras.ReplicationMessage + ) -> None: + """Parse one WAL message and hand it to the owning stream.""" + # parse (+pre-parse text[] values) while the cursor is alive + payload = parse_wal_message(message.payload, cursor) + if payload is None: + self._logger.warning( + "A message payload of %s could not be converted to JSON", + message.payload, + ) + self.records_malformed += 1 + return + + # non-data messages (transactions, truncates) have no schema/table + # consume() returns None for them and we're done + schema_name = payload.get("schema") + table_name = payload.get("table") + if schema_name is None or table_name is None: + return + + fqn = normalize_fqn(schema_name, table_name) + routed = self._streams_by_fqn.get(fqn) + if routed is None: + # this should never happen: add-tables filters at the server, so we only receive + # messages for registered tables... count it and move on; a non-zero counter + # in logs is a signal to investigate (e.g. a normalize_fqn format mismatch) + self.records_unroutable += 1 + self._logger.debug("Received message for unregistered table %s; dropping", fqn) + return + + stream, stream_start_lsn = routed + + # per-stream LSN filter: because start_replication was opened at min(start_lsn), + # streams with higher bookmarks will see some already-processed messages + # it's safe to drop them silently + if message.data_start < stream_start_lsn: + self.records_filtered_by_lsn += 1 + return + + row = stream.consume(payload, message.data_start) + if not row: + return + + stream.emit_record(row) + self.records_emitted += 1 + self.records_emitted_by_fqn[fqn] += 1 + + def _advance_slot_and_state_all(self, cursor: extras.ReplicationCursor, start_lsn: int) -> None: + """Advance slot to the WAL tip and update every stream's bookmark. + + Mirrors ``PostgresLogBasedStream._advance_slot_and_state`` but applies resulting LSN + to every registered stream. + """ + # prefer server-reported wal_end if it's ahead of start_lsn, otherwise query the server + flush_lsn: int | None = None + try: + wal_end = getattr(cursor, "wal_end", None) + if wal_end is not None and wal_end > start_lsn: + flush_lsn = wal_end + except Exception: + pass + + if flush_lsn is None or flush_lsn <= start_lsn: + flush_lsn = self._query_current_wal_lsn() + + if flush_lsn is None or flush_lsn <= start_lsn: + return + + try: + cursor.send_feedback(flush_lsn=flush_lsn) + except Exception as exc: + self._logger.warning("Final send_feedback failed: %s", exc) + return + + self._logger.info( + "Advanced replication slot from %d to %d (delta %.2f MB)", + start_lsn, + flush_lsn, + (flush_lsn - start_lsn) / (1024 * 1024), + ) + + # update every stream's bookmark to the advanced LSN + # for streams whose per-stream start_lsn was already above this value, skip -- + # don't move bookmarks backward! 
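+        # (hypothetical numbers: a flush_lsn of 900 must not overwrite a bookmark already at 1000)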
+ for stream, stream_start_lsn in self._streams_by_fqn.values(): + if flush_lsn <= stream_start_lsn: + continue + state = stream.get_context_state(context=None) + state["replication_key"] = stream.replication_key + state["replication_key_value"] = flush_lsn + + # one final STATE emission so the next run picks up the advance + self._state_flush_callback() + + def _query_current_wal_lsn(self) -> int | None: + """Query ``pg_current_wal_flush_lsn()`` on a non-replication conn.""" + try: + conn = psycopg2.connect( + self._connection_parameters.render_as_psycopg2_dsn(), + ) + try: + conn.autocommit = True + with conn.cursor() as cur: + cur.execute("SELECT pg_current_wal_flush_lsn()") + row = cur.fetchone() + if row is None: + return None + hi, lo = row[0].split("/") + return (int(hi, 16) << 32) + int(lo, 16) + finally: + conn.close() + except Exception as exc: + self._logger.warning("Could not query current WAL LSN: %s", exc) + return None diff --git a/tests/test_consume.py b/tests/test_consume.py new file mode 100644 index 00000000..2cc31eb3 --- /dev/null +++ b/tests/test_consume.py @@ -0,0 +1,153 @@ +"""Unit tests for parsing helpers and ``PostgresLogBasedStream.consume()``. + +No Postgres needed. Hand-crafted wal2json payloads cover the action-type branches +in ``consume()`` and the recovery paths in ``parse_wal_message``. +""" + +from __future__ import annotations + +import pytest +import sqlalchemy as sa +from singer_sdk.singerlib import CatalogEntry, MetadataMapping, Schema + +from tap_postgres._wal_helpers import parse_wal_message +from tap_postgres.client import PostgresConnector, PostgresLogBasedStream +from tap_postgres.connection_parameters import ConnectionParameters +from tap_postgres.tap import TapPostgres + +DUMMY_CONFIG = { + "user": "postgres", + "password": "postgres", + "host": "localhost", + "port": 5432, + "database": "postgres", +} + + +# TODO: should this be a shared fixture/function in conftest? 
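+# (a second copy currently lives in tests/test_wal_reader.py, which strengthens the case)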
+class DummyConnector(PostgresConnector): + """Connector that doesn't talk to a real database.""" + + def __init__(self, config: dict) -> None: + params = ConnectionParameters.from_tap_config(config) + super().__init__(config, params.render_as_sqlalchemy_url()) + + def get_table(self, full_table_name, column_names=None): + return sa.Table("dummy", sa.MetaData(), sa.Column("id", sa.Integer)) + + +@pytest.fixture +def stream() -> PostgresLogBasedStream: + """A ``PostgresLogBasedStream`` wired against a stub connector.""" + tap = TapPostgres(config=DUMMY_CONFIG, setup_mapper=False) + catalog_entry = CatalogEntry( + tap_stream_id="public-users", + metadata=MetadataMapping.from_iterable( + [ + { + "breadcrumb": [], + "metadata": { + "inclusion": "available", + "selected": True, + "schema-name": "public", + }, + }, + { + "breadcrumb": ["properties", "id"], + "metadata": {"inclusion": "available", "selected": True}, + }, + ] + ), + schema=Schema(properties={"id": Schema(type=["integer", "null"])}, type="object"), + table="users", + ) + return PostgresLogBasedStream( + tap, + catalog_entry.to_dict(), + connection_parameters=ConnectionParameters.from_tap_config(DUMMY_CONFIG), + connector=DummyConnector(config=DUMMY_CONFIG), + ) + + +@pytest.mark.parametrize("action", ["I", "U"], ids=["insert", "update"]) +def test_consume_upsert_returns_row_with_sdc_columns(stream, action): + """Inserts and updates have identical row-construction semantics.""" + payload = { + "action": action, + "schema": "public", + "table": "users", + "columns": [ + {"name": "id", "type": "integer", "value": 42}, + {"name": "name", "type": "text", "value": "alice"}, + ], + } + assert stream.consume(payload, lsn=12345) == { + "id": 42, + "name": "alice", + "_sdc_deleted_at": None, + "_sdc_lsn": 12345, + } + + +def test_consume_delete_uses_identity_and_sets_deleted_at(stream): + payload = { + "action": "D", + "schema": "public", + "table": "users", + "identity": [{"name": "id", "type": "integer", "value": 5}], + } + row = stream.consume(payload, lsn=55) + assert row["id"] == 5 + assert row["_sdc_lsn"] == 55 + # a stringly-typed UTC ISO timestamp is set; exact value is time-dependent + assert isinstance(row["_sdc_deleted_at"], str) + assert row["_sdc_deleted_at"].endswith("Z") + + +@pytest.mark.parametrize( + "action", + ["T", "B", "C"], + ids=["truncate", "transaction-begin", "transaction-commit"], +) +def test_consume_non_data_actions_return_none(stream, action): + """Truncate and transaction begin/commit are non-data and produce no row.""" + assert stream.consume({"action": action}, lsn=1) is None + + +def test_consume_unknown_action_raises(stream): + with pytest.raises(RuntimeError, match="unknown action"): + stream.consume({"action": "X", "columns": []}, lsn=1) + + +@pytest.mark.parametrize( + "column_type", + ["integer", "numeric(10,2)", "bigint", "double precision"], +) +def test_consume_numeric_empty_string_becomes_none(stream, column_type): + """wal2json sometimes emits ``""`` for numeric columns; treat as NULL.""" + payload = { + "action": "I", + "schema": "public", + "table": "users", + "columns": [ + {"name": "id", "type": "integer", "value": 1}, + {"name": "amount", "type": column_type, "value": ""}, + ], + } + assert stream.consume(payload, lsn=1)["amount"] is None + + +@pytest.mark.parametrize( + ["raw", "expected"], + [ + ('{"action":"B"}', {"action": "B"}), + ( + '{"action":"I","columns":[{"name":"c","type":""MyEnum"","value":"X"}]}', + {"action": "I", "columns": [{"name": "c", "type": "MyEnum", "value": "X"}]}, + 
), + ("{not json{", None), + ], + ids=["valid", "enum-quote-bug-recovered", "unrecoverable"], +) +def test_parse_wal_message(raw, expected): + assert parse_wal_message(raw, cursor=None) == expected diff --git a/tests/test_wal_helpers.py b/tests/test_wal_helpers.py new file mode 100644 index 00000000..65d95db7 --- /dev/null +++ b/tests/test_wal_helpers.py @@ -0,0 +1,44 @@ +import pytest + +from tap_postgres._wal_helpers import build_add_tables_option, escape_for_add_tables, normalize_fqn + + +@pytest.mark.parametrize( + ["schema", "table", "expected"], + [ + ("public", "users", "public.users"), + ("Public", "MyTable", "Public.MyTable"), + ("weird.schema", "t", "weird.schema.t"), + ], + ids=["basic", "preserve-case", "schema-with-dot"], +) +def test_normalize_fqn(schema, table, expected): + assert normalize_fqn(schema, table) == expected + + +@pytest.mark.parametrize( + ["identifier", "expected"], + [ + ("users", "users"), + ("a,b", "a\\,b"), + ("a.b", "a\\.b"), + ("a\\b", "a\\\\b"), + ("a\\,b", "a\\\\\\,b"), + ], + ids=["basic", "comma", "dot", "backslash", "escaping-order"], +) +def test_escape_for_add_tables(identifier, expected): + assert escape_for_add_tables(identifier) == expected + + +@pytest.mark.parametrize( + ["fqn_pairs", "expected"], + [ + ([("public", "users")], "public.users"), + ([("public", "users"), ("app", "orders")], "public.users,app.orders"), + ([("my,schema", "tbl.name")], "my\\,schema.tbl\\.name"), + ], + ids=["single", "multiple", "special-chars"], +) +def test_build_add_tables_option(fqn_pairs, expected): + assert build_add_tables_option(fqn_pairs) == expected diff --git a/tests/test_wal_reader.py b/tests/test_wal_reader.py new file mode 100644 index 00000000..91e294da --- /dev/null +++ b/tests/test_wal_reader.py @@ -0,0 +1,643 @@ +"""Unit tests for the single-connection WAL reader. + +No Postgres DB necessary! Stubbed streams have the same interface as ``PostgresLogBasedStream``. +Tests for ``emit_record`` and the tap's LOG_BASED dispatch use ``PostgresLogBasedStream`` +against a ``DummyConnector``. Tests patch ``tap_postgres.wal_reader.psycopg2.connect`` +and ``tap_postgres.wal_reader.select.select`` so the read loop runs against in-memory fakes. +""" + +from __future__ import annotations + +import json +import logging +from contextlib import contextmanager +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +import sqlalchemy as sa +from singer_sdk.singerlib import CatalogEntry, MetadataMapping, Schema + +from tap_postgres.client import PostgresConnector, PostgresLogBasedStream +from tap_postgres.connection_parameters import ConnectionParameters +from tap_postgres.tap import TapPostgres +from tap_postgres.wal_reader import SingleConnectionWALReader + +# fake replication primitives + +DUMMY_CONFIG = { + "user": "postgres", + "password": "postgres", + "host": "localhost", + "port": 5432, + "database": "postgres", +} + + +class FakeReplicationMessage: + """Stand-in for ``psycopg2.extras.ReplicationMessage``.""" + + def __init__(self, payload: str, data_start: int) -> None: + self.payload = payload + self.data_start = data_start + + +class FakeReplicationCursor: + """ + Minimal stand-in for ``psycopg2.extras.ReplicationCursor``. + + Returns scripted messages from ``read_message()``. Returns None when exhausted, + so reader's idle-exit path fires (with a patched ``select.select`` returning empty). 
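+
+    Typical wiring in these tests (names from this file)::
+
+        cursor = FakeReplicationCursor(messages=[...], wal_end=100)
+        with patch_replication(cursor):
+            reader.run()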
+ """ + + def __init__( + self, messages: list[FakeReplicationMessage] | None = None, wal_end: int = 0 + ) -> None: + self._messages = list(messages or []) + self.feedback_lsns: list[int] = [] + self.start_options: dict | None = None + self.started = False + self.wal_end = wal_end + self.closed = False + + def send_feedback(self, *, flush_lsn: int) -> None: + self.feedback_lsns.append(flush_lsn) + + def start_replication(self, **kwargs) -> None: + self.started = True + self.start_options = kwargs + + def read_message(self): + if self._messages: + msg = self._messages.pop(0) + self.wal_end = max(self.wal_end, msg.data_start) + return msg + return None + + def close(self) -> None: + self.closed = True + + def fileno(self) -> int: + # select.select is patched out, but some env may still call this + return 0 + + +class FakeReplicationConnection: + """Stand-in for the connection returned by ``psycopg2.connect``.""" + + def __init__(self, cursor: FakeReplicationCursor) -> None: + self._cursor = cursor + self.closed = False + + def cursor(self) -> FakeReplicationCursor: + return self._cursor + + def close(self) -> None: + self.closed = True + + +@contextmanager +def patch_replication(cursor: FakeReplicationCursor): + """ + Patch ``psycopg2.connect`` and ``select.select`` in ``wal_reader``. + + ``select.select`` is patched to immediately return "no readiness" so the + idle-exit branch fires when the scripted message list is exhausted. + """ + conn = FakeReplicationConnection(cursor) + with ( + patch("tap_postgres.wal_reader.psycopg2.connect", return_value=conn) as p_connect, + patch("tap_postgres.wal_reader.select.select", return_value=([], [], [])) as p_select, + ): + yield p_connect, p_select + + +# stub stream -- mimics interface that PostgresLogBasedStream exposes to the reader + + +class StubStream: + """ + In-memory stand-in for ``PostgresLogBasedStream`` for reader unit tests. + Implements only the surface area touched by the reader! 
+ """ + + replication_key = "_sdc_lsn" + + def __init__(self, schema: str, table: str, start_lsn: int = 0) -> None: + self.fully_qualified_name = SimpleNamespace(schema=schema, table=table) + self.name = f"{schema}-{table}" + self._start_lsn = start_lsn + self._state: dict = {} + self.emitted: list[dict] = [] + + def get_starting_replication_key_value(self, *, context=None): + return self._start_lsn + + def consume(self, payload: dict, lsn: int) -> dict | None: + action = payload.get("action") + if action in ("I", "U"): + row = {c["name"]: c["value"] for c in payload.get("columns", [])} + row["_sdc_deleted_at"] = None + row["_sdc_lsn"] = lsn + return row + if action == "D": + row = {c["name"]: c["value"] for c in payload.get("identity", [])} + row["_sdc_deleted_at"] = "2024-01-01T00:00:00Z" + row["_sdc_lsn"] = lsn + return row + # truncate / transaction — non-data + return None + + def emit_record(self, record: dict, *, context=None) -> None: + self.emitted.append(record) + + def get_context_state(self, context): + return self._state + + +def _wal_payload(schema: str, table: str, action: str = "I", **columns) -> str: + """Build a wal2json format-version=2 JSON payload.""" + cols = [{"name": k, "type": "int", "value": v} for k, v in columns.items()] + body: dict = {"action": action, "schema": schema, "table": table} + if action == "D": + body["identity"] = cols + elif action in ("I", "U"): + body["columns"] = cols + return json.dumps(body) + + +def _build_reader(streams, *, max_run=60, idle_exit=0, slot="testslot"): + """Construct a ``SingleConnectionWALReader`` with sensible defaults.""" + return SingleConnectionWALReader( + connection_parameters=ConnectionParameters.from_tap_config(DUMMY_CONFIG), + replication_slot_name=slot, + max_run_seconds=max_run, + idle_exit_seconds=idle_exit, + streams=streams, + state_flush_callback=MagicMock(), + logger=logging.getLogger("test_wal_reader"), + ) + + +# dummy connector / real-stream helpers -- for emit_record and tap-level tests + + +class DummyConnector(PostgresConnector): + """ + Connector that doesn't talk to a real database. + Mirrors ``DummyConnector`` in ``tests/test_stream_class.py``. 
+ """ + + def __init__(self, config: dict) -> None: + params = ConnectionParameters.from_tap_config(config) + super().__init__(config, params.render_as_sqlalchemy_url()) + + def get_table(self, full_table_name, column_names=None): + return sa.Table("dummy", sa.MetaData(), sa.Column("id", sa.Integer)) + + +def _build_log_based_stream( + tap: TapPostgres, *, schema_name: str, table_name: str, stream_id: str | None = None +) -> PostgresLogBasedStream: + """Build a real ``PostgresLogBasedStream`` against a stub connector.""" + catalog_entry = CatalogEntry( + tap_stream_id=stream_id or f"{schema_name}-{table_name}", + metadata=MetadataMapping.from_iterable( + [ + { + "breadcrumb": [], + "metadata": { + "inclusion": "available", + "selected": True, + "schema-name": schema_name, + }, + }, + { + "breadcrumb": ["properties", "id"], + "metadata": {"inclusion": "available", "selected": True}, + }, + ] + ), + schema=Schema(properties={"id": Schema(type=["integer", "null"])}, type="object"), + table=table_name, + ) + return PostgresLogBasedStream( + tap, + catalog_entry.to_dict(), + connection_parameters=ConnectionParameters.from_tap_config(DUMMY_CONFIG), + connector=DummyConnector(config=DUMMY_CONFIG), + ) + + +# the actual tests, finally + + +def test_construction_rejects_empty_stream_list(): + """An empty stream list is a programming error; reject it loudly.""" + with pytest.raises(ValueError, match="≥1 stream"): + SingleConnectionWALReader( + connection_parameters=ConnectionParameters.from_tap_config(DUMMY_CONFIG), + replication_slot_name="s", + max_run_seconds=1, + idle_exit_seconds=0, + streams=[], + state_flush_callback=MagicMock(), + logger=logging.getLogger("test"), + ) + + +def test_construction_rejects_duplicate_fqn(): + """Two streams pointing at the same table is a misconfiguration.""" + s1 = StubStream("public", "users", start_lsn=10) + s2 = StubStream("public", "users", start_lsn=20) + with pytest.raises(ValueError, match="Duplicate fully-qualified name"): + _build_reader([s1, s2]) + + +def test_start_replication_uses_min_start_lsn_and_escaped_add_tables(): + """``add-tables`` must be the escaped FQN list; ``start_lsn`` the min.""" + s1 = StubStream("public", "users", start_lsn=200) + s2 = StubStream("my,schema", "tbl.name", start_lsn=50) # special chars + reader = _build_reader([s1, s2]) + + cursor = FakeReplicationCursor(messages=[], wal_end=300) + with patch_replication(cursor): + reader.run() + + assert cursor.started is True + assert cursor.start_options["start_lsn"] == 50 # min of (200, 50) + assert cursor.start_options["slot_name"] == "testslot" + assert cursor.start_options["decode"] is True + options = cursor.start_options["options"] + assert options["format-version"] == 2 + assert options["include-transaction"] is False + # FQNs are escape_for_add_tables()'d and joined with commas + add_tables = options["add-tables"] + assert "public.users" in add_tables + assert "my\\,schema.tbl\\.name" in add_tables + + +def test_routes_message_to_correct_stream_by_fqn(): + """A message for ``schema.table`` ends up only on that stream.""" + s_users = StubStream("public", "users") + s_orders = StubStream("public", "orders") + reader = _build_reader([s_users, s_orders]) + + msgs = [ + FakeReplicationMessage(_wal_payload("public", "users", id=1), data_start=10), + FakeReplicationMessage(_wal_payload("public", "orders", id=99), data_start=11), + FakeReplicationMessage(_wal_payload("public", "users", id=2), data_start=12), + ] + cursor = FakeReplicationCursor(messages=msgs, wal_end=12) + with 
patch_replication(cursor): + reader.run() + + assert [r["id"] for r in s_users.emitted] == [1, 2] + assert [r["id"] for r in s_orders.emitted] == [99] + assert reader.records_emitted == 3 + + +def test_drops_messages_below_per_stream_start_lsn(): + """Stream B with start_lsn=200 must NOT see a message at LSN 150.""" + s_a = StubStream("public", "a", start_lsn=100) + s_b = StubStream("public", "b", start_lsn=200) + reader = _build_reader([s_a, s_b]) + + msgs = [ + # Below B's start_lsn but above A's: A sees it, B does not. + FakeReplicationMessage(_wal_payload("public", "b", id=1), data_start=150), + # Above both — both eligible (this is a B message, B sees it). + FakeReplicationMessage(_wal_payload("public", "b", id=2), data_start=250), + # A always eligible above 100. + FakeReplicationMessage(_wal_payload("public", "a", id=3), data_start=120), + ] + cursor = FakeReplicationCursor(messages=msgs, wal_end=250) + with patch_replication(cursor): + reader.run() + + assert [r["id"] for r in s_a.emitted] == [3] + assert [r["id"] for r in s_b.emitted] == [2] + assert reader.records_filtered_by_lsn == 1 + assert reader.records_emitted == 2 + + +def test_unroutable_message_increments_counter_and_does_not_crash(): + """A payload whose schema/table doesn't match any registered stream is counted.""" + s = StubStream("public", "users") + reader = _build_reader([s]) + + msgs = [ + FakeReplicationMessage(_wal_payload("public", "ghosts", id=1), data_start=10), + FakeReplicationMessage(_wal_payload("public", "users", id=2), data_start=11), + ] + cursor = FakeReplicationCursor(messages=msgs, wal_end=11) + with patch_replication(cursor): + reader.run() + + assert s.emitted == [{"id": 2, "_sdc_deleted_at": None, "_sdc_lsn": 11}] + assert reader.records_unroutable == 1 + assert reader.records_emitted == 1 + + +def test_truncate_and_transaction_messages_do_not_emit(): + """Action ``T`` (truncate) and ``B``/``C`` (transaction) yield no records.""" + s = StubStream("public", "users") + reader = _build_reader([s]) + + msgs = [ + # Truncate has schema/table, but consume() returns None. + FakeReplicationMessage( + json.dumps({"action": "T", "schema": "public", "table": "users"}), + data_start=10, + ), + # Transaction begin/commit have no schema/table — dropped before consume. + FakeReplicationMessage(json.dumps({"action": "B"}), data_start=11), + FakeReplicationMessage(json.dumps({"action": "C"}), data_start=12), + # Real data message after, just to confirm the loop kept running. + FakeReplicationMessage(_wal_payload("public", "users", id=42), data_start=13), + ] + cursor = FakeReplicationCursor(messages=msgs, wal_end=13) + with patch_replication(cursor): + reader.run() + + assert s.emitted == [{"id": 42, "_sdc_deleted_at": None, "_sdc_lsn": 13}] + assert reader.records_emitted == 1 + + +def test_periodic_state_flush_fires_on_cadence(monkeypatch): + """``state_flush_callback`` fires once the STATE_FLUSH_INTERVAL has elapsed. + + Rather than fake the clock, drive the cadence by setting the interval to + zero so the flush happens every iteration that processes a message. 
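+    (The asserted lower bound is deliberately lenient: ``_advance_slot_and_state_all``
+    adds one more flush after the loop exits.)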
+ """ + monkeypatch.setattr(SingleConnectionWALReader, "STATE_FLUSH_INTERVAL", 0) + s = StubStream("public", "users") + reader = _build_reader([s]) + + msgs = [ + FakeReplicationMessage(_wal_payload("public", "users", id=i), data_start=i) + for i in range(1, 4) + ] + cursor = FakeReplicationCursor(messages=msgs, wal_end=10) + with patch_replication(cursor): + reader.run() + + # 3 in-loop flushes + 1 in _advance_slot_and_state_all = 4 -- be lenient + assert reader._state_flush_callback.call_count >= 2 + + +def test_send_feedback_uses_max_lsn_seen_on_cadence(monkeypatch): + """When the feedback interval elapses, send_feedback uses max_lsn_seen.""" + monkeypatch.setattr(SingleConnectionWALReader, "FEEDBACK_INTERVAL", 0) + s = StubStream("public", "users") + reader = _build_reader([s]) + + msgs = [ + FakeReplicationMessage(_wal_payload("public", "users", id=1), data_start=100), + FakeReplicationMessage(_wal_payload("public", "users", id=2), data_start=200), + ] + cursor = FakeReplicationCursor(messages=msgs, wal_end=300) + with patch_replication(cursor): + reader.run() + + # initial feedback at start_lsn=0, then in-loop feedbacks once max_lsn>0, + # then the final advance feedback; assert max_lsn_seen=200 made it in + assert 200 in cursor.feedback_lsns + + +def test_idle_exit_advances_slot_and_state_for_all_streams(): + """On idle-exit, every stream's ``replication_key_value`` advances to wal_end.""" + s_a = StubStream("public", "a", start_lsn=10) + s_b = StubStream("public", "b", start_lsn=20) + reader = _build_reader([s_a, s_b], idle_exit=0) + + msgs = [ + FakeReplicationMessage(_wal_payload("public", "a", id=1), data_start=50), + FakeReplicationMessage(_wal_payload("public", "b", id=2), data_start=60), + ] + advanced_to = 999 + cursor = FakeReplicationCursor(messages=msgs, wal_end=advanced_to) + with patch_replication(cursor): + reader.run() + + assert s_a.get_context_state(None)["replication_key_value"] == advanced_to + assert s_b.get_context_state(None)["replication_key_value"] == advanced_to + assert s_a.get_context_state(None)["replication_key"] == "_sdc_lsn" + # final feedback should have been sent at the advanced LSN + assert advanced_to in cursor.feedback_lsns + + +def test_max_run_time_exit_advances_slot_and_state(): + """Same advancement path runs when ``max_run_seconds`` is exceeded. + + With ``max_run_seconds = -1``, the very first iteration's time check is always true, + so the loop breaks before any reads -- exercising the max-run exit path deterministically. 
+ """ + s = StubStream("public", "users", start_lsn=10) + reader = _build_reader([s], max_run=-1, idle_exit=10_000) + + cursor = FakeReplicationCursor(messages=[], wal_end=777) + with patch_replication(cursor): + reader.run() + + assert s.get_context_state(None)["replication_key_value"] == 777 + + +def test_emit_record_writes_record_message_and_advances_state(): + """``emit_record`` sends one RECORD message and bumps the LSN bookmark.""" + tap = TapPostgres(config=DUMMY_CONFIG, setup_mapper=False) + stream = _build_log_based_stream(tap, schema_name="public", table_name="users") + + # patch SDK call we depend on; we're only asserting the contract + stream._write_record_message = MagicMock() + + record = {"id": 1, "_sdc_deleted_at": None, "_sdc_lsn": 12345} + stream.emit_record(record) + + stream._write_record_message.assert_called_once_with(record) + state = stream.get_context_state(None) + assert state["replication_key"] == "_sdc_lsn" + assert state["replication_key_value"] == 12345 + + +def test_emit_record_does_not_move_bookmark_backward(): + """A record with an LSN below the current bookmark must not regress state.""" + tap = TapPostgres(config=DUMMY_CONFIG, setup_mapper=False) + stream = _build_log_based_stream(tap, schema_name="public", table_name="users") + stream._write_record_message = MagicMock() + + state = stream.get_context_state(None) + state["replication_key"] = "_sdc_lsn" + state["replication_key_value"] = 1000 + + stream.emit_record({"id": 1, "_sdc_deleted_at": None, "_sdc_lsn": 500}) + assert state["replication_key_value"] == 1000 # unchanged + + stream.emit_record({"id": 2, "_sdc_deleted_at": None, "_sdc_lsn": 2000}) + assert state["replication_key_value"] == 2000 # forward only + + +def test_schema_messages_emitted_before_any_record_message(): + """Every stream's SCHEMA must hit the wire before any of its RECORDs. + + Two LOG_BASED streams; WAL records interleaved between them. We mock each stream's + ``_write_schema_message`` and ``emit_record`` to record into a shared event log, + then assert that the first emit_record event for each stream is preceded + by at least one schema event for that stream. + """ + config = { + **DUMMY_CONFIG, + # force reader's main loop to exit immediately once read_message returns None + # i.e. 
as soon as scripted messages are exhausted + "replication_idle_exit_seconds": 0, + } + tap = TapPostgres(config=config, setup_mapper=False) + s_users = _build_log_based_stream(tap, schema_name="public", table_name="users") + s_orders = _build_log_based_stream(tap, schema_name="public", table_name="orders") + + # inject streams into the tap, bypassing discover_streams and cached connector property + # which would try to reach a real db + tap._streams = {s_users.name: s_users, s_orders.name: s_orders} + tap.connection_parameters = ConnectionParameters.from_tap_config(config) + tap._write_state_checkpoint = MagicMock() # avoid writing to stdout + + events: list[tuple[str, str]] = [] + + def schema_writer(stream_name): + return lambda: events.append(("SCHEMA", stream_name)) + + def record_writer(stream_name): + def _emit(record, *, context=None): + events.append(("RECORD", stream_name)) + + return _emit + + s_users._write_schema_message = schema_writer(s_users.name) + s_orders._write_schema_message = schema_writer(s_orders.name) + s_users.emit_record = record_writer(s_users.name) + s_orders.emit_record = record_writer(s_orders.name) + + msgs = [ + FakeReplicationMessage(_wal_payload("public", "users", id=1), data_start=10), + FakeReplicationMessage(_wal_payload("public", "orders", id=99), data_start=11), + FakeReplicationMessage(_wal_payload("public", "users", id=2), data_start=12), + FakeReplicationMessage(_wal_payload("public", "orders", id=100), data_start=13), + ] + cursor = FakeReplicationCursor(messages=msgs, wal_end=20) + with patch_replication(cursor): + tap._sync_log_based_streams_shared() + + # find index of the first RECORD per stream + def _first_record(name): + for i, (kind, sn) in enumerate(events): + if kind == "RECORD" and sn == name: + return i + raise AssertionError(f"no RECORD event for {name}") + + def _first_schema(name): + for i, (kind, sn) in enumerate(events): + if kind == "SCHEMA" and sn == name: + return i + raise AssertionError(f"no SCHEMA event for {name}") + + # both streams' SCHEMA must precede the first RECORD for either stream + last_schema_idx = max(_first_schema(s_users.name), _first_schema(s_orders.name)) + first_record_idx = min(_first_record(s_users.name), _first_record(s_orders.name)) + assert last_schema_idx < first_record_idx, f"SCHEMA-before-RECORD violated; events: {events}" + + +def test_write_schema_message_is_idempotent(): + """ + ``Stream.sync()`` and ``_sync_log_based_streams_shared`` both call this. + Without idempotency we'd emit duplicate SCHEMA messages for every LOG_BASED stream. 
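+    The guard is the ``_schema_message_written`` flag set in ``PostgresLogBasedStream.__init__``.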
+ """ + tap = TapPostgres(config=DUMMY_CONFIG, setup_mapper=False) + stream = _build_log_based_stream(tap, schema_name="public", table_name="users") + + # patch the SDK base-class method so we can count actual emissions + with patch("singer_sdk.sql.SQLStream._write_schema_message", autospec=True) as base_write: + stream._write_schema_message() + stream._write_schema_message() + stream._write_schema_message() + + assert base_write.call_count == 1 + + +def test_malformed_payload_increments_counter(): + """Payloads that fail JSON parsing (even after the enum-quote repair) are counted.""" + s = StubStream("public", "users") + reader = _build_reader([s]) + + msgs = [ + # garbage JSON, beyond the enum-quote bug repair + FakeReplicationMessage("{not json{", data_start=10), + FakeReplicationMessage(_wal_payload("public", "users", id=1), data_start=11), + ] + cursor = FakeReplicationCursor(messages=msgs, wal_end=11) + with patch_replication(cursor): + reader.run() + + assert reader.records_malformed == 1 + assert reader.records_emitted == 1 + + +def test_per_stream_emit_counter_tracks_routing(): + """``records_emitted_by_fqn`` is keyed by FQN and reflects routing.""" + s_users = StubStream("public", "users") + s_orders = StubStream("public", "orders") + s_quiet = StubStream("public", "quiet") # registered but receives nothing + reader = _build_reader([s_users, s_orders, s_quiet]) + + msgs = [ + FakeReplicationMessage(_wal_payload("public", "users", id=1), data_start=10), + FakeReplicationMessage(_wal_payload("public", "orders", id=2), data_start=11), + FakeReplicationMessage(_wal_payload("public", "users", id=3), data_start=12), + ] + cursor = FakeReplicationCursor(messages=msgs, wal_end=12) + with patch_replication(cursor): + reader.run() + + assert reader.records_emitted_by_fqn == { + "public.users": 2, + "public.orders": 1, + "public.quiet": 0, + } + + +def test_idle_exit_seconds_zero_exits_immediately_when_no_messages(): + """``idle_exit_seconds=0`` is the explicit "exit as soon as the queue drains" knob.""" + s = StubStream("public", "users", start_lsn=10) + reader = _build_reader([s], idle_exit=0) + + cursor = FakeReplicationCursor(messages=[], wal_end=20) + with patch_replication(cursor): + reader.run() + + # no records, no crash: advancement path still ran since wal_end > start_lsn + assert reader.records_emitted == 0 + assert s.get_context_state(None)["replication_key_value"] == 20 + + +def test_config_flag_off_uses_legacy_per_stream_path(): + """``log_based_single_connection=false`` falls through to the legacy generator. + + With the flag off, ``get_records()`` must yield from ``_get_records_per_stream`` + and never trigger ``_sync_log_based_streams_shared``, which would build the reader. + """ + config = {**DUMMY_CONFIG, "log_based_single_connection": False} + tap = TapPostgres(config=config, setup_mapper=False) + stream = _build_log_based_stream(tap, schema_name="public", table_name="users") + + sentinel = [{"id": 1, "_sdc_lsn": 7, "_sdc_deleted_at": None}] + + with ( + patch.object(stream, "_get_records_per_stream", return_value=iter(sentinel)) as p_legacy, + patch.object(tap, "_sync_log_based_streams_shared") as p_shared, + patch("tap_postgres.wal_reader.SingleConnectionWALReader") as p_reader_cls, + ): + out = list(stream.get_records(context=None)) + + assert out == sentinel + p_legacy.assert_called_once() + p_shared.assert_not_called() + p_reader_cls.assert_not_called()