25 changes: 21 additions & 4 deletions datajunction-server/.env
@@ -2,10 +2,27 @@ QUERY_SERVICE=http://djqs:8001
 SECRET=a-fake-secretkey
 NODE_LIST_MAX=10000
 
+# Rate Limiting Configuration
+# Enable/disable rate limiting (default: true)
+RATE_LIMITING_ENABLED=false
+
+# Cache backend (OPTIONAL - uses in-memory cache if not configured)
+# REDIS_CACHE=redis://localhost:6379/0
+
+# Time window for rate limiting in seconds (default: 1)
+RATE_LIMIT_WINDOW_SECONDS=1
+
+# Rate limits (requests per second per user)
+# Adjusted for actual traffic: 0.5-5 RPS incoming
+RATE_LIMIT_DEFAULT_RPS=3      # Default for most endpoints
+RATE_LIMIT_STANDARD_RPS=5     # For standard GET operations
+RATE_LIMIT_EXPENSIVE_RPS=2    # For expensive operations (/sql/*, /data/*, /graphql)
+
 # Writer DB (required)
 WRITER_DB__URI=postgresql+psycopg://dj:dj@postgres_metadata:5432/dj
-WRITER_DB__POOL_SIZE=20
-WRITER_DB__MAX_OVERFLOW=20
+WRITER_DB__POOL_SIZE=30
+WRITER_DB__MAX_OVERFLOW=30
 WRITER_DB__POOL_TIMEOUT=10
 WRITER_DB__CONNECT_TIMEOUT=5
 WRITER_DB__POOL_PRE_PING=true
@@ -17,8 +34,8 @@ WRITER_DB__KEEPALIVES_COUNT=5
 
 # Reader DB (optional)
 READER_DB__URI=postgresql+psycopg://readonly_user:readonly_pass@postgres_metadata:5432/dj
-READER_DB__POOL_SIZE=10
-READER_DB__MAX_OVERFLOW=10
+READER_DB__POOL_SIZE=40
+READER_DB__MAX_OVERFLOW=40
 READER_DB__POOL_TIMEOUT=5
 READER_DB__CONNECT_TIMEOUT=5
 READER_DB__POOL_PRE_PING=true
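Worth noting when sizing these pools: with SQLAlchemy's default `QueuePool`, each worker process can hold up to `POOL_SIZE + MAX_OVERFLOW` connections per engine, so this change raises the ceiling to 60 writer and 80 reader connections per process. A quick back-of-the-envelope check (the worker count is an assumed deployment parameter, not part of this PR):

```python
# Rough capacity check for the new pool settings; WORKERS is an
# assumption about the deployment, not a value from this PR.
WORKERS = 4

writer = (30 + 30) * WORKERS  # pool_size + max_overflow per process
reader = (40 + 40) * WORKERS

print(f"writer connections: up to {writer}")  # 240
print(f"reader connections: up to {reader}")  # 320
# Postgres max_connections must comfortably exceed writer + reader (560 here).
```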
13 changes: 8 additions & 5 deletions datajunction-server/datajunction_server/api/helpers.py
@@ -527,11 +527,14 @@ async def validate_cube(
             message=("Metrics and dimensions must be part of a common catalog"),
         )
 
-    await validate_shared_dimensions(
-        session,
-        metric_nodes,
-        dimension_names,
-    )
+    # Only validate shared dimensions if dimensions were actually requested
+    # This avoids expensive dimension graph loading when dimensions=[]
+    if dimension_names:
+        await validate_shared_dimensions(
+            session,
+            metric_nodes,
+            dimension_names,
+        )
     return metrics, metric_nodes, list(dimension_nodes.values()), dimensions, catalog


55 changes: 48 additions & 7 deletions datajunction-server/datajunction_server/api/sql.py
@@ -35,6 +35,7 @@
 from datajunction_server.database.queryrequest import QueryBuildType
 from datajunction_server.errors import DJInvalidInputException
 from datajunction_server.internal.access.authentication.http import SecureAPIRouter
+from datajunction_server.internal.rate_limiting import enforce_rate_limit
 from datajunction_server.models.metric import TranslatedSQL, V3TranslatedSQL
 from datajunction_server.models.node_type import NodeType
 from datajunction_server.models.query import V3ColumnMetadata
@@ -53,7 +54,14 @@
 
 _logger = logging.getLogger(__name__)
 settings = get_settings()
-router = SecureAPIRouter(tags=["sql"])
+
+# SQL router with rate limiting applied to all /sql/* endpoints
+router = SecureAPIRouter(
+    tags=["sql"],
+    dependencies=[Depends(enforce_rate_limit)]
+    if settings.rate_limiting_enabled
+    else [],
+)
 
 
 @router.get(
@@ -578,13 +586,30 @@ async def get_sql_for_metrics(
     """
     Return SQL for a set of metrics with dimensions and filters
     """
+    import time
+
+    start_time = time.time()
+
+    # Label this session for debugging
+    session.info["session_label"] = "initial node loading"
+
-    # make sure all metrics exist and have correct node type
-    nodes = [
-        await Node.get_by_name(session, node, raise_if_not_exists=True)
-        for node in metrics
-    ]
-    non_metric_nodes = [node for node in nodes if node and node.type != NodeType.METRIC]
+    # Use get_by_names (plural) to fetch all nodes in a single query instead of N queries
+    t1 = time.time()
+    nodes = await Node.get_by_names(session, metrics)
+    _logger.info(f"[PERF] get_by_names took {(time.time() - t1) * 1000:.0f}ms")
+
+    # Check if all requested nodes exist
+    found_names = {node.name for node in nodes}
+    missing_nodes = set(metrics) - found_names
+    if missing_nodes:
+        raise DJInvalidInputException(
+            message=f"The following nodes do not exist: {', '.join(missing_nodes)}",
+            http_status_code=HTTPStatus.NOT_FOUND,
+        )
+
+    # Validate node types
+    non_metric_nodes = [node for node in nodes if node and node.type != NodeType.METRIC]
     if non_metric_nodes:
         raise DJInvalidInputException(
             message="All nodes must be of metric type, but some are not: "
@@ -596,7 +621,9 @@
         cache=cache,
         query_type=QueryBuildType.METRICS,
     )
-    return await query_cache_manager.get_or_load(
+
+    t2 = time.time()
+    result = await query_cache_manager.get_or_load(
         background_tasks,
         request,
         QueryRequestParams(
@@ -611,4 +638,18 @@
             use_materialized=use_materialized,
             ignore_errors=ignore_errors,
         ),
+        session=session,  # Pass the session to reuse it
     )
+
+    total_time = time.time() - start_time
+    build_time = time.time() - t2
+    _logger.info(
+        f"[PERF] /sql/ total={total_time * 1000:.0f}ms "
+        f"(build={build_time * 1000:.0f}ms, "
+        f"metrics={len(metrics)}, dims={len(dimensions)})",
+    )
+    _logger.info(
+        "[REQUEST_END] /sql/ request completed - add up QUERY_COUNT from sessions above",
+    )
+
+    return result
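The `enforce_rate_limit` dependency itself is not shown in this diff. Given how it is wired in as a router-level dependency, a plausible shape is a FastAPI dependency that counts requests per user per window and raises `429` when the budget is spent. Everything below is an illustrative sketch, not the PR's implementation: the `get_settings` import path, the `request.state.user` attribute, and the in-memory counter are all assumptions.

```python
import time
from collections import defaultdict
from http import HTTPStatus

from fastapi import HTTPException, Request

from datajunction_server.utils import get_settings  # import path assumed

_hits: dict[tuple[str, int], int] = defaultdict(int)


async def enforce_rate_limit(request: Request) -> None:
    """Hypothetical fixed-window limiter matching how the router wires it in."""
    settings = get_settings()
    # Key on the authenticated user when available, else the client IP
    # (request.state.user is an assumption about the auth middleware).
    user = getattr(request.state, "user", None)
    key = user.username if user else (request.client.host if request.client else "anon")
    window_id = int(time.time()) // settings.rate_limit_window_seconds
    _hits[(key, window_id)] += 1
    limit = settings.rate_limit_expensive_rps * settings.rate_limit_window_seconds
    if _hits[(key, window_id)] > limit:
        raise HTTPException(
            status_code=HTTPStatus.TOO_MANY_REQUESTS,
            detail="Rate limit exceeded; retry after the current window.",
        )
```

A production version would store counters in the configured `REDIS_CACHE` backend rather than process-local memory, so limits hold across workers.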
18 changes: 17 additions & 1 deletion datajunction-server/datajunction_server/config.py
@@ -87,7 +87,7 @@ class Settings(BaseSettings):  # pragma: no cover
     # [optional] `reader_db` is used for read operations and defaults to `writer_db`
     # if no dedicated read replica is configured.
     writer_db: DatabaseConfig = DatabaseConfig(
-        uri="postgresql+psycopg://dj:dj@postgres_metadata:5432/dj",
+        uri="postgresql+psycopg://dj:dj@host.docker.internal:4000/dj",
     )
     reader_db: DatabaseConfig = writer_db

@@ -205,6 +205,22 @@ class Settings(BaseSettings):  # pragma: no cover
     github_app_private_key: Optional[str] = None  # PEM-encoded private key
     github_app_installation_id: Optional[str] = None
 
+    # Rate limiting configuration
+    # Enable/disable rate limiting globally
+    rate_limiting_enabled: bool = True
+
+    # Time window for rate limiting (in seconds)
+    rate_limit_window_seconds: int = 1
+
+    # Default rate limit for most endpoints (requests per second per user)
+    rate_limit_default_rps: int = 3
+
+    # Rate limit for standard read operations (GET requests)
+    rate_limit_standard_rps: int = 5
+
+    # Rate limit for expensive operations (data queries, SQL generation, refreshes)
+    rate_limit_expensive_rps: int = 2
+
     @property
     def celery(self) -> Celery:
         """
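These defaults mirror the `.env` entries above: pydantic's `BaseSettings` matches each field to the environment variable of the same name (case-insensitively), so `RATE_LIMITING_ENABLED=false` overrides `rate_limiting_enabled` at startup. A minimal standalone illustration, assuming pydantic v2 with the `pydantic-settings` package (the class is a toy subset, not the real `Settings`):

```python
import os

from pydantic_settings import BaseSettings  # pydantic v1 would use pydantic.BaseSettings


class RateLimitSettings(BaseSettings):
    """Toy subset of the new Settings fields, for illustration only."""

    rate_limiting_enabled: bool = True
    rate_limit_window_seconds: int = 1
    rate_limit_default_rps: int = 3


os.environ["RATE_LIMITING_ENABLED"] = "false"  # mirrors the sample .env
assert RateLimitSettings().rate_limiting_enabled is False
```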
18 changes: 14 additions & 4 deletions datajunction-server/datajunction_server/construction/build_v2.py
@@ -16,7 +16,7 @@
 
 from sqlalchemy import text, bindparam, select
 from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy.orm import joinedload, selectinload
+from sqlalchemy.orm import joinedload, selectinload, noload
 
 from datajunction_server.internal.access.authorization import (
     AccessChecker,
@@ -723,6 +723,8 @@ async def find_join_paths_batch(
 
     This is O(1) database calls instead of O(nodes * depth) individual queries.
     """
+    # Filter out empty strings and check if we have any valid dimension names
+    target_dimension_names = {name for name in target_dimension_names if name}
     if not target_dimension_names:
         return {}  # pragma: no cover
 
@@ -800,18 +802,24 @@ async def load_dimension_links_and_nodes(
         .where(DimensionLink.id.in_(link_ids))
         .options(
             joinedload(DimensionLink.dimension).options(
+                noload(Node.created_by),  # Prevent User N+1 queries
                 joinedload(Node.current).options(
+                    noload(NodeRevision.created_by),  # Prevent User N+1 queries
                     selectinload(NodeRevision.columns).options(
                         joinedload(Column.attributes).joinedload(
                             ColumnAttribute.attribute_type,
                         ),
-                        joinedload(Column.dimension),
+                        joinedload(Column.dimension).options(
+                            noload(Node.created_by),  # Prevent User N+1 queries
+                        ),
                         joinedload(Column.partition),
                     ),
                     joinedload(NodeRevision.catalog),
                     selectinload(NodeRevision.availability),
                     selectinload(NodeRevision.dimension_links).options(
-                        joinedload(DimensionLink.dimension),
+                        joinedload(DimensionLink.dimension).options(
+                            noload(Node.created_by),  # Prevent User N+1 queries
+                        ),
                     ),
                 ),
             ),
@@ -1323,7 +1331,9 @@ async def build(self) -> ast.Query:
         Builds SQL for multiple metrics with the requested set of dimensions,
         filter expressions, order by, and limit clauses.
         """
-        self.add_dimensions(get_dimensions_referenced_in_metrics(self.metric_nodes))
+        # Only load dimension graph if dimensions were actually requested
+        if self.dimensions:
+            self.add_dimensions(get_dimensions_referenced_in_metrics(self.metric_nodes))
 
         measures_queries = await self.build_measures_queries()
 
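The repeated `noload(...created_by)` options across these loader trees all target the same failure mode: with the default lazy loader, touching each node's `created_by` relationship during serialization issues one extra `SELECT` against the user table per node (the classic N+1 pattern), while `noload` leaves the attribute unloaded and emits no query at all. A self-contained sketch of the pattern with toy models (not DJ's actual schema):

```python
from sqlalchemy import ForeignKey, select
from sqlalchemy.orm import (
    DeclarativeBase,
    Mapped,
    Session,
    mapped_column,
    noload,
    relationship,
)


class Base(DeclarativeBase):
    pass


class User(Base):
    __tablename__ = "users"
    id: Mapped[int] = mapped_column(primary_key=True)


class Node(Base):
    __tablename__ = "nodes"
    id: Mapped[int] = mapped_column(primary_key=True)
    created_by_id: Mapped[int] = mapped_column(ForeignKey("users.id"))
    created_by: Mapped[User] = relationship(lazy="select")  # lazy by default


def load_nodes(session: Session) -> list[Node]:
    # Without noload, reading node.created_by later emits one SELECT per node.
    # With noload, the attribute simply stays unloaded (None) and no query runs.
    return list(session.scalars(select(Node).options(noload(Node.created_by))))
```

The trade-off is that `created_by` silently reads as `None` on these objects, so this only works on code paths that never need the creator, which is exactly what the comments in the diff assert.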
@@ -13,7 +13,7 @@
 
 from sqlalchemy import and_, select
 from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy.orm import joinedload, selectinload
+from sqlalchemy.orm import joinedload, selectinload, noload
 
 from datajunction_server.construction.build_v3.decomposition import is_derived_metric
 from datajunction_server.models.dialect import Dialect
@@ -87,9 +87,13 @@ async def find_matching_cube(
             ),
         )
         .options(
+            noload(Node.created_by),  # Prevent User N+1 queries
             joinedload(Node.current).options(
-                selectinload(NodeRevision.cube_elements).selectinload(
-                    Column.node_revision,
+                noload(NodeRevision.created_by),  # Prevent User N+1 queries
+                selectinload(NodeRevision.cube_elements).options(
+                    selectinload(Column.node_revision).options(
+                        noload(NodeRevision.created_by),  # Prevent User N+1 queries
+                    ),
                 ),
                 joinedload(NodeRevision.availability),
                 selectinload(NodeRevision.materializations),
@@ -225,7 +229,18 @@ async def resolve_dialect_and_engine_for_metrics(
     )
 
     # Fallback: use first metric's catalog's default engine
-    node = await Node.get_by_name(session, metrics[0], raise_if_not_exists=True)
+    node = await Node.get_by_name(
+        session,
+        metrics[0],
+        raise_if_not_exists=True,
+        options=[
+            joinedload(Node.current).options(
+                noload(NodeRevision.created_by),  # Prevent User N+1 queries
+                joinedload(NodeRevision.catalog),
+            ),
+            noload(Node.created_by),  # Prevent User N+1 queries
+        ],
+    )
     if not node:  # pragma: no cover
         raise ValueError(f"Metric not found: {metrics[0]}")
 
@@ -9,7 +9,7 @@
 
 from sqlalchemy import select, text, bindparam
 from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy.orm import selectinload, joinedload, load_only
+from sqlalchemy.orm import selectinload, joinedload, load_only, noload
 
 from datajunction_server.database.dimensionlink import DimensionLink
 from datajunction_server.database.node import Node, NodeRevision, Column
@@ -275,7 +275,9 @@ async def load_dimension_links_batch(
         .where(DimensionLink.id.in_(link_ids))
         .options(
             joinedload(DimensionLink.dimension).options(
+                noload(Node.created_by),  # Prevent User N+1 queries
                 joinedload(Node.current).options(
+                    noload(NodeRevision.created_by),  # Prevent User N+1 queries
                     # Load what's needed for table references, parsing, and type lookups
                     joinedload(NodeRevision.catalog),
                     joinedload(NodeRevision.availability),
@@ -375,6 +377,7 @@ async def load_nodes(ctx: BuildContext) -> None:
                 Node.current_version,
             ),
             joinedload(Node.current).options(
+                noload(NodeRevision.created_by),  # Prevent User N+1 queries
                 load_only(
                     NodeRevision.name,
                     NodeRevision.query,
@@ -391,13 +394,18 @@
                 selectinload(NodeRevision.required_dimensions).options(
                     # Load the node_revision and node to reconstruct full dimension path
                     joinedload(Column.node_revision).options(
-                        joinedload(NodeRevision.node),
+                        noload(NodeRevision.created_by),  # Prevent User N+1 queries
+                        joinedload(NodeRevision.node).options(
+                            noload(Node.created_by),  # Prevent User N+1 queries
+                        ),
                     ),
                 ),
                 joinedload(NodeRevision.availability),  # For materialization support
                 selectinload(NodeRevision.dimension_links).options(
                     # Load dimension node for link matching in temporal filters
-                    joinedload(DimensionLink.dimension),
+                    joinedload(DimensionLink.dimension).options(
+                        noload(Node.created_by),  # Prevent User N+1 queries
+                    ),
                 ),
             ),
         )