Skip to content

Commit 60c2ffd

Browse files
committed
work in progress
1 parent e626bd5 commit 60c2ffd

2 files changed

Lines changed: 59 additions & 27 deletions

File tree

src/query_farm_server_base/flight_inventory.py

Lines changed: 37 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -6,9 +6,9 @@
66
import pyarrow as pa
77
import pyarrow.flight as flight
88
import structlog
9-
import zstandard as zstd
9+
from pydantic import BaseModel
1010

11-
from . import schema_uploader
11+
from . import schema_uploader, server
1212

1313
log = structlog.get_logger()
1414

@@ -35,6 +35,28 @@ class SchemaInfo:
3535
tags: dict[str, Any]
3636

3737

38+
class AirportSerializedContentsWithSHA256Hash(BaseModel):
39+
# This is the sha256 hash of the serialized data
40+
sha256: str
41+
# This is the url to the serialized data
42+
url: str | None
43+
# This is the serialized data, if we are doing inline serialization
44+
serialized: str | None
45+
46+
47+
class AirportSerializedSchema(BaseModel):
48+
name: str
49+
description: str
50+
tags: dict[str, str]
51+
contents: AirportSerializedContentsWithSHA256Hash
52+
53+
54+
class AirportSerializedCatalogRoot(BaseModel):
55+
contents: AirportSerializedContentsWithSHA256Hash
56+
schemas: list[AirportSerializedSchema]
57+
version_info: server.GetCatalogVersionResult
58+
59+
3860
class FlightSchemaMetadata:
3961
def __init__(
4062
self,
@@ -83,7 +105,7 @@ def upload_and_generate_schema_list(
83105
catalog_version_fixed: bool,
84106
enable_sha256_caching: bool = True,
85107
serialize_inline: bool = False,
86-
) -> bytes:
108+
) -> AirportSerializedCatalogRoot:
87109
serialized_schema_data: list[dict[str, Any]] = []
88110
s3_client = boto3.client("s3")
89111
all_schema_flights_serialized: list[Any] = []
@@ -128,7 +150,7 @@ def upload_and_generate_schema_list(
128150

129151
serialized_schema_data.append(
130152
{
131-
"schema": schema_name,
153+
"name": schema_name,
132154
"description": schema_details[schema_name].description
133155
if schema_name in schema_details
134156
else "",
@@ -153,18 +175,14 @@ def upload_and_generate_schema_list(
153175
)
154176
all_schema_path = f"{SCHEMA_BASE_URL}/{all_schema_contents_upload.s3_path}"
155177

156-
schemas_list_data = {
157-
"schemas": serialized_schema_data,
158-
# This encodes the contents of all schemas in one file.
159-
"contents": {
160-
"sha256": all_schema_contents_upload.sha256_hash,
161-
"url": all_schema_path if not serialize_inline else None,
162-
"serialized": all_schema_contents_upload.compressed_data if serialize_inline else None,
163-
},
164-
"version_info": [catalog_version, catalog_version_fixed],
165-
}
166-
167-
packed_data = msgpack.packb(schemas_list_data)
168-
compressor = zstd.ZstdCompressor(level=SCHEMA_TOP_LEVEL_COMPRESSION_LEVEL)
169-
compressed_data = compressor.compress(packed_data)
170-
return msgpack.packb([len(packed_data), compressed_data])
178+
return AirportSerializedCatalogRoot(
179+
schemas=serialized_schema_data,
180+
contents=AirportSerializedContentsWithSHA256Hash(
181+
sha256=all_schema_contents_upload.sha256_hash,
182+
url=all_schema_path if not serialize_inline else None,
183+
serialized=all_schema_contents_upload.compressed_data if serialize_inline else None,
184+
),
185+
version_info=server.GetCatalogVersionResult(
186+
catalog_version=catalog_version, is_fixed=catalog_version_fixed
187+
),
188+
)

src/query_farm_server_base/server.py

Lines changed: 22 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -7,10 +7,17 @@
77
import msgpack
88
import pyarrow.flight as flight
99
import structlog
10+
import zstandard as zstd
11+
from pydantic import BaseModel
1012

1113
from query_farm_server_base import action_decoders
1214

13-
from . import auth, middleware
15+
from . import auth, flight_inventory, middleware
16+
17+
# This is the level of ZStandard compression to use for the top-level schema
18+
# information (msgpack-encoded before compression).
19+
SCHEMA_TOP_LEVEL_COMPRESSION_LEVEL = 12
20+
1421

1522
P = ParamSpec("P")
1623
R = TypeVar("R")
@@ -29,6 +36,11 @@ class CallContext(Generic[AccountType, TokenType]):
2936
logger: structlog.BoundLogger
3037

3138

39+
class GetCatalogVersionResult(BaseModel):
40+
catalog_version: int
41+
is_fixed: bool
42+
43+
3244
# Setup a decorator to log the action and its parameters.
3345
def log_action() -> Callable[[Callable[P, R]], Callable[P, R]]:
3446
def decorator(func: Callable[P, R]) -> Callable[P, R]:
@@ -245,7 +257,7 @@ def action_list_schemas(
245257
*,
246258
context: CallContext[AccountType, TokenType],
247259
parameters: action_decoders.ListSchemasParameters,
248-
) -> list[bytes]:
260+
) -> flight_inventory.AirportSerializedCatalogRoot:
249261
self._unimplemented_action("list_schemas")
250262

251263
@log_action()
@@ -326,7 +338,7 @@ def action_catalog_version(
326338
*,
327339
context: CallContext[AccountType, TokenType],
328340
database_name: str,
329-
) -> tuple[int, bool]:
341+
) -> GetCatalogVersionResult:
330342
pass
331343

332344
@abstractmethod
@@ -475,12 +487,14 @@ def do_action(
475487
]
476488
)
477489
elif action.type == "list_schemas":
478-
return self.pack_result(
479-
self.action_list_schemas(
480-
context=call_context,
481-
parameters=action_decoders.list_schemas(action),
482-
)
490+
schemas_result = self.action_list_schemas(
491+
context=call_context,
492+
parameters=action_decoders.list_schemas(action),
483493
)
494+
packed_data = msgpack.packb(schemas_result.model_dump())
495+
compressor = zstd.ZstdCompressor(level=SCHEMA_TOP_LEVEL_COMPRESSION_LEVEL)
496+
compressed_data = compressor.compress(packed_data)
497+
return iter([msgpack.packb([len(packed_data), compressed_data])])
484498
elif action.type == "remove_column":
485499
self.action_remove_column(
486500
context=call_context,

0 commit comments

Comments
 (0)