diff --git a/src/beaver_gateway/backends/raycast.py b/src/beaver_gateway/backends/raycast.py index f6ebac0..39ce4bd 100644 --- a/src/beaver_gateway/backends/raycast.py +++ b/src/beaver_gateway/backends/raycast.py @@ -17,12 +17,24 @@ Two halves live here: ``arguments``) — Anthropic wants one ``content_block_start`` → deltas → ``content_block_stop`` per block, so we de-duplicate the final summary against the per-delta increments already emitted. + +MCP wiring: Raycast has no native MCP concept, so when an agent declares +``expose_mcps`` we splice each MCP's tools into the wire request as +``Tool.local(name=f"{mcp}__{tool}", ...)`` and run a gateway-internal +loop. Every time the model emits a tool_call for one of those local +tools we route it back to the underlying MCP in-process, append the +result as a ``tool`` message, and re-issue the stream — all inside one +Anthropic envelope (one ``message_start`` … one ``message_stop``). The +tool_use blocks DO surface to the caller (mirrors what +``ClaudeCodeBackendAdapter`` does), but tool_results stay internal. """ from __future__ import annotations import json +import logging import uuid +from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, cast from raycast_api import Message as RaycastMessage @@ -47,6 +59,8 @@ if TYPE_CHECKING: from collections.abc import AsyncIterator, Iterable, Mapping, Sequence from anthropic.types import MessageParam + from fastmcp import FastMCP + from fastmcp.tools.base import Tool as FastMCPTool from raycast_api import ChatStreamChunk, Client from beaver_gateway.agents.base import BaseAgent @@ -58,6 +72,16 @@ else: __all__ = ["RaycastBackend"] +_log = logging.getLogger("beaver_gateway.backends.raycast") + + +# Cap on consecutive tool-call turns inside one Anthropic envelope. +# Real conversations rarely chain more than a handful; the limit only +# fires on a model that loops, and surfaces as a clean ``end_turn`` +# with an error tool_result rather than a hang. +_MAX_TOOL_TURNS = 20 + + _RAYCAST_TO_ANTHROPIC_STOP: dict[str, StopReason] = { "stop": "end_turn", "STOP": "end_turn", @@ -132,9 +156,7 @@ def _extract_tool_result_text(content: object) -> str: return json.dumps(content, separators=(",", ":"), ensure_ascii=False) -def _to_raycast_messages( - messages: Iterable[MessageParam], -) -> list[RaycastMessage]: +def _to_raycast_messages(messages: Iterable[MessageParam]) -> list[RaycastMessage]: """Convert an Anthropic message history into a Raycast one. ``tool_use_id → name`` is tracked across the iteration so that @@ -206,27 +228,143 @@ def _to_raycast_messages( ) out.append( RaycastMessage.assistant( - text="\n".join(text_parts), - tool_calls=tool_calls or None, + text="\n".join(text_parts), tool_calls=tool_calls or None ) ) return out -def _build_tool_list( - agent: RaycastAgent, -) -> list[Tool | RemoteTool | str] | None: - """Wrap each native-tool name as a Raycast remote tool. +@dataclass(frozen=True, slots=True) +class _AgentToolCatalog: + """Per-agent MCP tool catalog + routing map. - Phase 1.2 scope: only the three model-agnostic remote tools - (`web_search`, `search_images`, `read_page`). Client-defined tools - coming through the Anthropic body's ``tools`` field stay out — that's - Phase 1.3+ alongside ``accept_client_tools``. + Cached on :class:`RaycastBackend` keyed by agent name. ``tools`` is + the fully-rendered ``tools`` argument for ``chat.stream`` (local MCP + tools first, native remote tools last). ``mcp_routing`` maps each + wire-name we registered back to the underlying ``(mcp, original)`` + pair so we can dispatch a tool_call without re-parsing the prefix. """ - if not agent.available_native_tools: - return None - return [Tool.remote(name) for name in agent.available_native_tools] + + tools: tuple[Tool, ...] + mcp_routing: Mapping[str, tuple[str, str]] + + +_EMPTY_SCHEMA: dict[str, Any] = {"type": "object", "properties": {}} + + +def _sanitize_schema(schema: Mapping[str, Any] | None) -> dict[str, Any]: + """Strip JSON-Schema meta keys that Raycast/Gemini's tool API rejects. + + MCP servers ship draft-07 schemas including ``$schema``/``$id``/etc. + OpenAI-style function-calling expects a leaner subset — leaving the + meta keys in causes the upstream provider to reject the entire + request with an opaque ``unknown_api_error``. We drop the + document-level meta keys and keep the structural keys + (``type``/``properties``/``required``/...). + """ + if not schema: + return dict(_EMPTY_SCHEMA) + out: dict[str, Any] = {} + for key, value in schema.items(): + if key.startswith("$"): + continue + out[key] = value + if "type" not in out: + out["type"] = "object" + if out.get("type") == "object" and "properties" not in out: + out["properties"] = {} + return out + + +def _build_agent_catalog( + agent: RaycastAgent, mcp_tools: Mapping[str, list[FastMCPTool]] +) -> _AgentToolCatalog: + """Resolve ``agent.expose_mcps`` against the prefetched MCP tool lists. + + Each exposed MCP contributes one ``Tool.local`` per tool, named + ``{mcp}__{tool}`` (mirrors claude-code's wire convention, so a model + that has seen one style sees a familiar one here). ``ExposedMcp.tools`` + optionally filters to a subset by original (unprefixed) name. + + A missing MCP entry (broken at startup, see + ``cli._prefetch_mcp_tools``) silently contributes nothing — surfaces + in logs at start, doesn't crash request-time. + """ + routing: dict[str, tuple[str, str]] = {} + local_tools: list[Tool] = [] + for em in agent.expose_mcps: + for mt in mcp_tools.get(em.name, []): + if em.tools is not None and mt.name not in em.tools: + continue + wire_name = f"{em.name}__{mt.name}" + routing[wire_name] = (em.name, mt.name) + local_tools.append( + Tool.local( + name=wire_name, + description=mt.description or "", + parameters=_sanitize_schema(mt.parameters), + ) + ) + remote_tools = [Tool.remote(n) for n in agent.available_native_tools] + return _AgentToolCatalog( + tools=tuple(local_tools + remote_tools), mcp_routing=routing + ) + + +def _render_mcp_tool_result(result: Any) -> str: + """Flatten a ``fastmcp.ToolResult`` into the text Raycast carries. + + Raycast ``tool`` messages only hold a string. MCP results can carry + text, images, structured payloads, etc. — we keep text blocks + verbatim, JSON-encode anything else, and fall back to + ``structured_content`` if the content list is empty. + """ + parts: list[str] = [] + for block in getattr(result, "content", None) or []: + text = getattr(block, "text", None) + if isinstance(text, str): + parts.append(text) + continue + dump = getattr(block, "model_dump", None) + if callable(dump): + parts.append(json.dumps(dump(), ensure_ascii=False)) + else: + parts.append(str(block)) + if not parts: + structured = getattr(result, "structured_content", None) + if structured is not None: + parts.append(json.dumps(structured, ensure_ascii=False)) + return "\n".join(parts) + + +@dataclass(slots=True) +class _TurnAccumulator: + """Pieces collected from one Raycast stream needed to feed the next. + + ``text_parts`` is the assistant's text portion (joined and replayed + on the next ``chat.stream`` call so the model sees its own prior + reply). ``tool_*`` mirrors the same data the wire-state already + emitted, but kept in raw form so we can build ``ToolCall`` objects + and dispatch them through the gateway. + """ + + text_parts: list[str] = field(default_factory=list) + tool_order: list[str] = field(default_factory=list) + tool_names: dict[str, str] = field(default_factory=dict) + tool_args: dict[str, list[str]] = field(default_factory=dict) + finish_reason: str | None = None + usage: dict[str, int] | None = None + + def add_tool_phase1(self, tool_id: str, name: str) -> None: + if tool_id in self.tool_names: + return + self.tool_order.append(tool_id) + self.tool_names[tool_id] = name + self.tool_args.setdefault(tool_id, []) + + def add_tool_args(self, tool_id: str, fragment: str) -> None: + self.tool_args.setdefault(tool_id, []).append(fragment) class _BlockState: @@ -241,6 +379,10 @@ class _BlockState: in phase 1 only and ``index`` in every chunk: ``tool_id_to_block`` keys by Raycast tool-call id, ``tool_idx_to_id`` resolves chunk indices back to that id for the no-id delta chunks. + + State spans the entire Anthropic envelope (all Raycast turns) so + block indices grow monotonically across the tool-call loop. The + per-turn ``_TurnAccumulator`` carries the volatile bits. """ __slots__ = ("index", "kind", "tool_id_to_block", "tool_idx_to_id") @@ -251,18 +393,49 @@ class _BlockState: self.tool_id_to_block: dict[str, int] = {} self.tool_idx_to_id: dict[int, str] = {} + def reset_turn_indexing(self) -> None: + """Drop chunk-index → id routing between turns. + + Raycast chunk ``index`` fields restart at 0 inside each fresh + ``chat.stream``; reusing the previous turn's table would alias + new tool_calls onto old ids. ``tool_id_to_block`` we keep — ids + are stream-unique and new ones won't collide. + """ + self.tool_idx_to_id = {} + class RaycastBackend: """Adapter from ``raycast-api`` chat streams to Anthropic stream events. Construction takes a long-lived :class:`raycast_api.Client` (one per - gateway — bearer + device_id are process-wide). Each - :meth:`complete` call opens one ``chat.stream`` and yields a fully - Anthropic-shaped event sequence. + gateway — bearer + device_id are process-wide) plus the in-process + MCP server map and the prefetched per-MCP tool list. Each + :meth:`complete` call opens one or more ``chat.stream`` calls — one + per agent "turn" inside the gateway-internal tool-call loop — and + yields a single Anthropic stream envelope spanning all of them. """ - def __init__(self, client: Client) -> None: + def __init__( + self, + client: Client, + *, + mcp_servers: Mapping[str, FastMCP] | None = None, + mcp_tools: Mapping[str, list[FastMCPTool]] | None = None, + ) -> None: self._client = client + self._mcp_servers: Mapping[str, FastMCP] = mcp_servers or {} + self._mcp_tools: Mapping[str, list[FastMCPTool]] = mcp_tools or {} + # Cached per-agent catalog — agents are immutable, so a single + # render at first use covers the gateway's lifetime. + self._agent_catalog: dict[str, _AgentToolCatalog] = {} + + def _catalog_for(self, agent: RaycastAgent) -> _AgentToolCatalog: + cached = self._agent_catalog.get(agent.name) + if cached is not None: + return cached + cat = _build_agent_catalog(agent, self._mcp_tools) + self._agent_catalog[agent.name] = cat + return cat async def complete( self, @@ -277,7 +450,13 @@ class RaycastBackend: raise TypeError(msg) raycast_messages = _to_raycast_messages(messages) - tools = _build_tool_list(agent) + catalog = self._catalog_for(agent) + # Native remote tools + spliced MCP locals. ``None`` keeps the + # SDK from sending a ``tools`` field at all when neither is + # declared. + tools_arg: list[Tool | RemoteTool | str] | None = ( + list(catalog.tools) if catalog.tools else None + ) # On the wire Raycast uses ``system_instructions`` as a format # marker (``"markdown"`` for AI_CHAT, ``"plain"`` otherwise — @@ -294,7 +473,8 @@ class RaycastBackend: async for event in self._stream( agent=agent, raycast_messages=raycast_messages, - tools=tools, + tools=tools_arg, + mcp_routing=catalog.mcp_routing, prompt_content=prompt_content, temperature=_first_set(options.get("temperature"), agent.temperature), reasoning_effort=_first_set( @@ -310,6 +490,7 @@ class RaycastBackend: agent: RaycastAgent, raycast_messages: Sequence[RaycastMessage], tools: list[Tool | RemoteTool | str] | None, + mcp_routing: Mapping[str, tuple[str, str]], prompt_content: str | None, temperature: float | None, reasoning_effort: str | None, @@ -319,45 +500,153 @@ class RaycastBackend: yield build_message_start(message_id=message_id, model=agent.model) state = _BlockState() - final_finish: str | None = None - final_usage: dict[str, int] | None = None + # ``working_messages`` is the rolling history fed back to + # Raycast as we resolve MCP tool_calls turn by turn. + working_messages = list(raycast_messages) + last_usage: dict[str, int] | None = None + last_finish: str | None = None - stream = self._client.chat.stream( - model=agent.model, - messages=list(raycast_messages), - source=agent.source, - # ``system_instructions=None`` → SDK substitutes the source - # default (``"markdown"`` / ``"plain"``). Real prompt goes - # into ``additional_system_instructions``. - additional_system_instructions=prompt_content, - user_preferences=agent.user_preferences, - tools=tools, - tool_choice=tool_choice, - temperature=temperature, - reasoning_effort=reasoning_effort, - ) + for _turn in range(_MAX_TOOL_TURNS): + acc = _TurnAccumulator() + state.reset_turn_indexing() - async for chunk in stream: - for event in self._handle_chunk(chunk, state): - yield event - if chunk.finish_reason: - final_finish = chunk.finish_reason - if chunk.usage: - final_usage = chunk.usage + stream = self._client.chat.stream( + model=agent.model, + messages=working_messages, + source=agent.source, + # ``system_instructions=None`` → SDK substitutes the + # source default (``"markdown"`` / ``"plain"``). Real + # prompt goes into ``additional_system_instructions``. + additional_system_instructions=prompt_content, + user_preferences=agent.user_preferences, + tools=tools, + tool_choice=tool_choice, + temperature=temperature, + reasoning_effort=reasoning_effort, + ) - # Close whatever block is still open before the message delta. + async for chunk in stream: + for event in self._handle_chunk(chunk, state, acc): + yield event + if chunk.finish_reason: + acc.finish_reason = chunk.finish_reason + if chunk.usage: + acc.usage = chunk.usage + + last_finish = acc.finish_reason + if acc.usage: + last_usage = acc.usage + + # Figure out which (if any) tool_calls land on our MCP + # routing table. Anything not in the table is left to bubble + # out of the envelope as a regular tool_use block — the + # Anthropic caller can then respond with a tool_result the + # usual way, and the next ``complete`` invocation will + # carry it back in. + pending_mcp_calls = [ + tid for tid in acc.tool_order if acc.tool_names.get(tid) in mcp_routing + ] + if not pending_mcp_calls: + break + + # Close whatever block is still open before we step into + # tool execution — the next turn's chunks start a fresh + # block sequence. + if state.kind is not None: + yield build_content_block_stop(state.index) + state.kind = None + + # Echo the assistant turn so Raycast sees its own reply + + # tool_calls in subsequent context. + assistant_text = "".join(acc.text_parts) + assistant_tool_calls = [ + ToolCall( + id=tid, + name=acc.tool_names.get(tid, ""), + arguments="".join(acc.tool_args.get(tid, [])) or "{}", + ) + for tid in acc.tool_order + ] + working_messages.append( + RaycastMessage.assistant( + text=assistant_text, tool_calls=assistant_tool_calls or None + ) + ) + + # Dispatch each MCP-routed call and append a ``tool`` + # message. Calls not in the routing table get a placeholder + # error so the model can correct itself rather than the + # gateway hanging the conversation. + for tid in acc.tool_order: + tool_name = acc.tool_names.get(tid, "") + args_str = "".join(acc.tool_args.get(tid, [])) or "{}" + if tool_name in mcp_routing: + result_text = await self._dispatch_mcp_call( + tool_name=tool_name, args_json=args_str, mcp_routing=mcp_routing + ) + else: + # Non-MCP tool — shouldn't really happen because we + # haven't surfaced any other locals, but defend. + result_text = f"Tool {tool_name!r} is not handled by the gateway." + working_messages.append( + RaycastMessage.tool( + tool_call_id=tid, name=tool_name, result=result_text + ) + ) + + # Loop: re-stream with the new history. + else: + _log.warning( + "raycast tool-call loop hit %d-turn cap for agent %r; " + "closing the envelope", + _MAX_TOOL_TURNS, + agent.name, + ) + + # Close whatever block is still open before the final delta. if state.kind is not None: yield build_content_block_stop(state.index) state.kind = None yield build_message_delta( - stop_reason=_map_stop_reason(final_finish), - usage=final_usage, + stop_reason=_map_stop_reason(last_finish), usage=last_usage ) yield build_message_stop() + async def _dispatch_mcp_call( + self, + *, + tool_name: str, + args_json: str, + mcp_routing: Mapping[str, tuple[str, str]], + ) -> str: + """Route one tool_call to the matching MCP and render its result. + + Errors (missing server, JSON parse, tool exception) surface as + text in the ``tool`` message rather than crashing the stream — + the model gets a chance to recover or apologize. + """ + mcp_name, original_name = mcp_routing[tool_name] + server = self._mcp_servers.get(mcp_name) + if server is None: + return f"Error: MCP {mcp_name!r} is not registered." + try: + args = json.loads(args_json) if args_json else {} + except json.JSONDecodeError as exc: + return f"Error: tool arguments are not valid JSON ({exc})." + if not isinstance(args, dict): + return "Error: tool arguments must be a JSON object." + try: + result = await server.call_tool(original_name, args) + except Exception as exc: # noqa: BLE001 + _log.exception( + "MCP call failed: %s.%s args=%r", mcp_name, original_name, args + ) + return f"Error calling {tool_name}: {exc}" + return _render_mcp_tool_result(result) + def _handle_chunk( - self, chunk: ChatStreamChunk, state: _BlockState + self, chunk: ChatStreamChunk, state: _BlockState, acc: _TurnAccumulator ) -> Iterable[MessageStreamEvent]: """Translate one Raycast chunk into zero or more Anthropic events. @@ -373,6 +662,7 @@ class RaycastBackend: if chunk.text: events.extend(self._ensure_kind(state, "text")) events.append(build_text_delta(state.index, chunk.text)) + acc.text_parts.append(chunk.text) if chunk.reasoning: events.extend(self._ensure_kind(state, "thinking")) @@ -380,7 +670,9 @@ class RaycastBackend: if chunk.tool_calls: events.extend( - self._handle_tool_calls(chunk, state, is_final_summary=is_final_summary) + self._handle_tool_calls( + chunk, state, acc, is_final_summary=is_final_summary + ) ) return events @@ -409,6 +701,7 @@ class RaycastBackend: self, chunk: ChatStreamChunk, state: _BlockState, + acc: _TurnAccumulator, *, is_final_summary: bool, ) -> Iterable[MessageStreamEvent]: @@ -419,6 +712,11 @@ class RaycastBackend: by id when present, otherwise by their wire ``index`` field, and the final-summary chunk's arguments string is dropped because the deltas already streamed it. + + Side-effects on ``acc`` mirror what gets emitted to the wire so + the gateway can rebuild a full ``ToolCall`` for the next Raycast + turn (it needs the joined ``arguments`` JSON string, which the + wire never gives us as a whole — only in fragments). """ events: list[MessageStreamEvent] = [] raw_tcs = chunk.raw.get("tool_calls") or [] @@ -449,20 +747,35 @@ class RaycastBackend: state.kind = "tool_use" block_idx = state.index state.tool_id_to_block[tool_id] = block_idx + acc.add_tool_phase1(tool_id, tc.name or "") events.append( build_tool_use_block_start( block_idx, tool_use_id=tool_id, name=tc.name or "" ) ) - if tc.arguments and not is_final_summary: + # Streaming providers (GPT) deliver arguments as deltas + # across chunks and re-state the full string in a final + # summary chunk; non-streaming-args providers (Gemini) + # only send args in the final summary. Emit args in + # both cases when this is the first appearance — final + # summary then IS the full string. + if tc.arguments: events.append(build_input_json_delta(block_idx, tc.arguments)) + acc.add_tool_args(tool_id, tc.arguments) continue - # Existing block. Skip args on the final summary — they're a - # full restatement of what's already been delta'd. + # Existing block. Final summary chunks restate the full args + # string; if we already streamed deltas, that restatement is + # a duplicate (skip). If we streamed nothing (the streaming + # provider didn't send mid-arg deltas — Gemini path again), + # the summary IS the args — emit it once. if is_final_summary: + if tc.arguments and not acc.tool_args.get(tool_id): + events.append(build_input_json_delta(block_idx, tc.arguments)) + acc.add_tool_args(tool_id, tc.arguments) continue if tc.arguments: events.append(build_input_json_delta(block_idx, tc.arguments)) + acc.add_tool_args(tool_id, tc.arguments) return events diff --git a/src/beaver_gateway/cli.py b/src/beaver_gateway/cli.py index f2b1aad..c74ea4d 100644 --- a/src/beaver_gateway/cli.py +++ b/src/beaver_gateway/cli.py @@ -44,6 +44,8 @@ from beaver_gateway.settings import Settings from beaver_gateway.storage import Database if TYPE_CHECKING: + from fastmcp import FastMCP + from fastmcp.tools.base import Tool as FastMCPTool from starlette.applications import Starlette from beaver_gateway.backends.base import Backend @@ -95,16 +97,26 @@ async def _async_main() -> None: stack.push_async_callback(token_store.stop) # Internal MCP URLs must exist before we construct any # ClaudeCodeBackendAdapter — adapters bake the URLs into their - # ``BackendOptions.mcp_servers`` at construction time. - internal_app, internal_urls = _build_internal_mcp( + # ``BackendOptions.mcp_servers`` at construction time. The + # ``mcp_servers`` map is used by the Raycast backend, which + # needs in-process ``list_tools`` / ``call_tool`` access (the + # Raycast wire has no native MCP concept). + internal_app, internal_urls, mcp_servers = _build_internal_mcp( gateway.mcps, settings=settings ) + # Prefetch tool catalogs for every MCP so RaycastAgent requests + # don't pay a per-turn list_tools roundtrip and so a broken MCP + # surfaces at startup instead of mid-conversation. + mcp_tools = await _prefetch_mcp_tools(mcp_servers) + backends: dict[str, Backend] = await _build_backends( settings=settings, agents=agents, stack=stack, mcp_internal_urls=internal_urls, + mcp_servers=mcp_servers, + mcp_tools=mcp_tools, ) runtime = GatewayRuntime( @@ -152,20 +164,39 @@ async def _async_main() -> None: def _build_internal_mcp( mcps: list[McpServerT], *, settings: Settings -) -> tuple[Starlette | None, dict[str, str]]: - """Build the aggregator app + URL map, or return ``(None, {})``. +) -> tuple[Starlette | None, dict[str, str], dict[str, FastMCP]]: + """Build the aggregator app + URL map + server map, or empty equivalents. The URL map is always handed out (frontends may still introspect ``runtime.mcp_internal_urls`` even if nothing is configured); the app is ``None`` when there are no MCPs to mount, so the caller - skips the uvicorn task entirely. + skips the uvicorn task entirely. The server map is the in-process + handle the Raycast backend needs to splice MCP tools into its + requests — empty when no MCPs are configured. """ if not mcps: - return None, {} - app, urls = build_internal_app( - mcps, host="127.0.0.1", port=settings.internal_mcp_port - ) - return app, urls + return None, {}, {} + return build_internal_app(mcps, host="127.0.0.1", port=settings.internal_mcp_port) + + +async def _prefetch_mcp_tools( + servers: dict[str, FastMCP], +) -> dict[str, list[FastMCPTool]]: + """Eagerly enumerate tools per MCP so the Raycast loop has a static catalog. + + Each underlying proxy is allowed to fail independently — a broken + MCP shouldn't take down the whole gateway. The result has one entry + per MCP that responded; agents that ``expose_mcps`` a missing entry + will simply expose no tools from it (logged once per request). + """ + out: dict[str, list[FastMCPTool]] = {} + for name, server in servers.items(): + try: + out[name] = list(await server.list_tools()) + except Exception: # noqa: BLE001 — proxy can raise any transport error; we degrade per-MCP rather than fail the whole gateway + _log.exception("failed to list tools for MCP %r — skipping", name) + out[name] = [] + return out async def _serve_internal_mcp(app: Starlette, *, settings: Settings) -> None: @@ -196,6 +227,8 @@ async def _build_backends( agents: AgentRegistry, stack: AsyncExitStack, mcp_internal_urls: dict[str, str], + mcp_servers: dict[str, FastMCP], + mcp_tools: dict[str, list[FastMCPTool]], ) -> dict[str, Backend]: """Construct one backend per agent name. @@ -216,7 +249,9 @@ async def _build_backends( if raycast_agents: client = await _try_open_raycast_client(settings, stack) if client is not None: - raycast_backend = RaycastBackend(client) + raycast_backend = RaycastBackend( + client, mcp_servers=mcp_servers, mcp_tools=mcp_tools + ) for a in raycast_agents: backends[a.name] = raycast_backend diff --git a/src/beaver_gateway/core/conversation_store.py b/src/beaver_gateway/core/conversation_store.py index 37e7b9a..4e0b54b 100644 --- a/src/beaver_gateway/core/conversation_store.py +++ b/src/beaver_gateway/core/conversation_store.py @@ -45,6 +45,7 @@ import uuid from dataclasses import dataclass from typing import TYPE_CHECKING, Any, cast +from sqlalchemy import delete from sqlmodel import select from beaver_gateway.storage.models import Conversation, ConversationMessage @@ -136,7 +137,34 @@ async def load_messages( ) result = await session.exec(stmt) rows = result.all() - return [{"role": r.role, "content": json.loads(r.content_json)} for r in rows] + return [ + {"role": r.role, "content": _sanitize_content(json.loads(r.content_json))} + for r in rows + ] + + +def _sanitize_content(content: Any) -> Any: + """Strip wire-illegal fields from stored Anthropic content blocks. + + Older capture code emitted ``"is_error": null`` on ``tool_result`` + blocks; the Anthropic API rejects null there (the field is optional + but, when present, must be boolean). We omit the key on read so + historical rows don't break continuation. + """ + if not isinstance(content, list): + return content + cleaned: list[Any] = [] + for blk in content: + out_blk = blk + if ( + isinstance(blk, dict) + and blk.get("type") == "tool_result" + and blk.get("is_error") is None + and "is_error" in blk + ): + out_blk = {k: v for k, v in blk.items() if k != "is_error"} + cleaned.append(out_blk) + return cleaned async def rewrite_messages( @@ -148,13 +176,20 @@ async def rewrite_messages( our volume; if it ever matters we can switch to soft-delete + branch pointers. """ - # Delete existing rows for this conversation. - existing_stmt = select(ConversationMessage).where( - ConversationMessage.conversation_id == conversation_id + # Bulk-delete and flush before inserting the new sequence: SQLAlchemy's + # unit-of-work flushes INSERTs before DELETEs by default, which would + # trip ``uq_msg_conv_seq`` when the new rows reuse the same seq numbers + # as the soon-to-be-deleted ones. + # SQLModel descriptors resolve to ColumnElement at runtime but to bare + # ``int`` in ty's stubs; the select-path at line 135 lives behind sqlmodel's + # own ``select`` overloads that hide it, but ``sqlalchemy.delete().where`` + # uses the raw stubs. + await session.execute( # ty: ignore[deprecated] + delete(ConversationMessage).where( + ConversationMessage.conversation_id == conversation_id # ty: ignore[invalid-argument-type] + ) ) - result = await session.exec(existing_stmt) - for row in result.all(): - await session.delete(row) + await session.flush() # Insert the new sequence. for seq, m in enumerate(messages): session.add( diff --git a/src/beaver_gateway/frontends/markdown/frontend.py b/src/beaver_gateway/frontends/markdown/frontend.py index c1d4b3c..51d3fb7 100644 --- a/src/beaver_gateway/frontends/markdown/frontend.py +++ b/src/beaver_gateway/frontends/markdown/frontend.py @@ -29,10 +29,12 @@ import json import logging import os import tempfile +import time from pathlib import Path from typing import TYPE_CHECKING, Any import aiofile +from anthropic.types import RawContentBlockStopEvent from fastapi import FastAPI, HTTPException, Request, status from fastapi.responses import JSONResponse @@ -46,7 +48,7 @@ from beaver_gateway.core.conversation_store import ( rewrite_messages, ) from beaver_gateway.core.turn_record import TurnRecord -from beaver_gateway.frontends._accumulate import accumulate +from beaver_gateway.frontends._accumulate import StreamAccumulator from beaver_gateway.frontends._auth import require_token from beaver_gateway.frontends.base import Frontend from beaver_gateway.frontends.markdown import parser, renderer @@ -69,6 +71,14 @@ _log = logging.getLogger("beaver_gateway.frontends.markdown") __all__ = ["MarkdownFrontend"] +# How often we re-render the assistant turn into the .md file while the +# backend stream is still open. Trades responsiveness (faster updates to +# Obsidian sync / Raycast tailers) against write amplification. Each +# ``RawContentBlockStopEvent`` also forces a flush regardless of the +# timer, so block boundaries always land in the file. +_STREAM_FLUSH_DEBOUNCE = 0.4 + + class MarkdownFrontend(Frontend): """FastAPI app behind ``POST /chat`` driven by Obsidian-vault files.""" @@ -285,21 +295,23 @@ class MarkdownFrontend(Frontend): TurnCapture() if isinstance(backend, ClaudeCodeBackendAdapter) else None ) + kwargs: dict[str, Any] = {} + if capture is not None: + kwargs["capture"] = capture + events = backend.complete( + agent=agent, messages=outcome.messages, system=None, **kwargs + ) try: - kwargs: dict[str, Any] = {} - if capture is not None: - kwargs["capture"] = capture - events = backend.complete( - agent=agent, messages=outcome.messages, system=None, **kwargs + message = await self._stream_to_file( + events=events, + file_path=file_path, + parsed=parsed, + model=agent.model or agent.name, + filename=filename, ) - message = await accumulate(events, model=agent.model or agent.name) + except HTTPException: + raise except Exception as exc: - _log.exception("backend failed for %s", filename) - error_block = _render_error_block(exc) - new_body = renderer.append_to_body(parsed.body, error_block) - await _write_atomic( - file_path, _reattach_frontmatter(parsed.metadata, new_body) - ) raise HTTPException( status.HTTP_500_INTERNAL_SERVER_ERROR, f"backend error: {exc}" ) from exc @@ -346,6 +358,67 @@ class MarkdownFrontend(Frontend): # ---- helpers ------------------------------------------------------- + async def _stream_to_file( + self, + *, + events: Any, + file_path: Path, + parsed: parser.ParsedFile, + model: str, + filename: str, + ) -> Any: + """Drain ``events`` into a ``Message``, flushing partials to disk. + + Flushes happen on each ``RawContentBlockStopEvent`` (natural + block boundary, content is markdown-consistent) and on the + ``_STREAM_FLUSH_DEBOUNCE`` timer between events. The partial + write keeps the as-parsed frontmatter; the post-stream final + write in ``_write_assistant_reply`` is what stamps the refreshed + fingerprint / agent / conversation_id. + + On backend exception we still flush the last partial and append + an error callout, so the human sees both what arrived and why it + stopped. The exception propagates so ``_handle_chat`` can map it + to a 500. + """ + acc = StreamAccumulator() + + async def flush_partial() -> None: + partial = acc.finalize(model=model) + if not partial.content: + return + rendered = renderer.render_assistant_message(partial) + new_body = renderer.append_to_body(parsed.body, rendered) + await _write_atomic( + file_path, _reattach_frontmatter(parsed.metadata, new_body) + ) + + try: + last_flush = time.monotonic() + async for ev in events: + acc.feed(ev) + now = time.monotonic() + if ( + isinstance(ev, RawContentBlockStopEvent) + or (now - last_flush) >= _STREAM_FLUSH_DEBOUNCE + ): + await flush_partial() + last_flush = now + except Exception as exc: + _log.exception("backend failed for %s", filename) + partial = acc.finalize(model=model) + new_body = parsed.body + if partial.content: + new_body = renderer.append_to_body( + new_body, renderer.render_assistant_message(partial) + ) + new_body = renderer.append_to_body(new_body, _render_error_block(exc)) + await _write_atomic( + file_path, _reattach_frontmatter(parsed.metadata, new_body) + ) + raise + return acc.finalize(model=model) + async def _write_assistant_reply( self, *, diff --git a/src/beaver_gateway/mcp/internal_app.py b/src/beaver_gateway/mcp/internal_app.py index 294a7f4..faf639f 100644 --- a/src/beaver_gateway/mcp/internal_app.py +++ b/src/beaver_gateway/mcp/internal_app.py @@ -48,8 +48,8 @@ ALL_NAMESPACE = "all" def build_internal_app( mcps: Iterable[McpServerT], *, host: str, port: int -) -> tuple[Starlette, dict[str, str]]: - """Build the aggregator ``Starlette`` app and the per-namespace URL map. +) -> tuple[Starlette, dict[str, str], dict[str, FastMCP]]: + """Build the aggregator ``Starlette`` app, per-namespace URL map, and server map. ``host``/``port`` only flavour the URL strings handed back — actually listening on them is the caller's job (``cli.main`` runs a uvicorn @@ -57,21 +57,23 @@ def build_internal_app( have to format the URLs themselves and risk drifting from the ``/mcp/`` convention. - Returns a map of ``{namespace: url}`` for the per-domain endpoints - only (claude-code's MCP routing expects per-domain framing). The - ``/mcp/all/`` bundle endpoint exists on the same app but is - intentionally omitted from the URL map — it's only meaningful to - external clients via the MCP frontend, not to claude-code. + Returns: + * Starlette app to serve via uvicorn. + * ``{namespace: url}`` for the per-domain endpoints (claude-code's + MCP routing expects per-domain framing). ``/mcp/all/`` is + omitted — only meaningful to external clients via the MCP + frontend, not to claude-code. + * ``{namespace: FastMCP}`` for backends that need in-process + access (Raycast doesn't natively understand MCP, so the + gateway calls ``list_tools``/``call_tool`` directly to splice + MCP tools into the Raycast wire). """ servers: dict[str, FastMCP] = {spec.name: _build_server(spec) for spec in mcps} child_apps = { - name: s.http_app(transport="http", path="/") - for name, s in servers.items() + name: s.http_app(transport="http", path="/") for name, s in servers.items() } - routes = [ - Mount(f"/mcp/{name}", app=app) for name, app in child_apps.items() - ] + routes = [Mount(f"/mcp/{name}", app=app) for name, app in child_apps.items()] # /mcp/all — flat-namespace bundle. Skip when there's nothing to # bundle so we don't pay for an empty session manager lifecycle. @@ -102,7 +104,7 @@ def build_internal_app( # 307 redirect from ``/mcp/`` to ``/mcp//`` that # ``Mount`` produces when a child route lives at ``/``. urls = {name: f"http://{host}:{port}/mcp/{name}/" for name in servers} - return app, urls + return app, urls, servers def _build_server(spec: McpServerT) -> FastMCP: diff --git a/uv.lock b/uv.lock index 29e7262..cbab96b 100644 --- a/uv.lock +++ b/uv.lock @@ -287,7 +287,7 @@ local = [ { name = "raycast-api", version = "0.1.0", source = { editable = "../raycast-api" } }, ] prod = [ - { name = "claude-code-api", version = "0.1.0", source = { git = "https://git.kotikot.com/beaver/claude-code-api.git#1f20cef7d49d290f2b620ebb8a7aca92cdbd0e2a" } }, + { name = "claude-code-api", version = "0.1.0", source = { git = "https://git.kotikot.com/beaver/claude-code-api.git#86d8a8f4c471605577716ab0f039c857e6261a0e" } }, { name = "raycast-api", version = "0.1.0", source = { git = "https://git.kotikot.com/beaver/raycast-api.git#e73894c8e435da5c0709f92f69f11bcd0dab9afe" } }, ] @@ -419,7 +419,7 @@ wheels = [ [[package]] name = "claude-code-api" version = "0.1.0" -source = { git = "https://git.kotikot.com/beaver/claude-code-api.git#1f20cef7d49d290f2b620ebb8a7aca92cdbd0e2a" } +source = { git = "https://git.kotikot.com/beaver/claude-code-api.git#86d8a8f4c471605577716ab0f039c857e6261a0e" } resolution-markers = [ "python_full_version >= '3.14'", "python_full_version < '3.14'",