Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
3 changes: 3 additions & 0 deletions src/lang2sql/components/retrieval/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .keyword import KeywordRetriever

__all__ = ["KeywordRetriever"]
130 changes: 130 additions & 0 deletions src/lang2sql/components/retrieval/_bm25.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
"""
Internal BM25 index — stdlib only (math, collections).

BM25 parameters:
k1 = 1.5 (term frequency saturation)
b = 0.75 (document length normalization)

Tokenization: text.lower().split() (whitespace, no external deps)
"""

from __future__ import annotations

import math
from collections import Counter
from typing import Any

_K1 = 1.5
_B = 0.75


def _tokenize(text: str) -> list[str]:
return text.lower().split()


def _extract_text(value: Any) -> list[str]:
"""Recursively extract text tokens from any value (str, list, dict, other)."""
if isinstance(value, str):
return [value]
if isinstance(value, list):
result: list[str] = []
for item in value:
result.extend(_extract_text(item))
return result
if isinstance(value, dict):
result = []
for k, v in value.items():
result.append(str(k))
result.extend(_extract_text(v))
return result
return [str(value)]


def _entry_to_text(entry: dict[str, Any], index_fields: list[str]) -> str:
"""
Convert a catalog dict entry into a single text string for indexing.

Handles:
- str fields → joined as-is
- dict fields → "key value key value ..." (for columns: {col_name: col_desc})
- list fields → each element extracted recursively
- other types → str(value)
"""
parts: list[str] = []
for field in index_fields:
value = entry.get(field)
if value is None:
continue
parts.extend(_extract_text(value))
return " ".join(parts)


class _BM25Index:
"""
In-memory BM25 index over a list[dict] catalog.

Usage:
index = _BM25Index(catalog, index_fields=["name", "description", "columns"])
scores = index.score("주문 테이블") # list[float], one per catalog entry
"""

def __init__(
self,
catalog: list[dict[str, Any]],
index_fields: list[str],
) -> None:
self._catalog = catalog
self._n = len(catalog)

# Tokenize each document
self._docs: list[list[str]] = [
_tokenize(_entry_to_text(entry, index_fields)) for entry in catalog
]

# Term frequencies per document
self._tfs: list[Counter[str]] = [Counter(doc) for doc in self._docs]

# Document lengths
doc_lengths = [len(doc) for doc in self._docs]
self._avgdl: float = sum(doc_lengths) / self._n if self._n > 0 else 0.0

# Inverted index: term → set of doc indices that contain it
self._df: Counter[str] = Counter()
for tf in self._tfs:
for term in tf:
self._df[term] += 1

def score(self, query: str) -> list[float]:
"""
Return a BM25 score for each catalog entry.

Args:
query: Natural language query string.

Returns:
List of float scores, one per catalog entry, in original order.
"""
if self._n == 0:
return []

query_terms = _tokenize(query)
scores = [0.0] * self._n

for term in query_terms:
df_t = self._df.get(term, 0)
if df_t == 0:
continue

# IDF — smoothed to avoid log(0)
idf = math.log((self._n - df_t + 0.5) / (df_t + 0.5) + 1)

for i, tf in enumerate(self._tfs):
tf_t = tf.get(term, 0)
if tf_t == 0:
continue

dl = len(self._docs[i])
denom = tf_t + _K1 * (1 - _B + _B * dl / self._avgdl)
scores[i] += idf * (tf_t * (_K1 + 1)) / denom

return scores
85 changes: 85 additions & 0 deletions src/lang2sql/components/retrieval/keyword.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from __future__ import annotations

from typing import Any, Optional

from ...core.base import BaseComponent
from ...core.context import RunContext
from ...core.hooks import TraceHook
from ._bm25 import _BM25Index

_DEFAULT_INDEX_FIELDS = ["name", "description", "columns"]


class KeywordRetriever(BaseComponent):
    """
    BM25-based keyword retriever over a table catalog.

    The catalog is indexed once, in memory, at construction time. Each
    invocation reads ``run.query`` and writes the top-N ranked matches
    into ``run.schema_selected``.

    Args:
        catalog: List of table dicts. Each dict should have at minimum
            ``name`` (str) and ``description`` (str).
            Optional keys: ``columns`` (dict[str, str]), ``meta`` (dict).
        top_n: Maximum number of results to return. Defaults to 5.
        index_fields: Fields to index. Defaults to ["name", "description", "columns"].
            Pass a custom list to replace the default (complete override).
        name: Component name for tracing. Defaults to "KeywordRetriever".
        hook: Optional TraceHook for observability.

    Example::

        retriever = KeywordRetriever(catalog=[
            {"name": "orders", "description": "주문 정보 테이블"},
        ])
        run = retriever(RunContext(query="주문 조회"))
        print(run.schema_selected)  # [{"name": "orders", ...}]
    """

    def __init__(
        self,
        *,
        catalog: list[dict[str, Any]],
        top_n: int = 5,
        index_fields: Optional[list[str]] = None,
        name: Optional[str] = None,
        hook: Optional[TraceHook] = None,
    ) -> None:
        super().__init__(name=name or "KeywordRetriever", hook=hook)
        self._catalog = catalog
        self._top_n = top_n
        if index_fields is None:
            self._index_fields = _DEFAULT_INDEX_FIELDS
        else:
            self._index_fields = index_fields
        # Build the BM25 index eagerly so every call only scores.
        self._index = _BM25Index(self._catalog, self._index_fields)

    def run(self, run: RunContext) -> RunContext:
        """
        Search the catalog with BM25 and store results in ``run.schema_selected``.

        Args:
            run: Current RunContext. Reads ``run.query``.

        Returns:
            The same RunContext with ``run.schema_selected`` set to a
            ranked list[dict] (BM25 score descending). Empty list if no match.
        """
        if not self._catalog:
            run.schema_selected = []
            return run

        scores = self._index.score(run.query)

        # Rank catalog indices by score, highest first; the stable sort keeps
        # the original catalog order among tied scores.
        order = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)

        # Keep at most top_n hits, dropping anything with a zero score.
        run.schema_selected = [
            self._catalog[i] for i in order[: self._top_n] if scores[i] > 0.0
        ]
        return run
16 changes: 16 additions & 0 deletions src/lang2sql/core/ports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from __future__ import annotations

from typing import Protocol


class EmbeddingPort(Protocol):
    """
    Placeholder — will be implemented in OQ-2 (VectorRetriever).

    Abstracts embedding backends (OpenAI, Azure, Bedrock, etc.)
    so VectorRetriever is not coupled to any specific provider.
    """

    # Embed a single query string into one float vector.
    def embed_query(self, text: str) -> list[float]: ...

    # Embed a batch of texts; presumably one vector per input text,
    # in input order — implementations should confirm this contract.
    def embed_texts(self, texts: list[str]) -> list[list[float]]: ...
Loading