"""Tests for `ChatAPI._resolve_model` + the Client-level catalog cache. Resolution rules (mirrored from PROGRESS.md "Phase 6 / 6a"): 1. `ModelInfo` argument → use `.model` and `.provider`, ignore the `provider=` kwarg. 2. `str` + `provider=` → pass through verbatim, no catalog lookup. 3. `str` only: - match catalog id (`info.id`), - else match wire id (`info.model`), - else match display name (`info.name`), - else raise `ValueError`. The catalog is fetched at most once per Client and cached. We assert this by mocking `/ai/models` and counting requests across two `chat.complete` calls. """ from __future__ import annotations from typing import Any import pytest from aioresponses import aioresponses from raycast_api.ai import Message from raycast_api.ai.chat import ChatAPI from raycast_api.ai.models import ModelInfo, ModelsResponse from raycast_api.client import Client from raycast_api.config import Config from raycast_api.signing_spec import RotRange, SigningSpec def _config() -> Config: return Config( signature_secret="0" * 64, signing_spec=SigningSpec( rot_fn_name="Sur", signing_fn_name="Nkt", rot_ranges=[ RotRange(start=65, end=90, shift=13), RotRange(start=97, end=122, shift=13), RotRange(start=48, end=57, shift=5), ], ), app_version="0.60.1.0", user_agent="Raycast/0.60.1.0 (x-macOS Version 26.3.1)", bundle_hash="0" * 64, launcher_hash="0" * 64, ) def _client(**kwargs: Any) -> Client: return Client( config=_config(), bearer_token="rca_test", device_id="a" * 64, clock=lambda: 1700000000, **kwargs, ) CATALOG_PAYLOAD = { "models": [ { "id": "google-gemini-3.1-pro-preview", "name": "Gemini 3.1 Pro Preview", "model": "gemini-3.1-pro-preview", "provider": "google", "provider_name": "Google", }, { "id": "anthropic-claude-sonnet-4-6", "name": "Claude Sonnet 4.6", "model": "claude-sonnet-4-6", "provider": "anthropic", "provider_name": "Anthropic", }, ], "default_models": {"chat": "google-gemini-3.1-pro-preview"}, "free_models": ["anthropic-claude-sonnet-4-6"], } def _catalog() -> ModelsResponse: return ModelsResponse.from_wire(CATALOG_PAYLOAD) class TestResolveModel: async def test_model_info_argument_wins(self) -> None: """A `ModelInfo` short-circuits any catalog lookup; `provider=` is ignored.""" async with _client(models=_catalog()) as client: chat = ChatAPI(client) info = _catalog().by_id("google-gemini-3.1-pro-preview") assert info is not None wire, provider = await chat._resolve_model(info, provider="ignored") assert wire == "gemini-3.1-pro-preview" assert provider == "google" async def test_string_plus_provider_passes_through(self) -> None: """The escape hatch: explicit `provider=` skips the catalog entirely.""" async with _client(models=_catalog()) as client: chat = ChatAPI(client) wire, provider = await chat._resolve_model( "some-future-model", provider="custom-provider" ) assert wire == "some-future-model" assert provider == "custom-provider" async def test_string_matches_catalog_id(self) -> None: async with _client(models=_catalog()) as client: chat = ChatAPI(client) wire, provider = await chat._resolve_model( "google-gemini-3.1-pro-preview", provider=None ) assert wire == "gemini-3.1-pro-preview" assert provider == "google" async def test_string_matches_wire_id(self) -> None: async with _client(models=_catalog()) as client: chat = ChatAPI(client) wire, provider = await chat._resolve_model( "gemini-3.1-pro-preview", provider=None ) assert wire == "gemini-3.1-pro-preview" assert provider == "google" async def test_string_matches_display_name(self) -> None: async with _client(models=_catalog()) as client: chat = ChatAPI(client) wire, provider = await chat._resolve_model( "Claude Sonnet 4.6", provider=None ) assert wire == "claude-sonnet-4-6" assert provider == "anthropic" async def test_unknown_string_raises_value_error(self) -> None: async with _client(models=_catalog()) as client: chat = ChatAPI(client) with pytest.raises(ValueError, match="not found in catalog"): await chat._resolve_model("totally-made-up", provider=None) class TestCatalogCache: async def test_catalog_fetched_once_and_reused(self) -> None: """Two `_resolve_model` calls with no `provider=` should hit `/ai/models` exactly once (catalog cached on the Client).""" call_count = {"n": 0} def _cb(url: Any, **kwargs: Any) -> Any: call_count["n"] += 1 from aioresponses import CallbackResult return CallbackResult(status=200, payload=CATALOG_PAYLOAD) with aioresponses() as mocked: mocked.get( "https://backend.raycast.com/api/v1/ai/models", callback=_cb, repeat=True, ) async with _client() as client: chat = ChatAPI(client) a = await chat._resolve_model( "google-gemini-3.1-pro-preview", provider=None ) b = await chat._resolve_model("claude-sonnet-4-6", provider=None) assert a == ("gemini-3.1-pro-preview", "google") assert b == ("claude-sonnet-4-6", "anthropic") assert call_count["n"] == 1 async def test_models_constructor_kwarg_skips_fetch(self) -> None: """Passing `models=` to Client should mean zero `/ai/models` hits.""" with aioresponses() as mocked: mocked.get( "https://backend.raycast.com/api/v1/ai/models", status=500, repeat=True, ) async with _client(models=_catalog()) as client: chat = ChatAPI(client) wire, provider = await chat._resolve_model( "google-gemini-3.1-pro-preview", provider=None ) assert (wire, provider) == ("gemini-3.1-pro-preview", "google") async def test_invalidate_models_cache_refetches(self) -> None: call_count = {"n": 0} def _cb(url: Any, **kwargs: Any) -> Any: call_count["n"] += 1 from aioresponses import CallbackResult return CallbackResult(status=200, payload=CATALOG_PAYLOAD) with aioresponses() as mocked: mocked.get( "https://backend.raycast.com/api/v1/ai/models", callback=_cb, repeat=True, ) async with _client() as client: chat = ChatAPI(client) await chat._resolve_model( "google-gemini-3.1-pro-preview", provider=None ) client.invalidate_models_cache() await chat._resolve_model( "google-gemini-3.1-pro-preview", provider=None ) assert call_count["n"] == 2