diff --git a/src/lang2sql/components/__init__.py b/src/lang2sql/components/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/lang2sql/components/retrieval/__init__.py b/src/lang2sql/components/retrieval/__init__.py new file mode 100644 index 0000000..3d9f791 --- /dev/null +++ b/src/lang2sql/components/retrieval/__init__.py @@ -0,0 +1,3 @@ +from .keyword import KeywordRetriever + +__all__ = ["KeywordRetriever"] diff --git a/src/lang2sql/components/retrieval/_bm25.py b/src/lang2sql/components/retrieval/_bm25.py new file mode 100644 index 0000000..d3e87c0 --- /dev/null +++ b/src/lang2sql/components/retrieval/_bm25.py @@ -0,0 +1,130 @@ +""" +Internal BM25 index — stdlib only (math, collections). + +BM25 parameters: + k1 = 1.5 (term frequency saturation) + b = 0.75 (document length normalization) + +Tokenization: text.lower().split() (whitespace, no external deps) +""" + +from __future__ import annotations + +import math +from collections import Counter +from typing import Any + +_K1 = 1.5 +_B = 0.75 + + +def _tokenize(text: str) -> list[str]: + return text.lower().split() + + +def _extract_text(value: Any) -> list[str]: + """Recursively extract text tokens from any value (str, list, dict, other).""" + if isinstance(value, str): + return [value] + if isinstance(value, list): + result: list[str] = [] + for item in value: + result.extend(_extract_text(item)) + return result + if isinstance(value, dict): + result = [] + for k, v in value.items(): + result.append(str(k)) + result.extend(_extract_text(v)) + return result + return [str(value)] + + +def _entry_to_text(entry: dict[str, Any], index_fields: list[str]) -> str: + """ + Convert a catalog dict entry into a single text string for indexing. + + Handles: + - str fields → joined as-is + - dict fields → "key value key value ..." (for columns: {col_name: col_desc}) + - list fields → each element extracted recursively + - other types → str(value) + """ + parts: list[str] = [] + for field in index_fields: + value = entry.get(field) + if value is None: + continue + parts.extend(_extract_text(value)) + return " ".join(parts) + + +class _BM25Index: + """ + In-memory BM25 index over a list[dict] catalog. + + Usage: + index = _BM25Index(catalog, index_fields=["name", "description", "columns"]) + scores = index.score("주문 테이블") # list[float], one per catalog entry + """ + + def __init__( + self, + catalog: list[dict[str, Any]], + index_fields: list[str], + ) -> None: + self._catalog = catalog + self._n = len(catalog) + + # Tokenize each document + self._docs: list[list[str]] = [ + _tokenize(_entry_to_text(entry, index_fields)) for entry in catalog + ] + + # Term frequencies per document + self._tfs: list[Counter[str]] = [Counter(doc) for doc in self._docs] + + # Document lengths + doc_lengths = [len(doc) for doc in self._docs] + self._avgdl: float = sum(doc_lengths) / self._n if self._n > 0 else 0.0 + + # Inverted index: term → set of doc indices that contain it + self._df: Counter[str] = Counter() + for tf in self._tfs: + for term in tf: + self._df[term] += 1 + + def score(self, query: str) -> list[float]: + """ + Return a BM25 score for each catalog entry. + + Args: + query: Natural language query string. + + Returns: + List of float scores, one per catalog entry, in original order. + """ + if self._n == 0: + return [] + + query_terms = _tokenize(query) + scores = [0.0] * self._n + + for term in query_terms: + df_t = self._df.get(term, 0) + if df_t == 0: + continue + + # IDF — smoothed to avoid log(0) + idf = math.log((self._n - df_t + 0.5) / (df_t + 0.5) + 1) + + for i, tf in enumerate(self._tfs): + tf_t = tf.get(term, 0) + if tf_t == 0: + continue + + dl = len(self._docs[i]) + denom = tf_t + _K1 * (1 - _B + _B * dl / self._avgdl) + scores[i] += idf * (tf_t * (_K1 + 1)) / denom + + return scores diff --git a/src/lang2sql/components/retrieval/keyword.py b/src/lang2sql/components/retrieval/keyword.py new file mode 100644 index 0000000..1444086 --- /dev/null +++ b/src/lang2sql/components/retrieval/keyword.py @@ -0,0 +1,85 @@ +from __future__ import annotations + +from typing import Any, Optional + +from ...core.base import BaseComponent +from ...core.context import RunContext +from ...core.hooks import TraceHook +from ._bm25 import _BM25Index + +_DEFAULT_INDEX_FIELDS = ["name", "description", "columns"] + + +class KeywordRetriever(BaseComponent): + """ + BM25-based keyword retriever over a table catalog. + + Indexes catalog entries at init time (in-memory). + On each call, reads ``run.query`` and writes top-N matches + into ``run.schema_selected``. + + Args: + catalog: List of table dicts. Each dict should have at minimum + ``name`` (str) and ``description`` (str). + Optional keys: ``columns`` (dict[str, str]), ``meta`` (dict). + top_n: Maximum number of results to return. Defaults to 5. + index_fields: Fields to index. Defaults to ["name", "description", "columns"]. + Pass a custom list to replace the default (complete override). + name: Component name for tracing. Defaults to "KeywordRetriever". + hook: Optional TraceHook for observability. + + Example:: + + retriever = KeywordRetriever(catalog=[ + {"name": "orders", "description": "주문 정보 테이블"}, + ]) + run = retriever(RunContext(query="주문 조회")) + print(run.schema_selected) # [{"name": "orders", ...}] + """ + + def __init__( + self, + *, + catalog: list[dict[str, Any]], + top_n: int = 5, + index_fields: Optional[list[str]] = None, + name: Optional[str] = None, + hook: Optional[TraceHook] = None, + ) -> None: + super().__init__(name=name or "KeywordRetriever", hook=hook) + self._catalog = catalog + self._top_n = top_n + self._index_fields = ( + index_fields if index_fields is not None else _DEFAULT_INDEX_FIELDS + ) + self._index = _BM25Index(catalog, self._index_fields) + + def run(self, run: RunContext) -> RunContext: + """ + Search the catalog with BM25 and store results in ``run.schema_selected``. + + Args: + run: Current RunContext. Reads ``run.query``. + + Returns: + The same RunContext with ``run.schema_selected`` set to a + ranked list[dict] (BM25 score descending). Empty list if no match. + """ + if not self._catalog: + run.schema_selected = [] + return run + + scores = self._index.score(run.query) + + # Pair each catalog entry with its score, sort descending + ranked = sorted( + zip(scores, self._catalog), + key=lambda x: x[0], + reverse=True, + ) + + # Return up to top_n entries that have a positive score + results = [entry for score, entry in ranked[: self._top_n] if score > 0.0] + + run.schema_selected = results + return run diff --git a/src/lang2sql/core/ports.py b/src/lang2sql/core/ports.py new file mode 100644 index 0000000..7452d4c --- /dev/null +++ b/src/lang2sql/core/ports.py @@ -0,0 +1,16 @@ +from __future__ import annotations + +from typing import Protocol + + +class EmbeddingPort(Protocol): + """ + Placeholder — will be implemented in OQ-2 (VectorRetriever). + + Abstracts embedding backends (OpenAI, Azure, Bedrock, etc.) + so VectorRetriever is not coupled to any specific provider. + """ + + def embed_query(self, text: str) -> list[float]: ... + + def embed_texts(self, texts: list[str]) -> list[list[float]]: ... diff --git a/tests/test_components_keyword_retriever.py b/tests/test_components_keyword_retriever.py new file mode 100644 index 0000000..27b9d5c --- /dev/null +++ b/tests/test_components_keyword_retriever.py @@ -0,0 +1,227 @@ +""" +Tests for KeywordRetriever — 14 cases. + +Pattern follows test_core_base.py: +- pytest, inline fixtures, MemoryHook +""" + +import pytest + +from lang2sql.components.retrieval import KeywordRetriever +from lang2sql.core.context import RunContext +from lang2sql.core.hooks import MemoryHook +from lang2sql.flows.baseline import SequentialFlow + +# ------------------------- +# Shared test catalog +# ------------------------- + +ORDER_TABLE = { + "name": "order_table", + "description": "고객 주문 정보를 저장하는 테이블", + "columns": {"order_id": "주문 고유 ID", "amount": "주문 금액"}, + "meta": {"primary_key": "order_id", "tags": ["finance", "core"]}, +} + +USER_TABLE = { + "name": "user_table", + "description": "사용자 계정 정보 테이블", + "columns": {"user_id": "사용자 고유 ID", "email": "이메일"}, + "meta": {"primary_key": "user_id"}, +} + +PRODUCT_TABLE = { + "name": "product_table", + "description": "상품 목록 및 재고 테이블", + "columns": {"product_id": "상품 ID", "stock": "재고 수량"}, +} + +CATALOG = [ORDER_TABLE, USER_TABLE, PRODUCT_TABLE] + + +# ------------------------- +# Tests +# ------------------------- + + +def test_basic_search_returns_relevant_table(): + """'주문' 질문 → order_table이 top 위치.""" + retriever = KeywordRetriever(catalog=CATALOG) + run = retriever(RunContext(query="주문 정보 조회")) + + assert run.schema_selected + assert run.schema_selected[0]["name"] == "order_table" + + +def test_top_n_limits_results(): + """top_n=2 → 최대 2개 반환.""" + retriever = KeywordRetriever(catalog=CATALOG, top_n=2) + run = retriever(RunContext(query="테이블")) + + assert len(run.schema_selected) <= 2 + + +def test_top_n_larger_than_catalog(): + """top_n=10, catalog 3개 → 최대 3개 반환.""" + retriever = KeywordRetriever(catalog=CATALOG, top_n=10) + run = retriever(RunContext(query="테이블")) + + assert len(run.schema_selected) <= len(CATALOG) + + +def test_zero_results_returns_empty_list(): + """완전히 무관한 query → schema_selected == [].""" + retriever = KeywordRetriever(catalog=CATALOG) + run = retriever(RunContext(query="xyzzy_no_match_token_12345")) + + assert run.schema_selected == [] + + +def test_schema_selected_is_list_of_dict(): + """결과가 list[dict]인지 확인.""" + retriever = KeywordRetriever(catalog=CATALOG) + run = retriever(RunContext(query="주문")) + + assert isinstance(run.schema_selected, list) + assert len(run.schema_selected) > 0 + assert isinstance(run.schema_selected[0], dict) + + +def test_returns_runcontext(): + """run 메서드가 RunContext를 반환하는지 확인.""" + retriever = KeywordRetriever(catalog=CATALOG) + result = retriever(RunContext(query="주문")) + + assert isinstance(result, RunContext) + + +def test_hook_start_end_events(): + """MemoryHook으로 start/end 이벤트 확인.""" + hook = MemoryHook() + retriever = KeywordRetriever(catalog=CATALOG, hook=hook) + retriever(RunContext(query="주문")) + + assert len(hook.events) == 2 + assert hook.events[0].name == "component.run" + assert hook.events[0].phase == "start" + assert hook.events[1].name == "component.run" + assert hook.events[1].phase == "end" + assert hook.events[1].duration_ms is not None + assert hook.events[1].duration_ms >= 0.0 + + +def test_empty_catalog(): + """catalog=[] → schema_selected == [].""" + retriever = KeywordRetriever(catalog=[]) + run = retriever(RunContext(query="주문")) + + assert run.schema_selected == [] + + +def test_meta_preserved_in_results(): + """meta 필드가 결과 dict에 그대로 포함되는지 확인.""" + retriever = KeywordRetriever(catalog=CATALOG) + run = retriever(RunContext(query="주문")) + + result = run.schema_selected[0] + assert "meta" in result + assert result["meta"]["primary_key"] == "order_id" + + +def test_index_fields_meta(): + """index_fields=["description","meta"] → meta 텍스트도 검색에 반영.""" + # finance라는 단어는 meta.tags에만 존재 (name/description/columns에는 없음) + catalog = [ + { + "name": "alpha", + "description": "일반 데이터 저장소", + "meta": {"tags": ["finance", "core"]}, + }, + { + "name": "beta", + "description": "기타 로그 테이블", + "meta": {"tags": ["logging"]}, + }, + ] + + retriever = KeywordRetriever( + catalog=catalog, + index_fields=["description", "meta"], + ) + run = retriever(RunContext(query="finance")) + + assert len(run.schema_selected) > 0 + assert run.schema_selected[0]["name"] == "alpha" + + +def test_result_order_by_relevance(): + """관련도 높은 테이블이 앞에 위치하는지 확인.""" + catalog = [ + { + "name": "order_summary", + "description": "주문 요약 주문 집계 주문 통계", # '주문' 3회 + }, + { + "name": "user_table", + "description": "사용자 주문 기록", # '주문' 1회 + }, + ] + + retriever = KeywordRetriever(catalog=catalog) + run = retriever(RunContext(query="주문")) + + assert len(run.schema_selected) >= 2 + assert run.schema_selected[0]["name"] == "order_summary" + + +def test_columns_text_indexed(): + """컬럼명/컬럼설명으로 검색 가능한지 확인.""" + catalog = [ + { + "name": "sales", + "description": "판매 데이터", + "columns": {"revenue": "매출액", "region": "지역"}, + }, + { + "name": "logs", + "description": "시스템 로그", + "columns": {"event_type": "이벤트 유형"}, + }, + ] + + retriever = KeywordRetriever(catalog=catalog) + run = retriever(RunContext(query="매출액")) + + assert len(run.schema_selected) > 0 + assert run.schema_selected[0]["name"] == "sales" + + +def test_missing_optional_fields_no_error(): + """columns/meta 없는 entry가 있어도 crash 없음.""" + catalog = [ + {"name": "minimal", "description": "최소 필드만 있는 테이블"}, + { + "name": "full", + "description": "전체 필드", + "columns": {"id": "ID"}, + "meta": {}, + }, + ] + + retriever = KeywordRetriever(catalog=catalog) + # 예외가 발생하지 않으면 테스트 통과 + run = retriever(RunContext(query="테이블")) + assert isinstance(run.schema_selected, list) + + +def test_end_to_end_in_sequential_flow(): + """SequentialFlow(steps=[retriever]).run_query('...') 가 동작하는지 확인.""" + retriever = KeywordRetriever(catalog=CATALOG) + flow = SequentialFlow(steps=[retriever]) + + run = flow.run_query("주문 내역 확인") + + assert isinstance(run, RunContext) + assert isinstance(run.schema_selected, list) + assert len(run.schema_selected) > 0 + assert run.schema_selected[0]["name"] == "order_table"