diff --git a/taskbadger/cli_main.py b/taskbadger/cli_main.py index 88d1e0b..1244a87 100644 --- a/taskbadger/cli_main.py +++ b/taskbadger/cli_main.py @@ -6,6 +6,7 @@ from taskbadger import __version__ from taskbadger.cli import create, get, list_tasks_command, run, update from taskbadger.config import get_config, write_config +from taskbadger.sdk import _parse_token app = typer.Typer( rich_markup_mode="rich", @@ -30,9 +31,18 @@ def version_callback(value: bool): def configure(ctx: typer.Context): """Update CLI configuration.""" config = ctx.meta["tb_config"] - config.organization_slug = typer.prompt("Organization slug", default=config.organization_slug) - config.project_slug = typer.prompt("Project slug", default=config.project_slug) - config.token = typer.prompt("API Key", default=config.token) + token = typer.prompt("API Key", default=config.token) + parsed = _parse_token(token) + if parsed: + org_slug, project_slug, api_key = parsed + print(f"Project key detected — organization: [green]{org_slug}[/green], project: [green]{project_slug}[/green]") + config.organization_slug = org_slug + config.project_slug = project_slug + config.token = token + else: + config.organization_slug = typer.prompt("Organization slug", default=config.organization_slug) + config.project_slug = typer.prompt("Project slug", default=config.project_slug) + config.token = token path = write_config(config) print(f"Config written to [green]{path}[/green]") diff --git a/taskbadger/config.py b/taskbadger/config.py index 2c13305..9102646 100644 --- a/taskbadger/config.py +++ b/taskbadger/config.py @@ -7,7 +7,7 @@ import typer from tomlkit import document, table -from taskbadger.sdk import _TB_HOST, _init +from taskbadger.sdk import _TB_HOST, _init, _parse_token APP_NAME = "taskbadger" @@ -47,10 +47,20 @@ def from_dict(config_dict, **overrides) -> "Config": """ defaults = config_dict.get("defaults", {}) auth = config_dict.get("auth", {}) + token = overrides.get("token") or _from_env("API_KEY", auth.get("token")) + organization_slug = overrides.get("org") or _from_env("ORG", defaults.get("org")) + project_slug = overrides.get("project") or _from_env("PROJECT", defaults.get("project")) + + if token: + parsed = _parse_token(token) + if parsed: + organization_slug = parsed[0] + project_slug = parsed[1] + return Config( - token=overrides.get("token") or _from_env("API_KEY", auth.get("token")), - organization_slug=overrides.get("org") or _from_env("ORG", defaults.get("org")), - project_slug=overrides.get("project") or _from_env("PROJECT", defaults.get("project")), + token=token, + organization_slug=organization_slug, + project_slug=project_slug, host=overrides.get("host") or auth.get("host"), tags=config_dict.get("tags", {}), ) diff --git a/taskbadger/sdk.py b/taskbadger/sdk.py index 1f52273..28ee3df 100644 --- a/taskbadger/sdk.py +++ b/taskbadger/sdk.py @@ -1,3 +1,4 @@ +import base64 import datetime import logging import os @@ -35,6 +36,26 @@ _TB_HOST = "https://taskbadger.net" +def _parse_token(token): + """Try to decode a project API key. + + Project keys are base64-encoded strings in the format ``org/project/key``. + + Returns: + A tuple of ``(organization_slug, project_slug, api_key)`` if *token* + is a valid project key, otherwise ``None``. + """ + try: + decoded = base64.b64decode(token, validate=True).decode("utf-8") + except Exception: + return None + + parts = decoded.split("/") + if len(parts) == 3 and all(parts): + return tuple(parts) + return None + + def init( organization_slug: str = None, project_slug: str = None, @@ -43,9 +64,16 @@ def init( tags: dict[str, str] = None, before_create: Callback = None, ): - """Initialize Task Badger client + """Initialize Task Badger client. + + If *token* is a project API key (base64-encoded ``org/project/key``), + the organization and project slugs are extracted automatically and + *organization_slug* / *project_slug* are ignored. - Call this function once per thread + For legacy API keys, *organization_slug* and *project_slug* are + required and a deprecation warning is emitted. + + Call this function once per thread. """ _init(_TB_HOST, organization_slug, project_slug, token, systems, tags, before_create) @@ -64,6 +92,17 @@ def _init( project_slug = project_slug or os.environ.get("TASKBADGER_PROJECT") token = token or os.environ.get("TASKBADGER_API_KEY") + if token: + parsed = _parse_token(token) + if parsed: + organization_slug, project_slug, token = parsed + else: + warnings.warn( + "Legacy API keys are deprecated. Please switch to a project API key.", + DeprecationWarning, + stacklevel=3, + ) + if before_create and isinstance(before_create, str): try: before_create = import_string(before_create) diff --git a/tests/test_celery_system_integration.py b/tests/test_celery_system_integration.py index 24a0fe4..d7d7c7d 100644 --- a/tests/test_celery_system_integration.py +++ b/tests/test_celery_system_integration.py @@ -10,6 +10,7 @@ import logging import sys +import time import weakref from http import HTTPStatus from unittest import mock @@ -26,6 +27,18 @@ from tests.utils import task_for_test +def _wait_for_mock_calls(mock_obj, expected_count, timeout=5): + """Wait for a mock to reach the expected call count. + + Celery stores the task result before firing task_success, so + ``result.get()`` can return before the success signal handler runs. + Without this wait the mock context may exit before the handler fires. + """ + deadline = time.monotonic() + timeout + while mock_obj.call_count < expected_count and time.monotonic() < deadline: + time.sleep(0.05) + + @pytest.fixture() def _bind_settings_with_system(): systems = [CelerySystemIntegration()] @@ -71,6 +84,7 @@ def add_normal(self, a, b): result = add_normal.delay(2, 2) assert result.info.get("taskbadger_task_id") == tb_task.id assert result.get(timeout=10, propagate=True) == 4 + _wait_for_mock_calls(update, 2) create.assert_called_once() assert get_task.call_count == 1 @@ -102,6 +116,7 @@ def add_normal(self, a, b): result = add_normal.delay(2, 2) assert result.info.get("taskbadger_task_id") == tb_task.id assert result.get(timeout=10, propagate=True) == 4 + _wait_for_mock_calls(update, 2) create.assert_called_once_with( "tests.test_celery_system_integration.add_normal", diff --git a/tests/test_cli_config.py b/tests/test_cli_config.py index 051aab7..180af10 100644 --- a/tests/test_cli_config.py +++ b/tests/test_cli_config.py @@ -112,7 +112,7 @@ def test_info_config_env_args(): def test_configure(mock_config_location): - result = runner.invoke(app, ["configure"], input="an-org\na-project\na-token") + result = runner.invoke(app, ["configure"], input="a-token\nan-org\na-project") assert result.exit_code == 0 assert mock_config_location.is_file() with mock_config_location.open("rt", encoding="utf-8") as fp: diff --git a/tests/test_init.py b/tests/test_init.py index 07acd31..81d57ea 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -1,3 +1,5 @@ +import warnings + import pytest from taskbadger import Badger, init @@ -14,16 +16,22 @@ def _reset(): def test_init(): - init("org", "project", "token", before_create=lambda x: x) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + init("org", "project", "token", before_create=lambda x: x) def test_init_import_before_create(): - init("org", "project", "token", before_create="tests.test_init._before_create") + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + init("org", "project", "token", before_create="tests.test_init._before_create") def test_init_import_before_create_fail(): - with pytest.raises(ConfigurationError): - init("org", "project", "token", before_create="missing") + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + with pytest.raises(ConfigurationError): + init("org", "project", "token", before_create="missing") def _before_create(_): diff --git a/tests/test_project_key.py b/tests/test_project_key.py new file mode 100644 index 0000000..aed27ce --- /dev/null +++ b/tests/test_project_key.py @@ -0,0 +1,190 @@ +import base64 +import os +from pathlib import Path +from unittest import mock + +import pytest +import tomlkit +from typer.testing import CliRunner + +from taskbadger.cli_main import app +from taskbadger.config import Config +from taskbadger.mug import Badger, _local +from taskbadger.sdk import _parse_token, init + + +def _make_project_key(org="myorg", project="myproject", key="secret123"): + return base64.b64encode(f"{org}/{project}/{key}".encode()).decode() + + +# --- _parse_token tests --- + + +class TestParseToken: + def test_valid_project_key(self): + token = _make_project_key("org1", "proj1", "apikey") + result = _parse_token(token) + assert result == ("org1", "proj1", "apikey") + + def test_legacy_key(self): + result = _parse_token("some-legacy-api-key") + assert result is None + + def test_invalid_base64(self): + result = _parse_token("!!!not-base64!!!") + assert result is None + + def test_base64_but_wrong_format_two_parts(self): + token = base64.b64encode(b"only/two").decode() + result = _parse_token(token) + assert result is None + + def test_base64_but_wrong_format_four_parts(self): + token = base64.b64encode(b"a/b/c/d").decode() + result = _parse_token(token) + assert result is None + + def test_base64_with_empty_parts(self): + token = base64.b64encode(b"org//key").decode() + result = _parse_token(token) + assert result is None + + def test_empty_string(self): + result = _parse_token("") + assert result is None + + +# --- init() tests --- + + +@pytest.fixture(autouse=True) +def _reset_badger(): + b_global = Badger.current + _local.set(Badger()) + yield + _local.set(b_global) + + +class TestInitWithProjectKey: + def test_init_with_project_key(self): + token = _make_project_key("org1", "proj1", "apikey") + init(token=token) + settings = Badger.current.settings + assert settings.organization_slug == "org1" + assert settings.project_slug == "proj1" + assert settings.token == "apikey" + + def test_init_project_key_overrides_org_project(self): + token = _make_project_key("org1", "proj1", "apikey") + init(organization_slug="ignored", project_slug="ignored", token=token) + settings = Badger.current.settings + assert settings.organization_slug == "org1" + assert settings.project_slug == "proj1" + assert settings.token == "apikey" + + def test_init_project_key_via_env(self): + token = _make_project_key("org1", "proj1", "apikey") + with mock.patch.dict(os.environ, {"TASKBADGER_API_KEY": token}): + init() + settings = Badger.current.settings + assert settings.organization_slug == "org1" + assert settings.project_slug == "proj1" + assert settings.token == "apikey" + + def test_init_project_key_no_deprecation_warning(self, recwarn): + token = _make_project_key() + init(token=token) + deprecation_warnings = [w for w in recwarn if issubclass(w.category, DeprecationWarning)] + assert len(deprecation_warnings) == 0 + + def test_init_legacy_key_emits_deprecation_warning(self): + with pytest.warns(DeprecationWarning, match="Legacy API keys are deprecated"): + init("org", "project", "legacy-token") + + def test_init_legacy_key_still_works(self): + with pytest.warns(DeprecationWarning): + init("org", "project", "legacy-token") + settings = Badger.current.settings + assert settings.organization_slug == "org" + assert settings.project_slug == "project" + assert settings.token == "legacy-token" + + +# --- Config.from_dict tests --- + + +class TestConfigFromDictWithProjectKey: + def test_project_key_in_config(self): + token = _make_project_key("org1", "proj1", "apikey") + config = Config.from_dict({"auth": {"token": token}}) + assert config.organization_slug == "org1" + assert config.project_slug == "proj1" + # Token remains as original base64 string (decoded by _init at init time) + assert config.token == token + + def test_project_key_overrides_config_org_project(self): + token = _make_project_key("org1", "proj1", "apikey") + config = Config.from_dict( + { + "auth": {"token": token}, + "defaults": {"org": "old-org", "project": "old-project"}, + } + ) + assert config.organization_slug == "org1" + assert config.project_slug == "proj1" + + def test_project_key_via_env(self): + token = _make_project_key("org1", "proj1", "apikey") + with mock.patch.dict(os.environ, {"TASKBADGER_API_KEY": token}): + config = Config.from_dict({}) + assert config.organization_slug == "org1" + assert config.project_slug == "proj1" + assert config.token == token + + def test_project_key_is_valid(self): + token = _make_project_key("org1", "proj1", "apikey") + config = Config.from_dict({"auth": {"token": token}}) + assert config.is_valid() + + +# --- CLI configure tests --- + +runner = CliRunner() + + +@pytest.fixture() +def mock_config_location(): + config_path = Path(__file__).parent / "_mock_config_project_key" + with mock.patch("taskbadger.config._get_config_path", return_value=config_path): + yield config_path + if config_path.exists(): + os.remove(config_path) + + +class TestCLIConfigureProjectKey: + def test_configure_with_project_key(self, mock_config_location): + token = _make_project_key("myorg", "myproj", "mykey") + result = runner.invoke(app, ["configure"], input=f"{token}\n") + assert result.exit_code == 0 + assert "Project key detected" in result.stdout + assert "myorg" in result.stdout + assert "myproj" in result.stdout + + with mock_config_location.open("rt", encoding="utf-8") as fp: + raw_config = tomlkit.load(fp) + config_dict = raw_config.unwrap() + assert config_dict["defaults"]["org"] == "myorg" + assert config_dict["defaults"]["project"] == "myproj" + assert config_dict["auth"]["token"] == token + + def test_configure_with_legacy_key(self, mock_config_location): + result = runner.invoke(app, ["configure"], input="a-token\nan-org\na-project\n") + assert result.exit_code == 0 + assert "Project key detected" not in result.stdout + + with mock_config_location.open("rt", encoding="utf-8") as fp: + raw_config = tomlkit.load(fp) + config_dict = raw_config.unwrap() + assert config_dict["defaults"]["org"] == "an-org" + assert config_dict["defaults"]["project"] == "a-project" + assert config_dict["auth"]["token"] == "a-token" diff --git a/tests/test_scope.py b/tests/test_scope.py index 0dc93c5..ba0b0be 100644 --- a/tests/test_scope.py +++ b/tests/test_scope.py @@ -1,3 +1,5 @@ +import warnings + import pytest from taskbadger import create_task, init @@ -46,7 +48,9 @@ def test_scope_context(): @pytest.fixture(autouse=True) def _init_skd(): - init("org", "project", "token") + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + init("org", "project", "token") def test_create_task_with_scope(httpx_mock): diff --git a/tests/test_sdk.py b/tests/test_sdk.py index 225f836..89fb22c 100644 --- a/tests/test_sdk.py +++ b/tests/test_sdk.py @@ -1,4 +1,5 @@ import datetime +import warnings from http import HTTPStatus from unittest import mock @@ -18,7 +19,9 @@ @pytest.fixture(autouse=True) def _init_skd(): - init("org", "project", "token") + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + init("org", "project", "token") @pytest.fixture() diff --git a/tests/test_sdk_primatives.py b/tests/test_sdk_primatives.py index 1de0a99..88e5b91 100644 --- a/tests/test_sdk_primatives.py +++ b/tests/test_sdk_primatives.py @@ -1,3 +1,5 @@ +import warnings + import pytest from taskbadger import Action, EmailIntegration, StatusEnum, update_task @@ -7,7 +9,9 @@ @pytest.fixture(autouse=True) def _init_skd(): - init("org", "project", "token") + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + init("org", "project", "token") def test_get_task(httpx_mock):