Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions src/firebolt/async_db/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
fix_url_schema,
parse_url_and_params,
validate_engine_name_and_url_v1,
validate_firebolt_parameters_v1,
validate_firebolt_parameters_v2,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -292,6 +294,7 @@ async def connect(
url: Optional[str] = None,
autocommit: bool = True,
additional_parameters: Dict[str, Any] = {},
client_side_lb: Optional[bool] = None,
) -> Connection:
# auth parameter is optional in function signature
# but is required to connect.
Expand All @@ -313,14 +316,22 @@ async def connect(
if auth_version == FireboltAuthVersion.CORE:
# Verify that Core-incompatible parameters are not provided
validate_firebolt_core_parameters(account_name, engine_name, engine_url)
if client_side_lb == None:
# When using Core, client_side_lb is True by default
client_side_lb = True

return connect_core(
auth=auth,
user_agent_header=user_agent_header,
database=database,
connection_url=url,
autocommit=autocommit,
client_side_lb=client_side_lb,
)
elif auth_version == FireboltAuthVersion.V2:
# Verify that v2-incompatible parameters are not provided
validate_firebolt_parameters_v2(client_side_lb)

assert account_name is not None
return await connect_v2(
auth=auth,
Expand All @@ -334,6 +345,9 @@ async def connect(
autocommit=autocommit,
)
elif auth_version == FireboltAuthVersion.V1:
# Verify that v1-incompatible parameters are not provided
validate_firebolt_parameters_v1(client_side_lb)

return await connect_v1(
auth=auth,
user_agent_header=user_agent_header,
Expand Down Expand Up @@ -490,6 +504,7 @@ def connect_core(
database: Optional[str] = None,
connection_url: Optional[str] = None,
autocommit: bool = True,
client_side_lb: bool = False,
) -> Connection:
"""Connect to Firebolt Core.

Expand Down Expand Up @@ -519,6 +534,7 @@ def connect_core(
timeout=Timeout(DEFAULT_TIMEOUT_SECONDS, read=None),
headers={"User-Agent": user_agent_header},
verify=ctx,
client_side_lb=client_side_lb,
)

return Connection(
Expand Down
19 changes: 17 additions & 2 deletions src/firebolt/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,16 @@ def clone(self) -> "Client":

class Client(FireboltClientMixin, HttpxClient, metaclass=ABCMeta):
def __init__(self, *args: Any, **kwargs: Any):
# We pop it from kwargs because it's unknown to HttpxClient which won't accept it
client_side_lb = kwargs.pop("client_side_lb", False)

super().__init__(
*args,
**kwargs,
transport=KeepaliveTransport(verify=kwargs.get("verify", True)),
transport=KeepaliveTransport(
verify=kwargs.get("verify", True),
client_side_lb=client_side_lb,
),
)

@property
Expand Down Expand Up @@ -139,13 +145,15 @@ def __init__(
auth: Auth,
account_name: str,
api_endpoint: str = DEFAULT_API_URL,
client_side_lb: bool = False,
**kwargs: Any,
):
super().__init__(
*args,
auth=auth,
account_name=account_name,
api_endpoint=api_endpoint,
client_side_lb=client_side_lb,
**kwargs,
)

Expand Down Expand Up @@ -273,10 +281,15 @@ def _resolve_engine_url(self, engine_name: str) -> str:

class AsyncClient(FireboltClientMixin, HttpxAsyncClient, metaclass=ABCMeta):
def __init__(self, *args: Any, **kwargs: Any):
# We pop it from kwargs because it's unknown to HttpxClient which won't accept it
client_side_lb = kwargs.pop("client_side_lb", False)
super().__init__(
*args,
**kwargs,
transport=AsyncKeepaliveTransport(verify=kwargs.get("verify", True)),
transport=AsyncKeepaliveTransport(
verify=kwargs.get("verify", True),
client_side_lb=client_side_lb,
),
)

@property
Expand Down Expand Up @@ -306,13 +319,15 @@ def __init__(
auth: Auth,
account_name: str,
api_endpoint: str = DEFAULT_API_URL,
client_side_lb: bool = False,
**kwargs: Any,
):
super().__init__(
*args,
auth=auth,
account_name=account_name,
api_endpoint=api_endpoint,
client_side_lb=client_side_lb,
**kwargs,
)

Expand Down
169 changes: 159 additions & 10 deletions src/firebolt/client/http_backend.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import anyio
import socket
from typing import Any
import threading
import time
from typing import Any, Dict, List

try:
from httpcore.backends.auto import AutoBackend # type: ignore
Expand All @@ -8,7 +11,7 @@
from httpcore._backends.auto import AutoBackend # type: ignore
from httpcore._backends.sync import SyncBackend # type: ignore

from httpx import AsyncHTTPTransport, HTTPTransport
from httpx import AsyncHTTPTransport, HTTPTransport, Request, Response

from firebolt.common.constants import KEEPALIVE_FLAG, KEEPIDLE_RATE

Expand All @@ -29,6 +32,45 @@ def override_stream(stream): # type: ignore [no-untyped-def]
return stream


class DNSCache:
def __init__(self, ttl: float = 30.0):
self.ttl = ttl
self.cache: Dict[str, List[str]] = {}
self.expiry: Dict[str, float] = {}
self.indices: Dict[str, int] = {}
self._lock = threading.Lock()

def get_ip_round_robin(self, hostname: str) -> str:
now = time.monotonic()

with self._lock:
cached_ips = self.cache.get(hostname)
expires_at = self.expiry.get(hostname, 0)

if not cached_ips or now >= expires_at:
try:
_, _, new_ips = socket.gethostbyname_ex(hostname)
if new_ips:
self.cache[hostname] = sorted(new_ips)
self.expiry[hostname] = now + self.ttl
cached_ips = self.cache[hostname]
except Exception:
if not cached_ips:
raise

# explicit check as hint for type checkers
if not cached_ips:
raise RuntimeError(f"Could not resolve or find cached IPs for {hostname}")

# calculate round robin index
current_index = self.indices.get(hostname, 0)
target_ip = cached_ips[current_index % len(cached_ips)]

self.indices[hostname] = (current_index + 1) % len(cached_ips)

return target_ip


class AsyncOverriddenHttpBackend(AutoBackend):
"""
`OverriddenHttpBackend` is a short-term solution for the TCP
Expand Down Expand Up @@ -68,18 +110,125 @@ def open_tcp_stream(self, *args, **kwargs): # type: ignore


class AsyncKeepaliveTransport(AsyncHTTPTransport):
_dns_cache = DNSCache(ttl=30.0)

def __init__(self, *args: Any, **kwargs: Any) -> None:
self._client_side_lb = kwargs.pop("client_side_lb", False)
super().__init__(*args, **kwargs)
if hasattr(self._pool, "_network_backend"):
self._pool._network_backend = AsyncOverriddenHttpBackend() # type: ignore
if hasattr(self._pool, "_backend"):
self._pool._backend = AsyncOverriddenHttpBackend() # type: ignore
self._apply_custom_backend(self)
self._transport_kwargs = kwargs
self._ip_transports: Dict[str, AsyncHTTPTransport] = {}
self._lock = anyio.Lock()

def _apply_custom_backend(self, transport: AsyncHTTPTransport) -> None:
pool = getattr(transport, "_pool", None)
if pool:
for attr in ["_network_backend", "_backend"]:
if hasattr(pool, attr):
setattr(pool, attr, AsyncOverriddenHttpBackend())

async def handle_async_request(self, request: Request) -> Response:
if not self._client_side_lb:
return await super().handle_async_request(request)

hostname = request.url.host

try:
target_ip = self._dns_cache.get_ip_round_robin(hostname)
except Exception:
return await super().handle_async_request(request)

# Lazy-load the lock to ensure it's bound to the correct event loop
if self._lock is None:
self._lock = anyio.Lock()

async with self._lock:
if target_ip not in self._ip_transports:
new_transport = AsyncHTTPTransport(**self._transport_kwargs)
self._apply_custom_backend(new_transport)
self._ip_transports[target_ip] = new_transport
sub_transport = self._ip_transports[target_ip]

original_url = request.url
request.url = request.url.copy_with(host=target_ip)
try:
return await sub_transport.handle_async_request(request)
finally:
request.url = original_url

async def aclose(self) -> None:
"""
Close the primary transport and all sub-transports created for load balancing.
"""
# Close the base transport first
await super().aclose()

# Close all child transports created for specific IPs
if self._ip_transports:
async with anyio.create_task_group() as tg:
# Gather all transports in task group and close them
for transport in self._ip_transports.values():
tg.start_soon(transport.aclose)

self._ip_transports.clear()


class KeepaliveTransport(HTTPTransport):
_dns_cache = DNSCache(ttl=30.0)

def __init__(self, *args: Any, **kwargs: Any) -> None:
self._client_side_lb = kwargs.pop("client_side_lb", False)
super().__init__(*args, **kwargs)
if hasattr(self._pool, "_network_backend"):
self._pool._network_backend = OverriddenHttpBackend() # type: ignore
if hasattr(self._pool, "_backend"):
self._pool._backend = OverriddenHttpBackend() # type: ignore
self._apply_custom_backend(self)
self._transport_kwargs = kwargs
self._ip_transports: Dict[str, HTTPTransport] = {}
self._lock = threading.Lock()

def _apply_custom_backend(self, transport: HTTPTransport) -> None:
pool = getattr(transport, "_pool", None)
if pool:
for attr in ["_network_backend", "_backend"]:
if hasattr(pool, attr):
setattr(pool, attr, OverriddenHttpBackend())

def handle_request(self, request: Request) -> Response:
if not self._client_side_lb:
return super().handle_request(request)

hostname = request.url.host

try:
target_ip = self._dns_cache.get_ip_round_robin(hostname)
except Exception:
return super().handle_request(request)

with self._lock:
if target_ip not in self._ip_transports:
new_transport = HTTPTransport(**self._transport_kwargs)
self._apply_custom_backend(new_transport)
self._ip_transports[target_ip] = new_transport
sub_transport = self._ip_transports[target_ip]

original_url = request.url
request.url = request.url.copy_with(host=target_ip)
try:
return sub_transport.handle_request(request)
finally:
request.url = original_url

def close(self) -> None:
"""
Close the primary transport and all sub-transports.
"""
# Close the base transport first
super().close()

# Close all child transports created for specific IPs
with self._lock:
for transport in self._ip_transports.values():
try:
transport.close()
except Exception:
# Best effort to close others if one fails
pass
self._ip_transports.clear()
15 changes: 15 additions & 0 deletions src/firebolt/db/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
fix_url_schema,
parse_url_and_params,
validate_engine_name_and_url_v1,
validate_firebolt_parameters_v1,
validate_firebolt_parameters_v2,
)

logger = logging.getLogger(__name__)
Expand All @@ -61,6 +63,7 @@ def connect(
url: Optional[str] = None,
autocommit: bool = True,
additional_parameters: Dict[str, Any] = {},
client_side_lb: Optional[bool] = None,
) -> Connection:
# auth parameter is optional in function signature
# but is required to connect.
Expand All @@ -82,15 +85,22 @@ def connect(
if auth_version == FireboltAuthVersion.CORE:
# Verify that Core-incompatible parameters are not provided
validate_firebolt_core_parameters(account_name, engine_name, engine_url)
if client_side_lb == None:
# When using Core, client_side_lb is True by default
client_side_lb = True

return connect_core(
auth=auth,
user_agent_header=user_agent_header,
database=database,
connection_url=url,
autocommit=autocommit,
client_side_lb=client_side_lb,
)
elif auth_version == FireboltAuthVersion.V2:
# Verify that v2-incompatible parameters are not provided
validate_firebolt_parameters_v2(client_side_lb)

assert account_name is not None
return connect_v2(
auth=auth,
Expand All @@ -104,6 +114,9 @@ def connect(
autocommit=autocommit,
)
elif auth_version == FireboltAuthVersion.V1:
# Verify that v1-incompatible parameters are not provided
validate_firebolt_parameters_v1(client_side_lb)

return connect_v1(
auth=auth,
user_agent_header=user_agent_header,
Expand Down Expand Up @@ -490,6 +503,7 @@ def connect_core(
database: Optional[str] = None,
connection_url: Optional[str] = None,
autocommit: bool = True,
client_side_lb: bool = True,
) -> Connection:
"""Connect to Firebolt Core.

Expand Down Expand Up @@ -520,6 +534,7 @@ def connect_core(
timeout=Timeout(DEFAULT_TIMEOUT_SECONDS, read=None),
headers={"User-Agent": user_agent_header},
verify=ctx,
client_side_lb=client_side_lb,
)

return Connection(
Expand Down
Loading
Loading