Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions core/ast/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ class NodeType(Enum):
SUBQUERY = "subquery"
COLUMN = "column"
LITERAL = "literal"
TYPE = "type"
LIST = "list"
INTERVAL = "interval"

# VarSQL specific
VAR = "var"
VARSET = "varset"
Expand All @@ -32,6 +36,7 @@ class NodeType(Enum):
LIMIT = "limit"
OFFSET = "offset"
QUERY = "query"
CASE = "case"

# ============================================================================
# Join Type Enumeration
Expand Down
71 changes: 68 additions & 3 deletions core/ast/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,42 @@ def __eq__(self, other):
def __hash__(self):
return hash((super().__hash__(), self.value))

class TypeNode(Node):
"""SQL type keyword node (e.g. TEXT, DATE, INTEGER)"""
SQL_TYPE_KEYWORDS = {"TEXT", "DATE", "INTEGER", "TIMESTAMP", "VARCHAR", "BOOLEAN", "FLOAT", "SECOND", "MINUTE", "HOUR", "DAY", "WEEK", "MONTH", "YEAR", "NULL"}

def __init__(self, _name: str, **kwargs):
if _name not in TypeNode.SQL_TYPE_KEYWORDS:
raise ValueError(f"Invalid SQL type keyword: {_name}")
Comment on lines +118 to +123
Copy link

Copilot AI Feb 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TypeNode’s docstring says it represents SQL type keywords (TEXT/DATE/INTEGER), but the allowed keyword set also includes interval units (SECOND, MINUTE, …) and NULL. Either widen the docstring (and possibly rename the class) to reflect that it models general SQL keywords/units, or split this into separate node types to avoid confusion for AST consumers.

Suggested change
"""SQL type keyword node (e.g. TEXT, DATE, INTEGER)"""
SQL_TYPE_KEYWORDS = {"TEXT", "DATE", "INTEGER", "TIMESTAMP", "VARCHAR", "BOOLEAN", "FLOAT", "SECOND", "MINUTE", "HOUR", "DAY", "WEEK", "MONTH", "YEAR", "NULL"}
def __init__(self, _name: str, **kwargs):
if _name not in TypeNode.SQL_TYPE_KEYWORDS:
raise ValueError(f"Invalid SQL type keyword: {_name}")
"""SQL keyword/unit node for types, interval units, and NULL (e.g. TEXT, DATE, INTEGER, SECOND, YEAR, NULL)"""
SQL_TYPE_KEYWORDS = {"TEXT", "DATE", "INTEGER", "TIMESTAMP", "VARCHAR", "BOOLEAN", "FLOAT", "SECOND", "MINUTE", "HOUR", "DAY", "WEEK", "MONTH", "YEAR", "NULL"}
def __init__(self, _name: str, **kwargs):
if _name not in TypeNode.SQL_TYPE_KEYWORDS:
raise ValueError(f"Invalid SQL type/keyword: {_name}")

Copilot uses AI. Check for mistakes.
super().__init__(NodeType.TYPE, **kwargs)
self.name = _name

def __eq__(self, other):
if not isinstance(other, TypeNode):
return False
return super().__eq__(other) and self.name == other.name

def __hash__(self):
return hash((super().__hash__(), self.name))

class ListNode(Node):
"""A list of nodes, e.g. the right-hand side of an IN expression"""
def __init__(self, _items: List[Node], **kwargs):
super().__init__(NodeType.LIST, children=_items, **kwargs)

class IntervalNode(Node):
def __init__(self, _value, _unit: TypeNode, **kwargs):
super().__init__(NodeType.INTERVAL, children=[_unit], **kwargs)
Copy link

Copilot AI Feb 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IntervalNode stores _value as an attribute but does not include it in children. Since much of the codebase (e.g., formatting/visualization/traversal) walks the AST via children, the interval value can be skipped entirely. Consider including the value node in children (e.g., [value, unit] when value is a Node) or otherwise ensuring traversal/formatting code accounts for it.

Suggested change
super().__init__(NodeType.INTERVAL, children=[_unit], **kwargs)
# Include the value in children when it is itself a Node, so that
# generic traversals/formatters that walk via `children` see it.
if isinstance(_value, Node):
children = [_value, _unit]
else:
children = [_unit]
super().__init__(NodeType.INTERVAL, children=children, **kwargs)

Copilot uses AI. Check for mistakes.
self.value = _value
self.unit = _unit

def __eq__(self, other):
if not isinstance(other, IntervalNode):
return False
return super().__eq__(other) and self.value == other.value and self.unit == other.unit

def __hash__(self):
return hash((super().__hash__(), self.value, self.unit))

class VarNode(Node):
"""VarSQL variable node"""
Expand Down Expand Up @@ -192,9 +228,19 @@ def __hash__(self):
# ============================================================================

class SelectNode(Node):
"""SELECT clause node"""
def __init__(self, _items: List['Node'], **kwargs):
"""SELECT clause node. _distinct_on is the list of expressions for DISTINCT ON (e.g. ListNode of columns)."""
def __init__(self, _items: List['Node'], _distinct: bool = False, _distinct_on: Optional['Node'] = None, **kwargs):
super().__init__(NodeType.SELECT, children=_items, **kwargs)
Copy link

Copilot AI Feb 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SelectNode tracks _distinct_on but does not include it in children. Any generic AST traversal that relies on children will miss DISTINCT ON expressions, which can lead to incorrect rewrites/formatting/analysis. Consider representing DISTINCT ON as part of the node’s subtree (e.g., include it in children with a dedicated wrapper node/field-aware traversal) and update consumers accordingly.

Suggested change
super().__init__(NodeType.SELECT, children=_items, **kwargs)
# Include DISTINCT ON expressions in children so generic AST traversals see them.
children: List[Node] = list(_items) if _items is not None else []
if _distinct_on is not None:
children.append(_distinct_on)
super().__init__(NodeType.SELECT, children=children, **kwargs)

Copilot uses AI. Check for mistakes.
self.distinct = _distinct
self.distinct_on = _distinct_on

def __eq__(self, other):
if not isinstance(other, SelectNode):
return False
return super().__eq__(other) and self.distinct == other.distinct and self.distinct_on == other.distinct_on

def __hash__(self):
return hash((super().__hash__(), self.distinct, self.distinct_on))


# TODO - confine the valid NodeTypes as children of FromNode
Expand Down Expand Up @@ -304,4 +350,23 @@ def __init__(self,
children.append(_limit)
if _offset:
children.append(_offset)
super().__init__(NodeType.QUERY, children=children, **kwargs)
super().__init__(NodeType.QUERY, children=children, **kwargs)

class CaseNode(Node):
"""SQL CASE WHEN ... THEN ... ELSE ... END expression"""
def __init__(self, _whens: List[tuple], _else=None, **kwargs):
# flatten whens into children: [cond1, val1, cond2, val2, ..., else]
children = [node for pair in _whens for node in pair]
if _else is not None:
children.append(_else)
super().__init__(NodeType.CASE, children=children, **kwargs)
self.whens = _whens
self.else_val = _else

def __eq__(self, other):
if not isinstance(other, CaseNode):
return False
return super().__eq__(other) and self.whens == other.whens and self.else_val == other.else_val

def __hash__(self):
return hash((super().__hash__(), tuple(self.whens), self.else_val))
Loading