Files
raycast-api/tests/test_chat_model_resolution.py
T

216 lines
7.5 KiB
Python

"""Tests for `ChatAPI._resolve_model` + the Client-level catalog cache.
Resolution rules (mirrored from PROGRESS.md "Phase 6 / 6a"):

1. `ModelInfo` argument → use `.model` and `.provider`, ignore the
   `provider=` kwarg.
2. `str` + `provider=` → pass through verbatim, no catalog lookup.
3. `str` only:
   - match catalog id (`info.id`),
   - else match wire id (`info.model`),
   - else match display name (`info.name`),
   - else raise `ValueError`.

The catalog is fetched at most once per Client and cached. We assert this
by mocking `/ai/models` and counting requests across two `_resolve_model`
calls.
"""
from __future__ import annotations
from typing import Any
import pytest
from aioresponses import aioresponses
from raycast_api.ai import Message
from raycast_api.ai.chat import ChatAPI
from raycast_api.ai.models import ModelInfo, ModelsResponse
from raycast_api.client import Client
from raycast_api.config import Config
from raycast_api.signing_spec import RotRange, SigningSpec
def _config() -> Config:
    """Build the fixed `Config` shared by every test in this module.

    All values are hard-coded placeholders: the secret and both hashes are
    64-character strings of ``"0"``, and the signing spec carries three
    rotation ranges.
    """
    # Code points 65-90, 97-122 and 48-57 are ASCII A-Z, a-z and 0-9.
    # NOTE(review): presumably a ROT13/ROT5-style scheme -- confirm against
    # SigningSpec before relying on that reading.
    rotation_ranges = [
        RotRange(start=65, end=90, shift=13),
        RotRange(start=97, end=122, shift=13),
        RotRange(start=48, end=57, shift=5),
    ]
    spec = SigningSpec(
        rot_fn_name="Sur",
        signing_fn_name="Nkt",
        rot_ranges=rotation_ranges,
    )
    zeros_64 = "0" * 64
    return Config(
        signature_secret=zeros_64,
        signing_spec=spec,
        app_version="0.60.1.0",
        user_agent="Raycast/0.60.1.0 (x-macOS Version 26.3.1)",
        bundle_hash=zeros_64,
        launcher_hash=zeros_64,
    )
def _client(**kwargs: Any) -> Client:
    """Create a `Client` with the test config, a dummy token and a fixed clock.

    Any extra keyword arguments are forwarded to `Client` unchanged; the
    tests in this module use that to pass `models=`.
    """

    def _frozen_clock() -> int:
        # Constant timestamp so the clock never advances between calls.
        return 1700000000

    return Client(
        config=_config(),
        bearer_token="rca_test",
        device_id="a" * 64,
        clock=_frozen_clock,
        **kwargs,
    )
# The two model entries used throughout this module. Each one has a distinct
# catalog id (`id`), display name (`name`) and wire id (`model`), so the
# resolution tests can tell the three lookup paths apart.
_GEMINI_ENTRY = {
    "id": "google-gemini-3.1-pro-preview",
    "name": "Gemini 3.1 Pro Preview",
    "model": "gemini-3.1-pro-preview",
    "provider": "google",
    "provider_name": "Google",
}
_CLAUDE_ENTRY = {
    "id": "anthropic-claude-sonnet-4-6",
    "name": "Claude Sonnet 4.6",
    "model": "claude-sonnet-4-6",
    "provider": "anthropic",
    "provider_name": "Anthropic",
}

# Canned catalog body: parsed by `ModelsResponse.from_wire` in `_catalog()`
# and served as the mocked `/ai/models` response in the cache tests.
CATALOG_PAYLOAD = {
    "models": [_GEMINI_ENTRY, _CLAUDE_ENTRY],
    "default_models": {"chat": _GEMINI_ENTRY["id"]},
    "free_models": [_CLAUDE_ENTRY["id"]],
}
def _catalog() -> ModelsResponse:
    """Parse `CATALOG_PAYLOAD` with `ModelsResponse.from_wire` on every call."""
    parsed = ModelsResponse.from_wire(CATALOG_PAYLOAD)
    return parsed
class TestResolveModel:
    """Cover each `ChatAPI._resolve_model` resolution path.

    Every test builds its client with `models=_catalog()`, so the catalog is
    supplied up front and no HTTP mock is installed here.

    NOTE(review): the tests are bare `async def` methods with no explicit
    marker -- presumably the project runs pytest-asyncio in auto mode (or an
    equivalent plugin); confirm in the pytest configuration.
    """

    async def test_model_info_argument_wins(self) -> None:
        """A `ModelInfo` argument supplies both values; `provider=` is ignored."""
        gemini = _catalog().by_id("google-gemini-3.1-pro-preview")
        assert gemini is not None
        async with _client(models=_catalog()) as client:
            resolver = ChatAPI(client)
            wire_id, provider_id = await resolver._resolve_model(
                gemini, provider="ignored"
            )
            assert (wire_id, provider_id) == ("gemini-3.1-pro-preview", "google")

    async def test_string_plus_provider_passes_through(self) -> None:
        """The escape hatch: a string plus explicit `provider=` comes back verbatim."""
        async with _client(models=_catalog()) as client:
            resolver = ChatAPI(client)
            wire_id, provider_id = await resolver._resolve_model(
                "some-future-model", provider="custom-provider"
            )
            assert (wire_id, provider_id) == ("some-future-model", "custom-provider")

    async def test_string_matches_catalog_id(self) -> None:
        """A bare string equal to a catalog id resolves to that entry."""
        async with _client(models=_catalog()) as client:
            resolver = ChatAPI(client)
            wire_id, provider_id = await resolver._resolve_model(
                "google-gemini-3.1-pro-preview", provider=None
            )
            assert (wire_id, provider_id) == ("gemini-3.1-pro-preview", "google")

    async def test_string_matches_wire_id(self) -> None:
        """A bare string equal to a wire id resolves to that entry."""
        async with _client(models=_catalog()) as client:
            resolver = ChatAPI(client)
            wire_id, provider_id = await resolver._resolve_model(
                "gemini-3.1-pro-preview", provider=None
            )
            assert (wire_id, provider_id) == ("gemini-3.1-pro-preview", "google")

    async def test_string_matches_display_name(self) -> None:
        """A bare string equal to a display name resolves to that entry."""
        async with _client(models=_catalog()) as client:
            resolver = ChatAPI(client)
            wire_id, provider_id = await resolver._resolve_model(
                "Claude Sonnet 4.6", provider=None
            )
            assert (wire_id, provider_id) == ("claude-sonnet-4-6", "anthropic")

    async def test_unknown_string_raises_value_error(self) -> None:
        """A string matching nothing in the catalog raises `ValueError`."""
        async with _client(models=_catalog()) as client:
            resolver = ChatAPI(client)
            with pytest.raises(ValueError, match="not found in catalog"):
                await resolver._resolve_model("totally-made-up", provider=None)
class TestCatalogCache:
    """Check when the Client does, and does not, request `/ai/models`.

    HTTP traffic is intercepted with `aioresponses`; fetches are counted by a
    callback registered on the models URL.
    """

    # Endpoint mocked by every test in this class.
    MODELS_URL = "https://backend.raycast.com/api/v1/ai/models"

    @staticmethod
    def _count_catalog_fetches(mocked: Any) -> dict[str, int]:
        """Serve `CATALOG_PAYLOAD` from `MODELS_URL` and count the requests.

        Args:
            mocked: The active `aioresponses` mock to register the GET on.

        Returns:
            A ``{"n": <count>}`` dict that the callback increments once per
            request. A dict is used so the nested callback can mutate it.
        """
        from aioresponses import CallbackResult

        hits = {"n": 0}

        def _on_request(url: Any, **kwargs: Any) -> Any:
            hits["n"] += 1
            return CallbackResult(status=200, payload=CATALOG_PAYLOAD)

        # repeat=True keeps the mock alive for any number of requests, so an
        # unexpected second fetch is counted rather than left unmatched.
        mocked.get(
            TestCatalogCache.MODELS_URL,
            callback=_on_request,
            repeat=True,
        )
        return hits

    async def test_catalog_fetched_once_and_reused(self) -> None:
        """Two `_resolve_model` calls with no `provider=` should hit
        `/ai/models` exactly once (catalog cached on the Client)."""
        with aioresponses() as mocked:
            fetches = self._count_catalog_fetches(mocked)
            async with _client() as client:
                chat = ChatAPI(client)
                by_catalog_id = await chat._resolve_model(
                    "google-gemini-3.1-pro-preview", provider=None
                )
                by_wire_id = await chat._resolve_model(
                    "claude-sonnet-4-6", provider=None
                )
                assert by_catalog_id == ("gemini-3.1-pro-preview", "google")
                assert by_wire_id == ("claude-sonnet-4-6", "anthropic")
                assert fetches["n"] == 1

    async def test_models_constructor_kwarg_skips_fetch(self) -> None:
        """Passing `models=` to Client should mean zero `/ai/models` hits."""
        with aioresponses() as mocked:
            # The endpoint answers 500, so a fetch could not supply a usable
            # catalog; resolution must come from the `models=` argument.
            mocked.get(
                self.MODELS_URL,
                status=500,
                repeat=True,
            )
            async with _client(models=_catalog()) as client:
                chat = ChatAPI(client)
                wire, provider = await chat._resolve_model(
                    "google-gemini-3.1-pro-preview", provider=None
                )
                assert (wire, provider) == ("gemini-3.1-pro-preview", "google")
                # Assert the "zero hits" claim directly instead of relying on
                # the 500 response surfacing as an error.
                mocked.assert_not_called()

    async def test_invalidate_models_cache_refetches(self) -> None:
        """After `invalidate_models_cache()`, the next resolution fetches again."""
        with aioresponses() as mocked:
            fetches = self._count_catalog_fetches(mocked)
            async with _client() as client:
                chat = ChatAPI(client)
                await chat._resolve_model(
                    "google-gemini-3.1-pro-preview", provider=None
                )
                client.invalidate_models_cache()
                await chat._resolve_model(
                    "google-gemini-3.1-pro-preview", provider=None
                )
                assert fetches["n"] == 2