379 lines
13 KiB
Python
379 lines
13 KiB
Python
"""Tests for `raycast_api.ai.chat.ChatAPI`.
|
|
|
|
Covers, in roughly increasing scope:

- `_build_body` produces fields in the right order with the right defaults
  per `source`.
- Tool/message serialisation matches the wire shape from BUNDLE_NOTES §3.
- `UserPreferences.render()` matches the byte-exact preamble from the
  real Raycast `Ya()` function.
- `complete(...)` accumulates deltas into a single `ChatResult` and
  handles the streamed-arguments case (concatenating tool-call argument
  fragments across chunks).
- `stream(...)` POSTs exactly the body bytes that its
  `X-Raycast-Signature-v2` header was computed over.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
from typing import Any
|
|
|
|
import pytest
|
|
from aioresponses import aioresponses
|
|
|
|
from raycast_api.ai import (
|
|
ChatAPI,
|
|
ChatStreamChunk,
|
|
Message,
|
|
RemoteTool,
|
|
Source,
|
|
Tool,
|
|
ToolCall,
|
|
UserPreferences,
|
|
)
|
|
from raycast_api.client import Client
|
|
from raycast_api.config import Config
|
|
from raycast_api.signing_spec import RotRange, SigningSpec
|
|
|
|
|
|
# Fixed signing inputs shared by the tests below. The secret and device id
# are constant 64-character hex strings and the client clock is pinned, so a
# signature recomputed inside a test matches the one the client sent.
REFERENCE_SECRET = "6bc455473576ce2cd6f70426caff867aabbe3f7291c1a79681af5e8ce0ca1408"
DEVICE_ID = "20eca913cada74f879e6535304f9d44da380c28eb855065c0d71017a3d7c3099"
FIXED_TIMESTAMP = 1_778_858_809  # Unix seconds (2026-05-15 UTC); returned by the client clock
|
|
|
|
|
|
def _config() -> Config:
    """Return a `Config` whose signing inputs are all fixed.

    Uses the reference secret, a fixed `SigningSpec` (function names plus
    three rotation ranges), a pinned app version / user agent, and all-zero
    bundle and launcher hashes.
    """
    # (start, end, shift) triples; these look like the ASCII code points of
    # 'A'-'Z', 'a'-'z' and '0'-'9' — presumably rotated by the signer.
    ranges = [
        RotRange(start=65, end=90, shift=13),
        RotRange(start=97, end=122, shift=13),
        RotRange(start=48, end=57, shift=5),
    ]
    spec = SigningSpec(rot_fn_name="Sur", signing_fn_name="Nkt", rot_ranges=ranges)
    zero_hash = "0" * 64
    return Config(
        signature_secret=REFERENCE_SECRET,
        signing_spec=spec,
        app_version="0.60.1.0",
        user_agent="Raycast/0.60.1.0 (x-macOS Version 26.3.1)",
        bundle_hash=zero_hash,
        launcher_hash=zero_hash,
    )
|
|
|
|
|
|
def _client(**kwargs: Any) -> Client:
    """Build a `Client` with a fixed identity, clock and locale.

    Every keyword argument is forwarded to `Client`. Keys that collide with
    the defaults below (``config``, ``bearer_token``, ``device_id``,
    ``clock``, ``locale``) override them. Previously such a collision raised
    ``TypeError: got multiple values for keyword argument``.
    """
    options: dict[str, Any] = {
        "config": _config(),
        "bearer_token": "rca_test_token",
        "device_id": DEVICE_ID,
        # Pinned clock: the signing test recomputes the signature with
        # str(FIXED_TIMESTAMP) as the timestamp.
        "clock": lambda: FIXED_TIMESTAMP,
        "locale": "en-GB",
    }
    options.update(kwargs)
    return Client(**options)
|
|
|
|
|
|
|
|
|
|
class TestUserPreferences:
    """`UserPreferences.render()` wording and `UserPreferences.auto()`."""

    def test_render_matches_real_client_wording(self) -> None:
        """The block must be byte-identical to what `Ya()` emits.

        The captured request body in `request_simple.curl.txt` contains:

            <user-preferences>\\n
             The user has the following system preferences:\\n
             - Locale: en-GB\\n
             - Timezone: Europe/Warsaw\\n
             - Current Date: 2026-05-15\\n
             - Use the system preferences to format your answers accordingly\\n
            </user-preferences>

        Any deviation (spacing, line breaks, punctuation) breaks the
        fingerprint match.
        """
        prefs = UserPreferences(
            locale="en-GB", timezone="Europe/Warsaw", current_date="2026-05-15"
        )
        rendered = prefs.render()
        assert rendered == (
            "<user-preferences>\n"
            " The user has the following system preferences:\n"
            " - Locale: en-GB\n"
            " - Timezone: Europe/Warsaw\n"
            " - Current Date: 2026-05-15\n"
            " - Use the system preferences to format your answers accordingly\n"
            "</user-preferences>"
        )

    def test_auto_picks_today_and_locale_argument(self) -> None:
        """`auto()` keeps the given locale and fills in today's date and a timezone.

        The local date is sampled both before and after `auto()`. If the test
        runs across midnight, either day is accepted instead of flaking.
        """
        import datetime

        before = datetime.date.today().isoformat()
        prefs = UserPreferences.auto(locale="ru-RU")
        after = datetime.date.today().isoformat()

        assert prefs.locale == "ru-RU"
        assert prefs.current_date in {before, after}
        assert prefs.timezone
|
|
|
|
|
|
|
|
|
|
class TestSerialisation:
    """Wire shapes produced by `Tool.to_wire()` and `Message.to_wire()`."""

    def test_remote_tool_shape(self) -> None:
        """Remote tools become ``{"type": "remote_tool", "name": ...}``.

        Both a plain string name and a `RemoteTool` member are accepted.
        """
        from_string = Tool.remote("web_search").to_wire()
        assert from_string == {"type": "remote_tool", "name": "web_search"}

        from_enum = Tool.remote(RemoteTool.SEARCH_IMAGES).to_wire()
        assert from_enum == {"type": "remote_tool", "name": "search_images"}

    def test_local_tool_shape(self) -> None:
        """Local tools nest name, description and parameters under ``function``."""
        tool = Tool.local(
            name="weather__get",
            description="get weather",
            parameters={"type": "object", "properties": {"city": {"type": "string"}}},
        )
        expected_function = {
            "name": "weather__get",
            "description": "get weather",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
            },
        }
        assert tool.to_wire() == {"type": "local_tool", "function": expected_function}

    def test_user_message(self) -> None:
        """User messages wrap their text in a ``content.text`` object."""
        wire = Message.user("hello").to_wire()
        assert wire == {"role": "user", "content": {"text": "hello"}}

    def test_assistant_with_tool_calls(self) -> None:
        """Assistant turns carry function-typed tool calls plus ``extra_content``."""
        call = ToolCall(
            id="abc", name="coffee__caffeinate-for", arguments='{"minutes":5}'
        )
        msg = Message.assistant(
            text="",
            tool_calls=[call],
            extra_content={"google": {"thought_signature": "xyz"}},
        )
        expected_call = {
            "id": "abc",
            "type": "function",
            "function": {
                "name": "coffee__caffeinate-for",
                "arguments": '{"minutes":5}',
            },
        }
        assert msg.to_wire() == {
            "role": "assistant",
            "content": {"text": ""},
            "tool_calls": [expected_call],
            "extra_content": {"google": {"thought_signature": "xyz"}},
        }

    def test_tool_message_wraps_string_as_mcp_text(self) -> None:
        """A plain-string tool result becomes a JSON-encoded list holding one
        ``{"type": "text", ...}`` item, stored as the message text."""
        msg = Message.tool(
            tool_call_id="abc",
            name="coffee__caffeinate-for",
            result="Mac will stay awake for 5m",
        )
        wrapped = '[{"type":"text","text":"Mac will stay awake for 5m"}]'
        assert msg.to_wire() == {
            "role": "tool",
            "content": {"text": wrapped},
            "name": "coffee__caffeinate-for",
            "tool_call_id": "abc",
        }
|
|
|
|
|
|
|
|
|
|
class TestBuildBody:
    """`ChatAPI._build_body` field order, omission of unset fields, source defaults."""

    def test_minimal_body_field_order(self) -> None:
        """First-turn body should serialise fields in the captured order."""
        chat = ChatAPI(_client())
        # Same bytes as the `UserPreferences.render()` output pinned elsewhere,
        # split only for readability (implicit concatenation).
        prefs_block = (
            "<user-preferences>\n"
            " The user has the following system preferences:\n"
            " - Locale: en-GB\n"
            " - Timezone: Europe/Warsaw\n"
            " - Current Date: 2026-05-15\n"
            " - Use the system preferences to format your answers accordingly\n"
            "</user-preferences>"
        )
        body = chat._build_body(
            model="gemini-3.1-pro-preview",
            provider="google",
            messages=[Message.user("привет")],
            source=Source.AI_CHAT,
            buffer_id="8480fbbb-4592-4257-812d-f24a67da3c07",
            message_id="2f138e1c-edcf-495b-915c-db5cbb154674",
            locale="en-GB",
            current_date="2026-05-15",
            system_instructions="markdown",
            additional_system_instructions=prefs_block,
            temperature=0,
            reasoning_effort="high",
            tools=[
                Tool.remote(RemoteTool.WEB_SEARCH).to_wire(),
                Tool.remote(RemoteTool.SEARCH_IMAGES).to_wire(),
                Tool.remote(RemoteTool.READ_PAGE).to_wire(),
            ],
            tool_choice="auto",
            resume_from=None,
        )
        assert list(body.keys()) == [
            "system_instructions",
            "additional_system_instructions",
            "locale",
            "temperature",
            "current_date",
            "message_id",
            "reasoning_effort",
            "messages",
            "tools",
            "tool_choice",
            "source",
            "model",
            "provider",
            "buffer_id",
        ]

    def test_omits_optional_fields_when_none(self) -> None:
        """No tools → no tools / tool_choice in the body at all."""
        chat = ChatAPI(_client())
        body = chat._build_body(
            model="m",
            provider="p",
            messages=[Message.user("hi")],
            source=Source.AI_CHAT,
            buffer_id="b",
            message_id="m",
            locale="en-US",
            current_date=None,
            system_instructions=None,
            additional_system_instructions=None,
            temperature=None,
            reasoning_effort=None,
            tools=None,
            tool_choice=None,
            resume_from=None,
        )
        for key in (
            "tools",
            "tool_choice",
            "temperature",
            "system_instructions",
            "additional_system_instructions",
            "reasoning_effort",
            "current_date",
        ):
            assert key not in body, key
        assert body["model"] == "m"
        assert body["provider"] == "p"
        assert body["buffer_id"] == "b"
        assert body["source"] == "ai_chat"

    def test_source_default_temperature_only_applies_when_unspecified(self) -> None:
        """Quick AI defaults to 0.2; passing temperature=0 overrides.

        Checks the Quick AI defaults (temperature 0.2, "plain" instructions).
        It then checks that an explicit ``temperature=0`` reaches the body
        unchanged. ``0`` is falsy, so a truthiness test instead of
        ``is None`` would drop it or replace it with the source default.
        """
        from raycast_api.ai.chat import _SOURCE_DEFAULTS

        defaults = _SOURCE_DEFAULTS[Source.QUICK_AI]
        assert defaults["temperature"] == 0.2
        assert defaults["system_instructions"] == "plain"

        chat = ChatAPI(_client())
        body = chat._build_body(
            model="m",
            provider="p",
            messages=[Message.user("hi")],
            source=Source.QUICK_AI,
            buffer_id="b",
            message_id="m",
            locale="en-US",
            current_date=None,
            system_instructions=None,
            additional_system_instructions=None,
            temperature=0,
            reasoning_effort=None,
            tools=None,
            tool_choice=None,
            resume_from=None,
        )
        assert body["temperature"] == 0
|
|
|
|
|
|
|
|
|
|
class TestComplete:
    """`client.chat.complete(...)` against a mocked SSE endpoint."""

    @pytest.mark.asyncio
    async def test_complete_concatenates_streamed_tool_arguments(self) -> None:
        """If `arguments` arrives in multiple chunks, they're concatenated.

        The synthetic SSE stream repeats tool_call id ``tc1`` across two
        frames, each carrying part of the JSON arguments. It then sends a
        finish/usage frame and a ``complete`` event.
        """
        frames = [
            b"id: 0\n"
            b'data: {"text":"","tool_calls":[{"id":"tc1","name":"f","arguments":"{\\"a\\":"}]}\n\n',
            b"id: 1\n"
            b'data: {"text":"","tool_calls":[{"id":"tc1","arguments":"1}"}]}\n\n',
            b"id: 2\n"
            b'data: {"text":"","finish_reason":"STOP","usage":{"input_tokens":1,"output_tokens":1}}\n\n',
            b'event: complete\ndata: {"complete":true}\n\n',
        ]
        sse = b"".join(frames)
        endpoint = "https://backend.raycast.com/api/v1/ai/chat_completions"

        with aioresponses() as mocked:
            mocked.post(
                endpoint,
                status=200,
                body=sse,
                headers={"Content-Type": "text/event-stream"},
            )
            async with _client() as client:
                result = await client.chat.complete(
                    model="m",
                    provider="p",
                    messages=[Message.user("x")],
                    user_preferences=False,
                )

        assert len(result.tool_calls) == 1
        call = result.tool_calls[0]
        assert call.id == "tc1"
        assert call.name == "f"
        assert call.arguments == '{"a":1}'
|
|
|
|
|
|
|
|
|
|
class TestSignedBytesMatch:
    """The body `client.chat.stream` POSTs must be the exact bytes it signed.

    A callback registered with `aioresponses` captures the outgoing body and
    headers. The test then recomputes the signature independently.
    """

    @pytest.mark.asyncio
    async def test_stream_post_body_matches_signed_bytes(self) -> None:
        from aioresponses import CallbackResult
        from raycast_api.signing import Signer

        captured: dict[str, Any] = {}

        def _record_request(url: Any, **kwargs: Any) -> Any:
            # Remember what the client sent, then answer with a stream that
            # completes immediately.
            captured["data"] = kwargs.get("data")
            captured["headers"] = kwargs.get("headers")
            return CallbackResult(
                status=200,
                body=b'event: complete\ndata: {"complete":true}\n\n',
                headers={"Content-Type": "text/event-stream"},
            )

        with aioresponses() as mocked:
            mocked.post(
                "https://backend.raycast.com/api/v1/ai/chat_completions",
                callback=_record_request,
            )
            async with _client() as client:
                async for _ in client.chat.stream(
                    model="m",
                    provider="p",
                    messages=[Message.user("hi")],
                    user_preferences=False,
                    buffer_id="b",
                    message_id="mid",
                    current_date="2026-05-15",
                ):
                    pass

        # The captured body may be a wrapper object exposing the raw bytes as
        # `_value` (presumably an aiohttp payload); unwrap it if so.
        payload = captured["data"]
        payload = getattr(payload, "_value", payload)
        assert isinstance(payload, (bytes, bytearray))
        raw = bytes(payload)

        signer = Signer(spec=_config().signing_spec, secret=REFERENCE_SECRET)
        expected_sig = signer.sign(
            timestamp=str(FIXED_TIMESTAMP), device_id=DEVICE_ID, body=raw
        )
        assert captured["headers"]["X-Raycast-Signature-v2"] == expected_sig

        parsed = json.loads(raw)
        assert parsed["model"] == "m"
        assert parsed["provider"] == "p"
        assert parsed["buffer_id"] == "b"
        assert parsed["message_id"] == "mid"
        assert parsed["source"] == "ai_chat"
        assert parsed["system_instructions"] == "markdown"
|