Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions core/ast/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ class NodeType(Enum):
SUBQUERY = "subquery"
COLUMN = "column"
LITERAL = "literal"
DATA_TYPE = "data_type"
TIME_UNIT = "time_unit"
LIST = "list"
INTERVAL = "interval"

# VarSQL specific
VAR = "var"
VARSET = "varset"
Expand All @@ -32,6 +37,8 @@ class NodeType(Enum):
LIMIT = "limit"
OFFSET = "offset"
QUERY = "query"
CASE = "case"
WHEN_THEN = "when_then"

# ============================================================================
# Join Type Enumeration
Expand Down
108 changes: 104 additions & 4 deletions core/ast/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,59 @@ def __eq__(self, other):
def __hash__(self):
return hash((super().__hash__(), self.value))

class DataTypeNode(Node):
"""SQL data type node used in CAST expressions (e.g. TEXT, DATE, INTEGER)"""
def __init__(self, _name: str, **kwargs):
super().__init__(NodeType.DATA_TYPE, **kwargs)
self.name = _name

def __eq__(self, other):
if not isinstance(other, DataTypeNode):
return False
return super().__eq__(other) and self.name == other.name

def __hash__(self):
return hash((super().__hash__(), self.name))


class TimeUnitNode(Node):
"""SQL time unit node used in INTERVAL and temporal functions (e.g. DAY, MONTH, SECOND)"""
def __init__(self, _name: str, **kwargs):
super().__init__(NodeType.TIME_UNIT, **kwargs)
self.name = _name

def __eq__(self, other):
if not isinstance(other, TimeUnitNode):
return False
return super().__eq__(other) and self.name == other.name

def __hash__(self):
return hash((super().__hash__(), self.name))

class ListNode(Node):
"""A list of nodes, e.g. the right-hand side of an IN expression"""
def __init__(self, _items: List[Node], **kwargs):
super().__init__(NodeType.LIST, children=_items, **kwargs)

class IntervalNode(Node):
def __init__(self, _value, _unit: TimeUnitNode, **kwargs):
# Include the value in children when it is itself a Node, so that
# generic traversals/formatters that walk via `children` see it.
if isinstance(_value, Node):
children = [_value, _unit]
else:
children = [_unit]
super().__init__(NodeType.INTERVAL, children=children, **kwargs)
self.value = _value
self.unit = _unit

def __eq__(self, other):
if not isinstance(other, IntervalNode):
return False
return super().__eq__(other) and self.value == other.value and self.unit == other.unit

def __hash__(self):
return hash((super().__hash__(), self.value, self.unit))

class VarNode(Node):
"""VarSQL variable node"""
Expand Down Expand Up @@ -192,9 +245,22 @@ def __hash__(self):
# ============================================================================

class SelectNode(Node):
"""SELECT clause node"""
def __init__(self, _items: List['Node'], **kwargs):
super().__init__(NodeType.SELECT, children=_items, **kwargs)
"""SELECT clause node. _distinct_on is the list of expressions for DISTINCT ON (e.g. ListNode of columns)."""
def __init__(self, _items: List['Node'], _distinct: bool = False, _distinct_on: Optional['Node'] = None, **kwargs):
children = list(_items)
if _distinct_on is not None:
children.append(_distinct_on)
super().__init__(NodeType.SELECT, children=children, **kwargs)
self.distinct = _distinct
self.distinct_on = _distinct_on

def __eq__(self, other):
if not isinstance(other, SelectNode):
return False
return super().__eq__(other) and self.distinct == other.distinct and self.distinct_on == other.distinct_on

def __hash__(self):
return hash((super().__hash__(), self.distinct, self.distinct_on))


# TODO - confine the valid NodeTypes as children of FromNode
Expand Down Expand Up @@ -304,4 +370,38 @@ def __init__(self,
children.append(_limit)
if _offset:
children.append(_offset)
super().__init__(NodeType.QUERY, children=children, **kwargs)
super().__init__(NodeType.QUERY, children=children, **kwargs)

class WhenThenNode(Node):
"""Single WHEN ... THEN ... branch of a CASE expression"""
def __init__(self, _when: Node, _then: Node, **kwargs):
super().__init__(NodeType.WHEN_THEN, children=[_when, _then], **kwargs)
self.when = _when
self.then = _then

def __eq__(self, other):
if not isinstance(other, WhenThenNode):
return False
return super().__eq__(other) and self.when == other.when and self.then == other.then

def __hash__(self):
return hash((super().__hash__(), self.when, self.then))


class CaseNode(Node):
"""SQL CASE WHEN ... THEN ... ELSE ... END expression"""
def __init__(self, _whens: List[WhenThenNode], _else: Optional[Node] = None, **kwargs):
children: List[Node] = list(_whens)
if _else is not None:
children.append(_else)
super().__init__(NodeType.CASE, children=children, **kwargs)
self.whens = _whens
self.else_val = _else

def __eq__(self, other):
if not isinstance(other, CaseNode):
return False
return super().__eq__(other) and self.whens == other.whens and self.else_val == other.else_val

def __hash__(self):
return hash((super().__hash__(), tuple(self.whens), self.else_val))
Loading