diff --git a/tests/unit/vertexai/genai/replays/test_assess_multimodal_dataset.py b/tests/unit/vertexai/genai/replays/test_assess_multimodal_dataset.py index be01b64e9f..dc6167c916 100644 --- a/tests/unit/vertexai/genai/replays/test_assess_multimodal_dataset.py +++ b/tests/unit/vertexai/genai/replays/test_assess_multimodal_dataset.py @@ -111,6 +111,28 @@ def test_assess_batch_prediction_resources(client): assert isinstance(response, types.BatchPredictionResourceUsageAssessmentResult) +def test_assess_batch_prediction_validity(client): + response = client.datasets.assess_batch_prediction_validity( + dataset_name=DATASET, + model_name="gemini-2.5-flash-001", + template_config=types.GeminiTemplateConfig( + gemini_example=types.GeminiExample( + contents=[ + { + "role": "user", + "parts": [{"text": "What is the capital of {name}?"}], + }, + { + "role": "model", + "parts": [{"text": "{capital}"}], + }, + ], + ), + ), + ) + assert isinstance(response, types.BatchPredictionValidationAssessmentResult) + + pytestmark = pytest_helper.setup( file=__file__, globals_for_file=globals(), @@ -206,3 +228,26 @@ async def test_assess_batch_prediction_resources_async(client): ), ) assert isinstance(response, types.BatchPredictionResourceUsageAssessmentResult) + + +@pytest.mark.asyncio +async def test_assess_batch_prediction_validity_async(client): + response = await client.aio.datasets.assess_batch_prediction_validity( + dataset_name=DATASET, + model_name="gemini-2.5-flash-001", + template_config=types.GeminiTemplateConfig( + gemini_example=types.GeminiExample( + contents=[ + { + "role": "user", + "parts": [{"text": "What is the capital of {name}?"}], + }, + { + "role": "model", + "parts": [{"text": "{capital}"}], + }, + ], + ), + ), + ) + assert isinstance(response, types.BatchPredictionValidationAssessmentResult) diff --git a/vertexai/_genai/datasets.py b/vertexai/_genai/datasets.py index 00bc55003f..9a6b507e10 100644 --- a/vertexai/_genai/datasets.py +++ b/vertexai/_genai/datasets.py 
@@ -1180,6 +1180,65 @@ def assess_batch_prediction_resources( types.BatchPredictionResourceUsageAssessmentResult, result ) + def assess_batch_prediction_validity( + self, + *, + dataset_name: str, + model_name: str, + template_config: Optional[types.GeminiTemplateConfigOrDict] = None, + config: Optional[types.AssessDatasetConfigOrDict] = None, + ) -> types.BatchPredictionValidationAssessmentResult: + """Assess if the assembled dataset is valid in terms of batch prediction + for a given model. Any problems found are reported in the `errors` field + of the returned result. + + Args: + dataset_name: + Required. The name of the dataset to assess the batch prediction + validity for. + model_name: + Required. The name of the model to assess the batch prediction + validity for. + template_config: + Optional. The template config used to assemble the dataset + before assessing the batch prediction validity. If not provided, the + template config attached to the dataset will be used. Required + if no template config is attached to the dataset. + config: + Optional. A configuration for assessing the batch prediction validity. + If not provided, the default configuration will be used. + + Returns: + A types.BatchPredictionValidationAssessmentResult object representing + the batch prediction validity assessment result. + It contains the following keys: + - errors: A list of errors that occurred during the batch prediction + validity assessment. 
+ """ + if isinstance(config, dict): + config = types.AssessDatasetConfig(**config) + elif not config: + config = types.AssessDatasetConfig() + + operation = self._assess_multimodal_dataset( + name=dataset_name, + batch_prediction_validation_assessment_config=types.BatchPredictionValidationAssessmentConfig( + model_name=model_name, + ), + gemini_request_read_config=types.GeminiRequestReadConfig( + template_config=template_config, + ), + config=config, + ) + response = self._wait_for_operation( + operation=operation, + timeout_seconds=config.timeout, + ) + result = response["batchPredictionValidationAssessmentResult"] + return _datasets_utils.create_from_response( + types.BatchPredictionValidationAssessmentResult, result + ) + class AsyncDatasets(_api_module.BaseModule): @@ -2127,3 +2186,62 @@ async def assess_batch_prediction_resources( return _datasets_utils.create_from_response( types.BatchPredictionResourceUsageAssessmentResult, result ) + + async def assess_batch_prediction_validity( + self, + *, + dataset_name: str, + model_name: str, + template_config: Optional[types.GeminiTemplateConfigOrDict] = None, + config: Optional[types.AssessDatasetConfigOrDict] = None, + ) -> types.BatchPredictionValidationAssessmentResult: + """Assess if the assembled dataset is valid in terms of batch prediction + for a given model. Any problems found are reported in the `errors` field + of the returned result. + + Args: + dataset_name: + Required. The name of the dataset to assess the batch prediction + validity for. + model_name: + Required. The name of the model to assess the batch prediction + validity for. + template_config: + Optional. The template config used to assemble the dataset + before assessing the batch prediction validity. If not provided, the + template config attached to the dataset will be used. Required + if no template config is attached to the dataset. + config: + Optional. A configuration for assessing the batch prediction validity. 
+ If not provided, the default configuration will be used. + + Returns: + A types.BatchPredictionValidationAssessmentResult object representing + the batch prediction validity assessment result. + It contains the following keys: + - errors: A list of errors that occurred during the batch prediction + validity assessment. + """ + if isinstance(config, dict): + config = types.AssessDatasetConfig(**config) + elif not config: + config = types.AssessDatasetConfig() + + operation = await self._assess_multimodal_dataset( + name=dataset_name, + batch_prediction_validation_assessment_config=types.BatchPredictionValidationAssessmentConfig( + model_name=model_name, + ), + gemini_request_read_config=types.GeminiRequestReadConfig( + template_config=template_config, + ), + config=config, + ) + response = await self._wait_for_operation( + operation=operation, + timeout_seconds=config.timeout, + ) + result = response["batchPredictionValidationAssessmentResult"] + return _datasets_utils.create_from_response( + types.BatchPredictionValidationAssessmentResult, result + ) diff --git a/vertexai/_genai/types/__init__.py b/vertexai/_genai/types/__init__.py index fa43a483dc..19c3b021bb 100644 --- a/vertexai/_genai/types/__init__.py +++ b/vertexai/_genai/types/__init__.py @@ -148,6 +148,9 @@ from .common import BatchPredictionValidationAssessmentConfig from .common import BatchPredictionValidationAssessmentConfigDict from .common import BatchPredictionValidationAssessmentConfigOrDict +from .common import BatchPredictionValidationAssessmentResult +from .common import BatchPredictionValidationAssessmentResultDict +from .common import BatchPredictionValidationAssessmentResultOrDict from .common import BigQueryRequestSet from .common import BigQueryRequestSetDict from .common import BigQueryRequestSetOrDict @@ -1908,6 +1911,9 @@ "BatchPredictionResourceUsageAssessmentResult", "BatchPredictionResourceUsageAssessmentResultDict", "BatchPredictionResourceUsageAssessmentResultOrDict", + 
"BatchPredictionValidationAssessmentResult", + "BatchPredictionValidationAssessmentResultDict", + "BatchPredictionValidationAssessmentResultOrDict", "TuningResourceUsageAssessmentResult", "TuningResourceUsageAssessmentResultDict", "TuningResourceUsageAssessmentResultOrDict", diff --git a/vertexai/_genai/types/common.py b/vertexai/_genai/types/common.py index d0cedc22e0..b56d563fdf 100644 --- a/vertexai/_genai/types/common.py +++ b/vertexai/_genai/types/common.py @@ -14250,6 +14250,27 @@ class BatchPredictionResourceUsageAssessmentResultDict(TypedDict, total=False): ] +class BatchPredictionValidationAssessmentResult(_common.BaseModel): + """Result of batch prediction validation assessment.""" + + errors: Optional[list[str]] = Field( + default=None, description="""The list of errors found in the dataset.""" + ) + + +class BatchPredictionValidationAssessmentResultDict(TypedDict, total=False): + """Result of batch prediction validation assessment.""" + + errors: Optional[list[str]] + """The list of errors found in the dataset.""" + + +BatchPredictionValidationAssessmentResultOrDict = Union[ + BatchPredictionValidationAssessmentResult, + BatchPredictionValidationAssessmentResultDict, +] + + class TuningResourceUsageAssessmentResult(_common.BaseModel): """Result of tuning resource usage assessment."""