"""PageIndex client.

Indexes PDF/Markdown documents and answers questions about them with an
OpenAI Agents SDK agent that uses three retrieval tools
(get_document, get_document_structure, get_page_content).
"""
import os
import uuid
import json
import asyncio
import concurrent.futures
from pathlib import Path
from typing import Optional

from agents import Agent, Runner, function_tool
from agents.stream_events import RunItemStreamEvent

from .page_index import page_index
from .page_index_md import md_to_tree
from .retrieve import tool_get_document, tool_get_document_structure, tool_get_page_content

AGENT_SYSTEM_PROMPT = """
You are PageIndex, a document QA assistant.
TOOL USE:
- Call get_document() first to confirm status and page/line count.
- Call get_document_structure() to find relevant page ranges (use node summaries and start_index/end_index).
- Call get_page_content(pages="5-7") with tight ranges. Never fetch the whole doc.
- For Markdown, pages = line numbers from the structure (the line_num field). Use line_count from get_document() as the upper bound.
ANSWERING: Answer based only on tool output. Be concise.
"""


class PageIndexClient:
    """
    A client for the PageIndex API.
    Uses an OpenAI Agents SDK agent with 3 tools to answer document questions.

    Flow: index() -> query_agent() (tool-use loop) -> answer.
    """

    def __init__(self, api_key: Optional[str] = None, model: str = "gpt-4o-2024-11-20", workspace: Optional[str] = None):
        """
        Args:
            api_key: API key; falls back to the CHATGPT_API_KEY env var.
            model: Model name used for both indexing and the QA agent.
            workspace: Optional directory where indexed documents are persisted
                as JSON; if set, previously indexed documents are reloaded.
        """
        self.api_key = api_key or os.getenv("CHATGPT_API_KEY")
        if self.api_key:
            # Export under both names so pageindex and the Agents SDK see it.
            os.environ["CHATGPT_API_KEY"] = self.api_key
            os.environ["OPENAI_API_KEY"] = self.api_key
        self.model = model
        self.workspace = Path(workspace).expanduser() if workspace else None
        # doc_id -> record: {'id', 'path', 'type', 'structure', 'doc_name', 'doc_description'}
        self.documents = {}
        if self.workspace:
            self.workspace.mkdir(parents=True, exist_ok=True)
            self._load_workspace()

    def index(self, file_path: str, mode: str = "auto") -> str:
        """Index a document and return its document_id.

        Args:
            file_path: Path to the document on disk.
            mode: 'pdf', 'md', or 'auto' (detect from the file extension).

        Returns:
            A freshly generated UUID string identifying the document.

        Raises:
            FileNotFoundError: If file_path does not exist.
            ValueError: If the format/mode combination is unsupported.
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File not found: {file_path}")

        doc_id = str(uuid.uuid4())
        ext = os.path.splitext(file_path)[1].lower()
        is_pdf = ext == '.pdf'
        is_md = ext in ['.md', '.markdown']

        if mode == "pdf" or (mode == "auto" and is_pdf):
            print(f"Indexing PDF: {file_path}")
            result = page_index(
                doc=file_path,
                model=self.model,
                if_add_node_summary='yes',
                if_add_node_text='yes',
                if_add_node_id='yes',
                if_add_doc_description='yes'
            )
            doc_type = 'pdf'
        elif mode == "md" or (mode == "auto" and is_md):
            print(f"Indexing Markdown: {file_path}")
            # md_to_tree is async; run it to completion here.
            result = asyncio.run(md_to_tree(
                md_path=file_path,
                if_thinning=False,
                if_add_node_summary='yes',
                summary_token_threshold=200,
                model=self.model,
                if_add_doc_description='yes',
                if_add_node_text='yes',
                if_add_node_id='yes'
            ))
            doc_type = 'md'
        else:
            raise ValueError(f"Unsupported file format for: {file_path}")

        # Both branches produce the same record shape; build it once.
        self.documents[doc_id] = {
            'id': doc_id,
            'path': file_path,
            'type': doc_type,
            'structure': result['structure'],
            'doc_name': result.get('doc_name', ''),
            'doc_description': result.get('doc_description', '')
        }

        print(f"Indexing complete. Document ID: {doc_id}")
        if self.workspace:
            self._save_doc(doc_id)
        return doc_id

    def _save_doc(self, doc_id: str):
        """Persist one document record to <workspace>/<doc_id>.json."""
        path = self.workspace / f"{doc_id}.json"
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self.documents[doc_id], f, ensure_ascii=False, indent=2)

    def _load_workspace(self):
        """Reload all persisted document records from the workspace directory.

        Corrupt or unreadable files are skipped with a warning rather than
        aborting the whole load.
        """
        loaded = 0
        for path in self.workspace.glob("*.json"):
            try:
                with open(path, "r", encoding="utf-8") as f:
                    doc = json.load(f)
                # The filename stem is the doc_id (see _save_doc).
                self.documents[path.stem] = doc
                loaded += 1
            except (json.JSONDecodeError, OSError) as e:
                print(f"Warning: skipping corrupt workspace file {path.name}: {e}")
        if loaded:
            print(f"Loaded {loaded} document(s) from workspace.")

    # ── Public tool methods (thin wrappers) ───────────────────────────────────

    def get_document(self, doc_id: str) -> str:
        """Return document metadata JSON."""
        return tool_get_document(self.documents, doc_id)

    def get_document_structure(self, doc_id: str) -> str:
        """Return document tree structure JSON (without text fields)."""
        return tool_get_document_structure(self.documents, doc_id)

    def get_page_content(self, doc_id: str, pages: str) -> str:
        """Return page content JSON for the given pages string (e.g. '5-7', '3,8', '12')."""
        return tool_get_page_content(self.documents, doc_id, pages)

    # ── Agent core ────────────────────────────────────────────────────────────

    def query_agent(self, doc_id: str, prompt: str, verbose: bool = False) -> str:
        """
        Run the PageIndex agent for a document query.

        The agent automatically calls get_document, get_document_structure,
        and get_page_content tools as needed; the tool closures below capture
        doc_id so the model never has to pass it.

        Args:
            doc_id: ID returned by index().
            prompt: The user's question.
            verbose: If True, print each tool call and result as they happen.

        Returns:
            The agent's final answer text.
        """
        client_self = self

        @function_tool
        def get_document() -> str:
            """Get document metadata: status, page count, name, and description."""
            return client_self.get_document(doc_id)

        @function_tool
        def get_document_structure() -> str:
            """Get the document's full tree structure (without text) to find relevant sections."""
            return client_self.get_document_structure(doc_id)

        @function_tool
        def get_page_content(pages: str) -> str:
            """
            Get the text content of specific pages or line numbers.
            Use tight ranges: e.g. '5-7' for pages 5 to 7, '3,8' for pages 3 and 8, '12' for page 12.
            For Markdown documents, use line numbers from the structure's line_num field.
            """
            return client_self.get_page_content(doc_id, pages)

        agent = Agent(
            name="PageIndex",
            instructions=AGENT_SYSTEM_PROMPT,
            tools=[get_document, get_document_structure, get_page_content],
            model=self.model,
        )

        if not verbose:
            result = Runner.run_sync(agent, prompt)
            return result.final_output

        # Verbose mode: stream run events and print each tool call/result.
        async def _run_verbose():
            turn = 0
            stream = Runner.run_streamed(agent, prompt)
            async for event in stream.stream_events():
                if not isinstance(event, RunItemStreamEvent):
                    continue
                if event.name == "tool_called":
                    turn += 1
                    raw = event.item.raw_item
                    args = getattr(raw, "arguments", "{}")
                    print(f"\n[Turn {turn}] → {raw.name}({args})")
                elif event.name == "tool_output":
                    output = str(event.item.output)
                    preview = output[:200] + "..." if len(output) > 200 else output
                    print(f"  ← {preview}")
            return stream.final_output

        # Bug fix: keep the try body minimal — only the loop probe may raise
        # the RuntimeError we want to catch. Previously the whole thread-pool
        # run sat inside the try, so a RuntimeError raised by the agent run
        # itself was swallowed and the query silently re-executed.
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No running loop: plain synchronous entry point.
            return asyncio.run(_run_verbose())
        # Inside a running event loop (e.g. Jupyter) — run in a thread.
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(asyncio.run, _run_verbose()).result()

    # ── Public query API ──────────────────────────────────────────────────────

    def query(self, doc_id: str, prompt: str) -> str:
        """Ask a question about an indexed document. Returns the agent's answer."""
        return self.query_agent(doc_id, prompt)

    def query_stream(self, doc_id: str, prompt: str):
        """
        Ask a question about an indexed document with streaming output.
        MVP: yields the full answer as a single chunk.
        """
        yield self.query_agent(doc_id, prompt)
def _get_pdf_page_content(doc_info: dict, page_nums: list[int]) -> list[dict]:
    """Extract text for specific PDF pages (1-indexed), opening the PDF once.

    Page numbers outside [1, total] are silently dropped.
    """
    with open(doc_info['path'], 'rb') as fh:
        reader = PyPDF2.PdfReader(fh)
        last = len(reader.pages)
        extracted = []
        for num in page_nums:
            if 1 <= num <= last:
                # extract_text() may return None for image-only pages.
                text = reader.pages[num - 1].extract_text() or ''
                extracted.append({'page': num, 'content': text})
        return extracted


def _get_md_page_content(doc_info: dict, page_nums: list[int]) -> list[dict]:
    """
    For Markdown documents, 'pages' are line numbers.
    Find nodes whose line_num falls within the requested set and return their
    text. When several nodes share a line_num, the first one in pre-order wins.
    """
    def _walk(nodes):
        # Pre-order generator over the tree (parent before children).
        for node in nodes:
            yield node
            yield from _walk(node.get('nodes') or [])

    wanted = set(page_nums)
    found = {}
    for node in _walk(doc_info.get('structure', [])):
        line = node.get('line_num')
        if line and line in wanted and line not in found:
            found[line] = node.get('text', '')
    return [{'page': line, 'content': found[line]} for line in sorted(found)]


# ── Tool functions ────────────────────────────────────────────────────────────

def tool_get_document(documents: dict, doc_id: str) -> str:
    """Return JSON with document metadata: doc_id, doc_name, doc_description, type, status, page_count (PDF) or line_count (Markdown)."""
    doc_info = documents.get(doc_id)
    if not doc_info:
        return json.dumps({'error': f'Document {doc_id} not found'})
    meta = {
        'doc_id': doc_id,
        'doc_name': doc_info.get('doc_name', ''),
        'doc_description': doc_info.get('doc_description', ''),
        'type': doc_info.get('type', ''),
        'status': 'completed',
    }
    # PDFs report a physical page count; Markdown reports a line count.
    count_key = 'page_count' if doc_info.get('type') == 'pdf' else 'line_count'
    meta[count_key] = _count_pages(doc_info)
    return json.dumps(meta)


def tool_get_document_structure(documents: dict, doc_id: str) -> str:
    """Return tree structure JSON with text fields removed (saves tokens)."""
    doc_info = documents.get(doc_id)
    if not doc_info:
        return json.dumps({'error': f'Document {doc_id} not found'})
    stripped = remove_fields(doc_info.get('structure', []), fields=['text'])
    return json.dumps(stripped, ensure_ascii=False)


def tool_get_page_content(documents: dict, doc_id: str, pages: str) -> str:
    """
    Retrieve page content for a document.

    pages format: '5-7', '3,8', or '12'
    For PDF: pages are physical page numbers (1-indexed).
    For Markdown: pages are line numbers corresponding to node headers.

    Returns JSON list of {'page': int, 'content': str}, or {'error': ...}.
    """
    doc_info = documents.get(doc_id)
    if not doc_info:
        return json.dumps({'error': f'Document {doc_id} not found'})

    try:
        page_nums = _parse_pages(pages)
    except (ValueError, AttributeError) as e:
        return json.dumps({'error': f'Invalid pages format: {pages!r}. Use "5-7", "3,8", or "12". Error: {e}'})

    fetch = _get_pdf_page_content if doc_info.get('type') == 'pdf' else _get_md_page_content
    try:
        content = fetch(doc_info, page_nums)
    except Exception as e:  # tool boundary: report failures to the agent as JSON
        return json.dumps({'error': f'Failed to read page content: {e}'})

    return json.dumps(content, ensure_ascii=False)
def ChatGPT_API_stream(model, prompt, api_key=None):
    """Stream a chat completion, yielding content chunks (str) one at a time.

    Args:
        model: Model name.
        prompt: Single user message.
        api_key: API key; when None, falls back to the module-level
            CHATGPT_API_KEY. Resolved at call time rather than being baked in
            as an import-time default.

    Yields:
        Non-empty content deltas, in arrival order.
    """
    if api_key is None:
        api_key = CHATGPT_API_KEY
    client = openai.OpenAI(api_key=api_key)
    with client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        temperature=0,
        stream=True,
    ) as stream:
        for chunk in stream:
            # Bug fix: some stream chunks (e.g. usage-only or filtered chunks)
            # arrive with an empty `choices` list; indexing [0] would raise.
            if not chunk.choices:
                continue
            delta = chunk.choices[0].delta.content
            if delta:
                yield delta


def create_node_mapping(tree):
    """Create a flat dict mapping node_id to node for quick lookup.

    Nodes without a 'node_id' are skipped; nested 'nodes' lists are traversed
    recursively. Values are the original node dicts (not copies), so mutations
    through the mapping are visible in the tree.
    """
    mapping = {}
    def _traverse(nodes):
        for node in nodes:
            if node.get('node_id'):
                mapping[node['node_id']] = node
            if node.get('nodes'):
                _traverse(node['nodes'])
    _traverse(tree)
    return mapping


def print_tree(tree, indent=0):
    """Pretty-print a node tree, one line per node, children indented.

    Shows up to 60 characters of each node's summary (falling back to
    prefix_summary), appending an ellipsis only when text was actually cut.
    """
    for node in tree:
        summary = node.get('summary') or node.get('prefix_summary', '')
        if summary:
            # Bug fix: previously '...' was appended even when the summary was
            # already 60 characters or fewer.
            summary_str = f" — {summary[:60]}..." if len(summary) > 60 else f" — {summary}"
        else:
            summary_str = ""
        print(' ' * indent + f"[{node.get('node_id', '?')}] {node.get('title', '')}{summary_str}")
        if node.get('nodes'):
            print_tree(node['nodes'], indent + 1)


def print_wrapped(text, width=100):
    """Print text wrapped to `width` columns, preserving existing line breaks
    (blank input lines print as blank lines)."""
    for line in text.splitlines():
        print(textwrap.fill(line, width=width))
"""
PageIndex Agent SDK Demo
4-step demo using the 3-tool agent:
  Step 1 — Download, index PDF, and inspect tree structure
  Step 2 — Inspect document metadata via get_document()
  Step 3 — Ask a question (agent auto-calls tools)
  Step 4 — Reload from workspace and verify persistence
"""
import os
import requests
from pageindex import PageIndexClient
import pageindex.utils as utils

PDF_URL = "https://arxiv.org/pdf/2501.12948.pdf"
PDF_PATH = "tests/pdfs/deepseek-r1.pdf"
WORKSPACE = "~/.pageindex"


def download_pdf(url: str, path: str) -> None:
    """Download url to path (streaming) unless the file already exists."""
    if os.path.exists(path):
        return
    print(f"Downloading {url} ...")
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with requests.get(url, stream=True, timeout=30) as r:
        r.raise_for_status()
        with open(path, "wb") as f:
            for chunk in r.iter_content(chunk_size=8192):
                if chunk:  # skip keep-alive chunks
                    f.write(chunk)
    print("Download complete.\n")


# ── Download PDF if needed ────────────────────────────────────────────────────
download_pdf(PDF_URL, PDF_PATH)

# ── Setup ─────────────────────────────────────────────────────────────────────
client = PageIndexClient(workspace=WORKSPACE)

# ── Step 1: Index + Tree ──────────────────────────────────────────────────────
print("=" * 60)
print("Step 1: Indexing PDF and inspecting tree structure")
print("=" * 60)
doc_id = client.index(PDF_PATH)
print(f"\nDocument ID: {doc_id}")
print("\nTree Structure:")
utils.print_tree(client.documents[doc_id]["structure"])

# ── Step 2: Document Metadata ─────────────────────────────────────────────────
print("\n" + "=" * 60)
print("Step 2: Document Metadata (get_document)")
print("=" * 60)
metadata = client.get_document(doc_id)
print(metadata)

# ── Step 3: Agent Query ───────────────────────────────────────────────────────
print("\n" + "=" * 60)
print("Step 3: Agent Query (auto tool-use)")
print("=" * 60)
question = "What are the main conclusions of this paper?"
print(f"\nQuestion: '{question}'\n")
answer = client.query_agent(doc_id, question)
print("Answer:")
print(answer)

# ── Step 4: Persistence Verification ─────────────────────────────────────────
print("\n" + "=" * 60)
print("Step 4: Persistence — reload without re-indexing")
print("=" * 60)
# A fresh client pointed at the same workspace must answer without re-indexing.
client2 = PageIndexClient(workspace=WORKSPACE)
# Reuse `question` instead of retyping the string (keeps the two runs in sync).
answer2 = client2.query_agent(doc_id, question, verbose=True)
print("Answer from reloaded client:")
print(answer2)
print("\nPersistence verified. ✓")