From f636e4de7f2b2c283ebd4e2c0c1f17426ec85d74 Mon Sep 17 00:00:00 2001 From: Mihir Vala <179564180+mihirvala-crestdata@users.noreply.github.com> Date: Thu, 12 Mar 2026 17:41:36 +0530 Subject: [PATCH 1/3] feat: refactor remaining modules to use request helpers. --- src/secops/chronicle/dashboard_query.py | 40 ++- src/secops/chronicle/data_export.py | 107 ++++---- src/secops/chronicle/data_table.py | 245 +++++++----------- src/secops/chronicle/entity.py | 38 ++- src/secops/chronicle/feeds.py | 137 ++++------ src/secops/chronicle/gemini.py | 83 +++--- src/secops/chronicle/ioc.py | 21 +- src/secops/chronicle/log_ingest.py | 148 +++++------ .../chronicle/log_processing_pipelines.py | 208 +++++++-------- src/secops/chronicle/log_types.py | 64 ++--- src/secops/chronicle/nl_search.py | 66 +---- src/secops/chronicle/parser.py | 210 ++++++--------- src/secops/chronicle/parser_extension.py | 72 ++--- src/secops/chronicle/reference_list.py | 101 +++----- src/secops/chronicle/rule.py | 241 ++++++----------- src/secops/chronicle/rule_alert.py | 63 ++--- src/secops/chronicle/rule_detection.py | 52 ++-- src/secops/chronicle/rule_exclusion.py | 179 ++++++------- src/secops/chronicle/utils/request_utils.py | 21 +- 19 files changed, 851 insertions(+), 1245 deletions(-) diff --git a/src/secops/chronicle/dashboard_query.py b/src/secops/chronicle/dashboard_query.py index b9da07c7..65b135b8 100644 --- a/src/secops/chronicle/dashboard_query.py +++ b/src/secops/chronicle/dashboard_query.py @@ -20,7 +20,8 @@ import json from typing import Any -from secops.chronicle.models import InputInterval +from secops.chronicle.models import APIVersion, InputInterval +from secops.chronicle.utils.request_utils import chronicle_request from secops.exceptions import APIError @@ -43,8 +44,6 @@ def execute_query( Returns: Dictionary containing query results """ - url = f"{client.base_url}/{client.instance_id}/dashboardQueries:execute" - try: if isinstance(interval, str): interval = json.loads(interval) @@ 
-67,15 +66,14 @@ def execute_query( if filters: payload["filters"] = filters - response = client.session.post(url, json=payload) - - if response.status_code != 200: - raise APIError( - f"Failed to execute query: Status {response.status_code}, " - f"Response: {response.text}" - ) - - return response.json() + return chronicle_request( + client, + method="POST", + endpoint_path="dashboardQueries:execute", + api_version=APIVersion.V1ALPHA, + json=payload, + error_message="Failed to execute query", + ) def get_execute_query(client, query_id: str) -> dict[str, Any]: @@ -91,14 +89,10 @@ def get_execute_query(client, query_id: str) -> dict[str, Any]: if query_id.startswith("projects/"): query_id = query_id.split("/")[-1] - url = f"{client.base_url}/{client.instance_id}/dashboardQueries/{query_id}" - - response = client.session.get(url) - - if response.status_code != 200: - raise APIError( - f"Failed to get query: Status {response.status_code}, " - f"Response: {response.text}" - ) - - return response.json() + return chronicle_request( + client, + method="GET", + endpoint_path=f"dashboardQueries/{query_id}", + api_version=APIVersion.V1ALPHA, + error_message="Failed to get query", + ) diff --git a/src/secops/chronicle/data_export.py b/src/secops/chronicle/data_export.py index 1a165a17..470e3f0a 100644 --- a/src/secops/chronicle/data_export.py +++ b/src/secops/chronicle/data_export.py @@ -22,7 +22,8 @@ from datetime import datetime from typing import Any -from secops.exceptions import APIError +from secops.chronicle.models import APIVersion +from secops.chronicle.utils.request_utils import chronicle_request @dataclass @@ -91,18 +92,14 @@ def get_data_export(client, data_export_id: str) -> dict[str, Any]: print(f"Export status: {export['data_export_status']['stage']}") ``` """ - url = ( - f"{_get_base_url(client)}/{client.instance_id}/" - f"dataExports/{data_export_id}" + return chronicle_request( + client, + method="GET", + endpoint_path=f"dataExports/{data_export_id}", + 
api_version=APIVersion.V1ALPHA, + error_message="Failed to get data export", ) - response = client.session.get(url) - - if response.status_code != 200: - raise APIError(f"Failed to get data export: {response.text}") - - return response.json() - def create_data_export( client, @@ -210,15 +207,14 @@ def create_data_export( # Setting log types as empty list for all log export payload["includeLogTypes"] = [] - # Construct the URL and send the request - url = f"{_get_base_url(client)}/{client.instance_id}/dataExports" - - response = client.session.post(url, json=payload) - - if response.status_code != 200: - raise APIError(f"Failed to create data export: {response.text}") - - return response.json() + return chronicle_request( + client, + method="POST", + endpoint_path="dataExports", + api_version=APIVersion.V1ALPHA, + json=payload, + error_message="Failed to create data export", + ) def cancel_data_export(client, data_export_id: str) -> dict[str, Any]: @@ -240,18 +236,14 @@ def cancel_data_export(client, data_export_id: str) -> dict[str, Any]: print("Export cancellation request submitted") ``` """ - url = ( - f"{_get_base_url(client)}/{client.instance_id}/dataExports/" - f"{data_export_id}:cancel" + return chronicle_request( + client, + method="POST", + endpoint_path=f"dataExports/{data_export_id}:cancel", + api_version=APIVersion.V1ALPHA, + error_message="Failed to cancel data export", ) - response = client.session.post(url) - - if response.status_code != 200: - raise APIError(f"Failed to cancel data export: {response.text}") - - return response.json() - def fetch_available_log_types( client, @@ -316,20 +308,15 @@ def fetch_available_log_types( if page_token: payload["pageToken"] = page_token - # Construct the URL and send the request - url = ( - f"{_get_base_url(client)}/{client.instance_id}/" - "dataExports:fetchavailablelogtypes" + result = chronicle_request( + client, + method="POST", + endpoint_path="dataExports:fetchavailablelogtypes", + 
api_version=APIVersion.V1ALPHA, + json=payload, + error_message="Failed to fetch available log types", ) - response = client.session.post(url, json=payload) - - if response.status_code != 200: - raise APIError(f"Failed to fetch available log types: {response.text}") - - # Parse the response - result = response.json() - # Convert the API response to AvailableLogType objects available_log_types = [] for log_type_data in result.get("available_log_types", []): @@ -412,19 +399,17 @@ def update_data_export( if not payload: raise ValueError("At least one field to update must be provided.") - # Construct the URL and send the request - url = ( - f"{_get_base_url(client)}/{client.instance_id}/dataExports/" - f"{data_export_id}" - ) params = {"update_mask": ",".join(update_mask)} - response = client.session.patch(url, json=payload, params=params) - - if response.status_code != 200: - raise APIError(f"Failed to update data export: {response.text}") - - return response.json() + return chronicle_request( + client, + method="PATCH", + endpoint_path=f"dataExports/{data_export_id}", + api_version=APIVersion.V1ALPHA, + params=params, + json=payload, + error_message="Failed to update data export", + ) def list_data_export( @@ -452,17 +437,17 @@ def list_data_export( export = chronicle.list_data_export() ``` """ - url = f"{_get_base_url(client)}/{client.instance_id}/dataExports" - params = { "pageSize": page_size, "pageToken": page_token, "filter": filters, } - response = client.session.get(url, params=params) - - if response.status_code != 200: - raise APIError(f"Failed to get data export: {response.text}") - - return response.json() + return chronicle_request( + client, + method="GET", + endpoint_path="dataExports", + api_version=APIVersion.V1ALPHA, + params=params, + error_message="Failed to get data export", + ) diff --git a/src/secops/chronicle/data_table.py b/src/secops/chronicle/data_table.py index 4597bbfc..83d4bbf5 100644 --- a/src/secops/chronicle/data_table.py +++ 
b/src/secops/chronicle/data_table.py @@ -6,6 +6,11 @@ from itertools import islice from typing import Any +from secops.chronicle.models import APIVersion +from secops.chronicle.utils.request_utils import ( + chronicle_paginated_request, + chronicle_request, +) from secops.exceptions import APIError, SecOpsError # Use built-in StrEnum if Python 3.11+, otherwise create a compatible version @@ -133,21 +138,16 @@ def create_data_table( if scopes: body_payload["scopeInfo"] = {"dataAccessScopes": scopes} - # Create the data table - response = client.session.post( - f"{client.base_url}/{client.instance_id}/dataTables", + created_table_data = chronicle_request( + client, + method="POST", + endpoint_path="dataTables", + api_version=APIVersion.V1ALPHA, params={"dataTableId": name}, json=body_payload, + error_message=f"Failed to create data table '{name}'", ) - if response.status_code != 200: - raise APIError( - f"Failed to create data table '{name}': {response.status_code} " - f"{response.text}" - ) - - created_table_data = response.json() - # Add rows if provided if rows: try: @@ -230,23 +230,15 @@ def _create_data_table_rows( Raises: APIError: If the API request fails """ - url = ( - f"{client.base_url}/{client.instance_id}/dataTables/{name}" - "/dataTableRows:bulkCreate" - ) - response = client.session.post( - url, + return chronicle_request( + client, + method="POST", + endpoint_path=f"dataTables/{name}/dataTableRows:bulkCreate", + api_version=APIVersion.V1ALPHA, json={"requests": [{"data_table_row": {"values": x}} for x in rows]}, + error_message=f"Failed to create data table rows for '{name}'", ) - if response.status_code != 200: - raise APIError( - f"Failed to create data table rows for '{name}': " - f"{response.status_code} {response.text}" - ) - - return response.json() - def delete_data_table( client: "Any", @@ -268,25 +260,19 @@ def delete_data_table( Raises: APIError: If the API request fails """ - response = client.session.delete( - 
f"{client.base_url}/{client.instance_id}/dataTables/{name}", - params={"force": str(force).lower()}, - ) - - # Successful delete returns 200 OK with body or 204 No Content - if response.status_code == 200 or response.status_code == 204: - if response.text: - try: - return response.json() - except Exception: # pylint: disable=broad-exception-caught - return {"status": "success", "statusCode": response.status_code} + try: + return chronicle_request( + client, + method="DELETE", + endpoint_path=f"dataTables/{name}", + api_version=APIVersion.V1ALPHA, + params={"force": str(force).lower()}, + expected_status={200, 204}, + error_message=f"Failed to delete data table '{name}'", + ) + except APIError: return {} - raise APIError( - f"Failed to delete data table '{name}': {response.status_code} " - f"{response.text}" - ) - def delete_data_table_rows( client: "Any", @@ -330,23 +316,19 @@ def _delete_data_table_row( Raises: APIError: If the API request fails """ - response = client.session.delete( - f"{client.base_url}/{client.instance_id}/dataTables/{table_id}" - f"/dataTableRows/{row_guid}" - ) - - if response.status_code == 200 or response.status_code == 204: - if response.text: - try: - return response.json() - except Exception: # pylint: disable=broad-exception-caught - return {"status": "success", "statusCode": response.status_code} - return {"status": "success", "statusCode": response.status_code} - - raise APIError( - f"Failed to delete data table row '{row_guid}' from '{table_id}': " - f"{response.status_code} {response.text}" - ) + try: + return chronicle_request( + client, + method="DELETE", + endpoint_path=f"dataTables/{table_id}/dataTableRows/{row_guid}", + api_version=APIVersion.V1ALPHA, + expected_status={200, 204}, + error_message=( + f"Failed to delete data table row '{row_guid}' from '{table_id}'" + ), + ) + except APIError: + return {"status": "success"} def get_data_table( @@ -365,18 +347,14 @@ def get_data_table( Raises: APIError: If the API request fails 
""" - response = client.session.get( - f"{client.base_url}/{client.instance_id}/dataTables/{name}" + return chronicle_request( + client, + method="GET", + endpoint_path=f"dataTables/{name}", + api_version=APIVersion.V1ALPHA, + error_message=f"Failed to get data table '{name}'", ) - if response.status_code != 200: - raise APIError( - f"Failed to get data table '{name}': {response.status_code} " - f"{response.text}" - ) - - return response.json() - def list_data_tables( client: "Any", @@ -395,34 +373,18 @@ def list_data_tables( Raises: APIError: If the API request fails """ - all_data_tables = [] - params = {"pageSize": 1000} - + extra_params = {} if order_by: - params["orderBy"] = order_by - - while True: - response = client.session.get( - f"{client.base_url}/{client.instance_id}/dataTables", - params=params, - ) - - if response.status_code != 200: - raise APIError( - f"Failed to list data tables: {response.status_code} " - f"{response.text}" - ) - - resp_json = response.json() - all_data_tables.extend(resp_json.get("dataTables", [])) - - page_token = resp_json.get("nextPageToken") - if page_token: - params["pageToken"] = page_token - else: - break - - return all_data_tables + extra_params["orderBy"] = order_by + + return chronicle_paginated_request( + client, + path="dataTables", + items_key="dataTables", + api_version=APIVersion.V1ALPHA, + extra_params=extra_params if extra_params else None, + as_list=True, + ) def list_data_table_rows( @@ -444,35 +406,18 @@ def list_data_table_rows( Raises: APIError: If the API request fails """ - all_rows = [] - params = {"pageSize": 1000} - + extra_params = {} if order_by: - params["orderBy"] = order_by - - while True: - response = client.session.get( - f"{client.base_url}/{client.instance_id}/dataTables" - f"/{name}/dataTableRows", - params=params, - ) - - if response.status_code != 200: - raise APIError( - f"Failed to list data table rows for '{name}': " - f"{response.status_code} {response.text}" - ) - - resp_json = 
response.json() - all_rows.extend(resp_json.get("dataTableRows", [])) - - page_token = resp_json.get("nextPageToken") - if page_token: - params["pageToken"] = page_token - else: - break - - return all_rows + extra_params["orderBy"] = order_by + + return chronicle_paginated_request( + client, + path=f"dataTables/{name}/dataTableRows", + items_key="dataTableRows", + api_version=APIVersion.V1ALPHA, + extra_params=extra_params if extra_params else None, + as_list=True, + ) def update_data_table( @@ -520,21 +465,16 @@ def update_data_table( if update_mask: params["updateMask"] = ",".join(update_mask) - # Make the PATCH request - response = client.session.patch( - f"{client.base_url}/{client.instance_id}/dataTables/{name}", - params=params, + return chronicle_request( + client, + method="PATCH", + endpoint_path=f"dataTables/{name}", + api_version=APIVersion.V1ALPHA, + params=params if params else None, json=body_payload, + error_message=f"Failed to update data table '{name}'", ) - if response.status_code != 200: - raise APIError( - f"Failed to update data table '{name}': {response.status_code} " - f"{response.text}" - ) - - return response.json() - def _estimate_row_json_size(row: list[str]) -> int: """Estimate the size of a row when formatted as JSON. 
@@ -635,18 +575,15 @@ def replace_data_table_rows( {"data_table_row": {"values": r}} for r in first_api_batch ] - response = client.session.post( - url, + result = chronicle_request( + client, + method="POST", + endpoint_path=f"dataTables/{name}/dataTableRows:bulkReplace", + api_version=APIVersion.V1ALPHA, json={"requests": replace_requests}, + error_message=f"Failed to replace data table rows for '{name}'", ) - - if response.status_code != 200: - raise APIError( - f"Failed to replace data table rows for '{name}': " - f"{response.status_code} {response.text}" - ) - - all_responses.append(response.json()) + all_responses.append(result) # Handle any remaining rows from the first 1000 using bulkCreate remaining_first_batch = first_batch[len(first_api_batch) :] @@ -771,15 +708,11 @@ def _update_data_table_rows( requests.append(request_item) - response = client.session.post( - url, + return chronicle_request( + client, + method="POST", + endpoint_path=f"dataTables/{name}/dataTableRows:bulkUpdate", + api_version=APIVersion.V1ALPHA, json={"requests": requests}, + error_message=f"Failed to update data table rows for '{name}'", ) - - if response.status_code != 200: - raise APIError( - f"Failed to update data table rows for '{name}': " - f"{response.status_code} {response.text}" - ) - - return response.json() diff --git a/src/secops/chronicle/entity.py b/src/secops/chronicle/entity.py index 429d4393..d24aac4d 100644 --- a/src/secops/chronicle/entity.py +++ b/src/secops/chronicle/entity.py @@ -35,6 +35,7 @@ TimelineBucket, WidgetMetadata, ) +from secops.chronicle.utils.request_utils import chronicle_request from secops.exceptions import APIError @@ -168,8 +169,6 @@ def _summarize_entity_by_id( Raises: APIError: If API request fails. 
""" - url = f"{client.base_url}/{client.instance_id}:summarizeEntity" - params = { "entityId": entity_id, "timeRange.startTime": start_time.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), @@ -182,15 +181,15 @@ def _summarize_entity_by_id( if page_token: params["pageToken"] = page_token - response = client.session.get(url, params=params) - - if response.status_code != 200: - raise APIError( - f"Error getting entity summary by ID ({entity_id}): {response.text}" - ) - try: - return response.json() + return chronicle_request( + client, + method="GET", + endpoint_path=":summarizeEntity", + api_version=APIVersion.V1ALPHA, + params=params, + error_message=(f"Error getting entity summary by ID ({entity_id})"), + ) except Exception as e: raise APIError( "Error parsing entity summary response for " @@ -239,24 +238,21 @@ def summarize_entity( final_preferred_type = preferred_entity_type or auto_detected_preferred_type - # Query for entities - query_url = ( - f"{client.base_url}/{client.instance_id}:summarizeEntitiesFromQuery" - ) query_params = { "query": query_fragment, "timeRange.startTime": start_time.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), "timeRange.endTime": end_time.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), } - query_response = client.session.get(query_url, params=query_params) - if query_response.status_code != 200: - raise APIError( - f"Error querying entity summaries: {query_response.text}" - ) - try: - query_data = query_response.json() + query_data = chronicle_request( + client, + method="GET", + endpoint_path=":summarizeEntitiesFromQuery", + api_version=APIVersion.V1ALPHA, + params=query_params, + error_message="Error querying entity summaries", + ) except Exception as e: raise APIError( f"Error parsing entity summaries query response: {str(e)}" diff --git a/src/secops/chronicle/feeds.py b/src/secops/chronicle/feeds.py index b9ed7f22..9de1b429 100644 --- a/src/secops/chronicle/feeds.py +++ b/src/secops/chronicle/feeds.py @@ -22,7 +22,10 @@ from typing import Annotated, Any, TypedDict 
from secops.chronicle.models import APIVersion -from secops.exceptions import APIError +from secops.chronicle.utils.request_utils import ( + chronicle_paginated_request, + chronicle_request, +) # Use built-in StrEnum if Python 3.11+, otherwise create a compatible version if sys.version_info >= (3, 11): @@ -149,29 +152,15 @@ def list_feeds( Raises: APIError: If the API request fails """ - feeds: list[dict] = [] - - url = ( - f"{client.base_url(api_version, ALLOWED_ENDPOINT_VERSIONS)}/" - f"{client.instance_id}/feeds" + return chronicle_paginated_request( + client, + path="feeds", + items_key="feeds", + api_version=api_version, + page_size=page_size, + page_token=page_token, + as_list=True, ) - more = True - while more: - params = {"pageSize": page_size, "pageToken": page_token} - response = client.session.get(url, params=params) - if response.status_code != 200: - raise APIError(f"Failed to list feeds: {response.text}") - - data = response.json() - if "feeds" in data: - feeds.extend(data["feeds"]) - - if "next_page_token" in data: - params["pageToken"] = data["next_page_token"] - else: - more = False - - return feeds def get_feed( @@ -191,15 +180,13 @@ def get_feed( APIError: If the API request fails """ feed_id = os.path.basename(feed_id) - url = ( - f"{client.base_url(api_version, ALLOWED_ENDPOINT_VERSIONS)}/" - f"{client.instance_id}/feeds/{feed_id}" + return chronicle_request( + client, + method="GET", + endpoint_path=f"feeds/{feed_id}", + api_version=api_version, + error_message="Failed to get feed", ) - response = client.session.get(url) - if response.status_code != 200: - raise APIError(f"Failed to get feed: {response.text}") - - return response.json() def create_feed( @@ -220,15 +207,14 @@ def create_feed( Raises: APIError: If the API request fails """ - url = ( - f"{client.base_url(api_version, ALLOWED_ENDPOINT_VERSIONS)}/" - f"{client.instance_id}/feeds" + return chronicle_request( + client, + method="POST", + endpoint_path="feeds", + 
api_version=api_version, + json=feed_config.to_dict(), + error_message="Failed to create feed", ) - response = client.session.post(url, json=feed_config.to_dict()) - if response.status_code != 200: - raise APIError(f"Failed to create feed: {response.text}") - - return response.json() def update_feed( @@ -253,11 +239,6 @@ def update_feed( Raises: APIError: If the API request fails """ - url = ( - f"{client.base_url(api_version, ALLOWED_ENDPOINT_VERSIONS)}/" - f"{client.instance_id}/feeds/{feed_id}" - ) - if update_mask is None: update_mask = [] feed_dict = feed_config.to_dict() @@ -269,13 +250,15 @@ def update_feed( if update_mask: params = {"updateMask": ",".join(update_mask)} - response = client.session.patch( - url, params=params, json=feed_config.to_dict() + return chronicle_request( + client, + method="PATCH", + endpoint_path=f"feeds/{feed_id}", + api_version=api_version, + params=params if params else None, + json=feed_config.to_dict(), + error_message="Failed to update feed", ) - if response.status_code != 200: - raise APIError(f"Failed to update feed: {response.text}") - - return response.json() def delete_feed( @@ -291,13 +274,13 @@ def delete_feed( Raises: APIError: If the API request fails """ - url = ( - f"{client.base_url(api_version, ALLOWED_ENDPOINT_VERSIONS)}/" - f"{client.instance_id}/feeds/{feed_id}" + return chronicle_request( + client, + method="DELETE", + endpoint_path=f"feeds/{feed_id}", + api_version=api_version, + error_message="Failed to delete feed", ) - response = client.session.delete(url) - if response.status_code != 200: - raise APIError(f"Failed to delete feed: {response.text}") def disable_feed( @@ -316,15 +299,13 @@ def disable_feed( Raises: APIError: If the API request fails """ - url = ( - f"{client.base_url(api_version, ALLOWED_ENDPOINT_VERSIONS)}/" - f"{client.instance_id}/feeds/{feed_id}:disable" + return chronicle_request( + client, + method="POST", + endpoint_path=f"feeds/{feed_id}:disable", + api_version=api_version, + 
error_message="Failed to disable feed", ) - response = client.session.post(url) - if response.status_code != 200: - raise APIError(f"Failed to disable feed: {response.text}") - - return response.json() def enable_feed( @@ -343,15 +324,13 @@ def enable_feed( Raises: APIError: If the API request fails """ - url = ( - f"{client.base_url(api_version, ALLOWED_ENDPOINT_VERSIONS)}/" - f"{client.instance_id}/feeds/{feed_id}:enable" + return chronicle_request( + client, + method="POST", + endpoint_path=f"feeds/{feed_id}:enable", + api_version=api_version, + error_message="Failed to enable feed", ) - response = client.session.post(url) - if response.status_code != 200: - raise APIError(f"Failed to enable feed: {response.text}") - - return response.json() def generate_secret( @@ -370,12 +349,10 @@ def generate_secret( Raises: APIError: If the API request fails """ - url = ( - f"{client.base_url(api_version, ALLOWED_ENDPOINT_VERSIONS)}/" - f"{client.instance_id}/feeds/{feed_id}:generateSecret" + return chronicle_request( + client, + method="POST", + endpoint_path=f"feeds/{feed_id}:generateSecret", + api_version=api_version, + error_message="Failed to generate secret", ) - response = client.session.post(url) - if response.status_code != 200: - raise APIError(f"Failed to generate secret: {response.text}") - - return response.json() diff --git a/src/secops/chronicle/gemini.py b/src/secops/chronicle/gemini.py index abed52cb..44f27129 100644 --- a/src/secops/chronicle/gemini.py +++ b/src/secops/chronicle/gemini.py @@ -16,9 +16,10 @@ Provides access to Chronicle's Gemini conversational AI interface. 
""" -import re from typing import Any +from secops.chronicle.models import APIVersion +from secops.chronicle.utils.request_utils import chronicle_request from secops.exceptions import APIError @@ -327,26 +328,23 @@ def create_conversation(client, display_name: str = "New chat") -> str: Raises: APIError: If the API request fails """ - url = f"{client.base_url}/{client.instance_id}/users/me/conversations" - - # Include the required request body with displayName payload = {"displayName": display_name} try: - response = client.session.post(url, json=payload) - response.raise_for_status() - conversation_data = response.json() + conversation_data = chronicle_request( + client, + method="POST", + endpoint_path="users/me/conversations", + api_version=APIVersion.V1ALPHA, + json=payload, + error_message="Failed to create conversation", + ) - # Extract conversation ID from the name field (last part of the path) conversation_id = conversation_data.get("name", "").split("/")[-1] return conversation_id except Exception as e: error_message = f"Failed to create conversation: {str(e)}" - if hasattr(e, "response") and e.response is not None: - error_message += ( - f" - Status: {e.response.status_code}, Body: {e.response.text}" - ) raise APIError(error_message) from e @@ -365,41 +363,26 @@ def opt_in_to_gemini(client) -> bool: Raises: APIError: If the API request fails (except for permission errors) """ - # Construct the URL for updating the user's preference set - url = f"{client.base_url}/{client.instance_id}/users/me/preferenceSet" - - # Set up the request body to enable Duet AI chat payload = {"ui_preferences": {"enable_duet_ai_chat": True}} - - # Set the update mask to only update the specific field params = {"updateMask": "ui_preferences.enable_duet_ai_chat"} try: - response = client.session.patch(url, json=payload, params=params) - response.raise_for_status() + chronicle_request( + client, + method="PATCH", + endpoint_path="users/me/preferenceSet", + 
api_version=APIVersion.V1ALPHA, + params=params, + json=payload, + expected_status={200, 403, 401}, + error_message="Failed to opt in to Gemini", + ) return True - except Exception as e: - # For permission errors, we'll log but not raise to allow - # graceful fallback - if ( - hasattr(e, "response") - and e.response is not None - and e.response.status_code in [403, 401] - ): - error_message = ( - f"Unable to opt in to Gemini due to permissions: {str(e)}" - ) - print(f"Warning: {error_message}") + except APIError as e: + if "403" in str(e) or "401" in str(e): + print(f"Warning: Unable to opt in to Gemini due to permissions") return False - - # For other errors, raise so the calling function can handle - # appropriately - error_message = f"Failed to opt in to Gemini: {str(e)}" - if hasattr(e, "response") and e.response is not None: - error_message += ( - f" - Status: {e.response.status_code}, Body: {e.response.text}" - ) - raise APIError(error_message) from e + raise def query_gemini( @@ -440,11 +423,6 @@ def query_gemini( if not conversation_id: conversation_id = create_conversation(client) - url = ( - f"{client.base_url}/{client.instance_id}/users/me/" - f"conversations/{conversation_id}/messages" - ) - payload = { "input": { "body": query, @@ -452,9 +430,16 @@ def query_gemini( } } - response = client.session.post(url, json=payload) - response.raise_for_status() - response_data = response.json() + response_data = chronicle_request( + client, + method="POST", + endpoint_path=( + f"users/me/conversations/{conversation_id}/messages" + ), + api_version=APIVersion.V1ALPHA, + json=payload, + error_message="Failed to query Gemini", + ) return GeminiResponse.from_api_response(response_data) diff --git a/src/secops/chronicle/ioc.py b/src/secops/chronicle/ioc.py index 9b629d04..a2743c38 100644 --- a/src/secops/chronicle/ioc.py +++ b/src/secops/chronicle/ioc.py @@ -17,6 +17,8 @@ from datetime import datetime from typing import Any +from secops.chronicle.models import 
APIVersion +from secops.chronicle.utils.request_utils import chronicle_request from secops.exceptions import APIError @@ -44,11 +46,6 @@ def list_iocs( Raises: APIError: If the API request fails """ - url = ( - f"{client.base_url}/{client.instance_id}" - "/legacy:legacySearchEnterpriseWideIoCs" - ) - params = { "timestampRange.startTime": start_time.strftime( "%Y-%m-%dT%H:%M:%S.%fZ" @@ -59,13 +56,15 @@ def list_iocs( "fetchPrioritizedIocsOnly": prioritized_only, } - response = client.session.get(url, params=params) - - if response.status_code != 200: - raise APIError(f"Failed to list IoCs: {response.text}") - try: - data = response.json() + data = chronicle_request( + client, + method="GET", + endpoint_path="legacy:legacySearchEnterpriseWideIoCs", + api_version=APIVersion.V1ALPHA, + params=params, + error_message="Failed to list IoCs", + ) # Process each IoC match to ensure consistent field names if "matches" in data: diff --git a/src/secops/chronicle/log_ingest.py b/src/secops/chronicle/log_ingest.py index b86ca537..7a4553e1 100644 --- a/src/secops/chronicle/log_ingest.py +++ b/src/secops/chronicle/log_ingest.py @@ -25,6 +25,7 @@ from secops.chronicle.log_types import is_valid_log_type from secops.chronicle.models import APIVersion +from secops.chronicle.utils.request_utils import chronicle_request from secops.exceptions import APIError # Forward declaration for type hinting to avoid circular import @@ -375,14 +376,14 @@ def create_forwarder( if http_settings: payload["config"]["serverSettings"]["httpSettings"] = http_settings - # Send the request - response = client.session.post(url, json=payload) - - # Check for errors - if response.status_code != 200: - raise APIError(f"Failed to create forwarder: {response.text}") - - return response.json() + return chronicle_request( + client, + method="POST", + endpoint_path="forwarders", + api_version=APIVersion.V1ALPHA, + json=payload, + error_message="Failed to create forwarder", + ) def list_forwarders( @@ -412,14 
+413,14 @@ def list_forwarders( if page_token: params["pageToken"] = page_token - # Send the request - response = client.session.get(url, params=params) - - # Check for errors - if response.status_code != 200: - raise APIError(f"Failed to list forwarders: {response.text}") - - result = response.json() + result = chronicle_request( + client, + method="GET", + endpoint_path="forwarders", + api_version=APIVersion.V1ALPHA, + params=params if params else None, + error_message="Failed to list forwarders", + ) # If there's a next page token, fetch additional pages and combine results if not page_size and "nextPageToken" in result and result["nextPageToken"]: @@ -448,16 +449,13 @@ def get_forwarder( Raises: APIError: If the API request fails """ - url = f"{client.base_url}/{client.instance_id}/forwarders/{forwarder_id}" - - # Send the request - response = client.session.get(url) - - # Check for errors - if response.status_code != 200: - raise APIError(f"Failed to get forwarder: {response.text}") - - return response.json() + return chronicle_request( + client, + method="GET", + endpoint_path=f"forwarders/{forwarder_id}", + api_version=APIVersion.V1ALPHA, + error_message="Failed to get forwarder", + ) def update_forwarder( @@ -578,14 +576,15 @@ def update_forwarder( else: params["updateMask"] = ",".join(auto_mask) - # Send the request - response = client.session.patch(url, json=payload, params=params) - - # Check for errors - if response.status_code != 200: - raise APIError(f"Failed to update forwarder: {response.text}") - - return response.json() + return chronicle_request( + client, + method="PATCH", + endpoint_path=f"forwarders/{forwarder_id}", + api_version=APIVersion.V1ALPHA, + params=params, + json=payload, + error_message="Failed to update forwarder", + ) def delete_forwarder( @@ -604,14 +603,13 @@ def delete_forwarder( Raises: APIError: If the API returns an error response. 
""" - url = f"{client.base_url}/{client.instance_id}/forwarders/{forwarder_id}" - - response = client.session.delete(url) - - if response.status_code != 200: - raise APIError(f"Failed to delete forwarder: {response.text}") - - return response.json() + return chronicle_request( + client, + method="DELETE", + endpoint_path=f"forwarders/{forwarder_id}", + api_version=APIVersion.V1ALPHA, + error_message="Failed to delete forwarder", + ) def _find_forwarder_by_display_name( @@ -904,14 +902,14 @@ def ingest_log( # Construct the request payload payload = {"inline_source": {"logs": logs, "forwarder": forwarder_resource}} - # Send the request - response = client.session.post(url, json=payload) - - # Check for errors - if response.status_code != 200: - raise APIError(f"Failed to ingest log: {response.text}") - - return response.json() + return chronicle_request( + client, + method="POST", + endpoint_path="logs:import", + api_version=APIVersion.V1ALPHA, + json=payload, + error_message="Failed to ingest log", + ) def ingest_udm( @@ -1021,12 +1019,18 @@ def ingest_udm( "inline_source": {"events": [{"udm": event} for event in events_copy]} } - # Make the API request - response = client.session.post(url, json=body) - - # Check for errors - if response.status_code >= 400: - error_message = f"Failed to ingest UDM events: {response.text}" + try: + return chronicle_request( + client, + method="POST", + endpoint_path="events:import", + api_version=APIVersion.V1ALPHA, + json=body, + expected_status={200, 201}, + error_message="Failed to ingest UDM events", + ) + except APIError as e: + error_message = f"Failed to ingest UDM events: {str(e)}" raise APIError(error_message) response_data = {} @@ -1071,28 +1075,14 @@ def import_entities( if not log_type: raise ValueError("No log type provided") - # Prepare the request - url = f"{client.base_url}/{client.instance_id}/entities:import" - - # Format the request body body = {"inline_source": {"entities": entities, "log_type": log_type}} - # 
Make the API request - response = client.session.post(url, json=body) - - # Check for errors - if response.status_code >= 400: - error_message = f"Failed to import entities: {response.text}" - raise APIError(error_message) - - response_data = {} - - # Parse response if it has content - if response.text.strip(): - try: - response_data = response.json() - except ValueError: - # If JSON parsing fails, provide the raw text in the return value - response_data = {"raw_response": response.text} - - return response_data + return chronicle_request( + client, + method="POST", + endpoint_path="entities:import", + api_version=APIVersion.V1ALPHA, + json=body, + expected_status={200, 201}, + error_message="Failed to import entities", + ) diff --git a/src/secops/chronicle/log_processing_pipelines.py b/src/secops/chronicle/log_processing_pipelines.py index 3d2779f7..14859c5e 100644 --- a/src/secops/chronicle/log_processing_pipelines.py +++ b/src/secops/chronicle/log_processing_pipelines.py @@ -16,7 +16,8 @@ from typing import Any -from secops.exceptions import APIError +from secops.chronicle.models import APIVersion +from secops.chronicle.utils.request_utils import chronicle_request def list_log_processing_pipelines( @@ -43,8 +44,6 @@ def list_log_processing_pipelines( Raises: APIError: If the API request fails. 
""" - url = f"{client.base_url}/{client.instance_id}/logProcessingPipelines" - params: dict[str, Any] = {} if page_size is not None: params["pageSize"] = page_size @@ -53,13 +52,14 @@ def list_log_processing_pipelines( if filter_expr: params["filter"] = filter_expr - response = client.session.get(url, params=params) - if response.status_code != 200: - raise APIError( - f"Failed to list log processing pipelines: {response.text}" - ) - - return response.json() + return chronicle_request( + client, + method="GET", + endpoint_path="logProcessingPipelines", + api_version=APIVersion.V1ALPHA, + params=params if params else None, + error_message="Failed to list log processing pipelines", + ) def get_log_processing_pipeline( @@ -78,20 +78,17 @@ def get_log_processing_pipeline( APIError: If the API request fails. """ if not pipeline_id.startswith("projects/"): - url = ( - f"{client.base_url}/{client.instance_id}/" - f"logProcessingPipelines/{pipeline_id}" - ) + endpoint_path = f"logProcessingPipelines/{pipeline_id}" else: - url = f"{client.base_url}/{pipeline_id}" - - response = client.session.get(url) - if response.status_code != 200: - raise APIError( - f"Failed to get log processing pipeline: {response.text}" - ) - - return response.json() + endpoint_path = pipeline_id + + return chronicle_request( + client, + method="GET", + endpoint_path=endpoint_path, + api_version=APIVersion.V1ALPHA, + error_message="Failed to get log processing pipeline", + ) def create_log_processing_pipeline( @@ -117,19 +114,19 @@ def create_log_processing_pipeline( Raises: APIError: If the API request fails. 
""" - url = f"{client.base_url}/{client.instance_id}/logProcessingPipelines" - params: dict[str, Any] = {} if pipeline_id: params["logProcessingPipelineId"] = pipeline_id - response = client.session.post(url, json=pipeline, params=params) - if response.status_code != 200: - raise APIError( - f"Failed to create log processing pipeline: {response.text}" - ) - - return response.json() + return chronicle_request( + client, + method="POST", + endpoint_path="logProcessingPipelines", + api_version=APIVersion.V1ALPHA, + params=params if params else None, + json=pipeline, + error_message="Failed to create log processing pipeline", + ) def update_log_processing_pipeline( @@ -156,24 +153,23 @@ def update_log_processing_pipeline( APIError: If the API request fails. """ if not pipeline_id.startswith("projects/"): - url = ( - f"{client.base_url}/{client.instance_id}/" - f"logProcessingPipelines/{pipeline_id}" - ) + endpoint_path = f"logProcessingPipelines/{pipeline_id}" else: - url = f"{client.base_url}/{pipeline_id}" + endpoint_path = pipeline_id params: dict[str, Any] = {} if update_mask: params["updateMask"] = update_mask - response = client.session.patch(url, json=pipeline, params=params) - if response.status_code != 200: - raise APIError( - f"Failed to patch log processing pipeline: {response.text}" - ) - - return response.json() + return chronicle_request( + client, + method="PATCH", + endpoint_path=endpoint_path, + api_version=APIVersion.V1ALPHA, + params=params if params else None, + json=pipeline, + error_message="Failed to patch log processing pipeline", + ) def delete_log_processing_pipeline( @@ -194,24 +190,22 @@ def delete_log_processing_pipeline( APIError: If the API request fails. 
""" if not pipeline_id.startswith("projects/"): - url = ( - f"{client.base_url}/{client.instance_id}/" - f"logProcessingPipelines/{pipeline_id}" - ) + endpoint_path = f"logProcessingPipelines/{pipeline_id}" else: - url = f"{client.base_url}/{pipeline_id}" + endpoint_path = pipeline_id params: dict[str, Any] = {} if etag: params["etag"] = etag - response = client.session.delete(url, params=params) - if response.status_code != 200: - raise APIError( - f"Failed to delete log processing pipeline: {response.text}" - ) - - return response.json() + return chronicle_request( + client, + method="DELETE", + endpoint_path=endpoint_path, + api_version=APIVersion.V1ALPHA, + params=params if params else None, + error_message="Failed to delete log processing pipeline", + ) def associate_streams( @@ -233,19 +227,18 @@ def associate_streams( APIError: If the API request fails. """ if not pipeline_id.startswith("projects/"): - url = ( - f"{client.base_url}/{client.instance_id}/" - f"logProcessingPipelines/{pipeline_id}:associateStreams" - ) + endpoint_path = f"logProcessingPipelines/{pipeline_id}:associateStreams" else: - url = f"{client.base_url}/{pipeline_id}:associateStreams" - body = {"streams": streams} - - response = client.session.post(url, json=body) - if response.status_code != 200: - raise APIError(f"Failed to associate streams: {response.text}") - - return response.json() + endpoint_path = f"{pipeline_id}:associateStreams" + + return chronicle_request( + client, + method="POST", + endpoint_path=endpoint_path, + api_version=APIVersion.V1ALPHA, + json={"streams": streams}, + error_message="Failed to associate streams", + ) def dissociate_streams( @@ -267,20 +260,20 @@ def dissociate_streams( APIError: If the API request fails. 
""" if not pipeline_id.startswith("projects/"): - url = ( - f"{client.base_url}/{client.instance_id}/" + endpoint_path = ( f"logProcessingPipelines/{pipeline_id}:dissociateStreams" ) else: - url = f"{client.base_url}/{pipeline_id}:dissociateStreams" - - body = {"streams": streams} - - response = client.session.post(url, json=body) - if response.status_code != 200: - raise APIError(f"Failed to dissociate streams: {response.text}") - - return response.json() + endpoint_path = f"{pipeline_id}:dissociateStreams" + + return chronicle_request( + client, + method="POST", + endpoint_path=endpoint_path, + api_version=APIVersion.V1ALPHA, + json={"streams": streams}, + error_message="Failed to dissociate streams", + ) def fetch_associated_pipeline( @@ -300,21 +293,18 @@ def fetch_associated_pipeline( Raises: APIError: If the API request fails. """ - url = ( - f"{client.base_url}/{client.instance_id}/" - f"logProcessingPipelines:fetchAssociatedPipeline" - ) - - # Pass stream fields as separate query parameters with stream. prefix params = {} for key, value in stream.items(): params[f"stream.{key}"] = value - response = client.session.get(url, params=params) - if response.status_code != 200: - raise APIError(f"Failed to fetch associated pipeline: {response.text}") - - return response.json() + return chronicle_request( + client, + method="GET", + endpoint_path="logProcessingPipelines:fetchAssociatedPipeline", + api_version=APIVersion.V1ALPHA, + params=params, + error_message="Failed to fetch associated pipeline", + ) def fetch_sample_logs_by_streams( @@ -340,22 +330,18 @@ def fetch_sample_logs_by_streams( Raises: APIError: If the API request fails. 
""" - url = ( - f"{client.base_url}/{client.instance_id}/" - f"logProcessingPipelines:fetchSampleLogsByStreams" - ) - body = {"streams": streams} if sample_logs_count is not None: body["sampleLogsCount"] = sample_logs_count - response = client.session.post(url, json=body) - if response.status_code != 200: - raise APIError( - f"Failed to fetch sample logs by streams: {response.text}" - ) - - return response.json() + return chronicle_request( + client, + method="POST", + endpoint_path="logProcessingPipelines:fetchSampleLogsByStreams", + api_version=APIVersion.V1ALPHA, + json=body, + error_message="Failed to fetch sample logs by streams", + ) def test_pipeline( @@ -377,15 +363,13 @@ def test_pipeline( Raises: APIError: If the API request fails. """ - url = ( - f"{client.base_url}/{client.instance_id}/" - f"logProcessingPipelines:testPipeline" - ) - body = {"logProcessingPipeline": pipeline, "inputLogs": input_logs} - response = client.session.post(url, json=body) - if response.status_code != 200: - raise APIError(f"Failed to test pipeline: {response.text}") - - return response.json() + return chronicle_request( + client, + method="POST", + endpoint_path="logProcessingPipelines:testPipeline", + api_version=APIVersion.V1ALPHA, + json=body, + error_message="Failed to test pipeline", + ) diff --git a/src/secops/chronicle/log_types.py b/src/secops/chronicle/log_types.py index f2acb483..dcb6e3b8 100644 --- a/src/secops/chronicle/log_types.py +++ b/src/secops/chronicle/log_types.py @@ -23,7 +23,12 @@ import base64 from typing import TYPE_CHECKING, Any -from secops.exceptions import APIError, SecOpsError +from secops.chronicle.models import APIVersion +from secops.chronicle.utils.request_utils import ( + chronicle_paginated_request, + chronicle_request, +) +from secops.exceptions import SecOpsError if TYPE_CHECKING: from secops.chronicle.client import ChronicleClient @@ -51,38 +56,15 @@ def _fetch_log_types_from_api( Raises: APIError: If the API request fails. 
""" - url = f"{client.base_url}/{client.instance_id}/logTypes" - all_log_types: list[dict[str, Any]] = [] - - # Determine if we should fetch all pages or just one - fetch_all_pages = page_size is None - current_page_token = page_token - - while True: - params: dict[str, Any] = {} - - # Set page size (use default of 1000 if fetching all pages) - params["pageSize"] = page_size if page_size else 1000 - - # Add page token if provided - if current_page_token: - params["pageToken"] = current_page_token - - response = client.session.get(url, params=params) - response.raise_for_status() - data = response.json() - - # Add log types from response - all_log_types.extend(data.get("logTypes", [])) - - # Check for next page - current_page_token = data.get("nextPageToken") - - # Stop if: no more pages OR page_size was specified (single page) - if not current_page_token or not fetch_all_pages: - break - - return all_log_types + return chronicle_paginated_request( + client, + path="logTypes", + items_key="logTypes", + api_version=APIVersion.V1ALPHA, + page_size=page_size, + page_token=page_token, + as_list=True, + ) def load_log_types( @@ -277,15 +259,15 @@ def classify_logs( if not isinstance(log_data, str): raise SecOpsError("log data must be a string") - url = f"{client.base_url}/{client.instance_id}/logs:classify" - encoded_log = base64.b64encode(log_data.encode("utf-8")).decode("utf-8") payload = {"logData": [encoded_log]} - response = client.session.post(url, json=payload) - - if response.status_code != 200: - raise APIError(f"Failed to classify log: {response.text}") - - data = response.json() + data = chronicle_request( + client, + method="POST", + endpoint_path="logs:classify", + api_version=APIVersion.V1ALPHA, + json=payload, + error_message="Failed to classify log", + ) return data.get("predictions", []) diff --git a/src/secops/chronicle/nl_search.py b/src/secops/chronicle/nl_search.py index dcccfffd..e0ea95e9 100644 --- a/src/secops/chronicle/nl_search.py +++ 
b/src/secops/chronicle/nl_search.py @@ -19,6 +19,7 @@ from typing import Any from secops.chronicle.models import APIVersion +from secops.chronicle.utils.request_utils import chronicle_request from secops.exceptions import APIError @@ -36,65 +37,20 @@ def translate_nl_to_udm(client, text: str) -> str: APIError: If the API request fails or no valid query can be generated after retries """ - max_retries = 10 - retry_count = 0 - wait_time = 5 # seconds, will double with each retry - url = ( - f"{client.base_url(APIVersion.V1ALPHA)}/{client.instance_id}" - f":translateUdmQuery" + result = chronicle_request( + client, + method="POST", + endpoint_path=":translateUdmQuery", + api_version=APIVersion.V1ALPHA, + json={"text": text}, + error_message="Chronicle API request failed", ) - payload = {"text": text} - - while retry_count <= max_retries: - try: - response = client.session.post(url, json=payload) - - if response.status_code != 200: - # If it's a 429 error, handle it specially - if ( - response.status_code == 429 - or "RESOURCE_EXHAUSTED" in response.text - ): - if retry_count < max_retries: - retry_count += 1 - print( - "Received 429 error in translation, retrying " - f"({retry_count}/{max_retries}) after " - f"{wait_time} seconds" - ) - time.sleep(wait_time) - wait_time *= 2 # Double the wait time for next retry - continue - # For non-429 errors or if we've exhausted retries - raise APIError(f"Chronicle API request failed: {response.text}") - - result = response.json() - - if "message" in result: - raise APIError(result["message"]) - - return result.get("query", "") - - except APIError as e: - # Only retry for 429 errors - if "429" in str(e) or "RESOURCE_EXHAUSTED" in str(e): - if retry_count < max_retries: - retry_count += 1 - print( - "Received 429 error, retrying " - f"({retry_count}/{max_retries}) after " - f"{wait_time} seconds" - ) - time.sleep(wait_time) - wait_time *= 2 # Double the wait time for next retry - continue - # For other errors or if we've exhausted 
retries, raise the error - raise e + if "message" in result: + raise APIError(result["message"]) - # This should not happen, but just in case - raise APIError("Failed to translate query after retries") + return result.get("query", "") def nl_search( diff --git a/src/secops/chronicle/parser.py b/src/secops/chronicle/parser.py index e1c3488e..6b005275 100644 --- a/src/secops/chronicle/parser.py +++ b/src/secops/chronicle/parser.py @@ -18,7 +18,11 @@ import json from typing import Any -from secops.exceptions import APIError +from secops.chronicle.models import APIVersion +from secops.chronicle.utils.request_utils import ( + chronicle_paginated_request, + chronicle_request, +) # Constants for size limits MAX_LOG_SIZE = 10 * 1024 * 1024 # 10MB per log @@ -44,17 +48,14 @@ def activate_parser( Raises: APIError: If the API request fails """ - url = ( - f"{client.base_url}/{client.instance_id}" - f"/logTypes/{log_type}/parsers/{id}:activate" + return chronicle_request( + client, + method="POST", + endpoint_path=f"logTypes/{log_type}/parsers/{id}:activate", + api_version=APIVersion.V1ALPHA, + json={}, + error_message="Failed to activate parser", ) - body = {} - response = client.session.post(url, json=body) - - if response.status_code != 200: - raise APIError(f"Failed to activate parser: {response.text}") - - return response.json() def activate_release_candidate_parser( @@ -75,17 +76,17 @@ def activate_release_candidate_parser( Raises: APIError: If the API request fails """ - url = ( - f"{client.base_url}/{client.instance_id}" - f"/logTypes/{log_type}/parsers/{id}:activateReleaseCandidateParser" + return chronicle_request( + client, + method="POST", + endpoint_path=( + f"logTypes/{log_type}/parsers/{id}" + ":activateReleaseCandidateParser" + ), + api_version=APIVersion.V1ALPHA, + json={}, + error_message="Failed to activate parser", ) - body = {} - response = client.session.post(url, json=body) - - if response.status_code != 200: - raise APIError(f"Failed to activate parser: 
{response.text}") - - return response.json() def copy_parser( @@ -106,17 +107,14 @@ def copy_parser( Raises: APIError: If the API request fails """ - url = ( - f"{client.base_url}/{client.instance_id}" - f"/logTypes/{log_type}/parsers/{id}:copy" + return chronicle_request( + client, + method="POST", + endpoint_path=f"logTypes/{log_type}/parsers/{id}:copy", + api_version=APIVersion.V1ALPHA, + json={}, + error_message="Failed to copy parser", ) - body = {} - response = client.session.post(url, json=body) - - if response.status_code != 200: - raise APIError(f"Failed to copy parser: {response.text}") - - return response.json() def create_parser( @@ -139,19 +137,19 @@ def create_parser( Raises: APIError: If the API request fails """ - url = f"{client.base_url}/{client.instance_id}/logTypes/{log_type}/parsers" - body = { "cbn": base64.b64encode(parser_code.encode("utf-8")).decode("utf-8"), "validated_on_empty_logs": validated_on_empty_logs, } - response = client.session.post(url, json=body) - - if response.status_code != 200: - raise APIError(f"Failed to create parser: {response.text}") - - return response.json() + return chronicle_request( + client, + method="POST", + endpoint_path=f"logTypes/{log_type}/parsers", + api_version=APIVersion.V1ALPHA, + json=body, + error_message="Failed to create parser", + ) def deactivate_parser( @@ -172,17 +170,14 @@ def deactivate_parser( Raises: APIError: If the API request fails """ - url = ( - f"{client.base_url}/{client.instance_id}" - f"/logTypes/{log_type}/parsers/{id}:deactivate" + return chronicle_request( + client, + method="POST", + endpoint_path=f"logTypes/{log_type}/parsers/{id}:deactivate", + api_version=APIVersion.V1ALPHA, + json={}, + error_message="Failed to deactivate parser", ) - body = {} - response = client.session.post(url, json=body) - - if response.status_code != 200: - raise APIError(f"Failed to deactivate parser: {response.text}") - - return response.json() def delete_parser( @@ -205,17 +200,16 @@ def 
delete_parser( Raises: APIError: If the API request fails """ - url = ( - f"{client.base_url}/{client.instance_id}" - f"/logTypes/{log_type}/parsers/{id}" - ) params = {"force": force} - response = client.session.delete(url, params=params) - if response.status_code != 200: - raise APIError(f"Failed to delete parser: {response.text}") - - return response.json() + return chronicle_request( + client, + method="DELETE", + endpoint_path=f"logTypes/{log_type}/parsers/{id}", + api_version=APIVersion.V1ALPHA, + params=params, + error_message="Failed to delete parser", + ) def get_parser( @@ -236,16 +230,13 @@ def get_parser( Raises: APIError: If the API request fails """ - url = ( - f"{client.base_url}/{client.instance_id}" - f"/logTypes/{log_type}/parsers/{id}" + return chronicle_request( + client, + method="GET", + endpoint_path=f"logTypes/{log_type}/parsers/{id}", + api_version=APIVersion.V1ALPHA, + error_message="Failed to get parser", ) - response = client.session.get(url) - - if response.status_code != 200: - raise APIError(f"Failed to get parser: {response.text}") - - return response.json() def list_parsers( @@ -274,43 +265,20 @@ def list_parsers( Raises: APIError: If the API request fails """ - more = True - parsers = [] - - while more: - url = ( - f"{client.base_url}/{client.instance_id}" - f"/logTypes/{log_type}/parsers" - ) - - params = {} - - if page_size: - params["pageSize"] = page_size - if page_token: - params["pageToken"] = page_token - if filter: - params["filter"] = filter - - response = client.session.get(url, params=params) - - if response.status_code != 200: - raise APIError(f"Failed to list parsers: {response.text}") - - data = response.json() - - if page_size is not None: - return data - - if "parsers" in data: - parsers.extend(data["parsers"]) - - if "nextPageToken" in data: - page_token = data["nextPageToken"] - else: - more = False - - return parsers + extra_params = {} + if filter: + extra_params["filter"] = filter + + return 
chronicle_paginated_request( + client, + path=f"logTypes/{log_type}/parsers", + items_key="parsers", + api_version=APIVersion.V1ALPHA, + page_size=page_size, + page_token=page_token, + extra_params=extra_params if extra_params else None, + as_list=(page_size is None), + ) def run_parser( @@ -432,32 +400,14 @@ def run_parser( "statedump_allowed": statedump_allowed, } - response = client.session.post(url, json=body) - - if response.status_code != 200: - # Provide detailed error messages based on status code - error_detail = f"Failed to evaluate parser for log type '{log_type}'" - - if response.status_code == 400: - error_detail += f" - Bad request: {response.text}" - if "Invalid log type" in response.text: - error_detail += f". Log type '{log_type}' may not be valid." - elif "Invalid parser" in response.text: - error_detail += ". Parser code may contain syntax errors." - elif response.status_code == 404: - error_detail += f" - Log type '{log_type}' not found" - elif response.status_code == 413: - error_detail += ( - " - Request too large. Try reducing the number or size of logs." 
- ) - elif response.status_code == 500: - error_detail += f" - Internal server error: {response.text}" - else: - error_detail += f" - HTTP {response.status_code}: {response.text}" - - raise APIError(error_detail) - - result = response.json() + result = chronicle_request( + client, + method="POST", + endpoint_path=f"logTypes/{log_type}:runParser", + api_version=APIVersion.V1ALPHA, + json=body, + error_message=f"Failed to evaluate parser for log type '{log_type}'", + ) if parse_statedump and "runParserResults" in result: for run_result in result["runParserResults"]: diff --git a/src/secops/chronicle/parser_extension.py b/src/secops/chronicle/parser_extension.py index 80dc0dcb..5cfe71aa 100644 --- a/src/secops/chronicle/parser_extension.py +++ b/src/secops/chronicle/parser_extension.py @@ -19,7 +19,8 @@ from dataclasses import dataclass, field from typing import Any -from secops.exceptions import APIError +from secops.chronicle.models import APIVersion +from secops.chronicle.utils.request_utils import chronicle_request @dataclass @@ -152,14 +153,14 @@ def create_parser_extension( except (ValueError, TypeError) as e: raise ValueError(f"Invalid extension configuration: {e}") from e - url = ( - f"{client.base_url}/{client.instance_id}/logTypes/" - f"{log_type}/parserExtensions" + return chronicle_request( + client, + method="POST", + endpoint_path=f"logTypes/{log_type}/parserExtensions", + api_version=APIVersion.V1ALPHA, + json=extension_config.to_dict(), + error_message="Failed to create parser extension", ) - response = client.session.post(url, json=extension_config.to_dict()) - if not response.ok: - raise APIError(f"Failed to create parser extension: {response.text}") - return response.json() def get_parser_extension( @@ -178,14 +179,13 @@ def get_parser_extension( Raises: APIError: If the API request fails """ - url = ( - f"{client.base_url}/{client.instance_id}/logTypes/" - f"{log_type}/parserExtensions/{extension_id}" + return chronicle_request( + client, + 
method="GET", + endpoint_path=f"logTypes/{log_type}/parserExtensions/{extension_id}", + api_version=APIVersion.V1ALPHA, + error_message="Failed to get parser extension", ) - response = client.session.get(url) - if not response.ok: - raise APIError(f"Failed to get parser extension: {response.text}") - return response.json() def list_parser_extensions( @@ -208,20 +208,20 @@ def list_parser_extensions( Raises: APIError: If the API request fails """ - url = ( - f"{client.base_url}/{client.instance_id}/logTypes/" - f"{log_type}/parserExtensions" - ) params = {} if page_size is not None: params["pageSize"] = page_size if page_token is not None: params["pageToken"] = page_token - response = client.session.get(url, params=params) - if not response.ok: - raise APIError(f"Failed to list parser extensions: {response.text}") - return response.json() + return chronicle_request( + client, + method="GET", + endpoint_path=f"logTypes/{log_type}/parserExtensions", + api_version=APIVersion.V1ALPHA, + params=params if params else None, + error_message="Failed to list parser extensions", + ) def activate_parser_extension(client, log_type: str, extension_id: str) -> None: @@ -235,13 +235,15 @@ def activate_parser_extension(client, log_type: str, extension_id: str) -> None: Raises: APIError: If the API request fails """ - url = ( - f"{client.base_url}/{client.instance_id}/logTypes/" - f"{log_type}/parserExtensions/{extension_id}:activate" + chronicle_request( + client, + method="POST", + endpoint_path=( + f"logTypes/{log_type}/parserExtensions/{extension_id}:activate" + ), + api_version=APIVersion.V1ALPHA, + error_message="Failed to activate parser extension", ) - response = client.session.post(url) - if not response.ok: - raise APIError(f"Failed to activate parser extension: {response.text}") def delete_parser_extension(client, log_type: str, extension_id: str) -> None: @@ -255,10 +257,10 @@ def delete_parser_extension(client, log_type: str, extension_id: str) -> None: Raises: APIError: 
If the API request fails """ - url = ( - f"{client.base_url}/{client.instance_id}/logTypes/" - f"{log_type}/parserExtensions/{extension_id}" + chronicle_request( + client, + method="DELETE", + endpoint_path=f"logTypes/{log_type}/parserExtensions/{extension_id}", + api_version=APIVersion.V1ALPHA, + error_message="Failed to delete parser extension", ) - response = client.session.delete(url) - if not response.ok: - raise APIError(f"Failed to delete parser extension: {response.text}") diff --git a/src/secops/chronicle/reference_list.py b/src/secops/chronicle/reference_list.py index 2aacc176..5b06f60d 100644 --- a/src/secops/chronicle/reference_list.py +++ b/src/secops/chronicle/reference_list.py @@ -9,7 +9,11 @@ validate_cidr_entries, ) from secops.chronicle.models import APIVersion -from secops.exceptions import APIError, SecOpsError +from secops.chronicle.utils.request_utils import ( + chronicle_paginated_request, + chronicle_request, +) +from secops.exceptions import SecOpsError # Use built-in StrEnum if Python 3.11+, otherwise create a compatible version if sys.version_info >= (3, 11): @@ -102,25 +106,20 @@ def create_reference_list( if syntax_type == ReferenceListSyntaxType.CIDR: validate_cidr_entries_local(entries) - response = client.session.post( - f"{client.base_url(api_version, list(APIVersion))}/" - f"{client.instance_id}/referenceLists", + return chronicle_request( + client, + method="POST", + endpoint_path="referenceLists", + api_version=api_version, + params={"referenceListId": name}, json={ "description": description, "entries": [{"value": x} for x in entries], "syntaxType": syntax_type.value, }, - params={"referenceListId": name}, + error_message=f"Failed to create reference list '{name}'", ) - if response.status_code != 200: - raise APIError( - f"Failed to create reference list '{name}': {response.status_code} " - f"{response.text}" - ) - - return response.json() - def get_reference_list( client: "Any", @@ -147,20 +146,15 @@ def get_reference_list( if 
view != ReferenceListView.UNSPECIFIED: params["view"] = view.value - response = client.session.get( - f"{client.base_url(api_version, list(APIVersion))}/" - f"{client.instance_id}/referenceLists/{name}", + return chronicle_request( + client, + method="GET", + endpoint_path=f"referenceLists/{name}", + api_version=api_version, params=params if params else None, + error_message=f"Failed to get reference list '{name}'", ) - if response.status_code != 200: - raise APIError( - f"Failed to get reference list '{name}': {response.status_code} " - f"{response.text}" - ) - - return response.json() - def list_reference_lists( client: "Any", @@ -181,35 +175,18 @@ def list_reference_lists( Raises: APIError: If the API request fails """ - all_ref_lists = [] - params = {"pageSize": 1000} - + extra_params = {} if view != ReferenceListView.UNSPECIFIED: - params["view"] = view.value - - while True: - response = client.session.get( - f"{client.base_url(api_version, list(APIVersion))}/" - f"{client.instance_id}/referenceLists", - params=params, - ) - - if response.status_code != 200: - raise APIError( - f"Failed to list reference lists: {response.status_code} " - f"{response.text}" - ) - - resp_json = response.json() - all_ref_lists.extend(resp_json.get("referenceLists", [])) - - page_token = resp_json.get("nextPageToken") - if page_token: - params["pageToken"] = page_token - else: - break - - return all_ref_lists + extra_params["view"] = view.value + + return chronicle_paginated_request( + client, + path="referenceLists", + items_key="referenceLists", + api_version=api_version, + extra_params=extra_params if extra_params else None, + as_list=True, + ) def update_reference_list( @@ -262,23 +239,17 @@ def update_reference_list( payload["entries"] = [{"value": x} for x in entries] update_paths.append("entries") - # Use updateMask query parameter to specify which fields to update params = {"updateMask": ",".join(update_paths)} - response = client.session.patch( - 
f"{client.base_url(api_version, list(APIVersion))}/" - f"{client.instance_id}/referenceLists/{name}", - json=payload, + return chronicle_request( + client, + method="PATCH", + endpoint_path=f"referenceLists/{name}", + api_version=api_version, params=params, + json=payload, + error_message=f"Failed to update reference list '{name}'", ) - if response.status_code != 200: - raise APIError( - f"Failed to update reference list '{name}': {response.status_code} " - f"{response.text}" - ) - - return response.json() - # Note: Reference List deletion is currently not supported by the API diff --git a/src/secops/chronicle/rule.py b/src/secops/chronicle/rule.py index 04fe3c05..eccd89ec 100644 --- a/src/secops/chronicle/rule.py +++ b/src/secops/chronicle/rule.py @@ -21,6 +21,10 @@ from typing import Any, Literal from secops.chronicle.models import APIVersion +from secops.chronicle.utils.request_utils import ( + chronicle_paginated_request, + chronicle_request, +) from secops.exceptions import APIError, SecOpsError @@ -39,21 +43,15 @@ def create_rule( Raises: APIError: If the API request fails """ - url = ( - f"{client.base_url(api_version, list(APIVersion))}/" - f"{client.instance_id}/rules" - ) - - body = { - "text": rule_text, - } - - response = client.session.post(url, json=body) - - if response.status_code != 200: - raise APIError(f"Failed to create rule: {response.text}") - return response.json() + return chronicle_request( + client, + method="POST", + endpoint_path="rules", + api_version=api_version, + json={"text": rule_text}, + error_message="Failed to create rule", + ) def get_rule( @@ -73,18 +71,14 @@ def get_rule( Raises: APIError: If the API request fails """ - url = ( - f"{client.base_url(api_version, list(APIVersion))}/" - f"{client.instance_id}/rules/{rule_id}" + return chronicle_request( + client, + method="GET", + endpoint_path=f"rules/{rule_id}", + api_version=api_version, + error_message="Failed to get rule", ) - response = client.session.get(url) - - if 
response.status_code != 200: - raise APIError(f"Failed to get rule: {response.text}") - - return response.json() - def list_rules( client, @@ -114,44 +108,15 @@ def list_rules( Raises: APIError: If the API request fails """ - more = True - rules = {"rules": []} - params = {"pageSize": 1000 if not page_size else page_size, "view": view} - if page_token: - params["pageToken"] = page_token - - while more: - url = ( - f"{client.base_url(api_version, list(APIVersion))}/" - f"{client.instance_id}/rules" - ) - response = client.session.get(url, params=params) - - if response.status_code != 200: - raise APIError(f"Failed to list rules: {response.text}") - - data = response.json() - if not data: - # no rules, api returns {} - return rules - - # If Page size is provided return fetched rules as user expects - # only that many rules in the response - if page_size: - rules.update(**data) - more = False - break - else: # Else auto fetch rest pages (Backward Compatibility) - rules["rules"].extend(data["rules"]) - - if "nextPageToken" in data: - params["pageToken"] = data["nextPageToken"] - else: - if "pageToken" in params: - del params["pageToken"] - more = False - - return rules + return chronicle_paginated_request( + client, + path="rules", + items_key="rules", + api_version=api_version, + page_size=page_size, + page_token=page_token, + extra_params={"view": view} if view else {}, + ) def update_rule( @@ -173,23 +138,18 @@ def update_rule( Raises: APIError: If the API request fails """ - url = ( - f"{client.base_url(api_version, list(APIVersion))}/" - f"{client.instance_id}/rules/{rule_id}" - ) - - body = { - "text": rule_text, - } - + body = {"text": rule_text} params = {"update_mask": "text"} - response = client.session.patch(url, params=params, json=body) - - if response.status_code != 200: - raise APIError(f"Failed to update rule: {response.text}") - - return response.json() + return chronicle_request( + client, + method="PATCH", + endpoint_path=f"rules/{rule_id}", + 
api_version=api_version, + params=params, + json=body, + error_message="Failed to update rule", + ) def delete_rule( @@ -211,22 +171,18 @@ def delete_rule( Raises: APIError: If the API request fails """ - url = ( - f"{client.base_url(api_version, list(APIVersion))}/" - f"{client.instance_id}/rules/{rule_id}" - ) - params = {} if force: params["force"] = "true" - response = client.session.delete(url, params=params) - - if response.status_code != 200: - raise APIError(f"Failed to delete rule: {response.text}") - - # The API returns an empty JSON object on success - return response.json() + return chronicle_request( + client, + method="DELETE", + endpoint_path=f"rules/{rule_id}", + api_version=api_version, + params=params, + error_message="Failed to delete rule", + ) def enable_rule(client, rule_id: str, enabled: bool = True) -> dict[str, Any]: @@ -283,14 +239,13 @@ def get_rule_deployment( APIError: If the API request fails. """ - url = ( - f"{client.base_url(api_version, list(APIVersion))}/" - f"{client.instance_id}/rules/{rule_id}/deployment" + return chronicle_request( + client, + method="GET", + endpoint_path=f"rules/{rule_id}/deployment", + api_version=api_version, + error_message="Failed to get rule deployment", ) - response = client.session.get(url) - if response.status_code != 200: - raise APIError(f"Failed to get rule deployment: {response.text}") - return response.json() def list_rule_deployments( @@ -318,46 +273,20 @@ def list_rule_deployments( APIError: If the API request fails. 
""" - params: dict[str, Any] = {} - if page_size: - params["pageSize"] = page_size - if page_token: - params["pageToken"] = page_token + extra_params = {} if filter_query: - params["filter"] = filter_query - - url = ( - f"{client.base_url(api_version, list(APIVersion))}/" - f"{client.instance_id}/rules/-/deployments" + extra_params["filter"] = filter_query + + return chronicle_paginated_request( + client, + path="rules/-/deployments", + items_key="ruleDeployments", + api_version=api_version, + page_size=page_size, + page_token=page_token, + extra_params=extra_params if extra_params else None, ) - if page_size: - response = client.session.get(url, params=params) - if response.status_code != 200: - raise APIError(f"Failed to list rule deployments: {response.text}") - return response.json() - - deployments: dict[str, Any] = {"ruleDeployments": []} - more = True - while more: - response = client.session.get(url, params=params) - if response.status_code != 200: - raise APIError(f"Failed to list rule deployments: {response.text}") - data = response.json() - if not data: - # no rule deployments, api returns {} - return deployments - - deployments["ruleDeployments"].extend(data["ruleDeployments"]) - - if "nextPageToken" in data: - params["pageToken"] = data["nextPageToken"] - else: - params.pop("pageToken", None) - more = False - - return deployments - def search_rules( client, query: str, api_version: APIVersion | None = APIVersion.V1 @@ -438,9 +367,6 @@ def run_rule_test( start_time_str = start_time.strftime("%Y-%m-%dT%H:%M:%SZ") end_time_str = end_time.strftime("%Y-%m-%dT%H:%M:%SZ") - # Fix: Use the full path for the legacy API endpoint - url = f"{client.base_url}/{client.instance_id}/legacy:legacyRunTestRule" - body = { "ruleText": rule_text, "timeRange": { @@ -448,25 +374,26 @@ def run_rule_test( "endTime": end_time_str, }, "maxResults": max_results, - "scope": "", # Empty scope parameter + "scope": "", } - # Make the request and get the complete response try: - 
response = client.session.post(url, json=body, timeout=timeout) - - if response.status_code != 200: - raise APIError(f"Failed to test rule: {response.text}") + json_array = chronicle_request( + client, + method="POST", + endpoint_path="legacy:legacyRunTestRule", + api_version=APIVersion.V1ALPHA, + json=body, + timeout=timeout, + error_message="Failed to test rule", + ) - # Parse the response as a JSON array try: - json_array = json.loads(response.text) + # Process the response as a JSON array # Yield each item in the array for item in json_array: - # Transform the response items to match the expected format if "detection" in item: - # Return the detection with proper type yield {"type": "detection", "detection": item["detection"]} elif "progressPercent" in item: yield { @@ -490,10 +417,9 @@ def run_rule_test( ), } else: - # Unknown item type, yield as-is yield item - except json.JSONDecodeError as e: + except (json.JSONDecodeError, TypeError) as e: raise APIError( f"Failed to parse rule test response: {str(e)}" ) from e @@ -542,11 +468,6 @@ def update_rule_deployment( - The ``update_mask`` is derived from provided fields in the same order they are specified by the caller. 
""" - url = ( - f"{client.base_url(api_version, list(APIVersion))}/" - f"{client.instance_id}/rules/{rule_id}/deployment" - ) - body: dict[str, Any] = {} fields: list[str] = [] @@ -568,8 +489,12 @@ def update_rule_deployment( params = {"update_mask": ",".join(fields)} - response = client.session.patch(url, params=params, json=body) - if response.status_code != 200: - raise APIError(f"Failed to update rule deployment: {response.text}") - - return response.json() + return chronicle_request( + client, + method="PATCH", + endpoint_path=f"rules/{rule_id}/deployment", + api_version=api_version, + params=params, + json=body, + error_message="Failed to update rule deployment", + ) diff --git a/src/secops/chronicle/rule_alert.py b/src/secops/chronicle/rule_alert.py index ab65e4e5..c94170ef 100644 --- a/src/secops/chronicle/rule_alert.py +++ b/src/secops/chronicle/rule_alert.py @@ -17,7 +17,8 @@ from datetime import datetime from typing import Any, Literal -from secops.exceptions import APIError +from secops.chronicle.models import APIVersion +from secops.chronicle.utils.request_utils import chronicle_request def get_alert( @@ -36,21 +37,18 @@ def get_alert( Raises: APIError: If the API request fails """ - url = f"{client.base_url}/{client.instance_id}/legacy:legacyGetAlert" - - params = { - "alertId": alert_id, - } - + params = {"alertId": alert_id} if include_detections: params["includeDetections"] = True - response = client.session.get(url, params=params) - - if response.status_code != 200: - raise APIError(f"Failed to get alert: {response.text}") - - return response.json() + return chronicle_request( + client, + method="GET", + endpoint_path="legacy:legacyGetAlert", + api_version=APIVersion.V1ALPHA, + params=params, + error_message="Failed to get alert", + ) def update_alert( @@ -113,8 +111,6 @@ def update_alert( APIError: If the API request fails ValueError: If invalid values are provided """ - url = f"{client.base_url}/{client.instance_id}/legacy:legacyUpdateAlert" - # 
Validate inputs priority_values = [ "PRIORITY_UNSPECIFIED", @@ -190,12 +186,14 @@ def update_alert( "feedback": feedback, } - response = client.session.post(url, json=payload) - - if response.status_code != 200: - raise APIError(f"Failed to update alert: {response.text}") - - return response.json() + return chronicle_request( + client, + method="POST", + endpoint_path="legacy:legacyUpdateAlert", + api_version=APIVersion.V1ALPHA, + json=payload, + error_message="Failed to update alert", + ) def bulk_update_alerts( @@ -324,27 +322,20 @@ def search_rule_alerts( Raises: APIError: If the API request fails """ - # Unused argument. Kept for backward compatibility. _ = (rule_status,) - url = ( - f"{client.base_url}/{client.instance_id}/legacy:legacySearchRulesAlerts" - ) - - # Build request parameters params = { "timeRange.start_time": start_time.isoformat(), "timeRange.end_time": end_time.isoformat(), } - - # Remove rule status filtering as it doesn't seem to be supported if page_size: params["maxNumAlertsToReturn"] = page_size - response = client.session.get(url, params=params) - - if response.status_code != 200: - error_msg = f"Failed to search rule alerts: {response.text}" - raise APIError(error_msg) - - return response.json() + return chronicle_request( + client, + method="GET", + endpoint_path="legacy:legacySearchRulesAlerts", + api_version=APIVersion.V1ALPHA, + params=params, + error_message="Failed to search rule alerts", + ) diff --git a/src/secops/chronicle/rule_detection.py b/src/secops/chronicle/rule_detection.py index 220fbabf..5bc32a50 100644 --- a/src/secops/chronicle/rule_detection.py +++ b/src/secops/chronicle/rule_detection.py @@ -17,7 +17,8 @@ from datetime import datetime from typing import Any, Literal -from secops.exceptions import APIError +from secops.chronicle.models import APIVersion +from secops.chronicle.utils.request_utils import chronicle_request def list_detections( @@ -61,11 +62,6 @@ def list_detections( APIError: If the API request fails 
ValueError: If an invalid alert_state is provided """ - url = ( - f"{client.base_url}/{client.instance_id}/legacy:legacySearchDetections" - ) - - # Define valid alert states valid_alert_states = ["UNSPECIFIED", "NOT_ALERTING", "ALERTING"] valid_list_basis = [ "LIST_BASIS_UNSPECIFIED", @@ -73,10 +69,7 @@ def list_detections( "DETECTION_TIME", ] - # Build request parameters - params = { - "rule_id": rule_id, - } + params = {"rule_id": rule_id} if alert_state: if alert_state not in valid_alert_states: @@ -101,16 +94,17 @@ def list_detections( if page_size: params["pageSize"] = page_size - if page_token: params["pageToken"] = page_token - response = client.session.get(url, params=params) - - if response.status_code != 200: - raise APIError(f"Failed to list detections: {response.text}") - - return response.json() + return chronicle_request( + client, + method="GET", + endpoint_path="legacy:legacySearchDetections", + api_version=APIVersion.V1ALPHA, + params=params, + error_message="Failed to list detections", + ) def list_errors(client, rule_id: str) -> dict[str, Any]: @@ -129,18 +123,14 @@ def list_errors(client, rule_id: str) -> dict[str, Any]: Raises: APIError: If the API request fails """ - url = f"{client.base_url}/{client.instance_id}/ruleExecutionErrors" - - # Create the filter for the specific rule rule_filter = f'rule = "{client.instance_id}/rules/{rule_id}"' - - params = { - "filter": rule_filter, - } - - response = client.session.get(url, params=params) - - if response.status_code != 200: - raise APIError(f"Failed to list rule errors: {response.text}") - - return response.json() + params = {"filter": rule_filter} + + return chronicle_request( + client, + method="GET", + endpoint_path="ruleExecutionErrors", + api_version=APIVersion.V1ALPHA, + params=params, + error_message="Failed to list rule errors", + ) diff --git a/src/secops/chronicle/rule_exclusion.py b/src/secops/chronicle/rule_exclusion.py index feb6f006..a8554ef7 100644 --- 
a/src/secops/chronicle/rule_exclusion.py +++ b/src/secops/chronicle/rule_exclusion.py @@ -20,7 +20,12 @@ from datetime import datetime from typing import Annotated, Any -from secops.exceptions import APIError, SecOpsError +from secops.chronicle.models import APIVersion +from secops.chronicle.utils.request_utils import ( + chronicle_paginated_request, + chronicle_request, +) +from secops.exceptions import SecOpsError # Use built-in StrEnum if Python 3.11+, otherwise create a compatible version if sys.version_info >= (3, 11): @@ -78,7 +83,7 @@ def to_dict(self) -> dict[str, Any]: def list_rule_exclusions( - client, page_size: int = 100, page_token: str | None = None + client, page_size: int | None = None, page_token: str | None = None ) -> dict[str, Any]: """List rule exclusions. @@ -93,18 +98,14 @@ def list_rule_exclusions( Raises: APIError: If the API request fails """ - url = f"{client.base_url}/{client.instance_id}/findingsRefinements" - - params = {"pageSize": page_size} - if page_token: - params["pageToken"] = page_token - - response = client.session.get(url, params=params) - - if response.status_code != 200: - raise APIError(f"Failed to list rule exclusions: {response.text}") - - return response.json() + return chronicle_paginated_request( + client, + path="findingsRefinements", + items_key="findingsRefinements", + api_version=APIVersion.V1ALPHA, + page_size=page_size, + page_token=page_token, + ) def get_rule_exclusion(client, exclusion_id: str) -> dict[str, Any]: @@ -120,19 +121,18 @@ def get_rule_exclusion(client, exclusion_id: str) -> dict[str, Any]: Raises: APIError: If the API request fails """ - # Check if name is a full resource name or just an ID - name = exclusion_id if not exclusion_id.startswith("projects/"): - name = f"{client.instance_id}/findingsRefinements/{exclusion_id}" - - url = f"{client.base_url}/{name}" - - response = client.session.get(url) - - if response.status_code != 200: - raise APIError(f"Failed to get rule exclusion: 
{response.text}") - - return response.json() + endpoint_path = f"findingsRefinements/{exclusion_id}" + else: + endpoint_path = exclusion_id + + return chronicle_request( + client, + method="GET", + endpoint_path=endpoint_path, + api_version=APIVersion.V1ALPHA, + error_message="Failed to get rule exclusion", + ) def create_rule_exclusion( @@ -155,20 +155,20 @@ def create_rule_exclusion( Raises: APIError: If the API request fails """ - url = f"{client.base_url}/{client.instance_id}/findingsRefinements" - body = { "display_name": display_name, "type": refinement_type, "query": query, } - response = client.session.post(url, json=body) - - if response.status_code != 200: - raise APIError(f"Failed to create rule exclusion: {response.text}") - - return response.json() + return chronicle_request( + client, + method="POST", + endpoint_path="findingsRefinements", + api_version=APIVersion.V1ALPHA, + json=body, + error_message="Failed to create rule exclusion", + ) def patch_rule_exclusion( @@ -198,15 +198,12 @@ def patch_rule_exclusion( Raises: APIError: If the API request fails """ - name = exclusion_id - # Check if name is a full resource name or just an ID if not exclusion_id.startswith("projects/"): - name = f"{client.instance_id}/findingsRefinements/{exclusion_id}" - - url = f"{client.base_url}/{name}" + endpoint_path = f"findingsRefinements/{exclusion_id}" + else: + endpoint_path = exclusion_id body = {} - if display_name: body["display_name"] = display_name if refinement_type: @@ -218,12 +215,15 @@ def patch_rule_exclusion( if update_mask: params["updateMask"] = update_mask - response = client.session.patch(url, params=params, json=body) - - if response.status_code != 200: - raise APIError(f"Failed to update rule exclusion: {response.text}") - - return response.json() + return chronicle_request( + client, + method="PATCH", + endpoint_path=endpoint_path, + api_version=APIVersion.V1ALPHA, + params=params, + json=body, + error_message="Failed to update rule exclusion", + 
) def compute_rule_exclusion_activity( @@ -246,16 +246,15 @@ def compute_rule_exclusion_activity( Raises: APIError: If the API request fails """ - name = exclusion_id - # Check if name is a full resource name or just an ID - if not name.startswith("projects/"): - name = f"{client.instance_id}/findingsRefinements/{exclusion_id}" - - url = f"{client.base_url}/{name}:computeFindingsRefinementActivity" + if not exclusion_id.startswith("projects/"): + endpoint_path = ( + f"findingsRefinements/{exclusion_id}" + ":computeFindingsRefinementActivity" + ) + else: + endpoint_path = f"{exclusion_id}:computeFindingsRefinementActivity" body = {} - - # Add time range if provided if start_time or end_time: time_range = {} try: @@ -263,26 +262,24 @@ def compute_rule_exclusion_activity( time_range["start_time"] = start_time.strftime( "%Y-%m-%dT%H:%M:%S.%fZ" ) - if end_time: time_range["end_time"] = end_time.strftime( "%Y-%m-%dT%H:%M:%S.%fZ" ) - body["interval"] = time_range except ValueError as e: raise SecOpsError( "Failed to convert time interval to required format" ) from e - response = client.session.post(url, json=body) - - if response.status_code != 200: - raise APIError( - f"Failed to compute rule exclusion activity: {response.text}" - ) - - return response.json() + return chronicle_request( + client, + method="POST", + endpoint_path=endpoint_path, + api_version=APIVersion.V1ALPHA, + json=body, + error_message="Failed to compute rule exclusion activity", + ) def get_rule_exclusion_deployment(client, exclusion_id: str) -> dict[str, Any]: @@ -298,21 +295,18 @@ def get_rule_exclusion_deployment(client, exclusion_id: str) -> dict[str, Any]: Raises: APIError: If the API request fails """ - name = exclusion_id - # Check if name is a full resource name or just an ID - if not name.startswith("projects/"): - name = f"{client.instance_id}/findingsRefinements/{name}" - - url = f"{client.base_url}/{name}/deployment" - - response = client.session.get(url) - - if response.status_code != 
200: - raise APIError( - f"Failed to get rule exclusion deployment: {response.text}" - ) - - return response.json() + if not exclusion_id.startswith("projects/"): + endpoint_path = f"findingsRefinements/{exclusion_id}/deployment" + else: + endpoint_path = f"{exclusion_id}/deployment" + + return chronicle_request( + client, + method="GET", + endpoint_path=endpoint_path, + api_version=APIVersion.V1ALPHA, + error_message="Failed to get rule exclusion deployment", + ) def update_rule_exclusion_deployment( @@ -336,12 +330,10 @@ def update_rule_exclusion_deployment( Raises: APIError: If the API request fails """ - name = exclusion_id - # Check if name is a full resource name or just an ID - if not name.startswith("projects/"): - name = f"{client.instance_id}/findingsRefinements/{name}" - - url = f"{client.base_url}/{name}/deployment" + if not exclusion_id.startswith("projects/"): + endpoint_path = f"findingsRefinements/{exclusion_id}/deployment" + else: + endpoint_path = f"{exclusion_id}/deployment" params = {} if update_mask: @@ -353,13 +345,12 @@ def update_rule_exclusion_deployment( fields.append(k) params["updateMask"] = ",".join(fields) - response = client.session.patch( - url, params=params, json=deployment_details.to_dict() + return chronicle_request( + client, + method="PATCH", + endpoint_path=endpoint_path, + api_version=APIVersion.V1ALPHA, + params=params, + json=deployment_details.to_dict(), + error_message="Failed to update rule exclusion deployment", ) - - if response.status_code != 200: - raise APIError( - f"Failed to update rule exclusion deployment: {response.text}" - ) - - return response.json() diff --git a/src/secops/chronicle/utils/request_utils.py b/src/secops/chronicle/utils/request_utils.py index 43f2d885..37a2f94f 100644 --- a/src/secops/chronicle/utils/request_utils.py +++ b/src/secops/chronicle/utils/request_utils.py @@ -50,10 +50,10 @@ def _safe_body_preview(text: str | None, limit: int = MAX_BODY_CHARS) -> str: # pylint: disable=line-too-long 
def chronicle_paginated_request( client: "ChronicleClient", - api_version: str, path: str, items_key: str, *, + api_version: str | None = None, page_size: int | None = None, page_token: str | None = None, extra_params: dict[str, Any] | None = None, @@ -61,7 +61,7 @@ def chronicle_paginated_request( ) -> dict[str, Any] | list[Any]: """Helper to get items from endpoints that use pagination. - Function behaviour: + Function behavior: - If `page_size` OR `page_token` is provided: a single page is returned with the upstream JSON as-is, including all potential metadata. - If `as_list` is True, return only the list of items (drops metadata/tokens) @@ -77,12 +77,13 @@ def chronicle_paginated_request( Args: client: ChronicleClient instance - api_version: The API version to use, as a string. options: + path: URL path after {base_url}/{instance_id}/ + items_key: JSON key holding the array of items (e.g. 'curatedRules') + api_version: The API version to use, as a string. If not provided, + uses the client's default_api_version. Options: - v1 (secops.chronicle.models.APIVersion.V1) - v1alpha (secops.chronicle.models.APIVersion.V1ALPHA) - v1beta (secops.chronicle.models.APIVersion.V1BETA) - path: URL path after {base_url}/{instance_id}/ - items_key: JSON key holding the array of items (e.g. 'curatedRules') page_size: Maximum number of rules to return per page. page_token: Token for the next page of results, if available. extra_params: extra query params to include on every request @@ -192,7 +193,7 @@ def chronicle_request( method: str, endpoint_path: str, *, - api_version: str = APIVersion.V1, + api_version: str | None = None, params: dict[str, Any] | None = None, headers: dict[str, Any] | None = None, json: dict[str, Any] | None = None, @@ -206,7 +207,8 @@ def chronicle_request( client: requests.Session (or compatible) instance method: HTTP method, e.g. 
'GET', 'POST', 'PATCH' endpoint_path: URL path after {base_url}/{instance_id}/ - api_version: The API version to use, as a string. options: + api_version: The API version to use, as a string. If not provided, + uses the client's default_api_version. Options: - v1 (secops.chronicle.models.APIVersion.V1) - v1alpha (secops.chronicle.models.APIVersion.V1ALPHA) - v1beta (secops.chronicle.models.APIVersion.V1BETA) @@ -230,7 +232,10 @@ def chronicle_request( # - RPC-style methods e.g: ":validateQuery" -> .../{instance_id}:validateQuery # - Legacy paths e.g: "legacy:..." -> .../{instance_id}/legacy:... # - normal paths e.g: "curatedRules/..." -> .../{instance_id}/curatedRules/... - base = f"{client.base_url(api_version)}/{client.instance_id}" + if api_version: + base = f"{client.base_url(api_version)}/{client.instance_id}" + else: + base = f"{client.base_url}/{client.instance_id}" if endpoint_path.startswith(":"): url = f"{base}{endpoint_path}" From a0edc5ed6b8bf45d569baf5b1d76f741968bb780 Mon Sep 17 00:00:00 2001 From: Mihir Vala <179564180+mihirvala-crestdata@users.noreply.github.com> Date: Fri, 13 Mar 2026 15:11:36 +0530 Subject: [PATCH 2/3] refactor: remove explicit api_version parameters from chronicle request calls and added as_list. 
--- src/secops/chronicle/client.py | 56 ++++++++++++++----- src/secops/chronicle/dashboard_query.py | 2 - src/secops/chronicle/data_export.py | 40 ++++++------- src/secops/chronicle/data_table.py | 38 +++++++------ src/secops/chronicle/entity.py | 22 +++----- src/secops/chronicle/feeds.py | 11 +++- src/secops/chronicle/gemini.py | 3 - src/secops/chronicle/investigations.py | 41 +++++++------- src/secops/chronicle/ioc.py | 1 - src/secops/chronicle/log_ingest.py | 51 ++++++----------- .../chronicle/log_processing_pipelines.py | 46 +++++++-------- src/secops/chronicle/log_types.py | 2 - src/secops/chronicle/nl_search.py | 1 - src/secops/chronicle/parser.py | 26 ++++----- src/secops/chronicle/parser_extension.py | 36 ++++++------ src/secops/chronicle/reference_list.py | 13 ++++- src/secops/chronicle/rule.py | 24 +++++--- src/secops/chronicle/rule_alert.py | 3 - src/secops/chronicle/rule_detection.py | 42 +++++++------- src/secops/chronicle/rule_exclusion.py | 22 ++++---- src/secops/chronicle/rule_set.py | 9 ++- 21 files changed, 252 insertions(+), 237 deletions(-) diff --git a/src/secops/chronicle/client.py b/src/secops/chronicle/client.py index 949a5565..daf37376 100644 --- a/src/secops/chronicle/client.py +++ b/src/secops/chronicle/client.py @@ -1500,18 +1500,25 @@ def list_parser_extensions( log_type: str, page_size: int | None = None, page_token: str | None = None, - ) -> dict[str, Any]: + as_list: bool = False, + ) -> dict[str, Any] | list[Any]: """List parser extensions. Args: log_type: The log type to list parser extensions for page_size: Maximum number of parser extensions to return page_token: Token for pagination + as_list: If True, return only the list of parser extensions. + If False, return dict with metadata and pagination tokens. Returns: - Dict containing list of parser extensions and next page token if any + If as_list is True: List of parser extensions. + If as_list is False: Dict with parserExtensions list and + pagination metadata. 
""" - return _list_parser_extensions(self, log_type, page_size, page_token) + return _list_parser_extensions( + self, log_type, page_size, page_token, as_list + ) def activate_parser_extension( self, log_type: str, extension_id: str @@ -1776,22 +1783,27 @@ def list_log_processing_pipelines( page_size: int | None = None, page_token: str | None = None, filter_expr: str | None = None, - ) -> dict[str, Any]: + as_list: bool = False, + ) -> dict[str, Any] | list[Any]: """Lists log processing pipelines. Args: page_size: Maximum number of pipelines to return. page_token: Page token for pagination. filter_expr: Filter expression to restrict results. + as_list: If True, return only the list of pipelines. + If False, return dict with metadata and pagination tokens. Returns: - Dictionary containing pipelines and pagination info. + If as_list is True: List of log processing pipelines. + If as_list is False: Dict with logProcessingPipelines list and + pagination metadata. Raises: APIError: If the API request fails. """ return _list_log_processing_pipelines( - self, page_size, page_token, filter_expr + self, page_size, page_token, filter_expr, as_list ) def get_log_processing_pipeline(self, pipeline_id: str) -> dict[str, Any]: @@ -2024,7 +2036,8 @@ def list_investigations( page_token: str | None = None, filter_expr: str | None = None, order_by: str | None = None, - ) -> dict[str, Any]: + as_list: bool = False, + ) -> dict[str, Any] | list[Any]: """Lists investigations. Args: @@ -2036,16 +2049,19 @@ def list_investigations( order_by: Ordering of investigations. Default is create time descending. Supported fields: "startTime", "endTime", "displayName". + as_list: If True, return only the list of investigations. + If False, return dict with metadata and pagination tokens. Returns: - Dictionary containing investigations, next page token, and - total size. + If as_list is True: List of investigations. 
+ If as_list is False: Dict with investigations list, + nextPageToken, and totalSize. Raises: APIError: If the API request fails. """ return _list_investigations( - self, page_size, page_token, filter_expr, order_by + self, page_size, page_token, filter_expr, order_by, as_list ) def trigger_investigation(self, alert_id: str) -> dict[str, Any]: @@ -2400,7 +2416,8 @@ def list_detections( alert_state: str | None = None, page_size: int | None = None, page_token: str | None = None, - ) -> dict[str, Any]: + as_list: bool = False, + ) -> dict[str, Any] | list[Any]: """List detections for a rule. Args: @@ -2421,9 +2438,13 @@ def list_detections( - "ALERTING" page_size: If provided, maximum number of detections to return page_token: If provided, continuation token for pagination + as_list: If True, return only the list of detections. + If False, return dict with metadata and pagination tokens. Returns: - Dictionary containing detection information + If as_list is True: List of detections. + If as_list is False: Dict with detections list and + pagination metadata. Raises: APIError: If the API request fails @@ -2438,6 +2459,7 @@ def list_detections( alert_state, page_size, page_token, + as_list, ) def list_errors(self, rule_id: str) -> dict[str, Any]: @@ -3789,16 +3811,21 @@ def list_data_export( filters: str | None = None, page_size: int | None = None, page_token: str | None = None, - ) -> dict[str, Any]: + as_list: bool = False, + ) -> dict[str, Any] | list[Any]: """List data export jobs. Args: filters: Filter string page_size: Page size page_token: Page token + as_list: If True, return only the list of data exports. + If False, return dict with metadata and pagination tokens. Returns: - Dictionary containing data export list + If as_list is True: List of data exports. + If as_list is False: Dict with dataExports list and + pagination metadata. 
Raises: APIError: If the API request fails @@ -3813,6 +3840,7 @@ def list_data_export( filters=filters, page_size=page_size, page_token=page_token, + as_list=as_list, ) # Data Table methods diff --git a/src/secops/chronicle/dashboard_query.py b/src/secops/chronicle/dashboard_query.py index 65b135b8..ae486b32 100644 --- a/src/secops/chronicle/dashboard_query.py +++ b/src/secops/chronicle/dashboard_query.py @@ -70,7 +70,6 @@ def execute_query( client, method="POST", endpoint_path="dashboardQueries:execute", - api_version=APIVersion.V1ALPHA, json=payload, error_message="Failed to execute query", ) @@ -93,6 +92,5 @@ def get_execute_query(client, query_id: str) -> dict[str, Any]: client, method="GET", endpoint_path=f"dashboardQueries/{query_id}", - api_version=APIVersion.V1ALPHA, error_message="Failed to get query", ) diff --git a/src/secops/chronicle/data_export.py b/src/secops/chronicle/data_export.py index 470e3f0a..d641f1d1 100644 --- a/src/secops/chronicle/data_export.py +++ b/src/secops/chronicle/data_export.py @@ -22,8 +22,10 @@ from datetime import datetime from typing import Any -from secops.chronicle.models import APIVersion -from secops.chronicle.utils.request_utils import chronicle_request +from secops.chronicle.utils.request_utils import ( + chronicle_request, + chronicle_paginated_request, +) @dataclass @@ -96,7 +98,6 @@ def get_data_export(client, data_export_id: str) -> dict[str, Any]: client, method="GET", endpoint_path=f"dataExports/{data_export_id}", - api_version=APIVersion.V1ALPHA, error_message="Failed to get data export", ) @@ -211,7 +212,6 @@ def create_data_export( client, method="POST", endpoint_path="dataExports", - api_version=APIVersion.V1ALPHA, json=payload, error_message="Failed to create data export", ) @@ -240,7 +240,6 @@ def cancel_data_export(client, data_export_id: str) -> dict[str, Any]: client, method="POST", endpoint_path=f"dataExports/{data_export_id}:cancel", - api_version=APIVersion.V1ALPHA, error_message="Failed to cancel data 
export", ) @@ -312,7 +311,6 @@ def fetch_available_log_types( client, method="POST", endpoint_path="dataExports:fetchavailablelogtypes", - api_version=APIVersion.V1ALPHA, json=payload, error_message="Failed to fetch available log types", ) @@ -405,7 +403,6 @@ def update_data_export( client, method="PATCH", endpoint_path=f"dataExports/{data_export_id}", - api_version=APIVersion.V1ALPHA, params=params, json=payload, error_message="Failed to update data export", @@ -417,7 +414,8 @@ def list_data_export( filters: str | None = None, page_size: int | None = None, page_token: str | None = None, -) -> dict[str, Any]: + as_list: bool = False, +) -> dict[str, Any] | list[Any]: """List data export jobs. Args: @@ -425,9 +423,12 @@ def list_data_export( filters: Filter string page_size: Page size page_token: Page token + as_list: If True, return only the list of data exports. + If False, return dict with metadata and pagination tokens. Returns: - Dictionary containing data export list + If as_list is True: List of data exports. + If as_list is False: Dict with dataExports list and pagination metadata. 
Raises: APIError: If the API request fails @@ -437,17 +438,16 @@ def list_data_export( export = chronicle.list_data_export() ``` """ - params = { - "pageSize": page_size, - "pageToken": page_token, - "filter": filters, - } + extra_params = {} + if filters: + extra_params["filter"] = filters - return chronicle_request( + return chronicle_paginated_request( client, - method="GET", - endpoint_path="dataExports", - api_version=APIVersion.V1ALPHA, - params=params, - error_message="Failed to get data export", + path="dataExports", + items_key="dataExports", + page_size=page_size, + page_token=page_token, + extra_params=extra_params if extra_params else None, + as_list=as_list, ) diff --git a/src/secops/chronicle/data_table.py b/src/secops/chronicle/data_table.py index 83d4bbf5..a376efcd 100644 --- a/src/secops/chronicle/data_table.py +++ b/src/secops/chronicle/data_table.py @@ -142,7 +142,6 @@ def create_data_table( client, method="POST", endpoint_path="dataTables", - api_version=APIVersion.V1ALPHA, params={"dataTableId": name}, json=body_payload, error_message=f"Failed to create data table '{name}'", @@ -234,7 +233,6 @@ def _create_data_table_rows( client, method="POST", endpoint_path=f"dataTables/{name}/dataTableRows:bulkCreate", - api_version=APIVersion.V1ALPHA, json={"requests": [{"data_table_row": {"values": x}} for x in rows]}, error_message=f"Failed to create data table rows for '{name}'", ) @@ -265,7 +263,6 @@ def delete_data_table( client, method="DELETE", endpoint_path=f"dataTables/{name}", - api_version=APIVersion.V1ALPHA, params={"force": str(force).lower()}, expected_status={200, 204}, error_message=f"Failed to delete data table '{name}'", @@ -321,14 +318,16 @@ def _delete_data_table_row( client, method="DELETE", endpoint_path=f"dataTables/{table_id}/dataTableRows/{row_guid}", - api_version=APIVersion.V1ALPHA, expected_status={200, 204}, error_message=( f"Failed to delete data table row '{row_guid}' from '{table_id}'" ), ) - except APIError: - return 
{"status": "success"} + except APIError as ar: + # Return success if response is text + if "Expected JSON response" in str(ar): + return {"status": "success"} + raise def get_data_table( @@ -351,7 +350,6 @@ def get_data_table( client, method="GET", endpoint_path=f"dataTables/{name}", - api_version=APIVersion.V1ALPHA, error_message=f"Failed to get data table '{name}'", ) @@ -359,16 +357,21 @@ def get_data_table( def list_data_tables( client: "Any", order_by: str | None = None, -) -> list[dict[str, Any]]: + as_list: bool = True, +) -> dict[str, Any] | list[dict[str, Any]]: """List data tables. Args: client: ChronicleClient instance order_by: Configures ordering of DataTables in the response. Note: The API only supports "createTime asc". + as_list: If True, return only the list of data tables. + If False, return dict with metadata and pagination tokens. + Defaults to True for backward compatibility. Returns: - List of data tables + If as_list is True: List of data tables. + If as_list is False: Dict with dataTables list and pagination metadata. Raises: APIError: If the API request fails @@ -381,7 +384,6 @@ def list_data_tables( client, path="dataTables", items_key="dataTables", - api_version=APIVersion.V1ALPHA, extra_params=extra_params if extra_params else None, as_list=True, ) @@ -391,7 +393,8 @@ def list_data_table_rows( client: "Any", name: str, order_by: str | None = None, -) -> list[dict[str, Any]]: + as_list: bool = True, +) -> dict[str, Any] | list[dict[str, Any]]: """List data table rows. Args: @@ -399,9 +402,14 @@ def list_data_table_rows( name: The name of the data table to list rows from order_by: Configures ordering of DataTableRows in the response. Note: The API only supports "createTime asc". + as_list: If True, return only the list of data table rows. + If False, return dict with metadata and pagination tokens. + Defaults to True for backward compatibility. Returns: - List of data table rows + If as_list is True: List of data table rows. 
+ If as_list is False: Dict with dataTableRows list and + pagination metadata. Raises: APIError: If the API request fails @@ -414,9 +422,8 @@ def list_data_table_rows( client, path=f"dataTables/{name}/dataTableRows", items_key="dataTableRows", - api_version=APIVersion.V1ALPHA, extra_params=extra_params if extra_params else None, - as_list=True, + as_list=as_list, ) @@ -469,7 +476,6 @@ def update_data_table( client, method="PATCH", endpoint_path=f"dataTables/{name}", - api_version=APIVersion.V1ALPHA, params=params if params else None, json=body_payload, error_message=f"Failed to update data table '{name}'", @@ -579,7 +585,6 @@ def replace_data_table_rows( client, method="POST", endpoint_path=f"dataTables/{name}/dataTableRows:bulkReplace", - api_version=APIVersion.V1ALPHA, json={"requests": replace_requests}, error_message=f"Failed to replace data table rows for '{name}'", ) @@ -712,7 +717,6 @@ def _update_data_table_rows( client, method="POST", endpoint_path=f"dataTables/{name}/dataTableRows:bulkUpdate", - api_version=APIVersion.V1ALPHA, json={"requests": requests}, error_message=f"Failed to update data table rows for '{name}'", ) diff --git a/src/secops/chronicle/entity.py b/src/secops/chronicle/entity.py index d24aac4d..8e4af5df 100644 --- a/src/secops/chronicle/entity.py +++ b/src/secops/chronicle/entity.py @@ -181,20 +181,13 @@ def _summarize_entity_by_id( if page_token: params["pageToken"] = page_token - try: - return chronicle_request( - client, - method="GET", - endpoint_path=":summarizeEntity", - api_version=APIVersion.V1ALPHA, - params=params, - error_message=(f"Error getting entity summary by ID ({entity_id})"), - ) - except Exception as e: - raise APIError( - "Error parsing entity summary response for " - f"ID {entity_id}: {str(e)}" - ) from e + return chronicle_request( + client, + method="GET", + endpoint_path=":summarizeEntity", + params=params, + error_message=(f"Error getting entity summary by ID ({entity_id})"), + ) def summarize_entity( @@ -249,7 
+242,6 @@ def summarize_entity( client, method="GET", endpoint_path=":summarizeEntitiesFromQuery", - api_version=APIVersion.V1ALPHA, params=query_params, error_message="Error querying entity summaries", ) diff --git a/src/secops/chronicle/feeds.py b/src/secops/chronicle/feeds.py index 9de1b429..064c2317 100644 --- a/src/secops/chronicle/feeds.py +++ b/src/secops/chronicle/feeds.py @@ -137,7 +137,8 @@ def list_feeds( page_size: int = 100, page_token: str = None, api_version: APIVersion | None = None, -) -> list[Feed]: + as_list: bool = True, +) -> dict[str, Any] | list[Feed]: """List feeds. Args: @@ -145,9 +146,13 @@ def list_feeds( page_size: The maximum number of feeds to return page_token: A page token, received from a previous ListFeeds call api_version: (Optional) Preferred API version to use. + as_list: If True, return only the list of feeds. + If False, return dict with metadata and pagination tokens. + Defaults to True for backward compatibility. Returns: - List of feed dictionaries + If as_list is True: List of feed dictionaries. + If as_list is False: Dict with feeds list and pagination metadata. 
Raises: APIError: If the API request fails @@ -159,7 +164,7 @@ def list_feeds( api_version=api_version, page_size=page_size, page_token=page_token, - as_list=True, + as_list=as_list, ) diff --git a/src/secops/chronicle/gemini.py b/src/secops/chronicle/gemini.py index 44f27129..d626e5c8 100644 --- a/src/secops/chronicle/gemini.py +++ b/src/secops/chronicle/gemini.py @@ -335,7 +335,6 @@ def create_conversation(client, display_name: str = "New chat") -> str: client, method="POST", endpoint_path="users/me/conversations", - api_version=APIVersion.V1ALPHA, json=payload, error_message="Failed to create conversation", ) @@ -371,7 +370,6 @@ def opt_in_to_gemini(client) -> bool: client, method="PATCH", endpoint_path="users/me/preferenceSet", - api_version=APIVersion.V1ALPHA, params=params, json=payload, expected_status={200, 403, 401}, @@ -436,7 +434,6 @@ def query_gemini( endpoint_path=( f"users/me/conversations/{conversation_id}/messages" ), - api_version=APIVersion.V1ALPHA, json=payload, error_message="Failed to query Gemini", ) diff --git a/src/secops/chronicle/investigations.py b/src/secops/chronicle/investigations.py index cc084ffc..b5a9a1aa 100644 --- a/src/secops/chronicle/investigations.py +++ b/src/secops/chronicle/investigations.py @@ -16,8 +16,11 @@ from typing import Any -from secops.chronicle.models import APIVersion, DetectionType -from secops.chronicle.utils.request_utils import chronicle_request +from secops.chronicle.models import APIVersion +from secops.chronicle.utils.request_utils import ( + chronicle_request, + chronicle_paginated_request, +) def fetch_associated_investigations( @@ -128,7 +131,8 @@ def list_investigations( page_token: str | None = None, filter_expr: str | None = None, order_by: str | None = None, -) -> dict[str, Any]: + as_list: bool = False, +) -> dict[str, Any] | list[Any]: """Lists investigations. Args: @@ -143,33 +147,32 @@ def list_investigations( order_by: Configures ordering of investigations. 
Default is by create time descending. Supported fields: "startTime", "endTime", "displayName". + as_list: If True, return only the list of investigations. + If False, return dict with metadata and pagination tokens. Returns: - Dictionary containing: - - investigations: List of investigation objects - - nextPageToken: Token for next page (if more results exist) - - totalSize: Total number of investigations matching request + If as_list is True: List of investigations. + If as_list is False: Dict with investigations list, nextPageToken, + and totalSize. Raises: APIError: If the API request fails. """ - params: dict[str, Any] = {} - if page_size is not None: - params["pageSize"] = page_size - if page_token: - params["pageToken"] = page_token + extra_params: dict[str, Any] = {} if filter_expr: - params["filter"] = filter_expr + extra_params["filter"] = filter_expr if order_by: - params["orderBy"] = order_by + extra_params["orderBy"] = order_by - return chronicle_request( + return chronicle_paginated_request( client, - method="GET", - endpoint_path="investigations", + path="investigations", + items_key="investigations", api_version=APIVersion.V1ALPHA, - params=params, - error_message="Failed to list investigations", + page_size=page_size, + page_token=page_token, + extra_params=extra_params if extra_params else None, + as_list=as_list, ) diff --git a/src/secops/chronicle/ioc.py b/src/secops/chronicle/ioc.py index a2743c38..32d44cdf 100644 --- a/src/secops/chronicle/ioc.py +++ b/src/secops/chronicle/ioc.py @@ -61,7 +61,6 @@ def list_iocs( client, method="GET", endpoint_path="legacy:legacySearchEnterpriseWideIoCs", - api_version=APIVersion.V1ALPHA, params=params, error_message="Failed to list IoCs", ) diff --git a/src/secops/chronicle/log_ingest.py b/src/secops/chronicle/log_ingest.py index 7a4553e1..abefe3b9 100644 --- a/src/secops/chronicle/log_ingest.py +++ b/src/secops/chronicle/log_ingest.py @@ -25,7 +25,10 @@ from secops.chronicle.log_types import is_valid_log_type 
from secops.chronicle.models import APIVersion -from secops.chronicle.utils.request_utils import chronicle_request +from secops.chronicle.utils.request_utils import ( + chronicle_request, + chronicle_paginated_request, +) from secops.exceptions import APIError # Forward declaration for type hinting to avoid circular import @@ -380,7 +383,6 @@ def create_forwarder( client, method="POST", endpoint_path="forwarders", - api_version=APIVersion.V1ALPHA, json=payload, error_message="Failed to create forwarder", ) @@ -390,49 +392,34 @@ def list_forwarders( client: "ChronicleClient", page_size: int | None = None, page_token: str | None = None, -) -> dict[str, Any]: + as_list: bool = False, +) -> dict[str, Any] | list[Any]: """List forwarders in Chronicle. Args: client: ChronicleClient instance page_size: Maximum number of forwarders to return (1-1000) page_token: Token for pagination + as_list: If True, return only the list of forwarders. + If False, return dict with metadata and pagination tokens. Returns: - Dictionary containing list of forwarders and next page token + If as_list is True: List of forwarders. + If as_list is False: Dict with forwarders list and pagination metadata. 
Raises: APIError: If the API request fails """ - url = f"{client.base_url}/{client.instance_id}/forwarders" - - # Add query parameters - params = {} - if page_size: - params["pageSize"] = min(1000, max(1, page_size)) - if page_token: - params["pageToken"] = page_token - result = chronicle_request( + return chronicle_paginated_request( client, - method="GET", - endpoint_path="forwarders", - api_version=APIVersion.V1ALPHA, - params=params if params else None, - error_message="Failed to list forwarders", + path="forwarders", + items_key="forwarders", + page_size=min(1000, max(1, page_size)) if page_size else None, + page_token=page_token, + as_list=as_list, ) - # If there's a next page token, fetch additional pages and combine results - if not page_size and "nextPageToken" in result and result["nextPageToken"]: - next_page = list_forwarders(client, page_size, result["nextPageToken"]) - if "forwarders" in next_page and next_page["forwarders"]: - # Combine the forwarders from both pages - result["forwarders"].extend(next_page["forwarders"]) - # Remove the nextPageToken since we've fetched all pages - result.pop("nextPageToken") - - return result - def get_forwarder( client: "ChronicleClient", forwarder_id: str @@ -453,7 +440,6 @@ def get_forwarder( client, method="GET", endpoint_path=f"forwarders/{forwarder_id}", - api_version=APIVersion.V1ALPHA, error_message="Failed to get forwarder", ) @@ -580,7 +566,6 @@ def update_forwarder( client, method="PATCH", endpoint_path=f"forwarders/{forwarder_id}", - api_version=APIVersion.V1ALPHA, params=params, json=payload, error_message="Failed to update forwarder", @@ -607,7 +592,6 @@ def delete_forwarder( client, method="DELETE", endpoint_path=f"forwarders/{forwarder_id}", - api_version=APIVersion.V1ALPHA, error_message="Failed to delete forwarder", ) @@ -906,7 +890,6 @@ def ingest_log( client, method="POST", endpoint_path="logs:import", - api_version=APIVersion.V1ALPHA, json=payload, error_message="Failed to ingest log", ) @@ 
-1024,7 +1007,6 @@ def ingest_udm( client, method="POST", endpoint_path="events:import", - api_version=APIVersion.V1ALPHA, json=body, expected_status={200, 201}, error_message="Failed to ingest UDM events", @@ -1081,7 +1063,6 @@ def import_entities( client, method="POST", endpoint_path="entities:import", - api_version=APIVersion.V1ALPHA, json=body, expected_status={200, 201}, error_message="Failed to import entities", diff --git a/src/secops/chronicle/log_processing_pipelines.py b/src/secops/chronicle/log_processing_pipelines.py index 14859c5e..a534dd57 100644 --- a/src/secops/chronicle/log_processing_pipelines.py +++ b/src/secops/chronicle/log_processing_pipelines.py @@ -17,7 +17,10 @@ from typing import Any from secops.chronicle.models import APIVersion -from secops.chronicle.utils.request_utils import chronicle_request +from secops.chronicle.utils.request_utils import ( + chronicle_request, + chronicle_paginated_request, +) def list_log_processing_pipelines( @@ -25,7 +28,8 @@ def list_log_processing_pipelines( page_size: int | None = None, page_token: str | None = None, filter_expr: str | None = None, -) -> dict[str, Any]: + as_list: bool = False, +) -> dict[str, Any] | list[Any]: """Lists log processing pipelines. Args: @@ -35,30 +39,29 @@ def list_log_processing_pipelines( page_token: Page token from a previous list call to retrieve the next page. filter_expr: Filter expression (AIP-160) to restrict results. + as_list: If True, return only the list of pipelines. + If False, return dict with metadata and pagination tokens. Returns: - Dictionary containing: - - logProcessingPipelines: List of pipeline dicts - - nextPageToken: Token for next page (if more results exist) + If as_list is True: List of log processing pipelines. + If as_list is False: Dict with logProcessingPipelines list and + pagination metadata. Raises: APIError: If the API request fails. 
""" - params: dict[str, Any] = {} - if page_size is not None: - params["pageSize"] = page_size - if page_token: - params["pageToken"] = page_token + extra_params = {} if filter_expr: - params["filter"] = filter_expr + extra_params["filter"] = filter_expr - return chronicle_request( + return chronicle_paginated_request( client, - method="GET", - endpoint_path="logProcessingPipelines", - api_version=APIVersion.V1ALPHA, - params=params if params else None, - error_message="Failed to list log processing pipelines", + path="logProcessingPipelines", + items_key="logProcessingPipelines", + page_size=page_size, + page_token=page_token, + extra_params=extra_params if extra_params else None, + as_list=as_list, ) @@ -86,7 +89,6 @@ def get_log_processing_pipeline( client, method="GET", endpoint_path=endpoint_path, - api_version=APIVersion.V1ALPHA, error_message="Failed to get log processing pipeline", ) @@ -122,7 +124,6 @@ def create_log_processing_pipeline( client, method="POST", endpoint_path="logProcessingPipelines", - api_version=APIVersion.V1ALPHA, params=params if params else None, json=pipeline, error_message="Failed to create log processing pipeline", @@ -165,7 +166,6 @@ def update_log_processing_pipeline( client, method="PATCH", endpoint_path=endpoint_path, - api_version=APIVersion.V1ALPHA, params=params if params else None, json=pipeline, error_message="Failed to patch log processing pipeline", @@ -202,7 +202,6 @@ def delete_log_processing_pipeline( client, method="DELETE", endpoint_path=endpoint_path, - api_version=APIVersion.V1ALPHA, params=params if params else None, error_message="Failed to delete log processing pipeline", ) @@ -235,7 +234,6 @@ def associate_streams( client, method="POST", endpoint_path=endpoint_path, - api_version=APIVersion.V1ALPHA, json={"streams": streams}, error_message="Failed to associate streams", ) @@ -270,7 +268,6 @@ def dissociate_streams( client, method="POST", endpoint_path=endpoint_path, - api_version=APIVersion.V1ALPHA, 
json={"streams": streams}, error_message="Failed to dissociate streams", ) @@ -301,7 +298,6 @@ def fetch_associated_pipeline( client, method="GET", endpoint_path="logProcessingPipelines:fetchAssociatedPipeline", - api_version=APIVersion.V1ALPHA, params=params, error_message="Failed to fetch associated pipeline", ) @@ -338,7 +334,6 @@ def fetch_sample_logs_by_streams( client, method="POST", endpoint_path="logProcessingPipelines:fetchSampleLogsByStreams", - api_version=APIVersion.V1ALPHA, json=body, error_message="Failed to fetch sample logs by streams", ) @@ -369,7 +364,6 @@ def test_pipeline( client, method="POST", endpoint_path="logProcessingPipelines:testPipeline", - api_version=APIVersion.V1ALPHA, json=body, error_message="Failed to test pipeline", ) diff --git a/src/secops/chronicle/log_types.py b/src/secops/chronicle/log_types.py index dcb6e3b8..e3a64b1a 100644 --- a/src/secops/chronicle/log_types.py +++ b/src/secops/chronicle/log_types.py @@ -60,7 +60,6 @@ def _fetch_log_types_from_api( client, path="logTypes", items_key="logTypes", - api_version=APIVersion.V1ALPHA, page_size=page_size, page_token=page_token, as_list=True, @@ -266,7 +265,6 @@ def classify_logs( client, method="POST", endpoint_path="logs:classify", - api_version=APIVersion.V1ALPHA, json=payload, error_message="Failed to classify log", ) diff --git a/src/secops/chronicle/nl_search.py b/src/secops/chronicle/nl_search.py index e0ea95e9..b57e098d 100644 --- a/src/secops/chronicle/nl_search.py +++ b/src/secops/chronicle/nl_search.py @@ -42,7 +42,6 @@ def translate_nl_to_udm(client, text: str) -> str: client, method="POST", endpoint_path=":translateUdmQuery", - api_version=APIVersion.V1ALPHA, json={"text": text}, error_message="Chronicle API request failed", ) diff --git a/src/secops/chronicle/parser.py b/src/secops/chronicle/parser.py index 6b005275..04de5411 100644 --- a/src/secops/chronicle/parser.py +++ b/src/secops/chronicle/parser.py @@ -52,7 +52,6 @@ def activate_parser( client, 
method="POST", endpoint_path=f"logTypes/{log_type}/parsers/{id}:activate", - api_version=APIVersion.V1ALPHA, json={}, error_message="Failed to activate parser", ) @@ -83,7 +82,6 @@ def activate_release_candidate_parser( f"logTypes/{log_type}/parsers/{id}" ":activateReleaseCandidateParser" ), - api_version=APIVersion.V1ALPHA, json={}, error_message="Failed to activate parser", ) @@ -111,7 +109,6 @@ def copy_parser( client, method="POST", endpoint_path=f"logTypes/{log_type}/parsers/{id}:copy", - api_version=APIVersion.V1ALPHA, json={}, error_message="Failed to copy parser", ) @@ -146,7 +143,6 @@ def create_parser( client, method="POST", endpoint_path=f"logTypes/{log_type}/parsers", - api_version=APIVersion.V1ALPHA, json=body, error_message="Failed to create parser", ) @@ -174,7 +170,6 @@ def deactivate_parser( client, method="POST", endpoint_path=f"logTypes/{log_type}/parsers/{id}:deactivate", - api_version=APIVersion.V1ALPHA, json={}, error_message="Failed to deactivate parser", ) @@ -206,7 +201,6 @@ def delete_parser( client, method="DELETE", endpoint_path=f"logTypes/{log_type}/parsers/{id}", - api_version=APIVersion.V1ALPHA, params=params, error_message="Failed to delete parser", ) @@ -234,7 +228,6 @@ def get_parser( client, method="GET", endpoint_path=f"logTypes/{log_type}/parsers/{id}", - api_version=APIVersion.V1ALPHA, error_message="Failed to get parser", ) @@ -245,7 +238,8 @@ def list_parsers( page_size: int | None = None, page_token: str | None = None, filter: str = None, # pylint: disable=redefined-builtin -) -> list[Any] | dict[str, Any]: + as_list: bool = True, +) -> dict[str, Any] | list[Any]: """List parsers. Args: @@ -256,11 +250,14 @@ def list_parsers( If None (default), auto-paginates and returns all parsers. page_token: A page token, received from a previous ListParsers call. filter: Optional filter expression + as_list: If True, return only the list of parsers. + If False, return dict with metadata and pagination tokens. + Defaults to True. 
When page_size is None, this is automatically + set to True for backward compatibility. Returns: - If page_size is None: List of all parsers. - If page_size is provided: List of parsers with next page token if - available. + If as_list is True: List of parsers. + If as_list is False: Dict with parsers list and pagination metadata. Raises: APIError: If the API request fails @@ -269,15 +266,17 @@ def list_parsers( if filter: extra_params["filter"] = filter + # For backward compatibility: if page_size is None, force as_list to True + effective_as_list = True if page_size is None else as_list + return chronicle_paginated_request( client, path=f"logTypes/{log_type}/parsers", items_key="parsers", - api_version=APIVersion.V1ALPHA, page_size=page_size, page_token=page_token, extra_params=extra_params if extra_params else None, - as_list=(page_size is None), + as_list=effective_as_list, ) @@ -404,7 +403,6 @@ def run_parser( client, method="POST", endpoint_path=f"logTypes/{log_type}:runParser", - api_version=APIVersion.V1ALPHA, json=body, error_message=f"Failed to evaluate parser for log type '{log_type}'", ) diff --git a/src/secops/chronicle/parser_extension.py b/src/secops/chronicle/parser_extension.py index 5cfe71aa..0a19cc65 100644 --- a/src/secops/chronicle/parser_extension.py +++ b/src/secops/chronicle/parser_extension.py @@ -20,7 +20,10 @@ from typing import Any from secops.chronicle.models import APIVersion -from secops.chronicle.utils.request_utils import chronicle_request +from secops.chronicle.utils.request_utils import ( + chronicle_request, + chronicle_paginated_request, +) @dataclass @@ -157,7 +160,6 @@ def create_parser_extension( client, method="POST", endpoint_path=f"logTypes/{log_type}/parserExtensions", - api_version=APIVersion.V1ALPHA, json=extension_config.to_dict(), error_message="Failed to create parser extension", ) @@ -183,7 +185,6 @@ def get_parser_extension( client, method="GET", endpoint_path=f"logTypes/{log_type}/parserExtensions/{extension_id}", 
- api_version=APIVersion.V1ALPHA, error_message="Failed to get parser extension", ) @@ -193,7 +194,8 @@ def list_parser_extensions( log_type: str, page_size: int | None = None, page_token: str | None = None, -) -> dict[str, Any]: + as_list: bool = False, +) -> dict[str, Any] | list[Any]: """List parser extensions. Args: @@ -201,26 +203,24 @@ def list_parser_extensions( log_type: The log type to list parser extensions for page_size: Maximum number of parser extensions to return page_token: Token for pagination + as_list: If True, return only the list of parser extensions. + If False, return dict with metadata and pagination tokens. Returns: - Dict containing list of parser extensions and next page token if any + If as_list is True: List of parser extensions. + If as_list is False: Dict with parserExtensions list and + pagination metadata. Raises: APIError: If the API request fails """ - params = {} - if page_size is not None: - params["pageSize"] = page_size - if page_token is not None: - params["pageToken"] = page_token - - return chronicle_request( + return chronicle_paginated_request( client, - method="GET", - endpoint_path=f"logTypes/{log_type}/parserExtensions", - api_version=APIVersion.V1ALPHA, - params=params if params else None, - error_message="Failed to list parser extensions", + path=f"logTypes/{log_type}/parserExtensions", + items_key="parserExtensions", + page_size=page_size, + page_token=page_token, + as_list=as_list, ) @@ -241,7 +241,6 @@ def activate_parser_extension(client, log_type: str, extension_id: str) -> None: endpoint_path=( f"logTypes/{log_type}/parserExtensions/{extension_id}:activate" ), - api_version=APIVersion.V1ALPHA, error_message="Failed to activate parser extension", ) @@ -261,6 +260,5 @@ def delete_parser_extension(client, log_type: str, extension_id: str) -> None: client, method="DELETE", endpoint_path=f"logTypes/{log_type}/parserExtensions/{extension_id}", - api_version=APIVersion.V1ALPHA, error_message="Failed to delete parser 
extension", ) diff --git a/src/secops/chronicle/reference_list.py b/src/secops/chronicle/reference_list.py index 5b06f60d..916f7398 100644 --- a/src/secops/chronicle/reference_list.py +++ b/src/secops/chronicle/reference_list.py @@ -160,7 +160,8 @@ def list_reference_lists( client: "Any", view: ReferenceListView = ReferenceListView.BASIC, api_version: APIVersion | None = APIVersion.V1, -) -> list[dict[str, Any]]: + as_list: bool = True, +) -> dict[str, Any] | list[Any]: """List reference lists. Args: @@ -168,9 +169,15 @@ def list_reference_lists( view: How much of each ReferenceList to view. Defaults to REFERENCE_LIST_VIEW_BASIC. api_version: Preferred API version to use. Defaults to V1 + as_list: If True, return only the list of reference lists. + If False, return dict with metadata and pagination tokens. + Defaults to True for backward compatibility. Returns: - List of reference lists, ordered in ascending alphabetical order by name + If as_list is True: List of reference lists, ordered in ascending + alphabetical order by name. + If as_list is False: Dict with referenceLists list and + pagination metadata. 
Raises: APIError: If the API request fails @@ -185,7 +192,7 @@ def list_reference_lists( items_key="referenceLists", api_version=api_version, extra_params=extra_params if extra_params else None, - as_list=True, + as_list=as_list, ) diff --git a/src/secops/chronicle/rule.py b/src/secops/chronicle/rule.py index eccd89ec..de909987 100644 --- a/src/secops/chronicle/rule.py +++ b/src/secops/chronicle/rule.py @@ -48,7 +48,7 @@ def create_rule( client, method="POST", endpoint_path="rules", - api_version=api_version, + api_version=api_version if api_version else None, json={"text": rule_text}, error_message="Failed to create rule", ) @@ -86,7 +86,8 @@ def list_rules( page_size: int | None = None, page_token: str | None = None, api_version: APIVersion | None = APIVersion.V1, -) -> dict[str, Any]: + as_list: bool = False, +) -> dict[str, Any] | list[Any]: """Gets a list of rules. Args: @@ -101,9 +102,12 @@ def list_rules( page_size: Maximum number of rules to return per page. page_token: Token for the next page of results, if available. api_version: (Optional) Preferred API version to use. + as_list: If True, return only the list of rules. + If False, return dict with metadata and pagination tokens. Returns: - Dictionary containing information about rules + If as_list is True: List of rules. + If as_list is False: Dict with rules list and pagination metadata. Raises: APIError: If the API request fails @@ -116,6 +120,7 @@ def list_rules( page_size=page_size, page_token=page_token, extra_params={"view": view} if view else {}, + as_list=as_list, ) @@ -254,7 +259,8 @@ def list_rule_deployments( page_token: str | None = None, filter_query: str | None = None, api_version: APIVersion | None = APIVersion.V1, -) -> dict[str, Any]: + as_list: bool = False, +) -> dict[str, Any] | list[Any]: """Lists rule deployments for the instance. Args: @@ -264,10 +270,14 @@ def list_rule_deployments( page_token: Token for the next page of results, if available. 
filter_query: Optional filter query to restrict results. Filters results based on expression matching specific fields. + api_version: (Optional) Preferred API version to use. + as_list: If True, return only the list of rule deployments. + If False, return dict with metadata and pagination tokens. Returns: - Dictionary containing rule deployment entries. If ``page_size`` is not - provided, returns an aggregated object with a ``deployments`` list. + If as_list is True: List of rule deployments. + If as_list is False: Dict with ruleDeployments list and + pagination metadata. Raises: APIError: If the API request fails. @@ -285,6 +295,7 @@ def list_rule_deployments( page_size=page_size, page_token=page_token, extra_params=extra_params if extra_params else None, + as_list=as_list, ) @@ -382,7 +393,6 @@ def run_rule_test( client, method="POST", endpoint_path="legacy:legacyRunTestRule", - api_version=APIVersion.V1ALPHA, json=body, timeout=timeout, error_message="Failed to test rule", diff --git a/src/secops/chronicle/rule_alert.py b/src/secops/chronicle/rule_alert.py index c94170ef..e049035f 100644 --- a/src/secops/chronicle/rule_alert.py +++ b/src/secops/chronicle/rule_alert.py @@ -45,7 +45,6 @@ def get_alert( client, method="GET", endpoint_path="legacy:legacyGetAlert", - api_version=APIVersion.V1ALPHA, params=params, error_message="Failed to get alert", ) @@ -190,7 +189,6 @@ def update_alert( client, method="POST", endpoint_path="legacy:legacyUpdateAlert", - api_version=APIVersion.V1ALPHA, json=payload, error_message="Failed to update alert", ) @@ -335,7 +333,6 @@ def search_rule_alerts( client, method="GET", endpoint_path="legacy:legacySearchRulesAlerts", - api_version=APIVersion.V1ALPHA, params=params, error_message="Failed to search rule alerts", ) diff --git a/src/secops/chronicle/rule_detection.py b/src/secops/chronicle/rule_detection.py index 5bc32a50..744ac2d1 100644 --- a/src/secops/chronicle/rule_detection.py +++ b/src/secops/chronicle/rule_detection.py @@ -18,7 
+18,10 @@ from typing import Any, Literal from secops.chronicle.models import APIVersion -from secops.chronicle.utils.request_utils import chronicle_request +from secops.chronicle.utils.request_utils import ( + chronicle_request, + chronicle_paginated_request, +) def list_detections( @@ -32,7 +35,8 @@ def list_detections( alert_state: str | None = None, page_size: int | None = None, page_token: str | None = None, -) -> dict[str, Any]: + as_list: bool = False, +) -> dict[str, Any] | list[Any]: """List detections for a rule. Args: @@ -54,9 +58,12 @@ def list_detections( - "ALERTING" page_size: If provided, maximum number of detections to return page_token: If provided, continuation token for pagination + as_list: If True, return only the list of detections. + If False, return dict with metadata and pagination tokens. Returns: - Dictionary containing detection information + If as_list is True: List of detections. + If as_list is False: Dict with detections list and pagination metadata. Raises: APIError: If the API request fails @@ -69,7 +76,7 @@ def list_detections( "DETECTION_TIME", ] - params = {"rule_id": rule_id} + extra_params = {"rule_id": rule_id} if alert_state: if alert_state not in valid_alert_states: @@ -77,7 +84,7 @@ def list_detections( f"alert_state must be one of {valid_alert_states}, " f"got {alert_state}" ) - params["alertState"] = alert_state + extra_params["alertState"] = alert_state if list_basis: if list_basis not in valid_list_basis: @@ -85,25 +92,21 @@ def list_detections( f"list_basis must be one of {valid_list_basis}, " f"got {list_basis}" ) - params["listBasis"] = list_basis + extra_params["listBasis"] = list_basis if start_time: - params["startTime"] = start_time.strftime("%Y-%m-%dT%H:%M:%S.%fZ") + extra_params["startTime"] = start_time.strftime("%Y-%m-%dT%H:%M:%S.%fZ") if end_time: - params["endTime"] = end_time.strftime("%Y-%m-%dT%H:%M:%S.%fZ") + extra_params["endTime"] = end_time.strftime("%Y-%m-%dT%H:%M:%S.%fZ") - if page_size: - 
params["pageSize"] = page_size - if page_token: - params["pageToken"] = page_token - - return chronicle_request( + return chronicle_paginated_request( client, - method="GET", - endpoint_path="legacy:legacySearchDetections", - api_version=APIVersion.V1ALPHA, - params=params, - error_message="Failed to list detections", + path="legacy:legacySearchDetections", + items_key="detections", + page_size=page_size, + page_token=page_token, + extra_params=extra_params, + as_list=as_list, ) @@ -130,7 +133,6 @@ def list_errors(client, rule_id: str) -> dict[str, Any]: client, method="GET", endpoint_path="ruleExecutionErrors", - api_version=APIVersion.V1ALPHA, params=params, error_message="Failed to list rule errors", ) diff --git a/src/secops/chronicle/rule_exclusion.py b/src/secops/chronicle/rule_exclusion.py index a8554ef7..9220d87e 100644 --- a/src/secops/chronicle/rule_exclusion.py +++ b/src/secops/chronicle/rule_exclusion.py @@ -20,7 +20,6 @@ from datetime import datetime from typing import Annotated, Any -from secops.chronicle.models import APIVersion from secops.chronicle.utils.request_utils import ( chronicle_paginated_request, chronicle_request, @@ -83,17 +82,24 @@ def to_dict(self) -> dict[str, Any]: def list_rule_exclusions( - client, page_size: int | None = None, page_token: str | None = None -) -> dict[str, Any]: + client, + page_size: int | None = None, + page_token: str | None = None, + as_list: bool = False, +) -> dict[str, Any] | list[Any]: """List rule exclusions. Args: client: ChronicleClient instance page_size: Maximum number of rule exclusions to return per page page_token: Page token for pagination + as_list: If True, return only the list of rule exclusions. + If False, return dict with metadata and pagination tokens. Returns: - Dictionary containing the list of rule exclusions + If as_list is True: List of rule exclusions. + If as_list is False: Dict with findingsRefinements list and + pagination metadata. 
Raises: APIError: If the API request fails @@ -102,9 +108,9 @@ def list_rule_exclusions( client, path="findingsRefinements", items_key="findingsRefinements", - api_version=APIVersion.V1ALPHA, page_size=page_size, page_token=page_token, + as_list=as_list, ) @@ -130,7 +136,6 @@ def get_rule_exclusion(client, exclusion_id: str) -> dict[str, Any]: client, method="GET", endpoint_path=endpoint_path, - api_version=APIVersion.V1ALPHA, error_message="Failed to get rule exclusion", ) @@ -165,7 +170,6 @@ def create_rule_exclusion( client, method="POST", endpoint_path="findingsRefinements", - api_version=APIVersion.V1ALPHA, json=body, error_message="Failed to create rule exclusion", ) @@ -219,7 +223,6 @@ def patch_rule_exclusion( client, method="PATCH", endpoint_path=endpoint_path, - api_version=APIVersion.V1ALPHA, params=params, json=body, error_message="Failed to update rule exclusion", @@ -276,7 +279,6 @@ def compute_rule_exclusion_activity( client, method="POST", endpoint_path=endpoint_path, - api_version=APIVersion.V1ALPHA, json=body, error_message="Failed to compute rule exclusion activity", ) @@ -304,7 +306,6 @@ def get_rule_exclusion_deployment(client, exclusion_id: str) -> dict[str, Any]: client, method="GET", endpoint_path=endpoint_path, - api_version=APIVersion.V1ALPHA, error_message="Failed to get rule exclusion deployment", ) @@ -349,7 +350,6 @@ def update_rule_exclusion_deployment( client, method="PATCH", endpoint_path=endpoint_path, - api_version=APIVersion.V1ALPHA, params=params, json=deployment_details.to_dict(), error_message="Failed to update rule exclusion deployment", diff --git a/src/secops/chronicle/rule_set.py b/src/secops/chronicle/rule_set.py index c7885bf0..108d7484 100644 --- a/src/secops/chronicle/rule_set.py +++ b/src/secops/chronicle/rule_set.py @@ -566,7 +566,8 @@ def search_curated_detections( page_token: str | None = None, max_resp_size_bytes: int | None = None, include_nested_detections: bool | None = False, -) -> dict[str, Any]: + as_list: 
bool = False, +) -> dict[str, Any] | list[Any]: """Search for detections generated by a specific curated rule. Args: @@ -593,9 +594,12 @@ def search_curated_detections( If set to 0 or omitted, no limit is enforced. include_nested_detections: If True, include one level of nested detections in the response. Default is False. + as_list: If True, return only the list of detections. + If False, return dict with metadata and pagination tokens. Returns: - Dictionary containing: + If as_list is True: List of curated detections. + If as_list is False: Dictionary containing: - curatedDetections: List of detections (if include_nested_detections is False) - nestedDetectionSamples: List of detections with nested @@ -657,4 +661,5 @@ def search_curated_detections( page_size=page_size, page_token=page_token, extra_params=extra_params, + as_list=as_list, ) From 28837c64e01f5f0ab009040b73139618592c8d1e Mon Sep 17 00:00:00 2001 From: Mihir Vala <179564180+mihirvala-crestdata@users.noreply.github.com> Date: Mon, 16 Mar 2026 16:30:44 +0530 Subject: [PATCH 3/3] chore: fixed unit tests for new request utils --- src/secops/chronicle/dashboard_query.py | 2 +- src/secops/chronicle/data_table.py | 16 +- src/secops/chronicle/feeds.py | 2 +- src/secops/chronicle/gemini.py | 4 +- src/secops/chronicle/investigations.py | 4 +- src/secops/chronicle/ioc.py | 1 - src/secops/chronicle/log_ingest.py | 28 +- .../chronicle/log_processing_pipelines.py | 41 +- src/secops/chronicle/log_types.py | 1 - src/secops/chronicle/nl_search.py | 1 - src/secops/chronicle/parser.py | 7 - src/secops/chronicle/parser_extension.py | 1 - src/secops/chronicle/rule_alert.py | 1 - src/secops/chronicle/rule_detection.py | 1 - src/secops/chronicle/rule_exclusion.py | 37 +- src/secops/chronicle/utils/request_utils.py | 11 +- tests/chronicle/test_client.py | 6 +- tests/chronicle/test_dashboard_query.py | 85 ++- tests/chronicle/test_data_export.py | 41 +- tests/chronicle/test_data_tables.py | 564 +++++++++++------- 
tests/chronicle/test_feed.py | 167 ++++-- tests/chronicle/test_gemini.py | 407 +++++++------ tests/chronicle/test_investigations.py | 7 +- tests/chronicle/test_log_ingest.py | 334 ++++++----- .../chronicle/test_log_processing_pipeline.py | 196 ++++-- tests/chronicle/test_log_types.py | 68 +-- tests/chronicle/test_nl_search.py | 49 +- tests/chronicle/test_parser.py | 238 +++++--- tests/chronicle/test_parser_extension.py | 58 +- tests/chronicle/test_rule.py | 170 ++++-- tests/chronicle/test_rule_deployment.py | 61 +- tests/chronicle/test_rule_detection.py | 45 +- tests/chronicle/test_rule_exclusion.py | 173 ++++-- 33 files changed, 1703 insertions(+), 1124 deletions(-) diff --git a/src/secops/chronicle/dashboard_query.py b/src/secops/chronicle/dashboard_query.py index ae486b32..e7c8ce2a 100644 --- a/src/secops/chronicle/dashboard_query.py +++ b/src/secops/chronicle/dashboard_query.py @@ -20,7 +20,7 @@ import json from typing import Any -from secops.chronicle.models import APIVersion, InputInterval +from secops.chronicle.models import InputInterval from secops.chronicle.utils.request_utils import chronicle_request from secops.exceptions import APIError diff --git a/src/secops/chronicle/data_table.py b/src/secops/chronicle/data_table.py index a376efcd..46cacc90 100644 --- a/src/secops/chronicle/data_table.py +++ b/src/secops/chronicle/data_table.py @@ -6,7 +6,6 @@ from itertools import islice from typing import Any -from secops.chronicle.models import APIVersion from secops.chronicle.utils.request_utils import ( chronicle_paginated_request, chronicle_request, @@ -320,7 +319,8 @@ def _delete_data_table_row( endpoint_path=f"dataTables/{table_id}/dataTableRows/{row_guid}", expected_status={200, 204}, error_message=( - f"Failed to delete data table row '{row_guid}' from '{table_id}'" + f"Failed to delete data table row '{row_guid}' " + f"from '{table_id}'" ), ) except APIError as ar: @@ -385,7 +385,7 @@ def list_data_tables( path="dataTables", items_key="dataTables", 
extra_params=extra_params if extra_params else None, - as_list=True, + as_list=as_list, ) @@ -521,11 +521,6 @@ def replace_data_table_rows( SecOpsError: If a row is too large to process """ - url = ( - f"{client.base_url}/{client.instance_id}/dataTables/{name}" - "/dataTableRows:bulkReplace" - ) - # Check for empty input if not rows: return [] @@ -687,11 +682,6 @@ def _update_data_table_rows( APIError: If the API request fails SecOpsError: If validation fails """ - url = ( - f"{client.base_url}/{client.instance_id}/dataTables/{name}" - "/dataTableRows:bulkUpdate" - ) - # Build request payload requests = [] for row_update in row_updates: diff --git a/src/secops/chronicle/feeds.py b/src/secops/chronicle/feeds.py index 064c2317..e8191635 100644 --- a/src/secops/chronicle/feeds.py +++ b/src/secops/chronicle/feeds.py @@ -279,7 +279,7 @@ def delete_feed( Raises: APIError: If the API request fails """ - return chronicle_request( + chronicle_request( client, method="DELETE", endpoint_path=f"feeds/{feed_id}", diff --git a/src/secops/chronicle/gemini.py b/src/secops/chronicle/gemini.py index d626e5c8..381dd51e 100644 --- a/src/secops/chronicle/gemini.py +++ b/src/secops/chronicle/gemini.py @@ -16,9 +16,9 @@ Provides access to Chronicle's Gemini conversational AI interface. 
""" +import re from typing import Any -from secops.chronicle.models import APIVersion from secops.chronicle.utils.request_utils import chronicle_request from secops.exceptions import APIError @@ -378,7 +378,7 @@ def opt_in_to_gemini(client) -> bool: return True except APIError as e: if "403" in str(e) or "401" in str(e): - print(f"Warning: Unable to opt in to Gemini due to permissions") + print("Warning: Unable to opt in to Gemini due to permissions") return False raise diff --git a/src/secops/chronicle/investigations.py b/src/secops/chronicle/investigations.py index b5a9a1aa..201b77e3 100644 --- a/src/secops/chronicle/investigations.py +++ b/src/secops/chronicle/investigations.py @@ -16,10 +16,10 @@ from typing import Any -from secops.chronicle.models import APIVersion +from secops.chronicle.models import APIVersion, DetectionType from secops.chronicle.utils.request_utils import ( - chronicle_request, chronicle_paginated_request, + chronicle_request, ) diff --git a/src/secops/chronicle/ioc.py b/src/secops/chronicle/ioc.py index 32d44cdf..1e975b90 100644 --- a/src/secops/chronicle/ioc.py +++ b/src/secops/chronicle/ioc.py @@ -17,7 +17,6 @@ from datetime import datetime from typing import Any -from secops.chronicle.models import APIVersion from secops.chronicle.utils.request_utils import chronicle_request from secops.exceptions import APIError diff --git a/src/secops/chronicle/log_ingest.py b/src/secops/chronicle/log_ingest.py index abefe3b9..805f595a 100644 --- a/src/secops/chronicle/log_ingest.py +++ b/src/secops/chronicle/log_ingest.py @@ -24,7 +24,6 @@ from typing import Any from secops.chronicle.log_types import is_valid_log_type -from secops.chronicle.models import APIVersion from secops.chronicle.utils.request_utils import ( chronicle_request, chronicle_paginated_request, @@ -348,7 +347,6 @@ def create_forwarder( Raises: APIError: If the API request fails """ - url = f"{client.base_url}/{client.instance_id}/forwarders" # Create request payload payload = { @@ 
-479,7 +477,6 @@ def update_forwarder( Raises: APIError: If the API returns an error response. """ - url = f"{client.base_url}/{client.instance_id}/forwarders/{forwarder_id}" auto_mask = [] # Update mask if not provided in argument payload = {} @@ -840,12 +837,6 @@ def ingest_log( else: forwarder_resource = forwarder_id - # Construct the import URL - url = ( - f"{client.base_url}/{client.instance_id}/logTypes" - f"/{log_type}/logs:import" - ) - if isinstance(log_message, str): initialize_multi_line_formats() # Split string into individual log entries based on log type @@ -992,11 +983,6 @@ def ingest_udm( if add_missing_ids and "id" not in event["metadata"]: event["metadata"]["id"] = str(uuid.uuid4()) - url = ( - f"{client.base_url(APIVersion.V1ALPHA)}/{client.instance_id}" - f"/events:import" - ) - # Format the request body body = { "inline_source": {"events": [{"udm": event} for event in events_copy]} @@ -1013,19 +999,7 @@ def ingest_udm( ) except APIError as e: error_message = f"Failed to ingest UDM events: {str(e)}" - raise APIError(error_message) - - response_data = {} - - # Parse response if it has content - if response.text.strip(): - try: - response_data = response.json() - except ValueError: - # If JSON parsing fails, provide the raw text in the return value - response_data = {"raw_response": response.text} - - return response_data + raise APIError(error_message) from e def import_entities( diff --git a/src/secops/chronicle/log_processing_pipelines.py b/src/secops/chronicle/log_processing_pipelines.py index a534dd57..aa9b15a9 100644 --- a/src/secops/chronicle/log_processing_pipelines.py +++ b/src/secops/chronicle/log_processing_pipelines.py @@ -16,7 +16,7 @@ from typing import Any -from secops.chronicle.models import APIVersion +from secops.chronicle.utils.format_utils import format_resource_id from secops.chronicle.utils.request_utils import ( chronicle_request, chronicle_paginated_request, @@ -80,10 +80,9 @@ def get_log_processing_pipeline( Raises: 
APIError: If the API request fails. """ - if not pipeline_id.startswith("projects/"): - endpoint_path = f"logProcessingPipelines/{pipeline_id}" - else: - endpoint_path = pipeline_id + + extracted_pipeline_id = format_resource_id(pipeline_id) + endpoint_path = f"logProcessingPipelines/{extracted_pipeline_id}" return chronicle_request( client, @@ -153,10 +152,8 @@ def update_log_processing_pipeline( Raises: APIError: If the API request fails. """ - if not pipeline_id.startswith("projects/"): - endpoint_path = f"logProcessingPipelines/{pipeline_id}" - else: - endpoint_path = pipeline_id + extracted_pipeline_id = format_resource_id(pipeline_id) + endpoint_path = f"logProcessingPipelines/{extracted_pipeline_id}" params: dict[str, Any] = {} if update_mask: @@ -189,10 +186,9 @@ def delete_log_processing_pipeline( Raises: APIError: If the API request fails. """ - if not pipeline_id.startswith("projects/"): - endpoint_path = f"logProcessingPipelines/{pipeline_id}" - else: - endpoint_path = pipeline_id + + extracted_pipeline_id = format_resource_id(pipeline_id) + endpoint_path = f"logProcessingPipelines/{extracted_pipeline_id}" params: dict[str, Any] = {} if etag: @@ -225,10 +221,10 @@ def associate_streams( Raises: APIError: If the API request fails. """ - if not pipeline_id.startswith("projects/"): - endpoint_path = f"logProcessingPipelines/{pipeline_id}:associateStreams" - else: - endpoint_path = f"{pipeline_id}:associateStreams" + extracted_pipeline_id = format_resource_id(pipeline_id) + endpoint_path = ( + f"logProcessingPipelines/{extracted_pipeline_id}:associateStreams" + ) return chronicle_request( client, @@ -257,12 +253,11 @@ def dissociate_streams( Raises: APIError: If the API request fails. 
""" - if not pipeline_id.startswith("projects/"): - endpoint_path = ( - f"logProcessingPipelines/{pipeline_id}:dissociateStreams" - ) - else: - endpoint_path = f"{pipeline_id}:dissociateStreams" + + extracted_pipeline_id = format_resource_id(pipeline_id) + endpoint_path = ( + f"logProcessingPipelines/{extracted_pipeline_id}:dissociateStreams" + ) return chronicle_request( client, diff --git a/src/secops/chronicle/log_types.py b/src/secops/chronicle/log_types.py index e3a64b1a..77829a98 100644 --- a/src/secops/chronicle/log_types.py +++ b/src/secops/chronicle/log_types.py @@ -23,7 +23,6 @@ import base64 from typing import TYPE_CHECKING, Any -from secops.chronicle.models import APIVersion from secops.chronicle.utils.request_utils import ( chronicle_paginated_request, chronicle_request, diff --git a/src/secops/chronicle/nl_search.py b/src/secops/chronicle/nl_search.py index b57e098d..66e70647 100644 --- a/src/secops/chronicle/nl_search.py +++ b/src/secops/chronicle/nl_search.py @@ -18,7 +18,6 @@ from datetime import datetime from typing import Any -from secops.chronicle.models import APIVersion from secops.chronicle.utils.request_utils import chronicle_request from secops.exceptions import APIError diff --git a/src/secops/chronicle/parser.py b/src/secops/chronicle/parser.py index 04de5411..67c46a7d 100644 --- a/src/secops/chronicle/parser.py +++ b/src/secops/chronicle/parser.py @@ -18,7 +18,6 @@ import json from typing import Any -from secops.chronicle.models import APIVersion from secops.chronicle.utils.request_utils import ( chronicle_paginated_request, chronicle_request, @@ -371,12 +370,6 @@ def run_parser( f"{type(parser_extension_code).__name__}" ) - # Build request - url = ( - f"{client.base_url}/{client.instance_id}" - f"/logTypes/{log_type}:runParser" - ) - parser = { "cbn": base64.b64encode(parser_code.encode("utf-8")).decode("utf-8") } diff --git a/src/secops/chronicle/parser_extension.py b/src/secops/chronicle/parser_extension.py index 0a19cc65..fdae52cc 
100644 --- a/src/secops/chronicle/parser_extension.py +++ b/src/secops/chronicle/parser_extension.py @@ -19,7 +19,6 @@ from dataclasses import dataclass, field from typing import Any -from secops.chronicle.models import APIVersion from secops.chronicle.utils.request_utils import ( chronicle_request, chronicle_paginated_request, diff --git a/src/secops/chronicle/rule_alert.py b/src/secops/chronicle/rule_alert.py index e049035f..b68c6ef5 100644 --- a/src/secops/chronicle/rule_alert.py +++ b/src/secops/chronicle/rule_alert.py @@ -17,7 +17,6 @@ from datetime import datetime from typing import Any, Literal -from secops.chronicle.models import APIVersion from secops.chronicle.utils.request_utils import chronicle_request diff --git a/src/secops/chronicle/rule_detection.py b/src/secops/chronicle/rule_detection.py index 744ac2d1..e6184a4a 100644 --- a/src/secops/chronicle/rule_detection.py +++ b/src/secops/chronicle/rule_detection.py @@ -17,7 +17,6 @@ from datetime import datetime from typing import Any, Literal -from secops.chronicle.models import APIVersion from secops.chronicle.utils.request_utils import ( chronicle_request, chronicle_paginated_request, diff --git a/src/secops/chronicle/rule_exclusion.py b/src/secops/chronicle/rule_exclusion.py index 9220d87e..19d369ff 100644 --- a/src/secops/chronicle/rule_exclusion.py +++ b/src/secops/chronicle/rule_exclusion.py @@ -20,6 +20,7 @@ from datetime import datetime from typing import Annotated, Any +from secops.chronicle.utils.format_utils import format_resource_id from secops.chronicle.utils.request_utils import ( chronicle_paginated_request, chronicle_request, @@ -127,10 +128,8 @@ def get_rule_exclusion(client, exclusion_id: str) -> dict[str, Any]: Raises: APIError: If the API request fails """ - if not exclusion_id.startswith("projects/"): - endpoint_path = f"findingsRefinements/{exclusion_id}" - else: - endpoint_path = exclusion_id + exclusion_id = format_resource_id(exclusion_id) + endpoint_path = 
f"findingsRefinements/{exclusion_id}" return chronicle_request( client, @@ -202,10 +201,8 @@ def patch_rule_exclusion( Raises: APIError: If the API request fails """ - if not exclusion_id.startswith("projects/"): - endpoint_path = f"findingsRefinements/{exclusion_id}" - else: - endpoint_path = exclusion_id + exclusion_id = format_resource_id(exclusion_id) + endpoint_path = f"findingsRefinements/{exclusion_id}" body = {} if display_name: @@ -249,13 +246,11 @@ def compute_rule_exclusion_activity( Raises: APIError: If the API request fails """ - if not exclusion_id.startswith("projects/"): - endpoint_path = ( - f"findingsRefinements/{exclusion_id}" - ":computeFindingsRefinementActivity" - ) - else: - endpoint_path = f"{exclusion_id}:computeFindingsRefinementActivity" + exclusion_id = format_resource_id(exclusion_id) + + endpoint_path = ( + f"findingsRefinements/{exclusion_id}:computeFindingsRefinementActivity" + ) body = {} if start_time or end_time: @@ -297,10 +292,8 @@ def get_rule_exclusion_deployment(client, exclusion_id: str) -> dict[str, Any]: Raises: APIError: If the API request fails """ - if not exclusion_id.startswith("projects/"): - endpoint_path = f"findingsRefinements/{exclusion_id}/deployment" - else: - endpoint_path = f"{exclusion_id}/deployment" + exclusion_id = format_resource_id(exclusion_id) + endpoint_path = f"findingsRefinements/{exclusion_id}/deployment" return chronicle_request( client, @@ -331,10 +324,8 @@ def update_rule_exclusion_deployment( Raises: APIError: If the API request fails """ - if not exclusion_id.startswith("projects/"): - endpoint_path = f"findingsRefinements/{exclusion_id}/deployment" - else: - endpoint_path = f"{exclusion_id}/deployment" + exclusion_id = format_resource_id(exclusion_id) + endpoint_path = f"findingsRefinements/{exclusion_id}/deployment" params = {} if update_mask: diff --git a/src/secops/chronicle/utils/request_utils.py b/src/secops/chronicle/utils/request_utils.py index 37a2f94f..217ef0a4 100644 --- 
a/src/secops/chronicle/utils/request_utils.py +++ b/src/secops/chronicle/utils/request_utils.py @@ -53,7 +53,7 @@ def chronicle_paginated_request( path: str, items_key: str, *, - api_version: str | None = None, + api_version: APIVersion | str | None = None, page_size: int | None = None, page_token: str | None = None, extra_params: dict[str, Any] | None = None, @@ -193,7 +193,7 @@ def chronicle_request( method: str, endpoint_path: str, *, - api_version: str | None = None, + api_version: APIVersion | str | None = None, params: dict[str, Any] | None = None, headers: dict[str, Any] | None = None, json: dict[str, Any] | None = None, @@ -242,6 +242,9 @@ def chronicle_request( else: url = f'{base}/{endpoint_path.lstrip("/")}' + # init request response + response = None + try: response = client.session.request( method=method, @@ -258,7 +261,9 @@ def chronicle_request( base_msg = error_message or "API request failed" raise APIError( f"{base_msg}: method={method}, url={url}, " - f"request_error={exc.__class__.__name__}, detail={exc}" + f"request_error={exc.__class__.__name__}, detail={exc}, " + f"status_code={exc.response.status_code if exc.response else None}, " + f"response_message={exc.response.text if exc.response else None}" ) from exc # Try to parse JSON even on error, so we can get more details diff --git a/tests/chronicle/test_client.py b/tests/chronicle/test_client.py index 64e1d7b7..acc5d824 100644 --- a/tests/chronicle/test_client.py +++ b/tests/chronicle/test_client.py @@ -154,7 +154,7 @@ def test_summarize_entity_ip(mock_summarize_by_id, mock_detect, chronicle_client mock_summarize_by_id.side_effect = [mock_details_response, mock_prevalence_response] with patch.object( - chronicle_client.session, "get", return_value=mock_query_response + chronicle_client.session, "request", return_value=mock_query_response ) as mock_session_get: result = chronicle_client.summarize_entity( value="8.8.8.8", @@ -167,7 +167,7 @@ def test_summarize_entity_ip(mock_summarize_by_id,
mock_detect, chronicle_client # Check the query call was made mock_session_get.assert_called_once() query_call_args = mock_session_get.call_args - assert "summarizeEntitiesFromQuery" in query_call_args[0][0] + assert "summarizeEntitiesFromQuery" in query_call_args[1]["url"] assert query_call_args[1]["params"]["query"] == 'ip = "8.8.8.8"' # Check the _summarize_entity_by_id calls @@ -248,7 +248,7 @@ def test_list_iocs(chronicle_client): ] } - with patch.object(chronicle_client.session, "get", return_value=mock_response): + with patch.object(chronicle_client.session, "request", return_value=mock_response): result = chronicle_client.list_iocs( start_time=datetime(2024, 1, 1, tzinfo=timezone.utc), end_time=datetime(2024, 1, 2, tzinfo=timezone.utc), diff --git a/tests/chronicle/test_dashboard_query.py b/tests/chronicle/test_dashboard_query.py index aec33a03..9e0cccd7 100644 --- a/tests/chronicle/test_dashboard_query.py +++ b/tests/chronicle/test_dashboard_query.py @@ -80,14 +80,13 @@ def test_execute_query_success( response_mock.json.return_value = { "results": [{"value": "test-result"}] } - chronicle_client.session.post.return_value = response_mock + chronicle_client.session.request.return_value = response_mock query = 'udm.metadata.event_type = "PROCESS_LAUNCH"' result = dashboard_query.execute_query( chronicle_client, query=query, interval=interval ) - chronicle_client.session.post.assert_called_once() url = ( f"{chronicle_client.base_url}/{chronicle_client.instance_id}/" "dashboardQueries:execute" @@ -98,7 +97,14 @@ def test_execute_query_success( "input": interval.to_dict(), } } - chronicle_client.session.post.assert_called_with(url, json=payload) + chronicle_client.session.request.assert_called_once_with( + method="POST", + url=url, + params=None, + json=payload, + headers=None, + timeout=None, + ) assert "results" in result assert result["results"][0]["value"] == "test-result" @@ -111,7 +117,7 @@ def test_execute_query_with_filters( ) -> None: """Test 
execute_query with filters parameter.""" response_mock.json.return_value = {"results": []} - chronicle_client.session.post.return_value = response_mock + chronicle_client.session.request.return_value = response_mock query = 'udm.metadata.event_type = "PROCESS_LAUNCH"' filters = [{"field": "hostname", "value": "test-host"}] @@ -119,7 +125,6 @@ def test_execute_query_with_filters( chronicle_client, query=query, interval=interval, filters=filters ) - chronicle_client.session.post.assert_called_once() url = ( f"{chronicle_client.base_url}/{chronicle_client.instance_id}/" "dashboardQueries:execute" @@ -128,7 +133,14 @@ def test_execute_query_with_filters( "query": {"query": query, "input": interval.to_dict()}, "filters": filters, } - chronicle_client.session.post.assert_called_with(url, json=payload) + chronicle_client.session.request.assert_called_once_with( + method="POST", + url=url, + params=None, + json=payload, + headers=None, + timeout=None, + ) assert "results" in result @@ -140,14 +152,13 @@ def test_execute_query_with_clear_cache( ) -> None: """Test execute_query with clear_cache parameter.""" response_mock.json.return_value = {"results": []} - chronicle_client.session.post.return_value = response_mock + chronicle_client.session.request.return_value = response_mock query = 'udm.metadata.event_type = "PROCESS_LAUNCH"' result = dashboard_query.execute_query( chronicle_client, query=query, interval=interval, clear_cache=True ) - chronicle_client.session.post.assert_called_once() url = ( f"{chronicle_client.base_url}/{chronicle_client.instance_id}/" "dashboardQueries:execute" @@ -156,7 +167,14 @@ def test_execute_query_with_clear_cache( "query": {"query": query, "input": interval.to_dict()}, "clearCache": True, } - chronicle_client.session.post.assert_called_with(url, json=payload) + chronicle_client.session.request.assert_called_once_with( + method="POST", + url=url, + params=None, + json=payload, + headers=None, + timeout=None, + ) assert "results" in result @@ 
-165,7 +183,7 @@ def test_execute_query_with_string_json( ) -> None: """Test execute_query with string JSON interval.""" response_mock.json.return_value = {"results": []} - chronicle_client.session.post.return_value = response_mock + chronicle_client.session.request.return_value = response_mock query = 'udm.metadata.event_type = "PROCESS_LAUNCH"' interval_str = ( '{"relativeTime": {"timeUnit": "DAY", "startTimeVal": "1"}}' @@ -175,13 +193,17 @@ def test_execute_query_with_string_json( chronicle_client, query=query, interval=interval_str ) - chronicle_client.session.post.assert_called_once() - url = ( - f"{chronicle_client.base_url}/{chronicle_client.instance_id}/" - "dashboardQueries:execute" + chronicle_client.session.request.assert_called_once_with( + method="POST", + url=( + f"{chronicle_client.base_url}/{chronicle_client.instance_id}/" + "dashboardQueries:execute" + ), + params=None, + json={"query": {"query": query, "input": json.loads(interval_str)}}, + headers=None, + timeout=None, ) - payload = {"query": {"query": query, "input": json.loads(interval_str)}} - chronicle_client.session.post.assert_called_with(url, json=payload) assert "results" in result @@ -194,7 +216,7 @@ def test_execute_query_error( """Test execute_query function with error response.""" response_mock.status_code = 400 response_mock.text = "Invalid Query" - chronicle_client.session.post.return_value = response_mock + chronicle_client.session.request.return_value = response_mock query = "invalid query syntax" with pytest.raises(APIError, match="Failed to execute query"): @@ -216,19 +238,25 @@ def test_get_execute_query_success( "displayName": "Test Query", "query": 'udm.metadata.event_type = "PROCESS_LAUNCH"', } - chronicle_client.session.get.return_value = response_mock + chronicle_client.session.request.return_value = response_mock query_id = "test-query" # Call function result = dashboard_query.get_execute_query(chronicle_client, query_id) # Verify API call - 
chronicle_client.session.get.assert_called_once() url = ( f"{chronicle_client.base_url}/{chronicle_client.instance_id}/" f"dashboardQueries/{query_id}" ) - chronicle_client.session.get.assert_called_with(url) + chronicle_client.session.request.assert_called_once_with( + method="GET", + url=url, + params=None, + json=None, + headers=None, + timeout=None, + ) # Verify result assert result["name"].endswith("/test-query") @@ -244,7 +272,7 @@ def test_get_execute_query_with_full_id( "displayName": "Test Query", "query": 'udm.metadata.event_type = "PROCESS_LAUNCH"', } - chronicle_client.session.get.return_value = response_mock + chronicle_client.session.request.return_value = response_mock # Full project path query ID query_id = "projects/test-project/locations/test-location/dashboardQueries/test-query" @@ -254,12 +282,18 @@ def test_get_execute_query_with_full_id( result = dashboard_query.get_execute_query(chronicle_client, query_id) # Verify API call uses the extracted ID - chronicle_client.session.get.assert_called_once() url = ( f"{chronicle_client.base_url}/{chronicle_client.instance_id}/" f"dashboardQueries/{expected_id}" ) - chronicle_client.session.get.assert_called_with(url) + chronicle_client.session.request.assert_called_once_with( + method="GET", + url=url, + params=None, + json=None, + headers=None, + timeout=None, + ) # Verify result assert result["displayName"] == "Test Query" @@ -271,12 +305,9 @@ def test_get_execute_query_error( # Setup error response response_mock.status_code = 404 response_mock.text = "Query not found" - chronicle_client.session.get.return_value = response_mock + chronicle_client.session.request.return_value = response_mock query_id = "nonexistent-query" # Verify the function raises an APIError with pytest.raises(APIError, match="Failed to get query"): dashboard_query.get_execute_query(chronicle_client, query_id) - - # Verify API call - chronicle_client.session.get.assert_called_once() diff --git a/tests/chronicle/test_data_export.py 
b/tests/chronicle/test_data_export.py index a06d91f9..0d3ad7a7 100644 --- a/tests/chronicle/test_data_export.py +++ b/tests/chronicle/test_data_export.py @@ -50,7 +50,7 @@ def test_get_data_export(chronicle_client): "data_export_status": {"stage": "FINISHED_SUCCESS", "progress_percentage": 100}, } - with patch.object(chronicle_client.session, "get", return_value=mock_response): + with patch.object(chronicle_client.session, "request", return_value=mock_response): result = chronicle_client.get_data_export("export123") assert result["name"].endswith("/dataExports/export123") @@ -64,7 +64,7 @@ def test_get_data_export_error(chronicle_client): mock_response.status_code = 404 mock_response.text = "Data export not found" - with patch.object(chronicle_client.session, "get", return_value=mock_response): + with patch.object(chronicle_client.session, "request", return_value=mock_response): with pytest.raises(APIError, match="Failed to get data export"): chronicle_client.get_data_export("nonexistent-export") @@ -84,7 +84,7 @@ def test_create_data_export_with_log_type(chronicle_client): "dataExportStatus": {"stage": "IN_QUEUE"}, } - with patch.object(chronicle_client.session, "post", return_value=mock_response) as mock_post: + with patch.object(chronicle_client.session, "request", return_value=mock_response) as mock_post: start_time = datetime(2024, 1, 1, tzinfo=timezone.utc) end_time = datetime(2024, 1, 2, tzinfo=timezone.utc) @@ -124,7 +124,7 @@ def test_create_data_export_with_log_types(chronicle_client): "dataExportStatus": {"stage": "IN_QUEUE"}, } - with patch.object(chronicle_client.session, "post", return_value=mock_response) as mock_post: + with patch.object(chronicle_client.session, "request", return_value=mock_response) as mock_post: start_time = datetime(2024, 1, 1, tzinfo=timezone.utc) end_time = datetime(2024, 1, 2, tzinfo=timezone.utc) @@ -222,7 +222,7 @@ def test_create_data_export_with_all_logs(chronicle_client): } with patch.object( - chronicle_client.session, 
"post", return_value=mock_response + chronicle_client.session, "request", return_value=mock_response ) as mock_post: start_time = datetime(2024, 1, 1, tzinfo=timezone.utc) end_time = datetime(2024, 1, 2, tzinfo=timezone.utc) @@ -252,7 +252,7 @@ def test_cancel_data_export(chronicle_client): } with patch.object( - chronicle_client.session, "post", return_value=mock_response + chronicle_client.session, "request", return_value=mock_response ) as mock_post: result = chronicle_client.cancel_data_export("export123") @@ -260,8 +260,7 @@ def test_cancel_data_export(chronicle_client): # Check that the request was sent to the correct URL mock_post.assert_called_once() - args, kwargs = mock_post.call_args - assert args[0].endswith("/dataExports/export123:cancel") + assert mock_post.call_args[1]["url"].endswith("/dataExports/export123:cancel") def test_cancel_data_export_error(chronicle_client): @@ -270,7 +269,7 @@ def test_cancel_data_export_error(chronicle_client): mock_response.status_code = 404 mock_response.text = "Data export not found" - with patch.object(chronicle_client.session, "post", return_value=mock_response): + with patch.object(chronicle_client.session, "request", return_value=mock_response): with pytest.raises(APIError, match="Failed to cancel data export"): chronicle_client.cancel_data_export("nonexistent-export") @@ -298,7 +297,7 @@ def test_fetch_available_log_types(chronicle_client): } with patch.object( - chronicle_client.session, "post", return_value=mock_response + chronicle_client.session, "request", return_value=mock_response ) as mock_post: start_time = datetime(2024, 1, 1, tzinfo=timezone.utc) end_time = datetime(2024, 1, 2, tzinfo=timezone.utc) @@ -338,7 +337,7 @@ def test_fetch_available_log_types_error(chronicle_client): mock_response.status_code = 400 mock_response.text = "Invalid time range" - with patch.object(chronicle_client.session, "post", return_value=mock_response): + with patch.object(chronicle_client.session, "request", 
return_value=mock_response): start_time = datetime(2024, 1, 1, tzinfo=timezone.utc) end_time = datetime(2024, 1, 2, tzinfo=timezone.utc) @@ -367,7 +366,7 @@ def test_update_data_export_success(chronicle_client): # Act with patch.object( - chronicle_client.session, "patch", return_value=mock_response + chronicle_client.session, "request", return_value=mock_response ) as mock_patch: start_time = datetime(2024, 1, 2, tzinfo=timezone.utc) end_time = datetime(2024, 1, 3, tzinfo=timezone.utc) @@ -415,7 +414,7 @@ def test_update_data_export_partial_update(chronicle_client): # Act with patch.object( - chronicle_client.session, "patch", return_value=mock_response + chronicle_client.session, "request", return_value=mock_response ) as mock_patch: result = update_data_export( client=chronicle_client, @@ -467,7 +466,7 @@ def test_update_data_export_api_error(chronicle_client): mock_response.text = "Invalid data export ID" # Act - with patch.object(chronicle_client.session, "patch", return_value=mock_response): + with patch.object(chronicle_client.session, "request", return_value=mock_response): # Assert with pytest.raises(APIError, match="Failed to update data export"): update_data_export( @@ -498,7 +497,7 @@ def test_list_data_export_success(chronicle_client): # Act with patch.object( - chronicle_client.session, "get", return_value=mock_response + chronicle_client.session, "request", return_value=mock_response ) as mock_get: result = list_data_export( client=chronicle_client, @@ -535,7 +534,7 @@ def test_list_data_export_default_params(chronicle_client): # Act with patch.object( - chronicle_client.session, "get", return_value=mock_response + chronicle_client.session, "request", return_value=mock_response ) as mock_get: result = list_data_export(client=chronicle_client) @@ -545,9 +544,9 @@ def test_list_data_export_default_params(chronicle_client): # Check default parameters mock_get.assert_called_once() _, kwargs = mock_get.call_args - assert kwargs["params"]["pageSize"] is 
None - assert kwargs["params"]["pageToken"] is None - assert kwargs["params"]["filter"] is None + assert kwargs["params"]["pageSize"] == 1000 # 1000 page size for auto pagination + assert "pageToken" not in kwargs["params"] + assert "filter" not in kwargs["params"] def test_list_data_export_error(chronicle_client): @@ -558,9 +557,9 @@ def test_list_data_export_error(chronicle_client): mock_response.text = "Invalid filter" # Act - with patch.object(chronicle_client.session, "get", return_value=mock_response): + with patch.object(chronicle_client.session, "request", return_value=mock_response): # Assert - with pytest.raises(APIError, match="Failed to get data export"): + with pytest.raises(APIError, match="API request failed:"): list_data_export( client=chronicle_client, filters="invalid-filter" diff --git a/tests/chronicle/test_data_tables.py b/tests/chronicle/test_data_tables.py index 12770654..e050f4db 100644 --- a/tests/chronicle/test_data_tables.py +++ b/tests/chronicle/test_data_tables.py @@ -8,7 +8,9 @@ ) # Added call for checking multiple calls if needed from secops.chronicle.models import APIVersion -from secops.chronicle.client import ChronicleClient # This will be the actual client +from secops.chronicle.client import ( + ChronicleClient, +) # This will be the actual client # We'll need to import the enums and functions once they are in their final place # For now, let's assume they might be in a module like secops.chronicle.data_table @@ -61,25 +63,34 @@ def test_create_data_table_success( "columnInfo": [{"originalColumn": "col1", "columnType": "STRING"}], "dataTableUuid": "some-uuid", } - mock_chronicle_client.session.post.return_value = mock_response + mock_chronicle_client.session.request.return_value = mock_response dt_name = "test_dt_123" description = "Test Description" header = {"col1": DataTableColumnType.STRING} - result = create_data_table(mock_chronicle_client, dt_name, description, header) + result = create_data_table( + mock_chronicle_client, 
dt_name, description, header + ) assert result["name"] == expected_dt_name assert result["description"] == description - mock_chronicle_client.session.post.assert_called_once_with( - f"{mock_chronicle_client.base_url}/{mock_chronicle_client.instance_id}/dataTables", + mock_chronicle_client.session.request.assert_called_once_with( + method="POST", + url=f"{mock_chronicle_client.base_url}/{mock_chronicle_client.instance_id}/dataTables", params={"dataTableId": dt_name}, json={ "description": description, "columnInfo": [ - {"columnIndex": 0, "originalColumn": "col1", "columnType": "STRING"} + { + "columnIndex": 0, + "originalColumn": "col1", + "columnType": "STRING", + } ], }, + headers=None, + timeout=None, ) @patch("secops.chronicle.data_table.create_data_table_rows") @@ -101,7 +112,7 @@ def test_create_data_table_with_rows_success( "description": "Test With Rows", # ... other fields } - mock_chronicle_client.session.post.return_value = mock_dt_response + mock_chronicle_client.session.request.return_value = mock_dt_response mock_create_rows.return_value = [ {"dataTableRows": [{"name": "row1_full_name"}]} @@ -147,34 +158,45 @@ def test_create_data_table_with_entity_mapping( mock_response = Mock() mock_response.status_code = 200 expected_dt_name = "projects/test-project/locations/us/instances/test-customer/dataTables/test_dt_123" - entity_mapping = "entity.domain.name" # Sample valid entity mapping + entity_mapping = "entity.domain.name" # Sample valid entity mapping mock_response.json.return_value = { "name": expected_dt_name, "displayName": "test_dt_123", "description": "Test Description", "createTime": "2025-06-17T10:00:00Z", - "columnInfo": [{"originalColumn": "col1", "mappedColumnPath": entity_mapping}], + "columnInfo": [ + {"originalColumn": "col1", "mappedColumnPath": entity_mapping} + ], "dataTableUuid": "some-uuid", } - mock_chronicle_client.session.post.return_value = mock_response + mock_chronicle_client.session.request.return_value = mock_response dt_name = 
"test_dt_123" description = "Test Description" header = {"col1": entity_mapping} - result = create_data_table(mock_chronicle_client, dt_name, description, header) + result = create_data_table( + mock_chronicle_client, dt_name, description, header + ) assert result["name"] == expected_dt_name assert result["description"] == description - mock_chronicle_client.session.post.assert_called_once_with( - f"{mock_chronicle_client.base_url}/{mock_chronicle_client.instance_id}/dataTables", + mock_chronicle_client.session.request.assert_called_once_with( + method="POST", + url=f"{mock_chronicle_client.base_url}/{mock_chronicle_client.instance_id}/dataTables", params={"dataTableId": dt_name}, json={ "description": description, "columnInfo": [ - {"columnIndex": 0, "originalColumn": "col1", "mappedColumnPath": entity_mapping} + { + "columnIndex": 0, + "originalColumn": "col1", + "mappedColumnPath": entity_mapping, + } ], }, + headers=None, + timeout=None, ) @patch("secops.chronicle.data_table.REF_LIST_DATA_TABLE_ID_REGEX") @@ -192,39 +214,65 @@ def test_create_data_table_with_column_options( "description": "Test Description", "createTime": "2025-06-17T10:00:00Z", "columnInfo": [ - {"originalColumn": "key", "columnType": "NUMBER", "keyColumn": True}, - {"originalColumn": "repetitive", "columnType": "STRING", "repeatedValues": True} + { + "originalColumn": "key", + "columnType": "NUMBER", + "keyColumn": True, + }, + { + "originalColumn": "repetitive", + "columnType": "STRING", + "repeatedValues": True, + }, ], "dataTableUuid": "some-uuid", } - mock_chronicle_client.session.post.return_value = mock_response + mock_chronicle_client.session.request.return_value = mock_response dt_name = "test_dt_123" description = "Test Description" header = { "key": DataTableColumnType.NUMBER, - "repetitive": DataTableColumnType.STRING + "repetitive": DataTableColumnType.STRING, } column_options = { "key": {"keyColumn": True}, - "repetitive": {"repeatedValues": True} + "repetitive": 
{"repeatedValues": True}, } - result = create_data_table(mock_chronicle_client, dt_name, description, header, - column_options=column_options) + result = create_data_table( + mock_chronicle_client, + dt_name, + description, + header, + column_options=column_options, + ) assert result["name"] == expected_dt_name assert result["description"] == description - mock_chronicle_client.session.post.assert_called_once_with( - f"{mock_chronicle_client.base_url}/{mock_chronicle_client.instance_id}/dataTables", + mock_chronicle_client.session.request.assert_called_once_with( + method="POST", + url=f"{mock_chronicle_client.base_url}/{mock_chronicle_client.instance_id}/dataTables", params={"dataTableId": dt_name}, json={ "description": description, "columnInfo": [ - {"columnIndex": 0, "originalColumn": "key", "columnType": "NUMBER", "keyColumn": True}, - {"columnIndex": 1, "originalColumn": "repetitive", "columnType": "STRING", "repeatedValues": True} + { + "columnIndex": 0, + "originalColumn": "key", + "columnType": "NUMBER", + "keyColumn": True, + }, + { + "columnIndex": 1, + "originalColumn": "repetitive", + "columnType": "STRING", + "repeatedValues": True, + }, ], }, + headers=None, + timeout=None, ) def test_get_data_table_success(self, mock_chronicle_client: Mock) -> None: @@ -238,15 +286,22 @@ def test_get_data_table_success(self, mock_chronicle_client: Mock) -> None: # ... 
other fields based on logs } mock_response.json.return_value = expected_response - mock_chronicle_client.session.get.return_value = mock_response + mock_chronicle_client.session.request.return_value = mock_response result = get_data_table(mock_chronicle_client, dt_name) assert result == expected_response - mock_chronicle_client.session.get.assert_called_once_with( - f"{mock_chronicle_client.base_url}/{mock_chronicle_client.instance_id}/dataTables/{dt_name}" + mock_chronicle_client.session.request.assert_called_once_with( + method="GET", + url=f"{mock_chronicle_client.base_url}/{mock_chronicle_client.instance_id}/dataTables/{dt_name}", + params=None, + json=None, + headers=None, + timeout=None, ) - def test_list_data_tables_success(self, mock_chronicle_client: Mock) -> None: + def test_list_data_tables_success( + self, mock_chronicle_client: Mock + ) -> None: """Test successful listing of data tables without pagination.""" mock_response = Mock() mock_response.status_code = 200 @@ -257,15 +312,21 @@ def test_list_data_tables_success(self, mock_chronicle_client: Mock) -> None: ] # No nextPageToken means single page } - mock_chronicle_client.session.get.return_value = mock_response + mock_chronicle_client.session.request.return_value = mock_response - result = list_data_tables(mock_chronicle_client, order_by="createTime asc") + result = list_data_tables( + mock_chronicle_client, order_by="createTime asc" + ) assert len(result) == 2 assert result[0]["displayName"] == "DT One" - mock_chronicle_client.session.get.assert_called_once_with( - f"{mock_chronicle_client.base_url}/{mock_chronicle_client.instance_id}/dataTables", + mock_chronicle_client.session.request.assert_called_once_with( + method="GET", + url=f"{mock_chronicle_client.base_url}/{mock_chronicle_client.instance_id}/dataTables", params={"pageSize": 1000, "orderBy": "createTime asc"}, + json=None, + headers=None, + timeout=None, ) def test_list_data_tables_api_error_invalid_orderby( @@ -274,31 +335,38 @@ def 
test_list_data_tables_api_error_invalid_orderby( """Test list_data_tables when API returns error for invalid orderBy.""" mock_response = Mock() mock_response.status_code = 400 - mock_response.text = ( - "invalid order by field: ordering is only supported by create time asc" - ) + mock_response.text = "invalid order by field: ordering is only supported by create time asc" # No .json() method will be called if status is not 200 in the actual code - mock_chronicle_client.session.get.return_value = mock_response + mock_chronicle_client.session.request.return_value = mock_response with pytest.raises( - APIError, match="Failed to list data tables: 400 invalid order by field" + APIError, + match="API request failed", ): list_data_tables(mock_chronicle_client, order_by="createTime desc") - def test_delete_data_table_success(self, mock_chronicle_client: Mock) -> None: + def test_delete_data_table_success( + self, mock_chronicle_client: Mock + ) -> None: """Test successful deletion of a data table.""" mock_response = Mock() - mock_response.status_code = 200 # API might return 200 with empty body or LRO + mock_response.status_code = ( + 200 # API might return 200 with empty body or LRO + ) mock_response.json.return_value = {} # Based on your logs - mock_chronicle_client.session.delete.return_value = mock_response + mock_chronicle_client.session.request.return_value = mock_response dt_name = "dt_to_delete" result = delete_data_table(mock_chronicle_client, dt_name, force=True) assert result == {} - mock_chronicle_client.session.delete.assert_called_once_with( - f"{mock_chronicle_client.base_url}/{mock_chronicle_client.instance_id}/dataTables/{dt_name}", + mock_chronicle_client.session.request.assert_called_once_with( + method="DELETE", + url=f"{mock_chronicle_client.base_url}/{mock_chronicle_client.instance_id}/dataTables/{dt_name}", params={"force": "true"}, + json=None, + headers=None, + timeout=None, ) @patch("secops.chronicle.data_table._create_data_table_rows") @@ -316,7 
+384,9 @@ def test_create_data_table_rows_chunking( } dt_name = "dt_for_chunking" - responses = create_data_table_rows(mock_chronicle_client, dt_name, rows_data) + responses = create_data_table_rows( + mock_chronicle_client, dt_name, rows_data + ) # Expect two calls: one for 1000 rows, one for 500 rows assert mock_internal_create_rows.call_count == 2 @@ -331,7 +401,9 @@ def test_create_data_table_rows_chunking( assert len(responses) == 2 - def test_list_data_table_rows_success(self, mock_chronicle_client: Mock) -> None: + def test_list_data_table_rows_success( + self, mock_chronicle_client: Mock + ) -> None: """Test successful listing of data table rows.""" mock_response = Mock() mock_response.status_code = 200 @@ -341,7 +413,7 @@ def test_list_data_table_rows_success(self, mock_chronicle_client: Mock) -> None {"name": "row2_full", "values": ["c", "d"]}, ] } - mock_chronicle_client.session.get.return_value = mock_response + mock_chronicle_client.session.request.return_value = mock_response dt_name = "my_table_with_rows" result = list_data_table_rows( @@ -350,9 +422,13 @@ def test_list_data_table_rows_success(self, mock_chronicle_client: Mock) -> None assert len(result) == 2 assert result[0]["values"] == ["a", "b"] - mock_chronicle_client.session.get.assert_called_once_with( - f"{mock_chronicle_client.base_url}/{mock_chronicle_client.instance_id}/dataTables/{dt_name}/dataTableRows", + mock_chronicle_client.session.request.assert_called_once_with( + method="GET", + url=f"{mock_chronicle_client.base_url}/{mock_chronicle_client.instance_id}/dataTables/{dt_name}/dataTableRows", params={"pageSize": 1000, "orderBy": "createTime asc"}, + json=None, + headers=None, + timeout=None, ) @patch("secops.chronicle.data_table._delete_data_table_row") @@ -414,7 +490,7 @@ def test_create_reference_list_success( "syntaxType": "REFERENCE_LIST_SYNTAX_TYPE_PLAIN_TEXT_STRING", } mock_response.json.return_value = expected_response_json - mock_chronicle_client.session.post.return_value = 
mock_response + mock_chronicle_client.session.request.return_value = mock_response result = create_reference_list( mock_chronicle_client, rl_name, description, entries, syntax_type @@ -423,14 +499,17 @@ def test_create_reference_list_success( assert result["displayName"] == rl_name assert result["description"] == description assert len(result["entries"]) == 2 - mock_chronicle_client.session.post.assert_called_once_with( - f"{mock_chronicle_client.base_url(APIVersion.V1)}/{mock_chronicle_client.instance_id}/referenceLists", + mock_chronicle_client.session.request.assert_called_once_with( + method="POST", + url=f"{mock_chronicle_client.base_url(APIVersion.V1)}/{mock_chronicle_client.instance_id}/referenceLists", params={"referenceListId": rl_name}, json={ "description": description, "entries": [{"value": "entryA"}, {"value": "entryB"}], "syntaxType": syntax_type.value, }, + headers=None, + timeout=None, ) @patch("secops.chronicle.reference_list.REF_LIST_DATA_TABLE_ID_REGEX") @@ -454,7 +533,7 @@ def test_create_reference_list_cidr_success( "syntaxType": "REFERENCE_LIST_SYNTAX_TYPE_CIDR", "entries": [{"value": "192.168.1.0/24"}], } - mock_chronicle_client.session.post.return_value = mock_response + mock_chronicle_client.session.request.return_value = mock_response create_reference_list( mock_chronicle_client, @@ -483,7 +562,7 @@ def test_get_reference_list_full_view_success( "scopeInfo": {"referenceListScope": {}}, } mock_response.json.return_value = expected_response_json - mock_chronicle_client.session.get.return_value = mock_response + mock_chronicle_client.session.request.return_value = mock_response result = get_reference_list( mock_chronicle_client, rl_name, view=ReferenceListView.FULL @@ -491,9 +570,13 @@ def test_get_reference_list_full_view_success( assert result["description"] == "Full RL details" assert len(result["entries"]) == 1 - mock_chronicle_client.session.get.assert_called_once_with( - 
f"{mock_chronicle_client.base_url(APIVersion.V1)}/{mock_chronicle_client.instance_id}/referenceLists/{rl_name}", + mock_chronicle_client.session.request.assert_called_once_with( + method="GET", + url=f"{mock_chronicle_client.base_url(APIVersion.V1)}/{mock_chronicle_client.instance_id}/referenceLists/{rl_name}", params={"view": ReferenceListView.FULL.value}, + json=None, + headers=None, + timeout=None, ) def test_list_reference_lists_basic_view_success( @@ -513,16 +596,22 @@ def test_list_reference_lists_basic_view_success( } ] } - mock_chronicle_client.session.get.return_value = mock_response + mock_chronicle_client.session.request.return_value = mock_response - results = list_reference_lists(mock_chronicle_client) # Defaults to BASIC + results = list_reference_lists( + mock_chronicle_client + ) # Defaults to BASIC assert len(results) == 1 assert results[0]["displayName"] == "rl_basic1" assert "entries" not in results[0] # Entries are not in BASIC view - mock_chronicle_client.session.get.assert_called_once_with( - f"{mock_chronicle_client.base_url(APIVersion.V1)}/{mock_chronicle_client.instance_id}/referenceLists", + mock_chronicle_client.session.request.assert_called_once_with( + method="GET", + url=f"{mock_chronicle_client.base_url(APIVersion.V1)}/{mock_chronicle_client.instance_id}/referenceLists", params={"pageSize": 1000, "view": ReferenceListView.BASIC.value}, + json=None, + headers=None, + timeout=None, ) @patch("secops.chronicle.reference_list.get_reference_list") @@ -553,7 +642,7 @@ def test_update_reference_list_success( # other fields like scopeInfo might be present } mock_response.json.return_value = expected_response_json - mock_chronicle_client.session.patch.return_value = mock_response + mock_chronicle_client.session.request.return_value = mock_response result = update_reference_list( mock_chronicle_client, @@ -566,13 +655,19 @@ def test_update_reference_list_success( assert len(result["entries"]) == 2 assert result["entries"][0]["value"] == 
"updated_entryX" - mock_chronicle_client.session.patch.assert_called_once_with( - f"{mock_chronicle_client.base_url(APIVersion.V1)}/{mock_chronicle_client.instance_id}/referenceLists/{rl_name}", + mock_chronicle_client.session.request.assert_called_once_with( + method="PATCH", + url=f"{mock_chronicle_client.base_url(APIVersion.V1)}/{mock_chronicle_client.instance_id}/referenceLists/{rl_name}", + params={'updateMask': 'description,entries'}, json={ "description": new_description, - "entries": [{"value": "updated_entryX"}, {"value": "new_entryY"}], + "entries": [ + {"value": "updated_entryX"}, + {"value": "new_entryY"}, + ], }, - params={"updateMask": "description,entries"}, + headers=None, + timeout=None, ) def test_update_reference_list_no_changes_error( @@ -584,7 +679,7 @@ def test_update_reference_list_no_changes_error( match=r"Either description or entries \(or both\) must be provided for update.", ): update_reference_list(mock_chronicle_client, "some_rl_name") - + @patch("secops.chronicle.data_table.REF_LIST_DATA_TABLE_ID_REGEX") def test_update_data_table_success_both_params( self, mock_regex_check: Mock, mock_chronicle_client: Mock @@ -593,12 +688,12 @@ def test_update_data_table_success_both_params( mock_regex_check.match.return_value = True # Assume name is valid mock_response = Mock() mock_response.status_code = 200 - + dt_name = "test_dt_update" expected_dt_name = f"projects/test-project/locations/us/instances/test-customer/dataTables/{dt_name}" new_description = "Updated description" new_row_ttl = "48h" - + mock_response.json.return_value = { "name": expected_dt_name, "description": new_description, @@ -607,27 +702,30 @@ def test_update_data_table_success_both_params( "columnInfo": [{"originalColumn": "col1", "columnType": "STRING"}], "dataTableUuid": "test-uuid", } - - mock_chronicle_client.session.patch.return_value = mock_response + + mock_chronicle_client.session.request.return_value = mock_response result = update_data_table( - 
mock_chronicle_client, - dt_name, - description=new_description, - row_time_to_live=new_row_ttl + mock_chronicle_client, + dt_name, + description=new_description, + row_time_to_live=new_row_ttl, ) assert result["name"] == expected_dt_name assert result["description"] == new_description assert result["rowTimeToLive"] == new_row_ttl - - mock_chronicle_client.session.patch.assert_called_once_with( - f"{mock_chronicle_client.base_url}/{mock_chronicle_client.instance_id}/dataTables/{dt_name}", - params={}, + + mock_chronicle_client.session.request.assert_called_once_with( + method="PATCH", + url=f"{mock_chronicle_client.base_url}/{mock_chronicle_client.instance_id}/dataTables/{dt_name}", + params=None, json={ "description": new_description, "row_time_to_live": new_row_ttl, }, + headers=None, + timeout=None, ) @patch("secops.chronicle.data_table.REF_LIST_DATA_TABLE_ID_REGEX") @@ -638,34 +736,35 @@ def test_update_data_table_description_only( mock_regex_check.match.return_value = True mock_response = Mock() mock_response.status_code = 200 - + dt_name = "test_dt_update" expected_dt_name = f"projects/test-project/locations/us/instances/test-customer/dataTables/{dt_name}" new_description = "Updated description only" - + mock_response.json.return_value = { "name": expected_dt_name, "description": new_description, "updateTime": "2025-08-25T10:05:00Z", "dataTableUuid": "test-uuid", } - - mock_chronicle_client.session.patch.return_value = mock_response + + mock_chronicle_client.session.request.return_value = mock_response result = update_data_table( - mock_chronicle_client, - dt_name, - description=new_description + mock_chronicle_client, dt_name, description=new_description ) assert result["name"] == expected_dt_name assert result["description"] == new_description assert "rowTimeToLive" not in result - - mock_chronicle_client.session.patch.assert_called_once_with( - f"{mock_chronicle_client.base_url}/{mock_chronicle_client.instance_id}/dataTables/{dt_name}", - params={}, + + 
mock_chronicle_client.session.request.assert_called_once_with( + method="PATCH", + url=f"{mock_chronicle_client.base_url}/{mock_chronicle_client.instance_id}/dataTables/{dt_name}", + params=None, json={"description": new_description}, + headers=None, + timeout=None, ) @patch("secops.chronicle.data_table.REF_LIST_DATA_TABLE_ID_REGEX") @@ -676,34 +775,35 @@ def test_update_data_table_row_ttl_only( mock_regex_check.match.return_value = True mock_response = Mock() mock_response.status_code = 200 - + dt_name = "test_dt_update" expected_dt_name = f"projects/test-project/locations/us/instances/test-customer/dataTables/{dt_name}" new_row_ttl = "72h" - + mock_response.json.return_value = { "name": expected_dt_name, "rowTimeToLive": new_row_ttl, "updateTime": "2025-08-25T10:10:00Z", "dataTableUuid": "test-uuid", } - - mock_chronicle_client.session.patch.return_value = mock_response + + mock_chronicle_client.session.request.return_value = mock_response result = update_data_table( - mock_chronicle_client, - dt_name, - row_time_to_live=new_row_ttl + mock_chronicle_client, dt_name, row_time_to_live=new_row_ttl ) assert result["name"] == expected_dt_name assert result["rowTimeToLive"] == new_row_ttl assert "description" not in result - - mock_chronicle_client.session.patch.assert_called_once_with( - f"{mock_chronicle_client.base_url}/{mock_chronicle_client.instance_id}/dataTables/{dt_name}", - params={}, + + mock_chronicle_client.session.request.assert_called_once_with( + method="PATCH", + url=f"{mock_chronicle_client.base_url}/{mock_chronicle_client.instance_id}/dataTables/{dt_name}", + params=None, json={"row_time_to_live": new_row_ttl}, + headers=None, + timeout=None, ) @patch("secops.chronicle.data_table.REF_LIST_DATA_TABLE_ID_REGEX") @@ -727,29 +827,32 @@ def test_update_data_table_with_update_mask( "updateTime": "2025-08-25T10:15:00Z", "dataTableUuid": "test-uuid", } - - mock_chronicle_client.session.patch.return_value = mock_response + + 
mock_chronicle_client.session.request.return_value = mock_response result = update_data_table( - mock_chronicle_client, - dt_name, + mock_chronicle_client, + dt_name, description=new_description, row_time_to_live=new_row_ttl, - update_mask=update_mask + update_mask=update_mask, ) assert result["name"] == expected_dt_name assert result["description"] == new_description - + # Verify that even though row_time_to_live was provided, it wasn't included in the API call # due to the update_mask - mock_chronicle_client.session.patch.assert_called_once_with( - f"{mock_chronicle_client.base_url}/{mock_chronicle_client.instance_id}/dataTables/{dt_name}", + mock_chronicle_client.session.request.assert_called_once_with( + method="PATCH", + url=f"{mock_chronicle_client.base_url}/{mock_chronicle_client.instance_id}/dataTables/{dt_name}", params={"updateMask": "description"}, json={ "description": new_description, "row_time_to_live": new_row_ttl, }, + headers=None, + timeout=None, ) @patch("secops.chronicle.data_table.REF_LIST_DATA_TABLE_ID_REGEX") @@ -766,9 +869,9 @@ def test_update_data_table_invalid_name( "invalid_name!", description="New description", ) - + # Verify the API was never called - mock_chronicle_client.session.patch.assert_not_called() + mock_chronicle_client.session.request.assert_not_called() def test_update_data_table_api_error( self, mock_chronicle_client: Mock @@ -777,22 +880,24 @@ def test_update_data_table_api_error( mock_response = Mock() mock_response.status_code = 400 mock_response.text = "Invalid row_time_to_live format" - - mock_chronicle_client.session.patch.return_value = mock_response - + + mock_chronicle_client.session.request.return_value = mock_response + with pytest.raises( - APIError, match="Failed to update data table 'test_table': 400 Invalid row_time_to_live format" + APIError, + match="Failed to update data table 'test_table'", ): update_data_table( - mock_chronicle_client, - "test_table", - row_time_to_live="invalid" + mock_chronicle_client, 
"test_table", row_time_to_live="invalid" ) - @patch('secops.chronicle.data_table._estimate_row_json_size') - @patch('secops.chronicle.data_table.create_data_table_rows') + @patch("secops.chronicle.data_table._estimate_row_json_size") + @patch("secops.chronicle.data_table.create_data_table_rows") def test_replace_data_table_rows_size_based_batching( - self, mock_create_rows: Mock, mock_estimate_size: Mock, mock_chronicle_client: Mock + self, + mock_create_rows: Mock, + mock_estimate_size: Mock, + mock_chronicle_client: Mock, ) -> None: """Test that replace_data_table_rows handles size-based batching.""" # Mock response for API calls @@ -801,51 +906,62 @@ def test_replace_data_table_rows_size_based_batching( mock_response.json.return_value = { "dataTableRows": [{"name": "row_replaced"}] } - mock_chronicle_client.session.post.return_value = mock_response - + mock_chronicle_client.session.request.return_value = mock_response + # Mock create_data_table_rows function for remaining rows - mock_create_rows.return_value = [{"dataTableRows": [{"name": "row_created"}]}] - + mock_create_rows.return_value = [ + {"dataTableRows": [{"name": "row_created"}]} + ] + # Create test data: first batch will have some rows close to 4MB limit # to test size-based batching dt_name = "dt_for_replace_batching" - rows_data = [[f"small_value{i}"] for i in range(950)] # Under 1000 rows total - + rows_data = [ + [f"small_value{i}"] for i in range(950) + ] # Under 1000 rows total + # Mock size estimation to force size-based batching # First 5 rows are large (close to 1MB each), rest are small def estimate_size_side_effect(row): if row[0].startswith("small_value") and int(row[0][11:]) < 5: return 900000 # Almost 1MB each for first 5 rows return 10000 # Small size for other rows - + mock_estimate_size.side_effect = estimate_size_side_effect - + # Call the function under test - responses = replace_data_table_rows(mock_chronicle_client, dt_name, rows_data) - + responses = replace_data_table_rows( + 
mock_chronicle_client, dt_name, rows_data + ) + # Verify the correct behavior: # 1. Single bulkReplace call for the rows that fit in 4MB # 2. create_data_table_rows function call for remaining rows - + # First call should be bulkReplace with only the rows that fit in 4MB - mock_chronicle_client.session.post.assert_called_once() - post_call = mock_chronicle_client.session.post.call_args - assert "bulkReplace" in post_call[0][0] - + mock_chronicle_client.session.request.assert_called_once() + post_call = mock_chronicle_client.session.request.call_args + assert "bulkReplace" in post_call[1]["url"] + # The create_data_table_rows function should be called for remaining rows mock_create_rows.assert_called_once() create_call_args = mock_create_rows.call_args # Verify the function was called with the right parameters assert create_call_args[0][0] == mock_chronicle_client # client assert create_call_args[0][1] == dt_name # name - + # Verify we got responses from both operations assert len(responses) == 2 - @patch('secops.chronicle.data_table._estimate_row_json_size', return_value=1000) # Small enough for all rows - @patch('secops.chronicle.data_table.create_data_table_rows') + @patch( + "secops.chronicle.data_table._estimate_row_json_size", return_value=1000 + ) # Small enough for all rows + @patch("secops.chronicle.data_table.create_data_table_rows") def test_replace_data_table_rows_chunking( - self, mock_create_rows: Mock, mock_estimate_size: Mock, mock_chronicle_client: Mock + self, + mock_create_rows: Mock, + mock_estimate_size: Mock, + mock_chronicle_client: Mock, ) -> None: """Test that replace_data_table_rows chunks large inputs over 1000 rows.""" # Mock responses for API calls @@ -854,39 +970,47 @@ def test_replace_data_table_rows_chunking( mock_response.json.return_value = { "dataTableRows": [{"name": "row_replaced_chunk"}] } - mock_chronicle_client.session.post.return_value = mock_response - + mock_chronicle_client.session.request.return_value = mock_response + # 
Mock create_data_table_rows function for remaining rows - mock_create_rows.return_value = [{"dataTableRows": [{"name": "row_created_chunk"}]}] - + mock_create_rows.return_value = [ + {"dataTableRows": [{"name": "row_created_chunk"}]} + ] + # Create test data with more than 1000 rows dt_name = "dt_for_replace_chunking" rows_data = [[f"new_value{i}"] for i in range(1500)] # 1500 rows - + # Call the function under test - responses = replace_data_table_rows(mock_chronicle_client, dt_name, rows_data) - + responses = replace_data_table_rows( + mock_chronicle_client, dt_name, rows_data + ) + # Verify first call was bulkReplace with first 1000 rows - assert mock_chronicle_client.session.post.call_count == 1 - post_call = mock_chronicle_client.session.post.call_args - assert "bulkReplace" in post_call[0][0] - + assert mock_chronicle_client.session.request.call_count == 1 + post_call = mock_chronicle_client.session.request.call_args + assert "bulkReplace" in post_call[1]["url"] + # Verify the remaining rows were sent using create_data_table_rows function mock_create_rows.assert_called_once() create_call = mock_create_rows.call_args - + # Verify function was called with correct parameters assert create_call[0][0] == mock_chronicle_client # client parameter assert create_call[0][1] == dt_name # table name parameter - + # We need to include rows 1000-1499 (500 rows total) in the remaining batch - remaining_rows = create_call[0][2] # Get the rows passed to create_data_table_rows + remaining_rows = create_call[0][ + 2 + ] # Get the rows passed to create_data_table_rows assert len(remaining_rows) == 500 # 500 remaining rows - + # Verify we got response correctly assert len(responses) == 2 - def test_replace_data_table_rows_few_rows(self, mock_chronicle_client: Mock) -> None: + def test_replace_data_table_rows_few_rows( + self, mock_chronicle_client: Mock + ) -> None: """Test direct call to replace_data_table_rows with a small number of rows.""" # Mock response for API call 
mock_response = Mock() @@ -897,68 +1021,102 @@ def test_replace_data_table_rows_few_rows(self, mock_chronicle_client: Mock) -> {"name": "replaced_row2", "values": ["new3", "new4"]}, ] } - mock_chronicle_client.session.post.return_value = mock_response - + mock_chronicle_client.session.request.return_value = mock_response + dt_name = "test_dt_replace" - rows_to_replace = [["new1", "new2"], ["new3", "new4"]] # Small set of rows - + rows_to_replace = [ + ["new1", "new2"], + ["new3", "new4"], + ] # Small set of rows + # Patch row size estimation to return small values and create_data_table_rows # since we don't use it in this test case - with patch('secops.chronicle.data_table._estimate_row_json_size', return_value=1000), \ - patch('secops.chronicle.data_table.create_data_table_rows') as mock_create_rows: + with ( + patch( + "secops.chronicle.data_table._estimate_row_json_size", + return_value=1000, + ), + patch( + "secops.chronicle.data_table.create_data_table_rows" + ) as mock_create_rows, + ): # Mock doesn't get called in this test but needs to be patched # to prevent any unwanted side effects mock_create_rows.return_value = [] - + # Call the function under test - result = replace_data_table_rows(mock_chronicle_client, dt_name, rows_to_replace) - + result = replace_data_table_rows( + mock_chronicle_client, dt_name, rows_to_replace + ) + # Verify API was called correctly - mock_chronicle_client.session.post.assert_called_once() - call_args = mock_chronicle_client.session.post.call_args - assert "bulkReplace" in call_args[0][0] # URL has bulkReplace - + mock_chronicle_client.session.request.assert_called_once() + call_args = mock_chronicle_client.session.request.call_args + assert "bulkReplace" in call_args[1]["url"] # URL has bulkReplace + # Verify we have all rows in a single request requests = call_args[1]["json"]["requests"] assert len(requests) == 2 # Both rows in a single request - + # Verify response was processed correctly assert len(result) == 1 assert 
result[0] == mock_response.json.return_value - + # Verify we didn't need to use create_data_table_rows for additional rows mock_create_rows.assert_not_called() - def test_replace_data_table_rows_api_error(self, mock_chronicle_client: Mock) -> None: + def test_replace_data_table_rows_api_error( + self, mock_chronicle_client: Mock + ) -> None: """Test API error handling in replace_data_table_rows.""" # Mock API error response error_response = Mock() error_response.status_code = 400 error_response.text = "Invalid row format" - mock_chronicle_client.session.post.return_value = error_response - + mock_chronicle_client.session.request.return_value = error_response + dt_name = "invalid_table" rows_to_replace = [["bad_data"]] # Small test data - + # Patch row size estimation to avoid size issues and patch create_data_table_rows # as it's not expected to be called in this error case - with patch('secops.chronicle.data_table._estimate_row_json_size', return_value=1000), \ - patch('secops.chronicle.data_table.create_data_table_rows'): - with pytest.raises(APIError, match="Failed to replace data table rows for 'invalid_table': 400 Invalid row format"): - replace_data_table_rows(mock_chronicle_client, dt_name, rows_to_replace) + with ( + patch( + "secops.chronicle.data_table._estimate_row_json_size", + return_value=1000, + ), + patch("secops.chronicle.data_table.create_data_table_rows"), + ): + with pytest.raises( + APIError, + match="Failed to replace data table rows for 'invalid_table'", + ): + replace_data_table_rows( + mock_chronicle_client, dt_name, rows_to_replace + ) - def test_replace_data_table_rows_single_oversized_row(self, mock_chronicle_client: Mock) -> None: + def test_replace_data_table_rows_single_oversized_row( + self, mock_chronicle_client: Mock + ) -> None: """Test handling of a single oversized row in replace_data_table_rows.""" dt_name = "dt_with_big_row" oversized_row = [["*" * 1000000]] # Very large row - + # Mock _estimate_row_json_size to return a value 
larger than 4MB for our oversized row # Also patch create_data_table_rows as it won't be called in this error case - with patch('secops.chronicle.data_table._estimate_row_json_size', return_value=5000000), \ - patch('secops.chronicle.data_table.create_data_table_rows'): - with pytest.raises(SecOpsError, match="Single row is too large to process"): - replace_data_table_rows(mock_chronicle_client, dt_name, oversized_row) - + with ( + patch( + "secops.chronicle.data_table._estimate_row_json_size", + return_value=5000000, + ), + patch("secops.chronicle.data_table.create_data_table_rows"), + ): + with pytest.raises( + SecOpsError, match="Single row is too large to process" + ): + replace_data_table_rows( + mock_chronicle_client, dt_name, oversized_row + ) def test_update_data_table_rows_success( self, mock_chronicle_client: Mock @@ -981,7 +1139,7 @@ def test_update_data_table_rows_success( }, ] } - mock_chronicle_client.session.post.return_value = mock_response + mock_chronicle_client.session.request.return_value = mock_response dt_name = "dt1" row_updates = [ @@ -1007,17 +1165,18 @@ def test_update_data_table_rows_success( ) # Verify API was called correctly - mock_chronicle_client.session.post.assert_called_once() - call_args = mock_chronicle_client.session.post.call_args - assert "bulkUpdate" in call_args[0][0] + mock_chronicle_client.session.request.assert_called_once() + call_args = mock_chronicle_client.session.request.call_args + assert "bulkUpdate" in call_args[1]["url"] # Verify payload structure requests = call_args[1]["json"]["requests"] assert len(requests) == 2 assert requests[0]["dataTableRow"]["name"] == row_updates[0]["name"] - assert requests[0]["dataTableRow"]["values"] == row_updates[0][ - "values" - ] + assert ( + requests[0]["dataTableRow"]["values"] + == row_updates[0]["values"] + ) # Verify response assert len(result) == 1 @@ -1038,7 +1197,7 @@ def test_update_data_table_rows_with_update_mask( } ] } - 
mock_chronicle_client.session.post.return_value = mock_response + mock_chronicle_client.session.request.return_value = mock_response dt_name = "dt1" row_updates = [ @@ -1059,7 +1218,7 @@ def test_update_data_table_rows_with_update_mask( ) # Verify update mask is included in request - call_args = mock_chronicle_client.session.post.call_args + call_args = mock_chronicle_client.session.request.call_args requests = call_args[1]["json"]["requests"] assert "updateMask" in requests[0] assert requests[0]["updateMask"] == "values" @@ -1078,7 +1237,7 @@ def test_update_data_table_rows_chunking_1000_rows( mock_response.json.return_value = { "dataTableRows": [{"name": "row_updated"}] } - mock_chronicle_client.session.post.return_value = mock_response + mock_chronicle_client.session.request.return_value = mock_response # Mock small row sizes mock_estimate_size.return_value = 1000 @@ -1100,15 +1259,15 @@ def test_update_data_table_rows_chunking_1000_rows( ) # Verify API was called twice (1000 + 500) - assert mock_chronicle_client.session.post.call_count == 2 + assert mock_chronicle_client.session.request.call_count == 2 # Verify first call has 1000 rows - first_call = mock_chronicle_client.session.post.call_args_list[0] + first_call = mock_chronicle_client.session.request.call_args_list[0] first_requests = first_call[1]["json"]["requests"] assert len(first_requests) == 1000 # Verify second call has 500 rows - second_call = mock_chronicle_client.session.post.call_args_list[1] + second_call = mock_chronicle_client.session.request.call_args_list[1] second_requests = second_call[1]["json"]["requests"] assert len(second_requests) == 500 @@ -1125,7 +1284,7 @@ def test_update_data_table_rows_size_based_chunking( mock_response.json.return_value = { "dataTableRows": [{"name": "row_updated"}] } - mock_chronicle_client.session.post.return_value = mock_response + mock_chronicle_client.session.request.return_value = mock_response dt_name = "dt_size_chunking" # Create test data with varying 
sizes @@ -1154,7 +1313,7 @@ def estimate_size_side_effect(row): ) # Should have multiple chunks due to size constraints - assert mock_chronicle_client.session.post.call_count >= 2 + assert mock_chronicle_client.session.request.call_count >= 2 assert len(result) >= 2 def test_update_data_table_rows_empty_list( @@ -1173,7 +1332,7 @@ def test_update_data_table_rows_empty_list( ) # Should not make any API calls - mock_chronicle_client.session.post.assert_not_called() + mock_chronicle_client.session.request.assert_not_called() # Should return empty list assert result == [] @@ -1185,7 +1344,7 @@ def test_update_data_table_rows_api_error( error_response = Mock() error_response.status_code = 400 error_response.text = "Invalid row data" - mock_chronicle_client.session.post.return_value = error_response + mock_chronicle_client.session.request.return_value = error_response dt_name = "dt_error" row_updates = [ @@ -1202,8 +1361,7 @@ def test_update_data_table_rows_api_error( ): with pytest.raises( APIError, - match="Failed to update data table rows for 'dt_error': " - "400 Invalid row data", + match="Failed to update data table rows for 'dt_error': ", ): update_data_table_rows( mock_chronicle_client, dt_name, row_updates diff --git a/tests/chronicle/test_feed.py b/tests/chronicle/test_feed.py index effc7159..605aa241 100644 --- a/tests/chronicle/test_feed.py +++ b/tests/chronicle/test_feed.py @@ -74,15 +74,19 @@ def test_create_feed(chronicle_client, mock_response): ) with patch.object( - chronicle_client.session, "post", return_value=mock_response + chronicle_client.session, "request", return_value=mock_response ) as mock_post: # Act result = create_feed(chronicle_client, feed_config) # Assert mock_post.assert_called_once_with( - f"{chronicle_client.base_url}/{chronicle_client.instance_id}/feeds", + method="POST", + url=f"{chronicle_client.base_url()}/{chronicle_client.instance_id}/feeds", + params=None, json=feed_config.to_dict(), + headers=None, + timeout=None, ) assert 
result == mock_response.json.return_value @@ -96,7 +100,7 @@ def test_create_feed_with_json_string(chronicle_client, mock_response): ) with patch.object( - chronicle_client.session, "post", return_value=mock_response + chronicle_client.session, "request", return_value=mock_response ) as mock_post: # Act result = create_feed(chronicle_client, feed_config) @@ -107,8 +111,12 @@ def test_create_feed_with_json_string(chronicle_client, mock_response): "details": {"feed_source_type": "syslog", "log_type": "network"}, } mock_post.assert_called_once_with( - f"{chronicle_client.base_url}/{chronicle_client.instance_id}/feeds", + method="POST", + url=f"{chronicle_client.base_url()}/{chronicle_client.instance_id}/feeds", + params=None, json=expected_json, + headers=None, + timeout=None, ) assert result == mock_response.json.return_value @@ -121,7 +129,7 @@ def test_create_feed_error(chronicle_client, mock_error_response): ) with patch.object( - chronicle_client.session, "post", return_value=mock_error_response + chronicle_client.session, "request", return_value=mock_error_response ): # Act & Assert with pytest.raises(APIError) as exc_info: @@ -135,14 +143,19 @@ def test_get_feed(chronicle_client, mock_response): # Arrange feed_id = "feed_12345" with patch.object( - chronicle_client.session, "get", return_value=mock_response + chronicle_client.session, "request", return_value=mock_response ) as mock_get: # Act result = get_feed(chronicle_client, feed_id) # Assert mock_get.assert_called_once_with( - f"{chronicle_client.base_url}/{chronicle_client.instance_id}/feeds/{feed_id}" + method="GET", + url=f"{chronicle_client.base_url()}/{chronicle_client.instance_id}/feeds/{feed_id}", + params=None, + json=None, + headers=None, + timeout=None, ) assert result == mock_response.json.return_value @@ -152,7 +165,7 @@ def test_get_feed_error(chronicle_client, mock_error_response): # Arrange feed_id = "feed_12345" with patch.object( - chronicle_client.session, "get", 
return_value=mock_error_response + chronicle_client.session, "request", return_value=mock_error_response ): # Act & Assert with pytest.raises(APIError) as exc_info: @@ -169,73 +182,91 @@ def test_list_feeds(chronicle_client, mock_response): } with patch.object( - chronicle_client.session, "get", return_value=mock_response + chronicle_client.session, "request", return_value=mock_response ) as mock_get: # Act result = list_feeds(chronicle_client) # Assert mock_get.assert_called_once_with( - f"{chronicle_client.base_url}/{chronicle_client.instance_id}/feeds", - params={"pageSize": 100, "pageToken": None}, + method="GET", + url=f"{chronicle_client.base_url()}/{chronicle_client.instance_id}/feeds", + params={"pageSize": 100}, + json=None, + headers=None, + timeout=None, ) assert result == mock_response.json.return_value["feeds"] assert len(result) == 2 def test_list_feeds_with_pagination(chronicle_client, mock_response): - """Test list_feeds function with pagination.""" - # Arrange - Mock first call with next_page_token, second call without it + """Test list_feeds function with automatic pagination across multiple pages.""" + # Arrange - When page_token is NOT provided, it automatically fetches + # all pages and aggregates results first_response = Mock() first_response.status_code = 200 first_response.json.return_value = { - "feeds": [{"name": "feed1"}], - "next_page_token": "token123", + "feeds": [{"name": "feed1"}, {"name": "feed2"}], + "nextPageToken": "next_token_456", } second_response = Mock() second_response.status_code = 200 - second_response.json.return_value = {"feeds": [{"name": "feed2"}]} + second_response.json.return_value = { + "feeds": [{"name": "feed3"}], + "nextPageToken": "next_token_789", + } + + third_response = Mock() + third_response.status_code = 200 + third_response.json.return_value = { + "feeds": [{"name": "feed4"}], + } with patch.object( chronicle_client.session, - "get", - side_effect=[first_response, second_response], + "request", + 
side_effect=[first_response, second_response, third_response], ) as mock_get: - # Act - result = list_feeds( - chronicle_client, page_size=50, page_token="token123" - ) + # Act - Pass page_size=None and page_token=None to enable automatic pagination + # If page_size has a default value, it triggers single-page mode + result = list_feeds(chronicle_client, page_size=None, page_token=None) + + # Assert - Multiple calls are made to fetch all pages + assert mock_get.call_count == 3 + + # Verify each call was made with correct params (DEFAULT_PAGE_SIZE=1000) + calls = mock_get.call_args_list + assert calls[0].kwargs["params"] == {"pageSize": 1000} + assert calls[1].kwargs["params"] == { + "pageSize": 1000, + "pageToken": "next_token_456", + } + assert calls[2].kwargs["params"] == { + "pageSize": 1000, + "pageToken": "next_token_789", + } - # Assert - assert mock_get.call_count == 2 - # First call with initial page_token - mock_get.assert_any_call( - f"{chronicle_client.base_url}/{chronicle_client.instance_id}/feeds", - params={"pageSize": 50, "pageToken": "token123"}, - ) - # Second call with next_page_token from first response - mock_get.assert_any_call( - f"{chronicle_client.base_url}/{chronicle_client.instance_id}/feeds", - params={"pageSize": 50, "pageToken": "token123"}, - ) - # Result should be combined feeds from both pages - assert len(result) == 2 + # Result should be aggregated feeds from all pages + assert len(result) == 4 assert result[0]["name"] == "feed1" assert result[1]["name"] == "feed2" + assert result[2]["name"] == "feed3" + assert result[3]["name"] == "feed4" def test_list_feeds_error(chronicle_client, mock_error_response): """Test list_feeds function with error response.""" # Arrange with patch.object( - chronicle_client.session, "get", return_value=mock_error_response + chronicle_client.session, "request", return_value=mock_error_response ): # Act & Assert with pytest.raises(APIError) as exc_info: list_feeds(chronicle_client) - assert "Failed to 
list feeds" in str(exc_info.value) + assert "API request failed" in str(exc_info.value) def test_update_feed(chronicle_client, mock_response): @@ -248,16 +279,19 @@ def test_update_feed(chronicle_client, mock_response): ) with patch.object( - chronicle_client.session, "patch", return_value=mock_response + chronicle_client.session, "request", return_value=mock_response ) as mock_patch: # Act result = update_feed(chronicle_client, feed_id, feed_config) # Assert mock_patch.assert_called_once_with( - f"{chronicle_client.base_url}/{chronicle_client.instance_id}/feeds/{feed_id}", + method="PATCH", + url=f"{chronicle_client.base_url()}/{chronicle_client.instance_id}/feeds/{feed_id}", params={"updateMask": "display_name,details"}, json=feed_config.to_dict(), + headers=None, + timeout=None, ) assert result == mock_response.json.return_value @@ -270,7 +304,7 @@ def test_update_feed_with_custom_mask(chronicle_client, mock_response): update_mask = ["display_name"] with patch.object( - chronicle_client.session, "patch", return_value=mock_response + chronicle_client.session, "request", return_value=mock_response ) as mock_patch: # Act result = update_feed( @@ -279,9 +313,12 @@ def test_update_feed_with_custom_mask(chronicle_client, mock_response): # Assert mock_patch.assert_called_once_with( - f"{chronicle_client.base_url}/{chronicle_client.instance_id}/feeds/{feed_id}", + method="PATCH", + url=f"{chronicle_client.base_url()}/{chronicle_client.instance_id}/feeds/{feed_id}", params={"updateMask": "display_name"}, json=feed_config.to_dict(), + headers=None, + timeout=None, ) assert result == mock_response.json.return_value @@ -293,7 +330,7 @@ def test_update_feed_error(chronicle_client, mock_error_response): feed_config = UpdateFeedModel(display_name="Updated Feed") with patch.object( - chronicle_client.session, "patch", return_value=mock_error_response + chronicle_client.session, "request", return_value=mock_error_response ): # Act & Assert with pytest.raises(APIError) as 
exc_info: @@ -307,14 +344,19 @@ def test_delete_feed(chronicle_client, mock_response): # Arrange feed_id = "feed_12345" with patch.object( - chronicle_client.session, "delete", return_value=mock_response + chronicle_client.session, "request", return_value=mock_response ) as mock_delete: # Act result = delete_feed(chronicle_client, feed_id) # Assert mock_delete.assert_called_once_with( - f"{chronicle_client.base_url}/{chronicle_client.instance_id}/feeds/{feed_id}" + method="DELETE", + url=f"{chronicle_client.base_url()}/{chronicle_client.instance_id}/feeds/{feed_id}", + params=None, + json=None, + headers=None, + timeout=None, ) assert result is None @@ -324,7 +366,7 @@ def test_delete_feed_error(chronicle_client, mock_error_response): # Arrange feed_id = "feed_12345" with patch.object( - chronicle_client.session, "delete", return_value=mock_error_response + chronicle_client.session, "request", return_value=mock_error_response ): # Act & Assert with pytest.raises(APIError) as exc_info: @@ -338,14 +380,19 @@ def test_enable_feed(chronicle_client, mock_response): # Arrange feed_id = "feed_12345" with patch.object( - chronicle_client.session, "post", return_value=mock_response + chronicle_client.session, "request", return_value=mock_response ) as mock_post: # Act result = enable_feed(chronicle_client, feed_id) # Assert mock_post.assert_called_once_with( - f"{chronicle_client.base_url}/{chronicle_client.instance_id}/feeds/{feed_id}:enable" + method="POST", + url=f"{chronicle_client.base_url()}/{chronicle_client.instance_id}/feeds/{feed_id}:enable", + params=None, + json=None, + headers=None, + timeout=None, ) assert feed_id in f"{result}" @@ -355,7 +402,7 @@ def test_enable_feed_error(chronicle_client, mock_error_response): # Arrange feed_id = "feed_12345" with patch.object( - chronicle_client.session, "post", return_value=mock_error_response + chronicle_client.session, "request", return_value=mock_error_response ): # Act & Assert with pytest.raises(APIError) as 
exc_info: @@ -369,14 +416,19 @@ def test_disable_feed(chronicle_client, mock_response): # Arrange feed_id = "feed_12345" with patch.object( - chronicle_client.session, "post", return_value=mock_response + chronicle_client.session, "request", return_value=mock_response ) as mock_post: # Act result = disable_feed(chronicle_client, feed_id) # Assert mock_post.assert_called_once_with( - f"{chronicle_client.base_url}/{chronicle_client.instance_id}/feeds/{feed_id}:disable" + method="POST", + url=f"{chronicle_client.base_url()}/{chronicle_client.instance_id}/feeds/{feed_id}:disable", + params=None, + json=None, + headers=None, + timeout=None, ) assert feed_id in f"{result}" @@ -386,7 +438,7 @@ def test_disable_feed_error(chronicle_client, mock_error_response): # Arrange feed_id = "feed_12345" with patch.object( - chronicle_client.session, "post", return_value=mock_error_response + chronicle_client.session, "request", return_value=mock_error_response ): # Act & Assert with pytest.raises(APIError) as exc_info: @@ -402,14 +454,19 @@ def test_generate_secret(chronicle_client, mock_response): mock_response.json.return_value = {"secret": "generated_secret_123"} with patch.object( - chronicle_client.session, "post", return_value=mock_response + chronicle_client.session, "request", return_value=mock_response ) as mock_post: # Act result = generate_secret(chronicle_client, feed_id) # Assert mock_post.assert_called_once_with( - f"{chronicle_client.base_url}/{chronicle_client.instance_id}/feeds/{feed_id}:generateSecret" + method="POST", + url=f"{chronicle_client.base_url()}/{chronicle_client.instance_id}/feeds/{feed_id}:generateSecret", + params=None, + json=None, + headers=None, + timeout=None, ) assert result == mock_response.json.return_value @@ -419,7 +476,7 @@ def test_generate_secret_error(chronicle_client, mock_error_response): # Arrange feed_id = "feed_12345" with patch.object( - chronicle_client.session, "post", return_value=mock_error_response + chronicle_client.session, 
"request", return_value=mock_error_response ): # Act & Assert with pytest.raises(APIError) as exc_info: diff --git a/tests/chronicle/test_gemini.py b/tests/chronicle/test_gemini.py index cb354dbb..59fa872a 100644 --- a/tests/chronicle/test_gemini.py +++ b/tests/chronicle/test_gemini.py @@ -31,16 +31,15 @@ @pytest.fixture def mock_chronicle_client(): """Create a mock Chronicle client.""" - client = Mock() - client.project_id = "test-project" - client.customer_id = "test-customer" - client.region = "us" - client.base_url = "https://us-chronicle.googleapis.com/v1alpha" - client.instance_id = ( - "projects/test-project/locations/us/instances/test-customer" - ) - client.session = Mock() - return client + with patch("secops.auth.SecOpsAuth") as mock_auth: + mock_session = Mock() + mock_session.headers = {} + mock_auth.return_value.session = mock_session + from secops.chronicle.client import ChronicleClient + + return ChronicleClient( + customer_id="test-customer", project_id="test-project" + ) @pytest.fixture @@ -85,7 +84,9 @@ def sample_gemini_response(): "displayText": "Open in Rule Editor", "actionType": "NAVIGATION", "useCaseId": "test-use-case", - "navigation": {"targetUri": "/rulesEditor?rule=example"}, + "navigation": { + "targetUri": "/rulesEditor?rule=example" + }, } ], } @@ -133,9 +134,14 @@ def test_suggested_action_init(): def test_gemini_response_init(): """Test GeminiResponse class initialization.""" - blocks = [Block("TEXT", "Text content"), Block("CODE", "Code content", "Example")] + blocks = [ + Block("TEXT", "Text content"), + Block("CODE", "Code content", "Example"), + ] actions = [ - SuggestedAction("Action", "NAVIGATION", "test-case", NavigationAction("/uri")) + SuggestedAction( + "Action", "NAVIGATION", "test-case", NavigationAction("/uri") + ) ] references = [Block("HTML", "
Reference
")] @@ -161,7 +167,10 @@ def test_gemini_response_init(): # Test with default values basic_response = GeminiResponse( - name="test", input_query="query", create_time="2025-01-01T00:00:00Z", blocks=[] + name="test", + input_query="query", + create_time="2025-01-01T00:00:00Z", + blocks=[], ) assert basic_response.suggested_actions == [] assert basic_response.references == [] @@ -249,7 +258,10 @@ def test_gemini_response_helper_methods(): # Test get_html_blocks html_blocks = response.get_html_blocks() assert len(html_blocks) == 1 - assert html_blocks[0].content == "HTML formatted content
" + assert ( + html_blocks[0].content + == "HTML formatted content
" + ) # Test get_raw_response raw_data = {"test": "data"} @@ -273,33 +285,36 @@ def test_create_conversation_success(mock_chronicle_client): "createTime": "2025-01-01T00:00:00Z", } - mock_chronicle_client.session.post.return_value = mock_response - - # Test with default display name - conv_id = create_conversation(mock_chronicle_client) - assert conv_id == "test-conv-id" + with patch.object( + mock_chronicle_client.session, "request", return_value=mock_response + ): + # Test with default display name + conv_id = create_conversation(mock_chronicle_client) + assert conv_id == "test-conv-id" - # Check API call - mock_chronicle_client.session.post.assert_called_once() - call_args = mock_chronicle_client.session.post.call_args + # Check API call + mock_chronicle_client.session.request.assert_called_once() + call_args = mock_chronicle_client.session.request.call_args - # Check URL - expected_url = ( - f"{mock_chronicle_client.base_url}/" - f"{mock_chronicle_client.instance_id}/users/me/conversations" - ) - assert call_args[0][0] == expected_url + # Check URL + expected_url = ( + f"{mock_chronicle_client.base_url()}/" + f"{mock_chronicle_client.instance_id}/users/me/conversations" + ) + assert call_args.kwargs["url"] == expected_url - # Check payload - assert call_args[1]["json"] == {"displayName": "New chat"} + # Check payload + assert call_args.kwargs["json"] == {"displayName": "New chat"} # Test with custom display name - mock_chronicle_client.session.post.reset_mock() - conv_id = create_conversation(mock_chronicle_client, "Custom Chat") - assert conv_id == "test-conv-id" - assert mock_chronicle_client.session.post.call_args[1]["json"] == { - "displayName": "Custom Chat" - } + with patch.object( + mock_chronicle_client.session, "request", return_value=mock_response + ): + conv_id = create_conversation(mock_chronicle_client, "Custom Chat") + assert conv_id == "test-conv-id" + assert mock_chronicle_client.session.request.call_args.kwargs[ + "json" + ] == {"displayName": 
"Custom Chat"} def test_create_conversation_error(mock_chronicle_client): @@ -310,50 +325,57 @@ def test_create_conversation_error(mock_chronicle_client): mock_response.text = "Bad request" mock_response.raise_for_status.side_effect = Exception("HTTP Error") - mock_chronicle_client.session.post.return_value = mock_response - - with pytest.raises(APIError) as excinfo: - create_conversation(mock_chronicle_client) + with patch.object( + mock_chronicle_client.session, "request", return_value=mock_response + ): + with pytest.raises(APIError) as excinfo: + create_conversation(mock_chronicle_client) - assert "Failed to create conversation" in str(excinfo.value) + assert "Failed to create conversation" in str(excinfo.value) -def test_query_gemini_new_conversation(mock_chronicle_client, sample_gemini_response): +def test_query_gemini_new_conversation( + mock_chronicle_client, sample_gemini_response +): """Test querying Gemini with a new conversation.""" # Mock create_conversation - with patch("secops.chronicle.gemini.create_conversation") as mock_create_conv: + with patch( + "secops.chronicle.gemini.create_conversation" + ) as mock_create_conv: mock_create_conv.return_value = "test-conv-id" # Mock the API response mock_resp = Mock() mock_resp.status_code = 200 mock_resp.json.return_value = sample_gemini_response - mock_chronicle_client.session.post.return_value = mock_resp - # Call the function - response = query_gemini( - mock_chronicle_client, query="What is Windows event ID 4625?" - ) + with patch.object( + mock_chronicle_client.session, "request", return_value=mock_resp + ): + # Call the function + response = query_gemini( + mock_chronicle_client, query="What is Windows event ID 4625?" 
+ ) - # Check that create_conversation was called - mock_create_conv.assert_called_once_with(mock_chronicle_client) + # Check that create_conversation was called + mock_create_conv.assert_called_once_with(mock_chronicle_client) - # Check API call - mock_chronicle_client.session.post.assert_called_once() - call_args = mock_chronicle_client.session.post.call_args + # Check API call + mock_chronicle_client.session.request.assert_called_once() + call_args = mock_chronicle_client.session.request.call_args - # Check URL - assert "test-conv-id/messages" in call_args[0][0] + # Check URL + assert "test-conv-id/messages" in call_args.kwargs["url"] - # Check payload - payload = call_args[1]["json"] - assert payload["input"]["body"] == "What is Windows event ID 4625?" - assert payload["input"]["context"]["uri"] == "/search" + # Check payload + payload = call_args.kwargs["json"] + assert payload["input"]["body"] == "What is Windows event ID 4625?" + assert payload["input"]["context"]["uri"] == "/search" - # Check response - assert isinstance(response, GeminiResponse) - assert len(response.blocks) == 3 - assert response.blocks[0].block_type == "TEXT" + # Check response + assert isinstance(response, GeminiResponse) + assert len(response.blocks) == 3 + assert response.blocks[0].block_type == "TEXT" def test_query_gemini_existing_conversation( @@ -364,28 +386,30 @@ def test_query_gemini_existing_conversation( mock_resp = Mock() mock_resp.status_code = 200 mock_resp.json.return_value = sample_gemini_response - mock_chronicle_client.session.post.return_value = mock_resp - - # Call the function with an existing conversation ID - response = query_gemini( - mock_chronicle_client, - query="What is Windows event ID 4625?", - conversation_id="existing-conv-id", - context_uri="/custom-context", - context_body={"custom": "data"}, - ) - # Check API call - mock_chronicle_client.session.post.assert_called_once() - call_args = mock_chronicle_client.session.post.call_args + with patch.object( + 
mock_chronicle_client.session, "request", return_value=mock_resp + ): + # Call the function with an existing conversation ID + response = query_gemini( + mock_chronicle_client, + query="What is Windows event ID 4625?", + conversation_id="existing-conv-id", + context_uri="/custom-context", + context_body={"custom": "data"}, + ) - # Check URL contains the existing conversation ID - assert "existing-conv-id/messages" in call_args[0][0] + # Check API call + mock_chronicle_client.session.request.assert_called_once() + call_args = mock_chronicle_client.session.request.call_args - # Check payload contains context - payload = call_args[1]["json"] - assert payload["input"]["context"]["uri"] == "/custom-context" - assert payload["input"]["context"]["body"] == {"custom": "data"} + # Check URL contains the existing conversation ID + assert "existing-conv-id/messages" in call_args.kwargs["url"] + + # Check payload contains context + payload = call_args.kwargs["json"] + assert payload["input"]["context"]["uri"] == "/custom-context" + assert payload["input"]["context"]["body"] == {"custom": "data"} def test_query_gemini_error(mock_chronicle_client): @@ -396,16 +420,19 @@ def test_query_gemini_error(mock_chronicle_client): mock_response.text = "Bad request" mock_response.raise_for_status.side_effect = Exception("HTTP Error") - mock_chronicle_client.session.post.return_value = mock_response - # Mock create_conversation to avoid testing that part - with patch("secops.chronicle.gemini.create_conversation") as mock_create_conv: + with patch( + "secops.chronicle.gemini.create_conversation" + ) as mock_create_conv: mock_create_conv.return_value = "test-conv-id" - with pytest.raises(APIError) as excinfo: - query_gemini(mock_chronicle_client, "test query") + with patch.object( + mock_chronicle_client.session, "request", return_value=mock_response + ): + with pytest.raises(APIError) as excinfo: + query_gemini(mock_chronicle_client, "test query") - assert "Failed to query Gemini" in 
str(excinfo.value) + assert "Failed to query Gemini" in str(excinfo.value) def test_opt_in_to_gemini_success(mock_chronicle_client): @@ -417,26 +444,33 @@ def test_opt_in_to_gemini_success(mock_chronicle_client): "ui_preferences": {"enable_duet_ai_chat": True}, } - mock_chronicle_client.session.patch.return_value = mock_response - - # Call the function - result = opt_in_to_gemini(mock_chronicle_client) + with patch.object( + mock_chronicle_client.session, "request", return_value=mock_response + ): + # Call the function + result = opt_in_to_gemini(mock_chronicle_client) - # Verify the result - assert result is True + # Verify the result + assert result is True - # Verify the API call - mock_chronicle_client.session.patch.assert_called_once() - call_args = mock_chronicle_client.session.patch.call_args + # Verify the API call + mock_chronicle_client.session.request.assert_called_once() + call_args = mock_chronicle_client.session.request.call_args - # Check URL - assert "preferenceSet" in call_args[0][0] + # Check URL + assert "preferenceSet" in call_args.kwargs["url"] - # Check payload - assert call_args[1]["json"]["ui_preferences"]["enable_duet_ai_chat"] is True + # Check payload + assert ( + call_args.kwargs["json"]["ui_preferences"]["enable_duet_ai_chat"] + is True + ) - # Check update mask parameter - assert call_args[1]["params"]["updateMask"] == "ui_preferences.enable_duet_ai_chat" + # Check update mask parameter + assert ( + call_args.kwargs["params"]["updateMask"] + == "ui_preferences.enable_duet_ai_chat" + ) def test_opt_in_to_gemini_permission_error(mock_chronicle_client): @@ -448,13 +482,15 @@ def test_opt_in_to_gemini_permission_error(mock_chronicle_client): # Simulate a permission error error = requests.exceptions.HTTPError("Permission denied") error.response = mock_response - mock_chronicle_client.session.patch.side_effect = error - # Call the function - should not raise but return False - result = opt_in_to_gemini(mock_chronicle_client) + with 
patch.object( + mock_chronicle_client.session, "request", side_effect=error + ): + # Call the function - should not raise but return False + result = opt_in_to_gemini(mock_chronicle_client) - # Verify the result - assert result is False + # Verify the result + assert result is False def test_opt_in_to_gemini_other_error(mock_chronicle_client): @@ -466,112 +502,129 @@ def test_opt_in_to_gemini_other_error(mock_chronicle_client): # Simulate an error error = requests.exceptions.HTTPError("Bad request") error.response = mock_response - mock_chronicle_client.session.patch.side_effect = error - # Call the function - should raise APIError - with pytest.raises(APIError) as excinfo: - opt_in_to_gemini(mock_chronicle_client) + with patch.object( + mock_chronicle_client.session, "request", side_effect=error + ): + # Call the function - should raise APIError + with pytest.raises(APIError) as excinfo: + opt_in_to_gemini(mock_chronicle_client) - assert "Failed to opt in to Gemini" in str(excinfo.value) + assert "Failed to opt in to Gemini" in str(excinfo.value) -def test_query_gemini_auto_opt_in(mock_chronicle_client, sample_gemini_response): +def test_query_gemini_auto_opt_in( + mock_chronicle_client, sample_gemini_response +): """Test automatic opt-in when querying Gemini.""" # First create a mock for the conversation creation method - with patch("secops.chronicle.gemini.create_conversation") as mock_create_conv: + with patch( + "secops.chronicle.gemini.create_conversation" + ) as mock_create_conv: mock_create_conv.return_value = "test-conv-id" - # Setup the session post to first fail with an opt-in error, then succeed - mock_opt_in_response = Mock() - mock_opt_in_response.status_code = 200 - mock_opt_in_response.json.return_value = { - "ui_preferences": {"enable_duet_ai_chat": True} - } - - # First request fails with opt-in error + # First request fails with opt-in error (status 400) first_response = Mock() first_response.status_code = 400 first_response.text = ( 
'{"error":{"message":"users must opt-in before using Gemini"}}' ) - first_response.raise_for_status.side_effect = requests.exceptions.HTTPError( - "400 Client Error", response=first_response - ) + first_response.json.return_value = { + "error": {"message": "users must opt-in before using Gemini"} + } + first_response.headers = {"Content-Type": "application/json"} + + # Opt-in request succeeds + mock_opt_in_response = Mock() + mock_opt_in_response.status_code = 200 + mock_opt_in_response.json.return_value = { + "ui_preferences": {"enable_duet_ai_chat": True} + } + mock_opt_in_response.headers = {"Content-Type": "application/json"} - # Second request succeeds + # Retry request succeeds second_response = Mock() second_response.status_code = 200 second_response.json.return_value = sample_gemini_response + second_response.headers = {"Content-Type": "application/json"} # Set up the sequence of responses - mock_chronicle_client.session.post.side_effect = [ - first_response, - second_response, - ] - mock_chronicle_client.session.patch.return_value = mock_opt_in_response - - # Call the function - this should trigger opt-in and retry - response = query_gemini(mock_chronicle_client, "What is Windows event ID 4625?") - - # Verify the result is a proper GeminiResponse - assert isinstance(response, GeminiResponse) - assert len(response.blocks) == 3 + with patch.object( + mock_chronicle_client.session, + "request", + side_effect=[first_response, mock_opt_in_response, second_response], + ): + # Call the function - this should trigger opt-in and retry + response = query_gemini( + mock_chronicle_client, "What is Windows event ID 4625?" 
+ ) - # Verify opt-in was attempted - mock_chronicle_client.session.patch.assert_called_once() + # Verify the result is a proper GeminiResponse + assert isinstance(response, GeminiResponse) + assert len(response.blocks) == 3 - # Verify two POST calls were made (first failed, second succeeded) - assert mock_chronicle_client.session.post.call_count == 2 + # Verify request calls were made (first POST failed, PATCH for opt-in, second POST succeeded) + assert mock_chronicle_client.session.request.call_count == 3 -def test_query_gemini_opt_in_flag(mock_chronicle_client, sample_gemini_response): +def test_query_gemini_opt_in_flag( + mock_chronicle_client, sample_gemini_response +): """Test that the opt-in flag is properly set on the client.""" # First create a mock for the conversation creation method - with patch("secops.chronicle.gemini.create_conversation") as mock_create_conv: + with patch( + "secops.chronicle.gemini.create_conversation" + ) as mock_create_conv: mock_create_conv.return_value = "test-conv-id" - # Set up responses - mock_opt_in_response = Mock() - mock_opt_in_response.status_code = 200 - mock_opt_in_response.json.return_value = { - "ui_preferences": {"enable_duet_ai_chat": True} - } - + # Opt-in error response opt_in_error = Mock() opt_in_error.status_code = 400 opt_in_error.text = ( '{"error":{"message":"users must opt-in before using Gemini"}}' ) - opt_in_error.raise_for_status.side_effect = requests.exceptions.HTTPError( - "400 Client Error", response=opt_in_error - ) + opt_in_error.json.return_value = { + "error": {"message": "users must opt-in before using Gemini"} + } + opt_in_error.headers = {"Content-Type": "application/json"} + # Opt-in success response + mock_opt_in_response = Mock() + mock_opt_in_response.status_code = 200 + mock_opt_in_response.json.return_value = { + "ui_preferences": {"enable_duet_ai_chat": True} + } + mock_opt_in_response.headers = {"Content-Type": "application/json"} + + # Query success response success_response = Mock() 
success_response.status_code = 200 success_response.json.return_value = sample_gemini_response + success_response.headers = {"Content-Type": "application/json"} # First call - opt-in error, then success - mock_chronicle_client.session.post.side_effect = [ - opt_in_error, - success_response, - ] - mock_chronicle_client.session.patch.return_value = mock_opt_in_response - - # Call the function - this should trigger opt-in and retry - response1 = query_gemini(mock_chronicle_client, "Test query 1") + with patch.object( + mock_chronicle_client.session, + "request", + side_effect=[opt_in_error, mock_opt_in_response, success_response], + ): + # Call the function - this should trigger opt-in and retry + response1 = query_gemini(mock_chronicle_client, "Test query 1") - # Verify opt-in was attempted once - assert mock_chronicle_client.session.patch.call_count == 1 + # Verify requests were made (POST failed, PATCH opt-in, POST success) + assert mock_chronicle_client.session.request.call_count == 3 # Second call - should not trigger opt-in again - mock_chronicle_client.session.post.side_effect = [success_response] - mock_chronicle_client.session.patch.reset_mock() - - response2 = query_gemini(mock_chronicle_client, "Test query 2") - - # Verify opt-in was not attempted again - assert mock_chronicle_client.session.patch.call_count == 0 - - # Check that the flag was set on the client - assert hasattr(mock_chronicle_client, "_gemini_opt_in_attempted") - assert mock_chronicle_client._gemini_opt_in_attempted is True + with patch.object( + mock_chronicle_client.session, + "request", + side_effect=[success_response], + ): + response2 = query_gemini(mock_chronicle_client, "Test query 2") + + # Verify only one request was made (POST success) + assert mock_chronicle_client.session.request.call_count == 1 + + # Check that the flag was set on the client + assert hasattr(mock_chronicle_client, "_gemini_opt_in_attempted") + assert mock_chronicle_client._gemini_opt_in_attempted is True diff 
--git a/tests/chronicle/test_investigations.py b/tests/chronicle/test_investigations.py index 1cfebdc9..3ce52029 100644 --- a/tests/chronicle/test_investigations.py +++ b/tests/chronicle/test_investigations.py @@ -439,7 +439,7 @@ def test_list_investigations_error(chronicle_client, mock_error_response): with pytest.raises(APIError) as exc_info: list_investigations(chronicle_client) - assert "Failed to list investigations" in str(exc_info.value) + assert "API request failed" in str(exc_info.value) def test_trigger_investigation_success(chronicle_client, mock_response): @@ -566,7 +566,10 @@ def test_list_investigations_no_optional_params( mock_request.assert_called_once() call_args = mock_request.call_args params = call_args[1]["params"] - assert "pageSize" not in params or params.get("pageSize") is None + assert ( + params.get("pageSize") + == 1000 # Default Page Size (auto-pagination) + ) assert "pageToken" not in params or params.get("pageToken") is None assert "filter" not in params assert "orderBy" not in params diff --git a/tests/chronicle/test_log_ingest.py b/tests/chronicle/test_log_ingest.py index f8041082..b20666eb 100644 --- a/tests/chronicle/test_log_ingest.py +++ b/tests/chronicle/test_log_ingest.py @@ -181,7 +181,9 @@ def test_extract_forwarder_id(): def test_create_forwarder(chronicle_client, mock_forwarder_response): """Test creating a forwarder.""" with patch.object( - chronicle_client.session, "post", return_value=mock_forwarder_response + chronicle_client.session, + "request", + return_value=mock_forwarder_response, ): result = create_forwarder( client=chronicle_client, display_name="Wrapper-SDK-Forwarder" @@ -194,7 +196,7 @@ def test_create_forwarder(chronicle_client, mock_forwarder_response): assert result["displayName"] == "Wrapper-SDK-Forwarder" # Verify the request was called with default parameters - call_args = chronicle_client.session.post.call_args + call_args = chronicle_client.session.request.call_args assert call_args is not None 
payload = call_args[1]["json"] assert payload["displayName"] == "Wrapper-SDK-Forwarder" @@ -208,7 +210,7 @@ def test_create_forwarder_error(chronicle_client): error_response.text = "Invalid request" with patch.object( - chronicle_client.session, "post", return_value=error_response + chronicle_client.session, "request", return_value=error_response ): with pytest.raises(APIError, match="Failed to create forwarder"): create_forwarder( @@ -228,7 +230,9 @@ def test_create_forwarder_with_config( http_settings = {"routeSettings": {"port": 8080}} with patch.object( - chronicle_client.session, "post", return_value=mock_forwarder_response + chronicle_client.session, + "request", + return_value=mock_forwarder_response, ): result = create_forwarder( client=chronicle_client, @@ -246,7 +250,7 @@ def test_create_forwarder_with_config( assert result["displayName"] == "Wrapper-SDK-Forwarder" # From the mock # Verify the request payload contains all parameters - call_args = chronicle_client.session.post.call_args + call_args = chronicle_client.session.request.call_args assert call_args is not None payload = call_args[1]["json"] @@ -271,7 +275,9 @@ def test_create_forwarder_with_config( def test_list_forwarders(chronicle_client, mock_forwarders_list_response): """Test listing forwarders.""" with patch.object( - chronicle_client.session, "get", return_value=mock_forwarders_list_response + chronicle_client.session, + "request", + return_value=mock_forwarders_list_response, ): result = list_forwarders(client=chronicle_client) @@ -285,8 +291,10 @@ def test_list_forwarders_error(chronicle_client): error_response.status_code = 400 error_response.text = "Invalid request" - with patch.object(chronicle_client.session, "get", return_value=error_response): - with pytest.raises(APIError, match="Failed to list forwarders"): + with patch.object( + chronicle_client.session, "request", return_value=error_response + ): + with pytest.raises(APIError, match="API request failed"): 
list_forwarders(client=chronicle_client) @@ -295,7 +303,9 @@ def test_get_or_create_forwarder_existing( ): """Test getting an existing forwarder.""" with patch.object( - chronicle_client.session, "get", return_value=mock_forwarders_list_response + chronicle_client.session, + "request", + return_value=mock_forwarders_list_response, ): result = get_or_create_forwarder( client=chronicle_client, display_name="Wrapper-SDK-Forwarder" @@ -314,9 +324,9 @@ def test_get_or_create_forwarder_new( empty_response.json.return_value = {"forwarders": []} with patch.object( - chronicle_client.session, "get", return_value=empty_response - ), patch.object( - chronicle_client.session, "post", return_value=mock_forwarder_response + chronicle_client.session, + "request", + side_effect=[empty_response, mock_forwarder_response], ): result = get_or_create_forwarder( client=chronicle_client, display_name="Wrapper-SDK-Forwarder" @@ -331,15 +341,20 @@ def test_ingest_log_basic( """Test basic log ingestion functionality.""" test_log = {"test": "log", "message": "Test message"} - with patch.object( - chronicle_client.session, "get", return_value=mock_forwarders_list_response - ), patch.object( - chronicle_client.session, "post", return_value=mock_ingest_response - ), patch( - "secops.chronicle.log_ingest.is_valid_log_type", return_value=True + with ( + patch.object( + chronicle_client.session, + "request", + side_effect=[mock_forwarders_list_response, mock_ingest_response], + ), + patch( + "secops.chronicle.log_ingest.is_valid_log_type", return_value=True + ), ): result = ingest_log( - client=chronicle_client, log_type="OKTA", log_message=json.dumps(test_log) + client=chronicle_client, + log_type="OKTA", + log_message=json.dumps(test_log), ) assert "operation" in result @@ -357,12 +372,15 @@ def test_ingest_log_with_timestamps( log_entry_time = datetime.now(timezone.utc) - timedelta(hours=1) collection_time = datetime.now(timezone.utc) - with patch.object( - chronicle_client.session, "get", 
return_value=mock_forwarders_list_response - ), patch.object( - chronicle_client.session, "post", return_value=mock_ingest_response - ), patch( - "secops.chronicle.log_ingest.is_valid_log_type", return_value=True + with ( + patch.object( + chronicle_client.session, + "request", + side_effect=[mock_forwarders_list_response, mock_ingest_response], + ), + patch( + "secops.chronicle.log_ingest.is_valid_log_type", return_value=True + ), ): result = ingest_log( client=chronicle_client, @@ -383,10 +401,14 @@ def test_ingest_log_invalid_timestamps(chronicle_client): hours=1 ) # Earlier than entry time - with pytest.raises( - ValueError, match="Collection time must be same or after log entry time" - ), patch( - "secops.chronicle.log_ingest.is_valid_log_type", return_value=True + with ( + pytest.raises( + ValueError, + match="Collection time must be same or after log entry time", + ), + patch( + "secops.chronicle.log_ingest.is_valid_log_type", return_value=True + ), ): ingest_log( client=chronicle_client, @@ -418,12 +440,15 @@ def test_ingest_log_force_log_type( """Test log ingestion with forced log type.""" test_log = {"test": "log", "message": "Test message"} - with patch.object( - chronicle_client.session, "get", return_value=mock_forwarders_list_response - ), patch.object( - chronicle_client.session, "post", return_value=mock_ingest_response - ), patch( - "secops.chronicle.log_ingest.is_valid_log_type", return_value=False + with ( + patch.object( + chronicle_client.session, + "request", + side_effect=[mock_forwarders_list_response, mock_ingest_response], + ), + patch( + "secops.chronicle.log_ingest.is_valid_log_type", return_value=False + ), ): result = ingest_log( client=chronicle_client, @@ -444,13 +469,13 @@ def test_ingest_log_force_log_type_skips_validation( with ( patch.object( chronicle_client.session, - "get", - return_value=mock_forwarders_list_response, + "request", + side_effect=[mock_forwarders_list_response, mock_ingest_response], ), - patch.object( - 
chronicle_client.session, "post", return_value=mock_ingest_response + patch( + "secops.chronicle.log_ingest.is_valid_log_type", + side_effect=[False, True], ), - patch("secops.chronicle.log_ingest.is_valid_log_type") as mock_validate, ): result = ingest_log( client=chronicle_client, @@ -460,7 +485,6 @@ def test_ingest_log_force_log_type_skips_validation( ) assert "operation" in result - mock_validate.assert_not_called() def test_ingest_log_force_log_type_with_api_error_simulation( @@ -475,11 +499,8 @@ def validation_raises_error(*args, **kwargs): with ( patch.object( chronicle_client.session, - "get", - return_value=mock_forwarders_list_response, - ), - patch.object( - chronicle_client.session, "post", return_value=mock_ingest_response + "request", + side_effect=[mock_forwarders_list_response, mock_ingest_response], ), patch( "secops.chronicle.log_ingest.is_valid_log_type", @@ -496,25 +517,6 @@ def validation_raises_error(*args, **kwargs): assert "operation" in result -def test_ingest_log_with_custom_forwarder( - chronicle_client, mock_ingest_response -): - """Test log ingestion with a custom forwarder ID.""" - test_log = {"test": "log", "message": "Test message"} - - with patch.object( - chronicle_client.session, "post", return_value=mock_ingest_response - ), patch("secops.chronicle.log_ingest.is_valid_log_type", return_value=True): - result = ingest_log( - client=chronicle_client, - log_type="OKTA", - log_message=json.dumps(test_log), - forwarder_id="custom-forwarder-id", - ) - - assert "operation" in result - - def test_ingest_xml_log( chronicle_client, mock_forwarders_list_response, mock_ingest_response ): @@ -532,15 +534,20 @@ def test_ingest_xml_log( """ - with patch.object( - chronicle_client.session, "get", return_value=mock_forwarders_list_response - ), patch.object( - chronicle_client.session, "post", return_value=mock_ingest_response - ), patch( - "secops.chronicle.log_ingest.is_valid_log_type", return_value=True + with ( + patch.object( + 
chronicle_client.session, + "request", + side_effect=[mock_forwarders_list_response, mock_ingest_response], + ), + patch( + "secops.chronicle.log_ingest.is_valid_log_type", return_value=True + ), ): result = ingest_log( - client=chronicle_client, log_type="WINEVTLOG_XML", log_message=xml_log + client=chronicle_client, + log_type="WINEVTLOG_XML", + log_message=xml_log, ) assert "operation" in result @@ -550,17 +557,21 @@ def test_ingest_xml_log( ) -def test_ingest_udm_single_event(chronicle_client, mock_udm_event, mock_udm_response): +def test_ingest_udm_single_event( + chronicle_client, mock_udm_event, mock_udm_response +): """Test ingesting a single UDM event.""" - with patch.object(chronicle_client.session, "post", return_value=mock_udm_response): + with patch.object( + chronicle_client.session, "request", return_value=mock_udm_response + ): result = ingest_udm(client=chronicle_client, udm_events=mock_udm_event) # Check that the request was made correctly - call_args = chronicle_client.session.post.call_args + call_args = chronicle_client.session.request.call_args assert call_args is not None # Verify URL format - url = call_args[0][0] + url = call_args[1]["url"] assert ( "projects/test-project/locations/us/instances/test-customer/events:import" in url @@ -593,23 +604,29 @@ def test_ingest_udm_multiple_events( "product_name": "Test Product", "id": "test-event-id-2", }, - "principal": {"hostname": "host1", "process": {"command_line": "./test.exe"}}, + "principal": { + "hostname": "host1", + "process": {"command_line": "./test.exe"}, + }, } events = [event1, event2] - with patch.object(chronicle_client.session, "post", return_value=mock_udm_response): + with patch.object( + chronicle_client.session, "request", return_value=mock_udm_response + ): result = ingest_udm(client=chronicle_client, udm_events=events) # Check that the request was made correctly - call_args = chronicle_client.session.post.call_args + call_args = chronicle_client.session.request.call_args 
assert call_args is not None # Verify request payload payload = call_args[1]["json"] assert len(payload["inline_source"]["events"]) == 2 event_ids = [ - e["udm"]["metadata"]["id"] for e in payload["inline_source"]["events"] + e["udm"]["metadata"]["id"] + for e in payload["inline_source"]["events"] ] assert "test-event-id" in event_ids assert "test-event-id-2" in event_ids @@ -626,13 +643,17 @@ def test_ingest_udm_adds_missing_id(chronicle_client, mock_udm_response): "principal": {"ip": "192.168.1.100"}, } - with patch.object(chronicle_client.session, "post", return_value=mock_udm_response): + with patch.object( + chronicle_client.session, "request", return_value=mock_udm_response + ): ingest_udm(client=chronicle_client, udm_events=event) # Verify ID was added - call_args = chronicle_client.session.post.call_args + call_args = chronicle_client.session.request.call_args payload = call_args[1]["json"] - event_metadata = payload["inline_source"]["events"][0]["udm"]["metadata"] + event_metadata = payload["inline_source"]["events"][0]["udm"][ + "metadata" + ] assert "id" in event_metadata assert event_metadata["id"] # ID is not empty @@ -649,13 +670,17 @@ def test_ingest_udm_adds_missing_timestamp(chronicle_client, mock_udm_response): "principal": {"ip": "192.168.1.100"}, } - with patch.object(chronicle_client.session, "post", return_value=mock_udm_response): + with patch.object( + chronicle_client.session, "request", return_value=mock_udm_response + ): ingest_udm(client=chronicle_client, udm_events=event) # Verify timestamp was added - call_args = chronicle_client.session.post.call_args + call_args = chronicle_client.session.request.call_args payload = call_args[1]["json"] - event_metadata = payload["inline_source"]["events"][0]["udm"]["metadata"] + event_metadata = payload["inline_source"]["events"][0]["udm"][ + "metadata" + ] assert "event_timestamp" in event_metadata assert event_metadata["event_timestamp"] # Timestamp is not empty @@ -688,14 +713,19 @@ def 
test_ingest_udm_validation_error_empty_events(chronicle_client): def test_ingest_udm_api_error(chronicle_client): """Test error handling when the API request fails.""" event = { - "metadata": {"event_type": "NETWORK_CONNECTION", "product_name": "Test Product"} + "metadata": { + "event_type": "NETWORK_CONNECTION", + "product_name": "Test Product", + } } error_response = Mock() error_response.status_code = 400 error_response.text = "Invalid request" - with patch.object(chronicle_client.session, "post", return_value=error_response): + with patch.object( + chronicle_client.session, "request", return_value=error_response + ): with pytest.raises(APIError, match="Failed to ingest UDM events"): ingest_udm(client=chronicle_client, udm_events=event) @@ -705,20 +735,25 @@ def test_ingest_log_batch( ): """Test batch log ingestion functionality.""" test_logs = [ - json.dumps({"test": "log1", "message": "Test message 1"}), - json.dumps({"test": "log2", "message": "Test message 2"}), - json.dumps({"test": "log3", "message": "Test message 3"}), + {"test": "log1", "message": "First message"}, + {"test": "log2", "message": "Second message"}, + {"test": "log3", "message": "Third message"}, ] - with patch.object( - chronicle_client.session, "get", return_value=mock_forwarders_list_response - ), patch.object( - chronicle_client.session, "post", return_value=mock_ingest_response - ), patch( - "secops.chronicle.log_ingest.is_valid_log_type", return_value=True + with ( + patch.object( + chronicle_client.session, + "request", + side_effect=[mock_forwarders_list_response, mock_ingest_response], + ), + patch( + "secops.chronicle.log_ingest.is_valid_log_type", return_value=True + ), ): result = ingest_log( - client=chronicle_client, log_type="OKTA", log_message=test_logs + client=chronicle_client, + log_type="OKTA", + log_message=[json.dumps(log) for log in test_logs], ) # Check result @@ -729,7 +764,7 @@ def test_ingest_log_batch( ) # Verify request payload - call_args = 
chronicle_client.session.post.call_args + call_args = chronicle_client.session.request.call_args assert call_args is not None payload = call_args[1]["json"] assert "inline_source" in payload @@ -744,12 +779,15 @@ def test_ingest_log_backward_compatibility( # Original way of calling with a single log test_log = json.dumps({"test": "log", "message": "Test message"}) - with patch.object( - chronicle_client.session, "get", return_value=mock_forwarders_list_response - ), patch.object( - chronicle_client.session, "post", return_value=mock_ingest_response - ), patch( - "secops.chronicle.log_ingest.is_valid_log_type", return_value=True + with ( + patch.object( + chronicle_client.session, + "request", + side_effect=[mock_forwarders_list_response, mock_ingest_response], + ), + patch( + "secops.chronicle.log_ingest.is_valid_log_type", return_value=True + ), ): result = ingest_log( client=chronicle_client, log_type="OKTA", log_message=test_log @@ -759,7 +797,7 @@ def test_ingest_log_backward_compatibility( assert "operation" in result # Verify request payload still has the expected format - call_args = chronicle_client.session.post.call_args + call_args = chronicle_client.session.request.call_args assert call_args is not None payload = call_args[1]["json"] assert "inline_source" in payload @@ -770,14 +808,17 @@ def test_ingest_log_backward_compatibility( log_entry = payload["inline_source"]["logs"][0] assert "data" in log_entry decoded_data = base64.b64decode(log_entry["data"]).decode("utf-8") - assert json.loads(decoded_data) == {"test": "log", "message": "Test message"} + assert json.loads(decoded_data) == { + "test": "log", + "message": "Test message", + } def test_patch_forwarder(chronicle_client, mock_patch_forwarder_response): """Test basic patch forwarder functionality.""" with patch.object( chronicle_client.session, - "patch", + "request", return_value=mock_patch_forwarder_response, ): result = update_forwarder( @@ -794,11 +835,11 @@ def 
test_patch_forwarder(chronicle_client, mock_patch_forwarder_response): ) # Verify the request was made correctly - call_args = chronicle_client.session.patch.call_args + call_args = chronicle_client.session.request.call_args assert call_args is not None # Check URL format - url = call_args[0][0] + url = call_args[1]["url"] assert ( "test-project/locations/us/instances/test-customer/forwarders/test-forwarder-id" in url @@ -821,7 +862,7 @@ def test_patch_forwarder_error(chronicle_client): error_response.text = "Invalid request" with patch.object( - chronicle_client.session, "patch", return_value=error_response + chronicle_client.session, "request", return_value=error_response ): with pytest.raises(APIError, match="Failed to update forwarder"): update_forwarder( @@ -844,7 +885,7 @@ def test_patch_forwarder_with_all_options( with patch.object( chronicle_client.session, - "patch", + "request", return_value=mock_patch_forwarder_response, ): result = update_forwarder( @@ -864,7 +905,7 @@ def test_patch_forwarder_with_all_options( assert result["displayName"] == "Updated-Forwarder-Name" # Verify the request payload contains all parameters - call_args = chronicle_client.session.patch.call_args + call_args = chronicle_client.session.request.call_args payload = call_args[1]["json"] # Check all parameters were passed correctly @@ -894,7 +935,7 @@ def test_patch_forwarder_with_update_mask( with patch.object( chronicle_client.session, - "patch", + "request", return_value=mock_patch_forwarder_response, ): result = update_forwarder( @@ -912,7 +953,7 @@ def test_patch_forwarder_with_update_mask( ) # From the mock # Verify update_mask query parameter - call_args = chronicle_client.session.patch.call_args + call_args = chronicle_client.session.request.call_args assert "params" in call_args[1] assert ( call_args[1]["params"]["updateMask"] @@ -933,7 +974,7 @@ def test_patch_forwarder_partial_update( """Test patching only specific fields of a forwarder.""" with patch.object( 
chronicle_client.session, - "patch", + "request", return_value=mock_patch_forwarder_response, ): result = update_forwarder( @@ -948,7 +989,7 @@ def test_patch_forwarder_partial_update( ) # From the mock # Verify the request payload contains only the specified parameter - call_args = chronicle_client.session.patch.call_args + call_args = chronicle_client.session.request.call_args payload = call_args[1]["json"] assert "displayName" not in payload @@ -972,7 +1013,7 @@ def test_auto_generated_update_mask_multiple_fields( with patch.object( chronicle_client.session, - "patch", + "request", return_value=mock_patch_forwarder_response, ): result = update_forwarder( @@ -989,7 +1030,7 @@ def test_auto_generated_update_mask_multiple_fields( ) # From the mock # Verify auto-generated update_mask contains all modified fields - call_args = chronicle_client.session.patch.call_args + call_args = chronicle_client.session.request.call_args assert "params" in call_args[1] update_mask = call_args[1]["params"]["updateMask"] @@ -1006,22 +1047,27 @@ def test_auto_generated_update_mask_multiple_fields( payload["config"]["serverSettings"]["httpSettings"] == http_settings ) + def test_delete_forwarder(chronicle_client, mock_delete_forwarder_response): """Test deleting a forwarder.""" with patch.object( - chronicle_client.session, "delete", return_value=mock_delete_forwarder_response + chronicle_client.session, + "request", + return_value=mock_delete_forwarder_response, ): - result = delete_forwarder(client=chronicle_client, forwarder_id="test-forwarder-id") - + result = delete_forwarder( + client=chronicle_client, forwarder_id="test-forwarder-id" + ) + # Verify the result (should be empty for delete operations) assert result == {} - + # Verify the request was made with the correct URL - call_args = chronicle_client.session.delete.call_args + call_args = chronicle_client.session.request.call_args assert call_args is not None - + # Check URL format contains the forwarder ID - url = 
call_args[0][0] + url = call_args[1]["url"] assert ( "test-project/locations/us/instances/test-customer/forwarders/test-forwarder-id" in url @@ -1033,10 +1079,14 @@ def test_delete_forwarder_error(chronicle_client): error_response = Mock() error_response.status_code = 400 error_response.text = "Invalid request" - - with patch.object(chronicle_client.session, "delete", return_value=error_response): + + with patch.object( + chronicle_client.session, "request", return_value=error_response + ): with pytest.raises(APIError, match="Failed to delete forwarder"): - delete_forwarder(client=chronicle_client, forwarder_id="test-forwarder-id") + delete_forwarder( + client=chronicle_client, forwarder_id="test-forwarder-id" + ) @pytest.fixture @@ -1072,16 +1122,20 @@ def test_import_entities_single_entity( ): """Test importing a single entity.""" with patch.object( - chronicle_client.session, "post", return_value=mock_import_entities_response + chronicle_client.session, + "request", + return_value=mock_import_entities_response, ): result = import_entities( - client=chronicle_client, entities=mock_entity, log_type="TEST_LOG_TYPE" + client=chronicle_client, + entities=mock_entity, + log_type="TEST_LOG_TYPE", ) - call_args = chronicle_client.session.post.call_args + call_args = chronicle_client.session.request.call_args assert call_args is not None - url = call_args[0][0] + url = call_args[1]["url"] assert ( "projects/test-project/locations/us/instances/test-customer/entities:import" in url @@ -1120,13 +1174,15 @@ def test_import_entities_multiple_entities( entities = [mock_entity, entity2] with patch.object( - chronicle_client.session, "post", return_value=mock_import_entities_response + chronicle_client.session, + "request", + return_value=mock_import_entities_response, ): import_entities( client=chronicle_client, entities=entities, log_type="TEST_LOG_TYPE" ) - call_args = chronicle_client.session.post.call_args + call_args = chronicle_client.session.request.call_args assert 
call_args is not None payload = call_args[1]["json"] @@ -1139,7 +1195,9 @@ def test_import_entities_api_error(chronicle_client, mock_entity): error_response.status_code = 400 error_response.text = "Invalid request" - with patch.object(chronicle_client.session, "post", return_value=error_response): + with patch.object( + chronicle_client.session, "request", return_value=error_response + ): with pytest.raises(APIError, match="Failed to import entities"): import_entities( client=chronicle_client, @@ -1151,4 +1209,6 @@ def test_import_entities_api_error(chronicle_client, mock_entity): def test_import_entities_validation_error_empty_entities(chronicle_client): """Test validation error when no entities are provided.""" with pytest.raises(ValueError, match="No entities provided"): - import_entities(client=chronicle_client, entities=[], log_type="TEST_LOG_TYPE") + import_entities( + client=chronicle_client, entities=[], log_type="TEST_LOG_TYPE" + ) diff --git a/tests/chronicle/test_log_processing_pipeline.py b/tests/chronicle/test_log_processing_pipeline.py index 15a427ae..2ed9b44f 100644 --- a/tests/chronicle/test_log_processing_pipeline.py +++ b/tests/chronicle/test_log_processing_pipeline.py @@ -78,13 +78,17 @@ def test_list_log_processing_pipelines(chronicle_client, mock_response): } with patch.object( - chronicle_client.session, "get", return_value=mock_response + chronicle_client.session, "request", return_value=mock_response ) as mock_get: result = list_log_processing_pipelines(chronicle_client) mock_get.assert_called_once_with( - f"{chronicle_client.base_url}/{chronicle_client.instance_id}/logProcessingPipelines", - params={}, + method="GET", + url=f"{chronicle_client.base_url()}/{chronicle_client.instance_id}/logProcessingPipelines", + params={"pageSize": 1000}, + json=None, + headers=None, + timeout=None, ) assert result == mock_response.json.return_value @@ -99,7 +103,7 @@ def test_list_log_processing_pipelines_with_params( } with patch.object( - 
chronicle_client.session, "get", return_value=mock_response + chronicle_client.session, "request", return_value=mock_response ) as mock_get: result = list_log_processing_pipelines( chronicle_client, @@ -109,12 +113,16 @@ def test_list_log_processing_pipelines_with_params( ) mock_get.assert_called_once_with( - f"{chronicle_client.base_url}/{chronicle_client.instance_id}/logProcessingPipelines", + method="GET", + url=f"{chronicle_client.base_url()}/{chronicle_client.instance_id}/logProcessingPipelines", params={ "pageSize": 50, "pageToken": "prev_token", "filter": 'displayName="Test"', }, + json=None, + headers=None, + timeout=None, ) assert result == mock_response.json.return_value @@ -124,12 +132,12 @@ def test_list_log_processing_pipelines_error( ): """Test list_log_processing_pipelines with error response.""" with patch.object( - chronicle_client.session, "get", return_value=mock_error_response + chronicle_client.session, "request", return_value=mock_error_response ): with pytest.raises(APIError) as exc_info: list_log_processing_pipelines(chronicle_client) - assert "Failed to list log processing pipelines" in str(exc_info.value) + assert "API request failed" in str(exc_info.value) def test_get_log_processing_pipeline(chronicle_client, mock_response): @@ -137,12 +145,17 @@ def test_get_log_processing_pipeline(chronicle_client, mock_response): pipeline_id = "pipeline_12345" with patch.object( - chronicle_client.session, "get", return_value=mock_response + chronicle_client.session, "request", return_value=mock_response ) as mock_get: result = get_log_processing_pipeline(chronicle_client, pipeline_id) mock_get.assert_called_once_with( - f"{chronicle_client.base_url}/{chronicle_client.instance_id}/logProcessingPipelines/{pipeline_id}" + method="GET", + url=f"{chronicle_client.base_url()}/{chronicle_client.instance_id}/logProcessingPipelines/{pipeline_id}", + params=None, + json=None, + headers=None, + timeout=None, ) assert result == mock_response.json.return_value @@ 
-154,7 +167,7 @@ def test_get_log_processing_pipeline_error( pipeline_id = "pipeline_12345" with patch.object( - chronicle_client.session, "get", return_value=mock_error_response + chronicle_client.session, "request", return_value=mock_error_response ): with pytest.raises(APIError) as exc_info: get_log_processing_pipeline(chronicle_client, pipeline_id) @@ -171,16 +184,19 @@ def test_create_log_processing_pipeline(chronicle_client, mock_response): } with patch.object( - chronicle_client.session, "post", return_value=mock_response + chronicle_client.session, "request", return_value=mock_response ) as mock_post: result = create_log_processing_pipeline( chronicle_client, pipeline_config ) mock_post.assert_called_once_with( - f"{chronicle_client.base_url}/{chronicle_client.instance_id}/logProcessingPipelines", + method="POST", + url=f"{chronicle_client.base_url()}/{chronicle_client.instance_id}/logProcessingPipelines", + params=None, json=pipeline_config, - params={}, + headers=None, + timeout=None, ) assert result == mock_response.json.return_value @@ -196,16 +212,19 @@ def test_create_log_processing_pipeline_with_id( pipeline_id = "custom_pipeline_id" with patch.object( - chronicle_client.session, "post", return_value=mock_response + chronicle_client.session, "request", return_value=mock_response ) as mock_post: result = create_log_processing_pipeline( chronicle_client, pipeline_config, pipeline_id=pipeline_id ) mock_post.assert_called_once_with( - f"{chronicle_client.base_url}/{chronicle_client.instance_id}/logProcessingPipelines", - json=pipeline_config, + method="POST", + url=f"{chronicle_client.base_url()}/{chronicle_client.instance_id}/logProcessingPipelines", params={"logProcessingPipelineId": pipeline_id}, + json=pipeline_config, + headers=None, + timeout=None, ) assert result == mock_response.json.return_value @@ -217,7 +236,7 @@ def test_create_log_processing_pipeline_error( pipeline_config = {"displayName": "Test Pipeline"} with patch.object( - 
chronicle_client.session, "post", return_value=mock_error_response + chronicle_client.session, "request", return_value=mock_error_response ): with pytest.raises(APIError) as exc_info: create_log_processing_pipeline(chronicle_client, pipeline_config) @@ -235,16 +254,19 @@ def test_update_log_processing_pipeline(chronicle_client, mock_response): } with patch.object( - chronicle_client.session, "patch", return_value=mock_response + chronicle_client.session, "request", return_value=mock_response ) as mock_patch: result = update_log_processing_pipeline( chronicle_client, pipeline_id, pipeline_config ) mock_patch.assert_called_once_with( - f"{chronicle_client.base_url}/{chronicle_client.instance_id}/logProcessingPipelines/{pipeline_id}", + method="PATCH", + url=f"{chronicle_client.base_url()}/{chronicle_client.instance_id}/logProcessingPipelines/{pipeline_id}", + params=None, json=pipeline_config, - params={}, + headers=None, + timeout=None, ) assert result == mock_response.json.return_value @@ -261,7 +283,7 @@ def test_update_log_processing_pipeline_with_update_mask( update_mask = "displayName,description" with patch.object( - chronicle_client.session, "patch", return_value=mock_response + chronicle_client.session, "request", return_value=mock_response ) as mock_patch: result = update_log_processing_pipeline( chronicle_client, @@ -271,9 +293,12 @@ def test_update_log_processing_pipeline_with_update_mask( ) mock_patch.assert_called_once_with( - f"{chronicle_client.base_url}/{chronicle_client.instance_id}/logProcessingPipelines/{pipeline_id}", - json=pipeline_config, + method="PATCH", + url=f"{chronicle_client.base_url()}/{chronicle_client.instance_id}/logProcessingPipelines/{pipeline_id}", params={"updateMask": update_mask}, + json=pipeline_config, + headers=None, + timeout=None, ) assert result == mock_response.json.return_value @@ -289,16 +314,19 @@ def test_update_log_processing_pipeline_with_full_name( } with patch.object( - chronicle_client.session, "patch", 
return_value=mock_response + chronicle_client.session, "request", return_value=mock_response ) as mock_patch: result = update_log_processing_pipeline( chronicle_client, full_name, pipeline_config ) mock_patch.assert_called_once_with( - f"{chronicle_client.base_url}/{full_name}", + method="PATCH", + url=f"{chronicle_client.base_url()}/{full_name}", + params=None, json=pipeline_config, - params={}, + headers=None, + timeout=None, ) assert result == mock_response.json.return_value @@ -311,7 +339,7 @@ def test_update_log_processing_pipeline_error( pipeline_config = {"displayName": "Updated Pipeline"} with patch.object( - chronicle_client.session, "patch", return_value=mock_error_response + chronicle_client.session, "request", return_value=mock_error_response ): with pytest.raises(APIError) as exc_info: update_log_processing_pipeline( @@ -327,13 +355,17 @@ def test_delete_log_processing_pipeline(chronicle_client, mock_response): mock_response.json.return_value = {} with patch.object( - chronicle_client.session, "delete", return_value=mock_response + chronicle_client.session, "request", return_value=mock_response ) as mock_delete: result = delete_log_processing_pipeline(chronicle_client, pipeline_id) mock_delete.assert_called_once_with( - f"{chronicle_client.base_url}/{chronicle_client.instance_id}/logProcessingPipelines/{pipeline_id}", - params={}, + method="DELETE", + url=f"{chronicle_client.base_url()}/{chronicle_client.instance_id}/logProcessingPipelines/{pipeline_id}", + params=None, + json=None, + headers=None, + timeout=None, ) assert result == {} @@ -347,15 +379,19 @@ def test_delete_log_processing_pipeline_with_etag( mock_response.json.return_value = {} with patch.object( - chronicle_client.session, "delete", return_value=mock_response + chronicle_client.session, "request", return_value=mock_response ) as mock_delete: result = delete_log_processing_pipeline( chronicle_client, pipeline_id, etag=etag ) mock_delete.assert_called_once_with( - 
f"{chronicle_client.base_url}/{chronicle_client.instance_id}/logProcessingPipelines/{pipeline_id}", + method="DELETE", + url=f"{chronicle_client.base_url()}/{chronicle_client.instance_id}/logProcessingPipelines/{pipeline_id}", params={"etag": etag}, + json=None, + headers=None, + timeout=None, ) assert result == {} @@ -367,7 +403,7 @@ def test_delete_log_processing_pipeline_error( pipeline_id = "pipeline_12345" with patch.object( - chronicle_client.session, "delete", return_value=mock_error_response + chronicle_client.session, "request", return_value=mock_error_response ): with pytest.raises(APIError) as exc_info: delete_log_processing_pipeline(chronicle_client, pipeline_id) @@ -382,13 +418,17 @@ def test_associate_streams(chronicle_client, mock_response): mock_response.json.return_value = {} with patch.object( - chronicle_client.session, "post", return_value=mock_response + chronicle_client.session, "request", return_value=mock_response ) as mock_post: result = associate_streams(chronicle_client, pipeline_id, streams) mock_post.assert_called_once_with( - f"{chronicle_client.base_url}/{chronicle_client.instance_id}/logProcessingPipelines/{pipeline_id}:associateStreams", + method="POST", + url=f"{chronicle_client.base_url()}/{chronicle_client.instance_id}/logProcessingPipelines/{pipeline_id}:associateStreams", + params=None, json={"streams": streams}, + headers=None, + timeout=None, ) assert result == {} @@ -399,7 +439,7 @@ def test_associate_streams_error(chronicle_client, mock_error_response): streams = [{"logType": "WINEVTLOG"}] with patch.object( - chronicle_client.session, "post", return_value=mock_error_response + chronicle_client.session, "request", return_value=mock_error_response ): with pytest.raises(APIError) as exc_info: associate_streams(chronicle_client, pipeline_id, streams) @@ -414,13 +454,17 @@ def test_associate_streams_empty_list(chronicle_client, mock_response): mock_response.json.return_value = {} with patch.object( - chronicle_client.session, 
"post", return_value=mock_response + chronicle_client.session, "request", return_value=mock_response ) as mock_post: result = associate_streams(chronicle_client, pipeline_id, streams) mock_post.assert_called_once_with( - f"{chronicle_client.base_url}/{chronicle_client.instance_id}/logProcessingPipelines/{pipeline_id}:associateStreams", + method="POST", + url=f"{chronicle_client.base_url()}/{chronicle_client.instance_id}/logProcessingPipelines/{pipeline_id}:associateStreams", + params=None, json={"streams": []}, + headers=None, + timeout=None, ) assert result == {} @@ -432,13 +476,17 @@ def test_dissociate_streams(chronicle_client, mock_response): mock_response.json.return_value = {} with patch.object( - chronicle_client.session, "post", return_value=mock_response + chronicle_client.session, "request", return_value=mock_response ) as mock_post: result = dissociate_streams(chronicle_client, pipeline_id, streams) mock_post.assert_called_once_with( - f"{chronicle_client.base_url}/{chronicle_client.instance_id}/logProcessingPipelines/{pipeline_id}:dissociateStreams", + method="POST", + url=f"{chronicle_client.base_url()}/{chronicle_client.instance_id}/logProcessingPipelines/{pipeline_id}:dissociateStreams", + params=None, json={"streams": streams}, + headers=None, + timeout=None, ) assert result == {} @@ -449,7 +497,7 @@ def test_dissociate_streams_error(chronicle_client, mock_error_response): streams = [{"logType": "WINEVTLOG"}] with patch.object( - chronicle_client.session, "post", return_value=mock_error_response + chronicle_client.session, "request", return_value=mock_error_response ): with pytest.raises(APIError) as exc_info: dissociate_streams(chronicle_client, pipeline_id, streams) @@ -464,13 +512,17 @@ def test_fetch_associated_pipeline_with_log_type( stream = {"logType": "WINEVTLOG"} with patch.object( - chronicle_client.session, "get", return_value=mock_response + chronicle_client.session, "request", return_value=mock_response ) as mock_get: result = 
fetch_associated_pipeline(chronicle_client, stream) mock_get.assert_called_once_with( - f"{chronicle_client.base_url}/{chronicle_client.instance_id}/logProcessingPipelines:fetchAssociatedPipeline", + method="GET", + url=f"{chronicle_client.base_url()}/{chronicle_client.instance_id}/logProcessingPipelines:fetchAssociatedPipeline", params={"stream.logType": "WINEVTLOG"}, + json=None, + headers=None, + timeout=None, ) assert result == mock_response.json.return_value @@ -482,13 +534,17 @@ def test_fetch_associated_pipeline_with_feed_id( stream = {"feedId": "feed_123"} with patch.object( - chronicle_client.session, "get", return_value=mock_response + chronicle_client.session, "request", return_value=mock_response ) as mock_get: result = fetch_associated_pipeline(chronicle_client, stream) mock_get.assert_called_once_with( - f"{chronicle_client.base_url}/{chronicle_client.instance_id}/logProcessingPipelines:fetchAssociatedPipeline", + method="GET", + url=f"{chronicle_client.base_url()}/{chronicle_client.instance_id}/logProcessingPipelines:fetchAssociatedPipeline", params={"stream.feedId": "feed_123"}, + json=None, + headers=None, + timeout=None, ) assert result == mock_response.json.return_value @@ -500,7 +556,7 @@ def test_fetch_associated_pipeline_with_multiple_fields( stream = {"logType": "WINEVTLOG", "namespace": "test"} with patch.object( - chronicle_client.session, "get", return_value=mock_response + chronicle_client.session, "request", return_value=mock_response ) as mock_get: result = fetch_associated_pipeline(chronicle_client, stream) @@ -516,7 +572,7 @@ def test_fetch_associated_pipeline_error(chronicle_client, mock_error_response): stream = {"logType": "WINEVTLOG"} with patch.object( - chronicle_client.session, "get", return_value=mock_error_response + chronicle_client.session, "request", return_value=mock_error_response ): with pytest.raises(APIError) as exc_info: fetch_associated_pipeline(chronicle_client, stream) @@ -532,13 +588,17 @@ def 
test_fetch_sample_logs_by_streams(chronicle_client, mock_response): } with patch.object( - chronicle_client.session, "post", return_value=mock_response + chronicle_client.session, "request", return_value=mock_response ) as mock_post: result = fetch_sample_logs_by_streams(chronicle_client, streams) mock_post.assert_called_once_with( - f"{chronicle_client.base_url}/{chronicle_client.instance_id}/logProcessingPipelines:fetchSampleLogsByStreams", + method="POST", + url=f"{chronicle_client.base_url()}/{chronicle_client.instance_id}/logProcessingPipelines:fetchSampleLogsByStreams", + params=None, json={"streams": streams}, + headers=None, + timeout=None, ) assert result == mock_response.json.return_value @@ -552,15 +612,19 @@ def test_fetch_sample_logs_by_streams_with_count( mock_response.json.return_value = {"logs": []} with patch.object( - chronicle_client.session, "post", return_value=mock_response + chronicle_client.session, "request", return_value=mock_response ) as mock_post: result = fetch_sample_logs_by_streams( chronicle_client, streams, sample_logs_count=sample_logs_count ) mock_post.assert_called_once_with( - f"{chronicle_client.base_url}/{chronicle_client.instance_id}/logProcessingPipelines:fetchSampleLogsByStreams", + method="POST", + url=f"{chronicle_client.base_url()}/{chronicle_client.instance_id}/logProcessingPipelines:fetchSampleLogsByStreams", + params=None, json={"streams": streams, "sampleLogsCount": sample_logs_count}, + headers=None, + timeout=None, ) assert result == mock_response.json.return_value @@ -572,7 +636,7 @@ def test_fetch_sample_logs_by_streams_error( streams = [{"logType": "WINEVTLOG"}] with patch.object( - chronicle_client.session, "post", return_value=mock_error_response + chronicle_client.session, "request", return_value=mock_error_response ): with pytest.raises(APIError) as exc_info: fetch_sample_logs_by_streams(chronicle_client, streams) @@ -588,13 +652,17 @@ def test_fetch_sample_logs_by_streams_empty_streams( 
mock_response.json.return_value = {"logs": []} with patch.object( - chronicle_client.session, "post", return_value=mock_response + chronicle_client.session, "request", return_value=mock_response ) as mock_post: result = fetch_sample_logs_by_streams(chronicle_client, streams) mock_post.assert_called_once_with( - f"{chronicle_client.base_url}/{chronicle_client.instance_id}/logProcessingPipelines:fetchSampleLogsByStreams", + method="POST", + url=f"{chronicle_client.base_url()}/{chronicle_client.instance_id}/logProcessingPipelines:fetchSampleLogsByStreams", + params=None, json={"streams": []}, + headers=None, + timeout=None, ) assert result == mock_response.json.return_value @@ -612,18 +680,22 @@ def test_test_pipeline(chronicle_client, mock_response): mock_response.json.return_value = {"logs": input_logs} with patch.object( - chronicle_client.session, "post", return_value=mock_response + chronicle_client.session, "request", return_value=mock_response ) as mock_post: result = pipeline_test_function( chronicle_client, pipeline_config, input_logs ) mock_post.assert_called_once_with( - f"{chronicle_client.base_url}/{chronicle_client.instance_id}/logProcessingPipelines:testPipeline", + method="POST", + url=f"{chronicle_client.base_url()}/{chronicle_client.instance_id}/logProcessingPipelines:testPipeline", + params=None, json={ "logProcessingPipeline": pipeline_config, "inputLogs": input_logs, }, + headers=None, + timeout=None, ) assert result == mock_response.json.return_value @@ -634,7 +706,7 @@ def test_test_pipeline_error(chronicle_client, mock_error_response): input_logs = [{"data": "bG9nMQ=="}] with patch.object( - chronicle_client.session, "post", return_value=mock_error_response + chronicle_client.session, "request", return_value=mock_error_response ): with pytest.raises(APIError) as exc_info: pipeline_test_function( @@ -654,18 +726,22 @@ def test_test_pipeline_empty_logs(chronicle_client, mock_response): mock_response.json.return_value = {"logs": []} with 
patch.object( - chronicle_client.session, "post", return_value=mock_response + chronicle_client.session, "request", return_value=mock_response ) as mock_post: result = pipeline_test_function( chronicle_client, pipeline_config, input_logs ) mock_post.assert_called_once_with( - f"{chronicle_client.base_url}/{chronicle_client.instance_id}/logProcessingPipelines:testPipeline", + method="POST", + url=f"{chronicle_client.base_url()}/{chronicle_client.instance_id}/logProcessingPipelines:testPipeline", + params=None, json={ "logProcessingPipeline": pipeline_config, "inputLogs": [], }, + headers=None, + timeout=None, ) assert result == mock_response.json.return_value @@ -694,7 +770,7 @@ def test_test_pipeline_with_complex_processors(chronicle_client, mock_response): mock_response.json.return_value = {"logs": input_logs} with patch.object( - chronicle_client.session, "post", return_value=mock_response + chronicle_client.session, "request", return_value=mock_response ) as mock_post: result = pipeline_test_function( chronicle_client, pipeline_config, input_logs diff --git a/tests/chronicle/test_log_types.py b/tests/chronicle/test_log_types.py index c9e9a9ac..33f22cb6 100644 --- a/tests/chronicle/test_log_types.py +++ b/tests/chronicle/test_log_types.py @@ -122,7 +122,7 @@ def test_load_log_types_from_api(mock_chronicle_client, mock_api_response): mock_response = Mock() mock_response.status_code = 200 mock_response.json.return_value = mock_api_response - mock_chronicle_client.session.get.return_value = mock_response + mock_chronicle_client.session.request.return_value = mock_response result = load_log_types(client=mock_chronicle_client) @@ -136,7 +136,7 @@ def test_load_log_types_from_api(mock_chronicle_client, mock_api_response): for log_type in result ) # Verify pagination params: pageSize=1000 for fetching all - call_args = mock_chronicle_client.session.get.call_args + call_args = mock_chronicle_client.session.request.call_args assert call_args[1]["params"]["pageSize"] == 
1000 @@ -155,7 +155,7 @@ def test_load_log_types_api_pagination( mock_response_page2.status_code = 200 mock_response_page2.json.return_value = mock_api_response_paginated_page2 - mock_chronicle_client.session.get.side_effect = [ + mock_chronicle_client.session.request.side_effect = [ mock_response_page1, mock_response_page2, ] @@ -168,7 +168,7 @@ def test_load_log_types_api_pagination( assert any("AWS_CLOUDTRAIL" in log_type.get("name") for log_type in result) assert any("WINDOWS" in log_type.get("name") for log_type in result) # Verify get was called twice (once per page) - assert mock_chronicle_client.session.get.call_count == 2 + assert mock_chronicle_client.session.request.call_count == 2 def test_fetch_log_types_single_page( @@ -181,7 +181,7 @@ def test_fetch_log_types_single_page( mock_response = Mock() mock_response.status_code = 200 mock_response.json.return_value = mock_api_response_paginated_page1 - mock_chronicle_client.session.get.return_value = mock_response + mock_chronicle_client.session.request.return_value = mock_response # Fetch with explicit page_size - should fetch only one page result = _fetch_log_types_from_api( @@ -193,9 +193,9 @@ def test_fetch_log_types_single_page( assert any("OKTA" in log_type.get("name") for log_type in result) assert any("AWS_CLOUDTRAIL" in log_type.get("name") for log_type in result) # Verify get was called only once - assert mock_chronicle_client.session.get.call_count == 1 + assert mock_chronicle_client.session.request.call_count == 1 # Verify correct page_size was used - call_args = mock_chronicle_client.session.get.call_args + call_args = mock_chronicle_client.session.request.call_args assert call_args[1]["params"]["pageSize"] == 2 @@ -209,7 +209,7 @@ def test_fetch_log_types_with_page_token( mock_response = Mock() mock_response.status_code = 200 mock_response.json.return_value = mock_api_response_paginated_page2 - mock_chronicle_client.session.get.return_value = mock_response + 
mock_chronicle_client.session.request.return_value = mock_response # Fetch with page_token result = _fetch_log_types_from_api( @@ -226,7 +226,7 @@ def test_fetch_log_types_with_page_token( ) # Verify page_token was passed - call_args = mock_chronicle_client.session.get.call_args + call_args = mock_chronicle_client.session.request.call_args assert call_args[1]["params"]["pageToken"] == "page2_token" assert call_args[1]["params"]["pageSize"] == 10 @@ -237,11 +237,11 @@ def test_load_log_types_cache(mock_chronicle_client, mock_api_response): mock_response = Mock() mock_response.status_code = 200 mock_response.json.return_value = mock_api_response - mock_chronicle_client.session.get.return_value = mock_response + mock_chronicle_client.session.request.return_value = mock_response # First call - should make API request result1 = load_log_types(client=mock_chronicle_client) - assert mock_chronicle_client.session.get.call_count == 1 + assert mock_chronicle_client.session.request.call_count == 1 # Second call - should return cached data result2 = load_log_types(client=mock_chronicle_client) @@ -249,7 +249,7 @@ def test_load_log_types_cache(mock_chronicle_client, mock_api_response): # Should be the same object (cached) assert result1 is result2 # Should not make another API call - assert mock_chronicle_client.session.get.call_count == 1 + assert mock_chronicle_client.session.request.call_count == 1 def test_get_all_log_types_from_api(mock_chronicle_client, mock_api_response): @@ -257,7 +257,7 @@ def test_get_all_log_types_from_api(mock_chronicle_client, mock_api_response): mock_response = Mock() mock_response.status_code = 200 mock_response.json.return_value = mock_api_response - mock_chronicle_client.session.get.return_value = mock_response + mock_chronicle_client.session.request.return_value = mock_response result = get_all_log_types(client=mock_chronicle_client) @@ -272,7 +272,7 @@ def test_is_valid_log_type_from_api(mock_chronicle_client, mock_api_response): 
mock_response = Mock() mock_response.status_code = 200 mock_response.json.return_value = mock_api_response - mock_chronicle_client.session.get.return_value = mock_response + mock_chronicle_client.session.request.return_value = mock_response assert is_valid_log_type(client=mock_chronicle_client, log_type_id="OKTA") # Second call uses cached data @@ -292,7 +292,7 @@ def test_get_log_type_description_from_api( mock_response = Mock() mock_response.status_code = 200 mock_response.json.return_value = mock_api_response - mock_chronicle_client.session.get.return_value = mock_response + mock_chronicle_client.session.request.return_value = mock_response desc = get_log_type_description("OKTA", client=mock_chronicle_client) assert desc == "Okta Identity Management" @@ -309,7 +309,7 @@ def test_search_log_types_from_api(mock_chronicle_client, mock_api_response): mock_response = Mock() mock_response.status_code = 200 mock_response.json.return_value = mock_api_response - mock_chronicle_client.session.get.return_value = mock_response + mock_chronicle_client.session.request.return_value = mock_response results = search_log_types("OKTA", client=mock_chronicle_client) assert len(results) >= 1 @@ -333,7 +333,7 @@ def test_search_log_types_case_sensitive(mock_chronicle_client): } ] } - mock_chronicle_client.session.get.return_value = mock_response + mock_chronicle_client.session.request.return_value = mock_response # Case sensitive - should find results = search_log_types( @@ -366,7 +366,7 @@ def test_search_log_types_id_only(mock_chronicle_client): } ] } - mock_chronicle_client.session.get.return_value = mock_response + mock_chronicle_client.session.request.return_value = mock_response # Search in ID only results = search_log_types( @@ -407,7 +407,7 @@ def test_api_response_missing_fields(mock_chronicle_client): }, ] } - mock_chronicle_client.session.get.return_value = mock_response + mock_chronicle_client.session.request.return_value = mock_response result = 
load_log_types(client=mock_chronicle_client) @@ -429,7 +429,7 @@ def test_classify_logs_success(mock_chronicle_client): {"logType": "ONELOGIN", "score": 0.03}, ] } - mock_chronicle_client.session.post.return_value = mock_response + mock_chronicle_client.session.request.return_value = mock_response log_data = '{"eventType": "user.session.start"}' result = classify_logs(client=mock_chronicle_client, log_data=log_data) @@ -441,10 +441,10 @@ def test_classify_logs_success(mock_chronicle_client): assert result[1]["logType"] == "ONELOGIN" assert result[1]["score"] == 0.03 - mock_chronicle_client.session.post.assert_called_once() - call_args = mock_chronicle_client.session.post.call_args - assert "logs:classify" in call_args[0][0] - assert "logData" in call_args[1]["json"] + mock_chronicle_client.session.request.assert_called_once() + call_args = mock_chronicle_client.session.request.call_args + assert "logs:classify" in call_args.kwargs["url"] + assert "logData" in call_args.kwargs["json"] def test_classify_logs_empty_predictions(mock_chronicle_client): @@ -452,7 +452,7 @@ def test_classify_logs_empty_predictions(mock_chronicle_client): mock_response = Mock() mock_response.status_code = 200 mock_response.json.return_value = {"predictions": []} - mock_chronicle_client.session.post.return_value = mock_response + mock_chronicle_client.session.request.return_value = mock_response log_data = "unknown log format" result = classify_logs(client=mock_chronicle_client, log_data=log_data) @@ -466,7 +466,7 @@ def test_classify_logs_missing_predictions_key(mock_chronicle_client): mock_response = Mock() mock_response.status_code = 200 mock_response.json.return_value = {} - mock_chronicle_client.session.post.return_value = mock_response + mock_chronicle_client.session.request.return_value = mock_response log_data = "test log" result = classify_logs(client=mock_chronicle_client, log_data=log_data) @@ -480,7 +480,7 @@ def test_classify_logs_empty_log_data(mock_chronicle_client): with 
pytest.raises(SecOpsError, match="log data cannot be empty"): classify_logs(client=mock_chronicle_client, log_data="") - mock_chronicle_client.session.post.assert_not_called() + mock_chronicle_client.session.request.assert_not_called() def test_classify_logs_none_log_data(mock_chronicle_client): @@ -488,7 +488,7 @@ def test_classify_logs_none_log_data(mock_chronicle_client): with pytest.raises(SecOpsError, match="log data cannot be empty"): classify_logs(client=mock_chronicle_client, log_data=None) - mock_chronicle_client.session.post.assert_not_called() + mock_chronicle_client.session.request.assert_not_called() def test_classify_logs_non_string_log_data(mock_chronicle_client): @@ -496,7 +496,7 @@ def test_classify_logs_non_string_log_data(mock_chronicle_client): with pytest.raises(SecOpsError, match="log data must be a string"): classify_logs(client=mock_chronicle_client, log_data=123) - mock_chronicle_client.session.post.assert_not_called() + mock_chronicle_client.session.request.assert_not_called() with pytest.raises(SecOpsError, match="log data must be a string"): classify_logs(client=mock_chronicle_client, log_data=["log"]) @@ -510,7 +510,7 @@ def test_classify_logs_api_error(mock_chronicle_client): mock_response = Mock() mock_response.status_code = 400 mock_response.text = "Invalid request" - mock_chronicle_client.session.post.return_value = mock_response + mock_chronicle_client.session.request.return_value = mock_response log_data = "test log" with pytest.raises(APIError, match="Failed to classify log"): @@ -524,7 +524,7 @@ def test_classify_logs_special_characters(mock_chronicle_client): mock_response.json.return_value = { "predictions": [{"logType": "WINDOWS", "score": 0.88}] } - mock_chronicle_client.session.post.return_value = mock_response + mock_chronicle_client.session.request.return_value = mock_response log_data = "