Skip to content

Commit 287546c

Browse files
committed
fix: add support for varint and bitstring
1 parent 2740f73 commit 287546c

1 file changed

Lines changed: 119 additions & 6 deletions

File tree

src/query_farm_server_base/duckdb_serialized_expression.py

Lines changed: 119 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,102 @@
1+
import base64
2+
import codecs
13
from typing import Any
24

35

46
def _quote_string(value: str) -> str:
7+
assert isinstance(value, str)
58
return f"'{value}'"
69

710

11+
def decode_base64_value(value: Any) -> bytes:
    """Decode the base64 text stored under the ``"base64"`` key of *value*.

    Args:
        value: A container that must have a ``"base64"`` entry holding the
            encoded payload.

    Returns:
        The decoded raw bytes.
    """
    assert "base64" in value
    encoded_payload = value["base64"]
    return base64.b64decode(encoded_payload)
14+
15+
16+
def decode_bitstring(data: bytes) -> str:
    """Render a padded bit blob as a string of ``'0'``/``'1'`` characters.

    The first byte of *data* is read as a count of padding bits. Every
    following byte is expanded to eight characters, most significant bit
    first, and that many characters are then dropped from the *front* of
    the expanded string.

    NOTE(review): presumably this mirrors DuckDB's serialized BIT layout,
    where padding sits in the leading bits of the first data byte; confirm
    against the producer.

    Args:
        data: The count byte followed by at least one data byte.

    Returns:
        The bit characters with the padding removed, or ``""`` when *data*
        is empty or holds fewer than two bytes.
    """
    if data and len(data) >= 2:
        pad_count, *payload = data
        expanded = "".join(format(octet, "08b") for octet in payload)
        # Slicing from zero is a no-op, so no separate "no padding" branch.
        return expanded[pad_count:]
    return ""
31+
32+
33+
def varint_get_byte_array(blob: bytes) -> tuple[list[int], bool]:
    """Split a serialized VARINT blob into magnitude bytes and a sign flag.

    The first three bytes of *blob* are treated as a header and skipped.
    The high bit of the first header byte carries the sign: when it is
    clear the number is negative. For negative numbers every payload byte
    is bitwise-inverted (and masked to 8 bits) to recover the magnitude.

    Args:
        blob: The serialized value; must be at least four bytes long.

    Returns:
        ``(byte_array, is_negative)`` where ``byte_array`` holds the
        magnitude bytes in the order they appear in the blob.

    Raises:
        ValueError: If *blob* is shorter than four bytes.
    """
    if len(blob) < 4:
        raise ValueError("Invalid blob size.")

    # A cleared high bit in the first header byte marks a negative number.
    is_negative = (blob[0] & 0x80) == 0

    # The payload starts at the 4th byte, after the 3-byte header.
    payload = blob[3:]
    if is_negative:
        # Negative payloads are stored inverted; undo that, keeping 8 bits.
        byte_array = [~b & 0xFF for b in payload]
    else:
        byte_array = list(payload)

    return byte_array, is_negative


def varint_to_varchar(blob: bytes) -> str:
    """Return the base-10 text form of a serialized VARINT blob.

    The magnitude bytes from :func:`varint_get_byte_array` are read as one
    big-endian unsigned integer and rendered in decimal, prefixed with
    ``-`` when the blob is flagged negative. A zero magnitude always
    renders as ``"0"``, whatever the sign flag says.

    Args:
        blob: The serialized value; must be at least four bytes long.

    Returns:
        The decimal string, without leading zeros.

    Raises:
        ValueError: If *blob* is shorter than four bytes.
    """
    byte_array, is_negative = varint_get_byte_array(blob)
    magnitude = int.from_bytes(bytes(byte_array), "big")
    if magnitude == 0:
        # Also covers a negative-flagged zero, which must not become "-".
        return "0"

    # Peel off 9 decimal digits at a time instead of calling str() on the
    # whole integer: str() of a very large int can hit CPython's
    # integer-to-string digit limit, while every group here stays < 10**9.
    decimal_base = 10**9
    decimal_shift = 9
    groups: list[int] = []
    while magnitude:
        magnitude, group = divmod(magnitude, decimal_base)
        groups.append(group)

    # ``groups`` is least-significant first. The top group is printed
    # without padding, every lower group is zero-padded to full width.
    parts = [str(groups.pop())]
    parts.extend(f"{group:0{decimal_shift}d}" for group in reversed(groups))
    digits = "".join(parts)
    return f"-{digits}" if is_negative else digits
98+
99+
8100
comparison_type_to_operator: dict[str, str] = {
9101
"COMPARE_EQUAL": "=",
10102
"COMPARE_NOTEQUAL": "!=",
@@ -24,7 +116,7 @@ def comparison_type_to_string_(comparison_type: str) -> str:
24116
raise NotImplementedError(f"Comparison type {comparison_type} is not supported")
25117

26118

27-
simple_types = {
119+
non_parameterized_duckdb_types = {
28120
"BIGINT",
29121
"BIT",
30122
"BLOB",
@@ -70,7 +162,7 @@ def _type_to_sql_type(type: dict[str, Any]) -> str:
70162
)
71163
+ ")"
72164
)
73-
elif type["id"] in simple_types:
165+
elif type["id"] in non_parameterized_duckdb_types:
74166
return type["id"]
75167
elif type["id"] == "DECIMAL":
76168
return f"DECIMAL({type['type_info']['width']}, {type['type_info']['scale']})"
@@ -106,15 +198,36 @@ def e_to_s(expr: dict[str, Any]) -> str:
106198
elif expression["expression_class"] == "BOUND_CONSTANT":
107199
if expression["value"]["is_null"]:
108200
return "null"
109-
if expression["value"]["type"]["id"] in (
201+
elif expression["value"]["type"]["id"] == "VARINT":
202+
varint_value = expression["value"]["value"]
203+
if isinstance(varint_value, str):
204+
varint_bytes = codecs.decode(varint_value, "unicode_escape").encode("utf-8")
205+
elif "base64" in varint_value:
206+
varint_bytes = decode_base64_value(varint_value)
207+
else:
208+
raise Exception(
209+
"Varint value must be a base64 encoded string or a string with unicode escape sequences"
210+
)
211+
212+
return varint_to_varchar(varint_bytes)
213+
elif expression["value"]["type"]["id"] in (
110214
"VARCHAR",
111215
"BLOB",
112-
"BITSTRING",
113-
"BIT",
114-
"VARINT",
115216
"UUID",
116217
):
117218
return _quote_string(expression["value"]["value"])
219+
elif expression["value"]["type"]["id"] == "BIT":
220+
bit_value = expression["value"]["value"]
221+
222+
if isinstance(bit_value, str):
223+
bitstring_bytes = codecs.decode(bit_value, "unicode_escape").encode("utf-8")
224+
elif "base64" in bit_value:
225+
bitstring_bytes = decode_base64_value(bit_value)
226+
else:
227+
raise Exception(
228+
"Bit string value must be a base64 encoded string or a string with unicode escape sequences"
229+
)
230+
return decode_bitstring(bitstring_bytes)
118231
elif expression["value"]["type"]["id"] == "BOOLEAN":
119232
return "True" if expression["value"]["value"] else "False"
120233
elif expression["value"]["type"]["id"] == "NULL":

0 commit comments

Comments
 (0)