Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions datajunction-server/datajunction_server/internal/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2051,13 +2051,19 @@ async def column_lineage(
)

# Find the expression AST for the column on the node
column = [
matching = [
col
for col in query_ast.select.projection
if ( # pragma: no cover
col != ast.Null() and col.alias_or_name.name.lower() == column_name.lower() # type: ignore
)
][0]
]
if not matching:
# The column name (from the DB) doesn't appear in the compiled projection —
# this can happen for derived metrics whose alias was generated differently.
# Return what we have (empty lineage) rather than crashing.
return lineage_column # pragma: no cover
column = matching[0]
column_or_child = column.child if isinstance(column, ast.Alias) else column # type: ignore
column_expr = (
column_or_child.expression # type: ignore
Expand Down
61 changes: 44 additions & 17 deletions datajunction-server/datajunction_server/sql/parsing/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,6 +797,7 @@ class Column(Aliasable, Named, Expression):
default=None,
)
_is_struct_ref: bool = False
_struct_col_name: Optional[str] = field(repr=False, default=None)
_type: Optional["ColumnType"] = field(repr=False, default=None)
_expression: Optional[Expression] = field(repr=False, default=None)
_is_compiled: bool = False
Expand Down Expand Up @@ -858,12 +859,19 @@ def add_expression(self, expression: "Expression") -> "Column":
self._expression = expression
return self

def set_struct_ref(self):
def set_struct_ref(self, struct_col_name: Optional[str] = None):
"""
Marks this column as a struct dereference. This implies that we treat the name
and namespace values on this object as struct column and struct subscript values.

struct_col_name: the full dotted struct field path up to (but not including) the
leaf subscript, e.g. "device_health.network" for `device_health.network.rtt_ms`.
When provided, this is used verbatim by struct_column_name instead of being
re-derived from the namespace at render time.
"""
self._is_struct_ref = True
if struct_col_name is not None:
self._struct_col_name = struct_col_name

def add_table(self, table: "TableExpression"):
self._table = table
Expand Down Expand Up @@ -1272,10 +1280,12 @@ async def compile(self, ctx: CompileContext):
@property
def struct_column_name(self) -> str:
"""If this is a struct reference, the struct type's column name"""
column_namespace, column_name, subscript_name = self.column_names()
if len(self.namespace) == 1: # non-struct
return column_namespace
return column_name
if self._struct_col_name is not None:
return self._struct_col_name
# Fallback for columns where set_struct_ref() was called without a name
if len(self.namespace) == 1:
return self.namespace[0].name
return ".".join(n.name for n in self.namespace[1:])

@property
def struct_subscript(self) -> str:
Expand Down Expand Up @@ -1524,23 +1534,33 @@ def add_column_reference(
if len(column.namespace) == 2:
column_namespace, column_name, _ = column.column_names()

if col.alias_or_name.identifier(False) in (
_col_id = col.alias_or_name.identifier(False)
if _col_id in (
column_namespace,
column_name,
):
for type_field in col.type.fields:
if type_field.name.name == subscript_name:
self._ref_columns.append(column)
column.set_struct_ref()
column.add_table(self)
column.add_expression(col)
column.add_type(type_field.type)
return True
# One-level check: find subscript_name as a direct struct field.
# Skip when len(namespace)==2 and col_id==column_namespace —
# the user explicitly wrote a deeper path (e.g.
# struct_col.intermediate.leaf) so we must use the two-level
# check below to honour the full path and not short-circuit on
# a same-named field at the wrong depth.
if not (
len(column.namespace) == 2 and _col_id == column_namespace
):
for type_field in col.type.fields:
if type_field.name.name == subscript_name:
self._ref_columns.append(column)
column.set_struct_ref(_col_id)
column.add_table(self)
column.add_expression(col)
column.add_type(type_field.type)
return True
# Two-level struct access: col is viewing_secs (StructType),
# column_name is wall_clock (intermediate field), and
# subscript_name is total (leaf field).
# Find column_name in col.type.fields, then subscript_name in that.
if col.alias_or_name.identifier(False) == column_namespace:
if _col_id == column_namespace:
for mid_field in col.type.fields:
if mid_field.name.name == column_name and isinstance(
mid_field.type,
Expand All @@ -1549,7 +1569,9 @@ def add_column_reference(
for leaf_field in mid_field.type.fields:
if leaf_field.name.name == subscript_name:
self._ref_columns.append(column)
column.set_struct_ref()
column.set_struct_ref(
f"{column_namespace}.{column_name}",
)
column.add_table(self)
column.add_expression(col)
column.add_type(leaf_field.type)
Expand All @@ -1569,7 +1591,12 @@ def add_column_reference(
)
if resolved is not None:
self._ref_columns.append(column)
column.set_struct_ref()
struct_col_name = (
col_id
if len(field_path) == 1
else col_id + "." + ".".join(field_path[:-1])
)
column.set_struct_ref(struct_col_name)
column.add_table(self)
column.add_expression(col)
column.add_type(resolved)
Expand Down
30 changes: 30 additions & 0 deletions datajunction-server/tests/api/nodes_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4805,6 +4805,36 @@ async def test_node_column_lineage(self, client_with_roads: AsyncClient):
},
]

@pytest.mark.asyncio
async def test_node_column_lineage_derived_metric(
self,
client_with_roads: AsyncClient,
):
"""
A derived metric (references another metric, no FROM clause, no explicit alias)
stores its output column as 'col0' during creation, but column_lineage uses
format_metric_alias which aliases it to amenable_name(node_name). These names
don't match, so the projection lookup returns an empty list. Verify that the
endpoint returns 200 with empty lineage rather than crashing with IndexError.
"""
await client_with_roads.post(
"/nodes/metric/",
json={
"name": "default.derived_orders_scaled",
"query": "SELECT default.num_repair_orders * 2.0",
"mode": "published",
"description": "Derived metric with no explicit alias",
},
)
response = await client_with_roads.get(
"/nodes/default.derived_orders_scaled/lineage/",
)
assert response.status_code == 200
data = response.json()
assert len(data) == 1
assert data[0]["node_name"] == "default.derived_orders_scaled"
assert data[0]["node_type"] == "metric"

@pytest.mark.asyncio
async def test_revalidating_existing_nodes(self, client_with_roads: AsyncClient):
"""
Expand Down
Loading
Loading