151 changes: 112 additions & 39 deletions datajunction-server/datajunction_server/construction/build_v3/loaders.py
@@ -16,6 +16,8 @@
from datajunction_server.database.preaggregation import PreAggregation
from datajunction_server.models.node_type import NodeType

from datajunction_server.sql.parsing.backends.antlr4 import parse as parse_sql

from datajunction_server.construction.build_v3.dimensions import parse_dimension_ref
from datajunction_server.construction.build_v3.types import BuildContext
from datajunction_server.construction.build_v3.utils import collect_required_dimensions
@@ -334,6 +336,112 @@ async def preload_join_paths(
ctx.nodes[link.dimension.name] = link.dimension


def _node_load_options():
    """Shared eager-load options for batch loading nodes."""
    return [
        load_only(
            Node.name,
            Node.type,
            Node.current_version,
        ),
        joinedload(Node.current).options(
            noload(NodeRevision.created_by),  # Prevent User N+1 queries
            load_only(
                NodeRevision.name,
                NodeRevision.query,
                NodeRevision.schema_,
                NodeRevision.table,
            ),
            selectinload(NodeRevision.columns).options(
                load_only(Column.name, Column.type),
            ),
            joinedload(NodeRevision.catalog),
            selectinload(NodeRevision.required_dimensions).options(
                # Load the node_revision and node to reconstruct full dimension path
                joinedload(Column.node_revision).options(
                    noload(NodeRevision.created_by),  # Prevent User N+1 queries
                    joinedload(NodeRevision.node).options(
                        noload(Node.created_by),  # Prevent User N+1 queries
                    ),
                ),
            ),
            joinedload(NodeRevision.availability),  # For materialization support
            selectinload(NodeRevision.dimension_links).options(
                # Load dimension node for link matching in temporal filters
                joinedload(DimensionLink.dimension).options(
                    noload(Node.created_by),  # Prevent User N+1 queries
                ),
            ),
        ),
    ]


async def _load_missing_upstream_nodes(ctx: BuildContext) -> None:
"""
Self-healing fallback: scan loaded nodes' SQL for table references that are
not in ctx.nodes, and load them from the DB.

This guards against stale NodeRelationship data (e.g. a node was saved via a
code path that did not call revalidate_node, so its parent rows are missing or
wrong). When the relationship table is healthy this function is a no-op.
"""
    # Hoisted out of the per-node loop below so the import only runs once
    from datajunction_server.sql.parsing import ast as sql_ast

    max_stale_upstream_node_depth = 5
    for _ in range(max_stale_upstream_node_depth):
        missing: set[str] = set()
        for node in ctx.nodes.values():
            if (
                node.type in (NodeType.SOURCE,)
                or not node.current
                or not node.current.query
            ):
                continue
            try:
                query_ast = parse_sql(node.current.query)
                cte_names = {cte.alias_or_name.identifier() for cte in query_ast.ctes}
                for table in query_ast.find_all(sql_ast.Table):
                    name = str(table.name)
                    if name not in cte_names and name not in ctx.nodes:
                        missing.add(name)
            except Exception:  # pragma: no cover
                # Unparseable SQL: skip this node rather than fail the whole build
                pass
if not missing:
break

logger.warning(
"[BuildV3] %d node(s) referenced in SQL but absent from NodeRelationship "
"(stale parent data) — loading them now: %s",
len(missing),
missing,
)

# Recursively find their upstreams too, then batch-load everything
additional_names, additional_parent_map = await find_upstream_node_names(
ctx.session,
list(missing),
)
additional_names.update(missing)
new_names = additional_names - set(ctx.nodes.keys())
if not new_names:
break

stmt = (
select(Node)
.where(Node.name.in_(new_names))
.where(Node.deactivated_at.is_(None))
.options(*_node_load_options())
)
result = await ctx.session.execute(stmt)
for node in result.scalars().unique().all():
ctx.nodes[node.name] = node

# Merge newly discovered parent relationships into ctx.parent_map
for child, parents in additional_parent_map.items():
existing = ctx.parent_map.setdefault(child, [])
for p in parents:
if p not in existing:
existing.append(p)

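# Aside (reviewer sketch, illustrative only): the scan step above, isolated.
# Uses the same parser API as this diff (parse_sql, sql_ast.Table,
# alias_or_name.identifier()); the helper name `referenced_tables` is
# hypothetical and not part of the PR.
def referenced_tables(query: str) -> set[str]:
    """Names referenced as tables in ``query``, excluding its own CTEs."""
    from datajunction_server.sql.parsing import ast as sql_ast

    query_ast = parse_sql(query)
    cte_names = {cte.alias_or_name.identifier() for cte in query_ast.ctes}
    return {
        str(table.name)
        for table in query_ast.find_all(sql_ast.Table)
        if str(table.name) not in cte_names
    }


# Example: referenced_tables("WITH t AS (SELECT 1) SELECT * FROM t, default.parts")
# yields {"default.parts"}; "t" is skipped as a CTE.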

async def load_nodes(ctx: BuildContext) -> None:
"""
Load all nodes needed for SQL generation
@@ -370,45 +478,7 @@ async def load_nodes(ctx: BuildContext) -> None:
select(Node)
.where(Node.name.in_(all_node_names))
.where(Node.deactivated_at.is_(None))
.options(*_node_load_options())
)

result = await ctx.session.execute(stmt)
@@ -418,6 +488,9 @@
for node in nodes:
ctx.nodes[node.name] = node

# Self-healing: load any upstream nodes missing from NodeRelationship
await _load_missing_upstream_nodes(ctx)

# Collect required dimensions from metrics and add to context
# Required dimensions are stored as Column objects, so they don't have role info.
# We need to check if a user-requested dimension already covers the same (node, column).
@@ -506,16 +506,28 @@ def add_table_prefix(e):

# Build GROUP BY (use same column references as projection, without aliases)
# Skip filter-only dimensions as they're only needed for WHERE clause
# Deduplicate to handle cases where multiple dimensions resolve to the same
# physical column (e.g., via skip-join optimization mapping dateint -> calendar_date
# when calendar_date is also a direct local dimension).
group_by: list[ast.Expression] = []
group_by_seen: set[str] = set()
for resolved_dim in resolved_dimensions:
if resolved_dim.original_ref in ctx.filter_dimensions:
continue
table_alias = get_dimension_table_alias(resolved_dim, main_alias, dim_aliases)
col_ref = make_column_ref(resolved_dim.column_name, table_alias)
col_key = str(col_ref)
if col_key not in group_by_seen:
group_by_seen.add(col_key)
group_by.append(col_ref)

# Add grain columns to GROUP BY for LIMITED aggregability
for grain_col in grain_columns:
col_ref = make_column_ref(grain_col, main_alias)
col_key = str(col_ref)
if col_key not in group_by_seen:
group_by_seen.add(col_key)
group_by.append(col_ref)
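    # Worked example (hypothetical refs): if the skip-join optimization maps
    # "v3.date.dateint" to the same physical column as the local dimension
    # "calendar_date", both resolve to main.calendar_date here, and the
    # group_by_seen check above keeps a single GROUP BY entry for it.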

# Collect all nodes that need CTEs and their needed columns
nodes_for_ctes: list[Node] = []
@@ -1049,6 +1049,7 @@ async def _create_or_update_dimension_link(
role=join_link.role,
default_value=join_link.default_value,
)
print("link_input", link_input)
(
dimension_link,
activity_type,
@@ -2421,6 +2422,7 @@ async def create_or_update_dimension_join_link(
if existing_link:
if len(existing_link) >= 1: # pragma: no cover
for dup_link in existing_link[1:]:
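                    # Detach the duplicate from the in-memory relationship before
                    # deleting it (assumes SQLAlchemy session semantics: without
                    # this, node_revision.dimension_links can keep referencing the
                    # row scheduled for deletion).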
node_revision.dimension_links.remove(dup_link) # type: ignore
await self.session.delete(dup_link)
# Update the existing dimension link
activity_type = ActivityType.UPDATE
106 changes: 105 additions & 1 deletion datajunction-server/tests/api/deployments_test.py
@@ -22,7 +22,7 @@
from datajunction_server.internal.git.github_service import GitHubServiceError
from datajunction_server.api.deployments import InProcessExecutor
from datajunction_server.models.dimensionlink import JoinType
from datajunction_server.database.node import Node, NodeRelationship
from datajunction_server.database.tag import Tag
from datajunction_server.models.node import (
MetricDirection,
@@ -3997,3 +3997,107 @@ def test_falls_back_to_current_user_when_no_source(self):
spec = DeploymentSpec(namespace="test")
orchestrator = self._make_orchestrator(spec, username="admin")
assert orchestrator._history_user == "admin"


class TestDeploymentRevalidation:
"""
Deployment write path: _create_node_revision sets NodeRelationship rows
from node_graph, ensuring each new revision has correct parent links.
"""

@pytest.mark.asyncio
async def test_redeploy_creates_new_revision_with_correct_parents(
self,
client,
session,
):
"""
Force-redeploying a spec creates new node revisions via _create_node_revision,
which derives parents from node_graph. Verify the new revision has the
correct NodeRelationship row regardless of any corruption on the old revision.
"""
        from sqlalchemy import delete, select
        from sqlalchemy.orm import joinedload

namespace = "revalidate_parents_test"

source = SourceSpec(
name="default.parts",
catalog="default",
schema="shop",
table="parts",
columns=[
ColumnSpec(name="part_id", type="int"),
ColumnSpec(name="name", type="string"),
],
)
transform = TransformSpec(
name="default.parts_enriched",
query="SELECT part_id, name FROM ${prefix}default.parts",
columns=[
ColumnSpec(name="part_id", type="int"),
ColumnSpec(name="name", type="string"),
],
)
spec = DeploymentSpec(namespace=namespace, nodes=[source, transform])

# Initial deployment
data = await deploy_and_wait(client, spec)
assert data["status"] == DeploymentStatus.SUCCESS.value

# Fetch the deployed transform and verify it has a parent
transform_name = f"{namespace}.default.parts_enriched"

transform_node = (
await session.execute(
select(Node)
.where(Node.name == transform_name)
.options(joinedload(Node.current)),
)
).scalar_one()
original_rev_id = transform_node.current.id
original_rows = (
await session.execute(
select(NodeRelationship).where(
NodeRelationship.child_id == original_rev_id,
),
)
).all()
assert len(original_rows) == 1, "transform should have one parent after deploy"

# Corrupt: delete the NodeRelationship rows
await session.execute(
delete(NodeRelationship).where(
NodeRelationship.child_id == original_rev_id,
),
)
await session.commit()
assert (
await session.execute(
select(NodeRelationship).where(
NodeRelationship.child_id == original_rev_id,
),
)
).all() == []

# Re-deploy with force=True so even unchanged nodes get new revisions
# with correct parent links set by _create_node_revision.
force_spec = DeploymentSpec(
namespace=namespace,
nodes=[source, transform],
force=True,
)
data = await deploy_and_wait(client, force_spec)
assert data["status"] == DeploymentStatus.SUCCESS.value

# Re-fetch to get the new current revision created by re-deployment
await session.refresh(transform_node, ["current"])
new_rev_id = transform_node.current.id
restored_rows = (
await session.execute(
select(NodeRelationship).where(NodeRelationship.child_id == new_rev_id),
)
).all()
assert len(restored_rows) == 1, (
"_create_node_revision should have written correct parent relationships"
)