diff --git a/clarifai/__init__.py b/clarifai/__init__.py index 80d60db3..12d7ed4b 100644 --- a/clarifai/__init__.py +++ b/clarifai/__init__.py @@ -1 +1 @@ -__version__ = "12.2.1" +__version__ = "12.2.2" diff --git a/clarifai/cli/artifact.py b/clarifai/cli/artifact.py index 6c34787a..526cecd3 100644 --- a/clarifai/cli/artifact.py +++ b/clarifai/cli/artifact.py @@ -218,7 +218,7 @@ def _download_artifact( context_settings={'max_content_width': shutil.get_terminal_size().columns - 10}, ) def artifact(): - """Manage Artifacts: create, upload, download, list, get, delete""" + """Manage artifacts and files.""" @artifact.command(['list', 'ls']) diff --git a/clarifai/cli/base.py b/clarifai/cli/base.py index 98b04e88..bfe2729f 100644 --- a/clarifai/cli/base.py +++ b/clarifai/cli/base.py @@ -6,6 +6,7 @@ import yaml from clarifai import __version__ +from clarifai.errors import UserError from clarifai.utils.cli import ( LazyAliasedGroup, TableFormatter, @@ -18,11 +19,11 @@ @click.group(cls=LazyAliasedGroup) @click.version_option(version=__version__) -@click.option('--config', default=DEFAULT_CONFIG, help='Path to config file') +@click.option('--config', default=DEFAULT_CONFIG, help='Path to config file.') @click.option('--context', default=None, help='Context to use for this command') @click.pass_context def cli(ctx, config, context): - """Clarifai CLI""" + """Build, deploy, and manage AI models on the Clarifai platform.""" ctx.ensure_object(dict) if os.path.exists(config): cfg = Config.from_yaml(filename=config) @@ -55,121 +56,190 @@ def cli(ctx, config, context): ctx.obj.context_override = context -@cli.command() +@cli.command(short_help='Generate shell completion script.') @click.argument('shell', type=click.Choice(['bash', 'zsh'])) def shell_completion(shell): - """Shell completion script""" + """Generate shell completion script for bash or zsh.""" os.system(f"_CLARIFAI_COMPLETE={shell}_source clarifai") -@cli.group(cls=LazyAliasedGroup) +@cli.group(cls=LazyAliasedGroup, short_help='Manage configuration profiles (contexts).') def config(): - """ - Manage multiple configuration profiles (contexts). + """Manage configuration profiles (contexts). - Authentication Precedence:\n - 1. Environment variables (e.g., `CLARIFAI_PAT`) are used first if set. - 2. The settings from the active context are used if no environment variables are provided.\n + \b + Authentication Precedence: + 1. Environment variables (e.g., CLARIFAI_PAT) are used first if set. + 2. The active context settings are used as fallback. """ -@cli.command() -@click.argument('api_url', default=DEFAULT_BASE) -@click.option('--user_id', required=False, help='User ID') -@click.pass_context -def login(ctx, api_url, user_id): - """Login command to set PAT and other configurations.""" - from clarifai.utils.cli import validate_context_auth +def _get_pat_interactive(): + """Resolve PAT from --pat flag, env var, or interactive prompt. Returns (pat, source).""" + env_pat = os.environ.get('CLARIFAI_PAT') + if env_pat: + click.secho('Using PAT from CLARIFAI_PAT environment variable.', fg='cyan') + return env_pat - # Input user_id if not supplied - if not user_id: - user_id = click.prompt('Enter your Clarifai user ID', type=str) + # Best-effort browser open to PAT page + pat_url = 'https://clarifai.com/me/settings/secrets' + try: + import webbrowser - click.echo() # Blank line for readability + click.secho( + 'Opening browser to create/copy your Personal Access Token (PAT)...', fg='cyan' + ) + click.echo(f' {click.style(pat_url, underline=True)}') + webbrowser.open(pat_url) + except Exception: + click.echo(f'Get your PAT at: {click.style(pat_url, fg="cyan", underline=True)}') + return masked_input('Enter your PAT: ') - # Check for environment variable first - env_pat = os.environ.get('CLARIFAI_PAT') - if env_pat: - use_env = click.confirm('Use CLARIFAI_PAT from environment?', default=True) - if use_env: - pat = env_pat - else: - click.echo(f'> Create a PAT at: https://clarifai.com/{user_id}/settings/secrets') - pat = masked_input('Enter your Personal Access Token (PAT): ') - else: - click.echo('> To authenticate, you\'ll need a Personal Access Token (PAT).') - click.echo(f'> Create one at: https://clarifai.com/{user_id}/settings/secrets') - click.echo('> Tip: Set CLARIFAI_PAT environment variable to skip this prompt.\n') - pat = masked_input('Enter your Personal Access Token (PAT): ') - # Progress indicator - click.echo('\n> Verifying token...') - validate_context_auth(pat, user_id, api_url) +def _verify_and_resolve_user(pat, api_url): + """Validate PAT and return the personal user_id via GET /v2/users/me.""" + try: + from clarifai.client.user import User - # Save context with default name - context_name = 'default' - context = Context( - context_name, - CLARIFAI_API_BASE=api_url, - CLARIFAI_USER_ID=user_id, - CLARIFAI_PAT=pat, - ) + user = User(user_id='me', pat=pat, base_url=api_url) + response = user.get_user_info(user_id='me') + return response.user.id + except Exception as e: + click.secho(f'Authentication failed: {e}', fg='red', err=True) + click.echo( + f'Create a PAT at: {click.style("https://clarifai.com/me/settings/secrets", fg="cyan", underline=True)}', + err=True, + ) + raise click.Abort() - ctx.obj.contexts[context_name] = context - ctx.obj.current_context = context_name - ctx.obj.to_yaml() - click.secho(f'✅ Success! You\'re logged in as {user_id}', fg='green') - click.echo('💡 Tip: Use `clarifai config` to manage multiple accounts or environments') - logger.info(f"Login successful for user '{user_id}' in context '{context_name}'") +def _list_user_orgs(pat, user_id, api_url): + """Return list of (org_id, org_name) tuples for the user. Returns [] on failure.""" + try: + from clarifai.client.user import User + user = User(user_id=user_id, pat=pat, base_url=api_url) + orgs = user.list_organizations() + return [(org['id'], org['name']) for org in orgs if org.get('id')] + except Exception: + return [] -@cli.command() -@click.pass_context -def whoami(ctx): - """Display information about the current user.""" - from clarifai_grpc.grpc.api.status import status_code_pb2 - from clarifai.client.user import User +def _prompt_user_or_org(personal_user_id, orgs): + """Interactive org selection. Returns selected user_id.""" + click.echo() + choices = [(personal_user_id, '(personal)')] + for org_id, org_name in orgs: + label = f'({org_name})' if org_name and org_name != org_id else '(org)' + choices.append((org_id, label)) + + for i, (uid, label) in enumerate(choices, 1): + num = click.style(f'[{i}]', fg='yellow', bold=True) + name = click.style(uid, bold=True) + tag = click.style(label, dim=True) + click.echo(f' {num} {name} {tag}') + click.echo() - # Get the current context - cfg = ctx.obj - current_ctx = cfg.contexts[cfg.current_context] + selection = click.prompt('Select user_id', default='1', type=str).strip() - # Get user_id from context - context_user_id = current_ctx.CLARIFAI_USER_ID - pat = current_ctx.CLARIFAI_PAT - base_url = current_ctx.CLARIFAI_API_BASE + # Accept number or exact user_id string + if selection.isdigit() and 1 <= int(selection) <= len(choices): + return choices[int(selection) - 1][0] + for uid, _ in choices: + if selection == uid: + return uid + return personal_user_id - # Display context user info - click.echo("Context User ID: " + click.style(context_user_id, fg='cyan', bold=True)) - # Call GetUser RPC with "me" to get the actual authenticated user - try: - user_client = User(user_id="me", pat=pat, base_url=base_url) - response = user_client.get_user_info(user_id="me") +def _env_prefix(api_url): + """'https://api-dev.clarifai.com' -> 'dev', 'https://api-staging...' -> 'staging'.""" + from urllib.parse import urlparse - if response.status.code == status_code_pb2.SUCCESS: - actual_user_id = response.user.id - click.echo( - "Authenticated User ID: " + click.style(actual_user_id, fg='green', bold=True) - ) + host = urlparse(api_url).hostname or '' + if 'api-dev' in host: + return 'dev' + elif 'api-staging' in host: + return 'staging' + return 'prod' - # Check if they differ - if context_user_id != actual_user_id: - click.echo() - click.secho( - "��️ Warning: The context user ID differs from the authenticated user ID!", - fg='yellow', - ) - click.echo( - "This means you as the caller will be calling different user or organization." - ) + +@cli.command(short_help='Authenticate and save credentials.') +@click.argument('api_url', default=DEFAULT_BASE) +@click.option('--pat', required=False, help='Personal Access Token (skips interactive prompt).') +@click.option( + '--user-id', required=False, help='User or org ID. Auto-detected from PAT if omitted.' +) +@click.option( + '--context', + 'context_name', + required=False, + help='Context name. Defaults to the selected user_id.', +) +@click.pass_context +def login(ctx, api_url, pat, user_id, context_name): + """Authenticate and save credentials. + + \b + Verifies your PAT, detects your user_id, and saves a named + context to ~/.config/clarifai/config. + + \b + API_URL Clarifai API base URL (default: https://api.clarifai.com). + + \b + Examples: + clarifai login # interactive + clarifai login --pat $MY_PAT # non-interactive + clarifai login --pat $PAT --user-id openai # org user, skip selection + clarifai login https://api-dev.clarifai.com # dev environment + clarifai login --context my-context # custom context name + """ + # 1. Get PAT: --pat flag > CLARIFAI_PAT env var > browser + interactive prompt + if not pat: + pat = _get_pat_interactive() + + # 2. Validate PAT + resolve personal user_id + click.secho('Verifying...', dim=True) + personal_user_id = _verify_and_resolve_user(pat, api_url) + + # 3. Resolve which user_id to use + if user_id: + selected_user_id = user_id + else: + orgs = _list_user_orgs(pat, personal_user_id, api_url) + if orgs: + selected_user_id = _prompt_user_or_org(personal_user_id, orgs) else: - click.secho(f"Error getting user info: {response.status.description}", fg='red') + selected_user_id = personal_user_id - except Exception as e: - click.secho(f"Error: Could not retrieve authenticated user info: {str(e)}", fg='red') + # 4. Derive context name + if not context_name: + if api_url != DEFAULT_BASE: + prefix = _env_prefix(api_url) + context_name = f"{prefix}-{selected_user_id}" + else: + context_name = selected_user_id + + # 5. Save (update if exists, create if new) + action = "Updated" if context_name in ctx.obj.contexts else "Created" + context = Context( + context_name, + CLARIFAI_API_BASE=api_url, + CLARIFAI_USER_ID=selected_user_id, + CLARIFAI_PAT=pat, + ) + ctx.obj.contexts[context_name] = context + ctx.obj.current_context = context_name + ctx.obj.to_yaml() + + # 6. Output + click.echo() + click.echo( + click.style('Logged in as ', fg='green') + + click.style(selected_user_id, fg='green', bold=True) + + click.style(f' ({api_url})', dim=True) + ) + click.echo(f'{action} context {click.style(context_name, bold=True)} and set as active.') def _warn_env_pat(): @@ -245,45 +315,38 @@ def _logout_all_contexts(cfg): @cli.command() -@click.option( - '--current', - 'flag_current', - is_flag=True, - default=False, - help='Clear credentials from the current context (non-interactive).', -) @click.option( '--all', 'flag_all', is_flag=True, default=False, - help='Clear credentials from all contexts (non-interactive).', + help='Clear credentials from all contexts.', ) @click.option( '--context', 'flag_context', default=None, type=str, - help='Clear credentials from a specific named context (non-interactive).', + help='Clear credentials from a specific named context.', ) @click.option( '--delete', 'flag_delete', is_flag=True, default=False, - help='Also delete the context entry (use with --current or --context).', + help='Also delete the context entry (use with --context).', ) @click.pass_context -def logout(ctx, flag_current, flag_all, flag_context, flag_delete): +def logout(ctx, flag_all, flag_context, flag_delete): """Log out by clearing saved credentials. - Without flags, an interactive menu is shown. Use flags for - programmatic / non-interactive usage. + \b + By default, clears credentials from the current context. + Use flags to target a specific or all contexts. \b Examples: - clarifai logout # Interactive - clarifai logout --current # Clear current context PAT + clarifai logout # Log out of current context clarifai logout --context staging # Clear 'staging' PAT clarifai logout --context staging --delete # Remove 'staging' entirely clarifai logout --all # Clear every context PAT @@ -292,97 +355,118 @@ def logout(ctx, flag_current, flag_all, flag_context, flag_delete): if not cfg or not hasattr(cfg, 'contexts'): raise click.ClickException("Not logged in. Run `clarifai login` first.") - # --- Validation for flag combinations --- - if flag_all and (flag_current or flag_context or flag_delete): - raise click.UsageError("--all cannot be combined with --current, --context, or --delete.") - - if flag_delete and not (flag_current or flag_context): - raise click.UsageError("--delete requires --current or --context.") - - if flag_current and flag_context: - raise click.UsageError("Cannot use --current and --context together.") + if flag_all and (flag_context or flag_delete): + raise click.UsageError("--all cannot be combined with --context or --delete.") - # --- Non-interactive paths --- if flag_all: _logout_all_contexts(cfg) - _warn_env_pat() - return - - if flag_current: + elif flag_context: + _logout_one_context(cfg, flag_context, delete=flag_delete) + else: + # Default: log out of current context _logout_one_context(cfg, cfg.current_context, delete=flag_delete) - _warn_env_pat() - return - if flag_context: - _logout_one_context(cfg, flag_context, delete=flag_delete) - _warn_env_pat() - return + _warn_env_pat() + click.echo("Run 'clarifai login' to re-authenticate.") - # --- Interactive flow --- - cur_name = cfg.current_context - cur_ctx = cfg.contexts.get(cur_name) - if not cur_ctx: - raise click.ClickException("No active context found. Run `clarifai login` first.") - user_id = cur_ctx['env'].get('CLARIFAI_USER_ID', 'unknown') - api_base = cur_ctx['env'].get('CLARIFAI_API_BASE', DEFAULT_BASE) - click.echo( - f"\nCurrent context is configured for user '{user_id}' (context: '{cur_name}', api: {api_base})\n" - ) +def pat_display(pat): + return pat[:5] + "****" - # Build menu - choices = [] - choices.append(('switch', 'Switch to another context')) - choices.append(('logout_current', 'Log out of current context (clear credentials)')) - choices.append(('logout_delete', 'Log out and delete current context')) - choices.append(('logout_all', 'Log out of all contexts')) - choices.append(('cancel', 'Cancel')) - for i, (_, label) in enumerate(choices, 1): - click.echo(f" {i}. {label}") +@cli.command(short_help='Show current user and context.') +@click.option('--orgs', is_flag=True, help='Show organizations you belong to.') +@click.option('--all', 'show_all', is_flag=True, help='Show full profile (email, name, orgs).') +@click.option( + '-o', + '--output-format', + default='wide', + type=click.Choice(['wide', 'json']), + help='Output format.', +) +@click.pass_context +def whoami(ctx, orgs, show_all, output_format): + """Show current user and context. - click.echo() - choice_num = click.prompt('Enter choice', type=click.IntRange(1, len(choices))) - action = choices[choice_num - 1][0] - - if action == 'cancel': - click.echo('Cancelled. No changes made.') - return - - if action == 'switch': - other_contexts = [n for n in cfg.contexts if n != cur_name] - if not other_contexts: - click.echo("No other contexts available. Use `clarifai login` to create one.") - return - click.echo('\nAvailable contexts:') - for i, name in enumerate(other_contexts, 1): - uid = cfg.contexts[name]['env'].get('CLARIFAI_USER_ID', 'unknown') - click.echo(f" {i}. {name} (user: {uid})") - click.echo() - idx = click.prompt('Switch to', type=click.IntRange(1, len(other_contexts))) - target_name = other_contexts[idx - 1] - cfg.current_context = target_name - cfg.to_yaml() - click.secho( - f"Switched to context '{target_name}'. No credentials were cleared.", fg='green' - ) - return + \b + Examples: + clarifai whoami # user + context (local only) + clarifai whoami --orgs # include organizations (API call) + clarifai whoami --all # full profile + orgs (API call) + clarifai whoami -o json # JSON output for scripting + """ + cfg = ctx.obj + context = cfg.current - if action == 'logout_current': - _logout_one_context(cfg, cur_name) + # Check if logged in + pat = context.get('pat') + user_id = context.get('user_id') + api_base = context.get('api_base', DEFAULT_BASE) - elif action == 'logout_delete': - _logout_one_context(cfg, cur_name, delete=True) + if not pat or not user_id or user_id == '_empty_': + click.secho("Not logged in. Run 'clarifai login' to authenticate.", fg='red', err=True) + raise SystemExit(1) - elif action == 'logout_all': - _logout_all_contexts(cfg) + data = { + 'user_id': user_id, + 'context': cfg.current_context, + 'api_base': api_base, + } - _warn_env_pat() - click.echo("\nRun 'clarifai login' to re-authenticate.") + # Fetch full profile and/or orgs if requested + org_list = [] + if show_all or orgs: + try: + org_list = _list_user_orgs(pat, user_id, api_base) + data['organizations'] = [{'id': oid, 'name': oname} for oid, oname in org_list] + except Exception: + org_list = [] + if show_all: + try: + from clarifai.client.user import User + + user = User(user_id='me', pat=pat, base_url=api_base) + response = user.get_user_info(user_id=user_id) + u = response.user + if u.full_name: + data['name'] = u.full_name + if u.primary_email: + data['email'] = u.primary_email + if u.company_name: + data['company'] = u.company_name + except Exception: + if output_format == 'wide': + click.secho( + 'Warning: could not fetch full profile from API.', fg='yellow', err=True + ) -def pat_display(pat): - return pat[:5] + "****" + # Output + if output_format == 'json': + click.echo(json.dumps(data)) + else: + click.echo( + click.style('User: ', bold=True) + click.style(user_id, fg='green', bold=True) + ) + if data.get('name'): + click.echo(click.style('Name: ', bold=True) + data['name']) + if data.get('email'): + click.echo(click.style('Email: ', bold=True) + data['email']) + if data.get('company'): + click.echo(click.style('Company: ', bold=True) + data['company']) + click.echo( + click.style('Context: ', bold=True) + + f'{cfg.current_context} @ ' + + click.style(api_base, dim=True) + ) + + if org_list: + click.echo() + click.echo(click.style('Organizations:', bold=True)) + for org_id, org_name in org_list: + click.echo( + f' {click.style(org_id, bold=True)} {click.style(org_name, dim=True)}' + ) def input_or_default(prompt, default): @@ -420,10 +504,14 @@ def get_contexts(ctx, output_format): """List all available contexts.""" if output_format == 'wide': columns = { - '': lambda c: '*' if c.name == ctx.obj.current_context else '', - 'NAME': lambda c: c.name, + '': lambda c: click.style('*', fg='green', bold=True) + if c.name == ctx.obj.current_context + else '', + 'NAME': lambda c: click.style(c.name, bold=True) + if c.name == ctx.obj.current_context + else c.name, 'USER_ID': lambda c: c.user_id, - 'API_BASE': lambda c: c.api_base, + 'API_BASE': lambda c: click.style(c.api_base, dim=True), 'PAT': lambda c: pat_display(c.pat), } additional_columns = set() @@ -461,10 +549,12 @@ def get_contexts(ctx, output_format): def use_context(ctx, name): """Set the current context.""" if name not in ctx.obj.contexts: - raise click.UsageError('Context not found') + raise click.UsageError(f'Context "{name}" not found') ctx.obj.current_context = name ctx.obj.to_yaml() - print(f'Set {name} as the current context') + click.echo( + click.style('Switched to context ', fg='green') + click.style(name, fg='green', bold=True) + ) @config.command(aliases=['current-context', 'current']) @@ -482,46 +572,48 @@ def current_context(ctx, output_format): @config.command(aliases=['create-context', 'create']) @click.argument('name') -@click.option('--user-id', required=False, help='User ID') -@click.option('--base-url', required=False, help='Base URL') -@click.option('--pat', required=False, help='Personal access token') +@click.option( + '--user-id', required=False, help='User or org ID. Auto-detected from PAT if omitted.' +) +@click.option('--base-url', required=False, default=DEFAULT_BASE, help='API base URL.') +@click.option('--pat', required=False, help='Personal Access Token.') @click.pass_context -def create_context( - ctx, - name, - user_id=None, - base_url=None, - pat=None, -): - """Create a new context.""" - from clarifai.utils.cli import validate_context_auth - +def create_context(ctx, name, user_id, base_url, pat): + """Create a new named context.""" if name in ctx.obj.contexts: - click.secho(f'Error: Context "{name}" already exists', fg='red', err=True) - sys.exit(1) - if not user_id: - user_id = input('user id: ') - if not base_url: - base_url = input_or_default( - 'base url (default: https://api.clarifai.com): ', 'https://api.clarifai.com' + click.secho( + f'Context "{name}" already exists. Use "clarifai login" to update it.', + fg='red', + err=True, ) + raise SystemExit(1) + + # Same PAT resolution as login: flag > env > browser + prompt if not pat: - # Check for environment variable first - env_pat = os.environ.get('CLARIFAI_PAT') - if env_pat: - use_env = click.confirm('Found CLARIFAI_PAT in environment. Use it?', default=True) - if use_env: - pat = env_pat - else: - pat = masked_input('Enter your Personal Access Token (PAT): ') + pat = _get_pat_interactive() + + # Same user_id resolution: flag > auto-detect + org selection + if not user_id: + click.secho('Verifying...', dim=True) + personal_user_id = _verify_and_resolve_user(pat, base_url) + orgs = _list_user_orgs(pat, personal_user_id, base_url) + if orgs: + user_id = _prompt_user_or_org(personal_user_id, orgs) else: - click.echo('Tip: Set CLARIFAI_PAT environment variable to skip this step.') - pat = masked_input('Enter your Personal Access Token (PAT): ') - validate_context_auth(pat, user_id, base_url) + user_id = personal_user_id + else: + click.secho('Verifying...', dim=True) + _verify_and_resolve_user(pat, base_url) + context = Context(name, CLARIFAI_USER_ID=user_id, CLARIFAI_API_BASE=base_url, CLARIFAI_PAT=pat) - ctx.obj.contexts[context.name] = context + ctx.obj.contexts[name] = context ctx.obj.to_yaml() - click.secho(f"✅ Context '{name}' created successfully", fg='green') + click.echo( + click.style('Context ', fg='green') + + click.style(name, fg='green', bold=True) + + click.style(' created ', fg='green') + + click.style(f'({user_id} @ {base_url})', dim=True) + ) @config.command(aliases=['e']) @@ -540,11 +632,13 @@ def edit( def delete_context(ctx, name): """Delete a context.""" if name not in ctx.obj.contexts: - print(f'{name} is not a valid context') + click.secho(f'Context "{name}" not found.', fg='red', err=True) sys.exit(1) ctx.obj.contexts.pop(name) ctx.obj.to_yaml() - print(f'{name} deleted') + click.echo( + click.style('Deleted context ', fg='yellow') + click.style(name, fg='yellow', bold=True) + ) @config.command(aliases=['get-env']) @@ -572,12 +666,12 @@ def view(ctx, output_format): print(yaml.safe_dump(config_dict, default_flow_style=False)) -@cli.command() +@cli.command(short_help='Run a script with context env vars.') @click.argument('script', type=str) @click.option('--context', type=str, help='Context to use') @click.pass_context def run(ctx, script, context=None): - """Execute a script with the current context's environment""" + """Run a script with the current context's environment variables injected.""" # Get the effective context - either from --context flag or current context if context: context_obj = validate_and_get_context(ctx.obj, context) @@ -593,6 +687,25 @@ def run(ctx, script, context=None): # Import the CLI commands to register them # load_command_modules() - Now handled lazily by LazyLazyAliasedGroupp +# Define section ordering for `clarifai --help` +cli.command_sections = [ + ('Auth', ['login', 'whoami']), + ('Config', ['config']), + ('Models', ['model']), + ('Pipelines', ['pipeline', 'pipeline-step', 'pipelinerun', 'pipelinetemplate']), + ('Compute', ['list-instances', 'computecluster', 'nodepool', 'deployment']), + ('Other', ['artifact', 'run', 'shell-completion']), +] + def main(): - cli() + try: + cli(standalone_mode=False) + except click.exceptions.Exit: + pass # Normal exit (e.g. --help, --version) + except click.ClickException as e: + e.show() + sys.exit(e.exit_code) + except UserError as e: + click.echo(click.style(f"\nError: {e}", fg="red"), err=True) + sys.exit(1) diff --git a/clarifai/cli/compute_cluster.py b/clarifai/cli/compute_cluster.py index 064a77bb..44261a83 100644 --- a/clarifai/cli/compute_cluster.py +++ b/clarifai/cli/compute_cluster.py @@ -12,7 +12,7 @@ context_settings={'max_content_width': shutil.get_terminal_size().columns - 10}, ) def computecluster(): - """Manage Compute Clusters: create, delete, list""" + """Manage compute clusters.""" @computecluster.command(['c']) diff --git a/clarifai/cli/deployment.py b/clarifai/cli/deployment.py index 92a345b9..6bce7dc2 100644 --- a/clarifai/cli/deployment.py +++ b/clarifai/cli/deployment.py @@ -12,7 +12,7 @@ context_settings={'max_content_width': shutil.get_terminal_size().columns - 10}, ) def deployment(): - """Manage Deployments: create, delete, list""" + """Manage deployments.""" @deployment.command(['c']) @@ -111,19 +111,147 @@ def list(ctx, nodepool_id, page_no, per_page): ) +@deployment.command(['get', 'status']) +@click.argument('deployment_id') +@click.pass_context +def get(ctx, deployment_id): + """Show details for a single deployment. + + \b + Examples: + clarifai deployment get deploy-abc123 + clarifai deployment status deploy-abc123 + """ + validate_context(ctx) + + from clarifai.cli.model import _print_deployment_detail # noqa: E402 + from clarifai.errors import UserError + from clarifai.runners.models.model_deploy import get_deployment + + try: + dep = get_deployment( + deployment_id, + user_id=ctx.obj.current.user_id, + pat=ctx.obj.current.pat, + base_url=ctx.obj.current.api_base, + ) + _print_deployment_detail(dep) + except UserError as e: + click.echo(click.style(f"\nError: {e}", fg="red"), err=True) + raise SystemExit(1) + + +@deployment.command() +@click.argument('deployment_id') +@click.option( + '--follow/--no-follow', + default=True, + help='Continuously tail new logs. Use --no-follow to print and exit.', +) +@click.option( + '--duration', + default=None, + type=int, + help='Stop after N seconds (default: unlimited, Ctrl+C to stop).', +) +@click.option( + '--log-type', + default='model', + type=click.Choice(['model', 'events'], case_sensitive=False), + help='Log type: model (stdout/stderr) or events (k8s scheduling/scaling).', +) +@click.pass_context +def logs(ctx, deployment_id, follow, duration, log_type): + """Stream logs from a deployment's runner. + + \b + Resolves the model, version, and nodepool from the deployment + and streams runner stdout/stderr or k8s events. + + \b + Examples: + clarifai deployment logs deploy-abc123 + clarifai deployment logs deploy-abc123 --log-type events + clarifai deployment logs deploy-abc123 --no-follow + clarifai deployment logs deploy-abc123 --duration 60 + """ + validate_context(ctx) + + from clarifai.errors import UserError + from clarifai.runners.models.model_deploy import get_deployment, stream_model_logs + + user_id = ctx.obj.current.user_id + pat = ctx.obj.current.pat + base_url = ctx.obj.current.api_base + + try: + dep = get_deployment(deployment_id, user_id=user_id, pat=pat, base_url=base_url) + except UserError as e: + click.echo(click.style(f"\nError: {e}", fg="red"), err=True) + raise SystemExit(1) + + # Extract model/version/nodepool from deployment proto + model_id = app_id = model_version_id = None + compute_cluster_id = nodepool_id = None + + w = dep.worker + if w and w.model: + model_id = w.model.id + user_id = w.model.user_id or user_id + app_id = w.model.app_id + if w.model.model_version and w.model.model_version.id: + model_version_id = w.model.model_version.id + if dep.nodepools: + np = dep.nodepools[0] + nodepool_id = np.id + if np.compute_cluster and np.compute_cluster.id: + compute_cluster_id = np.compute_cluster.id + + # Map user-friendly names to API log_type values + api_log_type = {"model": "runner", "events": "runner.events"}[log_type.lower()] + + try: + stream_model_logs( + model_id=model_id, + user_id=user_id, + app_id=app_id, + model_version_id=model_version_id, + compute_cluster_id=compute_cluster_id, + nodepool_id=nodepool_id, + pat=pat, + base_url=base_url, + follow=follow, + duration=duration, + log_type=api_log_type, + ) + except UserError as e: + click.echo(click.style(f"\nError: {e}", fg="red"), err=True) + raise SystemExit(1) + + @deployment.command(['rm']) -@click.argument('nodepool_id') @click.argument('deployment_id') @click.pass_context -def delete(ctx, nodepool_id, deployment_id): - """Deletes a deployment for the nodepool.""" - from clarifai.client.nodepool import Nodepool +def delete(ctx, deployment_id): + """Delete a deployment. + \b + Examples: + clarifai deployment rm deploy-abc123 + """ validate_context(ctx) - nodepool = Nodepool( - nodepool_id=nodepool_id, - user_id=ctx.obj.current.user_id, - pat=ctx.obj.current.pat, - base_url=ctx.obj.current.api_base, - ) - nodepool.delete_deployments([deployment_id]) + + from clarifai.errors import UserError + from clarifai.runners.models.model_deploy import delete_deployment + + try: + delete_deployment( + deployment_id, + user_id=ctx.obj.current.user_id, + pat=ctx.obj.current.pat, + base_url=ctx.obj.current.api_base, + ) + click.echo(click.style(f" Deployment '{deployment_id}' deleted.", fg="green")) + except UserError as e: + click.echo(click.style(f"\nError: {e}", fg="red"), err=True) + raise SystemExit(1) diff --git a/clarifai/cli/list_instances.py b/clarifai/cli/list_instances.py new file mode 100644 index 00000000..fc799b85 --- /dev/null +++ b/clarifai/cli/list_instances.py @@ -0,0 +1,47 @@ +import click + +from clarifai.cli.base import cli +from clarifai.utils.cli import validate_context + + +@cli.command(['list-instances', 'li'], short_help='List available compute instances.') +@click.option('--cloud', default=None, help='Filter by cloud provider (aws, gcp, vultr, azure).') +@click.option('--region', default=None, help='Filter by region (us-east-1, us-central1).') +@click.option('--gpu', default=None, help='Filter by GPU name (A10G, H100, L40S).') +@click.option('--min-gpus', type=int, default=None, help='Minimum GPU count.') +@click.option('--min-gpu-mem', default=None, help='Minimum GPU memory (e.g., 80Gi, 48Gi).') +@click.pass_context +def list_instances(ctx, cloud, region, gpu, min_gpus, min_gpu_mem): + """List available compute instance types with GPU, memory, and cloud info. + + \b + Examples: + clarifai list-instances # all instances + clarifai li --cloud aws # AWS only + clarifai li --gpu H100 # H100 instances + clarifai li --min-gpus 2 # multi-GPU instances + clarifai li --min-gpu-mem 48Gi # 48+ GiB GPU memory + clarifai li --cloud aws --gpu L40S # combined filters + """ + from clarifai.utils.compute_presets import list_gpu_presets + + pat_val = None + base_url_val = None + try: + validate_context(ctx) + pat_val = ctx.obj.current.pat + base_url_val = ctx.obj.current.api_base + except Exception: + pass + + click.echo( + list_gpu_presets( + pat=pat_val, + base_url=base_url_val, + cloud_provider=cloud, + region=region, + gpu_name=gpu, + min_gpus=min_gpus, + min_gpu_mem=min_gpu_mem, + ) + ) diff --git a/clarifai/cli/model.py b/clarifai/cli/model.py index 30cf3e7e..e9fbd0c2 100644 --- a/clarifai/cli/model.py +++ b/clarifai/cli/model.py @@ -1,17 +1,21 @@ import json import os +import platform +import re import shutil import socket import subprocess -import tempfile from contextlib import closing +from pathlib import Path from typing import Any, Dict, Optional import click import yaml -from clarifai.cli.base import cli, pat_display +from clarifai.cli.base import cli +from clarifai.errors import UserError from clarifai.utils.cli import ( + AliasedGroup, check_lmstudio_installed, check_ollama_installed, check_requirements_installed, @@ -31,29 +35,14 @@ from clarifai.utils.constants import ( CLI_LOGIN_DOC_URL, CONFIG_GUIDE_URL, - DEFAULT_HF_MODEL_REPO_BRANCH, - DEFAULT_LMSTUDIO_MODEL_REPO_BRANCH, DEFAULT_LOCAL_RUNNER_APP_ID, DEFAULT_LOCAL_RUNNER_COMPUTE_CLUSTER_CONFIG, DEFAULT_LOCAL_RUNNER_COMPUTE_CLUSTER_ID, - DEFAULT_LOCAL_RUNNER_DEPLOYMENT_ID, - DEFAULT_LOCAL_RUNNER_MODEL_ID, DEFAULT_LOCAL_RUNNER_MODEL_TYPE, DEFAULT_LOCAL_RUNNER_NODEPOOL_CONFIG, DEFAULT_LOCAL_RUNNER_NODEPOOL_ID, - DEFAULT_OLLAMA_MODEL_REPO_BRANCH, - DEFAULT_PYTHON_MODEL_REPO_BRANCH, - DEFAULT_SGLANG_MODEL_REPO_BRANCH, - DEFAULT_TOOLKIT_MODEL_REPO, - DEFAULT_VLLM_MODEL_REPO_BRANCH, ) from clarifai.utils.logging import logger -from clarifai.utils.misc import ( - GitHubDownloader, - clone_github_repo, - format_github_repo_url, - get_list_of_files_to_download, -) def find_available_port(start_port=8080): @@ -257,12 +246,11 @@ def ensure_config_exists_for_upload(ctx, model_path: str) -> None: raise click.Abort() if ctx_config is not None: - selected_context = _select_context(ctx_config) - if selected_context is not None: - current_context = selected_context - elif current_context is None: - contexts_map = getattr(ctx_config, "contexts", {}) or {} - current_context = contexts_map.get(getattr(ctx_config, "current_context", None)) + # Use the active CLI context automatically (no interactive picker). + contexts_map = getattr(ctx_config, "contexts", {}) or {} + current_name = getattr(ctx_config, "current_context", None) + if current_name and current_name in contexts_map: + current_context = contexts_map[current_name] if current_context is None: click.echo( @@ -499,12 +487,103 @@ def ensure_config_exists_for_upload(ctx, model_path: str) -> None: @cli.group( - ['model'], context_settings={'max_content_width': shutil.get_terminal_size().columns - 10} + ['model'], + cls=AliasedGroup, + context_settings={'max_content_width': shutil.get_terminal_size().columns - 10}, ) def model(): - """Manage & Develop Models: init, download-checkpoints, signatures, upload\n - Run & Test Models Locally: local-runner, local-grpc, local-test\n - Model Inference: list, predict""" + """Build, test, and deploy models. + + \b + Workflow: init → serve → deploy + Observe: status, logs, predict + Lifecycle: undeploy + """ + + +def _sanitize_model_id(name): + """Convert a model name to a valid model.id (lowercase, alphanumeric, hyphens only).""" + name = name.split('/')[-1] # "Qwen/Qwen3-0.6B" -> "Qwen3-0.6B" + name = name.lower() + name = name.replace('_', '-') + name = re.sub(r'[^a-z0-9-]', '', name) # strip invalid chars (dots, etc.) + name = re.sub(r'-+', '-', name).strip('-') # collapse/trim hyphens + return name or "my-model" + + +def _copy_embedded_toolkit(toolkit, model_path): + """Copy embedded toolkit template files to model_path.""" + toolkit_dir = Path(__file__).parent / "templates" / "toolkits" / toolkit + if not toolkit_dir.exists(): + raise UserError(f"Toolkit '{toolkit}' template not found at {toolkit_dir}") + for item in toolkit_dir.iterdir(): + if item.name == '__pycache__': + continue + dest = Path(model_path) / item.name + if item.is_dir(): + shutil.copytree(item, dest, dirs_exist_ok=True) + else: + shutil.copy2(item, dest) + + +def _ensure_config_defaults(model_path, model_type_id='any-to-any'): + """Ensure config.yaml has required fields that older clarifai versions assert on. + + When running in env/container mode, the subprocess installs clarifai from PyPI + which may still assert on model_type_id. This patches it into config.yaml if missing. + """ + config_path = os.path.join(model_path, 'config.yaml') + with open(config_path, 'r', encoding='utf-8') as f: + config = yaml.safe_load(f) or {} + model = config.get('model', {}) + changed = False + if 'model_type_id' not in model: + model['model_type_id'] = model_type_id + config['model'] = model + changed = True + if changed: + with open(config_path, 'w', encoding='utf-8') as f: + yaml.dump(config, f, default_flow_style=False, sort_keys=False) + + +def _patch_config(config_path, model_id, checkpoints_repo_id=None): + """Update model.id and optionally checkpoints.repo_id in config.yaml.""" + with open(config_path, 'r', encoding='utf-8') as f: + config = yaml.safe_load(f) or {} + config.setdefault('model', {})['id'] = model_id + if checkpoints_repo_id: + config.setdefault('checkpoints', {})['repo_id'] = checkpoints_repo_id + with open(config_path, 'w', encoding='utf-8') as f: + yaml.dump(config, f, default_flow_style=False, sort_keys=False) + + +def _print_init_success(model_path, toolkit, instance=None): + """Print unified success message after init.""" + from clarifai.runners.models import deploy_output as out + + click.echo() + out.success(f"Model initialized in {model_path}") + click.echo() + if toolkit in ('python', 'mcp', 'openai', None): + click.echo(" 1. Edit 1/model.py with your model logic") + click.echo(" 2. Add dependencies to requirements.txt") + click.echo() + click.echo(" Test locally:") + click.echo(f" clarifai model serve {model_path}") + click.echo( + f" clarifai model serve {model_path} --mode env # auto-create venv and install deps" + ) + click.echo(f" clarifai model serve {model_path} --mode container # run inside Docker") + click.echo() + click.echo(" Deploy to Clarifai:") + if instance: + click.echo(f" clarifai model deploy {model_path}") + else: + click.echo(f" clarifai model deploy {model_path} --instance gpu-nvidia-a10g") + click.echo( + " clarifai list-instances # list available instances" + ) + click.echo() @model.command() @@ -512,318 +591,146 @@ def model(): "model_path", type=click.Path(), required=False, - default=".", -) -@click.option( - '--model-type-id', - type=click.Choice(['mcp', 'openai'], case_sensitive=False), - required=False, - help='Model type: "mcp" for MCPModelClass, "openai" for OpenAIModelClass, or leave empty for default ModelClass.', -) -@click.option( - '--github-pat', - required=False, - help='GitHub Personal Access Token for authentication when cloning private repositories.', -) -@click.option( - '--github-url', - required=False, - help='GitHub repository URL or "user/repo" format to clone a repository from. If provided, the entire repository contents will be copied to the target directory instead of using default templates.', + default=None, ) @click.option( '--toolkit', type=click.Choice( - ['ollama', 'huggingface', 'lmstudio', 'vllm', 'sglang', 'python'], case_sensitive=False + ['vllm', 'sglang', 'huggingface', 'ollama', 'mcp', 'python', 'openai', 'lmstudio'], + case_sensitive=False, ), required=False, - help='Toolkit to use for model initialization. Currently supports "ollama", "huggingface", "lmstudio", "vllm", "sglang" and "python".', + help='Inference toolkit to scaffold. Omit for a blank Python model.', ) @click.option( '--model-name', required=False, - help='Model name to configure when using --toolkit. For ollama toolkit, this sets the Ollama model to use (e.g., "llama3.1", "mistral", etc.). For vllm, sglang & huggingface toolkit, this sets the Hugging Face model repo_id (e.g., "unsloth/Llama-3.2-1B-Instruct").\n For lmstudio toolkit, this sets the LM Studio model name (e.g., "qwen/qwen3-4b-thinking-2507").\n', -) -@click.option( - '--port', - type=str, - help='Port to run the Ollama server on. Defaults to 23333.', - required=False, -) -@click.option( - '--context-length', - type=str, - help='Context length for the Ollama model. Defaults to 8192.', - required=False, + help='Model checkpoint (HF repo_id or ollama tag). Auto-creates directory from name.', ) @click.pass_context def init( ctx, model_path, - model_type_id, - github_pat, - github_url, toolkit, model_name, - port, - context_length, ): - """Initialize a new model directory structure. - - Creates the following structure in the specified directory:\n - ├── 1/\n - │ └── model.py\n - ├── requirements.txt\n - └── config.yaml\n - - If --github-repo is provided, the entire repository contents will be copied to the target - directory instead of using default templates. The --github-pat option can be used for authentication - when cloning private repositories. The --branch option can be used to specify a specific - branch to clone from. - - MODEL_PATH: Path where to create the model directory structure. If not specified, the current directory is used by default.\n - - OPTIONS:\n - MODEL_TYPE_ID: Type of model to create. If not specified, defaults to "text-to-text" for text models.\n - GITHUB_PAT: GitHub Personal Access Token for authentication when cloning private repositories.\n - GITHUB_URL: GitHub repository URL or "repo" format to clone a repository from. If provided, the entire repository contents will be copied to the target directory instead of using default templates.\n - TOOLKIT: Toolkit to use for model initialization. Currently supports "ollama", "huggingface", "lmstudio", "vllm", "sglang" and "python".\n - MODEL_NAME: Model name to configure when using --toolkit. For ollama toolkit, this sets the Ollama model to use (e.g., "llama3.1", "mistral", etc.). For vllm, sglang & huggingface toolkit, this sets the Hugging Face model repo_id (e.g., "Qwen/Qwen3-4B-Instruct-2507"). For lmstudio toolkit, this sets the LM Studio model name (e.g., "qwen/qwen3-4b-thinking-2507").\n - PORT: Port to run the (Ollama/lmstudio) server on. Defaults to 23333.\n - CONTEXT_LENGTH: Context length for the (Ollama/lmstudio) model. Defaults to 8192.\n - """ - validate_context(ctx) - user_id = ctx.obj.current.user_id - # Resolve the absolute path - model_path = os.path.abspath(model_path) + """Scaffold a new model project with a specific toolkit like vLLM, SGLang, HuggingFace, Ollama, etc. - # Create the model directory if it doesn't exist - os.makedirs(model_path, exist_ok=True) + \b + Creates a ready-to-serve model directory with config.yaml, + requirements.txt, and 1/model.py. Pick a toolkit for a specific + inference engine, or omit --toolkit for a blank Python template. - # Validate parameters - if port and not port.isdigit(): - logger.error("Invalid value: --port must be a number") - raise click.Abort() + \b + MODEL_PATH Target directory (default: current dir). + Auto-created from --model-name if omitted + (e.g., --model-name org/Model → ./Model/). - if context_length and not context_length.isdigit(): - logger.error("Invalid value: --context-length must be a number") - raise click.Abort() + \b + Toolkits (GPU): + vllm High-throughput LLM serving with vLLM + sglang Fast LLM serving with SGLang + huggingface HuggingFace Transformers (direct inference) + + \b + Toolkits (local): + ollama Ollama (local LLM server) + lmstudio LM Studio (local LLM server) + + \b + Toolkits (other): + python Blank Python model (default) + mcp MCP tool server (FastMCP) + openai OpenAI-compatible API wrapper + + \b + Examples: + clarifai model init --toolkit vllm --model-name Qwen/Qwen3-0.6B + clarifai model init --toolkit ollama --model-name llama3.1 + clarifai model init --toolkit mcp my-mcp-server + clarifai model init my-model + """ + validate_context(ctx) # Validate option combinations - if model_name and not (toolkit): + if model_name and not toolkit: logger.error("--model-name can only be used with --toolkit") raise click.Abort() - if toolkit and (github_url): - logger.error("Cannot specify both --toolkit and --github-repo") - raise click.Abort() + # Resolve model_path: explicit > derived from model-name > current dir + if model_path is None: + if model_name: + # "Qwen/Qwen3-0.6B" -> "./Qwen3-0.6B" + model_path = model_name.split('/')[-1] + else: + model_path = "." + model_path = os.path.abspath(model_path) + os.makedirs(model_path, exist_ok=True) - # --toolkit option - if toolkit == 'ollama': - if not check_ollama_installed(): - logger.error( - "Ollama is not installed. Please install it from `https://ollama.com/` to use the Ollama toolkit." - ) + # Derive model_id: from --model-name if given, else from directory name + if model_name: + model_id = _sanitize_model_id(model_name) + else: + model_id = _sanitize_model_id(os.path.basename(model_path)) + + # Embedded toolkits: copy from clarifai/cli/templates/toolkits/{name}/ + EMBEDDED_TOOLKITS = ('vllm', 'sglang', 'huggingface', 'ollama', 'lmstudio') + # Template toolkits: generate from string templates + TEMPLATE_TOOLKITS = ('mcp', 'openai', 'python') + + if toolkit in EMBEDDED_TOOLKITS: + # Pre-flight checks for local server toolkits + if toolkit == 'ollama' and not check_ollama_installed(): + logger.error("Ollama is not installed. Please install it from https://ollama.com/") raise click.Abort() - github_url = DEFAULT_TOOLKIT_MODEL_REPO - branch = DEFAULT_OLLAMA_MODEL_REPO_BRANCH - elif toolkit == 'huggingface': - github_url = DEFAULT_TOOLKIT_MODEL_REPO - branch = DEFAULT_HF_MODEL_REPO_BRANCH - elif toolkit == 'lmstudio': - if not check_lmstudio_installed(): + if toolkit == 'lmstudio' and not check_lmstudio_installed(): logger.error( - "LM Studio is not installed. Please install it from `https://lmstudio.com/` to use the LM Studio toolkit." + "LM Studio is not installed. Please install it from https://lmstudio.com/" ) raise click.Abort() - github_url = DEFAULT_TOOLKIT_MODEL_REPO - branch = DEFAULT_LMSTUDIO_MODEL_REPO_BRANCH - elif toolkit == 'vllm': - github_url = DEFAULT_TOOLKIT_MODEL_REPO - branch = DEFAULT_VLLM_MODEL_REPO_BRANCH - elif toolkit == 'sglang': - github_url = DEFAULT_TOOLKIT_MODEL_REPO - branch = DEFAULT_SGLANG_MODEL_REPO_BRANCH - elif toolkit == 'python': - github_url = DEFAULT_TOOLKIT_MODEL_REPO - branch = DEFAULT_PYTHON_MODEL_REPO_BRANCH - - if github_url: - downloader = GitHubDownloader( - max_retries=3, - github_token=github_pat, - ) - if toolkit: - owner, repo, _, folder_path = downloader.parse_github_url(url=github_url) - else: - owner, repo, branch, folder_path = downloader.parse_github_url(url=github_url) - logger.info( - f"Parsed GitHub repository: owner={owner}, repo={repo}, branch={branch}, folder_path={folder_path}" - ) - files_to_download = get_list_of_files_to_download( - downloader, owner, repo, folder_path, branch, [] - ) - for i, file in enumerate(files_to_download): - files_to_download[i] = f"{i + 1}. {file}" - files_to_download = '\n'.join(files_to_download) - logger.info(f"Files to be downloaded are:\n{files_to_download}") - input("Press Enter to continue...") - if not toolkit: - if folder_path != "": - try: - downloader.download_github_folder( - url=github_url, - output_dir=model_path, - github_token=github_pat, - ) - logger.info(f"Successfully downloaded folder contents to {model_path}") - logger.info("Model initialization complete with GitHub folder download") - logger.info("Next steps:") - logger.info("1. Review the model configuration") - logger.info("2. Install any required dependencies manually") - logger.info("3. Test the model locally using 'clarifai model local-test'") - return - except Exception as e: - logger.error(f"Failed to download GitHub folder: {e}") - # Continue with the rest of the initialization process - github_url = None # Fall back to template mode + logger.info(f"Initializing model with {toolkit} toolkit...") + _copy_embedded_toolkit(toolkit, model_path) - elif branch and folder_path == "": - # When we have a branch but no specific folder path - logger.info( - f"Initializing model from GitHub repository: {github_url} (branch: {branch})" - ) + # Toolkit-specific customization (updates toolkit.model, model.py defaults, etc.) + user_id = ctx.obj.current.user_id + if toolkit == 'ollama': + customize_ollama_model(model_path, user_id, model_name) + elif toolkit == 'lmstudio': + customize_lmstudio_model(model_path, user_id, model_name) + elif toolkit in ('huggingface', 'vllm', 'sglang'): + customize_huggingface_model(model_path, user_id, model_name) - # Check if it's a local path or normalize the GitHub repo URL - if os.path.exists(github_url): - repo_url = github_url - else: - repo_url = format_github_repo_url(github_url) - repo_url = f"https://github.com/{owner}/{repo}" + # Patch config LAST to ensure sanitized model_id and checkpoint override + config_path = os.path.join(model_path, "config.yaml") + # Only set checkpoints for HF-based toolkits; ollama/lmstudio use toolkit.model instead + hf_repo = model_name if toolkit in ('huggingface', 'vllm', 'sglang') else None + _patch_config(config_path, model_id=model_id, checkpoints_repo_id=hf_repo) - try: - # Create a temporary directory for cloning - with tempfile.TemporaryDirectory(prefix="clarifai_model_") as clone_dir: - # Clone the repository with explicit branch parameter - if not clone_github_repo(repo_url, clone_dir, github_pat, branch): - logger.error(f"Failed to clone repository from {repo_url}") - github_url = None # Fall back to template mode - - else: - # Copy the entire repository content to target directory (excluding .git) - for item in os.listdir(clone_dir): - if item == '.git': - continue - - source_path = os.path.join(clone_dir, item) - target_path = os.path.join(model_path, item) - - if os.path.isdir(source_path): - shutil.copytree(source_path, target_path, dirs_exist_ok=True) - else: - shutil.copy2(source_path, target_path) - - logger.info(f"Successfully cloned repository to {model_path}") - logger.info( - "Model initialization complete with GitHub repository clone" - ) - logger.info("Next steps:") - logger.info("1. Review the model configuration") - logger.info("2. Install any required dependencies manually") - logger.info( - "3. Test the model locally using 'clarifai model local-test'" - ) - return - - except Exception as e: - logger.error(f"Failed to clone GitHub repository: {e}") - github_url = None # Fall back to template mode - - if toolkit: - logger.info(f"Initializing model from GitHub repository: {github_url}") - - # Check if it's a local path or normalize the GitHub repo URL - if os.path.exists(github_url): - repo_url = github_url - else: - repo_url = format_github_repo_url(github_url) + else: + # Template-based initialization (mcp, openai, python, or no toolkit) + model_type_id = toolkit if toolkit in TEMPLATE_TOOLKITS else None - try: - # Create a temporary directory for cloning - with tempfile.TemporaryDirectory(prefix="clarifai_model_") as clone_dir: - # Clone the repository with explicit branch parameter - if not clone_github_repo(repo_url, clone_dir, github_pat, branch): - logger.error(f"Failed to clone repository from {repo_url}") - github_url = None # Fall back to template mode - - else: - # Copy the entire repository content to target directory (excluding .git) - for item in os.listdir(clone_dir): - if item == '.git': - continue - - source_path = os.path.join(clone_dir, item) - target_path = os.path.join(model_path, item) - - if os.path.isdir(source_path): - shutil.copytree(source_path, target_path, dirs_exist_ok=True) - else: - shutil.copy2(source_path, target_path) + if model_type_id: + logger.info(f"Initializing {model_type_id} model from template...") + else: + logger.info("Initializing model with default template...") - except Exception as e: - logger.error(f"Failed to clone GitHub repository: {e}") - github_url = None - - if (user_id or model_name or port or context_length) and (toolkit == 'ollama'): - customize_ollama_model(model_path, user_id, model_name, port, context_length) - - if (user_id or model_name or port or context_length) and (toolkit == 'lmstudio'): - customize_lmstudio_model(model_path, user_id, model_name, port, context_length) - - if (user_id or model_name) and ( - toolkit == 'huggingface' or toolkit == 'vllm' or toolkit == 'sglang' - ): - # Update the config.yaml file with the provided model name - customize_huggingface_model(model_path, user_id, model_name) - - if github_url: - logger.info("Model initialization complete with GitHub repository") - logger.info("Next steps:") - logger.info("1. Review the model configuration") - logger.info("2. Install any required dependencies manually") - logger.info("3. Test the model locally using 'clarifai model local-test'") - - # Fall back to template-based initialization if no GitHub repo or if GitHub repo failed - if not github_url: - logger.info("Initializing model with default templates...") - input("Press Enter to continue...") - - from clarifai.cli.base import input_or_default from clarifai.cli.templates.model_templates import ( get_config_template, get_model_template, get_requirements_template, ) - # Collect additional parameters for OpenAI template - template_kwargs = {} - if model_type_id == "openai": - logger.info("Configuring OpenAI local runner...") - port = input_or_default("Enter port (default: 8000): ", "8000") - template_kwargs = {"port": port} - - # Create the 1/ subdirectory + # Create 1/model.py model_version_dir = os.path.join(model_path, "1") os.makedirs(model_version_dir, exist_ok=True) - - # Create model.py model_py_path = os.path.join(model_version_dir, "model.py") if os.path.exists(model_py_path): logger.warning(f"File {model_py_path} already exists, skipping...") else: - model_template = get_model_template(model_type_id, **template_kwargs) with open(model_py_path, 'w') as f: - f.write(model_template) + f.write(get_model_template(model_type_id)) logger.info(f"Created {model_py_path}") # Create requirements.txt @@ -831,9 +738,8 @@ def init( if os.path.exists(requirements_path): logger.warning(f"File {requirements_path} already exists, skipping...") else: - requirements_template = get_requirements_template(model_type_id) with open(requirements_path, 'w') as f: - f.write(requirements_template) + f.write(get_requirements_template(model_type_id)) logger.info(f"Created {requirements_path}") # Create config.yaml @@ -841,21 +747,32 @@ def init( if os.path.exists(config_path): logger.warning(f"File {config_path} already exists, skipping...") else: - config_model_type_id = DEFAULT_LOCAL_RUNNER_MODEL_TYPE # default - - config_template = get_config_template( - user_id=user_id, model_type_id=config_model_type_id - ) + config_model_type_id = model_type_id or DEFAULT_LOCAL_RUNNER_MODEL_TYPE with open(config_path, 'w') as f: - f.write(config_template) + f.write(get_config_template(model_type_id=config_model_type_id, model_id=model_id)) logger.info(f"Created {config_path}") - logger.info(f"Model initialization complete in {model_path}") - logger.info("Next steps:") - logger.info("1. Search for '# TODO: please fill in' comments in the generated files") - logger.info("2. Update the model configuration in config.yaml") - logger.info("3. Add your model dependencies to requirements.txt") - logger.info("4. Implement your model logic in 1/model.py") + # Auto-select instance based on model size when --model-name is provided + resolved_instance = None + if model_name and toolkit in ('vllm', 'sglang', 'huggingface'): + config_path = os.path.join(model_path, "config.yaml") + if os.path.exists(config_path): + from clarifai.utils.cli import dump_yaml, from_yaml + from clarifai.utils.compute_presets import recommend_instance + + pat_val = getattr(ctx.obj.current, 'pat', None) + base_url_val = getattr(ctx.obj.current, 'api_base', None) + config = from_yaml(config_path) + recommended, reason = recommend_instance( + config, pat=pat_val, base_url=base_url_val, toolkit=toolkit, model_path=model_path + ) + if recommended: + config.setdefault('compute', {})['instance'] = recommended + dump_yaml(config, config_path) + click.echo(f" Instance: {recommended} ({reason})") + resolved_instance = recommended + + _print_init_success(model_path, toolkit, instance=resolved_instance) def _ensure_hf_token(ctx, model_path): @@ -903,31 +820,29 @@ def _ensure_hf_token(ctx, model_path): logger.warning(f"Unexpected error ensuring HF_TOKEN: {e}") -@model.command(help="Upload a trained model.") +@model.command() @click.argument("model_path", type=click.Path(exists=True), required=False, default=".") @click.option( - '--stage', + '--platform', required=False, - type=click.Choice(['runtime', 'build', 'upload'], case_sensitive=True), - default="upload", - show_default=True, - help='The stage we are calling download checkpoints from. Typically this would "upload" and will download checkpoints if config.yaml checkpoints section has when set to "upload". Other options include "runtime" to be used in load_model or "upload" to be used during model upload. Set this stage to whatever you have in config.yaml to force downloading now.', + help='Docker build platform (e.g., "linux/amd64"). Overrides config.yaml.', ) @click.option( - '--skip_dockerfile', + '-v', + '--verbose', is_flag=True, - help='Flag to skip generating a dockerfile so that you can manually edit an already created dockerfile. If not provided, intelligently handle existing Dockerfiles with user confirmation.', -) -@click.option( - '--platform', - required=False, - help='Target platform(s) for Docker image build (e.g., "linux/amd64" or "linux/amd64,linux/arm64"). This overrides the platform specified in config.yaml.', + help='Show detailed build and upload logs.', ) @click.pass_context -def upload(ctx, model_path, stage, skip_dockerfile, platform): - """Upload a model to Clarifai. +def upload(ctx, model_path, platform, verbose): + """Upload a model to Clarifai (without deploying). + + \b + Builds a Docker image and uploads it to the Clarifai registry. + Use 'clarifai model deploy' to upload and deploy in one step. - MODEL_PATH: Path to the model directory. If not specified, the current directory is used by default. + \b + MODEL_PATH Model directory containing config.yaml (default: "."). """ from clarifai.runners.models.model_builder import upload_model @@ -937,27 +852,21 @@ def upload(ctx, model_path, stage, skip_dockerfile, platform): _ensure_hf_token(ctx, model_path) upload_model( model_path, - stage, - skip_dockerfile, platform=platform, pat=ctx.obj.current.pat, base_url=ctx.obj.current.api_base, + verbose=verbose, ) -@model.command(help="Download model checkpoint files.") -@click.argument( - "model_path", - type=click.Path(exists=True), - required=False, - default=".", -) +@model.command(name="download-checkpoints", hidden=True) +@click.argument("model_path", type=click.Path(exists=True), required=False, default=".") @click.option( '--out_path', type=click.Path(exists=False), required=False, default=None, - help='Option path to write the checkpoints to. This will place them in {out_path}/1/checkpoints If not provided it will default to {model_path}/1/checkpoints where the config.yaml is read.', + help='Path to write the checkpoints to.', ) @click.option( '--stage', @@ -965,636 +874,1031 @@ def upload(ctx, model_path, stage, skip_dockerfile, platform): type=click.Choice(['runtime', 'build', 'upload'], case_sensitive=True), default="build", show_default=True, - help='The stage we are calling download checkpoints from. Typically this would be in the build stage which is the default. Other options include "runtime" to be used in load_model or "upload" to be used during model upload. Set this stage to whatever you have in config.yaml to force downloading now.', + help='The stage to download checkpoints for.', ) @click.pass_context def download_checkpoints(ctx, model_path, out_path, stage): - """Download checkpoints from external source to local model_path - - MODEL_PATH: Path to the model directory. If not specified, the current directory is used by default. - """ - + """Download checkpoints from external source (used internally by Dockerfile build).""" from clarifai.runners.models.model_builder import ModelBuilder - validate_context(ctx) - _ensure_hf_token(ctx, model_path) + model_path = os.path.abspath(model_path) builder = ModelBuilder(model_path, download_validation_only=True) builder.download_checkpoints(stage=stage, checkpoint_path_override=out_path) -@model.command(help="Generate model method signatures.") -@click.argument( - "model_path", - type=click.Path(exists=True), - required=False, - default=".", -) +@model.command() +@click.argument('model_path', type=click.Path(), required=False, default=None) @click.option( - '--out_path', - type=click.Path(exists=False), - required=False, + '--instance', default=None, - help='Path to write the method signature defitions to. If not provided, use stdout.', -) -def signatures(model_path, out_path): - """Generate method signatures for the model. - - MODEL_PATH: Path to the model directory. If not specified, the current directory is used by default. - """ - - from clarifai.runners.models.model_builder import ModelBuilder - - builder = ModelBuilder(model_path, download_validation_only=True) - signatures = builder.method_signatures_yaml() - if out_path: - with open(out_path, 'w') as f: - f.write(signatures) - else: - click.echo(signatures) - - -@model.command(name="local-test", help="Execute all model unit tests locally.") -@click.argument( - "model_path", - type=click.Path(exists=True), - required=False, - default=".", -) -@click.option( - '--mode', - type=click.Choice(['env', 'container'], case_sensitive=False), - default='env', - show_default=True, - help='Specify how to test the model locally: "env" for virtual environment or "container" for Docker container. Defaults to "env".', -) -@click.option( - '--keep_env', - is_flag=True, - help='Keep the virtual environment after testing the model locally (applicable for virtualenv mode). Defaults to False.', + help='Hardware instance type (e.g., gpu-nvidia-a10g). Use "clarifai list-instances" to list options.', ) @click.option( - '--keep_image', - is_flag=True, - help='Keep the Docker image after testing the model locally (applicable for container mode). Defaults to False.', + '--model-url', + default=None, + help='Deploy an already-uploaded model by its Clarifai URL (skips upload).', ) @click.option( - '--skip_dockerfile', - is_flag=True, - help='Flag to skip generating a dockerfile so that you can manually edit an already created dockerfile. If not provided, intelligently handle existing Dockerfiles with user confirmation.', -) -@click.pass_context -def test_locally( - ctx, model_path, keep_env=False, keep_image=False, mode='env', skip_dockerfile=False -): - """Test model locally. - - MODEL_PATH: Path to the model directory. If not specified, the current directory is used by default. - """ - try: - from clarifai.runners.models import model_run_locally - - validate_context(ctx) - _ensure_hf_token(ctx, model_path) - if mode == 'env' and keep_image: - raise ValueError("'keep_image' is applicable only for 'container' mode") - if mode == 'container' and keep_env: - raise ValueError("'keep_env' is applicable only for 'env' mode") - - if mode == "env": - click.echo("Testing model locally in a virtual environment...") - model_run_locally.main(model_path, run_model_server=False, keep_env=keep_env) - elif mode == "container": - click.echo("Testing model locally inside a container...") - model_run_locally.main( - model_path, - inside_container=True, - run_model_server=False, - keep_image=keep_image, - skip_dockerfile=skip_dockerfile, - ) - click.echo("Model tested successfully.") - except Exception as e: - click.echo(f"Failed to test model locally: {e}", err=True) - - -@model.command(name="local-grpc", help="Run the model locally via a gRPC server.") -@click.argument( - "model_path", - type=click.Path(exists=True), - required=False, - default=".", + '--model-version-id', + default=None, + help='Specific model version to deploy (default: latest).', ) @click.option( - '--port', - '-p', + '--min-replicas', + default=1, type=int, - default=8000, - show_default=True, - help="The port to host the gRPC server for running the model locally. Defaults to 8000.", -) -@click.option( - '--mode', - type=click.Choice(['env', 'container'], case_sensitive=False), - default='env', show_default=True, - help='Specifies how to run the model: "env" for virtual environment or "container" for Docker container. Defaults to "env".', -) -@click.option( - '--keep_env', - is_flag=True, - help='Keep the virtual environment after testing the model locally (applicable for virtualenv mode). Defaults to False.', -) -@click.option( - '--keep_image', - is_flag=True, - help='Keep the Docker image after testing the model locally (applicable for container mode). Defaults to False.', -) -@click.option( - '--skip_dockerfile', - is_flag=True, - help='Flag to skip generating a dockerfile so that you can manually edit an already created dockerfile. If not provided, intelligently handle existing Dockerfiles with user confirmation.', -) -@click.pass_context -def run_locally(ctx, model_path, port, mode, keep_env, keep_image, skip_dockerfile=False): - """Run the model locally and start a gRPC server to serve the model. - - MODEL_PATH: Path to the model directory. If not specified, the current directory is used by default. - """ - model_path = os.path.abspath(model_path) - try: - from clarifai.runners.models import model_run_locally - - validate_context(ctx) - _ensure_hf_token(ctx, model_path) - if mode == 'env' and keep_image: - raise ValueError("'keep_image' is applicable only for 'container' mode") - if mode == 'container' and keep_env: - raise ValueError("'keep_env' is applicable only for 'env' mode") - - if mode == "env": - click.echo("Running model locally in a virtual environment...") - model_run_locally.main(model_path, run_model_server=True, keep_env=keep_env, port=port) - elif mode == "container": - click.echo("Running model locally inside a container...") - model_run_locally.main( - model_path, - inside_container=True, - run_model_server=True, - port=port, - keep_image=keep_image, - skip_dockerfile=skip_dockerfile, - ) - click.echo(f"Model server started locally from {model_path} in {mode} mode.") - except Exception as e: - click.echo(f"Failed to starts model server locally: {e}", err=True) - - -@model.command(name="local-runner", help="Run the model locally for dev, debug, or local compute.") -@click.argument( - "model_path", - type=click.Path(exists=True), - required=False, - default=".", + help='Minimum number of running replicas.', ) @click.option( - "--pool_size", + '--max-replicas', + default=5, type=int, - default=32, show_default=True, - help="The number of threads to use. On community plan, the compute time allocation is drained at a rate proportional to the number of threads.", -) # pylint: disable=range-builtin-not-iterating -@click.option( - '--suppress-toolkit-logs', - is_flag=True, - help='Show detailed logs including Ollama server output. By default, Ollama logs are suppressed.', + help='Maximum replicas for autoscaling.', ) @click.option( - "--mode", - type=click.Choice(['env', 'container', 'none'], case_sensitive=False), - default='none', - show_default=True, - help='Specifies how to run the model: "env" for virtual environment, "container" for Docker container, or "none" to skip creating environment and directly run the model. Defaults to "none".', + '--cloud', + default=None, + help='Cloud provider (e.g., aws, gcp). Auto-detected from --instance if omitted.', ) @click.option( - '--keep_image', - is_flag=True, - help='Keep the Docker image after testing the model locally (applicable for container mode). Defaults to False.', + '--region', + default=None, + help='Cloud region (e.g., us-east-1). Auto-detected from --instance if omitted.', ) @click.option( - '--health-check-port', - type=int, - default=8080, - show_default=True, - help='The port to run the health check server on. Defaults to 8080.', + '--compute-cluster-id', + default=None, + help='[Advanced] Existing compute cluster ID (skip auto-creation).', ) @click.option( - '--disable-health-check', - is_flag=True, - help='Disable the health check server.', + '--nodepool-id', + default=None, + help='[Advanced] Existing nodepool ID (skip auto-creation).', ) @click.option( - '--auto-find-health-check-port', + '-v', + '--verbose', is_flag=True, - help='Automatically find an available port starting from --health-check-port.', + help='Show detailed build, upload, and deployment logs.', ) @click.pass_context -def local_runner( +def deploy( ctx, model_path, - pool_size, - suppress_toolkit_logs, - mode, - keep_image, - health_check_port, - disable_health_check, - auto_find_health_check_port, + instance, + model_url, + model_version_id, + min_replicas, + max_replicas, + cloud, + region, + compute_cluster_id, + nodepool_id, + verbose, ): - """Run the model as a local runner to help debug your model connected to the API or to - leverage local compute resources manually. This relies on many variables being present in the env - of the currently selected context. If they are not present then default values will be used to - ease the setup of a local runner and your context yaml will be updated in place. The required - env vars are: + """Deploy a model to Clarifai cloud compute. \b - CLARIFAI_PAT: + Uploads, builds, and deploys in one step. Compute infrastructure + (cluster + nodepool) is auto-created when needed. \b - # for where the model that represents the local runner should be: - \b - CLARIFAI_USER_ID: - CLARIFAI_APP_ID: - CLARIFAI_MODEL_ID: - \b - # for where the local runner should be in a compute cluster - # note the user_id of the compute cluster is the same as the user_id of the model. + MODEL_PATH Local model directory to upload and deploy (default: "."). + Not needed when using --model-url. + \b - CLARIFAI_COMPUTE_CLUSTER_ID: - CLARIFAI_NODEPOOL_ID: + Examples: + clarifai model deploy ./my-model --instance a10g + clarifai model deploy --model-url https://clarifai.com/user/app/models/id --instance a10g - # The following will be created in your context since it's generated by the API + List all GPU/ CPU available instance for model deployment: + clarifai list-instances + clarifai list-instances --cloud gcp + """ + validate_context(ctx) + user_id = ctx.obj.current.user_id + app_id = getattr(ctx.obj.current, 'app_id', None) + + # Resolve model_path to absolute if provided + if model_path: + model_path = os.path.abspath(model_path) + if not os.path.isdir(model_path): + raise click.BadParameter(f"Model path '{model_path}' is not a directory.") + + from clarifai.runners.models.model_deploy import ModelDeployer + + deployer = ModelDeployer( + model_path=model_path, + model_url=model_url, + user_id=user_id, + app_id=app_id, + model_version_id=model_version_id, + instance_type=instance, + cloud_provider=cloud, + region=region, + compute_cluster_id=compute_cluster_id, + nodepool_id=nodepool_id, + min_replicas=min_replicas, + max_replicas=max_replicas, + pat=ctx.obj.current.pat, + base_url=ctx.obj.current.api_base, + verbose=verbose, + ) - CLARIFAI_RUNNER_ID: + result = deployer.deploy() + _print_deploy_result(result) - Additionally using the provided model path, if the config.yaml file does not contain the model - information that matches the above CLARIFAI_USER_ID, CLARIFAI_APP_ID, CLARIFAI_MODEL_ID then the - config.yaml will be updated to include the model information. This is to ensure that the model - that starts up in the local runner is the same as the one you intend to call in the API. +def _print_deploy_result(result): + """Print a formatted deployment result.""" + from clarifai.runners.models import deploy_output as out - MODEL_PATH: Path to the model directory. If not specified, the current directory is used by default. - MODE: Specifies how to run the model: "env" for virtual environment or "container" for Docker container. Defaults to "env". - KEEP_IMAGE: Keep the Docker image after testing the model locally (applicable for container mode). Defaults to False. - """ - from clarifai.client.user import User - from clarifai.runners.models.model_builder import ModelBuilder - from clarifai.runners.models.model_run_locally import ModelRunLocally - from clarifai.runners.server import ModelServer + model_url = result['model_url'] + timed_out = result.get('timed_out', False) - validate_context(ctx) - model_path = os.path.abspath(model_path) - _ensure_hf_token(ctx, model_path) - builder = ModelBuilder(model_path, download_validation_only=True) - manager = ModelRunLocally(model_path, model_builder=builder) + if timed_out: + # Monitoring timed out — deployment exists but pod isn't ready yet. + # The monitor phase already printed detailed diagnostics, so just show + # the deployment info and emphasize status/logs commands. + out.phase_header("Deployed (pod pending)") + click.echo() + out.warning("Deployment created but model pod is not running yet.") + click.echo() + else: + out.phase_header("Ready") + click.echo() + out.success("Model deployed successfully!") + click.echo() + + out.link("Model", model_url) + out.info("Version", result['model_version_id']) + out.info("Deployment", result['deployment_id']) + if result.get('instance_type'): + out.info("Instance", result['instance_type']) + if result.get('cloud_provider') or result.get('region'): + cloud = result.get('cloud_provider', '').upper() + region = result.get('region', '') + out.info("Cloud", f"{cloud} / {region}" if cloud and region else cloud or region) + + if not timed_out: + # Show client script and predict hints only when pod is actually ready + client_script = result.get('client_script') + if client_script: + click.echo("\n" + "=" * 60) + click.echo("# Here is a code snippet to use this model:") + click.echo("=" * 60) + click.echo(client_script) + click.echo("=" * 60) + + from clarifai.runners.utils.code_script import generate_predict_hint + + model_ref = f'{result["user_id"]}/{result["app_id"]}/models/{result["model_id"]}' + predict_cmd = generate_predict_hint( + result.get('method_signatures') or [], + model_ref, + deployment_id=result.get('deployment_id'), + ) - if disable_health_check: - health_check_port = -1 - elif auto_find_health_check_port: - health_check_port = find_available_port(health_check_port) + # Build playground URL from model_url (e.g. https://clarifai.com/user/app/models/id) + from urllib.parse import urlparse - port = 8080 - if mode == "env": - manager.create_temp_venv() - manager.install_requirements() + ui_base = f"{urlparse(model_url).scheme}://{urlparse(model_url).netloc}" + playground_url = ( + f"{ui_base}/playground?model={result['model_id']}__{result['model_version_id']}" + f"&user_id={result['user_id']}&app_id={result['app_id']}" + ) - dependencies = parse_requirements(model_path) - if mode != "container": - logger.info("> Checking local runner requirements...") - # Post check while running `clarifai model local-runner` we check if the toolkit is ollama - if not check_requirements_installed(dependencies=dependencies): - logger.error(f"Requirements not installed for model at {model_path}.") - raise click.Abort() + out.phase_header("Next Steps") + if timed_out: + # Prioritize event logs when pod isn't ready (model logs will be empty) + out.hint( + "Events", + f'clarifai model logs --deployment "{result["deployment_id"]}" --log-type events', + ) + out.hint("Status", f'clarifai model status --deployment "{result["deployment_id"]}"') + out.hint("Predict", predict_cmd) + out.link("Playground", playground_url) + else: + out.hint("Predict", predict_cmd) + out.link("Playground", playground_url) + out.hint("Logs", f'clarifai model logs --deployment "{result["deployment_id"]}"') + out.hint("Status", f'clarifai model status --deployment "{result["deployment_id"]}"') + out.hint("Undeploy", f'clarifai model undeploy --deployment "{result["deployment_id"]}"') - if "ollama" in dependencies or builder.config.get('toolkit', {}).get('provider') == 'ollama': - logger.info("Verifying Ollama installation...") - if not check_ollama_installed(): - logger.error( - "Ollama application is not installed. Please install it from `https://ollama.com/` to use the Ollama toolkit." - ) - raise click.Abort() - elif ( - "lmstudio" in dependencies - or builder.config.get('toolkit', {}).get('provider') == 'lmstudio' - ): - logger.info("Verifying LM Studio installation...") - if not check_lmstudio_installed(): - logger.error( - "LM Studio application is not installed. Please install it from `https://lmstudio.com/` to use the LM Studio toolkit." - ) - raise click.Abort() - # Load model config - config_file = os.path.join(model_path, 'config.yaml') - if not os.path.exists(config_file): - logger.error( - f"config.yaml not found in {model_path}. Please ensure you are passing the correct directory." +@model.command() +@click.option('--model-url', default=None, help='Clarifai model URL.') +@click.option('--model-id', default=None, help='Model ID (alternative to --model-url).') +@click.option( + '--deployment', default=None, help='Deployment ID (resolves model/nodepool automatically).' +) +@click.option('--model-version-id', default=None, help='Specific version (default: latest).') +@click.option( + '--follow/--no-follow', + default=True, + help='Continuously tail new logs. Use --no-follow to print and exit.', +) +@click.option( + '--duration', + default=None, + type=int, + help='Stop after N seconds (default: unlimited, Ctrl+C to stop).', +) +@click.option('--compute-cluster-id', default=None, help='[Advanced] Filter by compute cluster.') +@click.option('--nodepool-id', default=None, help='[Advanced] Filter by nodepool.') +@click.option( + '--log-type', + default='model', + type=click.Choice(['model', 'events'], case_sensitive=False), + help='Log type: model (stdout/stderr) or events (k8s scheduling/scaling).', +) +@click.pass_context +def logs( + ctx, + model_url, + model_id, + deployment, + model_version_id, + follow, + duration, + compute_cluster_id, + nodepool_id, + log_type, +): + """Stream logs from a deployed model's runner. + + \b + Shows stdout/stderr from the runner pod — useful for monitoring + model loading, inference, and debugging errors. + + \b + Examples: + clarifai model logs --deployment deploy-abc123 + clarifai model logs --deployment deploy-abc123 --log-type events + clarifai model logs --model-url https://clarifai.com/user/app/models/id + clarifai model logs --model-url --no-follow + clarifai model logs --model-url --duration 60 + """ + validate_context(ctx) + + from clarifai.runners.models.model_deploy import stream_model_logs + + user_id = ctx.obj.current.user_id + app_id = getattr(ctx.obj.current, 'app_id', None) + + # Resolve deployment ID → model/version/nodepool + if deployment: + from clarifai.runners.models.model_deploy import get_deployment + + dep = get_deployment( + deployment, + user_id=user_id, + pat=ctx.obj.current.pat, + base_url=ctx.obj.current.api_base, ) - raise click.Abort() - config = ModelBuilder._load_config(config_file) - uploaded_model_type_id = config.get('model', {}).get( - 'model_type_id', DEFAULT_LOCAL_RUNNER_MODEL_TYPE + w = dep.worker + if w and w.model: + model_id = model_id or w.model.id + user_id = w.model.user_id or user_id + app_id = w.model.app_id or app_id + if w.model.model_version and w.model.model_version.id: + model_version_id = model_version_id or w.model.model_version.id + if dep.nodepools: + np = dep.nodepools[0] + nodepool_id = nodepool_id or np.id + if np.compute_cluster and np.compute_cluster.id: + compute_cluster_id = compute_cluster_id or np.compute_cluster.id + + # Map user-friendly names to API log_type values + api_log_type = {"model": "runner", "events": "runner.events"}[log_type.lower()] + + stream_model_logs( + model_url=model_url, + model_id=model_id, + user_id=user_id, + app_id=app_id, + model_version_id=model_version_id, + compute_cluster_id=compute_cluster_id, + nodepool_id=nodepool_id, + pat=ctx.obj.current.pat, + base_url=ctx.obj.current.api_base, + follow=follow, + duration=duration, + log_type=api_log_type, + ) + + +def _parse_model_ref(ref): + """Parse a model reference like 'user/app/models/model' into components. + + Returns: + tuple: (user_id, app_id, model_id) or raises UserError. + """ + parts = ref.strip('/').split('/') + if len(parts) == 4 and parts[2] == 'models': + return parts[0], parts[1], parts[3] + # Also accept user/app/model (without 'models' keyword) + if len(parts) == 3: + return parts[0], parts[1], parts[2] + raise UserError( + f"Invalid model reference: '{ref}'\n" + " Expected format: user_id/app_id/models/model_id\n" + " Example: luv_2261/main/models/my-model" + ) + + +def _resolve_deployment_id(deployment, model_ref, model_url, user_id, pat, base_url): + """Resolve to a single deployment ID from --deployment, model ref, or --model-url. + + Returns: + tuple: (deployment_id, user_id) for the resolved deployment. + + Raises: + UserError: If no deployment found, or ambiguous (multiple deployments). + """ + from clarifai.runners.models.model_deploy import ( + list_deployments_for_model, + ) + from clarifai.urls.helper import ClarifaiUrlHelper + + if deployment: + return deployment, user_id + + # Resolve model identity + if model_ref: + ref_user, ref_app, ref_model = _parse_model_ref(model_ref) + elif model_url: + ref_user, ref_app, _, ref_model, _ = ClarifaiUrlHelper.split_clarifai_url(model_url) + else: + raise UserError( + "You must specify one of: MODEL_REF, --model-url, or --deployment.\n" + " Examples:\n" + " clarifai model status user/app/models/model\n" + " clarifai model status --model-url \n" + " clarifai model status --deployment " + ) + + deployments = list_deployments_for_model( + model_id=ref_model, + user_id=ref_user, + app_id=ref_app, + pat=pat, + base_url=base_url, ) - logger.info("> Verifying local runner setup...") - logger.info(f"Current context: {ctx.obj.current.name}") + if len(deployments) == 0: + raise UserError( + f"No deployments found for model '{ref_user}/{ref_app}/models/{ref_model}'.\n" + " Deploy it first:\n" + f" clarifai model deploy --model-url https://clarifai.com/{ref_user}/{ref_app}/models/{ref_model} --instance " + ) + if len(deployments) == 1: + return deployments[0].id, ref_user + # Multiple deployments — list them and ask the user to pick + ids = [d.id for d in deployments] + raise UserError( + f"Multiple deployments found for model '{ref_user}/{ref_app}/models/{ref_model}'.\n" + " Specify one with --deployment:\n" + "\n".join(f" --deployment {did}" for did in ids) + ) + + +def _print_deployment_detail(dep): + """Print formatted details for a deployment proto.""" + from clarifai.runners.models import deploy_output as out + + dep_id = dep.id + bar = "\u2500" * max(1, 56 - len(dep_id) - 4) + click.echo(click.style(f"\n\u2500\u2500 Deployment: {dep_id} {bar}", fg="cyan", bold=True)) + + # Model info + worker = dep.worker + if worker and worker.model: + m = worker.model + model_ref = f"{m.user_id}/{m.app_id}/models/{m.id}" + out.info("Model", model_ref) + if m.model_version and m.model_version.id: + out.info("Version", m.model_version.id[:12]) + + # Autoscale + if dep.autoscale_config: + ac = dep.autoscale_config + if ac.min_replicas or ac.max_replicas: + out.info("Min replicas", str(ac.min_replicas)) + out.info("Max replicas", str(ac.max_replicas)) + + # Nodepool / compute cluster + if dep.nodepools: + np = dep.nodepools[0] + out.info("Nodepool", np.id) + if np.compute_cluster and np.compute_cluster.id: + out.info("Compute cluster", np.compute_cluster.id) + + # Created timestamp + if dep.created_at and dep.created_at.seconds: + from datetime import datetime, timezone + + ts = datetime.fromtimestamp(dep.created_at.seconds, tz=timezone.utc) + out.info("Created", ts.strftime("%Y-%m-%d %H:%M:%S UTC")) + + +@model.command() +@click.argument('model_ref', required=False, default=None) +@click.option('--model-url', default=None, help='Clarifai model URL.') +@click.option('--deployment', default=None, help='Deployment ID to show status for.') +@click.pass_context +def status(ctx, model_ref, model_url, deployment): + """Show deployment status for a model. + + \b + Shows replica count, instance type, cloud, nodepool, and timing + for each deployment. + + \b + MODEL_REF Model reference: user_id/app_id/models/model_id + + \b + Examples: + clarifai model status user/app/models/model + clarifai model status --model-url https://clarifai.com/user/app/models/id + clarifai model status --deployment deploy-abc123 + """ + validate_context(ctx) + + from clarifai.runners.models.model_deploy import ( + get_deployment, + list_deployments_for_model, + ) + from clarifai.urls.helper import ClarifaiUrlHelper + user_id = ctx.obj.current.user_id - logger.info(f"Current user_id: {user_id}") - if not user_id: - logger.error(f"User with ID '{user_id}' not found. Use 'clarifai login' to setup context.") - raise click.Abort() pat = ctx.obj.current.pat - display_pat = pat_display(pat) if pat else "" - logger.info(f"Current PAT: {display_pat}") - if not pat: - logger.error( - "Personal Access Token (PAT) not found. Use 'clarifai login' to setup context." + base_url = ctx.obj.current.api_base + + if deployment: + # Show a single deployment + dep = get_deployment(deployment, user_id=user_id, pat=pat, base_url=base_url) + _print_deployment_detail(dep) + return + + # Resolve model → list all deployments + if model_ref: + ref_user, ref_app, ref_model = _parse_model_ref(model_ref) + elif model_url: + ref_user, ref_app, _, ref_model, _ = ClarifaiUrlHelper.split_clarifai_url(model_url) + else: + raise click.UsageError( + "Provide MODEL_REF, --model-url, or --deployment.\n" + " Example: clarifai model status user/app/models/model" ) - raise click.Abort() - user = User(user_id=user_id, pat=ctx.obj.current.pat, base_url=ctx.obj.current.api_base) - logger.debug("Checking if a local runner compute cluster exists...") - # see if ctx has CLARIFAI_COMPUTE_CLUSTER_ID, if not use default - try: - compute_cluster_id = ctx.obj.current.compute_cluster_id - except AttributeError: - compute_cluster_id = DEFAULT_LOCAL_RUNNER_COMPUTE_CLUSTER_ID - logger.info(f"Current compute_cluster_id: {compute_cluster_id}") + deployments = list_deployments_for_model( + model_id=ref_model, + user_id=ref_user, + app_id=ref_app, + pat=pat, + base_url=base_url, + ) - try: - compute_cluster = user.compute_cluster(compute_cluster_id) - if compute_cluster.cluster_type != 'local-dev': - raise ValueError( - f"Compute cluster {user_id}/{compute_cluster_id} is not a compute cluster of type 'local-dev'. Please use a compute cluster of type 'local-dev'." + if not deployments: + click.echo( + click.style( + f"\nNo deployments found for '{ref_user}/{ref_app}/models/{ref_model}'.", + fg="yellow", ) - try: - compute_cluster_id = ctx.obj.current.compute_cluster_id - except AttributeError: # doesn't exist in context but does in API then update the context. - ctx.obj.current.CLARIFAI_COMPUTE_CLUSTER_ID = compute_cluster.id - ctx.obj.to_yaml() # save to yaml file. - except ValueError: - raise - except Exception as e: - logger.warning(f"Failed to get compute cluster with ID '{compute_cluster_id}':\n{e}") - y = input( - f"Compute cluster not found. Do you want to create a new compute cluster {user_id}/{compute_cluster_id}? (y/n): " ) - if y.lower() != 'y': - raise click.Abort() - # Create a compute cluster with default configuration for local runner. - compute_cluster = user.create_compute_cluster( - compute_cluster_id=compute_cluster_id, - compute_cluster_config=DEFAULT_LOCAL_RUNNER_COMPUTE_CLUSTER_CONFIG, + click.echo( + f" Deploy it: clarifai model deploy --model-url " + f"https://clarifai.com/{ref_user}/{ref_app}/models/{ref_model} --instance " ) - ctx.obj.current.CLARIFAI_COMPUTE_CLUSTER_ID = compute_cluster_id - ctx.obj.to_yaml() # save to yaml file. + return - # Now check if there is a nodepool created in this compute cluser - try: - nodepool_id = ctx.obj.current.nodepool_id - except AttributeError: - nodepool_id = DEFAULT_LOCAL_RUNNER_NODEPOOL_ID - logger.info(f"Current nodepool_id: {nodepool_id}") + for dep in deployments: + _print_deployment_detail(dep) + + +@model.command() +@click.argument('model_ref', required=False, default=None) +@click.option('--model-url', default=None, help='Clarifai model URL.') +@click.option('--deployment', default=None, help='Deployment ID to remove.') +@click.pass_context +def undeploy(ctx, model_ref, model_url, deployment): + """Remove a deployment (permanently stop serving the model). + + \b + Deletes the deployment. The model version and infrastructure remain + intact — you can re-deploy at any time with 'clarifai model deploy'. + + \b + MODEL_REF Model reference: user_id/app_id/models/model_id + + \b + Examples: + clarifai model undeploy --deployment deploy-abc123 + clarifai model undeploy user/app/models/model + clarifai model undeploy --model-url https://clarifai.com/user/app/models/id + """ + validate_context(ctx) + + from clarifai.runners.models import deploy_output as out + from clarifai.runners.models.model_deploy import delete_deployment + + user_id = ctx.obj.current.user_id + pat = ctx.obj.current.pat + base_url = ctx.obj.current.api_base + + dep_id, resolved_user = _resolve_deployment_id( + deployment, model_ref, model_url, user_id, pat, base_url + ) + + delete_deployment(dep_id, user_id=resolved_user, pat=pat, base_url=base_url) + out.success(f"Deployment '{dep_id}' deleted.") + + +def _detect_toolkit(config, dependencies): + """Detect the inference toolkit from config or requirements.txt deps. + + Returns the toolkit name (e.g. 'vllm', 'sglang', 'ollama', 'lmstudio') + or None if no known toolkit is detected. + """ + provider = ( + config.get('toolkit', {}).get('provider', '').lower() if config.get('toolkit') else '' + ) + if provider: + return provider + for name in ('vllm', 'sglang', 'ollama', 'lmstudio'): + if name in dependencies: + return name + return None + + +def _check_platform_for_toolkit(toolkit): + """Raise UserError if toolkit requires Linux and we're not on Linux.""" + if toolkit in ('vllm', 'sglang') and platform.system() != 'Linux': + raise UserError( + f"'{toolkit}' requires a Linux environment with GPU access.\n" + " Options:\n" + " clarifai model serve --mode container # run in Docker locally\n" + " clarifai model deploy . # deploy to cloud GPU\n" + " clarifai model init --toolkit ollama # switch to local-friendly toolkit" + ) + + +def _run_local_grpc(model_path, mode, port, keep_image, verbose): + """Run a model locally via a standalone gRPC server (no PAT, no API).""" + import signal + + from clarifai.runners.models import deploy_output as out + from clarifai.runners.models.model_builder import ModelBuilder + from clarifai.runners.models.model_deploy import _quiet_sdk_logger + from clarifai.runners.models.model_run_locally import ModelRunLocally + from clarifai.runners.server import ModelServer + + model_path = os.path.abspath(model_path) + suppress = not verbose + + # ── Phase 1: Validate ────────────────────────────────────────────── + out.phase_header("Validate") + + with _quiet_sdk_logger(suppress): + builder = ModelBuilder(model_path, download_validation_only=True) + config = builder.config + model_config = config.get('model', {}) + + model_id = model_config.get('id', os.path.basename(model_path)) + model_type_id = model_config.get('model_type_id', DEFAULT_LOCAL_RUNNER_MODEL_TYPE) + + # Validate requirements for none mode only (env creates its own venv, container builds image) + dependencies = parse_requirements(model_path) + toolkit = _detect_toolkit(config, dependencies) + + # vLLM/SGLang need Linux + GPU — block early with clear alternatives + _check_platform_for_toolkit(toolkit) + + if mode not in ("container", "env"): + if not check_requirements_installed(dependencies=dependencies): + raise UserError(f"Requirements not installed for model at {model_path}.") + + if toolkit == 'ollama': + if not check_ollama_installed(): + raise UserError( + "Ollama is not installed. Install from https://ollama.com/ to use the Ollama toolkit." + ) + elif toolkit == 'lmstudio': + if not check_lmstudio_installed(): + raise UserError( + "LM Studio is not installed. Install from https://lmstudio.com/ to use the LM Studio toolkit." + ) + + # Get method signatures to generate test snippet + use_mocking = mode in ("container", "env") + with _quiet_sdk_logger(suppress): + method_signatures = builder.get_method_signatures(mocking=use_mocking) + + out.info("Model", model_id) + out.info("Type", model_type_id) + out.info("Port", str(port)) + + # ── Phase 2: Prepare environment ───────────────────────────────── + container_name = None + image_name = None + manager = None + cleanup_done = False + + def _do_cleanup(): + nonlocal cleanup_done + if cleanup_done: + return + cleanup_done = True + if mode == "container" and manager is not None: + try: + if container_name and manager.container_exists(container_name): + manager.stop_docker_container(container_name) + manager.remove_docker_container(container_name=container_name) + if not keep_image and image_name: + manager.remove_docker_image(image_name=image_name) + except Exception: + pass + out.status("Stopped.") + + original_sigint = signal.getsignal(signal.SIGINT) + + def _sigint_handler(signum, frame): + signal.signal(signal.SIGINT, original_sigint) + _do_cleanup() + raise KeyboardInterrupt + + signal.signal(signal.SIGINT, _sigint_handler) try: - nodepool = compute_cluster.nodepool(nodepool_id) - try: - nodepool_id = ctx.obj.current.nodepool_id - except AttributeError: # doesn't exist in context but does in API then update the context. - ctx.obj.current.CLARIFAI_NODEPOOL_ID = nodepool.id - ctx.obj.to_yaml() # save to yaml file. - except Exception as e: - logger.warning(f"Failed to get nodepool with ID '{nodepool_id}':\n{e}") - y = input( - f"Nodepool not found. Do you want to create a new nodepool {user_id}/{compute_cluster_id}/{nodepool_id}? (y/n): " + # Ensure config.yaml has fields required by PyPI clarifai in subprocess + if mode in ("container", "env"): + _ensure_config_defaults(model_path, model_type_id=model_type_id) + + if mode == "container": + manager = ModelRunLocally(model_path) + if not manager.is_docker_installed(): + raise UserError("Docker is not installed.") + with _quiet_sdk_logger(suppress): + manager.builder.create_dockerfile(generate_dockerfile=True) + image_tag = manager._docker_hash() + container_name = model_id.lower() + image_name = f"{container_name}:{image_tag}" + if not manager.docker_image_exists(image_name): + out.status("Building Docker image... ") + with _quiet_sdk_logger(suppress): + manager.build_docker_image(image_name=image_name) + elif mode == "env": + manager = ModelRunLocally(model_path) + out.status("Creating virtual environment... ") + manager.create_temp_venv() + out.status("Installing requirements... ") + manager.install_requirements() + + # ── Phase 3: Running ──────────────────────────────────────────── + out.phase_header("Running") + out.success(f"gRPC server running at localhost:{port}") + click.echo() + + from clarifai.runners.utils.code_script import generate_client_script + + snippet = generate_client_script( + method_signatures, + user_id=None, + app_id=None, + model_id=model_id, + local_grpc_port=port, + colorize=True, ) - if y.lower() != 'y': - raise click.Abort() - nodepool = compute_cluster.create_nodepool( - nodepool_config=DEFAULT_LOCAL_RUNNER_NODEPOOL_CONFIG, nodepool_id=nodepool_id + out.status("Test with Python:") + click.echo(snippet) + + out.status("Press Ctrl+C to stop.") + click.echo() + + # ── Serve ─────────────────────────────────────────────────────── + if mode == "container": + manager.run_docker_container( + image_name=image_name, + container_name=container_name, + port=port, + ) + elif mode == "env": + manager.run_model_server(port) + else: + # none mode: run in-process + with _quiet_sdk_logger(suppress): + server = ModelServer(model_path=model_path) + server.serve(port=port, grpc=True) + except KeyboardInterrupt: + pass + except Exception as e: + click.echo() + out.warning(f"Model failed: {e}") + finally: + signal.signal(signal.SIGINT, original_sigint) + _do_cleanup() + + +@model.command(name="serve", aliases=["local-runner"]) +@click.argument( + "model_path", + type=click.Path(exists=True), + required=False, + default=".", +) +@click.option( + "--mode", + type=click.Choice(['none', 'env', 'container'], case_sensitive=False), + default='none', + show_default=True, + help='Execution environment. none: use current Python (fastest, deps must be installed). env: auto-create virtualenv and install deps. container: build and run a Docker image.', +) +@click.option( + '--grpc', + is_flag=True, + help='Standalone gRPC server (no login required). Without this flag, the model registers with the Clarifai API for Playground and API access.', +) +@click.option( + '-p', + '--port', + type=int, + default=8000, + show_default=True, + help="Server port (used with --grpc).", +) +@click.option( + "--concurrency", + type=int, + default=32, + show_default=True, + help="Maximum number of concurrent requests.", +) +@click.option( + '--keep-image', + is_flag=True, + help='Keep Docker image after exit (only with --mode container).', +) +@click.option( + '-v', + '--verbose', + is_flag=True, + help='Show detailed SDK and server logs.', +) +@click.pass_context +def serve_cmd(ctx, model_path, grpc, mode, port, concurrency, keep_image, verbose): + """Run a model locally for development and testing. + + \b + Starts the model and registers it with Clarifai so you can send + predictions via the API, SDK, or Playground UI. Use --grpc for a + standalone gRPC server with no API connection. Cleans up on Ctrl+C. + + \b + MODEL_PATH Model directory containing config.yaml (default: "."). + + \b + Modes: + none Run in current Python env (fastest, deps pre-installed) + env Auto-create a virtualenv, install deps, then run + container Build a Docker image with all deps, then run + + \b + Examples: + clarifai model serve ./my-model # current env, API-connected + clarifai model serve --mode env # auto-install deps in venv + clarifai model serve --mode container # run inside Docker + clarifai model serve --grpc # offline gRPC server + clarifai model serve --grpc --port 9000 # custom port + clarifai model serve --mode container --keep-image + """ + if grpc: + # Standalone gRPC server — no PAT, no API + _run_local_grpc(model_path, mode, port, keep_image, verbose) + return + + from clarifai.client.user import User + from clarifai.runners.models import deploy_output as out + from clarifai.runners.models.model_builder import ModelBuilder + from clarifai.runners.models.model_deploy import _quiet_sdk_logger + from clarifai.runners.models.model_run_locally import ModelRunLocally + from clarifai.runners.server import ModelServer + from clarifai.runners.utils import code_script + + validate_context(ctx) + model_path = os.path.abspath(model_path) + suppress = not verbose + + # ── Phase 1: Validate ────────────────────────────────────────────── + out.phase_header("Validate") + + with _quiet_sdk_logger(suppress): + builder = ModelBuilder(model_path, download_validation_only=True) + config = builder.config + model_config = config.get('model', {}) + + model_id = model_config.get('id') + if not model_id: + raise UserError( + "model.id is required in config.yaml.\n" + " Add to your config.yaml:\n" + " model:\n" + " id: my-model" ) - ctx.obj.current.CLARIFAI_NODEPOOL_ID = nodepool_id - ctx.obj.to_yaml() # save to yaml file. - logger.debug("Checking if model is created to call for local development...") - # see if ctx has CLARIFAI_APP_ID, if not use default - try: - app_id = ctx.obj.current.app_id - except AttributeError: - app_id = DEFAULT_LOCAL_RUNNER_APP_ID - logger.info(f"Current app_id: {app_id}") + model_type_id = model_config.get('model_type_id', DEFAULT_LOCAL_RUNNER_MODEL_TYPE) - try: - app = user.app(app_id) + # Resolve user_id: config → context → error + user_id = model_config.get('user_id') + if not user_id: + try: + user_id = ctx.obj.current.user_id + except AttributeError: + pass + if not user_id: + raise UserError( + "user_id not found in config.yaml or CLI context.\n" + " Run 'clarifai login' to set up credentials." + ) + + # Resolve app_id: config → context → default + app_id = model_config.get('app_id') + if not app_id: try: app_id = ctx.obj.current.app_id - except AttributeError: # doesn't exist in context but does in API then update the context. - ctx.obj.current.CLARIFAI_APP_ID = app.id - ctx.obj.to_yaml() # save to yaml file. - except Exception as e: - logger.warning(f"Failed to get app with ID '{app_id}':\n{e}") - y = input(f"App not found. Do you want to create a new app {user_id}/{app_id}? (y/n): ") - if y.lower() != 'y': - raise click.Abort() - app = user.create_app(app_id) - ctx.obj.current.CLARIFAI_APP_ID = app_id - ctx.obj.to_yaml() # save to yaml file. + except AttributeError: + pass + if not app_id: + app_id = DEFAULT_LOCAL_RUNNER_APP_ID - # Within this app we now need a model to call as the local runner. - try: - model_id = ctx.obj.current.model_id - except AttributeError: - model_id = DEFAULT_LOCAL_RUNNER_MODEL_ID - logger.info(f"Current model_id: {model_id}") + pat = ctx.obj.current.pat + if not pat: + raise UserError( + "Personal Access Token (PAT) not found.\n Run 'clarifai login' to set up credentials." + ) + base_url = ctx.obj.current.api_base - try: - model = app.model(model_id) - current_model_type_id = model.model_type_id - try: - model_id = ctx.obj.current.model_id - except AttributeError: # doesn't exist in context but does in API then update the context. - ctx.obj.current.CLARIFAI_MODEL_ID = model.id - ctx.obj.to_yaml() # save to yaml file. - if current_model_type_id != uploaded_model_type_id: - logger.warning( - f"Model type ID mismatch: expected '{uploaded_model_type_id}', found '{current_model_type_id}'. Deleting the model." + # Validate requirements before loading method signatures + # Skip for container (builds image) and env (creates its own venv) + dependencies = parse_requirements(model_path) + toolkit = _detect_toolkit(config, dependencies) + + # vLLM/SGLang need Linux + GPU — block early with clear alternatives + _check_platform_for_toolkit(toolkit) + + if mode not in ("container", "env"): + if not check_requirements_installed(dependencies=dependencies): + raise UserError(f"Requirements not installed for model at {model_path}.") + + if toolkit == 'ollama': + if not check_ollama_installed(): + raise UserError( + "Ollama is not installed. Install from https://ollama.com/ to use the Ollama toolkit." + ) + elif toolkit == 'lmstudio': + if not check_lmstudio_installed(): + raise UserError( + "LM Studio is not installed. Install from https://lmstudio.com/ to use the LM Studio toolkit." ) - app.delete_model(model_id) - raise Exception - except Exception as e: - logger.warning(f"Failed to get model with ID '{model_id}':\n{e}") - y = input( - f"Model not found. Do you want to create a new model {user_id}/{app_id}/models/{model_id}? (y/n): " - ) - if y.lower() != 'y': - raise click.Abort() - model = app.create_model(model_id, model_type_id=uploaded_model_type_id) - ctx.obj.current.CLARIFAI_MODEL_TYPE_ID = uploaded_model_type_id - ctx.obj.current.CLARIFAI_MODEL_ID = model_id - ctx.obj.to_yaml() # save to yaml file. + # Method signatures from ModelBuilder (same as upload/deploy). + # Use mocking=False for "none" mode since requirements are verified installed. + # mocking=True pollutes sys.modules with MagicMock'd third-party packages inside + # clarifai modules (e.g. FastMCP in stdio_mcp_class), which breaks ModelServer.__init__ + # when it later tries to load the model for real. + # For container/env modes, deps may not be in current env, so mocking is needed. + use_mocking = mode in ("container", "env") + with _quiet_sdk_logger(suppress): + method_signatures = builder.get_method_signatures(mocking=use_mocking) - # Now we need to create a version for the model if no version exists. Only need one version that - # mentions it's a local runner. - model_versions = list(model.list_versions()) - method_signatures = manager._get_method_signatures() + out.info("Model", f"{user_id}/{app_id}/models/{model_id}") + out.info("Type", model_type_id) - create_new_version = False - if len(model_versions) == 0: - logger.warning("No model versions found. Creating a new version for local runner.") - create_new_version = True - else: - # Try to patch the latest version, and fallback to creating a new one if that fails. - latest_version = model_versions[0] - logger.warning(f"Attempting to patch latest version: {latest_version.model_version.id}") + _ensure_hf_token(ctx, model_path) + + # ── Phase 2: Setup ───────────────────────────────────────────────── + out.phase_header("Setup") + + # Track what we create for cleanup + created = {} # resource_name → cleanup_info + with _quiet_sdk_logger(suppress): + user = User(user_id=user_id, pat=pat, base_url=base_url) + + # 1. Compute cluster (shared, reusable — never cleaned up) + cc_id = DEFAULT_LOCAL_RUNNER_COMPUTE_CLUSTER_ID try: - patched_model = model.patch_version( - version_id=latest_version.model_version.id, - pretrained_model_config={"local_dev": True}, - method_signatures=method_signatures, + user.compute_cluster(cc_id) + out.status("Compute cluster ready") + except Exception: + out.status("Creating compute cluster... ", nl=False) + user.create_compute_cluster( + compute_cluster_id=cc_id, + compute_cluster_config=DEFAULT_LOCAL_RUNNER_COMPUTE_CLUSTER_CONFIG, ) - patched_model.load_info() - version = patched_model.model_version - logger.info(f"Successfully patched version {version.id}") - ctx.obj.current.CLARIFAI_MODEL_VERSION_ID = version.id - ctx.obj.to_yaml() # save to yaml file. - except Exception as e: - logger.warning(f"Failed to patch model version: {e}. Creating a new version instead.") - create_new_version = True - - if create_new_version: - version = model.create_version( - pretrained_model_config={"local_dev": True}, method_signatures=method_signatures - ).model_version - ctx.obj.current.CLARIFAI_MODEL_VERSION_ID = version.id - ctx.obj.to_yaml() - - logger.info(f"Current model version {version.id}") - - worker = { - "model": { - "id": f"{model.id}", - "model_version": { - "id": f"{version.id}", - }, - "user_id": f"{user_id}", - "app_id": f"{app_id}", - }, - } + click.echo("done") - try: - # if it's already in our context then we'll re-use the same one. - # note these are UUIDs, we cannot provide a runner ID. - runner_id = ctx.obj.current.runner_id + # 2. Nodepool (shared, reusable — never cleaned up) + np_id = DEFAULT_LOCAL_RUNNER_NODEPOOL_ID + try: + nodepool = user.compute_cluster(cc_id).nodepool(np_id) + out.status("Nodepool ready") + except Exception: + out.status("Creating nodepool... ", nl=False) + nodepool = user.compute_cluster(cc_id).create_nodepool( + nodepool_config=DEFAULT_LOCAL_RUNNER_NODEPOOL_CONFIG, + nodepool_id=np_id, + ) + click.echo("done") + # 3. App (shared, reusable — never cleaned up) try: - runner = nodepool.runner(runner_id) - # ensure the deployment is using the latest version. - if runner.worker.model.model_version.id != version.id: - nodepool.delete_runners([runner_id]) - logger.warning("Deleted runner that was for an old model version ID.") - raise AttributeError( - "Runner deleted because it was associated with an outdated model version." - ) - except Exception as e: - logger.warning(f"Failed to get runner with ID '{runner_id}':\n{e}") - raise AttributeError("Runner not found in nodepool.") - except AttributeError: - logger.info( - f"Creating the local runner tying this '{user_id}/{app_id}/models/{model.id}' model (version: {version.id}) to the '{user_id}/{compute_cluster_id}/{nodepool_id}' nodepool." - ) + app = user.app(app_id) + out.status("App ready") + except Exception: + out.status("Creating app... ", nl=False) + app = user.create_app(app_id) + click.echo("done") + + # 4. Model (ephemeral if we create it) + model_existed = False try: - logger.info("Checking for existing runners in the nodepool...") - runners = nodepool.list_runners( - model_version_ids=[version.id], - ) - runner_id = None - for runner in runners: - logger.info( - f"Found existing runner {runner.id} for model version {version.id}. Reusing it." - ) - runner_id = runner.id - break # use the first one we find. - if runner_id is None: - logger.warning("No existing runners found in nodepool. Creating a new one.\n") - runner = nodepool.create_runner( - runner_config={ - "runner": { - "description": "local runner for model testing", - "worker": worker, - "num_replicas": 1, - } - } + model = app.model(model_id) + model_existed = True + if model.model_type_id != model_type_id: + out.warning( + f"Model type mismatch (expected '{model_type_id}', " + f"found '{model.model_type_id}'). Recreating." ) - runner_id = runner.id - except Exception as e: - logger.warning( - f"Failed to list existing runners in nodepool {e}...Creating a new one.\n" - ) - runner = nodepool.create_runner( - runner_config={ - "runner": { - "description": "local runner for model testing", - "worker": worker, - "num_replicas": 1, - } - } - ) - runner_id = runner.id - ctx.obj.current.CLARIFAI_RUNNER_ID = runner.id - ctx.obj.to_yaml() - - logger.info(f"Current runner_id: {runner_id}") + app.delete_model(model_id) + model_existed = False + raise Exception("recreate") + out.status("Model ready") + except Exception: + if not model_existed: + out.status("Creating model... ", nl=False) + model = app.create_model(model_id, model_type_id=model_type_id) + created['model'] = model_id + click.echo("done") + + # 5. Model version (always created fresh — always cleaned up) + out.status("Creating model version... ", nl=False) + version_model = model.create_version( + pretrained_model_config={"local_dev": True}, + method_signatures=method_signatures, + ) + version_model.load_info() + version_id = version_model.model_version.id + created['model_version'] = version_id + click.echo(f"done ({version_id[:8]})") - # To make it easier to call the model without specifying a runner selector - # we will also create a deployment tying the model to the nodepool. - try: - deployment_id = ctx.obj.current.deployment_id - except AttributeError: - deployment_id = DEFAULT_LOCAL_RUNNER_DEPLOYMENT_ID - try: - deployment = nodepool.deployment(deployment_id) - # ensure the deployment is using the latest version. - if deployment.worker.model.model_version.id != version.id: - nodepool.delete_deployments([deployment_id]) - logger.warning("Deleted deployment that was for an old model version ID.") - raise Exception( - "Deployment deleted because it was associated with an outdated model version." - ) + # 6. Stale deployment cleanup (from previous crash) + deployment_id = f"local-{model_id}" try: - deployment_id = ctx.obj.current.deployment_id - except AttributeError: # doesn't exist in context but does in API then update the context. - ctx.obj.current.CLARIFAI_DEPLOYMENT_ID = deployment.id - ctx.obj.to_yaml() # save to yaml file. - except Exception as e: - logger.warning(f"Failed to get deployment with ID {deployment_id}:\n{e}") - y = input( - f"Deployment not found. Do you want to create a new deployment {user_id}/{compute_cluster_id}/{nodepool_id}/{deployment_id}? (y/n): " + nodepool.deployment(deployment_id) + nodepool.delete_deployments([deployment_id]) + except Exception: + pass + + # 7. Runner (always created fresh — always cleaned up) + worker = { + "model": { + "id": model_id, + "model_version": {"id": version_id}, + "user_id": user_id, + "app_id": app_id, + } + } + out.status("Creating runner... ", nl=False) + runner = nodepool.create_runner( + runner_config={ + "runner": { + "description": f"local runner for {model_id}", + "worker": worker, + "num_replicas": 1, + } + } ) - if y.lower() != 'y': - raise click.Abort() + runner_id = runner.id + created['runner'] = runner_id + click.echo("done") + + # 8. Deployment (always created fresh — always cleaned up) + out.status("Creating deployment... ", nl=False) nodepool.create_deployment( deployment_id=deployment_id, deployment_config={ "deployment": { - "scheduling_choice": 3, # 3 means by price + "scheduling_choice": 3, "worker": worker, "nodepools": [ { - "id": f"{nodepool_id}", + "id": np_id, "compute_cluster": { - "id": f"{compute_cluster_id}", - "user_id": f"{user_id}", + "id": cc_id, + "user_id": user_id, }, } ], @@ -1602,129 +1906,205 @@ def local_runner( } }, ) - ctx.obj.current.CLARIFAI_DEPLOYMENT_ID = deployment_id - ctx.obj.to_yaml() # save to yaml file. - - logger.info(f"Current deployment_id: {deployment_id}") - - # Now that we have all the context in ctx.obj, we need to update the config.yaml in - # the model_path directory with the model object containing user_id, app_id, model_id, version_id - # The config.yaml doens't match what we created above. - if 'model' in config and model_id != config['model'].get('id'): - logger.info(f"Current model section of config.yaml: {config.get('model', {})}") - y = input( - "Do you want to backup config.yaml to config.yaml.bk then update the config.yaml with the new model information? (y/n): " - ) - if y.lower() != 'y': - raise click.Abort() - config = ModelBuilder._set_local_runner_model( - config, user_id, app_id, model_id, uploaded_model_type_id - ) - ModelBuilder._backup_config(config_file) - ModelBuilder._save_config(config_file, config) + created['deployment'] = deployment_id + click.echo("done") - # Post check while running `clarifai model local-runner` we check if the toolkit is ollama - if builder.config.get('toolkit', {}).get('provider') == 'ollama': + # Toolkit customization (before serving) + if config.get('toolkit', {}).get('provider') == 'ollama': try: - logger.info("Customizing Ollama model with provided parameters...") - customize_ollama_model( - model_path=model_path, - user_id=user_id, - verbose=False if suppress_toolkit_logs else True, - ) + customize_ollama_model(model_path=model_path, user_id=user_id, verbose=verbose) except Exception as e: - logger.error(f"Failed to customize Ollama model: {e}") - raise click.Abort() - elif builder.config.get('toolkit', {}).get('provider') == 'lmstudio': + raise UserError(f"Failed to customize Ollama model: {e}") + elif config.get('toolkit', {}).get('provider') == 'lmstudio': try: - logger.info("Customizing LM Studio model with provided parameters...") - customize_lmstudio_model( - model_path=model_path, - user_id=user_id, - ) + customize_lmstudio_model(model_path=model_path, user_id=user_id) except Exception as e: - logger.error(f"Failed to customize LM Studio model: {e}") - raise click.Abort() - - logger.info("✅ Starting local runner...") - - def print_code_snippet(): - if ctx.obj.current is None: - logger.debug("Context is None. Skipping code snippet generation.") - else: - from clarifai.runners.utils import code_script - - snippet = code_script.generate_client_script( - method_signatures, - user_id=ctx.obj.current.user_id, - app_id=ctx.obj.current.app_id, - model_id=ctx.obj.current.model_id, - deployment_id=ctx.obj.current.deployment_id, - base_url=ctx.obj.current.api_base, - colorize=True, - ) - logger.info( - "✅ Your model is running locally and is ready for requests from the API...\n" - ) - logger.info( - f"> Code Snippet: To call your model via the API, use this code snippet:\n{snippet}" - ) - logger.info( - f"> Playground: To chat with your model, visit: {ctx.obj.current.ui}/playground?model={ctx.obj.current.model_id}__{ctx.obj.current.model_version_id}&user_id={ctx.obj.current.user_id}&app_id={ctx.obj.current.app_id}\n" - ) - logger.info( - f"> API URL: To call your model via the API, use this model URL: {ctx.obj.current.ui}/{ctx.obj.current.user_id}/{ctx.obj.current.app_id}/models/{ctx.obj.current.model_id}\n" - ) - logger.info("Press CTRL+C to stop the runner.\n") + raise UserError(f"Failed to customize LM Studio model: {e}") serving_args = { - "pool_size": pool_size, - "num_threads": pool_size, + "pool_size": concurrency, + "num_threads": concurrency, "user_id": user_id, - "compute_cluster_id": compute_cluster_id, - "nodepool_id": nodepool_id, + "compute_cluster_id": cc_id, + "nodepool_id": np_id, "runner_id": runner_id, - "base_url": ctx.obj.current.api_base, - "pat": ctx.obj.current.pat, - "context": ctx.obj.current, - "health_check_port": health_check_port, + "base_url": base_url, + "pat": pat, + "health_check_port": 0, # OS-assigned port; avoids collisions between local runners } - if mode == "container": - try: - if not manager.is_docker_installed(): - raise click.abort() + container_name = None + image_name = None + manager = None + cleanup_done = False + + def _cleanup(): + out.phase_header("Stopping") + with _quiet_sdk_logger(suppress): + if 'deployment' in created: + out.status("Deleting deployment... ", nl=False) + try: + nodepool.delete_deployments([created['deployment']]) + click.echo("done") + except Exception: + click.echo("failed") + if 'runner' in created: + out.status("Deleting runner... ", nl=False) + try: + nodepool.delete_runners([created['runner']]) + click.echo("done") + except Exception: + click.echo("failed") + if 'model_version' in created: + out.status("Deleting model version... ", nl=False) + try: + model.delete_version(version_id=created['model_version']) + click.echo("done") + except Exception: + click.echo("failed") + if 'model' in created: + out.status("Deleting model... ", nl=False) + try: + app.delete_model(created['model']) + click.echo("done") + except Exception: + click.echo("failed") + out.status("Stopped.") + + def _do_cleanup(): + nonlocal cleanup_done + if cleanup_done: + return + cleanup_done = True + # Container cleanup (Docker) + if mode == "container" and manager is not None: + try: + if container_name and manager.container_exists(container_name): + manager.stop_docker_container(container_name) + manager.remove_docker_container(container_name=container_name) + if not keep_image and image_name: + manager.remove_docker_image(image_name=image_name) + except Exception: + pass + # API resource cleanup (always) + _cleanup() + + # Register SIGINT handler so cleanup runs before BaseRunner's os._exit(130). + # BaseRunner catches KeyboardInterrupt internally and calls os._exit(130), + # which bypasses try/finally. Our signal handler fires first. + import signal + + original_sigint = signal.getsignal(signal.SIGINT) + + def _sigint_handler(signum, frame): + signal.signal(signal.SIGINT, original_sigint) # Restore so second Ctrl+C force-kills + _do_cleanup() + raise KeyboardInterrupt + + signal.signal(signal.SIGINT, _sigint_handler) + + try: + # ── Phase 3: Prepare environment ──────────────────────────────── + # Ensure config.yaml has fields required by PyPI clarifai in subprocess + if mode in ("container", "env"): + _ensure_config_defaults(model_path, model_type_id=model_type_id) + if mode == "container": + manager = ModelRunLocally(model_path) + if not manager.is_docker_installed(): + raise UserError("Docker is not installed.") + with _quiet_sdk_logger(suppress): + manager.builder.create_dockerfile(generate_dockerfile=True) image_tag = manager._docker_hash() - model_id = manager.config['model']['id'].lower() - # must be in lowercase - image_name = f"{model_id}:{image_tag}" - container_name = model_id + container_name = model_id.lower() + image_name = f"{container_name}:{image_tag}" if not manager.docker_image_exists(image_name): + out.status("Building Docker image... ") manager.build_docker_image(image_name=image_name) + elif mode == "env": + manager = ModelRunLocally(model_path) + out.status("Creating virtual environment... ") + manager.create_temp_venv() + out.status("Installing requirements... ") + manager.install_requirements() + + # ── Phase 4: Running ──────────────────────────────────────────── + out.phase_header("Running") + out.success("Model is ready for API requests!") + click.echo() + + # Code snippet + snippet = code_script.generate_client_script( + method_signatures, + user_id=user_id, + app_id=app_id, + model_id=model_id, + deployment_id=deployment_id, + base_url=base_url, + colorize=True, + ) + click.echo(snippet) + + ui_base = getattr(ctx.obj.current, 'ui', None) or "https://clarifai.com" + playground_url = ( + f"{ui_base}/playground?model={model_id}__{version_id}" + f"&user_id={user_id}&app_id={app_id}" + ) + model_url = f"{ui_base}/{user_id}/{app_id}/models/{model_id}" + model_ref = f"{user_id}/{app_id}/models/{model_id}" + predict_cmd = code_script.generate_predict_hint( + method_signatures or [], model_ref, deployment_id=deployment_id + ) - manager.build_docker_image(image_name=image_name) - print_code_snippet() + out.phase_header("Next Steps") + out.hint("Predict", predict_cmd) + out.link("Playground", playground_url) + out.link("Model URL", model_url) + click.echo() + out.status("Press Ctrl+C to stop.") + click.echo() + + # ── Serve ─────────────────────────────────────────────────────── + if mode == "container": manager.run_docker_container( image_name=image_name, container_name=container_name, - port=port, + port=8080, is_local_runner=True, - env_vars={"CLARIFAI_PAT": ctx.obj.current.pat}, + env_vars={"CLARIFAI_PAT": pat}, **serving_args, ) - - finally: - if manager.container_exists(container_name): - manager.stop_docker_container(container_name) - manager.remove_docker_container(container_name=container_name) - if not keep_image: - manager.remove_docker_image(image_name=image_name) - else: - print_code_snippet() - # This reads the config.yaml from the model_path so we alter it above first. - server = ModelServer(model_path=model_path, model_runner_local=None, model_builder=builder) - server.serve(**serving_args) + elif mode == "env": + # Run via venv subprocess so model code uses venv's packages + # Filter to args accepted by clarifai.runners.server CLI + runner_args = { + k: v + for k, v in serving_args.items() + if k + in ( + 'pool_size', + 'num_threads', + 'user_id', + 'compute_cluster_id', + 'nodepool_id', + 'runner_id', + 'base_url', + 'pat', + ) + } + manager.run_model_server(grpc=False, **runner_args) + else: + server = ModelServer(model_path=model_path, model_runner_local=None) + server.serve(**serving_args) + except KeyboardInterrupt: + pass + except Exception as e: + # Model failed to load or serve — show clean error, then cleanup below + click.echo() + out.warning(f"Model failed: {e}") + finally: + signal.signal(signal.SIGINT, original_sigint) + _do_cleanup() def _parse_json_param(param_value, param_name): @@ -1835,112 +2215,485 @@ def _validate_compute_params(compute_cluster_id, nodepool_id, deployment_id): raise click.Abort() -@model.command(help="Perform a prediction using the model.") +def _resolve_model_ref(model_ref, ui_base=None): + """Resolve a model reference to a full Clarifai URL. + + Accepts: + - Full URL: https://clarifai.com/user/app/models/model → passthrough + - Shorthand: user/app/models/model → prepend UI base + + Args: + model_ref: Model reference string. + ui_base: UI base URL. Defaults to DEFAULT_UI. + + Returns: + str: Full Clarifai model URL. + + Raises: + click.UsageError: If format is invalid. + """ + from clarifai.utils.constants import DEFAULT_UI + + if not model_ref: + return None + + if model_ref.startswith(("http://", "https://")): + return model_ref + + parts = model_ref.split("/") + if len(parts) == 4 and parts[2] == "models": + base = ui_base or DEFAULT_UI + return f"{base.rstrip('/')}/{model_ref}" + + raise click.UsageError( + f"Invalid model reference: '{model_ref}'. " + "Use user_id/app_id/models/model_id or a full URL." + ) + + +def _get_first_str_param(model_client, method_name): + """Find the first string-typed input parameter name for a method. + + Args: + model_client: ModelClient instance with fetched signatures. + method_name: Method name to inspect. + + Returns: + str or None: The parameter name, or None if no string param found. + """ + from clarifai_grpc.grpc.api import resources_pb2 + + if not model_client._defined: + model_client.fetch() + method_sig = model_client._method_signatures.get(method_name) + if not method_sig: + return None + for field in method_sig.input_fields: + if field.type == resources_pb2.ModelTypeField.DataType.STR: + return field.name + return None + + +def _get_first_media_param(model_client, method_name): + """Find the first Image/Video/Audio-typed input parameter name and its type. + + Args: + model_client: ModelClient instance with fetched signatures. + method_name: Method name to inspect. + + Returns: + tuple: (param_name, data_type_enum) or (None, None). + """ + from clarifai_grpc.grpc.api import resources_pb2 + + media_types = { + resources_pb2.ModelTypeField.DataType.IMAGE: 'image', + resources_pb2.ModelTypeField.DataType.VIDEO: 'video', + resources_pb2.ModelTypeField.DataType.AUDIO: 'audio', + } + if not model_client._defined: + model_client.fetch() + method_sig = model_client._method_signatures.get(method_name) + if not method_sig: + return None, None + for field in method_sig.input_fields: + if field.type in media_types: + return field.name, media_types[field.type] + return None, None + + +def _coerce_input_value(value, model_client, method_name, param_name): + """Coerce a string value to the correct type based on the method signature. + + Args: + value: String value to coerce. + model_client: ModelClient instance. + method_name: Method name. + param_name: Parameter name. + + Returns: + Coerced value. + """ + from clarifai_grpc.grpc.api import resources_pb2 + + if not model_client._defined: + model_client.fetch() + method_sig = model_client._method_signatures.get(method_name) + if not method_sig: + return value + type_map = { + resources_pb2.ModelTypeField.DataType.INT: int, + resources_pb2.ModelTypeField.DataType.FLOAT: float, + resources_pb2.ModelTypeField.DataType.BOOL: lambda v: v.lower() in ('true', '1', 'yes'), + } + for field in method_sig.input_fields: + if field.name == param_name and field.type in type_map: + try: + return type_map[field.type](value) + except (ValueError, TypeError): + return value + return value + + +def _parse_kv_inputs(input_params, model_client, method_name): + """Parse key=value input parameters into a dict with type coercion. + + Args: + input_params: Tuple of "key=value" strings. + model_client: ModelClient instance for type coercion. + method_name: Method name for signature lookup. + + Returns: + dict: Parsed parameters. + """ + result = {} + for kv in input_params: + if '=' not in kv: + raise click.UsageError(f"Invalid input format: '{kv}'. Use key=value.") + key, value = kv.split('=', 1) + result[key] = _coerce_input_value(value, model_client, method_name, key) + return result + + +def _detect_media_type_from_ext(path): + """Detect media type from file extension. + + Returns: + str: 'image', 'video', or 'audio'. + """ + ext = os.path.splitext(path)[1].lower() + video_exts = {'.mp4', '.mov', '.avi', '.mkv', '.webm', '.flv', '.wmv'} + audio_exts = {'.wav', '.mp3', '.flac', '.aac', '.ogg', '.m4a', '.wma'} + if ext in video_exts: + return 'video' + elif ext in audio_exts: + return 'audio' + return 'image' + + +def _is_streaming_method(model_client, method_name): + """Check if a method returns a streaming (Iterator) response. + + Args: + model_client: ModelClient instance. + method_name: Method name. + + Returns: + bool: True if the method returns an Iterator type. + """ + sig_str = model_client.method_signature(method_name) + # Signature looks like: "def name(...) -> Iterator[str]:" + return_part = sig_str.split('->')[-1].strip() if '->' in sig_str else '' + return return_part.lower().startswith('iterator') + + +def _select_method(model_methods, model_client, explicit_method, is_chat, has_text_input): + """Select the best method to call based on available methods and flags. + + Priority: + 1. --chat → openai_stream_transport + 2. --method explicit → use that + 3. OpenAI auto-detection (has text input + model has openai_stream_transport) + 4. Streaming method (generate, or any Iterator-returning method) + 5. Fallback to predict + + Returns: + tuple: (method_name, is_openai_chat_path) + """ + methods = list(model_methods) + + # 1. --chat flag + if is_chat: + if 'openai_stream_transport' in methods: + return 'openai_stream_transport', True + elif 'openai_transport' in methods: + return 'openai_transport', True + else: + raise click.UsageError( + "This model does not support OpenAI chat. Available methods: " + ", ".join(methods) + ) + + # 2. Explicit --method + if explicit_method: + if explicit_method in methods: + return explicit_method, False + raise click.UsageError( + f"Method '{explicit_method}' not available. Available methods: " + ", ".join(methods) + ) + + # 3. OpenAI auto-detection for text input + if has_text_input and 'openai_stream_transport' in methods: + return 'openai_stream_transport', True + + # 4. Prefer streaming method + for m in methods: + if m in ('openai_stream_transport', 'openai_transport'): + continue + if _is_streaming_method(model_client, m): + return m, False + + # 5. Fallback to predict or first available + if 'predict' in methods: + return 'predict', False + return methods[0] if methods else 'predict', False + + +def _build_chat_request(message): + """Build an OpenAI-compatible chat request JSON string.""" + return json.dumps( + { + "messages": [{"role": "user", "content": message}], + "stream": True, + } + ) + + +def _display_openai_stream(stream_response, output_format): + """Display streaming OpenAI chat response, handling reasoning_content and content. + + Args: + stream_response: Iterator of streaming chunks. + output_format: 'text' or 'json'. + """ + full_reasoning = [] + full_content = [] + in_reasoning = False + + for chunk in stream_response: + try: + if isinstance(chunk, str): + data = json.loads(chunk) if chunk.strip() else {} + else: + data = chunk + except (json.JSONDecodeError, TypeError): + if output_format == 'text': + click.echo(chunk, nl=False) + full_content.append(str(chunk)) + continue + + if not isinstance(data, dict): + if output_format == 'text': + click.echo(str(data), nl=False) + full_content.append(str(data)) + continue + + # Handle OpenAI SSE format: data contains choices[0].delta + choices = data.get('choices', []) + if not choices: + # Might be raw text content + if output_format == 'text': + click.echo(str(data), nl=False) + full_content.append(str(data)) + continue + + delta = choices[0].get('delta', {}) + reasoning = delta.get('reasoning_content', '') + content = delta.get('content', '') + + if reasoning and output_format == 'text': + if not in_reasoning: + click.echo('', nl=True) + in_reasoning = True + click.echo(reasoning, nl=False) + full_reasoning.append(reasoning) + + if content: + if in_reasoning and output_format == 'text': + click.echo('\n', nl=True) + in_reasoning = False + if output_format == 'text': + click.echo(content, nl=False) + full_content.append(content) + + if in_reasoning and output_format == 'text': + click.echo('\n', nl=True) + + if output_format == 'text': + click.echo() # Final newline + elif output_format == 'json': + result = {} + if full_reasoning: + result['reasoning'] = ''.join(full_reasoning) + result['result'] = ''.join(full_content) + click.echo(json.dumps(result)) + + +@model.command() +@click.argument('model_ref', required=False, default=None) +@click.argument('text_input', required=False, default=None) +@click.option( + '-i', + '--input', + 'input_params', + multiple=True, + help='Named parameter as key=value (repeatable).', +) +@click.option( + '--inputs', + required=False, + help='All parameters as JSON string.', +) @click.option( - '--config', + '--file', + 'input_file', type=click.Path(exists=True), required=False, - help='Path to the model predict config file.', + help='Input file (image, audio, video).', +) +@click.option( + '--url', + 'input_url', + required=False, + help='Input URL (image, audio, video).', +) +@click.option( + '--chat', + 'chat_message', + required=False, + help='OpenAI chat message (auto-uses OpenAI client).', +) +@click.option( + '--method', + 'explicit_method', + required=False, + default=None, + help='Method to call. Overrides auto-selection.', +) +@click.option( + '--info', + is_flag=True, + default=False, + help='Show available methods and their signatures, then exit.', +) +@click.option( + '-o', + '--output', + 'output_format', + type=click.Choice(['text', 'json']), + default='text', + help='Output format (default: text).', +) +@click.option( + '--deployment', + 'deployment_id', + required=False, + help='Route to a specific deployment.', +) +@click.option( + '--model-url', + 'model_url_opt', + required=False, + help='Full model URL (alternative to positional MODEL).', ) -@click.option('--model_id', required=False, help='Model ID of the model used to predict.') -@click.option('--user_id', required=False, help='User ID of the model used to predict.') -@click.option('--app_id', required=False, help='App ID of the model used to predict.') -@click.option('--model_url', required=False, help='Model URL of the model used to predict.') +# Hidden legacy flags — still functional +@click.option('--model_id', required=False, hidden=True, help='Model ID.') +@click.option('--user_id', required=False, hidden=True, help='User ID.') +@click.option('--app_id', required=False, hidden=True, help='App ID.') +@click.option('--model_url', 'model_url_legacy', required=False, hidden=True, help='Model URL.') @click.option( '-cc_id', '--compute_cluster_id', required=False, - help='Compute Cluster ID to use for the model', + hidden=True, + help='Compute Cluster ID.', ) -@click.option('-np_id', '--nodepool_id', required=False, help='Nodepool ID to use for the model') @click.option( - '-dpl_id', '--deployment_id', required=False, help='Deployment ID to use for the model' + '-np_id', + '--nodepool_id', + required=False, + hidden=True, + help='Nodepool ID.', ) @click.option( - '-dpl_usr_id', - '--deployment_user_id', + '-dpl_id', + '--deployment_id', + 'deployment_id_legacy', required=False, - help='User ID to use for runner selector (organization or user). If not provided, defaults to PAT owner user_id.', + hidden=True, + help='Deployment ID.', ) @click.option( - '--inputs', + '-dpl_usr_id', + '--deployment_user_id', required=False, - help='JSON string of input parameters for pythonic models (e.g., \'{"prompt": "Hello", "max_tokens": 100}\')', + hidden=True, + help='Deployment user ID.', ) -@click.option('--method', required=False, default='predict', help='Method to call on the model.') @click.pass_context def predict( ctx, - config, + model_ref, + text_input, + input_params, + inputs, + input_file, + input_url, + chat_message, + explicit_method, + info, + output_format, + deployment_id, + model_url_opt, model_id, user_id, app_id, - model_url, + model_url_legacy, compute_cluster_id, nodepool_id, - deployment_id, + deployment_id_legacy, deployment_user_id, - inputs, - method, ): - """Predict using a Clarifai model. - - \b - Model Identification: - Use either --model_url OR the combination of --model_id, --user_id, and --app_id - - \b - Input Methods: - --inputs: JSON string with parameters (e.g., '{"prompt": "Hello", "max_tokens": 100}') - --method: Method to call on the model (default is 'predict') + """Run a prediction against a Clarifai model. \b - Compute Options: - Use either --deployment_id OR both --compute_cluster_id and --nodepool_id + Arguments: + MODEL Model as user_id/app_id/models/model_id or full URL. + INPUT Text input (for text models). Use --file or --url for media. \b Examples: - Text model: - clarifai model predict --model_url --inputs '{"prompt": "Hello world"}' - - With compute cluster: - clarifai model predict --model_id --user_id --app_id \\ - --compute_cluster_id --nodepool_id \\ - --inputs '{"prompt": "Hello"}' + clarifai model predict openai/chat-completion/models/GPT-4 "Hello world" + clarifai model predict openai/chat-completion/models/GPT-4 --info + echo "Hello" | clarifai model predict openai/chat-completion/models/GPT-4 + clarifai model predict my/app/models/detector --file photo.jpg + clarifai model predict my/app/models/detector --url https://example.com/img.jpg + clarifai model predict my/app/models/m -i prompt="Hello" -i max_tokens=200 + clarifai model predict openai/chat-completion/models/GPT-4 --chat "What is AI?" + clarifai model predict openai/chat-completion/models/GPT-4 "Hello" -o json """ + import sys + from clarifai.client.model import Model from clarifai.urls.helper import ClarifaiUrlHelper - from clarifai.utils.cli import from_yaml, validate_context + from clarifai.utils.cli import validate_context validate_context(ctx) - # Load configuration from file if provided - if config: - config_data = from_yaml(config) - # Override None values with config data - model_id = model_id or config_data.get('model_id') - user_id = user_id or config_data.get('user_id') - app_id = app_id or config_data.get('app_id') - model_url = model_url or config_data.get('model_url') - compute_cluster_id = compute_cluster_id or config_data.get('compute_cluster_id') - nodepool_id = nodepool_id or config_data.get('nodepool_id') - deployment_id = deployment_id or config_data.get('deployment_id') - deployment_user_id = deployment_user_id or config_data.get('deployment_user_id') - inputs = inputs or config_data.get('inputs') - method = method or config_data.get('method', 'predict') - - # Validate parameters - _validate_model_params(model_id, user_id, app_id, model_url) - _validate_compute_params(compute_cluster_id, nodepool_id, deployment_id) + # --- Merge legacy flags --- + model_url = model_url_opt or model_url_legacy + deployment_id = deployment_id or deployment_id_legacy + + # --- Resolve model URL --- + # Priority: positional model_ref > --model-url/--model_url > --model_id triple > config + if model_ref: + model_url = _resolve_model_ref(model_ref, ui_base=ctx.obj.current.ui) + elif not model_url: + # Try legacy triple + if all([model_id, user_id, app_id]): + model_url = ClarifaiUrlHelper.clarifai_url( + user_id=user_id, app_id=app_id, resource_type="models", resource_id=model_id + ) + else: + raise click.UsageError( + "No model specified. Use: clarifai model predict ..." + ) - # Generate model URL if not provided - if not model_url: - model_url = ClarifaiUrlHelper.clarifai_url( - user_id=user_id, app_id=app_id, resource_type="models", resource_id=model_id - ) logger.debug(f"Using model at URL: {model_url}") - # Create model instance + # --- Validate compute params --- + _validate_compute_params(compute_cluster_id, nodepool_id, deployment_id) + + # --- Create model instance --- model = Model( url=model_url, pat=ctx.obj.current.pat, @@ -1951,27 +2704,123 @@ def predict( deployment_user_id=deployment_user_id, ) - model_methods = model.client.available_methods() - stream_method = ( - model.client.method_signature(method).split()[-1][:-1].lower().startswith('iter') + model_methods = list(model.client.available_methods()) + + # --- --info: display methods and exit --- + if info: + click.echo(f"Model: {model.id} ({model.user_id}/{model.app_id})\n") + click.echo("Methods:") + for m in model_methods: + sig = model.client.method_signature(m) + click.echo(f" {sig}") + return + + # --- Determine if we have text input (for OpenAI auto-detection) --- + has_text = bool(text_input or chat_message) + if not has_text and not input_file and not input_url and not inputs and not input_params: + # Check stdin + if not sys.stdin.isatty(): + text_input = sys.stdin.read().strip() + has_text = bool(text_input) + + # --- Select method --- + method_name, is_openai_path = _select_method( + model_methods, model.client, explicit_method, bool(chat_message), has_text ) - - # Determine prediction method and execute - if inputs and (method in model_methods): - # Pythonic model prediction with JSON inputs + is_stream = _is_streaming_method(model.client, method_name) + + # --- Build inputs --- + if is_openai_path: + # OpenAI chat path + chat_text = chat_message or text_input + if not chat_text: + raise click.UsageError("No input provided for chat. Pass a text argument or --chat.") + request_body = _build_chat_request(chat_text) + model_prediction = getattr(model, method_name)(msg=request_body) + _display_openai_stream(model_prediction, output_format) + return + + # Build inputs dict from various sources + inputs_dict = {} + + # --inputs JSON + if inputs: inputs_dict = _parse_json_param(inputs, "inputs") - inputs_dict = _process_multimodal_inputs(inputs_dict) - model_prediction = getattr(model, method)(**inputs_dict) - else: - logger.error( - f"ValueError: The model does not support the '{method}' method. Please check the model's capabilities." + + # -i key=value pairs (override JSON keys) + if input_params: + kv_dict = _parse_kv_inputs(input_params, model.client, method_name) + inputs_dict.update(kv_dict) + + # --file + if input_file: + from clarifai.runners.utils.data_types import Audio, Image, Video + + param_name, media_type = _get_first_media_param(model.client, method_name) + if not param_name: + # Fallback: detect from extension, use generic name + media_type = _detect_media_type_from_ext(input_file) + param_name = media_type + + with open(input_file, 'rb') as f: + file_bytes = f.read() + + type_cls = {'image': Image, 'video': Video, 'audio': Audio}[media_type] + inputs_dict[param_name] = type_cls(bytes=file_bytes) + + # --url + if input_url: + from clarifai.runners.utils.data_types import Audio, Image, Video + + param_name, media_type = _get_first_media_param(model.client, method_name) + if not param_name: + media_type = _detect_media_type_from_ext(input_url) + param_name = media_type + + type_cls = {'image': Image, 'video': Video, 'audio': Audio}[media_type] + inputs_dict[param_name] = type_cls(url=input_url) + + # Positional text input or stdin + if not inputs_dict and text_input: + param_name = _get_first_str_param(model.client, method_name) + if param_name: + inputs_dict[param_name] = text_input + else: + # If no str param found, try passing as first positional + inputs_dict = {'text': text_input} + + if not inputs_dict: + raise click.UsageError( + "No input provided. Pass text, --file, --url, --inputs, or -i key=value.\n" + "Use --info to see available methods and their parameters." ) - raise click.Abort() - if stream_method: - for chunk in model_prediction: - click.echo(chunk, nl=False) - click.echo() # Ensure a newline after the stream ends + # Process multimodal inputs (URL/file strings in dict values) + inputs_dict = _process_multimodal_inputs(inputs_dict) + + # --- Execute prediction --- + if method_name not in model_methods: + raise click.UsageError( + f"Method '{method_name}' not available. Available methods: " + ", ".join(model_methods) + ) + + model_prediction = getattr(model, method_name)(**inputs_dict) + + # --- Display output --- + if is_stream: + if output_format == 'json': + chunks = [] + for chunk in model_prediction: + if isinstance(chunk, str): + chunks.append(chunk) + click.echo(json.dumps({"result": ''.join(chunks)})) + else: + for chunk in model_prediction: + if isinstance(chunk, str): + click.echo(chunk, nl=False) + click.echo() + elif output_format == 'json': + click.echo(json.dumps({"result": str(model_prediction)})) else: click.echo(model_prediction) @@ -1983,18 +2832,19 @@ def predict( default=None, ) @click.option( - '--app_id', '-a', + '--app_id', type=str, default=None, - show_default=True, - help="Get all models of an app", + help='Filter by app ID.', ) @click.pass_context def list_model(ctx, user_id, app_id): - """List models of user/community. + """List models for a user or across the platform. - USER_ID: User id. If not specified, the current user is used by default. Set "all" to get all public models in Clarifai platform. + \b + USER_ID User ID to list models for (default: current user). + Use "all" to list public models across Clarifai. """ from clarifai.client import User diff --git a/clarifai/cli/nodepool.py b/clarifai/cli/nodepool.py index 7cbaf692..7c1d64f7 100644 --- a/clarifai/cli/nodepool.py +++ b/clarifai/cli/nodepool.py @@ -18,7 +18,7 @@ context_settings={'max_content_width': shutil.get_terminal_size().columns - 10}, ) def nodepool(): - """Manage Nodepools: create, delete, list""" + """Manage nodepools.""" @nodepool.command(['c']) diff --git a/clarifai/cli/pipeline.py b/clarifai/cli/pipeline.py index 34a8db95..9b71cb52 100644 --- a/clarifai/cli/pipeline.py +++ b/clarifai/cli/pipeline.py @@ -19,7 +19,7 @@ context_settings={'max_content_width': shutil.get_terminal_size().columns - 10}, ) def pipeline(): - """Manage pipelines: upload, init, list, etc""" + """Create and manage pipelines.""" @pipeline.command() diff --git a/clarifai/cli/pipeline_run.py b/clarifai/cli/pipeline_run.py index efc6ed57..25a86166 100644 --- a/clarifai/cli/pipeline_run.py +++ b/clarifai/cli/pipeline_run.py @@ -86,7 +86,7 @@ def _create_pipeline(ctx, user_id, app_id, pipeline_id, pipeline_version_id): context_settings={'max_content_width': shutil.get_terminal_size().columns - 10}, ) def pipelinerun(): - """Manage Pipeline Version Runs: pause, cancel, resume, monitor""" + """Monitor and control pipeline runs.""" @pipelinerun.command() diff --git a/clarifai/cli/pipeline_step.py b/clarifai/cli/pipeline_step.py index d90465d6..21de508c 100644 --- a/clarifai/cli/pipeline_step.py +++ b/clarifai/cli/pipeline_step.py @@ -19,7 +19,7 @@ context_settings={'max_content_width': shutil.get_terminal_size().columns - 10}, ) def pipeline_step(): - """Manage pipeline steps: upload, test, list, etc""" + """Manage pipeline steps.""" @pipeline_step.command() diff --git a/clarifai/cli/pipeline_template.py b/clarifai/cli/pipeline_template.py index 773b1b12..65e56406 100644 --- a/clarifai/cli/pipeline_template.py +++ b/clarifai/cli/pipeline_template.py @@ -15,7 +15,7 @@ context_settings={'max_content_width': shutil.get_terminal_size().columns - 10}, ) def pipelinetemplate(): - """Manage pipeline templates: list, discover, etc""" + """Browse pipeline templates.""" @pipelinetemplate.command(name='list', aliases=['ls']) diff --git a/clarifai/cli/templates/model_templates.py b/clarifai/cli/templates/model_templates.py index 993f433e..88eb32de 100644 --- a/clarifai/cli/templates/model_templates.py +++ b/clarifai/cli/templates/model_templates.py @@ -10,13 +10,10 @@ def get_model_class_template() -> str: from clarifai.runners.utils.data_utils import Param class MyModel(ModelClass): - """A custom model implementation using ModelClass.""" + """A custom model.""" def load_model(self): - """Load the model here. - # TODO: please fill in - # Add your model loading logic here - """ + """Initialize your model here. Called once when the model starts.""" pass @ModelClass.method @@ -24,28 +21,25 @@ def predict( self, prompt: str = "", chat_history: List[dict] = None, - max_tokens: int = Param(default=256, description="The maximum number of tokens to generate. Shorter token lengths will provide faster performance."), - temperature: float = Param(default=1.0, description="A decimal number that determines the degree of randomness in the response"), - top_p: float = Param(default=1.0, description="An alternative to sampling with temperature, where the model considers the results of the tokens with top_p probability mass."), + max_tokens: int = Param(default=256, description="The maximum number of tokens to generate."), + temperature: float = Param(default=1.0, description="Sampling temperature (higher = more random)."), + top_p: float = Param(default=1.0, description="Nucleus sampling threshold."), ) -> str: - """This is the method that will be called when the runner is run. It takes in an input and returns an output.""" - # TODO: please fill in - # Implement your prediction logic here - return "This is a placeholder response. Please implement your model logic." + """Return a single response.""" + return f"Echo: {prompt}" @ModelClass.method def generate( self, prompt: str = "", chat_history: List[dict] = None, - max_tokens: int = Param(default=256, description="The maximum number of tokens to generate. Shorter token lengths will provide faster performance."), - temperature: float = Param(default=1.0, description="A decimal number that determines the degree of randomness in the response"), - top_p: float = Param(default=1.0, description="An alternative to sampling with temperature, where the model considers the results of the tokens with top_p probability mass."), + max_tokens: int = Param(default=256, description="The maximum number of tokens to generate."), + temperature: float = Param(default=1.0, description="Sampling temperature (higher = more random)."), + top_p: float = Param(default=1.0, description="Nucleus sampling threshold."), ) -> Iterator[str]: - """Example yielding a streamed response.""" - # TODO: please fill in - # Implement your generation logic here - yield "This is a placeholder response. Please implement your model logic." + """Stream a response.""" + for word in f"Echo: {prompt}".split(): + yield word + " " ''' @@ -53,45 +47,28 @@ def get_mcp_model_class_template() -> str: """Return the template for an MCPModelClass-based model.""" return '''from typing import Any -from fastmcp import FastMCP # use fastmcp v2 not the built in mcp +from fastmcp import FastMCP from pydantic import Field from clarifai.runners.models.mcp_class import MCPModelClass -# TODO: please fill in -# Configure your FastMCP server -server = FastMCP("my-mcp-server", instructions="", stateless_http=True) +server = FastMCP("my-mcp-server", instructions="A sample MCP server.", stateless_http=True) -# TODO: please fill in -# Add your tools, resources, and prompts here -@server.tool("example_tool", description="An example tool") -def example_tool(input_param: Any = Field(description="Example input parameter")): - """Example tool implementation.""" - # TODO: please fill in - # Implement your tool logic here - return f"Processed: {input_param}" +@server.tool("hello", description="Say hello to someone") +def hello(name: str = Field(description="Name to greet")) -> str: + """Greet a user by name.""" + return f"Hello, {name}!" -# Static resource example @server.resource("config://version") def get_version(): - """Example static resource.""" - # TODO: please fill in - # Return your resource data + """Return the server version.""" return "1.0.0" -@server.prompt() -def example_prompt(text: str) -> str: - """Example prompt template.""" - # TODO: please fill in - # Define your prompt template - return f"Process this text: {text}" - - class MyModel(MCPModelClass): - """A custom model implementation using MCPModelClass.""" + """MCP model that exposes tools, resources, and prompts.""" def get_server(self) -> FastMCP: """Return the FastMCP server instance.""" @@ -108,22 +85,17 @@ def get_openai_model_class_template(port: str = "8000") -> str: from clarifai.runners.utils.openai_convertor import build_openai_messages class MyModel(OpenAIModelClass): - """A custom model implementation using OpenAIModelClass.""" + """Wraps an OpenAI-compatible API endpoint.""" - # TODO: please fill in - # Configure your OpenAI-compatible client for local model client = OpenAI( - api_key="local-key", # TODO: please fill in - use your local API key - base_url="http://localhost:{port}/v1", # TODO: please fill in - your local model server endpoint + api_key="local-key", + base_url="http://localhost:{port}/v1", ) - # Automatically get the first available model model = client.models.list().data[0].id def load_model(self): - """Optional: Add any additional model loading logic here.""" - # TODO: please fill in (optional) - # Add any initialization logic if needed + """Optional initialization logic.""" pass @OpenAIModelClass.method @@ -131,13 +103,11 @@ def predict( self, prompt: str = "", chat_history: List[dict] = None, - max_tokens: int = Param(default=256, description="The maximum number of tokens to generate. Shorter token lengths will provide faster performance."), - temperature: float = Param(default=1.0, description="A decimal number that determines the degree of randomness in the response"), - top_p: float = Param(default=1.0, description="An alternative to sampling with temperature, where the model considers the results of the tokens with top_p probability mass."), + max_tokens: int = Param(default=256, description="The maximum number of tokens to generate."), + temperature: float = Param(default=1.0, description="Sampling temperature (higher = more random)."), + top_p: float = Param(default=1.0, description="Nucleus sampling threshold."), ) -> str: - """Run a single prompt completion using the OpenAI client.""" - # TODO: please fill in - # Implement your prediction logic here + """Run a single prompt completion.""" messages = build_openai_messages(prompt, chat_history) response = self.client.chat.completions.create( model=self.model, @@ -153,13 +123,11 @@ def generate( self, prompt: str = "", chat_history: List[dict] = None, - max_tokens: int = Param(default=256, description="The maximum number of tokens to generate. Shorter token lengths will provide faster performance."), - temperature: float = Param(default=1.0, description="A decimal number that determines the degree of randomness in the response"), - top_p: float = Param(default=1.0, description="An alternative to sampling with temperature, where the model considers the results of the tokens with top_p probability mass."), + max_tokens: int = Param(default=256, description="The maximum number of tokens to generate."), + temperature: float = Param(default=1.0, description="Sampling temperature (higher = more random)."), + top_p: float = Param(default=1.0, description="Nucleus sampling threshold."), ) -> Iterator[str]: - """Stream a completion response using the OpenAI client.""" - # TODO: please fill in - # Implement your streaming logic here + """Stream a completion response.""" messages = build_openai_messages(prompt, chat_history) stream = self.client.chat.completions.create( model=self.model, @@ -177,60 +145,77 @@ def generate( ''' -def get_config_template(user_id: str = None, model_type_id: str = "any-to-any") -> str: - """Return the template for config.yaml.""" - return f'''# Configuration file for your Clarifai model +def get_config_template( + user_id: str = None, + model_type_id: str = "any-to-any", + model_id: str = "my-model", + simplified: bool = True, +) -> str: + """Return the template for config.yaml. + + Args: + user_id: User ID to include in the config. In simplified mode, this is omitted + (resolved from CLI context at deploy time). + model_type_id: Model type ID. + model_id: Model ID. + simplified: If True, generate simplified config (no TODOs, compute.instance shorthand). + If False, generate verbose config with all fields. + """ + if simplified: + return f'''model: + id: "{model_id}" + model_type_id: "{model_type_id}" + +compute: + instance: g5.xlarge # Run 'clarifai list-instances' to see all options. + # cloud: aws # Cloud provider (aws, gcp, vultr). Auto-detected from instance. + # region: us-east-1 # Cloud region. Auto-detected from instance. + +# Uncomment to auto-download model checkpoints: +# checkpoints: +# repo_id: owner/model-name +''' + else: + return _get_verbose_config_template(user_id, model_type_id, model_id) + -model: - id: "my-model" # TODO: please fill in - replace with your model ID - user_id: "{user_id}" # TODO: please fill in - replace with your user ID - app_id: "app_id" # TODO: please fill in - replace with your app ID - model_type_id: "{model_type_id}" # TODO: please fill in - replace if different model type ID +def _get_verbose_config_template( + user_id: str = None, model_type_id: str = "any-to-any", model_id: str = "my-model" +) -> str: + """Return the verbose template for config.yaml (original format).""" + return f'''model: + id: "{model_id}" + user_id: "{user_id}" + app_id: "app_id" + model_type_id: "{model_type_id}" build_info: python_version: "3.12" - # platform: "linux/amd64,linux/arm64" # Optional: Specify target platform(s) for Docker image build -# TODO: please fill in - adjust compute requirements for your model inference_compute_info: - cpu_limit: "1" # TODO: please fill in - Amount of CPUs to use as a limit - cpu_memory: "1Gi" # TODO: please fill in - Amount of CPU memory to use as a limit - cpu_requests: "0.5" # TODO: please fill in - Amount of CPUs to use as a minimum - cpu_memory_requests: "512Mi" # TODO: please fill in - Amount of CPU memory to use as a minimum - num_accelerators: 1 # TODO: please fill in - Amount of GPU/TPUs to use - accelerator_type: ["NVIDIA-*"] # TODO: please fill in - type of accelerators requested - accelerator_memory: "1Gi" # TODO: please fill in - Amount of accelerator/GPU memory to use as a minimum - -# TODO: please fill in (optional) - add checkpoints section if needed -# checkpoints: -# type: "huggingface" # supported type -# repo_id: "your-model-repo" # for huggingface like openai/gpt-oss-20b -# # hf_token: "your-huggingface-token" # if private repo -# when: "runtime" # or "build", "upload" + cpu_limit: "1" + cpu_memory: "1Gi" + cpu_requests: "0.5" + cpu_memory_requests: "512Mi" + num_accelerators: 1 + accelerator_type: ["NVIDIA-*"] + accelerator_memory: "1Gi" -# Uncomment if model needs to work with streaming video runners (adds additional packages): -# streaming_video_consumer: true +# checkpoints: +# type: "huggingface" +# repo_id: "your-model-repo" +# when: "runtime" ''' def get_requirements_template(model_type_id: str = None) -> str: """Return the template for requirements.txt.""" - requirements = f'''# Clarifai SDK - required -clarifai>={__version__} -''' + req = f'clarifai>={__version__}\n' if model_type_id == "mcp": - requirements += "fastmcp\n" + req += "fastmcp\n" elif model_type_id == "openai": - requirements += "openai\n" - requirements += ''' -# TODO: please fill in - add your model's dependencies here -# Examples: -# torch>=2.0.0 -# transformers>=4.30.0 -# numpy>=1.21.0 -# pillow>=9.0.0 -''' - return requirements + req += "openai\n" + return req # Mapping of model type IDs to their corresponding templates @@ -244,7 +229,6 @@ def get_model_template(model_type_id: str = None, **kwargs) -> str: """Get the appropriate model template based on model_type_id.""" if model_type_id in MODEL_TYPE_TEMPLATES: template_func = MODEL_TYPE_TEMPLATES[model_type_id] - # Check if the template function accepts additional parameters import inspect sig = inspect.signature(template_func) diff --git a/clarifai/cli/templates/toolkits/__init__.py b/clarifai/cli/templates/toolkits/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/clarifai/cli/templates/toolkits/huggingface/1/model.py b/clarifai/cli/templates/toolkits/huggingface/1/model.py new file mode 100644 index 00000000..6494cb30 --- /dev/null +++ b/clarifai/cli/templates/toolkits/huggingface/1/model.py @@ -0,0 +1,133 @@ +import os +from threading import Thread +from typing import Iterator, List + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer + +from clarifai.runners.models.model_builder import ModelBuilder +from clarifai.runners.models.model_class import ModelClass +from clarifai.runners.utils.data_utils import Param +from clarifai.utils.logging import logger + + +class HuggingFaceModel(ModelClass): + def load_model(self): + if torch.backends.mps.is_available(): + self.device = 'mps' + elif torch.cuda.is_available(): + self.device = 'cuda' + else: + self.device = 'cpu' + logger.info(f"Using device: {self.device}") + + model_path = os.path.dirname(os.path.dirname(__file__)) + builder = ModelBuilder(model_path, download_validation_only=True) + config = builder.config + stage = config["checkpoints"]["when"] + checkpoints = config["checkpoints"]["repo_id"] + if stage in ["build", "runtime"]: + checkpoints = builder.download_checkpoints(stage=stage) + + self.tokenizer = AutoTokenizer.from_pretrained(checkpoints) + self.tokenizer.pad_token = self.tokenizer.eos_token + self.hf_model = AutoModelForCausalLM.from_pretrained( + checkpoints, + low_cpu_mem_usage=True, + device_map=self.device, + torch_dtype=torch.bfloat16, + ) + self.streamer = TextIteratorStreamer( + tokenizer=self.tokenizer, + skip_prompt=True, + skip_special_tokens=True, + ) + + @ModelClass.method + def predict( + self, + prompt: str = "", + chat_history: List[dict] = None, + max_tokens: int = Param( + default=512, + description="The maximum number of tokens to generate.", + ), + temperature: float = Param( + default=0.7, + description="Sampling temperature (higher = more random).", + ), + top_p: float = Param( + default=0.8, + description="Nucleus sampling threshold.", + ), + ) -> str: + """Return a single completion.""" + messages = chat_history if chat_history else [] + if prompt: + messages.append({"role": "user", "content": prompt}) + + inputs = self.tokenizer.apply_chat_template( + messages, + tokenize=True, + add_generation_prompt=True, + return_tensors="pt", + return_dict=True, + ).to(self.hf_model.device) + + output = self.hf_model.generate( + **inputs, + do_sample=True, + max_new_tokens=max_tokens, + temperature=float(temperature), + top_p=float(top_p), + eos_token_id=self.tokenizer.eos_token_id, + ) + generated_tokens = output[0][inputs["input_ids"].shape[-1] :] + return self.tokenizer.decode(generated_tokens, skip_special_tokens=True) + + @ModelClass.method + def generate( + self, + prompt: str = "", + chat_history: List[dict] = None, + max_tokens: int = Param( + default=512, + description="The maximum number of tokens to generate.", + ), + temperature: float = Param( + default=0.7, + description="Sampling temperature (higher = more random).", + ), + top_p: float = Param( + default=0.8, + description="Nucleus sampling threshold.", + ), + ) -> Iterator[str]: + """Stream a completion response.""" + messages = chat_history if chat_history else [] + if prompt: + messages.append({"role": "user", "content": prompt}) + + inputs = self.tokenizer.apply_chat_template( + messages, + tokenize=True, + add_generation_prompt=True, + return_tensors="pt", + return_dict=True, + ).to(self.hf_model.device) + + generation_kwargs = { + **inputs, + "do_sample": True, + "max_new_tokens": max_tokens, + "temperature": float(temperature), + "top_p": float(top_p), + "eos_token_id": self.tokenizer.eos_token_id, + "streamer": self.streamer, + } + thread = Thread(target=self.hf_model.generate, kwargs=generation_kwargs) + thread.start() + for text in self.streamer: + if text: + yield text + thread.join() diff --git a/clarifai/cli/templates/toolkits/huggingface/config.yaml b/clarifai/cli/templates/toolkits/huggingface/config.yaml new file mode 100644 index 00000000..1e82cd04 --- /dev/null +++ b/clarifai/cli/templates/toolkits/huggingface/config.yaml @@ -0,0 +1,13 @@ +model: + id: "my-model" + +build_info: + python_version: "3.11" + +compute: + instance: g5.xlarge + +checkpoints: + repo_id: unsloth/Llama-3.2-1B-Instruct + type: huggingface + when: runtime diff --git a/clarifai/cli/templates/toolkits/huggingface/requirements.txt b/clarifai/cli/templates/toolkits/huggingface/requirements.txt new file mode 100644 index 00000000..9bda5e91 --- /dev/null +++ b/clarifai/cli/templates/toolkits/huggingface/requirements.txt @@ -0,0 +1,6 @@ +clarifai +torch>=2.6.0 +transformers>=4.47.0 +accelerate>=1.2.0 +optimum>=1.23.3 +scipy>=1.10.1 diff --git a/clarifai/cli/templates/toolkits/lmstudio/1/model.py b/clarifai/cli/templates/toolkits/lmstudio/1/model.py new file mode 100644 index 00000000..871322c1 --- /dev/null +++ b/clarifai/cli/templates/toolkits/lmstudio/1/model.py @@ -0,0 +1,246 @@ +import json +import os +import socket +import subprocess +import sys +import time +from typing import Iterator, List + +from openai import OpenAI + +from clarifai.runners.models.openai_class import OpenAIModelClass +from clarifai.runners.utils.data_types import Image +from clarifai.runners.utils.data_utils import Param +from clarifai.runners.utils.openai_convertor import build_openai_messages +from clarifai.utils.logging import logger + +VERBOSE_LMSTUDIO = True +LMS_MODEL_NAME = os.environ.get("LMS_MODEL_NAME", "google/gemma-3-4b") +LMS_PORT = int(os.environ.get("LMS_PORT", "23333")) +LMS_CONTEXT_LENGTH = int(os.environ.get("LMS_CONTEXT_LENGTH", "4096")) + + +def _stream_command(cmd, verbose=True): + env = os.environ.copy() + env["PYTHONUNBUFFERED"] = "1" + process = subprocess.Popen( + cmd, + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, + env=env, + ) + if verbose and process.stdout: + for line in iter(process.stdout.readline, ""): + if line: + logger.info(f"[lms] {line.rstrip()}") + ret = process.wait() + if ret != 0: + raise RuntimeError(f"Command failed ({ret}): {cmd}") + return True + + +def _wait_for_port(port, timeout=30.0): + start = time.time() + while time.time() - start < timeout: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.settimeout(1) + try: + if sock.connect_ex(("127.0.0.1", port)) == 0: + return True + except Exception: + pass + time.sleep(0.5) + raise RuntimeError(f"LM Studio server did not start on port {port} within {timeout}s") + + +def _is_model_available(model_name): + """Check if a model is already available locally in LM Studio.""" + try: + result = subprocess.run( + "lms ls --json", shell=True, capture_output=True, text=True, timeout=10, check=False + ) + if result.returncode == 0 and result.stdout.strip(): + models = json.loads(result.stdout) + for m in models: + if m.get("modelKey", "") == model_name: + return True + except Exception: + pass + return False + + +def _is_model_loaded(model_name): + """Check if a model is currently loaded in LM Studio.""" + try: + result = subprocess.run( + "lms ps --json", shell=True, capture_output=True, text=True, timeout=10, check=False + ) + if result.returncode == 0 and result.stdout.strip(): + models = json.loads(result.stdout) + for m in models: + if m.get("modelKey", "") == model_name: + return True + except Exception: + pass + return False + + +def run_lms_server(model_name='google/gemma-3-4b', port=11434, context_length=4096): + """Download model if needed, load it, and start the LM Studio server.""" + try: + if _is_model_available(model_name): + logger.info(f"Model {model_name} is already available locally, skipping download.") + else: + logger.info(f"Model {model_name} not found locally, downloading...") + _stream_command( + f"lms get https://huggingface.co/{model_name} --verbose", + verbose=VERBOSE_LMSTUDIO, + ) + + if _is_model_loaded(model_name): + logger.info(f"Model {model_name} is already loaded.") + else: + _stream_command("lms unload --all", verbose=VERBOSE_LMSTUDIO) + _stream_command( + f"lms load {model_name} --verbose --context-length {context_length}", + verbose=VERBOSE_LMSTUDIO, + ) + + subprocess.Popen( + f"lms server start --port {port}", + shell=True, + stdout=None if not VERBOSE_LMSTUDIO else sys.stdout, + stderr=None if not VERBOSE_LMSTUDIO else sys.stderr, + ) + _wait_for_port(port) + logger.info(f"LM Studio server started on port {port}") + except Exception as e: + raise RuntimeError(f"Failed to start LM Studio server: {e}") + + +def has_image_content(image: Image) -> bool: + return bool(getattr(image, 'url', None) or getattr(image, 'bytes', None)) + + +class LMStudioModel(OpenAIModelClass): + client = True + model = True + + def load_model(self): + self.model = LMS_MODEL_NAME + self.port = LMS_PORT + run_lms_server( + model_name=self.model, + port=self.port, + context_length=LMS_CONTEXT_LENGTH, + ) + self.client = OpenAI(api_key="notset", base_url=f"http://localhost:{self.port}/v1") + + @OpenAIModelClass.method + def predict( + self, + prompt: str = "", + image: Image = None, + images: List[Image] = None, + chat_history: List[dict] = None, + tools: List[dict] = None, + tool_choice: str = None, + max_tokens: int = Param( + default=2048, + description="The maximum number of tokens to generate.", + ), + temperature: float = Param( + default=0.7, + description="Sampling temperature (higher = more random).", + ), + top_p: float = Param( + default=0.95, + description="Nucleus sampling threshold.", + ), + ) -> str: + """Return a single completion.""" + if tools is not None and tool_choice is None: + tool_choice = "auto" + + img_content = image if has_image_content(image) else None + messages = build_openai_messages( + prompt=prompt, image=img_content, images=images, messages=chat_history + ) + response = self.client.chat.completions.create( + model=self.model, + messages=messages, + tools=tools, + tool_choice=tool_choice, + max_completion_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + ) + + if response.usage is not None: + self.set_output_context( + prompt_tokens=response.usage.prompt_tokens, + completion_tokens=response.usage.completion_tokens, + ) + + if response.choices[0] and response.choices[0].message.tool_calls: + tool_calls = response.choices[0].message.tool_calls + return json.dumps([tc.to_dict() for tc in tool_calls], indent=2) + return response.choices[0].message.content + + @OpenAIModelClass.method + def generate( + self, + prompt: str = "", + image: Image = None, + images: List[Image] = None, + chat_history: List[dict] = None, + tools: List[dict] = None, + tool_choice: str = None, + max_tokens: int = Param( + default=2048, + description="The maximum number of tokens to generate.", + ), + temperature: float = Param( + default=0.7, + description="Sampling temperature (higher = more random).", + ), + top_p: float = Param( + default=0.95, + description="Nucleus sampling threshold.", + ), + ) -> Iterator[str]: + """Stream a completion response.""" + if tools is not None and tool_choice is None: + tool_choice = "auto" + + img_content = image if has_image_content(image) else None + messages = build_openai_messages( + prompt=prompt, image=img_content, images=images, messages=chat_history + ) + for chunk in self.client.chat.completions.create( + model=self.model, + messages=messages, + tools=tools, + tool_choice=tool_choice, + max_completion_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + stream=True, + stream_options={"include_usage": True}, + ): + if chunk.usage is not None: + if chunk.usage.prompt_tokens or chunk.usage.completion_tokens: + self.set_output_context( + prompt_tokens=chunk.usage.prompt_tokens, + completion_tokens=chunk.usage.completion_tokens, + ) + if chunk.choices: + if chunk.choices[0].delta.tool_calls: + tool_calls_json = [tc.to_dict() for tc in chunk.choices[0].delta.tool_calls] + yield json.dumps(tool_calls_json, indent=2) + else: + text = chunk.choices[0].delta.content if chunk.choices[0].delta.content else '' + yield text diff --git a/clarifai/cli/templates/toolkits/lmstudio/config.yaml b/clarifai/cli/templates/toolkits/lmstudio/config.yaml new file mode 100644 index 00000000..c593a1f1 --- /dev/null +++ b/clarifai/cli/templates/toolkits/lmstudio/config.yaml @@ -0,0 +1,8 @@ +model: + id: "my-model" + +build_info: + python_version: "3.12" + +toolkit: + provider: lmstudio diff --git a/clarifai/cli/templates/toolkits/lmstudio/requirements.txt b/clarifai/cli/templates/toolkits/lmstudio/requirements.txt new file mode 100644 index 00000000..0f711a4a --- /dev/null +++ b/clarifai/cli/templates/toolkits/lmstudio/requirements.txt @@ -0,0 +1,2 @@ +clarifai +openai diff --git a/clarifai/cli/templates/toolkits/ollama/1/model.py b/clarifai/cli/templates/toolkits/ollama/1/model.py new file mode 100644 index 00000000..5f269b1e --- /dev/null +++ b/clarifai/cli/templates/toolkits/ollama/1/model.py @@ -0,0 +1,176 @@ +import json +import os +import subprocess +import time +from typing import Iterator, List + +from openai import OpenAI + +from clarifai.runners.models.openai_class import OpenAIModelClass +from clarifai.runners.utils.data_types import Image +from clarifai.runners.utils.data_utils import Param +from clarifai.runners.utils.model_utils import execute_shell_command +from clarifai.runners.utils.openai_convertor import build_openai_messages +from clarifai.utils.logging import logger + +if not os.environ.get('OLLAMA_HOST'): + PORT = '23333' + os.environ["OLLAMA_HOST"] = f'127.0.0.1:{PORT}' +OLLAMA_HOST = os.environ.get('OLLAMA_HOST') + +if not os.environ.get('OLLAMA_CONTEXT_LENGTH'): + os.environ["OLLAMA_CONTEXT_LENGTH"] = '8192' + + +def run_ollama_server(model_name: str = 'llama3.2'): + """Start Ollama server and pull the model.""" + try: + # Start server in the background + execute_shell_command( + "ollama serve", + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + # Wait for server to be ready + start = time.time() + while time.time() - start < 30: + try: + r = subprocess.run(["ollama", "list"], capture_output=True, timeout=5, check=False) + if r.returncode == 0: + break + except (subprocess.TimeoutExpired, FileNotFoundError): + pass + time.sleep(1) + else: + raise RuntimeError("Ollama server did not start within 30s") + + # Pull model (blocking — must finish before we accept requests) + logger.info(f"Pulling ollama model '{model_name}'...") + result = subprocess.run( + ["ollama", "pull", model_name], capture_output=True, text=True, check=False + ) + if result.returncode != 0: + raise RuntimeError(f"ollama pull failed: {result.stderr}") + logger.info(f"Model '{model_name}' ready.") + except Exception as e: + raise RuntimeError(f"Failed to start Ollama server: {e}") + + +def has_image_content(image: Image) -> bool: + return bool(getattr(image, 'url', None) or getattr(image, 'bytes', None)) + + +class OllamaModel(OpenAIModelClass): + client = True + model = True + + def load_model(self): + self.model = os.environ.get("OLLAMA_MODEL_NAME", 'llama3.2') + run_ollama_server(model_name=self.model) + self.client = OpenAI(api_key="notset", base_url=f"http://{OLLAMA_HOST}/v1") + + @OpenAIModelClass.method + def predict( + self, + prompt: str = "", + image: Image = None, + images: List[Image] = None, + chat_history: List[dict] = None, + tools: List[dict] = None, + tool_choice: str = None, + max_tokens: int = Param( + default=2048, + description="The maximum number of tokens to generate.", + ), + temperature: float = Param( + default=0.7, + description="Sampling temperature (higher = more random).", + ), + top_p: float = Param( + default=0.95, + description="Nucleus sampling threshold.", + ), + ) -> str: + """Return a single completion.""" + if tools is not None and tool_choice is None: + tool_choice = "auto" + + img_content = image if has_image_content(image) else None + messages = build_openai_messages( + prompt=prompt, image=img_content, images=images, messages=chat_history + ) + response = self.client.chat.completions.create( + model=self.model, + messages=messages, + tools=tools, + tool_choice=tool_choice, + max_completion_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + ) + + if response.usage is not None: + self.set_output_context( + prompt_tokens=response.usage.prompt_tokens, + completion_tokens=response.usage.completion_tokens, + ) + + if response.choices[0] and response.choices[0].message.tool_calls: + tool_calls = response.choices[0].message.tool_calls + return json.dumps([tc.to_dict() for tc in tool_calls], indent=2) + return response.choices[0].message.content + + @OpenAIModelClass.method + def generate( + self, + prompt: str = "", + image: Image = None, + images: List[Image] = None, + chat_history: List[dict] = None, + tools: List[dict] = None, + tool_choice: str = None, + max_tokens: int = Param( + default=2048, + description="The maximum number of tokens to generate.", + ), + temperature: float = Param( + default=0.7, + description="Sampling temperature (higher = more random).", + ), + top_p: float = Param( + default=0.95, + description="Nucleus sampling threshold.", + ), + ) -> Iterator[str]: + """Stream a completion response.""" + if tools is not None and tool_choice is None: + tool_choice = "auto" + + img_content = image if has_image_content(image) else None + messages = build_openai_messages( + prompt=prompt, image=img_content, images=images, messages=chat_history + ) + for chunk in self.client.chat.completions.create( + model=self.model, + messages=messages, + tools=tools, + tool_choice=tool_choice, + max_completion_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + stream=True, + stream_options={"include_usage": True}, + ): + if chunk.usage is not None: + if chunk.usage.prompt_tokens or chunk.usage.completion_tokens: + self.set_output_context( + prompt_tokens=chunk.usage.prompt_tokens, + completion_tokens=chunk.usage.completion_tokens, + ) + if chunk.choices: + if chunk.choices[0].delta.tool_calls: + tool_calls_json = [tc.to_dict() for tc in chunk.choices[0].delta.tool_calls] + yield json.dumps(tool_calls_json, indent=2) + else: + text = chunk.choices[0].delta.content if chunk.choices[0].delta.content else '' + yield text diff --git a/clarifai/cli/templates/toolkits/ollama/config.yaml b/clarifai/cli/templates/toolkits/ollama/config.yaml new file mode 100644 index 00000000..79634cc4 --- /dev/null +++ b/clarifai/cli/templates/toolkits/ollama/config.yaml @@ -0,0 +1,9 @@ +model: + id: "my-model" + +build_info: + python_version: "3.12" + image: "ollama/ollama:latest" + +toolkit: + provider: ollama diff --git a/clarifai/cli/templates/toolkits/ollama/requirements.txt b/clarifai/cli/templates/toolkits/ollama/requirements.txt new file mode 100644 index 00000000..0f711a4a --- /dev/null +++ b/clarifai/cli/templates/toolkits/ollama/requirements.txt @@ -0,0 +1,2 @@ +clarifai +openai diff --git a/clarifai/cli/templates/toolkits/sglang/1/model.py b/clarifai/cli/templates/toolkits/sglang/1/model.py new file mode 100644 index 00000000..7c604f88 --- /dev/null +++ b/clarifai/cli/templates/toolkits/sglang/1/model.py @@ -0,0 +1,155 @@ +import os +import sys +from typing import Iterator, List + +from openai import OpenAI + +from clarifai.runners.models.model_builder import ModelBuilder +from clarifai.runners.models.openai_class import OpenAIModelClass +from clarifai.runners.utils.data_utils import Param +from clarifai.runners.utils.openai_convertor import build_openai_messages +from clarifai.utils.logging import logger + +PYTHON_EXEC = sys.executable + + +def sglang_openai_server(checkpoints, **kwargs): + """Start SGLang OpenAI-compatible server.""" + from clarifai.runners.utils.model_utils import ( + execute_shell_command, + terminate_process, + wait_for_server, + ) + + cmds = [ + PYTHON_EXEC, + '-m', + 'sglang.launch_server', + '--model-path', + checkpoints, + ] + for key, value in kwargs.items(): + if value is None: + continue + param_name = key.replace('_', '-') + if isinstance(value, bool): + if value: + cmds.append(f'--{param_name}') + else: + cmds.extend([f'--{param_name}', str(value)]) + + server = type( + 'Server', + (), + { + 'host': kwargs.get('host', '0.0.0.0'), + 'port': kwargs.get('port', 23333), + 'process': None, + }, + )() + + try: + server.process = execute_shell_command(" ".join(cmds)) + url = f"http://{server.host}:{server.port}" + logger.info(f"Waiting for SGLang server at {url}") + wait_for_server(url) + logger.info(f"SGLang server started at {url}") + except Exception as e: + logger.error(f"Failed to start SGLang server: {e}") + if server.process: + terminate_process(server.process) + raise RuntimeError(f"Failed to start SGLang server: {e}") + return server + + +class SGLangModel(OpenAIModelClass): + client = True + model = True + + def load_model(self): + server_args = { + 'dtype': 'auto', + 'kv_cache_dtype': 'auto', + 'tp_size': 1, + 'context_length': None, + 'device': 'cuda', + 'port': 23333, + 'host': '0.0.0.0', + 'mem_fraction_static': 0.9, + } + + model_path = os.path.dirname(os.path.dirname(__file__)) + builder = ModelBuilder(model_path, download_validation_only=True) + config = builder.config + stage = config["checkpoints"]["when"] + checkpoints = config["checkpoints"]["repo_id"] + if stage in ["build", "runtime"]: + checkpoints = builder.download_checkpoints(stage=stage) + + self.server = sglang_openai_server(checkpoints, **server_args) + self.client = OpenAI( + api_key="notset", + base_url=f"http://{self.server.host}:{self.server.port}/v1", + ) + self.model = self.client.models.list().data[0].id + + @OpenAIModelClass.method + def predict( + self, + prompt: str = "", + chat_history: List[dict] = None, + max_tokens: int = Param( + default=512, + description="The maximum number of tokens to generate.", + ), + temperature: float = Param( + default=0.7, + description="Sampling temperature (higher = more random).", + ), + top_p: float = Param( + default=0.8, + description="Nucleus sampling threshold.", + ), + ) -> str: + """Return a single completion.""" + messages = build_openai_messages(prompt=prompt, messages=chat_history) + response = self.client.chat.completions.create( + model=self.model, + messages=messages, + max_completion_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + ) + return response.choices[0].message.content + + @OpenAIModelClass.method + def generate( + self, + prompt: str = "", + chat_history: List[dict] = None, + max_tokens: int = Param( + default=512, + description="The maximum number of tokens to generate.", + ), + temperature: float = Param( + default=0.7, + description="Sampling temperature (higher = more random).", + ), + top_p: float = Param( + default=0.8, + description="Nucleus sampling threshold.", + ), + ) -> Iterator[str]: + """Stream a completion response.""" + messages = build_openai_messages(prompt=prompt, messages=chat_history) + for chunk in self.client.chat.completions.create( + model=self.model, + messages=messages, + max_completion_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + stream=True, + ): + if chunk.choices: + text = chunk.choices[0].delta.content if chunk.choices[0].delta.content else '' + yield text diff --git a/clarifai/cli/templates/toolkits/sglang/config.yaml b/clarifai/cli/templates/toolkits/sglang/config.yaml new file mode 100644 index 00000000..d5978f9e --- /dev/null +++ b/clarifai/cli/templates/toolkits/sglang/config.yaml @@ -0,0 +1,13 @@ +model: + id: "my-model" + +build_info: + image: "lmsysorg/sglang:latest" + +compute: + instance: g5.xlarge + +checkpoints: + repo_id: google/gemma-3-1b-it + type: huggingface + when: runtime diff --git a/clarifai/cli/templates/toolkits/sglang/requirements.txt b/clarifai/cli/templates/toolkits/sglang/requirements.txt new file mode 100644 index 00000000..866458fe --- /dev/null +++ b/clarifai/cli/templates/toolkits/sglang/requirements.txt @@ -0,0 +1,3 @@ +clarifai +openai +nvidia-cudnn-cu12>=9.16.0.29 diff --git a/clarifai/cli/templates/toolkits/vllm/1/model.py b/clarifai/cli/templates/toolkits/vllm/1/model.py new file mode 100644 index 00000000..d0245cfb --- /dev/null +++ b/clarifai/cli/templates/toolkits/vllm/1/model.py @@ -0,0 +1,177 @@ +import os +import sys +from typing import Iterator, List + +from openai import OpenAI + +from clarifai.runners.models.model_builder import ModelBuilder +from clarifai.runners.models.openai_class import OpenAIModelClass +from clarifai.runners.utils.data_utils import Param +from clarifai.runners.utils.openai_convertor import build_openai_messages +from clarifai.utils.logging import logger + +PYTHON_EXEC = sys.executable + + +def vllm_openai_server(checkpoints, **kwargs): + """Start vLLM OpenAI-compatible server.""" + from clarifai.runners.utils.model_utils import ( + execute_shell_command, + terminate_process, + wait_for_server, + ) + + cmds = [ + PYTHON_EXEC, + '-m', + 'vllm.entrypoints.openai.api_server', + '--model', + checkpoints, + ] + for key, value in kwargs.items(): + if value is None: + continue + param_name = key.replace('_', '-') + if isinstance(value, bool): + if value: + cmds.append(f'--{param_name}') + else: + cmds.extend([f'--{param_name}', str(value)]) + + server = type( + 'Server', + (), + { + 'host': kwargs.get('host', '0.0.0.0'), + 'port': kwargs.get('port', 23333), + 'process': None, + }, + )() + + try: + server.process = execute_shell_command(" ".join(cmds)) + url = f"http://{server.host}:{server.port}" + logger.info(f"Waiting for vLLM server at {url}") + wait_for_server(url) + logger.info(f"vLLM server started at {url}") + except Exception as e: + logger.error(f"Failed to start vLLM server: {e}") + if server.process: + terminate_process(server.process) + raise RuntimeError(f"Failed to start vLLM server: {e}") + return server + + +class VLLMModel(OpenAIModelClass): + client = True + model = True + + def load_model(self): + server_args = { + 'tensor_parallel_size': 1, + 'port': 23333, + 'host': 'localhost', + } + + model_path = os.path.dirname(os.path.dirname(__file__)) + builder = ModelBuilder(model_path, download_validation_only=True) + config = builder.config + stage = config["checkpoints"]["when"] + checkpoints = config["checkpoints"]["repo_id"] + if stage in ["build", "runtime"]: + checkpoints = builder.download_checkpoints(stage=stage) + + self.server = vllm_openai_server(checkpoints, **server_args) + self.client = OpenAI( + api_key="notset", + base_url=f"http://{self.server.host}:{self.server.port}/v1", + ) + self.model = self.client.models.list().data[0].id + + @OpenAIModelClass.method + def predict( + self, + prompt: str = "", + chat_history: List[dict] = None, + tools: List[dict] = None, + tool_choice: str = None, + max_tokens: int = Param( + default=512, + description="The maximum number of tokens to generate.", + ), + temperature: float = Param( + default=0.7, + description="Sampling temperature (higher = more random).", + ), + top_p: float = Param( + default=0.95, + description="Nucleus sampling threshold.", + ), + ) -> str: + """Return a single completion.""" + if tools is not None and tool_choice is None: + tool_choice = "auto" + + messages = build_openai_messages(prompt=prompt, messages=chat_history) + response = self.client.chat.completions.create( + model=self.model, + messages=messages, + tools=tools, + tool_choice=tool_choice, + max_completion_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + ) + + if response.choices[0] and response.choices[0].message.tool_calls: + import json + + tool_calls = response.choices[0].message.tool_calls + return json.dumps([tc.to_dict() for tc in tool_calls], indent=2) + return response.choices[0].message.content + + @OpenAIModelClass.method + def generate( + self, + prompt: str = "", + chat_history: List[dict] = None, + tools: List[dict] = None, + tool_choice: str = None, + max_tokens: int = Param( + default=512, + description="The maximum number of tokens to generate.", + ), + temperature: float = Param( + default=0.7, + description="Sampling temperature (higher = more random).", + ), + top_p: float = Param( + default=0.95, + description="Nucleus sampling threshold.", + ), + ) -> Iterator[str]: + """Stream a completion response.""" + if tools is not None and tool_choice is None: + tool_choice = "auto" + + messages = build_openai_messages(prompt=prompt, messages=chat_history) + response = self.client.chat.completions.create( + model=self.model, + messages=messages, + tools=tools, + tool_choice=tool_choice, + max_completion_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + stream=True, + ) + for chunk in response: + if chunk.choices: + if chunk.choices[0].delta.tool_calls: + import json + + tool_calls_json = [tc.to_dict() for tc in chunk.choices[0].delta.tool_calls] + yield json.dumps(tool_calls_json, indent=2) + else: + text = chunk.choices[0].delta.content if chunk.choices[0].delta.content else '' + yield text diff --git a/clarifai/cli/templates/toolkits/vllm/config.yaml b/clarifai/cli/templates/toolkits/vllm/config.yaml new file mode 100644 index 00000000..b5999dfc --- /dev/null +++ b/clarifai/cli/templates/toolkits/vllm/config.yaml @@ -0,0 +1,13 @@ +model: + id: "my-model" + +build_info: + image: "vllm/vllm-openai:latest" + +compute: + instance: g5.xlarge + +checkpoints: + repo_id: google/gemma-3-1b-it + type: huggingface + when: runtime diff --git a/clarifai/cli/templates/toolkits/vllm/requirements.txt b/clarifai/cli/templates/toolkits/vllm/requirements.txt new file mode 100644 index 00000000..0f711a4a --- /dev/null +++ b/clarifai/cli/templates/toolkits/vllm/requirements.txt @@ -0,0 +1,2 @@ +clarifai +openai diff --git a/clarifai/client/model_client.py b/clarifai/client/model_client.py index 92cc76f1..57417b9d 100644 --- a/clarifai/client/model_client.py +++ b/clarifai/client/model_client.py @@ -37,6 +37,33 @@ def is_async_context(): return False +class _LocalGRPCStub: + """Lightweight gRPC stub for local model servers. + + Only wraps the 3 RPCs needed for model predictions (PostModelOutputs, + GenerateModelOutputs, StreamModelOutputs). Avoids instantiating the full + V2Stub which can fail on certain clarifai_grpc versions due to missing + response deserializers. + """ + + def __init__(self, channel): + self.PostModelOutputs = channel.unary_unary( + "/clarifai.api.V2/PostModelOutputs", + request_serializer=service_pb2.PostModelOutputsRequest.SerializeToString, + response_deserializer=service_pb2.MultiOutputResponse.FromString, + ) + self.GenerateModelOutputs = channel.unary_stream( + "/clarifai.api.V2/GenerateModelOutputs", + request_serializer=service_pb2.PostModelOutputsRequest.SerializeToString, + response_deserializer=service_pb2.MultiOutputResponse.FromString, + ) + self.StreamModelOutputs = channel.stream_stream( + "/clarifai.api.V2/StreamModelOutputs", + request_serializer=service_pb2.PostModelOutputsRequest.SerializeToString, + response_deserializer=service_pb2.MultiOutputResponse.FromString, + ) + + class ModelClient: ''' Client for calling model predict, generate, and stream methods. @@ -62,6 +89,33 @@ def __init__( self._method_signatures = None self._defined = False + @classmethod + def from_local_grpc(cls, port: int = 8000) -> 'ModelClient': + """Create a ModelClient connected to a local gRPC model server. + + Connects to a local gRPC server started with ``clarifai model serve --grpc``. + Method signatures are auto-discovered from the running model. + + Args: + port: Port of the local gRPC server (default 8000). + + Returns: + ModelClient with predict/generate/stream methods ready to use. + + Example:: + + from clarifai.client.model_client import ModelClient + + client = ModelClient.from_local_grpc(port=8000) + response = client.predict(text="What is the future of AI?") + print(response) + """ + import grpc + + channel = grpc.insecure_channel(f"localhost:{port}") + stub = _LocalGRPCStub(channel) + return cls(stub=stub) + def fetch(self): ''' Fetch function signature definitions from the model and define the functions in the client diff --git a/clarifai/client/user.py b/clarifai/client/user.py index 67d05678..60f0edc9 100644 --- a/clarifai/client/user.py +++ b/clarifai/client/user.py @@ -1,6 +1,7 @@ import os from typing import Any, Dict, Generator, List, Optional +import requests import yaml from clarifai_grpc.grpc.api import resources_pb2, service_pb2 from clarifai_grpc.grpc.api.status import status_code_pb2 @@ -524,6 +525,29 @@ def get_user_info(self, user_id: str = None) -> resources_pb2.User: return response + def list_organizations(self) -> List[Dict[str, str]]: + """List organizations the user belongs to via REST API. + + Returns: + list: List of dicts with 'id' and 'name' keys for each organization. + """ + base = self.auth_helper.base + user_id = self.id + url = f"{base}/v2/users/{user_id}/organizations" + headers = {"Authorization": f"Key {self.auth_helper.pat}"} + try: + resp = requests.get(url, headers=headers, timeout=30) + resp.raise_for_status() + data = resp.json() + orgs = [] + for uo in data.get("organizations", []): + org = uo.get("organization", {}) + orgs.append({"id": org.get("id", ""), "name": org.get("name", "")}) + return orgs + except Exception as e: + self.logger.debug(f"Failed to list organizations: {e}") + return [] + def __getattr__(self, name): return getattr(self.user_info, name) @@ -708,6 +732,70 @@ def delete_secrets(self, secret_ids: List[str]) -> None: raise Exception(response.status) self.logger.info("\nSecrets Deleted\n%s", response.status) + def list_cloud_providers(self) -> list: + """List available cloud providers (e.g. aws, gcp, azure, vultr). + + Returns: + list: List of CloudProvider protobuf objects with id, name, special_handling fields. + + Example: + >>> from clarifai.client.user import User + >>> client = User(user_id="user_id") + >>> providers = client.list_cloud_providers() + """ + request = service_pb2.ListCloudProvidersRequest() + response = self._grpc_request(self.STUB.ListCloudProviders, request) + if response.status.code != status_code_pb2.SUCCESS: + raise Exception(response.status) + return list(response.cloud_providers) + + def list_cloud_regions(self, cloud_provider_id: str) -> list: + """List available regions for a cloud provider. + + Args: + cloud_provider_id (str): The cloud provider ID (e.g. 'aws', 'gcp'). + + Returns: + list: List of CloudRegion protobuf objects with id field. + + Example: + >>> from clarifai.client.user import User + >>> client = User(user_id="user_id") + >>> regions = client.list_cloud_regions("aws") + """ + request = service_pb2.ListCloudRegionsRequest( + cloud_provider=resources_pb2.CloudProvider(id=cloud_provider_id) + ) + response = self._grpc_request(self.STUB.ListCloudRegions, request) + if response.status.code != status_code_pb2.SUCCESS: + raise Exception(response.status) + return list(response.cloud_regions) + + def list_instance_types(self, cloud_provider_id: str, region: str) -> list: + """List available GPU/instance types for a cloud provider and region. + + Args: + cloud_provider_id (str): The cloud provider ID (e.g. 'aws', 'gcp'). + region (str): The region ID (e.g. 'us-east-1'). + + Returns: + list: List of InstanceType protobuf objects with id, description, compute_info, + price, cloud_provider, region fields. + + Example: + >>> from clarifai.client.user import User + >>> client = User(user_id="user_id") + >>> instance_types = client.list_instance_types("aws", "us-east-1") + """ + request = service_pb2.ListInstanceTypesRequest( + cloud_provider=resources_pb2.CloudProvider(id=cloud_provider_id), + region=region, + ) + response = self._grpc_request(self.STUB.ListInstanceTypes, request) + if response.status.code != status_code_pb2.SUCCESS: + raise Exception(response.status) + return list(response.instance_types) + def list_models( self, user_id: str = None, diff --git a/clarifai/runners/dockerfile_template/Dockerfile.custom_image.template b/clarifai/runners/dockerfile_template/Dockerfile.custom_image.template new file mode 100644 index 00000000..2524c788 --- /dev/null +++ b/clarifai/runners/dockerfile_template/Dockerfile.custom_image.template @@ -0,0 +1,65 @@ +# syntax=docker/dockerfile:1.13-labs + +FROM --platform=$TARGETPLATFORM ${IMAGE} + +ENV DEBIAN_FRONTEND=noninteractive + +# Install Python if not already available in the base image +RUN command -v python3 >/dev/null 2>&1 || \ + (apt-get update && apt-get install -y --no-install-recommends \ + python3 python3-pip python3-venv && \ + rm -rf /var/lib/apt/lists/*) && \ + # Ensure 'python' points to python3 + (command -v python >/dev/null 2>&1 || ln -sf "$(command -v python3)" /usr/bin/python) && \ + # Ensure 'pip' is available + (command -v pip >/dev/null 2>&1 || command -v pip3 >/dev/null 2>&1 || python3 -m ensurepip 2>/dev/null || true) && \ + (command -v pip >/dev/null 2>&1 || ln -sf "$(command -v pip3)" /usr/bin/pip 2>/dev/null || true) + +COPY --link requirements.txt /home/nonroot/requirements.txt + +# Update clarifai package so we always have latest protocol to the API. +RUN pip install --no-cache-dir --break-system-packages -r /home/nonroot/requirements.txt || \ + pip install --no-cache-dir -r /home/nonroot/requirements.txt +RUN pip show clarifai + +# Ensure the nonroot user (uid 65532) exists in /etc/passwd so getpass.getuser() works +# (required by torch._inductor when the pod runs as uid 65532). +RUN grep -q ':65532:' /etc/passwd 2>/dev/null || echo 'nonroot:x:65532:65532:nonroot:/home/nonroot:/usr/sbin/nologin' >> /etc/passwd + +# Set the NUMBA cache dir to /tmp +# Set the TORCHINDUCTOR cache dir to /tmp +# The CLARIFAI* will be set by the templating system. +ENV NUMBA_CACHE_DIR=/tmp/numba_cache \ + TORCHINDUCTOR_CACHE_DIR=/tmp/torchinductor_cache \ + HOME=/tmp + +##### +# Download checkpoints if config.yaml has checkpoints.when = "build" +COPY --link=true config.yaml /home/nonroot/main/ +# RUN ["python", "-m", "clarifai.cli", "model", "download-checkpoints", "/home/nonroot/main", "--out_path", "/home/nonroot/main/1/checkpoints", "--stage", "build"] + +##### +# Copy in the actual files like config.yaml, requirements.txt, and most importantly 1/model.py +# for the actual model. +# If checkpoints aren't downloaded since a checkpoints: block is not provided, then they will +# be in the build context and copied here as well. +COPY --link=true 1 /home/nonroot/main/1 + +# At this point we only need these for validation in the SDK. +COPY --link=true requirements.txt config.yaml /home/nonroot/main/ + +# Add the model directory to the python path. +ENV PYTHONPATH=${PYTHONPATH}:/home/nonroot/main \ + CLARIFAI_PAT=${CLARIFAI_PAT} \ + CLARIFAI_USER_ID=${CLARIFAI_USER_ID} \ + CLARIFAI_RUNNER_ID=${CLARIFAI_RUNNER_ID} \ + CLARIFAI_NODEPOOL_ID=${CLARIFAI_NODEPOOL_ID} \ + CLARIFAI_COMPUTE_CLUSTER_ID=${CLARIFAI_COMPUTE_CLUSTER_ID} \ + CLARIFAI_API_BASE=${CLARIFAI_API_BASE:-https://api.clarifai.com} + +WORKDIR /home/nonroot/main + +# Finally run the clarifai entrypoint to start the runner loop and local runner server. +ENTRYPOINT ["python", "-m", "clarifai.runners.server"] +CMD ["--model_path", "/home/nonroot/main"] +############################# diff --git a/clarifai/runners/dockerfile_template/Dockerfile.node.template b/clarifai/runners/dockerfile_template/Dockerfile.node.template index bf3a31df..628a904b 100644 --- a/clarifai/runners/dockerfile_template/Dockerfile.node.template +++ b/clarifai/runners/dockerfile_template/Dockerfile.node.template @@ -15,6 +15,9 @@ COPY --link requirements.txt /home/nonroot/requirements.txt RUN ["pip", "install", "--no-cache-dir", "-r", "/home/nonroot/requirements.txt"] RUN ["pip", "show", "--no-cache-dir", "clarifai"] +# Ensure the nonroot user (uid 65532) exists in /etc/passwd so getpass.getuser() works +RUN grep -q ':65532:' /etc/passwd 2>/dev/null || echo 'nonroot:x:65532:65532:nonroot:/home/nonroot:/usr/sbin/nologin' >> /etc/passwd + # Set the NUMBA cache dir to /tmp # Set the TORCHINDUCTOR cache dir to /tmp # The CLARIFAI* will be set by the templaing system. diff --git a/clarifai/runners/dockerfile_template/Dockerfile.template b/clarifai/runners/dockerfile_template/Dockerfile.template index 3de761d4..caff44af 100644 --- a/clarifai/runners/dockerfile_template/Dockerfile.template +++ b/clarifai/runners/dockerfile_template/Dockerfile.template @@ -7,7 +7,7 @@ FROM --platform=$BUILDPLATFORM ${DOWNLOADER_IMAGE} as model-assets # Install minimal tools needed for download -RUN pip install --no-cache-dir clarifai==${CLARIFAI_VERSION} huggingface_hub +RUN pip install --no-cache-dir clarifai==${CLARIFAI_VERSION} huggingface_hub[hf_transfer] WORKDIR /home/nonroot/main @@ -45,6 +45,7 @@ ENV PATH="/home/nonroot/.local/bin:$VIRTUAL_ENV/bin:$PATH" # Update clarifai package so we always have latest protocol to the API. Everything should land in /venv RUN ["uv", "pip", "install", "--no-cache-dir", "-r", "/home/nonroot/requirements.txt"] +RUN ["uv", "pip", "install", "--no-cache-dir", "huggingface_hub[hf_transfer]"] RUN ["uv", "pip", "show", "--no-cache-dir", "clarifai"] # Set the NUMBA cache dir to /tmp diff --git a/clarifai/runners/models/deploy_output.py b/clarifai/runners/models/deploy_output.py new file mode 100644 index 00000000..fe56ff70 --- /dev/null +++ b/clarifai/runners/models/deploy_output.py @@ -0,0 +1,66 @@ +"""Structured output helpers for `clarifai model deploy`. + +Provides phase headers, status lines, and progress indicators using click.echo +so that deploy output is visually organized into clear phases (Validate, Upload, +Build, Deploy, Monitor, Ready) rather than a wall of undifferentiated log lines. +""" + +import click + +HEADER_WIDTH = 58 + + +def phase_header(title): + """Print a phase separator: ── Title ────────────────────""" + bar = "\u2500" * max(1, HEADER_WIDTH - len(title) - 4) + click.echo(click.style(f"\n\u2500\u2500 {title} {bar}", fg="cyan", bold=True)) + + +def info(label, value): + """Print a labeled info line: ' Label: value'""" + styled_label = click.style(f"{label}:", fg="white", bold=True) + click.echo(f" {styled_label:30s} {value}") + + +def status(message, nl=True): + """Print a status message.""" + click.echo(f" {message}", nl=nl) + + +def inline_progress(message): + """Print inline progress (overwrites current line).""" + click.echo(f"\r {message:<70}", nl=False) + + +def clear_inline(): + """Clear inline progress line.""" + click.echo(f"\r{' ':74}\r", nl=False) + + +def success(message): + """Print a success message.""" + click.echo(click.style(f" {message}", fg="green")) + + +def warning(message): + """Print a warning.""" + click.echo(click.style(f" [warning] {message}", fg="yellow")) + + +def hint(label, command): + """Print a CLI command hint: ' Label: command' with yellow command.""" + styled_label = click.style(f"{label}:", fg="white", bold=True) + styled_cmd = click.style(command, fg="yellow") + click.echo(f" {styled_label:30s} {styled_cmd}") + + +def link(label, url): + """Print a clickable URL: ' Label: url' with OSC 8 hyperlink.""" + styled_label = click.style(f"{label}:", fg="white", bold=True) + styled_url = f"\033]8;;{url}\033\\{click.style(url, fg='cyan', underline=True)}\033]8;;\033\\" + click.echo(f" {styled_label:30s} {styled_url}") + + +def event(message): + """Print a deployment event (dimmed).""" + click.echo(click.style(f" {message}", fg="bright_black")) diff --git a/clarifai/runners/models/mcp_class.py b/clarifai/runners/models/mcp_class.py index 4f72bd5c..f271f390 100644 --- a/clarifai/runners/models/mcp_class.py +++ b/clarifai/runners/models/mcp_class.py @@ -253,6 +253,11 @@ def shutdown(self) -> None: logger.info("MCP bridge shut down") def __del__(self): + import sys + + # Skip cleanup during Python shutdown — modules are already torn down + if sys is None or sys.meta_path is None: + return try: self.shutdown() except Exception: diff --git a/clarifai/runners/models/model_builder.py b/clarifai/runners/models/model_builder.py index 28bab546..16969bae 100644 --- a/clarifai/runners/models/model_builder.py +++ b/clarifai/runners/models/model_builder.py @@ -175,6 +175,8 @@ def __init__( platform: Optional[str] = None, pat: Optional[str] = None, base_url: Optional[str] = None, + user_id: Optional[str] = None, + app_id: Optional[str] = None, compute_info_required: bool = False, ): """ @@ -189,6 +191,8 @@ def __init__( :param platform: Target platform(s) for Docker image build (e.g., "linux/amd64" or "linux/amd64,linux/arm64"). This overrides the platform specified in config.yaml. :param pat: Personal access token for authentication. If None, will use environment variables. :param base_url: Base URL for the API. If None, will use environment variables. + :param user_id: Optional user ID to inject into config if missing (for simplified configs). + :param app_id: Optional app ID to inject into config if missing (for simplified configs). :param compute_info_required: Whether inference compute info is required. This affects certain validation and behavior. """ assert app_not_found_action in ["auto_create", "prompt", "error"], ValueError( @@ -204,6 +208,12 @@ def __init__( self.download_validation_only = download_validation_only self.folder = self._validate_folder(folder) self.config = self._load_config(os.path.join(self.folder, 'config.yaml')) + # Auto-resolve user_id if not provided and not in config + if not user_id and 'user_id' not in self.config.get('model', {}): + from clarifai.utils.config import resolve_user_id + + user_id = resolve_user_id(pat=pat, base_url=base_url) + self.config = self.normalize_config(self.config, user_id=user_id, app_id=app_id) self._validate_config() self._validate_config_secrets() self._validate_stream_options() @@ -360,14 +370,18 @@ def _validate_folder(self, folder): f"Folder {folder} not found, please provide a valid folder path" ) files = os.listdir(folder) - assert "config.yaml" in files, "config.yaml not found in the folder" + if "config.yaml" not in files: + raise UserError(f"config.yaml not found in {folder}") # If just downloading we don't need requirements.txt or the python code, we do need the # 1/ folder to put 1/checkpoints into. - assert "1" in files, "Subfolder '1' not found in the folder" + if "1" not in files: + raise UserError(f"Subfolder '1' not found in {folder}") if not self.download_validation_only: - assert "requirements.txt" in files, "requirements.txt not found in the folder" + if "requirements.txt" not in files: + raise UserError(f"requirements.txt not found in {folder}") subfolder_files = os.listdir(os.path.join(folder, '1')) - assert 'model.py' in subfolder_files, "model.py not found in the folder" + if 'model.py' not in subfolder_files: + raise UserError(f"model.py not found in {folder}/1/") return folder @staticmethod @@ -405,9 +419,8 @@ def _validate_config_checkpoints(self): """ if "checkpoints" not in self.config: return None, None, None, DEFAULT_DOWNLOAD_CHECKPOINT_WHEN, None, None - assert "type" in self.config.get("checkpoints"), ( - "No loader type specified in the config file" - ) + if "type" not in self.config.get("checkpoints"): + raise UserError("No loader type specified in checkpoints section of config.yaml") loader_type = self.config.get("checkpoints").get("type") if not loader_type: logger.info("No loader type specified in the config file for checkpoints") @@ -418,18 +431,17 @@ def _validate_config_checkpoints(self): f"No 'when' specified in the config file for checkpoints, defaulting to download at {DEFAULT_DOWNLOAD_CHECKPOINT_WHEN}" ) when = checkpoints.get("when", DEFAULT_DOWNLOAD_CHECKPOINT_WHEN) - assert when in [ - "upload", - "build", - "runtime", - ], ( - "Invalid value for when in the checkpoint loader when, needs to be one of ['upload', 'build', 'runtime']" - ) - assert loader_type == "huggingface", "Only huggingface loader supported for now" - if loader_type == "huggingface": - assert "repo_id" in self.config.get("checkpoints"), ( - "No repo_id specified in the config file" + if when not in ["upload", "build", "runtime"]: + raise UserError( + f"Invalid value '{when}' for checkpoints.when. Must be one of: upload, build, runtime" ) + if loader_type != "huggingface": + raise UserError( + f"Unsupported checkpoint loader type '{loader_type}'. Only 'huggingface' is supported." + ) + if loader_type == "huggingface": + if "repo_id" not in self.config.get("checkpoints"): + raise UserError("No repo_id specified in checkpoints section of config.yaml") repo_id = self.config.get("checkpoints").get("repo_id") # get from config.yaml otherwise fall back to HF_TOKEN env var. @@ -506,23 +518,99 @@ def create_app(): create_app() return True + @staticmethod + def normalize_config(config, user_id=None, app_id=None): + """Expand simplified config format to full format. + + Handles: + 1. Inject user_id/app_id from CLI context if missing + 2. Expand compute.instance (or legacy compute.gpu) -> inference_compute_info + 3. Expand simplified checkpoints (infer type, default when) + 4. Set build_info defaults + + This is a no-op for configs that already have all fields. + """ + config = dict(config) + + # 1. Inject user_id/app_id into model section if missing + model = dict(config.get('model', {})) + if user_id and 'user_id' not in model: + model['user_id'] = user_id + if app_id and 'app_id' not in model: + model['app_id'] = app_id + # Default app_id to "main" if still missing (auto-created on deploy/upload) + if 'app_id' not in model: + model['app_id'] = 'main' + # Default model_type_id to "any-to-any" if not specified + if 'model_type_id' not in model: + model['model_type_id'] = 'any-to-any' + config['model'] = model + + # 2. Expand compute.instance (or legacy compute.gpu) -> inference_compute_info + compute = config.get('compute') + if compute and 'inference_compute_info' not in config: + instance = compute.get('instance') or compute.get('gpu') + if instance: + from clarifai.utils.compute_presets import get_inference_compute_for_gpu + + try: + ici = get_inference_compute_for_gpu(instance) + # Always use wildcard accelerator_type so the model can be scheduled + # on any compatible NVIDIA GPU, not locked to a specific type. + if ici.get('num_accelerators', 0) > 0: + ici['accelerator_type'] = ['NVIDIA-*'] + config['inference_compute_info'] = ici + except ValueError: + logger.debug( + f"Could not resolve compute instance '{instance}'. " + "Skipping inference_compute_info normalization." + ) + # Normalize to compute.instance + compute['instance'] = instance + compute.pop('gpu', None) + + # 3. Expand simplified checkpoints + checkpoints = config.get('checkpoints') + if checkpoints: + checkpoints = dict(checkpoints) + if 'type' not in checkpoints and 'repo_id' in checkpoints: + checkpoints['type'] = 'huggingface' + if 'when' not in checkpoints: + checkpoints['when'] = 'runtime' + config['checkpoints'] = checkpoints + + # 4. Build info defaults + if 'build_info' not in config: + config['build_info'] = {'python_version': '3.12'} + + return config + def _validate_config_model(self): - assert "model" in self.config, "model section not found in the config file" + if "model" not in self.config: + raise UserError("'model' section not found in config.yaml") model = self.config.get('model') - assert "user_id" in model, "user_id not found in the config file" - assert "app_id" in model, "app_id not found in the config file" - assert "model_type_id" in model, "model_type_id not found in the config file" - assert "id" in model, "model_id not found in the config file" - if '.' in model.get('id'): - logger.error( - "Model ID cannot contain '.', please remove it from the model_id in the config file" + if "user_id" not in model: + raise UserError( + "user_id could not be resolved. Either:\n" + " - Add 'user_id' to the model section in config.yaml\n" + " - Run 'clarifai login' to set up your CLI config" ) - sys.exit(1) - - assert model.get('user_id') != "", "user_id cannot be empty in the config file" - assert model.get('app_id') != "", "app_id cannot be empty in the config file" - assert model.get('model_type_id') != "", "model_type_id cannot be empty in the config file" - assert model.get('id') != "", "model_id cannot be empty in the config file" + if "app_id" not in model: + raise UserError("app_id not found in config.yaml") + if "model_type_id" not in model: + model["model_type_id"] = "any-to-any" + if "id" not in model: + raise UserError("model id not found in the model section of config.yaml") + if '.' in model.get('id', ''): + raise UserError( + "Model ID cannot contain '.'. Please remove it from the model id in config.yaml." + ) + if not model.get('user_id'): + raise UserError("user_id cannot be empty in config.yaml") + if not model.get('app_id'): + raise UserError("app_id cannot be empty in config.yaml") + if not model.get('id'): + raise UserError("model id cannot be empty in config.yaml") if not self._check_app_exists(): sys.exit(1) @@ -545,18 +633,19 @@ def _validate_config(self): if not self.download_validation_only: self._validate_config_model() - assert "inference_compute_info" in self.config, ( - "inference_compute_info not found in the config file" - ) + if "inference_compute_info" not in self.config: + logger.warning( + "inference_compute_info not found in config. " + "Set 'compute.instance' or 'inference_compute_info' in config.yaml for deployment." + ) if self.config.get("concepts"): model_type_id = self.config.get('model').get('model_type_id') - assert model_type_id in CONCEPTS_REQUIRED_MODEL_TYPE, ( - f"Model type {model_type_id} not supported for concepts" - ) + if model_type_id not in CONCEPTS_REQUIRED_MODEL_TYPE: + raise UserError(f"Model type '{model_type_id}' not supported for concepts") if self.config.get("checkpoints"): - loader_type, _, hf_token, _, _, _ = self._validate_config_checkpoints() + loader_type, _, hf_token, when, _, _ = self._validate_config_checkpoints() if loader_type == "huggingface": is_valid_token = hf_token and HuggingFaceLoader.validate_hftoken(hf_token) @@ -565,22 +654,53 @@ def _validate_config(self): "Continuing without Hugging Face token for validating config in model builder." ) - has_repo_access = HuggingFaceLoader.validate_hf_repo_access( - repo_id=self.config.get("checkpoints", {}).get("repo_id"), - token=hf_token if is_valid_token else None, + repo_id = self.config.get("checkpoints", {}).get("repo_id") + config_hf_token = self.config.get("checkpoints", {}).get("hf_token") + + # First, check anonymous access (no cached login) to detect gated repos. + has_access, reason = HuggingFaceLoader.validate_hf_repo_access( + repo_id=repo_id, + token=False, # bypass cached huggingface-cli login ) - if not has_repo_access: - logger.error( - f"Invalid Hugging Face repo access for repo {self.config.get('checkpoints').get('repo_id')}. Please check your repo and try again." - ) - sys.exit("Token does not have access to HuggingFace repo , exiting.") + if has_access: + # Public repo — no token needed anywhere. + pass + elif reason == "gated_no_token": + # Repo requires auth. Validate with the available token. + if is_valid_token: + has_access, reason = HuggingFaceLoader.validate_hf_repo_access( + repo_id=repo_id, + token=hf_token, + ) + if not is_valid_token or not has_access: + if not is_valid_token: + reason = "gated_no_token" + self._raise_hf_access_error(repo_id, reason) + + # Token works — for build/runtime, persist it to config so + # the container has it too. + if when in ("build", "runtime") and not config_hf_token: + self.config.setdefault("checkpoints", {})["hf_token"] = hf_token + config_path = os.path.join(self.folder, "config.yaml") + if os.path.exists(config_path): + with open(config_path, 'r') as f: + file_config = yaml.safe_load(f) or {} + file_config.setdefault("checkpoints", {})["hf_token"] = hf_token + with open(config_path, 'w') as f: + yaml.dump(file_config, f, sort_keys=False) + logger.info( + "Wrote HF_TOKEN from environment to config.yaml " + "so the build container can access the gated repo." + ) + else: + # not_found or gated_no_access + self._raise_hf_access_error(repo_id, reason) num_threads = self.config.get("num_threads") if num_threads or num_threads == 0: - assert isinstance(num_threads, int) and num_threads >= 1, ValueError( - f"`num_threads` must be an integer greater than or equal to 1. Received type {type(num_threads)} with value {num_threads}." - ) + if not isinstance(num_threads, int) or num_threads < 1: + raise UserError(f"num_threads must be an integer >= 1. Got: {num_threads!r}") else: num_threads = int(os.environ.get("CLARIFAI_NUM_THREADS", 16)) self.config["num_threads"] = num_threads @@ -589,6 +709,33 @@ def _validate_config(self): if not self.download_validation_only: self._validate_agentic_model_requirements() + @staticmethod + def _raise_hf_access_error(repo_id, reason): + """Raise UserError with actionable guidance for HuggingFace access failures.""" + if reason == "gated_no_token": + raise UserError( + f"HuggingFace repo '{repo_id}' requires authentication.\n" + " Set HF_TOKEN in your environment:\n" + " export HF_TOKEN=hf_...\n" + " Or add to config.yaml:\n" + " checkpoints:\n" + " hf_token: hf_...\n" + f" Request access at: https://huggingface.co/{repo_id}" + ) + elif reason == "gated_no_access": + raise UserError( + f"Your HF token does not have access to gated repo '{repo_id}'.\n" + f" Request access at: https://huggingface.co/{repo_id}\n" + " Then wait for approval and retry." + ) + elif reason == "not_found": + raise UserError( + f"HuggingFace repo '{repo_id}' not found.\n" + " Check the repo_id in your config.yaml checkpoints section." + ) + else: + raise UserError(f"Cannot access HuggingFace repo '{repo_id}'.") + def _validate_agentic_model_requirements(self): """ Validate that AgenticModelClass models have required dependencies (fastmcp and mcp) in requirements.txt. @@ -639,7 +786,7 @@ def _validate_stream_options(self): ) if not self.has_proper_usage_tracking(all_python_content): - logger.error( + logger.warning( "Missing configuration to track usage for OpenAI chat completion calls. " "Go to your model scripts and make sure to set both: " "1) stream_options={'include_usage': True}" @@ -863,10 +1010,13 @@ def get_method_signatures(self, mocking=True): @property def client(self): if self._client is None: - assert "model" in self.config, "model info not found in the config file" + if "model" not in self.config: + raise UserError("'model' section not found in config.yaml") model = self.config.get('model') - assert "user_id" in model, "user_id not found in the config file" - assert "app_id" in model, "app_id not found in the config file" + if "user_id" not in model: + raise UserError("user_id not found in config.yaml") + if "app_id" not in model: + raise UserError("app_id not found in config.yaml") # The owner of the model and the app. user_id = model.get('user_id') app_id = model.get('app_id') @@ -928,14 +1078,19 @@ def model_api_url(self): ) def _get_model_proto(self): - assert "model" in self.config, "model info not found in the config file" + if "model" not in self.config: + raise UserError("'model' section not found in config.yaml") model = self.config.get('model') - assert "model_type_id" in model, "model_type_id not found in the config file" - assert "id" in model, "model_id not found in the config file" + if "model_type_id" not in model: + model["model_type_id"] = "any-to-any" + if "id" not in model: + raise UserError("model id not found in the model section of config.yaml") if not self.download_validation_only: - assert "user_id" in model, "user_id not found in the config file" - assert "app_id" in model, "app_id not found in the config file" + if "user_id" not in model: + raise UserError("user_id not found in config.yaml") + if "app_id" not in model: + raise UserError("app_id not found in config.yaml") model_proto = json_format.ParseDict(model, resources_pb2.Model()) @@ -946,7 +1101,13 @@ def _get_inference_compute_info(self, compute_info_required=False): assert "inference_compute_info" in self.config, ( "inference_compute_info not found in the config file" ) - inference_compute_info = self.config.get('inference_compute_info') or {} + inference_compute_info = self.config.get('inference_compute_info') + if not inference_compute_info: + logger.debug( + "inference_compute_info not found in config. " + "Set 'compute.instance' or 'inference_compute_info' in config.yaml for deployment." + ) + return None # Ensure cpu_limit is a string if it exists and is an int if 'cpu_limit' in inference_compute_info and isinstance( inference_compute_info['cpu_limit'], int @@ -1155,6 +1316,11 @@ def _generate_dockerfile_content(self): # Get the Python version from the config file build_info = self.config.get('build_info', {}) + # Check if custom base image is specified + custom_image = build_info.get('image', '') or '' + if custom_image and str(custom_image).strip(): + return self._generate_custom_image_dockerfile(custom_image.strip()) + # Check if node_version is specified - if so, use the Node.js Dockerfile template node_version = build_info.get('node_version', '') or '' use_node_template = bool(node_version and str(node_version).strip()) @@ -1355,6 +1521,22 @@ def _generate_dockerfile_content(self): return dockerfile_content + def _generate_custom_image_dockerfile(self, image): + """Generate Dockerfile using a custom base image specified in build_info.image.""" + dockerfile_template_path = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + 'dockerfile_template', + 'Dockerfile.custom_image.template', + ) + with open(dockerfile_template_path, 'r') as template_file: + dockerfile_template = template_file.read() + + dockerfile_template = Template(dockerfile_template) + dockerfile_content = dockerfile_template.safe_substitute(IMAGE=image) + + logger.info(f"Setup: Using custom base image '{image}' from build_info.image") + return dockerfile_content + def create_dockerfile(self, generate_dockerfile=False): """ Create a Dockerfile for the model based on its configuration. @@ -1382,16 +1564,10 @@ def create_dockerfile(self, generate_dockerfile=False): ) should_create_dockerfile = False else: - logger.info("Dockerfile already exists with different content.") - response = input( - "A different Dockerfile already exists. Do you want to overwrite it with the generated one? " - "Type 'y' to overwrite, 'n' to keep your custom Dockerfile: " + logger.warning( + "Custom Dockerfile differs from auto-generated one — keeping yours." ) - if response.lower() != 'y': - logger.info("Keeping existing custom Dockerfile.") - should_create_dockerfile = False - else: - logger.info("Overwriting existing Dockerfile with generated content.") + should_create_dockerfile = False if should_create_dockerfile: # Write Dockerfile @@ -1490,9 +1666,8 @@ def hf_labels_to_config(self, labels, config_file): config = yaml.safe_load(file) model = config.get('model') model_type_id = model.get('model_type_id') - assert model_type_id in CONCEPTS_REQUIRED_MODEL_TYPE, ( - f"Model type {model_type_id} not supported for concepts" - ) + if model_type_id not in CONCEPTS_REQUIRED_MODEL_TYPE: + raise UserError(f"Model type '{model_type_id}' not supported for concepts") concept_protos = self._concepts_protos_from_concepts(labels) config['concepts'] = [ @@ -1564,15 +1739,9 @@ def _get_git_info(self) -> Optional[Dict[str, Any]]: # Not a git repository or git not available return None - def _check_git_status_and_prompt(self) -> bool: - """ - Check for uncommitted changes in git repository within the model path and prompt user. - - Returns: - True if should continue with upload, False if should abort - """ + def _check_git_status(self) -> None: + """Check for uncommitted changes in model path and warn (non-blocking).""" try: - # Check for uncommitted changes within the model path only status_result = subprocess.run( ['git', 'status', '--porcelain', '.'], cwd=self.folder, @@ -1582,21 +1751,16 @@ def _check_git_status_and_prompt(self) -> bool: ) if status_result.stdout.strip(): - logger.warning("Uncommitted changes detected in model path:") - logger.warning(status_result.stdout) - - response = input( - "\nDo you want to continue upload with uncommitted changes? (y/N): " + logger.warning( + "Uncommitted changes detected in model path — uploading working-tree state:" ) - return response.lower() in ['y', 'yes'] + for line in status_result.stdout.strip().splitlines(): + logger.warning(f" {line}") else: logger.info("Model path has no uncommitted changes.") - return True except (subprocess.CalledProcessError, FileNotFoundError): - # Error checking git status, but we already know it's a git repo - logger.warning("Could not check git status, continuing with upload.") - return True + logger.debug("Could not check git status, continuing with upload.") def get_model_version_proto(self, git_info: Optional[Dict[str, Any]] = None): """ @@ -1705,7 +1869,24 @@ def get_model_version_proto(self, git_info: Optional[Dict[str, Any]] = None): ) return model_version_proto - def upload_model_version(self, git_info=None): + def upload_model_version( + self, git_info=None, show_client_script=True, quiet_build=False, post_upload_callback=None + ): + if self.inference_compute_info is None: + raise ValueError( + "inference_compute_info is required for uploading a model.\n" + " Add one of the following to your config.yaml:\n" + " compute:\n" + " instance: gpu-nvidia-a10g # simplified format\n" + " Or:\n" + " inference_compute_info:\n" + " cpu_limit: '4'\n" + " cpu_memory: '16Gi'\n" + " num_accelerators: 1\n" + " accelerator_type: ['NVIDIA-A10G']\n" + " accelerator_memory: '24Gi'\n" + " Run 'clarifai list-instances' to see available options." + ) file_path = f"{self.folder}.tar.gz" logger.debug(f"Will tar it into file: {file_path}") @@ -1726,8 +1907,8 @@ def upload_model_version(self, git_info=None): if when != "upload" and not HuggingFaceLoader.validate_config( self.checkpoint_path ): - input( - "Press Enter to download the HuggingFace model's config.json file to infer the concepts and continue..." + logger.info( + "Downloading HuggingFace model config.json to infer concepts..." ) loader = HuggingFaceLoader(repo_id=repo_id, token=hf_token) loader.download_config(self.checkpoint_path) @@ -1745,10 +1926,25 @@ def filter_func(tarinfo): exclude = [self.tar_file, "*~", "*.pyc", "*.pyo", "__pycache__", ".ruff_cache"] if when != "upload": exclude.append(self.checkpoint_suffix) + # Exclude on-disk config.yaml — we inject the normalized in-memory version below + if name == './config.yaml' or name == 'config.yaml': + return None return None if any(name.endswith(ex) for ex in exclude) else tarinfo + import io + with tarfile.open(self.tar_file, "w:gz") as tar: tar.add(self.folder, arcname=".", filter=filter_func) + # Inject the normalized in-memory config (with user_id, app_id, + # inference_compute_info, etc.) so the packaged image has the full config + # without ever modifying the user's on-disk config.yaml. + config_bytes = yaml.dump( + self.config, default_flow_style=False, sort_keys=False + ).encode('utf-8') + config_info = tarfile.TarInfo(name='./config.yaml') + config_info.size = len(config_bytes) + config_info.mtime = int(time.time()) + tar.addfile(config_info, io.BytesIO(config_bytes)) logger.debug("Tarring complete, about to start upload.") file_size = os.path.getsize(self.tar_file) @@ -1779,22 +1975,37 @@ def filter_func(tarinfo): percent_completed = response.status.percent_completed details = response.status.details - print( - f"Status: {response.status.description}, Progress: {percent_completed}% - {details} ", - f"request_id: {response.status.req_id}", - end='\r', - flush=True, - ) + if quiet_build: + print( + f"\r Uploading... {percent_completed}%", + end='', + flush=True, + ) + else: + print( + f"Status: {response.status.description}, Progress: {percent_completed}% - {details} ", + f"request_id: {response.status.req_id}", + end='\r', + flush=True, + ) + if quiet_build: + # Overwrite "Uploading... X%" with "done" and newline + print(f"\r Uploading... done{' ':50}", flush=True) if response.status.code != status_code_pb2.MODEL_BUILDING: logger.error(f"Failed to upload model version: {response}") return self.model_version_id = response.model_version_id logger.info(f"Created Model Version ID: {self.model_version_id}") logger.info(f"Full url to that version is: {self.model_ui_url}") + + # Callback for deploy orchestrator to emit Version/URL before build starts + if post_upload_callback: + post_upload_callback(self.model_version_id, self.model_ui_url) + is_uploaded = False try: - is_uploaded = self.monitor_model_build() - if is_uploaded: + is_uploaded = self.monitor_model_build(quiet=quiet_build) + if is_uploaded and show_client_script: # python code to run the model. method_signatures = self.get_method_signatures() @@ -1882,10 +2093,21 @@ def get_model_build_logs(self, current_page=1): response = self.client.STUB.ListLogEntries(logs_request) return response - def monitor_model_build(self): + def monitor_model_build(self, quiet=False): + """Monitor model build, optionally suppressing detailed Docker build logs. + + Args: + quiet: If True, suppress Docker build step logs and show only a progress + indicator. Build completion/failure is always shown. + """ st = time.time() seen_logs = set() # To avoid duplicate log messages current_page = 1 # Track current page for log pagination + if quiet: + from clarifai.runners.models import deploy_output as out + + out.phase_header("Build") + while True: resp = self.client.STUB.GetModelVersion( service_pb2.GetModelVersionRequest( @@ -1896,109 +2118,206 @@ def monitor_model_build(self): ) status_code = resp.model_version.status.code - logs = self.get_model_build_logs(current_page) - entries_count = 0 - for log_entry in logs.log_entries: - entries_count += 1 - if log_entry.url not in seen_logs: - seen_logs.add(log_entry.url) - log_entry_msg = re.sub( - r"(\\*)(\[[a-z#/@][^[]*?])", - lambda m: f"{m.group(1)}{m.group(1)}\\{m.group(2)}", - log_entry.message.strip(), - ) - logger.info(log_entry_msg) - # If we got a full page (50 entries), there might be more logs on the next page - # If we got fewer than 50 entries, we've reached the end and should stay on current page - if entries_count == 50: - current_page += 1 - # else: stay on current_page - if status_code == status_code_pb2.MODEL_BUILDING: - print( - f"Model is building... (elapsed {time.time() - st:.1f}s)", end='\r', flush=True - ) + if not quiet: + logs = self.get_model_build_logs(current_page) + entries_count = 0 + for log_entry in logs.log_entries: + entries_count += 1 + if log_entry.url not in seen_logs: + seen_logs.add(log_entry.url) + log_entry_msg = re.sub( + r"(\\*)(\[[a-z#/@][^[]*?])", + lambda m: f"{m.group(1)}{m.group(1)}\\{m.group(2)}", + log_entry.message.strip(), + ) + logger.info(log_entry_msg) + + # If we got a full page (50 entries), there might be more logs on the next page + if entries_count == 50: + current_page += 1 - # Fetch and display the logs + if status_code == status_code_pb2.MODEL_BUILDING: + elapsed = time.time() - st + if quiet: + out.inline_progress(f"Building image... ({elapsed:.0f}s)") + else: + print( + f"Model is building... (elapsed {elapsed:.1f}s)", + end='\r', + flush=True, + ) time.sleep(1) elif status_code == status_code_pb2.MODEL_TRAINED: - logger.info("Model build complete!") - logger.info(f"Build time elapsed {time.time() - st:.1f}s)") - logger.info( - f"Check out the model at {self.model_ui_url} version: {self.model_version_id}" - ) + elapsed = time.time() - st + if quiet: + out.clear_inline() + out.status(f"Building image... done ({elapsed:.1f}s)") + else: + logger.info("Model build complete!") + logger.info(f"Build time elapsed {elapsed:.1f}s)") + logger.info( + f"Check out the model at {self.model_ui_url} version: {self.model_version_id}" + ) return True else: - logger.info( - f"\nModel build failed with status: {resp.model_version.status} and response {resp}" - ) + if quiet: + out.clear_inline() + out.warning( + f"Model build failed with status: {resp.model_version.status.description}" + ) + # Always show build logs on failure so users can diagnose + page = 1 + has_logs = False + while True: + logs = self.get_model_build_logs(page) + if not logs.log_entries: + break + if not has_logs: + out.status("Build logs:") + has_logs = True + for log_entry in logs.log_entries: + msg = log_entry.message.strip() + if msg: + out.status(f" {msg}") + if len(logs.log_entries) < 50: + break + page += 1 + if not has_logs: + out.status( + "No build logs available yet. Check the platform UI for details." + ) + else: + logger.info( + f"\nModel build failed with status: {resp.model_version.status} and response {resp}" + ) return False def upload_model( folder, - stage, - skip_dockerfile, platform: Optional[str] = None, pat: Optional[str] = None, base_url: Optional[str] = None, + verbose: bool = False, ): """ Uploads a model to Clarifai. :param folder: The folder containing the model files. - :param stage: The stage we are calling download checkpoints from. Typically this would "upload" and will download checkpoints if config.yaml checkpoints section has when set to "upload". Other options include "runtime" to be used in load_model or "upload" to be used during model upload. Set this stage to whatever you have in config.yaml to force downloading now. - :param skip_dockerfile: If True, will skip Dockerfile generation entirely. If False or not provided, intelligently handle existing Dockerfiles with user confirmation. :param platform: Target platform(s) for Docker image build (e.g., "linux/amd64" or "linux/amd64,linux/arm64"). This overrides the platform specified in config.yaml. :param pat: Personal access token for authentication. If None, will use environment variables. :param base_url: Base URL for the API. If None, will use environment variables. + :param verbose: If True, show detailed SDK logs and build output. """ - builder = ModelBuilder( - folder, - app_not_found_action="prompt", - platform=platform, - pat=pat, - base_url=base_url, - compute_info_required=True, + import click + + from clarifai.runners.models import deploy_output as out + from clarifai.runners.models.model_deploy import _quiet_sdk_logger + + suppress = not verbose + + # ── Validate ── + out.phase_header("Validate") + with _quiet_sdk_logger(suppress): + builder = ModelBuilder( + folder, + app_not_found_action="prompt", + platform=platform, + pat=pat, + base_url=base_url, + compute_info_required=True, + ) + builder.download_checkpoints(stage="upload") + # Use existing Dockerfile if present, otherwise auto-generate + if not os.path.exists(os.path.join(folder, 'Dockerfile')): + builder.create_dockerfile(generate_dockerfile=True) + + # Validation summary + model_config = builder.config.get('model', {}) + out.info( + "Model", + f"{model_config.get('user_id', '')}/{model_config.get('app_id', '')}/models/{builder.model_id}", ) - builder.download_checkpoints(stage=stage) + out.info("Type", model_config.get('model_type_id', 'unknown')) - if not skip_dockerfile: - builder.create_dockerfile() + compute = builder.config.get('compute', {}) + instance_label = compute.get('instance') or compute.get('gpu') or 'cpu' + out.info("Instance", instance_label) - exists = builder.check_model_exists() - if exists: - logger.info( - f"Model already exists at {builder.model_ui_url}, this upload will create a new version for it." - ) - else: - logger.info( - f"New model will be created at {builder.model_ui_url} with it's first version." - ) + checkpoints = builder.config.get('checkpoints', {}) + if checkpoints and checkpoints.get('repo_id'): + out.info("Checkpoints", checkpoints['repo_id']) + + dockerfile_exists = os.path.exists(os.path.join(folder, 'Dockerfile')) + out.info("Dockerfile", "existing" if dockerfile_exists else "auto-generated") - # Check for git repository information git_info = builder._get_git_info() if git_info: - logger.info(f"Detected git repository: {git_info.get('url', 'local repository')}") - logger.info(f"Current commit: {git_info['commit']}") - logger.info(f"Current branch: {git_info['branch']}") + branch = git_info.get('branch', '') + commit = git_info.get('commit', '')[:8] + out.info("Git", f"{branch} @ {commit}") + builder._check_git_status() + + # ── Upload ── + out.phase_header("Upload") + + def _on_upload_complete(version_id, url): + out.info("Version", version_id) + out.info("URL", url) + + with _quiet_sdk_logger(suppress): + model_version_id = builder.upload_model_version( + git_info, + show_client_script=False, + quiet_build=not verbose, + post_upload_callback=_on_upload_complete, + ) - # Check for uncommitted changes and prompt user - if not builder._check_git_status_and_prompt(): - logger.info("Upload cancelled by user due to uncommitted changes.") - return - input("Press Enter to continue...") + if not model_version_id: + out.warning("Upload failed. Check logs above for details.") + return - model_version = builder.upload_model_version(git_info) + # ── Ready ── + out.phase_header("Ready") + click.echo() + out.success("Model uploaded successfully!") + click.echo() + out.link("Model", builder.model_ui_url) + out.info("Version", model_version_id) - # Ask user if they want to deploy the model - if model_version is not None: # if it comes back None then it failed. - if get_yes_no_input("\n🔶 Do you want to deploy the model?", True): - # Setup deployment for the uploaded model - setup_deployment_for_model(builder) - else: - logger.info("Model uploaded successfully. Skipping deployment setup.") - return + # Client script + method_signatures = None + try: + method_signatures = builder.get_method_signatures() + snippet = code_script.generate_client_script( + method_signatures, + user_id=builder.client.user_app_id.user_id, + app_id=builder.client.user_app_id.app_id, + model_id=builder.model_proto.id, + colorize=True, + ) + click.echo("\n" + "=" * 60) + click.echo("# Here is a code snippet to use this model:") + click.echo("=" * 60) + click.echo(snippet) + click.echo("=" * 60) + except Exception: + pass + + user_id = builder.client.user_app_id.user_id + app_id = builder.client.user_app_id.app_id + model_id = builder.model_proto.id + model_ref = f"{user_id}/{app_id}/models/{model_id}" + predict_cmd = code_script.generate_predict_hint(method_signatures or [], model_ref) + + out.phase_header("Next Steps") + out.hint( + "Deploy", + f'clarifai model deploy --model-url "{builder.model_ui_url}" --instance ', + ) + out.hint("GPU info", "clarifai list-instances") + out.hint("Predict", predict_cmd) def deploy_model( @@ -2015,6 +2334,7 @@ def deploy_model( max_replicas=5, pat=None, base_url=None, + quiet=False, ): """ Deploy a model on Clarifai platform. @@ -2095,12 +2415,14 @@ def deploy_model( deployment_id=deployment_id, deployment_config=deployment_config ) - print( - f"✅ Deployment '{deployment_id}' successfully created for model '{model_id}' with version '{model_version_id}'." - ) + if not quiet: + print( + f"✅ Deployment '{deployment_id}' successfully created for model '{model_id}' with version '{model_version_id}'." + ) return True except Exception as e: - print(f"❌ Failed to create deployment '{deployment_id}': {e}") + if not quiet: + print(f"❌ Failed to create deployment '{deployment_id}': {e}") return False diff --git a/clarifai/runners/models/model_deploy.py b/clarifai/runners/models/model_deploy.py new file mode 100644 index 00000000..051bbd56 --- /dev/null +++ b/clarifai/runners/models/model_deploy.py @@ -0,0 +1,1449 @@ +"""Non-interactive model deployment orchestrator. + +Orchestrates upload + build + deploy in a single command with zero prompts. +All decisions come from parameters or config. Errors are raised as UserError +with actionable fix instructions. This class NEVER calls input(), click.confirm(), +click.prompt(), or any interactive function. +""" + +import contextlib +import json +import logging +import os +import re +import time +import uuid + +from clarifai.errors import UserError +from clarifai.urls.helper import ClarifaiUrlHelper +from clarifai.utils.compute_presets import ( + get_compute_cluster_config, + get_deploy_compute_cluster_id, + get_deploy_nodepool_id, + get_nodepool_config, + resolve_gpu, +) +from clarifai.utils.logging import logger + +# Default deployment monitoring settings +DEFAULT_MONITOR_TIMEOUT = 1200 # 20 minutes +DEFAULT_POLL_INTERVAL = 5 # seconds +DEFAULT_LOG_TAIL_DURATION = 15 # seconds to check for runner logs after pods are ready + +# K8s events to skip in default (non-verbose) mode — transient scheduler noise +_SKIP_EVENTS = {"TaintManagerEviction", "SandboxChanged", "FailedCreatePodSandBox"} + +# Map k8s Reason to human-friendly status for non-verbose mode +_EVENT_PHASE_MAP = { + "FailedScheduling": "Scheduling", + "NotTriggerScaleUp": "Scaling", + "NominatedNode": "Nominated", + "Nominated": "Nominated", + "Scheduled": "Scheduled", + "Pulling": "Pulling image", + "Pulled": "Image pulled", + "Created": "Starting", + "Started": "Running", + "BackOff": "Restarting", + "Unhealthy": "Health check", + "Killing": "Stopping", + "Preempted": "Preempted", + "FailedMount": "Volume", + "FailedAttachVolume": "Volume", + "SuccessfulAttachVolume": "Volume", + "ScalingReplicaSet": "Scaling", +} + + +@contextlib.contextmanager +def _quiet_sdk_logger(suppress=True): + """Temporarily suppress SDK logger INFO output for clean deploy output. + + When suppress=True, raises the logger level to WARNING so that only + WARNING/ERROR messages are visible. INFO-level noise (thread IDs, + microsecond timestamps, protobuf status dumps) is hidden. + + Args: + suppress: If True, suppress INFO. If False, no-op (verbose mode). + """ + if not suppress: + yield + return + old_level = logger.level + logger.setLevel(logging.WARNING) + try: + yield + finally: + logger.setLevel(old_level) + + +class ModelDeployer: + """Non-interactive model deployment orchestrator. + + Two modes: + 1. Local model (upload + deploy): model_path provided + 2. Existing model (deploy only): model_url provided + """ + + def __init__( + self, + model_path=None, + model_url=None, + user_id=None, + app_id=None, + model_version_id=None, + instance_type=None, + cloud_provider=None, + region=None, + compute_cluster_id=None, + nodepool_id=None, + min_replicas=1, + max_replicas=5, + pat=None, + base_url=None, + stage="runtime", + verbose=False, + ): + self.model_path = model_path + self.model_url = model_url + self.model_id = None + self.user_id = user_id + self.app_id = app_id + self.model_version_id = model_version_id + self.instance_type = instance_type + self.cloud_provider = cloud_provider + self.region = region + self.compute_cluster_id = compute_cluster_id + self.nodepool_id = nodepool_id + self.min_replicas = min_replicas + self.max_replicas = max_replicas + self.pat = pat + self.base_url = base_url + self.stage = stage + self.verbose = verbose + + # Resolved during deploy + self._builder = None + self._gpu_preset = None + + def deploy(self): + """Run the full deployment pipeline. Returns a result dict.""" + self._validate() + + if self.model_path: + return self._deploy_local_model() + else: + return self._deploy_existing_model() + + def _validate(self): + """Validate inputs, fail with clear error messages.""" + if not self.model_path and not self.model_url: + raise UserError( + "You must specify either MODEL_PATH (directory) or --model-url.\n" + " Local model: clarifai model deploy ./my-model --instance gpu-nvidia-a10g\n" + " Existing model: clarifai model deploy --model-url --instance gpu-nvidia-a10g" + ) + if self.model_path and self.model_url: + raise UserError("Specify only one of: MODEL_PATH or --model-url.") + + if self.model_url: + user_id, app_id, _, model_id, _ = ClarifaiUrlHelper.split_clarifai_url(self.model_url) + self.user_id = self.user_id or user_id + self.app_id = self.app_id or app_id + self.model_id = model_id + + if not self.instance_type and not self.nodepool_id: + raise UserError( + "You must specify --instance or --nodepool-id when deploying an existing model.\n" + " Example: clarifai model deploy --model-url --instance a10g\n" + " Run 'clarifai list-instances' to see available options." + ) + + # Validate instance type early (before upload/deployment work) + if self.instance_type: + try: + self._resolve_gpu() + except ValueError as e: + raise UserError(str(e)) + + def _resolve_gpu(self): + """Resolve GPU name to preset info if gpu is specified.""" + if self.instance_type and not self._gpu_preset: + self._gpu_preset = resolve_gpu( + self.instance_type, + pat=self.pat, + base_url=self.base_url, + cloud_provider=self.cloud_provider, + region=self.region, + ) + return self._gpu_preset + + def _write_instance_to_config(self, instance_type_id): + """Persist auto-selected instance to config.yaml.""" + config_path = os.path.join(self.model_path, 'config.yaml') + if not os.path.exists(config_path): + return + from clarifai.utils.cli import dump_yaml, from_yaml + + config = from_yaml(config_path) + config.setdefault('compute', {})['instance'] = instance_type_id + dump_yaml(config, config_path) + + def _deploy_local_model(self): + """Upload model from local path, then deploy.""" + from clarifai.runners.models import deploy_output as out + from clarifai.runners.models.model_builder import ModelBuilder + + suppress = not self.verbose + + # ── Validate ── + out.phase_header("Validate") + with _quiet_sdk_logger(suppress): + self._builder = ModelBuilder( + self.model_path, + app_not_found_action="auto_create", + pat=self.pat, + base_url=self.base_url, + user_id=self.user_id, + app_id=self.app_id, + ) + + # Resolve IDs from the builder's config + model_config = self._builder.config.get('model', {}) + self.user_id = self.user_id or model_config.get('user_id') + self.app_id = self.app_id or model_config.get('app_id') + self.model_id = self._builder.model_id + + # Read compute section from config (instance, cloud, region) + compute = self._builder.config.get('compute', {}) + + # Cloud and region from config (CLI flags take priority) + if not self.cloud_provider: + self.cloud_provider = compute.get('cloud') + if not self.region: + self.region = compute.get('region') + + # If instance not specified, try to read from config + if not self.instance_type and not self.nodepool_id: + compute_instance = compute.get('instance') or compute.get('gpu') + if compute_instance: + self.instance_type = compute_instance + else: + # Fallback: try to infer from inference_compute_info, then auto-recommend + from clarifai.utils.compute_presets import ( + infer_gpu_from_config, + recommend_instance, + ) + + inferred = infer_gpu_from_config(self._builder.config) + if inferred: + self.instance_type = inferred + else: + recommended, reason = recommend_instance( + self._builder.config, + pat=self.pat, + base_url=self.base_url, + model_path=self.model_path, + ) + if recommended: + self.instance_type = recommended + out.info("Auto-selected", f"{recommended} ({reason})") + # Persist to config.yaml so future deploys reuse this choice + self._write_instance_to_config(recommended) + else: + raise UserError( + f"Could not auto-detect instance type. {reason or ''}\n" + " Specify --instance or set 'compute.instance' in config.yaml.\n" + " Run 'clarifai list-instances' to see available options." + ) + + # Show clean validation summary + model_type_id = model_config.get('model_type_id', 'unknown') + instance_label = self.instance_type or 'cpu' + checkpoints = self._builder.config.get('checkpoints', {}) + has_checkpoints = bool(checkpoints and checkpoints.get('repo_id')) + dockerfile_path = os.path.join(self.model_path, 'Dockerfile') + has_dockerfile = os.path.exists(dockerfile_path) + + out.info("Model", f"{self.user_id}/{self.app_id}/models/{self.model_id}") + out.info("Type", model_type_id) + out.info("Instance", instance_label) + if has_checkpoints: + out.info("Checkpoints", checkpoints.get('repo_id', '')) + out.info("Dockerfile", "existing" if has_dockerfile else "auto-generated") + + # Only download checkpoints locally when config says when: upload + # (they must be bundled in the tarball). when: runtime or when: build + # means they'll be fetched inside the container in the cloud. + checkpoint_when = checkpoints.get('when', 'runtime') if has_checkpoints else None + with _quiet_sdk_logger(suppress): + if checkpoint_when and checkpoint_when != 'runtime': + self._builder.download_checkpoints(stage=self.stage) + if has_dockerfile: + pass # Use existing + else: + self._builder.create_dockerfile(generate_dockerfile=True) + + # Resolve inference_compute_info from --instance flag. + # Always override when --instance is provided, even if normalize_config + # already set it from config.yaml — the CLI flag takes priority. + if self.instance_type: + from clarifai.utils.compute_presets import get_inference_compute_for_gpu + + ici = get_inference_compute_for_gpu( + self.instance_type, pat=self.pat, base_url=self.base_url + ) + if ici.get('num_accelerators', 0) > 0: + ici.setdefault('accelerator_type', ['NVIDIA-*']) + self._builder.config['inference_compute_info'] = ici + self._builder.inference_compute_info = self._builder._get_inference_compute_info() + + # ── Upload ── + out.phase_header("Upload") + git_info = self._builder._get_git_info() + + # Callback emits Version/URL after upload completes but before build starts, + # so these info lines appear under Upload, not under the Build phase header. + def _on_upload_complete(version_id, url): + out.info("Version", version_id) + out.info("URL", url) + + with _quiet_sdk_logger(suppress): + model_version_id = self._builder.upload_model_version( + git_info, + show_client_script=False, + quiet_build=not self.verbose, + post_upload_callback=_on_upload_complete, + ) + + if not model_version_id: + raise UserError("Model upload failed. Check logs above for details.") + + self.model_version_id = model_version_id + + # Capture client script and method signatures for display after deployment + try: + from clarifai.runners.utils import code_script + + method_signatures = self._builder.get_method_signatures() + self._method_signatures = method_signatures + self._client_script = code_script.generate_client_script( + method_signatures, + user_id=self.user_id, + app_id=self.app_id, + model_id=self.model_id, + colorize=True, + ) + except Exception: + self._method_signatures = None + self._client_script = None + + # ── Deploy ── + out.phase_header("Deploy") + return self._create_deployment() + + def _deploy_existing_model(self): + """Deploy an already-uploaded model.""" + from clarifai.client import Model + from clarifai.runners.models import deploy_output as out + + suppress = not self.verbose + + out.phase_header("Deploy") + + with _quiet_sdk_logger(suppress): + model = Model( + model_id=self.model_id, + app_id=self.app_id, + user_id=self.user_id, + pat=self.pat, + base_url=self.base_url, + ) + + # Get latest version if not specified + if not self.model_version_id: + versions = list(model.list_versions()) + if not versions: + raise UserError(f"No versions found for model '{self.model_id}'.") + self.model_version_id = versions[0].model_version.id + + # Auto-update compute info if the target instance exceeds model version's spec + if self.instance_type: + self._auto_update_compute_if_needed(model) + + # Fetch method signatures from the model version for client script & predict hint + self._fetch_method_signatures(model) + + out.info("Model", f"{self.user_id}/{self.app_id}/models/{self.model_id}") + out.info("Version", self.model_version_id) + out.info("Instance", self.instance_type or "cpu") + + return self._create_deployment() + + def _get_model_version_compute_info(self, model): + """Fetch the model version's current inference_compute_info from the API. + + Returns: + ComputeInfo proto, or None if not set. + """ + from clarifai_grpc.grpc.api import service_pb2 + + try: + resp = model._grpc_request( + model.STUB.GetModelVersion, + service_pb2.GetModelVersionRequest( + user_app_id=model.user_app_id, + model_id=model.id, + version_id=self.model_version_id, + ), + ) + ci = resp.model_version.inference_compute_info + # Check if compute info is actually populated (not just an empty proto) + if ci.ByteSize() > 0: + return ci + return None + except Exception as e: + logger.debug(f"Failed to fetch model version compute info: {e}") + return None + + def _fetch_method_signatures(self, model): + """Fetch method signatures from the model version for client script & predict hint. + + Populates self._method_signatures and self._client_script from the API + so that --model-url deployments show the same output as local deploys. + """ + from clarifai_grpc.grpc.api import service_pb2 + + try: + resp = model._grpc_request( + model.STUB.GetModelVersion, + service_pb2.GetModelVersionRequest( + user_app_id=model.user_app_id, + model_id=model.id, + version_id=self.model_version_id, + ), + ) + sigs = list(resp.model_version.method_signatures) + if sigs: + self._method_signatures = sigs + from clarifai.runners.utils import code_script + + self._client_script = code_script.generate_client_script( + sigs, + user_id=self.user_id, + app_id=self.app_id, + model_id=self.model_id, + colorize=True, + ) + else: + self._method_signatures = None + self._client_script = None + except Exception as e: + logger.debug(f"Failed to fetch method signatures: {e}") + self._method_signatures = None + self._client_script = None + + @staticmethod + def _needs_compute_update(model_compute_info, instance_compute_info): + """Check if the model version's compute info needs to be updated for the target instance. + + The model version's inference_compute_info acts as a ceiling — the scheduler only + places it on instances at or below those specs. If the target instance exceeds + the model version's spec, we need to update the model version. + + Args: + model_compute_info: ComputeInfo proto from the model version (or None). + instance_compute_info: dict with instance's compute info from the preset. + + Returns: + tuple: (needs_update: bool, reasons: list[str]) + """ + from clarifai.utils.compute_presets import parse_k8s_quantity + + reasons = [] + + # No compute info on model version → needs update + if model_compute_info is None: + return True, ["model version has no inference_compute_info"] + + instance_num_acc = instance_compute_info.get("num_accelerators", 0) + model_num_acc = model_compute_info.num_accelerators + + # num_accelerators: if instance has more GPUs than model specifies + if instance_num_acc > model_num_acc: + reasons.append( + f"num_accelerators: instance has {instance_num_acc}, " + f"model version specifies {model_num_acc}" + ) + + # accelerator_memory: if instance has more GPU memory than model specifies + instance_acc_mem = parse_k8s_quantity(instance_compute_info.get("accelerator_memory", "")) + model_acc_mem = parse_k8s_quantity(model_compute_info.accelerator_memory) + + if instance_acc_mem > 0 and model_acc_mem > 0 and instance_acc_mem > model_acc_mem: + reasons.append( + f"accelerator_memory: instance has {instance_compute_info.get('accelerator_memory')}, " + f"model version specifies {model_compute_info.accelerator_memory}" + ) + + return len(reasons) > 0, reasons + + def _auto_update_compute_if_needed(self, model): + """Auto-update model version compute info if the target instance exceeds its spec. + + Fetches the model version's current inference_compute_info, compares against + the target instance, and patches if the instance exceeds the model's spec. + + Only patches num_accelerators and accelerator_memory — NOT accelerator_type, + since the API rejects changing accelerator_type after upload. + """ + from clarifai_grpc.grpc.api import resources_pb2 + + gpu_preset = self._resolve_gpu() + if not gpu_preset: + return + + instance_compute = gpu_preset.get("inference_compute_info", {}) + + # Fetch model version's current compute info + model_compute = self._get_model_version_compute_info(model) + + needs_update, reasons = self._needs_compute_update(model_compute, instance_compute) + + if not needs_update: + logger.debug("Model version compute info is compatible with target instance.") + return + + # Build a ComputeInfo that preserves the existing accelerator_type + # (the API rejects changes to accelerator_type) while updating + # num_accelerators and accelerator_memory. + existing_acc_type = list(model_compute.accelerator_type) if model_compute else ["NVIDIA-*"] + patch_compute = resources_pb2.ComputeInfo( + num_accelerators=instance_compute.get("num_accelerators", 0), + accelerator_memory=instance_compute.get("accelerator_memory", ""), + accelerator_type=existing_acc_type, + ) + + reason_str = "; ".join(reasons) + if self.verbose: + logger.info( + f"Updating model version compute info to match instance " + f"'{self.instance_type}' ({reason_str})" + ) + model.patch_version( + version_id=self.model_version_id, + inference_compute_info=patch_compute, + ) + + def _get_cloud_and_region(self): + """Determine cloud_provider and region for infrastructure creation. + + Priority: + 1. Explicit --cloud/--region flags + 2. Cloud/region from the resolved GPU preset + 3. Default: aws / us-east-1 + """ + gpu_preset = self._resolve_gpu() + + cloud = self.cloud_provider + region = self.region + + if gpu_preset: + cloud = cloud or gpu_preset.get("cloud_provider") + region = region or gpu_preset.get("region") + + cloud = cloud or "aws" + region = region or "us-east-1" + + return cloud, region + + def _ensure_compute_infrastructure(self): + """Auto-create compute cluster and nodepool if needed. + + Returns: + tuple: (compute_cluster_id, nodepool_id, cluster_user_id) + """ + if self.nodepool_id and self.compute_cluster_id: + return self.compute_cluster_id, self.nodepool_id, self.user_id + + from clarifai.client.user import User + from clarifai.runners.models import deploy_output as out + + suppress = not self.verbose + + with _quiet_sdk_logger(suppress): + user = User(user_id=self.user_id, pat=self.pat, base_url=self.base_url) + gpu_preset = self._resolve_gpu() + cloud, region = self._get_cloud_and_region() + + # Determine compute cluster ID (cloud/region-aware) + cc_id = self.compute_cluster_id or get_deploy_compute_cluster_id(cloud, region) + + # Try to get existing compute cluster, create if not found + with _quiet_sdk_logger(suppress): + try: + user.compute_cluster(cc_id) + except Exception: + out.status(f"Creating compute cluster '{cc_id}'...") + cc_config = get_compute_cluster_config(self.user_id, cloud, region) + user.create_compute_cluster(compute_cluster_config=cc_config) + + # Determine nodepool ID + if self.nodepool_id: + np_id = self.nodepool_id + else: + instance_type_id = gpu_preset["instance_type_id"] if gpu_preset else "cpu-t3a-2xlarge" + np_id = get_deploy_nodepool_id(instance_type_id) + + # Try to get existing nodepool, create if not found + from clarifai.client.compute_cluster import ComputeCluster + + with _quiet_sdk_logger(suppress): + cc = ComputeCluster( + compute_cluster_id=cc_id, + user_id=self.user_id, + pat=self.pat, + base_url=self.base_url, + ) + try: + cc.nodepool(np_id) + except Exception: + out.status(f"Creating nodepool '{np_id}'...") + np_config = get_nodepool_config( + instance_type_id=instance_type_id, + compute_cluster_id=cc_id, + user_id=self.user_id, + compute_info=gpu_preset.get("inference_compute_info") + if gpu_preset + else None, + ) + cc.create_nodepool(nodepool_config=np_config) + + return cc_id, np_id, self.user_id + + def _create_deployment(self): + """Create the deployment using existing deploy_model function.""" + from clarifai.runners.models import deploy_output as out + from clarifai.runners.models.model_builder import deploy_model + + cc_id, np_id, cluster_user_id = self._ensure_compute_infrastructure() + + deployment_id = f"deploy-{self.model_id}-{uuid.uuid4().hex[:6]}" + suppress = not self.verbose + + out.status(f"Deploying to nodepool '{np_id}'...") + with _quiet_sdk_logger(suppress): + success = deploy_model( + model_id=self.model_id, + app_id=self.app_id, + user_id=self.user_id, + deployment_id=deployment_id, + model_version_id=self.model_version_id, + nodepool_id=np_id, + compute_cluster_id=cc_id, + cluster_user_id=cluster_user_id, + min_replicas=self.min_replicas, + max_replicas=self.max_replicas, + pat=self.pat, + base_url=self.base_url, + quiet=suppress, + ) + + if not success: + raise UserError( + f"Deployment failed for model '{self.model_id}'. Check logs above for details." + ) + + out.success(f"Deployment '{deployment_id}' created") + + # Skip monitoring when min_replicas is 0 — no pods will be scheduled + timed_out = False + if self.min_replicas > 0: + timed_out = self._monitor_deployment(deployment_id, np_id, cc_id) + + result = self._format_result(deployment_id, np_id, cc_id) + result['timed_out'] = timed_out + return result + + def _monitor_deployment(self, deployment_id, nodepool_id, compute_cluster_id): + """Monitor deployment status until runner pods are ready or timeout. + + Polls runner status and fetches runner logs to show the user what's happening + after the deployment is created (pod scheduling, image pulling, model loading). + + Returns: + True if monitoring timed out, False if pods became ready. + """ + from clarifai_grpc.grpc.api import service_pb2 + + from clarifai.client.auth import create_stub + from clarifai.client.auth.helper import ClarifaiAuthHelper + from clarifai.runners.models import deploy_output as out + + out.phase_header("Monitor") + + # Create a lightweight client for gRPC calls + auth = ClarifaiAuthHelper.from_env( + user_id=self.user_id, pat=self.pat, base=self.base_url, validate=False + ) + stub = create_stub(auth) + user_app_id = auth.get_user_app_id_proto() + + timeout = DEFAULT_MONITOR_TIMEOUT + poll_interval = DEFAULT_POLL_INTERVAL + start_time = time.time() + seen_logs = set() + seen_messages = set() # Dedup simplified event messages across polls + log_page = 1 + has_inline_progress = False # Track if we printed \r progress + + while time.time() - start_time < timeout: + elapsed = int(time.time() - start_time) + + # List runners for our model version in this nodepool + try: + resp = stub.ListRunners( + service_pb2.ListRunnersRequest( + user_app_id=user_app_id, + compute_cluster_id=compute_cluster_id, + nodepool_id=nodepool_id, + model_version_ids=[self.model_version_id], + ) + ) + runners = list(resp.runners) + except Exception as e: + logger.debug(f"Error listing runners: {e}") + runners = [] + + if runners: + runner = runners[0] + metrics = runner.runner_metrics + pods_running = metrics.pods_running if metrics else 0 + pods_total = metrics.pods_total if metrics else 0 + + # Collect new log lines (without printing yet) + log_page, log_lines = self._fetch_runner_logs( + stub, + user_app_id, + compute_cluster_id, + nodepool_id, + runner.id, + seen_logs, + log_page, + verbose=self.verbose, + seen_messages=seen_messages, + ) + + # Print log lines if any (clear inline progress first) + if log_lines: + if has_inline_progress: + out.clear_inline() + has_inline_progress = False + for line in log_lines: + print(line, flush=True) + + # Check if ready + if pods_running >= max(self.min_replicas, 1): + # Brief delay to let late-arriving k8s events propagate to the API, + # then fetch one final time so fast deploys still show events. + time.sleep(3) + _, final_lines = self._fetch_runner_logs( + stub, + user_app_id, + compute_cluster_id, + nodepool_id, + runner.id, + seen_logs, + log_page, + verbose=self.verbose, + seen_messages=seen_messages, + ) + if final_lines: + if has_inline_progress: + out.clear_inline() + has_inline_progress = False + for line in final_lines: + print(line, flush=True) + + if has_inline_progress: + out.clear_inline() + has_inline_progress = False + out.success( + f"Model is running! Pods: {pods_running}/{pods_total} ({elapsed}s)" + ) + # Tail model logs briefly to show startup output + self._tail_runner_logs( + stub, + user_app_id, + compute_cluster_id, + nodepool_id, + runner.id, + ) + return False # Not timed out + + status_msg = f"Pods: {pods_running}/{pods_total} running ({elapsed}s elapsed)" + else: + status_msg = f"Waiting for runner to be scheduled... ({elapsed}s elapsed)" + + # Inline progress update (overwrite same line) + out.inline_progress(status_msg) + has_inline_progress = True + + time.sleep(poll_interval) + + # Timeout reached — provide actionable context + if has_inline_progress: + out.clear_inline() + + elapsed_min = timeout // 60 + out.warning(f"Pod not ready after {elapsed_min} minutes of monitoring.") + out.status("") + + # Determine the last known stage from seen event messages + last_stage = _infer_last_stage(seen_messages) + if last_stage: + out.status(f" Last observed stage: {last_stage}") + out.status("") + + out.status(" The deployment was created successfully but the model pod") + out.status(" hasn't started yet. Common causes:") + out.status(" - GPU nodes are scaling up (can take 5-15 min)") + out.status(" - Large model image is being pulled") + out.status(" - Model is loading checkpoints into GPU memory") + out.status("") + out.status(" Check progress with:") + out.hint("Events", f'clarifai model logs --deployment "{deployment_id}" --log-type events') + out.hint("Status", f'clarifai model status --deployment "{deployment_id}"') + + return True # Timed out + + @staticmethod + def _fetch_runner_logs( + stub, + user_app_id, + compute_cluster_id, + nodepool_id, + runner_id, + seen_logs, + current_page, + verbose=False, + seen_messages=None, + ): + """Fetch k8s event logs during monitoring. + + Only fetches "runner.events" (k8s events like scheduling, pulling, starting). + Model stdout/stderr ("runner" logs) are reserved for the Startup Logs phase + to avoid consuming them here and leaving that section empty. + + Args: + seen_messages: Optional set for deduplicating simplified output lines across + poll cycles. When not verbose, repeated messages like "Scheduling: Waiting + for node..." are suppressed after the first occurrence. + """ + from clarifai_grpc.grpc.api import service_pb2 + + lines = [] + + try: + resp = stub.ListLogEntries( + service_pb2.ListLogEntriesRequest( + log_type="runner.events", + user_app_id=user_app_id, + compute_cluster_id=compute_cluster_id, + nodepool_id=nodepool_id, + runner_id=runner_id, + page=current_page, + per_page=50, + ) + ) + for entry in resp.log_entries: + log_key = ("runner.events", entry.url or entry.message[:100]) + if log_key not in seen_logs: + seen_logs.add(log_key) + event_lines = _format_event_logs(entry.message.strip(), verbose=verbose) + for line in event_lines: + # Deduplicate simplified messages across polls + if seen_messages is not None and not verbose: + if line in seen_messages: + continue + seen_messages.add(line) + lines.append(line) + except Exception: + pass # Log fetching is best-effort + + return current_page, lines + + def _format_result(self, deployment_id, nodepool_id, compute_cluster_id): + """Format deployment result.""" + ui_base = "https://clarifai.com" + model_url = f"{ui_base}/{self.user_id}/{self.app_id}/models/{self.model_id}" + + instance_desc = "" + if self._gpu_preset: + desc = self._gpu_preset['description'] + # Avoid redundant display like "g5.2xlarge (g5.2xlarge)" + if desc and desc.lower() != self.instance_type.lower(): + instance_desc = f"{self.instance_type} ({desc})" + else: + instance_desc = self.instance_type + elif self.instance_type: + instance_desc = self.instance_type + + cloud, region = self._get_cloud_and_region() + + return { + "model_url": model_url, + "model_id": self.model_id, + "model_version_id": self.model_version_id, + "deployment_id": deployment_id, + "nodepool_id": nodepool_id, + "compute_cluster_id": compute_cluster_id, + "instance_type": instance_desc, + "cloud_provider": cloud, + "region": region, + "user_id": self.user_id, + "app_id": self.app_id, + "client_script": getattr(self, '_client_script', None), + "method_signatures": getattr(self, '_method_signatures', None), + } + + def _tail_runner_logs(self, stub, user_app_id, compute_cluster_id, nodepool_id, runner_id): + """Briefly tail runner logs after pods are ready to show model startup output. + + Fetches log_type="runner" (model pod stdout/stderr) for a short period + so the user can see model loading progress, then exits with a hint. + + Uses the same print-raw approach as stream_model_logs() for reliability, + with optional JSON parsing for cleaner output. + """ + from clarifai_grpc.grpc.api import service_pb2 + + from clarifai.runners.models import deploy_output as out + + model_url = f"https://clarifai.com/{self.user_id}/{self.app_id}/models/{self.model_id}" + + seen_logs = set() + log_page = 1 + has_logs = False + tail_start = time.time() + total_api_entries = 0 # Track raw API entry count for diagnostics + + while time.time() - tail_start < DEFAULT_LOG_TAIL_DURATION: + new_entries = 0 + try: + resp = stub.ListLogEntries( + service_pb2.ListLogEntriesRequest( + log_type="runner", + user_app_id=user_app_id, + compute_cluster_id=compute_cluster_id or "", + nodepool_id=nodepool_id or "", + runner_id=runner_id, + page=log_page, + per_page=50, + ) + ) + entries_count = 0 + for entry in resp.log_entries: + entries_count += 1 + total_api_entries += 1 + log_key = entry.url or entry.message[:100] + if log_key in seen_logs: + continue + seen_logs.add(log_key) + msg = entry.message.strip() + if not msg: + continue + + # Try to extract clean message from JSON logs. + parsed = _parse_runner_log(msg, verbose=self.verbose) + display = parsed + if not display and self.verbose: + display = msg[:200] + + if display: + if not has_logs: + out.phase_header("Startup Logs") + has_logs = True + out.status(display) + new_entries += 1 + if entries_count == 50: + log_page += 1 + except Exception as e: + # Make errors visible — logger.debug is not shown at default level + out.event(f"Log fetch error: {e}") + + # If we displayed logs and then an empty poll, we're done + if has_logs and new_entries == 0: + break + + time.sleep(3) + + if not has_logs: + out.phase_header("Startup Logs") + if total_api_entries > 0: + out.status( + f"{total_api_entries} log entries found but all filtered " + f"(use --verbose to see). Logs may appear shortly." + ) + else: + out.status("No startup logs available yet.") + + out.status("") + out.status("Stream model logs:") + out.status(f' clarifai model logs --model-url "{model_url}"') + + +def _parse_runner_log(raw_msg, verbose=False): + """Parse a runner log line, extracting the message from JSON if possible. + + Raw input example: + '{"msg": "Starting MCP bridge...", "@timestamp": "...", "stack_info": null, ...}' + Output: "Starting MCP bridge..." + Args: + raw_msg: Raw log message string. + verbose: If True, pass through all messages unfiltered. + + Returns: + Cleaned message string, or None if the message should be suppressed. + """ + if not raw_msg: + return None + # Try to parse as JSON (runner logs are often JSON-formatted) + try: + data = json.loads(raw_msg) + if isinstance(data, dict) and "msg" in data: + msg = data["msg"] + if msg and isinstance(msg, str): + # Decode unicode escapes if present (e.g. \ud83d\ude80 → emoji) + try: + msg = msg.encode('utf-16', 'surrogatepass').decode('utf-16') + except (UnicodeDecodeError, UnicodeEncodeError): + pass + return msg + return None + except (json.JSONDecodeError, TypeError): + pass + + # In non-verbose mode, filter noisy lines + if not verbose: + if "DeprecationWarning:" in raw_msg: + return None + if raw_msg.startswith("Downloading ") or raw_msg.startswith(" Downloading "): + return None + if raw_msg.startswith("Installing collected packages:"): + return None + + # Return raw message as-is + return raw_msg + + +def _infer_last_stage(seen_messages): + """Infer the last deployment stage from observed k8s event messages. + + Uses the _EVENT_PHASE_MAP ordering to determine the furthest stage reached. + Returns a human-readable description, or None if no events were observed. + """ + if not seen_messages: + return None + + # Ordered from latest to earliest stage — return the first match + stage_keywords = [ + ("Running", "Container started — model may be loading"), + ("Starting", "Container created"), + ("Health check", "Container started but health check failing"), + ("Restarting", "Container is crash-looping"), + ("Image pulled", "Image downloaded, starting container"), + ("Pulling image", "Downloading model image (can be large)"), + ("Scheduled", "Pod assigned to a node"), + ("Nominated", "Node selected, waiting for scheduling"), + ("Scaling", "Cluster is scaling up to add GPU nodes"), + ("Scheduling", "Waiting for a node with available GPU"), + ("Volume", "Waiting for volume attachment"), + ] + for keyword, description in stage_keywords: + for msg in seen_messages: + if keyword in msg: + return description + + return None + + +def _simplify_k8s_message(reason, message): + """Simplify k8s event messages for non-verbose mode. + + Strips internal node IPs, taint specifications, and pod full names. + + Args: + reason: K8s event reason (e.g. "FailedScheduling", "Pulling"). + message: Raw k8s event message. + + Returns: + Simplified, human-friendly message string. + """ + _SIMPLE = { + "FailedScheduling": "Waiting for node to become available...", + "NotTriggerScaleUp": "Waiting for cluster to scale up...", + "NominatedNode": "Node selected for scheduling", + "Nominated": "Node selected for scheduling", + "Scheduled": "Pod scheduled on node", + "Pulling": "Pulling model image...", + "Pulled": "Model image pulled", + "Created": "Container created", + "Started": "Container started", + "BackOff": "Container restarting (back-off)", + "Unhealthy": "Health check failed, waiting...", + "Killing": "Stopping container...", + "Preempted": "Pod preempted, rescheduling...", + "FailedMount": "Volume mount failed", + "FailedAttachVolume": "Volume attach failed", + "SuccessfulAttachVolume": "Volume attached", + "ScalingReplicaSet": "Scaling replicas...", + } + simplified = _SIMPLE.get(reason) + if simplified: + return simplified + # Truncate anything beyond 80 chars + if len(message) > 80: + return message[:77] + "..." + return message + + +def _format_event_logs(raw_message, verbose=False): + """Parse Kubernetes-style event log entries into formatted lines. + + Raw format: "Name: pod-xyz, Type: Warning, Source: {karpenter }, Reason: FailedScheduling, + FirstTimestamp: ..., LastTimestamp: ..., Message: ..." + Multiple events may be concatenated with newlines. + + Args: + raw_message: Raw k8s event log string. + verbose: If True, show all events with full detail. If False, simplify and filter. + + Returns: + list of formatted strings. + """ + if not raw_message: + return [] + + lines = [] + # Split concatenated events (each starts with "Name:") + events = re.split(r'\n(?=Name:\s)', raw_message) + + for event_str in events: + event_str = event_str.strip() + if not event_str: + continue + + # Extract key fields + type_match = re.search(r'Type:\s*(\w+)', event_str) + reason_match = re.search(r'Reason:\s*(\w+)', event_str) + message_match = re.search(r'Message:\s*(.+?)(?:\s*$)', event_str, re.DOTALL) + + event_type = type_match.group(1) if type_match else "" + reason = reason_match.group(1) if reason_match else "" + message = message_match.group(1).strip() if message_match else event_str + + # In non-verbose mode, skip transient noise events + if not verbose and reason in _SKIP_EVENTS: + continue + + # In non-verbose mode, simplify messages + if not verbose: + message = _simplify_k8s_message(reason, message) + phase = _EVENT_PHASE_MAP.get(reason, reason) + else: + phase = reason + # Truncate very long messages in verbose mode too + if len(message) > 200: + message = message[:197] + "..." + + # Format with type indicator (consistent width) + tag = "warning" if event_type == "Warning" else "event" + + if phase: + lines.append(f" [{tag:7s}] {phase}: {message}") + else: + lines.append(f" [{tag:7s}] {message}") + + return lines + + +def stream_model_logs( + model_url=None, + model_id=None, + user_id=None, + app_id=None, + model_version_id=None, + compute_cluster_id=None, + nodepool_id=None, + pat=None, + base_url=None, + follow=True, + duration=None, + log_type="runner", +): + """Stream model runner logs to stdout. + + Looks up the runner for the given model, then continuously fetches and prints + log entries (model pod stdout/stderr or k8s events). + + Args: + model_url: Clarifai model URL. Used to extract user_id, app_id, model_id. + model_id: Model ID (alternative to model_url). + user_id: User ID. + app_id: App ID. + model_version_id: Specific version (default: latest). + compute_cluster_id: Filter by compute cluster. + nodepool_id: Filter by nodepool. + pat: PAT for auth. + base_url: API base URL. + follow: If True, continuously tail logs. If False, print existing and exit. + duration: Max seconds to tail (None = until Ctrl+C). + log_type: Log type to fetch — "runner" (model stdout/stderr) or + "runner.events" (k8s scheduling/scaling events). + """ + from clarifai_grpc.grpc.api import service_pb2 + + from clarifai.client.auth import create_stub + from clarifai.client.auth.helper import ClarifaiAuthHelper + + # Parse model URL if provided + if model_url: + url_user, url_app, _, url_model, _ = ClarifaiUrlHelper.split_clarifai_url(model_url) + user_id = user_id or url_user + app_id = app_id or url_app + model_id = model_id or url_model + + if not model_id or not user_id: + raise UserError( + "You must specify --model-url or --model-id with --user-id.\n" + " Example: clarifai model logs --model-url https://clarifai.com/user/app/models/id" + ) + + # Get latest version if not specified + if not model_version_id: + from clarifai.client import Model + + model = Model( + model_id=model_id, + app_id=app_id, + user_id=user_id, + pat=pat, + base_url=base_url, + ) + versions = list(model.list_versions()) + if not versions: + raise UserError(f"No versions found for model '{model_id}'.") + model_version_id = versions[0].model_version.id + + # Create gRPC client + auth = ClarifaiAuthHelper.from_env(user_id=user_id, pat=pat, base=base_url, validate=False) + stub = create_stub(auth) + user_app_id = auth.get_user_app_id_proto() + + # Find the runner + runner_id = None + cc_id = compute_cluster_id + np_id = nodepool_id + + try: + resp = stub.ListRunners( + service_pb2.ListRunnersRequest( + user_app_id=user_app_id, + compute_cluster_id=cc_id or "", + nodepool_id=np_id or "", + model_version_ids=[model_version_id], + ) + ) + runners = list(resp.runners) + if runners: + runner = runners[0] + runner_id = runner.id + # Extract cc/np from runner if not provided + if not cc_id and runner.nodepool: + cc_id = ( + runner.nodepool.compute_cluster.id if runner.nodepool.compute_cluster else "" + ) + np_id = runner.nodepool.id + except Exception as e: + logger.debug(f"Error listing runners: {e}") + + if not runner_id: + url_hint = model_url or f"https://clarifai.com/{user_id}/{app_id}/models/{model_id}" + raise UserError( + f"No active runner found for model '{model_id}' (version: {model_version_id}).\n" + " The model is not currently deployed. To deploy it, run:\n" + f" clarifai model deploy --model-url \"{url_hint}\" --instance \n" + " Run 'clarifai list-instances' to see available instance types." + ) + + print(f"Streaming logs for model '{model_id}' (runner: {runner_id})...", flush=True) + print("Press Ctrl+C to stop.\n", flush=True) + + seen_logs = set() + log_page = 1 + start_time = time.time() + poll_interval = 3 # Slightly faster polling for log streaming + + try: + while True: + if duration and (time.time() - start_time) > duration: + break + + try: + resp = stub.ListLogEntries( + service_pb2.ListLogEntriesRequest( + log_type=log_type, + user_app_id=user_app_id, + compute_cluster_id=cc_id or "", + nodepool_id=np_id or "", + runner_id=runner_id, + page=log_page, + per_page=50, + ) + ) + entries_count = 0 + for entry in resp.log_entries: + entries_count += 1 + log_key = entry.url or entry.message[:100] + if log_key not in seen_logs: + seen_logs.add(log_key) + msg = entry.message.strip() + if msg: + print(msg, flush=True) + if entries_count == 50: + log_page += 1 + except Exception as e: + logger.debug(f"Error fetching logs: {e}") + + if not follow: + break + + time.sleep(poll_interval) + except KeyboardInterrupt: + print("\nStopped log streaming.", flush=True) + + +def get_deployment(deployment_id, user_id, pat=None, base_url=None): + """Fetch a single deployment by ID. + + Args: + deployment_id: The deployment ID. + user_id: User ID that owns the deployment. + pat: PAT for auth. + base_url: API base URL. + + Returns: + Deployment proto object. + + Raises: + UserError: If the deployment is not found. + """ + from clarifai_grpc.grpc.api import service_pb2 + from clarifai_grpc.grpc.api.status import status_code_pb2 + + from clarifai.client.auth import create_stub + from clarifai.client.auth.helper import ClarifaiAuthHelper + + auth = ClarifaiAuthHelper.from_env(user_id=user_id, pat=pat, base=base_url, validate=False) + stub = create_stub(auth) + user_app_id = auth.get_user_app_id_proto(user_id=user_id, app_id="") + + request = service_pb2.GetDeploymentRequest( + user_app_id=user_app_id, deployment_id=deployment_id + ) + response = stub.GetDeployment(request) + + if response.status.code != status_code_pb2.SUCCESS: + raise UserError( + f"Deployment '{deployment_id}' not found.\n" + f" Status: {response.status.description}\n" + " Check the deployment ID and try again." + ) + return response.deployment + + +def list_deployments_for_model( + model_id, user_id, app_id, model_version_id=None, pat=None, base_url=None +): + """List deployments for a specific model. + + Uses the ListDeployments API with model_version_ids filter to find + deployments without walking compute clusters/nodepools. + + Args: + model_id: Model ID. + user_id: User ID. + app_id: App ID. + model_version_id: Specific version to filter (default: latest). + pat: PAT for auth. + base_url: API base URL. + + Returns: + List of deployment proto objects. + """ + from clarifai_grpc.grpc.api import service_pb2 + + from clarifai.client.auth import create_stub + from clarifai.client.auth.helper import ClarifaiAuthHelper + + # Get latest version if not specified + if not model_version_id: + from clarifai.client import Model + + model = Model( + model_id=model_id, + app_id=app_id, + user_id=user_id, + pat=pat, + base_url=base_url, + ) + versions = list(model.list_versions()) + if not versions: + raise UserError(f"No versions found for model '{model_id}'.") + model_version_id = versions[0].model_version.id + + auth = ClarifaiAuthHelper.from_env(user_id=user_id, pat=pat, base=base_url, validate=False) + stub = create_stub(auth) + user_app_id = auth.get_user_app_id_proto(user_id=user_id, app_id="") + + response = stub.ListDeployments( + service_pb2.ListDeploymentsRequest( + user_app_id=user_app_id, + model_version_ids=[model_version_id], + per_page=100, + ) + ) + return list(response.deployments) + + +def delete_deployment(deployment_id, user_id, pat=None, base_url=None): + """Delete a deployment by ID. No nodepool needed. + + Args: + deployment_id: The deployment ID to delete. + user_id: User ID that owns the deployment. + pat: PAT for auth. + base_url: API base URL. + + Raises: + UserError: If the deletion fails. + """ + from clarifai_grpc.grpc.api import service_pb2 + from clarifai_grpc.grpc.api.status import status_code_pb2 + + from clarifai.client.auth import create_stub + from clarifai.client.auth.helper import ClarifaiAuthHelper + + auth = ClarifaiAuthHelper.from_env(user_id=user_id, pat=pat, base=base_url, validate=False) + stub = create_stub(auth) + user_app_id = auth.get_user_app_id_proto(user_id=user_id, app_id="") + + request = service_pb2.DeleteDeploymentsRequest(user_app_id=user_app_id, ids=[deployment_id]) + response = stub.DeleteDeployments(request) + + if response.status.code != status_code_pb2.SUCCESS: + raise UserError( + f"Failed to delete deployment '{deployment_id}'.\n" + f" Status: {response.status.description}" + ) diff --git a/clarifai/runners/models/model_run_locally.py b/clarifai/runners/models/model_run_locally.py index 63e615d1..9436b114 100644 --- a/clarifai/runners/models/model_run_locally.py +++ b/clarifai/runners/models/model_run_locally.py @@ -191,8 +191,18 @@ def test_model(self): process.kill() # run the model server - def run_model_server(self, port=8080): - """Run the Clarifai Runners's model server.""" + def run_model_server(self, port=8080, grpc=True, **kwargs): + """Run the Clarifai Runner's model server. + + Args: + port: Port to run the server on. + grpc: If True, start a standalone gRPC server. If False, start a + runner that connects to the Clarifai API (requires kwargs like + user_id, compute_cluster_id, nodepool_id, runner_id, base_url, pat). + **kwargs: Additional arguments passed to clarifai.runners.server + (e.g. user_id, compute_cluster_id, nodepool_id, runner_id, + base_url, pat, num_threads, pool_size). + """ command = [ self.python_executable, @@ -200,10 +210,17 @@ def run_model_server(self, port=8080): "clarifai.runners.server", "--model_path", self.model_path, - "--grpc", "--port", str(port), ] + if grpc: + command.append("--grpc") + # Pass additional runner args + for key, value in kwargs.items(): + if value is None: + continue + command.extend([f"--{key}", str(value)]) + try: logger.info( f"Starting model server at localhost:{port} with the model at {self.model_path}..." @@ -252,11 +269,16 @@ def build_docker_image( with open(dockerfile_path, 'r') as file: lines = file.readlines() - # Comment out the COPY instruction that copies the current folder + # Comment out lines not needed for local container builds: + # - download-checkpoints: checkpoints are mounted at runtime, not baked in + # - downloader/unused.yaml: only exists in cloud build context modified_lines = [] for line in lines: + stripped = line.strip() if 'download-checkpoints' in line and '/home/nonroot/main' in line: modified_lines.append(f'# {line}') + elif stripped.startswith('COPY') and 'downloader/' in line: + modified_lines.append(f'# {line}') else: modified_lines.append(line) @@ -307,7 +329,7 @@ def _validate_test_environment(self): Validate that the current environment supports model testing. Provides immediate feedback for unsupported configurations. This function runs only during CLI commands: - 1. clarifai model local-grpc + 1. clarifai model serve 2. clarifai model local-test """ warnings = [] @@ -395,8 +417,13 @@ def run_docker_container( cmd = ["docker", "run", "--name", container_name, '--rm', "--network", "host"] if self._gpu_is_available(): cmd.extend(["--gpus", "all"]) - # Add volume mappings - cmd.extend(["-v", f"{self.model_path}:/home/nonroot/main"]) + # Add volume mappings (use --mount to handle paths with colons, e.g. "gemma3:1b") + cmd.extend( + [ + "--mount", + f"type=bind,source={self.model_path},target=/home/nonroot/main", + ] + ) # Add environment variables if env_vars: for key, value in env_vars.items(): @@ -405,13 +432,18 @@ def run_docker_container( cmd.extend(["-e", "PYTHONDONTWRITEBYTECODE=1"]) # Add the image name cmd.append(image_name) - # update the CMD to run the server + # Override CMD to run the server (ENTRYPOINT is tini, so we need the full command) + server_cmd = [ + "python", + "-m", + "clarifai.runners.server", + "--model_path", + "/home/nonroot/main", + ] if is_local_runner: kwargs.pop("pool_size", None) # remove pool_size if exists - cmd.extend( + server_cmd.extend( [ - "--model_path", - "/home/nonroot/main", "--compute_cluster_id", str(kwargs.get("compute_cluster_id", None)), "--user_id", @@ -426,12 +458,16 @@ def run_docker_container( str(kwargs.get("pat", None)), "--num_threads", str(kwargs.get("num_threads", 0)), - "--health_check_port", - str(kwargs.get("health_check_port", 8080)), ] ) + # Only pass health_check_port if non-zero (avoids compat issues with + # older clarifai versions inside the container that lack the flag) + hcp = kwargs.get("health_check_port", 8080) + if hcp and hcp > 0: + server_cmd.extend(["--health_check_port", str(hcp)]) else: - cmd.extend(["--model_path", "/home/nonroot/main", "--grpc", "--port", str(port)]) + server_cmd.extend(["--grpc", "--port", str(port)]) + cmd.extend(server_cmd) # Run the container logger.info(f"Running docker commands: {cmd}") process = subprocess.Popen( @@ -473,8 +509,13 @@ def test_model_container( cmd.extend(["--gpus", "all"]) # update the entrypoint for testing the model cmd.extend(["--entrypoint", "python"]) - # Add volume mappings - cmd.extend(["-v", f"{self.model_path}:/home/nonroot/main"]) + # Add volume mappings (use --mount to handle paths with colons, e.g. "gemma3:1b") + cmd.extend( + [ + "--mount", + f"type=bind,source={self.model_path},target=/home/nonroot/main", + ] + ) # Add environment variables if env_vars: for key, value in env_vars.items(): diff --git a/clarifai/runners/server.py b/clarifai/runners/server.py index 65f5961a..a60ee065 100644 --- a/clarifai/runners/server.py +++ b/clarifai/runners/server.py @@ -312,7 +312,7 @@ def serve( base_url, pat, num_threads, - health_check_port, + health_check_port=health_check_port, ) def start_servicer(self, port, pool_size, max_queue_size, max_msg_length, enable_tls): @@ -344,7 +344,7 @@ def start_runner( base_url, pat, num_threads, - health_check_port, + health_check_port=8080, ): # initialize the Runner class. This is what the user implements. assert compute_cluster_id is not None, "compute_cluster_id must be set for the runner." diff --git a/clarifai/runners/utils/code_script.py b/clarifai/runners/utils/code_script.py index 81c2fb43..a2595382 100644 --- a/clarifai/runners/utils/code_script.py +++ b/clarifai/runners/utils/code_script.py @@ -7,6 +7,55 @@ from clarifai.urls.helper import ClarifaiUrlHelper from clarifai.utils.constants import MCP_TRANSPORT_NAME, OPENAI_TRANSPORT_NAME +_SAMPLE_URLS = { + resources_pb2.ModelTypeField.DataType.IMAGE: "https://s3.amazonaws.com/samples.clarifai.com/featured-models/image-captioning-statue-of-liberty.jpeg", + resources_pb2.ModelTypeField.DataType.AUDIO: "https://s3.amazonaws.com/samples.clarifai.com/GoodMorning.wav", + resources_pb2.ModelTypeField.DataType.VIDEO: "https://s3.amazonaws.com/samples.clarifai.com/beer.mp4", +} + + +def generate_predict_hint( + method_signatures: List[resources_pb2.MethodSignature], + model_ref: str, + deployment_id: str = None, +) -> str: + """Build a concrete ``clarifai model predict`` command from method signatures. + + Returns a ready-to-copy CLI string like:: + + clarifai model predict user/app/models/m "Hello world" + clarifai model predict user/app/models/m --url https://...jpg + """ + deployment_flag = f" --deployment {deployment_id}" if deployment_id else "" + base = f"clarifai model predict {model_ref}{deployment_flag}" + + if not method_signatures: + return f"{base} --info" + + # OpenAI-style model — positional text auto-routes to OpenAI client + if has_signature_method(OPENAI_TRANSPORT_NAME, method_signatures): + return f'{base} "Hello world"' + + # MCP model → --info (no direct predict) + if has_signature_method(MCP_TRANSPORT_NAME, method_signatures): + return f"{base} --info" + + # Pick the first user-facing method + sig = method_signatures[0] + if not sig.input_fields: + return f"{base} --info" + + first_field = sig.input_fields[0] + dt = first_field.type + DT = resources_pb2.ModelTypeField.DataType + + if dt in (DT.STR, DT.TEXT): + return f'{base} "Hello world"' + elif dt in _SAMPLE_URLS: + return f"{base} --url {_SAMPLE_URLS[dt]}" + else: + return f"{base} --info" + def has_signature_method( name: str, method_signatures: List[resources_pb2.MethodSignature] @@ -23,6 +72,78 @@ def has_signature_method( ) +def _colorize_script(script: str) -> str: + """Apply Python syntax highlighting to a script string using pygments.""" + try: + from pygments import highlight # type: ignore + from pygments.formatters import TerminalFormatter # type: ignore + from pygments.lexers import PythonLexer # type: ignore + + return highlight(script, PythonLexer(), TerminalFormatter()) + except Exception: + return script + + +def _generate_local_grpc_script( + method_signatures: List[resources_pb2.MethodSignature], + port: int, + colorize: bool = False, +) -> str: + """Generate a Python snippet for calling a model on a local gRPC server. + + Uses ModelClient.from_local_grpc() which auto-discovers method signatures + and provides the same SDK interface as cloud models. + """ + lines = [ + "from clarifai.client.model_client import ModelClient", + "", + f"client = ModelClient.from_local_grpc(port={port})", + "", + ] + + for method_signature in method_signatures: + if method_signature is None: + continue + method_name = method_signature.name + # Skip bidirectional streaming — too complex for a quick snippet + if method_signature.method_type == resources_pb2.RunnerMethodType.STREAMING_STREAMING: + continue + + annotations = _get_annotations_source(method_signature) + # Build sample kwargs from input params + param_parts = [] + for param_name, (param_type, default_value, _required) in annotations.items(): + if param_name == "return": + continue + if default_value is None: + default_value = _set_default_value(param_type) + if default_value is not None: + if param_type == "str": + param_parts.append(f'{param_name}={json.dumps(str(default_value))}') + else: + param_parts.append(f"{param_name}={default_value}") + break # Just show the first param for brevity + + call_args = ", ".join(param_parts) if param_parts else "" + + is_streaming = ( + method_signature.method_type == resources_pb2.RunnerMethodType.UNARY_STREAMING + ) + + lines.append(f"# Method: {method_name}") + if is_streaming: + lines.append(f"for chunk in client.{method_name}({call_args}):") + lines.append(" print(chunk, end='')") + lines.append("print()") + else: + lines.append(f"response = client.{method_name}({call_args})") + lines.append("print(response)") + lines.append("") + + script = "\n".join(lines) + return _colorize_script(script) if colorize else script + + def generate_client_script( method_signatures: List[resources_pb2.MethodSignature], user_id, @@ -35,7 +156,12 @@ def generate_client_script( deployment_user_id: str = None, use_ctx: bool = False, colorize: bool = False, + local_grpc_port: int = None, ) -> str: + # ── Local gRPC mode ──────────────────────────────────────────────── + if local_grpc_port is not None: + return _generate_local_grpc_script(method_signatures, local_grpc_port, colorize=colorize) + url_helper = ClarifaiUrlHelper() # Provide an mcp client config if there is a method named "mcp_transport" @@ -69,7 +195,7 @@ async def main(): if __name__ == "__main__": asyncio.run(main()) """ - return _CLIENT_TEMPLATE + return _colorize_script(_CLIENT_TEMPLATE) if colorize else _CLIENT_TEMPLATE if has_signature_method(OPENAI_TRANSPORT_NAME, method_signatures): openai_api_base = url_helper.openai_api_url() @@ -87,18 +213,16 @@ async def main(): response = client.chat.completions.create( model="{model_ui_url}", messages=[ - {{"role": "system", "content": "Talk like a pirate."}}, {{ "role": "user", "content": "How do I check if a Python object is an instance of a class?", }}, ], - temperature=1.0, stream=False, # stream=True also works, just iterator over the response ) print(response) """ - return _CLIENT_TEMPLATE + return _colorize_script(_CLIENT_TEMPLATE) if colorize else _CLIENT_TEMPLATE # Generate client template _CLIENT_TEMPLATE = ( "import os\n\n" @@ -224,17 +348,7 @@ async def main(): script_lines.append(method_signatures_str) script_lines.append("") script = "\n".join(script_lines) - if colorize: - try: - from pygments import highlight # type: ignore - from pygments.formatters import TerminalFormatter # type: ignore - from pygments.lexers import PythonLexer # type: ignore - - return highlight(script, PythonLexer(), TerminalFormatter()) - except Exception: - # Fallback to plain text if pygments is unavailable - return script - return script + return _colorize_script(script) if colorize else script # get annotations source with default values diff --git a/clarifai/runners/utils/loader.py b/clarifai/runners/utils/loader.py index 0d3be0f1..da7e7d02 100644 --- a/clarifai/runners/utils/loader.py +++ b/clarifai/runners/utils/loader.py @@ -264,8 +264,16 @@ def _get_ignore_patterns(self): return self.ignore_patterns @classmethod - def validate_hf_repo_access(cls, repo_id: str, token: str = None) -> bool: - # check if model exists on HF + def validate_hf_repo_access(cls, repo_id: str, token: str = None) -> tuple: + """Validate access to a HuggingFace repo. + + Returns: + (bool, str): (has_access, reason) where reason is one of: + "" - success + "gated_no_token" - gated repo, no token provided + "gated_no_access" - gated repo, token lacks access + "not_found" - repo does not exist + """ try: from huggingface_hub import auth_check from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError @@ -275,15 +283,13 @@ def validate_hf_repo_access(cls, repo_id: str, token: str = None) -> bool: try: auth_check(repo_id, token=token) logger.info("Hugging Face repo access validated") - return True + return True, "" except GatedRepoError: - logger.error( - "Hugging Face repo is gated. Please make sure you have access to the repo." - ) - return False + if token: + return False, "gated_no_access" + return False, "gated_no_token" except RepositoryNotFoundError: - logger.error("Hugging Face repo not found. Please make sure the repo exists.") - return False + return False, "not_found" @staticmethod def validate_config(checkpoint_path: str): diff --git a/clarifai/utils/cli.py b/clarifai/utils/cli.py index 650ba625..d6e22531 100644 --- a/clarifai/utils/cli.py +++ b/clarifai/utils/cli.py @@ -309,19 +309,42 @@ def get_command(self, ctx: click.Context, cmd_name: str) -> t.Optional[click.Com def format_commands(self, ctx, formatter): sub_commands = self.list_commands(ctx) + limit = formatter.width - 6 - max(len(c) for c in sub_commands) if sub_commands else 80 - rows = [] + # Build command -> (display_name, short_help) map + cmd_info = {} for sub_command in sub_commands: cmd = self.get_command(ctx, sub_command) if cmd is None or getattr(cmd, 'hidden', False): continue + name = sub_command if cmd in self.command_to_aliases: aliases = ', '.join(self.command_to_aliases[cmd]) - sub_command = f'{sub_command} ({aliases})' - cmd_help = cmd.help - rows.append((sub_command, cmd_help)) - - if rows: + name = f'{sub_command} ({aliases})' + help_text = cmd.get_short_help_str(limit=limit) + cmd_info[sub_command] = (name, help_text) + + if not cmd_info: + return + + sections = getattr(self, 'command_sections', None) + if sections: + listed = set() + for section_name, cmd_names in sections: + rows = [] + for cmd_name in cmd_names: + if cmd_name in cmd_info: + rows.append(cmd_info[cmd_name]) + listed.add(cmd_name) + if rows: + with formatter.section(section_name): + formatter.write_dl(rows) + remaining = [(n, h) for cn, (n, h) in cmd_info.items() if cn not in listed] + if remaining: + with formatter.section("Other"): + formatter.write_dl(remaining) + else: + rows = list(cmd_info.values()) with formatter.section("Commands"): formatter.write_dl(rows) @@ -349,6 +372,8 @@ def __init__(self, *args, **kwargs): 'ps': 'pipeline_step', 'pipelinetemplate': 'pipeline_template', 'pt': 'pipeline_template', + 'list-instances': 'list_instances', + 'li': 'list_instances', } def get_command(self, ctx: click.Context, cmd_name: str) -> Optional[click.Command]: @@ -428,8 +453,6 @@ def customize_ollama_model( with open(config_path, 'r', encoding='utf-8') as f: config = yaml.safe_load(f) - # Update the user_id in the model section - config['model']['user_id'] = user_id if 'toolkit' not in config or config['toolkit'] is None: config['toolkit'] = {} if model_name is not None: @@ -441,6 +464,9 @@ def customize_ollama_model( with open(config_path, 'w', encoding='utf-8') as f: yaml.dump(config, f, default_flow_style=False, sort_keys=False) + # Simplify the cloned config (remove placeholder user_id/app_id, convert to compute.instance) + simplify_cloned_config(config_path, model_name=model_name) + model_py_path = os.path.join(model_path, "1", "model.py") if not os.path.exists(model_py_path): @@ -561,17 +587,14 @@ def check_requirements_installed(model_path: str = None, dependencies: dict = No True if all dependencies are installed, False otherwise """ - if model_path and dependencies: - logger.warning( - "model_path and dependencies cannot be provided together, using dependencies instead" - ) - dependencies = parse_requirements(model_path) - try: if not dependencies: + if not model_path: + logger.error("No model_path or dependencies provided to check requirements.") + return False dependencies = parse_requirements(model_path) missing = [ - full_req + f"{package_name}{full_req}" if full_req else package_name for package_name, full_req in dependencies.items() if not _is_package_installed(package_name) ] @@ -585,8 +608,16 @@ def check_requirements_installed(model_path: str = None, dependencies: dict = No f"❌ {len(missing)} of {len(dependencies)} required packages are missing in the current environment" ) logger.error("\n".join(f" - {pkg}" for pkg in missing)) - requirements_path = Path(model_path) / "requirements.txt" - logger.warning(f"To install: pip install -r {requirements_path}") + if model_path: + try: + requirements_path = (Path(model_path) / "requirements.txt").relative_to(Path.cwd()) + except ValueError: + requirements_path = Path(model_path) / "requirements.txt" + logger.warning(f"To install: pip install -r {requirements_path}") + logger.warning( + "Tip: use '--mode env' to auto-install deps in a virtualenv, " + "or '--mode container' to run in Docker." + ) return False except Exception as e: @@ -617,9 +648,6 @@ def customize_huggingface_model(model_path, user_id, model_name): with open(config_path, 'r', encoding='utf-8') as f: config = yaml.safe_load(f) - # Update the user_id in the model section - config['model']['user_id'] = user_id - if model_name: # Update the repo_id in checkpoints section if 'checkpoints' not in config: @@ -629,11 +657,66 @@ def customize_huggingface_model(model_path, user_id, model_name): with open(config_path, 'w', encoding='utf-8') as f: yaml.dump(config, f, default_flow_style=False, sort_keys=False) + # Simplify the cloned config to use compute.instance shorthand + simplify_cloned_config(config_path, model_name=model_name) + logger.info(f"Updated Hugging Face model repo_id to: {model_name}") else: logger.warning(f"config.yaml not found at {config_path}, skipping model configuration") +def simplify_cloned_config(config_path, user_id=None, model_name=None): + """Post-process a cloned config.yaml to simplified format. + + - Removes user_id/app_id placeholders (will be injected from CLI context at deploy time) + - Converts inference_compute_info to compute.instance shorthand (if it matches a preset) + - Keeps model.id, model_type_id, checkpoints, build_info as-is + """ + if not os.path.exists(config_path): + return + + with open(config_path, 'r', encoding='utf-8') as f: + config = yaml.safe_load(f) + + if not config: + return + + # Remove placeholder user_id/app_id from model section + model = config.get('model', {}) + placeholder_values = {'user_id', 'YOUR_USER_ID', 'app_id', 'YOUR_APP_ID', ''} + if model.get('user_id') in placeholder_values: + model.pop('user_id', None) + if model.get('app_id') in placeholder_values: + model.pop('app_id', None) + + # Convert inference_compute_info to compute.instance shorthand + if 'inference_compute_info' in config and 'compute' not in config: + from clarifai.utils.compute_presets import infer_gpu_from_config + + gpu_name = infer_gpu_from_config(config) + if gpu_name: + config['compute'] = {'instance': gpu_name} + del config['inference_compute_info'] + + # Remove placeholder hf_token values from checkpoints + checkpoints = config.get('checkpoints', {}) + if checkpoints: + hf_token = checkpoints.get('hf_token', '') + if hf_token in {'your_hf_token', 'hf_token', 'your-huggingface-token', ''}: + checkpoints.pop('hf_token', None) + + # Update model_id from directory name if it's a placeholder + if model_name: + # Use the last part of the model_name (e.g. 'Llama-3-8B' from 'meta-llama/Llama-3-8B') + simple_name = model_name.split('/')[-1] if '/' in model_name else model_name + model['id'] = simple_name + + config['model'] = model + + with open(config_path, 'w', encoding='utf-8') as f: + yaml.dump(config, f, default_flow_style=False, sort_keys=False) + + def customize_lmstudio_model(model_path, user_id, model_name=None, port=None, context_length=None): """Customize the LM Studio model name in the cloned template files. Args: @@ -648,8 +731,6 @@ def customize_lmstudio_model(model_path, user_id, model_name=None, port=None, co if os.path.exists(config_path): with open(config_path, 'r', encoding='utf-8') as f: config = yaml.safe_load(f) - # Update the user_id in the model section - config['model']['user_id'] = user_id if 'toolkit' not in config or config['toolkit'] is None: config['toolkit'] = {} if model_name is not None: @@ -660,10 +741,41 @@ def customize_lmstudio_model(model_path, user_id, model_name=None, port=None, co config['toolkit']['context_length'] = context_length with open(config_path, 'w', encoding='utf-8') as f: yaml.dump(config, f, default_flow_style=False, sort_keys=False) + + # Simplify the cloned config (remove placeholder user_id/app_id, convert to compute.instance) + simplify_cloned_config(config_path, model_name=model_name) + logger.info(f"Updated LM Studio model configuration in: {config_path}") else: logger.warning(f"config.yaml not found at {config_path}, skipping model configuration") + # Patch model.py defaults to match the configured model + model_py_path = os.path.join(model_path, "1", "model.py") + if os.path.exists(model_py_path): + try: + with open(model_py_path, 'r', encoding='utf-8') as f: + content = f.read() + if model_name: + content = content.replace( + 'LMS_MODEL_NAME = os.environ.get("LMS_MODEL_NAME", "google/gemma-3-4b")', + f'LMS_MODEL_NAME = os.environ.get("LMS_MODEL_NAME", "{model_name}")', + ) + if port: + content = content.replace( + 'LMS_PORT = int(os.environ.get("LMS_PORT", "23333"))', + f'LMS_PORT = int(os.environ.get("LMS_PORT", "{port}"))', + ) + if context_length: + content = content.replace( + 'LMS_CONTEXT_LENGTH = int(os.environ.get("LMS_CONTEXT_LENGTH", "4096"))', + f'LMS_CONTEXT_LENGTH = int(os.environ.get("LMS_CONTEXT_LENGTH", "{context_length}"))', + ) + with open(model_py_path, 'w', encoding='utf-8') as f: + f.write(content) + except Exception as e: + logger.error(f"Failed to customize LM Studio model name in {model_py_path}: {e}") + raise + def prompt_required_field(message: str, default: Optional[str] = None) -> str: """Prompt the user for a required field, optionally with a default. diff --git a/clarifai/utils/compute_presets.py b/clarifai/utils/compute_presets.py new file mode 100644 index 00000000..419ff4c3 --- /dev/null +++ b/clarifai/utils/compute_presets.py @@ -0,0 +1,1083 @@ +"""GPU/compute resource discovery via Clarifai API with hardcoded fallbacks. + +This module provides: +1. Dynamic GPU/instance type lookup via ListInstanceTypes API (across all cloud providers) +2. Hardcoded fallback presets for offline / CI usage +3. Auto-create compute cluster & nodepool configs for model deployment +""" + +import os +import re + +import requests + +from clarifai.utils.logging import logger + +# Kubernetes quantity suffixes and their multipliers (in bytes or millicores) +_K8S_SUFFIXES = { + '': 1, + 'm': 0.001, # millicores (CPU) + 'k': 1e3, + 'K': 1e3, + 'Ki': 1024, + 'M': 1e6, + 'Mi': 1024**2, + 'G': 1e9, + 'Gi': 1024**3, + 'T': 1e12, + 'Ti': 1024**4, +} + + +def parse_k8s_quantity(value): + """Parse a Kubernetes quantity string to a numeric value. + + Handles formats like: "24Gi", "16Mi", "4", "100m", "4.5", "1500Mi", "3Gi" + Returns a float (bytes for memory, cores for CPU). + + Args: + value: K8s quantity string or numeric value. + + Returns: + float: Parsed numeric value, or 0 if parsing fails. + """ + if value is None: + return 0 + if isinstance(value, (int, float)): + return float(value) + + value = str(value).strip() + if not value: + return 0 + + match = re.match(r'^([0-9]*\.?[0-9]+)\s*([A-Za-z]*)$', value) + if not match: + return 0 + + number = float(match.group(1)) + suffix = match.group(2) + + multiplier = _K8S_SUFFIXES.get(suffix) + if multiplier is None: + return 0 + + return number * multiplier + + +# Hardcoded fallback presets (used when API is unavailable) +FALLBACK_GPU_PRESETS = { + "CPU": { + "description": "CPU only (no GPU)", + "instance_type_id": "t3a.2xlarge", + "cloud_provider": "aws", + "region": "us-east-1", + "inference_compute_info": { + "cpu_limit": "4", + "cpu_memory": "16Gi", + "num_accelerators": 0, + "accelerator_type": [], + "accelerator_memory": "", + }, + }, + "A10G": { + "description": "NVIDIA A10G 24GB", + "instance_type_id": "gpu-nvidia-a10g", + "cloud_provider": "aws", + "region": "us-east-1", + "inference_compute_info": { + "cpu_limit": "4", + "cpu_memory": "16Gi", + "num_accelerators": 1, + "accelerator_type": ["NVIDIA-A10G"], + "accelerator_memory": "24Gi", + }, + }, + "L40S": { + "description": "NVIDIA L40S 48GB", + "instance_type_id": "gpu-nvidia-l40s", + "cloud_provider": "aws", + "region": "us-east-1", + "inference_compute_info": { + "cpu_limit": "8", + "cpu_memory": "32Gi", + "num_accelerators": 1, + "accelerator_type": ["NVIDIA-L40S"], + "accelerator_memory": "48Gi", + }, + }, + "G6E": { + "description": "NVIDIA L40S 2x48GB", + "instance_type_id": "gpu-nvidia-g6e-2x-large", + "cloud_provider": "aws", + "region": "us-east-1", + "inference_compute_info": { + "cpu_limit": "16", + "cpu_memory": "64Gi", + "num_accelerators": 2, + "accelerator_type": ["NVIDIA-L40S"], + "accelerator_memory": "96Gi", + }, + }, +} + + +def get_deploy_compute_cluster_id(cloud_provider="aws", region="us-east-1"): + """Return a deterministic compute cluster ID for the given cloud/region.""" + return f"deploy-cc-{cloud_provider}-{region}" + + +def get_deploy_nodepool_id(instance_type_id): + """Return a deterministic nodepool ID for the given instance type.""" + import re + + # Sanitize: only alphanumeric, hyphens, underscores allowed in IDs + sanitized = re.sub(r'[^a-zA-Z0-9_-]', '-', instance_type_id) + # Collapse consecutive hyphens + sanitized = re.sub(r'-{2,}', '-', sanitized) + return f"deploy-np-{sanitized}" + + +# Module-level cache for instance types (avoids repeated API calls in one session) +_instance_types_cache = None + + +def _try_list_all_instance_types(pat=None, base_url=None): + """Fetch instance types across all cloud providers and regions. + + Queries the API for all available cloud providers, their regions, + and the instance types in each. Results are cached for the session. + + Returns: + list of InstanceType protos (with cloud_provider and region set), or None on failure. + """ + global _instance_types_cache + if _instance_types_cache is not None: + return _instance_types_cache + + try: + from clarifai.client.user import User + + user = User(pat=pat, base_url=base_url) if (pat or base_url) else User() + + all_instance_types = [] + providers = user.list_cloud_providers() + + for provider in providers: + try: + regions = user.list_cloud_regions(provider.id) + for region in regions: + region_id = getattr(region, 'id', None) or str(region) + try: + instance_types = user.list_instance_types(provider.id, region_id) + all_instance_types.extend(instance_types) + except Exception as e: + logger.debug( + f"Failed to list instance types for {provider.id}/{region_id}: {e}" + ) + except Exception as e: + logger.debug(f"Failed to list regions for {provider.id}: {e}") + + if all_instance_types: + _instance_types_cache = all_instance_types + return all_instance_types + return None + except Exception as e: + logger.debug(f"Failed to fetch instance types from API: {e}") + return None + + +def _try_list_instance_types(cloud_provider="aws", region="us-east-1", pat=None, base_url=None): + """Fetch instance types for a specific cloud provider and region. + + Returns list of InstanceType protos, or None on failure. + """ + try: + from clarifai.client.user import User + + user = User(pat=pat, base_url=base_url) if (pat or base_url) else User() + return user.list_instance_types(cloud_provider, region) + except Exception as e: + logger.debug(f"Failed to fetch instance types from API: {e}") + return None + + +def _sort_instance_types(instance_types): + """Sort instance types by cloud provider, GPU count desc, then ID. + + This ensures consistent priority: aws before gcp before vultr, + matching the display order of list-instances. + """ + return sorted( + instance_types, + key=lambda it: ( + it.cloud_provider.id if it.cloud_provider else "", + -(it.compute_info.num_accelerators if it.compute_info else 0), + it.id, + ), + ) + + +def _match_gpu_name_to_instance_type(gpu_name, instance_types): + """Match a GPU shorthand name (e.g. 'A10G') to an API InstanceType.""" + gpu_upper = gpu_name.upper() + for it in instance_types: + it_id_upper = it.id.upper() + # Match by GPU name in instance type ID + if gpu_upper in it_id_upper: + return it + # Match by accelerator type + if it.compute_info and it.compute_info.accelerator_type: + for acc_type in it.compute_info.accelerator_type: + if gpu_upper in acc_type.upper(): + return it + return None + + +def _instance_type_to_preset(instance_type): + """Convert an API InstanceType proto to a preset dict.""" + ci = instance_type.compute_info + return { + "description": instance_type.description or instance_type.id, + "instance_type_id": instance_type.id, + "cloud_provider": instance_type.cloud_provider.id if instance_type.cloud_provider else "", + "region": instance_type.region or "", + "inference_compute_info": { + "cpu_limit": ci.cpu_limit or "4", + "cpu_memory": ci.cpu_memory or "16Gi", + "num_accelerators": ci.num_accelerators, + "accelerator_type": list(ci.accelerator_type) if ci.accelerator_type else [], + "accelerator_memory": ci.accelerator_memory or "", + }, + } + + +def _normalize_gpu_name(gpu_name): + """Extract GPU shorthand from various formats. + + Handles formats like: + - 'gpu-nvidia-a10g' → 'A10G' + - 'gpu-nvidia-g6e-2x-large' → 'G6E-2X-LARGE' + - 'A10G' → 'A10G' (already short) + - 'g5.xlarge' → 'G5.XLARGE' (no prefix to strip) + """ + name = gpu_name.strip() + # Strip 'gpu-nvidia-' prefix if present + lower = name.lower() + if lower.startswith("gpu-nvidia-"): + name = name[len("gpu-nvidia-") :] + elif lower.startswith("gpu-"): + name = name[len("gpu-") :] + return name.upper() + + +def resolve_gpu(gpu_name, pat=None, base_url=None, cloud_provider=None, region=None): + """Resolve a GPU/instance type name to its full preset info. + + Accepts either: + - Instance type IDs from the API (e.g. 'g5.xlarge', 'g6e.2xlarge', 't3a.2xlarge') + - GPU shorthand names (e.g. 'A10G', 'L40S', 'CPU') as fallback aliases + - Legacy nodepool-style names (e.g. 'gpu-nvidia-a10g') — normalized to GPU shorthand + + Queries all cloud providers/regions unless cloud_provider/region are specified. + If --cloud/--region are given, only that provider+region is queried. + + Args: + gpu_name: Instance type ID or GPU shorthand name. + pat: Optional PAT for API auth. + base_url: Optional API base URL. + cloud_provider: Optional cloud provider filter (e.g. 'aws', 'gcp'). + region: Optional region filter (e.g. 'us-east-1'). + + Returns: + dict with keys: description, instance_type_id, cloud_provider, region, inference_compute_info + + Raises: + ValueError: If GPU name is not found. + """ + # If user specified cloud/region, query only that combination + if cloud_provider and region: + instance_types = _try_list_instance_types( + cloud_provider, region, pat=pat, base_url=base_url + ) + else: + # Query all providers/regions + instance_types = _try_list_all_instance_types(pat=pat, base_url=base_url) + + if instance_types: + # Sort for consistent priority (aws first, matching list-instances order) + instance_types = _sort_instance_types(instance_types) + + # Optionally filter by cloud_provider (even when querying all) + filtered = instance_types + if cloud_provider: + filtered = [ + it + for it in instance_types + if (it.cloud_provider and it.cloud_provider.id == cloud_provider) + ] + if region: + filtered = [it for it in filtered if it.region == region] + + # 1. Exact match by instance type ID (e.g. 'g5.xlarge') + for it in filtered: + if it.id.lower() == gpu_name.lower(): + return _instance_type_to_preset(it) + + # 2. Fuzzy match by GPU shorthand in instance type ID or accelerator type + # Normalize 'gpu-nvidia-a10g' → 'A10G' for better matching + normalized = _normalize_gpu_name(gpu_name) + matched = _match_gpu_name_to_instance_type(normalized, filtered) + if matched: + return _instance_type_to_preset(matched) + + # If filtering narrowed results too much, try unfiltered + if filtered != instance_types: + for it in instance_types: + if it.id.lower() == gpu_name.lower(): + return _instance_type_to_preset(it) + matched = _match_gpu_name_to_instance_type(normalized, instance_types) + if matched: + return _instance_type_to_preset(matched) + + # Fallback to hardcoded presets (by shorthand name, then by instance_type_id) + gpu_upper = gpu_name.upper() + if gpu_upper in FALLBACK_GPU_PRESETS: + return dict(FALLBACK_GPU_PRESETS[gpu_upper]) + gpu_lower = gpu_name.lower() + for preset in FALLBACK_GPU_PRESETS.values(): + if preset["instance_type_id"].lower() == gpu_lower: + return dict(preset) + + available = "Run 'clarifai list-instances' to see available options." + raise ValueError(f"Unknown instance type '{gpu_name}'. {available}") + + +def get_inference_compute_for_gpu(gpu_name, pat=None, base_url=None): + """Get inference_compute_info dict for a GPU name. + + Args: + gpu_name: GPU shorthand name (e.g. 'A10G'). + pat: Optional PAT for API auth. + base_url: Optional API base URL. + + Returns: + dict: inference_compute_info suitable for config.yaml. + """ + preset = resolve_gpu(gpu_name, pat=pat, base_url=base_url) + return dict(preset["inference_compute_info"]) + + +def infer_gpu_from_config(config): + """Infer GPU shorthand name from an existing inference_compute_info config. + + Args: + config: dict with inference_compute_info section. + + Returns: + str or None: GPU name like 'A10G', or None if not recognized. + """ + ici = config.get("inference_compute_info") + if not ici: + return None + + acc_types = ici.get("accelerator_type", []) + if not acc_types or ici.get("num_accelerators", 0) == 0: + return "CPU" + + # Try to match accelerator types against known presets + for gpu_name, preset in FALLBACK_GPU_PRESETS.items(): + if gpu_name == "CPU": + continue + preset_acc = preset["inference_compute_info"].get("accelerator_type", []) + preset_num = preset["inference_compute_info"].get("num_accelerators", 0) + if ( + preset_acc + and set(preset_acc) == set(acc_types) + and preset_num == ici.get("num_accelerators", 0) + ): + return gpu_name + + return None + + +def list_gpu_presets( + pat=None, + base_url=None, + cloud_provider=None, + region=None, + gpu_name=None, + min_gpus=None, + min_gpu_mem=None, +): + """Return a formatted table of available GPU presets. + + Queries all cloud providers/regions via the API, falls back to hardcoded presets. + + Args: + pat: Optional PAT for API auth. + base_url: Optional API base URL. + cloud_provider: Optional filter by cloud provider (e.g. 'aws', 'gcp'). + region: Optional filter by region (e.g. 'us-east-1'). + gpu_name: Optional filter by GPU name substring (case-insensitive, e.g. 'H100'). + min_gpus: Optional minimum GPU count filter. + min_gpu_mem: Optional minimum GPU memory filter (K8s quantity string, e.g. '48Gi'). + + Returns: + str: Formatted table string. + """ + rows = [] + header = "Available instance types (use the ID with --instance flag):\n" + + # Try API first (all providers/regions) + instance_types = _try_list_all_instance_types(pat=pat, base_url=base_url) + if instance_types: + # Filter by cloud/region if specified + filtered = instance_types + if cloud_provider: + filtered = [ + it + for it in filtered + if (it.cloud_provider and it.cloud_provider.id == cloud_provider) + ] + header = f"Available instance types for {cloud_provider}" + if region: + header += f" / {region}" + header += ":\n" + if region: + filtered = [it for it in filtered if it.region == region] + + # Deduplicate by (instance_type_id, cloud_provider) - keep first per combo + seen = set() + deduped = [] + for it in filtered: + cp = it.cloud_provider.id if it.cloud_provider else "" + key = (it.id, cp) + if key not in seen: + seen.add(key) + deduped.append(it) + + # Apply gpu_name filter (case-insensitive substring on accelerator_type entries) + if gpu_name: + gpu_upper = gpu_name.upper() + deduped = [ + it + for it in deduped + if it.compute_info + and it.compute_info.accelerator_type + and any(gpu_upper in acc.upper() for acc in it.compute_info.accelerator_type) + ] + + # Apply min_gpus filter + if min_gpus is not None: + deduped = [ + it + for it in deduped + if it.compute_info and (it.compute_info.num_accelerators or 0) >= min_gpus + ] + + # Apply min_gpu_mem filter + if min_gpu_mem is not None: + threshold = parse_k8s_quantity(min_gpu_mem) + deduped = [ + it + for it in deduped + if it.compute_info + and it.compute_info.accelerator_memory + and parse_k8s_quantity(it.compute_info.accelerator_memory) >= threshold + ] + + # Sort: cloud first, then GPU count desc, then ID + sorted_types = _sort_instance_types(deduped) + for it in sorted_types: + ci = it.compute_info + acc_type = ", ".join(ci.accelerator_type) if ci.accelerator_type else "-" + gpu_mem = ci.accelerator_memory if ci.accelerator_memory else "-" + cloud = it.cloud_provider.id if it.cloud_provider else "-" + rgn = it.region or "-" + rows.append( + { + "--instance value": it.id, + "Cloud": cloud, + "Region": rgn, + "GPUs": ci.num_accelerators or 0, + "Accelerator": acc_type, + "GPU Memory": gpu_mem, + "CPU": ci.cpu_limit, + "CPU Memory": ci.cpu_memory, + } + ) + else: + return ( + "Could not fetch instance types from API.\nMake sure you are logged in: clarifai login" + ) + + if not rows: + return "No instance types match the given filters." + + from tabulate import tabulate + + table = tabulate(rows, headers="keys", tablefmt="simple") + example = "\nExample: clarifai model deploy ./my-model --instance a10g" + return header + table + example + + +def _get_hf_model_info(repo_id): + """Fetch model metadata from HuggingFace API. + + Returns dict with: num_params, quant_method, quant_bits, dtype_breakdown, pipeline_tag. + Returns None on failure. + """ + try: + url = f"https://huggingface.co/api/models/{repo_id}" + resp = requests.get(url, timeout=5) + resp.raise_for_status() + data = resp.json() + except Exception: + logger.debug(f"Failed to fetch HF model info for {repo_id}") + return None + + result = { + "num_params": None, + "quant_method": None, + "quant_bits": None, + "dtype_breakdown": None, + "pipeline_tag": data.get("pipeline_tag"), + } + + # Extract parameter count and dtype breakdown from safetensors metadata + safetensors = data.get("safetensors") + if safetensors: + result["num_params"] = safetensors.get("total") + params_by_dtype = safetensors.get("parameters") + if params_by_dtype: + result["dtype_breakdown"] = dict(params_by_dtype) + + # Extract quantization config + config = data.get("config") or {} + quant_config = config.get("quantization_config") or {} + if quant_config: + result["quant_method"] = quant_config.get("quant_method") + result["quant_bits"] = quant_config.get("bits") + + return result + + +def _detect_quant_from_repo_name(repo_id): + """Detect quantization from repo name. Returns (quant_method, bits) or (None, None).""" + name = repo_id.lower() + patterns = [ + ("-awq", "awq", 4), + ("-gptq", "gptq", 4), + ("-bnb-4bit", "bnb", 4), + ("-int8", None, 8), + ("-int4", None, 4), + ("-4bit", None, 4), + ("-fp16", "fp16", 16), + ] + for suffix, method, bits in patterns: + if suffix in name: + return (method, bits) + return (None, None) + + +def _get_hf_token(config=None): + """Get HuggingFace token from config, environment, or cached token file. + + Checks in order: + 1. config['checkpoints']['hf_token'] + 2. HF_TOKEN environment variable + 3. ~/.cache/huggingface/token (standard HF CLI cache) + + Returns token string or None. + """ + # From config + if config: + token = (config.get('checkpoints') or {}).get('hf_token') + if token: + return token + + # From environment + token = os.environ.get('HF_TOKEN') + if token: + return token + + # From HF CLI cache + token_path = os.path.expanduser('~/.cache/huggingface/token') + try: + with open(token_path) as f: + token = f.read().strip() + if token: + return token + except (OSError, IOError): + pass + + return None + + +def _get_hf_model_config(repo_id, hf_token=None): + """Fetch model config.json from HuggingFace for KV cache calculation. + + Extracts architecture details needed for accurate KV cache sizing: + - num_hidden_layers + - num_key_value_heads (for GQA/MQA models) + - head_dim + - max_position_embeddings (context window) + + Returns dict with these keys, or None if config unavailable or missing required fields. + """ + try: + url = f"https://huggingface.co/{repo_id}/raw/main/config.json" + headers = {} + if hf_token: + headers['Authorization'] = f'Bearer {hf_token}' + resp = requests.get(url, headers=headers, timeout=5) + resp.raise_for_status() + data = resp.json() + except Exception: + logger.debug(f"Failed to fetch HF config.json for {repo_id}") + return None + + # Extract num_hidden_layers (required) + num_layers = ( + data.get('num_hidden_layers') + or data.get('n_layer') + or data.get('n_layers') + or data.get('num_layers') + ) + if not num_layers: + return None + + # Extract num_attention_heads (needed for head_dim fallback and MHA) + num_attention_heads = data.get('num_attention_heads') or data.get('n_head') + + # Extract num_key_value_heads (for GQA/MQA; falls back to num_attention_heads for MHA) + num_kv_heads = data.get('num_key_value_heads') + if num_kv_heads is None: + num_kv_heads = num_attention_heads + if not num_kv_heads: + return None + + # Extract head_dim (explicit field or computed from hidden_size / num_attention_heads) + head_dim = data.get('head_dim') + if not head_dim: + hidden_size = data.get('hidden_size') or data.get('n_embd') + if hidden_size and num_attention_heads: + head_dim = hidden_size // num_attention_heads + else: + return None + + # Extract max_position_embeddings (context window - required for KV cache sizing) + max_seq_len = ( + data.get('max_position_embeddings') + or data.get('max_seq_len') + or data.get('seq_length') + or data.get('n_positions') + ) + if not max_seq_len: + return None + + return { + 'num_hidden_layers': int(num_layers), + 'num_key_value_heads': int(num_kv_heads), + 'head_dim': int(head_dim), + 'max_position_embeddings': int(max_seq_len), + } + + +def _estimate_kv_cache_bytes(model_config, dtype_bytes=2): + """Estimate KV cache memory for full context window. + + Formula: 2 (K+V) × num_layers × num_kv_heads × head_dim × dtype_bytes × max_seq_len + + Args: + model_config: Dict from _get_hf_model_config() with architecture details. + dtype_bytes: Bytes per element (default 2 for FP16/BF16). + + Returns: + int: KV cache size in bytes. + """ + return ( + 2 + * model_config['num_hidden_layers'] + * model_config['num_key_value_heads'] + * model_config['head_dim'] + * dtype_bytes + * model_config['max_position_embeddings'] + ) + + +# Bytes per parameter by dtype/quantization +_BYTES_PER_PARAM = { + "BF16": 2.0, + "F16": 2.0, + "FP16": 2.0, + "F32": 4.0, + "FP32": 4.0, + "I8": 1.0, + "I32": 4.0, + "U8": 1.0, +} + +# Fixed framework overhead (2 GiB) for CUDA context, PyTorch runtime, etc. +_FRAMEWORK_OVERHEAD_FIXED = 2 * 1024**3 +# Variable overhead as fraction of weight bytes (activations, internal buffers) +_FRAMEWORK_OVERHEAD_FRACTION = 0.10 +# Fallback KV cache overhead as fraction of model weights (used when config.json unavailable) +_KV_CACHE_FRACTION = 0.50 + + +def _compute_overhead(weight_bytes): + """Compute framework overhead: 2 GiB fixed + 10% of weight bytes.""" + return int(_FRAMEWORK_OVERHEAD_FIXED + weight_bytes * _FRAMEWORK_OVERHEAD_FRACTION) + + +def _estimate_weight_bytes(num_params, quant_method=None, quant_bits=None, dtype_breakdown=None): + """Estimate model weight bytes. Returns int.""" + if dtype_breakdown: + weight_bytes = 0 + for dtype, count in dtype_breakdown.items(): + bpp = _BYTES_PER_PARAM.get(dtype.upper(), 2.0) + weight_bytes += count * bpp + elif quant_method in ("awq", "gptq"): + bpp = 0.5 if (quant_bits is None or quant_bits == 4) else 1.0 + weight_bytes = num_params * bpp + elif quant_bits: + bpp = quant_bits / 8.0 + weight_bytes = num_params * bpp + else: + # Default: BF16 + weight_bytes = num_params * 2.0 + return int(weight_bytes) + + +def _estimate_vram_bytes(num_params, quant_method=None, quant_bits=None, dtype_breakdown=None): + """Estimate VRAM bytes needed for inference (heuristic fallback). Returns int.""" + weight_bytes = _estimate_weight_bytes(num_params, quant_method, quant_bits, dtype_breakdown) + # Total: weights + KV cache overhead + framework overhead + return int( + weight_bytes + (weight_bytes * _KV_CACHE_FRACTION) + _compute_overhead(weight_bytes) + ) + + +# Pre-Ampere GPU indicators (compute capability < 8.0). +# SGLang requires Ampere+ for CUDA graph capture (RMSNorm kernels). +# Checked case-insensitively against instance IDs across all clouds: +# AWS: "g4dn.xlarge", "p3.2xlarge" +# Azure: "Standard_NC4as_T4_v3" +# GCP: "n1-standard-4-nvidia-tesla-t4" +_PRE_AMPERE_INDICATORS = ("t4", "v100", "k80", "p100", "p40", "m60", "g4dn", "g4ad", "p3.", "p2.") + + +# Minimum GPU utilization headroom: require at least 10% free VRAM. +# vLLM/SGLang default to gpu_memory_utilization=0.9, and the remaining 10% covers +# CUDA block allocator overhead, page tables, and memory fragmentation. +# Without this, a 15.1 GiB model on a 16 GiB GPU leaves only ~0.9 GiB headroom +# which gets eaten by vLLM internals, causing OOM. +_GPU_UTILIZATION_FACTOR = 0.90 + +# Cloud providers supported for auto-recommendation. +# Other providers (e.g. CoreWeave) may have instance types with non-standard VRAM +# configurations that don't map well to our estimation heuristics. +_SUPPORTED_CLOUDS = {"aws", "gcp", "vultr"} + + +def _select_instance_by_vram( + vram_bytes, pat=None, base_url=None, exclude_pre_ampere=False, reason_detail="" +): + """Select smallest instance whose usable VRAM >= vram_bytes. + + Applies a 10% headroom factor (matching vLLM/SGLang gpu_memory_utilization=0.9) + so the selected GPU isn't filled to the brim. + + Args: + vram_bytes: Minimum required VRAM in bytes. + pat: Clarifai PAT for API lookups. + base_url: Clarifai API base URL. + exclude_pre_ampere: If True, skip pre-Ampere instances (T4, V100, etc.). + Required by SGLang which needs compute capability >= 8.0. + reason_detail: Optional detail string for the reason (e.g. weight/KV breakdown). + + Returns (instance_type_id, reason) or (None, reason). + """ + vram_gib = vram_bytes / (1024**3) + estimate_prefix = f"Estimated {vram_gib:.1f} GiB VRAM" + if reason_detail: + estimate_prefix += f" ({reason_detail})" + + def _is_excluded(inst_id): + if not exclude_pre_ampere: + return False + inst_lower = inst_id.lower() + return any(indicator in inst_lower for indicator in _PRE_AMPERE_INDICATORS) + + # Try API first for real available instances + instance_types = _try_list_all_instance_types(pat=pat, base_url=base_url) + if instance_types: + # Build list of (instance_id, vram_bytes) for GPU instances, sorted by VRAM ascending + # Only include instances from supported clouds (aws, gcp, vultr) + gpu_instances = [] + for it in instance_types: + cloud = it.cloud_provider.id if it.cloud_provider else "" + if cloud not in _SUPPORTED_CLOUDS: + continue + ci = it.compute_info if it.compute_info else None + if not ci or not ci.num_accelerators or ci.num_accelerators == 0: + continue + acc_mem = ci.accelerator_memory + if not acc_mem: + continue + mem_bytes = parse_k8s_quantity(acc_mem) + if mem_bytes > 0: + gpu_instances.append((it.id, mem_bytes)) + + # Deduplicate by instance ID, keeping largest VRAM for each + seen = {} + for inst_id, mem in gpu_instances: + if inst_id not in seen or mem > seen[inst_id]: + seen[inst_id] = mem + sorted_instances = sorted(seen.items(), key=lambda x: x[1]) + + for inst_id, mem in sorted_instances: + if _is_excluded(inst_id): + continue + usable = mem * _GPU_UTILIZATION_FACTOR + if usable >= vram_bytes: + mem_gib = mem / (1024**3) + return ( + inst_id, + f"{estimate_prefix}, fits {inst_id} ({mem_gib:.0f} GiB)", + ) + + if sorted_instances: + max_gib = sorted_instances[-1][1] / (1024**3) + return ( + None, + f"{estimate_prefix}, exceeds max available {max_gib:.0f} GiB", + ) + + # Fallback to hardcoded GPU tiers + fallback_tiers = [ + ("gpu-nvidia-a10g", 24 * 1024**3), # A10G: 24 GiB + ("gpu-nvidia-l40s", 48 * 1024**3), # L40S: 48 GiB + ("gpu-nvidia-g6e-2x-large", 96 * 1024**3), # G6E 2x: 96 GiB + ] + for inst_id, mem in fallback_tiers: + if _is_excluded(inst_id): + continue + usable = mem * _GPU_UTILIZATION_FACTOR + if usable >= vram_bytes: + mem_gib = mem / (1024**3) + return ( + inst_id, + f"{estimate_prefix}, fits {inst_id} ({mem_gib:.0f} GiB)", + ) + + return (None, f"{estimate_prefix}, exceeds max 96 GiB") + + +def _detect_toolkit_from_config(config, model_path=None): + """Detect inference toolkit from build_info.image or requirements.txt. + + Checks build_info.image first (e.g. "lmsysorg/sglang:latest"), then + falls back to scanning requirements.txt for known toolkit packages. + + Returns toolkit name ('vllm', 'sglang') or empty string. + """ + # Check build_info.image + build_image = (config.get('build_info', {}).get('image') or '').lower() + if 'sglang' in build_image: + return 'sglang' + elif 'vllm' in build_image: + return 'vllm' + + # Check requirements.txt + if model_path: + try: + from clarifai.utils.cli import parse_requirements + + deps = parse_requirements(model_path) + for name in ('vllm', 'sglang'): + if name in deps: + return name + except Exception: + pass + + return '' + + +def recommend_instance(config, pat=None, base_url=None, toolkit=None, model_path=None): + """Recommend instance type based on model config. + + Args: + config: Parsed config.yaml dict. + pat: Clarifai PAT for API lookups. + base_url: Clarifai API base URL. + toolkit: Explicit toolkit name (e.g. 'vllm', 'sglang'). If not provided, + detected from build_info.image or requirements.txt. + model_path: Path to model directory (for requirements.txt-based toolkit detection). + + Returns (instance_type_id, reason) or (None, reason). + """ + model_config = config.get('model', {}) + model_type_id = model_config.get('model_type_id', '') + + # MCP models run on CPU + if model_type_id in ("mcp", "mcp-stdio"): + return ("t3a.2xlarge", "MCP models run on CPU") + + checkpoints = config.get('checkpoints', {}) + repo_id = checkpoints.get('repo_id') if checkpoints else None + + if not toolkit: + toolkit = _detect_toolkit_from_config(config, model_path=model_path) + + if not repo_id: + # Check if this is a GPU toolkit (vllm/sglang) that needs a repo_id + if toolkit in ('vllm', 'sglang'): + return (None, "Cannot estimate without checkpoints.repo_id") + # No checkpoints, no GPU toolkit → default to CPU + return ("t3a.2xlarge", "No model checkpoints, defaulting to CPU") + + # SGLang requires Ampere+ GPUs (compute capability >= 8.0). + # Skip pre-Ampere instances like T4 (g4dn) which fail with CUDA graph errors. + exclude_pre_ampere = toolkit == 'sglang' + + # For vLLM/SGLang, try to get HF token for gated model access + hf_token = _get_hf_token(config) if toolkit in ('vllm', 'sglang') else None + + # Try HF metadata API for parameter count + quantization + hf_info = _get_hf_model_info(repo_id) + num_params = hf_info.get("num_params") if hf_info else None + + if num_params: + quant_method = hf_info.get("quant_method") + quant_bits = hf_info.get("quant_bits") + dtype_breakdown = hf_info.get("dtype_breakdown") + + # Also check repo name for quantization hints if API didn't report any + if not quant_method: + name_method, name_bits = _detect_quant_from_repo_name(repo_id) + if name_method: + quant_method = name_method + quant_bits = name_bits + + # For vLLM/SGLang: try accurate KV cache estimation from config.json + if toolkit in ('vllm', 'sglang'): + hf_config = _get_hf_model_config(repo_id, hf_token=hf_token) + if hf_config: + weight_bytes = _estimate_weight_bytes( + num_params, quant_method, quant_bits, dtype_breakdown + ) + kv_bytes = _estimate_kv_cache_bytes(hf_config) + vram = int(weight_bytes + kv_bytes + _compute_overhead(weight_bytes)) + weight_gib = weight_bytes / (1024**3) + kv_gib = kv_bytes / (1024**3) + ctx_len = hf_config['max_position_embeddings'] + reason_detail = ( + f"{weight_gib:.1f} GiB weights + {kv_gib:.1f} GiB KV cache for {ctx_len} ctx" + ) + return _select_instance_by_vram( + vram, + pat=pat, + base_url=base_url, + exclude_pre_ampere=exclude_pre_ampere, + reason_detail=reason_detail, + ) + + # Fallback: heuristic KV cache (fraction of weights) + vram = _estimate_vram_bytes(num_params, quant_method, quant_bits, dtype_breakdown) + return _select_instance_by_vram( + vram, pat=pat, base_url=base_url, exclude_pre_ampere=exclude_pre_ampere + ) + + # Fallback: file-size-based estimate via HuggingFaceLoader + try: + from clarifai.runners.utils.loader import HuggingFaceLoader + + file_size = HuggingFaceLoader.get_huggingface_checkpoint_total_size(repo_id) + if file_size and file_size > 0: + # For vLLM/SGLang: try accurate KV cache with file-size weights + if toolkit in ('vllm', 'sglang'): + hf_config = _get_hf_model_config(repo_id, hf_token=hf_token) + if hf_config: + kv_bytes = _estimate_kv_cache_bytes(hf_config) + vram = int(file_size + kv_bytes + _compute_overhead(file_size)) + file_gib = file_size / (1024**3) + kv_gib = kv_bytes / (1024**3) + ctx_len = hf_config['max_position_embeddings'] + reason_detail = ( + f"{file_gib:.1f} GiB weights + {kv_gib:.1f} GiB KV cache for {ctx_len} ctx" + ) + return _select_instance_by_vram( + vram, + pat=pat, + base_url=base_url, + exclude_pre_ampere=exclude_pre_ampere, + reason_detail=reason_detail, + ) + # Heuristic: file size + 30% overhead for runtime buffers + KV cache + vram = int(file_size * 1.3) + _compute_overhead(file_size) + return _select_instance_by_vram( + vram, pat=pat, base_url=base_url, exclude_pre_ampere=exclude_pre_ampere + ) + except Exception: + logger.debug(f"Failed to get checkpoint size for {repo_id}") + + return (None, "Could not determine model size for " + repo_id) + + +def get_compute_cluster_config(user_id, cloud_provider="aws", region="us-east-1"): + """Get auto-create config for a compute cluster. + + Args: + user_id: The user ID for the compute cluster. + cloud_provider: Cloud provider ID (e.g. 'aws', 'gcp', 'vultr'). + region: Region ID (e.g. 'us-east-1', 'us-central1'). + + Returns: + dict: Compute cluster config suitable for User.create_compute_cluster(). + """ + return { + "compute_cluster": { + "id": get_deploy_compute_cluster_id(cloud_provider, region), + "description": f"Auto-created compute cluster for {cloud_provider}/{region}", + "cloud_provider": {"id": cloud_provider}, + "region": region, + "managed_by": "clarifai", + "cluster_type": "dedicated", + } + } + + +def get_nodepool_config(instance_type_id, compute_cluster_id, user_id, compute_info=None): + """Build nodepool config from instance type info. + + Args: + instance_type_id: The instance type ID (e.g. 'gpu-nvidia-a10g'). + compute_cluster_id: The compute cluster ID. + user_id: The user ID that owns the compute cluster. + compute_info: Optional dict of compute info. If None, minimal config is used. + + Returns: + dict: Nodepool config suitable for ComputeCluster.create_nodepool(). + """ + instance_type = {"id": instance_type_id} + if compute_info: + instance_type["compute_info"] = compute_info + + return { + "nodepool": { + "id": get_deploy_nodepool_id(instance_type_id), + "description": f"Auto-created nodepool for {instance_type_id}", + "compute_cluster": { + "id": compute_cluster_id, + "user_id": user_id, + }, + "instance_types": [instance_type], + "node_capacity_type": { + "capacity_types": [1], + }, + "min_instances": 0, + "max_instances": 5, + } + } diff --git a/clarifai/utils/config.py b/clarifai/utils/config.py index fa25630a..ad290c72 100644 --- a/clarifai/utils/config.py +++ b/clarifai/utils/config.py @@ -211,3 +211,47 @@ def current(self) -> Context: ) return Context("_empty_") return self.contexts[context_name] + + +def resolve_user_id(pat=None, base_url=None): + """Resolve user_id from CLI config or API. + + Resolution order: + 1. CLI config file (~/.config/clarifai/config) current context's CLARIFAI_USER_ID + 2. API call using PAT: GET /v2/users/me + + Args: + pat: Optional PAT for API auth (used for API fallback). + base_url: Optional API base URL (used for API fallback). + + Returns: + str or None: The resolved user_id, or None if resolution fails. + """ + # 1. Try CLI config file + try: + config = Config.from_yaml() + user_id = config.current.get('user_id') + if user_id and user_id != '_empty_': + logger.debug(f"Resolved user_id from CLI config: {user_id}") + return user_id + except Exception: + pass + + # 2. Try API call using PAT + try: + from clarifai.client.user import User + + kwargs = {} + if pat: + kwargs['pat'] = pat + if base_url: + kwargs['base_url'] = base_url + user = User(**kwargs) + user_id = user.get_user_info(user_id='me').user.id + if user_id: + logger.debug(f"Resolved user_id from API: {user_id}") + return user_id + except Exception as e: + logger.debug(f"Failed to resolve user_id from API: {e}") + + return None diff --git a/clarifai/utils/constants.py b/clarifai/utils/constants.py index 6d28c957..cc25d37a 100644 --- a/clarifai/utils/constants.py +++ b/clarifai/utils/constants.py @@ -61,14 +61,6 @@ "max_instances": 1, } } -DEFAULT_TOOLKIT_MODEL_REPO = "https://github.com/Clarifai/runners-examples" -DEFAULT_OLLAMA_MODEL_REPO_BRANCH = "ollama" -DEFAULT_HF_MODEL_REPO_BRANCH = "huggingface" -DEFAULT_LMSTUDIO_MODEL_REPO_BRANCH = "lmstudio" -DEFAULT_VLLM_MODEL_REPO_BRANCH = "vllm" -DEFAULT_SGLANG_MODEL_REPO_BRANCH = "sglang" -DEFAULT_PYTHON_MODEL_REPO_BRANCH = "python" - STATUS_OK = "200 OK" STATUS_MIXED = "207 MIXED" STATUS_FAIL = "500 FAIL" diff --git a/clarifai/utils/logging.py b/clarifai/utils/logging.py index d1beb658..63c028ad 100644 --- a/clarifai/utils/logging.py +++ b/clarifai/utils/logging.py @@ -394,7 +394,9 @@ def format(self, record): logr.update( { JSON_LOG_KEY: msg, - '@timestamp': datetime.datetime.utcnow().strftime('%Y-%m-%dT%H:%M:%S.%fZ'), + '@timestamp': datetime.datetime.now(datetime.timezone.utc).strftime( + '%Y-%m-%dT%H:%M:%S.%fZ' + ), } ) @@ -447,7 +449,11 @@ def format(self, record): def formatTime(self, record, datefmt=None): # Note we didn't go with UTC here as it's easier to understand time in your time zone. # The json logger leverages UTC though. - return datetime.datetime.fromtimestamp(record.created).strftime('%H:%M:%S.%f') + try: + return datetime.datetime.fromtimestamp(record.created).strftime('%H:%M:%S.%f') + except Exception: + # During Python shutdown, datetime may already be torn down. + return "" # the default logger for the SDK. diff --git a/requirements.txt b/requirements.txt index 839dd4bf..ca1fd3e7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,3 +19,5 @@ packaging>=25.0 tenacity>=8.2.3 httpx>=0.27.0 openai>=1.0.0 +huggingface_hub>=0.16.4 +hf-transfer>=0.1.9 diff --git a/setup.py b/setup.py index 765bffbd..7e180d72 100644 --- a/setup.py +++ b/setup.py @@ -50,7 +50,7 @@ }, entry_points={ "console_scripts": [ - "clarifai = clarifai.cli.base:cli", + "clarifai = clarifai.cli.base:main", ], }, include_package_data=True, diff --git a/tests/cli/__init__.py b/tests/cli/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/cli/test_artifact_cli.py b/tests/cli/test_artifact_cli.py index aafb7958..d035e74d 100644 --- a/tests/cli/test_artifact_cli.py +++ b/tests/cli/test_artifact_cli.py @@ -452,7 +452,7 @@ def test_artifact_alias_commands(self): # Test that 'af' alias works result = self.runner.invoke(artifact, ['--help']) assert result.exit_code == 0 - assert "Manage Artifacts" in result.output + assert "Manage artifacts" in result.output # Test that 'ls' alias works for list command with patch('clarifai.cli.artifact.validate_context'): diff --git a/tests/cli/test_compute_orchestration.py b/tests/cli/test_compute_orchestration.py index 40b06f88..d5ad9ee4 100644 --- a/tests/cli/test_compute_orchestration.py +++ b/tests/cli/test_compute_orchestration.py @@ -189,9 +189,7 @@ def test_list_deployments(self, cli_runner): @pytest.mark.coverage_only def test_delete_deployment(self, cli_runner): cli_runner.invoke(cli, ["login", "--env", CLARIFAI_ENV]) - result = cli_runner.invoke( - cli, ["deployment", "delete", CREATE_NODEPOOL_ID, CREATE_DEPLOYMENT_ID] - ) + result = cli_runner.invoke(cli, ["deployment", "delete", CREATE_DEPLOYMENT_ID]) assert result.exit_code == 0, logger.exception(result) def test_delete_nodepool(self, cli_runner): diff --git a/tests/cli/test_local_runner_cli.py b/tests/cli/test_local_runner_cli.py index dc482105..72f76cbb 100644 --- a/tests/cli/test_local_runner_cli.py +++ b/tests/cli/test_local_runner_cli.py @@ -1,12 +1,12 @@ -"""Tests for the local-runner CLI command. +"""Tests for the serve CLI command. -These tests verify the basic functionality of the `clarifai model local-runner` command +These tests verify the basic functionality of the `clarifai model serve` command by mocking external dependencies and testing key behaviors. """ import os from pathlib import Path -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, Mock, patch import pytest import yaml @@ -15,8 +15,65 @@ from clarifai.cli.base import cli +def _make_mock_context(): + """Create a mock click context with valid credentials.""" + mock_ctx = Mock() + mock_ctx.obj = Mock() + mock_ctx.obj.current = Mock() + mock_ctx.obj.current.user_id = "test-user" + mock_ctx.obj.current.pat = "test-pat" + mock_ctx.obj.current.api_base = "https://api.clarifai.com" + mock_ctx.obj.current.name = "default" + mock_ctx.obj.to_yaml = Mock() + return mock_ctx + + +def _make_mock_user_with_existing_resources(): + """Create a fully mocked User with existing compute cluster, nodepool, app, model, etc.""" + mock_user = MagicMock() + + mock_compute_cluster = MagicMock() + mock_compute_cluster.id = "local-dev-cluster" + mock_compute_cluster.cluster_type = "local-dev" + mock_user.compute_cluster.return_value = mock_compute_cluster + + mock_nodepool = MagicMock() + mock_nodepool.id = "local-dev-nodepool" + mock_compute_cluster.nodepool.return_value = mock_nodepool + + mock_app = MagicMock() + mock_app.id = "local-dev-app" + mock_user.app.return_value = mock_app + + mock_model = MagicMock() + mock_model.id = "dummy-runner-model" + mock_model.model_type_id = "multimodal-to-text" + mock_app.model.return_value = mock_model + + # Model version (created fresh every serve) + mock_version_response = MagicMock() + mock_version_response.model_version.id = "version-123" + mock_version_response.load_info = MagicMock() + mock_model.create_version.return_value = mock_version_response + + # Runner + mock_runner = MagicMock() + mock_runner.id = "runner-123" + mock_nodepool.create_runner.return_value = mock_runner + + # Deployment + mock_deployment = MagicMock() + mock_deployment.id = "deployment-123" + mock_nodepool.create_deployment.return_value = mock_deployment + + # Stale deployment cleanup — no existing deployment to clean up + mock_nodepool.deployment.side_effect = Exception("Not found") + + return mock_user + + class TestLocalRunnerCLI: - """Test cases for the local-runner CLI command.""" + """Test cases for the serve CLI command.""" @pytest.fixture def dummy_model_dir(self): @@ -36,10 +93,8 @@ def dummy_model_dir(self): @patch("clarifai.cli.model.check_requirements_installed") @patch("clarifai.cli.model.parse_requirements") @patch("clarifai.cli.model.validate_context") - @patch("builtins.input") def test_local_runner_requires_installed_dependencies( self, - mock_input, mock_validate_context, mock_parse_requirements, mock_check_requirements, @@ -48,14 +103,15 @@ def test_local_runner_requires_installed_dependencies( mock_server_class, dummy_model_dir, ): - """Test that local-runner checks for installed requirements.""" - # Setup: Requirements not installed + """Test that serve checks for installed requirements (mode=none).""" mock_check_requirements.return_value = False mock_parse_requirements.return_value = [] - # Mock ModelBuilder to return a basic config mock_builder = MagicMock() - mock_builder.config = {"model": {"model_type_id": "multimodal-to-text"}, "toolkit": {}} + mock_builder.config = { + "model": {"id": "dummy-runner-model", "model_type_id": "multimodal-to-text"}, + "toolkit": {}, + } mock_builder_class.return_value = mock_builder runner = CliRunner() @@ -63,10 +119,9 @@ def test_local_runner_requires_installed_dependencies( result = runner.invoke( cli, - ["model", "local-runner", str(dummy_model_dir)], + ["model", "serve", str(dummy_model_dir)], ) - # Should abort because requirements are not installed assert result.exit_code == 1 mock_check_requirements.assert_called() @@ -76,58 +131,8 @@ def test_local_runner_requires_installed_dependencies( @patch("clarifai.cli.model.check_requirements_installed") @patch("clarifai.cli.model.parse_requirements") @patch("clarifai.cli.model.validate_context") - @patch("builtins.input") - def test_local_runner_user_declines_resource_creation( - self, - mock_input, - mock_validate_context, - mock_parse_requirements, - mock_check_requirements, - mock_builder_class, - mock_user_class, - mock_server_class, - dummy_model_dir, - ): - """Test that local-runner aborts when user declines resource creation.""" - # Setup: user declines resource creation - mock_input.return_value = "n" # User says no - mock_check_requirements.return_value = True - mock_parse_requirements.return_value = [] - - # Mock ModelBuilder - mock_builder = MagicMock() - mock_builder.config = {"model": {"model_type_id": "multimodal-to-text"}, "toolkit": {}} - mock_builder_class.return_value = mock_builder - mock_builder_class._load_config.return_value = mock_builder.config - - # Mock User that throws exception for missing compute cluster - mock_user = MagicMock() - mock_user.compute_cluster.side_effect = Exception("Cluster not found") - mock_user_class.return_value = mock_user - - runner = CliRunner() - runner.invoke(cli, ["login", "--user_id", "test-user", "--pat", "test-pat"]) - - result = runner.invoke( - cli, - ["model", "local-runner", str(dummy_model_dir)], - ) - - # Should abort when user declines - assert result.exit_code == 1 - # Verify that create_compute_cluster was NOT called - mock_user.create_compute_cluster.assert_not_called() - - @patch("clarifai.runners.server.ModelServer") - @patch("clarifai.client.user.User") - @patch("clarifai.runners.models.model_builder.ModelBuilder") - @patch("clarifai.cli.model.check_requirements_installed") - @patch("clarifai.cli.model.parse_requirements") - @patch("clarifai.cli.model.validate_context") - @patch("builtins.input") - def test_local_runner_creates_resources_when_missing( + def test_local_runner_auto_creates_resources( self, - mock_input, mock_validate_context, mock_parse_requirements, mock_check_requirements, @@ -136,32 +141,37 @@ def test_local_runner_creates_resources_when_missing( mock_server_class, dummy_model_dir, ): - """Test that local-runner creates missing resources when user accepts.""" - # Setup: user accepts resource creation - mock_input.return_value = "y" # User says yes + """Test that serve auto-creates missing resources without prompting.""" mock_check_requirements.return_value = True mock_parse_requirements.return_value = [] - # Mock ModelBuilder mock_builder = MagicMock() - mock_builder.config = {"model": {"model_type_id": "multimodal-to-text"}, "toolkit": {}} + mock_builder.config = { + "model": {"id": "dummy-runner-model", "model_type_id": "multimodal-to-text"}, + "toolkit": {}, + } mock_method_sig = MagicMock() mock_method_sig.name = "predict" mock_builder.get_method_signatures.return_value = [mock_method_sig] mock_builder_class.return_value = mock_builder - mock_builder_class._load_config.return_value = mock_builder.config - # Mock User and resources mock_user = MagicMock() - # Compute cluster doesn't exist - mock_user.compute_cluster.side_effect = Exception("Cluster not found") + # Compute cluster doesn't exist first time, then succeeds after creation mock_compute_cluster = MagicMock() mock_compute_cluster.id = "local-dev-cluster" - mock_compute_cluster.cluster_type = "local-dev" + call_count = {"compute_cluster": 0} + + def compute_cluster_side_effect(cc_id): + call_count["compute_cluster"] += 1 + if call_count["compute_cluster"] == 1: + raise Exception("Cluster not found") + return mock_compute_cluster + + mock_user.compute_cluster.side_effect = compute_cluster_side_effect mock_user.create_compute_cluster.return_value = mock_compute_cluster - # Nodepool doesn't exist + # Nodepool doesn't exist first time, then succeeds after creation mock_nodepool = MagicMock() mock_nodepool.id = "local-dev-nodepool" mock_compute_cluster.nodepool.side_effect = Exception("Nodepool not found") @@ -175,17 +185,15 @@ def test_local_runner_creates_resources_when_missing( # Model doesn't exist mock_model = MagicMock() - mock_model.id = "local-dev-model" + mock_model.id = "dummy-runner-model" mock_model.model_type_id = "multimodal-to-text" mock_app.model.side_effect = Exception("Model not found") mock_app.create_model.return_value = mock_model - # Model version - mock_version = MagicMock() - mock_version.id = "version-123" - mock_model.list_versions.return_value = [] + # Version (always created) mock_version_response = MagicMock() - mock_version_response.model_version = mock_version + mock_version_response.model_version.id = "version-123" + mock_version_response.load_info = MagicMock() mock_model.create_version.return_value = mock_version_response # Runner @@ -193,29 +201,18 @@ def test_local_runner_creates_resources_when_missing( mock_runner.id = "runner-123" mock_nodepool.create_runner.return_value = mock_runner - # Deployment doesn't exist + # Stale deployment cleanup + new deployment creation + mock_nodepool.deployment.side_effect = Exception("Not found") mock_deployment = MagicMock() mock_deployment.id = "deployment-123" - mock_nodepool.deployment.side_effect = Exception("Deployment not found") mock_nodepool.create_deployment.return_value = mock_deployment mock_user_class.return_value = mock_user - # Mock ModelServer mock_server = MagicMock() mock_server_class.return_value = mock_server - # Mock a proper context - from unittest.mock import Mock - - mock_ctx = Mock() - mock_ctx.obj = Mock() - mock_ctx.obj.current = Mock() - mock_ctx.obj.current.user_id = "test-user" - mock_ctx.obj.current.pat = "test-pat" - mock_ctx.obj.current.api_base = "https://api.clarifai.com" - mock_ctx.obj.current.name = "default" - mock_ctx.obj.to_yaml = Mock() + mock_ctx = _make_mock_context() def validate_ctx_mock(ctx): ctx.obj = mock_ctx.obj @@ -225,20 +222,15 @@ def validate_ctx_mock(ctx): runner = CliRunner() result = runner.invoke( cli, - ["model", "local-runner", str(dummy_model_dir)], + ["model", "serve", str(dummy_model_dir)], catch_exceptions=False, ) - # Should succeed after creating resources assert result.exit_code == 0, f"Command failed with: {result.output}" - # Verify resources were created mock_user.create_compute_cluster.assert_called_once() - mock_compute_cluster.create_nodepool.assert_called_once() mock_user.create_app.assert_called_once() mock_app.create_model.assert_called_once() mock_model.create_version.assert_called_once() - # TODO: Create runner is failing in CI, so commenting out for now - # mock_nodepool.create_runner.assert_called_once() mock_nodepool.create_deployment.assert_called_once() @patch("clarifai.runners.server.ModelServer") @@ -247,10 +239,8 @@ def validate_ctx_mock(ctx): @patch("clarifai.cli.model.check_requirements_installed") @patch("clarifai.cli.model.parse_requirements") @patch("clarifai.cli.model.validate_context") - @patch("builtins.input") def test_local_runner_uses_existing_resources( self, - mock_input, mock_validate_context, mock_parse_requirements, mock_check_requirements, @@ -259,98 +249,27 @@ def test_local_runner_uses_existing_resources( mock_server_class, dummy_model_dir, ): - """Test that local-runner uses existing resources without creating new ones.""" - # Setup - mock_input.return_value = "y" + """Test that serve reuses existing resources but always creates a fresh version.""" mock_check_requirements.return_value = True mock_parse_requirements.return_value = [] - # Mock ModelBuilder mock_builder = MagicMock() - mock_builder.config = {"model": {"model_type_id": "multimodal-to-text"}, "toolkit": {}} + mock_builder.config = { + "model": {"id": "dummy-runner-model", "model_type_id": "multimodal-to-text"}, + "toolkit": {}, + } mock_method_sig = MagicMock() mock_method_sig.name = "predict" mock_builder.get_method_signatures.return_value = [mock_method_sig] mock_builder_class.return_value = mock_builder - mock_builder_class._load_config.return_value = mock_builder.config - - # Mock User with all existing resources - mock_user = MagicMock() - - # Existing compute cluster - mock_compute_cluster = MagicMock() - mock_compute_cluster.id = "local-dev-cluster" - mock_compute_cluster.cluster_type = "local-dev" - mock_user.compute_cluster.return_value = mock_compute_cluster - - # Existing nodepool - mock_nodepool = MagicMock() - mock_nodepool.id = "local-dev-nodepool" - mock_compute_cluster.nodepool.return_value = mock_nodepool - - # Existing app - mock_app = MagicMock() - mock_app.id = "local-dev-app" - mock_user.app.return_value = mock_app - - # Existing model - mock_model = MagicMock() - mock_model.id = "local-dev-model" - mock_model.model_type_id = "multimodal-to-text" - mock_app.model.return_value = mock_model - - # Existing model version - mock_version = MagicMock() - mock_version.id = "version-123" - mock_model_version_obj = MagicMock() - mock_model_version_obj.model_version = mock_version - mock_model.list_versions.return_value = [mock_model_version_obj] - mock_patched_model = MagicMock() - mock_patched_model.model_version = mock_version - mock_patched_model.load_info = MagicMock() - mock_model.patch_version.return_value = mock_patched_model - - # Existing runner - mock_runner = MagicMock() - mock_runner.id = "runner-123" - mock_runner.worker = MagicMock() - mock_runner.worker.model = MagicMock() - mock_runner.worker.model.model_version = MagicMock() - mock_runner.worker.model.model_version.id = "version-123" - mock_nodepool.runner.return_value = mock_runner - - # Existing deployment - mock_deployment = MagicMock() - mock_deployment.id = "deployment-123" - mock_deployment.worker = MagicMock() - mock_deployment.worker.model = MagicMock() - mock_deployment.worker.model.model_version = MagicMock() - mock_deployment.worker.model.model_version.id = "version-123" - mock_nodepool.deployment.return_value = mock_deployment + mock_user = _make_mock_user_with_existing_resources() mock_user_class.return_value = mock_user - # Mock ModelServer mock_server = MagicMock() mock_server_class.return_value = mock_server - # Mock a proper context - from unittest.mock import Mock - - mock_ctx = Mock() - mock_ctx.obj = Mock() - mock_ctx.obj.current = Mock() - mock_ctx.obj.current.user_id = "test-user" - mock_ctx.obj.current.pat = "test-pat" - mock_ctx.obj.current.api_base = "https://api.clarifai.com" - mock_ctx.obj.current.name = "default" - mock_ctx.obj.current.compute_cluster_id = "local-dev-cluster" - mock_ctx.obj.current.nodepool_id = "local-dev-nodepool" - mock_ctx.obj.current.app_id = "local-dev-app" - mock_ctx.obj.current.model_id = "local-dev-model" - mock_ctx.obj.current.runner_id = "runner-123" - mock_ctx.obj.current.deployment_id = "deployment-123" - mock_ctx.obj.to_yaml = Mock() + mock_ctx = _make_mock_context() def validate_ctx_mock(ctx): ctx.obj = mock_ctx.obj @@ -360,20 +279,16 @@ def validate_ctx_mock(ctx): runner = CliRunner() result = runner.invoke( cli, - ["model", "local-runner", str(dummy_model_dir)], + ["model", "serve", str(dummy_model_dir)], catch_exceptions=False, ) - # Should succeed using existing resources assert result.exit_code == 0, f"Command failed with: {result.output}" - # Verify no new resources were created + # Existing resources not re-created mock_user.create_compute_cluster.assert_not_called() - mock_compute_cluster.create_nodepool.assert_not_called() mock_user.create_app.assert_not_called() - mock_app.create_model.assert_not_called() - mock_model.create_version.assert_not_called() - mock_nodepool.create_runner.assert_not_called() - mock_nodepool.create_deployment.assert_not_called() + # But version IS always created fresh + mock_user.app().model().create_version.assert_called_once() @patch("clarifai.runners.server.ModelServer") @patch("clarifai.client.user.User") @@ -381,10 +296,8 @@ def validate_ctx_mock(ctx): @patch("clarifai.cli.model.check_requirements_installed") @patch("clarifai.cli.model.parse_requirements") @patch("clarifai.cli.model.validate_context") - @patch("builtins.input") - def test_local_runner_with_pool_size_parameter( + def test_local_runner_with_concurrency_parameter( self, - mock_input, mock_validate_context, mock_parse_requirements, mock_check_requirements, @@ -393,90 +306,27 @@ def test_local_runner_with_pool_size_parameter( mock_server_class, dummy_model_dir, ): - """Test that local-runner accepts and uses the pool_size parameter.""" - # Setup - mock_input.return_value = "y" + """Test that serve accepts and uses the --concurrency parameter.""" mock_check_requirements.return_value = True mock_parse_requirements.return_value = [] - # Mock ModelBuilder mock_builder = MagicMock() - mock_builder.config = {"model": {"model_type_id": "multimodal-to-text"}, "toolkit": {}} + mock_builder.config = { + "model": {"id": "dummy-runner-model", "model_type_id": "multimodal-to-text"}, + "toolkit": {}, + } mock_method_sig = MagicMock() mock_method_sig.name = "predict" mock_builder.get_method_signatures.return_value = [mock_method_sig] mock_builder_class.return_value = mock_builder - mock_builder_class._load_config.return_value = mock_builder.config - - # Mock User with all existing resources (simplified) - mock_user = MagicMock() - mock_compute_cluster = MagicMock() - mock_compute_cluster.id = "local-dev-cluster" - mock_compute_cluster.cluster_type = "local-dev" - mock_user.compute_cluster.return_value = mock_compute_cluster - - mock_nodepool = MagicMock() - mock_nodepool.id = "local-dev-nodepool" - mock_compute_cluster.nodepool.return_value = mock_nodepool - - mock_app = MagicMock() - mock_app.id = "local-dev-app" - mock_user.app.return_value = mock_app - - mock_model = MagicMock() - mock_model.id = "local-dev-model" - mock_model.model_type_id = "multimodal-to-text" - mock_app.model.return_value = mock_model - - mock_version = MagicMock() - mock_version.id = "version-123" - mock_model_version_obj = MagicMock() - mock_model_version_obj.model_version = mock_version - mock_model.list_versions.return_value = [mock_model_version_obj] - mock_patched_model = MagicMock() - mock_patched_model.model_version = mock_version - mock_patched_model.load_info = MagicMock() - mock_model.patch_version.return_value = mock_patched_model - - mock_runner = MagicMock() - mock_runner.id = "runner-123" - mock_runner.worker = MagicMock() - mock_runner.worker.model = MagicMock() - mock_runner.worker.model.model_version = MagicMock() - mock_runner.worker.model.model_version.id = "version-123" - mock_nodepool.runner.return_value = mock_runner - - mock_deployment = MagicMock() - mock_deployment.id = "deployment-123" - mock_deployment.worker = MagicMock() - mock_deployment.worker.model = MagicMock() - mock_deployment.worker.model.model_version = MagicMock() - mock_deployment.worker.model.model_version.id = "version-123" - mock_nodepool.deployment.return_value = mock_deployment + mock_user = _make_mock_user_with_existing_resources() mock_user_class.return_value = mock_user - # Mock ModelServer mock_server = MagicMock() mock_server_class.return_value = mock_server - # Mock a proper context - from unittest.mock import Mock - - mock_ctx = Mock() - mock_ctx.obj = Mock() - mock_ctx.obj.current = Mock() - mock_ctx.obj.current.user_id = "test-user" - mock_ctx.obj.current.pat = "test-pat" - mock_ctx.obj.current.api_base = "https://api.clarifai.com" - mock_ctx.obj.current.name = "default" - mock_ctx.obj.current.compute_cluster_id = "local-dev-cluster" - mock_ctx.obj.current.nodepool_id = "local-dev-nodepool" - mock_ctx.obj.current.app_id = "local-dev-app" - mock_ctx.obj.current.model_id = "local-dev-model" - mock_ctx.obj.current.runner_id = "runner-123" - mock_ctx.obj.current.deployment_id = "deployment-123" - mock_ctx.obj.to_yaml = Mock() + mock_ctx = _make_mock_context() def validate_ctx_mock(ctx): ctx.obj = mock_ctx.obj @@ -484,16 +334,13 @@ def validate_ctx_mock(ctx): mock_validate_context.side_effect = validate_ctx_mock runner = CliRunner() - # Test with custom pool_size result = runner.invoke( cli, - ["model", "local-runner", str(dummy_model_dir), "--pool_size", "24"], + ["model", "serve", str(dummy_model_dir), "--concurrency", "24"], catch_exceptions=False, ) - # Should succeed assert result.exit_code == 0, f"Command failed with: {result.output}" - # Verify pool_size was passed to serve mock_server.serve.assert_called_once() serve_kwargs = mock_server.serve.call_args[1] assert serve_kwargs["pool_size"] == 24 @@ -519,10 +366,8 @@ def test_local_runner_has_config_yaml_in_model_dir(self, dummy_model_dir): @patch("clarifai.cli.model.check_requirements_installed") @patch("clarifai.cli.model.parse_requirements") @patch("clarifai.cli.model.validate_context") - @patch("builtins.input") def test_local_runner_model_serving( self, - mock_input, mock_validate_context, mock_parse_requirements, mock_check_requirements, @@ -531,90 +376,27 @@ def test_local_runner_model_serving( mock_server_class, dummy_model_dir, ): - """Test that local-runner properly initializes and serves the model.""" - # Setup - mock_input.return_value = "y" + """Test that serve properly initializes and serves the model.""" mock_check_requirements.return_value = True mock_parse_requirements.return_value = [] - # Mock ModelBuilder mock_builder = MagicMock() - mock_builder.config = {"model": {"model_type_id": "multimodal-to-text"}, "toolkit": {}} + mock_builder.config = { + "model": {"id": "dummy-runner-model", "model_type_id": "multimodal-to-text"}, + "toolkit": {}, + } mock_method_sig = MagicMock() mock_method_sig.name = "predict" mock_builder.get_method_signatures.return_value = [mock_method_sig] mock_builder_class.return_value = mock_builder - mock_builder_class._load_config.return_value = mock_builder.config - - # Mock User with all existing resources - mock_user = MagicMock() - mock_compute_cluster = MagicMock() - mock_compute_cluster.id = "local-dev-cluster" - mock_compute_cluster.cluster_type = "local-dev" - mock_user.compute_cluster.return_value = mock_compute_cluster - - mock_nodepool = MagicMock() - mock_nodepool.id = "local-dev-nodepool" - mock_compute_cluster.nodepool.return_value = mock_nodepool - - mock_app = MagicMock() - mock_app.id = "local-dev-app" - mock_user.app.return_value = mock_app - - mock_model = MagicMock() - mock_model.id = "local-dev-model" - mock_model.model_type_id = "multimodal-to-text" - mock_app.model.return_value = mock_model - - mock_version = MagicMock() - mock_version.id = "version-123" - mock_model_version_obj = MagicMock() - mock_model_version_obj.model_version = mock_version - mock_model.list_versions.return_value = [mock_model_version_obj] - mock_patched_model = MagicMock() - mock_patched_model.model_version = mock_version - mock_patched_model.load_info = MagicMock() - mock_model.patch_version.return_value = mock_patched_model - - mock_runner = MagicMock() - mock_runner.id = "runner-123" - mock_runner.worker = MagicMock() - mock_runner.worker.model = MagicMock() - mock_runner.worker.model.model_version = MagicMock() - mock_runner.worker.model.model_version.id = "version-123" - mock_nodepool.runner.return_value = mock_runner - - mock_deployment = MagicMock() - mock_deployment.id = "deployment-123" - mock_deployment.worker = MagicMock() - mock_deployment.worker.model = MagicMock() - mock_deployment.worker.model.model_version = MagicMock() - mock_deployment.worker.model.model_version.id = "version-123" - mock_nodepool.deployment.return_value = mock_deployment + mock_user = _make_mock_user_with_existing_resources() mock_user_class.return_value = mock_user - # Mock ModelServer - this is the key part of this test mock_server = MagicMock() mock_server_class.return_value = mock_server - # Mock a proper context - from unittest.mock import Mock - - mock_ctx = Mock() - mock_ctx.obj = Mock() - mock_ctx.obj.current = Mock() - mock_ctx.obj.current.user_id = "test-user" - mock_ctx.obj.current.pat = "test-pat" - mock_ctx.obj.current.api_base = "https://api.clarifai.com" - mock_ctx.obj.current.name = "default" - mock_ctx.obj.current.compute_cluster_id = "local-dev-cluster" - mock_ctx.obj.current.nodepool_id = "local-dev-nodepool" - mock_ctx.obj.current.app_id = "local-dev-app" - mock_ctx.obj.current.model_id = "local-dev-model" - mock_ctx.obj.current.runner_id = "runner-123" - mock_ctx.obj.current.deployment_id = "deployment-123" - mock_ctx.obj.to_yaml = Mock() + mock_ctx = _make_mock_context() def validate_ctx_mock(ctx): ctx.obj = mock_ctx.obj @@ -624,30 +406,22 @@ def validate_ctx_mock(ctx): runner = CliRunner() result = runner.invoke( cli, - ["model", "local-runner", str(dummy_model_dir)], + ["model", "serve", str(dummy_model_dir)], catch_exceptions=False, ) - # Should succeed assert result.exit_code == 0, f"Command failed with: {result.output}" - # Verify ModelServer was instantiated with the correct model path + # Verify ModelServer was instantiated (without model_builder arg) mock_server_class.assert_called_once_with( - model_path=str(dummy_model_dir), model_runner_local=None, model_builder=mock_builder + model_path=str(dummy_model_dir), model_runner_local=None ) - # Verify serve method was called with correct parameters for local runner + # Verify serve method was called with correct parameters mock_server.serve.assert_called_once() serve_kwargs = mock_server.serve.call_args[1] - - # Check that all critical parameters are passed correctly assert serve_kwargs["user_id"] == "test-user" - assert serve_kwargs["runner_id"] == "runner-123" - assert serve_kwargs["compute_cluster_id"] == "local-dev-cluster" - assert serve_kwargs["nodepool_id"] == "local-dev-nodepool" assert serve_kwargs["base_url"] == "https://api.clarifai.com" assert serve_kwargs["pat"] == "test-pat" assert "pool_size" in serve_kwargs assert "num_threads" in serve_kwargs - # grpc defaults to False for local runner (not always passed as kwarg) - assert serve_kwargs.get("grpc", False) is False diff --git a/tests/cli/test_login.py b/tests/cli/test_login.py index e6a277a2..a29b22a5 100644 --- a/tests/cli/test_login.py +++ b/tests/cli/test_login.py @@ -3,141 +3,465 @@ import os from unittest import mock +import pytest from click.testing import CliRunner from clarifai.cli.base import cli +@pytest.fixture(autouse=True) +def _clean_env(): + """Ensure CLARIFAI_* variables are not leaked from the host environment into tests.""" + with mock.patch.dict(os.environ, {}, clear=False): + for key in list(os.environ.keys()): + if key.startswith('CLARIFAI_'): + os.environ.pop(key, None) + yield + + +def _mock_verify(pat, api_url): + """Mock _verify_and_resolve_user that returns a fake user_id.""" + return 'testuser' + + +def _mock_list_orgs_empty(pat, user_id, api_url): + return [] + + +def _mock_list_orgs_with_orgs(pat, user_id, api_url): + return [('clarifai', 'Clarifai'), ('openai', 'OpenAI')] + + class TestLoginCommand: """Test cases for the login command.""" def setup_method(self): self.runner = CliRunner() - self.validate_patch = mock.patch('clarifai.utils.cli.validate_context_auth') - self.mock_validate = self.validate_patch.start() + self.verify_patch = mock.patch( + 'clarifai.cli.base._verify_and_resolve_user', side_effect=_mock_verify + ) + self.orgs_patch = mock.patch( + 'clarifai.cli.base._list_user_orgs', side_effect=_mock_list_orgs_empty + ) + self.mock_verify = self.verify_patch.start() + self.mock_orgs = self.orgs_patch.start() def teardown_method(self): - self.validate_patch.stop() + self.verify_patch.stop() + self.orgs_patch.stop() - def test_login_with_env_var_accepted(self): - """Test login when CLARIFAI_PAT env var exists and user accepts it.""" + def test_login_with_pat_flag(self): + """Non-interactive login with --pat flag, no prompts.""" with self.runner.isolated_filesystem(): - with mock.patch.dict(os.environ, {'CLARIFAI_PAT': 'test_pat_123'}): - result = self.runner.invoke( - cli, ['login'], input='testuser\ny\n', catch_exceptions=False - ) + result = self.runner.invoke( + cli, ['login', '--pat', 'test_pat_123'], catch_exceptions=False + ) assert result.exit_code == 0 - assert 'Use CLARIFAI_PAT from environment?' in result.output - assert "Success! You're logged in as testuser" in result.output - # Should not prompt for context name (auto-creates 'default') - assert 'Enter a name for this context' not in result.output - # Should not show verbose context explanation from old flow - assert "Let's save these credentials to a new context" not in result.output - # Validation debug logs should not leak into output - assert 'Validating the Context Credentials' not in result.output + assert 'Verifying...' in result.output + assert 'Logged in as testuser' in result.output + assert 'context' in result.output and 'testuser' in result.output - def test_login_with_env_var_declined(self): - """Test login when user declines to use CLARIFAI_PAT env var.""" + def test_login_with_env_var(self): + """Login auto-uses CLARIFAI_PAT from environment (no confirm prompt).""" with self.runner.isolated_filesystem(): - with mock.patch.dict(os.environ, {'CLARIFAI_PAT': 'test_pat_123'}): - with mock.patch('clarifai.cli.base.masked_input', return_value='manual_pat'): - result = self.runner.invoke( - cli, ['login'], input='testuser\nn\n', catch_exceptions=False - ) + with mock.patch.dict(os.environ, {'CLARIFAI_PAT': 'env_pat_456'}): + result = self.runner.invoke(cli, ['login'], catch_exceptions=False) assert result.exit_code == 0 - assert 'Use CLARIFAI_PAT from environment?' in result.output - assert 'Create a PAT at:' in result.output - assert "Success! You're logged in as testuser" in result.output + assert 'Using PAT from CLARIFAI_PAT environment variable.' in result.output + assert 'Logged in as testuser' in result.output + # No confirm prompt — it just uses the env var + assert 'Use CLARIFAI_PAT from environment?' not in result.output - def test_login_without_env_var(self): - """Test login when CLARIFAI_PAT env var is not set.""" + def test_login_interactive_prompt(self): + """Login prompts for PAT when no flag and no env var.""" with self.runner.isolated_filesystem(): env = os.environ.copy() env.pop('CLARIFAI_PAT', None) with mock.patch.dict(os.environ, env, clear=True): - with mock.patch( - 'clarifai.cli.base.masked_input', return_value='typed_pat' - ) as mock_masked: - result = self.runner.invoke( - cli, ['login'], input='testuser\n', catch_exceptions=False - ) + with mock.patch('clarifai.cli.base.masked_input', return_value='typed_pat'): + with mock.patch('clarifai.cli.base.webbrowser', create=True): + result = self.runner.invoke(cli, ['login'], catch_exceptions=False) assert result.exit_code == 0 - assert "you'll need a Personal Access Token (PAT)" in result.output - assert 'Create one at:' in result.output - assert 'Set CLARIFAI_PAT environment variable to skip this prompt' in result.output - assert "Success! You're logged in as testuser" in result.output - mock_masked.assert_called_once() + assert 'Logged in as testuser' in result.output - def test_login_with_user_id_option(self): - """Test login with --user_id skips the user ID prompt.""" + def test_login_with_user_id_flag(self): + """--user-id skips org selection and uses the given user_id.""" with self.runner.isolated_filesystem(): - with mock.patch.dict(os.environ, {'CLARIFAI_PAT': 'test_pat'}): - result = self.runner.invoke( - cli, ['login', '--user_id', 'presetuser'], input='y\n', catch_exceptions=False - ) + result = self.runner.invoke( + cli, ['login', '--pat', 'test_pat', '--user-id', 'openai'], catch_exceptions=False + ) assert result.exit_code == 0 - assert 'Enter your Clarifai user ID' not in result.output - assert "Success! You're logged in as presetuser" in result.output + assert 'Logged in as openai' in result.output + assert 'context' in result.output and 'openai' in result.output + + def test_login_with_context_flag(self): + """--context sets a custom context name.""" + with self.runner.isolated_filesystem(): + result = self.runner.invoke( + cli, + ['login', '--pat', 'test_pat', '--context', 'my-custom'], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + assert 'my-custom' in result.output + assert 'set as active' in result.output + + def test_login_dev_env_context_naming(self): + """Dev environment URL produces 'dev-{user_id}' context name.""" + with self.runner.isolated_filesystem(): + result = self.runner.invoke( + cli, + ['login', 'https://api-dev.clarifai.com', '--pat', 'test_pat'], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + assert 'dev-testuser' in result.output + assert 'set as active' in result.output + + def test_login_staging_env_context_naming(self): + """Staging environment URL produces 'staging-{user_id}' context name.""" + with self.runner.isolated_filesystem(): + result = self.runner.invoke( + cli, + ['login', 'https://api-staging.clarifai.com', '--pat', 'test_pat'], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + assert 'staging-testuser' in result.output + assert 'set as active' in result.output + + def test_login_relogin_updates_existing(self): + """Re-login updates existing context instead of erroring.""" + with self.runner.isolated_filesystem(): + # First login + self.runner.invoke(cli, ['login', '--pat', 'old_pat'], catch_exceptions=False) + # Second login (same user_id resolves to same context name) + result = self.runner.invoke(cli, ['login', '--pat', 'new_pat'], catch_exceptions=False) + + assert result.exit_code == 0 + assert 'Updated' in result.output + assert 'testuser' in result.output + + def test_login_creates_new_context(self): + """First login creates a new context.""" + with self.runner.isolated_filesystem(): + result = self.runner.invoke( + cli, + ['--config', './test_config.yaml', 'login', '--pat', 'test_pat'], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + assert 'Created' in result.output + assert 'testuser' in result.output + + +class TestLoginOrgSelection: + """Test cases for org selection during login.""" + + def setup_method(self): + self.runner = CliRunner() + self.verify_patch = mock.patch( + 'clarifai.cli.base._verify_and_resolve_user', side_effect=_mock_verify + ) + self.orgs_patch = mock.patch( + 'clarifai.cli.base._list_user_orgs', side_effect=_mock_list_orgs_with_orgs + ) + self.mock_verify = self.verify_patch.start() + self.mock_orgs = self.orgs_patch.start() + + def teardown_method(self): + self.verify_patch.stop() + self.orgs_patch.stop() + + def test_login_shows_org_list(self): + """When user has orgs, login shows numbered list.""" + with self.runner.isolated_filesystem(): + result = self.runner.invoke( + cli, ['login', '--pat', 'test_pat'], input='1\n', catch_exceptions=False + ) + + assert result.exit_code == 0 + # Check key parts are present (colors stripped in test runner) + assert '[1]' in result.output and 'testuser' in result.output + assert '[2]' in result.output and 'clarifai' in result.output + assert '[3]' in result.output and 'openai' in result.output + assert '(personal)' in result.output + assert '(Clarifai)' in result.output + assert '(OpenAI)' in result.output + + def test_login_org_selection_default(self): + """Pressing enter selects personal user (default=1).""" + with self.runner.isolated_filesystem(): + result = self.runner.invoke( + cli, ['login', '--pat', 'test_pat'], input='\n', catch_exceptions=False + ) + + assert result.exit_code == 0 + assert 'Logged in as testuser' in result.output + + def test_login_org_selection_by_number(self): + """Selecting an org by number works.""" + with self.runner.isolated_filesystem(): + result = self.runner.invoke( + cli, ['login', '--pat', 'test_pat'], input='2\n', catch_exceptions=False + ) + + assert result.exit_code == 0 + assert 'Logged in as clarifai' in result.output + assert 'clarifai' in result.output and 'set as active' in result.output + + def test_login_user_id_flag_skips_org_prompt(self): + """--user-id bypasses org selection even when orgs exist.""" + with self.runner.isolated_filesystem(): + result = self.runner.invoke( + cli, + ['login', '--pat', 'test_pat', '--user-id', 'openai'], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + assert 'Select user_id' not in result.output + assert 'Logged in as openai' in result.output class TestCreateContextCommand: - """Test cases for the create-context command improvements.""" + """Test cases for the create-context command.""" def setup_method(self): self.runner = CliRunner() - self.validate_patch = mock.patch('clarifai.utils.cli.validate_context_auth') - self.mock_validate = self.validate_patch.start() + self.verify_patch = mock.patch( + 'clarifai.cli.base._verify_and_resolve_user', side_effect=_mock_verify + ) + self.orgs_patch = mock.patch( + 'clarifai.cli.base._list_user_orgs', side_effect=_mock_list_orgs_empty + ) + self.mock_verify = self.verify_patch.start() + self.mock_orgs = self.orgs_patch.start() def teardown_method(self): - self.validate_patch.stop() + self.verify_patch.stop() + self.orgs_patch.stop() def _login_first(self, config_path='./config.yaml'): """Helper to create initial config via login.""" - with mock.patch.dict(os.environ, {'CLARIFAI_PAT': 'test_pat'}): - self.runner.invoke(cli, ['--config', config_path, 'login'], input='testuser\ny\n') + self.runner.invoke( + cli, ['--config', config_path, 'login', '--pat', 'test_pat'], catch_exceptions=False + ) + + def test_create_context_with_pat_flag(self): + """Create context with all flags — no prompts.""" + with self.runner.isolated_filesystem(): + with mock.patch('clarifai.cli.base.DEFAULT_CONFIG', './config.yaml'): + self._login_first() + result = self.runner.invoke( + cli, + [ + '--config', + './config.yaml', + 'config', + 'create-context', + 'dev', + '--pat', + 'dev_pat', + '--user-id', + 'devuser', + ], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + assert 'Context' in result.output and 'dev' in result.output + assert 'created' in result.output def test_create_context_with_env_var(self): - """Test create-context detects and offers CLARIFAI_PAT from environment.""" + """Create context auto-uses CLARIFAI_PAT from env.""" + with self.runner.isolated_filesystem(): + with mock.patch('clarifai.cli.base.DEFAULT_CONFIG', './config.yaml'): + self._login_first() + with mock.patch.dict(os.environ, {'CLARIFAI_PAT': 'env_pat'}): + result = self.runner.invoke( + cli, + ['--config', './config.yaml', 'config', 'create-context', 'myctx'], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + assert 'Using PAT from CLARIFAI_PAT environment variable.' in result.output + assert 'Context' in result.output and 'myctx' in result.output + assert 'created' in result.output + + def test_create_context_duplicate_name_fails(self): + """Creating context with existing name fails.""" + with self.runner.isolated_filesystem(): + with mock.patch('clarifai.cli.base.DEFAULT_CONFIG', './config.yaml'): + self._login_first() + result = self.runner.invoke( + cli, + [ + '--config', + './config.yaml', + 'config', + 'create-context', + 'testuser', + '--pat', + 'x', + ], + catch_exceptions=False, + ) + + assert result.exit_code == 1 + assert 'already exists' in result.output + + +def _mock_get_user_info(user_id=None): + """Return a mock response with user profile fields.""" + from unittest.mock import MagicMock + + user = MagicMock() + user.id = 'testuser' + user.full_name = 'Test User' + user.primary_email = 'test@clarifai.com' + user.company_name = 'Clarifai' + resp = MagicMock() + resp.user = user + return resp + + +class TestWhoamiCommand: + """Test cases for the whoami command.""" + + def setup_method(self): + self.runner = CliRunner() + + def _login_first(self, config_path='./config.yaml', pat='test_pat_123'): + """Helper to login and create a config.""" + with ( + mock.patch('clarifai.cli.base._verify_and_resolve_user', side_effect=_mock_verify), + mock.patch('clarifai.cli.base._list_user_orgs', side_effect=_mock_list_orgs_empty), + ): + self.runner.invoke( + cli, + ['--config', config_path, 'login', '--pat', pat], + catch_exceptions=False, + ) + + def test_whoami_default(self): + """Default whoami shows user_id and context from local config (no API call).""" with self.runner.isolated_filesystem(): with mock.patch('clarifai.cli.base.DEFAULT_CONFIG', './config.yaml'): self._login_first() + result = self.runner.invoke( + cli, ['--config', './config.yaml', 'whoami'], catch_exceptions=False + ) + + assert result.exit_code == 0 + assert 'testuser' in result.output + assert 'Context:' in result.output - with mock.patch.dict(os.environ, {'CLARIFAI_PAT': 'another_pat'}): + def test_whoami_with_orgs(self): + """--orgs shows organization list.""" + with self.runner.isolated_filesystem(): + with mock.patch('clarifai.cli.base.DEFAULT_CONFIG', './config.yaml'): + self._login_first() + with mock.patch( + 'clarifai.cli.base._list_user_orgs', side_effect=_mock_list_orgs_with_orgs + ): result = self.runner.invoke( cli, - ['--config', './config.yaml', 'config', 'create-context', 'dev'], - input='devuser\nhttps://api-dev.clarifai.com\ny\n', + ['--config', './config.yaml', 'whoami', '--orgs'], catch_exceptions=False, ) assert result.exit_code == 0 - assert 'Found CLARIFAI_PAT in environment. Use it?' in result.output + assert 'testuser' in result.output + assert 'Organizations:' in result.output + assert 'clarifai' in result.output + assert 'openai' in result.output - def test_create_context_without_env_var(self): - """Test create-context uses masked_input when no env var is set.""" + def test_whoami_with_all(self): + """--all shows full profile including name, email, company, and orgs.""" with self.runner.isolated_filesystem(): with mock.patch('clarifai.cli.base.DEFAULT_CONFIG', './config.yaml'): self._login_first() + with ( + mock.patch( + 'clarifai.cli.base._list_user_orgs', side_effect=_mock_list_orgs_with_orgs + ), + mock.patch( + 'clarifai.client.user.User.get_user_info', side_effect=_mock_get_user_info + ), + ): + result = self.runner.invoke( + cli, + ['--config', './config.yaml', 'whoami', '--all'], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + assert 'testuser' in result.output + assert 'Test User' in result.output + assert 'test@clarifai.com' in result.output + assert 'Clarifai' in result.output + assert 'Organizations:' in result.output - env = os.environ.copy() - env.pop('CLARIFAI_PAT', None) + def test_whoami_json_output(self): + """JSON output contains expected keys.""" + import json - with mock.patch.dict(os.environ, env, clear=True): - with mock.patch( - 'clarifai.cli.base.masked_input', return_value='new_pat' - ) as mock_masked: - result = self.runner.invoke( - cli, - ['--config', './config.yaml', 'config', 'create-context', 'prod'], - input='produser\nhttps://api.clarifai.com\n', - catch_exceptions=False, - ) + with self.runner.isolated_filesystem(): + with mock.patch('clarifai.cli.base.DEFAULT_CONFIG', './config.yaml'): + self._login_first() + result = self.runner.invoke( + cli, + ['--config', './config.yaml', 'whoami', '-o', 'json'], + catch_exceptions=False, + ) assert result.exit_code == 0 - assert 'Set CLARIFAI_PAT environment variable' in result.output - mock_masked.assert_called_once() + data = json.loads(result.output.strip()) + assert data['user_id'] == 'testuser' + assert 'context' in data + assert 'api_base' in data + + def test_whoami_not_logged_in(self): + """Error when no PAT is configured.""" + with self.runner.isolated_filesystem(): + env = os.environ.copy() + env.pop('CLARIFAI_PAT', None) + with mock.patch.dict(os.environ, env, clear=True): + result = self.runner.invoke( + cli, + ['--config', './nonexistent.yaml', 'whoami'], + catch_exceptions=True, + ) + + assert result.exit_code == 1 + assert 'Not logged in' in result.output or 'Not logged in' in (result.stderr or '') + + +class TestEnvPrefix: + """Test cases for _env_prefix helper.""" + + def test_dev_url(self): + from clarifai.cli.base import _env_prefix + + assert _env_prefix('https://api-dev.clarifai.com') == 'dev' + + def test_staging_url(self): + from clarifai.cli.base import _env_prefix + + assert _env_prefix('https://api-staging.clarifai.com') == 'staging' + + def test_prod_url(self): + from clarifai.cli.base import _env_prefix + + assert _env_prefix('https://api.clarifai.com') == 'prod' diff --git a/tests/cli/test_model_cli.py b/tests/cli/test_model_cli.py index e9011a1c..21f23bf3 100644 --- a/tests/cli/test_model_cli.py +++ b/tests/cli/test_model_cli.py @@ -20,7 +20,6 @@ def test_ensure_config_exists_for_upload_creates_file(monkeypatch, tmp_path): responses = iter( [ "n", # Do you want to create config.yaml yourself? No, create interactively - "", # context selection (keep current) "custom-model", # model id "", # user id (default) "", # app id (default) @@ -97,7 +96,7 @@ class TestModelCliOllama: """Test CLI model commands with Ollama toolkit integration.""" def test_customize_ollama_model_function_call(self): - """Test that customize_ollama_model is called with correct parameters.""" + """Test that customize_ollama_model initializes the toolkit section.""" with tempfile.TemporaryDirectory() as tmp_dir: # Create a mock config.yaml file config_file = os.path.join(tmp_dir, 'config.yaml') @@ -117,8 +116,9 @@ def test_customize_ollama_model_function_call(self): with open(config_file, 'r', encoding='utf-8') as f: updated_config = yaml.safe_load(f) - assert updated_config['model']['user_id'] == test_user_id + # customize_ollama_model doesn't update model.user_id — that's injected from CLI context assert 'toolkit' in updated_config + assert isinstance(updated_config['toolkit'], dict) def test_customize_ollama_model_missing_user_id_raises_error(self): """Test that customize_ollama_model raises TypeError when user_id is missing.""" @@ -174,7 +174,7 @@ def __init__(self): with open(config_file, 'r', encoding='utf-8') as f: updated_config = yaml.safe_load(f) - assert updated_config['model']['user_id'] == 'new-user-id' + # customize_ollama_model doesn't update model.user_id — that's injected from CLI context assert updated_config['toolkit']['model'] == 'mistral' assert updated_config['toolkit']['port'] == '8080' assert updated_config['toolkit']['context_length'] == '8192' diff --git a/tests/cli/test_predict.py b/tests/cli/test_predict.py new file mode 100644 index 00000000..f7339fa8 --- /dev/null +++ b/tests/cli/test_predict.py @@ -0,0 +1,497 @@ +"""Tests for the improved `clarifai model predict` CLI command.""" + +import json +from unittest.mock import MagicMock + +import click +import pytest + +from clarifai.cli.model import ( + _build_chat_request, + _coerce_input_value, + _detect_media_type_from_ext, + _get_first_media_param, + _get_first_str_param, + _is_streaming_method, + _parse_kv_inputs, + _resolve_model_ref, + _select_method, +) + + +# --------------------------------------------------------------------------- +# _resolve_model_ref +# --------------------------------------------------------------------------- +class TestResolveModelRef: + def test_full_url_passthrough(self): + url = "https://clarifai.com/openai/chat-completion/models/GPT-4" + assert _resolve_model_ref(url) == url + + def test_http_url_passthrough(self): + url = "http://clarifai.com/openai/chat-completion/models/GPT-4" + assert _resolve_model_ref(url) == url + + def test_shorthand_default_base(self): + result = _resolve_model_ref("openai/chat-completion/models/GPT-4") + assert result == "https://clarifai.com/openai/chat-completion/models/GPT-4" + + def test_shorthand_custom_base(self): + result = _resolve_model_ref( + "openai/chat-completion/models/GPT-4", + ui_base="https://web-dev.clarifai.com", + ) + assert result == "https://web-dev.clarifai.com/openai/chat-completion/models/GPT-4" + + def test_shorthand_trailing_slash_base(self): + result = _resolve_model_ref( + "openai/chat-completion/models/GPT-4", + ui_base="https://clarifai.com/", + ) + assert result == "https://clarifai.com/openai/chat-completion/models/GPT-4" + + def test_invalid_no_models_keyword(self): + with pytest.raises(click.UsageError, match="Invalid model reference"): + _resolve_model_ref("openai/chat-completion/GPT-4") + + def test_invalid_too_few_parts(self): + with pytest.raises(click.UsageError, match="Invalid model reference"): + _resolve_model_ref("openai/models/GPT-4") + + def test_invalid_too_many_parts(self): + with pytest.raises(click.UsageError, match="Invalid model reference"): + _resolve_model_ref("openai/app/models/GPT-4/extra") + + def test_none_returns_none(self): + assert _resolve_model_ref(None) is None + + def test_empty_returns_none(self): + assert _resolve_model_ref("") is None + + +# --------------------------------------------------------------------------- +# Mock helpers for method signature-based tests +# --------------------------------------------------------------------------- +def _make_field(name, data_type, iterator=False): + """Create a mock ModelTypeField.""" + from clarifai_grpc.grpc.api import resources_pb2 + + field = resources_pb2.ModelTypeField() + field.name = name + field.type = data_type + field.iterator = iterator + return field + + +def _make_method_sig(name, input_fields, output_fields): + """Create a mock MethodSignature.""" + from clarifai_grpc.grpc.api import resources_pb2 + + sig = resources_pb2.MethodSignature() + sig.name = name + sig.input_fields.extend(input_fields) + sig.output_fields.extend(output_fields) + return sig + + +def _make_model_client(method_sigs_dict, sig_strings=None): + """Create a mock model client with pre-set method signatures. + + Args: + method_sigs_dict: Dict of method_name -> MethodSignature proto. + sig_strings: Dict of method_name -> signature display string. + """ + client = MagicMock() + client._defined = True + client._method_signatures = method_sigs_dict + client.available_methods.return_value = list(method_sigs_dict.keys()) + if sig_strings: + client.method_signature.side_effect = lambda m: sig_strings[m] + return client + + +# --------------------------------------------------------------------------- +# _get_first_str_param +# --------------------------------------------------------------------------- +class TestGetFirstStrParam: + def test_finds_str_param(self): + from clarifai_grpc.grpc.api import resources_pb2 + + sig = _make_method_sig( + "predict", + [_make_field("prompt", resources_pb2.ModelTypeField.DataType.STR)], + [_make_field("return", resources_pb2.ModelTypeField.DataType.STR)], + ) + client = _make_model_client({"predict": sig}) + assert _get_first_str_param(client, "predict") == "prompt" + + def test_skips_non_str(self): + from clarifai_grpc.grpc.api import resources_pb2 + + sig = _make_method_sig( + "predict", + [ + _make_field("count", resources_pb2.ModelTypeField.DataType.INT), + _make_field("text", resources_pb2.ModelTypeField.DataType.STR), + ], + [_make_field("return", resources_pb2.ModelTypeField.DataType.STR)], + ) + client = _make_model_client({"predict": sig}) + assert _get_first_str_param(client, "predict") == "text" + + def test_no_str_param(self): + from clarifai_grpc.grpc.api import resources_pb2 + + sig = _make_method_sig( + "predict", + [_make_field("image", resources_pb2.ModelTypeField.DataType.IMAGE)], + [_make_field("return", resources_pb2.ModelTypeField.DataType.STR)], + ) + client = _make_model_client({"predict": sig}) + assert _get_first_str_param(client, "predict") is None + + def test_missing_method(self): + client = _make_model_client({}) + assert _get_first_str_param(client, "nonexistent") is None + + +# --------------------------------------------------------------------------- +# _get_first_media_param +# --------------------------------------------------------------------------- +class TestGetFirstMediaParam: + def test_finds_image_param(self): + from clarifai_grpc.grpc.api import resources_pb2 + + sig = _make_method_sig( + "predict", + [_make_field("image", resources_pb2.ModelTypeField.DataType.IMAGE)], + [_make_field("return", resources_pb2.ModelTypeField.DataType.STR)], + ) + client = _make_model_client({"predict": sig}) + name, media_type = _get_first_media_param(client, "predict") + assert name == "image" + assert media_type == "image" + + def test_finds_video_param(self): + from clarifai_grpc.grpc.api import resources_pb2 + + sig = _make_method_sig( + "predict", + [_make_field("clip", resources_pb2.ModelTypeField.DataType.VIDEO)], + [_make_field("return", resources_pb2.ModelTypeField.DataType.STR)], + ) + client = _make_model_client({"predict": sig}) + name, media_type = _get_first_media_param(client, "predict") + assert name == "clip" + assert media_type == "video" + + def test_finds_audio_param(self): + from clarifai_grpc.grpc.api import resources_pb2 + + sig = _make_method_sig( + "predict", + [_make_field("audio", resources_pb2.ModelTypeField.DataType.AUDIO)], + [_make_field("return", resources_pb2.ModelTypeField.DataType.STR)], + ) + client = _make_model_client({"predict": sig}) + name, media_type = _get_first_media_param(client, "predict") + assert name == "audio" + assert media_type == "audio" + + def test_no_media_param(self): + from clarifai_grpc.grpc.api import resources_pb2 + + sig = _make_method_sig( + "predict", + [_make_field("prompt", resources_pb2.ModelTypeField.DataType.STR)], + [_make_field("return", resources_pb2.ModelTypeField.DataType.STR)], + ) + client = _make_model_client({"predict": sig}) + name, media_type = _get_first_media_param(client, "predict") + assert name is None + assert media_type is None + + +# --------------------------------------------------------------------------- +# _coerce_input_value +# --------------------------------------------------------------------------- +class TestCoerceInputValue: + def test_coerce_int(self): + from clarifai_grpc.grpc.api import resources_pb2 + + sig = _make_method_sig( + "predict", + [_make_field("max_tokens", resources_pb2.ModelTypeField.DataType.INT)], + [_make_field("return", resources_pb2.ModelTypeField.DataType.STR)], + ) + client = _make_model_client({"predict": sig}) + assert _coerce_input_value("200", client, "predict", "max_tokens") == 200 + + def test_coerce_float(self): + from clarifai_grpc.grpc.api import resources_pb2 + + sig = _make_method_sig( + "predict", + [_make_field("temperature", resources_pb2.ModelTypeField.DataType.FLOAT)], + [_make_field("return", resources_pb2.ModelTypeField.DataType.STR)], + ) + client = _make_model_client({"predict": sig}) + assert _coerce_input_value("0.7", client, "predict", "temperature") == 0.7 + + def test_coerce_bool_true(self): + from clarifai_grpc.grpc.api import resources_pb2 + + sig = _make_method_sig( + "predict", + [_make_field("stream", resources_pb2.ModelTypeField.DataType.BOOL)], + [_make_field("return", resources_pb2.ModelTypeField.DataType.STR)], + ) + client = _make_model_client({"predict": sig}) + assert _coerce_input_value("true", client, "predict", "stream") is True + assert _coerce_input_value("yes", client, "predict", "stream") is True + assert _coerce_input_value("1", client, "predict", "stream") is True + + def test_coerce_bool_false(self): + from clarifai_grpc.grpc.api import resources_pb2 + + sig = _make_method_sig( + "predict", + [_make_field("stream", resources_pb2.ModelTypeField.DataType.BOOL)], + [_make_field("return", resources_pb2.ModelTypeField.DataType.STR)], + ) + client = _make_model_client({"predict": sig}) + assert _coerce_input_value("false", client, "predict", "stream") is False + + def test_str_passthrough(self): + from clarifai_grpc.grpc.api import resources_pb2 + + sig = _make_method_sig( + "predict", + [_make_field("prompt", resources_pb2.ModelTypeField.DataType.STR)], + [_make_field("return", resources_pb2.ModelTypeField.DataType.STR)], + ) + client = _make_model_client({"predict": sig}) + assert _coerce_input_value("hello", client, "predict", "prompt") == "hello" + + def test_unknown_param_passthrough(self): + client = _make_model_client({}) + assert _coerce_input_value("hello", client, "predict", "unknown") == "hello" + + +# --------------------------------------------------------------------------- +# _parse_kv_inputs +# --------------------------------------------------------------------------- +class TestParseKvInputs: + def test_simple_kv(self): + from clarifai_grpc.grpc.api import resources_pb2 + + sig = _make_method_sig( + "predict", + [ + _make_field("prompt", resources_pb2.ModelTypeField.DataType.STR), + _make_field("max_tokens", resources_pb2.ModelTypeField.DataType.INT), + ], + [_make_field("return", resources_pb2.ModelTypeField.DataType.STR)], + ) + client = _make_model_client({"predict": sig}) + result = _parse_kv_inputs(("prompt=Hello", "max_tokens=200"), client, "predict") + assert result == {"prompt": "Hello", "max_tokens": 200} + + def test_value_with_equals(self): + """Values can contain = signs.""" + client = _make_model_client({}) + result = _parse_kv_inputs(("prompt=a=b=c",), client, "predict") + assert result == {"prompt": "a=b=c"} + + def test_invalid_no_equals(self): + client = _make_model_client({}) + with pytest.raises(click.UsageError, match="Invalid input format"): + _parse_kv_inputs(("no-equals-here",), client, "predict") + + +# --------------------------------------------------------------------------- +# _detect_media_type_from_ext +# --------------------------------------------------------------------------- +class TestDetectMediaTypeFromExt: + def test_image_extensions(self): + for ext in [".jpg", ".png", ".gif", ".bmp", ".tiff"]: + assert _detect_media_type_from_ext(f"photo{ext}") == "image" + + def test_video_extensions(self): + for ext in [".mp4", ".mov", ".avi", ".mkv", ".webm"]: + assert _detect_media_type_from_ext(f"clip{ext}") == "video" + + def test_audio_extensions(self): + for ext in [".wav", ".mp3", ".flac", ".aac", ".ogg"]: + assert _detect_media_type_from_ext(f"sound{ext}") == "audio" + + def test_unknown_defaults_to_image(self): + assert _detect_media_type_from_ext("file.xyz") == "image" + + def test_url_path(self): + assert _detect_media_type_from_ext("https://example.com/photo.jpg") == "image" + assert _detect_media_type_from_ext("https://example.com/clip.mp4") == "video" + + +# --------------------------------------------------------------------------- +# _is_streaming_method +# --------------------------------------------------------------------------- +class TestIsStreamingMethod: + def test_streaming_signature(self): + client = MagicMock() + client.method_signature.return_value = "def generate(prompt: str) -> Iterator[str]:" + assert _is_streaming_method(client, "generate") is True + + def test_unary_signature(self): + client = MagicMock() + client.method_signature.return_value = "def predict(prompt: str) -> str:" + assert _is_streaming_method(client, "predict") is False + + def test_no_return_type(self): + client = MagicMock() + client.method_signature.return_value = "def predict(prompt: str):" + assert _is_streaming_method(client, "predict") is False + + +# --------------------------------------------------------------------------- +# _select_method +# --------------------------------------------------------------------------- +class TestSelectMethod: + def _client_with_sigs(self, sig_strings): + client = MagicMock() + client.method_signature.side_effect = lambda m: sig_strings[m] + return client + + def test_chat_selects_openai_stream(self): + methods = ['predict', 'openai_stream_transport', 'openai_transport'] + client = self._client_with_sigs( + { + 'predict': 'def predict(prompt: str) -> str:', + 'openai_stream_transport': 'def openai_stream_transport(msg: str) -> Iterator[str]:', + 'openai_transport': 'def openai_transport(msg: str) -> str:', + } + ) + method, is_openai = _select_method( + methods, client, None, is_chat=True, has_text_input=True + ) + assert method == 'openai_stream_transport' + assert is_openai is True + + def test_chat_fallback_to_non_stream(self): + methods = ['predict', 'openai_transport'] + client = self._client_with_sigs( + { + 'predict': 'def predict(prompt: str) -> str:', + 'openai_transport': 'def openai_transport(msg: str) -> str:', + } + ) + method, is_openai = _select_method( + methods, client, None, is_chat=True, has_text_input=True + ) + assert method == 'openai_transport' + assert is_openai is True + + def test_chat_no_openai_methods_errors(self): + methods = ['predict'] + client = self._client_with_sigs( + { + 'predict': 'def predict(prompt: str) -> str:', + } + ) + with pytest.raises(click.UsageError, match="does not support OpenAI chat"): + _select_method(methods, client, None, is_chat=True, has_text_input=True) + + def test_explicit_method(self): + methods = ['predict', 'generate'] + client = self._client_with_sigs( + { + 'predict': 'def predict(prompt: str) -> str:', + 'generate': 'def generate(prompt: str) -> Iterator[str]:', + } + ) + method, is_openai = _select_method( + methods, client, 'predict', is_chat=False, has_text_input=True + ) + assert method == 'predict' + assert is_openai is False + + def test_explicit_method_not_available(self): + methods = ['predict'] + client = self._client_with_sigs( + { + 'predict': 'def predict(prompt: str) -> str:', + } + ) + with pytest.raises(click.UsageError, match="not available"): + _select_method(methods, client, 'nonexistent', is_chat=False, has_text_input=True) + + def test_auto_openai_for_text(self): + methods = ['predict', 'openai_stream_transport'] + client = self._client_with_sigs( + { + 'predict': 'def predict(prompt: str) -> str:', + 'openai_stream_transport': 'def openai_stream_transport(msg: str) -> Iterator[str]:', + } + ) + method, is_openai = _select_method( + methods, client, None, is_chat=False, has_text_input=True + ) + assert method == 'openai_stream_transport' + assert is_openai is True + + def test_prefer_streaming_no_openai(self): + methods = ['predict', 'generate'] + client = self._client_with_sigs( + { + 'predict': 'def predict(prompt: str) -> str:', + 'generate': 'def generate(prompt: str) -> Iterator[str]:', + } + ) + method, is_openai = _select_method( + methods, client, None, is_chat=False, has_text_input=True + ) + assert method == 'generate' + assert is_openai is False + + def test_fallback_to_predict(self): + methods = ['predict'] + client = self._client_with_sigs( + { + 'predict': 'def predict(prompt: str) -> str:', + } + ) + method, is_openai = _select_method( + methods, client, None, is_chat=False, has_text_input=True + ) + assert method == 'predict' + assert is_openai is False + + def test_no_text_skips_openai_auto(self): + """Without text input, don't auto-select OpenAI path.""" + methods = ['predict', 'openai_stream_transport'] + client = self._client_with_sigs( + { + 'predict': 'def predict(prompt: str) -> str:', + 'openai_stream_transport': 'def openai_stream_transport(msg: str) -> Iterator[str]:', + } + ) + method, is_openai = _select_method( + methods, client, None, is_chat=False, has_text_input=False + ) + assert method == 'predict' + assert is_openai is False + + +# --------------------------------------------------------------------------- +# _build_chat_request +# --------------------------------------------------------------------------- +class TestBuildChatRequest: + def test_builds_valid_json(self): + result = _build_chat_request("What is AI?") + data = json.loads(result) + assert data["messages"] == [{"role": "user", "content": "What is AI?"}] + assert data["stream"] is True + + def test_preserves_special_chars(self): + result = _build_chat_request('Say "hello" & ') + data = json.loads(result) + assert data["messages"][0]["content"] == 'Say "hello" & ' diff --git a/tests/client/__init__.py b/tests/client/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/compute_orchestration/__init__.py b/tests/compute_orchestration/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/requirements.txt b/tests/requirements.txt index 35b842f4..98af6db6 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -4,7 +4,7 @@ pytest-xdist==2.5.0 pytest-asyncio py llama-index-core==0.13.2 -huggingface_hub[hf_transfer]==0.27.1 +huggingface_hub[hf_transfer]>=0.27.1 pypdf==3.17.4 seaborn==0.13.2 pycocotools>=2.0.7 diff --git a/tests/runners/__init__.py b/tests/runners/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/runners/hf_mbart_model/requirements.txt b/tests/runners/hf_mbart_model/requirements.txt index d9a4d1c1..51efb6d4 100644 --- a/tests/runners/hf_mbart_model/requirements.txt +++ b/tests/runners/hf_mbart_model/requirements.txt @@ -1,9 +1,9 @@ -accelerate==1.6.0 +accelerate>=1.6.0 blobfile clarifai requests -sentencepiece==0.2.0 -tiktoken==0.9.0 -tokenizers==0.21.1 +sentencepiece>=0.2.0 +tiktoken>=0.9.0 +tokenizers>=0.21.1 torch==2.6.0 -transformers==4.51.3 +transformers>=4.51.3 diff --git a/tests/runners/test_git_info.py b/tests/runners/test_git_info.py index c9b0c79d..d746b510 100644 --- a/tests/runners/test_git_info.py +++ b/tests/runners/test_git_info.py @@ -4,7 +4,6 @@ import subprocess import tempfile import unittest -from unittest.mock import patch from clarifai.runners.models.model_builder import ModelBuilder @@ -133,7 +132,7 @@ def test_get_git_info_git_folder(self): self.assertEqual(len(git_info['commit']), 40) # Git commit hash length def test_check_git_status_clean_repo(self): - """Test _check_git_status_and_prompt with clean repository""" + """Test _check_git_status with clean repository (no warnings).""" # Initialize git repo with clean state subprocess.run(['git', 'init'], cwd=self.model_dir, capture_output=True, check=False) subprocess.run( @@ -157,12 +156,11 @@ def test_check_git_status_clean_repo(self): ) builder = ModelBuilder(self.model_dir, download_validation_only=True) - result = builder._check_git_status_and_prompt() - self.assertTrue(result) + # _check_git_status is non-blocking (returns None, just logs warnings) + builder._check_git_status() - @patch('builtins.input', return_value='y') - def test_check_git_status_dirty_repo_accept(self, mock_input): - """Test _check_git_status_and_prompt with uncommitted changes - user accepts""" + def test_check_git_status_dirty_repo(self): + """Test _check_git_status with uncommitted changes (warns but does not block).""" # Initialize git repo with uncommitted changes subprocess.run(['git', 'init'], cwd=self.model_dir, capture_output=True, check=False) subprocess.run( @@ -190,43 +188,8 @@ def test_check_git_status_dirty_repo_accept(self, mock_input): f.write("uncommitted content") builder = ModelBuilder(self.model_dir, download_validation_only=True) - result = builder._check_git_status_and_prompt() - self.assertTrue(result) - mock_input.assert_called_once() - - @patch('builtins.input', return_value='n') - def test_check_git_status_dirty_repo_decline(self, mock_input): - """Test _check_git_status_and_prompt with uncommitted changes - user declines""" - # Initialize git repo with uncommitted changes - subprocess.run(['git', 'init'], cwd=self.model_dir, capture_output=True, check=False) - subprocess.run( - ['git', 'config', 'user.email', 'test@test.com'], - cwd=self.model_dir, - capture_output=True, - check=False, - ) - subprocess.run( - ['git', 'config', 'user.name', 'Test User'], - cwd=self.model_dir, - capture_output=True, - check=False, - ) - subprocess.run(['git', 'add', '.'], cwd=self.model_dir, capture_output=True, check=False) - subprocess.run( - ['git', 'commit', '-m', 'Initial commit'], - cwd=self.model_dir, - capture_output=True, - check=False, - ) - - # Add uncommitted changes - with open(os.path.join(self.model_dir, "uncommitted.txt"), "w") as f: - f.write("uncommitted content") - - builder = ModelBuilder(self.model_dir, download_validation_only=True) - result = builder._check_git_status_and_prompt() - self.assertFalse(result) - mock_input.assert_called_once() + # _check_git_status is non-blocking (returns None, just logs warnings) + builder._check_git_status() def test_get_model_version_proto_with_git_info(self): """Test that git info is properly added to model version proto""" @@ -261,7 +224,7 @@ def test_get_model_version_proto_without_git_info(self): self.assertNotIn('git_registry', metadata_dict) def test_git_status_limited_to_model_path(self): - """Test that git status checking is limited to model path only""" + """Test that _check_git_status uses model path scope (git status --porcelain .)""" # Create a parent directory for the git repo parent_dir = os.path.join(self.test_dir, "git_repo") os.makedirs(parent_dir) @@ -300,27 +263,11 @@ def test_git_status_limited_to_model_path(self): with open(outside_file, "w") as f: f.write("This file is outside the model directory") - # Test that ModelBuilder ignores files outside model path + # _check_git_status is non-blocking — just verify it runs without error. + # It uses `git status --porcelain .` scoped to model path, so outside + # changes should not appear in its output. builder = ModelBuilder(self.model_dir, download_validation_only=True) - result = builder._check_git_status_and_prompt() - - # Should return True because no uncommitted changes within model path - self.assertTrue(result) - - # Now add uncommitted file inside model directory - inside_file = os.path.join(self.model_dir, "inside_model.txt") - with open(inside_file, "w") as f: - f.write("This file is inside the model directory") - - # Mock user declining the prompt - with patch('builtins.input', return_value='n'): - result = builder._check_git_status_and_prompt() - self.assertFalse(result) - - # Mock user accepting the prompt - with patch('builtins.input', return_value='y'): - result = builder._check_git_status_and_prompt() - self.assertTrue(result) + builder._check_git_status() if __name__ == '__main__': diff --git a/tests/runners/test_model_deploy.py b/tests/runners/test_model_deploy.py new file mode 100644 index 00000000..e9c4fc15 --- /dev/null +++ b/tests/runners/test_model_deploy.py @@ -0,0 +1,2633 @@ +"""Tests for model deploy, config normalization, and GPU presets.""" + +import os +import tempfile + +import pytest +import requests +import yaml + +from clarifai.runners.models.model_builder import ModelBuilder +from clarifai.utils.compute_presets import ( + FALLBACK_GPU_PRESETS, + _detect_quant_from_repo_name, + _estimate_kv_cache_bytes, + _estimate_vram_bytes, + _estimate_weight_bytes, + _get_hf_model_config, + _get_hf_model_info, + _get_hf_token, + _select_instance_by_vram, + get_compute_cluster_config, + get_deploy_compute_cluster_id, + get_deploy_nodepool_id, + get_inference_compute_for_gpu, + get_nodepool_config, + infer_gpu_from_config, + list_gpu_presets, + parse_k8s_quantity, + recommend_instance, + resolve_gpu, +) + + +class TestGPUPresets: + """Test GPU preset resolution and lookup.""" + + def test_resolve_known_gpu_fallback(self): + """Known GPU names resolve from fallback presets.""" + for name in ["A10G", "L40S", "G6E", "CPU"]: + preset = resolve_gpu(name) + assert "description" in preset + assert "instance_type_id" in preset + assert "inference_compute_info" in preset + + def test_resolve_instance_type_id_case_insensitive(self): + """GPU names should be case-insensitive.""" + preset_lower = resolve_gpu("a10g") + preset_upper = resolve_gpu("A10G") + assert preset_lower["instance_type_id"] == preset_upper["instance_type_id"] + + def test_resolve_unknown_gpu_raises(self): + """Unknown GPU name should raise ValueError.""" + with pytest.raises(ValueError, match="Unknown instance type"): + resolve_gpu("NONEXISTENT_GPU") + + def test_get_inference_compute_for_gpu(self): + """Returns a dict with expected compute info keys.""" + info = get_inference_compute_for_gpu("A10G") + assert "cpu_limit" in info + assert "cpu_memory" in info + assert "num_accelerators" in info + assert "accelerator_type" in info + assert info["num_accelerators"] == 1 + assert "NVIDIA-A10G" in info["accelerator_type"] + + def test_get_inference_compute_for_cpu(self): + """CPU preset has no accelerators.""" + info = get_inference_compute_for_gpu("CPU") + assert info["num_accelerators"] == 0 + + def test_infer_gpu_from_config_a10g(self): + """Infer A10G from inference_compute_info.""" + config = { + "inference_compute_info": { + "cpu_limit": "4", + "cpu_memory": "16Gi", + "num_accelerators": 1, + "accelerator_type": ["NVIDIA-A10G"], + "accelerator_memory": "24Gi", + } + } + assert infer_gpu_from_config(config) == "A10G" + + def test_infer_gpu_from_config_cpu(self): + """Infer CPU when no accelerators.""" + config = { + "inference_compute_info": { + "cpu_limit": "4", + "cpu_memory": "16Gi", + "num_accelerators": 0, + "accelerator_type": [], + } + } + assert infer_gpu_from_config(config) == "CPU" + + def test_infer_gpu_from_config_missing(self): + """Returns None when no inference_compute_info.""" + assert infer_gpu_from_config({}) is None + + def test_list_gpu_presets_returns_string(self): + """list_gpu_presets returns a string (API data or login message).""" + result = list_gpu_presets() + assert isinstance(result, str) + # Either shows instance types from API or a login prompt + assert "instance type" in result.lower() or "logged in" in result.lower() + + def test_resolve_gpu_nvidia_prefix_via_api(self): + """'gpu-nvidia-a10g' should resolve to real API instance type, not fallback.""" + from unittest.mock import MagicMock, patch + + mock_g5 = MagicMock() + mock_g5.id = "g5.xlarge" + mock_g5.description = "NVIDIA A10G 24GB" + mock_g5.cloud_provider.id = "aws" + mock_g5.region = "us-east-1" + mock_g5.compute_info.cpu_limit = "4" + mock_g5.compute_info.cpu_memory = "16Gi" + mock_g5.compute_info.num_accelerators = 1 + mock_g5.compute_info.accelerator_type = ["NVIDIA-A10G"] + mock_g5.compute_info.accelerator_memory = "24Gi" + + with patch( + "clarifai.utils.compute_presets._try_list_all_instance_types", + return_value=[mock_g5], + ): + # 'gpu-nvidia-a10g' should normalize to 'A10G' and match via accelerator_type + preset = resolve_gpu("gpu-nvidia-a10g") + assert preset["instance_type_id"] == "g5.xlarge" + assert preset["cloud_provider"] == "aws" + + def test_fallback_presets_complete(self): + """All fallback presets have required keys.""" + for name, preset in FALLBACK_GPU_PRESETS.items(): + assert "description" in preset, f"Missing description for {name}" + assert "instance_type_id" in preset, f"Missing instance_type_id for {name}" + assert "inference_compute_info" in preset, f"Missing inference_compute_info for {name}" + ici = preset["inference_compute_info"] + assert "cpu_limit" in ici + assert "num_accelerators" in ici + + +class TestComputeConfigs: + """Test compute cluster and nodepool config generation.""" + + def test_compute_cluster_config(self): + """Cluster config has required fields.""" + config = get_compute_cluster_config("test-user") + cc = config["compute_cluster"] + assert cc["id"] == "deploy-cc-aws-us-east-1" + assert cc["cloud_provider"]["id"] == "aws" + assert cc["region"] == "us-east-1" + assert cc["managed_by"] == "clarifai" + + def test_compute_cluster_config_custom_cloud(self): + """Cluster config uses specified cloud and region.""" + config = get_compute_cluster_config( + "test-user", cloud_provider="gcp", region="us-central1" + ) + cc = config["compute_cluster"] + assert cc["id"] == "deploy-cc-gcp-us-central1" + assert cc["cloud_provider"]["id"] == "gcp" + assert cc["region"] == "us-central1" + + def test_deploy_ids(self): + """Compute cluster and nodepool IDs are deterministic.""" + assert get_deploy_compute_cluster_id("aws", "us-east-1") == "deploy-cc-aws-us-east-1" + assert get_deploy_compute_cluster_id("gcp", "us-central1") == "deploy-cc-gcp-us-central1" + assert get_deploy_nodepool_id("g5.xlarge") == "deploy-np-g5-xlarge" + assert get_deploy_nodepool_id("gpu-nvidia-a10g") == "deploy-np-gpu-nvidia-a10g" + assert get_deploy_nodepool_id("g5.2xlarge") == "deploy-np-g5-2xlarge" + + def test_nodepool_config(self): + """Nodepool config has required fields.""" + config = get_nodepool_config( + instance_type_id="gpu-nvidia-a10g", + compute_cluster_id="test-cc", + user_id="test-user", + ) + np = config["nodepool"] + assert np["id"] == "deploy-np-gpu-nvidia-a10g" + assert np["compute_cluster"]["id"] == "test-cc" + assert np["compute_cluster"]["user_id"] == "test-user" + assert len(np["instance_types"]) == 1 + assert np["instance_types"][0]["id"] == "gpu-nvidia-a10g" + + def test_nodepool_config_with_compute_info(self): + """Nodepool config includes compute_info when provided.""" + ci = {"cpu_limit": "4", "num_accelerators": 1} + config = get_nodepool_config( + instance_type_id="gpu-nvidia-a10g", + compute_cluster_id="test-cc", + user_id="test-user", + compute_info=ci, + ) + np = config["nodepool"] + assert np["instance_types"][0]["compute_info"] == ci + + +class TestNormalizeConfig: + """Test ModelBuilder.normalize_config().""" + + def test_inject_user_id_and_app_id(self): + """user_id and app_id injected when missing.""" + config = {"model": {"id": "test", "model_type_id": "text-to-text"}} + result = ModelBuilder.normalize_config(config, user_id="user1", app_id="app1") + assert result["model"]["user_id"] == "user1" + assert result["model"]["app_id"] == "app1" + + def test_existing_user_id_not_overwritten(self): + """Existing user_id/app_id are preserved.""" + config = { + "model": { + "id": "test", + "model_type_id": "text-to-text", + "user_id": "existing", + "app_id": "existing-app", + } + } + result = ModelBuilder.normalize_config(config, user_id="new-user", app_id="new-app") + assert result["model"]["user_id"] == "existing" + assert result["model"]["app_id"] == "existing-app" + + def test_default_app_id_when_missing(self): + """app_id defaults to 'main' when not in config and not provided.""" + config = {"model": {"id": "test", "user_id": "user1", "model_type_id": "text-to-text"}} + result = ModelBuilder.normalize_config(config) + assert result["model"]["app_id"] == "main" + + def test_default_app_id_not_applied_when_provided(self): + """Default app_id is not used when explicitly provided via parameter.""" + config = {"model": {"id": "test", "model_type_id": "text-to-text"}} + result = ModelBuilder.normalize_config(config, app_id="my-custom-app") + assert result["model"]["app_id"] == "my-custom-app" + + def test_default_app_id_not_applied_when_in_config(self): + """Default app_id is not used when already in config.""" + config = { + "model": {"id": "test", "model_type_id": "text-to-text", "app_id": "existing-app"} + } + result = ModelBuilder.normalize_config(config) + assert result["model"]["app_id"] == "existing-app" + + def test_default_model_type_id_when_missing(self): + """model_type_id defaults to 'any-to-any' when not in config.""" + config = {"model": {"id": "test"}} + result = ModelBuilder.normalize_config(config) + assert result["model"]["model_type_id"] == "any-to-any" + + def test_default_model_type_id_not_applied_when_in_config(self): + """Existing model_type_id is preserved.""" + config = {"model": {"id": "test", "model_type_id": "text-to-text"}} + result = ModelBuilder.normalize_config(config) + assert result["model"]["model_type_id"] == "text-to-text" + + def test_expand_compute_instance(self): + """compute.instance expands to inference_compute_info with wildcard accelerator_type.""" + config = { + "model": {"id": "test", "model_type_id": "text-to-text"}, + "compute": {"instance": "A10G"}, + } + result = ModelBuilder.normalize_config(config) + assert "inference_compute_info" in result + assert result["inference_compute_info"]["num_accelerators"] == 1 + # Should use wildcard so model can be scheduled on any NVIDIA GPU + assert result["inference_compute_info"]["accelerator_type"] == ["NVIDIA-*"] + assert result["compute"]["instance"] == "A10G" # compute key preserved for reference + + def test_expand_compute_gpu_legacy(self): + """Legacy compute.gpu still works and gets normalized to compute.instance.""" + config = { + "model": {"id": "test", "model_type_id": "text-to-text"}, + "compute": {"gpu": "A10G"}, + } + result = ModelBuilder.normalize_config(config) + assert "inference_compute_info" in result + assert result["inference_compute_info"]["num_accelerators"] == 1 + assert result["compute"]["instance"] == "A10G" + assert "gpu" not in result["compute"] # legacy key removed after normalization + + def test_inference_compute_info_wins_over_compute_instance(self): + """If both compute.instance and inference_compute_info exist, inference_compute_info wins.""" + config = { + "model": {"id": "test", "model_type_id": "text-to-text"}, + "compute": {"instance": "A10G"}, + "inference_compute_info": { + "cpu_limit": "8", + "num_accelerators": 2, + "accelerator_type": ["NVIDIA-L40S"], + }, + } + result = ModelBuilder.normalize_config(config) + # inference_compute_info should be unchanged (it existed already) + assert result["inference_compute_info"]["num_accelerators"] == 2 + assert "NVIDIA-L40S" in result["inference_compute_info"]["accelerator_type"] + + def test_expand_simplified_checkpoints(self): + """Simplified checkpoints get type and when defaults.""" + config = { + "model": {"id": "test", "model_type_id": "text-to-text"}, + "checkpoints": {"repo_id": "meta-llama/Llama-3-8B"}, + } + result = ModelBuilder.normalize_config(config) + assert result["checkpoints"]["type"] == "huggingface" + assert result["checkpoints"]["when"] == "runtime" + + def test_existing_checkpoints_preserved(self): + """Existing checkpoints type and when are preserved.""" + config = { + "model": {"id": "test", "model_type_id": "text-to-text"}, + "checkpoints": { + "type": "custom", + "repo_id": "some/model", + "when": "build", + }, + } + result = ModelBuilder.normalize_config(config) + assert result["checkpoints"]["type"] == "custom" + assert result["checkpoints"]["when"] == "build" + + def test_build_info_defaults(self): + """build_info defaults are added when missing.""" + config = {"model": {"id": "test", "model_type_id": "text-to-text"}} + result = ModelBuilder.normalize_config(config) + assert result["build_info"] == {"python_version": "3.12"} + + def test_existing_build_info_preserved(self): + """Existing build_info is preserved.""" + config = { + "model": {"id": "test", "model_type_id": "text-to-text"}, + "build_info": {"python_version": "3.11"}, + } + result = ModelBuilder.normalize_config(config) + assert result["build_info"]["python_version"] == "3.11" + + def test_verbose_config_passthrough(self): + """Verbose config (already has all fields) passes through unchanged.""" + config = { + "model": { + "id": "test", + "user_id": "user1", + "app_id": "app1", + "model_type_id": "text-to-text", + }, + "inference_compute_info": { + "cpu_limit": "4", + "cpu_memory": "16Gi", + "num_accelerators": 1, + "accelerator_type": ["NVIDIA-A10G"], + "accelerator_memory": "24Gi", + }, + "build_info": {"python_version": "3.12"}, + } + result = ModelBuilder.normalize_config(config) + assert result["model"]["user_id"] == "user1" + assert result["inference_compute_info"]["num_accelerators"] == 1 + assert result["build_info"]["python_version"] == "3.12" + + +class TestSimplifyClonedConfig: + """Test simplify_cloned_config utility.""" + + def test_removes_placeholder_user_id(self): + """Placeholder user_id/app_id values are removed.""" + from clarifai.utils.cli import simplify_cloned_config + + config = { + "model": { + "id": "test", + "user_id": "user_id", + "app_id": "app_id", + "model_type_id": "text-to-text", + }, + "inference_compute_info": { + "cpu_limit": "4", + "cpu_memory": "16Gi", + "num_accelerators": 1, + "accelerator_type": ["NVIDIA-A10G"], + "accelerator_memory": "24Gi", + }, + } + with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: + yaml.dump(config, f) + tmp_path = f.name + + try: + simplify_cloned_config(tmp_path) + with open(tmp_path) as f: + result = yaml.safe_load(f) + assert "user_id" not in result["model"] + assert "app_id" not in result["model"] + finally: + os.unlink(tmp_path) + + def test_converts_compute_info_to_instance_shorthand(self): + """inference_compute_info matching A10G becomes compute.instance.""" + from clarifai.utils.cli import simplify_cloned_config + + config = { + "model": { + "id": "test", + "model_type_id": "text-to-text", + }, + "inference_compute_info": { + "cpu_limit": "4", + "cpu_memory": "16Gi", + "num_accelerators": 1, + "accelerator_type": ["NVIDIA-A10G"], + "accelerator_memory": "24Gi", + }, + } + with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: + yaml.dump(config, f) + tmp_path = f.name + + try: + simplify_cloned_config(tmp_path) + with open(tmp_path) as f: + result = yaml.safe_load(f) + assert "inference_compute_info" not in result + assert result["compute"]["instance"] == "A10G" + finally: + os.unlink(tmp_path) + + def test_updates_model_id_from_model_name(self): + """model_name is used to set model.id.""" + from clarifai.utils.cli import simplify_cloned_config + + config = { + "model": { + "id": "old-name", + "model_type_id": "text-to-text", + }, + } + with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: + yaml.dump(config, f) + tmp_path = f.name + + try: + simplify_cloned_config(tmp_path, model_name="meta-llama/Llama-3-8B") + with open(tmp_path) as f: + result = yaml.safe_load(f) + assert result["model"]["id"] == "Llama-3-8B" + finally: + os.unlink(tmp_path) + + +class TestResolveUserId: + """Test resolve_user_id() from config and API.""" + + def test_resolve_from_config(self): + """Resolves user_id from CLI config file.""" + from unittest.mock import MagicMock, patch + + from clarifai.utils.config import resolve_user_id + + mock_config = MagicMock() + mock_config.current.get.return_value = "config-user" + + with patch('clarifai.utils.config.Config.from_yaml', return_value=mock_config): + user_id = resolve_user_id() + assert user_id == "config-user" + + def test_resolve_falls_back_to_api(self): + """Falls back to API when config has no user_id.""" + from unittest.mock import MagicMock, patch + + from clarifai.utils.config import resolve_user_id + + # Mock config to have no user_id + mock_config = MagicMock() + mock_config.current.get.return_value = None + + # Mock User API call + mock_user = MagicMock() + mock_user.get_user_info.return_value.user.id = "api-user" + + with ( + patch('clarifai.utils.config.Config.from_yaml', return_value=mock_config), + patch('clarifai.client.user.User', return_value=mock_user), + ): + user_id = resolve_user_id(pat="test-pat") + assert user_id == "api-user" + + def test_resolve_returns_none_on_failure(self): + """Returns None when both config and API fail.""" + from unittest.mock import patch + + from clarifai.utils.config import resolve_user_id + + with ( + patch('clarifai.utils.config.Config.from_yaml', side_effect=Exception("no config")), + patch('clarifai.client.user.User', side_effect=Exception("no api")), + ): + user_id = resolve_user_id() + assert user_id is None + + def test_config_user_id_takes_priority_over_api(self): + """Config file user_id is used without making API call.""" + from unittest.mock import MagicMock, patch + + from clarifai.utils.config import resolve_user_id + + mock_config = MagicMock() + mock_config.current.get.return_value = "config-user" + + mock_user_cls = MagicMock() + + with ( + patch('clarifai.utils.config.Config.from_yaml', return_value=mock_config), + patch('clarifai.client.user.User', mock_user_cls), + ): + user_id = resolve_user_id(pat="test-pat") + assert user_id == "config-user" + # User class should NOT have been called + mock_user_cls.assert_not_called() + + +class TestModelDeployerValidation: + """Test ModelDeployer input validation.""" + + def test_no_source_raises(self): + """No model source raises UserError.""" + from clarifai.runners.models.model_deploy import ModelDeployer + + deployer = ModelDeployer() + with pytest.raises(Exception, match="You must specify either MODEL_PATH"): + deployer.deploy() + + def test_multiple_sources_raises(self): + """Multiple model sources raises UserError.""" + from clarifai.runners.models.model_deploy import ModelDeployer + + deployer = ModelDeployer( + model_path="/tmp/model", model_url="https://clarifai.com/u/a/models/m" + ) + with pytest.raises(Exception, match="Specify only one of"): + deployer.deploy() + + def test_existing_model_without_gpu_raises(self): + """Deploying existing model without GPU raises UserError.""" + from clarifai.runners.models.model_deploy import ModelDeployer + + deployer = ModelDeployer(model_url="https://clarifai.com/user1/app1/models/my-model") + with pytest.raises(Exception, match="You must specify --instance"): + deployer.deploy() + + +class TestInstanceOverride: + """Test that --instance flag properly overrides inference_compute_info.""" + + def test_instance_flag_overrides_config(self): + """--instance l40s should override inference_compute_info even if config had a10g.""" + from unittest.mock import MagicMock, patch + + from clarifai.runners.models.model_deploy import ModelDeployer + + deployer = ModelDeployer.__new__(ModelDeployer) + deployer.instance_type = "gpu-nvidia-l40s" + deployer.pat = None + deployer.base_url = None + + # Simulate builder with A10G inference_compute_info (set by normalize_config) + mock_builder = MagicMock() + mock_builder.config = { + "inference_compute_info": { + "cpu_limit": "4", + "cpu_memory": "16Gi", + "num_accelerators": 1, + "accelerator_type": ["NVIDIA-A10G"], + "accelerator_memory": "24Gi", + } + } + mock_builder.inference_compute_info = MagicMock() # non-None (already set) + deployer._builder = mock_builder + + # Mock get_inference_compute_for_gpu to return L40S info + l40s_ici = { + "cpu_limit": "8", + "cpu_memory": "32Gi", + "num_accelerators": 1, + "accelerator_type": ["NVIDIA-L40S"], + "accelerator_memory": "48Gi", + } + with patch( + "clarifai.utils.compute_presets.get_inference_compute_for_gpu", + return_value=l40s_ici, + ): + from clarifai.utils.compute_presets import get_inference_compute_for_gpu + + if deployer.instance_type: + ici = get_inference_compute_for_gpu( + deployer.instance_type, pat=deployer.pat, base_url=deployer.base_url + ) + if ici.get('num_accelerators', 0) > 0: + ici.setdefault('accelerator_type', ['NVIDIA-*']) + deployer._builder.config['inference_compute_info'] = ici + deployer._builder.inference_compute_info = ( + deployer._builder._get_inference_compute_info() + ) + + # Verify inference_compute_info was updated to L40S + updated_ici = deployer._builder.config['inference_compute_info'] + assert updated_ici['accelerator_memory'] == '48Gi' + assert updated_ici['cpu_limit'] == '8' + + def test_no_instance_flag_keeps_config(self): + """Without --instance, inference_compute_info from config is preserved.""" + from clarifai.runners.models.model_deploy import ModelDeployer + + deployer = ModelDeployer.__new__(ModelDeployer) + deployer.instance_type = None # No --instance flag + deployer.pat = None + deployer.base_url = None + + # Simulate builder with A10G inference_compute_info + from unittest.mock import MagicMock + + mock_builder = MagicMock() + a10g_ici = { + "cpu_limit": "4", + "cpu_memory": "16Gi", + "num_accelerators": 1, + "accelerator_type": ["NVIDIA-A10G"], + "accelerator_memory": "24Gi", + } + mock_builder.config = {"inference_compute_info": dict(a10g_ici)} + deployer._builder = mock_builder + + # The override block should NOT execute + if deployer.instance_type: + assert False, "Should not reach here" + + # inference_compute_info unchanged + assert deployer._builder.config['inference_compute_info'] == a10g_ici + + +class TestDeploymentMonitoring: + """Test deployment monitoring logic.""" + + def test_fetch_runner_logs_deduplicates(self): + """Runner log fetching deduplicates by (log_type, url/message). + + _fetch_runner_logs only fetches "runner.events" (k8s events). + Model stdout/stderr ("runner" logs) are reserved for _tail_runner_logs. + """ + from unittest.mock import MagicMock + + from clarifai_grpc.grpc.api import resources_pb2 + + from clarifai.runners.models.model_deploy import ModelDeployer + + # Mock stub and response with real-ish log entries + mock_stub = MagicMock() + mock_entry1 = resources_pb2.LogEntry( + url="http://log1", message="Pod scheduled on node abc" + ) + mock_entry2 = resources_pb2.LogEntry( + url="http://log2", message="Pulling image clarifai/runner:latest" + ) + + mock_response = MagicMock() + mock_response.log_entries = [mock_entry1, mock_entry2] + mock_stub.ListLogEntries.return_value = mock_response + + user_app_id = resources_pb2.UserAppIDSet(user_id="test-user") + seen_logs = set() + + # First call - should return log lines + page, lines = ModelDeployer._fetch_runner_logs( + mock_stub, + user_app_id, + "cc-id", + "np-id", + "runner-1", + seen_logs, + 1, + ) + + # Only "runner.events" is fetched (not "runner" — that's for Startup Logs) + assert mock_stub.ListLogEntries.call_count == 1 + # 2 entries from runner.events + assert len(seen_logs) == 2 + assert len(lines) > 0 + + # Second call with same logs - should not add new entries + prev_seen = len(seen_logs) + mock_stub.ListLogEntries.reset_mock() + page, lines = ModelDeployer._fetch_runner_logs( + mock_stub, + user_app_id, + "cc-id", + "np-id", + "runner-1", + seen_logs, + 1, + ) + # No new logs should be added + assert len(seen_logs) == prev_seen + assert len(lines) == 0 + + def test_fetch_runner_logs_handles_errors(self): + """Log fetching is best-effort and doesn't raise on errors.""" + from unittest.mock import MagicMock + + from clarifai_grpc.grpc.api import resources_pb2 + + from clarifai.runners.models.model_deploy import ModelDeployer + + mock_stub = MagicMock() + mock_stub.ListLogEntries.side_effect = Exception("API unavailable") + user_app_id = resources_pb2.UserAppIDSet(user_id="test-user") + + # Should not raise + page, lines = ModelDeployer._fetch_runner_logs( + mock_stub, + user_app_id, + "cc-id", + "np-id", + "runner-1", + set(), + 1, + ) + assert page == 1 # Page unchanged on error + assert len(lines) == 0 + + def test_format_event_logs_parses_events(self): + """Event log parser extracts reason and message from raw events.""" + from clarifai.runners.models.model_deploy import _format_event_logs + + raw = ( + "Name: runner-pod-xyz.abc123, Type: Warning, Source: {karpenter }, " + "Reason: FailedScheduling, FirstTimestamp: 2026-02-16 15:49:06 +0000 UTC, " + "LastTimestamp: 2026-02-16 15:49:06 +0000 UTC, " + 'Message: Failed to schedule pod, incompatible requirements' + ) + # verbose=True preserves the original reason + lines = _format_event_logs(raw, verbose=True) + assert len(lines) == 1 + assert "FailedScheduling" in lines[0] + assert "Failed to schedule pod" in lines[0] + # Should NOT contain raw pod name or timestamps + assert "runner-pod-xyz" not in lines[0] + assert "FirstTimestamp" not in lines[0] + + def test_format_event_logs_multi_events(self): + """Multiple events separated by newlines are returned as separate lines.""" + from clarifai.runners.models.model_deploy import _format_event_logs + + raw = ( + "Name: pod-1.abc, Type: Warning, Source: {}, Reason: FailedScheduling, " + "FirstTimestamp: 2026-01-01 00:00:00, LastTimestamp: 2026-01-01 00:00:00, " + "Message: No nodes available\n" + "Name: pod-1.def, Type: Normal, Source: {autoscaler}, Reason: ScaleUp, " + "FirstTimestamp: 2026-01-01 00:01:00, LastTimestamp: 2026-01-01 00:01:00, " + "Message: Scaling up node group" + ) + lines = _format_event_logs(raw, verbose=True) + assert len(lines) == 2 + assert "FailedScheduling" in lines[0] + assert "ScaleUp" in lines[1] + + def test_format_event_logs_non_verbose_simplifies(self): + """Non-verbose mode simplifies FailedScheduling messages.""" + from clarifai.runners.models.model_deploy import _format_event_logs + + raw = ( + "Name: pod-1.abc, Type: Warning, Source: {karpenter }, " + "Reason: FailedScheduling, FirstTimestamp: 2026-02-16 15:49:06, " + "LastTimestamp: 2026-02-16 15:49:06, " + "Message: 0/5 nodes are available: 3 had untolerated taint " + "{infra.clarifai.com/karpenter: }, 2 didn't match Pod topology" + ) + lines = _format_event_logs(raw, verbose=False) + assert len(lines) == 1 + assert "Scheduling" in lines[0] + assert "Waiting for node" in lines[0] + # Should NOT contain taint details + assert "untolerated taint" not in lines[0] + + def test_format_event_logs_non_verbose_skips_noise(self): + """Non-verbose mode skips TaintManagerEviction and other noise events.""" + from clarifai.runners.models.model_deploy import _format_event_logs + + raw = ( + "Name: pod-1.abc, Type: Normal, Source: {scheduler}, " + "Reason: TaintManagerEviction, FirstTimestamp: 2026-02-16 15:49:06, " + "LastTimestamp: 2026-02-16 15:49:06, " + "Message: Taint manager evicted the pod" + ) + lines = _format_event_logs(raw, verbose=False) + assert len(lines) == 0 + + # Same event in verbose mode should appear + lines_verbose = _format_event_logs(raw, verbose=True) + assert len(lines_verbose) == 1 + + def test_monitor_constants(self): + """Monitoring constants are set to reasonable values.""" + from clarifai.runners.models.model_deploy import ( + DEFAULT_LOG_TAIL_DURATION, + DEFAULT_MONITOR_TIMEOUT, + DEFAULT_POLL_INTERVAL, + ) + + assert DEFAULT_MONITOR_TIMEOUT == 1200 # 20 minutes + assert DEFAULT_POLL_INTERVAL == 5 # 5 seconds + assert DEFAULT_LOG_TAIL_DURATION == 15 # 15 seconds quick check after ready + + +class TestParseK8sQuantity: + """Test parse_k8s_quantity helper.""" + + def test_gibibytes(self): + assert parse_k8s_quantity("24Gi") == 24 * 1024**3 + assert parse_k8s_quantity("48Gi") == 48 * 1024**3 + + def test_mebibytes(self): + assert parse_k8s_quantity("1500Mi") == 1500 * 1024**2 + + def test_gigabytes(self): + assert parse_k8s_quantity("16G") == 16e9 + + def test_plain_number(self): + assert parse_k8s_quantity("4") == 4.0 + assert parse_k8s_quantity("4.5") == 4.5 + + def test_millicores(self): + assert parse_k8s_quantity("100m") == 0.1 + assert parse_k8s_quantity("500m") == 0.5 + + def test_none_and_empty(self): + assert parse_k8s_quantity(None) == 0 + assert parse_k8s_quantity("") == 0 + + def test_numeric_input(self): + assert parse_k8s_quantity(24) == 24.0 + assert parse_k8s_quantity(4.5) == 4.5 + + def test_tebibytes(self): + assert parse_k8s_quantity("1Ti") == 1024**4 + + def test_kibibytes(self): + assert parse_k8s_quantity("512Ki") == 512 * 1024 + + +class TestAutoComputeUpdate: + """Test automatic compute info update logic.""" + + def _make_compute_info_proto( + self, num_accelerators=1, accelerator_memory="24Gi", accelerator_type=None + ): + """Create a mock ComputeInfo proto.""" + from unittest.mock import MagicMock + + ci = MagicMock() + ci.num_accelerators = num_accelerators + ci.accelerator_memory = accelerator_memory + ci.accelerator_type = accelerator_type or ["NVIDIA-*"] + ci.ByteSize.return_value = 1 # Non-empty + return ci + + def test_a10g_to_l40s_needs_update(self): + """A10G model → L40S instance: instance exceeds spec, needs update.""" + from clarifai.runners.models.model_deploy import ModelDeployer + + model_ci = self._make_compute_info_proto(num_accelerators=1, accelerator_memory="24Gi") + instance_ci = FALLBACK_GPU_PRESETS["L40S"]["inference_compute_info"] + + needs_update, reasons = ModelDeployer._needs_compute_update(model_ci, instance_ci) + assert needs_update is True + assert any("accelerator_memory" in r for r in reasons) + + def test_l40s_to_a10g_no_update(self): + """L40S model → A10G instance: instance is below spec, no update needed.""" + from clarifai.runners.models.model_deploy import ModelDeployer + + model_ci = self._make_compute_info_proto(num_accelerators=1, accelerator_memory="48Gi") + instance_ci = FALLBACK_GPU_PRESETS["A10G"]["inference_compute_info"] + + needs_update, reasons = ModelDeployer._needs_compute_update(model_ci, instance_ci) + assert needs_update is False + assert len(reasons) == 0 + + def test_same_instance_no_update(self): + """A10G model → A10G instance: same spec, no update needed.""" + from clarifai.runners.models.model_deploy import ModelDeployer + + model_ci = self._make_compute_info_proto(num_accelerators=1, accelerator_memory="24Gi") + instance_ci = FALLBACK_GPU_PRESETS["A10G"]["inference_compute_info"] + + needs_update, reasons = ModelDeployer._needs_compute_update(model_ci, instance_ci) + assert needs_update is False + + def test_no_compute_info_needs_update(self): + """Model with no compute info → any instance needs update.""" + from clarifai.runners.models.model_deploy import ModelDeployer + + instance_ci = FALLBACK_GPU_PRESETS["A10G"]["inference_compute_info"] + + needs_update, reasons = ModelDeployer._needs_compute_update(None, instance_ci) + assert needs_update is True + assert any("no inference_compute_info" in r for r in reasons) + + def test_num_accelerators_triggers_update(self): + """1-GPU model → 2-GPU instance: needs update.""" + from clarifai.runners.models.model_deploy import ModelDeployer + + model_ci = self._make_compute_info_proto(num_accelerators=1, accelerator_memory="48Gi") + instance_ci = FALLBACK_GPU_PRESETS["G6E"]["inference_compute_info"] + + needs_update, reasons = ModelDeployer._needs_compute_update(model_ci, instance_ci) + assert needs_update is True + assert any("num_accelerators" in r for r in reasons) + + def test_auto_update_patches_when_needed(self): + """_auto_update_compute_if_needed patches model when instance exceeds spec.""" + from unittest.mock import MagicMock, patch + + from clarifai.runners.models.model_deploy import ModelDeployer + + deployer = ModelDeployer( + model_url="https://clarifai.com/user1/app1/models/my-model", + instance_type="L40S", + ) + deployer.model_version_id = "version-123" + + mock_model = MagicMock() + + # Model has A10G compute info (24Gi) with specific accelerator_type + model_ci = self._make_compute_info_proto( + num_accelerators=1, accelerator_memory="24Gi", accelerator_type=["NVIDIA-*"] + ) + + with ( + patch.object(deployer, '_resolve_gpu') as mock_resolve, + patch.object(deployer, '_get_model_version_compute_info', return_value=model_ci), + ): + mock_resolve.return_value = dict(FALLBACK_GPU_PRESETS["L40S"]) + deployer._auto_update_compute_if_needed(mock_model) + + # Should have patched the model version + mock_model.patch_version.assert_called_once() + call_kwargs = mock_model.patch_version.call_args + assert call_kwargs.kwargs["version_id"] == "version-123" + # num_accelerators and accelerator_memory updated, accelerator_type preserved + patched_ci = call_kwargs.kwargs["inference_compute_info"] + assert patched_ci.accelerator_memory == "48Gi" + assert patched_ci.num_accelerators == 1 + # accelerator_type should be preserved from the model version (not changed) + assert list(patched_ci.accelerator_type) == ["NVIDIA-*"] + + def test_auto_update_skips_when_compatible(self): + """_auto_update_compute_if_needed skips patch when instance is within spec.""" + from unittest.mock import MagicMock, patch + + from clarifai.runners.models.model_deploy import ModelDeployer + + deployer = ModelDeployer( + model_url="https://clarifai.com/user1/app1/models/my-model", + instance_type="A10G", + ) + deployer.model_version_id = "version-123" + + mock_model = MagicMock() + + # Model has L40S compute info (48Gi), deploying to A10G (24Gi) → compatible + model_ci = self._make_compute_info_proto(num_accelerators=1, accelerator_memory="48Gi") + + with ( + patch.object(deployer, '_resolve_gpu') as mock_resolve, + patch.object(deployer, '_get_model_version_compute_info', return_value=model_ci), + ): + mock_resolve.return_value = dict(FALLBACK_GPU_PRESETS["A10G"]) + deployer._auto_update_compute_if_needed(mock_model) + + # Should NOT have patched + mock_model.patch_version.assert_not_called() + + def test_auto_update_skips_without_gpu_preset(self): + """_auto_update_compute_if_needed is a no-op when GPU preset can't be resolved.""" + from unittest.mock import MagicMock, patch + + from clarifai.runners.models.model_deploy import ModelDeployer + + deployer = ModelDeployer( + model_url="https://clarifai.com/user1/app1/models/my-model", + instance_type="A10G", + ) + deployer.model_version_id = "version-123" + + mock_model = MagicMock() + + with patch.object(deployer, '_resolve_gpu', return_value=None): + deployer._auto_update_compute_if_needed(mock_model) + mock_model.patch_version.assert_not_called() + + +class TestStreamModelLogs: + """Test standalone log streaming function.""" + + def test_stream_logs_requires_model_info(self): + """stream_model_logs raises UserError without model info.""" + from clarifai.runners.models.model_deploy import stream_model_logs + + with pytest.raises(Exception, match="You must specify --model-url"): + stream_model_logs() + + def test_stream_logs_parses_model_url(self): + """stream_model_logs extracts user/app/model from URL.""" + from unittest.mock import MagicMock, patch + + from clarifai.runners.models.model_deploy import stream_model_logs + + # Mock the Model client and gRPC stub + mock_version = MagicMock() + mock_version.model_version.id = "ver-123" + + mock_stub = MagicMock() + # ListRunners returns no runners → should raise UserError + mock_resp = MagicMock() + mock_resp.runners = [] + mock_stub.ListRunners.return_value = mock_resp + + with ( + patch('clarifai.client.auth.create_stub', return_value=mock_stub), + patch('clarifai.client.model.Model.__init__', return_value=None), + patch('clarifai.client.model.Model.list_versions', return_value=[mock_version]), + ): + with pytest.raises(Exception, match="No active runner found"): + stream_model_logs( + model_url="https://clarifai.com/user1/app1/models/my-model", + pat="test-pat", + ) + + def test_stream_logs_no_follow(self, capsys): + """stream_model_logs with follow=False prints existing logs and exits.""" + from unittest.mock import MagicMock, patch + + from clarifai_grpc.grpc.api import resources_pb2 + + from clarifai.runners.models.model_deploy import stream_model_logs + + # Mock version lookup + mock_version = MagicMock() + mock_version.model_version.id = "ver-123" + + # Mock runner + mock_runner = MagicMock() + mock_runner.id = "runner-1" + mock_runner.nodepool.compute_cluster.id = "cc-1" + mock_runner.nodepool.id = "np-1" + mock_runners_resp = MagicMock() + mock_runners_resp.runners = [mock_runner] + + # Mock log entries + mock_log_entry = resources_pb2.LogEntry(message="Model loaded successfully!") + mock_log_resp = MagicMock() + mock_log_resp.log_entries = [mock_log_entry] + + mock_stub = MagicMock() + mock_stub.ListRunners.return_value = mock_runners_resp + mock_stub.ListLogEntries.return_value = mock_log_resp + + with ( + patch('clarifai.client.auth.create_stub', return_value=mock_stub), + patch('clarifai.client.model.Model.__init__', return_value=None), + patch('clarifai.client.model.Model.list_versions', return_value=[mock_version]), + ): + stream_model_logs( + model_url="https://clarifai.com/user1/app1/models/my-model", + pat="test-pat", + follow=False, + ) + + captured = capsys.readouterr() + assert "Model loaded successfully!" in captured.out + assert "runner-1" in captured.out + + +class TestConfigTemplate: + """Test config template generation.""" + + def test_simplified_template(self): + """Simplified template has no TODOs.""" + from clarifai.cli.templates.model_templates import get_config_template + + template = get_config_template(simplified=True, model_id="test-model") + assert "TODO" not in template + assert "compute:" in template + assert "instance:" in template + assert "test-model" in template + # Should NOT have user_id/app_id + assert "user_id" not in template + assert "app_id" not in template + + def test_verbose_template(self): + """Verbose template has full config fields.""" + from clarifai.cli.templates.model_templates import get_config_template + + template = get_config_template(simplified=False, user_id="test-user") + assert "test-user" in template + assert "inference_compute_info" in template + + +class TestCustomImageDockerfile: + """Test build_info.image custom base Docker image support.""" + + def test_custom_image_dockerfile_generated(self): + """build_info.image triggers custom image Dockerfile template.""" + import shutil + from pathlib import Path + + tests_dir = Path(__file__).parent.resolve() + original_dummy_path = tests_dir / "dummy_runner_models" + + with tempfile.TemporaryDirectory() as tmp_dir: + target = Path(tmp_dir) / "model" + shutil.copytree(original_dummy_path, target) + + config_path = target / "config.yaml" + with config_path.open("r") as f: + config = yaml.safe_load(f) + + config["build_info"] = {"image": "nvcr.io/nvidia/pytorch:24.01-py3"} + + with config_path.open("w") as f: + yaml.dump(config, f, sort_keys=False) + + builder = ModelBuilder(str(target), validate_api_ids=False) + content = builder._generate_dockerfile_content() + + assert "nvcr.io/nvidia/pytorch:24.01-py3" in content + assert "FROM --platform=$TARGETPLATFORM nvcr.io/nvidia/pytorch:24.01-py3" in content + assert 'pip' in content + # Should NOT contain multi-stage build FROM (no second FROM) + from_lines = [l for l in content.splitlines() if l.strip().startswith("FROM")] + assert len(from_lines) == 1 + + def test_custom_image_dockerfile_has_required_sections(self): + """Custom image Dockerfile includes requirements install, config copy, entrypoint.""" + import shutil + from pathlib import Path + + tests_dir = Path(__file__).parent.resolve() + original_dummy_path = tests_dir / "dummy_runner_models" + + with tempfile.TemporaryDirectory() as tmp_dir: + target = Path(tmp_dir) / "model" + shutil.copytree(original_dummy_path, target) + + config_path = target / "config.yaml" + with config_path.open("r") as f: + config = yaml.safe_load(f) + + config["build_info"] = {"image": "python:3.12-slim"} + + with config_path.open("w") as f: + yaml.dump(config, f, sort_keys=False) + + builder = ModelBuilder(str(target), validate_api_ids=False) + content = builder._generate_dockerfile_content() + + assert "requirements.txt" in content + assert "config.yaml" in content + assert "ENTRYPOINT" in content + assert "clarifai.runners.server" in content + assert "WORKDIR /home/nonroot/main" in content + + def test_no_custom_image_uses_standard_dockerfile(self): + """Without build_info.image, standard Dockerfile is generated.""" + import shutil + from pathlib import Path + + tests_dir = Path(__file__).parent.resolve() + original_dummy_path = tests_dir / "dummy_runner_models" + + with tempfile.TemporaryDirectory() as tmp_dir: + target = Path(tmp_dir) / "model" + shutil.copytree(original_dummy_path, target) + + builder = ModelBuilder(str(target), validate_api_ids=False) + content = builder._generate_dockerfile_content() + + # Standard Dockerfile uses uv and multi-stage build + assert "uv" in content.lower() or "pip" in content.lower() + # Should NOT reference a custom image like nvcr.io + assert "nvcr.io" not in content + + def test_empty_image_uses_standard_dockerfile(self): + """Empty string build_info.image falls back to standard Dockerfile.""" + import shutil + from pathlib import Path + + tests_dir = Path(__file__).parent.resolve() + original_dummy_path = tests_dir / "dummy_runner_models" + + with tempfile.TemporaryDirectory() as tmp_dir: + target = Path(tmp_dir) / "model" + shutil.copytree(original_dummy_path, target) + + config_path = target / "config.yaml" + with config_path.open("r") as f: + config = yaml.safe_load(f) + + config["build_info"] = {"image": ""} + + with config_path.open("w") as f: + yaml.dump(config, f, sort_keys=False) + + builder = ModelBuilder(str(target), validate_api_ids=False) + content = builder._generate_dockerfile_content() + + # Should NOT trigger custom image path + assert "nvcr.io" not in content + + +class TestParseRunnerLog: + """Test _parse_runner_log() JSON log parsing and filtering.""" + + def test_json_log_extracts_msg(self): + """JSON runner log extracts the 'msg' field.""" + from clarifai.runners.models.model_deploy import _parse_runner_log + + raw = '{"msg": "Starting MCP bridge...", "@timestamp": "2026-02-18T13:15:15Z", "stack_info": null}' + assert _parse_runner_log(raw) == "Starting MCP bridge..." + + def test_json_log_empty_msg_returns_none(self): + """JSON runner log with empty msg returns None.""" + from clarifai.runners.models.model_deploy import _parse_runner_log + + raw = '{"msg": "", "@timestamp": "2026-02-18T13:15:15Z"}' + assert _parse_runner_log(raw) is None + + def test_json_log_no_msg_field_passthrough(self): + """JSON object without 'msg' field passes through as raw string.""" + from clarifai.runners.models.model_deploy import _parse_runner_log + + raw = '{"level": "info", "@timestamp": "2026-02-18T13:15:15Z"}' + assert _parse_runner_log(raw) == raw + + def test_plain_text_passthrough(self): + """Non-JSON text passes through unchanged.""" + from clarifai.runners.models.model_deploy import _parse_runner_log + + raw = "[02/18/26 13:15:22] INFO Starting server on port 8080" + assert _parse_runner_log(raw) == raw + + def test_deprecation_warning_filtered(self): + """DeprecationWarning lines are filtered in non-verbose mode.""" + from clarifai.runners.models.model_deploy import _parse_runner_log + + raw = "/usr/local/lib/python3.12/site-packages/foo.py:42: DeprecationWarning: datetime.utcnow() is deprecated" + assert _parse_runner_log(raw, verbose=False) is None + # Verbose mode keeps it + assert _parse_runner_log(raw, verbose=True) == raw + + def test_pip_download_filtered(self): + """pip download lines are filtered in non-verbose mode.""" + from clarifai.runners.models.model_deploy import _parse_runner_log + + raw = "Downloading pygments (1.2MiB)" + assert _parse_runner_log(raw, verbose=False) is None + assert _parse_runner_log(raw, verbose=True) == raw + + def test_empty_and_none(self): + """Empty string and None return None.""" + from clarifai.runners.models.model_deploy import _parse_runner_log + + assert _parse_runner_log("") is None + assert _parse_runner_log(None) is None + + def test_installing_packages_filtered(self): + """'Installing collected packages:' lines are filtered in non-verbose mode.""" + from clarifai.runners.models.model_deploy import _parse_runner_log + + raw = "Installing collected packages: numpy, pandas, torch" + assert _parse_runner_log(raw, verbose=False) is None + + +class TestSimplifyK8sMessage: + """Test _simplify_k8s_message() human-friendly event mapping.""" + + def test_failed_scheduling_simplified(self): + """FailedScheduling becomes a simple 'waiting' message.""" + from clarifai.runners.models.model_deploy import _simplify_k8s_message + + msg = _simplify_k8s_message( + "FailedScheduling", + "0/5 nodes are available: 3 had untolerated taint {infra.clarifai.com/karpenter: }", + ) + assert msg == "Waiting for node to become available..." + + def test_scheduled_simplified(self): + from clarifai.runners.models.model_deploy import _simplify_k8s_message + + msg = _simplify_k8s_message( + "Scheduled", "Successfully assigned to ip-10-7-1-42.ec2.internal" + ) + assert msg == "Pod scheduled on node" + + def test_pulling_simplified(self): + from clarifai.runners.models.model_deploy import _simplify_k8s_message + + msg = _simplify_k8s_message( + "Pulling", "Pulling image public.ecr.aws/clarifai/runner:sha-abc123" + ) + assert msg == "Pulling model image..." + + def test_long_message_truncated(self): + from clarifai.runners.models.model_deploy import _simplify_k8s_message + + long_msg = "x" * 100 + result = _simplify_k8s_message("UnknownReason", long_msg) + assert len(result) == 80 + assert result.endswith("...") + + def test_short_message_passthrough(self): + from clarifai.runners.models.model_deploy import _simplify_k8s_message + + msg = _simplify_k8s_message("UnknownReason", "Short message") + assert msg == "Short message" + + def test_nominated_simplified(self): + """Nominated/NominatedNode hides internal node IPs.""" + from clarifai.runners.models.model_deploy import _simplify_k8s_message + + msg = _simplify_k8s_message( + "Nominated", "Pod should schedule on: node/ip-10-7-158-85.ec2.internal" + ) + assert msg == "Node selected for scheduling" + assert "ip-10" not in msg + + msg2 = _simplify_k8s_message( + "NominatedNode", "Pod should schedule on: node/ip-10-7-158-85.ec2.internal" + ) + assert msg2 == "Node selected for scheduling" + + +class TestEventDedup: + """Test that simplified event messages are deduplicated across polls.""" + + def test_seen_messages_deduplicates_events(self): + """Repeated simplified events are suppressed when seen_messages is used.""" + from clarifai.runners.models.model_deploy import _format_event_logs + + raw = ( + "Name: pod-1.abc, Type: Warning, Source: {karpenter }, " + "Reason: FailedScheduling, FirstTimestamp: 2026-02-16 15:49:06, " + "LastTimestamp: 2026-02-16 15:49:06, " + "Message: 0/5 nodes are available" + ) + raw2 = ( + "Name: pod-1.def, Type: Warning, Source: {karpenter }, " + "Reason: FailedScheduling, FirstTimestamp: 2026-02-16 15:49:11, " + "LastTimestamp: 2026-02-16 15:49:11, " + "Message: 0/5 nodes are available (different timestamp)" + ) + + # Both simplify to the same message + lines1 = _format_event_logs(raw, verbose=False) + lines2 = _format_event_logs(raw2, verbose=False) + assert len(lines1) == 1 + assert len(lines2) == 1 + # Both have the same simplified text + assert lines1[0] == lines2[0] + + def test_event_prefix_alignment(self): + """[warning] and [event ] prefixes have consistent width.""" + from clarifai.runners.models.model_deploy import _format_event_logs + + warning_raw = ( + "Name: pod-1, Type: Warning, Source: {}, Reason: FailedScheduling, " + "FirstTimestamp: 2026-01-01, LastTimestamp: 2026-01-01, " + "Message: test" + ) + normal_raw = ( + "Name: pod-1, Type: Normal, Source: {}, Reason: Scheduled, " + "FirstTimestamp: 2026-01-01, LastTimestamp: 2026-01-01, " + "Message: test" + ) + warning_lines = _format_event_logs(warning_raw, verbose=True) + normal_lines = _format_event_logs(normal_raw, verbose=True) + assert len(warning_lines) == 1 + assert len(normal_lines) == 1 + # Both prefixes should align — same character position for the reason + assert "[warning]" in warning_lines[0] + assert "[event ]" in normal_lines[0] + + +class TestDeployOutput: + """Test deploy_output.py helper functions.""" + + def test_phase_header_outputs(self, capsys): + """phase_header prints a formatted header.""" + from clarifai.runners.models.deploy_output import phase_header + + phase_header("Validate") + captured = capsys.readouterr() + assert "Validate" in captured.out + assert "\u2500" in captured.out # em dash character + + def test_info_outputs(self, capsys): + """info prints a labeled line.""" + from clarifai.runners.models.deploy_output import info + + info("Model", "my-model-id") + captured = capsys.readouterr() + assert "Model:" in captured.out + assert "my-model-id" in captured.out + + def test_status_outputs(self, capsys): + """status prints a status message.""" + from clarifai.runners.models.deploy_output import status + + status("Building image...") + captured = capsys.readouterr() + assert "Building image..." in captured.out + + def test_success_outputs(self, capsys): + """success prints a green message.""" + from clarifai.runners.models.deploy_output import success + + success("Model deployed!") + captured = capsys.readouterr() + assert "Model deployed!" in captured.out + + def test_warning_outputs(self, capsys): + """warning prints a yellow [warning] message.""" + from clarifai.runners.models.deploy_output import warning + + warning("Timeout reached") + captured = capsys.readouterr() + assert "[warning]" in captured.out + assert "Timeout reached" in captured.out + + +class TestVerboseFlag: + """Test that verbose flag is properly plumbed through ModelDeployer.""" + + def test_deployer_accepts_verbose(self): + """ModelDeployer accepts verbose parameter.""" + from clarifai.runners.models.model_deploy import ModelDeployer + + deployer = ModelDeployer(verbose=True) + assert deployer.verbose is True + + deployer2 = ModelDeployer() + assert deployer2.verbose is False + + def test_fetch_runner_logs_passes_verbose(self): + """_fetch_runner_logs accepts and passes verbose to formatters.""" + from unittest.mock import MagicMock + + from clarifai_grpc.grpc.api import resources_pb2 + + from clarifai.runners.models.model_deploy import ModelDeployer + + mock_stub = MagicMock() + mock_entry = resources_pb2.LogEntry( + url="http://log1", message='{"msg": "Hello", "@timestamp": "2026-01-01"}' + ) + mock_response = MagicMock() + mock_response.log_entries = [mock_entry] + mock_stub.ListLogEntries.return_value = mock_response + + user_app_id = resources_pb2.UserAppIDSet(user_id="test-user") + + # Should not raise with verbose=True + page, lines = ModelDeployer._fetch_runner_logs( + mock_stub, user_app_id, "cc", "np", "runner-1", set(), 1, verbose=True + ) + assert page >= 1 + + +class TestQuietSdkLogger: + """Test the _quiet_sdk_logger context manager.""" + + def test_suppresses_info_when_enabled(self): + """Logger level is raised to WARNING inside the context.""" + import logging + + from clarifai.runners.models.model_deploy import _quiet_sdk_logger + from clarifai.utils.logging import logger + + original_level = logger.level + with _quiet_sdk_logger(suppress=True): + assert logger.level >= logging.WARNING + # Restored after exiting + assert logger.level == original_level + + def test_noop_when_disabled(self): + """Logger level is unchanged when suppress=False.""" + + from clarifai.runners.models.model_deploy import _quiet_sdk_logger + from clarifai.utils.logging import logger + + original_level = logger.level + with _quiet_sdk_logger(suppress=False): + assert logger.level == original_level + + def test_restores_on_exception(self): + """Logger level is restored even if an exception is raised.""" + import logging + + from clarifai.runners.models.model_deploy import _quiet_sdk_logger + from clarifai.utils.logging import logger + + original_level = logger.level + with pytest.raises(ValueError): + with _quiet_sdk_logger(suppress=True): + assert logger.level >= logging.WARNING + raise ValueError("test error") + assert logger.level == original_level + + +class TestDeployModelQuiet: + """Test the quiet parameter on deploy_model.""" + + def test_deploy_model_quiet_suppresses_print(self, capsys): + """deploy_model with quiet=True should not print success/failure messages.""" + from unittest.mock import MagicMock, patch + + from clarifai.runners.models.model_builder import deploy_model + + mock_nodepool = MagicMock() + mock_deployment = MagicMock() + mock_nodepool.create_deployment.return_value = mock_deployment + + with patch('clarifai.runners.models.model_builder.Nodepool', return_value=mock_nodepool): + result = deploy_model( + model_id="test-model", + app_id="test-app", + user_id="test-user", + deployment_id="deploy-test", + model_version_id="v1", + nodepool_id="np-1", + compute_cluster_id="cc-1", + cluster_user_id="test-user", + quiet=True, + ) + + assert result is True + captured = capsys.readouterr() + assert "✅" not in captured.out + assert "Deployment" not in captured.out + + def test_deploy_model_not_quiet_prints(self, capsys): + """deploy_model with quiet=False should print success message.""" + from unittest.mock import MagicMock, patch + + from clarifai.runners.models.model_builder import deploy_model + + mock_nodepool = MagicMock() + mock_deployment = MagicMock() + mock_nodepool.create_deployment.return_value = mock_deployment + + with patch('clarifai.runners.models.model_builder.Nodepool', return_value=mock_nodepool): + result = deploy_model( + model_id="test-model", + app_id="test-app", + user_id="test-user", + deployment_id="deploy-test", + model_version_id="v1", + nodepool_id="np-1", + compute_cluster_id="cc-1", + cluster_user_id="test-user", + quiet=False, + ) + + assert result is True + captured = capsys.readouterr() + assert "Deployment" in captured.out + + +class TestRecommendInstance: + """Test auto-selection of GPU instance based on model size.""" + + def test_get_hf_model_info_success(self): + """Parses safetensors.total and config from mocked HF API.""" + from unittest.mock import MagicMock, patch + + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = { + "safetensors": { + "total": 7_000_000_000, + "parameters": {"BF16": 7_000_000_000}, + }, + "config": {}, + "pipeline_tag": "text-generation", + } + with patch("clarifai.utils.compute_presets.requests.get", return_value=mock_resp): + info = _get_hf_model_info("meta-llama/Llama-3-8B") + assert info is not None + assert info["num_params"] == 7_000_000_000 + assert info["dtype_breakdown"] == {"BF16": 7_000_000_000} + assert info["pipeline_tag"] == "text-generation" + assert info["quant_method"] is None + + def test_get_hf_model_info_with_quantization(self): + """Detects AWQ quantization from API response.""" + from unittest.mock import MagicMock, patch + + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = { + "safetensors": {"total": 7_000_000_000, "parameters": {"I32": 7_000_000_000}}, + "config": {"quantization_config": {"quant_method": "awq", "bits": 4}}, + } + with patch("clarifai.utils.compute_presets.requests.get", return_value=mock_resp): + info = _get_hf_model_info("some/model-awq") + assert info["quant_method"] == "awq" + assert info["quant_bits"] == 4 + + def test_get_hf_model_info_api_failure(self): + """Returns None when API fails.""" + from unittest.mock import patch + + with patch( + "clarifai.utils.compute_presets.requests.get", side_effect=Exception("timeout") + ): + info = _get_hf_model_info("nonexistent/model") + assert info is None + + def test_get_hf_model_info_no_safetensors(self): + """Returns num_params=None when safetensors field missing.""" + from unittest.mock import MagicMock, patch + + mock_resp = MagicMock() + mock_resp.json.return_value = {"config": {}, "pipeline_tag": "text-generation"} + with patch("clarifai.utils.compute_presets.requests.get", return_value=mock_resp): + info = _get_hf_model_info("some/old-model") + assert info is not None + assert info["num_params"] is None + + def test_detect_quant_awq(self): + """-awq in repo name → ("awq", 4)""" + method, bits = _detect_quant_from_repo_name("TheBloke/Llama-7B-AWQ") + assert method == "awq" + assert bits == 4 + + def test_detect_quant_gptq(self): + """-GPTQ in repo name → ("gptq", 4)""" + method, bits = _detect_quant_from_repo_name("TheBloke/Llama-7B-GPTQ") + assert method == "gptq" + assert bits == 4 + + def test_detect_quant_none(self): + """Clean name → (None, None)""" + method, bits = _detect_quant_from_repo_name("meta-llama/Llama-3-8B") + assert method is None + assert bits is None + + def test_estimate_vram_7b_bf16(self): + """~7B * 2 + overhead (50% KV + hybrid overhead)""" + vram = _estimate_vram_bytes(7_248_023_552) # 7B params, BF16 default + weight_bytes = 7_248_023_552 * 2.0 + # weights + 50% KV + overhead (2 GiB fixed + 10% of weights) + expected_approx = weight_bytes * 1.50 + (2 * 1024**3 + weight_bytes * 0.10) + assert abs(vram - expected_approx) < 1024 # within 1 KB + + def test_estimate_vram_7b_awq_4bit(self): + """~7B * 0.5 + overhead (50% KV + hybrid overhead)""" + vram = _estimate_vram_bytes(7_248_023_552, quant_method="awq", quant_bits=4) + weight_bytes = 7_248_023_552 * 0.5 + expected_approx = weight_bytes * 1.50 + (2 * 1024**3 + weight_bytes * 0.10) + assert abs(vram - expected_approx) < 1024 + + def test_estimate_vram_70b_bf16(self): + """~70B BF16 should be very large.""" + vram = _estimate_vram_bytes(70_000_000_000) + vram_gib = vram / (1024**3) + assert vram_gib > 150 # ~170 GiB, exceeds all instances + + def test_select_instance_small(self): + """10 GiB → A10G (24 GiB) via fallback.""" + from unittest.mock import patch + + with patch( + "clarifai.utils.compute_presets._try_list_all_instance_types", return_value=None + ): + inst_id, reason = _select_instance_by_vram(10 * 1024**3) + assert inst_id == "gpu-nvidia-a10g" + assert "10.0 GiB" in reason + + def test_select_instance_medium(self): + """30 GiB → L40S (48 GiB) via fallback.""" + from unittest.mock import patch + + with patch( + "clarifai.utils.compute_presets._try_list_all_instance_types", return_value=None + ): + inst_id, reason = _select_instance_by_vram(30 * 1024**3) + assert inst_id == "gpu-nvidia-l40s" + + def test_select_instance_large(self): + """60 GiB → G6E (96 GiB) via fallback.""" + from unittest.mock import patch + + with patch( + "clarifai.utils.compute_presets._try_list_all_instance_types", return_value=None + ): + inst_id, reason = _select_instance_by_vram(60 * 1024**3) + assert inst_id == "gpu-nvidia-g6e-2x-large" + + def test_select_instance_too_large(self): + """120 GiB → None via fallback.""" + from unittest.mock import patch + + with patch( + "clarifai.utils.compute_presets._try_list_all_instance_types", return_value=None + ): + inst_id, reason = _select_instance_by_vram(120 * 1024**3) + assert inst_id is None + assert "exceeds" in reason + + def test_recommend_mcp_model(self): + """MCP → CPU instance.""" + config = {"model": {"model_type_id": "mcp"}} + inst_id, reason = recommend_instance(config) + assert inst_id == "t3a.2xlarge" + assert "CPU" in reason + + def test_recommend_no_checkpoints(self): + """No repo_id, no GPU toolkit → CPU.""" + config = {"model": {"model_type_id": "any-to-any"}} + inst_id, reason = recommend_instance(config) + assert inst_id == "t3a.2xlarge" + assert "CPU" in reason or "cpu" in reason.lower() + + def test_recommend_vllm_no_repo(self): + """vLLM without repo_id → None.""" + config = { + "model": {"model_type_id": "any-to-any"}, + "build_info": {"image": "vllm/vllm-openai:latest"}, + } + inst_id, reason = recommend_instance(config) + assert inst_id is None + assert "checkpoints.repo_id" in reason + + def test_recommend_7b_model(self): + """Mock 7B BF16 → L40S (heuristic path with 90% utilization headroom).""" + from unittest.mock import MagicMock, patch + + config = { + "model": {"model_type_id": "any-to-any"}, + "checkpoints": {"repo_id": "meta-llama/Llama-3-8B"}, + } + mock_resp = MagicMock() + mock_resp.json.return_value = { + "safetensors": {"total": 7_248_023_552, "parameters": {"BF16": 7_248_023_552}}, + "config": {}, + } + with patch("clarifai.utils.compute_presets.requests.get", return_value=mock_resp): + with patch( + "clarifai.utils.compute_presets._try_list_all_instance_types", return_value=None + ): + inst_id, reason = recommend_instance(config) + # 7.2B * 2 * 1.5 + overhead = ~23.6 GiB > A10G usable (21.6 GiB) → L40S + assert inst_id == "gpu-nvidia-l40s" + + def test_recommend_13b_model(self): + """Mock 13B BF16 → L40S.""" + from unittest.mock import MagicMock, patch + + config = { + "model": {"model_type_id": "any-to-any"}, + "checkpoints": {"repo_id": "meta-llama/Llama-13B"}, + } + mock_resp = MagicMock() + mock_resp.json.return_value = { + "safetensors": {"total": 13_000_000_000, "parameters": {"BF16": 13_000_000_000}}, + "config": {}, + } + with patch("clarifai.utils.compute_presets.requests.get", return_value=mock_resp): + with patch( + "clarifai.utils.compute_presets._try_list_all_instance_types", return_value=None + ): + inst_id, reason = recommend_instance(config) + # 13B * 2 * 1.5 + (2 GiB + 10%) ≈ 40.7 GiB → L40S (48 GiB) + assert inst_id == "gpu-nvidia-l40s" + + def test_recommend_fallback_file_size(self): + """HF metadata fails, file size works → selects by file size.""" + from unittest.mock import patch + + config = { + "model": {"model_type_id": "any-to-any"}, + "checkpoints": {"repo_id": "some/model"}, + } + with patch( + "clarifai.utils.compute_presets._get_hf_model_info", + return_value={ + "num_params": None, + "quant_method": None, + "quant_bits": None, + "dtype_breakdown": None, + "pipeline_tag": None, + }, + ): + with patch( + "clarifai.runners.utils.loader.HuggingFaceLoader.get_huggingface_checkpoint_total_size", + return_value=10 * 1024**3, # 10 GiB files + ): + with patch( + "clarifai.utils.compute_presets._try_list_all_instance_types", + return_value=None, + ): + inst_id, reason = recommend_instance(config) + # 10 GiB * 1.3 + 2 GiB ≈ 15 GiB → A10G (24 GiB) + assert inst_id == "gpu-nvidia-a10g" + + def test_recommend_both_fail(self): + """Both APIs fail → (None, reason).""" + from unittest.mock import patch + + config = { + "model": {"model_type_id": "any-to-any"}, + "checkpoints": {"repo_id": "nonexistent/model"}, + } + with patch("clarifai.utils.compute_presets._get_hf_model_info", return_value=None): + with patch( + "clarifai.runners.utils.loader.HuggingFaceLoader.get_huggingface_checkpoint_total_size", + return_value=0, + ): + inst_id, reason = recommend_instance(config) + assert inst_id is None + assert "Could not determine" in reason + + def test_recommend_sglang_skips_pre_ampere(self): + """SGLang toolkit should skip pre-Ampere instances across all clouds.""" + from unittest.mock import MagicMock, patch + + # AWS T4 instance + mock_g4dn = MagicMock() + mock_g4dn.id = "g4dn.xlarge" + mock_g4dn.cloud_provider.id = "aws" + mock_g4dn.compute_info.num_accelerators = 1 + mock_g4dn.compute_info.accelerator_memory = "16Gi" + + # Azure T4 instance (not in supported clouds — filtered out by recommendation) + mock_azure_t4 = MagicMock() + mock_azure_t4.id = "Standard_NC4as_T4_v3" + mock_azure_t4.cloud_provider.id = "azure" + mock_azure_t4.compute_info.num_accelerators = 1 + mock_azure_t4.compute_info.accelerator_memory = "16Gi" + + # AWS A10G instance (Ampere) + mock_g5 = MagicMock() + mock_g5.id = "g5.xlarge" + mock_g5.cloud_provider.id = "aws" + mock_g5.compute_info.num_accelerators = 1 + mock_g5.compute_info.accelerator_memory = "24Gi" + + config_sglang = { + "model": {"model_type_id": "any-to-any"}, + "build_info": {"image": "lmsysorg/sglang:latest"}, + "checkpoints": {"repo_id": "Qwen/Qwen3-0.6B"}, + } + config_vllm = { + "model": {"model_type_id": "any-to-any"}, + "build_info": {"image": "vllm/vllm-openai:latest"}, + "checkpoints": {"repo_id": "Qwen/Qwen3-0.6B"}, + } + + small_vram = {"num_params": 600_000_000} # ~3 GiB, fits T4 + + with ( + patch("clarifai.utils.compute_presets._get_hf_model_info", return_value=small_vram), + patch("clarifai.utils.compute_presets._get_hf_model_config", return_value=None), + patch( + "clarifai.utils.compute_presets._try_list_all_instance_types", + return_value=[mock_g4dn, mock_azure_t4, mock_g5], + ), + ): + # SGLang should skip g4dn (pre-Ampere), Azure T4 filtered (unsupported cloud) → g5 + inst_id, _ = recommend_instance(config_sglang) + assert inst_id == "g5.xlarge" + + # vLLM picks g4dn (AWS, cheapest); Azure T4 filtered out + inst_id, _ = recommend_instance(config_vllm) + assert inst_id == "g4dn.xlarge" + + def test_recommend_sglang_from_requirements_txt(self): + """SGLang detected via requirements.txt should also skip pre-Ampere.""" + from unittest.mock import MagicMock, patch + + mock_g4dn = MagicMock() + mock_g4dn.id = "g4dn.xlarge" + mock_g4dn.cloud_provider.id = "aws" + mock_g4dn.compute_info.num_accelerators = 1 + mock_g4dn.compute_info.accelerator_memory = "16Gi" + + mock_g5 = MagicMock() + mock_g5.id = "g5.xlarge" + mock_g5.cloud_provider.id = "aws" + mock_g5.compute_info.num_accelerators = 1 + mock_g5.compute_info.accelerator_memory = "24Gi" + + # Config has no build_info.image — toolkit should be detected from requirements.txt + config = { + "model": {"model_type_id": "any-to-any"}, + "checkpoints": {"repo_id": "Qwen/Qwen3-0.6B"}, + } + + small_vram = {"num_params": 600_000_000} + + with tempfile.TemporaryDirectory() as tmpdir: + req_path = os.path.join(tmpdir, "requirements.txt") + with open(req_path, "w") as f: + f.write("sglang\nclarifai\n") + + with ( + patch( + "clarifai.utils.compute_presets._get_hf_model_info", return_value=small_vram + ), + patch("clarifai.utils.compute_presets._get_hf_model_config", return_value=None), + patch( + "clarifai.utils.compute_presets._try_list_all_instance_types", + return_value=[mock_g4dn, mock_g5], + ), + ): + # Should detect sglang from requirements.txt and skip g4dn + inst_id, _ = recommend_instance(config, model_path=tmpdir) + assert inst_id == "g5.xlarge" + + def test_write_instance_to_config(self): + """Verify config.yaml is updated with selected instance.""" + from clarifai.runners.models.model_deploy import ModelDeployer + + with tempfile.TemporaryDirectory() as tmpdir: + config_path = os.path.join(tmpdir, "config.yaml") + with open(config_path, "w") as f: + yaml.dump({"model": {"id": "test"}}, f) + + deployer = ModelDeployer.__new__(ModelDeployer) + deployer.model_path = tmpdir + deployer._write_instance_to_config("gpu-nvidia-l40s") + + with open(config_path) as f: + config = yaml.safe_load(f) + assert config["compute"]["instance"] == "gpu-nvidia-l40s" + + +class TestHFTokenValidation: + """Tests for HuggingFace token validation and error reporting.""" + + def test_hf_gated_no_token(self): + """Gated repo with no token returns (False, 'gated_no_token').""" + from unittest.mock import patch + + with patch( + "clarifai.runners.utils.loader.HuggingFaceLoader.validate_hf_repo_access" + ) as mock: + # Simulate real behavior + mock.return_value = (False, "gated_no_token") + has_access, reason = mock("meta-llama/Llama-3.1-8B-Instruct", token=None) + assert has_access is False + assert reason == "gated_no_token" + + def _mock_hf_response(self, status_code=403): + """Create a mock HTTP response for HF exceptions.""" + from unittest.mock import MagicMock + + response = MagicMock() + response.status_code = status_code + response.headers = {} + return response + + def test_hf_gated_no_access(self): + """Gated repo with token that lacks access returns (False, 'gated_no_access').""" + from unittest.mock import patch + + from huggingface_hub.utils import GatedRepoError + + err = GatedRepoError("gated", response=self._mock_hf_response(403)) + with patch("huggingface_hub.auth_check", side_effect=err): + from clarifai.runners.utils.loader import HuggingFaceLoader + + has_access, reason = HuggingFaceLoader.validate_hf_repo_access( + "meta-llama/Llama-3.1-8B-Instruct", token="hf_fake_token" + ) + assert has_access is False + assert reason == "gated_no_access" + + def test_hf_gated_no_token_real(self): + """Gated repo with no token triggers gated_no_token reason.""" + from unittest.mock import patch + + from huggingface_hub.utils import GatedRepoError + + err = GatedRepoError("gated", response=self._mock_hf_response(403)) + with patch("huggingface_hub.auth_check", side_effect=err): + from clarifai.runners.utils.loader import HuggingFaceLoader + + has_access, reason = HuggingFaceLoader.validate_hf_repo_access( + "meta-llama/Llama-3.1-8B-Instruct", token=None + ) + assert has_access is False + assert reason == "gated_no_token" + + def test_hf_not_found(self): + """Non-existent repo returns (False, 'not_found').""" + from unittest.mock import patch + + from huggingface_hub.utils import RepositoryNotFoundError + + err = RepositoryNotFoundError("not found", response=self._mock_hf_response(404)) + with patch("huggingface_hub.auth_check", side_effect=err): + from clarifai.runners.utils.loader import HuggingFaceLoader + + has_access, reason = HuggingFaceLoader.validate_hf_repo_access( + "fake-org/nonexistent-model", token=None + ) + assert has_access is False + assert reason == "not_found" + + def test_hf_success(self): + """Valid repo returns (True, '').""" + from unittest.mock import patch + + with patch("huggingface_hub.auth_check", return_value=None): + from clarifai.runners.utils.loader import HuggingFaceLoader + + has_access, reason = HuggingFaceLoader.validate_hf_repo_access( + "bert-base-uncased", token=None + ) + assert has_access is True + assert reason == "" + + def test_validate_config_gated_no_token_raises(self): + """ModelBuilder raises UserError with 'Set HF_TOKEN' for gated repo without token.""" + import shutil + from pathlib import Path + from unittest.mock import patch + + from clarifai.errors import UserError + + tests_dir = Path(__file__).parent.resolve() + original_dummy_path = tests_dir / "dummy_runner_models" + + with tempfile.TemporaryDirectory() as tmp_dir: + target = Path(tmp_dir) / "test_model" + shutil.copytree(original_dummy_path, target) + + config_path = target / "config.yaml" + with config_path.open("r") as f: + config = yaml.safe_load(f) + + config["checkpoints"] = { + "type": "huggingface", + "repo_id": "meta-llama/Llama-3.1-8B-Instruct", + "when": "runtime", + } + with config_path.open("w") as f: + yaml.dump(config, f, sort_keys=False) + + # Anonymous check returns gated, no env token available either + with ( + patch( + "clarifai.runners.utils.loader.HuggingFaceLoader.validate_hf_repo_access", + return_value=(False, "gated_no_token"), + ), + patch( + "clarifai.runners.utils.loader.HuggingFaceLoader.validate_hftoken", + return_value=False, + ), + ): + with pytest.raises(UserError, match="requires authentication"): + ModelBuilder(str(target), validate_api_ids=False) + + def test_validate_config_gated_no_access_raises(self): + """ModelBuilder raises UserError with 'Request access' for gated repo with bad token.""" + import shutil + from pathlib import Path + from unittest.mock import patch + + from clarifai.errors import UserError + + tests_dir = Path(__file__).parent.resolve() + original_dummy_path = tests_dir / "dummy_runner_models" + + with tempfile.TemporaryDirectory() as tmp_dir: + target = Path(tmp_dir) / "test_model" + shutil.copytree(original_dummy_path, target) + + config_path = target / "config.yaml" + with config_path.open("r") as f: + config = yaml.safe_load(f) + + config["checkpoints"] = { + "type": "huggingface", + "repo_id": "meta-llama/Llama-3.1-8B-Instruct", + "hf_token": "hf_bad_token", + "when": "runtime", + } + with config_path.open("w") as f: + yaml.dump(config, f, sort_keys=False) + + # 1st call (anonymous) → gated, 2nd call (with token) → no access + with ( + patch( + "clarifai.runners.utils.loader.HuggingFaceLoader.validate_hf_repo_access", + side_effect=[(False, "gated_no_token"), (False, "gated_no_access")], + ), + patch( + "clarifai.runners.utils.loader.HuggingFaceLoader.validate_hftoken", + return_value=True, + ), + ): + with pytest.raises(UserError, match="does not have access"): + ModelBuilder(str(target), validate_api_ids=False) + + def test_validate_config_not_found_raises(self): + """ModelBuilder raises UserError with 'not found' for missing repo.""" + import shutil + from pathlib import Path + from unittest.mock import patch + + from clarifai.errors import UserError + + tests_dir = Path(__file__).parent.resolve() + original_dummy_path = tests_dir / "dummy_runner_models" + + with tempfile.TemporaryDirectory() as tmp_dir: + target = Path(tmp_dir) / "test_model" + shutil.copytree(original_dummy_path, target) + + config_path = target / "config.yaml" + with config_path.open("r") as f: + config = yaml.safe_load(f) + + config["checkpoints"] = { + "type": "huggingface", + "repo_id": "fake-org/nonexistent-model", + "when": "runtime", + } + with config_path.open("w") as f: + yaml.dump(config, f, sort_keys=False) + + with ( + patch( + "clarifai.runners.utils.loader.HuggingFaceLoader.validate_hf_repo_access", + return_value=(False, "not_found"), + ), + patch( + "clarifai.runners.utils.loader.HuggingFaceLoader.validate_hftoken", + return_value=False, + ), + ): + with pytest.raises(UserError, match="not found"): + ModelBuilder(str(target), validate_api_ids=False) + + def test_validate_config_env_token_persisted_for_runtime(self): + """When when=runtime and HF_TOKEN only in env, token is validated and written to config.""" + import shutil + from pathlib import Path + from unittest.mock import patch + + tests_dir = Path(__file__).parent.resolve() + original_dummy_path = tests_dir / "dummy_runner_models" + + with tempfile.TemporaryDirectory() as tmp_dir: + target = Path(tmp_dir) / "test_model" + shutil.copytree(original_dummy_path, target) + + config_path = target / "config.yaml" + with config_path.open("r") as f: + config = yaml.safe_load(f) + + config["checkpoints"] = { + "type": "huggingface", + "repo_id": "meta-llama/Llama-3.1-8B-Instruct", + "when": "runtime", + } + with config_path.open("w") as f: + yaml.dump(config, f, sort_keys=False) + + # 1st call (anonymous/False) → gated, 2nd call (with env token) → success + mock_validate = patch( + "clarifai.runners.utils.loader.HuggingFaceLoader.validate_hf_repo_access", + side_effect=[(False, "gated_no_token"), (True, "")], + ) + mock_env = patch.dict(os.environ, {"HF_TOKEN": "hf_env_only_token"}) + mock_hftoken = patch( + "clarifai.runners.utils.loader.HuggingFaceLoader.validate_hftoken", + return_value=True, + ) + + with mock_validate as mv, mock_env, mock_hftoken: + ModelBuilder(str(target), validate_api_ids=False) + # First call anonymous (False), second with env token + assert mv.call_count == 2 + assert mv.call_args_list[0].kwargs["token"] is False + assert mv.call_args_list[1].kwargs["token"] == "hf_env_only_token" + + # Token should have been persisted to config.yaml + with config_path.open("r") as f: + saved = yaml.safe_load(f) + assert saved["checkpoints"]["hf_token"] == "hf_env_only_token" + + def test_validate_config_env_token_no_access_raises(self): + """When when=runtime, env token set but lacks access, raises UserError.""" + import shutil + from pathlib import Path + from unittest.mock import patch + + from clarifai.errors import UserError + + tests_dir = Path(__file__).parent.resolve() + original_dummy_path = tests_dir / "dummy_runner_models" + + with tempfile.TemporaryDirectory() as tmp_dir: + target = Path(tmp_dir) / "test_model" + shutil.copytree(original_dummy_path, target) + + config_path = target / "config.yaml" + with config_path.open("r") as f: + config = yaml.safe_load(f) + + config["checkpoints"] = { + "type": "huggingface", + "repo_id": "meta-llama/Llama-3.1-8B-Instruct", + "when": "runtime", + } + with config_path.open("w") as f: + yaml.dump(config, f, sort_keys=False) + + # 1st call (anonymous) → gated, 2nd call (with env token) → no access + with ( + patch( + "clarifai.runners.utils.loader.HuggingFaceLoader.validate_hf_repo_access", + side_effect=[(False, "gated_no_token"), (False, "gated_no_access")], + ), + patch.dict(os.environ, {"HF_TOKEN": "hf_bad_env_token"}), + patch( + "clarifai.runners.utils.loader.HuggingFaceLoader.validate_hftoken", + return_value=True, + ), + ): + with pytest.raises(UserError, match="does not have access"): + ModelBuilder(str(target), validate_api_ids=False) + + def test_validate_config_config_token_used_for_build_runtime(self): + """When when=runtime and hf_token IS in config, validate with that token.""" + import shutil + from pathlib import Path + from unittest.mock import patch + + tests_dir = Path(__file__).parent.resolve() + original_dummy_path = tests_dir / "dummy_runner_models" + + with tempfile.TemporaryDirectory() as tmp_dir: + target = Path(tmp_dir) / "test_model" + shutil.copytree(original_dummy_path, target) + + config_path = target / "config.yaml" + with config_path.open("r") as f: + config = yaml.safe_load(f) + + config["checkpoints"] = { + "type": "huggingface", + "repo_id": "meta-llama/Llama-3.1-8B-Instruct", + "hf_token": "hf_config_token", + "when": "runtime", + } + with config_path.open("w") as f: + yaml.dump(config, f, sort_keys=False) + + # 1st call (anonymous) → gated, 2nd call (with config token) → success + mock_validate = patch( + "clarifai.runners.utils.loader.HuggingFaceLoader.validate_hf_repo_access", + side_effect=[(False, "gated_no_token"), (True, "")], + ) + mock_hftoken = patch( + "clarifai.runners.utils.loader.HuggingFaceLoader.validate_hftoken", + return_value=True, + ) + + with mock_validate as mv, mock_hftoken: + ModelBuilder(str(target), validate_api_ids=False) + # 1st anonymous, 2nd with config token + assert mv.call_count == 2 + assert mv.call_args_list[1].kwargs["token"] == "hf_config_token" + + +class TestKVCacheEstimation: + """Tests for accurate KV cache estimation from HF config.json.""" + + # Qwen3-4B architecture params (from actual config.json) + QWEN3_4B_CONFIG = { + 'num_hidden_layers': 36, + 'num_key_value_heads': 8, + 'head_dim': 128, + 'max_position_embeddings': 40960, + } + + # Llama-3.1-8B architecture params + LLAMA_8B_CONFIG = { + 'num_hidden_layers': 32, + 'num_key_value_heads': 8, + 'head_dim': 128, + 'max_position_embeddings': 131072, + } + + # Phi-3-mini-4k (small context) + PHI3_MINI_CONFIG = { + 'num_hidden_layers': 32, + 'num_key_value_heads': 32, # MHA (num_kv_heads == num_attention_heads) + 'head_dim': 96, + 'max_position_embeddings': 4096, + } + + def test_kv_cache_qwen3_4b(self): + """Qwen3-4B KV cache should be ~5.62 GiB (matches vLLM error).""" + kv_bytes = _estimate_kv_cache_bytes(self.QWEN3_4B_CONFIG) + kv_gib = kv_bytes / (1024**3) + # 2 * 36 * 8 * 128 * 2 * 40960 = 6,039,797,760 bytes = ~5.625 GiB + assert abs(kv_gib - 5.625) < 0.01 + + def test_kv_cache_llama_8b(self): + """Llama-3.1-8B with 128k context should have ~16 GiB KV cache.""" + kv_bytes = _estimate_kv_cache_bytes(self.LLAMA_8B_CONFIG) + kv_gib = kv_bytes / (1024**3) + # 2 * 32 * 8 * 128 * 2 * 131072 = 17,179,869,184 bytes = 16 GiB + assert abs(kv_gib - 16.0) < 0.01 + + def test_kv_cache_phi3_small_context(self): + """Phi-3-mini with 4k context should have small KV cache.""" + kv_bytes = _estimate_kv_cache_bytes(self.PHI3_MINI_CONFIG) + kv_gib = kv_bytes / (1024**3) + # 2 * 32 * 32 * 96 * 2 * 4096 = 1,610,612,736 bytes = 1.5 GiB + assert abs(kv_gib - 1.5) < 0.01 + + def test_estimate_weight_bytes_bf16(self): + """_estimate_weight_bytes returns just weights, no KV or overhead.""" + weight_bytes = _estimate_weight_bytes(4_000_000_000) # 4B params, BF16 + # 4B * 2 = 8 GB + assert weight_bytes == 4_000_000_000 * 2 + + def test_estimate_weight_bytes_awq(self): + """AWQ 4-bit quantization: 0.5 bytes per param.""" + weight_bytes = _estimate_weight_bytes(7_000_000_000, quant_method="awq", quant_bits=4) + assert weight_bytes == 7_000_000_000 * 0.5 + + def test_get_hf_model_config_qwen3(self): + """Mock HF config.json for Qwen3-4B returns correct architecture.""" + from unittest.mock import MagicMock, patch + + mock_resp = MagicMock() + mock_resp.json.return_value = { + 'num_hidden_layers': 36, + 'num_attention_heads': 32, + 'num_key_value_heads': 8, + 'head_dim': 128, + 'hidden_size': 2560, + 'max_position_embeddings': 40960, + } + with patch("clarifai.utils.compute_presets.requests.get", return_value=mock_resp): + config = _get_hf_model_config("Qwen/Qwen3-4B") + assert config == self.QWEN3_4B_CONFIG + + def test_get_hf_model_config_compute_head_dim(self): + """When head_dim is not explicit, compute from hidden_size / num_attention_heads.""" + from unittest.mock import MagicMock, patch + + mock_resp = MagicMock() + mock_resp.json.return_value = { + 'num_hidden_layers': 32, + 'num_attention_heads': 32, + 'num_key_value_heads': 8, + 'hidden_size': 4096, + 'max_position_embeddings': 131072, + } + with patch("clarifai.utils.compute_presets.requests.get", return_value=mock_resp): + config = _get_hf_model_config("meta-llama/Llama-3.1-8B") + assert config is not None + assert config['head_dim'] == 128 # 4096 / 32 + assert config['num_key_value_heads'] == 8 + + def test_get_hf_model_config_mha_fallback(self): + """MHA model (no num_key_value_heads) falls back to num_attention_heads.""" + from unittest.mock import MagicMock, patch + + mock_resp = MagicMock() + mock_resp.json.return_value = { + 'num_hidden_layers': 32, + 'num_attention_heads': 32, + 'hidden_size': 3072, + 'max_position_embeddings': 4096, + } + with patch("clarifai.utils.compute_presets.requests.get", return_value=mock_resp): + config = _get_hf_model_config("microsoft/phi-3-mini-4k-instruct") + assert config is not None + assert config['num_key_value_heads'] == 32 # falls back to num_attention_heads + + def test_get_hf_model_config_missing_max_pos(self): + """Models without max_position_embeddings return None (safe fallback).""" + from unittest.mock import MagicMock, patch + + mock_resp = MagicMock() + mock_resp.json.return_value = { + 'num_hidden_layers': 32, + 'num_attention_heads': 71, + 'hidden_size': 4544, + # No max_position_embeddings + } + with patch("clarifai.utils.compute_presets.requests.get", return_value=mock_resp): + config = _get_hf_model_config("tiiuae/falcon-7b") + assert config is None + + def test_get_hf_model_config_network_failure(self): + """Network failure returns None gracefully.""" + from unittest.mock import patch + + with patch( + "clarifai.utils.compute_presets.requests.get", + side_effect=requests.exceptions.ConnectionError("offline"), + ): + config = _get_hf_model_config("any/model") + assert config is None + + def test_get_hf_model_config_with_token(self): + """HF token is passed as Bearer header for gated models.""" + from unittest.mock import MagicMock, patch + + mock_resp = MagicMock() + mock_resp.json.return_value = { + 'num_hidden_layers': 32, + 'num_attention_heads': 8, + 'hidden_size': 4096, + 'max_position_embeddings': 131072, + } + with patch("clarifai.utils.compute_presets.requests.get", return_value=mock_resp) as mock: + _get_hf_model_config("meta-llama/Llama-3.1-8B", hf_token="hf_test_token") + # Verify token was passed in headers + call_kwargs = mock.call_args + assert call_kwargs.kwargs['headers']['Authorization'] == 'Bearer hf_test_token' + + def test_get_hf_token_from_config(self): + """Token from config takes priority.""" + config = {'checkpoints': {'hf_token': 'hf_from_config'}} + assert _get_hf_token(config) == 'hf_from_config' + + def test_get_hf_token_from_env(self): + """Falls back to HF_TOKEN environment variable.""" + from unittest.mock import patch + + with patch.dict(os.environ, {'HF_TOKEN': 'hf_from_env'}): + assert _get_hf_token({}) == 'hf_from_env' + + def test_get_hf_token_none(self): + """Returns None when no token available.""" + from unittest.mock import patch + + with patch.dict(os.environ, {}, clear=True): + # Also ensure no cached token file + with patch("builtins.open", side_effect=FileNotFoundError): + assert _get_hf_token({}) is None + + def test_recommend_vllm_qwen3_4b_accurate_kv(self): + """vLLM + Qwen3-4B: accurate KV cache → A10G instead of g4dn.""" + from unittest.mock import MagicMock, patch + + config = { + "model": {"model_type_id": "any-to-any"}, + "build_info": {"image": "vllm/vllm-openai:latest"}, + "checkpoints": {"repo_id": "Qwen/Qwen3-4B"}, + } + # Mock HF model info (parameter count) + mock_hf_info = MagicMock() + mock_hf_info.json.return_value = { + "safetensors": {"total": 4_020_000_000, "parameters": {"BF16": 4_020_000_000}}, + "config": {}, + } + with ( + patch("clarifai.utils.compute_presets.requests.get") as mock_get, + patch( + "clarifai.utils.compute_presets._try_list_all_instance_types", return_value=None + ), + ): + # First call: _get_hf_model_info (HF API) + # Second call: _get_hf_model_config (config.json) + mock_config_resp = MagicMock() + mock_config_resp.json.return_value = { + 'num_hidden_layers': 36, + 'num_attention_heads': 32, + 'num_key_value_heads': 8, + 'head_dim': 128, + 'hidden_size': 2560, + 'max_position_embeddings': 40960, + } + mock_get.side_effect = [mock_hf_info, mock_config_resp] + + inst_id, reason = recommend_instance(config) + + # 4.02B * 2 = 8.04 GiB weights + 5.625 GiB KV + 2 GiB overhead = ~15.7 GiB + # g4dn (16 GiB) would be too tight → should pick A10G (24 GiB) + assert inst_id == "gpu-nvidia-a10g" + assert "KV cache" in reason + assert "40960 ctx" in reason + + def test_recommend_vllm_short_context_unchanged(self): + """vLLM + short context model: accurate KV cache is small, same result as heuristic.""" + from unittest.mock import MagicMock, patch + + config = { + "model": {"model_type_id": "any-to-any"}, + "build_info": {"image": "vllm/vllm-openai:latest"}, + "checkpoints": {"repo_id": "microsoft/phi-3-mini-4k-instruct"}, + } + mock_hf_info = MagicMock() + mock_hf_info.json.return_value = { + "safetensors": {"total": 3_800_000_000, "parameters": {"BF16": 3_800_000_000}}, + "config": {}, + } + mock_config_resp = MagicMock() + mock_config_resp.json.return_value = { + 'num_hidden_layers': 32, + 'num_attention_heads': 32, + 'hidden_size': 3072, + 'max_position_embeddings': 4096, + } + with ( + patch("clarifai.utils.compute_presets.requests.get") as mock_get, + patch( + "clarifai.utils.compute_presets._try_list_all_instance_types", return_value=None + ), + ): + mock_get.side_effect = [mock_hf_info, mock_config_resp] + inst_id, reason = recommend_instance(config) + + # 3.8B * 2 = 7.6 GiB weights + ~1.5 GiB KV (small ctx) + 2 GiB = ~11 GiB + # Fits g4dn (16 GiB) - but g4dn is not in fallback tiers, so → A10G (24 GiB) + assert inst_id == "gpu-nvidia-a10g" + assert "KV cache" in reason + + def test_recommend_non_vllm_uses_heuristic(self): + """Non-vLLM toolkit (huggingface) uses heuristic, not accurate KV cache.""" + from unittest.mock import MagicMock, patch + + config = { + "model": {"model_type_id": "any-to-any"}, + "checkpoints": {"repo_id": "some/model"}, + } + mock_resp = MagicMock() + mock_resp.json.return_value = { + "safetensors": {"total": 7_000_000_000, "parameters": {"BF16": 7_000_000_000}}, + "config": {}, + } + with ( + patch("clarifai.utils.compute_presets.requests.get", return_value=mock_resp), + patch( + "clarifai.utils.compute_presets._try_list_all_instance_types", return_value=None + ), + ): + inst_id, reason = recommend_instance(config) + + # Should NOT call _get_hf_model_config (no toolkit detected) + # 7B * 2 * 1.5 + overhead ≈ 22.9 GiB > A10G usable (21.6 GiB) → L40S + assert inst_id == "gpu-nvidia-l40s" + assert "KV cache" not in reason # heuristic path, no KV detail + + def test_recommend_vllm_config_unavailable_falls_back(self): + """vLLM with unavailable config.json falls back to heuristic.""" + from unittest.mock import MagicMock, patch + + config = { + "model": {"model_type_id": "any-to-any"}, + "build_info": {"image": "vllm/vllm-openai:latest"}, + "checkpoints": {"repo_id": "private/gated-model"}, + } + mock_hf_info = MagicMock() + mock_hf_info.json.return_value = { + "safetensors": {"total": 7_000_000_000, "parameters": {"BF16": 7_000_000_000}}, + "config": {}, + } + with ( + patch("clarifai.utils.compute_presets.requests.get") as mock_get, + patch( + "clarifai.utils.compute_presets._try_list_all_instance_types", return_value=None + ), + ): + # First call: _get_hf_model_info succeeds + # Second call: _get_hf_model_config fails (gated) + mock_config_fail = MagicMock() + mock_config_fail.raise_for_status.side_effect = requests.exceptions.HTTPError("403") + mock_get.side_effect = [mock_hf_info, mock_config_fail] + + inst_id, reason = recommend_instance(config) + + # Falls back to heuristic: 7B * 2 * 1.5 + overhead ≈ 22.9 GiB > A10G usable → L40S + assert inst_id == "gpu-nvidia-l40s" + assert "KV cache" not in reason # heuristic, no KV detail + + def test_recommend_vllm_file_size_with_kv(self): + """vLLM file-size fallback also uses accurate KV cache when available.""" + from unittest.mock import patch + + config = { + "model": {"model_type_id": "any-to-any"}, + "build_info": {"image": "vllm/vllm-openai:latest"}, + "checkpoints": {"repo_id": "Qwen/Qwen3-4B"}, + } + with ( + patch( + "clarifai.utils.compute_presets._get_hf_model_info", + return_value={ + "num_params": None, + "quant_method": None, + "quant_bits": None, + "dtype_breakdown": None, + "pipeline_tag": None, + }, + ), + patch( + "clarifai.utils.compute_presets._get_hf_model_config", + return_value=self.QWEN3_4B_CONFIG, + ), + patch( + "clarifai.runners.utils.loader.HuggingFaceLoader.get_huggingface_checkpoint_total_size", + return_value=int(7.5 * 1024**3), # 7.5 GiB file size + ), + patch( + "clarifai.utils.compute_presets._try_list_all_instance_types", return_value=None + ), + ): + inst_id, reason = recommend_instance(config) + + # 7.5 GiB files + 5.625 GiB KV + 2 GiB overhead = ~15.1 GiB → A10G + assert inst_id == "gpu-nvidia-a10g" + assert "KV cache" in reason + assert "40960 ctx" in reason diff --git a/tests/runners/test_model_init_huggingface_toolkit.py b/tests/runners/test_model_init_huggingface_toolkit.py index f9ce07c1..9bcb272c 100644 --- a/tests/runners/test_model_init_huggingface_toolkit.py +++ b/tests/runners/test_model_init_huggingface_toolkit.py @@ -1,58 +1,15 @@ -import os +"""Tests for model init with huggingface toolkit (embedded templates, no GitHub clone).""" -"""Tests for model init with huggingface toolkit. - -These tests run fully offline by setting CLARIFAI_SKIP_GITHUB_LISTING so the -command won't attempt to hit the GitHub contents API (which could rate-limit in CI). -""" import yaml from click.testing import CliRunner -import clarifai.cli.model as model_module from clarifai.cli.base import cli def test_model_init_huggingface_toolkit(monkeypatch, tmp_path): - """Happy path: model-name provided -> checkpoints.repo_id created and set.""" + """Happy path: --model-name provided -> checkpoints.repo_id set, model.id sanitized.""" runner = CliRunner() runner.invoke(cli, ["login", "--user_id", "test_user"]) - called = {'clone': False, 'repo_url': None, 'branch': None} - - def fake_clone(repo_url, clone_dir, github_pat, branch): - called['clone'] = True - called['repo_url'] = repo_url - called['branch'] = branch - version_dir = os.path.join(clone_dir, '1') - os.makedirs(version_dir, exist_ok=True) - # minimal model file (content should remain unchanged by huggingface customization) - with open(os.path.join(version_dir, 'model.py'), 'w') as f: - f.write('pass') - # config WITHOUT checkpoints so code path adds it - with open(os.path.join(clone_dir, 'config.yaml'), 'w') as f: - f.write('model:\n id: dummy\n') - with open(os.path.join(clone_dir, 'requirements.txt'), 'w') as f: - f.write('# none') - return True - - # Stub remote folder listing instead of relying on env flags - monkeypatch.setattr( - model_module.GitHubDownloader, - 'get_folder_contents', - lambda self, owner, repo, path, branch: [ - {'name': '1', 'type': 'dir', 'path': '1'}, - {'name': 'config.yaml', 'type': 'file', 'path': 'config.yaml'}, - {'name': 'requirements.txt', 'type': 'file', 'path': 'requirements.txt'}, - ], - raising=True, - ) - - # Patches - monkeypatch.setattr(model_module, 'clone_github_repo', fake_clone) - monkeypatch.setattr( - model_module, 'check_requirements_installed', lambda path: True, raising=False - ) - # Simulate pressing Enter for interactive confirmation - monkeypatch.setattr('builtins.input', lambda *a, **k: '\n') model_dir = tmp_path / 'hf_model' result = runner.invoke( @@ -70,8 +27,6 @@ def fake_clone(repo_url, clone_dir, github_pat, branch): ) assert result.exit_code == 0, result.output - assert called['clone'] is True - assert called['repo_url'] is not None # sanity that our fake saw a value cfg_path = model_dir / 'config.yaml' assert cfg_path.exists(), 'config.yaml not created' @@ -80,57 +35,33 @@ def fake_clone(repo_url, clone_dir, github_pat, branch): 'checkpoints section missing' ) assert data['checkpoints']['repo_id'] == 'UnsLOTH/Llama-1B' + assert data['model']['id'] == 'llama-1b' model_py = model_dir / '1' / 'model.py' assert model_py.exists(), 'model.py missing' - assert model_py.read_text() == 'pass', 'model.py unexpectedly modified' + assert 'HuggingFaceModel' in model_py.read_text(), 'embedded hf model class missing' + + requirements = model_dir / 'requirements.txt' + assert requirements.exists(), 'requirements.txt missing' + assert 'transformers' in requirements.read_text() def test_model_init_hf_no_model_name(monkeypatch, tmp_path): - """No --model-name: checkpoints section should NOT be added (mirrors current logic).""" + """No --model-name: default checkpoint from embedded template remains.""" runner = CliRunner() runner.invoke(cli, ["login", "--user_id", "test_user"]) - called = {'clone': False} - - def fake_clone(repo_url, clone_dir, github_pat, branch): - called['clone'] = True - version_dir = os.path.join(clone_dir, '1') - os.makedirs(version_dir, exist_ok=True) - with open(os.path.join(version_dir, 'model.py'), 'w') as f: - f.write('pass') - with open(os.path.join(clone_dir, 'config.yaml'), 'w') as f: - f.write('model:\n id: dummy\n') - with open(os.path.join(clone_dir, 'requirements.txt'), 'w') as f: - f.write('# none') - return True - - monkeypatch.setattr(model_module, 'clone_github_repo', fake_clone) - monkeypatch.setattr( - model_module.GitHubDownloader, - 'get_folder_contents', - lambda self, owner, repo, path, branch: [ - {'name': '1', 'type': 'dir', 'path': '1'}, - {'name': 'config.yaml', 'type': 'file', 'path': 'config.yaml'}, - {'name': 'requirements.txt', 'type': 'file', 'path': 'requirements.txt'}, - ], - raising=True, - ) - monkeypatch.setattr( - model_module, 'check_requirements_installed', lambda path: True, raising=False - ) - monkeypatch.setattr('builtins.input', lambda *a, **k: '\n') model_dir = tmp_path / 'hf_model2' result = runner.invoke( cli, - ['model', 'init', str(model_dir), '--toolkit', 'huggingface'], # no --model-name + ['model', 'init', str(model_dir), '--toolkit', 'huggingface'], standalone_mode=False, ) assert result.exit_code == 0, result.output - assert called['clone'] is True cfg_path = model_dir / 'config.yaml' data = yaml.safe_load(cfg_path.read_text()) - assert 'checkpoints' not in data, 'checkpoints unexpectedly added without model-name' + assert 'checkpoints' in data + assert data['checkpoints']['repo_id'] == 'unsloth/Llama-3.2-1B-Instruct' assert (model_dir / '1' / 'model.py').exists() diff --git a/tests/runners/test_model_init_lmstudio_toolkit.py b/tests/runners/test_model_init_lmstudio_toolkit.py index 7c15f376..fed49617 100644 --- a/tests/runners/test_model_init_lmstudio_toolkit.py +++ b/tests/runners/test_model_init_lmstudio_toolkit.py @@ -1,4 +1,4 @@ -import os +"""Tests for model init with lmstudio toolkit (embedded templates, no GitHub clone).""" import yaml from click.testing import CliRunner @@ -7,50 +7,13 @@ from clarifai.cli.base import cli -def test_model_init_lmstudio_toolkit(monkeypatch, tmp_path): - """Happy path: all customization flags provided; placeholders replaced.""" +def test_model_init_lmstudio_with_model_name(monkeypatch, tmp_path): + """Happy path: --model-name provided -> toolkit.model set in config.""" runner = CliRunner() runner.invoke(cli, ["login", "--user_id", "test_user"]) - called = {'clone': False, 'repo_url': None, 'branch': None} - - def fake_clone(repo_url, clone_dir, github_pat, branch): - called['clone'] = True - called['repo_url'] = repo_url - called['branch'] = branch - version_dir = os.path.join(clone_dir, '1') - os.makedirs(version_dir, exist_ok=True) - model_py = os.path.join(version_dir, 'model.py') - with open(model_py, 'w') as f: - f.write('pass') - with open(os.path.join(clone_dir, 'config.yaml'), 'w') as f: - f.write('model:\n id: dummy\n') - with open(os.path.join(clone_dir, 'requirements.txt'), 'w') as f: - f.write('# none') - return True - - # Stub remote GitHub folder listing - monkeypatch.setattr( - model_module.GitHubDownloader, - 'get_folder_contents', - lambda self, owner, repo, path, branch: [ - {'name': '1', 'type': 'dir', 'path': '1'}, - {'name': 'config.yaml', 'type': 'file', 'path': 'config.yaml'}, - {'name': 'requirements.txt', 'type': 'file', 'path': 'requirements.txt'}, - ], - raising=True, - ) - - # Patches - monkeypatch.setattr(model_module, 'clone_github_repo', fake_clone) monkeypatch.setattr(model_module, 'check_lmstudio_installed', lambda: True) - monkeypatch.setattr( - model_module, 'check_requirements_installed', lambda path: True, raising=False - ) - # Simulate Enter key for interactive confirmation - monkeypatch.setattr('builtins.input', lambda *a, **k: '\n') model_dir = tmp_path / 'lmstudio_model' - result = runner.invoke( cli, [ @@ -61,87 +24,42 @@ def fake_clone(repo_url, clone_dir, github_pat, branch): 'lmstudio', '--model-name', 'qwen/qwen3-4b', - '--port', - '11888', - '--context-length', - '16000', ], standalone_mode=False, ) assert result.exit_code == 0, result.output - assert called['clone'] is True - assert called['repo_url'] is not None cfg_path = model_dir / 'config.yaml' assert cfg_path.exists(), 'config.yaml not created' data = yaml.safe_load(cfg_path.read_text()) assert 'toolkit' in data and isinstance(data['toolkit'], dict), 'toolkit section missing' - - # New values assert data['toolkit']['model'] == 'qwen/qwen3-4b' - assert data['toolkit']['port'] == '11888' - assert data['toolkit']['context_length'] == '16000' + assert data['model']['id'] == 'qwen3-4b' - # Originals removed - assert data['toolkit']['model'] != 'LiquidAI/LFM2-1.2B' - assert data['toolkit']['port'] != '11434' - assert data['toolkit']['context_length'] != '2048' + model_py = model_dir / '1' / 'model.py' + assert model_py.exists(), 'model.py missing' + assert 'LMStudioModel' in model_py.read_text(), 'embedded lmstudio model class missing' def test_model_init_lmstudio_defaults(monkeypatch, tmp_path): - """No customization flags: placeholders remain unchanged.""" + """No --model-name: defaults from embedded template remain.""" runner = CliRunner() runner.invoke(cli, ["login", "--user_id", "test_user"]) - called = {'clone': True} - - def fake_clone(repo_url, clone_dir, github_pat, branch): - version_dir = os.path.join(clone_dir, '1') - os.makedirs(version_dir, exist_ok=True) - with open(os.path.join(version_dir, 'model.py'), 'w') as f: - f.write( - "LMS_MODEL_NAME = 'LiquidAI/LFM2-1.2B'\n" - "LMS_PORT = 11434\n" - "LMS_CONTEXT_LENGTH = 4096\n" - ) - with open(os.path.join(clone_dir, 'config.yaml'), 'w') as f: - f.write('model:\n id: dummy\n') - with open(os.path.join(clone_dir, 'requirements.txt'), 'w') as f: - f.write('# none') - return True - - monkeypatch.setattr(model_module, 'clone_github_repo', fake_clone) - monkeypatch.setattr( - model_module.GitHubDownloader, - 'get_folder_contents', - lambda self, owner, repo, path, branch: [ - {'name': '1', 'type': 'dir', 'path': '1'}, - {'name': 'config.yaml', 'type': 'file', 'path': 'config.yaml'}, - {'name': 'requirements.txt', 'type': 'file', 'path': 'requirements.txt'}, - ], - raising=True, - ) monkeypatch.setattr(model_module, 'check_lmstudio_installed', lambda: True) - monkeypatch.setattr( - model_module, 'check_requirements_installed', lambda path: True, raising=False - ) - monkeypatch.setattr('builtins.input', lambda *a, **k: '\n') model_dir = tmp_path / 'lmstudio_model_default' result = runner.invoke( cli, - [ - 'model', - 'init', - str(model_dir), - '--toolkit', - 'lmstudio', - ], # no customization args + ['model', 'init', str(model_dir), '--toolkit', 'lmstudio'], standalone_mode=False, ) + assert result.exit_code == 0, result.output content = (model_dir / '1' / 'model.py').read_text() - assert "LMS_MODEL_NAME = 'LiquidAI/LFM2-1.2B'" in content - assert "LMS_PORT = 11434" in content - assert "LMS_CONTEXT_LENGTH = 4096" in content + assert 'LMS_MODEL_NAME = os.environ.get("LMS_MODEL_NAME", "google/gemma-3-4b")' in content + assert 'LMS_PORT = int(os.environ.get("LMS_PORT", "23333"))' in content + + files = {p.name for p in model_dir.iterdir()} + assert {'1', 'config.yaml', 'requirements.txt'}.issubset(files) diff --git a/tests/runners/test_model_init_ollama_toolkit.py b/tests/runners/test_model_init_ollama_toolkit.py index a869f820..5eb4f239 100644 --- a/tests/runners/test_model_init_ollama_toolkit.py +++ b/tests/runners/test_model_init_ollama_toolkit.py @@ -1,98 +1,71 @@ -import os +"""Tests for model init with ollama toolkit (embedded templates, no GitHub clone).""" -import pytest +import yaml from click.testing import CliRunner import clarifai.cli.model as model_module from clarifai.cli.base import cli -@pytest.mark.parametrize( - "custom,model_name,port,context_length", - [ - (True, "my-ollama", "4567", "9999"), - (False, None, None, None), - ], -) -def test_model_init_ollama(monkeypatch, tmp_path, custom, model_name, port, context_length): - """Test ollama toolkit init with and without customization flags.""" +def test_model_init_ollama_with_model_name(monkeypatch, tmp_path): + """Happy path: --model-name provided -> toolkit.model set in config.""" runner = CliRunner() runner.invoke(cli, ["login", "--user_id", "test_user"]) - called = {'clone': False} - - def fake_clone(repo_url, clone_dir, github_pat, branch): - called['clone'] = True - version_dir = os.path.join(clone_dir, '1') - os.makedirs(version_dir, exist_ok=True) - # template model file - with open(os.path.join(version_dir, 'model.py'), 'w') as f: - f.write( - "# placeholder template\nimport os\n\nclass Dummy:\n" - " def __init__(self):\n" - " self.model = os.environ.get(\"OLLAMA_MODEL_NAME\", 'llama3.2')\n\n" - "PORT = '23333'\ncontext_length = '8192'\n" - ) - with open(os.path.join(clone_dir, 'config.yaml'), 'w') as f: - f.write("model:\n id: dummy\n") - with open(os.path.join(clone_dir, 'requirements.txt'), 'w') as f: - f.write("# none\n") - return True - - monkeypatch.setattr(model_module, 'clone_github_repo', fake_clone) monkeypatch.setattr(model_module, 'check_ollama_installed', lambda: True) - monkeypatch.setattr( - model_module, 'check_requirements_installed', lambda path: True, raising=False - ) - # Avoid real GitHub listing by stubbing folder contents. - monkeypatch.setattr( - model_module.GitHubDownloader, - 'get_folder_contents', - lambda self, owner, repo, path, branch: [ - {'name': '1', 'type': 'dir', 'path': '1'}, - {'name': 'config.yaml', 'type': 'file', 'path': 'config.yaml'}, - {'name': 'requirements.txt', 'type': 'file', 'path': 'requirements.txt'}, + + model_dir = tmp_path / 'ollama_custom' + result = runner.invoke( + cli, + [ + 'model', + 'init', + str(model_dir), + '--toolkit', + 'ollama', + '--model-name', + 'llama3.1', ], - raising=True, + standalone_mode=False, ) - # Simulate user pressing Enter at interactive prompts - monkeypatch.setattr('builtins.input', lambda *a, **k: '\n') + assert result.exit_code == 0, result.output + + cfg_path = model_dir / 'config.yaml' + assert cfg_path.exists(), 'config.yaml not created' + data = yaml.safe_load(cfg_path.read_text()) + assert data['toolkit']['model'] == 'llama3.1' + assert data['model']['id'] == 'llama31' # sanitized: dots stripped + + model_py = model_dir / '1' / 'model.py' + assert model_py.exists(), 'model.py missing' + content = model_py.read_text() + assert 'OllamaModel' in content, 'embedded ollama model class missing' + # model name should be replaced in the model.py + assert 'llama3.1' in content + + files = {p.name for p in model_dir.iterdir()} + assert {'1', 'config.yaml', 'requirements.txt'}.issubset(files) + - args = [ - 'model', - 'init', - str(tmp_path / ('ollama_custom' if custom else 'ollama_default')), - '--toolkit', - 'ollama', - ] - if custom: - args.extend( - ['--model-name', model_name, '--port', port, '--context-length', context_length] - ) +def test_model_init_ollama_defaults(monkeypatch, tmp_path): + """No --model-name: defaults from embedded template remain.""" + runner = CliRunner() + runner.invoke(cli, ["login", "--user_id", "test_user"]) + monkeypatch.setattr(model_module, 'check_ollama_installed', lambda: True) - result = runner.invoke(cli, args, standalone_mode=False) + model_dir = tmp_path / 'ollama_default' + result = runner.invoke( + cli, + ['model', 'init', str(model_dir), '--toolkit', 'ollama'], + standalone_mode=False, + ) assert result.exit_code == 0, result.output - assert called['clone'] is True - # We allow prompts; just ensure command succeeds. - model_py = tmp_path / ('ollama_custom' if custom else 'ollama_default') / '1' / 'model.py' + model_py = model_dir / '1' / 'model.py' content = model_py.read_text() + # defaults remain + assert 'llama3.2' in content - if custom: - assert model_name in content - assert f"PORT = '{port}'" in content - assert f"context_length = '{context_length}'" in content - # old defaults replaced - assert "PORT = '23333'" not in content - assert "context_length = '8192'" not in content - else: - # defaults remain - assert "llama3.2" in content - assert "PORT = '23333'" in content - assert "context_length = '8192'" in content - - # baseline file set - root = tmp_path / ('ollama_custom' if custom else 'ollama_default') - files = {p.name for p in root.iterdir()} + files = {p.name for p in model_dir.iterdir()} assert {'1', 'config.yaml', 'requirements.txt'}.issubset(files) diff --git a/tests/runners/test_model_init_sglang_toolkit.py b/tests/runners/test_model_init_sglang_toolkit.py index b3c89c75..81e6a437 100644 --- a/tests/runners/test_model_init_sglang_toolkit.py +++ b/tests/runners/test_model_init_sglang_toolkit.py @@ -1,58 +1,15 @@ -import os +"""Tests for model init with sglang toolkit (embedded templates, no GitHub clone).""" -"""Tests for model init with sglang toolkit. - -These tests run fully offline by setting CLARIFAI_SKIP_GITHUB_LISTING so the -command won't attempt to hit the GitHub contents API (which could rate-limit in CI). -""" import yaml from click.testing import CliRunner -import clarifai.cli.model as model_module from clarifai.cli.base import cli def test_model_init_sglang_toolkit(monkeypatch, tmp_path): - """Happy path: model-name provided -> checkpoints.repo_id created and set.""" + """Happy path: --model-name provided -> checkpoints.repo_id set, model.id sanitized.""" runner = CliRunner() runner.invoke(cli, ["login", "--user_id", "test_user"]) - called = {'clone': False, 'repo_url': None, 'branch': None} - - def fake_clone(repo_url, clone_dir, github_pat, branch): - called['clone'] = True - called['repo_url'] = repo_url - called['branch'] = branch - version_dir = os.path.join(clone_dir, '1') - os.makedirs(version_dir, exist_ok=True) - # minimal model file (content should remain unchanged by huggingface customization) - with open(os.path.join(version_dir, 'model.py'), 'w') as f: - f.write('pass') - # config WITHOUT checkpoints so code path adds it - with open(os.path.join(clone_dir, 'config.yaml'), 'w') as f: - f.write('model:\n id: dummy\n') - with open(os.path.join(clone_dir, 'requirements.txt'), 'w') as f: - f.write('# none') - return True - - # Stub remote folder listing instead of relying on env flags - monkeypatch.setattr( - model_module.GitHubDownloader, - 'get_folder_contents', - lambda self, owner, repo, path, branch: [ - {'name': '1', 'type': 'dir', 'path': '1'}, - {'name': 'config.yaml', 'type': 'file', 'path': 'config.yaml'}, - {'name': 'requirements.txt', 'type': 'file', 'path': 'requirements.txt'}, - ], - raising=True, - ) - - # Patches - monkeypatch.setattr(model_module, 'clone_github_repo', fake_clone) - monkeypatch.setattr( - model_module, 'check_requirements_installed', lambda path: True, raising=False - ) - # Simulate pressing Enter for interactive confirmation - monkeypatch.setattr('builtins.input', lambda *a, **k: '\n') model_dir = tmp_path / 'sglang_model' result = runner.invoke( @@ -70,8 +27,6 @@ def fake_clone(repo_url, clone_dir, github_pat, branch): ) assert result.exit_code == 0, result.output - assert called['clone'] is True - assert called['repo_url'] is not None # sanity that our fake saw a value cfg_path = model_dir / 'config.yaml' assert cfg_path.exists(), 'config.yaml not created' @@ -80,57 +35,29 @@ def fake_clone(repo_url, clone_dir, github_pat, branch): 'checkpoints section missing' ) assert data['checkpoints']['repo_id'] == 'microsoft/phi-1_5' + assert data['model']['id'] == 'phi-1-5' model_py = model_dir / '1' / 'model.py' assert model_py.exists(), 'model.py missing' - assert model_py.read_text() == 'pass', 'model.py unexpectedly modified' + assert 'SGLangModel' in model_py.read_text(), 'embedded sglang model class missing' def test_model_init_sglang_no_model_name(monkeypatch, tmp_path): - """No --model-name: checkpoints section should NOT be added (mirrors current logic).""" + """No --model-name: default checkpoint from embedded template remains.""" runner = CliRunner() runner.invoke(cli, ["login", "--user_id", "test_user"]) - called = {'clone': False} - - def fake_clone(repo_url, clone_dir, github_pat, branch): - called['clone'] = True - version_dir = os.path.join(clone_dir, '1') - os.makedirs(version_dir, exist_ok=True) - with open(os.path.join(version_dir, 'model.py'), 'w') as f: - f.write('pass') - with open(os.path.join(clone_dir, 'config.yaml'), 'w') as f: - f.write('model:\n id: dummy\n') - with open(os.path.join(clone_dir, 'requirements.txt'), 'w') as f: - f.write('# none') - return True - - monkeypatch.setattr(model_module, 'clone_github_repo', fake_clone) - monkeypatch.setattr( - model_module.GitHubDownloader, - 'get_folder_contents', - lambda self, owner, repo, path, branch: [ - {'name': '1', 'type': 'dir', 'path': '1'}, - {'name': 'config.yaml', 'type': 'file', 'path': 'config.yaml'}, - {'name': 'requirements.txt', 'type': 'file', 'path': 'requirements.txt'}, - ], - raising=True, - ) - monkeypatch.setattr( - model_module, 'check_requirements_installed', lambda path: True, raising=False - ) - monkeypatch.setattr('builtins.input', lambda *a, **k: '\n') model_dir = tmp_path / 'sglang_model2' result = runner.invoke( cli, - ['model', 'init', str(model_dir), '--toolkit', 'sglang'], # no --model-name + ['model', 'init', str(model_dir), '--toolkit', 'sglang'], standalone_mode=False, ) assert result.exit_code == 0, result.output - assert called['clone'] is True cfg_path = model_dir / 'config.yaml' data = yaml.safe_load(cfg_path.read_text()) - assert 'checkpoints' not in data, 'checkpoints unexpectedly added without model-name' + assert 'checkpoints' in data + assert data['checkpoints']['repo_id'] == 'google/gemma-3-1b-it' assert (model_dir / '1' / 'model.py').exists() diff --git a/tests/runners/test_model_init_vllm_toolkit.py b/tests/runners/test_model_init_vllm_toolkit.py index f9e7948e..b5f8a7e8 100644 --- a/tests/runners/test_model_init_vllm_toolkit.py +++ b/tests/runners/test_model_init_vllm_toolkit.py @@ -1,58 +1,15 @@ -import os +"""Tests for model init with vllm toolkit (embedded templates, no GitHub clone).""" -"""Tests for model init with vllm toolkit. - -These tests run fully offline by setting CLARIFAI_SKIP_GITHUB_LISTING so the -command won't attempt to hit the GitHub contents API (which could rate-limit in CI). -""" import yaml from click.testing import CliRunner -import clarifai.cli.model as model_module from clarifai.cli.base import cli def test_model_init_vllm_toolkit(monkeypatch, tmp_path): - """Happy path: model-name provided -> checkpoints.repo_id created and set.""" + """Happy path: --model-name provided -> checkpoints.repo_id set, model.id sanitized.""" runner = CliRunner() runner.invoke(cli, ["login", "--user_id", "test_user"]) - called = {'clone': False, 'repo_url': None, 'branch': None} - - def fake_clone(repo_url, clone_dir, github_pat, branch): - called['clone'] = True - called['repo_url'] = repo_url - called['branch'] = branch - version_dir = os.path.join(clone_dir, '1') - os.makedirs(version_dir, exist_ok=True) - # minimal model file (content should remain unchanged by huggingface customization) - with open(os.path.join(version_dir, 'model.py'), 'w') as f: - f.write('pass') - # config WITHOUT checkpoints so code path adds it - with open(os.path.join(clone_dir, 'config.yaml'), 'w') as f: - f.write('model:\n id: dummy\n') - with open(os.path.join(clone_dir, 'requirements.txt'), 'w') as f: - f.write('# none') - return True - - # Stub remote folder listing instead of relying on env flags - monkeypatch.setattr( - model_module.GitHubDownloader, - 'get_folder_contents', - lambda self, owner, repo, path, branch: [ - {'name': '1', 'type': 'dir', 'path': '1'}, - {'name': 'config.yaml', 'type': 'file', 'path': 'config.yaml'}, - {'name': 'requirements.txt', 'type': 'file', 'path': 'requirements.txt'}, - ], - raising=True, - ) - - # Patches - monkeypatch.setattr(model_module, 'clone_github_repo', fake_clone) - monkeypatch.setattr( - model_module, 'check_requirements_installed', lambda path: True, raising=False - ) - # Simulate pressing Enter for interactive confirmation - monkeypatch.setattr('builtins.input', lambda *a, **k: '\n') model_dir = tmp_path / 'vllm_model' result = runner.invoke( @@ -70,8 +27,6 @@ def fake_clone(repo_url, clone_dir, github_pat, branch): ) assert result.exit_code == 0, result.output - assert called['clone'] is True - assert called['repo_url'] is not None # sanity that our fake saw a value cfg_path = model_dir / 'config.yaml' assert cfg_path.exists(), 'config.yaml not created' @@ -80,57 +35,35 @@ def fake_clone(repo_url, clone_dir, github_pat, branch): 'checkpoints section missing' ) assert data['checkpoints']['repo_id'] == 'microsoft/phi-1_5' + assert data['model']['id'] == 'phi-1-5' # sanitized: dots and underscores handled model_py = model_dir / '1' / 'model.py' assert model_py.exists(), 'model.py missing' - assert model_py.read_text() == 'pass', 'model.py unexpectedly modified' + assert 'VLLMModel' in model_py.read_text(), 'embedded vllm model class missing' + requirements = model_dir / 'requirements.txt' + assert requirements.exists(), 'requirements.txt missing' + # vllm is provided by the Docker base image, not in requirements.txt + assert 'clarifai' in requirements.read_text() -def test_model_init_hf_no_model_name(monkeypatch, tmp_path): - """No --model-name: checkpoints section should NOT be added (mirrors current logic).""" + +def test_model_init_vllm_no_model_name(monkeypatch, tmp_path): + """No --model-name: default checkpoint from embedded template remains.""" runner = CliRunner() runner.invoke(cli, ["login", "--user_id", "test_user"]) - called = {'clone': False} - - def fake_clone(repo_url, clone_dir, github_pat, branch): - called['clone'] = True - version_dir = os.path.join(clone_dir, '1') - os.makedirs(version_dir, exist_ok=True) - with open(os.path.join(version_dir, 'model.py'), 'w') as f: - f.write('pass') - with open(os.path.join(clone_dir, 'config.yaml'), 'w') as f: - f.write('model:\n id: dummy\n') - with open(os.path.join(clone_dir, 'requirements.txt'), 'w') as f: - f.write('# none') - return True - - monkeypatch.setattr(model_module, 'clone_github_repo', fake_clone) - monkeypatch.setattr( - model_module.GitHubDownloader, - 'get_folder_contents', - lambda self, owner, repo, path, branch: [ - {'name': '1', 'type': 'dir', 'path': '1'}, - {'name': 'config.yaml', 'type': 'file', 'path': 'config.yaml'}, - {'name': 'requirements.txt', 'type': 'file', 'path': 'requirements.txt'}, - ], - raising=True, - ) - monkeypatch.setattr( - model_module, 'check_requirements_installed', lambda path: True, raising=False - ) - monkeypatch.setattr('builtins.input', lambda *a, **k: '\n') model_dir = tmp_path / 'vllm_model2' result = runner.invoke( cli, - ['model', 'init', str(model_dir), '--toolkit', 'vllm'], # no --model-name + ['model', 'init', str(model_dir), '--toolkit', 'vllm'], standalone_mode=False, ) assert result.exit_code == 0, result.output - assert called['clone'] is True cfg_path = model_dir / 'config.yaml' data = yaml.safe_load(cfg_path.read_text()) - assert 'checkpoints' not in data, 'checkpoints unexpectedly added without model-name' + # Default checkpoint from embedded template + assert 'checkpoints' in data + assert data['checkpoints']['repo_id'] == 'google/gemma-3-1b-it' assert (model_dir / '1' / 'model.py').exists() diff --git a/tests/runners/test_num_threads_config.py b/tests/runners/test_num_threads_config.py index 7b5d5e57..fe3448b5 100644 --- a/tests/runners/test_num_threads_config.py +++ b/tests/runners/test_num_threads_config.py @@ -63,7 +63,9 @@ def test_num_threads(my_tmp_path, num_threads, monkeypatch): assert builder.config.get("num_threads") == num_threads elif num_threads in [-1, 0, "a", 1.5]: - with pytest.raises(AssertionError): + from clarifai.errors import UserError + + with pytest.raises(UserError): builder = ModelBuilder(target_folder, validate_api_ids=False) diff --git a/tests/test_cli_logout.py b/tests/test_cli_logout.py index 7d42df17..f1017624 100644 --- a/tests/test_cli_logout.py +++ b/tests/test_cli_logout.py @@ -70,41 +70,45 @@ def _load_config(config_path): return Config.from_yaml(filename=config_path) -class TestLogoutNonInteractive: - """Tests for flag-based (non-interactive) logout.""" +class TestLogoutDefault: + """Tests for default logout (no flags = log out of current context).""" - def test_logout_current_clears_pat(self, tmp_path): - """--current should clear PAT from the active context.""" + def test_bare_logout_clears_pat(self, tmp_path): + """'clarifai logout' should clear PAT from the active context.""" config_path = _make_config(tmp_path) runner = CliRunner() - result = runner.invoke(cli, ['--config', config_path, 'logout', '--current']) + result = runner.invoke(cli, ['--config', config_path, 'logout']) assert result.exit_code == 0 assert "Logged out of context 'default'" in result.output cfg = _load_config(config_path) assert cfg.contexts['default']['env']['CLARIFAI_PAT'] == '' - def test_logout_current_delete_single_context(self, tmp_path): - """--current --delete with only one context should clear PAT but keep context.""" + def test_bare_logout_delete_single_context(self, tmp_path): + """'clarifai logout --delete' with only one context should clear PAT but keep context.""" config_path = _make_config(tmp_path) runner = CliRunner() - result = runner.invoke(cli, ['--config', config_path, 'logout', '--current', '--delete']) + result = runner.invoke(cli, ['--config', config_path, 'logout', '--delete']) assert result.exit_code == 0 cfg = _load_config(config_path) assert 'default' in cfg.contexts # context kept assert cfg.contexts['default']['env']['CLARIFAI_PAT'] == '' assert "only context" in result.output - def test_logout_current_delete_multi_context(self, tmp_path): - """--current --delete with multiple contexts should delete and switch.""" + def test_bare_logout_delete_multi_context(self, tmp_path): + """'clarifai logout --delete' with multiple contexts should delete and switch.""" config_path = _multi_context_config(tmp_path=tmp_path) runner = CliRunner() - result = runner.invoke(cli, ['--config', config_path, 'logout', '--current', '--delete']) + result = runner.invoke(cli, ['--config', config_path, 'logout', '--delete']) assert result.exit_code == 0 cfg = _load_config(config_path) assert 'default' not in cfg.contexts assert cfg.current_context == 'staging' assert "deleted" in result.output.lower() + +class TestLogoutNamedContext: + """Tests for --context flag.""" + def test_logout_named_context(self, tmp_path): """--context should clear PAT from the named context only.""" config_path = _multi_context_config(tmp_path=tmp_path) @@ -125,6 +129,10 @@ def test_logout_named_context_not_found(self, tmp_path): assert result.exit_code != 0 assert "not found" in result.output.lower() + +class TestLogoutAll: + """Tests for --all flag.""" + def test_logout_all(self, tmp_path): """--all should clear PATs from every context.""" config_path = _multi_context_config(tmp_path=tmp_path) @@ -140,31 +148,14 @@ def test_logout_all(self, tmp_path): class TestLogoutFlagValidation: """Tests for invalid flag combinations.""" - def test_delete_without_current_or_context(self, tmp_path): - """--delete alone should error.""" - config_path = _make_config(tmp_path) - runner = CliRunner() - result = runner.invoke(cli, ['--config', config_path, 'logout', '--delete']) - assert result.exit_code != 0 - assert "--delete requires" in result.output - - def test_current_and_context_together(self, tmp_path): - """--current and --context together should error.""" + def test_all_with_other_flags(self, tmp_path): + """--all combined with --context or --delete should error.""" config_path = _make_config(tmp_path) runner = CliRunner() result = runner.invoke( - cli, - ['--config', config_path, 'logout', '--current', '--context', 'default'], + cli, ['--config', config_path, 'logout', '--all', '--context', 'default'] ) assert result.exit_code != 0 - assert "Cannot use --current and --context together" in result.output - - def test_all_with_other_flags(self, tmp_path): - """--all combined with --current, --context, or --delete should error.""" - config_path = _make_config(tmp_path) - runner = CliRunner() - result = runner.invoke(cli, ['--config', config_path, 'logout', '--all', '--current']) - assert result.exit_code != 0 assert "--all cannot be combined" in result.output @@ -176,50 +167,6 @@ def test_warns_when_env_pat_set(self, tmp_path): config_path = _make_config(tmp_path) runner = CliRunner() with mock.patch.dict(os.environ, {'CLARIFAI_PAT': 'env_pat_value'}): - result = runner.invoke(cli, ['--config', config_path, 'logout', '--current']) + result = runner.invoke(cli, ['--config', config_path, 'logout']) assert result.exit_code == 0 assert "CLARIFAI_PAT environment variable is still set" in result.output - - -class TestLogoutInteractive: - """Tests for the interactive menu flow.""" - - def test_interactive_cancel(self, tmp_path): - """Choosing cancel should make no changes.""" - config_path = _make_config(tmp_path) - runner = CliRunner() - result = runner.invoke(cli, ['--config', config_path, 'logout'], input='5\n') - assert result.exit_code == 0 - assert "Cancelled" in result.output - cfg = _load_config(config_path) - assert cfg.contexts['default']['env']['CLARIFAI_PAT'] == 'test_pat_12345' - - def test_interactive_logout_current(self, tmp_path): - """Choosing option 2 should clear current context PAT.""" - config_path = _make_config(tmp_path) - runner = CliRunner() - result = runner.invoke(cli, ['--config', config_path, 'logout'], input='2\n') - assert result.exit_code == 0 - assert "Logged out of context 'default'" in result.output - cfg = _load_config(config_path) - assert cfg.contexts['default']['env']['CLARIFAI_PAT'] == '' - - def test_interactive_switch_context(self, tmp_path): - """Choosing option 1 should switch to another context.""" - config_path = _multi_context_config(tmp_path=tmp_path) - runner = CliRunner() - result = runner.invoke(cli, ['--config', config_path, 'logout'], input='1\n1\n') - assert result.exit_code == 0 - cfg = _load_config(config_path) - assert cfg.current_context == 'staging' - assert "No credentials were cleared" in result.output - - def test_interactive_logout_delete_multi(self, tmp_path): - """Choosing option 3 with multiple contexts should delete and switch.""" - config_path = _multi_context_config(tmp_path=tmp_path) - runner = CliRunner() - result = runner.invoke(cli, ['--config', config_path, 'logout'], input='3\n') - assert result.exit_code == 0 - cfg = _load_config(config_path) - assert 'default' not in cfg.contexts - assert cfg.current_context == 'staging' diff --git a/tests/test_cli_whoami_app.py b/tests/test_cli_whoami_app.py index 6268d2b8..edb472af 100644 --- a/tests/test_cli_whoami_app.py +++ b/tests/test_cli_whoami_app.py @@ -54,7 +54,7 @@ def test_whoami_displays_context_user_id(self, tmp_path): result = runner.invoke(cli, ['--config', config_path, 'whoami']) assert result.exit_code == 0 - assert 'Context User ID: test_user' in result.output + assert 'test_user' in result.output def test_whoami_handles_api_error(self, tmp_path): """should handle API errors gracefully.""" @@ -70,7 +70,7 @@ def test_whoami_handles_api_error(self, tmp_path): result = runner.invoke(cli, ['--config', config_path, 'whoami']) assert result.exit_code == 0 - assert 'Context User ID: test_user' in result.output + assert 'test_user' in result.output class TestAppList: diff --git a/tests/workflow/__init__.py b/tests/workflow/__init__.py new file mode 100644 index 00000000..e69de29b