From f40ac336853747863d45714420f22077220e88f6 Mon Sep 17 00:00:00 2001 From: Qiushi Bai Date: Tue, 24 Feb 2026 12:57:16 -0800 Subject: [PATCH 1/2] Adding a ast_util.py to visualize ASTs. --- pytest.ini | 9 + tests/ast_util.py | 304 +++++++++++++++++ tests/test_query_parser.py | 645 +++++++++++++++++++------------------ 3 files changed, 637 insertions(+), 321 deletions(-) create mode 100644 pytest.ini create mode 100644 tests/ast_util.py diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..6852c5a --- /dev/null +++ b/pytest.ini @@ -0,0 +1,9 @@ +[pytest] +# Capture log messages during test execution +log_cli = true +log_cli_level = INFO +log_file_level = INFO +log_level = INFO + +# Show output from print statements and logging +addopts = -v --tb=short diff --git a/tests/ast_util.py b/tests/ast_util.py new file mode 100644 index 0000000..274e074 --- /dev/null +++ b/tests/ast_util.py @@ -0,0 +1,304 @@ +""" +Utility functions for visualizing and working with AST structures. +""" +import textwrap +import sqlparse +from core.ast.node import ( + Node, QueryNode, SelectNode, FromNode, WhereNode, TableNode, ColumnNode, + LiteralNode, OperatorNode, FunctionNode, GroupByNode, HavingNode, + OrderByNode, OrderByItemNode, LimitNode, OffsetNode, JoinNode, SubqueryNode, + VarNode, VarSetNode +) + + +def _beautify_sql(sql: str) -> str: + """ + Beautify SQL query string with proper indentation and formatting. + + Uses sqlparse library. + + Args: + sql: Raw SQL query string + + Returns: + Formatted SQL string with proper indentation + """ + + formatted = sqlparse.format( + sql, + reindent=True, + keyword_case="upper" + ) + + return formatted + + +def _node_to_string(node: Node, indent: int = 0) -> str: + """ + Convert an AST node to a tree-formatted string representation. + + This function recursively converts AST nodes into a human-readable tree format + for visualization. The translation rules for each node type are: + + - TableNode: "table: name [alias]" + - name: table name + - [alias]: optional table alias (e.g., "employees [e]") + + - ColumnNode: "column: name (parent_alias) as alias" + - name: column name + - (parent_alias): optional table alias this column references (e.g., "salary (e)") + - as alias: optional column-level alias (e.g., "as emp_count") + + - LiteralNode: "literal: value" + - value: the literal value (e.g., 40000, 'text') + + - FunctionNode: "function: name as alias" + - name: function name (e.g., COUNT, SUM) + - as alias: optional function alias (e.g., "as emp_count") + - children: function arguments displayed as child nodes + + - OperatorNode: "operator: op_name" + - op_name: the operator (e.g., =, AND, OR, IN, >) + - children: operands as child nodes + - Special case for IN: displays a "values:" node containing the list items + + - JoinNode: "join: join_type" + - join_type: INNER, LEFT, RIGHT, FULL, CROSS, etc. + - children: left table, right table, and join condition + + - OrderByItemNode: "order_by_item: sort_order" + - sort_order: ASC or DESC + - children: the column being sorted + + - SelectNode, FromNode, WhereNode, GroupByNode, HavingNode, OrderByNode: + "select", "from", "where", "group_by", "having", "order_by" + - These clause nodes have children representing their contents + + - LimitNode, OffsetNode: "limit: value" / "offset: value" + - value: the numeric limit or offset + + - QueryNode: "query" + - Represents the root query or a subquery's internal structure + - children: SELECT, FROM, WHERE, GROUP BY, HAVING, ORDER BY, LIMIT, OFFSET clauses + + - SubqueryNode: "subquery [alias]" + - [alias]: optional subquery alias (e.g., "[grouped_items]") + - children: the internal QueryNode + + Args: + node: AST node to convert + indent: Current indentation level + + Returns: + String representation of the node in tree format + """ + result = [] + prefix = "| " * indent + "+- " + + # Get node type name + node_type = node.type.value if hasattr(node.type, 'value') else str(node.type) + + # Build node representation based on node type + if isinstance(node, TableNode): + # TableNode: display as "table: table_name [alias]" + # Example: "table: employees [e]" - "e" is the table alias for reference in WHERE/SELECT + alias_str = f" [{node.alias}]" if node.alias else "" + result.append(f"{prefix}{node_type}: {node.name}{alias_str}") + + elif isinstance(node, ColumnNode): + # ColumnNode: display as "column: column_name (parent_alias) as alias" + # Example: "column: salary (e) as avg_salary" + # - (e) indicates this column belongs to table with alias "e" + # - "as avg_salary" is the column's output alias in the result set + parent_alias = f" ({node.parent_alias})" if node.parent_alias else "" + alias_str = f" as {node.alias}" if node.alias else "" + result.append(f"{prefix}{node_type}: {node.name}{parent_alias}{alias_str}") + + elif isinstance(node, LiteralNode): + # LiteralNode: display the literal value + # Examples: "literal: 40000", "literal: 'hello'", "literal: true" + result.append(f"{prefix}{node_type}: {node.value}") + + elif isinstance(node, FunctionNode): + # FunctionNode: display as "function: function_name as alias" + # Example: "function: COUNT as emp_count", "function: SUM" + # The function arguments are shown as child nodes + alias_str = f" as {node.alias}" if node.alias else "" + result.append(f"{prefix}{node_type}: {node.name}{alias_str}") + if node.children: + for i, child in enumerate(node.children): + child_lines = _node_to_string(child, indent + 1).split('\n') + for line in child_lines: + result.append(line) + + elif isinstance(node, OperatorNode): + # OperatorNode: display as "operator: operator_symbol" + # Examples: "operator: =", "operator: AND", "operator: >", "operator: IN" + # Binary operators like "=" have two operands (left, right) as children + # Logical operators like "AND" combine conditions + result.append(f"{prefix}{node_type}: {node.name}") + if node.children: + for i, child in enumerate(node.children): + if isinstance(child, list): + # Special handling for IN operator with list of values + # IN can have: (column, IN, [value1, value2, ...]) + list_prefix = "| " * (indent + 1) + "+- " + result.append(f"{list_prefix}values:") + for item in child: + item_lines = _node_to_string(item, indent + 2).split('\n') + for line in item_lines: + result.append(line) + else: + child_lines = _node_to_string(child, indent + 1).split('\n') + for line in child_lines: + result.append(line) + + elif isinstance(node, JoinNode): + # JoinNode: display as "join: join_type" + # Example: "join: inner" for INNER JOIN + # Children include: left table, right table, and join condition (ON clause) + join_type = node.join_type.value if hasattr(node.join_type, 'value') else str(node.join_type) + result.append(f"{prefix}{node_type}: {join_type}") + left_lines = _node_to_string(node.left_table, indent + 1).split('\n') + for line in left_lines: + result.append(line) + right_lines = _node_to_string(node.right_table, indent + 1).split('\n') + for line in right_lines: + result.append(line) + if node.on_condition: + cond_lines = _node_to_string(node.on_condition, indent + 1).split('\n') + for line in cond_lines: + result.append(line) + + elif isinstance(node, OrderByItemNode): + # OrderByItemNode: display as "order_by_item: sort_order" + # Example: "order_by_item: ASC" or "order_by_item: DESC" + # The column being sorted is shown as a child + sort_order = node.sort.value if hasattr(node.sort, 'value') else str(node.sort) + result.append(f"{prefix}{node_type}: {sort_order}") + if node.children: + for child in node.children: + child_lines = _node_to_string(child, indent + 1).split('\n') + for line in child_lines: + result.append(line) + + elif isinstance(node, (SelectNode, FromNode, WhereNode, GroupByNode, HavingNode, OrderByNode)): + # Clause nodes: display as the clause name only + # Examples: "select", "from", "where", "group_by", "having", "order_by" + # Children represent the contents of each clause + result.append(f"{prefix}{node_type}") + if node.children: + for child in node.children: + child_lines = _node_to_string(child, indent + 1).split('\n') + for line in child_lines: + result.append(line) + + elif isinstance(node, (LimitNode, OffsetNode)): + # LimitNode/OffsetNode: display as "limit: value" or "offset: value" + # Example: "limit: 10", "offset: 5" + value = node.limit if isinstance(node, LimitNode) else node.offset + result.append(f"{prefix}{node_type}: {value}") + + elif isinstance(node, QueryNode): + # QueryNode: root query or subquery structure, display as "query" + # Maintains tree structure consistency by using proper prefix and indentation + # Children are the clauses: SELECT, FROM, WHERE, GROUP BY, etc. + result.append(f"{prefix}query") + if node.children: + for child in node.children: + child_lines = _node_to_string(child, indent + 1).split('\n') + for line in child_lines: + result.append(line) + + elif isinstance(node, SubqueryNode): + # SubqueryNode: display as "subquery [alias]" + # Example: "subquery [t1]" where "t1" is the alias used to reference this subquery + # Children: the internal QueryNode representing the subquery's structure + alias_str = f" [{node.alias}]" if node.alias else "" + result.append(f"{prefix}{node_type}{alias_str}") + if node.children: + for child in node.children: + child_lines = _node_to_string(child, indent + 1).split('\n') + for line in child_lines: + result.append(line) + + elif isinstance(node, (VarNode, VarSetNode)): + # VarNode/VarSetNode: VarSQL variable, display as "var: name" or "varset: name" + result.append(f"{prefix}{node_type}: {node.name}") + + else: + # Default case for any other node types + result.append(f"{prefix}{node_type}") + if node.children: + for child in node.children: + child_lines = _node_to_string(child, indent + 1).split('\n') + for line in child_lines: + result.append(line) + + return '\n'.join(result) + + +def visualize_ast(sql: str, ast: QueryNode, max_sql_width: int = 50) -> str: + """ + Generate a side-by-side visualization of SQL query and AST structure. + + This function beautifies the SQL query on the left and displays the AST + tree structure on the right, allowing for easy comparison and review. + Individual SQL lines that exceed max_sql_width are automatically wrapped. + + Args: + sql: SQL query string to visualize + ast: QueryNode representing the parsed AST + max_sql_width: Maximum width for SQL column before wrapping (default: 50) + + Returns: + Formatted string with SQL on the left and AST tree on the right + """ + # Beautify SQL + beautified_sql = _beautify_sql(sql) + sql_lines = beautified_sql.split('\n') + + # Wrap long SQL lines to fit within max_sql_width + wrapped_sql_lines = [] + for line in sql_lines: + if len(line) > max_sql_width: + # Wrap long lines, preserving indentation + wrapped = textwrap.fill( + line, + width=max_sql_width, + subsequent_indent=' ', # Indent continuation lines + break_long_words=False, + break_on_hyphens=False + ) + wrapped_sql_lines.extend(wrapped.split('\n')) + else: + wrapped_sql_lines.append(line) + + # Convert AST to tree format + ast_tree = _node_to_string(ast) + ast_lines = ast_tree.split('\n') + + # Calculate column widths based on wrapped SQL + actual_sql_width = max(len(line) for line in wrapped_sql_lines) if wrapped_sql_lines else 0 + max_ast_width = max(len(line) for line in ast_lines) if ast_lines else 0 + padding = 3 # Space between columns + + total_width = actual_sql_width + padding + max_ast_width + + result = [] + result.append("=" * total_width) + result.append(f"{'SQL QUERY':<{actual_sql_width}}{' ' * padding}{'AST STRUCTURE'}") + result.append("=" * total_width) + + # Merge lines side-by-side + max_lines = max(len(wrapped_sql_lines), len(ast_lines)) + for i in range(max_lines): + sql_line = wrapped_sql_lines[i] if i < len(wrapped_sql_lines) else "" + ast_line = ast_lines[i] if i < len(ast_lines) else "" + + # Pad SQL line to match column width + result.append(f"{sql_line:<{actual_sql_width}}{' ' * padding}{ast_line}") + + result.append("=" * total_width) + + return '\n'.join(result) diff --git a/tests/test_query_parser.py b/tests/test_query_parser.py index a9aecda..a0ba83f 100644 --- a/tests/test_query_parser.py +++ b/tests/test_query_parser.py @@ -1,322 +1,325 @@ -from core.query_parser import QueryParser -from data.queries import get_query -from data.asts import get_ast - -parser = QueryParser() - - -def test_basic_parse(): - """ - Test parsing of a complex SQL query with JOINs, WHERE, GROUP BY, HAVING, ORDER BY, LIMIT, and OFFSET clauses. - """ - - # Construct input query text - sql = """ - SELECT e.name, d.name as dept_name, COUNT(*) as emp_count - FROM employees e JOIN departments d ON e.department_id = d.id - WHERE e.salary > 40000 AND e.age < 60 - GROUP BY d.id, d.name - HAVING COUNT(*) > 2 - ORDER BY dept_name, emp_count DESC - LIMIT 10 OFFSET 5 - """ - - assert parser.parse(sql) == get_ast(44) - - -def test_subquery_parse(): - """ - Test parsing of a SQL query with subquery in WHERE clause (IN operator). - """ - query = get_query(9) - sql = query['pattern'] - - assert parser.parse(sql) == get_ast(9) - - -def test_query_1(): - """Query 1: Remove Cast Date Match Twice.""" - query = get_query(1) - sql = query["pattern"] - #assert parser.parse(sql) == get_ast(1) - - -def test_query_2(): - """Query 2: Remove Cast Date Match Once.""" - query = get_query(2) - sql = query["rewrite"] - #assert parser.parse(sql) == get_ast(2) - - -# query 3 has the exact same query as query 2, so I skipped it - - -def test_query_4(): - """Query 4.""" - query = get_query(4) - sql = query["rewrite"] - #assert parser.parse(sql) == get_ast(4) - - -# query 5 has the exact same query as query 4, so I skipped it - - -def test_query_6(): - """Query 6: Remove Self Join Match.""" - query = get_query(6) - sql = query["pattern"] - assert parser.parse(sql) == get_ast(6) - - -def test_query_7(): - """Query 7: Remove Self Join No Match.""" - query = get_query(7) - sql = query["pattern"] - assert parser.parse(sql) == get_ast(7) - - -def test_query_8(): - """Query 8.""" - query = get_query(8) - sql = query["pattern"] - assert parser.parse(sql) == get_ast(8) - - -# query 9 is used in test_subquery_parse - - -def test_query_10(): - """Query 10: Subquery to Join Match 2.""" - query = get_query(10) - sql = query["pattern"] - assert parser.parse(sql) == get_ast(10) - - -def test_query_11(): - """Query 11: Subquery to Join Match 3.""" - query = get_query(11) - sql = query["rewrite"] - # TODO: Rewrite has SELECT DISTINCT (not supported by parser yet) - #assert parser.parse(sql) == get_ast(11) - - -def test_query_12(): - """Query 12: Join to Filter Match 1.""" - query = get_query(12) - sql = query["pattern"] - assert parser.parse(sql) == get_ast(12) - - -def test_query_13(): - """Query 13: Join to Filter Match 2.""" - query = get_query(13) - sql = query["pattern"] - assert parser.parse(sql) == get_ast(13) - - -def test_query_14(): - """Query 14: Test Rule Wetune 90 Match.""" - query = get_query(14) - sql = query["pattern"] - assert parser.parse(sql) == get_ast(14) - - -# TODO: Query 15 uses UNION, which is not supported by parser yet - - -def test_query_16(): - """Query 16: Remove Max Distinct.""" - query = get_query(16) - sql = query["pattern"] - # TODO: DISTINCT is not supported by parser yet - #assert parser.parse(sql) == get_ast(16) - - -def test_query_17(): - """Query 17.""" - query = get_query(17) - sql = query["pattern"] - assert parser.parse(sql) == get_ast(17) - - -def test_query_18(): - """Query 18 (parser drops SELECT for SELECT DISTINCT with comma join).""" - query = get_query(18) - sql = query["pattern"] - # TODO: DISTINCT is not supported by parser yet - #assert parser.parse(sql) == get_ast(18) - - -def test_query_19(): - """Query 19: Stackoverflow 2.""" - query = get_query(19) - sql = query["pattern"] - assert parser.parse(sql) == get_ast(19) - - -def test_query_20(): - """Query 20: Partial Matching Base Case 2.""" - query = get_query(20) - sql = query["pattern"] - # TODO: IN with literal list not supported by parser yet - #assert parser.parse(sql) == get_ast(20) - - -def test_query_21(): - """Query 21: Partial Matching 0.""" - query = get_query(21) - sql = query["pattern"] - assert parser.parse(sql) == get_ast(21) - - -def test_query_22(): - """Query 22: Partial Matching 4.""" - query = get_query(22) - sql = query["pattern"] - assert parser.parse(sql) == get_ast(22) - - -def test_query_23(): - """Query 23: Partial Keeps Remaining OR.""" - query = get_query(23) - sql = query["pattern"] - assert parser.parse(sql) == get_ast(23) - - -def test_query_24(): - """Query 24: Partial Keeps Remaining AND.""" - query = get_query(24) - sql = query["pattern"] - assert parser.parse(sql) == get_ast(24) - - -def test_query_25(): - """Query 25: And On True.""" - query = get_query(25) - sql = query["pattern"] - assert parser.parse(sql) == get_ast(25) - - -def test_query_26(): - """Query 26: Multiple And On True.""" - query = get_query(26) - sql = query["pattern"] - assert parser.parse(sql) == get_ast(26) - - -def test_query_27(): - """Query 27: Remove Where True.""" - query = get_query(27) - sql = query["pattern"] - # TODO: arithmetic expressions not supported by parser yet - #assert parser.parse(sql) == get_ast(27) - - -def test_query_28(): - """Query 28: Rewrite Skips Failed Partial.""" - query = get_query(28) - sql = query["pattern"] - assert parser.parse(sql) == get_ast(28) - - -# TODO: Query 29: Full Matching: UNION not supported by parser - - -def test_query_30(): - """Query 30: Over Partial Matching.""" - query = get_query(30) - sql = query["pattern"] - assert parser.parse(sql) == get_ast(30) - - -def test_query_31(): - """Query 31: Aggregation to Subquery.""" - query = get_query(31) - sql = query["pattern"] - # TODO: CASE not cleanly supported yet - #assert parser.parse(sql) == get_ast(31) - - -# TODO: Query 32: UNION not supported by parser - - -def test_query_33(): - """Query 33: Spreadsheet ID 3.""" - query = get_query(33) - sql = query["pattern"] - assert parser.parse(sql) == get_ast(33) - - -def test_query_34(): - """Query 34: Spreadsheet ID 7.""" - query = get_query(34) - sql = query["pattern"] - assert parser.parse(sql) == get_ast(34) - - -def test_query_35(): - """Query 35: Spreadsheet ID 9.""" - query = get_query(35) - sql = query["pattern"] - # TODO: DISTINCT not supported by parser yet - #assert parser.parse(sql) == get_ast(35) - - -def test_query_36(): - """Query 36: Spreadsheet ID 10.""" - query = get_query(36) - sql = query["pattern"] - assert parser.parse(sql) == get_ast(36) - - -def test_query_37(): - """Query 37: Spreadsheet ID 11.""" - query = get_query(37) - sql = query["pattern"] - assert parser.parse(sql) == get_ast(37) - - -def test_query_38(): - """Query 38: Spreadsheet ID 12.""" - query = get_query(38) - sql = query["pattern"] - assert parser.parse(sql) == get_ast(38) - - -def test_query_39(): - """Query 39: Spreadsheet ID 15.""" - query = get_query(39) - sql = query["pattern"] - assert parser.parse(sql) == get_ast(39) - - -def test_query_40(): - """Query 40.""" - query = get_query(40) - sql = query["pattern"] - # TODO: DISTINCT ON not supported by parser yet - #assert parser.parse(sql) == get_ast(40) - - -def test_query_41(): - """Query 41: Spreadsheet ID 20.""" - query = get_query(41) - sql = query["pattern"] - # TODO: NULL keyword and IS NULL not fully supported yet - #assert parser.parse(sql) == get_ast(41) - - -def test_query_42(): - """Query 42: PostgreSQL Test.""" - query = get_query(42) - sql = query["pattern"] - # TODO: INTERVAL, unary minus, keyword types not fully supported - #assert parser.parse(sql) == get_ast(42) - - -def test_query_43(): - """Query 43: MySQL Test.""" - query = get_query(43) - sql = query["pattern"] - # TODO: INTERVAL unit keyword not fully supported +import logging +from core.query_parser import QueryParser +from data.queries import get_query +from data.asts import get_ast +from .ast_util import visualize_ast + +parser = QueryParser() +logger = logging.getLogger(__name__) + + +def test_basic_parse(): + """ + Test parsing of a complex SQL query with JOINs, WHERE, GROUP BY, HAVING, ORDER BY, LIMIT, and OFFSET clauses. + """ + + # Construct input query text + sql = """ + SELECT e.name, d.name as dept_name, COUNT(*) as emp_count + FROM employees e JOIN departments d ON e.department_id = d.id + WHERE e.salary > 40000 AND e.age < 60 + GROUP BY d.id, d.name + HAVING COUNT(*) > 2 + ORDER BY dept_name, emp_count DESC + LIMIT 10 OFFSET 5 + """ + + assert parser.parse(sql) == get_ast(44) + + +def test_subquery_parse(): + """ + Test parsing of a SQL query with subquery in WHERE clause (IN operator). + """ + query = get_query(9) + sql = query['pattern'] + + assert parser.parse(sql) == get_ast(9) + + +def test_query_1(): + """Query 1: Remove Cast Date Match Twice.""" + query = get_query(1) + sql = query["pattern"] + #assert parser.parse(sql) == get_ast(1) + + +def test_query_2(): + """Query 2: Remove Cast Date Match Once.""" + query = get_query(2) + sql = query["rewrite"] + #assert parser.parse(sql) == get_ast(2) + + +# query 3 has the exact same query as query 2, so I skipped it + + +def test_query_4(): + """Query 4.""" + query = get_query(4) + sql = query["rewrite"] + #assert parser.parse(sql) == get_ast(4) + + +# query 5 has the exact same query as query 4, so I skipped it + + +def test_query_6(): + """Query 6: Remove Self Join Match.""" + query = get_query(6) + sql = query["pattern"] + assert parser.parse(sql) == get_ast(6) + + +def test_query_7(): + """Query 7: Remove Self Join No Match.""" + query = get_query(7) + sql = query["pattern"] + assert parser.parse(sql) == get_ast(7) + + +def test_query_8(): + """Query 8.""" + query = get_query(8) + sql = query["pattern"] + assert parser.parse(sql) == get_ast(8) + + +# query 9 is used in test_subquery_parse + + +def test_query_10(): + """Query 10: Subquery to Join Match 2.""" + query = get_query(10) + sql = query["pattern"] + assert parser.parse(sql) == get_ast(10) + + +def test_query_11(): + """Query 11: Subquery to Join Match 3.""" + query = get_query(11) + sql = query["rewrite"] + # TODO: Rewrite has SELECT DISTINCT (not supported by parser yet) + #assert parser.parse(sql) == get_ast(11) + + +def test_query_12(): + """Query 12: Join to Filter Match 1.""" + query = get_query(12) + sql = query["pattern"] + assert parser.parse(sql) == get_ast(12) + + +def test_query_13(): + """Query 13: Join to Filter Match 2.""" + query = get_query(13) + sql = query["pattern"] + assert parser.parse(sql) == get_ast(13) + + +def test_query_14(): + """Query 14: Test Rule Wetune 90 Match.""" + query = get_query(14) + sql = query["pattern"] + assert parser.parse(sql) == get_ast(14) + + +# TODO: Query 15 uses UNION, which is not supported by parser yet + + +def test_query_16(): + """Query 16: Remove Max Distinct.""" + query = get_query(16) + sql = query["pattern"] + # TODO: DISTINCT is not supported by parser yet + #assert parser.parse(sql) == get_ast(16) + + +def test_query_17(): + """Query 17.""" + query = get_query(17) + sql = query["pattern"] + assert parser.parse(sql) == get_ast(17) + + +def test_query_18(): + """Query 18 (parser drops SELECT for SELECT DISTINCT with comma join).""" + query = get_query(18) + sql = query["pattern"] + # TODO: DISTINCT is not supported by parser yet + #assert parser.parse(sql) == get_ast(18) + + +def test_query_19(): + """Query 19: Stackoverflow 2.""" + query = get_query(19) + sql = query["pattern"] + assert parser.parse(sql) == get_ast(19) + + +def test_query_20(): + """Query 20: Partial Matching Base Case 2.""" + query = get_query(20) + sql = query["pattern"] + # TODO: IN with literal list not supported by parser yet + #assert parser.parse(sql) == get_ast(20) + + +def test_query_21(): + """Query 21: Partial Matching 0.""" + query = get_query(21) + sql = query["pattern"] + assert parser.parse(sql) == get_ast(21) + + +def test_query_22(): + """Query 22: Partial Matching 4.""" + query = get_query(22) + sql = query["pattern"] + assert parser.parse(sql) == get_ast(22) + + +def test_query_23(): + """Query 23: Partial Keeps Remaining OR.""" + query = get_query(23) + sql = query["pattern"] + assert parser.parse(sql) == get_ast(23) + + +def test_query_24(): + """Query 24: Partial Keeps Remaining AND.""" + query = get_query(24) + sql = query["pattern"] + assert parser.parse(sql) == get_ast(24) + + +def test_query_25(): + """Query 25: And On True.""" + query = get_query(25) + sql = query["pattern"] + assert parser.parse(sql) == get_ast(25) + + +def test_query_26(): + """Query 26: Multiple And On True.""" + query = get_query(26) + sql = query["pattern"] + assert parser.parse(sql) == get_ast(26) + + +def test_query_27(): + """Query 27: Remove Where True.""" + query = get_query(27) + sql = query["pattern"] + # TODO: arithmetic expressions not supported by parser yet + #assert parser.parse(sql) == get_ast(27) + + +def test_query_28(): + """Query 28: Rewrite Skips Failed Partial.""" + query = get_query(28) + sql = query["pattern"] + assert parser.parse(sql) == get_ast(28) + + +# TODO: Query 29: Full Matching: UNION not supported by parser + + +def test_query_30(): + """Query 30: Over Partial Matching.""" + query = get_query(30) + sql = query["pattern"] + assert parser.parse(sql) == get_ast(30) + + +def test_query_31(): + """Query 31: Aggregation to Subquery.""" + query = get_query(31) + sql = query["pattern"] + # TODO: CASE not cleanly supported yet + #assert parser.parse(sql) == get_ast(31) + + +# TODO: Query 32: UNION not supported by parser + + +def test_query_33(): + """Query 33: Spreadsheet ID 3.""" + query = get_query(33) + sql = query["pattern"] + assert parser.parse(sql) == get_ast(33) + + +def test_query_34(): + """Query 34: Spreadsheet ID 7.""" + query = get_query(34) + sql = query["pattern"] + assert parser.parse(sql) == get_ast(34) + + +def test_query_35(): + """Query 35: Spreadsheet ID 9.""" + query = get_query(35) + sql = query["pattern"] + # TODO: DISTINCT not supported by parser yet + #assert parser.parse(sql) == get_ast(35) + + +def test_query_36(): + """Query 36: Spreadsheet ID 10.""" + query = get_query(36) + sql = query["pattern"] + assert parser.parse(sql) == get_ast(36) + + +def test_query_37(): + """Query 37: Spreadsheet ID 11.""" + query = get_query(37) + sql = query["pattern"] + assert parser.parse(sql) == get_ast(37) + + +def test_query_38(): + """Query 38: Spreadsheet ID 12.""" + query = get_query(38) + sql = query["pattern"] + assert parser.parse(sql) == get_ast(38) + + +def test_query_39(): + """Query 39: Spreadsheet ID 15.""" + query = get_query(39) + sql = query["pattern"] + assert parser.parse(sql) == get_ast(39) + + +def test_query_40(): + """Query 40.""" + query = get_query(40) + sql = query["pattern"] + # TODO: DISTINCT ON not supported by parser yet + #assert parser.parse(sql) == get_ast(40) + + +def test_query_41(): + """Query 41: Spreadsheet ID 20.""" + query = get_query(41) + sql = query["pattern"] + # TODO: NULL keyword and IS NULL not fully supported yet + #assert parser.parse(sql) == get_ast(41) + + +def test_query_42(): + """Query 42: PostgreSQL Test.""" + query = get_query(42) + sql = query["pattern"] + # TODO: INTERVAL, unary minus, keyword types not fully supported + #assert parser.parse(sql) == get_ast(42) + + +def test_query_43(): + """Query 43: MySQL Test.""" + query = get_query(43) + sql = query["pattern"] + # TODO: INTERVAL unit keyword not fully supported #assert parser.parse(sql) == get_ast(43) \ No newline at end of file From 839ecf4f5c0cd517170f7f917d898dd5de422f3e Mon Sep 17 00:00:00 2001 From: Qiushi Bai Date: Tue, 24 Feb 2026 20:08:28 -0800 Subject: [PATCH 2/2] Using the visualize_ast for all test cases in test_query_parser.py. --- tests/test_query_parser.py | 41 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/tests/test_query_parser.py b/tests/test_query_parser.py index a0ba83f..5bfadfc 100644 --- a/tests/test_query_parser.py +++ b/tests/test_query_parser.py @@ -24,6 +24,8 @@ def test_basic_parse(): LIMIT 10 OFFSET 5 """ + logger.info("\n" + visualize_ast(sql, get_ast(44))) + assert parser.parse(sql) == get_ast(44) @@ -34,6 +36,8 @@ def test_subquery_parse(): query = get_query(9) sql = query['pattern'] + logger.info("\n" + visualize_ast(sql, get_ast(9))) + assert parser.parse(sql) == get_ast(9) @@ -41,6 +45,7 @@ def test_query_1(): """Query 1: Remove Cast Date Match Twice.""" query = get_query(1) sql = query["pattern"] + logger.info("\n" + visualize_ast(sql, get_ast(1))) #assert parser.parse(sql) == get_ast(1) @@ -48,6 +53,7 @@ def test_query_2(): """Query 2: Remove Cast Date Match Once.""" query = get_query(2) sql = query["rewrite"] + logger.info("\n" + visualize_ast(sql, get_ast(2))) #assert parser.parse(sql) == get_ast(2) @@ -58,6 +64,7 @@ def test_query_4(): """Query 4.""" query = get_query(4) sql = query["rewrite"] + logger.info("\n" + visualize_ast(sql, get_ast(4))) #assert parser.parse(sql) == get_ast(4) @@ -68,6 +75,7 @@ def test_query_6(): """Query 6: Remove Self Join Match.""" query = get_query(6) sql = query["pattern"] + logger.info("\n" + visualize_ast(sql, get_ast(6))) assert parser.parse(sql) == get_ast(6) @@ -75,6 +83,7 @@ def test_query_7(): """Query 7: Remove Self Join No Match.""" query = get_query(7) sql = query["pattern"] + logger.info("\n" + visualize_ast(sql, get_ast(7))) assert parser.parse(sql) == get_ast(7) @@ -82,6 +91,7 @@ def test_query_8(): """Query 8.""" query = get_query(8) sql = query["pattern"] + logger.info("\n" + visualize_ast(sql, get_ast(8))) assert parser.parse(sql) == get_ast(8) @@ -92,6 +102,7 @@ def test_query_10(): """Query 10: Subquery to Join Match 2.""" query = get_query(10) sql = query["pattern"] + logger.info("\n" + visualize_ast(sql, get_ast(10))) assert parser.parse(sql) == get_ast(10) @@ -100,6 +111,7 @@ def test_query_11(): query = get_query(11) sql = query["rewrite"] # TODO: Rewrite has SELECT DISTINCT (not supported by parser yet) + logger.info("\n" + visualize_ast(sql, get_ast(11))) #assert parser.parse(sql) == get_ast(11) @@ -107,6 +119,7 @@ def test_query_12(): """Query 12: Join to Filter Match 1.""" query = get_query(12) sql = query["pattern"] + logger.info("\n" + visualize_ast(sql, get_ast(12))) assert parser.parse(sql) == get_ast(12) @@ -114,6 +127,7 @@ def test_query_13(): """Query 13: Join to Filter Match 2.""" query = get_query(13) sql = query["pattern"] + logger.info("\n" + visualize_ast(sql, get_ast(13))) assert parser.parse(sql) == get_ast(13) @@ -121,6 +135,7 @@ def test_query_14(): """Query 14: Test Rule Wetune 90 Match.""" query = get_query(14) sql = query["pattern"] + logger.info("\n" + visualize_ast(sql, get_ast(14))) assert parser.parse(sql) == get_ast(14) @@ -132,6 +147,7 @@ def test_query_16(): query = get_query(16) sql = query["pattern"] # TODO: DISTINCT is not supported by parser yet + logger.info("\n" + visualize_ast(sql, get_ast(16))) #assert parser.parse(sql) == get_ast(16) @@ -139,6 +155,7 @@ def test_query_17(): """Query 17.""" query = get_query(17) sql = query["pattern"] + logger.info("\n" + visualize_ast(sql, get_ast(17))) assert parser.parse(sql) == get_ast(17) @@ -147,6 +164,7 @@ def test_query_18(): query = get_query(18) sql = query["pattern"] # TODO: DISTINCT is not supported by parser yet + logger.info("\n" + visualize_ast(sql, get_ast(18))) #assert parser.parse(sql) == get_ast(18) @@ -154,6 +172,7 @@ def test_query_19(): """Query 19: Stackoverflow 2.""" query = get_query(19) sql = query["pattern"] + logger.info("\n" + visualize_ast(sql, get_ast(19))) assert parser.parse(sql) == get_ast(19) @@ -162,6 +181,7 @@ def test_query_20(): query = get_query(20) sql = query["pattern"] # TODO: IN with literal list not supported by parser yet + logger.info("\n" + visualize_ast(sql, get_ast(20))) #assert parser.parse(sql) == get_ast(20) @@ -169,6 +189,7 @@ def test_query_21(): """Query 21: Partial Matching 0.""" query = get_query(21) sql = query["pattern"] + logger.info("\n" + visualize_ast(sql, get_ast(21))) assert parser.parse(sql) == get_ast(21) @@ -176,6 +197,7 @@ def test_query_22(): """Query 22: Partial Matching 4.""" query = get_query(22) sql = query["pattern"] + logger.info("\n" + visualize_ast(sql, get_ast(22))) assert parser.parse(sql) == get_ast(22) @@ -183,6 +205,7 @@ def test_query_23(): """Query 23: Partial Keeps Remaining OR.""" query = get_query(23) sql = query["pattern"] + logger.info("\n" + visualize_ast(sql, get_ast(23))) assert parser.parse(sql) == get_ast(23) @@ -190,6 +213,7 @@ def test_query_24(): """Query 24: Partial Keeps Remaining AND.""" query = get_query(24) sql = query["pattern"] + logger.info("\n" + visualize_ast(sql, get_ast(24))) assert parser.parse(sql) == get_ast(24) @@ -197,6 +221,7 @@ def test_query_25(): """Query 25: And On True.""" query = get_query(25) sql = query["pattern"] + logger.info("\n" + visualize_ast(sql, get_ast(25))) assert parser.parse(sql) == get_ast(25) @@ -204,6 +229,7 @@ def test_query_26(): """Query 26: Multiple And On True.""" query = get_query(26) sql = query["pattern"] + logger.info("\n" + visualize_ast(sql, get_ast(26))) assert parser.parse(sql) == get_ast(26) @@ -212,6 +238,7 @@ def test_query_27(): query = get_query(27) sql = query["pattern"] # TODO: arithmetic expressions not supported by parser yet + logger.info("\n" + visualize_ast(sql, get_ast(27))) #assert parser.parse(sql) == get_ast(27) @@ -219,6 +246,7 @@ def test_query_28(): """Query 28: Rewrite Skips Failed Partial.""" query = get_query(28) sql = query["pattern"] + logger.info("\n" + visualize_ast(sql, get_ast(28))) assert parser.parse(sql) == get_ast(28) @@ -229,6 +257,7 @@ def test_query_30(): """Query 30: Over Partial Matching.""" query = get_query(30) sql = query["pattern"] + logger.info("\n" + visualize_ast(sql, get_ast(30))) assert parser.parse(sql) == get_ast(30) @@ -237,6 +266,7 @@ def test_query_31(): query = get_query(31) sql = query["pattern"] # TODO: CASE not cleanly supported yet + logger.info("\n" + visualize_ast(sql, get_ast(31))) #assert parser.parse(sql) == get_ast(31) @@ -247,6 +277,7 @@ def test_query_33(): """Query 33: Spreadsheet ID 3.""" query = get_query(33) sql = query["pattern"] + logger.info("\n" + visualize_ast(sql, get_ast(33))) assert parser.parse(sql) == get_ast(33) @@ -254,6 +285,7 @@ def test_query_34(): """Query 34: Spreadsheet ID 7.""" query = get_query(34) sql = query["pattern"] + logger.info("\n" + visualize_ast(sql, get_ast(34))) assert parser.parse(sql) == get_ast(34) @@ -262,6 +294,7 @@ def test_query_35(): query = get_query(35) sql = query["pattern"] # TODO: DISTINCT not supported by parser yet + logger.info("\n" + visualize_ast(sql, get_ast(35))) #assert parser.parse(sql) == get_ast(35) @@ -269,6 +302,7 @@ def test_query_36(): """Query 36: Spreadsheet ID 10.""" query = get_query(36) sql = query["pattern"] + logger.info("\n" + visualize_ast(sql, get_ast(36))) assert parser.parse(sql) == get_ast(36) @@ -276,6 +310,7 @@ def test_query_37(): """Query 37: Spreadsheet ID 11.""" query = get_query(37) sql = query["pattern"] + logger.info("\n" + visualize_ast(sql, get_ast(37))) assert parser.parse(sql) == get_ast(37) @@ -283,6 +318,7 @@ def test_query_38(): """Query 38: Spreadsheet ID 12.""" query = get_query(38) sql = query["pattern"] + logger.info("\n" + visualize_ast(sql, get_ast(38))) assert parser.parse(sql) == get_ast(38) @@ -290,6 +326,7 @@ def test_query_39(): """Query 39: Spreadsheet ID 15.""" query = get_query(39) sql = query["pattern"] + logger.info("\n" + visualize_ast(sql, get_ast(39))) assert parser.parse(sql) == get_ast(39) @@ -298,6 +335,7 @@ def test_query_40(): query = get_query(40) sql = query["pattern"] # TODO: DISTINCT ON not supported by parser yet + logger.info("\n" + visualize_ast(sql, get_ast(40))) #assert parser.parse(sql) == get_ast(40) @@ -306,6 +344,7 @@ def test_query_41(): query = get_query(41) sql = query["pattern"] # TODO: NULL keyword and IS NULL not fully supported yet + logger.info("\n" + visualize_ast(sql, get_ast(41))) #assert parser.parse(sql) == get_ast(41) @@ -314,6 +353,7 @@ def test_query_42(): query = get_query(42) sql = query["pattern"] # TODO: INTERVAL, unary minus, keyword types not fully supported + logger.info("\n" + visualize_ast(sql, get_ast(42))) #assert parser.parse(sql) == get_ast(42) @@ -322,4 +362,5 @@ def test_query_43(): query = get_query(43) sql = query["pattern"] # TODO: INTERVAL unit keyword not fully supported + logger.info("\n" + visualize_ast(sql, get_ast(43))) #assert parser.parse(sql) == get_ast(43) \ No newline at end of file