diff --git a/changelog.md b/changelog.md index b9cf72ee..8a79c2d9 100644 --- a/changelog.md +++ b/changelog.md @@ -13,6 +13,7 @@ Bug Fixes --------- * Improve query cancellation on control-c. * Improve refresh of some format strings in the toolbar. +* Improve keyring storage, requiring re-entering most keyring passwords. Internal diff --git a/mycli/main.py b/mycli/main.py index db03bf25..9bf6ecac 100755 --- a/mycli/main.py +++ b/mycli/main.py @@ -689,17 +689,19 @@ def connect( # 5. cnf (.my.cnf / etc) # 6. keyring - keychain_identifier = f'{user}@{host}:{int_port}:{socket}' - keychain_domain = 'mycli.net' - keychain_retrieved = False + keyring_identifier = f'{user}@{host}:{"" if socket else int_port}:{socket or ""}' + keyring_domain = 'mycli.net' + keyring_retrieved_cleanly = False if passwd is None and use_keyring and not reset_keyring: - passwd = keyring.get_password(keychain_domain, keychain_identifier) - keychain_retrieved = True + passwd = keyring.get_password(keyring_domain, keyring_identifier) + if passwd is not None: + keyring_retrieved_cleanly = True # prompt for password if requested by user if passwd == "MYCLI_ASK_PASSWORD": passwd = click.prompt(f"Enter password for {user}", hide_input=True, show_default=False, default='', type=str, err=True) + keyring_retrieved_cleanly = False connection_info: dict[Any, Any] = { "database": database, @@ -720,21 +722,27 @@ def connect( "unbuffered": unbuffered, } - def _update_keyring(password: str | None): + def _update_keyring(password: str | None, keyring_retrieved_cleanly: bool): if not password: return - if reset_keyring or (use_keyring and not keychain_retrieved): + if reset_keyring or (use_keyring and not keyring_retrieved_cleanly): try: - saved_pw = keyring.get_password(keychain_domain, keychain_identifier) + saved_pw = keyring.get_password(keyring_domain, keyring_identifier) if password != saved_pw or reset_keyring: - keyring.set_password(keychain_domain, keychain_identifier, password) - click.secho('Password saved to the system keyring', err=True) + keyring.set_password(keyring_domain, keyring_identifier, password) + click.secho(f'Password saved to the system keyring at {keyring_domain}/{keyring_identifier}', err=True) except Exception as e: click.secho(f'Password not saved to the system keyring: {e}', err=True, fg='red') - def _connect(retry_ssl: bool = False, retry_password: bool = False) -> None: + def _connect( + retry_ssl: bool = False, + retry_password: bool = False, + keyring_save_eligible: bool = True, + keyring_retrieved_cleanly: bool = False, + ) -> None: try: - _update_keyring(connection_info["password"]) + if keyring_save_eligible: + _update_keyring(connection_info["password"], keyring_retrieved_cleanly=keyring_retrieved_cleanly) self.sqlexecute = SQLExecute(**connection_info) except pymysql.OperationalError as e1: if e1.args[0] == HANDSHAKE_ERROR and ssl is not None and ssl.get("mode", None) == "auto": @@ -743,7 +751,9 @@ def _connect(retry_ssl: bool = False, retry_password: bool = False) -> None: raise e1 # disable SSL and try to connect again connection_info["ssl"] = None - _connect(retry_ssl=True) + _connect( + retry_ssl=True, keyring_retrieved_cleanly=keyring_retrieved_cleanly, keyring_save_eligible=keyring_save_eligible + ) elif e1.args[0] == ACCESS_DENIED_ERROR and connection_info["password"] is None: # if we already tried and failed to connect with a new password, raise the error if retry_password: @@ -753,7 +763,12 @@ def _connect(retry_ssl: bool = False, retry_password: bool = False) -> None: f"Enter password for {user}", hide_input=True, show_default=False, default='', type=str, err=True ) connection_info["password"] = new_password - _connect(retry_password=True) + keyring_retrieved_cleanly = False + _connect( + retry_password=True, + keyring_retrieved_cleanly=keyring_retrieved_cleanly, + keyring_save_eligible=keyring_save_eligible, + ) elif e1.args[0] == CR_SERVER_LOST: self.echo( ( @@ -775,7 +790,7 @@ def _connect(retry_ssl: bool = False, retry_password: bool = False) -> None: socket_owner = '' self.echo(f"Connecting to socket {socket}, owned by user {socket_owner}", err=True) try: - _connect() + _connect(keyring_retrieved_cleanly=keyring_retrieved_cleanly) except pymysql.OperationalError as e: # These are "Can't open socket" and 2x "Can't connect" if [code for code in (2001, 2002, 2003) if code == e.args[0]]: @@ -790,12 +805,14 @@ def _connect(retry_ssl: bool = False, retry_password: bool = False) -> None: socket = "" host = "localhost" port = 3306 - _connect() + # todo should reload the keyring identifier here instead of invalidating + _connect(keyring_save_eligible=False) else: raise e else: host = host or "localhost" port = port or 3306 + # could try loading the keyring again here instead of assuming nothing important changed # Bad ports give particularly daft error messages try: @@ -804,7 +821,7 @@ def _connect(retry_ssl: bool = False, retry_password: bool = False) -> None: self.echo(f"Error: Invalid port number: '{port}'.", err=True, fg="red") sys.exit(1) - _connect() + _connect(keyring_retrieved_cleanly=keyring_retrieved_cleanly) except Exception as e: # Connecting to a database could fail. self.logger.debug("Database connection failed: %r.", e) self.logger.error("traceback: %r", traceback.format_exc())