diff --git a/setup.cfg b/setup.cfg index 136b00d645b..be198a8a537 100644 --- a/setup.cfg +++ b/setup.cfg @@ -36,7 +36,7 @@ install_requires = pydantic[dotenv]>=1.8.2,<3.0.0 python-dateutil>=2.8.2 readerwriterlock>=1.0.9 - sqlparse@git+https://github.com/lorenzhs/sqlparse.git@8d379386c1c3e103ee67ef6582ea1b7c2296aa5b + sqlparse==0.5.5 trio>=0.22.0 truststore>=0.10;python_version>="3.10" python_requires = >=3.9 diff --git a/src/firebolt/common/statement_formatter.py b/src/firebolt/common/statement_formatter.py index 27a59aea71f..636b92a21e5 100644 --- a/src/firebolt/common/statement_formatter.py +++ b/src/firebolt/common/statement_formatter.py @@ -3,6 +3,10 @@ from typing import Dict, List, Optional, Sequence, Union from sqlparse import parse as parse_sql # type: ignore +from sqlparse import tokens as _T +from sqlparse.engine.statement_splitter import ( + StatementSplitter as _StatementSplitter, +) from sqlparse.sql import ( # type: ignore Comment, Comparison, @@ -22,6 +26,78 @@ NotSupportedError, ) + +def _patched_change_splitlevel(self, ttype, value): # type: ignore[no-untyped-def] + """Patched version of StatementSplitter._change_splitlevel. + + Fixes CASE...END level tracking outside of CREATE blocks. + See: https://github.com/andialbrecht/sqlparse/pull/839 + """ + if ttype is _T.Punctuation and value == "(": + return 1 + elif ttype is _T.Punctuation and value == ")": + return -1 + elif ttype not in _T.Keyword: + return 0 + + unified = value.upper() + + if ttype is _T.Keyword.DDL and unified.startswith("CREATE"): + self._is_create = True + return 0 + + if unified == "DECLARE" and self._is_create and self._begin_depth == 0: + self._in_declare = True + return 1 + + if unified == "BEGIN": + self._begin_depth += 1 + self._seen_begin = True + if self._is_create: + return 1 + return 0 + + if ( + self._seen_begin + and (ttype is _T.Keyword or ttype is _T.Name) + and unified + in ( + "TRANSACTION", + "WORK", + "TRAN", + "DISTRIBUTED", + "DEFERRED", + "IMMEDIATE", + "EXCLUSIVE", + ) + ): + self._begin_depth = max(0, self._begin_depth - 1) + self._seen_begin = False + return 0 + + if unified == "END": + if not self._in_case: + self._begin_depth = max(0, self._begin_depth - 1) + else: + self._in_case = False + return -1 + + if unified == "CASE": + self._in_case = True + return 1 + + if unified in ("IF", "FOR", "WHILE") and self._is_create and self._begin_depth > 0: + return 1 + + if unified in ("END IF", "END FOR", "END WHILE"): + return -1 + + return 0 + + +setattr(_StatementSplitter, "_change_splitlevel", _patched_change_splitlevel) + + escape_chars_v2 = { "\0": "\\0", "'": "''",