diff --git a/src/beaver_gateway/frontends/markdown/frontend.py b/src/beaver_gateway/frontends/markdown/frontend.py index 51d3fb7..b5657e3 100644 --- a/src/beaver_gateway/frontends/markdown/frontend.py +++ b/src/beaver_gateway/frontends/markdown/frontend.py @@ -36,7 +36,8 @@ from typing import TYPE_CHECKING, Any import aiofile from anthropic.types import RawContentBlockStopEvent from fastapi import FastAPI, HTTPException, Request, status -from fastapi.responses import JSONResponse +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse, StreamingResponse from beaver_gateway.backends.claude_code import ClaudeCodeBackendAdapter, TurnCapture from beaver_gateway.core import audit @@ -58,7 +59,7 @@ from beaver_gateway.frontends.markdown.crossfront import ( ) if TYPE_CHECKING: - from collections.abc import Callable + from collections.abc import AsyncIterator, Callable from anthropic.types import MessageParam @@ -78,6 +79,13 @@ __all__ = ["MarkdownFrontend"] # timer, so block boundaries always land in the file. _STREAM_FLUSH_DEBOUNCE = 0.4 +# Debounce for the SSE ``/chat/stream`` path. Network IO is cheaper than +# atomic file rewrites, so we send updates more frequently — the client +# wants the lowest possible latency and we control the renderer on the +# other end (the Obsidian plugin splices deltas into the editor, no +# disk round-trip). +_SSE_FLUSH_DEBOUNCE = 0.1 + class MarkdownFrontend(Frontend): """FastAPI app behind ``POST /chat`` driven by Obsidian-vault files.""" @@ -148,6 +156,19 @@ class MarkdownFrontend(Frontend): def _build_app(self, runtime: GatewayRuntime) -> FastAPI: app = FastAPI(title="beaver-gateway / Markdown") + # ``/chat/stream`` is consumed via ``fetch`` from the Obsidian + # plugin (``requestUrl`` can't read a body incrementally), and + # ``fetch`` is subject to CORS. Auth is bearer-token so we don't + # need credentialed mode; allow any origin and the standard + # methods/headers. 
The other endpoints are happy to ride along. + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=False, + allow_methods=["*"], + allow_headers=["*"], + ) + @app.get("/healthz") async def healthz() -> dict[str, str]: return {"status": "ok"} @@ -201,6 +222,73 @@ class MarkdownFrontend(Frontend): finally: self._busy.discard(file_path) + @app.post("/chat/stream") + async def chat_stream(request: Request) -> Any: + # Same contract as ``/chat`` (bearer auth, identical body), + # but the response is ``text/event-stream`` and intermediate + # rendered states are pushed as ``delta`` events. The + # gateway-side disk write only happens once, at end of turn, + # so streaming consumers (Obsidian plugin) and Obsidian Sync + # don't fight over the same file mid-stream. + token_name = await require_token(request, runtime, scope="messages") + try: + body = await request.json() + except json.JSONDecodeError as exc: + raise HTTPException( + status.HTTP_400_BAD_REQUEST, f"invalid JSON: {exc}" + ) from exc + + filename = body.get("filename") + if not isinstance(filename, str) or not filename.strip(): + raise HTTPException( + status.HTTP_400_BAD_REQUEST, "missing or non-string `filename`" + ) + content_override = body.get("content") + agent_override = body.get("agent") + if agent_override is not None and not isinstance(agent_override, str): + raise HTTPException( + status.HTTP_400_BAD_REQUEST, "`agent` must be a string" + ) + + file_path = self._resolve_path(filename) + + # 409 path stays JSON — the stream hasn't started yet, so + # the caller can read it the same way as on ``/chat``. 
# ---- streaming dispatch (SSE) --------------------------------------

async def _handle_chat_streaming(
    self,
    *,
    runtime: GatewayRuntime,
    token_name: str,
    filename: str,
    file_path: Path,
    content_override: Any,
    agent_override: str | None,
) -> AsyncIterator[bytes]:
    """SSE counterpart of :meth:`_handle_chat`.

    Runs the pipeline (resolve file → parse → resolve agent → run
    backend → persist) and emits ``event: delta`` frames as the rendered
    turn grows, followed by a single terminal ``event: done`` or
    ``event: error`` frame.

    The HTTP envelope is already 200 by the time this generator is
    iterated, so failures cannot be reported through the status line.
    Any unexpected exception from the pipeline is therefore logged and
    converted into a terminal ``error`` frame here instead of tearing
    the stream down without one.

    Args:
        runtime: Gateway runtime providing agents, backends, audit
            logging and turn-log handlers.
        token_name: Name of the authenticated token, used for logging
            and the audit record.
        filename: Caller-supplied file name, used in logs and audit.
        file_path: Resolved path of the markdown file for this turn.
        content_override: Optional replacement file text; must be a
            ``str`` or ``None``.
        agent_override: Optional agent name taking part in agent
            resolution.

    Yields:
        Encoded SSE frames produced by ``_sse_pack``.
    """
    frames = self._stream_chat_frames(
        runtime=runtime,
        token_name=token_name,
        filename=filename,
        file_path=file_path,
        content_override=content_override,
        agent_override=agent_override,
    )
    try:
        async for frame in frames:
            yield frame
    except Exception as exc:  # noqa: BLE001 — the stream must always end with a terminal frame
        _log.exception("chat/stream failed for %s", filename)
        yield _sse_pack(
            "error",
            {
                "status_code": status.HTTP_500_INTERNAL_SERVER_ERROR,
                "detail": f"internal error: {exc}",
            },
        )
    finally:
        # If the consumer stops early (client disconnect), close the
        # inner generator now rather than leaving it to be finalized
        # later. No-op when it has already finished.
        await frames.aclose()


async def _stream_chat_frames(  # noqa: PLR0915 — mirrors _handle_chat, splitting only doubles read cost
    self,
    *,
    runtime: GatewayRuntime,
    token_name: str,
    filename: str,
    file_path: Path,
    content_override: Any,
    agent_override: str | None,
) -> AsyncIterator[bytes]:
    """Yield the SSE frames for one turn.

    Validation problems and backend failures are reported as ``error``
    frames. Other exceptions (parsing, conversation lookup, persistence)
    propagate to :meth:`_handle_chat_streaming`, which turns them into
    the terminal ``error`` frame.

    Intermediate snapshots are only sent to the client; disk is written
    for ``content_override``, for the final reply, and for the
    partial-plus-error callout when the backend fails.
    """

    def error_frame(status_code: int, detail: str) -> bytes:
        # Key order is kept stable so the wire format does not change.
        return _sse_pack("error", {"status_code": status_code, "detail": detail})

    def nothing_to_do(reason: str, content: str) -> bytes:
        return _sse_pack(
            "done",
            {"status": "nothing_to_do", "reason": reason, "new_content": content},
        )

    # File-text resolution + early bailouts. ``content_override`` is
    # written to disk once because the rest of the request consumes that
    # state.
    if isinstance(content_override, str):
        await _write_atomic(file_path, content_override)
        file_text = content_override
    elif content_override is None:
        file_text = await _read_or_empty(file_path)
    else:
        yield error_frame(
            status.HTTP_400_BAD_REQUEST,
            "`content` must be a string when present",
        )
        return

    parsed = parser.parse(file_text)
    agent_name = parser.resolve_agent(
        metadata=parsed.metadata,
        request_override=agent_override,
        default=self.default_agent,
    )
    if not agent_name:
        yield error_frame(
            status.HTTP_400_BAD_REQUEST,
            "no agent specified: pass `agent`, set frontmatter, "
            "or configure `default_agent`",
        )
        return

    if not parsed.messages:
        yield nothing_to_do("empty file", file_text)
        return

    if parser.last_role(parsed.messages) == "assistant":
        yield nothing_to_do("last turn is assistant", file_text)
        return

    agent = runtime.agents.get(agent_name)
    if agent is None:
        yield error_frame(
            status.HTTP_404_NOT_FOUND, f"unknown agent: {agent_name!r}"
        )
        return
    backend = runtime.backends.get(agent.name)
    if backend is None:
        yield error_frame(
            status.HTTP_503_SERVICE_UNAVAILABLE,
            f"no backend configured for agent {agent.name!r}",
        )
        return

    _log.info(
        "chat/stream: actor=%s agent=%s file=%s msgs=%d",
        token_name,
        agent.name,
        filename,
        len(parsed.messages),
    )
    await audit.log(
        runtime,
        actor=f"token:{token_name}",
        kind="markdown_chat_stream",
        agent_name=agent.name,
        filename=filename,
        msgs=len(parsed.messages),
    )

    conv, conv_external_id, stored_msgs = await self._resolve_conversation(
        runtime=runtime, metadata=parsed.metadata, agent_name=agent.name
    )
    outcome = diff_and_fork(stored=stored_msgs, incoming=parsed.turns)
    capture: TurnCapture | None = (
        TurnCapture() if isinstance(backend, ClaudeCodeBackendAdapter) else None
    )

    acc = StreamAccumulator()
    model = agent.model or agent.name
    last_flush = time.monotonic()
    last_payload: str | None = None

    def snapshot() -> str | None:
        # Full document text as it would look if the turn ended now, or
        # ``None`` while the accumulated message has no content.
        partial = acc.finalize(model=model)
        if not partial.content:
            return None
        rendered = renderer.render_assistant_message(partial)
        new_body = renderer.append_to_body(parsed.body, rendered)
        return _reattach_frontmatter(parsed.metadata, new_body)

    try:
        # ``complete`` is called inside the ``try`` so a backend that
        # fails before returning its iterator is handled like one that
        # fails mid-stream.
        kwargs: dict[str, Any] = {}
        if capture is not None:
            kwargs["capture"] = capture
        events = backend.complete(
            agent=agent, messages=outcome.messages, system=None, **kwargs
        )
        async for ev in events:
            acc.feed(ev)
            now = time.monotonic()
            block_closed = isinstance(ev, RawContentBlockStopEvent)
            if not block_closed and (now - last_flush) < _SSE_FLUSH_DEBOUNCE:
                continue
            payload = snapshot()
            # Reset the debounce timer on every snapshot attempt, not
            # only when a delta goes out. Otherwise, while the snapshot
            # is empty or unchanged, every event past the first window
            # re-renders the whole document.
            last_flush = now
            # Skip duplicate snapshots — e.g. tool_use blocks render to
            # the same prefix as before they closed.
            if payload is not None and payload != last_payload:
                yield _sse_pack("delta", {"new_content": payload})
                last_payload = payload
    except Exception as exc:  # noqa: BLE001 — wire any backend failure as an SSE error frame
        _log.exception("backend failed for %s", filename)
        # Mirror the legacy path: write the last partial + an error
        # callout to disk so other consumers see what arrived. This is
        # best-effort; a failing write must not suppress the ``error``
        # frame the client is waiting for.
        try:
            partial = acc.finalize(model=model)
            new_body = parsed.body
            if partial.content:
                new_body = renderer.append_to_body(
                    new_body, renderer.render_assistant_message(partial)
                )
            new_body = renderer.append_to_body(new_body, _render_error_block(exc))
            await _write_atomic(
                file_path, _reattach_frontmatter(parsed.metadata, new_body)
            )
        except Exception:  # noqa: BLE001 — best-effort persistence
            _log.exception("failed to persist partial reply for %s", filename)
        yield error_frame(
            status.HTTP_500_INTERNAL_SERVER_ERROR, f"backend error: {exc}"
        )
        return

    message = acc.finalize(model=model)

    new_content = await self._write_assistant_reply(
        file_path=file_path,
        parsed=parsed,
        message=message,
        agent_name=agent.name,
        conv_external_id=conv_external_id,
    )

    await self._persist_canonical_history(
        runtime=runtime,
        conversation_id=conv.id,
        persist_messages=outcome.persist_messages,
        new_user_text=parsed.turns[-1].text,
        capture=capture,
        message=message,
    )

    record = TurnRecord(
        agent_name=agent.name,
        input_messages=list(parsed.messages),
        output_message=message,
        system=None,
        source="markdown",
    )
    for handler in runtime.turn_log_handlers:
        try:
            await handler(record)
        except Exception:  # noqa: BLE001
            _log.exception("turn_log_handler raised; continuing")

    yield _sse_pack(
        "done",
        {
            "status": "ok",
            "turns_appended": 1,
            "agent": agent.name,
            "new_content": new_content,
        },
    )
MarkdownFrontend(Frontend): # ---- module-level utilities ---------------------------------------------- +def _sse_pack(event: str, data: dict[str, Any]) -> bytes: + r"""Format one Server-Sent Event frame. + + Uses named events (``event: ``) so the plugin can dispatch on + type without parsing JSON discriminators. ``ensure_ascii=False`` so + multibyte content rides through verbatim instead of becoming + ``\uXXXX`` blobs that bloat the wire. + """ + body = json.dumps(data, ensure_ascii=False) + return f"event: {event}\ndata: {body}\n\n".encode() + + async def _read_or_empty(path: Path) -> str: """Return file contents, or empty string if the file doesn't exist.""" # ``path.exists()`` here is a metadata stat — microseconds — and