Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Bug Fixes
---------
* Improve query cancellation on control-c.
* Improve refresh of some format strings in the toolbar.
* Improve keyring storage, requiring re-entering most keyring passwords.


Internal
Expand Down
51 changes: 34 additions & 17 deletions mycli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,17 +689,19 @@ def connect(
# 5. cnf (.my.cnf / etc)
# 6. keyring

keychain_identifier = f'{user}@{host}:{int_port}:{socket}'
keychain_domain = 'mycli.net'
keychain_retrieved = False
keyring_identifier = f'{user}@{host}:{"" if socket else int_port}:{socket or ""}'
keyring_domain = 'mycli.net'
keyring_retrieved_cleanly = False

if passwd is None and use_keyring and not reset_keyring:
passwd = keyring.get_password(keychain_domain, keychain_identifier)
keychain_retrieved = True
passwd = keyring.get_password(keyring_domain, keyring_identifier)
if passwd is not None:
keyring_retrieved_cleanly = True

# prompt for password if requested by user
if passwd == "MYCLI_ASK_PASSWORD":
passwd = click.prompt(f"Enter password for {user}", hide_input=True, show_default=False, default='', type=str, err=True)
keyring_retrieved_cleanly = False

connection_info: dict[Any, Any] = {
"database": database,
Expand All @@ -720,21 +722,27 @@ def connect(
"unbuffered": unbuffered,
}

def _update_keyring(password: str | None):
def _update_keyring(password: str | None, keyring_retrieved_cleanly: bool):
if not password:
return
if reset_keyring or (use_keyring and not keychain_retrieved):
if reset_keyring or (use_keyring and not keyring_retrieved_cleanly):
try:
saved_pw = keyring.get_password(keychain_domain, keychain_identifier)
saved_pw = keyring.get_password(keyring_domain, keyring_identifier)
if password != saved_pw or reset_keyring:
keyring.set_password(keychain_domain, keychain_identifier, password)
click.secho('Password saved to the system keyring', err=True)
keyring.set_password(keyring_domain, keyring_identifier, password)
click.secho(f'Password saved to the system keyring at {keyring_domain}/{keyring_identifier}', err=True)
except Exception as e:
click.secho(f'Password not saved to the system keyring: {e}', err=True, fg='red')

def _connect(retry_ssl: bool = False, retry_password: bool = False) -> None:
def _connect(
retry_ssl: bool = False,
retry_password: bool = False,
keyring_save_eligible: bool = True,
keyring_retrieved_cleanly: bool = False,
) -> None:
try:
_update_keyring(connection_info["password"])
if keyring_save_eligible:
_update_keyring(connection_info["password"], keyring_retrieved_cleanly=keyring_retrieved_cleanly)
self.sqlexecute = SQLExecute(**connection_info)
except pymysql.OperationalError as e1:
if e1.args[0] == HANDSHAKE_ERROR and ssl is not None and ssl.get("mode", None) == "auto":
Expand All @@ -743,7 +751,9 @@ def _connect(retry_ssl: bool = False, retry_password: bool = False) -> None:
raise e1
# disable SSL and try to connect again
connection_info["ssl"] = None
_connect(retry_ssl=True)
_connect(
retry_ssl=True, keyring_retrieved_cleanly=keyring_retrieved_cleanly, keyring_save_eligible=keyring_save_eligible
)
elif e1.args[0] == ACCESS_DENIED_ERROR and connection_info["password"] is None:
# if we already tried and failed to connect with a new password, raise the error
if retry_password:
Expand All @@ -753,7 +763,12 @@ def _connect(retry_ssl: bool = False, retry_password: bool = False) -> None:
f"Enter password for {user}", hide_input=True, show_default=False, default='', type=str, err=True
)
connection_info["password"] = new_password
_connect(retry_password=True)
keyring_retrieved_cleanly = False
_connect(
retry_password=True,
keyring_retrieved_cleanly=keyring_retrieved_cleanly,
keyring_save_eligible=keyring_save_eligible,
)
elif e1.args[0] == CR_SERVER_LOST:
self.echo(
(
Expand All @@ -775,7 +790,7 @@ def _connect(retry_ssl: bool = False, retry_password: bool = False) -> None:
socket_owner = '<unknown>'
self.echo(f"Connecting to socket {socket}, owned by user {socket_owner}", err=True)
try:
_connect()
_connect(keyring_retrieved_cleanly=keyring_retrieved_cleanly)
except pymysql.OperationalError as e:
# These are "Can't open socket" and 2x "Can't connect"
if [code for code in (2001, 2002, 2003) if code == e.args[0]]:
Expand All @@ -790,12 +805,14 @@ def _connect(retry_ssl: bool = False, retry_password: bool = False) -> None:
socket = ""
host = "localhost"
port = 3306
_connect()
# todo should reload the keyring identifier here instead of invalidating
_connect(keyring_save_eligible=False)
else:
raise e
else:
host = host or "localhost"
port = port or 3306
# could try loading the keyring again here instead of assuming nothing important changed

# Bad ports give particularly daft error messages
try:
Expand All @@ -804,7 +821,7 @@ def _connect(retry_ssl: bool = False, retry_password: bool = False) -> None:
self.echo(f"Error: Invalid port number: '{port}'.", err=True, fg="red")
sys.exit(1)

_connect()
_connect(keyring_retrieved_cleanly=keyring_retrieved_cleanly)
except Exception as e: # Connecting to a database could fail.
self.logger.debug("Database connection failed: %r.", e)
self.logger.error("traceback: %r", traceback.format_exc())
Expand Down