Skip to content

Commit e3fba86

Browse files
committed
fixes
1 parent 275b2d4 commit e3fba86

1 file changed

Lines changed: 68 additions & 38 deletions

File tree

src/query_farm_server_base/duckdb_serialized_expression.py

Lines changed: 68 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -2,32 +2,41 @@
22
import codecs
33
import math
44
import uuid
5-
from datetime import date, datetime, time, timedelta, timezone
5+
from datetime import UTC, date, datetime, time, timedelta
66
from decimal import Decimal
77
from typing import Any
88

99

10-
def interpret_timestamp_with_time_zone(value: str) -> str:
10+
def decode_timestamptz_value(value: str) -> str:
11+
"""
12+
Convert a DuckDB serialized timestamp with time zone into a SQL representation.
13+
"""
1114
return (
1215
"TIMESTAMPTZ '"
13-
+ datetime.fromtimestamp(int(value) / 1_000_000, tz=timezone.utc).strftime(
14-
"%Y-%m-%d %H:%M:%S.%f"
15-
)
16+
+ datetime.fromtimestamp(int(value) / 1_000_000, tz=UTC).strftime("%Y-%m-%d %H:%M:%S.%f")
1617
+ "'"
1718
)
1819

1920

2021
def _quote_string(value: str) -> str:
2122
assert isinstance(value, str)
22-
return f"'{value}'"
23+
return f"'{codecs.encode(value, 'unicode_escape')}'"
2324

2425

25-
def decode_base64_value(value: Any) -> bytes:
26+
def decode_base64_value(value: dict[str, str]) -> bytes:
27+
"""
28+
Decode a base64 encoded value into a series of bytes. The
29+
Base64 value is produced by the Airport extension's JSON serialization.
30+
"""
2631
assert "base64" in value
2732
return base64.b64decode(value["base64"])
2833

2934

30-
def decode_bitstring(data: bytes) -> str:
35+
def decode_bitstring_value(data: bytes) -> str:
36+
"""
37+
Decode a DuckDB serialized bitstring into a string of bits.
38+
The first byte indicates the number of padding bits at the end of the bitstring.
39+
"""
3140
if not data or len(data) < 2:
3241
return ""
3342

@@ -44,7 +53,10 @@ def decode_bitstring(data: bytes) -> str:
4453
return bits
4554

4655

47-
def interpret_time(value: int) -> str:
56+
def decode_time_value(value: int) -> str:
57+
"""
58+
Convert a DuckDB serialized time value (microseconds since midnight) into a SQL time string.
59+
"""
4860
t = timedelta(microseconds=value)
4961
hours, remainder = divmod(t.seconds, 3600)
5062
minutes, seconds = divmod(remainder, 60)
@@ -53,7 +65,10 @@ def interpret_time(value: int) -> str:
5365
return result.strftime("%H:%M:%S.%f")
5466

5567

56-
def interpret_real(value: Any) -> str:
68+
def decode_real_value(value: Any) -> str:
69+
"""
70+
Convert a DuckDB serialized real value (float or double) into a SQL string.
71+
"""
5772
if math.isinf(value):
5873
if value > 0:
5974
return "'infinity'"
@@ -63,18 +78,27 @@ def interpret_real(value: Any) -> str:
6378
return value
6479

6580

66-
def interpret_timestamp_ms(value: int) -> str:
67-
dt = datetime.fromtimestamp(value / 1000, tz=timezone.utc)
81+
def decode_timestamp_ms_value(value: int) -> str:
82+
"""
83+
Convert a DuckDB serialized timestamp in milliseconds since epoch into a SQL timestamp string.
84+
"""
85+
dt = datetime.fromtimestamp(value / 1000, tz=UTC)
6886
return dt.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3] # Trim to milliseconds
6987

7088

71-
def interpret_uhugeint(value: dict[str, Any]) -> str:
89+
def decode_uhugeint_value(value: dict[str, Any]) -> str:
90+
"""
91+
Decode a DuckDB serialized UHUGEINT value into a string representation.
92+
"""
7293
upper = value["upper"]
7394
lower = value["lower"]
7495
return str((upper << 64) | lower)
7596

7697

77-
def interpret_hugeint(value: dict) -> str:
98+
def decode_hugeint_value(value: dict[str, Any]) -> str:
99+
"""
100+
Decode a DuckDB serialized HUGEINT value into a string representation.
101+
"""
78102
upper = value["upper"]
79103
lower = value["lower"]
80104
result = (upper << 64) | lower
@@ -86,7 +110,10 @@ def interpret_hugeint(value: dict) -> str:
86110
return str(result)
87111

88112

89-
def decode_uuid(value: dict[str, int]) -> str:
113+
def decode_uuid_value(value: dict[str, int]) -> str:
114+
"""
115+
Decode a DuckDB serialized UUID value into a string representation.
116+
"""
90117
assert "upper" in value and "lower" in value, "Invalid GUID format"
91118

92119
# Handle the two's complement for the signed upper 64 bits
@@ -105,7 +132,10 @@ def decode_uuid(value: dict[str, int]) -> str:
105132
return str(u)
106133

107134

108-
def decode_date(days: int) -> str:
135+
def decode_date_value(days: int) -> str:
136+
"""
137+
Convert a DuckDB serialized date value (days since epoch) into a SQL date string.
138+
"""
109139
if days == -2147483647:
110140
return "'-infinity'"
111141
elif days == 2147483647:
@@ -114,7 +144,10 @@ def decode_date(days: int) -> str:
114144
return f"'{formatted_date}'"
115145

116146

117-
def interpret_decimal(value: dict[str, Any]) -> Decimal:
147+
def decode_decimal_value(value: dict[str, Any]) -> Decimal:
148+
"""
149+
Decode a DuckDB serialized decimal value into a Decimal object.
150+
"""
118151
type_info = value["type"]["type_info"]
119152
scale = type_info["scale"]
120153
v = value["value"]
@@ -141,25 +174,24 @@ def interpret_decimal(value: dict[str, Any]) -> Decimal:
141174
return decimal_value / Decimal(10) ** scale
142175

143176

144-
def varint_get_byte_array(blob: bytes) -> tuple[list[int], bool]:
177+
def _varint_get_byte_array(blob: bytes) -> tuple[list[int], bool]:
145178
if len(blob) < 4:
146179
raise ValueError("Invalid blob size.")
147180

148181
# Determine if the number is negative
149182
is_negative = (blob[0] & 0x80) == 0
150183

151184
# Extract byte array starting from the 4th byte
152-
if is_negative:
153-
byte_array = [~b & 0xFF for b in blob[3:]] # Apply bitwise NOT and mask to 8 bits
154-
else:
155-
byte_array = list(blob[3:])
156-
185+
byte_array = [~b & 255 for b in blob[3:]] if is_negative else list(blob[3:])
157186
return byte_array, is_negative
158187

159188

160-
def varint_to_varchar(blob: bytes) -> str:
189+
def decode_varint_value(blob: bytes) -> str:
190+
"""
191+
Decode a DuckDB serialized VARINT value into a decimal string.
192+
"""
161193
decimal_string = ""
162-
byte_array, is_negative = varint_get_byte_array(blob)
194+
byte_array, is_negative = _varint_get_byte_array(blob)
163195
digits: list[int] = []
164196

165197
# Constants matching the corresponding C++ implementation (update if needed)
@@ -320,9 +352,9 @@ def e_to_s(expr: dict[str, Any]) -> str:
320352
"Varint value must be a base64 encoded string or a string with unicode escape sequences"
321353
)
322354

323-
return varint_to_varchar(varint_bytes)
355+
return decode_varint_value(varint_bytes)
324356
elif expression["value"]["type"]["id"] == "UUID":
325-
return decode_uuid(expression["value"]["value"])
357+
return decode_uuid_value(expression["value"]["value"])
326358
elif expression["value"]["type"]["id"] in (
327359
"VARCHAR",
328360
"BLOB",
@@ -339,22 +371,22 @@ def e_to_s(expr: dict[str, Any]) -> str:
339371
raise Exception(
340372
"Bit string value must be a base64 encoded string or a string with unicode escape sequences"
341373
)
342-
return decode_bitstring(bitstring_bytes)
374+
return decode_bitstring_value(bitstring_bytes)
343375
elif expression["value"]["type"]["id"] == "BOOLEAN":
344376
return "True" if expression["value"]["value"] else "False"
345377
elif expression["value"]["type"]["id"] == "NULL":
346378
return "null"
347379
elif expression["value"]["type"]["id"] == "DATE":
348-
return decode_date(expression["value"]["value"])
380+
return decode_date_value(expression["value"]["value"])
349381
elif expression["value"]["type"]["id"] == "DECIMAL":
350-
decimal_value = interpret_decimal(expression["value"])
382+
decimal_value = decode_decimal_value(expression["value"])
351383
return str(decimal_value)
352384
elif expression["value"]["type"]["id"] in ("FLOAT", "DOUBLE"):
353-
return interpret_real(expression["value"]["value"])
385+
return decode_real_value(expression["value"]["value"])
354386
elif expression["value"]["type"]["id"] == "UHUGEINT":
355-
return interpret_uhugeint(expression["value"]["value"])
387+
return decode_uhugeint_value(expression["value"]["value"])
356388
elif expression["value"]["type"]["id"] == "HUGEINT":
357-
return interpret_hugeint(expression["value"]["value"])
389+
return decode_hugeint_value(expression["value"]["value"])
358390
elif expression["value"]["type"]["id"] in (
359391
"BIGINT",
360392
"INTEGER",
@@ -372,17 +404,15 @@ def e_to_s(expr: dict[str, Any]) -> str:
372404
elif expression["value"]["type"]["id"] == "TIMESTAMP":
373405
return f"make_timestamp({expression['value']['value']}::bigint)"
374406
elif expression["value"]["type"]["id"] == "TIMESTAMP WITH TIME ZONE":
375-
return interpret_timestamp_with_time_zone(expression["value"]["value"])
407+
return decode_timestamptz_value(expression["value"]["value"])
376408
elif expression["value"]["type"]["id"] == "TIME":
377-
return f"TIME '{interpret_time(expression['value']['value'])}'"
409+
return f"TIME '{decode_time_value(expression['value']['value'])}'"
378410
elif expression["value"]["type"]["id"] == "TIMESTAMP_S":
379411
return f"make_timestamp({expression['value']['value']}::bigint*1000000)"
380412
elif expression["value"]["type"]["id"] == "TIMESTAMP_MS":
381-
return f"'{interpret_timestamp_ms(expression['value']['value'])}'"
413+
return f"'{decode_timestamp_ms_value(expression['value']['value'])}'"
382414
elif expression["value"]["type"]["id"] == "TIMESTAMP_NS":
383415
return f"make_timestamp_ns({expression['value']['value']}::bigint)"
384-
# elif expression["value"]["type"]["id"] == "TIMESTAMP WITH TIME ZONE":
385-
# return f"make_timestamp({expression['value']['value']}::bigint)"
386416
elif expression["value"]["type"]["id"] == "LIST":
387417
if expression["type"] == "VALUE_CONSTANT":
388418
# So the children in this case aren't expressions, they are constants.

0 commit comments

Comments
 (0)