feat: add streaming to markdown, fix raycast mcps exposing

This commit is contained in:
h
2026-05-21 13:52:48 +02:00
parent 7fc0c9c0b1
commit 11f061070f
6 changed files with 557 additions and 99 deletions
+366 -53
View File
@@ -17,12 +17,24 @@ Two halves live here:
``arguments``) — Anthropic wants one ``content_block_start`` → deltas
→ ``content_block_stop`` per block, so we de-duplicate the final
summary against the per-delta increments already emitted.
MCP wiring: Raycast has no native MCP concept, so when an agent declares
``expose_mcps`` we splice each MCP's tools into the wire request as
``Tool.local(name=f"{mcp}__{tool}", ...)`` and run a gateway-internal
loop. Every time the model emits a tool_call for one of those local
tools we route it back to the underlying MCP in-process, append the
result as a ``tool`` message, and re-issue the stream — all inside one
Anthropic envelope (one ``message_start`` … one ``message_stop``). The
tool_use blocks DO surface to the caller (mirrors what
``ClaudeCodeBackendAdapter`` does), but tool_results stay internal.
"""
from __future__ import annotations
import json
import logging
import uuid
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, cast
from raycast_api import Message as RaycastMessage
@@ -47,6 +59,8 @@ if TYPE_CHECKING:
from collections.abc import AsyncIterator, Iterable, Mapping, Sequence
from anthropic.types import MessageParam
from fastmcp import FastMCP
from fastmcp.tools.base import Tool as FastMCPTool
from raycast_api import ChatStreamChunk, Client
from beaver_gateway.agents.base import BaseAgent
@@ -58,6 +72,16 @@ else:
__all__ = ["RaycastBackend"]
_log = logging.getLogger("beaver_gateway.backends.raycast")
# Cap on consecutive tool-call turns inside one Anthropic envelope.
# Real conversations rarely chain more than a handful; the limit only
# fires on a model that loops, and surfaces as a clean ``end_turn``
# with an error tool_result rather than a hang.
_MAX_TOOL_TURNS = 20
_RAYCAST_TO_ANTHROPIC_STOP: dict[str, StopReason] = {
"stop": "end_turn",
"STOP": "end_turn",
@@ -132,9 +156,7 @@ def _extract_tool_result_text(content: object) -> str:
return json.dumps(content, separators=(",", ":"), ensure_ascii=False)
def _to_raycast_messages(
messages: Iterable[MessageParam],
) -> list[RaycastMessage]:
def _to_raycast_messages(messages: Iterable[MessageParam]) -> list[RaycastMessage]:
"""Convert an Anthropic message history into a Raycast one.
``tool_use_id → name`` is tracked across the iteration so that
@@ -206,27 +228,143 @@ def _to_raycast_messages(
)
out.append(
RaycastMessage.assistant(
text="\n".join(text_parts),
tool_calls=tool_calls or None,
text="\n".join(text_parts), tool_calls=tool_calls or None
)
)
return out
def _build_tool_list(
agent: RaycastAgent,
) -> list[Tool | RemoteTool | str] | None:
"""Wrap each native-tool name as a Raycast remote tool.
@dataclass(frozen=True, slots=True)
class _AgentToolCatalog:
"""Per-agent MCP tool catalog + routing map.
Phase 1.2 scope: only the three model-agnostic remote tools
(`web_search`, `search_images`, `read_page`). Client-defined tools
coming through the Anthropic body's ``tools`` field stay out — that's
Phase 1.3+ alongside ``accept_client_tools``.
Cached on :class:`RaycastBackend` keyed by agent name. ``tools`` is
the fully-rendered ``tools`` argument for ``chat.stream`` (local MCP
tools first, native remote tools last). ``mcp_routing`` maps each
wire-name we registered back to the underlying ``(mcp, original)``
pair so we can dispatch a tool_call without re-parsing the prefix.
"""
if not agent.available_native_tools:
return None
return [Tool.remote(name) for name in agent.available_native_tools]
tools: tuple[Tool, ...]
mcp_routing: Mapping[str, tuple[str, str]]
_EMPTY_SCHEMA: dict[str, Any] = {"type": "object", "properties": {}}
def _sanitize_schema(schema: Mapping[str, Any] | None) -> dict[str, Any]:
"""Strip JSON-Schema meta keys that Raycast/Gemini's tool API rejects.
MCP servers ship draft-07 schemas including ``$schema``/``$id``/etc.
OpenAI-style function-calling expects a leaner subset — leaving the
meta keys in causes the upstream provider to reject the entire
request with an opaque ``unknown_api_error``. We drop the
document-level meta keys and keep the structural keys
(``type``/``properties``/``required``/...).
"""
if not schema:
return dict(_EMPTY_SCHEMA)
out: dict[str, Any] = {}
for key, value in schema.items():
if key.startswith("$"):
continue
out[key] = value
if "type" not in out:
out["type"] = "object"
if out.get("type") == "object" and "properties" not in out:
out["properties"] = {}
return out
def _build_agent_catalog(
agent: RaycastAgent, mcp_tools: Mapping[str, list[FastMCPTool]]
) -> _AgentToolCatalog:
"""Resolve ``agent.expose_mcps`` against the prefetched MCP tool lists.
Each exposed MCP contributes one ``Tool.local`` per tool, named
``{mcp}__{tool}`` (mirrors claude-code's wire convention, so a model
that has seen one style sees a familiar one here). ``ExposedMcp.tools``
optionally filters to a subset by original (unprefixed) name.
A missing MCP entry (broken at startup, see
``cli._prefetch_mcp_tools``) silently contributes nothing — surfaces
in logs at start, doesn't crash request-time.
"""
routing: dict[str, tuple[str, str]] = {}
local_tools: list[Tool] = []
for em in agent.expose_mcps:
for mt in mcp_tools.get(em.name, []):
if em.tools is not None and mt.name not in em.tools:
continue
wire_name = f"{em.name}__{mt.name}"
routing[wire_name] = (em.name, mt.name)
local_tools.append(
Tool.local(
name=wire_name,
description=mt.description or "",
parameters=_sanitize_schema(mt.parameters),
)
)
remote_tools = [Tool.remote(n) for n in agent.available_native_tools]
return _AgentToolCatalog(
tools=tuple(local_tools + remote_tools), mcp_routing=routing
)
def _render_mcp_tool_result(result: Any) -> str:
"""Flatten a ``fastmcp.ToolResult`` into the text Raycast carries.
Raycast ``tool`` messages only hold a string. MCP results can carry
text, images, structured payloads, etc. — we keep text blocks
verbatim, JSON-encode anything else, and fall back to
``structured_content`` if the content list is empty.
"""
parts: list[str] = []
for block in getattr(result, "content", None) or []:
text = getattr(block, "text", None)
if isinstance(text, str):
parts.append(text)
continue
dump = getattr(block, "model_dump", None)
if callable(dump):
parts.append(json.dumps(dump(), ensure_ascii=False))
else:
parts.append(str(block))
if not parts:
structured = getattr(result, "structured_content", None)
if structured is not None:
parts.append(json.dumps(structured, ensure_ascii=False))
return "\n".join(parts)
@dataclass(slots=True)
class _TurnAccumulator:
"""Pieces collected from one Raycast stream needed to feed the next.
``text_parts`` is the assistant's text portion (joined and replayed
on the next ``chat.stream`` call so the model sees its own prior
reply). ``tool_*`` mirrors the same data the wire-state already
emitted, but kept in raw form so we can build ``ToolCall`` objects
and dispatch them through the gateway.
"""
text_parts: list[str] = field(default_factory=list)
tool_order: list[str] = field(default_factory=list)
tool_names: dict[str, str] = field(default_factory=dict)
tool_args: dict[str, list[str]] = field(default_factory=dict)
finish_reason: str | None = None
usage: dict[str, int] | None = None
def add_tool_phase1(self, tool_id: str, name: str) -> None:
if tool_id in self.tool_names:
return
self.tool_order.append(tool_id)
self.tool_names[tool_id] = name
self.tool_args.setdefault(tool_id, [])
def add_tool_args(self, tool_id: str, fragment: str) -> None:
self.tool_args.setdefault(tool_id, []).append(fragment)
class _BlockState:
@@ -241,6 +379,10 @@ class _BlockState:
in phase 1 only and ``index`` in every chunk: ``tool_id_to_block``
keys by Raycast tool-call id, ``tool_idx_to_id`` resolves chunk
indices back to that id for the no-id delta chunks.
State spans the entire Anthropic envelope (all Raycast turns) so
block indices grow monotonically across the tool-call loop. The
per-turn ``_TurnAccumulator`` carries the volatile bits.
"""
__slots__ = ("index", "kind", "tool_id_to_block", "tool_idx_to_id")
@@ -251,18 +393,49 @@ class _BlockState:
self.tool_id_to_block: dict[str, int] = {}
self.tool_idx_to_id: dict[int, str] = {}
def reset_turn_indexing(self) -> None:
"""Drop chunk-index → id routing between turns.
Raycast chunk ``index`` fields restart at 0 inside each fresh
``chat.stream``; reusing the previous turn's table would alias
new tool_calls onto old ids. ``tool_id_to_block`` we keep — ids
are stream-unique and new ones won't collide.
"""
self.tool_idx_to_id = {}
class RaycastBackend:
"""Adapter from ``raycast-api`` chat streams to Anthropic stream events.
Construction takes a long-lived :class:`raycast_api.Client` (one per
gateway — bearer + device_id are process-wide). Each
:meth:`complete` call opens one ``chat.stream`` and yields a fully
Anthropic-shaped event sequence.
gateway — bearer + device_id are process-wide) plus the in-process
MCP server map and the prefetched per-MCP tool list. Each
:meth:`complete` call opens one or more ``chat.stream`` calls — one
per agent "turn" inside the gateway-internal tool-call loop — and
yields a single Anthropic stream envelope spanning all of them.
"""
def __init__(self, client: Client) -> None:
def __init__(
self,
client: Client,
*,
mcp_servers: Mapping[str, FastMCP] | None = None,
mcp_tools: Mapping[str, list[FastMCPTool]] | None = None,
) -> None:
self._client = client
self._mcp_servers: Mapping[str, FastMCP] = mcp_servers or {}
self._mcp_tools: Mapping[str, list[FastMCPTool]] = mcp_tools or {}
# Cached per-agent catalog — agents are immutable, so a single
# render at first use covers the gateway's lifetime.
self._agent_catalog: dict[str, _AgentToolCatalog] = {}
def _catalog_for(self, agent: RaycastAgent) -> _AgentToolCatalog:
cached = self._agent_catalog.get(agent.name)
if cached is not None:
return cached
cat = _build_agent_catalog(agent, self._mcp_tools)
self._agent_catalog[agent.name] = cat
return cat
async def complete(
self,
@@ -277,7 +450,13 @@ class RaycastBackend:
raise TypeError(msg)
raycast_messages = _to_raycast_messages(messages)
tools = _build_tool_list(agent)
catalog = self._catalog_for(agent)
# Native remote tools + spliced MCP locals. ``None`` keeps the
# SDK from sending a ``tools`` field at all when neither is
# declared.
tools_arg: list[Tool | RemoteTool | str] | None = (
list(catalog.tools) if catalog.tools else None
)
# On the wire Raycast uses ``system_instructions`` as a format
# marker (``"markdown"`` for AI_CHAT, ``"plain"`` otherwise —
@@ -294,7 +473,8 @@ class RaycastBackend:
async for event in self._stream(
agent=agent,
raycast_messages=raycast_messages,
tools=tools,
tools=tools_arg,
mcp_routing=catalog.mcp_routing,
prompt_content=prompt_content,
temperature=_first_set(options.get("temperature"), agent.temperature),
reasoning_effort=_first_set(
@@ -310,6 +490,7 @@ class RaycastBackend:
agent: RaycastAgent,
raycast_messages: Sequence[RaycastMessage],
tools: list[Tool | RemoteTool | str] | None,
mcp_routing: Mapping[str, tuple[str, str]],
prompt_content: str | None,
temperature: float | None,
reasoning_effort: str | None,
@@ -319,45 +500,153 @@ class RaycastBackend:
yield build_message_start(message_id=message_id, model=agent.model)
state = _BlockState()
final_finish: str | None = None
final_usage: dict[str, int] | None = None
# ``working_messages`` is the rolling history fed back to
# Raycast as we resolve MCP tool_calls turn by turn.
working_messages = list(raycast_messages)
last_usage: dict[str, int] | None = None
last_finish: str | None = None
stream = self._client.chat.stream(
model=agent.model,
messages=list(raycast_messages),
source=agent.source,
# ``system_instructions=None`` → SDK substitutes the source
# default (``"markdown"`` / ``"plain"``). Real prompt goes
# into ``additional_system_instructions``.
additional_system_instructions=prompt_content,
user_preferences=agent.user_preferences,
tools=tools,
tool_choice=tool_choice,
temperature=temperature,
reasoning_effort=reasoning_effort,
)
for _turn in range(_MAX_TOOL_TURNS):
acc = _TurnAccumulator()
state.reset_turn_indexing()
async for chunk in stream:
for event in self._handle_chunk(chunk, state):
yield event
if chunk.finish_reason:
final_finish = chunk.finish_reason
if chunk.usage:
final_usage = chunk.usage
stream = self._client.chat.stream(
model=agent.model,
messages=working_messages,
source=agent.source,
# ``system_instructions=None`` → SDK substitutes the
# source default (``"markdown"`` / ``"plain"``). Real
# prompt goes into ``additional_system_instructions``.
additional_system_instructions=prompt_content,
user_preferences=agent.user_preferences,
tools=tools,
tool_choice=tool_choice,
temperature=temperature,
reasoning_effort=reasoning_effort,
)
# Close whatever block is still open before the message delta.
async for chunk in stream:
for event in self._handle_chunk(chunk, state, acc):
yield event
if chunk.finish_reason:
acc.finish_reason = chunk.finish_reason
if chunk.usage:
acc.usage = chunk.usage
last_finish = acc.finish_reason
if acc.usage:
last_usage = acc.usage
# Figure out which (if any) tool_calls land on our MCP
# routing table. Anything not in the table is left to bubble
# out of the envelope as a regular tool_use block — the
# Anthropic caller can then respond with a tool_result the
# usual way, and the next ``complete`` invocation will
# carry it back in.
pending_mcp_calls = [
tid for tid in acc.tool_order if acc.tool_names.get(tid) in mcp_routing
]
if not pending_mcp_calls:
break
# Close whatever block is still open before we step into
# tool execution — the next turn's chunks start a fresh
# block sequence.
if state.kind is not None:
yield build_content_block_stop(state.index)
state.kind = None
# Echo the assistant turn so Raycast sees its own reply +
# tool_calls in subsequent context.
assistant_text = "".join(acc.text_parts)
assistant_tool_calls = [
ToolCall(
id=tid,
name=acc.tool_names.get(tid, ""),
arguments="".join(acc.tool_args.get(tid, [])) or "{}",
)
for tid in acc.tool_order
]
working_messages.append(
RaycastMessage.assistant(
text=assistant_text, tool_calls=assistant_tool_calls or None
)
)
# Dispatch each MCP-routed call and append a ``tool``
# message. Calls not in the routing table get a placeholder
# error so the model can correct itself rather than the
# gateway hanging the conversation.
for tid in acc.tool_order:
tool_name = acc.tool_names.get(tid, "")
args_str = "".join(acc.tool_args.get(tid, [])) or "{}"
if tool_name in mcp_routing:
result_text = await self._dispatch_mcp_call(
tool_name=tool_name, args_json=args_str, mcp_routing=mcp_routing
)
else:
# Non-MCP tool — shouldn't really happen because we
# haven't surfaced any other locals, but defend.
result_text = f"Tool {tool_name!r} is not handled by the gateway."
working_messages.append(
RaycastMessage.tool(
tool_call_id=tid, name=tool_name, result=result_text
)
)
# Loop: re-stream with the new history.
else:
_log.warning(
"raycast tool-call loop hit %d-turn cap for agent %r; "
"closing the envelope",
_MAX_TOOL_TURNS,
agent.name,
)
# Close whatever block is still open before the final delta.
if state.kind is not None:
yield build_content_block_stop(state.index)
state.kind = None
yield build_message_delta(
stop_reason=_map_stop_reason(final_finish),
usage=final_usage,
stop_reason=_map_stop_reason(last_finish), usage=last_usage
)
yield build_message_stop()
async def _dispatch_mcp_call(
self,
*,
tool_name: str,
args_json: str,
mcp_routing: Mapping[str, tuple[str, str]],
) -> str:
"""Route one tool_call to the matching MCP and render its result.
Errors (missing server, JSON parse, tool exception) surface as
text in the ``tool`` message rather than crashing the stream —
the model gets a chance to recover or apologize.
"""
mcp_name, original_name = mcp_routing[tool_name]
server = self._mcp_servers.get(mcp_name)
if server is None:
return f"Error: MCP {mcp_name!r} is not registered."
try:
args = json.loads(args_json) if args_json else {}
except json.JSONDecodeError as exc:
return f"Error: tool arguments are not valid JSON ({exc})."
if not isinstance(args, dict):
return "Error: tool arguments must be a JSON object."
try:
result = await server.call_tool(original_name, args)
except Exception as exc: # noqa: BLE001
_log.exception(
"MCP call failed: %s.%s args=%r", mcp_name, original_name, args
)
return f"Error calling {tool_name}: {exc}"
return _render_mcp_tool_result(result)
def _handle_chunk(
self, chunk: ChatStreamChunk, state: _BlockState
self, chunk: ChatStreamChunk, state: _BlockState, acc: _TurnAccumulator
) -> Iterable[MessageStreamEvent]:
"""Translate one Raycast chunk into zero or more Anthropic events.
@@ -373,6 +662,7 @@ class RaycastBackend:
if chunk.text:
events.extend(self._ensure_kind(state, "text"))
events.append(build_text_delta(state.index, chunk.text))
acc.text_parts.append(chunk.text)
if chunk.reasoning:
events.extend(self._ensure_kind(state, "thinking"))
@@ -380,7 +670,9 @@ class RaycastBackend:
if chunk.tool_calls:
events.extend(
self._handle_tool_calls(chunk, state, is_final_summary=is_final_summary)
self._handle_tool_calls(
chunk, state, acc, is_final_summary=is_final_summary
)
)
return events
@@ -409,6 +701,7 @@ class RaycastBackend:
self,
chunk: ChatStreamChunk,
state: _BlockState,
acc: _TurnAccumulator,
*,
is_final_summary: bool,
) -> Iterable[MessageStreamEvent]:
@@ -419,6 +712,11 @@ class RaycastBackend:
by id when present, otherwise by their wire ``index`` field, and
the final-summary chunk's arguments string is dropped because the
deltas already streamed it.
Side-effects on ``acc`` mirror what gets emitted to the wire so
the gateway can rebuild a full ``ToolCall`` for the next Raycast
turn (it needs the joined ``arguments`` JSON string, which the
wire never gives us as a whole — only in fragments).
"""
events: list[MessageStreamEvent] = []
raw_tcs = chunk.raw.get("tool_calls") or []
@@ -449,20 +747,35 @@ class RaycastBackend:
state.kind = "tool_use"
block_idx = state.index
state.tool_id_to_block[tool_id] = block_idx
acc.add_tool_phase1(tool_id, tc.name or "")
events.append(
build_tool_use_block_start(
block_idx, tool_use_id=tool_id, name=tc.name or ""
)
)
if tc.arguments and not is_final_summary:
# Streaming providers (GPT) deliver arguments as deltas
# across chunks and re-state the full string in a final
# summary chunk; non-streaming-args providers (Gemini)
# only send args in the final summary. Emit args in
# both cases when this is the first appearance — final
# summary then IS the full string.
if tc.arguments:
events.append(build_input_json_delta(block_idx, tc.arguments))
acc.add_tool_args(tool_id, tc.arguments)
continue
# Existing block. Skip args on the final summary — they're a
# full restatement of what's already been delta'd.
# Existing block. Final summary chunks restate the full args
# string; if we already streamed deltas, that restatement is
# a duplicate (skip). If we streamed nothing (the streaming
# provider didn't send mid-arg deltas — Gemini path again),
# the summary IS the args — emit it once.
if is_final_summary:
if tc.arguments and not acc.tool_args.get(tool_id):
events.append(build_input_json_delta(block_idx, tc.arguments))
acc.add_tool_args(tool_id, tc.arguments)
continue
if tc.arguments:
events.append(build_input_json_delta(block_idx, tc.arguments))
acc.add_tool_args(tool_id, tc.arguments)
return events
+46 -11
View File
@@ -44,6 +44,8 @@ from beaver_gateway.settings import Settings
from beaver_gateway.storage import Database
if TYPE_CHECKING:
from fastmcp import FastMCP
from fastmcp.tools.base import Tool as FastMCPTool
from starlette.applications import Starlette
from beaver_gateway.backends.base import Backend
@@ -95,16 +97,26 @@ async def _async_main() -> None:
stack.push_async_callback(token_store.stop)
# Internal MCP URLs must exist before we construct any
# ClaudeCodeBackendAdapter — adapters bake the URLs into their
# ``BackendOptions.mcp_servers`` at construction time.
internal_app, internal_urls = _build_internal_mcp(
# ``BackendOptions.mcp_servers`` at construction time. The
# ``mcp_servers`` map is used by the Raycast backend, which
# needs in-process ``list_tools`` / ``call_tool`` access (the
# Raycast wire has no native MCP concept).
internal_app, internal_urls, mcp_servers = _build_internal_mcp(
gateway.mcps, settings=settings
)
# Prefetch tool catalogs for every MCP so RaycastAgent requests
# don't pay a per-turn list_tools roundtrip and so a broken MCP
# surfaces at startup instead of mid-conversation.
mcp_tools = await _prefetch_mcp_tools(mcp_servers)
backends: dict[str, Backend] = await _build_backends(
settings=settings,
agents=agents,
stack=stack,
mcp_internal_urls=internal_urls,
mcp_servers=mcp_servers,
mcp_tools=mcp_tools,
)
runtime = GatewayRuntime(
@@ -152,20 +164,39 @@ async def _async_main() -> None:
def _build_internal_mcp(
mcps: list[McpServerT], *, settings: Settings
) -> tuple[Starlette | None, dict[str, str]]:
"""Build the aggregator app + URL map, or return ``(None, {})``.
) -> tuple[Starlette | None, dict[str, str], dict[str, FastMCP]]:
"""Build the aggregator app + URL map + server map, or empty equivalents.
The URL map is always handed out (frontends may still introspect
``runtime.mcp_internal_urls`` even if nothing is configured); the
app is ``None`` when there are no MCPs to mount, so the caller
skips the uvicorn task entirely.
skips the uvicorn task entirely. The server map is the in-process
handle the Raycast backend needs to splice MCP tools into its
requests — empty when no MCPs are configured.
"""
if not mcps:
return None, {}
app, urls = build_internal_app(
mcps, host="127.0.0.1", port=settings.internal_mcp_port
)
return app, urls
return None, {}, {}
return build_internal_app(mcps, host="127.0.0.1", port=settings.internal_mcp_port)
async def _prefetch_mcp_tools(
servers: dict[str, FastMCP],
) -> dict[str, list[FastMCPTool]]:
"""Eagerly enumerate tools per MCP so the Raycast loop has a static catalog.
Each underlying proxy is allowed to fail independently — a broken
MCP shouldn't take down the whole gateway. The result has one entry
per MCP that responded; agents that ``expose_mcps`` a missing entry
will simply expose no tools from it (logged once per request).
"""
out: dict[str, list[FastMCPTool]] = {}
for name, server in servers.items():
try:
out[name] = list(await server.list_tools())
except Exception: # noqa: BLE001 — proxy can raise any transport error; we degrade per-MCP rather than fail the whole gateway
_log.exception("failed to list tools for MCP %r — skipping", name)
out[name] = []
return out
async def _serve_internal_mcp(app: Starlette, *, settings: Settings) -> None:
@@ -196,6 +227,8 @@ async def _build_backends(
agents: AgentRegistry,
stack: AsyncExitStack,
mcp_internal_urls: dict[str, str],
mcp_servers: dict[str, FastMCP],
mcp_tools: dict[str, list[FastMCPTool]],
) -> dict[str, Backend]:
"""Construct one backend per agent name.
@@ -216,7 +249,9 @@ async def _build_backends(
if raycast_agents:
client = await _try_open_raycast_client(settings, stack)
if client is not None:
raycast_backend = RaycastBackend(client)
raycast_backend = RaycastBackend(
client, mcp_servers=mcp_servers, mcp_tools=mcp_tools
)
for a in raycast_agents:
backends[a.name] = raycast_backend
+42 -7
View File
@@ -45,6 +45,7 @@ import uuid
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, cast
from sqlalchemy import delete
from sqlmodel import select
from beaver_gateway.storage.models import Conversation, ConversationMessage
@@ -136,7 +137,34 @@ async def load_messages(
)
result = await session.exec(stmt)
rows = result.all()
return [{"role": r.role, "content": json.loads(r.content_json)} for r in rows]
return [
{"role": r.role, "content": _sanitize_content(json.loads(r.content_json))}
for r in rows
]
def _sanitize_content(content: Any) -> Any:
"""Strip wire-illegal fields from stored Anthropic content blocks.
Older capture code emitted ``"is_error": null`` on ``tool_result``
blocks; the Anthropic API rejects null there (the field is optional
but, when present, must be boolean). We omit the key on read so
historical rows don't break continuation.
"""
if not isinstance(content, list):
return content
cleaned: list[Any] = []
for blk in content:
out_blk = blk
if (
isinstance(blk, dict)
and blk.get("type") == "tool_result"
and blk.get("is_error") is None
and "is_error" in blk
):
out_blk = {k: v for k, v in blk.items() if k != "is_error"}
cleaned.append(out_blk)
return cleaned
async def rewrite_messages(
@@ -148,13 +176,20 @@ async def rewrite_messages(
our volume; if it ever matters we can switch to soft-delete +
branch pointers.
"""
# Delete existing rows for this conversation.
existing_stmt = select(ConversationMessage).where(
ConversationMessage.conversation_id == conversation_id
# Bulk-delete and flush before inserting the new sequence: SQLAlchemy's
# unit-of-work flushes INSERTs before DELETEs by default, which would
# trip ``uq_msg_conv_seq`` when the new rows reuse the same seq numbers
# as the soon-to-be-deleted ones.
# SQLModel descriptors resolve to ColumnElement at runtime but to bare
# ``int`` in ty's stubs; the select-path at line 135 lives behind sqlmodel's
# own ``select`` overloads that hide it, but ``sqlalchemy.delete().where``
# uses the raw stubs.
await session.execute( # ty: ignore[deprecated]
delete(ConversationMessage).where(
ConversationMessage.conversation_id == conversation_id # ty: ignore[invalid-argument-type]
)
)
result = await session.exec(existing_stmt)
for row in result.all():
await session.delete(row)
await session.flush()
# Insert the new sequence.
for seq, m in enumerate(messages):
session.add(
@@ -29,10 +29,12 @@ import json
import logging
import os
import tempfile
import time
from pathlib import Path
from typing import TYPE_CHECKING, Any
import aiofile
from anthropic.types import RawContentBlockStopEvent
from fastapi import FastAPI, HTTPException, Request, status
from fastapi.responses import JSONResponse
@@ -46,7 +48,7 @@ from beaver_gateway.core.conversation_store import (
rewrite_messages,
)
from beaver_gateway.core.turn_record import TurnRecord
from beaver_gateway.frontends._accumulate import accumulate
from beaver_gateway.frontends._accumulate import StreamAccumulator
from beaver_gateway.frontends._auth import require_token
from beaver_gateway.frontends.base import Frontend
from beaver_gateway.frontends.markdown import parser, renderer
@@ -69,6 +71,14 @@ _log = logging.getLogger("beaver_gateway.frontends.markdown")
__all__ = ["MarkdownFrontend"]
# How often we re-render the assistant turn into the .md file while the
# backend stream is still open. Trades responsiveness (faster updates to
# Obsidian sync / Raycast tailers) against write amplification. Each
# ``RawContentBlockStopEvent`` also forces a flush regardless of the
# timer, so block boundaries always land in the file.
_STREAM_FLUSH_DEBOUNCE = 0.4
class MarkdownFrontend(Frontend):
"""FastAPI app behind ``POST /chat`` driven by Obsidian-vault files."""
@@ -285,21 +295,23 @@ class MarkdownFrontend(Frontend):
TurnCapture() if isinstance(backend, ClaudeCodeBackendAdapter) else None
)
kwargs: dict[str, Any] = {}
if capture is not None:
kwargs["capture"] = capture
events = backend.complete(
agent=agent, messages=outcome.messages, system=None, **kwargs
)
try:
kwargs: dict[str, Any] = {}
if capture is not None:
kwargs["capture"] = capture
events = backend.complete(
agent=agent, messages=outcome.messages, system=None, **kwargs
message = await self._stream_to_file(
events=events,
file_path=file_path,
parsed=parsed,
model=agent.model or agent.name,
filename=filename,
)
message = await accumulate(events, model=agent.model or agent.name)
except HTTPException:
raise
except Exception as exc:
_log.exception("backend failed for %s", filename)
error_block = _render_error_block(exc)
new_body = renderer.append_to_body(parsed.body, error_block)
await _write_atomic(
file_path, _reattach_frontmatter(parsed.metadata, new_body)
)
raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR, f"backend error: {exc}"
) from exc
@@ -346,6 +358,67 @@ class MarkdownFrontend(Frontend):
# ---- helpers -------------------------------------------------------
async def _stream_to_file(
self,
*,
events: Any,
file_path: Path,
parsed: parser.ParsedFile,
model: str,
filename: str,
) -> Any:
"""Drain ``events`` into a ``Message``, flushing partials to disk.
Flushes happen on each ``RawContentBlockStopEvent`` (natural
block boundary, content is markdown-consistent) and on the
``_STREAM_FLUSH_DEBOUNCE`` timer between events. The partial
write keeps the as-parsed frontmatter; the post-stream final
write in ``_write_assistant_reply`` is what stamps the refreshed
fingerprint / agent / conversation_id.
On backend exception we still flush the last partial and append
an error callout, so the human sees both what arrived and why it
stopped. The exception propagates so ``_handle_chat`` can map it
to a 500.
"""
acc = StreamAccumulator()
async def flush_partial() -> None:
partial = acc.finalize(model=model)
if not partial.content:
return
rendered = renderer.render_assistant_message(partial)
new_body = renderer.append_to_body(parsed.body, rendered)
await _write_atomic(
file_path, _reattach_frontmatter(parsed.metadata, new_body)
)
try:
last_flush = time.monotonic()
async for ev in events:
acc.feed(ev)
now = time.monotonic()
if (
isinstance(ev, RawContentBlockStopEvent)
or (now - last_flush) >= _STREAM_FLUSH_DEBOUNCE
):
await flush_partial()
last_flush = now
except Exception as exc:
_log.exception("backend failed for %s", filename)
partial = acc.finalize(model=model)
new_body = parsed.body
if partial.content:
new_body = renderer.append_to_body(
new_body, renderer.render_assistant_message(partial)
)
new_body = renderer.append_to_body(new_body, _render_error_block(exc))
await _write_atomic(
file_path, _reattach_frontmatter(parsed.metadata, new_body)
)
raise
return acc.finalize(model=model)
async def _write_assistant_reply(
self,
*,
+15 -13
View File
@@ -48,8 +48,8 @@ ALL_NAMESPACE = "all"
def build_internal_app(
mcps: Iterable[McpServerT], *, host: str, port: int
) -> tuple[Starlette, dict[str, str]]:
"""Build the aggregator ``Starlette`` app and the per-namespace URL map.
) -> tuple[Starlette, dict[str, str], dict[str, FastMCP]]:
"""Build the aggregator ``Starlette`` app, per-namespace URL map, and server map.
``host``/``port`` only flavour the URL strings handed back — actually
listening on them is the caller's job (``cli.main`` runs a uvicorn
@@ -57,21 +57,23 @@ def build_internal_app(
have to format the URLs themselves and risk drifting from the
``/mcp/<name>`` convention.
Returns a map of ``{namespace: url}`` for the per-domain endpoints
only (claude-code's MCP routing expects per-domain framing). The
``/mcp/all/`` bundle endpoint exists on the same app but is
intentionally omitted from the URL map — it's only meaningful to
external clients via the MCP frontend, not to claude-code.
Returns:
* Starlette app to serve via uvicorn.
* ``{namespace: url}`` for the per-domain endpoints (claude-code's
MCP routing expects per-domain framing). ``/mcp/all/`` is
omitted — only meaningful to external clients via the MCP
frontend, not to claude-code.
* ``{namespace: FastMCP}`` for backends that need in-process
access (Raycast doesn't natively understand MCP, so the
gateway calls ``list_tools``/``call_tool`` directly to splice
MCP tools into the Raycast wire).
"""
servers: dict[str, FastMCP] = {spec.name: _build_server(spec) for spec in mcps}
child_apps = {
name: s.http_app(transport="http", path="/")
for name, s in servers.items()
name: s.http_app(transport="http", path="/") for name, s in servers.items()
}
routes = [
Mount(f"/mcp/{name}", app=app) for name, app in child_apps.items()
]
routes = [Mount(f"/mcp/{name}", app=app) for name, app in child_apps.items()]
# /mcp/all — flat-namespace bundle. Skip when there's nothing to
# bundle so we don't pay for an empty session manager lifecycle.
@@ -102,7 +104,7 @@ def build_internal_app(
# 307 redirect from ``/mcp/<name>`` to ``/mcp/<name>/`` that
# ``Mount`` produces when a child route lives at ``/``.
urls = {name: f"http://{host}:{port}/mcp/{name}/" for name in servers}
return app, urls
return app, urls, servers
def _build_server(spec: McpServerT) -> FastMCP:
Generated
+2 -2
View File
@@ -287,7 +287,7 @@ local = [
{ name = "raycast-api", version = "0.1.0", source = { editable = "../raycast-api" } },
]
prod = [
{ name = "claude-code-api", version = "0.1.0", source = { git = "https://git.kotikot.com/beaver/claude-code-api.git#1f20cef7d49d290f2b620ebb8a7aca92cdbd0e2a" } },
{ name = "claude-code-api", version = "0.1.0", source = { git = "https://git.kotikot.com/beaver/claude-code-api.git#86d8a8f4c471605577716ab0f039c857e6261a0e" } },
{ name = "raycast-api", version = "0.1.0", source = { git = "https://git.kotikot.com/beaver/raycast-api.git#e73894c8e435da5c0709f92f69f11bcd0dab9afe" } },
]
@@ -419,7 +419,7 @@ wheels = [
[[package]]
name = "claude-code-api"
version = "0.1.0"
source = { git = "https://git.kotikot.com/beaver/claude-code-api.git#1f20cef7d49d290f2b620ebb8a7aca92cdbd0e2a" }
source = { git = "https://git.kotikot.com/beaver/claude-code-api.git#86d8a8f4c471605577716ab0f039c857e6261a0e" }
resolution-markers = [
"python_full_version >= '3.14'",
"python_full_version < '3.14'",