diff --git a/.env.example b/.env.example new file mode 100644 index 000000000..364b75835 --- /dev/null +++ b/.env.example @@ -0,0 +1,15 @@ +# PageIndex - Environment Variables +# Copy this file to .env and fill in your API keys. +# Only one provider key is required depending on which model you use. + +# OpenAI (default) — required for gpt-* models +# Get your key at: https://platform.openai.com/api-keys +CHATGPT_API_KEY=your_openai_api_key_here + +# Anthropic — required for claude-* models +# Get your key at: https://console.anthropic.com/settings/keys +ANTHROPIC_API_KEY=your_anthropic_api_key_here + +# Google — required for gemini-* models +# Get your key at: https://aistudio.google.com/app/apikey +GOOGLE_API_KEY=your_google_api_key_here diff --git a/.gitignore b/.gitignore index 47d38baef..87ac6c099 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ chroma-collections.parquet chroma-embeddings.parquet .DS_Store .env* +!.env.example notebook SDK/* log/* diff --git a/pageindex/utils.py b/pageindex/utils.py index dc7acd888..313521a1e 100644 --- a/pageindex/utils.py +++ b/pageindex/utils.py @@ -1,5 +1,6 @@ import tiktoken import openai +import anthropic import logging import os from datetime import datetime @@ -18,94 +19,178 @@ from types import SimpleNamespace as config CHATGPT_API_KEY = os.getenv("CHATGPT_API_KEY") +ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY") +GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY") + +# Gemini OpenAI-compatible base URL +_GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/v1beta/openai/" + + +def get_provider(model): + """Detect LLM provider from model name.""" + if model and model.startswith("claude"): + return "anthropic" + elif model and model.startswith("gemini"): + return "google" + return "openai" + + +def _get_default_api_key(model): + provider = get_provider(model) + if provider == "anthropic": + return ANTHROPIC_API_KEY + elif provider == "google": + return GOOGLE_API_KEY + return CHATGPT_API_KEY + + +def _make_openai_client(model, api_key): + """Return a synchronous OpenAI-compatible client for OpenAI or Gemini.""" + if get_provider(model) == "google": + return openai.OpenAI(api_key=api_key, base_url=_GEMINI_BASE_URL) + return openai.OpenAI(api_key=api_key) + + +def _make_async_openai_client(model, api_key): + """Return an async OpenAI-compatible client for OpenAI or Gemini.""" + if get_provider(model) == "google": + return openai.AsyncOpenAI(api_key=api_key, base_url=_GEMINI_BASE_URL) + return openai.AsyncOpenAI(api_key=api_key) + def count_tokens(text, model=None): if not text: return 0 - enc = tiktoken.encoding_for_model(model) + try: + enc = tiktoken.encoding_for_model(model) + except KeyError: + # Fallback for non-OpenAI models (Claude, Gemini, etc.) + enc = tiktoken.get_encoding("cl100k_base") tokens = enc.encode(text) return len(tokens) -def ChatGPT_API_with_finish_reason(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None): +def ChatGPT_API_with_finish_reason(model, prompt, api_key=None, chat_history=None): + if api_key is None: + api_key = _get_default_api_key(model) max_retries = 10 - client = openai.OpenAI(api_key=api_key) + provider = get_provider(model) + + if chat_history: + messages = chat_history + messages.append({"role": "user", "content": prompt}) + else: + messages = [{"role": "user", "content": prompt}] + for i in range(max_retries): try: - if chat_history: - messages = chat_history - messages.append({"role": "user", "content": prompt}) - else: - messages = [{"role": "user", "content": prompt}] - - response = client.chat.completions.create( - model=model, - messages=messages, - temperature=0, - ) - if response.choices[0].finish_reason == "length": - return response.choices[0].message.content, "max_output_reached" + if provider == "anthropic": + client = anthropic.Anthropic(api_key=api_key) + msg = client.messages.create( + model=model, + max_tokens=8096, + messages=messages, + temperature=0, + ) + text = msg.content[0].text + finish = "max_output_reached" if msg.stop_reason == "max_tokens" else "finished" + return text, finish else: - return response.choices[0].message.content, "finished" + client = _make_openai_client(model, api_key) + response = client.chat.completions.create( + model=model, + messages=messages, + temperature=0, + ) + if response.choices[0].finish_reason == "length": + return response.choices[0].message.content, "max_output_reached" + else: + return response.choices[0].message.content, "finished" except Exception as e: print('************* Retrying *************') logging.error(f"Error: {e}") if i < max_retries - 1: - time.sleep(1) # Wait for 1秒 before retrying + time.sleep(1) else: logging.error('Max retries reached for prompt: ' + prompt) return "Error" - -def ChatGPT_API(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None): +def ChatGPT_API(model, prompt, api_key=None, chat_history=None): + if api_key is None: + api_key = _get_default_api_key(model) max_retries = 10 - client = openai.OpenAI(api_key=api_key) + provider = get_provider(model) + + if chat_history: + messages = chat_history + messages.append({"role": "user", "content": prompt}) + else: + messages = [{"role": "user", "content": prompt}] + for i in range(max_retries): try: - if chat_history: - messages = chat_history - messages.append({"role": "user", "content": prompt}) + if provider == "anthropic": + client = anthropic.Anthropic(api_key=api_key) + msg = client.messages.create( + model=model, + max_tokens=8096, + messages=messages, + temperature=0, + ) + return msg.content[0].text else: - messages = [{"role": "user", "content": prompt}] - - response = client.chat.completions.create( - model=model, - messages=messages, - temperature=0, - ) - - return response.choices[0].message.content + client = _make_openai_client(model, api_key) + response = client.chat.completions.create( + model=model, + messages=messages, + temperature=0, + ) + return response.choices[0].message.content except Exception as e: print('************* Retrying *************') logging.error(f"Error: {e}") if i < max_retries - 1: - time.sleep(1) # Wait for 1秒 before retrying + time.sleep(1) else: logging.error('Max retries reached for prompt: ' + prompt) return "Error" - -async def ChatGPT_API_async(model, prompt, api_key=CHATGPT_API_KEY): + +async def ChatGPT_API_async(model, prompt, api_key=None): + if api_key is None: + api_key = _get_default_api_key(model) max_retries = 10 + provider = get_provider(model) messages = [{"role": "user", "content": prompt}] + for i in range(max_retries): try: - async with openai.AsyncOpenAI(api_key=api_key) as client: - response = await client.chat.completions.create( - model=model, - messages=messages, - temperature=0, - ) - return response.choices[0].message.content + if provider == "anthropic": + async with anthropic.AsyncAnthropic(api_key=api_key) as client: + msg = await client.messages.create( + model=model, + max_tokens=8096, + messages=messages, + temperature=0, + ) + return msg.content[0].text + else: + async with _make_async_openai_client(model, api_key) as client: + response = await client.chat.completions.create( + model=model, + messages=messages, + temperature=0, + ) + return response.choices[0].message.content except Exception as e: print('************* Retrying *************') logging.error(f"Error: {e}") if i < max_retries - 1: - await asyncio.sleep(1) # Wait for 1s before retrying + await asyncio.sleep(1) else: logging.error('Max retries reached for prompt: ' + prompt) - return "Error" + return "Error" def get_json_content(response): diff --git a/requirements.txt b/requirements.txt index 463db58f1..8a03401f6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,6 @@ openai==1.101.0 +anthropic>=0.40.0 +google-generativeai>=0.8.0 pymupdf==1.26.4 PyPDF2==3.0.1 python-dotenv==1.1.0