diff --git a/tests/unit/vertexai/genai/replays/test_assess_multimodal_dataset.py b/tests/unit/vertexai/genai/replays/test_assess_multimodal_dataset.py
index 673a3b5b68..0db129943a 100644
--- a/tests/unit/vertexai/genai/replays/test_assess_multimodal_dataset.py
+++ b/tests/unit/vertexai/genai/replays/test_assess_multimodal_dataset.py
@@ -66,6 +66,29 @@ def test_assess_tuning_resources(client):
     assert isinstance(response, types.TuningResourceUsageAssessmentResult)
 
 
+def test_assess_tuning_validity(client):
+    response = client.datasets.assess_tuning_validity(
+        dataset_name=DATASET,
+        dataset_usage="SFT_VALIDATION",
+        model_name="gemini-2.5-flash-001",
+        template_config=types.GeminiTemplateConfig(
+            gemini_example=types.GeminiExample(
+                contents=[
+                    {
+                        "role": "user",
+                        "parts": [{"text": "What is the capital of {name}?"}],
+                    },
+                    {
+                        "role": "model",
+                        "parts": [{"text": "{capital}"}],
+                    },
+                ],
+            ),
+        ),
+    )
+    assert isinstance(response, types.TuningValidationAssessmentResult)
+
+
 pytestmark = pytest_helper.setup(
     file=__file__,
     globals_for_file=globals(),
@@ -88,7 +111,7 @@ async def test_assess_dataset_async(client):
                     {
                         "role": "user",
                         "parts": [{"text": "What is the capital of {name}?"}],
-                    }
+                    },
                 ],
             ),
         ),
@@ -114,3 +137,27 @@ async def test_assess_tuning_resources_async(client):
         ),
     )
     assert isinstance(response, types.TuningResourceUsageAssessmentResult)
+
+
+@pytest.mark.asyncio
+async def test_assess_tuning_validity_async(client):
+    response = await client.aio.datasets.assess_tuning_validity(
+        dataset_name=DATASET,
+        dataset_usage="SFT_VALIDATION",
+        model_name="gemini-2.5-flash-001",
+        template_config=types.GeminiTemplateConfig(
+            gemini_example=types.GeminiExample(
+                contents=[
+                    {
+                        "role": "user",
+                        "parts": [{"text": "What is the capital of {name}?"}],
+                    },
+                    {
+                        "role": "model",
+                        "parts": [{"text": "{capital}"}],
+                    },
+                ],
+            ),
+        ),
+    )
+    assert isinstance(response, types.TuningValidationAssessmentResult)
diff --git a/vertexai/_genai/datasets.py b/vertexai/_genai/datasets.py
index e92070704f..3da6db4709 100644
--- a/vertexai/_genai/datasets.py
+++ b/vertexai/_genai/datasets.py
@@ -1054,6 +1054,68 @@ def assess_tuning_resources(
             response["tuningResourceUsageAssessmentResult"],
         )
 
+    def assess_tuning_validity(
+        self,
+        *,
+        dataset_name: str,
+        model_name: str,
+        dataset_usage: str,
+        template_config: Optional[types.GeminiTemplateConfigOrDict] = None,
+        config: Optional[types.AssessDatasetConfigOrDict] = None,
+    ) -> types.TuningValidationAssessmentResult:
+        """Assesses whether the assembled dataset is valid for tuning a given
+        model.
+
+        Args:
+            dataset_name:
+                Required. The name of the dataset to assess the tuning
+                validity for.
+            model_name:
+                Required. The name of the model to assess the tuning validity
+                for.
+            dataset_usage:
+                Required. The dataset usage to assess the tuning validity for.
+                Must be one of the following: SFT_TRAINING, SFT_VALIDATION.
+            template_config:
+                Optional. The template config used to assemble the dataset
+                before assessing the tuning validity. If not provided, the
+                template config attached to the dataset will be used. Required
+                if no template config is attached to the dataset.
+            config:
+                Optional. A configuration for assessing the tuning validity.
+                If not provided, the default configuration will be used.
+
+        Returns:
+            A TuningValidationAssessmentResult containing the outcome of the
+            assessment, including:
+            - errors: A list of errors that occurred during the tuning
+              validity assessment.
+        """
+        if isinstance(config, dict):
+            config = types.AssessDatasetConfig(**config)
+        elif not config:
+            config = types.AssessDatasetConfig()
+
+        operation = self._assess_multimodal_dataset(
+            name=dataset_name,
+            tuning_validation_assessment_config=types.TuningValidationAssessmentConfig(
+                model_name=model_name,
+                dataset_usage=dataset_usage,
+            ),
+            gemini_request_read_config=types.GeminiRequestReadConfig(
+                template_config=template_config,
+            ),
+            config=config,
+        )
+        response = self._wait_for_operation(
+            operation=operation,
+            timeout_seconds=config.timeout,
+        )
+        return _datasets_utils.create_from_response(
+            types.TuningValidationAssessmentResult,
+            response["tuningValidationAssessmentResult"],
+        )
+
 
 class AsyncDatasets(_api_module.BaseModule):
 
@@ -1875,3 +1937,65 @@ async def assess_tuning_resources(
             types.TuningResourceUsageAssessmentResult,
             response["tuningResourceUsageAssessmentResult"],
         )
+
+    async def assess_tuning_validity(
+        self,
+        *,
+        dataset_name: str,
+        model_name: str,
+        dataset_usage: str,
+        template_config: Optional[types.GeminiTemplateConfigOrDict] = None,
+        config: Optional[types.AssessDatasetConfigOrDict] = None,
+    ) -> types.TuningValidationAssessmentResult:
+        """Assesses whether the assembled dataset is valid for tuning a given
+        model.
+
+        Args:
+            dataset_name:
+                Required. The name of the dataset to assess the tuning
+                validity for.
+            model_name:
+                Required. The name of the model to assess the tuning validity
+                for.
+            dataset_usage:
+                Required. The dataset usage to assess the tuning validity for.
+                Must be one of the following: SFT_TRAINING, SFT_VALIDATION.
+            template_config:
+                Optional. The template config used to assemble the dataset
+                before assessing the tuning validity. If not provided, the
+                template config attached to the dataset will be used. Required
+                if no template config is attached to the dataset.
+            config:
+                Optional. A configuration for assessing the tuning validity.
+                If not provided, the default configuration will be used.
+
+        Returns:
+            A TuningValidationAssessmentResult containing the outcome of the
+            assessment, including:
+            - errors: A list of errors that occurred during the tuning
+              validity assessment.
+        """
+        if isinstance(config, dict):
+            config = types.AssessDatasetConfig(**config)
+        elif not config:
+            config = types.AssessDatasetConfig()
+
+        operation = await self._assess_multimodal_dataset(
+            name=dataset_name,
+            tuning_validation_assessment_config=types.TuningValidationAssessmentConfig(
+                model_name=model_name,
+                dataset_usage=dataset_usage,
+            ),
+            gemini_request_read_config=types.GeminiRequestReadConfig(
+                template_config=template_config,
+            ),
+            config=config,
+        )
+        response = await self._wait_for_operation(
+            operation=operation,
+            timeout_seconds=config.timeout,
+        )
+        return _datasets_utils.create_from_response(
+            types.TuningValidationAssessmentResult,
+            response["tuningValidationAssessmentResult"],
+        )