Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 17 additions & 7 deletions aws_advanced_python_wrapper/aws_secrets_manager_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from __future__ import annotations

from datetime import timedelta
from json import JSONDecodeError, loads
from re import search
from types import SimpleNamespace
Expand All @@ -23,7 +24,7 @@

from aws_advanced_python_wrapper.aws_credentials_manager import \
AwsCredentialsManager
from aws_advanced_python_wrapper.utils.cache_map import CacheMap
from aws_advanced_python_wrapper.utils import services_container

if TYPE_CHECKING:
from boto3 import Session
Expand All @@ -46,14 +47,20 @@
logger = Logger(__name__)


class Secret:
"""Wrapper type for secrets, used as StorageService type key."""

def __init__(self, value: SimpleNamespace):
self.value = value


class AwsSecretsManagerPlugin(Plugin):
_SUBSCRIBED_METHODS: Set[str] = {DbApiMethod.CONNECT.method_name, DbApiMethod.FORCE_CONNECT.method_name}

_SECRETS_ARN_PATTERN = r"^arn:aws:secretsmanager:(?P<region>[^:\n]*):[^:\n]*:([^:/\n]*[:/])?(.*)$"
_ONE_YEAR_IN_SECONDS = 60 * 60 * 24 * 365

_secret: Optional[SimpleNamespace] = None
_secrets_cache: CacheMap[Tuple, SimpleNamespace] = CacheMap()
_secret_key: Tuple = ()

@property
Expand All @@ -63,6 +70,8 @@ def subscribed_methods(self) -> Set[str]:
def __init__(self, plugin_service: PluginService, props: Properties, session: Optional[Session] = None):
self._plugin_service = plugin_service
self._session = session
self._storage_service = services_container.get_storage_service()
self._storage_service.register(Secret, item_expiration_time=timedelta(minutes=30))

secret_id = WrapperProperties.SECRETS_MANAGER_SECRET_ID.get(props)
if not secret_id:
Expand Down Expand Up @@ -100,13 +109,13 @@ def force_connect(
return self._connect(host_info, props, force_connect_func)

def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callable) -> Connection:
token_expiration_sec: int = WrapperProperties.SECRETS_MANAGER_EXPIRATION.get_int(props)
token_expiration_sec = WrapperProperties.SECRETS_MANAGER_EXPIRATION.get_int(props)
# if value is less than 0, default to one year
if token_expiration_sec < 0:
token_expiration_sec = AwsSecretsManagerPlugin._ONE_YEAR_IN_SECONDS
token_expiration_ns = token_expiration_sec * 1_000_000_000

secret_fetched: bool = self._update_secret(host_info, props, token_expiration_ns=token_expiration_ns)
secret_fetched: bool = self._update_secret(host_info, props, token_expiration_ns)

try:
self._apply_secret_to_properties(props)
Expand All @@ -120,7 +129,7 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl
raise AwsWrapperError(
Messages.get_formatted("AwsSecretsManagerPlugin.ConnectException", e), e) from e

secret_fetched = self._update_secret(host_info, props, token_expiration_ns=token_expiration_ns, force_refetch=True)
secret_fetched = self._update_secret(host_info, props, token_expiration_ns, force_refetch=True)

if secret_fetched:
try:
Expand All @@ -146,13 +155,14 @@ def _update_secret(self, host_info: HostInfo, props: Properties, token_expiratio

try:
fetched: bool = False
self._secret: Optional[SimpleNamespace] = AwsSecretsManagerPlugin._secrets_cache.get(self._secret_key)
cached_secret = self._storage_service.get(Secret, self._secret_key)
self._secret = cached_secret.value if cached_secret is not None else None
endpoint = self._secret_key[2]
if not self._secret or force_refetch:
try:
self._secret = self._fetch_latest_credentials(host_info, props)
if self._secret:
AwsSecretsManagerPlugin._secrets_cache.put(self._secret_key, self._secret, token_expiration_ns)
self._storage_service.put(Secret, self._secret_key, Secret(self._secret), item_expiration_ns=token_expiration_ns)
fetched = True
except (ClientError, AttributeError) as e:
logger.debug("AwsSecretsManagerPlugin.FailedToFetchDbCredentials", e)
Expand Down
29 changes: 28 additions & 1 deletion aws_advanced_python_wrapper/blue_green_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from __future__ import annotations

import socket
from datetime import datetime
from datetime import datetime, timedelta
from time import perf_counter_ns
from typing import TYPE_CHECKING, FrozenSet, List, cast

Expand Down Expand Up @@ -44,9 +44,11 @@
from aws_advanced_python_wrapper.iam_plugin import IamAuthPlugin
from aws_advanced_python_wrapper.pep249_methods import DbApiMethod
from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory
from aws_advanced_python_wrapper.utils import services_container
from aws_advanced_python_wrapper.utils.atomic import AtomicInt
from aws_advanced_python_wrapper.utils.concurrent import (ConcurrentDict,
ConcurrentSet)
from aws_advanced_python_wrapper.utils.events import MonitorResetEvent
from aws_advanced_python_wrapper.utils.log import Logger
from aws_advanced_python_wrapper.utils.messages import Messages
from aws_advanced_python_wrapper.utils.properties import (Properties,
Expand Down Expand Up @@ -901,6 +903,9 @@ def _run(self):

finally:
self._close_connection()
if self._host_list_provider is not None:
self._host_list_provider.stop_monitor()
self._host_list_provider = None
logger.debug("BlueGreenStatusMonitor.ThreadCompleted", self._bg_role)

def _open_connection(self):
Expand Down Expand Up @@ -1237,6 +1242,8 @@ def __init__(self, plugin_service: PluginService, props: Properties, bg_id: str)
self._green_dns_removed = False
self._green_topology_changed = False
self._all_green_hosts_changed_name = False
self._monitor_reset_on_in_progress_completed = False
self._monitor_reset_on_topology_completed = False
self._post_status_end_time_ns = 0
self._process_status_lock = RLock()
self._status_check_intervals_ms: Dict[BlueGreenIntervalRate, int] = {}
Expand All @@ -1258,6 +1265,8 @@ def __init__(self, plugin_service: PluginService, props: Properties, bg_id: str)
Messages.get_formatted(
"BlueGreenStatusProvider.UnsupportedDialect", self._bg_id, dialect.__class__.__name__))

services_container.get_storage_service().register(BlueGreenStatus, item_expiration_time=timedelta(hours=1))

current_host_info = self._plugin_service.current_host_info
blue_monitor = BlueGreenStatusMonitor(
BlueGreenRole.SOURCE,
Expand Down Expand Up @@ -1476,6 +1485,7 @@ def _update_summary_status(self, bg_role: BlueGreenRole, interim_status: BlueGre
elif self._latest_phase == BlueGreenPhase.IN_PROGRESS:
self._update_dns_flags(bg_role, interim_status)
self._summary_status = self._get_status_of_in_progress()
self._reset_monitors("_monitor_reset_on_in_progress_completed", "- start")

elif self._latest_phase == BlueGreenPhase.POST:
self._update_dns_flags(bg_role, interim_status)
Expand Down Expand Up @@ -1503,6 +1513,7 @@ def _update_dns_flags(self, bg_role: BlueGreenRole, interim_status: BlueGreenInt
logger.debug("BlueGreenStatusProvider.GreenTopologyChanged", self._bg_id)
self._green_topology_changed = True
self._store_event_phase_time("Green topology changed")
self._reset_monitors("_monitor_reset_on_topology_completed", "- green topology")

def _store_event_phase_time(self, key_prefix: str, phase: Optional[BlueGreenPhase] = None):
rollback_str = " (rollback)" if self._rollback else ""
Expand Down Expand Up @@ -1846,6 +1857,20 @@ def _update_status_cache(self):
with latest_status.cv:
latest_status.cv.notify_all()

def _reset_monitors(self, completed_flag_attr: str, event_name: str):
if getattr(self, completed_flag_attr):
return
setattr(self, completed_flag_attr, True)

blue_endpoints = frozenset(
host for host, role in self._roles_by_host.items()
if role == BlueGreenRole.SOURCE)

cluster_id = self._plugin_service.host_list_provider.get_cluster_id()
services_container.get_event_publisher().publish(
MonitorResetEvent(cluster_id=cluster_id, endpoints=blue_endpoints))
self._store_event_phase_time(f"Monitor reset {event_name}")

def _log_current_context(self):
logger.debug(f"[bg_id: '{self._bg_id}'] Summary status: \n{self._summary_status}")
hosts_str = "\n".join(
Expand Down Expand Up @@ -1913,6 +1938,8 @@ def _reset_context_when_completed(self):
self._green_dns_removed = False
self._green_topology_changed = False
self._all_green_hosts_changed_name = False
self._monitor_reset_on_in_progress_completed = False
self._monitor_reset_on_topology_completed = False
self._post_status_end_time_ns = 0
self._interim_status_hashes = [0, 0]
self._latest_context_hash = 0
Expand Down
11 changes: 2 additions & 9 deletions aws_advanced_python_wrapper/cleanup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,11 @@
OpenedConnectionTracker
from aws_advanced_python_wrapper.aws_credentials_manager import \
AwsCredentialsManager
from aws_advanced_python_wrapper.host_monitoring_plugin import \
MonitoringThreadContainer
from aws_advanced_python_wrapper.thread_pool_container import \
ThreadPoolContainer
from aws_advanced_python_wrapper.utils.sliding_expiration_cache_container import \
SlidingExpirationCacheContainer
from aws_advanced_python_wrapper.utils import services_container


def release_resources() -> None:
"""Release all global resources used by the wrapper."""
MonitoringThreadContainer.clean_up()
ThreadPoolContainer.release_resources()
services_container.release_resources()
AwsCredentialsManager.release_resources()
OpenedConnectionTracker.release_resources()
SlidingExpirationCacheContainer.release_resources()
Loading
Loading