216 lines
7.5 KiB
Python
216 lines
7.5 KiB
Python
"""Tests for `ChatAPI._resolve_model` + the Client-level catalog cache.
|
|
|
|
Resolution rules (mirrored from PROGRESS.md "Phase 6 / 6a"):

1. `ModelInfo` argument → use `.model` and `.provider`, ignore the
   `provider=` kwarg.
2. `str` + `provider=` → pass through verbatim, no catalog lookup.
3. `str` only:
   - match catalog id (`info.id`),
   - else match wire id (`info.model`),
   - else match display name (`info.name`),
   - else raise `ValueError`.

The catalog is fetched at most once per Client and cached. We assert this
by mocking `/ai/models` and counting requests across two `_resolve_model`
calls.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Any
|
|
|
|
import pytest
|
|
from aioresponses import aioresponses
|
|
|
|
from raycast_api.ai import Message
|
|
from raycast_api.ai.chat import ChatAPI
|
|
from raycast_api.ai.models import ModelInfo, ModelsResponse
|
|
from raycast_api.client import Client
|
|
from raycast_api.config import Config
|
|
from raycast_api.signing_spec import RotRange, SigningSpec
|
|
|
|
|
|
def _config() -> Config:
    """Return the fixed `Config` shared by every test in this module.

    All hash-like fields are 64 zero characters. The signing spec carries
    three rot ranges over 65-90, 97-122 and 48-57 (the ASCII codes of A-Z,
    a-z and 0-9) with shifts 13, 13 and 5 respectively.
    """
    zero_hex = "0" * 64
    rot_ranges = [
        RotRange(start=first, end=last, shift=shift)
        for first, last, shift in ((65, 90, 13), (97, 122, 13), (48, 57, 5))
    ]
    spec = SigningSpec(
        rot_fn_name="Sur",
        signing_fn_name="Nkt",
        rot_ranges=rot_ranges,
    )
    return Config(
        signature_secret=zero_hex,
        signing_spec=spec,
        app_version="0.60.1.0",
        user_agent="Raycast/0.60.1.0 (x-macOS Version 26.3.1)",
        bundle_hash=zero_hex,
        launcher_hash=zero_hex,
    )
|
|
|
|
|
|
def _client(**kwargs: Any) -> Client:
    """Build a `Client` with fixed test credentials and a constant clock.

    Any extra keyword arguments (e.g. `models=`) are forwarded unchanged to
    the `Client` constructor.
    """

    def _frozen_clock() -> int:
        # Constant timestamp so anything derived from the clock is reproducible.
        return 1700000000

    device_id = "a" * 64
    return Client(
        config=_config(),
        bearer_token="rca_test",
        device_id=device_id,
        clock=_frozen_clock,
        **kwargs,
    )
|
|
|
|
|
|
# Wire-format `/ai/models` payload used as the catalog fixture. Each row below
# is (catalog id, display name, wire id, provider, provider display name).
CATALOG_PAYLOAD = {
    "models": [
        {
            "id": catalog_id,
            "name": display_name,
            "model": wire_id,
            "provider": provider_key,
            "provider_name": provider_label,
        }
        for catalog_id, display_name, wire_id, provider_key, provider_label in (
            (
                "google-gemini-3.1-pro-preview",
                "Gemini 3.1 Pro Preview",
                "gemini-3.1-pro-preview",
                "google",
                "Google",
            ),
            (
                "anthropic-claude-sonnet-4-6",
                "Claude Sonnet 4.6",
                "claude-sonnet-4-6",
                "anthropic",
                "Anthropic",
            ),
        )
    ],
    "default_models": {"chat": "google-gemini-3.1-pro-preview"},
    "free_models": ["anthropic-claude-sonnet-4-6"],
}
|
|
|
|
|
|
def _catalog() -> ModelsResponse:
    """Return `CATALOG_PAYLOAD` parsed through `ModelsResponse.from_wire`."""
    parsed = ModelsResponse.from_wire(CATALOG_PAYLOAD)
    return parsed
|
|
|
|
|
|
|
|
|
|
class TestResolveModel:
    """Checks for `ChatAPI._resolve_model` on a Client seeded via `models=`."""

    async def test_model_info_argument_wins(self) -> None:
        """A `ModelInfo` yields its own wire id and provider; `provider=` is ignored."""
        gemini = _catalog().by_id("google-gemini-3.1-pro-preview")
        assert gemini is not None
        async with _client(models=_catalog()) as client:
            resolver = ChatAPI(client)
            wire, provider = await resolver._resolve_model(gemini, provider="ignored")
            assert (wire, provider) == ("gemini-3.1-pro-preview", "google")

    async def test_string_plus_provider_passes_through(self) -> None:
        """A string with an explicit `provider=` comes back unchanged."""
        model_name = "some-future-model"
        provider_name = "custom-provider"
        async with _client(models=_catalog()) as client:
            resolver = ChatAPI(client)
            wire, provider = await resolver._resolve_model(
                model_name, provider=provider_name
            )
            assert (wire, provider) == ("some-future-model", "custom-provider")

    async def test_string_matches_catalog_id(self) -> None:
        """A bare catalog id resolves to the entry's wire id and provider."""
        query = "google-gemini-3.1-pro-preview"
        async with _client(models=_catalog()) as client:
            resolver = ChatAPI(client)
            wire, provider = await resolver._resolve_model(query, provider=None)
            assert (wire, provider) == ("gemini-3.1-pro-preview", "google")

    async def test_string_matches_wire_id(self) -> None:
        """A bare wire id resolves to itself plus the entry's provider."""
        query = "gemini-3.1-pro-preview"
        async with _client(models=_catalog()) as client:
            resolver = ChatAPI(client)
            wire, provider = await resolver._resolve_model(query, provider=None)
            assert (wire, provider) == ("gemini-3.1-pro-preview", "google")

    async def test_string_matches_display_name(self) -> None:
        """A bare display name resolves to the entry's wire id and provider."""
        query = "Claude Sonnet 4.6"
        async with _client(models=_catalog()) as client:
            resolver = ChatAPI(client)
            wire, provider = await resolver._resolve_model(query, provider=None)
            assert (wire, provider) == ("claude-sonnet-4-6", "anthropic")

    async def test_unknown_string_raises_value_error(self) -> None:
        """A string matching no catalog entry raises `ValueError`."""
        async with _client(models=_catalog()) as client:
            resolver = ChatAPI(client)
            with pytest.raises(ValueError, match="not found in catalog"):
                await resolver._resolve_model("totally-made-up", provider=None)
|
|
|
|
|
|
|
|
|
|
class TestCatalogCache:
    """Checks that the model catalog is cached on the `Client`.

    Every test mocks `GET /ai/models` with a counting callback so the number
    of catalog fetches is asserted directly rather than inferred.
    """

    _MODELS_URL = "https://backend.raycast.com/api/v1/ai/models"

    @classmethod
    def _count_catalog_hits(
        cls, mocked: aioresponses, *, status: int = 200, payload: Any
    ) -> dict[str, int]:
        """Register a repeating `/ai/models` mock and return its hit counter.

        Args:
            mocked: The active `aioresponses` context to register the mock on.
            status: HTTP status the mocked endpoint answers with.
            payload: JSON payload for the mocked response, or `None` for none.

        Returns:
            A dict whose `"n"` entry is incremented on every mocked request.
        """
        from aioresponses import CallbackResult

        hits = {"n": 0}

        def _on_request(url: Any, **kwargs: Any) -> Any:
            hits["n"] += 1
            return CallbackResult(status=status, payload=payload)

        mocked.get(cls._MODELS_URL, callback=_on_request, repeat=True)
        return hits

    async def test_catalog_fetched_once_and_reused(self) -> None:
        """Two `_resolve_model` calls with no `provider=` should hit
        `/ai/models` exactly once (catalog cached on the Client)."""
        with aioresponses() as mocked:
            hits = self._count_catalog_hits(mocked, payload=CATALOG_PAYLOAD)
            async with _client() as client:
                chat = ChatAPI(client)
                a = await chat._resolve_model(
                    "google-gemini-3.1-pro-preview", provider=None
                )
                b = await chat._resolve_model("claude-sonnet-4-6", provider=None)
                assert a == ("gemini-3.1-pro-preview", "google")
                assert b == ("claude-sonnet-4-6", "anthropic")
                assert hits["n"] == 1

    async def test_models_constructor_kwarg_skips_fetch(self) -> None:
        """Passing `models=` to Client should mean zero `/ai/models` hits."""
        with aioresponses() as mocked:
            # The endpoint answers 500 if it is ever reached; the counter lets
            # us assert the "zero hits" claim explicitly instead of relying on
            # the error surfacing somewhere downstream.
            hits = self._count_catalog_hits(mocked, status=500, payload=None)
            async with _client(models=_catalog()) as client:
                chat = ChatAPI(client)
                wire, provider = await chat._resolve_model(
                    "google-gemini-3.1-pro-preview", provider=None
                )
                assert (wire, provider) == ("gemini-3.1-pro-preview", "google")
                assert hits["n"] == 0

    async def test_invalidate_models_cache_refetches(self) -> None:
        """After `invalidate_models_cache()`, the next resolution refetches
        the catalog, for two `/ai/models` hits in total."""
        with aioresponses() as mocked:
            hits = self._count_catalog_hits(mocked, payload=CATALOG_PAYLOAD)
            async with _client() as client:
                chat = ChatAPI(client)
                await chat._resolve_model(
                    "google-gemini-3.1-pro-preview", provider=None
                )
                client.invalidate_models_cache()
                await chat._resolve_model(
                    "google-gemini-3.1-pro-preview", provider=None
                )
                assert hits["n"] == 2
|