Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 20 additions & 18 deletions src/datajoint/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ def __init__(
port: int | None = None,
use_tls: bool | dict | None = None,
*,
dbname: str | None = None,
backend: str | None = None,
config_override: "Config | None" = None,
) -> None:
Expand All @@ -180,7 +181,9 @@ def __init__(
port = int(port)
elif port is None:
port = self._config["database.port"]
self.conn_info = dict(host=host, port=port, user=user, passwd=password)
if dbname is None:
dbname = self._config.get("database.dbname")
self.conn_info = dict(host=host, port=port, user=user, passwd=password, dbname=dbname)
if use_tls is not False:
# use_tls can be: None (auto-detect), True (enable), False (disable), or dict (custom config)
if isinstance(use_tls, dict):
Expand Down Expand Up @@ -218,20 +221,26 @@ def __repr__(self):
connected = "connected" if self.is_connected else "disconnected"
return "DataJoint connection ({connected}) {user}@{host}:{port}".format(connected=connected, **self.conn_info)

def _build_connect_kwargs(self, use_tls=None):
"""Build kwargs dict for adapter.connect()."""
kwargs = dict(
host=self.conn_info["host"],
port=self.conn_info["port"],
user=self.conn_info["user"],
password=self.conn_info["passwd"],
charset=self._config["connection.charset"],
use_tls=use_tls if use_tls is not None else self.conn_info.get("ssl"),
)
if self.conn_info.get("dbname"):
kwargs["dbname"] = self.conn_info["dbname"]
return kwargs

def connect(self) -> None:
"""Establish or re-establish connection to the database server."""
with warnings.catch_warnings():
warnings.filterwarnings("ignore", ".*deprecated.*")
try:
# Use adapter to create connection
self._conn = self.adapter.connect(
host=self.conn_info["host"],
port=self.conn_info["port"],
user=self.conn_info["user"],
password=self.conn_info["passwd"],
charset=self._config["connection.charset"],
use_tls=self.conn_info.get("ssl"),
)
self._conn = self.adapter.connect(**self._build_connect_kwargs())
except Exception as ssl_error:
# If SSL fails, retry without SSL (if it was auto-detected)
if self.conn_info.get("ssl_input") is None:
Expand All @@ -240,14 +249,7 @@ def connect(self) -> None:
"To require SSL, set use_tls=True explicitly.",
ssl_error,
)
self._conn = self.adapter.connect(
host=self.conn_info["host"],
port=self.conn_info["port"],
user=self.conn_info["user"],
password=self.conn_info["passwd"],
charset=self._config["connection.charset"],
use_tls=False, # Explicitly disable SSL for fallback
)
self._conn = self.adapter.connect(**self._build_connect_kwargs(use_tls=False))
else:
raise
self._is_closed = False # Mark as connected after successful connection
Expand Down
5 changes: 5 additions & 0 deletions src/datajoint/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,11 @@ class DatabaseSettings(BaseSettings):
description="Database backend: 'mysql' or 'postgresql'",
)
port: int | None = Field(default=None, validation_alias="DJ_PORT")
dbname: str | None = Field(
default=None,
validation_alias="DJ_DBNAME",
description="Database name for PostgreSQL connections. Defaults to 'postgres' if not set.",
)
reconnect: bool = True
use_tls: bool | None = Field(default=None, validation_alias="DJ_USE_TLS")
database_prefix: str = Field(
Expand Down
2 changes: 1 addition & 1 deletion src/datajoint/version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# version bump auto managed by Github Actions:
# label_prs.yaml(prep), release.yaml(bump), post_release.yaml(edit)
# manually set this version will be eventually overwritten by the above actions
__version__ = "2.2.0"
__version__ = "2.2.1"
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is done automatically as part of the PyPI publishing action.

51 changes: 51 additions & 0 deletions tests/unit/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,6 +750,57 @@ def test_similar_prefix_names_allowed(self):
dj.config.stores.update(original_stores)


class TestDbnameConfiguration:
"""Test database.dbname configuration."""

def test_dbname_default_is_none(self):
"""Dbname defaults to None when not configured."""
from datajoint.settings import DatabaseSettings

s = DatabaseSettings()
assert s.dbname is None

def test_dbname_env_var(self, monkeypatch):
"""DJ_DBNAME environment variable sets dbname."""
from datajoint.settings import DatabaseSettings

monkeypatch.setenv("DJ_DBNAME", "my_database")
s = DatabaseSettings()
assert s.dbname == "my_database"

def test_dbname_from_config_file(self, tmp_path, monkeypatch):
"""Load dbname from config file."""
import json

from datajoint.settings import Config

config_file = tmp_path / "test_config.json"
config_file.write_text(json.dumps({"database": {"dbname": "custom_db", "host": "localhost"}}))

monkeypatch.delenv("DJ_DBNAME", raising=False)
monkeypatch.delenv("DJ_HOST", raising=False)

cfg = Config()
cfg.load(config_file)
assert cfg.database.dbname == "custom_db"

def test_dbname_dict_access(self):
"""Dict-style access reads and writes dbname."""
original = dj.config.database.dbname
try:
dj.config.database.dbname = "test_db"
assert dj.config["database.dbname"] == "test_db"
finally:
dj.config.database.dbname = original

def test_dbname_override_context_manager(self):
"""Override context manager temporarily sets dbname."""
original = dj.config.database.dbname
with dj.config.override(database__dbname="override_db"):
assert dj.config.database.dbname == "override_db"
assert dj.config.database.dbname == original


class TestBackendConfiguration:
"""Test database backend configuration and port auto-detection."""

Expand Down
Loading