Skip to content
1 change: 1 addition & 0 deletions mellea/backends/adapters/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class IntriniscsCatalogEntry(pydantic.BaseModel):
############################################
# Core Intrinsics
############################################
IntriniscsCatalogEntry(name="context-attribution", repo_id=_CORE_R1_REPO),
IntriniscsCatalogEntry(name="requirement-check", repo_id=_CORE_R1_REPO),
IntriniscsCatalogEntry(
name="requirement_check", repo_id=_CORE_REPO
Expand Down
65 changes: 49 additions & 16 deletions mellea/formatters/granite/intrinsics/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,31 +48,31 @@ def sentence_delimiter(tag, sentence_num) -> str:


def mark_sentence_boundaries(
    split_strings: list[list[str]], tag_prefix: str, index: int = 0
) -> tuple[list[str], int]:
    """Modify input strings by inserting sentence boundary markers.

    Modify one or more input strings by inserting a tag in the form
    ``<[prefix][number]>``
    at the location of each sentence boundary.

    :param split_strings: Input string(s), pre-split into sentences
    :param tag_prefix: String to place before the number part of each tagged
        sentence boundary.
    :param index: Starting index for sentence numbering. Defaults to 0. Pass a
        non-zero value to continue numbering from a prior call so that several
        locations (e.g. documents and conversation history) share one
        continuous numbering.

    :returns: Tuple of (list of input strings with all sentence boundaries marked,
        next available index after the last sentence).
    """
    result: list[str] = []
    for sentences in split_strings:
        # Prefix each sentence with its boundary tag, numbering continuously
        # across all strings in split_strings.
        to_concat = []
        for sentence in sentences:
            to_concat.append(f"{sentence_delimiter(tag_prefix, index)}{sentence}")
            index += 1
        result.append(" ".join(to_concat))
    return result, index


def move_documents_to_message(
Expand Down Expand Up @@ -291,10 +291,11 @@ def __init__(
f"Received {self.sentence_boundaries}."
)
for k, v in self.sentence_boundaries.items():
if k not in ("last_message", "documents"):
if k not in ("last_message", "documents", "all_but_last_message"):
raise ValueError(
f"Unexpected location '{k}' in 'sentence_boundaries' field. "
f"Value should be 'last_message' or 'documents'."
f"Value should be 'last_message', 'documents', or "
f"'all_but_last_message'."
)
if not isinstance(v, str):
raise TypeError(
Expand Down Expand Up @@ -324,21 +325,26 @@ def _mark_sentence_boundaries(
:rtype: ChatCompletion
"""
# Mark sentence boundaries in the last message.
# last_message uses its own numbering starting from 0, independent of
# the numbering used for documents and conversation history.
if self.sentence_boundaries and "last_message" in self.sentence_boundaries:
messages = chat_completion.messages.copy() # Do not modify input!
last_message_as_sentences = list(
self.sentence_splitter.tokenize(messages[-1].content)
)
last_message_tag = self.sentence_boundaries["last_message"]
if last_message_tag:
rewritten_last_message_text = mark_sentence_boundaries(
rewritten_texts, _ = mark_sentence_boundaries(
[last_message_as_sentences], last_message_tag
)[0]
messages[-1].content = rewritten_last_message_text
)
messages[-1].content = rewritten_texts[0]
chat_completion = chat_completion.model_copy(
update={"messages": messages}
)

# documents and all_but_last_message share a continuous numbering.
index = 0

# Mark sentence boundaries in documents if present
if (
chat_completion.extra_body
Expand All @@ -355,11 +361,14 @@ def _mark_sentence_boundaries(
# where `k` is the number of sentences in ALL documents.
documents_tag = self.sentence_boundaries["documents"]
if documents_tag:
rewritten_texts, index = mark_sentence_boundaries(
docs_as_sentences, documents_tag, index
)
rewritten_docs = [
doc.model_copy(update={"text": text})
for doc, text in zip(
chat_completion.extra_body.documents,
mark_sentence_boundaries(docs_as_sentences, documents_tag),
rewritten_texts,
strict=True,
)
]
Expand All @@ -370,6 +379,30 @@ def _mark_sentence_boundaries(
chat_completion = chat_completion.model_copy(
update={"extra_body": extra_body}
)

# Mark sentence boundaries in conversation history if requested.
# Uses the same numbering as documents, continuing from where they left off.
if (
self.sentence_boundaries
and "all_but_last_message" in self.sentence_boundaries
):
history_tag = self.sentence_boundaries["all_but_last_message"]
if history_tag:
messages = chat_completion.messages.copy() # Do not modify input!
for i, message in enumerate(messages[:-1]):
msg_as_sentences = list(
self.sentence_splitter.tokenize(message.content)
)
rewritten_texts, index = mark_sentence_boundaries(
[msg_as_sentences], history_tag, index
)
messages[i] = message.model_copy(
update={"content": rewritten_texts[0]}
)
chat_completion = chat_completion.model_copy(
update={"messages": messages}
)

return chat_completion

def _transform(
Expand Down
186 changes: 118 additions & 68 deletions mellea/formatters/granite/intrinsics/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,33 +515,42 @@ def __init__(
config: dict,
input_path_expr: list[str | int | None],
/,
source: str,
source: str | list[str],
output_names: dict,
):
"""Initialize DecodeSentences with a source location and output field name mapping.

:param source: Name (or list of names) of the location(s) to look for
sentences; each name can be "last_message", "documents", or
"all_but_last_message".
:param output_names: Names of new result fields to add

Raises:
ValueError: If ``source`` is not ``"last_message"`` or
``"documents"``, or if an unexpected key is found in
``output_names``.
ValueError: If ``source`` is not one of the allowed values, or if
an unexpected key is found in ``output_names``.
TypeError: If ``output_names`` is not a dict.
"""
super().__init__(config, input_path_expr)

allowed_sources = ("last_message", "documents")
if source not in allowed_sources:
raise ValueError(
f"'source' argument must be one of {allowed_sources}. "
f"Received '{source}'"
)
if isinstance(source, str):
source = [source]
allowed_sources = ("last_message", "documents", "all_but_last_message")
for s in source:
if s not in allowed_sources:
raise ValueError(
f"'source' argument must be one of {allowed_sources}. "
f"Received '{s}'"
)
self.source = source

if not isinstance(output_names, dict):
raise TypeError(
f"Expected mapping for output_names, but received {output_names}"
)
for k in output_names:
if source == "documents" and k == "document_id":
if "documents" in source and k == "document_id":
continue
if "all_but_last_message" in source and k == "message_index":
continue
if k not in ("begin", "end", "text"):
raise ValueError(f"Unexpected key '{k}' in output_names")
Expand All @@ -551,6 +560,7 @@ def __init__(
self.end_name = output_names.get("end")
self.text_name = output_names.get("text")
self.document_id_name = output_names.get("document_id")
self.message_index_name = output_names.get("message_index")

if config["docs_as_message"] and config["docs_as_message"] not in [
"json",
Expand All @@ -575,84 +585,122 @@ def _prepare(
f"'{self.rule_name()}' rule requires this object."
)

if self.source == "documents":
tag = self.config["sentence_boundaries"]["documents"]
if tag is None:
raise ValueError(
f"'{self.rule_name()}' attempting to decode document sentences, "
f"but 'sentence_boundaries' section of config file is missing "
f"the entry that tells how to tag document sentence boundaries."
)
begins: list[int] = []
ends: list[int] = []
texts: list[str] = []
document_ids: list[str | None] = []
message_indices: list[int | None] = []
next_sentence_num = 0

for src in self.source:
if src == "documents":
tag = self.config["sentence_boundaries"]["documents"]
if tag is None:
raise ValueError(
f"'{self.rule_name()}' attempting to decode document sentences, "
f"but 'sentence_boundaries' section of config file is missing "
f"the entry that tells how to tag document sentence boundaries."
)

documents: list[Document] = []
if not self.config["docs_as_message"]:
# Most common path: Documents from extra_body
if chat_completion.extra_body is not None:
documents = chat_completion.extra_body.documents or []
else:
# Model requires documents in a user message. Decode the message.
if self.config["docs_as_message"] == "json":
documents_json = json.loads(chat_completion.messages[0].content)
documents = [Document.model_validate(d) for d in documents_json]
elif self.config["docs_as_message"] == "roles":
for message in chat_completion.messages:
if message.role.startswith("document "):
document = Document(
doc_id=message.role[len("document ") :],
text=message.content,
)
documents.append(document)
documents: list[Document] = []
if not self.config["docs_as_message"]:
# Most common path: Documents from extra_body
if chat_completion.extra_body is not None:
documents = chat_completion.extra_body.documents or []
else:
# Model requires documents in a user message. Decode the message.
if self.config["docs_as_message"] == "json":
documents_json = json.loads(chat_completion.messages[0].content)
documents = [Document.model_validate(d) for d in documents_json]
elif self.config["docs_as_message"] == "roles":
for message in chat_completion.messages:
if message.role.startswith("document "):
document = Document(
doc_id=message.role[len("document ") :],
text=message.content,
)
documents.append(document)
else:
raise ValueError(
f"Unsupported doc type {self.config['docs_as_message']}"
)

# De-split sentences; numbers start at next_sentence_num and continue
# across documents.
for d in documents:
local_results = _desplit_sentences(d.text, tag, next_sentence_num)
num_local_sentences = len(local_results["begins"])
begins.extend(local_results["begins"])
ends.extend(local_results["ends"])
texts.extend(local_results["texts"])
document_ids.extend([d.doc_id] * num_local_sentences)
message_indices.extend([None] * num_local_sentences)
next_sentence_num += num_local_sentences

elif src == "last_message":
tag = self.config["sentence_boundaries"]["last_message"]
if tag is None:
raise ValueError(
f"Unsupported doc type {self.config['docs_as_message']}"
f"'{self.rule_name()}' attempting to decode the last message, "
f"but 'sentence_boundaries' section of config file is missing "
f"the entry that tells how to tag message sentence boundaries."
)

# De-split the sentences in each document in turn. Sentence numbers
# start at zero on the first document and continue in subsequent documents.
begins = []
ends = []
texts = []
document_ids = []

next_sentence_num = 0
for d in documents:
local_results = _desplit_sentences(d.text, tag, next_sentence_num)
# Use second-to-last turn if the input processing added an instruction turn
message_ix = -2 if self.config["instruction"] else -1
target_text = chat_completion.messages[message_ix].content
local_results = _desplit_sentences(target_text, tag, next_sentence_num)
num_local_sentences = len(local_results["begins"])
begins.extend(local_results["begins"])
ends.extend(local_results["ends"])
texts.extend(local_results["texts"])
document_ids.extend([d.doc_id] * num_local_sentences)
document_ids.extend([None] * num_local_sentences)
message_indices.extend([None] * num_local_sentences)
next_sentence_num += num_local_sentences

return {
"begins": begins,
"ends": ends,
"texts": texts,
"document_ids": document_ids,
}
if self.source == "last_message":
tag = self.config["sentence_boundaries"]["last_message"]
if tag is None:
raise ValueError(
f"'{self.rule_name()}' attempting to decode the last message, "
f"but 'sentence_boundaries' section of config file is missing "
f"the entry that tells how to tag message sentence boundaries."
)
elif src == "all_but_last_message":
tag = self.config["sentence_boundaries"]["all_but_last_message"]
if tag is None:
raise ValueError(
f"'{self.rule_name()}' attempting to decode conversation "
f"history sentences, but 'sentence_boundaries' section of "
f"config file is missing the entry that tells how to tag "
f"all_but_last_message sentence boundaries."
)

# Use second-to-last turn if the input processing added an instruction turn
message_ix = -2 if self.config["instruction"] else -1
target_text = chat_completion.messages[message_ix].content
# Use second-to-last as the boundary if an instruction turn was added
last_ix = -2 if self.config["instruction"] else -1
history_messages = chat_completion.messages[:last_ix]
for i, message in enumerate(history_messages):
local_results = _desplit_sentences(
message.content, tag, next_sentence_num
)
num_local_sentences = len(local_results["begins"])
begins.extend(local_results["begins"])
ends.extend(local_results["ends"])
texts.extend(local_results["texts"])
document_ids.extend([None] * num_local_sentences)
message_indices.extend([i] * num_local_sentences)
next_sentence_num += num_local_sentences

return _desplit_sentences(target_text, tag, 0)
else:
raise ValueError(f"Unexpected source string '{src}'")

raise ValueError(f"Unexpected source string '{self.source}'")
return {
"begins": begins,
"ends": ends,
"texts": texts,
"document_ids": document_ids,
"message_indices": message_indices,
}

def _transform(self, value: Any, path: tuple, prepare_output: dict) -> dict:
# Unpack global values we set aside during the prepare phase
begins = prepare_output["begins"]
ends = prepare_output["ends"]
texts = prepare_output["texts"]
document_ids = prepare_output.get("document_ids")
message_indices = prepare_output.get("message_indices")

if not isinstance(value, int):
raise TypeError(
Expand All @@ -670,6 +718,8 @@ def _transform(self, value: Any, path: tuple, prepare_output: dict) -> dict:
result[self.text_name] = texts[sentence_num]
if self.document_id_name is not None:
result[self.document_id_name] = document_ids[sentence_num] # type: ignore[index]
if self.message_index_name is not None:
result[self.message_index_name] = message_indices[sentence_num] # type: ignore[index]
return result


Expand Down
Loading
Loading