diff --git a/aws_advanced_python_wrapper/aws_secrets_manager_plugin.py b/aws_advanced_python_wrapper/aws_secrets_manager_plugin.py index 6fb696e19..2ec929b6f 100644 --- a/aws_advanced_python_wrapper/aws_secrets_manager_plugin.py +++ b/aws_advanced_python_wrapper/aws_secrets_manager_plugin.py @@ -14,6 +14,7 @@ from __future__ import annotations +from datetime import timedelta from json import JSONDecodeError, loads from re import search from types import SimpleNamespace @@ -23,7 +24,7 @@ from aws_advanced_python_wrapper.aws_credentials_manager import \ AwsCredentialsManager -from aws_advanced_python_wrapper.utils.cache_map import CacheMap +from aws_advanced_python_wrapper.utils import services_container if TYPE_CHECKING: from boto3 import Session @@ -46,6 +47,13 @@ logger = Logger(__name__) +class Secret: + """Wrapper type for secrets, used as StorageService type key.""" + + def __init__(self, value: SimpleNamespace): + self.value = value + + class AwsSecretsManagerPlugin(Plugin): _SUBSCRIBED_METHODS: Set[str] = {DbApiMethod.CONNECT.method_name, DbApiMethod.FORCE_CONNECT.method_name} @@ -53,7 +61,6 @@ class AwsSecretsManagerPlugin(Plugin): _ONE_YEAR_IN_SECONDS = 60 * 60 * 24 * 365 _secret: Optional[SimpleNamespace] = None - _secrets_cache: CacheMap[Tuple, SimpleNamespace] = CacheMap() _secret_key: Tuple = () @property @@ -63,6 +70,8 @@ def subscribed_methods(self) -> Set[str]: def __init__(self, plugin_service: PluginService, props: Properties, session: Optional[Session] = None): self._plugin_service = plugin_service self._session = session + self._storage_service = services_container.get_storage_service() + self._storage_service.register(Secret, item_expiration_time=timedelta(minutes=30)) secret_id = WrapperProperties.SECRETS_MANAGER_SECRET_ID.get(props) if not secret_id: @@ -100,13 +109,13 @@ def force_connect( return self._connect(host_info, props, force_connect_func) def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callable) -> 
Connection: - token_expiration_sec: int = WrapperProperties.SECRETS_MANAGER_EXPIRATION.get_int(props) + token_expiration_sec = WrapperProperties.SECRETS_MANAGER_EXPIRATION.get_int(props) # if value is less than 0, default to one year if token_expiration_sec < 0: token_expiration_sec = AwsSecretsManagerPlugin._ONE_YEAR_IN_SECONDS token_expiration_ns = token_expiration_sec * 1_000_000_000 - secret_fetched: bool = self._update_secret(host_info, props, token_expiration_ns=token_expiration_ns) + secret_fetched: bool = self._update_secret(host_info, props, token_expiration_ns) try: self._apply_secret_to_properties(props) @@ -120,7 +129,7 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl raise AwsWrapperError( Messages.get_formatted("AwsSecretsManagerPlugin.ConnectException", e), e) from e - secret_fetched = self._update_secret(host_info, props, token_expiration_ns=token_expiration_ns, force_refetch=True) + secret_fetched = self._update_secret(host_info, props, token_expiration_ns, force_refetch=True) if secret_fetched: try: @@ -146,13 +155,14 @@ def _update_secret(self, host_info: HostInfo, props: Properties, token_expiratio try: fetched: bool = False - self._secret: Optional[SimpleNamespace] = AwsSecretsManagerPlugin._secrets_cache.get(self._secret_key) + cached_secret = self._storage_service.get(Secret, self._secret_key) + self._secret = cached_secret.value if cached_secret is not None else None endpoint = self._secret_key[2] if not self._secret or force_refetch: try: self._secret = self._fetch_latest_credentials(host_info, props) if self._secret: - AwsSecretsManagerPlugin._secrets_cache.put(self._secret_key, self._secret, token_expiration_ns) + self._storage_service.put(Secret, self._secret_key, Secret(self._secret), item_expiration_ns=token_expiration_ns) fetched = True except (ClientError, AttributeError) as e: logger.debug("AwsSecretsManagerPlugin.FailedToFetchDbCredentials", e) diff --git 
a/aws_advanced_python_wrapper/blue_green_plugin.py b/aws_advanced_python_wrapper/blue_green_plugin.py index 2cff6c6b4..386caa53c 100644 --- a/aws_advanced_python_wrapper/blue_green_plugin.py +++ b/aws_advanced_python_wrapper/blue_green_plugin.py @@ -15,7 +15,7 @@ from __future__ import annotations import socket -from datetime import datetime +from datetime import datetime, timedelta from time import perf_counter_ns from typing import TYPE_CHECKING, FrozenSet, List, cast @@ -44,9 +44,11 @@ from aws_advanced_python_wrapper.iam_plugin import IamAuthPlugin from aws_advanced_python_wrapper.pep249_methods import DbApiMethod from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory +from aws_advanced_python_wrapper.utils import services_container from aws_advanced_python_wrapper.utils.atomic import AtomicInt from aws_advanced_python_wrapper.utils.concurrent import (ConcurrentDict, ConcurrentSet) +from aws_advanced_python_wrapper.utils.events import MonitorResetEvent from aws_advanced_python_wrapper.utils.log import Logger from aws_advanced_python_wrapper.utils.messages import Messages from aws_advanced_python_wrapper.utils.properties import (Properties, @@ -901,6 +903,9 @@ def _run(self): finally: self._close_connection() + if self._host_list_provider is not None: + self._host_list_provider.stop_monitor() + self._host_list_provider = None logger.debug("BlueGreenStatusMonitor.ThreadCompleted", self._bg_role) def _open_connection(self): @@ -1237,6 +1242,8 @@ def __init__(self, plugin_service: PluginService, props: Properties, bg_id: str) self._green_dns_removed = False self._green_topology_changed = False self._all_green_hosts_changed_name = False + self._monitor_reset_on_in_progress_completed = False + self._monitor_reset_on_topology_completed = False self._post_status_end_time_ns = 0 self._process_status_lock = RLock() self._status_check_intervals_ms: Dict[BlueGreenIntervalRate, int] = {} @@ -1258,6 +1265,8 @@ def __init__(self, plugin_service: PluginService, 
props: Properties, bg_id: str) Messages.get_formatted( "BlueGreenStatusProvider.UnsupportedDialect", self._bg_id, dialect.__class__.__name__)) + services_container.get_storage_service().register(BlueGreenStatus, item_expiration_time=timedelta(hours=1)) + current_host_info = self._plugin_service.current_host_info blue_monitor = BlueGreenStatusMonitor( BlueGreenRole.SOURCE, @@ -1476,6 +1485,7 @@ def _update_summary_status(self, bg_role: BlueGreenRole, interim_status: BlueGre elif self._latest_phase == BlueGreenPhase.IN_PROGRESS: self._update_dns_flags(bg_role, interim_status) self._summary_status = self._get_status_of_in_progress() + self._reset_monitors("_monitor_reset_on_in_progress_completed", "- start") elif self._latest_phase == BlueGreenPhase.POST: self._update_dns_flags(bg_role, interim_status) @@ -1503,6 +1513,7 @@ def _update_dns_flags(self, bg_role: BlueGreenRole, interim_status: BlueGreenInt logger.debug("BlueGreenStatusProvider.GreenTopologyChanged", self._bg_id) self._green_topology_changed = True self._store_event_phase_time("Green topology changed") + self._reset_monitors("_monitor_reset_on_topology_completed", "- green topology") def _store_event_phase_time(self, key_prefix: str, phase: Optional[BlueGreenPhase] = None): rollback_str = " (rollback)" if self._rollback else "" @@ -1846,6 +1857,20 @@ def _update_status_cache(self): with latest_status.cv: latest_status.cv.notify_all() + def _reset_monitors(self, completed_flag_attr: str, event_name: str): + if getattr(self, completed_flag_attr): + return + setattr(self, completed_flag_attr, True) + + blue_endpoints = frozenset( + host for host, role in self._roles_by_host.items() + if role == BlueGreenRole.SOURCE) + + cluster_id = self._plugin_service.host_list_provider.get_cluster_id() + services_container.get_event_publisher().publish( + MonitorResetEvent(cluster_id=cluster_id, endpoints=blue_endpoints)) + self._store_event_phase_time(f"Monitor reset {event_name}") + def _log_current_context(self): 
logger.debug(f"[bg_id: '{self._bg_id}'] Summary status: \n{self._summary_status}") hosts_str = "\n".join( @@ -1913,6 +1938,8 @@ def _reset_context_when_completed(self): self._green_dns_removed = False self._green_topology_changed = False self._all_green_hosts_changed_name = False + self._monitor_reset_on_in_progress_completed = False + self._monitor_reset_on_topology_completed = False self._post_status_end_time_ns = 0 self._interim_status_hashes = [0, 0] self._latest_context_hash = 0 diff --git a/aws_advanced_python_wrapper/cleanup.py b/aws_advanced_python_wrapper/cleanup.py index 2eb20ead0..8430e5456 100644 --- a/aws_advanced_python_wrapper/cleanup.py +++ b/aws_advanced_python_wrapper/cleanup.py @@ -16,18 +16,11 @@ OpenedConnectionTracker from aws_advanced_python_wrapper.aws_credentials_manager import \ AwsCredentialsManager -from aws_advanced_python_wrapper.host_monitoring_plugin import \ - MonitoringThreadContainer -from aws_advanced_python_wrapper.thread_pool_container import \ - ThreadPoolContainer -from aws_advanced_python_wrapper.utils.sliding_expiration_cache_container import \ - SlidingExpirationCacheContainer +from aws_advanced_python_wrapper.utils import services_container def release_resources() -> None: """Release all global resources used by the wrapper.""" - MonitoringThreadContainer.clean_up() - ThreadPoolContainer.release_resources() + services_container.release_resources() AwsCredentialsManager.release_resources() OpenedConnectionTracker.release_resources() - SlidingExpirationCacheContainer.release_resources() diff --git a/aws_advanced_python_wrapper/cluster_topology_monitor.py b/aws_advanced_python_wrapper/cluster_topology_monitor.py index 47b178f4b..173cb0014 100644 --- a/aws_advanced_python_wrapper/cluster_topology_monitor.py +++ b/aws_advanced_python_wrapper/cluster_topology_monitor.py @@ -18,16 +18,18 @@ import time from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor +from time import perf_counter_ns from 
typing import TYPE_CHECKING, Dict, Optional from aws_advanced_python_wrapper.errors import AwsWrapperError from aws_advanced_python_wrapper.host_availability import HostAvailability -from aws_advanced_python_wrapper.hostinfo import HostInfo +from aws_advanced_python_wrapper.hostinfo import HostInfo, Topology +from aws_advanced_python_wrapper.utils import services_container from aws_advanced_python_wrapper.utils.atomic import AtomicReference +from aws_advanced_python_wrapper.utils.events import (EventBase, + MonitorResetEvent) from aws_advanced_python_wrapper.utils.messages import Messages from aws_advanced_python_wrapper.utils.rds_utils import RdsUtils -from aws_advanced_python_wrapper.utils.storage.storage_service import ( - StorageService, Topology) from aws_advanced_python_wrapper.utils.thread_safe_connection_holder import \ ThreadSafeConnectionHolder from aws_advanced_python_wrapper.utils.utils import LogUtils @@ -55,10 +57,20 @@ def force_refresh(self, should_verify_writer: bool, timeout_sec: int) -> Topolog def force_refresh_with_connection(self, connection: Connection, timeout_sec: int) -> Topology: pass + @property @abstractmethod def can_dispose(self) -> bool: pass + @abstractmethod + def stop(self) -> None: + pass + + @property + @abstractmethod + def last_activity_ns(self) -> int: + pass + @abstractmethod def close(self) -> None: pass @@ -109,6 +121,7 @@ def __init__(self, plugin_service: PluginService, topology_utils: TopologyUtils, self._high_refresh_rate_end_time_nano = 0 self._stop = threading.Event() self._monitor_thread: Optional[threading.Thread] = None + self._last_activity_ns: int = perf_counter_ns() self._monitoring_properties = PropertiesUtils.create_topology_monitoring_properties(properties) if WrapperProperties.SOCKET_TIMEOUT_SEC.get(self._monitoring_properties) is None: @@ -124,7 +137,7 @@ def force_refresh(self, should_verify_writer: bool, timeout_sec: int) -> Topolog current_time_nano < self._ignore_new_topology_requests_end_time_nano): 
current_hosts = self._get_stored_hosts() if current_hosts is not None: - logger.debug("ClusterTopologyMonitorImpl.IgnoringTopologyRequest", self._cluster_id, LogUtils.log_topology(current_hosts)) + logger.debug("ClusterTopologyMonitor.IgnoringTopologyRequest", self._cluster_id, LogUtils.log_topology(current_hosts)) return current_hosts if should_verify_writer: @@ -145,7 +158,7 @@ def _wait_till_topology_gets_updated(self, timeout_sec: int) -> Topology: self._request_to_update_topology.set() if timeout_sec == 0: - logger.debug("ClusterTopologyMonitorImpl.TimeoutSetToZero", self._cluster_id, LogUtils.log_topology(current_hosts)) + logger.debug("ClusterTopologyMonitor.TimeoutSetToZero", self._cluster_id, LogUtils.log_topology(current_hosts)) return current_hosts end_time = time.time() + timeout_sec @@ -159,21 +172,45 @@ def _wait_till_topology_gets_updated(self, timeout_sec: int) -> Topology: raise TimeoutError( Messages.get_formatted( - "ClusterTopologyMonitorImpl.TopologyNotUpdated", + "ClusterTopologyMonitor.TopologyNotUpdated", self._cluster_id, timeout_sec * 1000)) def _get_stored_hosts(self) -> Topology: - hosts = StorageService.get(Topology, self._cluster_id) + hosts = services_container.get_storage_service().get(Topology, self._cluster_id) if hosts is None: return () return hosts + def stop(self) -> None: + self._stop.set() + self.close() + + @property def can_dispose(self) -> bool: return self._stop.is_set() + @property + def last_activity_ns(self) -> int: + return self._last_activity_ns + + def process_event(self, event: EventBase) -> None: + if isinstance(event, MonitorResetEvent) and event.cluster_id == self._cluster_id: + logger.debug("ClusterTopologyMonitor.ResetEventReceived", self._cluster_id) + self._host_threads_stop.set() + self._close_host_monitors() + self._close_connection_from_ref(self._host_threads_writer_connection) + self._close_connection_from_ref(self._host_threads_reader_connection) + self._host_threads_stop.clear() + 
self._submitted_hosts.clear() + self._host_threads_writer_host_info.set(None) + self._host_threads_latest_topology.set(None) + self._monitoring_connection.clear() + self._is_verified_writer_connection = False + self._writer_host_info.set(None) + self._high_refresh_rate_end_time_nano = 0 + def close(self) -> None: - logger.debug("ClusterTopologyMonitorImpl.ClosingMonitor", self._cluster_id) - self._stop.set() + logger.debug("ClusterTopologyMonitor.ClosingMonitor", self._cluster_id) self._request_to_update_topology.set() self._close_host_monitors() @@ -195,6 +232,7 @@ def _monitor(self) -> None: logger.debug("ClusterTopologyMonitor.StartMonitoringThread", self._cluster_id, self._initial_host_info.host) while not self._stop.is_set(): + self._last_activity_ns = perf_counter_ns() if self._is_in_panic_mode(): if not self._submitted_hosts: self._close_host_monitors() @@ -207,7 +245,7 @@ def _monitor(self) -> None: hosts = self._open_any_connection_and_update_topology() if hosts and not self._is_verified_writer_connection: - logger.debug("ClusterTopologyMonitorImpl.StartingHostMonitoringThreads", self._cluster_id) + logger.debug("ClusterTopologyMonitor.StartingHostMonitoringThreads", self._cluster_id) writer_host_info = self._writer_host_info.get() for host_info in hosts: if host_info.host not in self._submitted_hosts: @@ -217,14 +255,14 @@ def _monitor(self) -> None: self._submitted_hosts[host_info.host] = True except Exception as e: logger.debug( - "ClusterTopologyMonitorImpl.ExceptionStartingHostMonitor", + "ClusterTopologyMonitor.ExceptionStartingHostMonitor", self._cluster_id, host_info.host, e) else: # Check if writer has been detected writer_host_info = self._host_threads_writer_host_info.get() writer_connection = self._host_threads_writer_connection.get() if (writer_connection is not None and writer_host_info is not None): - logger.debug("ClusterTopologyMonitorImpl.WriterPickedUpFromHostMonitors", self._cluster_id, writer_host_info.host) + 
logger.debug("ClusterTopologyMonitor.WriterPickedUpFromHostMonitors", self._cluster_id, writer_host_info.host) # Transfer the writer connection to monitoring connection self._monitoring_connection.set(writer_connection, close_previous=True) self._writer_host_info.set(writer_host_info) @@ -254,7 +292,7 @@ def _monitor(self) -> None: self._submitted_hosts[host_info.host] = True except Exception as e: logger.debug( - "ClusterTopologyMonitorImpl.ExceptionStartingHostMonitor", + "ClusterTopologyMonitor.ExceptionStartingHostMonitor", self._cluster_id, host_info.host, e) self._delay(True) @@ -283,7 +321,7 @@ def _monitor(self) -> None: self._ignore_new_topology_requests_end_time_nano = 0 except Exception as ex: - logger.info("ClusterTopologyMonitorImpl.ExceptionDuringMonitoringStop", self._cluster_id, ex) + logger.info("ClusterTopologyMonitor.ExceptionDuringMonitoringStop", self._cluster_id, ex) finally: self._stop.set() self._close_host_monitors() @@ -303,7 +341,7 @@ def _open_any_connection_and_update_topology(self) -> Topology: try: conn = self._plugin_service.force_connect(self._initial_host_info, self._monitoring_properties) self._monitoring_connection.set(conn, close_previous=False) - logger.debug("ClusterTopologyMonitorImpl.OpenedMonitoringConnection", + logger.debug("ClusterTopologyMonitor.OpenedMonitoringConnection", self._cluster_id, self._initial_host_info.host) try: @@ -330,7 +368,7 @@ def _open_any_connection_and_update_topology(self) -> Topology: host_id=writer_id) self._writer_host_info.set(writer_host_info) - logger.debug("ClusterTopologyMonitorImpl.WriterMonitoringConnection", + logger.debug("ClusterTopologyMonitor.WriterMonitoringConnection", self._cluster_id, writer_host_info.host) except Exception: pass @@ -385,7 +423,7 @@ def _close_host_monitors(self) -> None: def _get_host_executor_service(self) -> ThreadPoolExecutor: if self._stop.is_set(): raise RuntimeError(Messages.get_formatted( - "ClusterTopologyMonitorImpl.CannotCreateExecutorWhenStopped", 
self._cluster_id)) + "ClusterTopologyMonitor.CannotCreateExecutorWhenStopped", self._cluster_id)) thread_pool_executor = self._thread_pool_executor.get() if thread_pool_executor is None: thread_pool_executor = ThreadPoolExecutor(thread_name_prefix=self._cluster_id) @@ -421,7 +459,7 @@ def _fetch_topology_and_update_cache(self, connection: Optional[Connection]) -> return hosts return () except Exception as ex: - logger.debug("ClusterTopologyMonitorImpl.ErrorFetchingTopology", self._cluster_id, ex) + logger.debug("ClusterTopologyMonitor.ErrorFetchingTopology", self._cluster_id, ex) return () def _fetch_topology_and_update_cache_safe(self) -> Topology: @@ -444,7 +482,7 @@ def _get_instance_template(self, instance_id: str, connection: Connection) -> Ho return self._instance_template def _update_topology_cache(self, hosts: Topology) -> None: - StorageService.set(self._cluster_id, hosts, Topology) + services_container.get_storage_service().put(Topology, self._cluster_id, hosts) # Notify waiting threads self._request_to_update_topology.clear() self._topology_updated.set() diff --git a/aws_advanced_python_wrapper/custom_endpoint_plugin.py b/aws_advanced_python_wrapper/custom_endpoint_plugin.py index 81f3e28fe..907530963 100644 --- a/aws_advanced_python_wrapper/custom_endpoint_plugin.py +++ b/aws_advanced_python_wrapper/custom_endpoint_plugin.py @@ -22,9 +22,9 @@ from aws_advanced_python_wrapper.allowed_and_blocked_hosts import \ AllowedAndBlockedHosts from aws_advanced_python_wrapper.errors import AwsWrapperError -from aws_advanced_python_wrapper.utils.cache_map import CacheMap from aws_advanced_python_wrapper.utils.messages import Messages from aws_advanced_python_wrapper.utils.region_utils import RegionUtils +from aws_advanced_python_wrapper.utils.storage.cache_map import CacheMap if TYPE_CHECKING: from aws_advanced_python_wrapper.driver_dialect import DriverDialect @@ -39,11 +39,10 @@ from aws_advanced_python_wrapper.pep249_methods import DbApiMethod from 
aws_advanced_python_wrapper.plugin import Plugin, PluginFactory +from aws_advanced_python_wrapper.utils import services_container from aws_advanced_python_wrapper.utils.log import Logger from aws_advanced_python_wrapper.utils.properties import WrapperProperties from aws_advanced_python_wrapper.utils.rds_utils import RdsUtils -from aws_advanced_python_wrapper.utils.sliding_expiration_cache_container import \ - SlidingExpirationCacheContainer from aws_advanced_python_wrapper.utils.telemetry.telemetry import ( TelemetryCounter, TelemetryFactory) @@ -136,6 +135,7 @@ def __init__(self, self._client = self._session.client('rds', region_name=region) self._stop_event = Event() + self._last_activity_ns: int = perf_counter_ns() telemetry_factory = self._plugin_service.get_telemetry_factory() self._info_changed_counter = telemetry_factory.create_counter("customEndpoint.infoChanged.counter") @@ -148,6 +148,7 @@ def _run(self): try: while not self._stop_event.is_set(): try: + self._last_activity_ns = perf_counter_ns() start_ns = perf_counter_ns() response = self._client.describe_db_cluster_endpoints( @@ -220,10 +221,21 @@ def _run(self): def has_custom_endpoint_info(self): return CustomEndpointMonitor._custom_endpoint_info_cache.get(self._custom_endpoint_host_info.host) is not None + def stop(self) -> None: + self._stop_event.set() + self.close() + + @property + def can_dispose(self) -> bool: + return self._stop_event.is_set() + + @property + def last_activity_ns(self) -> int: + return self._last_activity_ns + def close(self): logger.debug("CustomEndpointMonitor.StoppingMonitor", self._custom_endpoint_host_info.host) CustomEndpointMonitor._custom_endpoint_info_cache.remove(self._custom_endpoint_host_info.host) - self._stop_event.set() class CustomEndpointPlugin(Plugin): @@ -232,8 +244,6 @@ class CustomEndpointPlugin(Plugin): or removing an instance in the custom endpoint. 
""" _SUBSCRIBED_METHODS: ClassVar[Set[str]] = {DbApiMethod.CONNECT.method_name} - _CACHE_CLEANUP_RATE_NS: ClassVar[int] = 60_000_000_000 # 1 minute - _MONITOR_CACHE_NAME: ClassVar[str] = "custom_endpoint_monitors" def __init__(self, plugin_service: PluginService, props: Properties): self._plugin_service = plugin_service @@ -252,12 +262,11 @@ def __init__(self, plugin_service: PluginService, props: Properties): telemetry_factory: TelemetryFactory = self._plugin_service.get_telemetry_factory() self._wait_for_info_counter: TelemetryCounter | None = telemetry_factory.create_counter("customEndpoint.waitForInfo.counter") - self._monitors = SlidingExpirationCacheContainer.get_or_create_cache( - name=CustomEndpointPlugin._MONITOR_CACHE_NAME, - cleanup_interval_ns=CustomEndpointPlugin._CACHE_CLEANUP_RATE_NS, - should_dispose_func=lambda _: True, - item_disposal_func=lambda monitor: monitor.close() - ) + self._monitors = services_container.get_monitor_service() + self._monitors.register_monitor_type( + CustomEndpointMonitor, + expiration_timeout_ns=self._idle_monitor_expiration_ms * 1_000_000, + inactive_timeout_ns=1 * 60 * 1_000_000_000) # 1 minute, matches JDBC CustomEndpointPlugin._SUBSCRIBED_METHODS.update(self._plugin_service.network_bound_methods) @@ -302,15 +311,15 @@ def _create_monitor_if_absent(self, props: Properties) -> CustomEndpointMonitor: host_info = cast('HostInfo', self._custom_endpoint_host_info) endpoint_id = cast('str', self._custom_endpoint_id) region = cast('str', self._region) - monitor = self._monitors.compute_if_absent( + monitor = self._monitors.run_if_absent( + CustomEndpointMonitor, host_info.host, - lambda key: CustomEndpointMonitor( + lambda: CustomEndpointMonitor( self._plugin_service, host_info, endpoint_id, region, - WrapperProperties.CUSTOM_ENDPOINT_INFO_REFRESH_RATE_MS.get_int(props) * 1_000_000), - self._idle_monitor_expiration_ms * 1_000_000) + WrapperProperties.CUSTOM_ENDPOINT_INFO_REFRESH_RATE_MS.get_int(props) * 1_000_000)) return 
cast('CustomEndpointMonitor', monitor) diff --git a/aws_advanced_python_wrapper/database_dialect.py b/aws_advanced_python_wrapper/database_dialect.py index 2ef01a0d6..86c74acb3 100644 --- a/aws_advanced_python_wrapper/database_dialect.py +++ b/aws_advanced_python_wrapper/database_dialect.py @@ -39,8 +39,7 @@ QueryTimeoutError, UnsupportedOperationError) from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole -from aws_advanced_python_wrapper.thread_pool_container import \ - ThreadPoolContainer +from aws_advanced_python_wrapper.utils import services_container from aws_advanced_python_wrapper.utils.decorators import \ preserve_transaction_status_with_timeout from aws_advanced_python_wrapper.utils.log import Logger @@ -49,8 +48,8 @@ WrapperProperties) from aws_advanced_python_wrapper.utils.rds_utils import RdsUtils from .driver_dialect_codes import DriverDialectCodes -from .utils.cache_map import CacheMap from .utils.messages import Messages +from .utils.storage.cache_map import CacheMap from .utils.utils import Utils logger = Logger(__name__) @@ -861,7 +860,7 @@ def __init__(self, props: Properties, rds_helper: Optional[RdsUtils] = None): self._can_update: bool = False self._dialect: DatabaseDialect = UnknownDatabaseDialect() self._dialect_code: DialectCode = DialectCode.UNKNOWN - self._thread_pool = ThreadPoolContainer.get_thread_pool(self._executor_name) + self._thread_pool = services_container.get_thread_pool(self._executor_name) @staticmethod def get_custom_dialect(): diff --git a/aws_advanced_python_wrapper/driver_dialect.py b/aws_advanced_python_wrapper/driver_dialect.py index 3683a4354..47476fe77 100644 --- a/aws_advanced_python_wrapper/driver_dialect.py +++ b/aws_advanced_python_wrapper/driver_dialect.py @@ -27,8 +27,7 @@ from aws_advanced_python_wrapper.errors import (QueryTimeoutError, UnsupportedOperationError) from aws_advanced_python_wrapper.pep249_methods import DbApiMethod -from aws_advanced_python_wrapper.thread_pool_container import \ - 
ThreadPoolContainer +from aws_advanced_python_wrapper.utils import services_container from aws_advanced_python_wrapper.utils.decorators import timeout from aws_advanced_python_wrapper.utils.messages import Messages from aws_advanced_python_wrapper.utils.properties import (Properties, @@ -51,7 +50,7 @@ class DriverDialect(ABC): def __init__(self, props: Properties): self._props = props - self._thread_pool = ThreadPoolContainer.get_thread_pool(self._executor_name) + self._thread_pool = services_container.get_thread_pool(self._executor_name) @property def driver_name(self): diff --git a/aws_advanced_python_wrapper/fastest_response_strategy_plugin.py b/aws_advanced_python_wrapper/fastest_response_strategy_plugin.py index 5f9f0fdca..3d15e6b1b 100644 --- a/aws_advanced_python_wrapper/fastest_response_strategy_plugin.py +++ b/aws_advanced_python_wrapper/fastest_response_strategy_plugin.py @@ -17,21 +17,23 @@ import time from copy import copy from dataclasses import dataclass +from datetime import timedelta from threading import Event, Lock, Thread -from time import sleep +from time import perf_counter_ns, sleep from typing import (TYPE_CHECKING, Callable, ClassVar, Dict, List, Optional, Set, Tuple) from aws_advanced_python_wrapper.errors import AwsWrapperError from aws_advanced_python_wrapper.host_selector import RandomHostSelector from aws_advanced_python_wrapper.plugin import Plugin -from aws_advanced_python_wrapper.utils.cache_map import CacheMap +from aws_advanced_python_wrapper.utils import services_container +from aws_advanced_python_wrapper.utils.events import (EventBase, + MonitorResetEvent) from aws_advanced_python_wrapper.utils.log import Logger from aws_advanced_python_wrapper.utils.messages import Messages from aws_advanced_python_wrapper.utils.properties import (Properties, WrapperProperties) -from aws_advanced_python_wrapper.utils.sliding_expiration_cache_container import \ - SlidingExpirationCacheContainer +from 
aws_advanced_python_wrapper.utils.storage.cache_map import CacheMap from aws_advanced_python_wrapper.utils.telemetry.telemetry import ( TelemetryContext, TelemetryFactory, TelemetryGauge, TelemetryTraceLevel) @@ -47,6 +49,14 @@ MAX_VALUE = 2147483647 +class ResponseTimeHolder: + """Wrapper type for response time data, used as StorageService type key.""" + + def __init__(self, url: str, response_time: int): + self.url = url + self.response_time = response_time + + class FastestResponseStrategyPlugin(Plugin): _FASTEST_RESPONSE_STRATEGY_NAME = "fastest_response" _SUBSCRIBED_METHODS: Set[str] = {"accepts_strategy", @@ -154,12 +164,14 @@ def __init__(self, plugin_service: PluginService, host_info: HostInfo, props: Pr self._host_info = host_info self._properties = props self._interval_ms = interval_ms + self._storage_service = services_container.get_storage_service() self._telemetry_factory: TelemetryFactory = self._plugin_service.get_telemetry_factory() self._response_time: int = MAX_VALUE self._lock: Lock = Lock() self._monitoring_conn: Optional[Connection] = None self._is_stopped: Event = Event() + self._last_activity_ns: int = perf_counter_ns() self._host_id: Optional[str] = self._host_info.host_id if self._host_id is None or self._host_id == "": @@ -190,8 +202,26 @@ def host_info(self): def is_stopped(self): return self._is_stopped.is_set() - def close(self): + def stop(self) -> None: self._is_stopped.set() + self.close() + + @property + def can_dispose(self) -> bool: + return self._is_stopped.is_set() + + @property + def last_activity_ns(self) -> int: + return self._last_activity_ns + + def process_event(self, event: EventBase) -> None: + if isinstance(event, MonitorResetEvent) and self._host_info.host in event.endpoints: + logger.debug("HostResponseTimeMonitor.ResetEventReceived", self._host_info.host) + self._monitoring_conn = None + self._response_time = MAX_VALUE + self._storage_service.remove(ResponseTimeHolder, self._host_info.url) + + def close(self): 
self._daemon_thread.join(5) logger.debug("HostResponseTimeMonitor.Stopped", self._host_info.host) @@ -206,6 +236,7 @@ def run(self): context.set_attribute("url", self._host_info.url) try: while not self.is_stopped: + self._last_activity_ns = perf_counter_ns() self._open_connection() if self._monitoring_conn is not None: @@ -223,8 +254,11 @@ def run(self): if count > 0: self.response_time = response_time_sum / count + self._storage_service.put(ResponseTimeHolder, self._host_info.url, + ResponseTimeHolder(self._host_info.url, self._response_time)) else: self.response_time = MAX_VALUE + self._storage_service.remove(ResponseTimeHolder, self._host_info.url) logger.debug("HostResponseTimeMonitor.ResponseTime", self._host_info.host, self._response_time) sleep(self._interval_ms / 1000) @@ -279,8 +313,6 @@ def _open_connection(self): class HostResponseTimeService: _CACHE_EXPIRATION_NS: ClassVar[int] = 10 * 60_000_000_000 # 10 minutes - _CACHE_CLEANUP_NS: ClassVar[int] = 60_000_000_000 # 1 minute - _CACHE_NAME: ClassVar[str] = "host_response_time_monitors" _lock: ClassVar[Lock] = Lock() def __init__(self, plugin_service: PluginService, props: Properties, interval_ms: int): @@ -289,17 +321,20 @@ def __init__(self, plugin_service: PluginService, props: Properties, interval_ms self._interval_ms = interval_ms self._hosts: Tuple[HostInfo, ...] 
= () self._telemetry_factory: TelemetryFactory = self._plugin_service.get_telemetry_factory() + self._storage_service = services_container.get_storage_service() - self._monitoring_hosts = SlidingExpirationCacheContainer.get_or_create_cache( - name=HostResponseTimeService._CACHE_NAME, - cleanup_interval_ns=HostResponseTimeService._CACHE_CLEANUP_NS, - should_dispose_func=lambda monitor: True, - item_disposal_func=lambda monitor: HostResponseTimeService._monitor_close(monitor) - ) + self._storage_service.register( + ResponseTimeHolder, item_expiration_time=timedelta(minutes=10)) + + self._monitor_service = services_container.get_monitor_service() + self._monitor_service.register_monitor_type( + HostResponseTimeMonitor, + expiration_timeout_ns=HostResponseTimeService._CACHE_EXPIRATION_NS, + produced_data_type=ResponseTimeHolder) self._host_count_gauge: TelemetryGauge | None = self._telemetry_factory.create_gauge( "frt.hosts.count", - lambda: len(self._monitoring_hosts) + lambda: self._monitor_service.count(HostResponseTimeMonitor) ) @property @@ -310,18 +345,11 @@ def hosts(self) -> Tuple[HostInfo, ...]: def hosts(self, new_hosts: Tuple[HostInfo, ...]): self._hosts = new_hosts - @staticmethod - def _monitor_close(monitor: HostResponseTimeMonitor): - try: - monitor.close() - except Exception: - pass - def get_response_time(self, host_info: HostInfo) -> int: - monitor: Optional[HostResponseTimeMonitor] = self._monitoring_hosts.get(host_info.url) - if monitor is None: - return MAX_VALUE - return monitor.response_time + holder = self._storage_service.get(ResponseTimeHolder, host_info.url) + if holder is not None: + return holder.response_time + return MAX_VALUE def set_hosts(self, new_hosts: Tuple[HostInfo, ...]) -> None: old_hosts_dict = {x.url: x for x in self.hosts} @@ -329,10 +357,11 @@ def set_hosts(self, new_hosts: Tuple[HostInfo, ...]) -> None: for host in self.hosts: if host.url not in old_hosts_dict: + def _create_monitor(h: HostInfo = host) -> 
HostResponseTimeMonitor: + return HostResponseTimeMonitor(self._plugin_service, h, self._properties, self._interval_ms) + with self._lock: - self._monitoring_hosts.compute_if_absent(host.url, - lambda _: HostResponseTimeMonitor( - self._plugin_service, - host, - self._properties, - self._interval_ms), HostResponseTimeService._CACHE_EXPIRATION_NS) + self._monitor_service.run_if_absent( + HostResponseTimeMonitor, + host.url, + _create_monitor) diff --git a/aws_advanced_python_wrapper/federated_plugin.py b/aws_advanced_python_wrapper/federated_plugin.py index 33f3685b8..7e040e01d 100644 --- a/aws_advanced_python_wrapper/federated_plugin.py +++ b/aws_advanced_python_wrapper/federated_plugin.py @@ -43,6 +43,7 @@ from aws_advanced_python_wrapper.errors import AwsConnectError, AwsWrapperError from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory +from aws_advanced_python_wrapper.utils import services_container from aws_advanced_python_wrapper.utils.log import Logger from aws_advanced_python_wrapper.utils.messages import Messages from aws_advanced_python_wrapper.utils.properties import (Properties, @@ -57,15 +58,16 @@ class FederatedAuthPlugin(Plugin): _SUBSCRIBED_METHODS: Set[str] = {"connect", "force_connect"} _rds_utils: RdsUtils = RdsUtils() - _token_cache: Dict[str, TokenInfo] = {} def __init__(self, plugin_service: PluginService, credentials_provider_factory: CredentialsProviderFactory): self._plugin_service = plugin_service self._credentials_provider_factory = credentials_provider_factory + self._storage_service = services_container.get_storage_service() + self._storage_service.register(TokenInfo, item_expiration_time=timedelta(minutes=30)) telemetry_factory = self._plugin_service.get_telemetry_factory() self._fetch_token_counter = telemetry_factory.create_counter("federated.fetch_token.count") - self._cache_size_gauge = telemetry_factory.create_gauge("federated.token_cache.size", lambda: len(FederatedAuthPlugin._token_cache)) + self._cache_size_gauge 
= telemetry_factory.create_gauge("federated.token_cache.size", lambda: self._storage_service.size(TokenInfo)) @property def subscribed_methods(self) -> Set[str]: @@ -109,7 +111,7 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl region ) - token_info: Optional[TokenInfo] = FederatedAuthPlugin._token_cache.get(cache_key) + token_info: Optional[TokenInfo] = self._storage_service.get(TokenInfo, cache_key) token_host_info = deepcopy(host_info) token_host_info.host = host @@ -174,7 +176,7 @@ def _update_authentication_token(self, session, credentials) WrapperProperties.PASSWORD.set(props, token) - FederatedAuthPlugin._token_cache[cache_key] = TokenInfo(token, token_expiry) + self._storage_service.put(TokenInfo, cache_key, TokenInfo(token, token_expiry)) class FederatedAuthPluginFactory(PluginFactory): diff --git a/aws_advanced_python_wrapper/host_list_provider.py b/aws_advanced_python_wrapper/host_list_provider.py index 4b7d832fe..06cf26c21 100644 --- a/aws_advanced_python_wrapper/host_list_provider.py +++ b/aws_advanced_python_wrapper/host_list_provider.py @@ -27,12 +27,10 @@ from aws_advanced_python_wrapper.cluster_topology_monitor import ( ClusterTopologyMonitor, ClusterTopologyMonitorImpl, GlobalAuroraTopologyMonitor) +from aws_advanced_python_wrapper.utils import services_container from aws_advanced_python_wrapper.utils.decorators import \ preserve_transaction_status_with_timeout -from aws_advanced_python_wrapper.utils.sliding_expiration_cache_container import \ - SlidingExpirationCacheContainer -from aws_advanced_python_wrapper.utils.storage.storage_service import ( - StorageService, Topology) +from aws_advanced_python_wrapper.utils.events import MonitorStopEvent if TYPE_CHECKING: from aws_advanced_python_wrapper.driver_dialect import DriverDialect @@ -43,11 +41,9 @@ UnsupportedOperationError) from aws_advanced_python_wrapper.host_availability import ( HostAvailability, create_host_availability_strategy) -from 
aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole +from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole, Topology from aws_advanced_python_wrapper.pep249 import (Connection, Cursor, ProgrammingError) -from aws_advanced_python_wrapper.thread_pool_container import \ - ThreadPoolContainer from aws_advanced_python_wrapper.utils.log import Logger from aws_advanced_python_wrapper.utils.messages import Messages from aws_advanced_python_wrapper.utils.properties import (Properties, @@ -83,6 +79,9 @@ def force_monitoring_refresh(self, should_verify_writer: bool, timeout_sec: int) def get_cluster_id(self) -> str: ... + def stop_monitor(self) -> None: + ... + @runtime_checkable class DynamicHostListProvider(HostListProvider, Protocol): @@ -149,9 +148,7 @@ def is_static_host_list_provider(self) -> bool: class RdsHostListProvider(DynamicHostListProvider, HostListProvider): - _CACHE_CLEANUP_NANO: ClassVar[int] = 1 * 60 * 1_000_000_000 # 1 minute _MONITOR_CLEANUP_NANO: ClassVar[int] = 15 * 60 * 1_000_000_000 # 15 minutes - _MONITOR_CACHE_NAME: ClassVar[str] = "cluster_topology_monitors" _DEFAULT_TOPOLOGY_QUERY_TIMEOUT_SEC: ClassVar[int] = 5 def __init__( @@ -177,12 +174,11 @@ def __init__( self._high_refresh_rate_ns = ( WrapperProperties.CLUSTER_TOPOLOGY_HIGH_REFRESH_RATE_MS.get_int(self._props) * 1_000_000) - self._monitors = SlidingExpirationCacheContainer.get_or_create_cache( - name=RdsHostListProvider._MONITOR_CACHE_NAME, - cleanup_interval_ns=RdsHostListProvider._CACHE_CLEANUP_NANO, - should_dispose_func=lambda monitor: monitor.can_dispose(), - item_disposal_func=lambda monitor: monitor.close() - ) + self._monitor_service = services_container.get_monitor_service() + self._monitor_service.register_monitor_type( + ClusterTopologyMonitorImpl, + expiration_timeout_ns=RdsHostListProvider._MONITOR_CLEANUP_NANO, + produced_data_type=Topology) def _initialize(self): if self._is_initialized: @@ -219,7 +215,7 @@ def _get_topology(self, conn: 
Optional[Connection], force_update: bool = False) """ self._initialize() - cached_hosts = StorageService.get(Topology, self._cluster_id) + cached_hosts = services_container.get_storage_service().get(Topology, self._cluster_id) if not cached_hosts or force_update: if not conn: # Cannot fetch topology without a connection @@ -237,9 +233,10 @@ def _get_topology(self, conn: Optional[Connection], force_update: bool = False) def _get_or_create_monitor(self) -> Optional[ClusterTopologyMonitor]: """Get or create monitor - matches Java's getOrCreateMonitor""" - return self._monitors.compute_if_absent_with_disposal( + return self._monitor_service.run_if_absent( + ClusterTopologyMonitorImpl, self.get_cluster_id(), - lambda k: ClusterTopologyMonitorImpl( + lambda: ClusterTopologyMonitorImpl( self._plugin_service, self._topology_utils, self._cluster_id, @@ -248,8 +245,7 @@ def _get_or_create_monitor(self) -> Optional[ClusterTopologyMonitor]: self._topology_utils.instance_template, self._refresh_rate_ns, self._high_refresh_rate_ns - ), - RdsHostListProvider._MONITOR_CLEANUP_NANO + ) ) def _force_refresh_monitor(self, should_verify_writer: bool, timeout_sec: int) -> Optional[Topology]: @@ -322,6 +318,10 @@ def get_cluster_id(self): self._initialize() return self._cluster_id + def stop_monitor(self) -> None: + services_container.get_event_publisher().publish( + MonitorStopEvent(monitor_type=ClusterTopologyMonitorImpl, key=self._cluster_id)) + @dataclass() class FetchTopologyResult: hosts: Topology @@ -370,6 +370,9 @@ def force_monitoring_refresh(self, should_verify_writer: bool, timeout_sec: int) def get_cluster_id(self): return "" + def stop_monitor(self) -> None: + pass + class GlobalAuroraHostListProvider(RdsHostListProvider): _global_topology_utils: GlobalAuroraTopologyUtils @@ -395,9 +398,10 @@ def _init_settings(self): def _get_or_create_monitor(self) -> Optional[ClusterTopologyMonitor]: """Override to create GlobalAuroraTopologyMonitor""" - return 
self._monitors.compute_if_absent_with_disposal( + return self._monitor_service.run_if_absent( + ClusterTopologyMonitorImpl, self.get_cluster_id(), - lambda k: GlobalAuroraTopologyMonitor( + lambda: GlobalAuroraTopologyMonitor( self._plugin_service, self._global_topology_utils, self._cluster_id, @@ -407,8 +411,7 @@ def _get_or_create_monitor(self) -> Optional[ClusterTopologyMonitor]: self._refresh_rate_ns, self._high_refresh_rate_ns, self._instance_templates_by_region - ), - RdsHostListProvider._MONITOR_CLEANUP_NANO + ) ) def get_current_topology(self, connection: Connection, initial_host_info: HostInfo) -> Topology: @@ -461,7 +464,7 @@ def __init__(self, dialect: db_dialect.TopologyAwareDatabaseDialect, props: Prop self.instance_template: HostInfo = instance_template self._max_timeout_sec = WrapperProperties.AUXILIARY_QUERY_TIMEOUT_SEC.get_int(props) - self._thread_pool = ThreadPoolContainer.get_thread_pool(self._executor_name) + self._thread_pool = services_container.get_thread_pool(self._executor_name) def _validate_host_pattern(self, host: str): if not self._rds_utils.is_dns_pattern_valid(host): diff --git a/aws_advanced_python_wrapper/host_monitoring_plugin.py b/aws_advanced_python_wrapper/host_monitoring_plugin.py index ec6a1f349..445167a08 100644 --- a/aws_advanced_python_wrapper/host_monitoring_plugin.py +++ b/aws_advanced_python_wrapper/host_monitoring_plugin.py @@ -22,12 +22,11 @@ from aws_advanced_python_wrapper.pep249 import Connection from aws_advanced_python_wrapper.plugin_service import PluginService -from concurrent.futures import Future, TimeoutError from dataclasses import dataclass from queue import Queue -from threading import Event, Lock, RLock +from threading import Event, Lock, Thread from time import perf_counter_ns -from typing import Any, Callable, ClassVar, Dict, FrozenSet, Optional, Set +from typing import Any, Callable, Dict, FrozenSet, Optional, Set from _weakref import ReferenceType, ref @@ -36,9 +35,9 @@ from 
aws_advanced_python_wrapper.pep249_methods import DbApiMethod from aws_advanced_python_wrapper.plugin import (CanReleaseResources, Plugin, PluginFactory) -from aws_advanced_python_wrapper.thread_pool_container import \ - ThreadPoolContainer -from aws_advanced_python_wrapper.utils.concurrent import ConcurrentDict +from aws_advanced_python_wrapper.utils import services_container +from aws_advanced_python_wrapper.utils.events import (EventBase, + MonitorResetEvent) from aws_advanced_python_wrapper.utils.log import Logger from aws_advanced_python_wrapper.utils.messages import Messages from aws_advanced_python_wrapper.utils.notifications import ( @@ -78,7 +77,7 @@ def __init__(self, plugin_service, props): self._is_connection_initialized = False self._monitoring_host_info: Optional[HostInfo] = None self._rds_utils: RdsUtils = RdsUtils() - self._monitor_service: MonitorService = MonitorService(plugin_service) + self._monitor_service: HostMonitorService = HostMonitorService(plugin_service) self._lock: Lock = Lock() self._is_enabled = WrapperProperties.FAILURE_DETECTION_ENABLED.get_bool(self._props) self._failure_detection_time_ms = WrapperProperties.FAILURE_DETECTION_TIME_MS.get_int(self._props) @@ -201,9 +200,6 @@ def _get_monitoring_host_info(self) -> HostInfo: return self._monitoring_host_info def release_resources(self): - if self._monitor_service is not None: - self._monitor_service.release_resources() - self._monitor_service = None @@ -349,12 +345,10 @@ def __init__( self, plugin_service: PluginService, host_info: HostInfo, - props: Properties, - monitor_container: MonitoringThreadContainer): + props: Properties): self._plugin_service: PluginService = plugin_service self._host_info: HostInfo = host_info self._props: Properties = props - self._monitor_container: MonitoringThreadContainer = monitor_container self._telemetry_factory = self._plugin_service.get_telemetry_factory() self._lock: Lock = Lock() @@ -370,6 +364,9 @@ def __init__( self._host_invalid_counter = 
self._telemetry_factory.create_counter( f"host_monitoring.host_unhealthy.count.{host_id}") + self._thread = Thread(daemon=True, target=self.run, name="EFMv1Monitor") + self._thread.start() + @dataclass class HostStatus: is_available: bool @@ -381,6 +378,29 @@ def is_stopped(self): def stop(self): self._is_stopped.set() + self.close() + + def close(self) -> None: + if self._thread.is_alive(): + self._thread.join(timeout=5.0) + if self._monitoring_conn is not None: + try: + self._monitoring_conn.close() + except Exception: + pass + + @property + def can_dispose(self) -> bool: + return self._active_contexts.empty() and self._new_contexts.empty() + + @property + def last_activity_ns(self) -> int: + return self._context_last_used_ns + + def process_event(self, event: EventBase) -> None: + if isinstance(event, MonitorResetEvent) and self._host_info.host in event.endpoints: + logger.debug("Monitor.ResetEventReceived", self._host_info.host) + self._monitoring_conn = None def start_monitoring(self, context: MonitoringContext): current_time_ns = perf_counter_ns() @@ -434,7 +454,7 @@ def run(self): if self._active_contexts.empty(): if (perf_counter_ns() - self._context_last_used_ns) >= self._monitor_disposal_time_ms * 1_000_000: - self._monitor_container.release_monitor(self) + # No active contexts and idle too long — stop self break self.sleep(Monitor._INACTIVE_SLEEP_MS / 1000) @@ -499,15 +519,13 @@ def run(self): logger.debug("Monitor.StoppingMonitorUnhandledException", self._host_info.host) logger.debug(e, exc_info=True) finally: - self._monitor_container.release_monitor(self) - self.stop() + self._is_stopped.set() + services_container.get_monitor_service().detach(Monitor, self) if self._monitoring_conn is not None: try: self._monitoring_conn.close() except Exception: - # Do nothing pass - self.stop() def _check_host_status(self, host_check_timeout_ms: int) -> HostStatus: context = self._telemetry_factory.open_telemetry_context( @@ -564,119 +582,21 @@ def sleep(self, 
duration: int): self._is_stopped.wait(duration) -class MonitoringThreadContainer: - """ - This singleton class keeps track of all the monitoring threads and handles the creation and clean up of each - monitoring thread. - """ - - _instance: ClassVar[Optional[MonitoringThreadContainer]] = None - _lock: ClassVar[RLock] = RLock() - _monitor_lock: ClassVar[RLock] = RLock() - - _monitor_map: ConcurrentDict[str, Monitor] = ConcurrentDict() - _tasks_map: ConcurrentDict[Monitor, Future] = ConcurrentDict() - _executor_name: ClassVar[str] = "MonitoringThreadContainerExecutor" - - def __init__(self): - self._thread_pool = ThreadPoolContainer.get_thread_pool(self._executor_name) - - # This logic ensures that this class is a Singleton - def __new__(cls, *args, **kwargs): - if cls._instance is None: - with cls._lock: - if not cls._instance: - cls._instance = super().__new__(cls, *args, **kwargs) - return cls._instance - - def get_or_create_monitor(self, host_aliases: FrozenSet[str], monitor_supplier: Callable) -> Monitor: - if not host_aliases: - raise AwsWrapperError(Messages.get("MonitoringThreadContainer.EmptyHostKeys")) - - with self._monitor_lock: - monitor = None - any_alias = next(iter(host_aliases)) - for host_alias in host_aliases: - monitor = self._monitor_map.get(host_alias) - any_alias = host_alias - if monitor is not None: - break - - def _get_or_create_monitor(_) -> Monitor: - supplied_monitor = monitor_supplier() - if supplied_monitor is None: - raise AwsWrapperError(Messages.get("MonitoringThreadContainer.SupplierMonitorNone")) - self._tasks_map.compute_if_absent( - supplied_monitor, - lambda _: self._thread_pool.submit(supplied_monitor.run)) - return supplied_monitor - - if monitor is None: - monitor = self._monitor_map.compute_if_absent(any_alias, _get_or_create_monitor) - if monitor is None: - raise AwsWrapperError( - Messages.get_formatted("MonitoringThreadContainer.ErrorGettingMonitor", host_aliases)) - - for host_alias in host_aliases: - 
self._monitor_map.put_if_absent(host_alias, monitor) - - return monitor - - @staticmethod - def _cancel(monitor, future: Future) -> None: - future.cancel() - monitor.stop() - return None - - def get_monitor(self, alias: str) -> Optional[Monitor]: - return self._monitor_map.get(alias) - - def release_monitor(self, monitor: Monitor): - with self._monitor_lock: - self._monitor_map.remove_matching_values([monitor]) - self._tasks_map.compute_if_present(monitor, MonitoringThreadContainer._cancel) - - @staticmethod - def clean_up(): - """Clean up any dangling monitoring threads created by the host monitoring plugin. - This method should be called at the end of the application. - The Host Monitoring Plugin creates monitoring threads in the background to monitor all connections established to each of cluster instances. - The threads will terminate if there are no connections to the cluster instance the thread is monitoring for over a period of time, - specified by the `monitor_disposal_time_ms`. Client applications can also manually call this method to clean up any dangling resources. - This method should be called right before application termination. 
- """ - if MonitoringThreadContainer._instance is None: - return - - with MonitoringThreadContainer._lock: - if MonitoringThreadContainer._instance is not None: - MonitoringThreadContainer._instance._release_resources() - MonitoringThreadContainer._instance = None - - def _release_resources(self): - with self._monitor_lock: - self._monitor_map.clear() - self._tasks_map.apply_if( - lambda monitor, future: not future.done() and not future.cancelled(), - lambda monitor, future: future.cancel()) - - for monitor, _ in self._tasks_map.items(): - monitor.stop() - - ThreadPoolContainer.release_pool(MonitoringThreadContainer._executor_name, wait=False) - self._tasks_map.clear() - - -class MonitorService: +class HostMonitorService: def __init__(self, plugin_service: PluginService): self._plugin_service: PluginService = plugin_service - self._monitor_container: MonitoringThreadContainer = MonitoringThreadContainer() self._cached_monitor_aliases: Optional[FrozenSet[str]] = None self._cached_monitor: Optional[ReferenceType[Monitor]] = None telemetry_factory = self._plugin_service.get_telemetry_factory() self._aborted_connections_counter = telemetry_factory.create_counter("host_monitoring.connections.aborted") + self._monitor_service = services_container.get_monitor_service() + self._monitor_service.register_monitor_type( + Monitor, + expiration_timeout_ns=WrapperProperties.MONITOR_DISPOSAL_TIME_MS.get_int( + self._plugin_service.props) * 1_000_000) + def start_monitoring(self, conn: Connection, host_aliases: FrozenSet[str], @@ -693,8 +613,10 @@ def start_monitoring(self, or monitor.is_stopped \ or self._cached_monitor_aliases is None \ or self._cached_monitor_aliases != host_aliases: - monitor = self._monitor_container.get_or_create_monitor( - host_aliases, lambda: self._create_monitor(host_info, props, self._monitor_container)) + monitor = self._monitor_service.run_if_absent_with_aliases( + Monitor, + host_aliases, + lambda: Monitor(self._plugin_service, host_info, props)) 
self._cached_monitor = ref(monitor) self._cached_monitor_aliases = host_aliases @@ -704,9 +626,6 @@ def start_monitoring(self, monitor.start_monitoring(context) return context - def _create_monitor(self, host_info: HostInfo, props: Properties, monitor_container: MonitoringThreadContainer): - return Monitor(self._plugin_service, host_info, props, monitor_container) - @staticmethod def stop_monitoring(context: MonitoringContext): monitor = context.monitor @@ -714,10 +633,7 @@ def stop_monitoring(context: MonitoringContext): def stop_monitoring_host(self, host_aliases: FrozenSet): for alias in host_aliases: - monitor = self._monitor_container.get_monitor(alias) + monitor = self._monitor_service.get(Monitor, alias) if monitor is not None: monitor.clear_contexts() return - - def release_resources(self): - self._monitor_container = None diff --git a/aws_advanced_python_wrapper/host_monitoring_v2_plugin.py b/aws_advanced_python_wrapper/host_monitoring_v2_plugin.py index 728172e08..efd964fd8 100644 --- a/aws_advanced_python_wrapper/host_monitoring_v2_plugin.py +++ b/aws_advanced_python_wrapper/host_monitoring_v2_plugin.py @@ -25,9 +25,12 @@ from aws_advanced_python_wrapper.pep249_methods import DbApiMethod from aws_advanced_python_wrapper.plugin import (CanReleaseResources, Plugin, PluginFactory) +from aws_advanced_python_wrapper.utils import services_container from aws_advanced_python_wrapper.utils.atomic import (AtomicBoolean, AtomicReference) from aws_advanced_python_wrapper.utils.concurrent import ConcurrentDict +from aws_advanced_python_wrapper.utils.events import (EventBase, + MonitorResetEvent) from aws_advanced_python_wrapper.utils.log import Logger from aws_advanced_python_wrapper.utils.messages import Messages from aws_advanced_python_wrapper.utils.notifications import ( @@ -36,8 +39,6 @@ PropertiesUtils, WrapperProperties) from aws_advanced_python_wrapper.utils.rds_utils import RdsUtils -from aws_advanced_python_wrapper.utils.sliding_expiration_cache_container 
import \ - SlidingExpirationCacheContainer from aws_advanced_python_wrapper.utils.telemetry.telemetry import ( TelemetryCounter, TelemetryFactory, TelemetryTraceLevel) @@ -235,6 +236,7 @@ def __init__( self._invalid_host_start_time_ns: int = 0 self._monitoring_connection: Optional[Connection] = None self._driver_dialect: DriverDialect = self._plugin_service.driver_dialect + self._last_activity_ns: int = perf_counter_ns() self._monitor_run_thread: Thread = Thread(daemon=True, name="HostMonitoringThreadRun", target=self.run) self._monitor_run_thread.start() @@ -242,15 +244,29 @@ def __init__( target=self._new_context_run) self._monitor_new_context_thread.start() + @property def can_dispose(self) -> bool: return self._active_contexts.empty() and len(self._new_contexts.items()) == 0 + @property + def last_activity_ns(self) -> int: + return self._last_activity_ns + + def process_event(self, event: EventBase) -> None: + if isinstance(event, MonitorResetEvent) and self._host_info.host in event.endpoints: + logger.debug("HostMonitorV2.ResetEventReceived", self._host_info.host) + self._monitoring_connection = None + self._invalid_host_start_time_ns = 0 + self._failure_count = 0 + self._is_unhealthy = False + @property def is_stopped(self): return self._is_stopped.get() def stop(self): self._is_stopped.set(True) + self.close() def start_monitoring(self, context: MonitoringContext): if self.is_stopped: @@ -306,6 +322,7 @@ def run(self) -> None: try: while not self.is_stopped: + self._last_activity_ns = perf_counter_ns() if self._active_contexts.empty() and not self._is_unhealthy: sleep(HostMonitorV2._THREAD_SLEEP_SEC) continue @@ -356,7 +373,7 @@ def run(self) -> None: except Exception as ex: logger.debug("HostMonitorV2.ExceptionDuringMonitoringStop", self._host_info.host, ex) finally: - self.stop() + self._is_stopped.set(True) if self._monitoring_connection is not None: try: self.abort_connection(self._monitoring_connection) @@ -443,15 +460,12 @@ def abort_connection(self, 
connection: Connection) -> None: logger.debug("HostMonitorV2.ExceptionAbortingConnection", ex) def close(self) -> None: - self.stop() self._monitor_run_thread.join(10) self._monitor_new_context_thread.join(10) class MonitorServiceV2: - # 1 Minute to Nanoseconds _CACHE_CLEANUP_NANO: ClassVar[int] = 1 * 60 * 1_000_000_000 - _MONITOR_CACHE_NAME: ClassVar[str] = "host_monitors_v2" def __init__(self, plugin_service: PluginService): self._plugin_service: PluginService = plugin_service @@ -459,12 +473,10 @@ def __init__(self, plugin_service: PluginService): telemetry_factory = self._plugin_service.get_telemetry_factory() self._aborted_connections_counter = telemetry_factory.create_counter("efm2.connections.aborted") - self._monitors = SlidingExpirationCacheContainer.get_or_create_cache( - name=MonitorServiceV2._MONITOR_CACHE_NAME, - cleanup_interval_ns=MonitorServiceV2._CACHE_CLEANUP_NANO, - should_dispose_func=lambda monitor: monitor.can_dispose(), - item_disposal_func=lambda monitor: monitor.close() - ) + self._monitor_service = services_container.get_monitor_service() + self._monitor_service.register_monitor_type( + HostMonitorV2, + expiration_timeout_ns=MonitorServiceV2._CACHE_CLEANUP_NANO) def start_monitoring( self, @@ -510,13 +522,14 @@ def get_monitor(self, host_info.url ) - cache_expiration_ns = int(WrapperProperties.MONITOR_DISPOSAL_TIME_MS.get_float(props) * 10**6) - return self._monitors.compute_if_absent(monitor_key, - lambda k: HostMonitorV2(self._plugin_service, - host_info, - props, - failure_detection_time_ms, - failure_detection_interval_ms, - failure_detection_count, - self._aborted_connections_counter), - cache_expiration_ns) + return self._monitor_service.run_if_absent( + HostMonitorV2, + monitor_key, + lambda: HostMonitorV2(self._plugin_service, + host_info, + props, + failure_detection_time_ms, + failure_detection_interval_ms, + failure_detection_count, + self._aborted_connections_counter) + ) diff --git 
a/aws_advanced_python_wrapper/host_selector.py b/aws_advanced_python_wrapper/host_selector.py index 87b68fd8c..5a60b6a37 100644 --- a/aws_advanced_python_wrapper/host_selector.py +++ b/aws_advanced_python_wrapper/host_selector.py @@ -25,7 +25,7 @@ from .hostinfo import HostInfo, HostRole from aws_advanced_python_wrapper.errors import AwsWrapperError -from aws_advanced_python_wrapper.utils.cache_map import CacheMap +from aws_advanced_python_wrapper.utils.storage.cache_map import CacheMap from .pep249 import Error from .utils.messages import Messages from .utils.properties import Properties, WrapperProperties diff --git a/aws_advanced_python_wrapper/hostinfo.py b/aws_advanced_python_wrapper/hostinfo.py index e02247515..5c9893a9f 100644 --- a/aws_advanced_python_wrapper/hostinfo.py +++ b/aws_advanced_python_wrapper/hostinfo.py @@ -16,11 +16,14 @@ from dataclasses import dataclass from enum import Enum, auto -from typing import TYPE_CHECKING, ClassVar, FrozenSet, Optional, Set +from typing import (TYPE_CHECKING, ClassVar, FrozenSet, Optional, Set, Tuple, + TypeAlias) from aws_advanced_python_wrapper.host_availability import ( HostAvailability, HostAvailabilityStrategy) +Topology: TypeAlias = Tuple["HostInfo", ...] 
+ if TYPE_CHECKING: from datetime import datetime diff --git a/aws_advanced_python_wrapper/iam_plugin.py b/aws_advanced_python_wrapper/iam_plugin.py index ca655da36..791a1a531 100644 --- a/aws_advanced_python_wrapper/iam_plugin.py +++ b/aws_advanced_python_wrapper/iam_plugin.py @@ -31,11 +31,12 @@ from aws_advanced_python_wrapper.plugin_service import PluginService from datetime import datetime, timedelta -from typing import Callable, Dict, Set +from typing import Callable, Set from aws_advanced_python_wrapper.errors import AwsConnectError, AwsWrapperError from aws_advanced_python_wrapper.pep249_methods import DbApiMethod from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory +from aws_advanced_python_wrapper.utils import services_container from aws_advanced_python_wrapper.utils.log import Logger from aws_advanced_python_wrapper.utils.messages import Messages from aws_advanced_python_wrapper.utils.properties import (Properties, @@ -51,15 +52,16 @@ class IamAuthPlugin(Plugin): _DEFAULT_TOKEN_EXPIRATION_SEC = 15 * 60 - 30 _rds_utils: RdsUtils = RdsUtils() - _token_cache: Dict[str, TokenInfo] = {} def __init__(self, plugin_service: PluginService): self._plugin_service = plugin_service + self._storage_service = services_container.get_storage_service() + self._storage_service.register(TokenInfo, item_expiration_time=timedelta(minutes=15)) telemetry_factory = self._plugin_service.get_telemetry_factory() self._fetch_token_counter = telemetry_factory.create_counter("iam.fetch_token.count") self._cache_size_gauge = telemetry_factory.create_gauge( - "iam.token_cache.size", lambda: len(IamAuthPlugin._token_cache)) + "iam.token_cache.size", lambda: self._storage_service.size(TokenInfo)) @property def subscribed_methods(self) -> Set[str]: @@ -104,7 +106,7 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl region ) - token_info = IamAuthPlugin._token_cache.get(cache_key) + token_info = self._storage_service.get(TokenInfo, cache_key) 
if token_info is not None and not token_info.is_expired(): logger.debug("IamAuthPlugin.UseCachedIamToken", token_info.token) @@ -125,7 +127,7 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl region, session) self._plugin_service.driver_dialect.set_password(props, token) - IamAuthPlugin._token_cache[cache_key] = TokenInfo(token, token_expiry) + self._storage_service.put(TokenInfo, cache_key, TokenInfo(token, token_expiry)) try: return connect_func() @@ -150,7 +152,7 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl session = AwsCredentialsManager.get_session(session_host_info, props, region) token = IamAuthUtils.generate_authentication_token(self._plugin_service, user, host, port, region, session) self._plugin_service.driver_dialect.set_password(props, token) - IamAuthPlugin._token_cache[cache_key] = TokenInfo(token, token_expiry) + self._storage_service.put(TokenInfo, cache_key, TokenInfo(token, token_expiry)) try: return connect_func() diff --git a/aws_advanced_python_wrapper/limitless_plugin.py b/aws_advanced_python_wrapper/limitless_plugin.py index bc2398573..6946fc036 100644 --- a/aws_advanced_python_wrapper/limitless_plugin.py +++ b/aws_advanced_python_wrapper/limitless_plugin.py @@ -15,8 +15,9 @@ import math import time from contextlib import closing +from datetime import timedelta from threading import Event, RLock, Thread -from time import sleep +from time import perf_counter_ns, sleep from typing import (TYPE_CHECKING, Any, Callable, ClassVar, List, Optional, Set, Tuple) @@ -28,13 +29,12 @@ from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole from aws_advanced_python_wrapper.pep249_methods import DbApiMethod from aws_advanced_python_wrapper.plugin import Plugin +from aws_advanced_python_wrapper.utils import services_container from aws_advanced_python_wrapper.utils.concurrent import ConcurrentDict from aws_advanced_python_wrapper.utils.log import Logger from 
aws_advanced_python_wrapper.utils.messages import Messages from aws_advanced_python_wrapper.utils.properties import (Properties, WrapperProperties) -from aws_advanced_python_wrapper.utils.sliding_expiration_cache_container import \ - SlidingExpirationCacheContainer from aws_advanced_python_wrapper.utils.telemetry.telemetry import ( TelemetryContext, TelemetryFactory, TelemetryTraceLevel) from aws_advanced_python_wrapper.utils.utils import LogUtils, Utils @@ -47,6 +47,13 @@ logger = Logger(__name__) +class LimitlessRouters: + """Wrapper type for limitless router list, used as StorageService type key.""" + + def __init__(self, hosts: List[HostInfo]): + self.hosts = hosts + + class LimitlessPlugin(Plugin): _SUBSCRIBED_METHODS: Set[str] = {"connect"} @@ -112,14 +119,13 @@ class LimitlessRouterMonitor: def __init__(self, plugin_service: PluginService, host_info: HostInfo, - limitless_router_cache, # SlidingExpirationCache from container limitless_router_cache_key: str, props: Properties, interval_ms: int): self._plugin_service = plugin_service self._host_info = host_info - self._limitless_router_cache = limitless_router_cache self._limitless_router_cache_key = limitless_router_cache_key + self._storage_service = services_container.get_storage_service() self._properties = copy.deepcopy(props) for property_key in self._properties.keys(): @@ -134,6 +140,7 @@ def __init__(self, self._telemetry_factory: TelemetryFactory = self._plugin_service.get_telemetry_factory() self._monitoring_conn: Optional[Connection] = None self._is_stopped: Event = Event() + self._last_activity_ns: int = perf_counter_ns() self._daemon_thread: Thread = Thread(daemon=True, target=self.run) self._daemon_thread.start() @@ -146,8 +153,19 @@ def host_info(self): def is_stopped(self): return self._is_stopped.is_set() - def close(self): + def stop(self) -> None: self._is_stopped.set() + self.close() + + @property + def can_dispose(self) -> bool: + return self._is_stopped.is_set() + + @property + def 
last_activity_ns(self) -> int: + return self._last_activity_ns + + def close(self): if self._monitoring_conn: self._monitoring_conn.close() self._daemon_thread.join(5) @@ -163,15 +181,14 @@ def run(self): try: while not self.is_stopped: + self._last_activity_ns = perf_counter_ns() self._open_connection() if self._monitoring_conn is not None: new_limitless_routers = self._query_helper.query_for_limitless_routers(self._monitoring_conn, self._host_info.port) - self._limitless_router_cache.compute_if_absent(self._limitless_router_cache_key, - lambda _: new_limitless_routers, - WrapperProperties.LIMITLESS_MONITOR_DISPOSAL_TIME_MS.get( - self._properties) * 1_000_000) + self._storage_service.put(LimitlessRouters, self._limitless_router_cache_key, + LimitlessRouters(new_limitless_routers)) logger.debug(LogUtils.log_topology(tuple(new_limitless_routers), "[limitlessRouterMonitor] Topology:")) sleep(self._interval_ms / 1000) @@ -313,29 +330,26 @@ def is_any_router_available(self): class LimitlessRouterService: _CACHE_CLEANUP_NS: ClassVar[int] = 60_000_000_000 # 1 minute - _ROUTER_CACHE_NAME: ClassVar[str] = "limitless_router_cache" - _MONITOR_CACHE_NAME: ClassVar[str] = "limitless_monitor_cache" _force_get_limitless_routers_lock_map: ClassVar[ConcurrentDict[str, RLock]] = ConcurrentDict() def __init__(self, plugin_service: PluginService, query_helper: LimitlessQueryHelper): self._plugin_service = plugin_service self._query_helper = query_helper + self._storage_service = services_container.get_storage_service() - self._limitless_router_cache = SlidingExpirationCacheContainer.get_or_create_cache( - name=LimitlessRouterService._ROUTER_CACHE_NAME, - cleanup_interval_ns=LimitlessRouterService._CACHE_CLEANUP_NS - ) + self._storage_service.register( + LimitlessRouters, + item_expiration_time=timedelta(milliseconds=WrapperProperties.LIMITLESS_MONITOR_DISPOSAL_TIME_MS.get_int( + plugin_service.props))) - self._limitless_router_monitor = 
SlidingExpirationCacheContainer.get_or_create_cache( - name=LimitlessRouterService._MONITOR_CACHE_NAME, - cleanup_interval_ns=LimitlessRouterService._CACHE_CLEANUP_NS, - should_dispose_func=lambda monitor: True, - item_disposal_func=lambda monitor: monitor.close() - ) + self._monitor_service = services_container.get_monitor_service() + self._monitor_service.register_monitor_type( + LimitlessRouterMonitor, + expiration_timeout_ns=LimitlessRouterService._CACHE_CLEANUP_NS) def establish_connection(self, context: LimitlessContext) -> None: context.set_limitless_routers(self._get_limitless_routers( - self._plugin_service.host_list_provider.get_cluster_id(), context.get_props())) + self._plugin_service.host_list_provider.get_cluster_id())) if context.get_limitless_routers() is None or len(context.get_limitless_routers()) == 0: logger.debug("LimitlessRouterService.LimitlessRouterCacheEmpty") @@ -388,14 +402,11 @@ def establish_connection(self, context: LimitlessContext) -> None: self._retry_connection_with_least_loaded_routers(context) - def _get_limitless_routers(self, cluster_id: str, props: Properties) -> List[HostInfo]: - # Convert milliseconds to nanoseconds - cache_expiration_nano: int = WrapperProperties.LIMITLESS_MONITOR_DISPOSAL_TIME_MS.get_int(props) * 1_000_000 - self._limitless_router_cache.set_cleanup_interval_ns(cache_expiration_nano) - routers = self._limitless_router_cache.get(cluster_id) + def _get_limitless_routers(self, cluster_id: str) -> List[HostInfo]: + routers = self._storage_service.get(LimitlessRouters, cluster_id) if routers is None: return [] - return routers + return routers.hosts def _retry_connection_with_least_loaded_routers(self, context: LimitlessContext) -> None: retry_count = 0 @@ -476,8 +487,6 @@ def _synchronously_get_limitless_routers_with_retry(self, context: LimitlessCont raise AwsWrapperError(Messages.get("LimitlessRouterService.NoRoutersAvailable")) def _synchronously_get_limitless_routers(self, context: LimitlessContext) -> 
None: - cache_expiration_nano: int = WrapperProperties.LIMITLESS_MONITOR_DISPOSAL_TIME_MS.get_int(context.get_props()) * 1_000_000 - lock = LimitlessRouterService._force_get_limitless_routers_lock_map.compute_if_absent( self._plugin_service.host_list_provider.get_cluster_id(), lambda _: RLock() @@ -487,7 +496,7 @@ def _synchronously_get_limitless_routers(self, context: LimitlessContext) -> Non lock.acquire() try: - limitless_routers = self._limitless_router_cache.get( + limitless_routers = self._get_limitless_routers( self._plugin_service.host_list_provider.get_cluster_id()) if limitless_routers is not None and len(limitless_routers) != 0: context.set_limitless_routers(limitless_routers) @@ -501,11 +510,10 @@ def _synchronously_get_limitless_routers(self, context: LimitlessContext) -> Non if new_limitless_routers is not None and len(new_limitless_routers) != 0: context.set_limitless_routers(new_limitless_routers) - self._limitless_router_cache.compute_if_absent( + self._storage_service.put( + LimitlessRouters, self._plugin_service.host_list_provider.get_cluster_id(), - lambda _: new_limitless_routers, - cache_expiration_nano - ) + LimitlessRouters(new_limitless_routers)) else: raise AwsWrapperError(Messages.get("LimitlessRouterService.FetchedEmptyRouterList")) @@ -519,21 +527,20 @@ def start_monitoring(self, host_info: HostInfo, props: Properties) -> None: try: limitless_router_monitor_key: str = self._plugin_service.host_list_provider.get_cluster_id() - cache_expiration_nano: int = WrapperProperties.LIMITLESS_MONITOR_DISPOSAL_TIME_MS.get_int(props) * 1_000_000 intervals_ms: int = WrapperProperties.LIMITLESS_INTERVAL_MILLIS.get_int(props) - self._limitless_router_monitor.compute_if_absent( + self._monitor_service.run_if_absent( + LimitlessRouterMonitor, limitless_router_monitor_key, - lambda _: LimitlessRouterMonitor(self._plugin_service, - host_info, - self._limitless_router_cache, - limitless_router_monitor_key, - props, - intervals_ms), cache_expiration_nano) + 
lambda: LimitlessRouterMonitor(self._plugin_service, + host_info, + limitless_router_monitor_key, + props, + intervals_ms)) except Exception as e: logger.debug("LimitlessRouterService.ErrorStartingMonitor", e) raise e def clear_cache(self) -> None: LimitlessRouterService._force_get_limitless_routers_lock_map.clear() - self._limitless_router_cache.clear() + self._storage_service.clear(LimitlessRouters) diff --git a/aws_advanced_python_wrapper/okta_plugin.py b/aws_advanced_python_wrapper/okta_plugin.py index 4cc0a5d13..d4d1fe10d 100644 --- a/aws_advanced_python_wrapper/okta_plugin.py +++ b/aws_advanced_python_wrapper/okta_plugin.py @@ -40,6 +40,7 @@ from aws_advanced_python_wrapper.errors import AwsConnectError, AwsWrapperError from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory +from aws_advanced_python_wrapper.utils import services_container from aws_advanced_python_wrapper.utils.log import Logger from aws_advanced_python_wrapper.utils.messages import Messages from aws_advanced_python_wrapper.utils.properties import (Properties, @@ -53,15 +54,16 @@ class OktaAuthPlugin(Plugin): _SUBSCRIBED_METHODS: Set[str] = {"connect", "force_connect"} _rds_utils: RdsUtils = RdsUtils() - _token_cache: Dict[str, TokenInfo] = {} def __init__(self, plugin_service: PluginService, credentials_provider_factory: CredentialsProviderFactory): self._plugin_service = plugin_service self._credentials_provider_factory = credentials_provider_factory + self._storage_service = services_container.get_storage_service() + self._storage_service.register(TokenInfo, item_expiration_time=timedelta(minutes=30)) telemetry_factory = self._plugin_service.get_telemetry_factory() self._fetch_token_counter = telemetry_factory.create_counter("okta.fetch_token.count") - self._cache_size_gauge = telemetry_factory.create_gauge("okta.token_cache.size", lambda: len(OktaAuthPlugin._token_cache)) + self._cache_size_gauge = telemetry_factory.create_gauge("okta.token_cache.size", lambda: 
self._storage_service.size(TokenInfo)) @property def subscribed_methods(self) -> Set[str]: @@ -105,7 +107,7 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl region ) - token_info: Optional[TokenInfo] = OktaAuthPlugin._token_cache.get(cache_key) + token_info: Optional[TokenInfo] = self._storage_service.get(TokenInfo, cache_key) token_host_info = deepcopy(host_info) token_host_info.host = host @@ -169,7 +171,7 @@ def _update_authentication_token(self, session, credentials) WrapperProperties.PASSWORD.set(props, token) - OktaAuthPlugin._token_cache[cache_key] = TokenInfo(token, token_expiry) + self._storage_service.put(TokenInfo, cache_key, TokenInfo(token, token_expiry)) class OktaCredentialsProviderFactory(SamlCredentialsProviderFactory): diff --git a/aws_advanced_python_wrapper/plugin_service.py b/aws_advanced_python_wrapper/plugin_service.py index 9ab0fb93c..22834646f 100644 --- a/aws_advanced_python_wrapper/plugin_service.py +++ b/aws_advanced_python_wrapper/plugin_service.py @@ -36,7 +36,6 @@ from aws_advanced_python_wrapper.utils.utils import Utils if TYPE_CHECKING: - from aws_advanced_python_wrapper.allowed_and_blocked_hosts import AllowedAndBlockedHosts from aws_advanced_python_wrapper.driver_dialect import DriverDialect from aws_advanced_python_wrapper.driver_dialect_manager import DriverDialectManager from aws_advanced_python_wrapper.pep249 import Connection @@ -48,6 +47,8 @@ from typing import (Any, Callable, Dict, FrozenSet, Optional, Protocol, Set, Tuple) +from aws_advanced_python_wrapper.allowed_and_blocked_hosts import \ + AllowedAndBlockedHosts from aws_advanced_python_wrapper.aurora_connection_tracker_plugin import \ AuroraConnectionTrackerPluginFactory from aws_advanced_python_wrapper.aws_secrets_manager_plugin import \ @@ -88,9 +89,7 @@ from aws_advanced_python_wrapper.simple_read_write_splitting_plugin import \ SimpleReadWriteSplittingPluginFactory from aws_advanced_python_wrapper.stale_dns_plugin import 
StaleDnsPluginFactory -from aws_advanced_python_wrapper.thread_pool_container import \ - ThreadPoolContainer -from aws_advanced_python_wrapper.utils.cache_map import CacheMap +from aws_advanced_python_wrapper.utils import services_container from aws_advanced_python_wrapper.utils.decorators import \ preserve_transaction_status_with_timeout from aws_advanced_python_wrapper.utils.log import Logger @@ -100,6 +99,7 @@ from aws_advanced_python_wrapper.utils.properties import (Properties, PropertiesUtils, WrapperProperties) +from aws_advanced_python_wrapper.utils.storage.cache_map import CacheMap from aws_advanced_python_wrapper.utils.telemetry.telemetry import ( TelemetryContext, TelemetryFactory, TelemetryTraceLevel) @@ -112,7 +112,7 @@ def plugin_service(self) -> PluginService: return self._plugin_service @plugin_service.setter - def plugin_service(self, value): + def plugin_service(self, value: PluginService) -> None: self._plugin_service = value @property @@ -120,7 +120,7 @@ def plugin_manager(self) -> PluginManager: return self._plugin_manager @plugin_manager.setter - def plugin_manager(self, value): + def plugin_manager(self, value: PluginManager) -> None: self._plugin_manager = value @@ -320,7 +320,6 @@ def get_status(self, clazz: Type[StatusType], key: str) -> Optional[StatusType]: class PluginServiceImpl(PluginService, HostListProviderService, CanReleaseResources): _STATUS_CACHE_EXPIRATION_NANO = 60 * 60 * 1_000_000_000 # one hour _host_availability_expiring_cache: CacheMap[str, HostAvailability] = CacheMap() - _status_cache: ClassVar[CacheMap[str, Any]] = CacheMap() _executor_name: ClassVar[str] = "PluginServiceImplExecutor" @@ -339,7 +338,6 @@ def __init__( self._host_list_provider: HostListProvider = ConnectionStringHostListProvider(self, props) self._all_hosts: Tuple[HostInfo, ...] 
= () - self._allowed_and_blocked_hosts: Optional[AllowedAndBlockedHosts] = None self._current_connection: Optional[Connection] = None self._current_host_info: Optional[HostInfo] = None self._initial_connection_host_info: Optional[HostInfo] = None @@ -351,7 +349,7 @@ def __init__( self._driver_dialect = driver_dialect self._database_dialect = self._dialect_provider.get_dialect(driver_dialect.dialect_code, props) self._session_state_service = session_state_service if session_state_service is not None else SessionStateServiceImpl(self, props) - self._thread_pool = ThreadPoolContainer.get_thread_pool(self._executor_name) + self._thread_pool = services_container.get_thread_pool(self._executor_name) @property def all_hosts(self) -> Tuple[HostInfo, ...]: @@ -377,11 +375,15 @@ def hosts(self) -> Tuple[HostInfo, ...]: @property def allowed_and_blocked_hosts(self) -> Optional[AllowedAndBlockedHosts]: - return self._allowed_and_blocked_hosts + return services_container.get_storage_service().get(AllowedAndBlockedHosts, self._original_url) @allowed_and_blocked_hosts.setter def allowed_and_blocked_hosts(self, allowed_and_blocked_hosts: Optional[AllowedAndBlockedHosts]): - self._allowed_and_blocked_hosts = allowed_and_blocked_hosts + storage = services_container.get_storage_service() + if allowed_and_blocked_hosts is None: + storage.remove(AllowedAndBlockedHosts, self._original_url) + else: + storage.put(AllowedAndBlockedHosts, self._original_url, allowed_and_blocked_hosts) @property def current_connection(self) -> Optional[Connection]: @@ -787,19 +789,14 @@ def release_resources(self): host_list_provider.release_resources() def set_status(self, clazz: Type[StatusType], status: Optional[StatusType], key: str): - cache_key = self._get_status_cache_key(clazz, key) + storage = services_container.get_storage_service() if status is None: - self._status_cache.remove(cache_key) + storage.remove(clazz, key) else: - self._status_cache.put(cache_key, status, 
PluginServiceImpl._STATUS_CACHE_EXPIRATION_NANO) - - def _get_status_cache_key(self, clazz: Type[StatusType], key: str) -> str: - key_str = "" if key is None else key.strip().lower() - return f"{key_str}::{clazz.__name__}" + storage.put(clazz, key, status) def get_status(self, clazz: Type[StatusType], key: str) -> Optional[StatusType]: - cache_key = self._get_status_cache_key(clazz, key) - status = PluginServiceImpl._status_cache.get(cache_key) + status = services_container.get_storage_service().get(clazz, key) if status is None: return None diff --git a/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties b/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties index 0bacf5258..ff596d541 100644 --- a/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties +++ b/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties @@ -77,18 +77,19 @@ CloseConnectionExecuteRouting.InProgressConnectionClosed=[CloseConnectionExecute ClusterTopologyMonitor.StartMonitoringThread=[ClusterTopologyMonitor, clusterId: '{}'] Starting cluster topology monitoring thread for '{}'. ClusterTopologyMonitor.StopMonitoringThread=[ClusterTopologyMonitor, clusterId: '{}'] Stop cluster topology monitoring thread for '{}'. -ClusterTopologyMonitorImpl.IgnoringTopologyRequest=[ClusterTopologyMonitor, clusterId: '{}'] A topology refresh was requested, but the topology was already updated recently. Returning cached hosts: -ClusterTopologyMonitorImpl.TopologyNotUpdated=[ClusterTopologyMonitor, clusterId: '{}'] Topology has not been updated after {} ms. -ClusterTopologyMonitorImpl.TimeoutSetToZero=[ClusterTopologyMonitor, clusterId: '{}'] A topology refresh was requested, but the given timeout for the request was 0ms. Returning cached hosts: -ClusterTopologyMonitorImpl.StartingHostMonitoringThreads=[ClusterTopologyMonitor, clusterId: '{}'] Starting host monitoring threads. 
-ClusterTopologyMonitorImpl.ExceptionStartingHostMonitor=[ClusterTopologyMonitor, clusterId: '{}'] Exception starting monitor for host '{}': '{}'. -ClusterTopologyMonitorImpl.WriterPickedUpFromHostMonitors=[ClusterTopologyMonitor, clusterId: '{}'] The writer host detected by the host monitors was picked up by the topology monitor: '{}'. -ClusterTopologyMonitorImpl.ExceptionDuringMonitoringStop=[ClusterTopologyMonitor, clusterId: '{}'] Stopping cluster topology monitoring after unhandled exception was thrown in monitoring thread '{}'. -ClusterTopologyMonitorImpl.ClosingMonitor=[ClusterTopologyMonitor, clusterId: '{}'] Closing monitor. -ClusterTopologyMonitorImpl.OpenedMonitoringConnection=[ClusterTopologyMonitor, clusterId: '{}'] Opened monitoring connection to host '{}'. -ClusterTopologyMonitorImpl.WriterMonitoringConnection=[ClusterTopologyMonitor, clusterId: '{}'] The monitoring connection is connected to a writer: '{}'. -ClusterTopologyMonitorImpl.ErrorFetchingTopology=[ClusterTopologyMonitor, clusterId: '{}'] An error occurred while querying for topology: {} -ClusterTopologyMonitorImpl.CannotCreateExecutorWhenStopped=[ClusterTopologyMonitor, clusterId: '{}'] Monitor is stopped, cannot create executor. +ClusterTopologyMonitor.IgnoringTopologyRequest=[ClusterTopologyMonitor, clusterId: '{}'] A topology refresh was requested, but the topology was already updated recently. Returning cached hosts: +ClusterTopologyMonitor.TopologyNotUpdated=[ClusterTopologyMonitor, clusterId: '{}'] Topology has not been updated after {} ms. +ClusterTopologyMonitor.TimeoutSetToZero=[ClusterTopologyMonitor, clusterId: '{}'] A topology refresh was requested, but the given timeout for the request was 0ms. Returning cached hosts: +ClusterTopologyMonitor.StartingHostMonitoringThreads=[ClusterTopologyMonitor, clusterId: '{}'] Starting host monitoring threads. 
+ClusterTopologyMonitor.ExceptionStartingHostMonitor=[ClusterTopologyMonitor, clusterId: '{}'] Exception starting monitor for host '{}': '{}'. +ClusterTopologyMonitor.WriterPickedUpFromHostMonitors=[ClusterTopologyMonitor, clusterId: '{}'] The writer host detected by the host monitors was picked up by the topology monitor: '{}'. +ClusterTopologyMonitor.ExceptionDuringMonitoringStop=[ClusterTopologyMonitor, clusterId: '{}'] Stopping cluster topology monitoring after unhandled exception was thrown in monitoring thread '{}'. +ClusterTopologyMonitor.ClosingMonitor=[ClusterTopologyMonitor, clusterId: '{}'] Closing monitor. +ClusterTopologyMonitor.OpenedMonitoringConnection=[ClusterTopologyMonitor, clusterId: '{}'] Opened monitoring connection to host '{}'. +ClusterTopologyMonitor.WriterMonitoringConnection=[ClusterTopologyMonitor, clusterId: '{}'] The monitoring connection is connected to a writer: '{}'. +ClusterTopologyMonitor.ErrorFetchingTopology=[ClusterTopologyMonitor, clusterId: '{}'] An error occurred while querying for topology: {} +ClusterTopologyMonitor.CannotCreateExecutorWhenStopped=[ClusterTopologyMonitor, clusterId: '{}'] Monitor is stopped, cannot create executor. +ClusterTopologyMonitor.ResetEventReceived=[ClusterTopologyMonitor] MonitorResetEvent received for cluster '{}'. conftest.ExceptionWhileObtainingInstanceIDs=[conftest] An exception was thrown while attempting to obtain the cluster's instance IDs: '{}' @@ -285,6 +286,7 @@ Monitor.StoppingMonitorUnhandledException=[Monitor] Stopping thread after an unh Monitor.InterruptedException=[Monitor] Monitoring thread for host '{}' was interrupted. Monitor.OpenedMonitorConnection=[Monitor] Opened a monitoring connection to '{}'. Monitor.OpeningMonitorConnection=[Monitor] Opening a monitoring connection to '{}'. +Monitor.ResetEventReceived=[Monitor] MonitorResetEvent received for host '{}'. 
MonitorContext.ExceptionAbortingConnection=[MonitorContext] An exception occurred while attempting to abort the monitored connection: '{}'. MonitorContext.HostAvailable=[MonitorContext] Host '{}' is *available*. @@ -315,6 +317,7 @@ HostResponseTimeMonitor.OpenedConnection=[HostResponseTimeMonitor] Opened Respon HostResponseTimeMonitor.OpeningConnection=[HostResponseTimeMonitor] Opening a Response time connection to '{}'. HostResponseTimeMonitor.ResponseTime=[HostResponseTimeMonitor] Response time for '{}': {} ms HostResponseTimeMonitor.Stopped=[HostResponseTimeMonitor] Stopped Response time thread for host '{}'. +HostResponseTimeMonitor.ResetEventReceived=[HostResponseTimeMonitor] MonitorResetEvent received for host '{}'. ThreadPoolContainer.ErrorShuttingDownPool=[ThreadPoolContainer] Error shutting down pool '{}': '{}'. @@ -497,3 +500,15 @@ XRayTelemetryFactory.WrongParameterType="[XRayTelemetryFactory] Wrong parameter SlidingExpirationCacheContainer.ErrorReleasingCache=[SlidingExpirationCacheContainer] Error releasing cache '{}': {} SlidingExpirationCacheContainer.ErrorDuringCleanup=[SlidingExpirationCacheContainer] Error during cleanup of cache '{}': {} + +BatchingEventPublisher.ErrorDeliveringEvent=[BatchingEventPublisher] Error delivering event: {} + +CoreServices.ErrorShuttingDownPool=[CoreServices] Error shutting down thread pool '{}': {} +CoreServices.ErrorReleasingMonitorService=[CoreServices] Error releasing monitor service: {} +CoreServices.ErrorReleasingStorageService=[CoreServices] Error releasing storage service: {} +CoreServices.ErrorReleasingEventPublisher=[CoreServices] Error releasing event publisher: {} + +MonitorService.ErrorDisposingMonitor=[MonitorService] Error disposing monitor: {} +MonitorService.ErrorPropagatingEvent=[MonitorService] Error propagating event: {} +MonitorService.StuckMonitorDetected=[MonitorService] Stuck monitor detected for type '{}' with key '{}'. 
+MonitorService.ErrorDuringCleanup=[MonitorService] Error during cleanup: {} diff --git a/aws_advanced_python_wrapper/sql_alchemy_connection_provider.py b/aws_advanced_python_wrapper/sql_alchemy_connection_provider.py index 7e980d99f..927bd3c49 100644 --- a/aws_advanced_python_wrapper/sql_alchemy_connection_provider.py +++ b/aws_advanced_python_wrapper/sql_alchemy_connection_provider.py @@ -34,7 +34,7 @@ WrapperProperties) from aws_advanced_python_wrapper.utils.rds_url_type import RdsUrlType from aws_advanced_python_wrapper.utils.rds_utils import RdsUtils -from aws_advanced_python_wrapper.utils.sliding_expiration_cache import \ +from aws_advanced_python_wrapper.utils.storage.sliding_expiration_cache import \ SlidingExpirationCache diff --git a/aws_advanced_python_wrapper/thread_pool_container.py b/aws_advanced_python_wrapper/thread_pool_container.py deleted file mode 100644 index 9254dbb2f..000000000 --- a/aws_advanced_python_wrapper/thread_pool_container.py +++ /dev/null @@ -1,119 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import threading -from concurrent.futures import ThreadPoolExecutor -from typing import ClassVar, Dict, List, Optional - -from aws_advanced_python_wrapper.utils.log import Logger - -logger = Logger(__name__) - - -class ThreadPoolContainer: - """ - A container class for managing multiple named thread pools. 
- Provides static methods for getting, creating, and releasing thread pools. - """ - - _pools: ClassVar[Dict[str, ThreadPoolExecutor]] = {} - _lock: ClassVar[threading.Lock] = threading.Lock() - _default_max_workers: ClassVar[Optional[int]] = None # Uses Python's default - - @classmethod - def get_thread_pool( - cls, - name: str, - max_workers: Optional[int] = None - ) -> ThreadPoolExecutor: - """ - Get an existing thread pool or create a new one if it doesn't exist. - - Args: - name: Unique identifier for the thread pool - max_workers: Max worker threads (only used when creating new pool) - If None, uses Python's default: min(32, os.cpu_count() + 4) - - Returns: - ThreadPoolExecutor instance - """ - with cls._lock: - if name not in cls._pools: - workers = max_workers or cls._default_max_workers - cls._pools[name] = ThreadPoolExecutor( - max_workers=workers, - thread_name_prefix=name - ) - return cls._pools[name] - - @classmethod - def release_resources(cls, wait=False) -> None: - """ - Shutdown all thread pools and release resources. - - Args: - wait: If True, wait for all pending tasks to complete - """ - with cls._lock: - for name, pool in cls._pools.items(): - try: - pool.shutdown(wait=wait) - except Exception as e: - logger.warning("ThreadPoolContainer.ErrorShuttingDownPool", name, e) - cls._pools.clear() - - @classmethod - def release_pool(cls, name: str, wait: bool = True) -> bool: - """ - Release a specific thread pool by name. 
- - Args: - name: The name of the thread pool to release - wait: If True, wait for pending tasks to complete - - Returns: - True if pool was found and released, False otherwise - """ - with cls._lock: - if name in cls._pools: - try: - cls._pools[name].shutdown(wait=wait) - del cls._pools[name] - return True - except Exception as e: - logger.warning("ThreadPoolContainer.ErrorShuttingDownPool", name, e) - return False - - @classmethod - def has_pool(cls, name: str) -> bool: - """Check if a pool with the given name exists.""" - with cls._lock: - return name in cls._pools - - @classmethod - def get_pool_names(cls) -> List[str]: - """Get a list of all active pool names.""" - with cls._lock: - return list(cls._pools.keys()) - - @classmethod - def get_pool_count(cls) -> int: - """Get the number of active pools.""" - with cls._lock: - return len(cls._pools) - - @classmethod - def set_default_max_workers(cls, max_workers: Optional[int]) -> None: - """Set the default max workers for new pools.""" - cls._default_max_workers = max_workers diff --git a/aws_advanced_python_wrapper/utils/concurrent.py b/aws_advanced_python_wrapper/utils/concurrent.py index a209810e8..5a8250872 100644 --- a/aws_advanced_python_wrapper/utils/concurrent.py +++ b/aws_advanced_python_wrapper/utils/concurrent.py @@ -77,6 +77,25 @@ def compute_if_absent(self, key: K, mapping_func: Callable) -> Optional[V]: return new_value return value + def compute_for_keys( + self, keys, factory: Callable, + on_existing: Optional[Callable] = None) -> V: + with self._lock: + value = None + for key in keys: + value = self._dict.get(key) + if value is not None: + if on_existing: + on_existing(value) + break + + if value is None: + value = factory() + + for key in keys: + self._dict.setdefault(key, value) + return value + def put(self, key: K, value: V): with self._lock: self._dict[key] = value @@ -98,6 +117,14 @@ def remove(self, key: K) -> V: with self._lock: return self._dict.pop(key, None) + def remove_key_if(self, 
key: K, predicate: Callable[[V], bool]) -> Optional[V]: + with self._lock: + value = self._dict.get(key) + if value is not None and predicate(value): + del self._dict[key] + return value + return None + def remove_if(self, predicate: Callable) -> bool: with self._lock: original_len = len(self._dict) diff --git a/aws_advanced_python_wrapper/utils/events.py b/aws_advanced_python_wrapper/utils/events.py new file mode 100644 index 000000000..86b7fcc14 --- /dev/null +++ b/aws_advanced_python_wrapper/utils/events.py @@ -0,0 +1,146 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import weakref +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from threading import Event, Lock, Thread +from typing import Any, Dict, Protocol, Set, Type, runtime_checkable + +from aws_advanced_python_wrapper.utils.log import Logger + +logger = Logger(__name__) + + +class EventBase: + """Base class for all events.""" + immediate_delivery: bool = False + + +@dataclass(frozen=True, eq=True) +class DataAccessEvent(EventBase): + """Published when data is accessed in StorageService.""" + data_type: type + key: Any + immediate_delivery: bool = field(default=False, compare=False, hash=False) + + +@dataclass(frozen=True, eq=True) +class MonitorStopEvent(EventBase): + """Published to signal a monitor should be stopped.""" + monitor_type: type + key: Any + immediate_delivery: bool = field(default=True, compare=False, hash=False) + + +@dataclass(frozen=True, eq=True) +class MonitorResetEvent(EventBase): + """Published during Blue/Green switchover to reset monitors holding stale connections.""" + cluster_id: str + endpoints: frozenset + immediate_delivery: bool = field(default=True, compare=False, hash=False) + + +@runtime_checkable +class EventSubscriber(Protocol): + @abstractmethod + def process_event(self, event: EventBase) -> None: + ... + + +class EventPublisher(ABC): + @abstractmethod + def subscribe(self, subscriber: EventSubscriber, event_types: Set[Type[EventBase]]) -> None: + ... + + @abstractmethod + def unsubscribe(self, subscriber: EventSubscriber, event_types: Set[Type[EventBase]]) -> None: + ... + + @abstractmethod + def publish(self, event: EventBase) -> None: + ... + + @abstractmethod + def release_resources(self) -> None: + ... 
+ + +class BatchingEventPublisher(EventPublisher): + _DELIVERY_INTERVAL_SEC = 30.0 + + def __init__(self) -> None: + self._subscribers: Dict[Type[EventBase], weakref.WeakSet[EventSubscriber]] = {} + self._pending_events: Set[EventBase] = set() + self._lock = Lock() + self._stop_event = Event() + self._thread = Thread( + target=self._delivery_loop, daemon=True, name="BatchingEventPublisher") + self._thread.start() + + def subscribe(self, subscriber: EventSubscriber, event_types: Set[Type[EventBase]]) -> None: + with self._lock: + for event_type in event_types: + if event_type not in self._subscribers: + self._subscribers[event_type] = weakref.WeakSet() + self._subscribers[event_type].add(subscriber) + + def unsubscribe(self, subscriber: EventSubscriber, event_types: Set[Type[EventBase]]) -> None: + with self._lock: + for event_type in event_types: + ws = self._subscribers.get(event_type) + if ws is not None: + ws.discard(subscriber) + + def publish(self, event: EventBase) -> None: + if event.immediate_delivery: + self._deliver(event) + else: + with self._lock: + self._pending_events.add(event) + + def release_resources(self) -> None: + self._stop_event.set() + if self._thread.is_alive(): + self._thread.join(timeout=2.0) + with self._lock: + self._pending_events.clear() + self._subscribers.clear() + + def _delivery_loop(self) -> None: + while not self._stop_event.is_set(): + if self._stop_event.wait(timeout=self._DELIVERY_INTERVAL_SEC): + break + self._drain_and_deliver() + + def _drain_and_deliver(self) -> None: + with self._lock: + events = list(self._pending_events) + self._pending_events.clear() + for event in events: + self._deliver(event) + + def _deliver(self, event: EventBase) -> None: + with self._lock: + ws = self._subscribers.get(type(event)) + if ws is None: + return + subscribers = list(ws) + for subscriber in subscribers: + try: + subscriber.process_event(event) + except Exception as e: + logger.debug("BatchingEventPublisher.ErrorDeliveringEvent", e) 
diff --git a/aws_advanced_python_wrapper/utils/monitor_service.py b/aws_advanced_python_wrapper/utils/monitor_service.py new file mode 100644 index 000000000..2a7006dfa --- /dev/null +++ b/aws_advanced_python_wrapper/utils/monitor_service.py @@ -0,0 +1,222 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import threading +from dataclasses import dataclass +from time import perf_counter_ns +from typing import (Any, Callable, Dict, FrozenSet, Optional, Protocol, + runtime_checkable) + +from aws_advanced_python_wrapper.utils.events import (DataAccessEvent, + EventBase, + EventPublisher, + EventSubscriber, + MonitorResetEvent, + MonitorStopEvent) +from aws_advanced_python_wrapper.utils.log import Logger +from aws_advanced_python_wrapper.utils.storage.expiration_tracking_cache import \ + ExpirationTrackingCache + +logger = Logger(__name__) + + +@runtime_checkable +class Monitor(Protocol): + """Structural interface for all monitors managed by MonitorService. + + stop() signals the monitor to stop and releases all resources (calls close() internally). + close() releases resources only — called by stop(), not directly by external code. + """ + + def stop(self) -> None: ... + def close(self) -> None: ... + + @property + def can_dispose(self) -> bool: ... + + @property + def last_activity_ns(self) -> int: ... 
@dataclass
class MonitorSettings:
    """Per-monitor-type configuration for MonitorService caches."""

    # How long an entry lives in the cache without its expiration being extended.
    expiration_timeout_ns: int = 15 * 60 * 1_000_000_000  # 15 min
    # How long without monitor activity before it is considered stuck.
    inactive_timeout_ns: int = 3 * 60 * 1_000_000_000  # 3 min
    # Data type whose DataAccessEvents extend this monitor type's expiration.
    produced_data_type: Optional[type] = None


class _CacheContainer:
    """Pairs a monitor-type's settings with its expiration-tracking cache."""

    def __init__(self, settings: MonitorSettings) -> None:
        self.settings = settings
        self.cache: ExpirationTrackingCache = ExpirationTrackingCache(
            expiration_timeout_ns=settings.expiration_timeout_ns)

    @staticmethod
    def _dispose(monitor: Monitor) -> None:
        # Disposal must never propagate: a failing monitor.stop() must not
        # abort cleanup of the remaining monitors.
        try:
            monitor.stop()
        except Exception as e:
            logger.debug("MonitorService.ErrorDisposingMonitor", e)


class MonitorService(EventSubscriber):
    """Centralized monitor lifecycle manager.

    Holds one expiration-tracking cache per monitor type, propagates events to
    monitors, and runs a daemon cleanup thread that disposes expired or stuck
    monitors every ``_CLEANUP_INTERVAL_SEC`` seconds.
    """

    _CLEANUP_INTERVAL_SEC = 60.0

    def __init__(self, event_publisher: EventPublisher) -> None:
        self._event_publisher = event_publisher
        # monitor type -> container; guarded by self._lock (RLock: event
        # handlers may call back into locking methods on the same thread).
        self._monitor_caches: Dict[type, _CacheContainer] = {}
        self._lock = threading.RLock()
        self._stop_event = threading.Event()

        self._event_publisher.subscribe(self, {DataAccessEvent, MonitorStopEvent, MonitorResetEvent})

        self._cleanup_thread = threading.Thread(
            target=self._cleanup_loop, daemon=True, name="MonitorService-Cleanup")
        self._cleanup_thread.start()

    def register_monitor_type(
            self,
            monitor_type: type,
            expiration_timeout_ns: int = 15 * 60 * 1_000_000_000,
            inactive_timeout_ns: int = 3 * 60 * 1_000_000_000,
            produced_data_type: Optional[type] = None) -> None:
        """Register settings for *monitor_type*; a no-op if already registered
        (existing settings are NOT overwritten)."""
        with self._lock:
            if monitor_type not in self._monitor_caches:
                settings = MonitorSettings(
                    expiration_timeout_ns=expiration_timeout_ns,
                    inactive_timeout_ns=inactive_timeout_ns,
                    produced_data_type=produced_data_type)
                self._monitor_caches[monitor_type] = _CacheContainer(settings)

    def run_if_absent(
            self,
            monitor_type: type,
            key: Any,
            factory: Callable[[], Any]) -> Any:
        """Return the monitor cached under *key*, creating it via *factory* if absent."""
        container = self._get_or_create_container(monitor_type)
        return container.cache.compute_if_absent(key, lambda _: factory())

    def run_if_absent_with_aliases(
            self,
            monitor_type: type,
            aliases: FrozenSet[str],
            factory: Callable[[], Any]) -> Any:
        """Like run_if_absent, but a single monitor is shared across all *aliases*."""
        container = self._get_or_create_container(monitor_type)
        return container.cache.get_or_create_for_aliases(aliases, factory)

    def detach(self, monitor_type: type, monitor: Any) -> None:
        """Remove *monitor* from the cache WITHOUT stopping it."""
        container = self._monitor_caches.get(monitor_type)
        if container is not None:
            container.cache.detach_value(monitor)

    def get(self, monitor_type: type, key: Any) -> Optional[Any]:
        """Return the cached monitor for *key*, or None if absent/expired."""
        container = self._monitor_caches.get(monitor_type)
        if container is None:
            return None
        return container.cache.get(key)

    def count(self, monitor_type: type) -> int:
        """Number of monitors currently cached for *monitor_type*."""
        container = self._monitor_caches.get(monitor_type)
        return len(container.cache) if container is not None else 0

    def stop_and_remove(self, monitor_type: type, key: Any) -> None:
        """Remove the monitor under *key* and stop it (dispose resources)."""
        container = self._monitor_caches.get(monitor_type)
        if container is not None:
            monitor = container.cache.remove(key)
            if monitor is not None:
                _CacheContainer._dispose(monitor)

    def stop_and_remove_all(self, monitor_type: type) -> None:
        """Remove and stop every monitor of *monitor_type*."""
        container = self._monitor_caches.get(monitor_type)
        if container is not None:
            for monitor in container.cache.clear():
                _CacheContainer._dispose(monitor)

    def stop_all(self) -> None:
        """Remove and stop every monitor of every type (containers stay registered)."""
        with self._lock:
            for container in self._monitor_caches.values():
                for monitor in container.cache.clear():
                    _CacheContainer._dispose(monitor)

    def release_resources(self) -> None:
        """Shut down: unsubscribe, stop all monitors, stop the cleanup thread, drop caches."""
        self._stop_event.set()
        self._event_publisher.unsubscribe(self, {DataAccessEvent, MonitorStopEvent, MonitorResetEvent})
        self.stop_all()
        if self._cleanup_thread.is_alive():
            # Bounded join so shutdown cannot hang on the cleanup thread.
            self._cleanup_thread.join(timeout=2.0)
        with self._lock:
            self._monitor_caches.clear()

    def process_event(self, event: EventBase) -> None:
        """EventSubscriber hook.

        DataAccessEvent extends matching monitors' expirations; MonitorStopEvent
        stops the addressed monitor; anything else (e.g. MonitorResetEvent) is
        forwarded to every monitor that is itself an EventSubscriber.
        """
        if isinstance(event, DataAccessEvent):
            self._on_data_access(event)
            return
        if isinstance(event, MonitorStopEvent):
            self.stop_and_remove(event.monitor_type, event.key)
            return

        # Snapshot under the lock, then fan out outside it so a slow monitor
        # does not block other service operations.
        with self._lock:
            containers = list(self._monitor_caches.values())
        for container in containers:
            for _key, monitor in container.cache.items():
                if isinstance(monitor, EventSubscriber):
                    try:
                        monitor.process_event(event)
                    except Exception as e:
                        logger.debug("MonitorService.ErrorPropagatingEvent", e)

    def _on_data_access(self, event: DataAccessEvent) -> None:
        """Extend expiration of monitors whose produced_data_type matches."""
        with self._lock:
            containers = list(self._monitor_caches.values())
        for container in containers:
            if container.settings.produced_data_type == event.data_type:
                container.cache.extend_expiration(event.key)

    def _get_or_create_container(self, monitor_type: type) -> _CacheContainer:
        # Unregistered types fall back to default MonitorSettings.
        with self._lock:
            if monitor_type not in self._monitor_caches:
                self._monitor_caches[monitor_type] = _CacheContainer(MonitorSettings())
            return self._monitor_caches[monitor_type]

    def _cleanup_loop(self) -> None:
        # Event.wait doubles as the interval sleep and as a prompt stop signal.
        while not self._stop_event.is_set():
            if self._stop_event.wait(timeout=self._CLEANUP_INTERVAL_SEC):
                break
            self._run_cleanup()

    def _run_cleanup(self) -> None:
        """Dispose stuck monitors (no activity past inactive_timeout) and
        expired monitors that report can_dispose."""
        with self._lock:
            containers = list(self._monitor_caches.items())
        now = perf_counter_ns()
        for monitor_type, container in containers:
            try:
                inactive_timeout = container.settings.inactive_timeout_ns
                # cache.items() returns a snapshot list, so removing entries
                # while iterating is safe here.
                for key, monitor in container.cache.items():
                    if now - monitor.last_activity_ns > inactive_timeout:
                        # Stuck monitor: remove regardless of expiration.
                        logger.debug("MonitorService.StuckMonitorDetected", monitor_type, key)
                        removed = container.cache.remove(key)
                        if removed is not None:
                            _CacheContainer._dispose(removed)
                        continue

                    # Expired AND disposable: remove atomically and dispose.
                    removed = container.cache.remove_expired_if(key, lambda m: m.can_dispose)
                    if removed is not None:
                        _CacheContainer._dispose(removed)
            except Exception as e:
                logger.debug("MonitorService.ErrorDuringCleanup", e)
+++ b/aws_advanced_python_wrapper/utils/services_container.py @@ -0,0 +1,144 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import threading +from concurrent.futures import ThreadPoolExecutor +from datetime import timedelta +from typing import Dict, Optional + +from aws_advanced_python_wrapper.allowed_and_blocked_hosts import \ + AllowedAndBlockedHosts +from aws_advanced_python_wrapper.hostinfo import Topology +from aws_advanced_python_wrapper.utils.events import BatchingEventPublisher +from aws_advanced_python_wrapper.utils.log import Logger +from aws_advanced_python_wrapper.utils.monitor_service import MonitorService +from aws_advanced_python_wrapper.utils.storage.storage_service import \ + StorageService + +logger = Logger(__name__) + + +class _ServicesContainer: + def __init__(self) -> None: + self._event_publisher: Optional[BatchingEventPublisher] = None + self._storage_service: Optional[StorageService] = None + self._monitor_service: Optional[MonitorService] = None + self._thread_pools: Dict[str, ThreadPoolExecutor] = {} + self._lock = threading.Lock() + + def _ensure_initialized(self) -> None: + if self._event_publisher is not None: + return + self._event_publisher = BatchingEventPublisher() + self._storage_service = StorageService(self._event_publisher) + self._storage_service.register(Topology, item_expiration_time=timedelta(minutes=5)) + 
self._storage_service.register(AllowedAndBlockedHosts, item_expiration_time=timedelta(minutes=5)) + self._monitor_service = MonitorService(self._event_publisher) + + @property + def event_publisher(self) -> BatchingEventPublisher: + self._ensure_initialized() + return self._event_publisher # type: ignore + + @property + def storage_service(self) -> StorageService: + self._ensure_initialized() + return self._storage_service # type: ignore + + @property + def monitor_service(self) -> MonitorService: + self._ensure_initialized() + return self._monitor_service # type: ignore + + def get_thread_pool(self, name: str, max_workers: Optional[int] = None) -> ThreadPoolExecutor: + pool = self._thread_pools.get(name) + if pool is not None: + return pool + with self._lock: + if name not in self._thread_pools: + self._thread_pools[name] = ThreadPoolExecutor( + max_workers=max_workers, thread_name_prefix=name) + return self._thread_pools[name] + + def release_thread_pool(self, name: str, wait: bool = True) -> bool: + with self._lock: + pool = self._thread_pools.pop(name, None) + if pool is not None: + try: + pool.shutdown(wait=wait) + except Exception as e: + logger.warning("CoreServices.ErrorShuttingDownPool", name, e) + return True + return False + + def release_resources(self) -> None: + if self._monitor_service is not None: + try: + self._monitor_service.release_resources() + except Exception as e: + logger.debug("CoreServices.ErrorReleasingMonitorService", e) + + if self._storage_service is not None: + try: + self._storage_service.release_resources() + except Exception as e: + logger.debug("CoreServices.ErrorReleasingStorageService", e) + + if self._event_publisher is not None: + try: + self._event_publisher.release_resources() + except Exception as e: + logger.debug("CoreServices.ErrorReleasingEventPublisher", e) + + self._event_publisher = None + self._storage_service = None + self._monitor_service = None + + with self._lock: + for name, pool in self._thread_pools.items(): 
+ try: + pool.shutdown(wait=False) + except Exception as e: + logger.debug("CoreServices.ErrorShuttingDownPool", name, e) + self._thread_pools.clear() + + +_instance = _ServicesContainer() +_instance._ensure_initialized() + + +def get_event_publisher() -> BatchingEventPublisher: + return _instance.event_publisher + + +def get_storage_service() -> StorageService: + return _instance.storage_service + + +def get_monitor_service() -> MonitorService: + return _instance.monitor_service + + +def get_thread_pool(name: str, max_workers: Optional[int] = None) -> ThreadPoolExecutor: + return _instance.get_thread_pool(name, max_workers) + + +def release_thread_pool(name: str, wait: bool = True) -> bool: + return _instance.release_thread_pool(name, wait) + + +def release_resources() -> None: + _instance.release_resources() diff --git a/aws_advanced_python_wrapper/utils/sliding_expiration_cache_container.py b/aws_advanced_python_wrapper/utils/sliding_expiration_cache_container.py deleted file mode 100644 index 0f9f1e716..000000000 --- a/aws_advanced_python_wrapper/utils/sliding_expiration_cache_container.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import threading -from threading import Event, Thread -from typing import Callable, ClassVar, Dict, Optional - -from aws_advanced_python_wrapper.utils.log import Logger -from aws_advanced_python_wrapper.utils.sliding_expiration_cache import \ - SlidingExpirationCache - -logger = Logger(__name__) - - -class SlidingExpirationCacheContainer: - """ - A container class for managing multiple named sliding expiration caches. - Provides static methods for getting, creating, and releasing caches. - - This container manages SlidingExpirationCache instances and provides a single - cleanup thread that periodically cleans up all managed caches. - """ - - _caches: ClassVar[Dict[str, SlidingExpirationCache]] = {} - _lock: ClassVar[threading.Lock] = threading.Lock() - _cleanup_thread: ClassVar[Optional[Thread]] = None - _cleanup_interval_ns: ClassVar[int] = 300_000_000_000 # 5 minutes default - _is_stopped: ClassVar[Event] = Event() - - @classmethod - def get_or_create_cache( - cls, - name: str, - cleanup_interval_ns: int = 10 * 60_000_000_000, # 10 minutes - should_dispose_func: Optional[Callable] = None, - item_disposal_func: Optional[Callable] = None - ) -> SlidingExpirationCache: - """ - Get an existing cache or create a new one if it doesn't exist. - - The cleanup thread is started lazily when the first cache is created. 
- - Args: - name: Unique identifier for the cache - cleanup_interval_ns: Cleanup interval in nanoseconds (only used when creating new cache) - should_dispose_func: Optional function to determine if item should be disposed - item_disposal_func: Optional function to dispose items - - Returns: - SlidingExpirationCache instance - """ - with cls._lock: - if name not in cls._caches: - cls._caches[name] = SlidingExpirationCache( - cleanup_interval_ns=cleanup_interval_ns, - should_dispose_func=should_dispose_func, - item_disposal_func=item_disposal_func - ) - - # Start cleanup thread if not already running - if cls._cleanup_thread is None or not cls._cleanup_thread.is_alive(): - cls._is_stopped.clear() - cls._cleanup_thread = Thread( - target=cls._cleanup_thread_internal, - daemon=True, - name="SlidingExpirationCacheContainer-Cleanup" - ) - cls._cleanup_thread.start() - - return cls._caches[name] - - @classmethod - def release_resources(cls) -> None: - """ - Clear all caches and stop the cleanup thread. - This will dispose all cached items and release all resources. 
- """ - with cls._lock: - # Stop the cleanup thread - cls._is_stopped.set() - - # Clear all caches (will dispose items if disposal function is set) - for name, cache in cls._caches.items(): - try: - cache.clear() - except Exception as e: - logger.warning("SlidingExpirationCacheContainer.ErrorReleasingCache", name, e) - - cls._caches.clear() - - # Wait for cleanup thread to stop (outside the lock) - if cls._cleanup_thread is not None and cls._cleanup_thread.is_alive(): - cls._cleanup_thread.join(timeout=2.0) - cls._cleanup_thread = None - - @classmethod - def _cleanup_thread_internal(cls) -> None: - while not cls._is_stopped.is_set(): - # Wait for the cleanup interval or until stopped - if cls._is_stopped.wait(timeout=cls._cleanup_interval_ns / 1_000_000_000): - break - - # Cleanup all caches - with cls._lock: - cache_items = list(cls._caches.items()) - - for name, cache in cache_items: - try: - cache.cleanup() - except Exception as e: - logger.debug("SlidingExpirationCacheContainer.ErrorDuringCleanup", name, e) diff --git a/aws_advanced_python_wrapper/utils/cache_map.py b/aws_advanced_python_wrapper/utils/storage/cache_map.py similarity index 100% rename from aws_advanced_python_wrapper/utils/cache_map.py rename to aws_advanced_python_wrapper/utils/storage/cache_map.py diff --git a/aws_advanced_python_wrapper/utils/storage/expiration_tracking_cache.py b/aws_advanced_python_wrapper/utils/storage/expiration_tracking_cache.py new file mode 100644 index 000000000..304140639 --- /dev/null +++ b/aws_advanced_python_wrapper/utils/storage/expiration_tracking_cache.py @@ -0,0 +1,102 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. 
class _CacheItem(Generic[V]):
    """A cached value together with its absolute expiration timestamp
    (monotonic, in nanoseconds)."""

    def __init__(self, item: V, expiration_ns: int) -> None:
        self.item = item
        self.expiration_ns = expiration_ns

    def is_expired(self) -> bool:
        now_ns = perf_counter_ns()
        return now_ns > self.expiration_ns

    def extend(self, duration_ns: int) -> None:
        """Push the expiration out to ``duration_ns`` from now."""
        self.expiration_ns = perf_counter_ns() + duration_ns


class ExpirationTrackingCache(Generic[K, V]):
    """A cache that tracks per-entry expiration but performs no cleanup of its
    own; an external caller is responsible for removing expired entries
    (e.g. via remove_expired_if)."""

    def __init__(self, expiration_timeout_ns: int) -> None:
        self._expiration_timeout_ns = expiration_timeout_ns
        self._cache: ConcurrentDict[K, _CacheItem[V]] = ConcurrentDict()

    def _deadline(self) -> int:
        """Absolute expiration timestamp for an entry created/refreshed now."""
        return perf_counter_ns() + self._expiration_timeout_ns

    def __len__(self) -> int:
        return len(self._cache)

    def get(self, key: K) -> Optional[V]:
        """Return the live value under *key*; expired entries read as absent."""
        slot = self._cache.get(key)
        return None if slot is None or slot.is_expired() else slot.item

    def put(self, key: K, value: V) -> Optional[V]:
        """Store *value* with a fresh expiration, returning any replaced value."""
        previous = self._cache.remove(key)
        self._cache.put(key, _CacheItem(value, self._deadline()))
        return None if previous is None else previous.item

    def compute_if_absent(self, key: K, factory: Callable[[K], V]) -> Optional[V]:
        """Return the entry under *key*, creating it via *factory* if missing;
        the entry's expiration is refreshed either way."""
        slot = self._cache.compute_if_absent(
            key, lambda k: _CacheItem(factory(k), self._deadline()))
        if slot is None:
            return None
        slot.extend(self._expiration_timeout_ns)
        return slot.item

    def get_or_create_for_aliases(self, aliases: FrozenSet[K], factory: Callable[[], V]) -> V:
        """Return the single entry shared by all *aliases*, creating it if
        absent and refreshing its expiration if present."""
        slot = self._cache.compute_for_keys(
            aliases,
            lambda: _CacheItem(factory(), self._deadline()),
            on_existing=lambda existing: existing.extend(self._expiration_timeout_ns))
        return slot.item

    def extend_expiration(self, key: K) -> None:
        """Refresh the expiration of the entry under *key*, if any."""
        slot = self._cache.get(key)
        if slot is not None:
            slot.extend(self._expiration_timeout_ns)

    def remove(self, key: K) -> Optional[V]:
        """Remove and return the value under *key* (None if absent)."""
        slot = self._cache.remove(key)
        return None if slot is None else slot.item

    def remove_expired_if(self, key: K, predicate: Callable[[V], bool]) -> Optional[V]:
        """Atomically remove the entry under *key* when it is expired AND
        *predicate* approves; return the removed value, else None."""
        evicted = self._cache.remove_key_if(
            key, lambda slot: slot.is_expired() and predicate(slot.item))
        return None if evicted is None else evicted.item

    def detach_value(self, value: V) -> bool:
        """Remove the first entry holding *value* (identity comparison)
        without disposing it; return whether one was found."""
        match = next(
            (key for key, slot in self._cache.items() if slot.item is value), None)
        if match is None:
            return False
        self._cache.remove(match)
        return True

    def items(self) -> List[Tuple[K, V]]:
        """Snapshot of (key, value) pairs; safe to mutate the cache while
        iterating the returned list."""
        return [(key, slot.item) for key, slot in self._cache.items()]

    def clear(self) -> List[V]:
        """Drop every entry and return the values that were held."""
        drained: List[V] = []
        self._cache.clear(lambda _key, slot: drained.append(slot.item))
        return drained
items(self) -> List[Tuple[K, V]]: + return [(k, entry.item) for k, entry in self._cache.items()] + + def clear(self) -> List[V]: + values: List[V] = [] + self._cache.clear(lambda _k, entry: values.append(entry.item)) + return values diff --git a/aws_advanced_python_wrapper/utils/sliding_expiration_cache.py b/aws_advanced_python_wrapper/utils/storage/sliding_expiration_cache.py similarity index 62% rename from aws_advanced_python_wrapper/utils/sliding_expiration_cache.py rename to aws_advanced_python_wrapper/utils/storage/sliding_expiration_cache.py index 8033362e5..4a74003ad 100644 --- a/aws_advanced_python_wrapper/utils/sliding_expiration_cache.py +++ b/aws_advanced_python_wrapper/utils/storage/sliding_expiration_cache.py @@ -14,8 +14,7 @@ from __future__ import annotations -from threading import Thread -from time import perf_counter_ns, sleep +from time import perf_counter_ns from typing import Callable, Generic, List, Optional, Tuple, TypeVar from aws_advanced_python_wrapper.utils.atomic import AtomicInt @@ -43,6 +42,9 @@ def __init__( def __len__(self): return len(self._cdict) + def __contains__(self, key: K) -> bool: + return key in self._cdict + def set_cleanup_interval_ns(self, interval_ns): self._cleanup_interval_ns = interval_ns @@ -58,12 +60,12 @@ def compute_if_absent(self, key: K, mapping_func: Callable, item_expiration_ns: key, lambda k: CacheItem(mapping_func(k), perf_counter_ns() + item_expiration_ns)) return None if cache_item is None else cache_item.update_expiration(item_expiration_ns).item - def compute_if_absent_with_disposal(self, key: K, mapping_func: Callable, item_expiration_ns: int) -> Optional[V]: + def put(self, key: K, value: V, item_expiration_ns: int) -> None: self.cleanup() - self._remove_if_disposable(key) - cache_item = self._cdict.compute_if_absent( - key, lambda k: CacheItem(mapping_func(k), perf_counter_ns() + item_expiration_ns)) - return None if cache_item is None else cache_item.update_expiration(item_expiration_ns).item + old 
= self._cdict.remove(key) + if old is not None and self._item_disposal_func is not None: + self._item_disposal_func(old.item) + self._cdict.put(key, CacheItem(value, perf_counter_ns() + item_expiration_ns)) def get(self, key: K) -> Optional[V]: self.cleanup() @@ -71,7 +73,9 @@ def get(self, key: K) -> Optional[V]: return cache_item.item if cache_item is not None else None def remove(self, key: K): - self._remove_and_dispose(key) + cache_item = self._cdict.remove(key) + if cache_item is not None and self._item_disposal_func is not None: + self._item_disposal_func(cache_item.item) self.cleanup() def clear(self): @@ -88,35 +92,16 @@ def cleanup(self): return self._cleanup_time_ns.set(current_time + self._cleanup_interval_ns) + to_dispose = [] keys = self._cdict.keys() for key in keys: - self._remove_if_expired(key) - - def _remove_if_disposable(self, key: K): - def _remove_if_disposable_internal(_, cache_item): - if self._should_dispose_func is not None and self._should_dispose_func(cache_item.item): - if self._item_disposal_func is not None: - self._item_disposal_func(cache_item.item) - return None - return cache_item - - self._cdict.compute_if_present(key, _remove_if_disposable_internal) - - def _remove_and_dispose(self, key: K): - cache_item = self._cdict.remove(key) - if cache_item is not None and self._item_disposal_func is not None: - self._item_disposal_func(cache_item.item) - - def _remove_if_expired(self, key: K): - def _remove_if_expired_internal(_, cache_item): - if self._should_cleanup_item(cache_item): - # Dispose while holding the lock to prevent race conditions - if self._item_disposal_func is not None: - self._item_disposal_func(cache_item.item) - return None - return cache_item - - self._cdict.compute_if_present(key, _remove_if_expired_internal) + cache_item = self._cdict.remove_key_if(key, self._should_cleanup_item) + if cache_item is not None: + to_dispose.append(cache_item.item) + # Dispose outside the lock to avoid blocking cache operations during 
slow disposal (e.g. thread.join) + for item in to_dispose: + if self._item_disposal_func is not None: + self._item_disposal_func(item) def _should_cleanup_item(self, cache_item: CacheItem) -> bool: if self._should_dispose_func is not None: @@ -124,27 +109,6 @@ def _should_cleanup_item(self, cache_item: CacheItem) -> bool: return perf_counter_ns() > cache_item.expiration_time -class SlidingExpirationCacheWithCleanupThread(SlidingExpirationCache, Generic[K, V]): - def __init__( - self, - cleanup_interval_ns: int = 10 * 60_000_000_000, # 10 minutes - should_dispose_func: Optional[Callable] = None, - item_disposal_func: Optional[Callable] = None): - super().__init__(cleanup_interval_ns, should_dispose_func, item_disposal_func) - self._cleanup_thread = Thread(target=self._cleanup_thread_internal, daemon=True) - self._cleanup_thread.start() - - def _cleanup_thread_internal(self): - while True: - try: - sleep(self._cleanup_interval_ns / 1_000_000_000) - # Force cleanup by resetting the interval timer - self._cleanup_time_ns.set(0) - self.cleanup() - except Exception: - break - - class CacheItem(Generic[V]): def __init__(self, item: V, expiration_time: int): self.item = item diff --git a/aws_advanced_python_wrapper/utils/storage/storage_service.py b/aws_advanced_python_wrapper/utils/storage/storage_service.py index a5543a2c7..2df472442 100644 --- a/aws_advanced_python_wrapper/utils/storage/storage_service.py +++ b/aws_advanced_python_wrapper/utils/storage/storage_service.py @@ -12,58 +12,85 @@ # See the License for the specific language governing permissions and # limitations under the License. 
class StorageService:
    """Instance-based typed key-value cache with expiration and event publishing.

    Each registered item type gets its own SlidingExpirationCache; reads of
    cached values publish a DataAccessEvent so interested monitors can extend
    their lifetimes.
    """

    def __init__(self, event_publisher: EventPublisher) -> None:
        self._event_publisher = event_publisher
        # One cache per registered item type.
        self._caches: ConcurrentDict[type, SlidingExpirationCache] = ConcurrentDict()

    def register(
            self,
            item_type: Type[V],
            item_expiration_time: timedelta,
            should_dispose: Optional[Callable] = None,
            on_dispose: Optional[Callable] = None) -> None:
        """Create the cache for *item_type* if absent.

        A no-op when the type is already registered — the original expiration
        and disposal settings are kept.
        """
        item_expiration_ns = int(item_expiration_time.total_seconds() * 1_000_000_000)
        self._caches.compute_if_absent(
            item_type,
            lambda _: SlidingExpirationCache(
                cleanup_interval_ns=item_expiration_ns,
                should_dispose_func=should_dispose,
                item_disposal_func=on_dispose))

    def get(self, item_type: Type[V], key: Any) -> Optional[V]:
        """Return the cached value, publishing a DataAccessEvent on a hit.

        Returns None when the type is unregistered or the key is absent/expired
        (no event is published in those cases).
        """
        cache = self._caches.get(item_type)
        if cache is None:
            return None
        value = cache.get(key)
        if value is not None:
            self._event_publisher.publish(DataAccessEvent(data_type=item_type, key=key))
        return value

    def put(self, item_type: Type[V], key: Any, value: V, item_expiration_ns: Optional[int] = None) -> None:
        """Store *value* under *key*; raises ValueError for unregistered types.

        When no explicit expiration is given, the type's registered default is
        used (read from the cache's cleanup interval).
        """
        cache = self._caches.get(item_type)
        if cache is None:
            raise ValueError(f"Type {item_type} is not registered with StorageService")
        if item_expiration_ns is None:
            item_expiration_ns = cache._cleanup_interval_ns
        cache.put(key, value, item_expiration_ns)

    def exists(self, item_type: Type[V], key: Any) -> bool:
        """True when *key* is present (and unexpired) for a registered type."""
        cache = self._caches.get(item_type)
        return cache is not None and key in cache

    def remove(self, item_type: Type[V], key: Any) -> None:
        """Remove *key* from the type's cache; silently ignores unregistered types."""
        cache = self._caches.get(item_type)
        if cache is not None:
            cache.remove(key)

    def clear(self, item_type: Type[V]) -> None:
        """Drop every entry of *item_type*; silently ignores unregistered types."""
        cache = self._caches.get(item_type)
        if cache is not None:
            cache.clear()

    def size(self, item_type: Type[V]) -> int:
        """Number of entries cached for *item_type* (0 when unregistered)."""
        cache = self._caches.get(item_type)
        return len(cache) if cache is not None else 0

    def clear_all(self) -> None:
        """Drop every entry of every registered type (registrations remain)."""
        for cache in self._caches.values():
            cache.clear()

    def release_resources(self) -> None:
        """Clear all entries and drop the per-type caches themselves."""
        self.clear_all()
        self._caches.clear()
a/aws_advanced_python_wrapper/wrapper.py +++ b/aws_advanced_python_wrapper/wrapper.py @@ -17,8 +17,12 @@ from typing import (TYPE_CHECKING, Any, Callable, Iterator, List, Optional, Type, TypeVar, Union) +from aws_advanced_python_wrapper.plugin_service import ( + PluginManager, PluginServiceImpl, PluginServiceManagerContainer) + if TYPE_CHECKING: from aws_advanced_python_wrapper.host_list_provider import HostListProviderService + from aws_advanced_python_wrapper.plugin_service import PluginService from aws_advanced_python_wrapper.driver_dialect_manager import \ DriverDialectManager @@ -26,9 +30,6 @@ from aws_advanced_python_wrapper.pep249 import Connection, Cursor, Error from aws_advanced_python_wrapper.pep249_methods import DbApiMethod from aws_advanced_python_wrapper.plugin import CanReleaseResources -from aws_advanced_python_wrapper.plugin_service import ( - PluginManager, PluginService, PluginServiceImpl, - PluginServiceManagerContainer) from aws_advanced_python_wrapper.utils.log import Logger from aws_advanced_python_wrapper.utils.messages import Messages from aws_advanced_python_wrapper.utils.properties import (Properties, @@ -159,12 +160,16 @@ def connect( try: driver_dialect_manager: DriverDialectManager = DriverDialectManager() driver_dialect = driver_dialect_manager.get_dialect(target_func, props) - container: PluginServiceManagerContainer = PluginServiceManagerContainer() - plugin_service = PluginServiceImpl( - container, props, target_func, driver_dialect_manager, driver_dialect) + + container = PluginServiceManagerContainer() + plugin_service = PluginServiceImpl(container, props, target_func, driver_dialect_manager, driver_dialect) plugin_manager: PluginManager = PluginManager(container, props, telemetry_factory) - return AwsWrapperConnection(target_func, plugin_service, plugin_service, plugin_manager) + return AwsWrapperConnection( + target_func, + plugin_service, + plugin_service, + plugin_manager) except Exception as ex: if context is not None: 
context.set_exception(ex) diff --git a/tests/integration/container/conftest.py b/tests/integration/container/conftest.py index f32be63ad..88e553119 100644 --- a/tests/integration/container/conftest.py +++ b/tests/integration/container/conftest.py @@ -27,17 +27,10 @@ from aws_advanced_python_wrapper.driver_dialect_manager import \ DriverDialectManager from aws_advanced_python_wrapper.exception_handling import ExceptionManager -from aws_advanced_python_wrapper.host_monitoring_plugin import \ - MonitoringThreadContainer from aws_advanced_python_wrapper.plugin_service import PluginServiceImpl -from aws_advanced_python_wrapper.thread_pool_container import \ - ThreadPoolContainer +from aws_advanced_python_wrapper.utils import services_container from aws_advanced_python_wrapper.utils.log import Logger from aws_advanced_python_wrapper.utils.rds_utils import RdsUtils -from aws_advanced_python_wrapper.utils.sliding_expiration_cache_container import \ - SlidingExpirationCacheContainer -from aws_advanced_python_wrapper.utils.storage.storage_service import \ - StorageService if TYPE_CHECKING: from .utils.test_driver import TestDriver @@ -142,13 +135,10 @@ def pytest_runtest_setup(item): assert cluster_ip == writer_ip RdsUtils.clear_cache() - StorageService.clear_all() + services_container.release_resources() PluginServiceImpl._host_availability_expiring_cache.clear() DatabaseDialectManager._known_endpoint_dialects.clear() CustomEndpointMonitor._custom_endpoint_info_cache.clear() - MonitoringThreadContainer.clean_up() - ThreadPoolContainer.release_resources(wait=True) - SlidingExpirationCacheContainer.release_resources() ConnectionProviderManager.release_resources() ConnectionProviderManager.reset_provider() diff --git a/tests/integration/container/test_read_write_splitting.py b/tests/integration/container/test_read_write_splitting.py index f03e664e3..931b17c8a 100644 --- a/tests/integration/container/test_read_write_splitting.py +++ 
b/tests/integration/container/test_read_write_splitting.py @@ -25,11 +25,10 @@ ReadWriteSplittingError, TransactionResolutionUnknownError) from aws_advanced_python_wrapper.sql_alchemy_connection_provider import \ SqlAlchemyPooledConnectionProvider +from aws_advanced_python_wrapper.utils import services_container from aws_advanced_python_wrapper.utils.log import Logger from aws_advanced_python_wrapper.utils.properties import (Properties, WrapperProperties) -from aws_advanced_python_wrapper.utils.storage.storage_service import \ - StorageService from tests.integration.container.utils.conditions import ( disable_on_engines, disable_on_features, enable_on_deployments, enable_on_features, enable_on_num_instances) @@ -79,7 +78,7 @@ def rds_utils(self): @pytest.fixture(autouse=True) def clear_caches(self): - StorageService.clear_all() + services_container.get_storage_service().clear_all() yield ConnectionProviderManager.release_resources() ConnectionProviderManager.reset_provider() diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index dd11473a4..b410175a4 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -19,12 +19,11 @@ DriverDialectManager from aws_advanced_python_wrapper.exception_handling import ExceptionManager from aws_advanced_python_wrapper.plugin_service import PluginServiceImpl -from aws_advanced_python_wrapper.utils.storage.storage_service import \ - StorageService +from aws_advanced_python_wrapper.utils import services_container def pytest_runtest_setup(item): - StorageService.clear_all() + services_container.get_storage_service().clear_all() PluginServiceImpl._host_availability_expiring_cache.clear() DatabaseDialectManager._known_endpoint_dialects.clear() diff --git a/tests/unit/test_cache_map.py b/tests/unit/test_cache_map.py index ca1944698..24e2b1545 100644 --- a/tests/unit/test_cache_map.py +++ b/tests/unit/test_cache_map.py @@ -14,7 +14,7 @@ import time -from aws_advanced_python_wrapper.utils.cache_map import CacheMap +from 
aws_advanced_python_wrapper.utils.storage.cache_map import CacheMap def test_get(): diff --git a/tests/unit/test_core_services_thread_pool.py b/tests/unit/test_core_services_thread_pool.py new file mode 100644 index 000000000..f7dfec5b2 --- /dev/null +++ b/tests/unit/test_core_services_thread_pool.py @@ -0,0 +1,70 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from concurrent.futures import ThreadPoolExecutor + +import pytest + +from aws_advanced_python_wrapper.utils import services_container + + +@pytest.fixture(autouse=True) +def cleanup_pools(): + yield + services_container._instance._thread_pools.clear() + + +def test_get_thread_pool_creates_new_pool(): + pool = services_container.get_thread_pool("test_pool") + assert isinstance(pool, ThreadPoolExecutor) + assert "test_pool" in services_container._instance._thread_pools + + +def test_get_thread_pool_returns_existing_pool(): + pool1 = services_container.get_thread_pool("test_pool") + pool2 = services_container.get_thread_pool("test_pool") + assert pool1 is pool2 + + +def test_get_thread_pool_with_max_workers(): + pool = services_container.get_thread_pool("test_pool", max_workers=5) + assert pool._max_workers == 5 + + +def test_thread_name_prefix(): + pool = services_container.get_thread_pool("custom_name") + assert pool._thread_name_prefix == "custom_name" + + +def test_release_thread_pool(): + 
services_container.get_thread_pool("test_pool") + assert "test_pool" in services_container._instance._thread_pools + + result = services_container.release_thread_pool("test_pool") + assert result is True + assert "test_pool" not in services_container._instance._thread_pools + + +def test_release_nonexistent_pool(): + result = services_container.release_thread_pool("nonexistent") + assert result is False + + +def test_release_resources_clears_pools(): + services_container.get_thread_pool("pool1") + services_container.get_thread_pool("pool2") + assert len(services_container._instance._thread_pools) == 2 + + services_container.release_resources() + assert len(services_container._instance._thread_pools) == 0 diff --git a/tests/unit/test_developer_plugin.py b/tests/unit/test_developer_plugin.py index 51625fe14..6e785a456 100644 --- a/tests/unit/test_developer_plugin.py +++ b/tests/unit/test_developer_plugin.py @@ -19,6 +19,7 @@ from aws_advanced_python_wrapper.developer_plugin import \ ExceptionSimulatorManager +from aws_advanced_python_wrapper.hostinfo import HostInfo from aws_advanced_python_wrapper.plugin_service import ( PluginManager, PluginServiceImpl, PluginServiceManagerContainer) from aws_advanced_python_wrapper.utils.properties import Properties @@ -87,14 +88,21 @@ def setup_container(container, plugin_service, plugin_manager): container.plugin_manager = plugin_manager +def _make_conn(plugin_service, plugin_manager): + conn = object.__new__(AwsWrapperConnection) + conn._plugin_service = plugin_service + conn._plugin_manager = plugin_manager + return conn + + @pytest.fixture def telemetry_factory(mocker): return mocker.MagicMock() -def test_raise_exception(mocker, plugin_service, plugin_manager): +def test_raise_exception(mocker, plugin_service, plugin_manager, container): exception: RuntimeError = RuntimeError("exception to raise") - conn = AwsWrapperConnection(mocker.MagicMock(), plugin_service, plugin_service, plugin_manager) + conn = 
_make_conn(plugin_service, plugin_manager) conn.cursor() @@ -105,9 +113,9 @@ def test_raise_exception(mocker, plugin_service, plugin_manager): conn.cursor() -def test_raise_exception_for_method_name(mocker, plugin_service, plugin_manager): +def test_raise_exception_for_method_name(mocker, plugin_service, plugin_manager, container): exception: RuntimeError = RuntimeError("exception to raise") - conn = AwsWrapperConnection(mocker.MagicMock(), plugin_service, plugin_service, plugin_manager) + conn = _make_conn(plugin_service, plugin_manager) conn.cursor() @@ -118,9 +126,9 @@ def test_raise_exception_for_method_name(mocker, plugin_service, plugin_manager) conn.cursor() -def test_raise_exception_for_wrong_method_name(mocker, plugin_service, plugin_manager): +def test_raise_exception_for_wrong_method_name(mocker, plugin_service, plugin_manager, container): exception: RuntimeError = RuntimeError("exception to raise") - conn = AwsWrapperConnection(mocker.MagicMock(), plugin_service, plugin_service, plugin_manager) + conn = _make_conn(plugin_service, plugin_manager) conn.cursor() @@ -128,26 +136,34 @@ def test_raise_exception_for_wrong_method_name(mocker, plugin_service, plugin_ma conn.cursor() -def test_raise_exception_on_connect(mocker, plugin_service, plugin_manager): +def test_raise_exception_on_connect(mocker, plugin_service, plugin_manager, container): exception: Exception = Exception("exception to raise") exception_simulator_manager = ExceptionSimulatorManager() exception_simulator_manager.raise_exception_on_next_connect(exception) + mock_ps = mocker.MagicMock() + mock_ps.current_connection = None + mock_ps.initial_connection_host_info = HostInfo("localhost") + mock_ps.props = Properties({"host": "localhost", "plugins": "dev"}) with pytest.raises(Exception, match="exception to raise"): - AwsWrapperConnection(mocker.MagicMock(), plugin_service, plugin_service, plugin_manager) + AwsWrapperConnection(mocker.MagicMock(), mock_ps, mock_ps, plugin_manager) - 
AwsWrapperConnection(mocker.MagicMock(), plugin_service, plugin_service, plugin_manager) + AwsWrapperConnection(mocker.MagicMock(), mock_ps, mock_ps, plugin_manager) -def test_no_exception_on_connect_with_callback(mocker, mock_connect_callback, plugin_service, plugin_manager): +def test_no_exception_on_connect_with_callback(mocker, mock_connect_callback, plugin_service, plugin_manager, container): exception_simulator_manager = ExceptionSimulatorManager() mock_connect_callback.get_exception_to_raise.return_value = None exception_simulator_manager.set_connect_callback(mock_connect_callback) - AwsWrapperConnection(mocker.MagicMock(), plugin_service, plugin_service, plugin_manager) + mock_ps = mocker.MagicMock() + mock_ps.current_connection = None + mock_ps.initial_connection_host_info = HostInfo("localhost") + mock_ps.props = Properties({"host": "localhost", "plugins": "dev"}) + AwsWrapperConnection(mocker.MagicMock(), mock_ps, mock_ps, plugin_manager) -def test_raise_exception_on_connect_with_callback(mocker, mock_connect_callback, plugin_service, plugin_manager): +def test_raise_exception_on_connect_with_callback(mocker, mock_connect_callback, plugin_service, plugin_manager, container): exception: Exception = Exception("exception to raise") exception_simulator_manager = ExceptionSimulatorManager() ExceptionSimulatorManager.connect_callback = mocker.MagicMock() @@ -156,6 +172,11 @@ def test_raise_exception_on_connect_with_callback(mocker, mock_connect_callback, exception_simulator_manager.raise_exception_on_next_connect(exception) exception_simulator_manager.set_connect_callback(mock_connect_callback) + mock_ps = mocker.MagicMock() + mock_ps.current_connection = None + mock_ps.initial_connection_host_info = HostInfo("localhost") + mock_ps.props = Properties({"host": "localhost", "plugins": "dev"}) + with pytest.raises(Exception, match="exception to raise"): AwsWrapperConnection(mocker.MagicMock(), plugin_service, plugin_service, plugin_manager) diff --git 
a/tests/unit/test_expiration_tracking_cache.py b/tests/unit/test_expiration_tracking_cache.py new file mode 100644 index 000000000..faa4756ae --- /dev/null +++ b/tests/unit/test_expiration_tracking_cache.py @@ -0,0 +1,162 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import pytest + +from aws_advanced_python_wrapper.utils.storage.expiration_tracking_cache import \ + ExpirationTrackingCache + +_LONG_TTL = 15_000_000_000 # 15 seconds +_SHORT_TTL = 50_000_000 # 50 ms + + +@pytest.fixture +def cache(): + return ExpirationTrackingCache(_LONG_TTL) + + +def test_put_and_get(cache): + cache.put("k1", "v1") + assert cache.get("k1") == "v1" + + +def test_get_missing_key(cache): + assert cache.get("missing") is None + + +def test_get_expired_key(): + cache = ExpirationTrackingCache(_SHORT_TTL) + cache.put("k1", "v1") + time.sleep(0.07) + assert cache.get("k1") is None + + +def test_put_returns_old_value(cache): + assert cache.put("k1", "old") is None + assert cache.put("k1", "new") == "old" + assert cache.get("k1") == "new" + + +def test_compute_if_absent(cache): + result1 = cache.compute_if_absent("k1", lambda _: "v1") + result2 = cache.compute_if_absent("k1", lambda _: "v2") + assert result1 == "v1" + assert result2 == "v1" + + +def test_compute_if_absent_extends_expiration(): + cache = ExpirationTrackingCache(_SHORT_TTL) + cache.compute_if_absent("k1", lambda _: "v1") + 
time.sleep(0.03) + # Should extend expiration, not expire + cache.compute_if_absent("k1", lambda _: "v2") + assert cache.get("k1") == "v1" + + +def test_get_or_create_for_aliases(cache): + result = cache.get_or_create_for_aliases( + frozenset(["a", "b", "c"]), lambda: "val") + assert result == "val" + assert cache.get("a") == "val" + assert cache.get("b") == "val" + assert cache.get("c") == "val" + + +def test_get_or_create_for_aliases_reuses_existing(cache): + cache.put("b", "existing") + result = cache.get_or_create_for_aliases( + frozenset(["a", "b", "c"]), lambda: "new") + assert result == "existing" + # "a" and "c" should also be set + assert cache.get("a") == "existing" + + +def test_extend_expiration(): + cache = ExpirationTrackingCache(_SHORT_TTL) + cache.put("k1", "v1") + time.sleep(0.03) + cache.extend_expiration("k1") + time.sleep(0.03) + assert cache.get("k1") == "v1" + + +def test_remove(cache): + cache.put("k1", "v1") + removed = cache.remove("k1") + assert removed == "v1" + assert cache.get("k1") is None + + +def test_remove_missing_key(cache): + assert cache.remove("missing") is None + + +def test_remove_expired_if(cache): + short_cache = ExpirationTrackingCache(_SHORT_TTL) + short_cache.put("k1", "v1") + time.sleep(0.07) + removed = short_cache.remove_expired_if("k1", lambda v: True) + assert removed == "v1" + + +def test_remove_expired_if_not_expired(cache): + cache.put("k1", "v1") + removed = cache.remove_expired_if("k1", lambda v: True) + assert removed is None + assert cache.get("k1") == "v1" + + +def test_remove_expired_if_predicate_false(): + cache = ExpirationTrackingCache(_SHORT_TTL) + cache.put("k1", "v1") + time.sleep(0.07) + removed = cache.remove_expired_if("k1", lambda v: False) + assert removed is None + + +def test_detach_value(cache): + cache.put("k1", "v1") + cache.put("k2", "v2") + assert cache.detach_value("v1") is True + assert cache.get("k1") is None + assert cache.get("k2") == "v2" + + +def test_detach_value_not_found(cache): 
+ assert cache.detach_value("missing") is False + + +def test_items(cache): + cache.put("k1", "v1") + cache.put("k2", "v2") + result = dict(cache.items()) + assert result == {"k1": "v1", "k2": "v2"} + + +def test_clear(cache): + cache.put("k1", "v1") + cache.put("k2", "v2") + values = cache.clear() + assert sorted(values) == ["v1", "v2"] + assert len(cache) == 0 + + +def test_len(cache): + assert len(cache) == 0 + cache.put("k1", "v1") + assert len(cache) == 1 + cache.put("k2", "v2") + assert len(cache) == 2 diff --git a/tests/unit/test_federated_auth_plugin.py b/tests/unit/test_federated_auth_plugin.py index 0f60d00cc..b84ed072b 100644 --- a/tests/unit/test_federated_auth_plugin.py +++ b/tests/unit/test_federated_auth_plugin.py @@ -15,8 +15,6 @@ from __future__ import annotations from datetime import datetime, timedelta -from typing import Dict -from unittest.mock import patch import pytest from boto3 import Session @@ -27,6 +25,7 @@ from aws_advanced_python_wrapper.federated_plugin import FederatedAuthPlugin from aws_advanced_python_wrapper.hostinfo import HostInfo from aws_advanced_python_wrapper.iam_plugin import TokenInfo +from aws_advanced_python_wrapper.utils import services_container from aws_advanced_python_wrapper.utils.messages import Messages from aws_advanced_python_wrapper.utils.properties import (Properties, WrapperProperties) @@ -39,12 +38,12 @@ _PG_HOST_INFO = HostInfo("pg.testdb.us-east-2.rds.amazonaws.com") -_token_cache: Dict[str, TokenInfo] = {} - @pytest.fixture(autouse=True) def clear_cache(): - _token_cache.clear() + from datetime import timedelta + services_container.get_storage_service().register(TokenInfo, item_expiration_time=timedelta(minutes=30)) + services_container.get_storage_service().clear(TokenInfo) AwsCredentialsManager.release_resources() @@ -101,18 +100,17 @@ def mock_default_behavior(mock_session, mock_client, mock_func, mock_connection, yield 
-@patch("aws_advanced_python_wrapper.federated_plugin.FederatedAuthPlugin._token_cache", _token_cache) def test_pg_connect_valid_token_in_cache(mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect): properties: Properties = Properties() WrapperProperties.PLUGINS.set(properties, "federated_auth") WrapperProperties.DB_USER.set(properties, _DB_USER) initial_token = TokenInfo(_TEST_TOKEN, datetime.now() + timedelta(minutes=5)) - _token_cache[_PG_CACHE_KEY] = initial_token + services_container.get_storage_service().put(TokenInfo, _PG_CACHE_KEY, initial_token) target_plugin: FederatedAuthPlugin = FederatedAuthPlugin(mock_plugin_service, mock_session) key = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + str(_DEFAULT_PG_PORT) + ":postgesqlUser" - _token_cache[key] = initial_token + services_container.get_storage_service().put(TokenInfo, key, initial_token) target_plugin.connect( target_driver_func=mocker.MagicMock(), @@ -124,19 +122,18 @@ def test_pg_connect_valid_token_in_cache(mocker, mock_plugin_service, mock_sessi mock_client.generate_db_auth_token.assert_not_called() - actual_token = _token_cache.get(_PG_CACHE_KEY) + actual_token = services_container.get_storage_service().get(TokenInfo, _PG_CACHE_KEY) assert _GENERATED_TOKEN != actual_token.token assert _TEST_TOKEN == actual_token.token assert actual_token.is_expired() is False -@patch("aws_advanced_python_wrapper.federated_plugin.FederatedAuthPlugin._token_cache", _token_cache) def test_expired_cached_token(mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect, mock_credentials_provider_factory): test_props: Properties = Properties({"plugins": "federated_auth", "user": "postgresqlUser", "idp_username": "user", "idp_password": "password"}) WrapperProperties.DB_USER.set(test_props, _DB_USER) initial_token = TokenInfo(_TEST_TOKEN, datetime.now() - timedelta(minutes=5)) - _token_cache[_PG_CACHE_KEY] = initial_token + 
services_container.get_storage_service().put(TokenInfo, _PG_CACHE_KEY, initial_token) target_plugin: FederatedAuthPlugin = FederatedAuthPlugin(mock_plugin_service, mock_credentials_provider_factory) @@ -157,7 +154,6 @@ def test_expired_cached_token(mocker, mock_plugin_service, mock_session, mock_fu assert WrapperProperties.PASSWORD.get(test_props) == _TEST_TOKEN -@patch("aws_advanced_python_wrapper.federated_plugin.FederatedAuthPlugin._token_cache", _token_cache) def test_no_cached_token(mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect, mock_credentials_provider_factory): test_props: Properties = Properties({"plugins": "federated_auth", "user": "postgresqlUser", "idp_username": "user", "idp_password": "password"}) @@ -182,7 +178,6 @@ def test_no_cached_token(mocker, mock_plugin_service, mock_session, mock_func, m assert WrapperProperties.PASSWORD.get(test_props) == _TEST_TOKEN -@patch("aws_advanced_python_wrapper.federated_plugin.FederatedAuthPlugin._token_cache", _token_cache) def test_no_cached_token_raises_exception(mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect, mock_credentials_provider_factory): test_props: Properties = Properties( @@ -213,7 +208,6 @@ def test_no_cached_token_raises_exception(mocker, mock_plugin_service, mock_sess assert str(e_info.value) == Messages.get_formatted("FederatedAuthPlugin.ConnectException", exception_message) -@patch("aws_advanced_python_wrapper.federated_plugin.FederatedAuthPlugin._token_cache", _token_cache) def test_connect_with_specified_iam_host_port_region(mocker, mock_plugin_service, mock_session, @@ -234,7 +228,7 @@ def test_connect_with_specified_iam_host_port_region(mocker, test_token_info = TokenInfo(_TEST_TOKEN, datetime.now() + timedelta(minutes=5)) key = "us-west-2:pg.testdb.us-west-2.rds.amazonaws.com:" + str(expected_port) + ":specifiedUser" - _token_cache[key] = test_token_info + services_container.get_storage_service().put(TokenInfo, key, 
test_token_info) mock_client.generate_db_auth_token.return_value = f"{_TEST_TOKEN}:{expected_region}" diff --git a/tests/unit/test_global_aurora_host_list_provider.py b/tests/unit/test_global_aurora_host_list_provider.py index edca658ad..8610580cc 100644 --- a/tests/unit/test_global_aurora_host_list_provider.py +++ b/tests/unit/test_global_aurora_host_list_provider.py @@ -22,17 +22,14 @@ from aws_advanced_python_wrapper.host_list_provider import ( GlobalAuroraHostListProvider, GlobalAuroraTopologyUtils) from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole +from aws_advanced_python_wrapper.utils import services_container from aws_advanced_python_wrapper.utils.properties import Properties -from aws_advanced_python_wrapper.utils.sliding_expiration_cache_container import \ - SlidingExpirationCacheContainer -from aws_advanced_python_wrapper.utils.storage.storage_service import \ - StorageService @pytest.fixture(autouse=True) def clear_caches(): - StorageService.clear_all() - SlidingExpirationCacheContainer.release_resources() + services_container.get_storage_service().clear_all() + services_container.get_monitor_service().stop_all() @pytest.fixture diff --git a/tests/unit/test_host_monitor_v2_plugin.py b/tests/unit/test_host_monitor_v2_plugin.py index 46c79423b..273a2e442 100644 --- a/tests/unit/test_host_monitor_v2_plugin.py +++ b/tests/unit/test_host_monitor_v2_plugin.py @@ -241,17 +241,17 @@ def test_is_active_when_inactive(monitoring_context): def test_can_dispose_none_empty_active_context(host_monitor): - assert host_monitor.can_dispose() is True + assert host_monitor.can_dispose is True host_monitor._active_contexts.put(MagicMock()) - assert host_monitor.can_dispose() is False + assert host_monitor.can_dispose is False def test_can_dispose_none_new_contexts_context(host_monitor): - assert host_monitor.can_dispose() is True + assert host_monitor.can_dispose is True host_monitor._new_contexts.compute_if_absent(1, lambda key: Queue()) - assert 
host_monitor.can_dispose() is False + assert host_monitor.can_dispose is False def test_is_stopped(host_monitor): diff --git a/tests/unit/test_host_monitoring_plugin.py b/tests/unit/test_host_monitoring_plugin.py index faa8ffa20..dab274180 100644 --- a/tests/unit/test_host_monitoring_plugin.py +++ b/tests/unit/test_host_monitoring_plugin.py @@ -200,4 +200,4 @@ def test_get_monitoring_host_info_errors(mocker, plugin, mock_plugin_service): def test_release_resources(plugin, mock_monitor_service): plugin.release_resources() - mock_monitor_service.release_resources.assert_called_once() + assert plugin._monitor_service is None diff --git a/tests/unit/test_host_response_time_monitor.py b/tests/unit/test_host_response_time_monitor.py index fa6e44bc5..37d81b1b5 100644 --- a/tests/unit/test_host_response_time_monitor.py +++ b/tests/unit/test_host_response_time_monitor.py @@ -67,7 +67,7 @@ def test_run_host_available(mock_conn, mock_plugin_service, host_info, props): host_info, props, 1_000) sleep(0.1) - monitor.close() + monitor.stop() mock_plugin_service.driver_dialect.ping.assert_called() mock_plugin_service.force_connect.assert_called() diff --git a/tests/unit/test_iam_plugin.py b/tests/unit/test_iam_plugin.py index 9b1329416..9bf9e3aff 100644 --- a/tests/unit/test_iam_plugin.py +++ b/tests/unit/test_iam_plugin.py @@ -16,8 +16,6 @@ import urllib.request from datetime import datetime, timedelta -from typing import Dict -from unittest.mock import patch import pytest from boto3 import Session @@ -26,7 +24,9 @@ AwsCredentialsManager from aws_advanced_python_wrapper.errors import AwsWrapperError from aws_advanced_python_wrapper.hostinfo import HostInfo -from aws_advanced_python_wrapper.iam_plugin import IamAuthPlugin, TokenInfo +from aws_advanced_python_wrapper.iam_plugin import IamAuthPlugin +from aws_advanced_python_wrapper.utils import services_container +from aws_advanced_python_wrapper.utils.iam_utils import TokenInfo from aws_advanced_python_wrapper.utils.properties 
import (Properties, WrapperProperties) @@ -41,12 +41,12 @@ _PG_HOST_INFO_WITH_PORT = HostInfo("pg.testdb.us-east-2.rds.amazonaws.com", port=1234) _PG_HOST_INFO_WITH_REGION = HostInfo("pg.testdb.us-west-1.rds.amazonaws.com") -_token_cache: Dict[str, TokenInfo] = {} - @pytest.fixture(autouse=True) def clear_caches(): - _token_cache.clear() + from datetime import timedelta + services_container.get_storage_service().register(TokenInfo, item_expiration_time=timedelta(minutes=15)) + services_container.get_storage_service().clear(TokenInfo) AwsCredentialsManager.release_resources() @@ -99,11 +99,10 @@ def pg_properties(): return Properties({"user": "postgresqlUser"}) -@patch("aws_advanced_python_wrapper.iam_plugin.IamAuthPlugin._token_cache", _token_cache) def test_pg_connect_valid_token_in_cache(mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect): test_props: Properties = Properties({"user": "postgresqlUser"}) initial_token = TokenInfo(_TEST_TOKEN, datetime.now() + timedelta(minutes=5)) - _token_cache[_PG_CACHE_KEY] = initial_token + services_container.get_storage_service().put(TokenInfo, _PG_CACHE_KEY, initial_token) target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service) target_plugin.connect( @@ -116,13 +115,12 @@ def test_pg_connect_valid_token_in_cache(mocker, mock_plugin_service, mock_sessi mock_client.generate_db_auth_token.assert_not_called() - actual_token = _token_cache.get(_PG_CACHE_KEY) + actual_token = services_container.get_storage_service().get(TokenInfo, _PG_CACHE_KEY) assert _GENERATED_TOKEN != actual_token.token assert _TEST_TOKEN == actual_token.token assert actual_token.is_expired() is False -@patch("aws_advanced_python_wrapper.iam_plugin.IamAuthPlugin._token_cache", _token_cache) def test_pg_connect_with_invalid_port_fall_backs_to_host_port( mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect): test_props: Properties = Properties({"user": "postgresqlUser"}) @@ -147,7 +145,7 @@ def 
test_pg_connect_with_invalid_port_fall_backs_to_host_port( DBUsername="postgresqlUser" ) - actual_token = _token_cache.get("us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:1234:postgresqlUser") + actual_token = services_container.get_storage_service().get(TokenInfo, "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:1234:postgresqlUser") assert _GENERATED_TOKEN == actual_token.token assert actual_token.is_expired() is False @@ -156,7 +154,6 @@ def test_pg_connect_with_invalid_port_fall_backs_to_host_port( mock_dialect.set_password.assert_called_with(expected_props, _GENERATED_TOKEN) -@patch("aws_advanced_python_wrapper.iam_plugin.IamAuthPlugin._token_cache", _token_cache) def test_pg_connect_with_invalid_port_and_no_host_port_fall_backs_to_host_port( mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect): test_props: Properties = Properties({"user": "postgresqlUser"}) @@ -182,8 +179,8 @@ def test_pg_connect_with_invalid_port_and_no_host_port_fall_backs_to_host_port( DBUsername="postgresqlUser" ) - actual_token = _token_cache.get( - f"us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:{expected_default_pg_port}:postgresqlUser") + actual_token = services_container.get_storage_service().get( + TokenInfo, f"us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:{expected_default_pg_port}:postgresqlUser") assert _GENERATED_TOKEN == actual_token.token assert actual_token.is_expired() is False @@ -192,11 +189,10 @@ def test_pg_connect_with_invalid_port_and_no_host_port_fall_backs_to_host_port( mock_dialect.set_password.assert_called_with(expected_props, _GENERATED_TOKEN) -@patch("aws_advanced_python_wrapper.iam_plugin.IamAuthPlugin._token_cache", _token_cache) def test_connect_expired_token_in_cache(mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect): test_props: Properties = Properties({"user": "postgresqlUser"}) initial_token = TokenInfo(_TEST_TOKEN, datetime.now() - timedelta(minutes=5)) - _token_cache[_PG_CACHE_KEY] = 
initial_token + services_container.get_storage_service().put(TokenInfo, _PG_CACHE_KEY, initial_token) mock_func.side_effect = Exception("generic exception") target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service) @@ -215,13 +211,12 @@ def test_connect_expired_token_in_cache(mocker, mock_plugin_service, mock_sessio DBUsername="postgresqlUser" ) - actual_token = _token_cache.get(_PG_CACHE_KEY) + actual_token = services_container.get_storage_service().get(TokenInfo, _PG_CACHE_KEY) assert initial_token != actual_token assert _GENERATED_TOKEN == actual_token.token assert actual_token.is_expired() is False -@patch("aws_advanced_python_wrapper.iam_plugin.IamAuthPlugin._token_cache", _token_cache) def test_connect_empty_cache(mocker, mock_plugin_service, mock_connection, mock_session, mock_func, mock_client, mock_dialect): test_props: Properties = Properties({"user": "postgresqlUser"}) target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service) @@ -239,18 +234,17 @@ def test_connect_empty_cache(mocker, mock_plugin_service, mock_connection, mock_ DBUsername="postgresqlUser" ) - actual_token = _token_cache.get(_PG_CACHE_KEY) + actual_token = services_container.get_storage_service().get(TokenInfo, _PG_CACHE_KEY) assert mock_connection == actual_connection assert _GENERATED_TOKEN == actual_token.token assert actual_token.is_expired() is False -@patch("aws_advanced_python_wrapper.iam_plugin.IamAuthPlugin._token_cache", _token_cache) def test_connect_with_specified_port(mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect): test_props: Properties = Properties({"user": "postgresqlUser"}) cache_key_with_new_port: str = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:1234:postgresqlUser" initial_token = TokenInfo(f"{_TEST_TOKEN}:1234", datetime.now() + timedelta(minutes=5)) - _token_cache[cache_key_with_new_port] = initial_token + services_container.get_storage_service().put(TokenInfo, cache_key_with_new_port, initial_token) # Assert no 
password has been set assert test_props.get("password") is None @@ -266,8 +260,8 @@ def test_connect_with_specified_port(mocker, mock_plugin_service, mock_session, mock_client.generate_db_auth_token.assert_not_called() - actual_token = _token_cache.get(cache_key_with_new_port) - assert _token_cache.get(_PG_CACHE_KEY) is None + actual_token = services_container.get_storage_service().get(TokenInfo, cache_key_with_new_port) + assert services_container.get_storage_service().get(TokenInfo, _PG_CACHE_KEY) is None assert _GENERATED_TOKEN != actual_token.token assert f"{_TEST_TOKEN}:1234" == actual_token.token assert actual_token.is_expired() is False @@ -277,14 +271,13 @@ def test_connect_with_specified_port(mocker, mock_plugin_service, mock_session, mock_dialect.set_password.assert_called_with(expected_props, f"{_TEST_TOKEN}:1234") -@patch("aws_advanced_python_wrapper.iam_plugin.IamAuthPlugin._token_cache", _token_cache) def test_connect_with_specified_iam_default_port(mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect): test_props: Properties = Properties({"user": "postgresqlUser"}) iam_default_port: str = "9999" test_props[WrapperProperties.IAM_DEFAULT_PORT.name] = iam_default_port cache_key_with_new_port = f"us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:{iam_default_port}:postgresqlUser" initial_token = TokenInfo(f"{_TEST_TOKEN}:{iam_default_port}", datetime.now() + timedelta(minutes=5)) - _token_cache[cache_key_with_new_port] = initial_token + services_container.get_storage_service().put(TokenInfo, cache_key_with_new_port, initial_token) # Assert no password has been set assert test_props.get("password") is None @@ -300,8 +293,8 @@ def test_connect_with_specified_iam_default_port(mocker, mock_plugin_service, mo mock_client.generate_db_auth_token.assert_not_called() - actual_token = _token_cache.get(cache_key_with_new_port) - assert _token_cache.get(_PG_CACHE_KEY) is None + actual_token = 
services_container.get_storage_service().get(TokenInfo, cache_key_with_new_port) + assert services_container.get_storage_service().get(TokenInfo, _PG_CACHE_KEY) is None assert _GENERATED_TOKEN != actual_token.token assert f"{_TEST_TOKEN}:{iam_default_port}" == actual_token.token assert actual_token.is_expired() is False @@ -311,7 +304,6 @@ def test_connect_with_specified_iam_default_port(mocker, mock_plugin_service, mo mock_dialect.set_password.assert_called_with(expected_props, f"{_TEST_TOKEN}:{iam_default_port}") -@patch("aws_advanced_python_wrapper.iam_plugin.IamAuthPlugin._token_cache", _token_cache) def test_connect_with_specified_region(mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect): test_props: Properties = Properties({"user": "postgresqlUser"}) iam_region: str = "us-east-1" @@ -319,7 +311,7 @@ def test_connect_with_specified_region(mocker, mock_plugin_service, mock_session # Cache a token with a different region cache_key_with_region = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:5432:postgresqlUser" initial_token = TokenInfo("us-east-2", datetime.now() + timedelta(minutes=5)) - _token_cache[cache_key_with_region] = initial_token + services_container.get_storage_service().put(TokenInfo, cache_key_with_region, initial_token) test_props[WrapperProperties.IAM_REGION.name] = iam_region @@ -343,7 +335,7 @@ def test_connect_with_specified_region(mocker, mock_plugin_service, mock_session DBUsername="postgresqlUser" ) - actual_token = _token_cache.get("us-east-1:pg.testdb.us-east-2.rds.amazonaws.com:5432:postgresqlUser") + actual_token = services_container.get_storage_service().get(TokenInfo, "us-east-1:pg.testdb.us-east-2.rds.amazonaws.com:5432:postgresqlUser") assert f"{_TEST_TOKEN}:{iam_region}" == actual_token.token assert actual_token.is_expired() is False @@ -360,7 +352,6 @@ def test_connect_with_specified_region(mocker, mock_plugin_service, mock_session 
pytest.param("test-.proxy-123456789012.us-east-2.rds.amazonaws.com.cn"), pytest.param("test-.proxy-123456789012.us-east-2.rds.amazonaws.com"), ]) -@patch("aws_advanced_python_wrapper.iam_plugin.IamAuthPlugin._token_cache", _token_cache) def test_connect_with_specified_host(iam_host: str, mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect): test_props: Properties = Properties({"user": "postgresqlUser"}) @@ -385,7 +376,7 @@ def test_connect_with_specified_host(iam_host: str, mocker, mock_plugin_service, DBUsername="postgresqlUser" ) - actual_token = _token_cache.get(f"us-east-2:{iam_host}:5432:postgresqlUser") + actual_token = services_container.get_storage_service().get(TokenInfo, f"us-east-2:{iam_host}:5432:postgresqlUser") assert actual_token is not None assert _GENERATED_TOKEN != actual_token.token assert f"{_TEST_TOKEN}:{iam_host}" == actual_token.token diff --git a/tests/unit/test_limitless_router_service.py b/tests/unit/test_limitless_router_service.py index 9838565c8..295c36bc7 100644 --- a/tests/unit/test_limitless_router_service.py +++ b/tests/unit/test_limitless_router_service.py @@ -16,22 +16,17 @@ from aws_advanced_python_wrapper.host_availability import HostAvailability from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole from aws_advanced_python_wrapper.limitless_plugin import ( - LimitlessContext, LimitlessPlugin, LimitlessRouterService) + LimitlessContext, LimitlessPlugin, LimitlessRouters, + LimitlessRouterService) +from aws_advanced_python_wrapper.utils import services_container from aws_advanced_python_wrapper.utils.messages import Messages from aws_advanced_python_wrapper.utils.properties import (Properties, WrapperProperties) -from aws_advanced_python_wrapper.utils.sliding_expiration_cache_container import \ - SlidingExpirationCacheContainer CLUSTER_ID: str = "some_cluster_id" EXPIRATION_NANO_SECONDS: int = 60 * 60 * 1_000_000_000 -def get_router_cache(): - """Helper to get the limitless router cache 
from the container.""" - return SlidingExpirationCacheContainer.get_or_create_cache("limitless_router_cache") - - @pytest.fixture def writer_host(): return HostInfo("instance-0", 5432, HostRole.WRITER, HostAvailability.AVAILABLE) @@ -97,6 +92,7 @@ def mock_plugin_service(mocker, mock_driver_dialect, mock_conn, host_info, defau service_mock.hosts = default_hosts service_mock.host_list_provider = mocker.MagicMock() service_mock.host_list_provider.get_cluster_id.return_value = CLUSTER_ID + service_mock.props = Properties({}) type(service_mock).driver_dialect = mocker.PropertyMock(return_value=mock_driver_dialect) return service_mock @@ -137,14 +133,9 @@ def plugin(mock_plugin_service, props, mock_limitless_router_service): @pytest.fixture(autouse=True) -def run_before_and_after_tests(mock_limitless_router_service): - # Before - +def run_before_and_after_tests(): yield - - # After - # Clear the cache through the container - get_router_cache().clear() + services_container.get_storage_service().clear(LimitlessRouters) def test_establish_connection_empty_routers_list_then_wait_for_router_info_then_raises_exception(mocker, @@ -209,8 +200,10 @@ def test_establish_connection_host_info_in_router_cache_then_call_connection_fun props, mock_plugin_service, limitless_routers): - get_router_cache().compute_if_absent(CLUSTER_ID, lambda _: limitless_routers, - EXPIRATION_NANO_SECONDS) + limitless_router_service: LimitlessRouterService = LimitlessRouterService(mock_plugin_service, + mock_limitless_query_helper) + services_container.get_storage_service().put( + LimitlessRouters, CLUSTER_ID, LimitlessRouters(limitless_routers), EXPIRATION_NANO_SECONDS) mock_connect_func = mocker.MagicMock() mock_connect_func.return_value = mock_conn @@ -224,8 +217,6 @@ def test_establish_connection_host_info_in_router_cache_then_call_connection_fun mock_plugin_service ) - limitless_router_service: LimitlessRouterService = LimitlessRouterService(mock_plugin_service, - mock_limitless_query_helper) 
limitless_router_service.establish_connection(input_context) assert mock_conn == input_context.get_connection() @@ -258,7 +249,7 @@ def test_establish_connection_fetch_router_list_and_host_info_in_router_list_the limitless_router_service.establish_connection(input_context) assert mock_conn == input_context.get_connection() - assert limitless_routers == get_router_cache().get(CLUSTER_ID) + assert limitless_routers == services_container.get_storage_service().get(LimitlessRouters, CLUSTER_ID).hosts mock_limitless_query_helper.query_for_limitless_routers.assert_called_once() mock_connect_func.assert_called_once() @@ -272,8 +263,10 @@ def test_establish_connection_router_cache_then_select_host(mocker, plugin, limitless_router1, limitless_routers): - get_router_cache().compute_if_absent(CLUSTER_ID, lambda _: limitless_routers, - EXPIRATION_NANO_SECONDS) + limitless_router_service: LimitlessRouterService = LimitlessRouterService(mock_plugin_service, + mock_limitless_query_helper) + services_container.get_storage_service().put( + LimitlessRouters, CLUSTER_ID, LimitlessRouters(limitless_routers), EXPIRATION_NANO_SECONDS) mock_plugin_service.get_host_info_by_strategy.return_value = limitless_router1 mock_plugin_service.connect.return_value = mock_conn @@ -289,12 +282,10 @@ def test_establish_connection_router_cache_then_select_host(mocker, plugin ) - limitless_router_service: LimitlessRouterService = LimitlessRouterService(mock_plugin_service, - mock_limitless_query_helper) limitless_router_service.establish_connection(input_context) assert mock_conn == input_context.get_connection() - assert limitless_routers == get_router_cache().get(CLUSTER_ID) + assert limitless_routers == services_container.get_storage_service().get(LimitlessRouters, CLUSTER_ID).hosts mock_plugin_service.get_host_info_by_strategy.assert_called_once() mock_plugin_service.get_host_info_by_strategy.assert_called_with(HostRole.WRITER, "weighted_random", limitless_routers) @@ -333,7 +324,7 @@ def 
test_establish_connection_fetch_router_list_then_select_host(mocker, limitless_router_service.establish_connection(input_context) assert mock_conn == input_context.get_connection() - assert limitless_routers == get_router_cache().get(CLUSTER_ID) + assert limitless_routers == services_container.get_storage_service().get(LimitlessRouters, CLUSTER_ID).hosts mock_limitless_query_helper.query_for_limitless_routers.assert_called_once() mock_plugin_service.get_host_info_by_strategy.assert_called_once() mock_plugin_service.get_host_info_by_strategy.assert_called_with(HostRole.WRITER, "weighted_random", @@ -352,8 +343,10 @@ def test_establish_connection_host_info_in_router_cache_can_call_connection_func plugin, limitless_router1, limitless_routers): - get_router_cache().compute_if_absent(CLUSTER_ID, lambda _: limitless_routers, - EXPIRATION_NANO_SECONDS) + limitless_router_service: LimitlessRouterService = LimitlessRouterService(mock_plugin_service, + mock_limitless_query_helper) + services_container.get_storage_service().put( + LimitlessRouters, CLUSTER_ID, LimitlessRouters(limitless_routers), EXPIRATION_NANO_SECONDS) mock_plugin_service.get_host_info_by_strategy.return_value = limitless_router1 mock_plugin_service.connect.return_value = mock_conn @@ -369,12 +362,10 @@ def test_establish_connection_host_info_in_router_cache_can_call_connection_func plugin ) - limitless_router_service: LimitlessRouterService = LimitlessRouterService(mock_plugin_service, - mock_limitless_query_helper) limitless_router_service.establish_connection(input_context) assert mock_conn == input_context.get_connection() - assert limitless_routers == get_router_cache().get(CLUSTER_ID) + assert limitless_routers == services_container.get_storage_service().get(LimitlessRouters, CLUSTER_ID).hosts mock_plugin_service.get_host_info_by_strategy.assert_called_once() mock_plugin_service.get_host_info_by_strategy.assert_called_with(HostRole.WRITER, "highest_weight", limitless_routers) @@ -392,8 +383,10 @@ def 
test_establish_connection_selected_host_raises_exception_and_retries(mocker, plugin, limitless_router1, limitless_routers): - get_router_cache().compute_if_absent(CLUSTER_ID, lambda _: limitless_routers, - EXPIRATION_NANO_SECONDS) + limitless_router_service: LimitlessRouterService = LimitlessRouterService(mock_plugin_service, + mock_limitless_query_helper) + services_container.get_storage_service().put( + LimitlessRouters, CLUSTER_ID, LimitlessRouters(limitless_routers), EXPIRATION_NANO_SECONDS) mock_plugin_service.get_host_info_by_strategy.side_effect = [ Exception(), limitless_router1 @@ -412,12 +405,10 @@ def test_establish_connection_selected_host_raises_exception_and_retries(mocker, plugin ) - limitless_router_service: LimitlessRouterService = LimitlessRouterService(mock_plugin_service, - mock_limitless_query_helper) limitless_router_service.establish_connection(input_context) assert mock_conn == input_context.get_connection() - assert limitless_routers == get_router_cache().get(CLUSTER_ID) + assert limitless_routers == services_container.get_storage_service().get(LimitlessRouters, CLUSTER_ID).hosts assert mock_plugin_service.get_host_info_by_strategy.call_count == 2 mock_plugin_service.get_host_info_by_strategy.assert_called_with(HostRole.WRITER, "highest_weight", limitless_routers) @@ -436,8 +427,10 @@ def test_establish_connection_selected_host_none_then_retry(mocker, plugin, limitless_router1, limitless_routers): - get_router_cache().compute_if_absent(CLUSTER_ID, lambda _: limitless_routers, - EXPIRATION_NANO_SECONDS) + limitless_router_service: LimitlessRouterService = LimitlessRouterService(mock_plugin_service, + mock_limitless_query_helper) + services_container.get_storage_service().put( + LimitlessRouters, CLUSTER_ID, LimitlessRouters(limitless_routers), EXPIRATION_NANO_SECONDS) mock_plugin_service.get_host_info_by_strategy.side_effect = [ None, limitless_router1 @@ -456,12 +449,10 @@ def test_establish_connection_selected_host_none_then_retry(mocker, 
plugin ) - limitless_router_service: LimitlessRouterService = LimitlessRouterService(mock_plugin_service, - mock_limitless_query_helper) limitless_router_service.establish_connection(input_context) assert mock_conn == input_context.get_connection() - assert limitless_routers == get_router_cache().get(CLUSTER_ID) + assert limitless_routers == services_container.get_storage_service().get(LimitlessRouters, CLUSTER_ID).hosts assert mock_plugin_service.get_host_info_by_strategy.call_count == 2 mock_plugin_service.get_host_info_by_strategy.assert_called_with(HostRole.WRITER, "highest_weight", limitless_routers) @@ -481,8 +472,10 @@ def test_establish_connection_plugin_service_connect_raises_exception_then_retry limitless_router1, limitless_router2, limitless_routers): - get_router_cache().compute_if_absent(CLUSTER_ID, lambda _: limitless_routers, - EXPIRATION_NANO_SECONDS) + limitless_router_service: LimitlessRouterService = LimitlessRouterService(mock_plugin_service, + mock_limitless_query_helper) + services_container.get_storage_service().put( + LimitlessRouters, CLUSTER_ID, LimitlessRouters(limitless_routers), EXPIRATION_NANO_SECONDS) mock_plugin_service.get_host_info_by_strategy.side_effect = [ limitless_router1, limitless_router2 @@ -504,12 +497,10 @@ def test_establish_connection_plugin_service_connect_raises_exception_then_retry plugin ) - limitless_router_service: LimitlessRouterService = LimitlessRouterService(mock_plugin_service, - mock_limitless_query_helper) limitless_router_service.establish_connection(input_context) assert mock_conn == input_context.get_connection() - assert limitless_routers == get_router_cache().get(CLUSTER_ID) + assert limitless_routers == services_container.get_storage_service().get(LimitlessRouters, CLUSTER_ID).hosts assert mock_plugin_service.get_host_info_by_strategy.call_count == 2 mock_plugin_service.get_host_info_by_strategy.assert_called_with(HostRole.WRITER, "highest_weight", limitless_routers) @@ -528,8 +519,10 @@ def 
test_establish_connection_retry_and_max_retries_exceeded_then_raise_exceptio plugin, limitless_router1, limitless_routers): - get_router_cache().compute_if_absent(CLUSTER_ID, lambda _: limitless_routers, - EXPIRATION_NANO_SECONDS) + limitless_router_service: LimitlessRouterService = LimitlessRouterService(mock_plugin_service, + mock_limitless_query_helper) + services_container.get_storage_service().put( + LimitlessRouters, CLUSTER_ID, LimitlessRouters(limitless_routers), EXPIRATION_NANO_SECONDS) mock_plugin_service.get_host_info_by_strategy.return_value = limitless_router1 mock_plugin_service.connect.side_effect = Exception() @@ -545,8 +538,6 @@ def test_establish_connection_retry_and_max_retries_exceeded_then_raise_exceptio plugin ) - limitless_router_service: LimitlessRouterService = LimitlessRouterService(mock_plugin_service, - mock_limitless_query_helper) with pytest.raises(Exception) as e_info: limitless_router_service.establish_connection(input_context) diff --git a/tests/unit/test_monitor.py b/tests/unit/test_monitor.py index 9eca560ec..2504ad107 100644 --- a/tests/unit/test_monitor.py +++ b/tests/unit/test_monitor.py @@ -20,7 +20,7 @@ from aws_advanced_python_wrapper import release_resources from aws_advanced_python_wrapper.host_monitoring_plugin import ( - Monitor, MonitoringContext, MonitoringThreadContainer) + Monitor, MonitoringContext) from aws_advanced_python_wrapper.hostinfo import HostInfo from aws_advanced_python_wrapper.utils.properties import (Properties, WrapperProperties) @@ -73,18 +73,18 @@ def mock_plugin_service(mocker, mock_conn, mock_driver_dialect): @pytest.fixture def monitor(mock_plugin_service, host_info, props): - return Monitor( - mock_plugin_service, - host_info, - props, - MonitoringThreadContainer()) + m = Monitor(mock_plugin_service, host_info, props) + # Stop the auto-started thread so tests can control execution + m.stop() + m._thread.join(timeout=2) + m._is_stopped.clear() + return m @pytest.fixture(autouse=True) -def 
release_container(): +def cleanup(): yield - while MonitoringThreadContainer._instance is not None: - release_resources() + release_resources() @pytest.fixture @@ -137,10 +137,6 @@ def test_run_host_available( mock_driver_dialect, mock_aborted_connection_counter): remove_delays() - host_alias = "host-1" - container = MonitoringThreadContainer() - container._monitor_map.put_if_absent(host_alias, monitor) - container._tasks_map.put_if_absent(monitor, mocker.MagicMock()) executor = ThreadPoolExecutor() context = MonitoringContext(monitor, mock_conn, mock_driver_dialect, @@ -158,11 +154,9 @@ def test_run_host_available( mock_conn.close.assert_called_once() assert context._is_host_unavailable is False assert monitor._is_stopped.is_set() - assert container._monitor_map.get(host_alias) is None - assert container._tasks_map.get(monitor) is None -def test_ensure_stopped_monitor_removed_from_map( +def test_ensure_stopped_monitor_exits( mocker, monitor, host_info, @@ -172,10 +166,6 @@ def test_ensure_stopped_monitor_removed_from_map( mock_driver_dialect, mock_aborted_connection_counter): remove_delays() - host_alias = "host-1" - container = MonitoringThreadContainer() - container._monitor_map.put_if_absent(host_alias, monitor) - container._tasks_map.put_if_absent(monitor, mocker.MagicMock()) executor = ThreadPoolExecutor() context = MonitoringContext(monitor, mock_conn, mock_driver_dialect, @@ -189,8 +179,7 @@ def test_ensure_stopped_monitor_removed_from_map( sleep(0.1) # Allow some time for the monitor to loop wait([future], 3) - assert container._monitor_map.get(host_alias) is None - assert container._tasks_map.get(monitor) is None + assert monitor._is_stopped.is_set() def test_run_host_unavailable( @@ -221,17 +210,9 @@ def test_run_host_unavailable( def test_run__no_contexts(mocker, monitor): - host_alias = "host-1" - container = MonitoringThreadContainer() - container._monitor_map.put_if_absent(host_alias, monitor) - container._tasks_map.put_if_absent(monitor, 
mocker.MagicMock()) - # Monitor should exit because there are no contexts monitor.run() - - assert container._monitor_map.get(host_alias) is None - assert container._tasks_map.get(monitor) is None - release_resources() + assert monitor._is_stopped.is_set() def test_check_connection_status__valid_then_invalid(mocker, monitor): @@ -258,6 +239,19 @@ def test_check_connection_status__conn_check_throws_exception(mocker, monitor): assert not status.is_available +def test_can_dispose(monitor): + assert monitor.can_dispose is True + monitor._active_contexts.put("ctx") + assert monitor.can_dispose is False + + +def test_stop(monitor, mock_conn): + monitor._monitoring_conn = mock_conn + monitor.stop() + assert monitor._is_stopped.is_set() + mock_conn.close.assert_called_once() + + def remove_delays(): Monitor._INACTIVE_SLEEP_MS = 0 Monitor._MIN_HOST_CHECK_TIMEOUT_MS = 0 diff --git a/tests/unit/test_monitor_service.py b/tests/unit/test_monitor_service.py index b06209660..b570eb055 100644 --- a/tests/unit/test_monitor_service.py +++ b/tests/unit/test_monitor_service.py @@ -18,8 +18,8 @@ from aws_advanced_python_wrapper import release_resources from aws_advanced_python_wrapper.errors import AwsWrapperError -from aws_advanced_python_wrapper.host_monitoring_plugin import ( - MonitoringThreadContainer, MonitorService) +from aws_advanced_python_wrapper.host_monitoring_plugin import \ + HostMonitorService from aws_advanced_python_wrapper.hostinfo import HostInfo from aws_advanced_python_wrapper.utils.properties import Properties @@ -34,11 +34,6 @@ def mock_plugin_service(mocker): return mocker.MagicMock() -@pytest.fixture -def mock_thread_container(mocker): - return mocker.MagicMock() - - @pytest.fixture def mock_monitor(mocker): monitor = mocker.MagicMock() @@ -47,120 +42,65 @@ def mock_monitor(mocker): @pytest.fixture -def thread_container(): - return MonitoringThreadContainer() - - -@pytest.fixture -def monitor_service_mocked_container(mock_plugin_service, mock_thread_container): 
- service = MonitorService(mock_plugin_service) - service._monitor_container = mock_thread_container - return service - - -@pytest.fixture -def monitor_service_with_container(mock_plugin_service, thread_container): - service = MonitorService(mock_plugin_service) - service._monitor_container = thread_container +def monitor_service(mock_plugin_service, mocker, mock_monitor): + mocker.patch( + "aws_advanced_python_wrapper.host_monitoring_plugin.Monitor.__init__", return_value=None) + service = HostMonitorService(mock_plugin_service) return service @pytest.fixture(autouse=True) -def setup_teardown(mocker, mock_thread_container, mock_plugin_service, mock_monitor): - mock_thread_container.get_or_create_monitor.return_value = mock_monitor - mocker.patch( - "aws_advanced_python_wrapper.host_monitoring_plugin.MonitorService._create_monitor", return_value=mock_monitor) - +def cleanup(): yield + release_resources() - while MonitoringThreadContainer._instance is not None: - release_resources() - -def test_start_monitoring( - monitor_service_mocked_container, - mock_plugin_service, - mock_monitor, - mock_conn, - mock_thread_container): +def test_start_monitoring(mocker, monitor_service, mock_plugin_service, mock_monitor, mock_conn): aliases = frozenset({"instance-1"}) - monitor_service_mocked_container.start_monitoring( + mocker.patch.object(monitor_service._monitor_service, 'run_if_absent_with_aliases', return_value=mock_monitor) + + monitor_service.start_monitoring( mock_conn, aliases, HostInfo("instance-1"), Properties(), 5000, 1000, 3) mock_monitor.start_monitoring.assert_called_once() - assert mock_monitor == monitor_service_mocked_container._cached_monitor() - assert aliases == monitor_service_mocked_container._cached_monitor_aliases - - -def test_start_monitoring__multiple_calls(monitor_service_with_container, mock_monitor, mock_conn, mocker): - aliases = frozenset({"instance-1"}) - - # Mock the _thread_pool directly on the container instance since it's now cached in 
__init__ - mock_thread_pool = mocker.MagicMock() - monitor_service_with_container._monitor_container._thread_pool = mock_thread_pool - - num_calls = 5 - for _ in range(num_calls): - monitor_service_with_container.start_monitoring( - mock_conn, aliases, HostInfo("instance-1"), Properties(), 5000, 1000, 3) - - assert num_calls == mock_monitor.start_monitoring.call_count - mock_thread_pool.submit.assert_called_once_with(mock_monitor.run) - assert mock_monitor == monitor_service_with_container._cached_monitor() - assert aliases == monitor_service_with_container._cached_monitor_aliases + assert mock_monitor == monitor_service._cached_monitor() + assert aliases == monitor_service._cached_monitor_aliases def test_start_monitoring__cached_monitor( - monitor_service_mocked_container, mock_plugin_service, mock_monitor, mock_conn, mock_thread_container): + mocker, monitor_service, mock_plugin_service, mock_monitor, mock_conn): aliases = frozenset({"instance-1"}) - monitor_service_mocked_container._cached_monitor = ref(mock_monitor) - monitor_service_mocked_container._cached_monitor_aliases = aliases + monitor_service._cached_monitor = ref(mock_monitor) + monitor_service._cached_monitor_aliases = aliases - monitor_service_mocked_container.start_monitoring( + monitor_service.start_monitoring( mock_conn, aliases, HostInfo("instance-1"), Properties(), 5000, 1000, 3) - mock_plugin_service.get_dialect.assert_not_called() - mock_thread_container.get_or_create_monitor.assert_not_called() mock_monitor.start_monitoring.assert_called_once() - assert mock_monitor == monitor_service_mocked_container._cached_monitor() - assert aliases == monitor_service_mocked_container._cached_monitor_aliases + assert mock_monitor == monitor_service._cached_monitor() + assert aliases == monitor_service._cached_monitor_aliases -def test_start_monitoring__errors(monitor_service_mocked_container, mock_conn, mock_plugin_service): +def test_start_monitoring__errors(monitor_service, mock_conn): with 
pytest.raises(AwsWrapperError): - monitor_service_mocked_container.start_monitoring( + monitor_service.start_monitoring( mock_conn, frozenset(), HostInfo("instance-1"), Properties(), 5000, 1000, 3) -def test_stop_monitoring(monitor_service_with_container, mock_monitor, mock_conn): +def test_stop_monitoring(mocker, monitor_service, mock_monitor, mock_conn): aliases = frozenset({"instance-1"}) - context = monitor_service_with_container.start_monitoring( - mock_conn, aliases, HostInfo("instance-1"), Properties(), 5000, 1000, 3) - monitor_service_with_container.stop_monitoring(context) - mock_monitor.stop_monitoring.assert_called_once_with(context) - + mocker.patch.object(monitor_service._monitor_service, 'run_if_absent_with_aliases', return_value=mock_monitor) -def test_stop_monitoring__multiple_calls(monitor_service_with_container, mock_monitor, mock_conn): - aliases = frozenset({"instance-1"}) - context = monitor_service_with_container.start_monitoring( - mock_conn, aliases, HostInfo("instance-1"), Properties(), 5000, 1000, 3) - monitor_service_with_container.stop_monitoring(context) + context = monitor_service.start_monitoring( + mock_conn, aliases, HostInfo("instance-1"), Properties(), 5000, 1000, 3) + monitor_service.stop_monitoring(context) mock_monitor.stop_monitoring.assert_called_once_with(context) - monitor_service_with_container.stop_monitoring(context) - assert 2 == mock_monitor.stop_monitoring.call_count - -def test_stop_monitoring_host_connections(mocker, monitor_service_with_container, thread_container): - aliases1 = frozenset({"alias-1"}) - aliases2 = frozenset({"alias-2"}) - mock_monitor1 = mocker.MagicMock() - mock_monitor2 = mocker.MagicMock() - thread_container.get_or_create_monitor(aliases1, lambda: mock_monitor1) - thread_container.get_or_create_monitor(aliases2, lambda: mock_monitor2) - monitor_service_with_container.stop_monitoring_host(aliases1) - mock_monitor1.clear_contexts.assert_called_once() +def test_stop_monitoring_host(mocker, 
monitor_service, mock_monitor): + aliases = frozenset({"alias-1"}) + mocker.patch.object(monitor_service._monitor_service, 'get', return_value=mock_monitor) - monitor_service_with_container.stop_monitoring_host(aliases2) - mock_monitor2.clear_contexts.assert_called_once() + monitor_service.stop_monitoring_host(aliases) + mock_monitor.clear_contexts.assert_called_once() diff --git a/tests/unit/test_monitoring_thread_container.py b/tests/unit/test_monitoring_thread_container.py deleted file mode 100644 index 1a9469d14..000000000 --- a/tests/unit/test_monitoring_thread_container.py +++ /dev/null @@ -1,177 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import pytest - -from aws_advanced_python_wrapper import release_resources -from aws_advanced_python_wrapper.errors import AwsWrapperError -from aws_advanced_python_wrapper.host_monitoring_plugin import \ - MonitoringThreadContainer - - -@pytest.fixture -def container(): - return MonitoringThreadContainer() - - -@pytest.fixture -def mock_future(mocker): - return mocker.MagicMock() - - -@pytest.fixture -def mock_monitor1(mocker): - monitor = mocker.MagicMock() - monitor.is_stopped = False - return monitor - - -@pytest.fixture -def mock_monitor2(mocker): - monitor = mocker.MagicMock() - monitor.is_stopped = False - return monitor - - -@pytest.fixture -def mock_stopped_monitor(mocker): - monitor = mocker.MagicMock() - monitor.is_stopped = True - return monitor - - -@pytest.fixture -def mock_monitor_supplier(mocker, mock_monitor1, mock_monitor2): - supplier = mocker.MagicMock() - supplier.side_effect = [mock_monitor1, mock_monitor2] - return supplier - - -@pytest.fixture(autouse=True) -def release_container(): - yield - while MonitoringThreadContainer._instance is not None: - release_resources() - - -def test_get_or_create_monitor__monitor_created( - container, mock_monitor_supplier, mock_stopped_monitor, mock_monitor1, mock_future, mocker): - mock_thread_pool = mocker.MagicMock() - mock_thread_pool.submit.return_value = mock_future - # Mock the _thread_pool directly on the container instance since it's now cached in __init__ - container._thread_pool = mock_thread_pool - - result = container.get_or_create_monitor(frozenset({"alias-1", "alias-2"}), mock_monitor_supplier) - assert mock_monitor1 == result - - mock_monitor_supplier.assert_called_once() - mock_thread_pool.submit.assert_called_once_with(mock_monitor1.run) - assert mock_monitor1 == container._monitor_map.get("alias-1") - assert mock_monitor1 == container._monitor_map.get("alias-2") - - -def test_get_or_create_monitor__from_monitor_map(container, mock_monitor1): - 
container._monitor_map.put_if_absent("alias-2", mock_monitor1) - - result = container.get_or_create_monitor(frozenset({"alias-1", "alias-2"}), mock_monitor_supplier) - assert mock_monitor1 == result - assert mock_monitor1 == container._monitor_map.get("alias-1") - - -def test_get_or_create_monitor__shared_aliases(container, mock_monitor_supplier, mock_monitor1): - host_aliases1 = frozenset({"host-1", "host-2"}) - host_aliases2 = frozenset({"host-2"}) - - aliases1_monitor = container.get_or_create_monitor(host_aliases1, mock_monitor_supplier) - aliases2_monitor = container.get_or_create_monitor(host_aliases2, mock_monitor_supplier) - assert mock_monitor1 == aliases1_monitor - assert aliases1_monitor == aliases2_monitor - mock_monitor_supplier.assert_called_once() - - -def test_get_or_create_monitor__separate_aliases(container, mock_monitor_supplier, mock_monitor1, mock_monitor2): - host_aliases1 = frozenset({"host-1"}) - host_aliases2 = frozenset({"host-2"}) - - aliases1_monitor = container.get_or_create_monitor(host_aliases1, mock_monitor_supplier) - aliases1_monitor_second_call = container.get_or_create_monitor(host_aliases1, mock_monitor_supplier) - assert mock_monitor1 == aliases1_monitor - assert aliases1_monitor == aliases1_monitor_second_call - mock_monitor_supplier.assert_called_once() - - aliases2_monitor = container.get_or_create_monitor(host_aliases2, mock_monitor_supplier) - assert mock_monitor2 == aliases2_monitor - assert aliases2_monitor != aliases1_monitor - - -def test_get_or_create_monitor__aliases_intersection(container, mock_monitor_supplier, mock_monitor1): - host_aliases1 = frozenset({"host-1"}) - host_aliases2 = frozenset({"host-1", "host-2"}) - host_aliases3 = frozenset({"host-2"}) - - aliases1_monitor = container.get_or_create_monitor(host_aliases1, mock_monitor_supplier) - aliases2_monitor = container.get_or_create_monitor(host_aliases2, mock_monitor_supplier) - aliases3_monitor = container.get_or_create_monitor(host_aliases3, 
mock_monitor_supplier) - - assert mock_monitor1 == aliases1_monitor - assert aliases1_monitor == aliases2_monitor - assert aliases3_monitor == aliases1_monitor - mock_monitor_supplier.assert_called_once() - - -def test_get_or_create_monitor__empty_aliases(container, mock_monitor_supplier): - with pytest.raises(AwsWrapperError): - container.get_or_create_monitor(frozenset(), mock_monitor_supplier) - - -def test_get_or_create_monitor__null_monitor(container, mock_monitor_supplier): - mock_monitor_supplier.side_effect = None - mock_monitor_supplier.return_value = None - with pytest.raises(AwsWrapperError): - container.get_or_create_monitor(frozenset({"alias-1"}), mock_monitor_supplier) - - -def test_release_monitor(mocker, mock_monitor1, mock_monitor2, container): - container._monitor_map.put_if_absent("alias-1", mock_monitor1) - container._monitor_map.put_if_absent("alias-2", mock_monitor2) - mock_future_1 = mocker.MagicMock() - mock_future_2 = mocker.MagicMock() - container._tasks_map.put_if_absent(mock_monitor1, mock_future_1) - container._tasks_map.put_if_absent(mock_monitor2, mock_future_2) - - container.release_monitor(mock_monitor2) - assert container._monitor_map.get("alias-1") - assert container._monitor_map.get("alias-2") is None - assert mock_future_1 == container._tasks_map.get(mock_monitor1) - assert container._tasks_map.get(mock_monitor2) is None - mock_future_1.cancel.assert_not_called() - mock_future_2.cancel.assert_called_once() - - -def test_release_instance(mocker, container, mock_monitor1, mock_future): - container._monitor_map.put_if_absent("alias-1", mock_monitor1) - container._tasks_map.put_if_absent(mock_monitor1, mock_future) - mock_future.done.return_value = False - mock_future.cancelled.return_value = False - - container2 = MonitoringThreadContainer() - assert container2 is container - - container2.clean_up() - - assert 0 == len(container._monitor_map) - assert 0 == len(container._tasks_map) - mock_future.cancel.assert_called_once() - assert 
MonitoringThreadContainer._instance is None - release_resources() diff --git a/tests/unit/test_multi_az_rds_host_list_provider.py b/tests/unit/test_multi_az_rds_host_list_provider.py index 4d06b87b3..bd0064d43 100644 --- a/tests/unit/test_multi_az_rds_host_list_provider.py +++ b/tests/unit/test_multi_az_rds_host_list_provider.py @@ -22,15 +22,15 @@ MultiAzTopologyUtils, RdsHostListProvider) from aws_advanced_python_wrapper.hostinfo import HostInfo from aws_advanced_python_wrapper.pep249 import ProgrammingError +from aws_advanced_python_wrapper.utils import services_container from aws_advanced_python_wrapper.utils.properties import (Properties, WrapperProperties) -from aws_advanced_python_wrapper.utils.storage.storage_service import ( - StorageService, Topology) +from aws_advanced_python_wrapper.utils.services_container import Topology @pytest.fixture(autouse=True) def clear_caches(): - StorageService.clear_all() + services_container.get_storage_service().clear_all() def mock_topology_query(mock_conn, mock_cursor, records, writer_id=None): @@ -98,7 +98,7 @@ def create_provider(mock_provider_service, props): def test_get_topology_caches_topology(mocker, mock_provider_service, mock_conn, props, cache_hosts, refresh_ns): provider = create_provider(mock_provider_service, props) provider._initialize() - StorageService.set(provider._cluster_id, cache_hosts, Topology) + services_container.get_storage_service().put(Topology, provider._cluster_id, cache_hosts) mock_force_refresh = mocker.patch.object(provider, '_force_refresh_monitor') result = provider.refresh(mock_conn) @@ -110,7 +110,7 @@ def test_get_topology_caches_topology(mocker, mock_provider_service, mock_conn, def test_get_topology_force_update( mocker, mock_provider_service, mock_conn, cache_hosts, queried_hosts, props, refresh_ns): provider = create_provider(mock_provider_service, props) - StorageService.set(provider._cluster_id, cache_hosts, Topology) + services_container.get_storage_service().put(Topology, 
provider._cluster_id, cache_hosts) mocker.patch.object(provider, '_force_refresh_monitor', return_value=queried_hosts) result = provider.force_refresh(mock_conn) @@ -131,7 +131,7 @@ def test_get_topology_invalid_topology( mocker, mock_provider_service, mock_conn, mock_cursor, props, cache_hosts, refresh_ns): provider = create_provider(mock_provider_service, props) provider._initialize() - StorageService.set(provider._cluster_id, cache_hosts, Topology) + services_container.get_storage_service().put(Topology, provider._cluster_id, cache_hosts) mocker.patch.object(provider, '_force_refresh_monitor', return_value=()) result = provider.force_refresh() diff --git a/tests/unit/test_multithreaded_monitor_service.py b/tests/unit/test_multithreaded_monitor_service.py index 26cb91605..3c07731d1 100644 --- a/tests/unit/test_multithreaded_monitor_service.py +++ b/tests/unit/test_multithreaded_monitor_service.py @@ -22,7 +22,7 @@ from aws_advanced_python_wrapper import release_resources from aws_advanced_python_wrapper.host_monitoring_plugin import ( - MonitoringContext, MonitoringThreadContainer, MonitorService) + HostMonitorService, MonitoringContext) from aws_advanced_python_wrapper.hostinfo import HostInfo from aws_advanced_python_wrapper.utils.atomic import AtomicInt from aws_advanced_python_wrapper.utils.properties import (Properties, @@ -36,12 +36,9 @@ def mock_conn(mocker): @pytest.fixture def mock_monitor(mocker): - return mocker.MagicMock() - - -@pytest.fixture -def mock_executor(mocker): - return mocker.MagicMock() + monitor = mocker.MagicMock() + monitor.is_stopped = False + return monitor @pytest.fixture @@ -49,33 +46,6 @@ def mock_plugin_service(mocker): return mocker.MagicMock() -@pytest.fixture -def mock_future(mocker): - return mocker.MagicMock() - - -@pytest.fixture -def failure_detection_time_ms(): - return 10 - - -@pytest.fixture -def failure_detection_interval_ms(): - return 100 - - -@pytest.fixture -def failure_detection_count(): - return 3 - - 
-@pytest.fixture -def thread_container(mock_executor): - MonitoringThreadContainer._executor = mock_executor - mock_executor.return_value = mock_future - return mock_executor - - @pytest.fixture def host_info(): return HostInfo("localhost") @@ -102,9 +72,7 @@ def mock_aborted_connection_counter(mocker): @pytest.fixture(autouse=True) -def verify_concurrency(mock_monitor, mock_executor, mock_future, counter, concurrent_counter): - # The ThreadPoolExecutor may have been shut down by a previous test, so we'll need to recreate it here. - MonitoringThreadContainer._executor = ThreadPoolExecutor(thread_name_prefix="MonitoringThreadContainerExecutor") +def verify_concurrency(counter, concurrent_counter): yield counter.set(0) @@ -127,18 +95,18 @@ def test_start_monitoring__connections_to_different_hosts( host_alias_list = generate_host_aliases(num_conns, True) services = generate_services(num_conns) - try: - mock_create_monitor = mocker.patch( - "aws_advanced_python_wrapper.host_monitoring_plugin.MonitorService._create_monitor", - return_value=mock_monitor) - contexts = start_monitoring(num_conns, services, host_alias_list) - expected_start_monitoring_calls = [mocker.call(context) for context in contexts] - mock_monitor.start_monitoring.assert_has_calls(expected_start_monitoring_calls, True) - assert num_conns == len(MonitoringThreadContainer()._monitor_map) - expected_create_monitor_calls = [mocker.call(host_info, props, MonitoringThreadContainer())] * num_conns - mock_create_monitor.assert_has_calls(expected_create_monitor_calls) - finally: - release_service_resource(services) + mocker.patch( + "aws_advanced_python_wrapper.host_monitoring_plugin.Monitor.__init__", + return_value=None) + mocker.patch.object( + services[0]._monitor_service, 'run_if_absent_with_aliases', return_value=mock_monitor) + # Patch all services to use the same mock + for svc in services: + svc._monitor_service = services[0]._monitor_service + + contexts = start_monitoring(num_conns, services, 
host_alias_list) + expected_start_monitoring_calls = [mocker.call(context) for context in contexts] + mock_monitor.start_monitoring.assert_has_calls(expected_start_monitoring_calls, True) def test_start_monitoring__connections_to_same_host( @@ -154,18 +122,17 @@ def test_start_monitoring__connections_to_same_host( host_alias_list = generate_host_aliases(num_conns, False) services = generate_services(num_conns) - try: - mock_create_monitor = mocker.patch( - "aws_advanced_python_wrapper.host_monitoring_plugin.MonitorService._create_monitor", - return_value=mock_monitor) - contexts = start_monitoring(num_conns, services, host_alias_list) - expected_start_monitoring_calls = [mocker.call(context) for context in contexts] - mock_monitor.start_monitoring.assert_has_calls(expected_start_monitoring_calls, True) - assert 1 == len(MonitoringThreadContainer()._monitor_map) - expected_create_monitor_calls = [mocker.call(host_info, props, MonitoringThreadContainer())] - mock_create_monitor.assert_has_calls(expected_create_monitor_calls) - finally: - release_service_resource(services) + mocker.patch( + "aws_advanced_python_wrapper.host_monitoring_plugin.Monitor.__init__", + return_value=None) + mocker.patch.object( + services[0]._monitor_service, 'run_if_absent_with_aliases', return_value=mock_monitor) + for svc in services: + svc._monitor_service = services[0]._monitor_service + + contexts = start_monitoring(num_conns, services, host_alias_list) + expected_start_monitoring_calls = [mocker.call(context) for context in contexts] + mock_monitor.start_monitoring.assert_has_calls(expected_start_monitoring_calls, True) def test_stop_monitoring__connections_to_different_hosts( @@ -182,12 +149,9 @@ def test_stop_monitoring__connections_to_different_hosts( contexts = generate_contexts(num_conns, True) services = generate_services(num_conns) - try: - stop_monitoring(num_conns, services, contexts) - expected_stop_monitoring_calls = [mocker.call(context) for context in contexts] - 
mock_monitor.stop_monitoring.assert_has_calls(expected_stop_monitoring_calls, True) - finally: - release_service_resource(services) + stop_monitoring(num_conns, services, contexts) + expected_stop_monitoring_calls = [mocker.call(context) for context in contexts] + mock_monitor.stop_monitoring.assert_has_calls(expected_stop_monitoring_calls, True) def test_stop_monitoring__connections_to_same_host( @@ -204,12 +168,9 @@ def test_stop_monitoring__connections_to_same_host( contexts = generate_contexts(num_conns, False) services = generate_services(num_conns) - try: - stop_monitoring(num_conns, services, contexts) - expected_stop_monitoring_calls = [mocker.call(context) for context in contexts] - mock_monitor.stop_monitoring.assert_has_calls(expected_stop_monitoring_calls, True) - finally: - release_service_resource(services) + stop_monitoring(num_conns, services, contexts) + expected_stop_monitoring_calls = [mocker.call(context) for context in contexts] + mock_monitor.stop_monitoring.assert_has_calls(expected_stop_monitoring_calls, True) def generate_host_aliases(num_aliases: int, generate_unique_aliases: bool) -> List[FrozenSet[str]]: @@ -221,43 +182,35 @@ def generate_host_aliases(num_aliases: int, generate_unique_aliases: bool) -> Li @pytest.fixture def generate_services(mock_plugin_service): - def _generate_services(num_services: int) -> List[MonitorService]: - return [MonitorService(mock_plugin_service) for _ in range(num_services)] + def _generate_services(num_services: int) -> List[HostMonitorService]: + return [HostMonitorService(mock_plugin_service) for _ in range(num_services)] return _generate_services @pytest.fixture def generate_contexts( - mocker, mock_monitor, failure_detection_time_ms, failure_detection_interval_ms, - failure_detection_count, mock_aborted_connection_counter): + mocker, mock_monitor, mock_aborted_connection_counter): def _generate_contexts(num_contexts: int, generate_unique_contexts) -> List[MonitoringContext]: - host_aliases_list = 
generate_host_aliases(num_contexts, generate_unique_contexts) contexts = [] - for host_aliases in host_aliases_list: - MonitoringThreadContainer().get_or_create_monitor(host_aliases, lambda: mock_monitor) + for _ in range(num_contexts): contexts.append( MonitoringContext( mock_monitor, mocker.MagicMock(), mocker.MagicMock(), - failure_detection_time_ms, - failure_detection_interval_ms, - failure_detection_count, + 10, + 100, + 3, mock_aborted_connection_counter)) return contexts return _generate_contexts -def release_service_resource(services): - for service in services: - service.release_resources() - - @pytest.fixture def start_monitoring(counter, concurrent_counter, start_monitoring_thread): def _start_monitoring( num_threads: int, - services: List[MonitorService], + services: List[HostMonitorService], host_aliases_list: List[FrozenSet[str]]) -> List[MonitoringContext]: barrier = Barrier(num_threads) futures = [] @@ -279,33 +232,29 @@ def _start_monitoring( def start_monitoring_thread( mock_conn, host_info, - props, - failure_detection_time_ms, - failure_detection_interval_ms, - failure_detection_count): + props): def _start_monitoring_thread( barrier: Barrier, counter: AtomicInt, concurrent_counter: AtomicInt, - service: MonitorService, + service: HostMonitorService, host_aliases: FrozenSet[str]) -> MonitoringContext: barrier.wait() val = counter.get_and_increment() if val != 0: - # If the counter value is greater than 0 it means that this method was called by another thread concurrently concurrent_counter.get_and_increment() context = service.start_monitoring( mock_conn, host_aliases, host_info, - props, - failure_detection_time_ms, - failure_detection_interval_ms, - failure_detection_count) + Properties(), + 10, + 100, + 3) - sleep(0.01) # Briefly sleep to allow other threads to be executed concurrently + sleep(0.01) counter.get_and_decrement() return context return _start_monitoring_thread @@ -313,7 +262,7 @@ def _start_monitoring_thread( @pytest.fixture def 
stop_monitoring(counter, concurrent_counter): - def _stop_monitoring(num_threads: int, services: List[MonitorService], contexts: List[MonitoringContext]): + def _stop_monitoring(num_threads: int, services: List[HostMonitorService], contexts: List[MonitoringContext]): barrier = Barrier(num_threads) with ThreadPoolExecutor(num_threads) as executor: for i in range(num_threads): @@ -331,15 +280,14 @@ def stop_monitoring_thread( barrier: Barrier, counter: AtomicInt, concurrent_counter: AtomicInt, - service: MonitorService, + service: HostMonitorService, context: MonitoringContext): barrier.wait() val = counter.get_and_increment() if val != 0: - # If the counter value is greater than 0 it means that this method was called by another thread concurrently concurrent_counter.get_and_increment() - sleep(0.01) # Briefly sleep to allow other threads to be executed concurrently + sleep(0.01) service.stop_monitoring(context) counter.get_and_decrement() diff --git a/tests/unit/test_okta_plugin.py b/tests/unit/test_okta_plugin.py index 7c8902926..0ebafb95f 100644 --- a/tests/unit/test_okta_plugin.py +++ b/tests/unit/test_okta_plugin.py @@ -15,8 +15,6 @@ from __future__ import annotations from datetime import datetime, timedelta -from typing import Dict -from unittest.mock import patch import pytest from boto3 import Session @@ -25,8 +23,9 @@ AwsCredentialsManager from aws_advanced_python_wrapper.errors import AwsWrapperError from aws_advanced_python_wrapper.hostinfo import HostInfo -from aws_advanced_python_wrapper.iam_plugin import TokenInfo from aws_advanced_python_wrapper.okta_plugin import OktaAuthPlugin +from aws_advanced_python_wrapper.utils import services_container +from aws_advanced_python_wrapper.utils.iam_utils import TokenInfo from aws_advanced_python_wrapper.utils.messages import Messages from aws_advanced_python_wrapper.utils.properties import (Properties, WrapperProperties) @@ -39,12 +38,12 @@ _PG_HOST_INFO = HostInfo("pg.testdb.us-east-2.rds.amazonaws.com") 
-_token_cache: Dict[str, TokenInfo] = {} - @pytest.fixture(autouse=True) def clear_cache(): - _token_cache.clear() + from datetime import timedelta + services_container.get_storage_service().register(TokenInfo, item_expiration_time=timedelta(minutes=30)) + services_container.get_storage_service().clear(TokenInfo) AwsCredentialsManager.release_resources() @@ -101,17 +100,17 @@ def mock_default_behavior(mock_session, mock_client, mock_func, mock_connection, yield -@patch("aws_advanced_python_wrapper.okta_plugin.OktaAuthPlugin._token_cache", _token_cache) def test_pg_connect_valid_token_in_cache(mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect): properties: Properties = Properties() WrapperProperties.PLUGINS.set(properties, "okta") WrapperProperties.DB_USER.set(properties, _DB_USER) initial_token = TokenInfo(_TEST_TOKEN, datetime.now() + timedelta(minutes=5)) - _token_cache[_PG_CACHE_KEY] = initial_token + storage = services_container.get_storage_service() + storage.put(TokenInfo, _PG_CACHE_KEY, initial_token) target_plugin: OktaAuthPlugin = OktaAuthPlugin(mock_plugin_service, mock_session) key = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + str(_DEFAULT_PG_PORT) + ":postgesqlUser" - _token_cache[key] = initial_token + storage.put(TokenInfo, key, initial_token) target_plugin.connect( target_driver_func=mocker.MagicMock(), @@ -123,18 +122,17 @@ def test_pg_connect_valid_token_in_cache(mocker, mock_plugin_service, mock_sessi mock_client.generate_db_auth_token.assert_not_called() - actual_token = _token_cache.get(_PG_CACHE_KEY) + actual_token = storage.get(TokenInfo, _PG_CACHE_KEY) assert _GENERATED_TOKEN != actual_token.token assert _TEST_TOKEN == actual_token.token assert actual_token.is_expired() is False -@patch("aws_advanced_python_wrapper.okta_plugin.OktaAuthPlugin._token_cache", _token_cache) def test_expired_cached_token(mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect, 
mock_credentials_provider_factory): test_props: Properties = Properties({"plugins": "okta", "user": "postgresqlUser", "idp_username": "user", "idp_password": "password"}) WrapperProperties.DB_USER.set(test_props, _DB_USER) initial_token = TokenInfo(_TEST_TOKEN, datetime.now() - timedelta(minutes=5)) - _token_cache[_PG_CACHE_KEY] = initial_token + services_container.get_storage_service().put(TokenInfo, _PG_CACHE_KEY, initial_token) target_plugin: OktaAuthPlugin = OktaAuthPlugin(mock_plugin_service, mock_credentials_provider_factory) @@ -155,7 +153,6 @@ def test_expired_cached_token(mocker, mock_plugin_service, mock_session, mock_fu assert WrapperProperties.PASSWORD.get(test_props) == _TEST_TOKEN -@patch("aws_advanced_python_wrapper.okta_plugin.OktaAuthPlugin._token_cache", _token_cache) def test_no_cached_token(mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect, mock_credentials_provider_factory): test_props: Properties = Properties({"plugins": "okta", "user": "postgresqlUser", "idp_username": "user", "idp_password": "password"}) WrapperProperties.DB_USER.set(test_props, _DB_USER) @@ -179,7 +176,6 @@ def test_no_cached_token(mocker, mock_plugin_service, mock_session, mock_func, m assert WrapperProperties.PASSWORD.get(test_props) == _TEST_TOKEN -@patch("aws_advanced_python_wrapper.okta_plugin.OktaAuthPlugin._token_cache", _token_cache) def test_no_cached_token_raises_exception(mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect, mock_credentials_provider_factory): test_props: Properties = Properties({"plugins": "okta", "user": "postgresqlUser", "idp_username": "user", "idp_password": "password"}) @@ -210,7 +206,6 @@ def test_no_cached_token_raises_exception(mocker, mock_plugin_service, mock_sess assert str(e_info.value) == Messages.get_formatted("OktaAuthPlugin.ConnectException", exception_message) -@patch("aws_advanced_python_wrapper.okta_plugin.OktaAuthPlugin._token_cache", _token_cache) def 
test_connect_with_specified_iam_host_port_region(mocker, mock_plugin_service, mock_session, @@ -231,7 +226,7 @@ def test_connect_with_specified_iam_host_port_region(mocker, test_token_info = TokenInfo(_TEST_TOKEN, datetime.now() + timedelta(minutes=5)) key = "us-west-2:pg.testdb.us-west-2.rds.amazonaws.com:" + str(expected_port) + ":specifiedUser" - _token_cache[key] = test_token_info + services_container.get_storage_service().put(TokenInfo, key, test_token_info) mock_client.generate_db_auth_token.return_value = f"{_TEST_TOKEN}:{expected_region}" diff --git a/tests/unit/test_rds_host_list_provider.py b/tests/unit/test_rds_host_list_provider.py index 6c97cbf4d..c9df3ac83 100644 --- a/tests/unit/test_rds_host_list_provider.py +++ b/tests/unit/test_rds_host_list_provider.py @@ -21,15 +21,15 @@ AuroraTopologyUtils, RdsHostListProvider) from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole from aws_advanced_python_wrapper.pep249 import ProgrammingError +from aws_advanced_python_wrapper.utils import services_container from aws_advanced_python_wrapper.utils.properties import (Properties, WrapperProperties) -from aws_advanced_python_wrapper.utils.storage.storage_service import ( - StorageService, Topology) +from aws_advanced_python_wrapper.utils.services_container import Topology @pytest.fixture(autouse=True) def clear_caches(): - StorageService.clear_all() + services_container.get_storage_service().clear_all() def mock_topology_query(mock_conn, mock_cursor, records): @@ -91,7 +91,7 @@ def test_get_topology_caches_topology(mocker, mock_provider_service, mock_conn, topology_utils = AuroraTopologyUtils(AuroraPgDialect(), props) provider = RdsHostListProvider(mock_provider_service, mock_provider_service, props, topology_utils) provider._initialize() - StorageService.set(provider._cluster_id, cache_hosts, Topology) + services_container.get_storage_service().put(Topology, provider._cluster_id, cache_hosts) mock_force_refresh = mocker.patch.object(provider, 
'_force_refresh_monitor') result = provider.refresh(mock_conn) @@ -104,7 +104,7 @@ def test_get_topology_force_update( mocker, mock_provider_service, mock_conn, cache_hosts, queried_hosts, props, refresh_ns): topology_utils = AuroraTopologyUtils(AuroraPgDialect(), props) provider = RdsHostListProvider(mock_provider_service, mock_provider_service, props, topology_utils) - StorageService.set(provider._cluster_id, cache_hosts, Topology) + services_container.get_storage_service().put(Topology, provider._cluster_id, cache_hosts) mocker.patch.object(provider, '_force_refresh_monitor', return_value=queried_hosts) result = provider.force_refresh(mock_conn) @@ -127,7 +127,7 @@ def test_get_topology_invalid_topology( topology_utils = AuroraTopologyUtils(AuroraPgDialect(), props) provider = RdsHostListProvider(mock_provider_service, mock_provider_service, props, topology_utils) provider._initialize() - StorageService.set(provider._cluster_id, cache_hosts, Topology) + services_container.get_storage_service().put(Topology, provider._cluster_id, cache_hosts) mocker.patch.object(provider, '_force_refresh_monitor', return_value=()) result = provider.force_refresh() diff --git a/tests/unit/test_secrets_manager_plugin.py b/tests/unit/test_secrets_manager_plugin.py index ae7e94187..526de38a5 100644 --- a/tests/unit/test_secrets_manager_plugin.py +++ b/tests/unit/test_secrets_manager_plugin.py @@ -15,8 +15,6 @@ from __future__ import annotations from types import SimpleNamespace -from typing import Tuple -from unittest.mock import patch import pytest from boto3 import Session @@ -24,11 +22,11 @@ from aws_advanced_python_wrapper.aws_credentials_manager import \ AwsCredentialsManager -from aws_advanced_python_wrapper.aws_secrets_manager_plugin import \ - AwsSecretsManagerPlugin +from aws_advanced_python_wrapper.aws_secrets_manager_plugin import ( + AwsSecretsManagerPlugin, Secret) from aws_advanced_python_wrapper.errors import AwsWrapperError from aws_advanced_python_wrapper.hostinfo 
import HostInfo -from aws_advanced_python_wrapper.utils.cache_map import CacheMap +from aws_advanced_python_wrapper.utils import services_container from aws_advanced_python_wrapper.utils.properties import Properties _TEST_REGION = "us-east-2" @@ -67,12 +65,12 @@ } }, "some_operation") -_secrets_cache: CacheMap[Tuple, SimpleNamespace] = CacheMap() - @pytest.fixture(autouse=True) def clear_caches(): - _secrets_cache.clear() + from datetime import timedelta + services_container.get_storage_service().register(Secret, item_expiration_time=timedelta(minutes=30)) + services_container.get_storage_service().clear(Secret) AwsCredentialsManager.release_resources() @@ -129,26 +127,26 @@ def test_properties(): }) -@patch("aws_advanced_python_wrapper.aws_secrets_manager_plugin.AwsSecretsManagerPlugin._secrets_cache", _secrets_cache) def test_connect_with_cached_secrets( mocker, mock_plugin_service, mock_session, mock_func, mock_client, test_properties): - _secrets_cache.put(_SECRET_CACHE_KEY, _TEST_SECRET, _ONE_YEAR_IN_NANOSECONDS) + storage = services_container.get_storage_service() + storage.put(Secret, _SECRET_CACHE_KEY, Secret(_TEST_SECRET), item_expiration_ns=_ONE_YEAR_IN_NANOSECONDS) target_plugin: AwsSecretsManagerPlugin = AwsSecretsManagerPlugin( mock_plugin_service, test_properties, mock_session) target_plugin.connect( mocker.MagicMock(), mocker.MagicMock(), _TEST_HOST_INFO, test_properties, True, mock_func) - assert 1 == len(_secrets_cache) + assert 1 == storage.size(Secret) mock_client.get_secret_value.assert_not_called() mock_func.assert_called_once() assert _TEST_USERNAME == test_properties.get("user") assert _TEST_PASSWORD == test_properties.get("password") -@patch("aws_advanced_python_wrapper.aws_secrets_manager_plugin.AwsSecretsManagerPlugin._secrets_cache", _secrets_cache) def test_connect_with_new_secrets( mocker, mock_plugin_service, mock_session, mock_func, mock_client, test_properties): - assert 0 == len(_secrets_cache) + storage = 
services_container.get_storage_service() + assert 0 == storage.size(Secret) target_plugin: AwsSecretsManagerPlugin = AwsSecretsManagerPlugin( mock_plugin_service, test_properties, mock_session) @@ -156,7 +154,7 @@ def test_connect_with_new_secrets( target_plugin.connect( mocker.MagicMock(), mocker.MagicMock(), _TEST_HOST_INFO, test_properties, True, mock_func) - assert 1 == len(_secrets_cache) + assert 1 == storage.size(Secret) mock_client.get_secret_value.assert_called_once() mock_func.assert_called_once() assert _TEST_USERNAME == test_properties.get("user") @@ -176,9 +174,9 @@ def test_missing_required_params(key: str, mock_plugin_service, mock_session): assert "required" in str(exc_info.value).lower() -@patch("aws_advanced_python_wrapper.aws_secrets_manager_plugin.AwsSecretsManagerPlugin._secrets_cache", _secrets_cache) def test_failed_initial_connection_with_unhandled_error( mocker, mock_plugin_service, mock_session, mock_func, mock_client, test_properties): + storage = services_container.get_storage_service() exception_msg = "Unhandled error during connection" # Simulate an unhandled exception (neither a login exception nor a network exception) @@ -192,17 +190,17 @@ def test_failed_initial_connection_with_unhandled_error( target_plugin.connect( mocker.MagicMock(), mocker.MagicMock(), _TEST_HOST_INFO, test_properties, True, mock_func) - assert 1 == len(_secrets_cache) + assert 1 == storage.size(Secret) mock_client.get_secret_value.assert_called_once() mock_func.assert_called_once() assert _TEST_USERNAME == test_properties.get("user") assert _TEST_PASSWORD == test_properties.get("password") -@patch("aws_advanced_python_wrapper.aws_secrets_manager_plugin.AwsSecretsManagerPlugin._secrets_cache", _secrets_cache) def test_connect_with_new_secrets_after_trying_with_cached_secrets( mocker, mock_plugin_service, mock_session, mock_func, mock_client, test_properties): - _secrets_cache.put(_SECRET_CACHE_KEY, _INVALID_TEST_SECRET, _ONE_YEAR_IN_NANOSECONDS) + storage = 
services_container.get_storage_service() + storage.put(Secret, _SECRET_CACHE_KEY, Secret(_INVALID_TEST_SECRET), item_expiration_ns=_ONE_YEAR_IN_NANOSECONDS) login_exception = Exception("Login failed with cached credentials") mock_func.side_effect = [login_exception, mocker.MagicMock()] @@ -213,14 +211,13 @@ def test_connect_with_new_secrets_after_trying_with_cached_secrets( target_plugin.connect(mocker.MagicMock(), mocker.MagicMock(), _TEST_HOST_INFO, test_properties, True, mock_func) - assert 1 == len(_secrets_cache) + assert 1 == storage.size(Secret) mock_client.get_secret_value.assert_called_once() assert 2 == mock_func.call_count assert _TEST_USERNAME == test_properties.get("user") assert _TEST_PASSWORD == test_properties.get("password") -@patch("aws_advanced_python_wrapper.aws_secrets_manager_plugin.AwsSecretsManagerPlugin._secrets_cache", _secrets_cache) def test_failed_to_read_secrets( mocker, mock_plugin_service, mock_session, mock_func, mock_client, test_properties): mock_client.get_secret_value.return_value = "foo" @@ -233,7 +230,6 @@ def test_failed_to_read_secrets( mocker.MagicMock(), mocker.MagicMock(), _TEST_HOST_INFO, test_properties, True, mock_func) -@patch("aws_advanced_python_wrapper.aws_secrets_manager_plugin.AwsSecretsManagerPlugin._secrets_cache", _secrets_cache) def test_failed_to_get_secrets( mocker, mock_plugin_service, mock_session, mock_func, mock_client, test_properties): mock_client.get_secret_value.side_effect = _GENERIC_CLIENT_ERROR @@ -291,9 +287,9 @@ def test_connection_with_region_parameter_and_arn( mock_client.get_secret_value.assert_called_with(SecretId=arn) -@patch("aws_advanced_python_wrapper.aws_secrets_manager_plugin.AwsSecretsManagerPlugin._secrets_cache", _secrets_cache) def test_connect_with_different_secret_keys( mocker, mock_plugin_service, mock_session, mock_func, mock_client, test_properties): + storage = services_container.get_storage_service() test_properties["secrets_manager_secret_username_key"] = 
_TEST_USERNAME_KEY test_properties["secrets_manager_secret_password_key"] = _TEST_PASSWORD_KEY secret_string = ( @@ -307,7 +303,7 @@ def test_connect_with_different_secret_keys( target_plugin.connect( mocker.MagicMock(), mocker.MagicMock(), _TEST_HOST_INFO, test_properties, True, mock_func) - assert 1 == len(_secrets_cache) + assert 1 == storage.size(Secret) mock_client.get_secret_value.assert_called_once() mock_func.assert_called_once() assert _TEST_USERNAME == test_properties.get("user") diff --git a/tests/unit/test_sliding_expiration_cache.py b/tests/unit/test_sliding_expiration_cache.py index 4029faf86..328a2b9ca 100644 --- a/tests/unit/test_sliding_expiration_cache.py +++ b/tests/unit/test_sliding_expiration_cache.py @@ -14,8 +14,8 @@ import time -from aws_advanced_python_wrapper.utils.sliding_expiration_cache import ( - SlidingExpirationCache, SlidingExpirationCacheWithCleanupThread) +from aws_advanced_python_wrapper.utils.storage.sliding_expiration_cache import \ + SlidingExpirationCache def test_compute_if_absent(): @@ -89,32 +89,29 @@ def test_clear(): assert item2.disposed is True -def test_cleanup_thread_continuous_removal(): - # Use very short cleanup interval for testing (100ms) - cache = SlidingExpirationCacheWithCleanupThread( - cleanup_interval_ns=100_000_000, # 100ms - item_disposal_func=lambda item: item.dispose() - ) - - # First cycle: insert item that expires quickly - item1 = DisposableItem(True) - cache.compute_if_absent("key1", lambda _: item1, 50_000_000) # 50ms expiration - assert cache.get("key1") == item1 - - # Wait for cleanup thread to remove expired item - time.sleep(0.2) # Wait 200ms for cleanup - assert cache.get("key1") is None - assert item1.disposed is True +def test_put_inserts_new_item(): + cache = SlidingExpirationCache(50_000_000) + cache.put("key", "value", 15_000_000_000) + assert "value" == cache.get("key") - # Second cycle: insert another item that expires quickly - item2 = DisposableItem(True) - 
cache.compute_if_absent("key2", lambda _: item2, 50_000_000) # 50ms expiration - assert cache.get("key2") == item2 - # Wait for cleanup thread to remove second expired item - time.sleep(0.2) # Wait 200ms for cleanup - assert cache.get("key2") is None - assert item2.disposed is True +def test_put_replaces_existing_item(): + cache = SlidingExpirationCache(50_000_000) + cache.put("key", "old", 15_000_000_000) + cache.put("key", "new", 15_000_000_000) + assert "new" == cache.get("key") + assert 1 == len(cache) + + +def test_put_disposes_old_item(): + cache = SlidingExpirationCache(50_000_000, item_disposal_func=lambda item: item.dispose()) + old_item = DisposableItem(True) + new_item = DisposableItem(True) + cache.put("key", old_item, 15_000_000_000) + cache.put("key", new_item, 15_000_000_000) + assert old_item.disposed is True + assert new_item.disposed is False + assert new_item == cache.get("key") class DisposableItem: diff --git a/tests/unit/test_sliding_expiration_cache_container.py b/tests/unit/test_sliding_expiration_cache_container.py deleted file mode 100644 index c21c54826..000000000 --- a/tests/unit/test_sliding_expiration_cache_container.py +++ /dev/null @@ -1,231 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import time - -import pytest - -from aws_advanced_python_wrapper.utils.sliding_expiration_cache import \ - SlidingExpirationCache -from aws_advanced_python_wrapper.utils.sliding_expiration_cache_container import \ - SlidingExpirationCacheContainer - - -@pytest.fixture(autouse=True) -def cleanup_caches(): - """Clean up all caches after each test""" - yield - SlidingExpirationCacheContainer.release_resources() - - -def test_get_or_create_cache_creates_new_cache(): - cache = SlidingExpirationCacheContainer.get_or_create_cache("test_cache") - assert isinstance(cache, SlidingExpirationCache) - - -def test_get_or_create_cache_returns_existing_cache(): - cache1 = SlidingExpirationCacheContainer.get_or_create_cache("test_cache") - cache2 = SlidingExpirationCacheContainer.get_or_create_cache("test_cache") - assert cache1 is cache2 - - -def test_get_or_create_cache_with_custom_cleanup_interval(): - cache = SlidingExpirationCacheContainer.get_or_create_cache( - "test_cache", - cleanup_interval_ns=5_000_000_000 # 5 seconds - ) - assert cache._cleanup_interval_ns == 5_000_000_000 - - -def test_get_or_create_cache_with_disposal_functions(): - disposed_items = [] - - def should_dispose(item): - return item > 10 - - def dispose(item): - disposed_items.append(item) - - cache = SlidingExpirationCacheContainer.get_or_create_cache( - "test_cache", - should_dispose_func=should_dispose, - item_disposal_func=dispose - ) - - assert cache._should_dispose_func is should_dispose - assert cache._item_disposal_func is dispose - - -def test_multiple_caches_are_independent(): - cache1 = SlidingExpirationCacheContainer.get_or_create_cache("cache1") - cache2 = SlidingExpirationCacheContainer.get_or_create_cache("cache2") - - cache1.compute_if_absent("key1", lambda k: "value1", 1_000_000_000) - cache2.compute_if_absent("key2", lambda k: "value2", 1_000_000_000) - - assert cache1.get("key1") == "value1" - assert cache1.get("key2") is None - assert cache2.get("key2") == "value2" - assert 
cache2.get("key1") is None - - -def test_cleanup_thread_starts_on_first_cache(): - # Cleanup thread should start when first cache is created - SlidingExpirationCacheContainer.get_or_create_cache("test_cache") - - # Check that cleanup thread is running - assert SlidingExpirationCacheContainer._cleanup_thread is not None - assert SlidingExpirationCacheContainer._cleanup_thread.is_alive() - - -def test_release_resources_clears_all_caches(): - cache1 = SlidingExpirationCacheContainer.get_or_create_cache("cache1") - cache2 = SlidingExpirationCacheContainer.get_or_create_cache("cache2") - - cache1.compute_if_absent("key1", lambda k: "value1", 1_000_000_000) - cache2.compute_if_absent("key2", lambda k: "value2", 1_000_000_000) - - SlidingExpirationCacheContainer.release_resources() - - # Caches should be cleared - assert len(SlidingExpirationCacheContainer._caches) == 0 - - -def test_release_resources_stops_cleanup_thread(): - SlidingExpirationCacheContainer.get_or_create_cache("test_cache") - - cleanup_thread = SlidingExpirationCacheContainer._cleanup_thread - assert cleanup_thread is not None - assert cleanup_thread.is_alive() - - SlidingExpirationCacheContainer.release_resources() - - # Give thread time to stop - time.sleep(0.1) - - # Thread should be stopped - assert not cleanup_thread.is_alive() - - -def test_release_resources_disposes_items(): - disposed_items = [] - - def dispose(item): - disposed_items.append(item) - - cache = SlidingExpirationCacheContainer.get_or_create_cache( - "test_cache", - item_disposal_func=dispose - ) - - cache.compute_if_absent("key1", lambda k: "value1", 1_000_000_000) - cache.compute_if_absent("key2", lambda k: "value2", 1_000_000_000) - - SlidingExpirationCacheContainer.release_resources() - - # Items should have been disposed - assert "value1" in disposed_items - assert "value2" in disposed_items - - -def test_cleanup_thread_cleans_expired_items(): - # Use very short intervals for testing - cache = 
SlidingExpirationCacheContainer.get_or_create_cache( - "test_cache", - cleanup_interval_ns=100_000_000 # 0.1 seconds - ) - - # Add item with very short expiration - cache.compute_if_absent("key1", lambda k: "value1", 50_000_000) # 0.05 seconds - - assert cache.get("key1") == "value1" - - # Wait for item to expire and cleanup to run - time.sleep(0.3) - - # Item should be cleaned up - assert cache.get("key1") is None - - -def test_same_cache_name_returns_same_instance_across_calls(): - cache1 = SlidingExpirationCacheContainer.get_or_create_cache("shared_cache") - cache1.compute_if_absent("key1", lambda k: "value1", 1_000_000_000) - - # Get the same cache again - cache2 = SlidingExpirationCacheContainer.get_or_create_cache("shared_cache") - - # Should be the same instance with the same data - assert cache1 is cache2 - assert cache2.get("key1") == "value1" - - -def test_cleanup_thread_handles_multiple_caches(): - cache1 = SlidingExpirationCacheContainer.get_or_create_cache( - "cache1", - cleanup_interval_ns=100_000_000 # 0.1 seconds - ) - cache2 = SlidingExpirationCacheContainer.get_or_create_cache( - "cache2", - cleanup_interval_ns=100_000_000 # 0.1 seconds - ) - - # Add items with short expiration - cache1.compute_if_absent("key1", lambda k: "value1", 50_000_000) - cache2.compute_if_absent("key2", lambda k: "value2", 50_000_000) - - assert cache1.get("key1") == "value1" - assert cache2.get("key2") == "value2" - - # Wait for cleanup - time.sleep(0.3) - - # Both should be cleaned up - assert cache1.get("key1") is None - assert cache2.get("key2") is None - - -def test_release_resources_handles_disposal_errors(): - def failing_dispose(item): - raise Exception("Disposal failed") - - cache = SlidingExpirationCacheContainer.get_or_create_cache( - "test_cache", - item_disposal_func=failing_dispose - ) - - cache.compute_if_absent("key1", lambda k: "value1", 1_000_000_000) - - # Should not raise exception even if disposal fails - 
SlidingExpirationCacheContainer.release_resources() - - # Cache should still be cleared - assert len(SlidingExpirationCacheContainer._caches) == 0 - - -def test_cleanup_thread_respects_is_stopped_event(): - # Clear the stop event first in case it was set by a previous test - SlidingExpirationCacheContainer._is_stopped.clear() - - SlidingExpirationCacheContainer.get_or_create_cache("test_cache") - - cleanup_thread = SlidingExpirationCacheContainer._cleanup_thread - assert cleanup_thread is not None - assert cleanup_thread.is_alive() - - # Set the stop event - SlidingExpirationCacheContainer._is_stopped.set() - - # Thread should stop quickly (not wait for full cleanup interval) - cleanup_thread.join(timeout=1.0) - assert not cleanup_thread.is_alive() diff --git a/tests/unit/test_sql_alchemy_pooled_connection_provider.py b/tests/unit/test_sql_alchemy_pooled_connection_provider.py index ba36c08c4..56e82795e 100644 --- a/tests/unit/test_sql_alchemy_pooled_connection_provider.py +++ b/tests/unit/test_sql_alchemy_pooled_connection_provider.py @@ -21,7 +21,7 @@ PoolKey, SqlAlchemyPooledConnectionProvider) from aws_advanced_python_wrapper.utils.properties import (Properties, WrapperProperties) -from aws_advanced_python_wrapper.utils.sliding_expiration_cache import \ +from aws_advanced_python_wrapper.utils.storage.sliding_expiration_cache import \ SlidingExpirationCache diff --git a/tests/unit/test_storage_service.py b/tests/unit/test_storage_service.py new file mode 100644 index 000000000..69d44c0b5 --- /dev/null +++ b/tests/unit/test_storage_service.py @@ -0,0 +1,154 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datetime import timedelta +from unittest.mock import MagicMock + +import pytest + +from aws_advanced_python_wrapper.utils.storage.storage_service import \ + StorageService + + +class TypeA: + pass + + +class TypeB: + pass + + +@pytest.fixture +def publisher(): + return MagicMock() + + +@pytest.fixture +def storage(publisher): + return StorageService(publisher) + + +def test_register_and_get(storage): + storage.register(TypeA, item_expiration_time=timedelta(minutes=5)) + storage.put(TypeA, "key1", "value1") + assert storage.get(TypeA, "key1") == "value1" + + +def test_get_unregistered_type_returns_none(storage): + assert storage.get(TypeA, "key1") is None + + +def test_put_unregistered_type_raises(storage): + with pytest.raises(ValueError): + storage.put(TypeA, "key1", "value1") + + +def test_put_replaces_existing_value(storage): + storage.register(TypeA, item_expiration_time=timedelta(minutes=5)) + storage.put(TypeA, "key1", "old") + storage.put(TypeA, "key1", "new") + assert storage.get(TypeA, "key1") == "new" + + +def test_multiple_types_are_independent(storage): + storage.register(TypeA, item_expiration_time=timedelta(minutes=5)) + storage.register(TypeB, item_expiration_time=timedelta(minutes=5)) + + storage.put(TypeA, "key1", "a_value") + storage.put(TypeB, "key1", "b_value") + + assert storage.get(TypeA, "key1") == "a_value" + assert storage.get(TypeB, "key1") == "b_value" + + +def test_remove(storage): + storage.register(TypeA, item_expiration_time=timedelta(minutes=5)) + storage.put(TypeA, "key1", "value1") + storage.remove(TypeA, "key1") + 
assert storage.get(TypeA, "key1") is None + + +def test_remove_unregistered_type_is_noop(storage): + storage.remove(TypeA, "key1") # should not raise + + +def test_clear(storage): + storage.register(TypeA, item_expiration_time=timedelta(minutes=5)) + storage.put(TypeA, "k1", "v1") + storage.put(TypeA, "k2", "v2") + storage.clear(TypeA) + assert storage.get(TypeA, "k1") is None + assert storage.get(TypeA, "k2") is None + + +def test_clear_all(storage): + storage.register(TypeA, item_expiration_time=timedelta(minutes=5)) + storage.register(TypeB, item_expiration_time=timedelta(minutes=5)) + storage.put(TypeA, "k1", "v1") + storage.put(TypeB, "k2", "v2") + + storage.clear_all() + + assert storage.get(TypeA, "k1") is None + assert storage.get(TypeB, "k2") is None + + +def test_exists(storage): + storage.register(TypeA, item_expiration_time=timedelta(minutes=5)) + assert storage.exists(TypeA, "key1") is False + storage.put(TypeA, "key1", "value1") + assert storage.exists(TypeA, "key1") is True + + +def test_exists_unregistered_type(storage): + assert storage.exists(TypeA, "key1") is False + + +def test_release_resources(storage): + storage.register(TypeA, item_expiration_time=timedelta(minutes=5)) + storage.put(TypeA, "key1", "value1") + + storage.release_resources() + + assert storage.get(TypeA, "key1") is None + assert len(storage._caches) == 0 + + +def test_get_publishes_data_access_event(storage, publisher): + storage.register(TypeA, item_expiration_time=timedelta(minutes=5)) + storage.put(TypeA, "key1", "value1") + + publisher.reset_mock() + storage.get(TypeA, "key1") + + publisher.publish.assert_called_once() + event = publisher.publish.call_args[0][0] + assert event.data_type is TypeA + assert event.key == "key1" + + +def test_get_miss_does_not_publish_event(storage, publisher): + storage.register(TypeA, item_expiration_time=timedelta(minutes=5)) + + publisher.reset_mock() + storage.get(TypeA, "nonexistent") + + publisher.publish.assert_not_called() + + +def 
test_register_is_idempotent(storage): + storage.register(TypeA, item_expiration_time=timedelta(minutes=5)) + storage.put(TypeA, "key1", "value1") + storage.register(TypeA, item_expiration_time=timedelta(minutes=10)) # should not replace + assert storage.get(TypeA, "key1") == "value1" diff --git a/tests/unit/test_thread_pool_container.py b/tests/unit/test_thread_pool_container.py deleted file mode 100644 index 5f4d415c4..000000000 --- a/tests/unit/test_thread_pool_container.py +++ /dev/null @@ -1,103 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from concurrent.futures import ThreadPoolExecutor - -import pytest - -from aws_advanced_python_wrapper.thread_pool_container import \ - ThreadPoolContainer - - -@pytest.fixture(autouse=True) -def cleanup_pools(): - """Clean up all pools after each test""" - yield - ThreadPoolContainer.release_resources() - - -def test_get_thread_pool_creates_new_pool(): - pool = ThreadPoolContainer.get_thread_pool("test_pool") - assert isinstance(pool, ThreadPoolExecutor) - assert ThreadPoolContainer.has_pool("test_pool") - - -def test_get_thread_pool_returns_existing_pool(): - pool1 = ThreadPoolContainer.get_thread_pool("test_pool") - pool2 = ThreadPoolContainer.get_thread_pool("test_pool") - assert pool1 is pool2 - - -def test_get_thread_pool_with_max_workers(): - pool = ThreadPoolContainer.get_thread_pool("test_pool", max_workers=5) - assert pool._max_workers == 5 - - -def test_has_pool(): - assert not ThreadPoolContainer.has_pool("nonexistent") - ThreadPoolContainer.get_thread_pool("test_pool") - assert ThreadPoolContainer.has_pool("test_pool") - - -def test_get_pool_names(): - assert ThreadPoolContainer.get_pool_names() == [] - ThreadPoolContainer.get_thread_pool("pool1") - ThreadPoolContainer.get_thread_pool("pool2") - names = ThreadPoolContainer.get_pool_names() - assert "pool1" in names - assert "pool2" in names - assert len(names) == 2 - - -def test_get_pool_count(): - assert ThreadPoolContainer.get_pool_count() == 0 - ThreadPoolContainer.get_thread_pool("pool1") - assert ThreadPoolContainer.get_pool_count() == 1 - ThreadPoolContainer.get_thread_pool("pool2") - assert ThreadPoolContainer.get_pool_count() == 2 - - -def test_release_pool(): - ThreadPoolContainer.get_thread_pool("test_pool") - assert ThreadPoolContainer.has_pool("test_pool") - - result = ThreadPoolContainer.release_pool("test_pool") - assert result is True - assert not ThreadPoolContainer.has_pool("test_pool") - - -def test_release_nonexistent_pool(): - result = 
ThreadPoolContainer.release_pool("nonexistent") - assert result is False - - -def test_release_resources(): - ThreadPoolContainer.get_thread_pool("pool1") - ThreadPoolContainer.get_thread_pool("pool2") - assert ThreadPoolContainer.get_pool_count() == 2 - - ThreadPoolContainer.release_resources() - assert ThreadPoolContainer.get_pool_count() == 0 - - -def test_set_default_max_workers(): - ThreadPoolContainer.set_default_max_workers(10) - pool = ThreadPoolContainer.get_thread_pool("test_pool") - assert pool._max_workers == 10 - - -def test_thread_name_prefix(): - pool = ThreadPoolContainer.get_thread_pool("custom_name") - # Check that the thread name prefix is set correctly - assert pool._thread_name_prefix == "custom_name"