Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 29 additions & 7 deletions pageindex/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,25 @@
from types import SimpleNamespace as config

CHATGPT_API_KEY = os.getenv("CHATGPT_API_KEY")
# Novita AI - OpenAI-compatible API support
NOVITA_API_KEY = os.getenv("NOVITA_API_KEY")
NOVITA_BASE_URL = "https://api.novita.ai/openai"
NOVITA_MODEL = os.getenv("NOVITA_MODEL")
NOVITA_DEFAULT_MODEL = "deepseek/deepseek-r1"
DEFAULT_OPENAI_MODEL = "gpt-4o-2024-11-20"

def get_openai_client(api_key=CHATGPT_API_KEY, async_client=False):
"""Get OpenAI client - supports both OpenAI and Novita AI (OpenAI-compatible)."""
client_cls = openai.AsyncOpenAI if async_client else openai.OpenAI
if NOVITA_API_KEY and api_key == CHATGPT_API_KEY:
return client_cls(api_key=NOVITA_API_KEY, base_url=NOVITA_BASE_URL)
return client_cls(api_key=api_key)

def resolve_chat_model(model, api_key=CHATGPT_API_KEY):
"""Resolve model name for OpenAI-compatible providers."""
if NOVITA_API_KEY and api_key == CHATGPT_API_KEY and model == DEFAULT_OPENAI_MODEL:
return NOVITA_MODEL or NOVITA_DEFAULT_MODEL
return model

def count_tokens(text, model=None):
if not text:
Expand All @@ -28,7 +47,8 @@ def count_tokens(text, model=None):

def ChatGPT_API_with_finish_reason(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
max_retries = 10
client = openai.OpenAI(api_key=api_key)
client = get_openai_client(api_key=api_key)
resolved_model = resolve_chat_model(model, api_key=api_key)
for i in range(max_retries):
try:
if chat_history:
Expand All @@ -38,7 +58,7 @@ def ChatGPT_API_with_finish_reason(model, prompt, api_key=CHATGPT_API_KEY, chat_
messages = [{"role": "user", "content": prompt}]

response = client.chat.completions.create(
model=model,
model=resolved_model,
messages=messages,
temperature=0,
)
Expand All @@ -60,7 +80,8 @@ def ChatGPT_API_with_finish_reason(model, prompt, api_key=CHATGPT_API_KEY, chat_

def ChatGPT_API(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
max_retries = 10
client = openai.OpenAI(api_key=api_key)
client = get_openai_client(api_key=api_key)
resolved_model = resolve_chat_model(model, api_key=api_key)
for i in range(max_retries):
try:
if chat_history:
Expand All @@ -70,7 +91,7 @@ def ChatGPT_API(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
messages = [{"role": "user", "content": prompt}]

response = client.chat.completions.create(
model=model,
model=resolved_model,
messages=messages,
temperature=0,
)
Expand All @@ -88,12 +109,13 @@ def ChatGPT_API(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):

async def ChatGPT_API_async(model, prompt, api_key=CHATGPT_API_KEY):
max_retries = 10
resolved_model = resolve_chat_model(model, api_key=api_key)
messages = [{"role": "user", "content": prompt}]
for i in range(max_retries):
try:
async with openai.AsyncOpenAI(api_key=api_key) as client:
async with get_openai_client(api_key=api_key, async_client=True) as client:
response = await client.chat.completions.create(
model=model,
model=resolved_model,
messages=messages,
temperature=0,
)
Expand Down Expand Up @@ -709,4 +731,4 @@ def load(self, user_opt=None) -> config:

self._validate_keys(user_dict)
merged = {**self._default_dict, **user_dict}
return config(**merged)
return config(**merged)