feat: add streaming to markdown, fix raycast mcps exposing
This commit is contained in:
@@ -17,12 +17,24 @@ Two halves live here:
|
||||
``arguments``) — Anthropic wants one ``content_block_start`` → deltas
|
||||
→ ``content_block_stop`` per block, so we de-duplicate the final
|
||||
summary against the per-delta increments already emitted.
|
||||
|
||||
MCP wiring: Raycast has no native MCP concept, so when an agent declares
|
||||
``expose_mcps`` we splice each MCP's tools into the wire request as
|
||||
``Tool.local(name=f"{mcp}__{tool}", ...)`` and run a gateway-internal
|
||||
loop. Every time the model emits a tool_call for one of those local
|
||||
tools we route it back to the underlying MCP in-process, append the
|
||||
result as a ``tool`` message, and re-issue the stream — all inside one
|
||||
Anthropic envelope (one ``message_start`` … one ``message_stop``). The
|
||||
tool_use blocks DO surface to the caller (mirrors what
|
||||
``ClaudeCodeBackendAdapter`` does), but tool_results stay internal.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from raycast_api import Message as RaycastMessage
|
||||
@@ -47,6 +59,8 @@ if TYPE_CHECKING:
|
||||
from collections.abc import AsyncIterator, Iterable, Mapping, Sequence
|
||||
|
||||
from anthropic.types import MessageParam
|
||||
from fastmcp import FastMCP
|
||||
from fastmcp.tools.base import Tool as FastMCPTool
|
||||
from raycast_api import ChatStreamChunk, Client
|
||||
|
||||
from beaver_gateway.agents.base import BaseAgent
|
||||
@@ -58,6 +72,16 @@ else:
|
||||
__all__ = ["RaycastBackend"]
|
||||
|
||||
|
||||
_log = logging.getLogger("beaver_gateway.backends.raycast")
|
||||
|
||||
|
||||
# Cap on consecutive tool-call turns inside one Anthropic envelope.
|
||||
# Real conversations rarely chain more than a handful; the limit only
|
||||
# fires on a model that loops, and surfaces as a clean ``end_turn``
|
||||
# with an error tool_result rather than a hang.
|
||||
_MAX_TOOL_TURNS = 20
|
||||
|
||||
|
||||
_RAYCAST_TO_ANTHROPIC_STOP: dict[str, StopReason] = {
|
||||
"stop": "end_turn",
|
||||
"STOP": "end_turn",
|
||||
@@ -132,9 +156,7 @@ def _extract_tool_result_text(content: object) -> str:
|
||||
return json.dumps(content, separators=(",", ":"), ensure_ascii=False)
|
||||
|
||||
|
||||
def _to_raycast_messages(
|
||||
messages: Iterable[MessageParam],
|
||||
) -> list[RaycastMessage]:
|
||||
def _to_raycast_messages(messages: Iterable[MessageParam]) -> list[RaycastMessage]:
|
||||
"""Convert an Anthropic message history into a Raycast one.
|
||||
|
||||
``tool_use_id → name`` is tracked across the iteration so that
|
||||
@@ -206,27 +228,143 @@ def _to_raycast_messages(
|
||||
)
|
||||
out.append(
|
||||
RaycastMessage.assistant(
|
||||
text="\n".join(text_parts),
|
||||
tool_calls=tool_calls or None,
|
||||
text="\n".join(text_parts), tool_calls=tool_calls or None
|
||||
)
|
||||
)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def _build_tool_list(
|
||||
agent: RaycastAgent,
|
||||
) -> list[Tool | RemoteTool | str] | None:
|
||||
"""Wrap each native-tool name as a Raycast remote tool.
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class _AgentToolCatalog:
|
||||
"""Per-agent MCP tool catalog + routing map.
|
||||
|
||||
Phase 1.2 scope: only the three model-agnostic remote tools
|
||||
(`web_search`, `search_images`, `read_page`). Client-defined tools
|
||||
coming through the Anthropic body's ``tools`` field stay out — that's
|
||||
Phase 1.3+ alongside ``accept_client_tools``.
|
||||
Cached on :class:`RaycastBackend` keyed by agent name. ``tools`` is
|
||||
the fully-rendered ``tools`` argument for ``chat.stream`` (local MCP
|
||||
tools first, native remote tools last). ``mcp_routing`` maps each
|
||||
wire-name we registered back to the underlying ``(mcp, original)``
|
||||
pair so we can dispatch a tool_call without re-parsing the prefix.
|
||||
"""
|
||||
if not agent.available_native_tools:
|
||||
return None
|
||||
return [Tool.remote(name) for name in agent.available_native_tools]
|
||||
|
||||
tools: tuple[Tool, ...]
|
||||
mcp_routing: Mapping[str, tuple[str, str]]
|
||||
|
||||
|
||||
_EMPTY_SCHEMA: dict[str, Any] = {"type": "object", "properties": {}}
|
||||
|
||||
|
||||
def _sanitize_schema(schema: Mapping[str, Any] | None) -> dict[str, Any]:
|
||||
"""Strip JSON-Schema meta keys that Raycast/Gemini's tool API rejects.
|
||||
|
||||
MCP servers ship draft-07 schemas including ``$schema``/``$id``/etc.
|
||||
OpenAI-style function-calling expects a leaner subset — leaving the
|
||||
meta keys in causes the upstream provider to reject the entire
|
||||
request with an opaque ``unknown_api_error``. We drop the
|
||||
document-level meta keys and keep the structural keys
|
||||
(``type``/``properties``/``required``/...).
|
||||
"""
|
||||
if not schema:
|
||||
return dict(_EMPTY_SCHEMA)
|
||||
out: dict[str, Any] = {}
|
||||
for key, value in schema.items():
|
||||
if key.startswith("$"):
|
||||
continue
|
||||
out[key] = value
|
||||
if "type" not in out:
|
||||
out["type"] = "object"
|
||||
if out.get("type") == "object" and "properties" not in out:
|
||||
out["properties"] = {}
|
||||
return out
|
||||
|
||||
|
||||
def _build_agent_catalog(
|
||||
agent: RaycastAgent, mcp_tools: Mapping[str, list[FastMCPTool]]
|
||||
) -> _AgentToolCatalog:
|
||||
"""Resolve ``agent.expose_mcps`` against the prefetched MCP tool lists.
|
||||
|
||||
Each exposed MCP contributes one ``Tool.local`` per tool, named
|
||||
``{mcp}__{tool}`` (mirrors claude-code's wire convention, so a model
|
||||
that has seen one style sees a familiar one here). ``ExposedMcp.tools``
|
||||
optionally filters to a subset by original (unprefixed) name.
|
||||
|
||||
A missing MCP entry (broken at startup, see
|
||||
``cli._prefetch_mcp_tools``) silently contributes nothing — surfaces
|
||||
in logs at start, doesn't crash request-time.
|
||||
"""
|
||||
routing: dict[str, tuple[str, str]] = {}
|
||||
local_tools: list[Tool] = []
|
||||
for em in agent.expose_mcps:
|
||||
for mt in mcp_tools.get(em.name, []):
|
||||
if em.tools is not None and mt.name not in em.tools:
|
||||
continue
|
||||
wire_name = f"{em.name}__{mt.name}"
|
||||
routing[wire_name] = (em.name, mt.name)
|
||||
local_tools.append(
|
||||
Tool.local(
|
||||
name=wire_name,
|
||||
description=mt.description or "",
|
||||
parameters=_sanitize_schema(mt.parameters),
|
||||
)
|
||||
)
|
||||
remote_tools = [Tool.remote(n) for n in agent.available_native_tools]
|
||||
return _AgentToolCatalog(
|
||||
tools=tuple(local_tools + remote_tools), mcp_routing=routing
|
||||
)
|
||||
|
||||
|
||||
def _render_mcp_tool_result(result: Any) -> str:
|
||||
"""Flatten a ``fastmcp.ToolResult`` into the text Raycast carries.
|
||||
|
||||
Raycast ``tool`` messages only hold a string. MCP results can carry
|
||||
text, images, structured payloads, etc. — we keep text blocks
|
||||
verbatim, JSON-encode anything else, and fall back to
|
||||
``structured_content`` if the content list is empty.
|
||||
"""
|
||||
parts: list[str] = []
|
||||
for block in getattr(result, "content", None) or []:
|
||||
text = getattr(block, "text", None)
|
||||
if isinstance(text, str):
|
||||
parts.append(text)
|
||||
continue
|
||||
dump = getattr(block, "model_dump", None)
|
||||
if callable(dump):
|
||||
parts.append(json.dumps(dump(), ensure_ascii=False))
|
||||
else:
|
||||
parts.append(str(block))
|
||||
if not parts:
|
||||
structured = getattr(result, "structured_content", None)
|
||||
if structured is not None:
|
||||
parts.append(json.dumps(structured, ensure_ascii=False))
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _TurnAccumulator:
|
||||
"""Pieces collected from one Raycast stream needed to feed the next.
|
||||
|
||||
``text_parts`` is the assistant's text portion (joined and replayed
|
||||
on the next ``chat.stream`` call so the model sees its own prior
|
||||
reply). ``tool_*`` mirrors the same data the wire-state already
|
||||
emitted, but kept in raw form so we can build ``ToolCall`` objects
|
||||
and dispatch them through the gateway.
|
||||
"""
|
||||
|
||||
text_parts: list[str] = field(default_factory=list)
|
||||
tool_order: list[str] = field(default_factory=list)
|
||||
tool_names: dict[str, str] = field(default_factory=dict)
|
||||
tool_args: dict[str, list[str]] = field(default_factory=dict)
|
||||
finish_reason: str | None = None
|
||||
usage: dict[str, int] | None = None
|
||||
|
||||
def add_tool_phase1(self, tool_id: str, name: str) -> None:
|
||||
if tool_id in self.tool_names:
|
||||
return
|
||||
self.tool_order.append(tool_id)
|
||||
self.tool_names[tool_id] = name
|
||||
self.tool_args.setdefault(tool_id, [])
|
||||
|
||||
def add_tool_args(self, tool_id: str, fragment: str) -> None:
|
||||
self.tool_args.setdefault(tool_id, []).append(fragment)
|
||||
|
||||
|
||||
class _BlockState:
|
||||
@@ -241,6 +379,10 @@ class _BlockState:
|
||||
in phase 1 only and ``index`` in every chunk: ``tool_id_to_block``
|
||||
keys by Raycast tool-call id, ``tool_idx_to_id`` resolves chunk
|
||||
indices back to that id for the no-id delta chunks.
|
||||
|
||||
State spans the entire Anthropic envelope (all Raycast turns) so
|
||||
block indices grow monotonically across the tool-call loop. The
|
||||
per-turn ``_TurnAccumulator`` carries the volatile bits.
|
||||
"""
|
||||
|
||||
__slots__ = ("index", "kind", "tool_id_to_block", "tool_idx_to_id")
|
||||
@@ -251,18 +393,49 @@ class _BlockState:
|
||||
self.tool_id_to_block: dict[str, int] = {}
|
||||
self.tool_idx_to_id: dict[int, str] = {}
|
||||
|
||||
def reset_turn_indexing(self) -> None:
|
||||
"""Drop chunk-index → id routing between turns.
|
||||
|
||||
Raycast chunk ``index`` fields restart at 0 inside each fresh
|
||||
``chat.stream``; reusing the previous turn's table would alias
|
||||
new tool_calls onto old ids. ``tool_id_to_block`` we keep — ids
|
||||
are stream-unique and new ones won't collide.
|
||||
"""
|
||||
self.tool_idx_to_id = {}
|
||||
|
||||
|
||||
class RaycastBackend:
|
||||
"""Adapter from ``raycast-api`` chat streams to Anthropic stream events.
|
||||
|
||||
Construction takes a long-lived :class:`raycast_api.Client` (one per
|
||||
gateway — bearer + device_id are process-wide). Each
|
||||
:meth:`complete` call opens one ``chat.stream`` and yields a fully
|
||||
Anthropic-shaped event sequence.
|
||||
gateway — bearer + device_id are process-wide) plus the in-process
|
||||
MCP server map and the prefetched per-MCP tool list. Each
|
||||
:meth:`complete` call opens one or more ``chat.stream`` calls — one
|
||||
per agent "turn" inside the gateway-internal tool-call loop — and
|
||||
yields a single Anthropic stream envelope spanning all of them.
|
||||
"""
|
||||
|
||||
def __init__(self, client: Client) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
client: Client,
|
||||
*,
|
||||
mcp_servers: Mapping[str, FastMCP] | None = None,
|
||||
mcp_tools: Mapping[str, list[FastMCPTool]] | None = None,
|
||||
) -> None:
|
||||
self._client = client
|
||||
self._mcp_servers: Mapping[str, FastMCP] = mcp_servers or {}
|
||||
self._mcp_tools: Mapping[str, list[FastMCPTool]] = mcp_tools or {}
|
||||
# Cached per-agent catalog — agents are immutable, so a single
|
||||
# render at first use covers the gateway's lifetime.
|
||||
self._agent_catalog: dict[str, _AgentToolCatalog] = {}
|
||||
|
||||
def _catalog_for(self, agent: RaycastAgent) -> _AgentToolCatalog:
|
||||
cached = self._agent_catalog.get(agent.name)
|
||||
if cached is not None:
|
||||
return cached
|
||||
cat = _build_agent_catalog(agent, self._mcp_tools)
|
||||
self._agent_catalog[agent.name] = cat
|
||||
return cat
|
||||
|
||||
async def complete(
|
||||
self,
|
||||
@@ -277,7 +450,13 @@ class RaycastBackend:
|
||||
raise TypeError(msg)
|
||||
|
||||
raycast_messages = _to_raycast_messages(messages)
|
||||
tools = _build_tool_list(agent)
|
||||
catalog = self._catalog_for(agent)
|
||||
# Native remote tools + spliced MCP locals. ``None`` keeps the
|
||||
# SDK from sending a ``tools`` field at all when neither is
|
||||
# declared.
|
||||
tools_arg: list[Tool | RemoteTool | str] | None = (
|
||||
list(catalog.tools) if catalog.tools else None
|
||||
)
|
||||
|
||||
# On the wire Raycast uses ``system_instructions`` as a format
|
||||
# marker (``"markdown"`` for AI_CHAT, ``"plain"`` otherwise —
|
||||
@@ -294,7 +473,8 @@ class RaycastBackend:
|
||||
async for event in self._stream(
|
||||
agent=agent,
|
||||
raycast_messages=raycast_messages,
|
||||
tools=tools,
|
||||
tools=tools_arg,
|
||||
mcp_routing=catalog.mcp_routing,
|
||||
prompt_content=prompt_content,
|
||||
temperature=_first_set(options.get("temperature"), agent.temperature),
|
||||
reasoning_effort=_first_set(
|
||||
@@ -310,6 +490,7 @@ class RaycastBackend:
|
||||
agent: RaycastAgent,
|
||||
raycast_messages: Sequence[RaycastMessage],
|
||||
tools: list[Tool | RemoteTool | str] | None,
|
||||
mcp_routing: Mapping[str, tuple[str, str]],
|
||||
prompt_content: str | None,
|
||||
temperature: float | None,
|
||||
reasoning_effort: str | None,
|
||||
@@ -319,45 +500,153 @@ class RaycastBackend:
|
||||
yield build_message_start(message_id=message_id, model=agent.model)
|
||||
|
||||
state = _BlockState()
|
||||
final_finish: str | None = None
|
||||
final_usage: dict[str, int] | None = None
|
||||
# ``working_messages`` is the rolling history fed back to
|
||||
# Raycast as we resolve MCP tool_calls turn by turn.
|
||||
working_messages = list(raycast_messages)
|
||||
last_usage: dict[str, int] | None = None
|
||||
last_finish: str | None = None
|
||||
|
||||
stream = self._client.chat.stream(
|
||||
model=agent.model,
|
||||
messages=list(raycast_messages),
|
||||
source=agent.source,
|
||||
# ``system_instructions=None`` → SDK substitutes the source
|
||||
# default (``"markdown"`` / ``"plain"``). Real prompt goes
|
||||
# into ``additional_system_instructions``.
|
||||
additional_system_instructions=prompt_content,
|
||||
user_preferences=agent.user_preferences,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
temperature=temperature,
|
||||
reasoning_effort=reasoning_effort,
|
||||
)
|
||||
for _turn in range(_MAX_TOOL_TURNS):
|
||||
acc = _TurnAccumulator()
|
||||
state.reset_turn_indexing()
|
||||
|
||||
async for chunk in stream:
|
||||
for event in self._handle_chunk(chunk, state):
|
||||
yield event
|
||||
if chunk.finish_reason:
|
||||
final_finish = chunk.finish_reason
|
||||
if chunk.usage:
|
||||
final_usage = chunk.usage
|
||||
stream = self._client.chat.stream(
|
||||
model=agent.model,
|
||||
messages=working_messages,
|
||||
source=agent.source,
|
||||
# ``system_instructions=None`` → SDK substitutes the
|
||||
# source default (``"markdown"`` / ``"plain"``). Real
|
||||
# prompt goes into ``additional_system_instructions``.
|
||||
additional_system_instructions=prompt_content,
|
||||
user_preferences=agent.user_preferences,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
temperature=temperature,
|
||||
reasoning_effort=reasoning_effort,
|
||||
)
|
||||
|
||||
# Close whatever block is still open before the message delta.
|
||||
async for chunk in stream:
|
||||
for event in self._handle_chunk(chunk, state, acc):
|
||||
yield event
|
||||
if chunk.finish_reason:
|
||||
acc.finish_reason = chunk.finish_reason
|
||||
if chunk.usage:
|
||||
acc.usage = chunk.usage
|
||||
|
||||
last_finish = acc.finish_reason
|
||||
if acc.usage:
|
||||
last_usage = acc.usage
|
||||
|
||||
# Figure out which (if any) tool_calls land on our MCP
|
||||
# routing table. Anything not in the table is left to bubble
|
||||
# out of the envelope as a regular tool_use block — the
|
||||
# Anthropic caller can then respond with a tool_result the
|
||||
# usual way, and the next ``complete`` invocation will
|
||||
# carry it back in.
|
||||
pending_mcp_calls = [
|
||||
tid for tid in acc.tool_order if acc.tool_names.get(tid) in mcp_routing
|
||||
]
|
||||
if not pending_mcp_calls:
|
||||
break
|
||||
|
||||
# Close whatever block is still open before we step into
|
||||
# tool execution — the next turn's chunks start a fresh
|
||||
# block sequence.
|
||||
if state.kind is not None:
|
||||
yield build_content_block_stop(state.index)
|
||||
state.kind = None
|
||||
|
||||
# Echo the assistant turn so Raycast sees its own reply +
|
||||
# tool_calls in subsequent context.
|
||||
assistant_text = "".join(acc.text_parts)
|
||||
assistant_tool_calls = [
|
||||
ToolCall(
|
||||
id=tid,
|
||||
name=acc.tool_names.get(tid, ""),
|
||||
arguments="".join(acc.tool_args.get(tid, [])) or "{}",
|
||||
)
|
||||
for tid in acc.tool_order
|
||||
]
|
||||
working_messages.append(
|
||||
RaycastMessage.assistant(
|
||||
text=assistant_text, tool_calls=assistant_tool_calls or None
|
||||
)
|
||||
)
|
||||
|
||||
# Dispatch each MCP-routed call and append a ``tool``
|
||||
# message. Calls not in the routing table get a placeholder
|
||||
# error so the model can correct itself rather than the
|
||||
# gateway hanging the conversation.
|
||||
for tid in acc.tool_order:
|
||||
tool_name = acc.tool_names.get(tid, "")
|
||||
args_str = "".join(acc.tool_args.get(tid, [])) or "{}"
|
||||
if tool_name in mcp_routing:
|
||||
result_text = await self._dispatch_mcp_call(
|
||||
tool_name=tool_name, args_json=args_str, mcp_routing=mcp_routing
|
||||
)
|
||||
else:
|
||||
# Non-MCP tool — shouldn't really happen because we
|
||||
# haven't surfaced any other locals, but defend.
|
||||
result_text = f"Tool {tool_name!r} is not handled by the gateway."
|
||||
working_messages.append(
|
||||
RaycastMessage.tool(
|
||||
tool_call_id=tid, name=tool_name, result=result_text
|
||||
)
|
||||
)
|
||||
|
||||
# Loop: re-stream with the new history.
|
||||
else:
|
||||
_log.warning(
|
||||
"raycast tool-call loop hit %d-turn cap for agent %r; "
|
||||
"closing the envelope",
|
||||
_MAX_TOOL_TURNS,
|
||||
agent.name,
|
||||
)
|
||||
|
||||
# Close whatever block is still open before the final delta.
|
||||
if state.kind is not None:
|
||||
yield build_content_block_stop(state.index)
|
||||
state.kind = None
|
||||
|
||||
yield build_message_delta(
|
||||
stop_reason=_map_stop_reason(final_finish),
|
||||
usage=final_usage,
|
||||
stop_reason=_map_stop_reason(last_finish), usage=last_usage
|
||||
)
|
||||
yield build_message_stop()
|
||||
|
||||
async def _dispatch_mcp_call(
|
||||
self,
|
||||
*,
|
||||
tool_name: str,
|
||||
args_json: str,
|
||||
mcp_routing: Mapping[str, tuple[str, str]],
|
||||
) -> str:
|
||||
"""Route one tool_call to the matching MCP and render its result.
|
||||
|
||||
Errors (missing server, JSON parse, tool exception) surface as
|
||||
text in the ``tool`` message rather than crashing the stream —
|
||||
the model gets a chance to recover or apologize.
|
||||
"""
|
||||
mcp_name, original_name = mcp_routing[tool_name]
|
||||
server = self._mcp_servers.get(mcp_name)
|
||||
if server is None:
|
||||
return f"Error: MCP {mcp_name!r} is not registered."
|
||||
try:
|
||||
args = json.loads(args_json) if args_json else {}
|
||||
except json.JSONDecodeError as exc:
|
||||
return f"Error: tool arguments are not valid JSON ({exc})."
|
||||
if not isinstance(args, dict):
|
||||
return "Error: tool arguments must be a JSON object."
|
||||
try:
|
||||
result = await server.call_tool(original_name, args)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
_log.exception(
|
||||
"MCP call failed: %s.%s args=%r", mcp_name, original_name, args
|
||||
)
|
||||
return f"Error calling {tool_name}: {exc}"
|
||||
return _render_mcp_tool_result(result)
|
||||
|
||||
def _handle_chunk(
|
||||
self, chunk: ChatStreamChunk, state: _BlockState
|
||||
self, chunk: ChatStreamChunk, state: _BlockState, acc: _TurnAccumulator
|
||||
) -> Iterable[MessageStreamEvent]:
|
||||
"""Translate one Raycast chunk into zero or more Anthropic events.
|
||||
|
||||
@@ -373,6 +662,7 @@ class RaycastBackend:
|
||||
if chunk.text:
|
||||
events.extend(self._ensure_kind(state, "text"))
|
||||
events.append(build_text_delta(state.index, chunk.text))
|
||||
acc.text_parts.append(chunk.text)
|
||||
|
||||
if chunk.reasoning:
|
||||
events.extend(self._ensure_kind(state, "thinking"))
|
||||
@@ -380,7 +670,9 @@ class RaycastBackend:
|
||||
|
||||
if chunk.tool_calls:
|
||||
events.extend(
|
||||
self._handle_tool_calls(chunk, state, is_final_summary=is_final_summary)
|
||||
self._handle_tool_calls(
|
||||
chunk, state, acc, is_final_summary=is_final_summary
|
||||
)
|
||||
)
|
||||
|
||||
return events
|
||||
@@ -409,6 +701,7 @@ class RaycastBackend:
|
||||
self,
|
||||
chunk: ChatStreamChunk,
|
||||
state: _BlockState,
|
||||
acc: _TurnAccumulator,
|
||||
*,
|
||||
is_final_summary: bool,
|
||||
) -> Iterable[MessageStreamEvent]:
|
||||
@@ -419,6 +712,11 @@ class RaycastBackend:
|
||||
by id when present, otherwise by their wire ``index`` field, and
|
||||
the final-summary chunk's arguments string is dropped because the
|
||||
deltas already streamed it.
|
||||
|
||||
Side-effects on ``acc`` mirror what gets emitted to the wire so
|
||||
the gateway can rebuild a full ``ToolCall`` for the next Raycast
|
||||
turn (it needs the joined ``arguments`` JSON string, which the
|
||||
wire never gives us as a whole — only in fragments).
|
||||
"""
|
||||
events: list[MessageStreamEvent] = []
|
||||
raw_tcs = chunk.raw.get("tool_calls") or []
|
||||
@@ -449,20 +747,35 @@ class RaycastBackend:
|
||||
state.kind = "tool_use"
|
||||
block_idx = state.index
|
||||
state.tool_id_to_block[tool_id] = block_idx
|
||||
acc.add_tool_phase1(tool_id, tc.name or "")
|
||||
events.append(
|
||||
build_tool_use_block_start(
|
||||
block_idx, tool_use_id=tool_id, name=tc.name or ""
|
||||
)
|
||||
)
|
||||
if tc.arguments and not is_final_summary:
|
||||
# Streaming providers (GPT) deliver arguments as deltas
|
||||
# across chunks and re-state the full string in a final
|
||||
# summary chunk; non-streaming-args providers (Gemini)
|
||||
# only send args in the final summary. Emit args in
|
||||
# both cases when this is the first appearance — final
|
||||
# summary then IS the full string.
|
||||
if tc.arguments:
|
||||
events.append(build_input_json_delta(block_idx, tc.arguments))
|
||||
acc.add_tool_args(tool_id, tc.arguments)
|
||||
continue
|
||||
|
||||
# Existing block. Skip args on the final summary — they're a
|
||||
# full restatement of what's already been delta'd.
|
||||
# Existing block. Final summary chunks restate the full args
|
||||
# string; if we already streamed deltas, that restatement is
|
||||
# a duplicate (skip). If we streamed nothing (the streaming
|
||||
# provider didn't send mid-arg deltas — Gemini path again),
|
||||
# the summary IS the args — emit it once.
|
||||
if is_final_summary:
|
||||
if tc.arguments and not acc.tool_args.get(tool_id):
|
||||
events.append(build_input_json_delta(block_idx, tc.arguments))
|
||||
acc.add_tool_args(tool_id, tc.arguments)
|
||||
continue
|
||||
if tc.arguments:
|
||||
events.append(build_input_json_delta(block_idx, tc.arguments))
|
||||
acc.add_tool_args(tool_id, tc.arguments)
|
||||
|
||||
return events
|
||||
|
||||
+46
-11
@@ -44,6 +44,8 @@ from beaver_gateway.settings import Settings
|
||||
from beaver_gateway.storage import Database
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastmcp import FastMCP
|
||||
from fastmcp.tools.base import Tool as FastMCPTool
|
||||
from starlette.applications import Starlette
|
||||
|
||||
from beaver_gateway.backends.base import Backend
|
||||
@@ -95,16 +97,26 @@ async def _async_main() -> None:
|
||||
stack.push_async_callback(token_store.stop)
|
||||
# Internal MCP URLs must exist before we construct any
|
||||
# ClaudeCodeBackendAdapter — adapters bake the URLs into their
|
||||
# ``BackendOptions.mcp_servers`` at construction time.
|
||||
internal_app, internal_urls = _build_internal_mcp(
|
||||
# ``BackendOptions.mcp_servers`` at construction time. The
|
||||
# ``mcp_servers`` map is used by the Raycast backend, which
|
||||
# needs in-process ``list_tools`` / ``call_tool`` access (the
|
||||
# Raycast wire has no native MCP concept).
|
||||
internal_app, internal_urls, mcp_servers = _build_internal_mcp(
|
||||
gateway.mcps, settings=settings
|
||||
)
|
||||
|
||||
# Prefetch tool catalogs for every MCP so RaycastAgent requests
|
||||
# don't pay a per-turn list_tools roundtrip and so a broken MCP
|
||||
# surfaces at startup instead of mid-conversation.
|
||||
mcp_tools = await _prefetch_mcp_tools(mcp_servers)
|
||||
|
||||
backends: dict[str, Backend] = await _build_backends(
|
||||
settings=settings,
|
||||
agents=agents,
|
||||
stack=stack,
|
||||
mcp_internal_urls=internal_urls,
|
||||
mcp_servers=mcp_servers,
|
||||
mcp_tools=mcp_tools,
|
||||
)
|
||||
|
||||
runtime = GatewayRuntime(
|
||||
@@ -152,20 +164,39 @@ async def _async_main() -> None:
|
||||
|
||||
def _build_internal_mcp(
|
||||
mcps: list[McpServerT], *, settings: Settings
|
||||
) -> tuple[Starlette | None, dict[str, str]]:
|
||||
"""Build the aggregator app + URL map, or return ``(None, {})``.
|
||||
) -> tuple[Starlette | None, dict[str, str], dict[str, FastMCP]]:
|
||||
"""Build the aggregator app + URL map + server map, or empty equivalents.
|
||||
|
||||
The URL map is always handed out (frontends may still introspect
|
||||
``runtime.mcp_internal_urls`` even if nothing is configured); the
|
||||
app is ``None`` when there are no MCPs to mount, so the caller
|
||||
skips the uvicorn task entirely.
|
||||
skips the uvicorn task entirely. The server map is the in-process
|
||||
handle the Raycast backend needs to splice MCP tools into its
|
||||
requests — empty when no MCPs are configured.
|
||||
"""
|
||||
if not mcps:
|
||||
return None, {}
|
||||
app, urls = build_internal_app(
|
||||
mcps, host="127.0.0.1", port=settings.internal_mcp_port
|
||||
)
|
||||
return app, urls
|
||||
return None, {}, {}
|
||||
return build_internal_app(mcps, host="127.0.0.1", port=settings.internal_mcp_port)
|
||||
|
||||
|
||||
async def _prefetch_mcp_tools(
|
||||
servers: dict[str, FastMCP],
|
||||
) -> dict[str, list[FastMCPTool]]:
|
||||
"""Eagerly enumerate tools per MCP so the Raycast loop has a static catalog.
|
||||
|
||||
Each underlying proxy is allowed to fail independently — a broken
|
||||
MCP shouldn't take down the whole gateway. The result has one entry
|
||||
per MCP that responded; agents that ``expose_mcps`` a missing entry
|
||||
will simply expose no tools from it (logged once per request).
|
||||
"""
|
||||
out: dict[str, list[FastMCPTool]] = {}
|
||||
for name, server in servers.items():
|
||||
try:
|
||||
out[name] = list(await server.list_tools())
|
||||
except Exception: # noqa: BLE001 — proxy can raise any transport error; we degrade per-MCP rather than fail the whole gateway
|
||||
_log.exception("failed to list tools for MCP %r — skipping", name)
|
||||
out[name] = []
|
||||
return out
|
||||
|
||||
|
||||
async def _serve_internal_mcp(app: Starlette, *, settings: Settings) -> None:
|
||||
@@ -196,6 +227,8 @@ async def _build_backends(
|
||||
agents: AgentRegistry,
|
||||
stack: AsyncExitStack,
|
||||
mcp_internal_urls: dict[str, str],
|
||||
mcp_servers: dict[str, FastMCP],
|
||||
mcp_tools: dict[str, list[FastMCPTool]],
|
||||
) -> dict[str, Backend]:
|
||||
"""Construct one backend per agent name.
|
||||
|
||||
@@ -216,7 +249,9 @@ async def _build_backends(
|
||||
if raycast_agents:
|
||||
client = await _try_open_raycast_client(settings, stack)
|
||||
if client is not None:
|
||||
raycast_backend = RaycastBackend(client)
|
||||
raycast_backend = RaycastBackend(
|
||||
client, mcp_servers=mcp_servers, mcp_tools=mcp_tools
|
||||
)
|
||||
for a in raycast_agents:
|
||||
backends[a.name] = raycast_backend
|
||||
|
||||
|
||||
@@ -45,6 +45,7 @@ import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from sqlalchemy import delete
|
||||
from sqlmodel import select
|
||||
|
||||
from beaver_gateway.storage.models import Conversation, ConversationMessage
|
||||
@@ -136,7 +137,34 @@ async def load_messages(
|
||||
)
|
||||
result = await session.exec(stmt)
|
||||
rows = result.all()
|
||||
return [{"role": r.role, "content": json.loads(r.content_json)} for r in rows]
|
||||
return [
|
||||
{"role": r.role, "content": _sanitize_content(json.loads(r.content_json))}
|
||||
for r in rows
|
||||
]
|
||||
|
||||
|
||||
def _sanitize_content(content: Any) -> Any:
|
||||
"""Strip wire-illegal fields from stored Anthropic content blocks.
|
||||
|
||||
Older capture code emitted ``"is_error": null`` on ``tool_result``
|
||||
blocks; the Anthropic API rejects null there (the field is optional
|
||||
but, when present, must be boolean). We omit the key on read so
|
||||
historical rows don't break continuation.
|
||||
"""
|
||||
if not isinstance(content, list):
|
||||
return content
|
||||
cleaned: list[Any] = []
|
||||
for blk in content:
|
||||
out_blk = blk
|
||||
if (
|
||||
isinstance(blk, dict)
|
||||
and blk.get("type") == "tool_result"
|
||||
and blk.get("is_error") is None
|
||||
and "is_error" in blk
|
||||
):
|
||||
out_blk = {k: v for k, v in blk.items() if k != "is_error"}
|
||||
cleaned.append(out_blk)
|
||||
return cleaned
|
||||
|
||||
|
||||
async def rewrite_messages(
|
||||
@@ -148,13 +176,20 @@ async def rewrite_messages(
|
||||
our volume; if it ever matters we can switch to soft-delete +
|
||||
branch pointers.
|
||||
"""
|
||||
# Delete existing rows for this conversation.
|
||||
existing_stmt = select(ConversationMessage).where(
|
||||
ConversationMessage.conversation_id == conversation_id
|
||||
# Bulk-delete and flush before inserting the new sequence: SQLAlchemy's
|
||||
# unit-of-work flushes INSERTs before DELETEs by default, which would
|
||||
# trip ``uq_msg_conv_seq`` when the new rows reuse the same seq numbers
|
||||
# as the soon-to-be-deleted ones.
|
||||
# SQLModel descriptors resolve to ColumnElement at runtime but to bare
|
||||
# ``int`` in ty's stubs; the select-path at line 135 lives behind sqlmodel's
|
||||
# own ``select`` overloads that hide it, but ``sqlalchemy.delete().where``
|
||||
# uses the raw stubs.
|
||||
await session.execute( # ty: ignore[deprecated]
|
||||
delete(ConversationMessage).where(
|
||||
ConversationMessage.conversation_id == conversation_id # ty: ignore[invalid-argument-type]
|
||||
)
|
||||
)
|
||||
result = await session.exec(existing_stmt)
|
||||
for row in result.all():
|
||||
await session.delete(row)
|
||||
await session.flush()
|
||||
# Insert the new sequence.
|
||||
for seq, m in enumerate(messages):
|
||||
session.add(
|
||||
|
||||
@@ -29,10 +29,12 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import aiofile
|
||||
from anthropic.types import RawContentBlockStopEvent
|
||||
from fastapi import FastAPI, HTTPException, Request, status
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
@@ -46,7 +48,7 @@ from beaver_gateway.core.conversation_store import (
|
||||
rewrite_messages,
|
||||
)
|
||||
from beaver_gateway.core.turn_record import TurnRecord
|
||||
from beaver_gateway.frontends._accumulate import accumulate
|
||||
from beaver_gateway.frontends._accumulate import StreamAccumulator
|
||||
from beaver_gateway.frontends._auth import require_token
|
||||
from beaver_gateway.frontends.base import Frontend
|
||||
from beaver_gateway.frontends.markdown import parser, renderer
|
||||
@@ -69,6 +71,14 @@ _log = logging.getLogger("beaver_gateway.frontends.markdown")
|
||||
__all__ = ["MarkdownFrontend"]
|
||||
|
||||
|
||||
# How often we re-render the assistant turn into the .md file while the
|
||||
# backend stream is still open. Trades responsiveness (faster updates to
|
||||
# Obsidian sync / Raycast tailers) against write amplification. Each
|
||||
# ``RawContentBlockStopEvent`` also forces a flush regardless of the
|
||||
# timer, so block boundaries always land in the file.
|
||||
_STREAM_FLUSH_DEBOUNCE = 0.4
|
||||
|
||||
|
||||
class MarkdownFrontend(Frontend):
|
||||
"""FastAPI app behind ``POST /chat`` driven by Obsidian-vault files."""
|
||||
|
||||
@@ -285,21 +295,23 @@ class MarkdownFrontend(Frontend):
|
||||
TurnCapture() if isinstance(backend, ClaudeCodeBackendAdapter) else None
|
||||
)
|
||||
|
||||
kwargs: dict[str, Any] = {}
|
||||
if capture is not None:
|
||||
kwargs["capture"] = capture
|
||||
events = backend.complete(
|
||||
agent=agent, messages=outcome.messages, system=None, **kwargs
|
||||
)
|
||||
try:
|
||||
kwargs: dict[str, Any] = {}
|
||||
if capture is not None:
|
||||
kwargs["capture"] = capture
|
||||
events = backend.complete(
|
||||
agent=agent, messages=outcome.messages, system=None, **kwargs
|
||||
message = await self._stream_to_file(
|
||||
events=events,
|
||||
file_path=file_path,
|
||||
parsed=parsed,
|
||||
model=agent.model or agent.name,
|
||||
filename=filename,
|
||||
)
|
||||
message = await accumulate(events, model=agent.model or agent.name)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
_log.exception("backend failed for %s", filename)
|
||||
error_block = _render_error_block(exc)
|
||||
new_body = renderer.append_to_body(parsed.body, error_block)
|
||||
await _write_atomic(
|
||||
file_path, _reattach_frontmatter(parsed.metadata, new_body)
|
||||
)
|
||||
raise HTTPException(
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR, f"backend error: {exc}"
|
||||
) from exc
|
||||
@@ -346,6 +358,67 @@ class MarkdownFrontend(Frontend):
|
||||
|
||||
# ---- helpers -------------------------------------------------------
|
||||
|
||||
async def _stream_to_file(
|
||||
self,
|
||||
*,
|
||||
events: Any,
|
||||
file_path: Path,
|
||||
parsed: parser.ParsedFile,
|
||||
model: str,
|
||||
filename: str,
|
||||
) -> Any:
|
||||
"""Drain ``events`` into a ``Message``, flushing partials to disk.
|
||||
|
||||
Flushes happen on each ``RawContentBlockStopEvent`` (natural
|
||||
block boundary, content is markdown-consistent) and on the
|
||||
``_STREAM_FLUSH_DEBOUNCE`` timer between events. The partial
|
||||
write keeps the as-parsed frontmatter; the post-stream final
|
||||
write in ``_write_assistant_reply`` is what stamps the refreshed
|
||||
fingerprint / agent / conversation_id.
|
||||
|
||||
On backend exception we still flush the last partial and append
|
||||
an error callout, so the human sees both what arrived and why it
|
||||
stopped. The exception propagates so ``_handle_chat`` can map it
|
||||
to a 500.
|
||||
"""
|
||||
acc = StreamAccumulator()
|
||||
|
||||
async def flush_partial() -> None:
|
||||
partial = acc.finalize(model=model)
|
||||
if not partial.content:
|
||||
return
|
||||
rendered = renderer.render_assistant_message(partial)
|
||||
new_body = renderer.append_to_body(parsed.body, rendered)
|
||||
await _write_atomic(
|
||||
file_path, _reattach_frontmatter(parsed.metadata, new_body)
|
||||
)
|
||||
|
||||
try:
|
||||
last_flush = time.monotonic()
|
||||
async for ev in events:
|
||||
acc.feed(ev)
|
||||
now = time.monotonic()
|
||||
if (
|
||||
isinstance(ev, RawContentBlockStopEvent)
|
||||
or (now - last_flush) >= _STREAM_FLUSH_DEBOUNCE
|
||||
):
|
||||
await flush_partial()
|
||||
last_flush = now
|
||||
except Exception as exc:
|
||||
_log.exception("backend failed for %s", filename)
|
||||
partial = acc.finalize(model=model)
|
||||
new_body = parsed.body
|
||||
if partial.content:
|
||||
new_body = renderer.append_to_body(
|
||||
new_body, renderer.render_assistant_message(partial)
|
||||
)
|
||||
new_body = renderer.append_to_body(new_body, _render_error_block(exc))
|
||||
await _write_atomic(
|
||||
file_path, _reattach_frontmatter(parsed.metadata, new_body)
|
||||
)
|
||||
raise
|
||||
return acc.finalize(model=model)
|
||||
|
||||
async def _write_assistant_reply(
|
||||
self,
|
||||
*,
|
||||
|
||||
@@ -48,8 +48,8 @@ ALL_NAMESPACE = "all"
|
||||
|
||||
def build_internal_app(
|
||||
mcps: Iterable[McpServerT], *, host: str, port: int
|
||||
) -> tuple[Starlette, dict[str, str]]:
|
||||
"""Build the aggregator ``Starlette`` app and the per-namespace URL map.
|
||||
) -> tuple[Starlette, dict[str, str], dict[str, FastMCP]]:
|
||||
"""Build the aggregator ``Starlette`` app, per-namespace URL map, and server map.
|
||||
|
||||
``host``/``port`` only flavour the URL strings handed back — actually
|
||||
listening on them is the caller's job (``cli.main`` runs a uvicorn
|
||||
@@ -57,21 +57,23 @@ def build_internal_app(
|
||||
have to format the URLs themselves and risk drifting from the
|
||||
``/mcp/<name>`` convention.
|
||||
|
||||
Returns a map of ``{namespace: url}`` for the per-domain endpoints
|
||||
only (claude-code's MCP routing expects per-domain framing). The
|
||||
``/mcp/all/`` bundle endpoint exists on the same app but is
|
||||
intentionally omitted from the URL map — it's only meaningful to
|
||||
external clients via the MCP frontend, not to claude-code.
|
||||
Returns:
|
||||
* Starlette app to serve via uvicorn.
|
||||
* ``{namespace: url}`` for the per-domain endpoints (claude-code's
|
||||
MCP routing expects per-domain framing). ``/mcp/all/`` is
|
||||
omitted — only meaningful to external clients via the MCP
|
||||
frontend, not to claude-code.
|
||||
* ``{namespace: FastMCP}`` for backends that need in-process
|
||||
access (Raycast doesn't natively understand MCP, so the
|
||||
gateway calls ``list_tools``/``call_tool`` directly to splice
|
||||
MCP tools into the Raycast wire).
|
||||
"""
|
||||
servers: dict[str, FastMCP] = {spec.name: _build_server(spec) for spec in mcps}
|
||||
|
||||
child_apps = {
|
||||
name: s.http_app(transport="http", path="/")
|
||||
for name, s in servers.items()
|
||||
name: s.http_app(transport="http", path="/") for name, s in servers.items()
|
||||
}
|
||||
routes = [
|
||||
Mount(f"/mcp/{name}", app=app) for name, app in child_apps.items()
|
||||
]
|
||||
routes = [Mount(f"/mcp/{name}", app=app) for name, app in child_apps.items()]
|
||||
|
||||
# /mcp/all — flat-namespace bundle. Skip when there's nothing to
|
||||
# bundle so we don't pay for an empty session manager lifecycle.
|
||||
@@ -102,7 +104,7 @@ def build_internal_app(
|
||||
# 307 redirect from ``/mcp/<name>`` to ``/mcp/<name>/`` that
|
||||
# ``Mount`` produces when a child route lives at ``/``.
|
||||
urls = {name: f"http://{host}:{port}/mcp/{name}/" for name in servers}
|
||||
return app, urls
|
||||
return app, urls, servers
|
||||
|
||||
|
||||
def _build_server(spec: McpServerT) -> FastMCP:
|
||||
|
||||
@@ -287,7 +287,7 @@ local = [
|
||||
{ name = "raycast-api", version = "0.1.0", source = { editable = "../raycast-api" } },
|
||||
]
|
||||
prod = [
|
||||
{ name = "claude-code-api", version = "0.1.0", source = { git = "https://git.kotikot.com/beaver/claude-code-api.git#1f20cef7d49d290f2b620ebb8a7aca92cdbd0e2a" } },
|
||||
{ name = "claude-code-api", version = "0.1.0", source = { git = "https://git.kotikot.com/beaver/claude-code-api.git#86d8a8f4c471605577716ab0f039c857e6261a0e" } },
|
||||
{ name = "raycast-api", version = "0.1.0", source = { git = "https://git.kotikot.com/beaver/raycast-api.git#e73894c8e435da5c0709f92f69f11bcd0dab9afe" } },
|
||||
]
|
||||
|
||||
@@ -419,7 +419,7 @@ wheels = [
|
||||
[[package]]
|
||||
name = "claude-code-api"
|
||||
version = "0.1.0"
|
||||
source = { git = "https://git.kotikot.com/beaver/claude-code-api.git#1f20cef7d49d290f2b620ebb8a7aca92cdbd0e2a" }
|
||||
source = { git = "https://git.kotikot.com/beaver/claude-code-api.git#86d8a8f4c471605577716ab0f039c857e6261a0e" }
|
||||
resolution-markers = [
|
||||
"python_full_version >= '3.14'",
|
||||
"python_full_version < '3.14'",
|
||||
|
||||
Reference in New Issue
Block a user