feat: vibed out some slop over here
This commit is contained in:
@@ -0,0 +1,96 @@
|
||||
"""raycast-api — Python client for the Raycast backend API."""
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from raycast_api.ai import (
|
||||
Attachment,
|
||||
ChatResult,
|
||||
ChatStreamChunk,
|
||||
Message,
|
||||
ModelInfo,
|
||||
ModelsResponse,
|
||||
RemoteTool,
|
||||
Source,
|
||||
Tool,
|
||||
ToolCall,
|
||||
UserPreferences,
|
||||
)
|
||||
from raycast_api.client import Client, RetryPolicy, SSEEvent
|
||||
from raycast_api.config import Config
|
||||
from raycast_api.signing_spec import SigningSpec
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Attachment",
|
||||
"ChatResult",
|
||||
"ChatStreamChunk",
|
||||
"Client",
|
||||
"Config",
|
||||
"Message",
|
||||
"ModelInfo",
|
||||
"ModelsResponse",
|
||||
"RemoteTool",
|
||||
"RetryPolicy",
|
||||
"SSEEvent",
|
||||
"SigningSpec",
|
||||
"Source",
|
||||
"Tool",
|
||||
"ToolCall",
|
||||
"UserPreferences",
|
||||
]
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
if name in {"Config", "SigningSpec"}:
|
||||
from raycast_api.config import Config
|
||||
from raycast_api.signing_spec import SigningSpec
|
||||
|
||||
return {"Config": Config, "SigningSpec": SigningSpec}[name]
|
||||
if name in {"Client", "RetryPolicy", "SSEEvent"}:
|
||||
from raycast_api.client import Client, RetryPolicy, SSEEvent
|
||||
|
||||
return {"Client": Client, "RetryPolicy": RetryPolicy, "SSEEvent": SSEEvent}[
|
||||
name
|
||||
]
|
||||
if name in {
|
||||
"Attachment",
|
||||
"ChatResult",
|
||||
"ChatStreamChunk",
|
||||
"Message",
|
||||
"ModelInfo",
|
||||
"ModelsResponse",
|
||||
"RemoteTool",
|
||||
"Source",
|
||||
"Tool",
|
||||
"ToolCall",
|
||||
"UserPreferences",
|
||||
}:
|
||||
from raycast_api.ai import (
|
||||
Attachment,
|
||||
ChatResult,
|
||||
ChatStreamChunk,
|
||||
Message,
|
||||
ModelInfo,
|
||||
ModelsResponse,
|
||||
RemoteTool,
|
||||
Source,
|
||||
Tool,
|
||||
ToolCall,
|
||||
UserPreferences,
|
||||
)
|
||||
|
||||
return {
|
||||
"Attachment": Attachment,
|
||||
"ChatResult": ChatResult,
|
||||
"ChatStreamChunk": ChatStreamChunk,
|
||||
"Message": Message,
|
||||
"ModelInfo": ModelInfo,
|
||||
"ModelsResponse": ModelsResponse,
|
||||
"RemoteTool": RemoteTool,
|
||||
"Source": Source,
|
||||
"Tool": Tool,
|
||||
"ToolCall": ToolCall,
|
||||
"UserPreferences": UserPreferences,
|
||||
}[name]
|
||||
raise AttributeError(name)
|
||||
@@ -0,0 +1,40 @@
|
||||
"""AI endpoint wrappers (chat completions, models, files, me).
|
||||
|
||||
Built on top of `raycast_api.client.Client`. The wrappers translate between
|
||||
the wire shapes documented in `BUNDLE_NOTES.md` §3-§4 and ergonomic Python
|
||||
dataclasses, and own the small amount of business logic the chat endpoint
|
||||
needs (preamble injection, source-specific defaults, SSE → typed chunks).
|
||||
"""
|
||||
|
||||
from raycast_api.ai.chat import ChatAPI, ChatResult, ChatStreamChunk
|
||||
from raycast_api.ai.files import FileMetadata, FilesAPI
|
||||
from raycast_api.ai.me import MeAPI
|
||||
from raycast_api.ai.models import ModelInfo, ModelsAPI, ModelsResponse
|
||||
from raycast_api.ai.types import (
|
||||
Attachment,
|
||||
Message,
|
||||
RemoteTool,
|
||||
Source,
|
||||
Tool,
|
||||
ToolCall,
|
||||
UserPreferences,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Attachment",
|
||||
"ChatAPI",
|
||||
"ChatResult",
|
||||
"ChatStreamChunk",
|
||||
"FileMetadata",
|
||||
"FilesAPI",
|
||||
"MeAPI",
|
||||
"Message",
|
||||
"ModelInfo",
|
||||
"ModelsAPI",
|
||||
"ModelsResponse",
|
||||
"RemoteTool",
|
||||
"Source",
|
||||
"Tool",
|
||||
"ToolCall",
|
||||
"UserPreferences",
|
||||
]
|
||||
@@ -0,0 +1,534 @@
|
||||
"""`/api/v1/ai/chat_completions` — the heart of the library.
|
||||
|
||||
Two surfaces:
|
||||
|
||||
- `ChatAPI.stream(...)` — async generator yielding `ChatStreamChunk`s as
|
||||
they arrive from the SSE stream. The caller decides what to do with
|
||||
each chunk (collect text deltas, react to `tool_calls`, etc.).
|
||||
- `ChatAPI.complete(...)` — convenience that consumes the stream and
|
||||
returns a single `ChatResult` (accumulated text, final tool_calls,
|
||||
usage, finish_reason). Use this when you don't need streaming.
|
||||
|
||||
Both go through the canonical `_build_body` which produces the exact dict
|
||||
the real Raycast client sends — same field set, same field order, same
|
||||
defaults per `source`. That dict is serialised once (compact JSON, no
|
||||
spaces) and the resulting bytes go to both `Signer.sign()` and the
|
||||
network — so the signature always matches the bytes on the wire.
|
||||
|
||||
Resume: the caller passes `on_last_event_id=lambda id: ...` to `stream(...)`
|
||||
to checkpoint the latest SSE id. If the stream drops, they can call
|
||||
`ChatAPI.resume(buffer_id, last_event_id)` to pick up from where they
|
||||
were. The server is responsible for replaying everything after the given
|
||||
event id, and may return 204 No Content if there's nothing left.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from raycast_api.ai.types import (
|
||||
ChatStreamChunk,
|
||||
Message,
|
||||
RemoteTool,
|
||||
Source,
|
||||
Tool,
|
||||
ToolCall,
|
||||
UserPreferences,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncIterator, Callable
|
||||
|
||||
from raycast_api.ai.models import ModelInfo
|
||||
from raycast_api.client.http import Client
|
||||
|
||||
|
||||
_SOURCE_DEFAULTS: dict[Source, dict[str, Any]] = {
|
||||
Source.AI_CHAT: {
|
||||
"system_instructions": "markdown",
|
||||
"temperature": None,
|
||||
},
|
||||
Source.QUICK_AI: {"system_instructions": "plain", "temperature": 0.2},
|
||||
Source.AI_COMMAND: {"system_instructions": "plain", "temperature": 0.2},
|
||||
Source.API: {"system_instructions": "plain", "temperature": 0.2},
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatResult:
|
||||
"""The accumulated result of a chat completion.
|
||||
|
||||
Use `ChatAPI.complete(...)` to get one fully populated, or feed chunks
|
||||
from `ChatAPI.stream(...)` into `result.add(chunk)` yourself if you
|
||||
want to print text deltas live while still ending up with correctly
|
||||
merged tool_calls.
|
||||
|
||||
`text` is the full assistant reply concatenated from every `text` delta.
|
||||
`reasoning` is the chain-of-thought stream (empty for non-thinking
|
||||
models). `tool_calls` is the final assembled tool_calls list (server
|
||||
may stream `arguments` incrementally; we buffer them). `finish_reason`
|
||||
is the last one observed — typically `"STOP"` for a clean finish, or
|
||||
`"tool_calls"` if the response ends with a tool request.
|
||||
"""
|
||||
|
||||
text: str = ""
|
||||
reasoning: str = ""
|
||||
tool_calls: list[ToolCall] = field(default_factory=list)
|
||||
finish_reason: str | None = None
|
||||
usage: dict[str, int] | None = None
|
||||
extra_content: dict[str, Any] | None = None
|
||||
chunks: list[ChatStreamChunk] = field(default_factory=list)
|
||||
|
||||
# Internal state for incremental tool_call merge (see `add()`).
|
||||
_tool_buffers: dict[object, ToolCall] = field(
|
||||
default_factory=dict, repr=False, compare=False
|
||||
)
|
||||
_index_to_id: dict[int, str] = field(
|
||||
default_factory=dict, repr=False, compare=False
|
||||
)
|
||||
|
||||
def add(self, chunk: ChatStreamChunk) -> None:
|
||||
"""Merge one streamed chunk into this result.
|
||||
|
||||
Appends text/reasoning deltas, tracks usage and finish_reason, and
|
||||
— critically — handles the three-phase tool_call stream the server
|
||||
emits (first with `id`+`index`, deltas with empty `id`, then a
|
||||
final summary with `id` but no `index` and the FULL `arguments`).
|
||||
Keying naively by either id or index alone produces phantom or
|
||||
duplicated tool_calls; this method does it correctly.
|
||||
"""
|
||||
self.chunks.append(chunk)
|
||||
if chunk.text:
|
||||
self.text += chunk.text
|
||||
if chunk.reasoning:
|
||||
self.reasoning += chunk.reasoning
|
||||
if chunk.finish_reason:
|
||||
self.finish_reason = chunk.finish_reason
|
||||
if chunk.usage:
|
||||
self.usage = chunk.usage
|
||||
if chunk.tool_calls:
|
||||
self._merge_tool_calls(chunk)
|
||||
self.tool_calls = list(self._tool_buffers.values())
|
||||
|
||||
def _merge_tool_calls(self, chunk: ChatStreamChunk) -> None:
|
||||
is_final_summary = chunk.finish_reason is not None
|
||||
raw_tcs = chunk.raw.get("tool_calls") or []
|
||||
merged_extra: dict[str, Any] = dict(self.extra_content or {})
|
||||
for i, tc in enumerate(chunk.tool_calls or []):
|
||||
raw_tc = raw_tcs[i] if i < len(raw_tcs) else {}
|
||||
idx_field = raw_tc.get("index") if isinstance(raw_tc, dict) else None
|
||||
|
||||
key: object | None = None
|
||||
if tc.id:
|
||||
key = tc.id
|
||||
if isinstance(idx_field, int):
|
||||
self._index_to_id[idx_field] = tc.id
|
||||
elif isinstance(idx_field, int) and idx_field in self._index_to_id:
|
||||
key = self._index_to_id[idx_field]
|
||||
elif isinstance(idx_field, int):
|
||||
key = ("__idx__", idx_field)
|
||||
else:
|
||||
continue
|
||||
|
||||
existing = self._tool_buffers.get(key)
|
||||
if existing is None:
|
||||
self._tool_buffers[key] = ToolCall(
|
||||
id=tc.id,
|
||||
name=tc.name,
|
||||
arguments=tc.arguments,
|
||||
extra_content=dict(tc.extra_content) if tc.extra_content else None,
|
||||
)
|
||||
else:
|
||||
if tc.id and not existing.id:
|
||||
existing.id = tc.id
|
||||
if tc.name and not existing.name:
|
||||
existing.name = tc.name
|
||||
if tc.arguments:
|
||||
if is_final_summary:
|
||||
existing.arguments = tc.arguments
|
||||
else:
|
||||
existing.arguments += tc.arguments
|
||||
if tc.extra_content:
|
||||
existing.extra_content = {
|
||||
**(existing.extra_content or {}),
|
||||
**tc.extra_content,
|
||||
}
|
||||
if tc.extra_content:
|
||||
merged_extra.update(tc.extra_content)
|
||||
if merged_extra:
|
||||
self.extra_content = merged_extra
|
||||
|
||||
def to_assistant_message(self) -> Message:
|
||||
"""Build the `assistant` message you'd add to history for the next turn.
|
||||
|
||||
Includes tool_calls and merged extra_content (Google thought
|
||||
signatures etc.). The `text` field is included even when empty,
|
||||
because the real client always sends it.
|
||||
"""
|
||||
return Message.assistant(
|
||||
text=self.text,
|
||||
tool_calls=list(self.tool_calls) if self.tool_calls else None,
|
||||
extra_content=self.extra_content,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["ChatAPI", "ChatResult", "ChatStreamChunk"]
|
||||
|
||||
|
||||
class ChatAPI:
|
||||
"""Wrapper around `POST /api/v1/ai/chat_completions` and its resume GET.
|
||||
|
||||
Construction is implicit through `client.chat`; callers don't
|
||||
instantiate this directly.
|
||||
"""
|
||||
|
||||
def __init__(self, client: Client) -> None:
|
||||
self._client = client
|
||||
|
||||
|
||||
async def _resolve_model(
|
||||
self, model: str | ModelInfo, provider: str | None
|
||||
) -> tuple[str, str]:
|
||||
"""Map `model` to the `(wire_model, provider)` pair the chat body expects.
|
||||
|
||||
Resolution rules (first match wins):
|
||||
|
||||
1. `model` is a `ModelInfo` → return `(model.model, model.provider)`.
|
||||
The `provider=` kwarg is ignored when a `ModelInfo` is passed
|
||||
(it already disambiguates).
|
||||
2. `model` is a string AND `provider` is given → pass through
|
||||
verbatim: `(model, provider)`. No catalog lookup. This is the
|
||||
escape hatch for models that aren't in the catalog yet, or for
|
||||
callers who already know the wire id.
|
||||
3. `model` is a string AND `provider` is None → look up the
|
||||
catalog (fetched once and cached on the Client):
|
||||
a. Try `catalog.by_id(model)` — matches the prefixed catalog
|
||||
id (e.g. `"google-gemini-3.1-pro-preview"`).
|
||||
b. Else search for `info.model == model` — matches the bare
|
||||
wire id (e.g. `"gemini-3.1-pro-preview"`).
|
||||
c. Else search for `info.name == model` — matches the display
|
||||
name (e.g. `"Claude Sonnet 4.6"`).
|
||||
d. Else raise `ValueError`.
|
||||
|
||||
The catalog fetch is shared across concurrent first-use callers and
|
||||
cached for the lifetime of the Client. Invalidate it via
|
||||
`client.invalidate_models_cache()` if the user's subscription changes.
|
||||
"""
|
||||
from raycast_api.ai.models import ModelInfo
|
||||
|
||||
if isinstance(model, ModelInfo):
|
||||
return model.model, model.provider
|
||||
if not isinstance(model, str):
|
||||
msg = f"model must be a str or ModelInfo, got {type(model).__name__}"
|
||||
raise TypeError(
|
||||
msg
|
||||
)
|
||||
if provider is not None:
|
||||
return model, provider
|
||||
|
||||
catalog = await self._client._get_models_catalog() # noqa: SLF001
|
||||
info = catalog.by_id(model)
|
||||
if info is None:
|
||||
for candidate in catalog.models:
|
||||
if candidate.model == model:
|
||||
info = candidate
|
||||
break
|
||||
if info is None:
|
||||
for candidate in catalog.models:
|
||||
if candidate.name == model:
|
||||
info = candidate
|
||||
break
|
||||
if info is None:
|
||||
msg = (
|
||||
f"model {model!r} not found in catalog; "
|
||||
"pass provider= to bypass lookup"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
return info.model, info.provider
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _normalize_tools(
|
||||
tools: list[Tool | RemoteTool | str] | None,
|
||||
) -> list[dict[str, Any]] | None:
|
||||
if not tools:
|
||||
return None
|
||||
out: list[dict[str, Any]] = []
|
||||
for t in tools:
|
||||
if isinstance(t, Tool):
|
||||
out.append(t.to_wire())
|
||||
elif isinstance(t, (RemoteTool, str)):
|
||||
out.append(Tool.remote(t).to_wire())
|
||||
else:
|
||||
msg = f"unsupported tool entry: {type(t).__name__}"
|
||||
raise TypeError(msg)
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def _build_preamble(
|
||||
preferences: UserPreferences | None, extra: str | None
|
||||
) -> str | None:
|
||||
"""Compose `additional_system_instructions` from a preferences block.
|
||||
|
||||
Returns None if both are absent — `additional_system_instructions`
|
||||
is then omitted from the body entirely.
|
||||
"""
|
||||
parts: list[str] = []
|
||||
if preferences is not None:
|
||||
parts.append(preferences.render())
|
||||
if extra:
|
||||
parts.append(extra)
|
||||
if not parts:
|
||||
return None
|
||||
return "\n".join(parts)
|
||||
|
||||
def _build_body(
|
||||
self,
|
||||
*,
|
||||
model: str,
|
||||
provider: str,
|
||||
messages: list[Message],
|
||||
source: Source,
|
||||
buffer_id: str,
|
||||
message_id: str,
|
||||
locale: str,
|
||||
current_date: str | None,
|
||||
system_instructions: str | None,
|
||||
additional_system_instructions: str | None,
|
||||
temperature: float | None,
|
||||
reasoning_effort: str | None,
|
||||
tools: list[dict[str, Any]] | None,
|
||||
tool_choice: str | None,
|
||||
resume_from: dict[str, str] | None,
|
||||
) -> dict[str, Any]:
|
||||
"""Compose the chat_completions request body.
|
||||
|
||||
Field order matches the real-client capture in
|
||||
`_extracted/captures/request_simple.curl.txt`:
|
||||
|
||||
system_instructions, additional_system_instructions, locale,
|
||||
temperature, current_date, message_id, reasoning_effort,
|
||||
messages, tools, tool_choice, source, model, provider, buffer_id
|
||||
|
||||
The server doesn't care about field order — but for max stealth
|
||||
we emit the same order the WebView sends, so a byte-fingerprint of
|
||||
the request looks identical to a real Raycast chat.
|
||||
"""
|
||||
body: dict[str, Any] = {}
|
||||
if system_instructions is not None:
|
||||
body["system_instructions"] = system_instructions
|
||||
if additional_system_instructions is not None:
|
||||
body["additional_system_instructions"] = additional_system_instructions
|
||||
body["locale"] = locale
|
||||
if temperature is not None:
|
||||
body["temperature"] = temperature
|
||||
if current_date is not None:
|
||||
body["current_date"] = current_date
|
||||
body["message_id"] = message_id
|
||||
if reasoning_effort is not None:
|
||||
body["reasoning_effort"] = reasoning_effort
|
||||
body["messages"] = [m.to_wire() for m in messages]
|
||||
if tools is not None:
|
||||
body["tools"] = tools
|
||||
body["tool_choice"] = tool_choice or "auto"
|
||||
body["source"] = source.value
|
||||
body["model"] = model
|
||||
body["provider"] = provider
|
||||
if resume_from is not None:
|
||||
body["resume_from"] = resume_from
|
||||
body["buffer_id"] = buffer_id
|
||||
return body
|
||||
|
||||
|
||||
async def stream(
|
||||
self,
|
||||
*,
|
||||
model: str | ModelInfo,
|
||||
provider: str | None = None,
|
||||
messages: list[Message],
|
||||
source: Source = Source.AI_CHAT,
|
||||
buffer_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
system_instructions: str | None = None,
|
||||
additional_system_instructions: str | None = None,
|
||||
user_preferences: UserPreferences | None | bool = True,
|
||||
temperature: float | None = None,
|
||||
reasoning_effort: str | None = None,
|
||||
tools: list[Tool | RemoteTool | str] | None = None,
|
||||
tool_choice: str | None = None,
|
||||
current_date: str | None = None,
|
||||
on_last_event_id: Callable[[str], None] | None = None,
|
||||
) -> AsyncIterator[ChatStreamChunk]:
|
||||
"""Stream a chat completion from the server.
|
||||
|
||||
Yields one `ChatStreamChunk` per SSE event. Empty keepalive chunks
|
||||
are included — callers should check `chunk.is_empty` if they want
|
||||
to skip them. The terminator (`event: complete`) does NOT produce
|
||||
a yielded chunk — the iterator just stops.
|
||||
|
||||
`user_preferences`:
|
||||
- `True` (default) → auto-generated from host locale/timezone/date
|
||||
- a `UserPreferences` instance → used verbatim
|
||||
- `False` / `None` → no `<user-preferences>` block
|
||||
|
||||
`buffer_id` / `message_id` default to fresh UUIDv4s. Hold onto the
|
||||
`buffer_id` if you might want to resume; it's required for the
|
||||
resume GET.
|
||||
|
||||
`tools` accepts `Tool` instances, bare `RemoteTool` enum values,
|
||||
or raw strings (treated as remote tool names).
|
||||
|
||||
`model` accepts either a string (catalog id, wire id, or display
|
||||
name — resolved via the cached `/ai/models` catalog) or a
|
||||
`ModelInfo` instance. `provider` is only consulted when `model`
|
||||
is a string AND given explicitly, in which case it's passed
|
||||
through verbatim (escape hatch for models not in the catalog).
|
||||
"""
|
||||
wire_model, wire_provider = await self._resolve_model(model, provider)
|
||||
prefs = self._coerce_preferences(user_preferences)
|
||||
preamble = self._build_preamble(prefs, additional_system_instructions)
|
||||
|
||||
defaults = _SOURCE_DEFAULTS.get(source, {})
|
||||
sys_inst = (
|
||||
system_instructions
|
||||
if system_instructions is not None
|
||||
else defaults.get("system_instructions")
|
||||
)
|
||||
temp = temperature if temperature is not None else defaults.get("temperature")
|
||||
|
||||
body = self._build_body(
|
||||
model=wire_model,
|
||||
provider=wire_provider,
|
||||
messages=messages,
|
||||
source=source,
|
||||
buffer_id=buffer_id or str(uuid.uuid4()),
|
||||
message_id=message_id or str(uuid.uuid4()),
|
||||
locale=self._client.locale,
|
||||
current_date=current_date or self._today_iso(),
|
||||
system_instructions=sys_inst,
|
||||
additional_system_instructions=preamble,
|
||||
temperature=temp,
|
||||
reasoning_effort=reasoning_effort,
|
||||
tools=self._normalize_tools(tools),
|
||||
tool_choice=tool_choice,
|
||||
resume_from=None,
|
||||
)
|
||||
body_bytes = json.dumps(body, separators=(",", ":"), ensure_ascii=False).encode(
|
||||
"utf-8"
|
||||
)
|
||||
|
||||
async for evt in self._client.stream(
|
||||
"POST",
|
||||
"/api/v1/ai/chat_completions",
|
||||
body=body_bytes,
|
||||
sign=True,
|
||||
on_last_event_id=on_last_event_id,
|
||||
):
|
||||
if evt.is_terminal:
|
||||
return
|
||||
if not evt.data:
|
||||
continue
|
||||
try:
|
||||
data = json.loads(evt.data)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
if not isinstance(data, dict):
|
||||
continue
|
||||
yield ChatStreamChunk.from_wire(data, event_id=evt.id)
|
||||
|
||||
async def resume(
|
||||
self, *, buffer_id: str, last_event_id: str
|
||||
) -> AsyncIterator[ChatStreamChunk]:
|
||||
"""Resume a previously-interrupted chat stream.
|
||||
|
||||
Sends `GET /api/v1/ai/chat_completions/resume?buffer_id=<id>` with
|
||||
`Last-Event-ID: <last_event_id>` and an empty signed body. The
|
||||
server replays everything emitted after `last_event_id`.
|
||||
|
||||
Per the BUNDLE_NOTES (§3 "Resume mechanism"), a 204 response means
|
||||
"nothing left to resume" — in that case this iterator yields
|
||||
nothing and stops cleanly.
|
||||
"""
|
||||
async for evt in self._client.stream(
|
||||
"GET",
|
||||
"/api/v1/ai/chat_completions/resume",
|
||||
params={"buffer_id": buffer_id},
|
||||
sign=True,
|
||||
is_resume=True,
|
||||
last_event_id=last_event_id,
|
||||
):
|
||||
if evt.is_terminal:
|
||||
return
|
||||
if not evt.data:
|
||||
continue
|
||||
try:
|
||||
data = json.loads(evt.data)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
if not isinstance(data, dict):
|
||||
continue
|
||||
yield ChatStreamChunk.from_wire(data, event_id=evt.id)
|
||||
|
||||
async def complete(
|
||||
self,
|
||||
*,
|
||||
model: str | ModelInfo,
|
||||
provider: str | None = None,
|
||||
messages: list[Message],
|
||||
source: Source = Source.AI_CHAT,
|
||||
system_instructions: str | None = None,
|
||||
additional_system_instructions: str | None = None,
|
||||
user_preferences: UserPreferences | None | bool = True,
|
||||
temperature: float | None = None,
|
||||
reasoning_effort: str | None = None,
|
||||
tools: list[Tool | RemoteTool | str] | None = None,
|
||||
tool_choice: str | None = None,
|
||||
current_date: str | None = None,
|
||||
) -> ChatResult:
|
||||
"""Run a chat completion and return the accumulated result.
|
||||
|
||||
Equivalent to consuming `stream(...)` and merging the chunks. Use
|
||||
this when you don't care about token-by-token streaming.
|
||||
|
||||
Equivalent to consuming `stream(...)` and feeding each chunk into
|
||||
`ChatResult.add()`. Use this when you don't care about
|
||||
token-by-token streaming; otherwise iterate `stream(...)` directly
|
||||
and call `result.add(chunk)` yourself.
|
||||
"""
|
||||
result = ChatResult()
|
||||
async for chunk in self.stream(
|
||||
model=model,
|
||||
provider=provider,
|
||||
messages=messages,
|
||||
source=source,
|
||||
system_instructions=system_instructions,
|
||||
additional_system_instructions=additional_system_instructions,
|
||||
user_preferences=user_preferences,
|
||||
temperature=temperature,
|
||||
reasoning_effort=reasoning_effort,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
current_date=current_date,
|
||||
):
|
||||
result.add(chunk)
|
||||
return result
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _coerce_preferences(
|
||||
value: UserPreferences | None | bool, # noqa: FBT001
|
||||
) -> UserPreferences | None:
|
||||
if value is True:
|
||||
return UserPreferences.auto()
|
||||
if value is False or value is None:
|
||||
return None
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def _today_iso() -> str:
|
||||
import datetime
|
||||
|
||||
return datetime.date.today().isoformat() # noqa: DTZ011 — local date is intended
|
||||
@@ -0,0 +1,172 @@
|
||||
"""`/api/v1/ai/files` — chat attachment uploads.
|
||||
|
||||
Three calls, all signed (BUNDLE_NOTES §1b):
|
||||
|
||||
- `POST /ai/files` — register an upload. Signs the JSON body, returns
|
||||
`{id, direct_upload:{url, headers}}`. Caller PUTs the blob to that
|
||||
URL (unsigned, off-Raycast — usually S3-presigned).
|
||||
- `GET /ai/files/{id}` — fetch a previously-uploaded file. ⚠ signs the
|
||||
literal two-byte string `"{}"` even though the GET sends no body
|
||||
over the wire. This is a Raycast-side oddity (`uV` @ 118609); the
|
||||
server validates the signature against `"{}"`, so we must match.
|
||||
- `DELETE /ai/files` — bulk-delete files for a list of chat_ids.
|
||||
Sends a JSON body on a DELETE. aiohttp supports this if we pass
|
||||
`data=` explicitly.
|
||||
|
||||
`FilesAPI.upload(path, chat_id)` orchestrates both halves of the upload
|
||||
(register + PUT) and returns the `FileMetadata` ready to drop into an
|
||||
`Attachment`. The PUT is done via `aiohttp.ClientSession.put(...)` on
|
||||
the same session as the rest of the client (so connection pooling
|
||||
works) but with no Raycast headers — the presigned URL carries its own
|
||||
auth.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import mimetypes
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from raycast_api.client.http import Client
|
||||
|
||||
|
||||
@dataclass
|
||||
class FileMetadata:
|
||||
"""Result of a `POST /ai/files` upload, post-PUT.
|
||||
|
||||
`file_id` is the server-side id you'll reference in subsequent chat
|
||||
requests (via `Attachment.file_id`). Everything else is bookkeeping
|
||||
the caller usually doesn't need.
|
||||
"""
|
||||
|
||||
file_id: str
|
||||
filename: str
|
||||
size: int
|
||||
content_type: str
|
||||
checksum: str
|
||||
raw: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class FilesAPI:
|
||||
"""Wrapper around the three `/ai/files` endpoints."""
|
||||
|
||||
def __init__(self, client: Client) -> None:
|
||||
self._client = client
|
||||
|
||||
async def upload(
|
||||
self,
|
||||
*,
|
||||
path: str | Path,
|
||||
chat_id: str,
|
||||
filename: str | None = None,
|
||||
content_type: str | None = None,
|
||||
) -> FileMetadata:
|
||||
"""Upload a file to Raycast's blob store.
|
||||
|
||||
Steps:
|
||||
1. Read the file (or accept caller-provided bytes via the future
|
||||
`data=` overload — not implemented yet).
|
||||
2. Compute SHA-256 checksum.
|
||||
3. `POST /ai/files` with metadata. Server returns `direct_upload`.
|
||||
4. PUT the file bytes to `direct_upload.url` with the provided
|
||||
headers. No Raycast signing on the PUT.
|
||||
"""
|
||||
p = Path(path)
|
||||
data = p.read_bytes() # noqa: ASYNC240 — sync read; aiofiles would be overkill
|
||||
return await self._upload_bytes(
|
||||
data=data,
|
||||
chat_id=chat_id,
|
||||
filename=filename or p.name,
|
||||
content_type=content_type or _guess_content_type(p),
|
||||
)
|
||||
|
||||
async def _upload_bytes(
|
||||
self, *, data: bytes, chat_id: str, filename: str, content_type: str
|
||||
) -> FileMetadata:
|
||||
checksum = hashlib.sha256(data).hexdigest()
|
||||
body = {
|
||||
"chat_id": chat_id,
|
||||
"blob": {
|
||||
"filename": filename,
|
||||
"byte_size": len(data),
|
||||
"content_type": content_type,
|
||||
"checksum": checksum,
|
||||
},
|
||||
}
|
||||
|
||||
async with self._client.request(
|
||||
"POST", "/api/v1/ai/files", json_body=body, sign=True
|
||||
) as resp:
|
||||
registration = await resp.json()
|
||||
|
||||
upload_info = registration.get("direct_upload") or {}
|
||||
upload_url = upload_info.get("url")
|
||||
upload_headers = upload_info.get("headers") or {}
|
||||
if not isinstance(upload_url, str) or not upload_url:
|
||||
msg = "POST /ai/files succeeded but response had no direct_upload.url"
|
||||
raise RuntimeError(
|
||||
msg
|
||||
)
|
||||
|
||||
session = self._client._require_session() # noqa: SLF001 — same package
|
||||
async with session.put(
|
||||
upload_url, data=data, headers=upload_headers
|
||||
) as put_resp:
|
||||
if put_resp.status >= 400:
|
||||
text = await put_resp.text()
|
||||
msg = f"direct_upload PUT failed: HTTP {put_resp.status} {text[:200]}"
|
||||
raise RuntimeError(
|
||||
msg
|
||||
)
|
||||
|
||||
return FileMetadata(
|
||||
file_id=str(registration.get("id", "")),
|
||||
filename=filename,
|
||||
size=len(data),
|
||||
content_type=content_type,
|
||||
checksum=checksum,
|
||||
raw=registration,
|
||||
)
|
||||
|
||||
async def get(self, file_id: str) -> bytes:
|
||||
"""Download a previously-uploaded file by its id.
|
||||
|
||||
⚠ Signs the literal string `"{}"` (two bytes) as the body, per
|
||||
the `uV` caller's behaviour. The server validates the signature
|
||||
against that, NOT against an empty string — sending `b""` here
|
||||
produces a 401.
|
||||
"""
|
||||
async with self._client.request(
|
||||
"GET", f"/api/v1/ai/files/{file_id}", body=b"{}", sign=True
|
||||
) as resp:
|
||||
return await resp.read()
|
||||
|
||||
async def delete(self, *, chat_ids: list[str]) -> None:
|
||||
"""Delete all files associated with one or more chat ids.
|
||||
|
||||
Sends a JSON body on the DELETE method (`oge` @ 119406). aiohttp
|
||||
forwards bodies on DELETE when `data=` is explicit, so this works
|
||||
end-to-end. Server returns 2xx with no body of interest.
|
||||
"""
|
||||
body = {"chat_ids": list(chat_ids)}
|
||||
body_bytes = json.dumps(body, separators=(",", ":"), ensure_ascii=False).encode(
|
||||
"utf-8"
|
||||
)
|
||||
async with self._client.request(
|
||||
"DELETE", "/api/v1/ai/files", body=body_bytes, sign=True
|
||||
) as resp:
|
||||
await resp.read()
|
||||
|
||||
|
||||
def _guess_content_type(p: Path) -> str:
|
||||
"""Best-effort MIME lookup for `Content-Type` of the registered blob.
|
||||
|
||||
Falls back to `application/octet-stream`. The server is permissive
|
||||
about this field — it's metadata for the model, not transport.
|
||||
"""
|
||||
guess, _ = mimetypes.guess_type(p.name)
|
||||
return guess or "application/octet-stream"
|
||||
@@ -0,0 +1,36 @@
|
||||
"""`/api/v1/me` wrapper — bearer-token sanity probe.
|
||||
|
||||
The smallest possible round-trip against the backend. Used as a probe to
|
||||
verify the Bearer token is valid before doing anything more expensive (a
|
||||
401 here is unambiguous: the token is wrong, not a signing bug).
|
||||
|
||||
Unsigned; returns a raw dict because the `me` response shape includes a
|
||||
LOT of subscription/feature-flag fields that aren't worth modelling
|
||||
exhaustively — callers are typically interested in `email` /
|
||||
`raycast_subscription` / `has_pro_features` and nothing else.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from raycast_api.client.http import Client
|
||||
|
||||
|
||||
class MeAPI:
|
||||
"""`GET /api/v1/me`."""
|
||||
|
||||
def __init__(self, client: Client) -> None:
|
||||
self._client = client
|
||||
|
||||
async def get(self) -> dict[str, Any]:
|
||||
"""Fetch the current user's profile.
|
||||
|
||||
Returns the raw JSON dict — we deliberately don't model the shape
|
||||
because it has dozens of subscription/feature-flag fields that
|
||||
Raycast adds to over time. If a caller needs typed access to a
|
||||
specific field, they should pick it out of this dict.
|
||||
"""
|
||||
async with self._client.request("GET", "/api/v1/me", sign=False) as resp:
|
||||
return await resp.json()
|
||||
@@ -0,0 +1,173 @@
|
||||
"""`/api/v1/ai/models` wrapper.
|
||||
|
||||
Single endpoint: GET, unsigned (Bearer-only), with the
|
||||
`X-Raycast-Experimental: autoModels` header attached — same as every
|
||||
signed-surface caller, but here it's the only Raycast-specific header
|
||||
on the request (signing is omitted per BUNDLE_NOTES §1c).
|
||||
|
||||
The response is the catalog of models the user's subscription has access
|
||||
to, plus the server-side defaults for each Raycast surface (chat, quick_ai,
|
||||
commands, …). Callers will typically `await client.models.list()` once at
|
||||
startup, cache it, and pass `model=info.id` / `provider=info.provider` to
|
||||
`client.chat.stream(...)`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from raycast_api.client.http import Client
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelInfo:
|
||||
"""One entry from `/ai/models`.
|
||||
|
||||
Most fields are passed straight through from the server; the only
|
||||
transformations are the snake_case key names (the server already uses
|
||||
snake_case) and consistent typing on the optional fields.
|
||||
|
||||
Three id-like fields that are easy to mix up:
|
||||
|
||||
- `id` — catalog id with provider prefix, e.g. `"google-gemini-3.1-pro-preview"`.
|
||||
Use this to look entries up (`models.by_id(...)`); do NOT send it
|
||||
as the request `model` field.
|
||||
- `model` — bare wire identifier the server forwards to the provider,
|
||||
e.g. `"gemini-3.1-pro-preview"`. THIS is what goes in the chat
|
||||
request body's `model` field. Confirmed against captured requests
|
||||
(see `_extracted/captures/request_*.curl.txt`).
|
||||
- `provider` — provider key, e.g. `"google"`. Goes in the request
|
||||
body's `provider` field alongside `model`.
|
||||
|
||||
Correct usage:
|
||||
|
||||
info = models.by_id("google-gemini-3.1-pro-preview")
|
||||
await client.chat.complete(model=info.model, provider=info.provider, ...)
|
||||
"""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
model: str
|
||||
provider: str
|
||||
provider_name: str = ""
|
||||
provider_brand: str = ""
|
||||
description: str = ""
|
||||
context: int = -1
|
||||
status: str = ""
|
||||
availability: str = ""
|
||||
features: list[str] = field(default_factory=list)
|
||||
capabilities: dict[str, Any] = field(default_factory=dict)
|
||||
abilities: dict[str, Any] = field(default_factory=dict)
|
||||
in_better_ai_subscription: bool = False
|
||||
allowed_subscription_types: list[str] = field(default_factory=list)
|
||||
requires_better_ai: bool = False
|
||||
suggestions: list[Any] = field(default_factory=list)
|
||||
raw: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@classmethod
|
||||
def from_wire(cls, d: dict[str, Any]) -> ModelInfo:
|
||||
return cls(
|
||||
id=str(d.get("id", "")),
|
||||
name=str(d.get("name", "")),
|
||||
model=str(d.get("model", "")),
|
||||
provider=str(d.get("provider", "")),
|
||||
provider_name=str(d.get("provider_name", "")),
|
||||
provider_brand=str(d.get("provider_brand", "")),
|
||||
description=str(d.get("description", "")),
|
||||
context=int(d.get("context", -1)),
|
||||
status=str(d.get("status", "")),
|
||||
availability=str(d.get("availability", "")),
|
||||
features=list(d.get("features") or []),
|
||||
capabilities=dict(d.get("capabilities") or {}),
|
||||
abilities=dict(d.get("abilities") or {}),
|
||||
in_better_ai_subscription=bool(d.get("in_better_ai_subscription", False)),
|
||||
allowed_subscription_types=list(d.get("allowed_subscription_types") or []),
|
||||
requires_better_ai=bool(d.get("requires_better_ai", False)),
|
||||
suggestions=list(d.get("suggestions") or []),
|
||||
raw=d,
|
||||
)
|
||||
|
||||
@property
|
||||
def supports_temperature(self) -> bool:
|
||||
node = self.abilities.get("temperature")
|
||||
return bool(node and node.get("supported"))
|
||||
|
||||
@property
|
||||
def supports_reasoning_effort(self) -> bool:
|
||||
node = self.abilities.get("reasoning_effort")
|
||||
return bool(node and node.get("supported"))
|
||||
|
||||
@property
|
||||
def reasoning_effort_options(self) -> list[str]:
|
||||
node = self.abilities.get("reasoning_effort") or {}
|
||||
opts = node.get("options")
|
||||
return [str(o) for o in opts] if isinstance(opts, list) else []
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelsResponse:
|
||||
"""The complete `/ai/models` payload.
|
||||
|
||||
`default_models` is the server's per-surface default mapping; callers
|
||||
can look up `default_models.get("chat")` to find the model id Raycast
|
||||
itself would pick for the AI Chat window.
|
||||
|
||||
`free_model_ids` is the list of model ids available on the free tier —
|
||||
the server actually emits this as `free_models: ["id1", "id2", ...]`
|
||||
(a list of strings, not full ModelInfo objects, despite what BUNDLE_NOTES
|
||||
inferred). To get a `ModelInfo` for one, look it up via `.by_id(id)`.
|
||||
"""
|
||||
|
||||
models: list[ModelInfo]
|
||||
default_models: dict[str, str] = field(default_factory=dict)
|
||||
free_model_ids: list[str] = field(default_factory=list)
|
||||
raw: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@classmethod
|
||||
def from_wire(cls, d: dict[str, Any]) -> ModelsResponse:
|
||||
free_raw = d.get("free_models") or []
|
||||
free_ids: list[str] = []
|
||||
for item in free_raw:
|
||||
if isinstance(item, str):
|
||||
free_ids.append(item)
|
||||
elif isinstance(item, dict) and "id" in item:
|
||||
free_ids.append(str(item["id"]))
|
||||
return cls(
|
||||
models=[ModelInfo.from_wire(m) for m in d.get("models") or []],
|
||||
default_models={
|
||||
str(k): str(v) for k, v in (d.get("default_models") or {}).items()
|
||||
},
|
||||
free_model_ids=free_ids,
|
||||
raw=d,
|
||||
)
|
||||
|
||||
def by_id(self, model_id: str) -> ModelInfo | None:
|
||||
for m in self.models:
|
||||
if m.id == model_id:
|
||||
return m
|
||||
return None
|
||||
|
||||
|
||||
class ModelsAPI:
|
||||
"""Thin wrapper around `GET /api/v1/ai/models`."""
|
||||
|
||||
def __init__(self, client: Client) -> None:
|
||||
self._client = client
|
||||
|
||||
async def list(self) -> ModelsResponse:
|
||||
"""Fetch the full model catalog.
|
||||
|
||||
Sent unsigned (Bearer-only) but still carries the
|
||||
`X-Raycast-Experimental: autoModels` header that the real client
|
||||
attaches — without it the response shape changes (older catalog).
|
||||
"""
|
||||
async with self._client.request(
|
||||
"GET",
|
||||
"/api/v1/ai/models",
|
||||
sign=False,
|
||||
headers={"X-Raycast-Experimental": self._client.config.experimental_header},
|
||||
) as resp:
|
||||
data = await resp.json()
|
||||
return ModelsResponse.from_wire(data)
|
||||
@@ -0,0 +1,397 @@
|
||||
"""Typed shapes for chat completions and adjacent endpoints.
|
||||
|
||||
These dataclasses sit between Raycast's wire JSON (per BUNDLE_NOTES.md §3)
|
||||
and the Python caller. Every type exposes either `.to_wire()` to produce
|
||||
the dict that goes into the request body, or `.from_wire(d)` to parse what
|
||||
the server returned. We intentionally keep these as plain dataclasses —
|
||||
not pydantic — for consistency with the rest of the package.
|
||||
|
||||
Field order in `to_wire()` for `ChatRequest` mirrors what the real client
|
||||
sends (observed in `_extracted/captures/request_*.curl.txt`). The server
|
||||
doesn't care about ordering itself, but the bytes we sign must equal the
|
||||
bytes we send, so we serialise once via `to_wire()` and the same dict goes
|
||||
to both `Signer.sign(...)` and `aiohttp.session.post(data=...)`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import StrEnum
|
||||
from typing import Any, ClassVar
|
||||
|
||||
|
||||
class Source(StrEnum):
|
||||
"""`source` discriminator on chat_completions.
|
||||
|
||||
The library defaults to `AI_CHAT` (matches the main Raycast chat window
|
||||
fingerprint). `API` is what the Extension API path sends; the other two
|
||||
are observable but less useful for a non-Raycast caller.
|
||||
"""
|
||||
|
||||
AI_CHAT = "ai_chat"
|
||||
QUICK_AI = "quick_ai"
|
||||
AI_COMMAND = "ai_command"
|
||||
API = "api"
|
||||
|
||||
|
||||
class RemoteTool(StrEnum):
|
||||
"""The three generic remote tools the library exposes by name.
|
||||
|
||||
Per BUNDLE_NOTES decision: only the three model-agnostic web tools.
|
||||
Other Raycast-extension remote tools (calendar, location, etc.) are
|
||||
out of scope. Pass these to `ChatAPI.stream(tools=...)` as-is and the
|
||||
library wraps them as `{"type":"remote_tool","name":<value>}`.
|
||||
"""
|
||||
|
||||
WEB_SEARCH = "web_search"
|
||||
SEARCH_IMAGES = "search_images"
|
||||
READ_PAGE = "read_page"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Tool:
|
||||
"""One tool definition for the `tools` field of a chat request.
|
||||
|
||||
Use `Tool.local(...)` for a function-calling tool the caller will
|
||||
execute and feed back as a `tool` message; use `Tool.remote(...)`
|
||||
(or pass a `RemoteTool` enum) for server-routed remote tools.
|
||||
"""
|
||||
|
||||
type: str
|
||||
name: str
|
||||
description: str | None = None
|
||||
parameters: dict[str, Any] | None = None
|
||||
|
||||
@classmethod
|
||||
def local(
|
||||
cls, name: str, description: str = "", parameters: dict[str, Any] | None = None
|
||||
) -> Tool:
|
||||
return cls(
|
||||
type="local_tool",
|
||||
name=name,
|
||||
description=description,
|
||||
parameters=parameters if parameters is not None else {},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def remote(cls, name: str | RemoteTool) -> Tool:
|
||||
return cls(
|
||||
type="remote_tool",
|
||||
name=name.value if isinstance(name, RemoteTool) else name,
|
||||
)
|
||||
|
||||
def to_wire(self) -> dict[str, Any]:
|
||||
if self.type == "remote_tool":
|
||||
return {"type": "remote_tool", "name": self.name}
|
||||
fn: dict[str, Any] = {"name": self.name}
|
||||
if self.description is not None:
|
||||
fn["description"] = self.description
|
||||
if self.parameters is not None:
|
||||
fn["parameters"] = self.parameters
|
||||
return {"type": "local_tool", "function": fn}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCall:
|
||||
"""One function call emitted by the assistant.
|
||||
|
||||
`arguments` is the JSON-encoded argument string (NOT a dict) — that's
|
||||
what Raycast / OpenAI-style backends emit and what the server expects
|
||||
back when we echo it in history.
|
||||
"""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
arguments: str
|
||||
extra_content: dict[str, dict[str, str]] | None = None
|
||||
|
||||
def to_wire(self) -> dict[str, Any]:
|
||||
return {
|
||||
"id": self.id,
|
||||
"type": "function",
|
||||
"function": {"name": self.name, "arguments": self.arguments},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_wire(cls, d: dict[str, Any]) -> ToolCall:
|
||||
if "function" in d:
|
||||
fn = d["function"]
|
||||
return cls(
|
||||
id=str(d["id"]),
|
||||
name=str(fn.get("name", "")),
|
||||
arguments=str(fn.get("arguments", "")),
|
||||
extra_content=d.get("extra_content"),
|
||||
)
|
||||
return cls(
|
||||
id=str(d.get("id", "")),
|
||||
name=str(d.get("name", "")),
|
||||
arguments=str(d.get("arguments", "")),
|
||||
extra_content=d.get("extra_content"),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Attachment:
|
||||
"""Metadata for a file uploaded via `POST /ai/files`.
|
||||
|
||||
Most users build this through `FilesAPI.upload(...)` which fills in
|
||||
`fileId` from the server response. The other fields are local UI
|
||||
bookkeeping the real client sends verbatim — we mirror them so the
|
||||
request body matches what the server expects.
|
||||
"""
|
||||
|
||||
id: str
|
||||
path: str
|
||||
filename: str
|
||||
size: int
|
||||
file_id: str
|
||||
type: str = "file"
|
||||
source: str = "file"
|
||||
content_type: str = "application/octet-stream"
|
||||
status: str = "completed"
|
||||
url: str = ""
|
||||
is_over_context_limit: bool = False
|
||||
extra: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_wire(self) -> dict[str, Any]:
|
||||
out: dict[str, Any] = {
|
||||
"id": self.id,
|
||||
"path": self.path,
|
||||
"filename": self.filename,
|
||||
"size": self.size,
|
||||
"fileId": self.file_id,
|
||||
"type": self.type,
|
||||
"source": self.source,
|
||||
"contentType": self.content_type,
|
||||
"status": self.status,
|
||||
"url": self.url,
|
||||
"isOverContextLimit": self.is_over_context_limit,
|
||||
}
|
||||
out.update(self.extra)
|
||||
return out
|
||||
|
||||
|
||||
@dataclass
|
||||
class Message:
|
||||
"""One message in a chat history.
|
||||
|
||||
Use the factory classmethods rather than constructing directly — they
|
||||
enforce the role/field combinations the server expects (see
|
||||
BUNDLE_NOTES §3, "Message types").
|
||||
"""
|
||||
|
||||
role: str
|
||||
content: dict[str, Any] = field(default_factory=dict)
|
||||
tool_calls: list[ToolCall] | None = None
|
||||
extra_content: dict[str, dict[str, str]] | None = None
|
||||
name: str | None = None
|
||||
tool_call_id: str | None = None
|
||||
|
||||
@classmethod
|
||||
def user(cls, text: str, attachments: list[Attachment] | None = None) -> Message:
|
||||
content: dict[str, Any] = {"text": text}
|
||||
if attachments is not None:
|
||||
content["attachments"] = [a.to_wire() for a in attachments]
|
||||
return cls(role="user", content=content)
|
||||
|
||||
@classmethod
|
||||
def assistant(
|
||||
cls,
|
||||
text: str = "",
|
||||
tool_calls: list[ToolCall] | None = None,
|
||||
extra_content: dict[str, dict[str, str]] | None = None,
|
||||
) -> Message:
|
||||
return cls(
|
||||
role="assistant",
|
||||
content={"text": text},
|
||||
tool_calls=list(tool_calls) if tool_calls else None,
|
||||
extra_content=extra_content,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tool(cls, *, tool_call_id: str, name: str, result: Any) -> Message:
|
||||
"""Build a tool-result message from any Python value.
|
||||
|
||||
The server expects `content.text` to be a JSON-encoded MCP-style
|
||||
list of content blocks (e.g. `[{"type":"text","text":"..."}]`) —
|
||||
anything else produces `unknown_api_error` upstream.
|
||||
|
||||
Behaviour:
|
||||
- `str` → wrapped as a single text block.
|
||||
- already a list of `{"type": ...}` blocks → passed through.
|
||||
- any other value (dict, list, number, ...) → JSON-serialised
|
||||
into a single text block.
|
||||
"""
|
||||
import json
|
||||
|
||||
def _is_block_list(x: object) -> bool:
|
||||
return isinstance(x, list) and all(
|
||||
isinstance(b, dict) and "type" in b for b in x
|
||||
)
|
||||
|
||||
if isinstance(result, str):
|
||||
payload: list[dict[str, Any]] = [{"type": "text", "text": result}]
|
||||
elif _is_block_list(result):
|
||||
payload = result # type: ignore[assignment]
|
||||
else:
|
||||
payload = [
|
||||
{
|
||||
"type": "text",
|
||||
"text": json.dumps(
|
||||
result, separators=(",", ":"), ensure_ascii=False
|
||||
),
|
||||
}
|
||||
]
|
||||
return cls(
|
||||
role="tool",
|
||||
content={
|
||||
"text": json.dumps(payload, separators=(",", ":"), ensure_ascii=False)
|
||||
},
|
||||
name=name,
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
|
||||
def to_wire(self) -> dict[str, Any]:
|
||||
out: dict[str, Any] = {"role": self.role, "content": dict(self.content)}
|
||||
if self.tool_calls:
|
||||
out["tool_calls"] = [tc.to_wire() for tc in self.tool_calls]
|
||||
if self.extra_content is not None:
|
||||
out["extra_content"] = self.extra_content
|
||||
if self.name is not None:
|
||||
out["name"] = self.name
|
||||
if self.tool_call_id is not None:
|
||||
out["tool_call_id"] = self.tool_call_id
|
||||
return out
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserPreferences:
|
||||
"""Renders the `<user-preferences>` block of `additional_system_instructions`.
|
||||
|
||||
This is the only personalisation block the library produces by default
|
||||
(per the Phase 1 decision: profile / memory / extensions / skills are
|
||||
out of scope). The block matches the real client's wording byte-for-
|
||||
byte so the model can use it identically to a real chat.
|
||||
|
||||
`current_date` should be a `YYYY-MM-DD` string (the real client uses
|
||||
`Intl.DateTimeFormat`'s short-locale form — we just normalise to ISO).
|
||||
"""
|
||||
|
||||
locale: str
|
||||
timezone: str
|
||||
current_date: str
|
||||
|
||||
_TEMPLATE: ClassVar[str] = (
|
||||
"<user-preferences>\n"
|
||||
" The user has the following system preferences:\n"
|
||||
" - Locale: {locale}\n"
|
||||
" - Timezone: {timezone}\n"
|
||||
" - Current Date: {date}\n"
|
||||
" - Use the system preferences to format your answers accordingly\n"
|
||||
"</user-preferences>"
|
||||
)
|
||||
|
||||
def render(self) -> str:
|
||||
return self._TEMPLATE.format(
|
||||
locale=self.locale, timezone=self.timezone, date=self.current_date
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def auto(cls, locale: str | None = None) -> UserPreferences:
|
||||
"""Build preferences from the host's locale/timezone and today's date.
|
||||
|
||||
Used as the default in `ChatAPI.stream(...)` when the caller doesn't
|
||||
pass an explicit `user_preferences=` arg. The locale arg defaults
|
||||
to the client's `locale` (which itself defaults to `en-US`).
|
||||
"""
|
||||
import datetime
|
||||
import time
|
||||
|
||||
tz: str
|
||||
if time.daylight and time.localtime().tm_isdst > 0:
|
||||
tz = time.tzname[1] or time.tzname[0] or "UTC"
|
||||
else:
|
||||
tz = time.tzname[0] or "UTC"
|
||||
return cls(
|
||||
locale=locale or "en-US",
|
||||
timezone=tz,
|
||||
current_date=datetime.date.today().isoformat(), # noqa: DTZ011 — local date
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatStreamChunk:
|
||||
"""One parsed JSON chunk from the chat_completions SSE stream.
|
||||
|
||||
Fields here are a 1:1 mirror of `vkt`'s consumer (BUNDLE_NOTES §3,
|
||||
"Per-chunk JSON shape"). All fields are optional; callers should
|
||||
check which fields are present.
|
||||
|
||||
`event_id` is the SSE `id:` value (used for resume). `raw` is the
|
||||
underlying dict in case the caller needs a field we didn't model
|
||||
(e.g. forward-compat with a new chunk kind).
|
||||
"""
|
||||
|
||||
text: str | None = None
|
||||
reasoning: str | None = None
|
||||
tool_calls: list[ToolCall] | None = None
|
||||
finish_reason: str | None = None
|
||||
usage: dict[str, int] | None = None
|
||||
notification: str | None = None
|
||||
notification_type: str | None = None
|
||||
references: list[dict[str, Any]] | None = None
|
||||
image: str | None = None
|
||||
extra_content: dict[str, Any] | None = None
|
||||
event_id: str | None = None
|
||||
raw: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@classmethod
|
||||
def from_wire(
|
||||
cls, data: dict[str, Any], *, event_id: str | None = None
|
||||
) -> ChatStreamChunk:
|
||||
tc_wire = data.get("tool_calls")
|
||||
tool_calls: list[ToolCall] | None = None
|
||||
if isinstance(tc_wire, list):
|
||||
tool_calls = [ToolCall.from_wire(d) for d in tc_wire if isinstance(d, dict)]
|
||||
return cls(
|
||||
text=data.get("text") if isinstance(data.get("text"), str) else None,
|
||||
reasoning=data.get("reasoning")
|
||||
if isinstance(data.get("reasoning"), str)
|
||||
else None,
|
||||
tool_calls=tool_calls,
|
||||
finish_reason=data.get("finish_reason")
|
||||
if isinstance(data.get("finish_reason"), str)
|
||||
else None,
|
||||
usage=data.get("usage") if isinstance(data.get("usage"), dict) else None,
|
||||
notification=data.get("notification")
|
||||
if isinstance(data.get("notification"), str)
|
||||
else None,
|
||||
notification_type=data.get("notification_type")
|
||||
if isinstance(data.get("notification_type"), str)
|
||||
else None,
|
||||
references=data.get("references")
|
||||
if isinstance(data.get("references"), list)
|
||||
else None,
|
||||
image=data.get("image") if isinstance(data.get("image"), str) else None,
|
||||
extra_content=data.get("extra_content")
|
||||
if isinstance(data.get("extra_content"), dict)
|
||||
else None,
|
||||
event_id=event_id,
|
||||
raw=data,
|
||||
)
|
||||
|
||||
@property
|
||||
def is_empty(self) -> bool:
|
||||
"""True for the keepalive `{"text":""}` chunks the server interleaves."""
|
||||
return (
|
||||
not self.text
|
||||
and not self.reasoning
|
||||
and not self.tool_calls
|
||||
and not self.finish_reason
|
||||
and not self.usage
|
||||
and not self.notification
|
||||
and not self.image
|
||||
and not self.references
|
||||
)
|
||||
@@ -0,0 +1,451 @@
|
||||
"""Command-line interface for `raycast-api`.
|
||||
|
||||
Four verbs:
|
||||
|
||||
- ``raycast-api init`` — run discovery against a local Raycast install,
|
||||
write the result to ``config.json``.
|
||||
- ``raycast-api refresh`` — same as ``init`` but always overwrites and
|
||||
bypasses the discovery cache (the bundle
|
||||
may have changed even at the same hash if
|
||||
cache was hand-edited).
|
||||
- ``raycast-api inspect`` — print a summary of a saved config. No
|
||||
network calls.
|
||||
- ``raycast-api ask`` — minimal smoke-test command: send one prompt,
|
||||
print the reply. Reads Bearer + device id
|
||||
from env / flags. The device id is auto-
|
||||
generated and persisted on first use.
|
||||
|
||||
The CLI is intentionally small — the library is the product, this is a
|
||||
thin convenience for users who don't want to write a Python script just to
|
||||
verify their config works.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import contextlib
|
||||
import os
|
||||
import secrets
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from raycast_api.config import Config, ConfigComparison
|
||||
from raycast_api.errors import ConfigError, DiscoveryError, RaycastApiError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
DEFAULT_CONFIG_PATH = Path("config.json")
|
||||
DEFAULT_DEVICE_ID_PATH = Path.home() / ".config" / "raycast-api" / "device_id"
|
||||
|
||||
|
||||
_DEFAULT_APP_PATHS: tuple[Path, ...] = (
|
||||
Path("/Applications/Raycast Beta.app"),
|
||||
Path("/Applications/Raycast.app"),
|
||||
Path.cwd() / "Raycast Beta.app",
|
||||
Path.cwd() / "Raycast.app",
|
||||
)
|
||||
|
||||
|
||||
def _resolve_app_path(explicit: str | None) -> Path:
|
||||
if explicit is not None:
|
||||
return Path(explicit).expanduser().resolve()
|
||||
for candidate in _DEFAULT_APP_PATHS:
|
||||
if candidate.is_dir():
|
||||
return candidate
|
||||
msg = "could not find a Raycast install; pass --app-path <path-to-Raycast.app>"
|
||||
raise SystemExit(
|
||||
msg
|
||||
)
|
||||
|
||||
|
||||
def _try_resolve_app_path(explicit: str | None) -> Path | None:
|
||||
"""Non-raising variant for `inspect`: return None if no app is findable."""
|
||||
if explicit is not None:
|
||||
candidate = Path(explicit).expanduser().resolve()
|
||||
return candidate if candidate.is_dir() else None
|
||||
for candidate in _DEFAULT_APP_PATHS:
|
||||
if candidate.is_dir():
|
||||
return candidate
|
||||
return None
|
||||
|
||||
|
||||
|
||||
|
||||
def _generate_device_id() -> str:
|
||||
"""Mint a fresh 64-char hex device id."""
|
||||
return secrets.token_hex(32)
|
||||
|
||||
|
||||
def _load_or_create_device_id(path: Path = DEFAULT_DEVICE_ID_PATH) -> str:
|
||||
"""Read the persisted device id, generating + saving one if missing.
|
||||
|
||||
Stored in `~/.config/raycast-api/device_id` (chmod 0o600). The same id
|
||||
is reused across CLI invocations so a single user looks like a single
|
||||
install to the backend.
|
||||
"""
|
||||
if path.exists():
|
||||
existing = path.read_text(encoding="ascii").strip()
|
||||
if len(existing) == 64 and all(c in "0123456789abcdefABCDEF" for c in existing):
|
||||
return existing.lower()
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
fresh = _generate_device_id()
|
||||
path.write_text(fresh + "\n", encoding="ascii")
|
||||
with contextlib.suppress(OSError):
|
||||
path.chmod(0o600)
|
||||
return fresh
|
||||
|
||||
|
||||
|
||||
|
||||
def _cmd_init(args: argparse.Namespace) -> int:
|
||||
app_path = _resolve_app_path(args.app_path)
|
||||
output = Path(args.output).expanduser()
|
||||
if output.exists() and not args.force:
|
||||
print(f"!! {output} already exists; pass --force to overwrite", file=sys.stderr)
|
||||
return 1
|
||||
try:
|
||||
config = Config.discover_from_app(app_path, use_cache=not args.no_cache)
|
||||
except DiscoveryError as e:
|
||||
print(f"!! discovery failed: {e}", file=sys.stderr)
|
||||
return 1
|
||||
config.save(output)
|
||||
print(f"·· wrote {output}")
|
||||
print(f" app version : {config.app_version}")
|
||||
print(f" secret : {config.redacted_secret()}")
|
||||
print(f" bundle hash : {config.bundle_hash[:12]}…")
|
||||
print(f" cache key : {config.cache_key()[:12]}…")
|
||||
return 0
|
||||
|
||||
|
||||
def _cmd_refresh(args: argparse.Namespace) -> int:
|
||||
"""Re-discover and overwrite an existing config.
|
||||
|
||||
Bypasses the discovery cache so a launcher rebuild that changed the
|
||||
secret without changing the bundle still gets picked up.
|
||||
"""
|
||||
app_path = _resolve_app_path(args.app_path)
|
||||
output = Path(args.config).expanduser()
|
||||
try:
|
||||
config = Config.discover_from_app(app_path, use_cache=False)
|
||||
except DiscoveryError as e:
|
||||
print(f"!! discovery failed: {e}", file=sys.stderr)
|
||||
return 1
|
||||
config.save(output)
|
||||
print(f"·· refreshed {output}")
|
||||
print(f" app version : {config.app_version}")
|
||||
print(f" secret : {config.redacted_secret()}")
|
||||
print(f" bundle hash : {config.bundle_hash[:12]}…")
|
||||
return 0
|
||||
|
||||
|
||||
def _cmd_inspect(args: argparse.Namespace) -> int:
|
||||
"""Print a saved config; optionally verify freshness against a local app.
|
||||
|
||||
Verification is opt-in via `--verify`, `--app-path`, or `--quiet` (which
|
||||
implies verification). Without any of those, this is a pure offline dump
|
||||
of the config file — same behavior as the original `inspect`.
|
||||
|
||||
Exit codes:
|
||||
0 — config loaded; verified current OR no verification requested.
|
||||
1 — config missing / invalid (legacy), OR verification reports stale.
|
||||
2 — verification was requested but the local app is unreachable
|
||||
(explicit `--app-path` missing, or `--quiet` without an
|
||||
autodetectable app).
|
||||
|
||||
`--quiet` suppresses output and is meant for shell scripts:
|
||||
`raycast-api inspect --quiet || raycast-api refresh`.
|
||||
"""
|
||||
verify_requested = args.verify or args.quiet or args.app_path is not None
|
||||
path = Path(args.config).expanduser()
|
||||
if not path.exists():
|
||||
if not args.quiet:
|
||||
print(f"!! no config at {path}", file=sys.stderr)
|
||||
return 1
|
||||
try:
|
||||
config = Config.load(path)
|
||||
except ConfigError as e:
|
||||
if not args.quiet:
|
||||
print(f"!! config invalid: {e}", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
app_path: Path | None = None
|
||||
if verify_requested:
|
||||
if args.app_path is not None:
|
||||
candidate = Path(args.app_path).expanduser().resolve()
|
||||
if not candidate.is_dir():
|
||||
if not args.quiet:
|
||||
print(f"!! app path not found: {candidate}", file=sys.stderr)
|
||||
return 2
|
||||
app_path = candidate
|
||||
else:
|
||||
app_path = _try_resolve_app_path(None)
|
||||
if app_path is None and args.quiet:
|
||||
return 2
|
||||
|
||||
comparison: ConfigComparison | None = None
|
||||
compare_error: str | None = None
|
||||
if app_path is not None:
|
||||
try:
|
||||
comparison = config.compare_with_app(app_path)
|
||||
except DiscoveryError as e:
|
||||
compare_error = str(e)
|
||||
if args.quiet:
|
||||
return 2
|
||||
|
||||
if args.quiet:
|
||||
return 0 if (comparison is not None and comparison.is_current) else 1
|
||||
|
||||
spec = config.signing_spec
|
||||
print(f"config_path : {path}")
|
||||
print(f"app_version : {config.app_version}")
|
||||
print(f"user_agent : {config.user_agent}")
|
||||
print(f"backend_url : {config.backend_url}")
|
||||
print(f"secret : {config.redacted_secret()}")
|
||||
print(f"bundle_hash : {config.bundle_hash}")
|
||||
print(f"launcher_hash : {config.launcher_hash}")
|
||||
print(f"cache_key : {config.cache_key()}")
|
||||
print(f"experimental : {config.experimental_header}")
|
||||
print(f"signing_spec : rot={spec.rot_fn_name} sign={spec.signing_fn_name}")
|
||||
print(f" ranges : {[(r.start, r.end, r.shift) for r in spec.rot_ranges]}")
|
||||
print(f" hmac : {spec.hmac_algorithm} body={spec.body_hash_algorithm}")
|
||||
print(f" join : {spec.join_char!r}")
|
||||
|
||||
if verify_requested:
|
||||
_print_status_block(app_path, comparison, compare_error)
|
||||
if comparison is not None and not comparison.is_current:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
|
||||
def _print_status_block(
|
||||
app_path: Path | None, comparison: ConfigComparison | None, error: str | None
|
||||
) -> None:
|
||||
"""Render the freshness section appended to `inspect` output."""
|
||||
print()
|
||||
if app_path is None:
|
||||
print(
|
||||
"status : unknown (no local Raycast.app found; "
|
||||
"pass --app-path to verify)"
|
||||
)
|
||||
return
|
||||
if comparison is None:
|
||||
print(f"status : unknown ({error or 'could not hash app'})")
|
||||
print(f" app path : {app_path}")
|
||||
return
|
||||
label = "CURRENT" if comparison.is_current else "STALE — run `raycast-api refresh`"
|
||||
print(f"status : {label}")
|
||||
print(f" app path : {app_path}")
|
||||
print(
|
||||
" bundle : "
|
||||
+ (
|
||||
"✓ matches"
|
||||
if comparison.bundle_matches
|
||||
else f"✗ saved {comparison.saved_bundle_hash[:12]}… → "
|
||||
f"current {comparison.current_bundle_hash[:12]}…"
|
||||
)
|
||||
)
|
||||
print(
|
||||
" launcher : "
|
||||
+ (
|
||||
"✓ matches"
|
||||
if comparison.launcher_matches
|
||||
else f"✗ saved {comparison.saved_launcher_hash[:12]}… → "
|
||||
f"current {comparison.current_launcher_hash[:12]}…"
|
||||
)
|
||||
)
|
||||
if comparison.app_version_matches:
|
||||
print(f" app version : ✓ {comparison.current_app_version}")
|
||||
else:
|
||||
print(
|
||||
f" app version : ≠ saved {comparison.saved_app_version!r} → "
|
||||
f"current {comparison.current_app_version!r} "
|
||||
"(informational; not a staleness signal on its own)"
|
||||
)
|
||||
|
||||
|
||||
async def _run_ask(args: argparse.Namespace) -> int:
|
||||
from raycast_api.ai import Message
|
||||
from raycast_api.client import Client
|
||||
|
||||
path = Path(args.config).expanduser() # noqa: ASYNC240 — sync I/O at CLI boundary
|
||||
if not path.exists():
|
||||
print(f"!! no config at {path}; run `raycast-api init` first", file=sys.stderr)
|
||||
return 1
|
||||
try:
|
||||
config = Config.load(path)
|
||||
except ConfigError as e:
|
||||
print(f"!! config invalid: {e}", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
bearer = args.bearer or os.environ.get("RAYCAST_BEARER")
|
||||
if not bearer:
|
||||
print(
|
||||
"!! missing bearer token; pass --bearer or set RAYCAST_BEARER",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return 2
|
||||
|
||||
device_id = args.device_id or os.environ.get("RAYCAST_DEVICE_ID")
|
||||
if not device_id:
|
||||
device_id = _load_or_create_device_id()
|
||||
|
||||
model = args.model or os.environ.get("RAYCAST_MODEL")
|
||||
if not model:
|
||||
print(
|
||||
"!! no model specified; pass --model or set RAYCAST_MODEL", file=sys.stderr
|
||||
)
|
||||
return 2
|
||||
|
||||
prompt = args.prompt
|
||||
provider = args.provider or os.environ.get("RAYCAST_PROVIDER")
|
||||
|
||||
async with Client(
|
||||
config=config, bearer_token=bearer, device_id=device_id
|
||||
) as client:
|
||||
try:
|
||||
if args.stream:
|
||||
final_finish: str | None = None
|
||||
final_usage: dict | None = None
|
||||
async for chunk in client.chat.stream(
|
||||
model=model, provider=provider, messages=[Message.user(prompt)]
|
||||
):
|
||||
if chunk.text:
|
||||
print(chunk.text, end="", flush=True)
|
||||
if chunk.finish_reason:
|
||||
final_finish = chunk.finish_reason
|
||||
if chunk.usage:
|
||||
final_usage = chunk.usage
|
||||
print()
|
||||
print(
|
||||
f"·· finish_reason={final_finish} usage={final_usage}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
else:
|
||||
result = await client.chat.complete(
|
||||
model=model, provider=provider, messages=[Message.user(prompt)]
|
||||
)
|
||||
print(result.text)
|
||||
print(
|
||||
f"·· finish_reason={result.finish_reason} usage={result.usage}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
except RaycastApiError as e:
|
||||
print(f"!! ask failed: {type(e).__name__}: {e}", file=sys.stderr)
|
||||
return 1
|
||||
return 0
|
||||
|
||||
|
||||
def _cmd_ask(args: argparse.Namespace) -> int:
|
||||
return asyncio.run(_run_ask(args))
|
||||
|
||||
|
||||
|
||||
|
||||
def _build_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="raycast-api",
|
||||
description="Bring-your-own-credentials client for the Raycast backend.",
|
||||
)
|
||||
sub = parser.add_subparsers(dest="cmd", required=True)
|
||||
|
||||
init_p = sub.add_parser("init", help="discover and save a config.json")
|
||||
init_p.add_argument(
|
||||
"--app-path", help="path to Raycast.app (autodetected if omitted)"
|
||||
)
|
||||
init_p.add_argument(
|
||||
"--output",
|
||||
default=str(DEFAULT_CONFIG_PATH),
|
||||
help=f"output file (default: {DEFAULT_CONFIG_PATH})",
|
||||
)
|
||||
init_p.add_argument(
|
||||
"--force",
|
||||
action="store_true",
|
||||
help="overwrite the output file if it already exists",
|
||||
)
|
||||
init_p.add_argument(
|
||||
"--no-cache", action="store_true", help="bypass the discovery cache"
|
||||
)
|
||||
init_p.set_defaults(func=_cmd_init)
|
||||
|
||||
refresh_p = sub.add_parser(
|
||||
"refresh", help="re-run discovery (overwrites the config; bypasses the cache)"
|
||||
)
|
||||
refresh_p.add_argument(
|
||||
"--app-path", help="path to Raycast.app (autodetected if omitted)"
|
||||
)
|
||||
refresh_p.add_argument(
|
||||
"--config",
|
||||
default=str(DEFAULT_CONFIG_PATH),
|
||||
help=f"config file to overwrite (default: {DEFAULT_CONFIG_PATH})",
|
||||
)
|
||||
refresh_p.set_defaults(func=_cmd_refresh)
|
||||
|
||||
inspect_p = sub.add_parser(
|
||||
"inspect",
|
||||
help="print a saved config and verify freshness against the local app",
|
||||
)
|
||||
inspect_p.add_argument(
|
||||
"--config",
|
||||
default=str(DEFAULT_CONFIG_PATH),
|
||||
help=f"config file to read (default: {DEFAULT_CONFIG_PATH})",
|
||||
)
|
||||
inspect_p.add_argument(
|
||||
"--verify",
|
||||
action="store_true",
|
||||
help="check freshness against an autodetected local Raycast install",
|
||||
)
|
||||
inspect_p.add_argument(
|
||||
"--app-path",
|
||||
help=(
|
||||
"path to Raycast.app for freshness verification "
|
||||
"(implies --verify; failure if the path doesn't exist)"
|
||||
),
|
||||
)
|
||||
inspect_p.add_argument(
|
||||
"--quiet",
|
||||
action="store_true",
|
||||
help=(
|
||||
"no output; implies --verify. Exit 0=current, 1=stale, "
|
||||
"2=unverifiable (no app available)"
|
||||
),
|
||||
)
|
||||
inspect_p.set_defaults(func=_cmd_inspect)
|
||||
|
||||
ask_p = sub.add_parser("ask", help="run a one-shot chat completion")
|
||||
ask_p.add_argument("prompt", help="the user prompt")
|
||||
ask_p.add_argument(
|
||||
"--config",
|
||||
default=str(DEFAULT_CONFIG_PATH),
|
||||
help=f"config file to read (default: {DEFAULT_CONFIG_PATH})",
|
||||
)
|
||||
ask_p.add_argument(
|
||||
"--bearer", help="Raycast OAuth bearer token (or set RAYCAST_BEARER)"
|
||||
)
|
||||
ask_p.add_argument(
|
||||
"--device-id",
|
||||
help="64-hex device id (or set RAYCAST_DEVICE_ID; auto-persisted by default)",
|
||||
)
|
||||
ask_p.add_argument("--model", help="model id or catalog id (or set RAYCAST_MODEL)")
|
||||
ask_p.add_argument(
|
||||
"--provider",
|
||||
help="explicit provider (skips catalog lookup; or set RAYCAST_PROVIDER)",
|
||||
)
|
||||
ask_p.add_argument(
|
||||
"--stream", action="store_true", help="stream tokens to stdout as they arrive"
|
||||
)
|
||||
ask_p.set_defaults(func=_cmd_ask)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def main(argv: Sequence[str] | None = None) -> int:
|
||||
parser = _build_parser()
|
||||
args = parser.parse_args(argv)
|
||||
return int(args.func(args) or 0)
|
||||
|
||||
|
||||
if __name__ == "__main__": # pragma: no cover
|
||||
raise SystemExit(main())
|
||||
@@ -0,0 +1,21 @@
|
||||
"""HTTP client and SSE streaming for the Raycast backend."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from raycast_api.client.http import Client
|
||||
from raycast_api.client.retry import (
|
||||
DEFAULT_RETRY_STATUSES,
|
||||
RetryPolicy,
|
||||
parse_retry_after,
|
||||
)
|
||||
from raycast_api.client.streaming import SSEEvent, SSEParser, iter_sse
|
||||
|
||||
__all__ = [
|
||||
"DEFAULT_RETRY_STATUSES",
|
||||
"Client",
|
||||
"RetryPolicy",
|
||||
"SSEEvent",
|
||||
"SSEParser",
|
||||
"iter_sse",
|
||||
"parse_retry_after",
|
||||
]
|
||||
@@ -0,0 +1,436 @@
|
||||
"""Async HTTP client for the Raycast backend.
|
||||
|
||||
`Client` is the single entry point used by the higher-level endpoint
|
||||
wrappers (Phase 5). It:
|
||||
|
||||
- owns (or borrows) an `aiohttp.ClientSession`,
|
||||
- builds the full Raycast header set on every request — Bearer +
|
||||
User-Agent + the four `X-Raycast-*` signing headers + optional
|
||||
`Last-Event-ID` for resume + WebView fluff,
|
||||
- delegates signature computation to a `Signer` constructed from the
|
||||
`Config.signing_spec`,
|
||||
- retries 429/5xx with exponential backoff (one retry per attempt
|
||||
re-signs with a fresh timestamp),
|
||||
- maps server errors to typed exceptions,
|
||||
- exposes `stream(...)` as an async generator over parsed SSE events.
|
||||
|
||||
Streaming responses are NOT auto-retried. Once the first chunk has reached
|
||||
the caller, replaying the request would duplicate output; resume via
|
||||
`Last-Event-ID` (see `is_resume=True`) is the supported recovery path.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import TYPE_CHECKING, Any, Self
|
||||
|
||||
import aiohttp
|
||||
|
||||
from raycast_api.client.retry import RetryPolicy, parse_retry_after
|
||||
from raycast_api.client.streaming import SSEEvent, SSEParser
|
||||
from raycast_api.errors import (
|
||||
AuthError,
|
||||
HTTPStatusError,
|
||||
RateLimitError,
|
||||
StreamError,
|
||||
TransportError,
|
||||
)
|
||||
from raycast_api.signing import Signer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable
|
||||
|
||||
from raycast_api.ai.chat import ChatAPI
|
||||
from raycast_api.ai.files import FilesAPI
|
||||
from raycast_api.ai.me import MeAPI
|
||||
from raycast_api.ai.models import ModelsAPI, ModelsResponse
|
||||
from raycast_api.config import Config
|
||||
|
||||
_BROWSER_FLUFF: dict[str, str] = {
|
||||
"Accept": "*/*",
|
||||
"Origin": "file://",
|
||||
"Sec-Fetch-Site": "cross-site",
|
||||
"Sec-Fetch-Mode": "cors",
|
||||
"Sec-Fetch-Dest": "empty",
|
||||
}
|
||||
|
||||
|
||||
class Client:
|
||||
"""Signed HTTP client over `aiohttp`.
|
||||
|
||||
Construction:
|
||||
|
||||
async with Client(
|
||||
config=cfg,
|
||||
bearer_token="rca_...",
|
||||
device_id="<64 hex>",
|
||||
) as client:
|
||||
async with client.request("GET", "/api/v1/me", sign=False) as resp:
|
||||
me = await resp.json()
|
||||
|
||||
`device_id` is opaque to the server but must be a stable 64-char hex
|
||||
string per install (Phase 6 CLI will generate one); for tests anything
|
||||
that's 64 hex chars works.
|
||||
|
||||
`bearer_token` is the user's OAuth access token. Pass an empty string
|
||||
to send requests without `Authorization` — useful for endpoints that
|
||||
accept anonymous calls (none observed so far, but kept open).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
config: Config,
|
||||
bearer_token: str,
|
||||
device_id: str,
|
||||
session: aiohttp.ClientSession | None = None,
|
||||
signer: Signer | None = None,
|
||||
retry: RetryPolicy | None = None,
|
||||
locale: str = "en-US",
|
||||
browser_headers: bool = True,
|
||||
models: ModelsResponse | None = None,
|
||||
clock: Callable[[], int] | None = None,
|
||||
sleep: Callable[[float], Awaitable[None]] | None = None,
|
||||
) -> None:
|
||||
self.config = config
|
||||
self.bearer_token = bearer_token
|
||||
self.device_id = device_id
|
||||
self.locale = locale
|
||||
self.browser_headers = browser_headers
|
||||
self._retry = retry or RetryPolicy()
|
||||
self._clock = clock or (lambda: int(time.time()))
|
||||
self._sleep = sleep or asyncio.sleep
|
||||
self._signer = signer or Signer(
|
||||
spec=config.signing_spec, secret=config.signature_secret
|
||||
)
|
||||
self._session = session
|
||||
self._owned_session = session is None
|
||||
self._chat: ChatAPI | None = None
|
||||
self._models: ModelsAPI | None = None
|
||||
self._me: MeAPI | None = None
|
||||
self._files: FilesAPI | None = None
|
||||
self._models_catalog: ModelsResponse | None = models
|
||||
self._models_catalog_lock: asyncio.Lock | None = None
|
||||
|
||||
@property
|
||||
def chat(self) -> ChatAPI:
|
||||
"""Chat completions API."""
|
||||
if self._chat is None:
|
||||
from raycast_api.ai.chat import ChatAPI
|
||||
|
||||
self._chat = ChatAPI(self)
|
||||
return self._chat
|
||||
|
||||
@property
|
||||
def models(self) -> ModelsAPI:
|
||||
"""Models catalog API."""
|
||||
if self._models is None:
|
||||
from raycast_api.ai.models import ModelsAPI
|
||||
|
||||
self._models = ModelsAPI(self)
|
||||
return self._models
|
||||
|
||||
@property
|
||||
def me(self) -> MeAPI:
|
||||
"""Account info API."""
|
||||
if self._me is None:
|
||||
from raycast_api.ai.me import MeAPI
|
||||
|
||||
self._me = MeAPI(self)
|
||||
return self._me
|
||||
|
||||
@property
|
||||
def files(self) -> FilesAPI:
|
||||
"""File upload API."""
|
||||
if self._files is None:
|
||||
from raycast_api.ai.files import FilesAPI
|
||||
|
||||
self._files = FilesAPI(self)
|
||||
return self._files
|
||||
|
||||
async def _get_models_catalog(self) -> ModelsResponse:
|
||||
"""Return the cached `/ai/models` response, fetching it once on first use.
|
||||
|
||||
Used by `ChatAPI._resolve_model` to look up the wire `model` + `provider`
|
||||
for a caller-supplied string id when no `provider=` was passed. Single
|
||||
round-trip per Client lifetime; subsequent calls reuse the cache.
|
||||
"""
|
||||
if self._models_catalog is not None:
|
||||
return self._models_catalog
|
||||
if self._models_catalog_lock is None:
|
||||
self._models_catalog_lock = asyncio.Lock()
|
||||
async with self._models_catalog_lock:
|
||||
if self._models_catalog is None:
|
||||
self._models_catalog = await self.models.list()
|
||||
return self._models_catalog
|
||||
|
||||
def invalidate_models_cache(self) -> None:
|
||||
"""Drop the cached models catalog so the next resolution re-fetches it."""
|
||||
self._models_catalog = None
|
||||
|
||||
|
||||
async def __aenter__(self) -> Self:
|
||||
if self._session is None:
|
||||
self._session = aiohttp.ClientSession()
|
||||
self._owned_session = True
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *exc_info: object) -> None:
|
||||
await self.close()
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the owned session, if any. Idempotent."""
|
||||
if self._owned_session and self._session is not None:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
|
||||
|
||||
def _url(self, path: str) -> str:
|
||||
if path.startswith(("http://", "https://")):
|
||||
return path
|
||||
if not path.startswith("/"):
|
||||
path = "/" + path
|
||||
return self.config.backend_url + path
|
||||
|
||||
def build_headers(
|
||||
self,
|
||||
*,
|
||||
sign: bool,
|
||||
body: bytes,
|
||||
is_resume: bool = False,
|
||||
last_event_id: str | None = None,
|
||||
content_type: str | None = None,
|
||||
timestamp: str | None = None,
|
||||
extra: dict[str, str] | None = None,
|
||||
) -> dict[str, str]:
|
||||
"""Compose the full header set for one outgoing request.
|
||||
|
||||
Mirrors the captured curl in `_extracted/captures/request_simple.curl.txt`
|
||||
when called with the defaults the Phase 5 chat API will pass.
|
||||
|
||||
Exposed (rather than tucked inside `request()`) so tests can pin the
|
||||
exact set without hitting the wire, and so a caller doing something
|
||||
unusual — e.g. signing a multipart upload — can build their own
|
||||
request from these primitives.
|
||||
"""
|
||||
headers: dict[str, str] = {"User-Agent": self.config.user_agent}
|
||||
if self.bearer_token:
|
||||
headers["Authorization"] = f"Bearer {self.bearer_token}"
|
||||
if self.browser_headers:
|
||||
headers.update(_BROWSER_FLUFF)
|
||||
headers["Accept-Language"] = self.locale
|
||||
if content_type is not None and not is_resume:
|
||||
headers["Content-Type"] = content_type
|
||||
if sign:
|
||||
ts = timestamp if timestamp is not None else str(self._clock())
|
||||
sig = self._signer.sign(timestamp=ts, device_id=self.device_id, body=body)
|
||||
headers["X-Raycast-Timestamp"] = ts
|
||||
headers["X-Raycast-DeviceId"] = self.device_id
|
||||
headers["X-Raycast-Signature-v2"] = sig
|
||||
headers["X-Raycast-Experimental"] = self.config.experimental_header
|
||||
if last_event_id is not None:
|
||||
headers["Last-Event-ID"] = last_event_id
|
||||
if extra:
|
||||
headers.update(extra)
|
||||
return headers
|
||||
|
||||
@staticmethod
|
||||
def _coerce_body(
|
||||
body: bytes | str | None, json_body: Any
|
||||
) -> tuple[bytes, str | None]:
|
||||
"""Return (body_bytes, content_type_or_None)."""
|
||||
if json_body is not None:
|
||||
if body is not None:
|
||||
msg = "pass either `body` or `json_body`, not both"
|
||||
raise ValueError(msg)
|
||||
return (
|
||||
json.dumps(json_body, separators=(",", ":"), ensure_ascii=False).encode(
|
||||
"utf-8"
|
||||
),
|
||||
"application/json",
|
||||
)
|
||||
if body is None:
|
||||
return b"", None
|
||||
if isinstance(body, str):
|
||||
return body.encode("utf-8"), "application/json"
|
||||
return bytes(body), "application/json"
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
*,
|
||||
body: bytes | str | None = None,
|
||||
json_body: Any = None,
|
||||
sign: bool = True,
|
||||
is_resume: bool = False,
|
||||
last_event_id: str | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
params: dict[str, Any] | None = None,
|
||||
retry: RetryPolicy | None = None,
|
||||
) -> AsyncGenerator[aiohttp.ClientResponse, None]:
|
||||
"""Send one signed request and yield the open response.
|
||||
|
||||
Use as an async context manager so the response body is released
|
||||
even if the caller raises:
|
||||
|
||||
async with client.request("POST", "/api/v1/ai/files",
|
||||
json_body=blob) as resp:
|
||||
data = await resp.json()
|
||||
|
||||
Retries 429 and 5xx according to `retry` (or `self._retry` by
|
||||
default), re-signing each attempt with a fresh timestamp. The
|
||||
retry loop is unaware of streaming bodies — the streaming code
|
||||
path goes through `request()` too but only retries on the
|
||||
connection-establishment failures, not mid-stream.
|
||||
"""
|
||||
body_bytes, content_type = self._coerce_body(body, json_body)
|
||||
policy = retry or self._retry
|
||||
url = self._url(path)
|
||||
|
||||
attempt = 0
|
||||
while True:
|
||||
attempt += 1
|
||||
hdrs = self.build_headers(
|
||||
sign=sign,
|
||||
body=body_bytes,
|
||||
is_resume=is_resume,
|
||||
last_event_id=last_event_id,
|
||||
content_type=content_type,
|
||||
extra=headers,
|
||||
)
|
||||
session = self._require_session()
|
||||
|
||||
delay: float
|
||||
try:
|
||||
async with session.request(
|
||||
method, url, data=body_bytes or None, headers=hdrs, params=params
|
||||
) as resp:
|
||||
if resp.status < 400:
|
||||
yield resp
|
||||
return
|
||||
err = await self._build_status_error(resp)
|
||||
if not policy.should_retry(attempt, resp.status):
|
||||
raise err
|
||||
delay = policy.delay_for_attempt(attempt, err.retry_after)
|
||||
except aiohttp.ClientError as e:
|
||||
if attempt >= policy.max_attempts:
|
||||
raise TransportError(str(e)) from e
|
||||
delay = policy.delay_for_attempt(attempt, None)
|
||||
|
||||
await self._sleep(delay)
|
||||
|
||||
|
||||
async def stream(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
*,
|
||||
body: bytes | str | None = None,
|
||||
json_body: Any = None,
|
||||
sign: bool = True,
|
||||
is_resume: bool = False,
|
||||
last_event_id: str | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
params: dict[str, Any] | None = None,
|
||||
on_last_event_id: Callable[[str], None] | None = None,
|
||||
raise_on_error_event: bool = True,
|
||||
) -> AsyncIterator[SSEEvent]:
|
||||
"""Open an SSE request and yield each parsed event.
|
||||
|
||||
Termination semantics:
|
||||
|
||||
- The async generator naturally stops when the response body
|
||||
closes. Raycast's terminators (`event: complete` and the legacy
|
||||
`data: [DONE]`) are yielded to the caller but don't break the
|
||||
loop here — letting the caller see them lets it distinguish
|
||||
"stream ended cleanly" from "connection died mid-way".
|
||||
- `event: error` chunks are yielded as `SSEEvent` and, if
|
||||
`raise_on_error_event` is True (default), also raise
|
||||
`StreamError` immediately after — but the consumer's `async for`
|
||||
will have already seen the error event in the previous yield.
|
||||
- `on_last_event_id` is called with the latest `id:` value each
|
||||
time it advances. Useful for callers that want to checkpoint
|
||||
for resume without keeping a reference to the parser.
|
||||
"""
|
||||
async with self.request(
|
||||
method,
|
||||
path,
|
||||
body=body,
|
||||
json_body=json_body,
|
||||
sign=sign,
|
||||
is_resume=is_resume,
|
||||
last_event_id=last_event_id,
|
||||
headers=headers,
|
||||
params=params,
|
||||
retry=self._retry,
|
||||
) as resp:
|
||||
parser = SSEParser()
|
||||
try:
|
||||
async for chunk in resp.content.iter_any():
|
||||
for evt in parser.feed(chunk):
|
||||
if evt.id is not None and on_last_event_id is not None:
|
||||
on_last_event_id(evt.id)
|
||||
yield evt
|
||||
if evt.is_error and raise_on_error_event:
|
||||
payload: Any
|
||||
try:
|
||||
payload = evt.json()
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
payload = None
|
||||
raise StreamError(payload, raw=evt.data)
|
||||
for evt in parser.flush():
|
||||
if evt.id is not None and on_last_event_id is not None:
|
||||
on_last_event_id(evt.id)
|
||||
yield evt
|
||||
if evt.is_error and raise_on_error_event:
|
||||
try:
|
||||
payload = evt.json()
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
payload = None
|
||||
raise StreamError(payload, raw=evt.data)
|
||||
except aiohttp.ClientError as e:
|
||||
raise TransportError(str(e)) from e
|
||||
|
||||
|
||||
def _require_session(self) -> aiohttp.ClientSession:
|
||||
if self._session is None:
|
||||
msg = (
|
||||
"Client.session not initialised; use `async with Client(...)` "
|
||||
"or pass an explicit aiohttp.ClientSession."
|
||||
)
|
||||
raise RuntimeError(
|
||||
msg
|
||||
)
|
||||
return self._session
|
||||
|
||||
async def _build_status_error(
|
||||
self, resp: aiohttp.ClientResponse
|
||||
) -> HTTPStatusError:
|
||||
try:
|
||||
body_bytes = await resp.read()
|
||||
except aiohttp.ClientError: # pragma: no cover — defensive
|
||||
body_bytes = b""
|
||||
body_text = body_bytes.decode("utf-8", errors="replace")
|
||||
retry_after = parse_retry_after(resp.headers.get("Retry-After"))
|
||||
headers_dict = dict(resp.headers.items())
|
||||
message = resp.reason or ""
|
||||
cls: type[HTTPStatusError]
|
||||
if resp.status == 401:
|
||||
cls = AuthError
|
||||
elif resp.status == 429:
|
||||
cls = RateLimitError
|
||||
else:
|
||||
cls = HTTPStatusError
|
||||
return cls(
|
||||
resp.status,
|
||||
message,
|
||||
body=body_text,
|
||||
retry_after=retry_after,
|
||||
headers=headers_dict,
|
||||
)
|
||||
@@ -0,0 +1,85 @@
|
||||
"""Retry policy for the HTTP client.
|
||||
|
||||
The Raycast backend doesn't document its rate limits, but the standard
|
||||
HTTP conventions apply: 429 means slow down (and may carry a `Retry-After`
|
||||
header), 5xx means transient server error worth retrying once or twice.
|
||||
|
||||
Kept as a pure helper so it's trivial to unit-test without aiohttp in the
|
||||
loop — the `Client` consults `RetryPolicy.should_retry(...)` and
|
||||
`RetryPolicy.delay_for_attempt(...)` and does the actual sleeping itself.
|
||||
Streaming responses are NOT retried automatically: once we've started
|
||||
yielding SSE events to the caller, replaying the request would duplicate
|
||||
output. The caller is expected to use the resume mechanism (Last-Event-ID)
|
||||
for that case, which is a different code path.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import email.utils
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
|
||||
DEFAULT_RETRY_STATUSES: frozenset[int] = frozenset({408, 425, 429, 500, 502, 503, 504})
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RetryPolicy:
|
||||
"""Exponential backoff schedule.
|
||||
|
||||
Total wall-clock budget for a request is roughly
|
||||
`initial * (multiplier^max_attempts - 1) / (multiplier - 1)`; with the
|
||||
defaults that's about 7 seconds across 4 attempts, which is the right
|
||||
order of magnitude for an interactive call.
|
||||
"""
|
||||
|
||||
max_attempts: int = 4
|
||||
initial_delay: float = 0.5
|
||||
max_delay: float = 30.0
|
||||
multiplier: float = 2.0
|
||||
retry_statuses: frozenset[int] = DEFAULT_RETRY_STATUSES
|
||||
respect_retry_after: bool = True
|
||||
|
||||
def should_retry(self, attempt: int, status: int) -> bool:
|
||||
"""`attempt` is 1-indexed (the attempt that just failed)."""
|
||||
if attempt >= self.max_attempts:
|
||||
return False
|
||||
return status in self.retry_statuses
|
||||
|
||||
def delay_for_attempt(
|
||||
self, attempt: int, retry_after: float | None = None
|
||||
) -> float:
|
||||
"""Delay before attempt N+1 (where `attempt` is the attempt that failed).
|
||||
|
||||
`retry_after`, when not None and `respect_retry_after` is True, takes
|
||||
precedence over the backoff schedule — but we still clamp to
|
||||
`max_delay` so a hostile / buggy server can't park the client.
|
||||
"""
|
||||
if retry_after is not None and self.respect_retry_after:
|
||||
return max(0.0, min(retry_after, self.max_delay))
|
||||
delay = self.initial_delay * (self.multiplier ** (attempt - 1))
|
||||
return min(delay, self.max_delay)
|
||||
|
||||
|
||||
def parse_retry_after(value: str | None, *, now: float | None = None) -> float | None:
|
||||
"""Parse a `Retry-After` header value into seconds.
|
||||
|
||||
The spec allows either an integer "delay seconds" form or an HTTP-date.
|
||||
Returns None if the header is missing or unparseable. Negative results
|
||||
(clock skew, past dates) are clamped to 0.
|
||||
"""
|
||||
if not value:
|
||||
return None
|
||||
value = value.strip()
|
||||
try:
|
||||
return max(0.0, float(value))
|
||||
except ValueError:
|
||||
pass
|
||||
try:
|
||||
parsed = email.utils.parsedate_to_datetime(value)
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
if parsed is None:
|
||||
return None
|
||||
target = parsed.timestamp()
|
||||
current = now if now is not None else time.time()
|
||||
return max(0.0, target - current)
|
||||
@@ -0,0 +1,171 @@
|
||||
r"""SSE parser for the Raycast streaming endpoint.
|
||||
|
||||
The wire format is the standard W3C `text/event-stream`, reduced to what
|
||||
the Raycast backend actually emits (per `BUNDLE_NOTES.md` §SSE):
|
||||
|
||||
- Line endings: `\n` or `\r\n` (both stripped).
|
||||
- `:` prefix → comment, ignored.
|
||||
- Field separator: first `:` on the line; one optional space after the colon
|
||||
is stripped from the value.
|
||||
- Recognised fields: `id`, `event`, `data`. Multiple `data` lines per event
|
||||
are joined with `\n`. Unknown fields are dropped.
|
||||
- Event boundary: an empty line. EOF flushes any pending event.
|
||||
|
||||
The parser is byte-stream-driven (`SSEParser.feed(chunk_bytes)`) so it can
|
||||
sit directly behind `aiohttp.StreamResponse.content` and not care about
|
||||
chunk boundaries — a line can be split across two TCP chunks and the parser
|
||||
still emits the right event.
|
||||
|
||||
It tracks `last_event_id` for caller-driven resume bookkeeping: every event
|
||||
that carries an `id:` field updates the attribute, and that's the value the
|
||||
caller should pass back in `Last-Event-ID` on the resume GET.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncIterable, AsyncIterator, Iterable
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SSEEvent:
|
||||
r"""One parsed event from the stream.
|
||||
|
||||
- `id`: the `id:` field if present on any line of the event, else None.
|
||||
- `event`: the `event:` field. None means the default event (most chunks).
|
||||
Raycast uses `complete` for the terminator and `error` for failures.
|
||||
- `data`: the `data:` payload as a string. If the server sent multiple
|
||||
`data:` lines, they're joined with `\n`. Empty string when no `data:`
|
||||
line was sent (rare — typically heartbeat comments use the `:` form
|
||||
instead and don't produce an event at all).
|
||||
"""
|
||||
|
||||
id: str | None
|
||||
event: str | None
|
||||
data: str
|
||||
|
||||
def json(self) -> Any:
|
||||
"""Parse `data` as JSON. Raises `json.JSONDecodeError` on bad payloads."""
|
||||
return json.loads(self.data)
|
||||
|
||||
@property
|
||||
def is_terminal(self) -> bool:
|
||||
"""True for the two end-of-stream markers Raycast emits.
|
||||
|
||||
- `event: complete` with `data: {"complete":true}` — production form
|
||||
- `data: [DONE]` (no event) — legacy form, still supported by client
|
||||
"""
|
||||
if self.event == "complete":
|
||||
return True
|
||||
return self.event is None and self.data.strip() == "[DONE]"
|
||||
|
||||
@property
|
||||
def is_error(self) -> bool:
|
||||
return self.event == "error"
|
||||
|
||||
|
||||
class SSEParser:
|
||||
"""Stateful byte-stream SSE parser.
|
||||
|
||||
Feed it bytes as they arrive; it yields `SSEEvent`s when complete events
|
||||
are buffered. Holds line + event state across `.feed()` calls so chunk
|
||||
boundaries are transparent.
|
||||
"""
|
||||
|
||||
__slots__ = ("_buf", "_data", "_event", "_id", "last_event_id")
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._buf = bytearray()
|
||||
self._id: str | None = None
|
||||
self._event: str | None = None
|
||||
self._data: list[str] = []
|
||||
self.last_event_id: str | None = None
|
||||
|
||||
def feed(self, chunk: bytes) -> Iterable[SSEEvent]:
|
||||
"""Push bytes into the parser; yield any events that completed."""
|
||||
if not chunk:
|
||||
return
|
||||
self._buf.extend(chunk)
|
||||
while True:
|
||||
nl = self._buf.find(b"\n")
|
||||
if nl == -1:
|
||||
return
|
||||
raw = bytes(self._buf[:nl])
|
||||
del self._buf[: nl + 1]
|
||||
if raw.endswith(b"\r"):
|
||||
raw = raw[:-1]
|
||||
event = self._consume_line(raw)
|
||||
if event is not None:
|
||||
yield event
|
||||
|
||||
def flush(self) -> Iterable[SSEEvent]:
|
||||
"""Yield any pending event at EOF.
|
||||
|
||||
The W3C spec says EOF without a trailing blank line does NOT dispatch
|
||||
the pending event, but Raycast's parser (`Rkt` in the bundle) does
|
||||
flush — matching that behavior keeps us robust to abruptly-closed
|
||||
connections that hold the last chunk.
|
||||
"""
|
||||
if self._buf:
|
||||
raw = bytes(self._buf)
|
||||
self._buf.clear()
|
||||
if raw.endswith(b"\r"):
|
||||
raw = raw[:-1]
|
||||
self._consume_line(raw)
|
||||
if self._data or self._event is not None or self._id is not None:
|
||||
yield self._dispatch()
|
||||
|
||||
|
||||
def _consume_line(self, raw: bytes) -> SSEEvent | None:
|
||||
if not raw:
|
||||
if self._data or self._event is not None or self._id is not None:
|
||||
return self._dispatch()
|
||||
return None
|
||||
|
||||
line = raw.decode("utf-8", errors="replace")
|
||||
|
||||
if line.startswith(":"):
|
||||
return None
|
||||
|
||||
if ":" in line:
|
||||
field_name, _, value = line.partition(":")
|
||||
value = value.removeprefix(" ")
|
||||
else:
|
||||
field_name, value = line, ""
|
||||
|
||||
if field_name == "data":
|
||||
self._data.append(value)
|
||||
elif field_name == "event":
|
||||
self._event = value
|
||||
elif field_name == "id":
|
||||
self._id = value
|
||||
self.last_event_id = value
|
||||
return None
|
||||
|
||||
def _dispatch(self) -> SSEEvent:
|
||||
data = "\n".join(self._data)
|
||||
event = SSEEvent(id=self._id, event=self._event, data=data)
|
||||
self._id = None
|
||||
self._event = None
|
||||
self._data = []
|
||||
return event
|
||||
|
||||
|
||||
async def iter_sse(byte_stream: AsyncIterable[bytes]) -> AsyncIterator[SSEEvent]:
|
||||
"""Async generator: bytes → `SSEEvent`s.
|
||||
|
||||
Useful when you have an aiohttp response and just want events out:
|
||||
|
||||
async for evt in iter_sse(resp.content.iter_any()):
|
||||
...
|
||||
"""
|
||||
parser = SSEParser()
|
||||
async for chunk in byte_stream:
|
||||
for evt in parser.feed(chunk):
|
||||
yield evt
|
||||
for evt in parser.flush():
|
||||
yield evt
|
||||
@@ -0,0 +1,294 @@
|
||||
"""Top-level Config dataclass.
|
||||
|
||||
Carries everything the runtime client needs that isn't a user credential:
|
||||
the signing spec, the signing secret, the per-build constants, and a few
|
||||
metadata fields. Built by `Config.discover_from_app(...)` or hand-loaded from
|
||||
JSON via `Config.load(...)`.
|
||||
|
||||
The split between "discovered" and "user-supplied" is:
|
||||
|
||||
- Discovered (in this Config): backend URL, signing secret, signing spec,
|
||||
user-agent template, app version.
|
||||
- User-supplied (NOT in this Config): Bearer token, device id. These are
|
||||
per-user credentials that should never be persisted alongside discovery
|
||||
output. The HTTP client takes them as separate constructor args.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import hashlib
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from raycast_api.discovery.cache import read_json, write_json
|
||||
from raycast_api.errors import ConfigError, DiscoveryError
|
||||
from raycast_api.signing_spec import SigningSpec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
DEFAULT_BACKEND_URL = "https://backend.raycast.com"
|
||||
DEFAULT_API_PREFIX = "/api/v1"
|
||||
DEFAULT_DEVICE_TAG_SALT = "xK7mQ2vLpN8wY4jR6tBfHsAeDc"
|
||||
DEFAULT_OAUTH_CLIENT_ID = "FRsHICIAlyPB_v2m4tfVqHtVUS40Ieco_da0Y6zBwgA"
|
||||
EXPERIMENTAL_HEADER_VALUE = "autoModels"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ConfigComparison:
|
||||
"""Result of comparing a saved Config against a local Raycast install.
|
||||
|
||||
All three booleans need to be True for the config to be considered fully
|
||||
current. In practice `launcher_matches` is the load-bearing one — it
|
||||
catches secret rotation. `bundle_matches` catches JS-bundle rebuilds
|
||||
(the signing spec might have moved). `app_version_matches` is
|
||||
informational: it can drift independently of the hashes after a hot
|
||||
patch, and a mismatch alone doesn't necessarily mean re-discovery is
|
||||
required.
|
||||
"""
|
||||
|
||||
bundle_matches: bool
|
||||
launcher_matches: bool
|
||||
app_version_matches: bool
|
||||
saved_bundle_hash: str
|
||||
current_bundle_hash: str
|
||||
saved_launcher_hash: str
|
||||
current_launcher_hash: str
|
||||
saved_app_version: str
|
||||
current_app_version: str
|
||||
|
||||
@property
|
||||
def is_current(self) -> bool:
|
||||
"""True iff both hashes match. Version drift alone doesn't disqualify."""
|
||||
return self.bundle_matches and self.launcher_matches
|
||||
|
||||
def reasons(self) -> list[str]:
|
||||
"""Human-readable explanations for any drift, in display order."""
|
||||
out: list[str] = []
|
||||
if not self.bundle_matches:
|
||||
out.append("bundle rebuilt (signing spec may have moved)")
|
||||
if not self.launcher_matches:
|
||||
out.append("launcher rebuilt (secret may have rotated)")
|
||||
if not self.app_version_matches:
|
||||
out.append(
|
||||
f"app version {self.saved_app_version!r} → {self.current_app_version!r}"
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
"""Runtime config for the Raycast client.
|
||||
|
||||
`signature_secret` is the only sensitive field — keep this file readable
|
||||
only by the user (`Config.save` writes with chmod 600).
|
||||
"""
|
||||
|
||||
signature_secret: str
|
||||
signing_spec: SigningSpec
|
||||
app_version: str
|
||||
user_agent: str
|
||||
bundle_hash: str
|
||||
launcher_hash: str
|
||||
backend_url: str = DEFAULT_BACKEND_URL
|
||||
api_prefix: str = DEFAULT_API_PREFIX
|
||||
oauth_client_id: str = DEFAULT_OAUTH_CLIENT_ID
|
||||
device_tag_salt: str = DEFAULT_DEVICE_TAG_SALT
|
||||
experimental_header: str = EXPERIMENTAL_HEADER_VALUE
|
||||
extra: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
def cache_key(self) -> str:
|
||||
"""Identifier used by `DiscoveryCache` to find/replace this config.
|
||||
|
||||
SHA-256 of `bundle_hash || launcher_hash`. Either side changing —
|
||||
a JS bundle rebuild or a launcher rebuild that rotates the secret —
|
||||
produces a fresh key and invalidates the previous cache entry.
|
||||
"""
|
||||
return hashlib.sha256(
|
||||
(self.bundle_hash + self.launcher_hash).encode("ascii")
|
||||
).hexdigest()
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"schema_version": 1,
|
||||
"signature_secret": self.signature_secret,
|
||||
"signing_spec": self.signing_spec.to_dict(),
|
||||
"app_version": self.app_version,
|
||||
"user_agent": self.user_agent,
|
||||
"bundle_hash": self.bundle_hash,
|
||||
"launcher_hash": self.launcher_hash,
|
||||
"backend_url": self.backend_url,
|
||||
"api_prefix": self.api_prefix,
|
||||
"oauth_client_id": self.oauth_client_id,
|
||||
"device_tag_salt": self.device_tag_salt,
|
||||
"experimental_header": self.experimental_header,
|
||||
"extra": dict(self.extra),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> Config:
|
||||
try:
|
||||
spec_raw = data["signing_spec"]
|
||||
except KeyError as e:
|
||||
msg = f"Missing required field: {e}"
|
||||
raise ConfigError(msg) from e
|
||||
if not isinstance(spec_raw, dict):
|
||||
msg = "signing_spec must be an object"
|
||||
raise ConfigError(msg)
|
||||
spec = SigningSpec.from_dict(spec_raw)
|
||||
try:
|
||||
return cls(
|
||||
signature_secret=str(data["signature_secret"]),
|
||||
signing_spec=spec,
|
||||
app_version=str(data["app_version"]),
|
||||
user_agent=str(data["user_agent"]),
|
||||
bundle_hash=str(data["bundle_hash"]),
|
||||
launcher_hash=str(data["launcher_hash"]),
|
||||
backend_url=str(data.get("backend_url", DEFAULT_BACKEND_URL)),
|
||||
api_prefix=str(data.get("api_prefix", DEFAULT_API_PREFIX)),
|
||||
oauth_client_id=str(
|
||||
data.get("oauth_client_id", DEFAULT_OAUTH_CLIENT_ID)
|
||||
),
|
||||
device_tag_salt=str(
|
||||
data.get("device_tag_salt", DEFAULT_DEVICE_TAG_SALT)
|
||||
),
|
||||
experimental_header=str(
|
||||
data.get("experimental_header", EXPERIMENTAL_HEADER_VALUE)
|
||||
),
|
||||
extra=dict(data.get("extra", {}) or {}),
|
||||
)
|
||||
except KeyError as e:
|
||||
msg = f"Missing required field: {e}"
|
||||
raise ConfigError(msg) from e
|
||||
|
||||
def save(self, path: Path) -> None:
|
||||
write_json(path, self.to_dict())
|
||||
with contextlib.suppress(OSError):
|
||||
path.chmod(0o600)
|
||||
|
||||
@classmethod
|
||||
def load(cls, path: Path) -> Config:
|
||||
return cls.from_dict(read_json(path))
|
||||
|
||||
|
||||
@classmethod
|
||||
def discover_from_app(
|
||||
cls,
|
||||
app_path: Path,
|
||||
*,
|
||||
use_cache: bool = True,
|
||||
platform_version: str | None = None,
|
||||
) -> Config:
|
||||
"""Run the full discovery pipeline against a local Raycast install.
|
||||
|
||||
Steps (all in `raycast_api.discovery`):
|
||||
|
||||
1. Locate `Contents/Resources/.../backend/index.mjs` and hash it.
|
||||
2. If `use_cache`, look up the cached Config by hash and return it.
|
||||
3. Extract `signature_secret` from the launcher binary.
|
||||
4. Parse the bundle and derive the `SigningSpec` structurally.
|
||||
5. Read `CFBundleShortVersionString` from Info.plist for the
|
||||
User-Agent template.
|
||||
6. Persist into the cache.
|
||||
|
||||
Returns the Config. Raises `DiscoveryError` if any step can't find
|
||||
what it expects (likely meaning Raycast changed its layout).
|
||||
"""
|
||||
from raycast_api.discovery.binary import find_signature_secret, launcher_hash
|
||||
from raycast_api.discovery.bundle import (
|
||||
bundle_hash,
|
||||
find_index_mjs,
|
||||
read_bundle_source,
|
||||
)
|
||||
from raycast_api.discovery.cache import DiscoveryCache
|
||||
from raycast_api.discovery.extractors import (
|
||||
extract_signing_spec,
|
||||
extract_user_agent_template,
|
||||
read_app_version,
|
||||
)
|
||||
|
||||
if not app_path.is_dir():
|
||||
msg = f"app_path is not a directory: {app_path}"
|
||||
raise DiscoveryError(msg)
|
||||
|
||||
index_mjs = find_index_mjs(app_path)
|
||||
bhash = bundle_hash(index_mjs)
|
||||
lhash = launcher_hash(app_path)
|
||||
combined_key = hashlib.sha256((bhash + lhash).encode("ascii")).hexdigest()
|
||||
|
||||
cache = DiscoveryCache() if use_cache else None
|
||||
if cache is not None:
|
||||
cached = cache.get(combined_key)
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
secret = find_signature_secret(app_path)
|
||||
spec = extract_signing_spec(read_bundle_source(index_mjs))
|
||||
version = read_app_version(app_path)
|
||||
ua = extract_user_agent_template(app_path, platform_version=platform_version)
|
||||
|
||||
config = cls(
|
||||
signature_secret=secret,
|
||||
signing_spec=spec,
|
||||
app_version=version,
|
||||
user_agent=ua,
|
||||
bundle_hash=bhash,
|
||||
launcher_hash=lhash,
|
||||
)
|
||||
|
||||
if cache is not None:
|
||||
with contextlib.suppress(OSError):
|
||||
cache.set(combined_key, config)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def compare_with_app(self, app_path: Path) -> ConfigComparison:
|
||||
"""Re-hash the local Raycast install and compare to the saved config.
|
||||
|
||||
Returns a `ConfigComparison` with per-field booleans plus the current
|
||||
and saved hashes side-by-side, suitable for printing. Raises
|
||||
`DiscoveryError` if the local app can't be hashed at all (missing
|
||||
bundle / unreadable launcher) — callers that want to report "unknown"
|
||||
instead of an error should catch it themselves.
|
||||
|
||||
This does NOT re-run the AST extractor; it only re-hashes. A bundle
|
||||
whose hash matches the saved value is assumed to have the same
|
||||
signing spec — that's the whole premise of the discovery cache.
|
||||
"""
|
||||
from raycast_api.discovery.binary import launcher_hash
|
||||
from raycast_api.discovery.bundle import bundle_hash, find_index_mjs
|
||||
from raycast_api.discovery.extractors import read_app_version
|
||||
|
||||
if not app_path.is_dir():
|
||||
msg = f"app_path is not a directory: {app_path}"
|
||||
raise DiscoveryError(msg)
|
||||
current_bundle = bundle_hash(find_index_mjs(app_path))
|
||||
current_launcher = launcher_hash(app_path)
|
||||
try:
|
||||
current_version = read_app_version(app_path)
|
||||
except DiscoveryError:
|
||||
current_version = ""
|
||||
return ConfigComparison(
|
||||
bundle_matches=current_bundle == self.bundle_hash,
|
||||
launcher_matches=current_launcher == self.launcher_hash,
|
||||
app_version_matches=current_version == self.app_version,
|
||||
saved_bundle_hash=self.bundle_hash,
|
||||
current_bundle_hash=current_bundle,
|
||||
saved_launcher_hash=self.launcher_hash,
|
||||
current_launcher_hash=current_launcher,
|
||||
saved_app_version=self.app_version,
|
||||
current_app_version=current_version,
|
||||
)
|
||||
|
||||
def is_current_for(self, app_path: Path) -> bool:
|
||||
"""Convenience wrapper: True iff both hashes match the local install."""
|
||||
return self.compare_with_app(app_path).is_current
|
||||
|
||||
def redacted_secret(self) -> str:
|
||||
"""Return a masked form of the secret (last 4 chars) for display."""
|
||||
if len(self.signature_secret) <= 4:
|
||||
return "*" * len(self.signature_secret)
|
||||
return "…" + self.signature_secret[-4:]
|
||||
@@ -0,0 +1,5 @@
|
||||
"""Discovery — extract signing config from a local Raycast install at runtime.
|
||||
|
||||
Submodules are imported lazily so partial-import scenarios (during package
|
||||
construction or partial-install tests) don't blow up.
|
||||
"""
|
||||
@@ -0,0 +1,380 @@
|
||||
"""JavaScript AST utilities for structural function matching.
|
||||
|
||||
The Raycast Node bundle is 4 MB of minified, modern ES (import.meta, class
|
||||
fields, etc.) and the Python `esprima` port — being a 2018-era ES2017 parser —
|
||||
cannot parse it whole. We work around that by:
|
||||
|
||||
1. Scanning the source byte-by-byte (string/comment/regex-aware) for
|
||||
`function NAME(...)` and `async function NAME(...)` *declarations* and
|
||||
extracting each function's complete source via brace matching.
|
||||
2. Parsing each candidate's source — small, syntactically vanilla — with
|
||||
esprima individually.
|
||||
3. Applying structural matchers on the resulting AST.
|
||||
|
||||
This is enough for our needs (locate signing fn + rot fn), avoids depending on
|
||||
a separate Node.js parse step, and survives the bundle being re-minified or
|
||||
rearranged between releases. The structural patterns we match against are
|
||||
documented in `BUNDLE_NOTES.md` §2.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import esprima
|
||||
|
||||
from raycast_api.errors import DiscoveryError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable, Iterator
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionInfo:
|
||||
"""Located function declaration with its AST and source snippet.
|
||||
|
||||
`ast` is the dict form of esprima's FunctionDeclaration node — we use the
|
||||
plain-dict form (via `toDict()`) throughout so downstream matchers don't
|
||||
need to know esprima's object model.
|
||||
"""
|
||||
|
||||
name: str
|
||||
is_async: bool
|
||||
params: list[str]
|
||||
source: str
|
||||
"""Full function source, from `function`/`async function` keyword through
|
||||
the closing `}`."""
|
||||
|
||||
body_source: str
|
||||
"""Body source only, including the surrounding `{}`."""
|
||||
|
||||
ast: dict[str, Any]
|
||||
"""esprima dict for the FunctionDeclaration node."""
|
||||
|
||||
start: int
|
||||
"""Byte offset of the keyword in the original source."""
|
||||
|
||||
|
||||
|
||||
_DECL_RE = re.compile(
|
||||
r"(?:^|[^A-Za-z0-9_$])(async\s+)?function\s+([A-Za-z_$][\w$]*)\s*\("
|
||||
)
|
||||
|
||||
|
||||
def iter_function_declarations(source: str) -> Iterator[FunctionInfo]:
|
||||
"""Yield every `function NAME(...){...}` declaration found in `source`.
|
||||
|
||||
Each FunctionInfo is parsed individually with esprima; declarations whose
|
||||
bodies fail to parse are skipped (useful: the bundle contains some odd
|
||||
minifier output that's syntactically borderline — we don't care about those
|
||||
anyway because they're not our signing fn).
|
||||
"""
|
||||
for keyword_start, body_start in _iter_decl_positions(source):
|
||||
end = _find_matching_brace(source, body_start)
|
||||
if end == -1:
|
||||
continue
|
||||
snippet = source[keyword_start : end + 1]
|
||||
info = _try_parse_function(snippet, keyword_start)
|
||||
if info is not None:
|
||||
yield info
|
||||
|
||||
|
||||
def _iter_decl_positions(source: str) -> Iterator[tuple[int, int]]:
|
||||
"""Yield (keyword_start, body_open_brace) offsets for each function declaration."""
|
||||
for match in _DECL_RE.finditer(source):
|
||||
if match.group(1) is not None:
|
||||
kw_start = match.start(1)
|
||||
else:
|
||||
kw_start = source.rindex("function", match.start(), match.end())
|
||||
paren_end = _find_matching_paren(source, match.end() - 1)
|
||||
if paren_end == -1:
|
||||
continue
|
||||
i = paren_end + 1
|
||||
while i < len(source) and source[i] in " \t\r\n":
|
||||
i += 1
|
||||
if i >= len(source) or source[i] != "{":
|
||||
continue
|
||||
yield kw_start, i
|
||||
|
||||
|
||||
def _try_parse_function(snippet: str, abs_start: int) -> FunctionInfo | None:
|
||||
try:
|
||||
tree = esprima.parseScript(snippet, {"tolerant": False})
|
||||
except esprima.Error:
|
||||
return None
|
||||
if not tree.body or tree.body[0].type != "FunctionDeclaration":
|
||||
return None
|
||||
fn_node = tree.body[0]
|
||||
ast_dict = fn_node.toDict()
|
||||
params: list[str] = []
|
||||
for p in ast_dict.get("params", []):
|
||||
if p.get("type") == "Identifier":
|
||||
params.append(p["name"])
|
||||
else:
|
||||
params.append("")
|
||||
body_open = snippet.index("{")
|
||||
body_close = _find_matching_brace(snippet, body_open)
|
||||
body_source = snippet[body_open : body_close + 1] if body_close != -1 else ""
|
||||
return FunctionInfo(
|
||||
name=ast_dict.get("id", {}).get("name", ""),
|
||||
is_async=bool(ast_dict.get("async", False)),
|
||||
params=params,
|
||||
source=snippet,
|
||||
body_source=body_source,
|
||||
ast=ast_dict,
|
||||
start=abs_start,
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
def _find_matching_brace(source: str, open_pos: int) -> int:
|
||||
"""Return the index of the `}` that matches the `{` at `open_pos`, or -1.
|
||||
|
||||
Tracks string literals (`"`, `'`, template literals), block/line comments,
|
||||
and regex literals so curly braces inside them don't count.
|
||||
"""
|
||||
return _find_matching(source, open_pos, "{", "}")
|
||||
|
||||
|
||||
def _find_matching_paren(source: str, open_pos: int) -> int:
|
||||
return _find_matching(source, open_pos, "(", ")")
|
||||
|
||||
|
||||
def _find_matching(source: str, open_pos: int, open_ch: str, close_ch: str) -> int:
|
||||
if source[open_pos] != open_ch:
|
||||
msg = f"expected {open_ch!r} at position {open_pos}, got {source[open_pos]!r}"
|
||||
raise ValueError(msg)
|
||||
n = len(source)
|
||||
depth = 0
|
||||
i = open_pos
|
||||
prev_significant = ""
|
||||
while i < n:
|
||||
ch = source[i]
|
||||
|
||||
if ch in ("'", '"'):
|
||||
i = _skip_string(source, i, ch)
|
||||
prev_significant = ch
|
||||
continue
|
||||
|
||||
if ch == "`":
|
||||
i = _skip_template(source, i)
|
||||
prev_significant = "`"
|
||||
continue
|
||||
|
||||
if ch == "/" and i + 1 < n:
|
||||
nxt = source[i + 1]
|
||||
if nxt == "/":
|
||||
i = source.find("\n", i + 2)
|
||||
if i == -1:
|
||||
return -1
|
||||
i += 1
|
||||
continue
|
||||
if nxt == "*":
|
||||
end = source.find("*/", i + 2)
|
||||
if end == -1:
|
||||
return -1
|
||||
i = end + 2
|
||||
continue
|
||||
if _can_start_regex(prev_significant):
|
||||
i = _skip_regex(source, i)
|
||||
prev_significant = "/"
|
||||
continue
|
||||
|
||||
if ch == open_ch:
|
||||
depth += 1
|
||||
elif ch == close_ch:
|
||||
depth -= 1
|
||||
if depth == 0:
|
||||
return i
|
||||
|
||||
if not ch.isspace():
|
||||
prev_significant = ch
|
||||
i += 1
|
||||
return -1
|
||||
|
||||
|
||||
def _skip_string(source: str, i: int, quote: str) -> int:
|
||||
n = len(source)
|
||||
j = i + 1
|
||||
while j < n:
|
||||
c = source[j]
|
||||
if c == "\\":
|
||||
j += 2
|
||||
continue
|
||||
if c == quote:
|
||||
return j + 1
|
||||
if c == "\n":
|
||||
return j + 1
|
||||
j += 1
|
||||
return n
|
||||
|
||||
|
||||
def _skip_template(source: str, i: int) -> int:
|
||||
"""Skip a template literal `` `...${expr}...` `` starting at backtick `i`."""
|
||||
n = len(source)
|
||||
j = i + 1
|
||||
while j < n:
|
||||
c = source[j]
|
||||
if c == "\\":
|
||||
j += 2
|
||||
continue
|
||||
if c == "`":
|
||||
return j + 1
|
||||
if c == "$" and j + 1 < n and source[j + 1] == "{":
|
||||
close = _find_matching_brace(source, j + 1)
|
||||
if close == -1:
|
||||
return n
|
||||
j = close + 1
|
||||
continue
|
||||
j += 1
|
||||
return n
|
||||
|
||||
|
||||
def _skip_regex(source: str, i: int) -> int:
|
||||
"""Skip a `/regex/flags` literal starting at the `/` at index `i`."""
|
||||
n = len(source)
|
||||
j = i + 1
|
||||
in_class = False
|
||||
while j < n:
|
||||
c = source[j]
|
||||
if c == "\\":
|
||||
j += 2
|
||||
continue
|
||||
if c == "[":
|
||||
in_class = True
|
||||
elif c == "]":
|
||||
in_class = False
|
||||
elif c == "/" and not in_class:
|
||||
j += 1
|
||||
while j < n and source[j].isalpha():
|
||||
j += 1
|
||||
return j
|
||||
elif c == "\n":
|
||||
return j
|
||||
j += 1
|
||||
return n
|
||||
|
||||
|
||||
def _can_start_regex(prev: str) -> bool:
|
||||
if prev == "":
|
||||
return True
|
||||
return not (prev.isalnum() or prev in "_$)]")
|
||||
|
||||
|
||||
|
||||
|
||||
def find_function_by_shape(
|
||||
functions: list[FunctionInfo],
|
||||
*,
|
||||
is_async: bool | None = None,
|
||||
param_count: int | None = None,
|
||||
body_contains_all: list[str] | None = None,
|
||||
body_contains_any: list[str] | None = None,
|
||||
name_equals: str | None = None,
|
||||
custom: list[Callable[[FunctionInfo], bool]] | None = None,
|
||||
) -> list[FunctionInfo]:
|
||||
"""Filter `functions` to those matching all supplied predicates.
|
||||
|
||||
Substring matching on body source is intentional: minifiers rename
|
||||
identifiers but they preserve string literals, numeric literals, and
|
||||
standard library identifiers like `crypto.subtle`. The structural
|
||||
fingerprints in `BUNDLE_NOTES.md` are expressed in terms of these stable
|
||||
substrings.
|
||||
|
||||
For more specific structural checks (e.g. "this function calls X"),
|
||||
pass a `custom` predicate that walks `fn.ast`.
|
||||
"""
|
||||
results: list[FunctionInfo] = []
|
||||
for fn in functions:
|
||||
if is_async is not None and fn.is_async != is_async:
|
||||
continue
|
||||
if param_count is not None and len(fn.params) != param_count:
|
||||
continue
|
||||
if name_equals is not None and fn.name != name_equals:
|
||||
continue
|
||||
if body_contains_all and not all(
|
||||
s in fn.body_source for s in body_contains_all
|
||||
):
|
||||
continue
|
||||
if body_contains_any and not any(
|
||||
s in fn.body_source for s in body_contains_any
|
||||
):
|
||||
continue
|
||||
if custom and not all(p(fn) for p in custom):
|
||||
continue
|
||||
results.append(fn)
|
||||
return results
|
||||
|
||||
|
||||
def walk_ast(node: Any) -> Iterator[dict[str, Any]]:
|
||||
"""Depth-first iterator over every dict-shaped AST node."""
|
||||
if isinstance(node, dict):
|
||||
yield node
|
||||
for v in node.values():
|
||||
yield from walk_ast(v)
|
||||
elif isinstance(node, list):
|
||||
for v in node:
|
||||
yield from walk_ast(v)
|
||||
|
||||
|
||||
def find_calls(fn: FunctionInfo, callee_name: str) -> list[dict[str, Any]]:
|
||||
"""Return CallExpression nodes whose callee is the identifier `callee_name`.
|
||||
|
||||
Also matches member-expression callees like `obj.callee_name` — uses the
|
||||
`property.name` in that case. This is liberal on purpose: minifier-output
|
||||
sometimes has the callee referenced via short helper variables, and we'd
|
||||
rather over-match here and let the outer predicate verify.
|
||||
"""
|
||||
out: list[dict[str, Any]] = []
|
||||
for node in walk_ast(fn.ast):
|
||||
if not (isinstance(node, dict) and node.get("type") == "CallExpression"):
|
||||
continue
|
||||
callee = node.get("callee", {})
|
||||
if callee.get("type") == "Identifier" and callee.get("name") == callee_name:
|
||||
out.append(node)
|
||||
elif callee.get("type") == "MemberExpression":
|
||||
prop = callee.get("property", {})
|
||||
if prop.get("type") == "Identifier" and prop.get("name") == callee_name:
|
||||
out.append(node)
|
||||
return out
|
||||
|
||||
|
||||
def has_string_literal(fn: FunctionInfo, value: str) -> bool:
|
||||
"""True if a string literal with the given value appears in the function body."""
|
||||
for node in walk_ast(fn.ast):
|
||||
if (
|
||||
isinstance(node, dict)
|
||||
and node.get("type") == "Literal"
|
||||
and node.get("value") == value
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def collect_numeric_literals(fn: FunctionInfo) -> set[int]:
|
||||
"""Return the set of integer literal values used inside the function."""
|
||||
out: set[int] = set()
|
||||
for node in walk_ast(fn.ast):
|
||||
if not (isinstance(node, dict) and node.get("type") == "Literal"):
|
||||
continue
|
||||
v = node.get("value")
|
||||
if isinstance(v, (int, float)) and float(v).is_integer():
|
||||
out.add(int(v))
|
||||
return out
|
||||
|
||||
|
||||
def assert_one(matches: list[FunctionInfo], what: str) -> FunctionInfo:
|
||||
"""Helper that turns ambiguous match results into clear DiscoveryError messages."""
|
||||
if not matches:
|
||||
msg = f"No function matched: {what}"
|
||||
raise DiscoveryError(msg)
|
||||
if len(matches) > 1:
|
||||
names = ", ".join(f.name for f in matches[:5])
|
||||
raise DiscoveryError(
|
||||
f"Multiple functions matched {what}: {names}"
|
||||
+ (" ..." if len(matches) > 5 else "")
|
||||
)
|
||||
return matches[0]
|
||||
@@ -0,0 +1,91 @@
|
||||
"""Extract `window.signatureSecret = '<64 hex>'` from the Raycast launcher binary.
|
||||
|
||||
The Swift launcher hardcodes the per-build signing secret as a string literal so it
|
||||
can inject it into the WebView via `window.signatureSecret = '...'`. The same value
|
||||
is also forwarded across UniFFI into the Rust dylib as the `Secrets` record and
|
||||
appears verbatim in the binary's `__cstring` section. Scanning the raw bytes for
|
||||
the literal pattern is sufficient — no Mach-O parsing required.
|
||||
|
||||
Pattern observed in Raycast Beta 0.60.1.0:
|
||||
|
||||
window.signatureSecret = '6bc4...1408'
|
||||
|
||||
The secret is always 64 lowercase hex characters (32 bytes encoded as hex). Per
|
||||
HANDOFF.md the key is used AS-IS (utf-8 of the 64 ASCII chars), not hex-decoded.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import re
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from raycast_api.errors import DiscoveryError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
_SECRET_RE = re.compile(rb"window\.signatureSecret\s*=\s*(['\"])([0-9a-f]{64})\1")
|
||||
|
||||
|
||||
def find_signature_secret(app_path: Path) -> str:
|
||||
"""Locate and return the signature secret from the app's launcher binary.
|
||||
|
||||
`app_path` is the path to the .app bundle (e.g. "Raycast Beta.app"). The
|
||||
launcher binary lives at `Contents/MacOS/<basename minus .app>`.
|
||||
|
||||
Raises `DiscoveryError` if the binary isn't found or the pattern doesn't
|
||||
match — both indicate that Raycast changed the injection mechanism.
|
||||
"""
|
||||
binary = _resolve_launcher_binary(app_path)
|
||||
data = binary.read_bytes()
|
||||
match = _SECRET_RE.search(data)
|
||||
if not match:
|
||||
msg = (
|
||||
f"Could not find `window.signatureSecret = '<hex>'` in {binary}. "
|
||||
"Raycast may have changed how the secret is embedded."
|
||||
)
|
||||
raise DiscoveryError(
|
||||
msg
|
||||
)
|
||||
return match.group(2).decode("ascii")
|
||||
|
||||
|
||||
def launcher_hash(app_path: Path) -> str:
|
||||
"""SHA-256 (hex) of the launcher binary.
|
||||
|
||||
The launcher carries the signing secret, which can rotate per release
|
||||
without necessarily forcing a JS bundle rebuild. We include this in the
|
||||
discovery cache key so a secret-only rotation invalidates the cache too.
|
||||
"""
|
||||
binary = _resolve_launcher_binary(app_path)
|
||||
h = hashlib.sha256()
|
||||
with binary.open("rb") as f:
|
||||
for chunk in iter(lambda: f.read(1 << 20), b""):
|
||||
h.update(chunk)
|
||||
return h.hexdigest()
|
||||
|
||||
|
||||
def _resolve_launcher_binary(app_path: Path) -> Path:
|
||||
"""Return the path to the Mach-O launcher inside an .app bundle."""
|
||||
macos_dir = app_path / "Contents" / "MacOS"
|
||||
if not macos_dir.is_dir():
|
||||
msg = f"Not an app bundle (missing {macos_dir}): {app_path}"
|
||||
raise DiscoveryError(msg)
|
||||
|
||||
expected_name = app_path.name.removesuffix(".app")
|
||||
candidate = macos_dir / expected_name
|
||||
if candidate.is_file():
|
||||
return candidate
|
||||
|
||||
children = [p for p in macos_dir.iterdir() if p.is_file()]
|
||||
if len(children) == 1:
|
||||
return children[0]
|
||||
|
||||
msg = (
|
||||
f"Could not identify launcher binary in {macos_dir}; "
|
||||
f"expected {expected_name!r}, found {[p.name for p in children]}"
|
||||
)
|
||||
raise DiscoveryError(
|
||||
msg
|
||||
)
|
||||
@@ -0,0 +1,82 @@
|
||||
"""Locate the Node backend bundle inside a Raycast .app and read its source.
|
||||
|
||||
Bundle layout observed in Raycast Beta 0.60.1.0:
|
||||
|
||||
Raycast Beta.app/Contents/Resources/macos-app_RaycastDesktopApp.bundle/
|
||||
Contents/Resources/backend/index.mjs
|
||||
|
||||
The minified bundle is ~4 MB and contains the signing function we want to AST-match
|
||||
against. We hash it for cache invalidation and read its source straight from disk —
|
||||
beautification is not required for esprima (it doesn't care about whitespace), so
|
||||
we skip the prettier/js-beautify step that Phase 1 used for human reading.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
|
||||
from raycast_api.errors import DiscoveryError
|
||||
|
||||
_DESKTOP_BUNDLE_NAME = "macos-app_RaycastDesktopApp.bundle"
|
||||
_BACKEND_INDEX_RELATIVE = Path("Contents/Resources/backend/index.mjs")
|
||||
|
||||
|
||||
def locate_node_bundle(app_path: Path) -> Path:
|
||||
"""Return the path to the embedded RaycastDesktopApp sub-bundle."""
|
||||
resources = app_path / "Contents" / "Resources"
|
||||
direct = resources / _DESKTOP_BUNDLE_NAME
|
||||
if direct.is_dir():
|
||||
return direct
|
||||
|
||||
matches = list(resources.glob("*RaycastDesktopApp*.bundle"))
|
||||
if len(matches) == 1:
|
||||
return matches[0]
|
||||
if not matches:
|
||||
msg = f"No RaycastDesktopApp bundle under {resources}"
|
||||
raise DiscoveryError(msg)
|
||||
names = [m.name for m in matches]
|
||||
msg = f"Multiple RaycastDesktopApp bundles under {resources}: {names}"
|
||||
raise DiscoveryError(
|
||||
msg
|
||||
)
|
||||
|
||||
|
||||
def find_index_mjs(bundle_or_app_path: Path) -> Path:
|
||||
"""Return `backend/index.mjs` given either an app bundle or a sub-bundle.
|
||||
|
||||
Accepting both lets callers pass `Raycast Beta.app` directly or a
|
||||
pre-located desktop bundle (e.g. in tests).
|
||||
"""
|
||||
candidates = [
|
||||
bundle_or_app_path / _BACKEND_INDEX_RELATIVE,
|
||||
locate_node_bundle(bundle_or_app_path) / _BACKEND_INDEX_RELATIVE
|
||||
if (bundle_or_app_path / "Contents" / "Resources").is_dir()
|
||||
else None,
|
||||
]
|
||||
for c in candidates:
|
||||
if c and c.is_file():
|
||||
return c
|
||||
msg = f"Could not find backend/index.mjs starting from {bundle_or_app_path}"
|
||||
raise DiscoveryError(
|
||||
msg
|
||||
)
|
||||
|
||||
|
||||
def read_bundle_source(index_mjs_path: Path) -> str:
|
||||
"""Read index.mjs and return it as a UTF-8 string."""
|
||||
return index_mjs_path.read_text(encoding="utf-8")
|
||||
|
||||
|
||||
def bundle_hash(index_mjs_path: Path) -> str:
|
||||
"""SHA-256 (hex) of the bundle file, used as a cache key.
|
||||
|
||||
Any change to the bundle — new Raycast version, hotfix, secret rotation
|
||||
(the secret itself isn't in this file, but the bundle is rebuilt around it)
|
||||
— produces a different hash and invalidates the cached config.
|
||||
"""
|
||||
h = hashlib.sha256()
|
||||
with index_mjs_path.open("rb") as f:
|
||||
for chunk in iter(lambda: f.read(1 << 20), b""):
|
||||
h.update(chunk)
|
||||
return h.hexdigest()
|
||||
@@ -0,0 +1,84 @@
|
||||
"""On-disk cache of discovered configs, keyed by the bundle's SHA-256.
|
||||
|
||||
Discovery is fast (~2 s on a 6 MB bundle) but not free, and we'd rather not
|
||||
repeat the work on every CLI invocation. We key the cache on the bundle hash
|
||||
so a Raycast update — which always rebuilds the bundle — invalidates
|
||||
automatically.
|
||||
|
||||
Cache files live under `$XDG_CACHE_HOME/raycast-api/` (defaulting to
|
||||
`~/.cache/raycast-api/`), each named `<bundle_sha256>.json`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from raycast_api.config import Config
|
||||
|
||||
|
||||
def default_cache_dir() -> Path:
|
||||
"""Return the cache directory, honoring XDG_CACHE_HOME if set."""
|
||||
xdg = os.environ.get("XDG_CACHE_HOME")
|
||||
base = Path(xdg) if xdg else Path.home() / ".cache"
|
||||
return base / "raycast-api"
|
||||
|
||||
|
||||
class DiscoveryCache:
|
||||
"""File-hash-keyed cache for `Config` blobs."""
|
||||
|
||||
def __init__(self, root: Path | None = None) -> None:
|
||||
self.root = root or default_cache_dir()
|
||||
|
||||
def path_for(self, bundle_hash: str) -> Path:
|
||||
return self.root / f"{bundle_hash}.json"
|
||||
|
||||
def get(self, bundle_hash: str) -> Config | None:
|
||||
from raycast_api.config import Config
|
||||
|
||||
path = self.path_for(bundle_hash)
|
||||
if not path.is_file():
|
||||
return None
|
||||
try:
|
||||
return Config.load(path)
|
||||
except (OSError, ValueError, KeyError, TypeError, json.JSONDecodeError):
|
||||
return None
|
||||
|
||||
def set(self, bundle_hash: str, config: Config) -> None:
|
||||
self.root.mkdir(parents=True, exist_ok=True)
|
||||
config.save(self.path_for(bundle_hash))
|
||||
|
||||
def clear(self) -> None:
|
||||
if not self.root.is_dir():
|
||||
return
|
||||
for p in self.root.glob("*.json"):
|
||||
p.unlink()
|
||||
|
||||
def _all(self) -> list[Path]:
|
||||
if not self.root.is_dir():
|
||||
return []
|
||||
return sorted(self.root.glob("*.json"))
|
||||
|
||||
def __repr__(self) -> str: # pragma: no cover — debug aid
|
||||
return f"DiscoveryCache(root={self.root!r})"
|
||||
|
||||
|
||||
__all__ = ["DiscoveryCache", "default_cache_dir"]
|
||||
|
||||
|
||||
def write_json(path: Path, payload: dict[str, object]) -> None:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
tmp = path.with_suffix(path.suffix + ".tmp")
|
||||
tmp.write_text(json.dumps(payload, indent=2, sort_keys=False), encoding="utf-8")
|
||||
tmp.replace(path)
|
||||
|
||||
|
||||
def read_json(path: Path) -> dict[str, object]:
|
||||
data = json.loads(path.read_text(encoding="utf-8"))
|
||||
if not isinstance(data, dict):
|
||||
msg = f"Expected JSON object at {path}, got {type(data).__name__}"
|
||||
raise TypeError(msg)
|
||||
return data
|
||||
@@ -0,0 +1,302 @@
|
||||
"""High-level extractors that turn a parsed bundle into a `SigningSpec`.
|
||||
|
||||
The structural matchers in `ast_parse` give us candidate functions; the
|
||||
extractors here verify those candidates against the documented shape of
|
||||
`Sur`/`Nkt` (rot + HMAC signer) and tease out the few parameters that aren't
|
||||
visually identical between minified builds:
|
||||
|
||||
- rot ranges and shifts (might rotate keys differently in future builds)
|
||||
- canonical-string join character (currently ".")
|
||||
- body-hash algorithm (currently "SHA-256")
|
||||
- HMAC hash (currently "SHA-256")
|
||||
- key encoding (currently utf-8 of the hex string AS-IS, per HANDOFF.md)
|
||||
|
||||
This produces a `SigningSpec` that the runtime signer consumes — `sign.py`'s
|
||||
constants become `SigningSpec` fields, so the whole signing pipeline is
|
||||
data-driven and re-derives on every Raycast update.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import plistlib
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from raycast_api.discovery.ast_parse import (
|
||||
FunctionInfo,
|
||||
find_calls,
|
||||
find_function_by_shape,
|
||||
has_string_literal,
|
||||
iter_function_declarations,
|
||||
walk_ast,
|
||||
)
|
||||
from raycast_api.errors import DiscoveryError
|
||||
from raycast_api.signing_spec import RotRange, SigningSpec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
__all__ = ["extract_signing_spec", "extract_user_agent_template", "read_app_version"]
|
||||
|
||||
|
||||
def extract_signing_spec(bundle_source: str) -> SigningSpec:
|
||||
"""Find the signing primitives in the bundle and return a `SigningSpec`.
|
||||
|
||||
Strategy (in order):
|
||||
|
||||
1. Enumerate all top-level function declarations.
|
||||
2. Find the rot function by the literal triplet `(65, 90, 13)`,
|
||||
`(97, 122, 13)`, `(48, 57, 5)` co-located inside one 1-param fn.
|
||||
Minifier renames don't touch numeric literals; these five constants
|
||||
(plus 26 and 10) uniquely identify rot13+rot5.
|
||||
3. Find the signing function: an async 4-param fn that imports an HMAC key
|
||||
and calls .map(<rotFnName>) on a 3-element array. The verifier here is
|
||||
strict: we require it to mention the rot fn we just found by name, so
|
||||
we can't accidentally pick up an unrelated HMAC routine.
|
||||
4. Read the join character from the signing fn's `.join(...)` call and
|
||||
the digest/HMAC algorithm strings from the crypto.subtle calls.
|
||||
|
||||
`bundle_source` is the raw JS text — we don't need a pre-beautified copy.
|
||||
"""
|
||||
fns = list(iter_function_declarations(bundle_source))
|
||||
if not fns:
|
||||
msg = "No function declarations found in bundle source"
|
||||
raise DiscoveryError(msg)
|
||||
|
||||
rot, signing = _find_rot_and_signing(fns)
|
||||
join_char = _extract_join_char(signing, rot_fn_name=rot.name)
|
||||
digest_algo = _extract_digest_algo(signing)
|
||||
hmac_algo = _extract_hmac_algo(signing)
|
||||
|
||||
rot_ranges = _extract_rot_ranges(rot)
|
||||
|
||||
return SigningSpec(
|
||||
rot_fn_name=rot.name,
|
||||
signing_fn_name=signing.name,
|
||||
rot_ranges=rot_ranges,
|
||||
join_char=join_char,
|
||||
body_hash_algorithm=digest_algo,
|
||||
hmac_algorithm=hmac_algo,
|
||||
key_encoding="utf-8",
|
||||
output_encoding="hex-lower",
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
def _find_rot_and_signing(fns: list[FunctionInfo]) -> tuple[FunctionInfo, FunctionInfo]:
|
||||
"""Find a (rot, signing) pair where signing calls .map(rot.name).
|
||||
|
||||
Several rot candidates may exist (the bundle has two byte-identical copies,
|
||||
`Sur` and `Tur`). Several signing candidates similarly (`Nkt` is unique by
|
||||
the 4-param shape but if a future build splits the canonical path, we want
|
||||
to handle the ambiguity). We resolve by requiring the rot fn referenced by
|
||||
the signing fn's `.map(...)` call to be among the rot candidates.
|
||||
"""
|
||||
rot_candidates = find_function_by_shape(
|
||||
fns, param_count=1, custom=[_has_required_rot_triplets]
|
||||
)
|
||||
if not rot_candidates:
|
||||
msg = (
|
||||
"No rot13+rot5 candidate "
|
||||
"(1 param, all of (65,90,13)/(97,122,13)/(48,57,5))"
|
||||
)
|
||||
raise DiscoveryError(msg)
|
||||
rot_by_name = {f.name: f for f in rot_candidates}
|
||||
|
||||
sign_candidates = find_function_by_shape(
|
||||
fns,
|
||||
is_async=True,
|
||||
param_count=4,
|
||||
body_contains_all=["HMAC", "SHA-256", "importKey"],
|
||||
custom=[
|
||||
lambda f: has_string_literal(f, "HMAC"),
|
||||
lambda f: has_string_literal(f, "SHA-256"),
|
||||
],
|
||||
)
|
||||
if not sign_candidates:
|
||||
msg = "No signing candidate (async, 4 params, HMAC+SHA-256+importKey)"
|
||||
raise DiscoveryError(
|
||||
msg
|
||||
)
|
||||
|
||||
pairs: list[tuple[FunctionInfo, FunctionInfo]] = []
|
||||
for sign in sign_candidates:
|
||||
for name in _map_argument_identifiers(sign):
|
||||
if name in rot_by_name:
|
||||
pairs.append((rot_by_name[name], sign))
|
||||
break
|
||||
if not pairs:
|
||||
msg = (
|
||||
f"Found rot candidates {list(rot_by_name)} and signing candidates "
|
||||
f"{[s.name for s in sign_candidates]} but none of the signers calls "
|
||||
f".map(<rotName>)"
|
||||
)
|
||||
raise DiscoveryError(
|
||||
msg
|
||||
)
|
||||
return pairs[0]
|
||||
|
||||
|
||||
def _map_argument_identifiers(fn: FunctionInfo) -> list[str]:
|
||||
"""Return identifier names passed to any `.map(...)` call inside fn."""
|
||||
names: list[str] = []
|
||||
for node in find_calls(fn, "map"):
|
||||
args = node.get("arguments", [])
|
||||
if (
|
||||
len(args) == 1
|
||||
and isinstance(args[0], dict)
|
||||
and args[0].get("type") == "Identifier"
|
||||
):
|
||||
names.append(args[0].get("name", ""))
|
||||
return names
|
||||
|
||||
|
||||
def _has_required_rot_triplets(fn: FunctionInfo) -> bool:
|
||||
"""True iff the fn contains all three (start, end, shift) triplets as numerics.
|
||||
|
||||
We don't try to parse the *structure* of the conditional chain — too many
|
||||
valid shapes (if/else vs ternary vs switch). The numeric fingerprint is
|
||||
enough on its own; the 1-param shape filter prevents false positives from
|
||||
unrelated maths.
|
||||
"""
|
||||
nums = _collect_numeric_literals(fn)
|
||||
needed = {65, 90, 13, 26, 97, 122, 48, 57, 5, 10}
|
||||
return needed.issubset(nums)
|
||||
|
||||
|
||||
def _collect_numeric_literals(fn: FunctionInfo) -> set[int]:
|
||||
out: set[int] = set()
|
||||
for node in walk_ast(fn.ast):
|
||||
if not (isinstance(node, dict) and node.get("type") == "Literal"):
|
||||
continue
|
||||
v = node.get("value")
|
||||
if isinstance(v, (int, float)) and float(v).is_integer():
|
||||
out.add(int(v))
|
||||
return out
|
||||
|
||||
|
||||
def _extract_rot_ranges(rot: FunctionInfo) -> list[RotRange]: # noqa: ARG001 — kept for future structural derivation
|
||||
"""Return the rot transform parameters as a list of (start, end, shift) ranges.
|
||||
|
||||
For now we hardcode the three triplets we matched against — the structural
|
||||
matcher already confirmed they're present. If future builds add/remove a
|
||||
range, this is the place to teach the extractor to walk the conditional chain
|
||||
and discover them dynamically.
|
||||
"""
|
||||
return [
|
||||
RotRange(start=65, end=90, shift=13),
|
||||
RotRange(start=97, end=122, shift=13),
|
||||
RotRange(start=48, end=57, shift=5),
|
||||
]
|
||||
|
||||
|
||||
|
||||
|
||||
def _extract_join_char(fn: FunctionInfo, rot_fn_name: str) -> str:
|
||||
"""Find the `.join("X")` whose receiver is `<arr>.map(<rotName>)`.
|
||||
|
||||
The signing fn body has several `.map(...).join(...)` chains — the hex
|
||||
encoder uses `.join("")`, the canonical-string builder uses `.join(".")`.
|
||||
We pick the one whose `.map`'s sole argument is the rot fn identifier.
|
||||
"""
|
||||
for call in find_calls(fn, "join"):
|
||||
callee = call.get("callee", {})
|
||||
if callee.get("type") != "MemberExpression":
|
||||
continue
|
||||
receiver = callee.get("object", {})
|
||||
if receiver.get("type") != "CallExpression":
|
||||
continue
|
||||
r_callee = receiver.get("callee", {})
|
||||
if not (
|
||||
r_callee.get("type") == "MemberExpression"
|
||||
and r_callee.get("property", {}).get("name") == "map"
|
||||
):
|
||||
continue
|
||||
r_args = receiver.get("arguments", [])
|
||||
if not (
|
||||
len(r_args) == 1
|
||||
and r_args[0].get("type") == "Identifier"
|
||||
and r_args[0].get("name") == rot_fn_name
|
||||
):
|
||||
continue
|
||||
args = call.get("arguments", [])
|
||||
if args and args[0].get("type") == "Literal":
|
||||
val = args[0].get("value")
|
||||
if isinstance(val, str):
|
||||
return val
|
||||
msg = f"Could not find `.map({rot_fn_name}).join(<str>)` in `{fn.name}`"
|
||||
raise DiscoveryError(
|
||||
msg
|
||||
)
|
||||
|
||||
|
||||
def _extract_digest_algo(fn: FunctionInfo) -> str:
|
||||
"""Read the algorithm name from `crypto.subtle.digest("SHA-256", ...)`."""
|
||||
for call in find_calls(fn, "digest"):
|
||||
args = call.get("arguments", [])
|
||||
if args and args[0].get("type") == "Literal":
|
||||
v = args[0].get("value")
|
||||
if isinstance(v, str):
|
||||
return v
|
||||
msg = f"No crypto.subtle.digest(...) call in `{fn.name}`"
|
||||
raise DiscoveryError(msg)
|
||||
|
||||
|
||||
def _extract_hmac_algo(fn: FunctionInfo) -> str:
|
||||
"""Read the `hash:"SHA-256"` from the HMAC importKey options.
|
||||
|
||||
Looks for an object literal containing { name: "HMAC", hash: "<algo>" }
|
||||
inside the signing fn. That's the importKey args[2] but we don't rely on
|
||||
position — we walk all ObjectExpressions and find the matching shape.
|
||||
"""
|
||||
for node in walk_ast(fn.ast):
|
||||
if not (isinstance(node, dict) and node.get("type") == "ObjectExpression"):
|
||||
continue
|
||||
props: dict[str, Any] = {}
|
||||
for prop in node.get("properties", []):
|
||||
key = prop.get("key", {})
|
||||
value = prop.get("value", {})
|
||||
if key.get("type") == "Identifier" and value.get("type") == "Literal":
|
||||
props[key["name"]] = value.get("value")
|
||||
if props.get("name") == "HMAC":
|
||||
hash_val = props.get("hash")
|
||||
if isinstance(hash_val, str):
|
||||
return hash_val
|
||||
msg = f"No {{name:'HMAC', hash:'...'}} object in `{fn.name}`"
|
||||
raise DiscoveryError(msg)
|
||||
|
||||
|
||||
|
||||
|
||||
def read_app_version(app_path: Path) -> str:
|
||||
"""Return `CFBundleShortVersionString` from the app's Info.plist."""
|
||||
plist_path = app_path / "Contents" / "Info.plist"
|
||||
if not plist_path.is_file():
|
||||
msg = f"Missing Info.plist at {plist_path}"
|
||||
raise DiscoveryError(msg)
|
||||
with plist_path.open("rb") as f:
|
||||
plist = plistlib.load(f)
|
||||
version = plist.get("CFBundleShortVersionString")
|
||||
if not isinstance(version, str):
|
||||
msg = f"No CFBundleShortVersionString in {plist_path}"
|
||||
raise DiscoveryError(msg)
|
||||
return version
|
||||
|
||||
|
||||
def extract_user_agent_template(
|
||||
app_path: Path, *, platform: str = "macOS", platform_version: str | None = None
|
||||
) -> str:
|
||||
"""Build the `User-Agent` header Raycast sends.
|
||||
|
||||
Template (BUNDLE_NOTES §6): `Raycast/<version> (x-<platform> Version <ver>)`.
|
||||
We default platform to "macOS" because the bundle is macOS-only; future
|
||||
Windows builds would need this hooked up to a platform argument.
|
||||
`platform_version` defaults to the host's macOS version, looked up at call
|
||||
time so a config written on one machine still serializes the host string.
|
||||
"""
|
||||
import platform as platform_mod
|
||||
|
||||
version = read_app_version(app_path)
|
||||
if platform_version is None:
|
||||
platform_version = platform_mod.mac_ver()[0] or "26.0"
|
||||
return f"Raycast/{version} (x-{platform} Version {platform_version})"
|
||||
@@ -0,0 +1,82 @@
|
||||
"""Exception types for raycast_api."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
class RaycastApiError(Exception):
|
||||
"""Base for everything raycast_api raises."""
|
||||
|
||||
|
||||
class DiscoveryError(RaycastApiError):
|
||||
"""Failed to derive config from a local Raycast install.
|
||||
|
||||
Raised when binary parsing, bundle location, or AST extraction can't find
|
||||
what they expect — typically meaning Raycast changed its layout or signing
|
||||
shape and the library needs updating.
|
||||
"""
|
||||
|
||||
|
||||
class ConfigError(RaycastApiError):
|
||||
"""Loaded config is missing required fields or is internally inconsistent."""
|
||||
|
||||
|
||||
class TransportError(RaycastApiError):
|
||||
"""Network-level failure (DNS, connect, read timeout, dropped socket).
|
||||
|
||||
Wraps the underlying aiohttp / OSError so callers don't need to import
|
||||
aiohttp's exception hierarchy to catch them. The original exception is
|
||||
chained via `__cause__`.
|
||||
"""
|
||||
|
||||
|
||||
class HTTPStatusError(RaycastApiError):
|
||||
"""Server responded with a non-2xx status.
|
||||
|
||||
`body` is the response payload decoded as utf-8 with replacement;
|
||||
`retry_after` is the parsed `Retry-After` header in seconds if the server
|
||||
sent one (either as an integer or HTTP-date), otherwise None.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
status: int,
|
||||
message: str,
|
||||
*,
|
||||
body: str = "",
|
||||
retry_after: float | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
) -> None:
|
||||
super().__init__(f"HTTP {status}: {message}" if message else f"HTTP {status}")
|
||||
self.status = status
|
||||
self.message = message
|
||||
self.body = body
|
||||
self.retry_after = retry_after
|
||||
self.headers = headers or {}
|
||||
|
||||
|
||||
class AuthError(HTTPStatusError):
|
||||
"""401 — bearer token rejected or signature invalid."""
|
||||
|
||||
|
||||
class RateLimitError(HTTPStatusError):
|
||||
"""429 — rate limited. `retry_after` is set when the server provides it."""
|
||||
|
||||
|
||||
class StreamError(RaycastApiError):
|
||||
"""SSE stream emitted an `event: error` chunk.
|
||||
|
||||
`payload` is the parsed JSON body of the error event (or None if it
|
||||
failed to parse); `raw` is the raw `data:` string for debugging.
|
||||
"""
|
||||
|
||||
def __init__(self, payload: Any, *, raw: str = "") -> None:
|
||||
msg = ""
|
||||
if isinstance(payload, dict):
|
||||
msg = str(payload.get("message") or payload.get("error") or payload)
|
||||
elif payload is not None:
|
||||
msg = str(payload)
|
||||
super().__init__(msg or "stream error")
|
||||
self.payload = payload
|
||||
self.raw = raw
|
||||
@@ -0,0 +1,79 @@
|
||||
"""Request signer composing the discovered spec with a secret.
|
||||
|
||||
Usage:
|
||||
|
||||
>>> from raycast_api.signing import Signer
|
||||
>>> signer = Signer(spec=config.signing_spec, secret=config.signature_secret)
|
||||
>>> sig = signer.sign(timestamp="1778858809", device_id="20ec…", body=b"{...}")
|
||||
|
||||
The `Signer` is cheap to instantiate (it caches the encoded key + hash factory
|
||||
on construction) and is safe to reuse for many requests against the same
|
||||
secret; create a new one if the secret rotates.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from raycast_api.signing.canonical import build_canonical
|
||||
from raycast_api.signing.hmac import HMACSigner, encode_key, encode_output, hash_body
|
||||
from raycast_api.signing.transforms import apply_rot
|
||||
from raycast_api.signing_spec import RotRange, SigningSpec
|
||||
|
||||
__all__ = [
|
||||
"HMACSigner",
|
||||
"RotRange",
|
||||
"Signer",
|
||||
"SigningSpec",
|
||||
"apply_rot",
|
||||
"build_canonical",
|
||||
"encode_key",
|
||||
"encode_output",
|
||||
"hash_body",
|
||||
]
|
||||
|
||||
|
||||
class Signer:
|
||||
"""Produce `X-Raycast-Signature-v2` values for outgoing requests.
|
||||
|
||||
Constructed from a `SigningSpec` (discovered once per bundle) and a secret
|
||||
(discovered once per launcher build). Holds no per-request state; the same
|
||||
instance can sign any number of requests in parallel.
|
||||
"""
|
||||
|
||||
def __init__(self, *, spec: SigningSpec, secret: str) -> None:
|
||||
if not spec.rot_ranges:
|
||||
msg = "SigningSpec has no rot_ranges; discovery did not populate them"
|
||||
raise ValueError(
|
||||
msg
|
||||
)
|
||||
self._spec = spec
|
||||
self._hmac = HMACSigner(
|
||||
secret,
|
||||
algorithm=spec.hmac_algorithm,
|
||||
key_encoding=spec.key_encoding,
|
||||
output_encoding=spec.output_encoding,
|
||||
)
|
||||
|
||||
@property
|
||||
def spec(self) -> SigningSpec:
|
||||
return self._spec
|
||||
|
||||
def canonical_string(self, timestamp: str, device_id: str, body: bytes) -> str:
|
||||
"""Return the rot-transformed, joined canonical string (debug helper)."""
|
||||
body_hex = hash_body(body, self._spec.body_hash_algorithm)
|
||||
return build_canonical(
|
||||
(timestamp, device_id, body_hex),
|
||||
self._spec.rot_ranges,
|
||||
self._spec.join_char,
|
||||
)
|
||||
|
||||
def sign(self, *, timestamp: str, device_id: str, body: bytes) -> str:
|
||||
"""Compute the signature header value.
|
||||
|
||||
- `timestamp`: decimal seconds since epoch as a string. Use the SAME
|
||||
string in `X-Raycast-Timestamp` (not a re-stringified int) — the
|
||||
server hashes the byte representation, not the value.
|
||||
- `device_id`: lowercase 64-char hex string.
|
||||
- `body`: exact request body bytes. For GET resume requests pass `b""`.
|
||||
"""
|
||||
canonical = self.canonical_string(timestamp, device_id, body)
|
||||
return self._hmac.sign(canonical.encode("utf-8"))
|
||||
@@ -0,0 +1,29 @@
|
||||
"""Canonical-string assembly.
|
||||
|
||||
Raycast's signing canonical string is, per `BUNDLE_NOTES.md`:
|
||||
|
||||
rot(timestamp) + "." + rot(device_id) + "." + rot(sha256_hex_lower(body))
|
||||
|
||||
The components are passed in as strings (already body-hashed), each is rot-
|
||||
transformed, and they're joined with `join_char`. Both the rot ranges and the
|
||||
join character come from `SigningSpec` — discovery, not hardcoded.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from raycast_api.signing.transforms import apply_rot
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable, Sequence
|
||||
|
||||
from raycast_api.signing_spec import RotRange
|
||||
|
||||
|
||||
def build_canonical(
|
||||
components: Sequence[str], rot_ranges: Iterable[RotRange], join_char: str = "."
|
||||
) -> str:
|
||||
"""Apply `rot_ranges` to each component then join with `join_char`."""
|
||||
snapshot = tuple(rot_ranges)
|
||||
return join_char.join(apply_rot(c, snapshot) for c in components)
|
||||
@@ -0,0 +1,89 @@
|
||||
"""HMAC signing primitive driven by `SigningSpec`.
|
||||
|
||||
The JS implementation calls `crypto.subtle.importKey` with `{name:"HMAC",
|
||||
hash:"SHA-256"}` and `crypto.subtle.sign("HMAC", key, msg)`, then hex-encodes
|
||||
the bytes lowercase via `.padStart(2, "0")`. We mirror that here, parametrised
|
||||
on algorithm name, key encoding, and output encoding — all populated from the
|
||||
discovered spec so a future Raycast change only needs a discovery-side update.
|
||||
|
||||
Algorithm names follow the WebCrypto convention ("SHA-256", "SHA-1", …) since
|
||||
that's what the JS bundle exposes; we normalise to hashlib's casing.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac as _hmac
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
|
||||
def _hash_factory(name: str) -> Callable[[], hashlib._Hash]:
|
||||
"""Resolve a WebCrypto-style hash name (`SHA-256`) to a hashlib factory.
|
||||
|
||||
Strips dashes and lowercases ("SHA-256" → "sha256"). Raises `ValueError`
|
||||
if the name isn't supported by the local OpenSSL build.
|
||||
"""
|
||||
normalised = name.replace("-", "").lower()
|
||||
if normalised not in hashlib.algorithms_available:
|
||||
msg = f"Unsupported hash algorithm: {name!r}"
|
||||
raise ValueError(msg)
|
||||
return lambda: hashlib.new(normalised)
|
||||
|
||||
|
||||
def encode_key(secret: str, encoding: str) -> bytes:
|
||||
"""Turn the secret string into key bytes per `SigningSpec.key_encoding`.
|
||||
|
||||
Real Raycast: `"utf-8"` — the 64-char ASCII hex secret is the key bytes
|
||||
AS-IS, NOT hex-decoded. Hex-decoding silently breaks signing (HANDOFF.md
|
||||
"Things that will bite the next person", #1).
|
||||
"""
|
||||
if encoding == "utf-8":
|
||||
return secret.encode("utf-8")
|
||||
if encoding == "ascii":
|
||||
return secret.encode("ascii")
|
||||
if encoding == "hex":
|
||||
return bytes.fromhex(secret)
|
||||
if encoding == "base64":
|
||||
return base64.b64decode(secret)
|
||||
msg = f"Unsupported key encoding: {encoding!r}"
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
def encode_output(digest: bytes, encoding: str) -> str:
|
||||
"""Encode the raw HMAC bytes per `SigningSpec.output_encoding`."""
|
||||
if encoding == "hex-lower":
|
||||
return digest.hex()
|
||||
if encoding == "hex-upper":
|
||||
return digest.hex().upper()
|
||||
if encoding == "base64":
|
||||
return base64.b64encode(digest).decode("ascii")
|
||||
if encoding == "base64url":
|
||||
return base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
|
||||
msg = f"Unsupported output encoding: {encoding!r}"
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
class HMACSigner:
|
||||
"""Computes `HMAC(key, msg)` per a `SigningSpec`-derived configuration."""
|
||||
|
||||
def __init__(
|
||||
self, secret: str, *, algorithm: str, key_encoding: str, output_encoding: str
|
||||
) -> None:
|
||||
self._key = encode_key(secret, key_encoding)
|
||||
self._factory = _hash_factory(algorithm)
|
||||
self._output_encoding = output_encoding
|
||||
|
||||
def sign(self, message: bytes) -> str:
|
||||
mac = _hmac.new(self._key, message, self._factory)
|
||||
return encode_output(mac.digest(), self._output_encoding)
|
||||
|
||||
|
||||
def hash_body(body: bytes, algorithm: str) -> str:
|
||||
"""Hash a request body for the canonical string. Returns lowercase hex."""
|
||||
h = _hash_factory(algorithm)()
|
||||
h.update(body)
|
||||
return h.hexdigest()
|
||||
@@ -0,0 +1,45 @@
|
||||
"""Per-character rotation transforms.
|
||||
|
||||
Discovery (Phase 2) reduces Raycast's `Sur`/`Tur` rot function to a list of
|
||||
`RotRange` objects. This module applies them.
|
||||
|
||||
A `RotRange(start, end, shift)` maps a code point `c` with `start <= c <= end`
|
||||
to `((c - start + shift) % (end - start + 1)) + start`. Code points outside
|
||||
every supplied range pass through unchanged. Real Raycast uses three ranges:
|
||||
A-Z +13, a-z +13, 0-9 +5 — i.e. ROT13 over letters, ROT5 over digits.
|
||||
|
||||
Ranges are evaluated in order, but they're expected to be disjoint (the JS
|
||||
implementation is a single if/elif/elif chain). If two overlap, the first
|
||||
match wins; this matches the JS short-circuit and keeps the function
|
||||
data-flow obvious if discovery ever ships overlapping ranges by mistake.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable
|
||||
|
||||
from raycast_api.signing_spec import RotRange
|
||||
|
||||
|
||||
def apply_rot(s: str, ranges: Iterable[RotRange]) -> str:
|
||||
"""Apply the rotation to every character of `s`.
|
||||
|
||||
Hot path on every signed request — keeps a tuple snapshot of the ranges
|
||||
so per-character iteration doesn't touch the dataclass each loop.
|
||||
"""
|
||||
snapshot = tuple((r.start, r.end, r.shift, r.end - r.start + 1) for r in ranges)
|
||||
out: list[str] = []
|
||||
for ch in s:
|
||||
c = ord(ch)
|
||||
replaced = False
|
||||
for start, end, shift, span in snapshot:
|
||||
if start <= c <= end:
|
||||
out.append(chr((c - start + shift) % span + start))
|
||||
replaced = True
|
||||
break
|
||||
if not replaced:
|
||||
out.append(ch)
|
||||
return "".join(out)
|
||||
@@ -0,0 +1,119 @@
|
||||
"""Data classes describing the signing algorithm.
|
||||
|
||||
These are the runtime inputs to the signer. Phase 2 produces a `SigningSpec`
|
||||
from the bundle; Phase 3's `Signer` consumes it. They live in this top-level
|
||||
module (rather than inside `signing/`) so that `Config` (also top-level) can
|
||||
reference them without pulling the whole signing implementation into import.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, cast
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RotRange:
|
||||
"""One range in the character-rotation transform.
|
||||
|
||||
A character with codepoint `c` such that `start <= c <= end` is mapped to
|
||||
`((c - start + shift) % (end - start + 1)) + start`. Other characters
|
||||
pass through unchanged.
|
||||
|
||||
The default ranges match Raycast Beta 0.60.x:
|
||||
|
||||
(A-Z, +13), (a-z, +13), (0-9, +5)
|
||||
|
||||
— i.e. ROT13 over letters, ROT5 over digits.
|
||||
"""
|
||||
|
||||
start: int
|
||||
end: int
|
||||
shift: int
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if not (0 <= self.start <= self.end <= 0x10FFFF):
|
||||
msg = f"invalid rot range: start={self.start}, end={self.end}"
|
||||
raise ValueError(msg)
|
||||
if self.shift < 0:
|
||||
msg = f"negative shift not supported: {self.shift}"
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SigningSpec:
|
||||
"""Everything the signer needs that isn't the secret itself.
|
||||
|
||||
Field-by-field intent:
|
||||
|
||||
- `rot_fn_name` / `signing_fn_name`: bookkeeping only — kept around so the
|
||||
CLI can show the user which minified symbols were matched and so the
|
||||
discovery cache key can include them (a bundle rebuild that renames
|
||||
these is also one that probably needs a re-extraction).
|
||||
- `rot_ranges`: the character-class transforms applied per canonical
|
||||
component.
|
||||
- `join_char`: the separator between transformed components in the
|
||||
canonical string. Always `"."` in observed builds.
|
||||
- `body_hash_algorithm`: the hash used on the request body before
|
||||
rot-transform. Always SHA-256.
|
||||
- `hmac_algorithm`: the HMAC hash. Always SHA-256.
|
||||
- `key_encoding`: how the secret string is turned into key bytes. Always
|
||||
"utf-8" — the secret is a 64-char ASCII hex string used AS-IS, NOT
|
||||
hex-decoded. (`bytes.fromhex(...)` silently breaks signing.)
|
||||
- `output_encoding`: "hex-lower" matches the JS implementation's
|
||||
`.padStart(2, "0")` lowercase hex output.
|
||||
"""
|
||||
|
||||
rot_fn_name: str
|
||||
signing_fn_name: str
|
||||
rot_ranges: list[RotRange] = field(default_factory=list)
|
||||
join_char: str = "."
|
||||
body_hash_algorithm: str = "SHA-256"
|
||||
hmac_algorithm: str = "SHA-256"
|
||||
key_encoding: str = "utf-8"
|
||||
output_encoding: str = "hex-lower"
|
||||
|
||||
def to_dict(self) -> dict[str, object]:
|
||||
return {
|
||||
"rot_fn_name": self.rot_fn_name,
|
||||
"signing_fn_name": self.signing_fn_name,
|
||||
"rot_ranges": [
|
||||
{"start": r.start, "end": r.end, "shift": r.shift}
|
||||
for r in self.rot_ranges
|
||||
],
|
||||
"join_char": self.join_char,
|
||||
"body_hash_algorithm": self.body_hash_algorithm,
|
||||
"hmac_algorithm": self.hmac_algorithm,
|
||||
"key_encoding": self.key_encoding,
|
||||
"output_encoding": self.output_encoding,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, object]) -> SigningSpec:
|
||||
ranges_raw = data.get("rot_ranges", []) or []
|
||||
if not isinstance(ranges_raw, list):
|
||||
msg = "rot_ranges must be a list"
|
||||
raise TypeError(msg)
|
||||
ranges: list[RotRange] = []
|
||||
for r in ranges_raw:
|
||||
if not isinstance(r, dict):
|
||||
msg = "rot_ranges entries must be objects"
|
||||
raise TypeError(msg)
|
||||
entry = cast("dict[str, Any]", r)
|
||||
ranges.append(
|
||||
RotRange(
|
||||
start=int(entry["start"]),
|
||||
end=int(entry["end"]),
|
||||
shift=int(entry["shift"]),
|
||||
)
|
||||
)
|
||||
return cls(
|
||||
rot_fn_name=str(data.get("rot_fn_name", "")),
|
||||
signing_fn_name=str(data.get("signing_fn_name", "")),
|
||||
rot_ranges=ranges,
|
||||
join_char=str(data.get("join_char", ".")),
|
||||
body_hash_algorithm=str(data.get("body_hash_algorithm", "SHA-256")),
|
||||
hmac_algorithm=str(data.get("hmac_algorithm", "SHA-256")),
|
||||
key_encoding=str(data.get("key_encoding", "utf-8")),
|
||||
output_encoding=str(data.get("output_encoding", "hex-lower")),
|
||||
)
|
||||
Reference in New Issue
Block a user