Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions taskbadger/cli_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from taskbadger import __version__
from taskbadger.cli import create, get, list_tasks_command, run, update
from taskbadger.config import get_config, write_config
from taskbadger.sdk import _parse_token

app = typer.Typer(
rich_markup_mode="rich",
Expand All @@ -30,9 +31,18 @@ def version_callback(value: bool):
def configure(ctx: typer.Context):
"""Update CLI configuration."""
config = ctx.meta["tb_config"]
config.organization_slug = typer.prompt("Organization slug", default=config.organization_slug)
config.project_slug = typer.prompt("Project slug", default=config.project_slug)
config.token = typer.prompt("API Key", default=config.token)
token = typer.prompt("API Key", default=config.token)
parsed = _parse_token(token)
if parsed:
org_slug, project_slug, api_key = parsed
print(f"Project key detected — organization: [green]{org_slug}[/green], project: [green]{project_slug}[/green]")
config.organization_slug = org_slug
config.project_slug = project_slug
config.token = token
else:
config.organization_slug = typer.prompt("Organization slug", default=config.organization_slug)
config.project_slug = typer.prompt("Project slug", default=config.project_slug)
config.token = token
path = write_config(config)
print(f"Config written to [green]{path}[/green]")

Expand Down
18 changes: 14 additions & 4 deletions taskbadger/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import typer
from tomlkit import document, table

from taskbadger.sdk import _TB_HOST, _init
from taskbadger.sdk import _TB_HOST, _init, _parse_token

APP_NAME = "taskbadger"

Expand Down Expand Up @@ -47,10 +47,20 @@ def from_dict(config_dict, **overrides) -> "Config":
"""
defaults = config_dict.get("defaults", {})
auth = config_dict.get("auth", {})
token = overrides.get("token") or _from_env("API_KEY", auth.get("token"))
organization_slug = overrides.get("org") or _from_env("ORG", defaults.get("org"))
project_slug = overrides.get("project") or _from_env("PROJECT", defaults.get("project"))

if token:
parsed = _parse_token(token)
if parsed:
organization_slug = parsed[0]
project_slug = parsed[1]

return Config(
token=overrides.get("token") or _from_env("API_KEY", auth.get("token")),
organization_slug=overrides.get("org") or _from_env("ORG", defaults.get("org")),
project_slug=overrides.get("project") or _from_env("PROJECT", defaults.get("project")),
token=token,
organization_slug=organization_slug,
project_slug=project_slug,
host=overrides.get("host") or auth.get("host"),
tags=config_dict.get("tags", {}),
)
Expand Down
43 changes: 41 additions & 2 deletions taskbadger/sdk.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import base64
import datetime
import logging
import os
Expand Down Expand Up @@ -35,6 +36,26 @@
_TB_HOST = "https://taskbadger.net"


def _parse_token(token):
"""Try to decode a project API key.

Project keys are base64-encoded strings in the format ``org/project/key``.

Returns:
A tuple of ``(organization_slug, project_slug, api_key)`` if *token*
is a valid project key, otherwise ``None``.
"""
try:
decoded = base64.b64decode(token, validate=True).decode("utf-8")
except Exception:
return None

parts = decoded.split("/")
if len(parts) == 3 and all(parts):
return tuple(parts)
return None


def init(
organization_slug: str = None,
project_slug: str = None,
Expand All @@ -43,9 +64,16 @@ def init(
tags: dict[str, str] = None,
before_create: Callback = None,
):
"""Initialize Task Badger client
"""Initialize Task Badger client.

If *token* is a project API key (base64-encoded ``org/project/key``),
the organization and project slugs are extracted automatically and
*organization_slug* / *project_slug* are ignored.

Call this function once per thread
For legacy API keys, *organization_slug* and *project_slug* are
required and a deprecation warning is emitted.

Call this function once per thread.
"""
_init(_TB_HOST, organization_slug, project_slug, token, systems, tags, before_create)

Expand All @@ -64,6 +92,17 @@ def _init(
project_slug = project_slug or os.environ.get("TASKBADGER_PROJECT")
token = token or os.environ.get("TASKBADGER_API_KEY")

if token:
parsed = _parse_token(token)
if parsed:
organization_slug, project_slug, token = parsed
else:
warnings.warn(
"Legacy API keys are deprecated. Please switch to a project API key.",
DeprecationWarning,
stacklevel=3,
)

if before_create and isinstance(before_create, str):
try:
before_create = import_string(before_create)
Expand Down
15 changes: 15 additions & 0 deletions tests/test_celery_system_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import logging
import sys
import time
import weakref
from http import HTTPStatus
from unittest import mock
Expand All @@ -26,6 +27,18 @@
from tests.utils import task_for_test


def _wait_for_mock_calls(mock_obj, expected_count, timeout=5):
"""Wait for a mock to reach the expected call count.

Celery stores the task result before firing task_success, so
``result.get()`` can return before the success signal handler runs.
Without this wait the mock context may exit before the handler fires.
"""
deadline = time.monotonic() + timeout
while mock_obj.call_count < expected_count and time.monotonic() < deadline:
time.sleep(0.05)


@pytest.fixture()
def _bind_settings_with_system():
systems = [CelerySystemIntegration()]
Expand Down Expand Up @@ -71,6 +84,7 @@ def add_normal(self, a, b):
result = add_normal.delay(2, 2)
assert result.info.get("taskbadger_task_id") == tb_task.id
assert result.get(timeout=10, propagate=True) == 4
_wait_for_mock_calls(update, 2)

create.assert_called_once()
assert get_task.call_count == 1
Expand Down Expand Up @@ -102,6 +116,7 @@ def add_normal(self, a, b):
result = add_normal.delay(2, 2)
assert result.info.get("taskbadger_task_id") == tb_task.id
assert result.get(timeout=10, propagate=True) == 4
_wait_for_mock_calls(update, 2)

create.assert_called_once_with(
"tests.test_celery_system_integration.add_normal",
Expand Down
2 changes: 1 addition & 1 deletion tests/test_cli_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def test_info_config_env_args():


def test_configure(mock_config_location):
result = runner.invoke(app, ["configure"], input="an-org\na-project\na-token")
result = runner.invoke(app, ["configure"], input="a-token\nan-org\na-project")
assert result.exit_code == 0
assert mock_config_location.is_file()
with mock_config_location.open("rt", encoding="utf-8") as fp:
Expand Down
16 changes: 12 additions & 4 deletions tests/test_init.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import warnings

import pytest

from taskbadger import Badger, init
Expand All @@ -14,16 +16,22 @@ def _reset():


def test_init():
init("org", "project", "token", before_create=lambda x: x)
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
init("org", "project", "token", before_create=lambda x: x)


def test_init_import_before_create():
init("org", "project", "token", before_create="tests.test_init._before_create")
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
init("org", "project", "token", before_create="tests.test_init._before_create")


def test_init_import_before_create_fail():
with pytest.raises(ConfigurationError):
init("org", "project", "token", before_create="missing")
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
with pytest.raises(ConfigurationError):
init("org", "project", "token", before_create="missing")


def _before_create(_):
Expand Down
Loading