@@ -111,6 +111,28 @@ def test_assess_batch_prediction_resources(client):
    assert isinstance(response, types.BatchPredictionResourceUsageAssessmentResult)


def test_assess_batch_prediction_validity(client):
    response = client.datasets.assess_batch_prediction_validity(
        dataset_name=DATASET,
        model_name="gemini-2.5-flash-001",
        template_config=types.GeminiTemplateConfig(
            gemini_example=types.GeminiExample(
                contents=[
                    {
                        "role": "user",
                        "parts": [{"text": "What is the capital of {name}?"}],
                    },
                    {
                        "role": "model",
                        "parts": [{"text": "{capital}"}],
                    },
                ],
            ),
        ),
    )
    assert isinstance(response, types.BatchPredictionValidationAssessmentResult)


pytestmark = pytest_helper.setup(
    file=__file__,
    globals_for_file=globals(),
@@ -206,3 +228,26 @@ async def test_assess_batch_prediction_resources_async(client):
        ),
    )
    assert isinstance(response, types.BatchPredictionResourceUsageAssessmentResult)


@pytest.mark.asyncio
async def test_assess_batch_prediction_validity_async(client):
    response = await client.aio.datasets.assess_batch_prediction_validity(
        dataset_name=DATASET,
        model_name="gemini-2.5-flash-001",
        template_config=types.GeminiTemplateConfig(
            gemini_example=types.GeminiExample(
                contents=[
                    {
                        "role": "user",
                        "parts": [{"text": "What is the capital of {name}?"}],
                    },
                    {
                        "role": "model",
                        "parts": [{"text": "{capital}"}],
                    },
                ],
            ),
        ),
    )
    assert isinstance(response, types.BatchPredictionValidationAssessmentResult)
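The two tests above pass template_config as typed GeminiTemplateConfig/GeminiExample objects. Since the parameter is annotated as types.GeminiTemplateConfigOrDict, the plain-dict form should also be accepted; the sketch below is a hypothetical extra test illustrating that variant, not part of this change, and it assumes pydantic coercion of the nested dict.

def test_assess_batch_prediction_validity_dict_template(client):
    # Hypothetical variant: passes template_config as a plain dict, relying on
    # the GeminiTemplateConfigOrDict union instead of typed objects.
    response = client.datasets.assess_batch_prediction_validity(
        dataset_name=DATASET,
        model_name="gemini-2.5-flash-001",
        template_config={
            "gemini_example": {
                "contents": [
                    {
                        "role": "user",
                        "parts": [{"text": "What is the capital of {name}?"}],
                    },
                    {
                        "role": "model",
                        "parts": [{"text": "{capital}"}],
                    },
                ],
            },
        },
    )
    assert isinstance(response, types.BatchPredictionValidationAssessmentResult)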
118 changes: 118 additions & 0 deletions vertexai/_genai/datasets.py
@@ -1180,6 +1180,65 @@ def assess_batch_prediction_resources(
            types.BatchPredictionResourceUsageAssessmentResult, result
        )

    def assess_batch_prediction_validity(
        self,
        *,
        dataset_name: str,
        model_name: str,
        template_config: Optional[types.GeminiTemplateConfigOrDict] = None,
        config: Optional[types.AssessDatasetConfigOrDict] = None,
    ) -> types.BatchPredictionValidationAssessmentResult:
        """Assesses whether the assembled dataset is valid for batch prediction
        with the given model. Any validation errors found in the dataset are
        reported in the errors field of the returned assessment result.

        Args:
            dataset_name:
                Required. The name of the dataset to assess the batch prediction
                validity for.
            model_name:
                Required. The name of the model to assess the batch prediction
                validity for.
            template_config:
                Optional. The template config used to assemble the dataset
                before assessing the batch prediction validity. If not provided,
                the template config attached to the dataset will be used.
                Required if no template config is attached to the dataset.
            config:
                Optional. A configuration for assessing the batch prediction
                validity. If not provided, the default configuration will be
                used.

        Returns:
            A types.BatchPredictionValidationAssessmentResult object representing
            the batch prediction validity assessment result. It contains the
            following field:
            - errors: A list of errors found during the batch prediction
              validity assessment.
        """
        if isinstance(config, dict):
            config = types.AssessDatasetConfig(**config)
        elif not config:
            config = types.AssessDatasetConfig()

        operation = self._assess_multimodal_dataset(
            name=dataset_name,
            batch_prediction_validation_assessment_config=types.BatchPredictionValidationAssessmentConfig(
                model_name=model_name,
            ),
            gemini_request_read_config=types.GeminiRequestReadConfig(
                template_config=template_config,
            ),
            config=config,
        )
        response = self._wait_for_operation(
            operation=operation,
            timeout_seconds=config.timeout,
        )
        result = response["batchPredictionValidationAssessmentResult"]
        return _datasets_utils.create_from_response(
            types.BatchPredictionValidationAssessmentResult, result
        )


class AsyncDatasets(_api_module.BaseModule):

@@ -2127,3 +2186,62 @@ async def assess_batch_prediction_resources(
        return _datasets_utils.create_from_response(
            types.BatchPredictionResourceUsageAssessmentResult, result
        )

    async def assess_batch_prediction_validity(
        self,
        *,
        dataset_name: str,
        model_name: str,
        template_config: Optional[types.GeminiTemplateConfigOrDict] = None,
        config: Optional[types.AssessDatasetConfigOrDict] = None,
    ) -> types.BatchPredictionValidationAssessmentResult:
        """Assesses whether the assembled dataset is valid for batch prediction
        with the given model. Any validation errors found in the dataset are
        reported in the errors field of the returned assessment result.

        Args:
            dataset_name:
                Required. The name of the dataset to assess the batch prediction
                validity for.
            model_name:
                Required. The name of the model to assess the batch prediction
                validity for.
            template_config:
                Optional. The template config used to assemble the dataset
                before assessing the batch prediction validity. If not provided,
                the template config attached to the dataset will be used.
                Required if no template config is attached to the dataset.
            config:
                Optional. A configuration for assessing the batch prediction
                validity. If not provided, the default configuration will be
                used.

        Returns:
            A types.BatchPredictionValidationAssessmentResult object representing
            the batch prediction validity assessment result. It contains the
            following field:
            - errors: A list of errors found during the batch prediction
              validity assessment.
        """
        if isinstance(config, dict):
            config = types.AssessDatasetConfig(**config)
        elif not config:
            config = types.AssessDatasetConfig()

        operation = await self._assess_multimodal_dataset(
            name=dataset_name,
            batch_prediction_validation_assessment_config=types.BatchPredictionValidationAssessmentConfig(
                model_name=model_name,
            ),
            gemini_request_read_config=types.GeminiRequestReadConfig(
                template_config=template_config,
            ),
            config=config,
        )
        response = await self._wait_for_operation(
            operation=operation,
            timeout_seconds=config.timeout,
        )
        result = response["batchPredictionValidationAssessmentResult"]
        return _datasets_utils.create_from_response(
            types.BatchPredictionValidationAssessmentResult, result
        )
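Taken together, the synchronous and asynchronous methods above give the Datasets module a pre-flight validity check to run before submitting a batch prediction job. A minimal usage sketch follows; the client construction, project, location, and resource names are assumptions for illustration and are not part of this diff.

import vertexai
from vertexai._genai import types

# Assumed setup: a genai-style Vertex AI client; the project, location, and
# dataset resource name below are placeholders, not values from this change.
client = vertexai.Client(project="my-project", location="us-central1")

result = client.datasets.assess_batch_prediction_validity(
    dataset_name="projects/my-project/locations/us-central1/datasets/123",
    model_name="gemini-2.5-flash-001",
    # template_config can be omitted when one is already attached to the dataset.
)

# Validation findings are surfaced on the result's errors field: an empty or
# None list means the assembled requests passed the check.
if result.errors:
    for error in result.errors:
        print(f"validation error: {error}")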
6 changes: 6 additions & 0 deletions vertexai/_genai/types/__init__.py
@@ -148,6 +148,9 @@
from .common import BatchPredictionValidationAssessmentConfig
from .common import BatchPredictionValidationAssessmentConfigDict
from .common import BatchPredictionValidationAssessmentConfigOrDict
from .common import BatchPredictionValidationAssessmentResult
from .common import BatchPredictionValidationAssessmentResultDict
from .common import BatchPredictionValidationAssessmentResultOrDict
from .common import BigQueryRequestSet
from .common import BigQueryRequestSetDict
from .common import BigQueryRequestSetOrDict
@@ -1908,6 +1911,9 @@
    "BatchPredictionResourceUsageAssessmentResult",
    "BatchPredictionResourceUsageAssessmentResultDict",
    "BatchPredictionResourceUsageAssessmentResultOrDict",
    "BatchPredictionValidationAssessmentResult",
    "BatchPredictionValidationAssessmentResultDict",
    "BatchPredictionValidationAssessmentResultOrDict",
    "TuningResourceUsageAssessmentResult",
    "TuningResourceUsageAssessmentResultDict",
    "TuningResourceUsageAssessmentResultOrDict",
21 changes: 21 additions & 0 deletions vertexai/_genai/types/common.py
@@ -14250,6 +14250,27 @@ class BatchPredictionResourceUsageAssessmentResultDict(TypedDict, total=False):
]


class BatchPredictionValidationAssessmentResult(_common.BaseModel):
    """Result of batch prediction validation assessment."""

    errors: Optional[list[str]] = Field(
        default=None, description="""The list of errors found in the dataset."""
    )


class BatchPredictionValidationAssessmentResultDict(TypedDict, total=False):
    """Result of batch prediction validation assessment."""

    errors: Optional[list[str]]
    """The list of errors found in the dataset."""


BatchPredictionValidationAssessmentResultOrDict = Union[
    BatchPredictionValidationAssessmentResult,
    BatchPredictionValidationAssessmentResultDict,
]


class TuningResourceUsageAssessmentResult(_common.BaseModel):
    """Result of tuning resource usage assessment."""

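The new result type follows the SDK's existing model / TypedDict / *OrDict pattern, so callers can work with either representation. A small sketch, assuming _common.BaseModel is the pydantic base used elsewhere in these types; the error strings are illustrative only.

from vertexai._genai import types

# Typed form: a pydantic model whose fields are regular attributes.
typed = types.BatchPredictionValidationAssessmentResult(
    errors=["row 12: missing 'contents'"],  # illustrative value only
)
assert typed.errors == ["row 12: missing 'contents'"]

# Dict form: the TypedDict mirrors the same shape, with every key optional
# (total=False).
as_dict: types.BatchPredictionValidationAssessmentResultDict = {
    "errors": ["row 12: missing 'contents'"],
}

# Either form satisfies the *OrDict union used throughout the API surface.
either: types.BatchPredictionValidationAssessmentResultOrDict = typed
either = as_dict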