diff --git a/pageindex/page_index.py b/pageindex/page_index.py index 882fb5dea..a4ca299da 100644 --- a/pageindex/page_index.py +++ b/pageindex/page_index.py @@ -7,9 +7,104 @@ from .utils import * import os from concurrent.futures import ThreadPoolExecutor, as_completed +from pydantic import BaseModel, Field, ValidationError +from typing import Literal, Type, TypeVar, Optional, Union, List +T = TypeVar("T", bound=BaseModel) + +def parse_llm_response(model_cls: Type[T], raw_response: str, logger=None) -> T | None: + """ + Parse and validate LLM JSON output using Pydantic. + Returns None if validation fails. + """ + try: + data = extract_json(raw_response) + normalizable_fields = { + 'answer', 'start_begin', 'toc_detected', 'completed', + 'page_index_given_in_toc', 'start' + } + + # Normalize only specific fields, recursively handle nested structures + def normalize_data(obj): + if isinstance(obj, dict): + return { + k: v.strip().lower() if isinstance(v, str) and k in normalizable_fields else normalize_data(v) + for k, v in obj.items() + } + elif isinstance(obj, list): + return [normalize_data(item) for item in obj] + return obj + + data = normalize_data(data) + + return model_cls(**data) + + except (ValidationError, ValueError, TypeError) as e: + if logger: + logger.error( + f"LLM response validation failed for {model_cls.__name__}: {e}" + ) + return None ################### check title in page ######################################################### +class CheckTitleAppearanceResponse(BaseModel): + thinking: str = Field(..., min_length=1) + answer: Literal["yes", "no"] + +class CheckTitleAppearanceInStartResponse(BaseModel): + thinking: str = Field(..., min_length=1) + start_begin: Literal["yes", "no"] + +class TocDetectorResponse(BaseModel): + thinking: str = Field(..., min_length=1) + toc_detected: Literal["yes", "no"] + +class TocCompletionResponse(BaseModel): + thinking: str = Field(..., min_length=1) + completed: Literal["yes", "no"] + +class PageIndexDetectionResponse(BaseModel): + thinking: str = Field(..., min_length=1) + page_index_given_in_toc: Literal["yes", "no"] + +class TocIndexItem(BaseModel): + structure: Optional[str] = None + title: str + physical_index: str + +class TocIndexResponse(BaseModel): + root: List[TocIndexItem] + +class TocTransformerItem(BaseModel): + structure: Optional[str] = None + title: str + page: Union[str, int, None] = None + +class TocTransformerResponse(BaseModel): + table_of_contents: List[TocTransformerItem] + +class AddPageNumberItem(BaseModel): + structure: Optional[str] = None + title: str + start: Literal["yes", "no"] + physical_index: Optional[str] = None + +class AddPageNumberResponse(BaseModel): + root: List[AddPageNumberItem] + +class SingleTocFixerResponse(BaseModel): + thinking: str = Field(..., min_length=1) + physical_index: str + +class GenerateTocItem(BaseModel): + structure: Optional[str] = None + title: str + physical_index: str + +class GenerateTocResponse(BaseModel): + root: List[GenerateTocItem] + + async def check_title_appearance(item, page_list, start_index=1, model=None): title=item['title'] if 'physical_index' not in item or item['physical_index'] is None: @@ -37,15 +132,15 @@ async def check_title_appearance(item, page_list, start_index=1, model=None): Directly return the final JSON structure. Do not output anything else.""" response = await ChatGPT_API_async(model=model, prompt=prompt) - response = extract_json(response) - if 'answer' in response: - answer = response['answer'] + parsed_response = parse_llm_response(CheckTitleAppearanceResponse, response) + + if parsed_response: + answer = parsed_response.answer else: answer = 'no' return {'list_index': item['list_index'], 'answer': answer, 'title': title, 'page_number': page_number} - -async def check_title_appearance_in_start(title, page_text, model=None, logger=None): +async def check_title_appearance_in_start(title, page_text, model=None, logger=None): prompt = f""" You will be given the current section title and the current page_text. Your job is to check if the current section starts in the beginning of the given page_text. @@ -65,11 +160,10 @@ async def check_title_appearance_in_start(title, page_text, model=None, logger=N Directly return the final JSON structure. Do not output anything else.""" response = await ChatGPT_API_async(model=model, prompt=prompt) - response = extract_json(response) + parsed_response = parse_llm_response(CheckTitleAppearanceInStartResponse, response, logger) if logger: logger.info(f"Response: {response}") - return response.get("start_begin", "no") - + return parsed_response.start_begin if parsed_response else "no" async def check_title_appearance_in_start_concurrent(structure, page_list, model=None, logger=None): if logger: @@ -100,7 +194,7 @@ async def check_title_appearance_in_start_concurrent(structure, page_list, model return structure - + def toc_detector_single_page(content, model=None): prompt = f""" Your job is to detect if there is a table of content provided in the given text. @@ -117,9 +211,8 @@ def toc_detector_single_page(content, model=None): Please note: abstract,summary, notation list, figure list, table list, etc. are not table of contents.""" response = ChatGPT_API(model=model, prompt=prompt) - # print('response', response) - json_content = extract_json(response) - return json_content['toc_detected'] + parsed_response = parse_llm_response(TocDetectorResponse, response) + return parsed_response.toc_detected if parsed_response else "no" def check_if_toc_extraction_is_complete(content, toc, model=None): @@ -136,8 +229,8 @@ def check_if_toc_extraction_is_complete(content, toc, model=None): prompt = prompt + '\n Document:\n' + content + '\n Table of contents:\n' + toc response = ChatGPT_API(model=model, prompt=prompt) - json_content = extract_json(response) - return json_content['completed'] + parsed_response = parse_llm_response(TocCompletionResponse, response) + return parsed_response.completed if parsed_response else "no" def check_if_toc_transformation_is_complete(content, toc, model=None): @@ -154,8 +247,8 @@ def check_if_toc_transformation_is_complete(content, toc, model=None): prompt = prompt + '\n Raw Table of contents:\n' + content + '\n Cleaned Table of contents:\n' + toc response = ChatGPT_API(model=model, prompt=prompt) - json_content = extract_json(response) - return json_content['completed'] + parsed_response = parse_llm_response(TocCompletionResponse, response) + return parsed_response.completed if parsed_response else "no" def extract_toc_content(content, model=None): prompt = f""" @@ -191,7 +284,7 @@ def extract_toc_content(content, model=None): if_complete = check_if_toc_transformation_is_complete(content, response, model) # Optional: Add a maximum retry limit to prevent infinite loops - if len(chat_history) > 5: # Arbitrary limit of 10 attempts + if len(chat_history) > 10: # Arbitrary limit of 10 attempts raise Exception('Failed to complete table of contents after maximum retries') return response @@ -213,8 +306,8 @@ def detect_page_index(toc_content, model=None): Directly return the final JSON structure. Do not output anything else.""" response = ChatGPT_API(model=model, prompt=prompt) - json_content = extract_json(response) - return json_content['page_index_given_in_toc'] + parsed_response = parse_llm_response(PageIndexDetectionResponse, response) + return parsed_response.page_index_given_in_toc if parsed_response else "no" def toc_extractor(page_list, toc_page_list, model): def transform_dots_to_colon(text): @@ -237,9 +330,9 @@ def transform_dots_to_colon(text): -def toc_index_extractor(toc, content, model=None): +def toc_index_extractor(toc, content, model=None, logger=None): print('start toc_index_extractor') - tob_extractor_prompt = """ + toc_extractor_prompt = """ You are given a table of contents in a json format and several pages of a document, your job is to add the physical_index to the table of contents in the json format. The provided pages contains tags like and to indicate the physical location of the page X. @@ -260,14 +353,22 @@ def toc_index_extractor(toc, content, model=None): If the section is not in the provided pages, do not add the physical_index to it. Directly return the final JSON structure. Do not output anything else.""" - prompt = tob_extractor_prompt + '\nTable of contents:\n' + str(toc) + '\nDocument pages:\n' + content + prompt = toc_extractor_prompt + '\nTable of contents:\n' + str(toc) + '\nDocument pages:\n' + content response = ChatGPT_API(model=model, prompt=prompt) - json_content = extract_json(response) - return json_content + + # Normalize response to expected format + data = extract_json(response) + if isinstance(data, list): + data = {'root': data} + + parsed_response = parse_llm_response(TocIndexResponse, json.dumps(data), logger) + if parsed_response: + return [item.model_dump() for item in parsed_response.root] + return [] -def toc_transformer(toc_content, model=None): +def toc_transformer(toc_content, model=None, logger=None): print('start toc_transformer') init_prompt = """ You are given a table of contents, You job is to transform the whole table of content into a JSON format included table_of_contents. @@ -292,9 +393,12 @@ def toc_transformer(toc_content, model=None): last_complete, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt) if_complete = check_if_toc_transformation_is_complete(toc_content, last_complete, model) if if_complete == "yes" and finish_reason == "finished": - last_complete = extract_json(last_complete) - cleaned_response=convert_page_to_int(last_complete['table_of_contents']) - return cleaned_response + parsed_response = parse_llm_response(TocTransformerResponse, last_complete, logger) + if parsed_response: + # Normalize to list of dicts for convert_page_to_int + data_list = [item.model_dump() for item in parsed_response.table_of_contents] + return convert_page_to_int(data_list) + return [] last_complete = get_json_content(last_complete) while not (if_complete == "yes" and finish_reason == "finished"): @@ -322,10 +426,11 @@ def toc_transformer(toc_content, model=None): if_complete = check_if_toc_transformation_is_complete(toc_content, last_complete, model) - last_complete = json.loads(last_complete) - - cleaned_response=convert_page_to_int(last_complete['table_of_contents']) - return cleaned_response + parsed_response = parse_llm_response(TocTransformerResponse, last_complete, logger) + if parsed_response: + data_list = [item.model_dump() for item in parsed_response.table_of_contents] + return convert_page_to_int(data_list) + return [] @@ -450,7 +555,7 @@ def page_list_to_group_text(page_contents, token_lengths, max_tokens=20000, over print('divide page_list to groups', len(subsets)) return subsets -def add_page_number_to_toc(part, structure, model=None): +def add_page_number_to_toc(part, structure, model=None, logger=None): fill_prompt_seq = """ You are given an JSON structure of a document and a partial part of the document. Your task is to check if the title that is described in the structure is started in the partial given document. @@ -475,7 +580,17 @@ def add_page_number_to_toc(part, structure, model=None): prompt = fill_prompt_seq + f"\n\nCurrent Partial Document:\n{part}\n\nGiven Structure\n{json.dumps(structure, indent=2)}\n" current_json_raw = ChatGPT_API(model=model, prompt=prompt) - json_result = extract_json(current_json_raw) + extracted_data = extract_json(current_json_raw) + if isinstance(extracted_data, list): + data = {'root': extracted_data} + else: + data = extracted_data + + parsed_response = parse_llm_response(AddPageNumberResponse, json.dumps(data), logger) + if parsed_response: + json_result = [item.model_dump() for item in parsed_response.root] + else: + json_result = [] for item in json_result: if 'start' in item: @@ -495,8 +610,7 @@ def remove_first_physical_index_section(text): return text.replace(match.group(0), '', 1) return text -### add verify completeness -def generate_toc_continue(toc_content, part, model="gpt-4o-2024-11-20"): +def generate_toc_continue(toc_content, part, model="gpt-4o-2024-11-20", logger=None): print('start generate_toc_continue') prompt = """ You are an expert in extracting hierarchical tree structure. @@ -526,12 +640,19 @@ def generate_toc_continue(toc_content, part, model="gpt-4o-2024-11-20"): prompt = prompt + '\nGiven text\n:' + part + '\nPrevious tree structure\n:' + json.dumps(toc_content, indent=2) response, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt) if finish_reason == 'finished': - return extract_json(response) + data = extract_json(response) + if isinstance(data, list): + data = {'root': data} + + parsed_response = parse_llm_response(GenerateTocResponse, json.dumps(data), logger) + if parsed_response: + return [item.model_dump() for item in parsed_response.root] + return [] else: raise Exception(f'finish reason: {finish_reason}') ### add verify completeness -def generate_toc_init(part, model=None): +def generate_toc_init(part, model=None, logger=None): print('start generate_toc_init') prompt = """ You are an expert in extracting hierarchical tree structure, your task is to generate the tree structure of the document. @@ -561,7 +682,14 @@ def generate_toc_init(part, model=None): response, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt) if finish_reason == 'finished': - return extract_json(response) + data = extract_json(response) + if isinstance(data, list): + data = {'root': data} + + parsed_response = parse_llm_response(GenerateTocResponse, json.dumps(data), logger) + if parsed_response: + return [item.model_dump() for item in parsed_response.root] + return [] else: raise Exception(f'finish reason: {finish_reason}') @@ -575,9 +703,9 @@ def process_no_toc(page_list, start_index=1, model=None, logger=None): group_texts = page_list_to_group_text(page_contents, token_lengths) logger.info(f'len(group_texts): {len(group_texts)}') - toc_with_page_number= generate_toc_init(group_texts[0], model) + toc_with_page_number= generate_toc_init(group_texts[0], model, logger) for group_text in group_texts[1:]: - toc_with_page_number_additional = generate_toc_continue(toc_with_page_number, group_text, model) + toc_with_page_number_additional = generate_toc_continue(toc_with_page_number, group_text, model, logger) toc_with_page_number.extend(toc_with_page_number_additional) logger.info(f'generate_toc: {toc_with_page_number}') @@ -589,7 +717,7 @@ def process_no_toc(page_list, start_index=1, model=None, logger=None): def process_toc_no_page_numbers(toc_content, toc_page_list, page_list, start_index=1, model=None, logger=None): page_contents=[] token_lengths=[] - toc_content = toc_transformer(toc_content, model) + toc_content = toc_transformer(toc_content, model, logger) logger.info(f'toc_transformer: {toc_content}') for page_index in range(start_index, start_index+len(page_list)): page_text = f"\n{page_list[page_index-start_index][0]}\n\n\n" @@ -601,7 +729,7 @@ def process_toc_no_page_numbers(toc_content, toc_page_list, page_list, start_in toc_with_page_number=copy.deepcopy(toc_content) for group_text in group_texts: - toc_with_page_number = add_page_number_to_toc(group_text, toc_with_page_number, model) + toc_with_page_number = add_page_number_to_toc(group_text, toc_with_page_number, model, logger) logger.info(f'add_page_number_to_toc: {toc_with_page_number}') toc_with_page_number = convert_physical_index_to_int(toc_with_page_number) @@ -612,7 +740,7 @@ def process_toc_no_page_numbers(toc_content, toc_page_list, page_list, start_in def process_toc_with_page_numbers(toc_content, toc_page_list, page_list, toc_check_page_num=None, model=None, logger=None): - toc_with_page_number = toc_transformer(toc_content, model) + toc_with_page_number = toc_transformer(toc_content, model, logger) logger.info(f'toc_with_page_number: {toc_with_page_number}') toc_no_page_number = remove_page_number(copy.deepcopy(toc_with_page_number)) @@ -622,7 +750,7 @@ def process_toc_with_page_numbers(toc_content, toc_page_list, page_list, toc_che for page_index in range(start_page_index, min(start_page_index + toc_check_page_num, len(page_list))): main_content += f"\n{page_list[page_index][0]}\n\n\n" - toc_with_physical_index = toc_index_extractor(toc_no_page_number, main_content, model) + toc_with_physical_index = toc_index_extractor(toc_no_page_number, main_content, model, logger) logger.info(f'toc_with_physical_index: {toc_with_physical_index}') toc_with_physical_index = convert_physical_index_to_int(toc_with_physical_index) @@ -648,8 +776,6 @@ def process_toc_with_page_numbers(toc_content, toc_page_list, page_list, toc_che def process_none_page_numbers(toc_items, page_list, start_index=1, model=None): for i, item in enumerate(toc_items): if "physical_index" not in item: - # logger.info(f"fix item: {item}") - # Find previous physical_index prev_physical_index = 0 # Default if no previous item exists for j in range(i - 1, -1, -1): if toc_items[j].get('physical_index') is not None: @@ -674,11 +800,11 @@ def process_none_page_numbers(toc_items, page_list, start_index=1, model=None): continue item_copy = copy.deepcopy(item) - del item_copy['page'] + item_copy.pop('page', None) result = add_page_number_to_toc(page_contents, item_copy, model) if isinstance(result[0]['physical_index'], str) and result[0]['physical_index'].startswith('').strip()) - del item['page'] + item.pop('page', None) return toc_items @@ -744,8 +870,8 @@ def single_toc_item_index_fixer(section_title, content, model="gpt-4o-2024-11-20 prompt = tob_extractor_prompt + '\nSection Title:\n' + str(section_title) + '\nDocument pages:\n' + content response = ChatGPT_API(model=model, prompt=prompt) - json_content = extract_json(response) - return convert_physical_index_to_int(json_content['physical_index']) + parsed_response = parse_llm_response(SingleTocFixerResponse, response) + return convert_physical_index_to_int(parsed_response.physical_index) if parsed_response else None @@ -1141,4 +1267,4 @@ def validate_and_truncate_physical_indices(toc_with_page_number, page_list_lengt if truncated_items: print(f"Truncated {len(truncated_items)} TOC items that exceeded document length") - return toc_with_page_number \ No newline at end of file + return toc_with_page_number