From a83bec709d547acbba83a3c0a1d766dc2f5e1df4 Mon Sep 17 00:00:00 2001 From: h Date: Thu, 21 May 2026 12:27:11 +0200 Subject: [PATCH] feat: add stateful conversation storage --- src/beaver_gateway/backends/claude_code.py | 68 ++- src/beaver_gateway/core/conversation_store.py | 523 ++++++++++++++++++ .../frontends/markdown/frontend.py | 183 +++++- .../frontends/markdown/parser.py | 237 ++++++-- src/beaver_gateway/storage/models.py | 73 ++- uv.lock | 4 +- 6 files changed, 994 insertions(+), 94 deletions(-) create mode 100644 src/beaver_gateway/core/conversation_store.py diff --git a/src/beaver_gateway/backends/claude_code.py b/src/beaver_gateway/backends/claude_code.py index 0d29143..4967f23 100644 --- a/src/beaver_gateway/backends/claude_code.py +++ b/src/beaver_gateway/backends/claude_code.py @@ -28,6 +28,7 @@ from __future__ import annotations import json import uuid +from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Self from claude_code_api import ( @@ -38,6 +39,7 @@ from claude_code_api import ( TextBlock, ThinkingBlock, ToolUseBlock, + synthesize_turn_messages, ) from beaver_gateway.agents.claude import ClaudeAgent @@ -65,7 +67,27 @@ if TYPE_CHECKING: from beaver_gateway.core.events import MessageStreamEvent -__all__ = ["ClaudeCodeBackendAdapter"] +__all__ = ["ClaudeCodeBackendAdapter", "TurnCapture"] + + +@dataclass +class TurnCapture: + """Side-channel sink for per-turn metadata. + + Pass an instance via ``ClaudeCodeBackendAdapter.complete(capture=...)``. + After the stream finishes, :attr:`synthesized_messages` holds the + full assistant↔tool-result cycle (from + :func:`claude_code_api.synthesize_turn_messages`) — i.e. the exact + list of canonical Anthropic-shape messages claude-code-api stashed + the live session under. The markdown frontend uses this to write the + conversation history to its DB so a subsequent turn's prefix + fingerprint hits the same session. 
+ + Other backends (anthropic, raycast) ignore the kwarg — it lands in + their ``**options`` and is silently dropped. + """ + + synthesized_messages: list[dict[str, Any]] = field(default_factory=list) _CLAUDE_TO_ANTHROPIC_STOP: dict[str, StopReason] = { @@ -185,10 +207,7 @@ class ClaudeCodeBackendAdapter: """ def __init__( - self, - *, - agent: ClaudeAgent, - mcp_internal_urls: Mapping[str, str], + self, *, agent: ClaudeAgent, mcp_internal_urls: Mapping[str, str] ) -> None: self._agent = agent self._backend = ClaudeCodeBackend( @@ -207,9 +226,7 @@ class ClaudeCodeBackendAdapter: await self._backend.__aenter__() return self - async def __aexit__( - self, exc_type: object, exc: object, tb: object - ) -> None: + async def __aexit__(self, exc_type: object, exc: object, tb: object) -> None: await self._backend.__aexit__(exc_type, exc, tb) async def aclose(self) -> None: @@ -221,6 +238,7 @@ class ClaudeCodeBackendAdapter: agent: BaseAgent, messages: Iterable[MessageParam], system: str | None = None, # noqa: ARG002 — see module docstring + capture: TurnCapture | None = None, **options: Any, # noqa: ARG002 — no per-request knobs for claude-code yet ) -> AsyncIterator[MessageStreamEvent]: if not isinstance(agent, ClaudeAgent): @@ -245,27 +263,43 @@ class ClaudeCodeBackendAdapter: next_index = 0 stop_reason: str | None = None usage: Mapping[str, Any] | None = None + # We keep raw events so we can hand them to + # ``synthesize_turn_messages`` after the stream closes — the + # markdown frontend stores the result in its conversation + # history so the next turn's prefix matches the backend's + # session-pool fingerprint. UserMessage (tool_result) events + # are silently discarded from the wire but kept here. 
+ raw_events: list[Any] = [] async for event in self._backend.complete(list(messages)): + raw_events.append(event) if isinstance(event, AssistantMessage): for block in event.content: for ev in _emit_block(block, next_index): yield ev next_index += 1 elif isinstance(event, ResultMessage): + # ResultMessage is the terminal event from TurnManager + # — we capture its stop_reason / usage for the envelope + # below. We DO NOT break here: an early break would + # raise GeneratorExit inside claude-code-api's + # ``complete`` coroutine before it gets a chance to + # stash the live session under the post-turn + # fingerprint, so every continuation would miss the + # cache and reseed. Let the inner generator exit + # naturally instead. stop_reason = event.stop_reason usage = event.usage - # ResultMessage is always last (TurnManager synthesizes - # it as the terminal event), so we break after emitting - # the envelope close. - break # UserMessage (tool_result records) and SystemMessage # (turn_duration heartbeats) carry no content for the - # /v1/messages caller — skip silently. + # /v1/messages caller — skip silently on the wire, but they + # ARE retained in ``raw_events`` for synthesis below. 
+ + if capture is not None: + capture.synthesized_messages = synthesize_turn_messages(raw_events) yield build_message_delta( - stop_reason=_map_stop_reason(stop_reason), - usage=_normalize_usage(usage), + stop_reason=_map_stop_reason(stop_reason), usage=_normalize_usage(usage) ) yield build_message_stop() @@ -292,9 +326,7 @@ def _emit_block( build_content_block_stop(index), ) if isinstance(block, ToolUseBlock): - partial = json.dumps( - block.input, separators=(",", ":"), ensure_ascii=False - ) + partial = json.dumps(block.input, separators=(",", ":"), ensure_ascii=False) return ( build_tool_use_block_start(index, tool_use_id=block.id, name=block.name), build_input_json_delta(index, partial), diff --git a/src/beaver_gateway/core/conversation_store.py b/src/beaver_gateway/core/conversation_store.py new file mode 100644 index 0000000..37e7b9a --- /dev/null +++ b/src/beaver_gateway/core/conversation_store.py @@ -0,0 +1,523 @@ +"""Stateful conversation history for the markdown frontend. + +The gateway used to be stateless about identity: claude-code-api's +in-memory session pool was keyed by a fingerprint of the messages the +gateway forwarded, and on a fingerprint miss the same fingerprint was +used to seed a fresh PTY's JSONL transcript. That worked as long as +the frontend could round-trip the *exact* content blocks the live +session had observed. The markdown frontend can't — the parser strips +``[!tool]-`` callouts because the human is allowed to edit the prose, +and the rendered tool callouts don't carry the canonical ``tool_use`` +block fields anyway. So a continuation hit was *only* reliable for +turns that never used a tool; once tools entered the picture, every +subsequent turn missed the cache and reseeded from a tool-less +transcript, leading to "assistant doesn't remember the tool calls it +just made." + +This module makes the gateway stateful for the markdown frontend (and +any other frontend that wants in). 
The DB stores the full +Anthropic-shape message list — text blocks, ``tool_use`` blocks, +``tool_result`` blocks, thinking signatures — exactly as +claude-code-api would have seen on the wire. Before each turn we +align the file the user is editing against the stored history: + +* If the user just appended a new user turn at the bottom, we feed + the backend our stored-plus-new history and the fingerprint hits. +* If the user edited the *text* inside an assistant turn but left the + tool callouts alone, we splice the new text into the stored + ``tool_use`` blocks and feed *that* — the fingerprint misses (text + differs), claude-code-api reseeds with a full transcript (tools and + all), the new live session has memory of the prior tool calls. +* If the user changed the *structure* (added/removed/reordered a tool + callout, edited an old user turn, etc.) we fork: take stored history + up to the divergence, take incoming text-only past the divergence. + The fingerprint misses; claude-code-api reseeds with a clean + truncated history; downstream turns continue from there. + +"Divergence point" is found by walking the file's turns and the +stored display turns in lockstep. See :func:`diff_and_fork`. 
+""" + +from __future__ import annotations + +import json +import uuid +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, cast + +from sqlmodel import select + +from beaver_gateway.storage.models import Conversation, ConversationMessage + +if TYPE_CHECKING: + from anthropic.types import MessageParam + from sqlmodel.ext.asyncio.session import AsyncSession + + from beaver_gateway.frontends.markdown.parser import ParsedTurn + +__all__ = [ + "ForkOutcome", + "diff_and_fork", + "load_conversation", + "load_messages", + "mint_conversation", + "rewrite_messages", +] + + +# ---- types -------------------------------------------------------------- + + +@dataclass(frozen=True, slots=True) +class ForkOutcome: + """Result of aligning the incoming file against stored history. + + ``messages`` is what the gateway feeds to the backend (already + includes the new user prompt at the tail). ``persist_messages`` + is the canonical conversation state the gateway should hold in + the DB *up to but not including* the new assistant reply — the + caller appends the synthesized turn from the backend onto this + and writes the result back. ``divergence_index`` is the + display-turn index at which incoming first disagreed with stored + (``None`` if everything matched; the new tail is appended cleanly). 
+ """ + + messages: list[MessageParam] + persist_messages: list[dict[str, Any]] + divergence_index: int | None + + +# ---- public store API --------------------------------------------------- + + +async def load_conversation( + session: AsyncSession, *, frontend: str, external_id: str +) -> Conversation | None: + stmt = ( + select(Conversation) + .where(Conversation.frontend == frontend) + .where(Conversation.external_id == external_id) + ) + result = await session.exec(stmt) + return result.first() + + +async def mint_conversation( + session: AsyncSession, *, frontend: str, agent_name: str +) -> Conversation: + """Create a fresh conversation row with a new uuid for external_id. + + Caller is responsible for persisting the returned ``external_id`` on + the frontend side (frontmatter, response header, …) so future + requests can find this conversation again. + """ + row = Conversation( + frontend=frontend, external_id=str(uuid.uuid4()), agent_name=agent_name + ) + session.add(row) + await session.commit() + await session.refresh(row) + return row + + +async def load_messages( + session: AsyncSession, *, conversation_id: int +) -> list[dict[str, Any]]: + """Return stored messages ordered by ``seq`` ascending. + + Each entry is a canonical Anthropic ``MessageParam`` dict — ``role`` + plus ``content`` (string or list of block dicts). The same shape + we feed to the backend on continuation. + """ + stmt = ( + select(ConversationMessage) + .where(ConversationMessage.conversation_id == conversation_id) + .order_by(ConversationMessage.seq.asc()) # ty: ignore[unresolved-attribute] + ) + result = await session.exec(stmt) + rows = result.all() + return [{"role": r.role, "content": json.loads(r.content_json)} for r in rows] + + +async def rewrite_messages( + session: AsyncSession, *, conversation_id: int, messages: list[dict[str, Any]] +) -> None: + """Replace the conversation's stored messages with ``messages``. + + The user said no branch history — we overwrite on fork. 
Cheap at + our volume; if it ever matters we can switch to soft-delete + + branch pointers. + """ + # Delete existing rows for this conversation. + existing_stmt = select(ConversationMessage).where( + ConversationMessage.conversation_id == conversation_id + ) + result = await session.exec(existing_stmt) + for row in result.all(): + await session.delete(row) + # Insert the new sequence. + for seq, m in enumerate(messages): + session.add( + ConversationMessage( + conversation_id=conversation_id, + seq=seq, + role=str(m["role"]), + content_json=json.dumps(m["content"], separators=(",", ":")), + ) + ) + # Bump conversation.updated_at. + conv = await session.get(Conversation, conversation_id) + if conv is not None: + from datetime import UTC, datetime + + conv.updated_at = datetime.now(UTC) + session.add(conv) + await session.commit() + + +# ---- alignment ---------------------------------------------------------- + + +@dataclass(frozen=True, slots=True) +class _StoredDisplayTurn: + """A "display turn" reconstructed from stored raw messages. + + ``role`` is ``"user"`` (single user-prompt message) or + ``"assistant"`` (one or more assistant messages, optionally + interleaved with user-only-tool_result messages). ``messages`` is + the slice of stored raw messages this display turn covers, in + order. ``spoken_text`` and ``skeleton`` are the + parser-equivalents for diff purposes; ``text_segment_count`` lets + us refuse a splice when the user edited across a tool boundary in + a way we can't safely undo. + """ + + role: str + messages: tuple[dict[str, Any], ...] + spoken_text: str + skeleton: tuple[str, ...] + text_segment_count: int + + +def _group_display_turns(stored: list[dict[str, Any]]) -> list[_StoredDisplayTurn]: + """Walk raw stored messages, group them into Obsidian-visible turns. + + A user-prompt message (``role=user`` with string content, or list + content with no ``tool_result`` blocks) opens a user display turn. 
+ Otherwise it's a tool-result follow-up and rolls into the current + assistant display turn. + """ + out: list[_StoredDisplayTurn] = [] + i = 0 + while i < len(stored): + msg = stored[i] + role = msg["role"] + if role == "user" and _is_user_prompt(msg.get("content")): + out.append( + _StoredDisplayTurn( + role="user", + messages=(msg,), + spoken_text=_user_prompt_text(msg.get("content")), + skeleton=(), + text_segment_count=0, + ) + ) + i += 1 + continue + # Assistant display turn: collect consecutive non-prompt messages. + group: list[dict[str, Any]] = [] + while i < len(stored): + m = stored[i] + if m["role"] == "user" and _is_user_prompt(m.get("content")): + break + group.append(m) + i += 1 + spoken, skeleton, text_count = _summarize_assistant_group(group) + out.append( + _StoredDisplayTurn( + role="assistant", + messages=tuple(group), + spoken_text=spoken, + skeleton=tuple(skeleton), + text_segment_count=text_count, + ) + ) + return out + + +def _is_user_prompt(content: Any) -> bool: + """A user message is a *prompt* unless its content carries tool_result blocks.""" + if isinstance(content, str): + return True + if isinstance(content, list): + return not any( + isinstance(b, dict) and b.get("type") == "tool_result" for b in content + ) + # Unknown shape — be conservative, treat as prompt. + return True + + +def _user_prompt_text(content: Any) -> str: + if isinstance(content, str): + return content + if isinstance(content, list): + chunks = [ + str(b.get("text", "")) + for b in content + if isinstance(b, dict) and b.get("type") == "text" + ] + return "\n\n".join(c for c in chunks if c) + return "" + + +def _summarize_assistant_group( + group: list[dict[str, Any]], +) -> tuple[str, list[str], int]: + """Compute (spoken_text, tool_skeleton, text_segment_count) for a display group. 
+ + Mirrors what ``parser.parse_assistant_structure`` would produce when + re-parsing the rendered version of this group: consecutive text + blocks across assistant messages collapse into one text segment; + tool_use blocks become skeleton entries; tool_result messages and + thinking blocks are invisible. + """ + # See ``diff_and_fork`` for why the parser-type imports are deferred. + from beaver_gateway.frontends.markdown.parser import TextSegment, ToolSegment + + segments: list[TextSegment | ToolSegment] = [] + pending: list[str] = [] + + def _flush() -> None: + if not pending: + return + joined = "\n\n".join(p for p in pending if p) + pending.clear() + cleaned = joined.strip() + if cleaned: + segments.append(TextSegment(text=cleaned)) + + for msg in group: + if msg["role"] == "user": + # tool_result message — boundary for text but emits no segment. + _flush() + continue + content = msg.get("content") + if not isinstance(content, list): + continue + for blk in content: + if not isinstance(blk, dict): + continue + btype = blk.get("type") + if btype == "text": + text = str(blk.get("text", "")).strip() + if text: + pending.append(text) + elif btype == "tool_use": + _flush() + segments.append(ToolSegment(name=str(blk.get("name", "")))) + # thinking: skip silently + _flush() + spoken_chunks = [s.text for s in segments if isinstance(s, TextSegment)] + spoken = "\n\n".join(c for c in spoken_chunks if c).strip() + skeleton = [s.name for s in segments if isinstance(s, ToolSegment)] + text_count = sum(1 for s in segments if isinstance(s, TextSegment)) + return spoken, skeleton, text_count + + +# ---- the core algorithm ------------------------------------------------- + + +def diff_and_fork( + *, stored: list[dict[str, Any]], incoming: list[ParsedTurn] +) -> ForkOutcome: + """Align the incoming parsed file against stored history. + + ``stored`` is the raw Anthropic-shape message list from the DB + (one entry per ``ConversationMessage`` row). 
``incoming`` is the + user-visible turn list from the markdown parser. The last + ``incoming`` entry must be a user turn — that's the new prompt + triggering this request. + + Returns a :class:`ForkOutcome` whose ``messages`` is what the + backend should run on and whose ``persist_messages`` is the + canonical history to store in the DB once the backend's + synthesized cycle is appended. + """ + # ``parser`` lives under ``frontends/markdown/`` whose ``__init__`` + # eagerly loads ``frontend.py``, which in turn imports this module + # — pulling the parser at module-import time creates a cycle. The + # helpers below import the segment classes lazily inside their own + # function bodies to break it. + if not incoming or incoming[-1].role != "user": + msg = ( + "diff_and_fork expects incoming to end with a user turn " + "(the new prompt); got " + f"{incoming[-1].role if incoming else 'empty'}" + ) + raise ValueError(msg) + + stored_groups = _group_display_turns(stored) + new_user_turn = incoming[-1] + prior_incoming = incoming[:-1] + + spliced_groups, divergence = _walk_prefix(prior_incoming, stored_groups) + + if divergence is None and len(prior_incoming) < len(stored_groups): + # Incoming truncated stored (user deleted some past turns). + # Truncate stored to match. + divergence = len(prior_incoming) + + backend_msgs, persist_msgs = _assemble_tail( + spliced_groups=spliced_groups, + prior_incoming=prior_incoming, + divergence=divergence, + new_user_turn=new_user_turn, + ) + return ForkOutcome( + messages=backend_msgs, + persist_messages=persist_msgs, + divergence_index=divergence, + ) + + +def _walk_prefix( + prior_incoming: list[ParsedTurn], stored_groups: list[_StoredDisplayTurn] +) -> tuple[list[list[dict[str, Any]]], int | None]: + """Walk incoming vs stored side-by-side until first divergence. 
+ + Returns the spliced/matched group list (one entry per matched + display turn, each carrying the raw messages we'll feed to the + backend for that turn) and the divergence index (``None`` if all + of ``prior_incoming`` matched). + """ + from beaver_gateway.frontends.markdown.parser import TextSegment, ToolSegment + + spliced_groups: list[list[dict[str, Any]]] = [] + for i, inc in enumerate(prior_incoming): + if i >= len(stored_groups): + return spliced_groups, i + st = stored_groups[i] + if inc.role != st.role: + return spliced_groups, i + if inc.role == "user": + if inc.text != st.spoken_text: + return spliced_groups, i + spliced_groups.append(list(st.messages)) + continue + inc_skeleton = tuple( + s.name for s in inc.structure if isinstance(s, ToolSegment) + ) + inc_text_count = sum(1 for s in inc.structure if isinstance(s, TextSegment)) + if inc_skeleton != st.skeleton: + return spliced_groups, i + if inc.text == st.spoken_text: + spliced_groups.append(list(st.messages)) + continue + if inc_text_count != st.text_segment_count: + return spliced_groups, i + spliced = _splice_assistant_group(stored_group=st, incoming=inc) + if spliced is None: + return spliced_groups, i + spliced_groups.append(spliced) + return spliced_groups, None + + +def _assemble_tail( + *, + spliced_groups: list[list[dict[str, Any]]], + prior_incoming: list[ParsedTurn], + divergence: int | None, + new_user_turn: ParsedTurn, +) -> tuple[list[MessageParam], list[dict[str, Any]]]: + """Build the (backend, persist) lists from aligned + post-divergence tail.""" + backend_msgs: list[MessageParam] = [] + persist_msgs: list[dict[str, Any]] = [] + for spliced in spliced_groups: + for m in spliced: + entry: dict[str, Any] = {"role": m["role"], "content": m["content"]} + backend_msgs.append(cast("MessageParam", entry)) + persist_msgs.append(entry) + if divergence is not None: + for inc in prior_incoming[divergence:]: + if not inc.text: + continue + entry = {"role": inc.role, "content": inc.text} + 
backend_msgs.append(cast("MessageParam", entry)) + persist_msgs.append(entry) + backend_msgs.append({"role": "user", "content": new_user_turn.text}) + return backend_msgs, persist_msgs + + +def _splice_assistant_group( + *, stored_group: _StoredDisplayTurn, incoming: ParsedTurn +) -> list[dict[str, Any]] | None: + """Rebuild an assistant display turn with new text + stored tool_use blocks. + + Walks the incoming structure; for each ``TextSegment`` emits a + text block into the current assistant message; for each + ``ToolSegment`` consumes the next stored ``tool_use`` block (by + position), closes the current assistant message, emits the + matching ``tool_result`` user message, and opens a new assistant + message. Final ``TextSegment`` closes the last assistant message. + + Returns ``None`` if we can't find a matching tool_result for some + tool_use (stored history is malformed) — caller falls back to + fork. + """ + # See ``diff_and_fork`` for why this import is deferred. + from beaver_gateway.frontends.markdown.parser import TextSegment + + tool_uses, tool_results_by_id = _harvest_tool_blocks(stored_group) + + spliced: list[dict[str, Any]] = [] + current_asst: list[dict[str, Any]] = [] + next_tool = 0 + for seg in incoming.structure: + if isinstance(seg, TextSegment): + if seg.text: + current_asst.append({"type": "text", "text": seg.text}) + continue + if next_tool >= len(tool_uses): + return None + tu = tool_uses[next_tool] + next_tool += 1 + current_asst.append(tu) + spliced.append({"role": "assistant", "content": current_asst}) + current_asst = [] + tr = tool_results_by_id.get(str(tu.get("id", ""))) + if tr is None: + return None + spliced.append({"role": "user", "content": [tr]}) + if current_asst: + spliced.append({"role": "assistant", "content": current_asst}) + elif not spliced: + # Defensive: assistant turn with no text and no tools makes no + # sense; caller will treat as fork. 
+ return None + return spliced + + +def _harvest_tool_blocks( + stored_group: _StoredDisplayTurn, +) -> tuple[list[dict[str, Any]], dict[str, dict[str, Any]]]: + """Pull stored ``tool_use`` blocks (ordered) and ``tool_result`` blocks (by id).""" + tool_uses: list[dict[str, Any]] = [] + tool_results_by_id: dict[str, dict[str, Any]] = {} + for msg in stored_group.messages: + content = msg.get("content") + if not isinstance(content, list): + continue + if msg["role"] == "assistant": + tool_uses.extend( + blk + for blk in content + if isinstance(blk, dict) and blk.get("type") == "tool_use" + ) + continue + for blk in content: + if not isinstance(blk, dict) or blk.get("type") != "tool_result": + continue + tid = blk.get("tool_use_id") + if isinstance(tid, str): + tool_results_by_id[tid] = blk + return tool_uses, tool_results_by_id diff --git a/src/beaver_gateway/frontends/markdown/frontend.py b/src/beaver_gateway/frontends/markdown/frontend.py index a82d698..c1d4b3c 100644 --- a/src/beaver_gateway/frontends/markdown/frontend.py +++ b/src/beaver_gateway/frontends/markdown/frontend.py @@ -36,7 +36,15 @@ import aiofile from fastapi import FastAPI, HTTPException, Request, status from fastapi.responses import JSONResponse +from beaver_gateway.backends.claude_code import ClaudeCodeBackendAdapter, TurnCapture from beaver_gateway.core import audit +from beaver_gateway.core.conversation_store import ( + diff_and_fork, + load_conversation, + load_messages, + mint_conversation, + rewrite_messages, +) from beaver_gateway.core.turn_record import TurnRecord from beaver_gateway.frontends._accumulate import accumulate from beaver_gateway.frontends._auth import require_token @@ -264,9 +272,25 @@ class MarkdownFrontend(Frontend): msgs=len(parsed.messages), ) + # Resolve / mint the conversation row, align incoming against + # stored history, and feed the aligned messages to the backend + # — see ``core/conversation_store.py`` for the full rationale. 
+ # If the backend isn't claude-code (no ``TurnCapture`` support) + # we fall through to the legacy parser-only path. + conv, conv_external_id, stored_msgs = await self._resolve_conversation( + runtime=runtime, metadata=parsed.metadata, agent_name=agent.name + ) + outcome = diff_and_fork(stored=stored_msgs, incoming=parsed.turns) + capture: TurnCapture | None = ( + TurnCapture() if isinstance(backend, ClaudeCodeBackendAdapter) else None + ) + try: + kwargs: dict[str, Any] = {} + if capture is not None: + kwargs["capture"] = capture events = backend.complete( - agent=agent, messages=parsed.messages, system=None + agent=agent, messages=outcome.messages, system=None, **kwargs ) message = await accumulate(events, model=agent.model or agent.name) except Exception as exc: @@ -280,22 +304,22 @@ class MarkdownFrontend(Frontend): status.HTTP_500_INTERNAL_SERVER_ERROR, f"backend error: {exc}" ) from exc - rendered = renderer.render_assistant_message(message) - new_body = renderer.append_to_body(parsed.body, rendered) - new_body = renderer.append_to_body(new_body, renderer.USER_SCAFFOLD) - # Recompute fingerprint so a future cross-frontend hit on this - # same conversation can find it. Stored as hex string in - # frontmatter — only the markdown frontend reads it. 
- assistant_param: MessageParam = { - "role": "assistant", - "content": _flatten_assistant_text(message), - } - updated_messages: list[MessageParam] = [*parsed.messages, assistant_param] - updated_metadata = dict(parsed.metadata) - updated_metadata["agent"] = agent.name - updated_metadata["fingerprint"] = fingerprint_messages(updated_messages) - new_content = _reattach_frontmatter(updated_metadata, new_body) - await _write_atomic(file_path, new_content) + new_content = await self._write_assistant_reply( + file_path=file_path, + parsed=parsed, + message=message, + agent_name=agent.name, + conv_external_id=conv_external_id, + ) + + await self._persist_canonical_history( + runtime=runtime, + conversation_id=conv.id, + persist_messages=outcome.persist_messages, + new_user_text=parsed.turns[-1].text, + capture=capture, + message=message, + ) # Broadcast our own turn so other handlers (none today, but the # symmetry is worth keeping) see what happened. ``source`` marks @@ -322,6 +346,94 @@ class MarkdownFrontend(Frontend): # ---- helpers ------------------------------------------------------- + async def _write_assistant_reply( + self, + *, + file_path: Path, + parsed: parser.ParsedFile, + message: Any, + agent_name: str, + conv_external_id: str, + ) -> str: + """Render the assistant turn, append to the file, refresh frontmatter.""" + rendered = renderer.render_assistant_message(message) + new_body = renderer.append_to_body(parsed.body, rendered) + new_body = renderer.append_to_body(new_body, renderer.USER_SCAFFOLD) + # Recompute fingerprint so a future cross-frontend hit on this + # same conversation can find it. Stored as hex string in + # frontmatter — only the markdown frontend reads it. 
+ assistant_param: MessageParam = { + "role": "assistant", + "content": _flatten_assistant_text(message), + } + updated_messages: list[MessageParam] = [*parsed.messages, assistant_param] + updated_metadata = dict(parsed.metadata) + updated_metadata["agent"] = agent_name + updated_metadata["conversation_id"] = conv_external_id + updated_metadata["fingerprint"] = fingerprint_messages(updated_messages) + new_content = _reattach_frontmatter(updated_metadata, new_body) + await _write_atomic(file_path, new_content) + return new_content + + async def _resolve_conversation( + self, *, runtime: GatewayRuntime, metadata: dict[str, Any], agent_name: str + ) -> tuple[Any, str, list[dict[str, Any]]]: + """Resolve the conversation row + stored messages for this request. + + Looks up by frontmatter ``conversation_id``, mints a new row if + missing, and returns ``(conv, external_id, stored_messages)``. + ``conv.id`` is guaranteed non-None because both + ``load_conversation`` (after refresh on a committed row) and + ``mint_conversation`` (post-commit refresh) populate it. We + coerce with a runtime check so the rest of the handler can + treat it as ``int``. 
+ """ + raw = metadata.get("conversation_id") + lookup_id = raw if isinstance(raw, str) and raw else None + async with runtime.db.session() as session: + conv = None + if lookup_id is not None: + conv = await load_conversation( + session, frontend="markdown", external_id=lookup_id + ) + if conv is None: + conv = await mint_conversation( + session, frontend="markdown", agent_name=agent_name + ) + if conv.id is None: + msg = "conversation row missing primary key after commit" + raise RuntimeError(msg) + stored = await load_messages(session, conversation_id=conv.id) + return conv, conv.external_id, stored + + async def _persist_canonical_history( + self, + *, + runtime: GatewayRuntime, + conversation_id: int, + persist_messages: list[dict[str, Any]], + new_user_text: str, + capture: TurnCapture | None, + message: Any, + ) -> None: + """Stamp the DB with the post-turn canonical Anthropic-shape history. + + Combines the matched/spliced prior state, the new user prompt, + and the synthesized assistant↔tool cycle from the backend (or + a text-only fallback for backends without ``TurnCapture``). + """ + new_user_msg = {"role": "user", "content": new_user_text} + synthesized = ( + capture.synthesized_messages + if capture is not None + else _fallback_synthesized(message) + ) + canonical = [*persist_messages, new_user_msg, *synthesized] + async with runtime.db.session() as session: + await rewrite_messages( + session, conversation_id=conversation_id, messages=canonical + ) + def _resolve_path(self, filename: str) -> Path: """Resolve ``filename`` under the vault; reject escapes.""" # ``filename`` may be relative or absolute; we always anchor @@ -399,6 +511,43 @@ def _reattach_frontmatter(metadata: dict[str, Any], body: str) -> str: return _fm.dumps(post) + "\n" +def _fallback_synthesized(message: Any) -> list[dict[str, Any]]: + """Build a single-assistant ``synthesized_messages`` list from a raw ``Message``. 
+ + For backends that don't populate a :class:`TurnCapture` (anthropic + HTTP, raycast, …) we don't have access to per-tool-cycle + granularity, so the assistant reply lands in the DB as one + canonical-block message. Tool memory across cache misses would + degrade in that case, but those backends don't have the cache-miss + re-seed problem to begin with — they manage history client-side. + """ + content: list[dict[str, Any]] = [] + for block in getattr(message, "content", ()): + btype = getattr(block, "type", None) + if btype == "text": + content.append({"type": "text", "text": getattr(block, "text", "") or ""}) + elif btype == "tool_use": + content.append( + { + "type": "tool_use", + "id": getattr(block, "id", ""), + "name": getattr(block, "name", ""), + "input": getattr(block, "input", {}), + } + ) + elif btype == "thinking": + content.append( + { + "type": "thinking", + "thinking": getattr(block, "thinking", "") or "", + "signature": getattr(block, "signature", "") or "", + } + ) + if not content: + return [] + return [{"role": "assistant", "content": content}] + + def _flatten_assistant_text(message: Any) -> str: """Pull all text blocks from an assistant ``Message`` and join them. 
diff --git a/src/beaver_gateway/frontends/markdown/parser.py b/src/beaver_gateway/frontends/markdown/parser.py index 953b875..17e61db 100644 --- a/src/beaver_gateway/frontends/markdown/parser.py +++ b/src/beaver_gateway/frontends/markdown/parser.py @@ -27,7 +27,41 @@ if TYPE_CHECKING: from anthropic.types import MessageParam -__all__ = ["ParsedFile", "last_role", "parse", "resolve_agent"] +__all__ = [ + "AssistantSegment", + "ParsedFile", + "ParsedTurn", + "TextSegment", + "ToolSegment", + "last_role", + "parse", + "parse_assistant_structure", + "resolve_agent", +] + + +@dataclass(frozen=True, slots=True) +class TextSegment: + """A run of plain text inside an assistant turn (between callouts).""" + + text: str + + +@dataclass(frozen=True, slots=True) +class ToolSegment: + """A ``> [!tool]- <name>`` callout placeholder. + + Only the tool ``name`` is captured — the " · summary" suffix on the + callout title and the JSON body inside the quote block are + decorative for the human reader; the canonical tool_use block lives + in the DB and is keyed by *position+name* against the structure + parsed here. + """ + + name: str + + +AssistantSegment = TextSegment | ToolSegment # Turn marker — must be exactly ``### User:`` or ``### Assistant:`` on @@ -42,6 +76,35 @@ _TURN_RE = re.compile(r"^###\s+(User|Assistant):\s*$", re.MULTILINE) # need to drop the whole quoted block. _CALLOUT_START_RE = re.compile(r"^>\s+\[!(thinking|tool)\]") +# Tool-callout title line: ``> [!tool]- <name>`` or ``> [!tool]- <name> · <summary>``. +# We only need the ``<name>`` part for skeleton matching; the summary is +# decorative (built by ``renderer.summarize_tool_input`` from inputs the +# user can edit visually without semantic consequence). +_TOOL_TITLE_RE = re.compile(r"^>\s+\[!tool\]-\s*(.*?)\s*$") +# Renderer joins name + summary with " · " (U+00B7) — see +# ``renderer.summarize_tool_input``. We split on it to recover the +# bare tool name.
+_TOOL_TITLE_SEP = " · " + + +@dataclass(frozen=True, slots=True) +class ParsedTurn: + """One turn extracted from the chat file. + + ``role`` is ``"user"`` or ``"assistant"``. ``text`` is the spoken + content with callouts stripped and HRs dropped — used both as the + backend's ``MessageParam.content`` (back-compat with the existing + parser shape) and as the diff key against stored turns. + ``structure`` is non-empty only for assistant turns: an ordered + list of ``TextSegment`` / ``ToolSegment`` reflecting the visible + layout of the assistant block, used by the conversation store to + align with the canonical tool_use blocks held in DB. + """ + + role: str + text: str + structure: tuple[TextSegment | ToolSegment, ...] = () + @dataclass(frozen=True, slots=True) class ParsedFile: @@ -49,48 +112,159 @@ class ParsedFile: ``metadata`` is the YAML frontmatter as a plain dict (empty if the file has none). ``messages`` is the conversation history shaped for - ``Backend.complete`` — assistant turns are text-only. ``body`` is the - raw markdown content *after* the frontmatter is stripped; the - renderer needs it when it appends a new assistant turn so it can - preserve whatever the human typed verbatim (including any callouts - or HRs they added). + ``Backend.complete`` — assistant turns are text-only. ``turns`` is + 1:1 with ``messages`` and carries the per-turn structure (for + assistant turns) that the conversation store needs to detect + text-only edits vs. structural forks. ``body`` is the raw markdown + content *after* the frontmatter is stripped; the renderer needs it + when it appends a new assistant turn so it can preserve whatever + the human typed verbatim (including any callouts or HRs they + added). """ metadata: dict[str, Any] body: str messages: list[MessageParam] + turns: list[ParsedTurn] def parse(text: str) -> ParsedFile: - """Parse a chat ``.md`` into ``(metadata, body, messages)``. + """Parse a chat ``.md`` into ``(metadata, body, messages, turns)``. 
A file with no turn markers but non-empty body is treated as a single user turn — the friendly path for "user types into a new file and hits send" before any turn markers exist. + + Assistant turns that have *only* tool callouts (no spoken text) are + preserved here even though their ``MessageParam.content`` is empty + — the structure carries tool-segment information the conversation + store needs for skeleton matching. The renderer in practice always + emits at least a trailing text block, so this branch is defensive. """ parsed = frontmatter.loads(text) metadata = dict(parsed.metadata) body = parsed.content messages: list[MessageParam] = [] - turns = _split_turns(body) - if not turns: + parsed_turns: list[ParsedTurn] = [] + raw_turns = _split_turns(body) + if not raw_turns: stripped = body.strip() if stripped: messages.append({"role": "user", "content": stripped}) - return ParsedFile(metadata=metadata, body=body, messages=messages) + parsed_turns.append(ParsedTurn(role="user", text=stripped)) + return ParsedFile( + metadata=metadata, body=body, messages=messages, turns=parsed_turns + ) - for role, raw in turns: + for role, raw in raw_turns: if role == "user": text_content = _strip_hrs(raw).strip() if text_content: messages.append({"role": "user", "content": text_content}) + parsed_turns.append(ParsedTurn(role="user", text=text_content)) else: - text_content = _extract_assistant_text(raw) + structure = parse_assistant_structure(raw) + text_content = _segments_to_spoken_text(structure) + has_tools = any(isinstance(s, ToolSegment) for s in structure) if text_content: messages.append({"role": "assistant", "content": text_content}) + parsed_turns.append( + ParsedTurn( + role="assistant", text=text_content, structure=tuple(structure) + ) + ) + elif has_tools: + # Tool-only assistant turn: nothing to feed the backend + # as ``content`` (it'd reject an empty string), but the + # structure must survive so the store can align it + # against stored tool_use blocks. 
We synthesize a + # single-space text content for backend round-trip; the + # conversation store will replace this payload with the + # canonical stored blocks before the backend ever sees + # it on a continuation. + messages.append({"role": "assistant", "content": " "}) + parsed_turns.append( + ParsedTurn(role="assistant", text="", structure=tuple(structure)) + ) - return ParsedFile(metadata=metadata, body=body, messages=messages) + return ParsedFile( + metadata=metadata, body=body, messages=messages, turns=parsed_turns + ) + + +def parse_assistant_structure(raw: str) -> list[TextSegment | ToolSegment]: + """Walk an assistant turn body, return its ordered text/tool segments. + + Tool callouts become :class:`ToolSegment` with just the tool name — + the title's optional ``" · summary"`` suffix and the JSON body + inside the quote block are decorative; the canonical tool_use + block is held in the conversation store. Thinking callouts are + stripped entirely (they were never round-trippable through the + file — signatures expire). HR separator lines drop out. + + Empty / whitespace-only text segments at the boundaries (start, + end, between adjacent tool callouts) are dropped so the skeleton + is robust against renderer whitespace choices; a non-empty text + segment with surrounding whitespace is trimmed on both ends but + preserved. + """ + segments: list[TextSegment | ToolSegment] = [] + pending_text: list[str] = [] + + def _flush_text() -> None: + if not pending_text: + return + joined = "\n".join(pending_text) + # Collapse runs of two or more blank lines (e.g. left by a dropped + # ``---`` separator line) into one so the diff against a re-render + # is stable.
+ cleaned = re.sub(r"\n{3,}", "\n\n", joined).strip() + pending_text.clear() + if cleaned: + segments.append(TextSegment(text=cleaned)) + + lines = raw.splitlines() + i = 0 + while i < len(lines): + line = lines[i] + callout_match = _CALLOUT_START_RE.match(line) + if callout_match: + kind = callout_match.group(1) + # Capture tool name *before* advancing past the block. + if kind == "tool": + title_match = _TOOL_TITLE_RE.match(line) + title = title_match.group(1) if title_match else "" + name = title.split(_TOOL_TITLE_SEP, 1)[0].strip() + _flush_text() + segments.append(ToolSegment(name=name)) + else: + # Thinking callout — drop the whole block, emit nothing. + _flush_text() + # Skip the rest of the quote block. + while i < len(lines) and lines[i].lstrip().startswith(">"): + i += 1 + continue + if line.strip() == "---": + i += 1 + continue + pending_text.append(line) + i += 1 + _flush_text() + return segments + + +def _segments_to_spoken_text(segments: list[TextSegment | ToolSegment]) -> str: + r"""Reduce a structure list to the spoken-text view the backend sees. + + Concatenates :class:`TextSegment` contents with ``\n\n`` between + them, dropping :class:`ToolSegment` entries. Intended to match the + pre-Conversation-store parser so existing fingerprints (frontmatter + ``fingerprint`` field) stay valid; NOTE: text directly abutting a + callout is now joined with ``\n\n`` where the old parser gave ``\n``. + """ + chunks = [s.text for s in segments if isinstance(s, TextSegment)] + return "\n\n".join(c for c in chunks if c).strip() def last_role(messages: list[MessageParam]) -> str | None: @@ -148,40 +322,3 @@ def _strip_hrs(raw: str) -> str: lines = raw.splitlines() kept = [ln for ln in lines if ln.strip() != "---"] return "\n".join(kept) - - -def _extract_assistant_text(raw: str) -> str: - """Strip thinking/tool callouts from an assistant turn, return spoken text. - - Walks line by line.
When we see a callout-start line (``> [!thinking]-`` - or ``> [!tool]- ...``), we skip the entire contiguous quote block - (lines beginning with ``>`` or blank-then-`>` continuations don't - happen in Obsidian callouts — a blank line ends the callout). HR - lines (``---``) are dropped. Everything else is kept and joined, - then collapsed to a clean trim. - """ - lines = raw.splitlines() - out_lines: list[str] = [] - i = 0 - while i < len(lines): - line = lines[i] - if _CALLOUT_START_RE.match(line): - # Skip the whole quote block (consecutive lines starting - # with ``>``). Stop at first non-``>`` line, leaving it for - # the next iteration. Blank lines do not end the block — a - # callout body with a blank line uses ``> `` (quote-space) - # too — but in practice Obsidian's quote block ends on the - # first line that doesn't start with ``>``. - while i < len(lines) and lines[i].lstrip().startswith(">"): - i += 1 - continue - if line.strip() == "---": - i += 1 - continue - out_lines.append(line) - i += 1 - # Collapse runs of blank lines that callout-stripping creates - # (two newlines around a stripped block fold into one). - text_joined = "\n".join(out_lines) - text_joined = re.sub(r"\n{3,}", "\n\n", text_joined) - return text_joined.strip() diff --git a/src/beaver_gateway/storage/models.py b/src/beaver_gateway/storage/models.py index a18938c..4ec0170 100644 --- a/src/beaver_gateway/storage/models.py +++ b/src/beaver_gateway/storage/models.py @@ -1,15 +1,19 @@ """SQLModel tables. -Two tables, both flat, no relationships modelled yet (``actor`` and +Four tables, all flat, no FK relationships modelled (``actor`` and ``agent_name`` are stored as strings — joining audit→token by name is fine at this volume; we'll introduce FKs when the admin UI actually demands them). -A ``Session`` table originally lived here for live-session -observability. 
It was dropped after we decided the gateway stays -stateless about identity (claude-code-api's in-memory fingerprint pool -is the source of truth) and that conversation persistence belongs in a -future Obsidian-sync frontend, not a sessions table. +The ``Conversation`` + ``ConversationMessage`` pair persists chat +history per frontend so we can survive cache misses without losing +tool-call memory. The gateway is now stateful about conversation +content (we keep the raw Anthropic-shape message list including +``tool_use`` / ``tool_result`` blocks); the live ``claude-code-api`` +session pool stays the source of truth for *fingerprints*, and the DB +mirrors what we'd want to re-seed if a session evicts. See +``core/conversation_store.py`` for the diff-and-fork logic and +``frontends/markdown/frontend.py`` for the integration point. Datetimes are stored UTC; we set ``default_factory`` rather than relying on DB defaults so SQLite + Postgres behave identically. Every row that @@ -20,6 +24,7 @@ from __future__ import annotations from datetime import UTC, datetime +from sqlalchemy import UniqueConstraint from sqlmodel import Field, SQLModel @@ -66,4 +71,58 @@ class AuditLog(SQLModel, table=True): detail_json: str = Field(default="{}") -__all__ = ["AuditLog", "Token"] +class Conversation(SQLModel, table=True): + """One chat thread, scoped to a frontend. + + ``external_id`` is the identifier the frontend uses to find this + thread again on the next request — for the markdown frontend it's a + uuid we mint and persist into the file's frontmatter, for the + anthropic frontend it'd be the same metadata.conversation_id the + client passes. Unique per ``(frontend, external_id)`` because two + frontends sharing a uuid is fine; the same frontend reusing one is + a bug. 
+ """ + + __tablename__ = "conversations" + __table_args__ = ( + UniqueConstraint("frontend", "external_id", name="uq_conv_frontend_extid"), + ) + + id: int | None = Field(default=None, primary_key=True) + frontend: str = Field(index=True) + external_id: str = Field(index=True) + agent_name: str = Field(index=True) + created_at: datetime = Field(default_factory=_utcnow) + updated_at: datetime = Field(default_factory=_utcnow) + + +class ConversationMessage(SQLModel, table=True): + """One raw Anthropic-shape message in a conversation's transcript. + + A single user/assistant exchange visible in Obsidian can occupy + multiple rows when claude ran a tool cycle: ``assistant`` + (tool_use), ``user`` (tool_result), ``assistant`` (final text) all + live as separate rows with the same conversation_id and monotonic + ``seq``. ``content_json`` is the canonical Anthropic content payload + (string or list-of-blocks) — exactly what we'll feed back to the + backend so its session-pool fingerprint matches. + + ``seq`` is per-conversation 0-based monotonic; the unique + constraint catches the trivial bug of two writers racing on the + same conversation. 
+ """ + + __tablename__ = "conversation_messages" + __table_args__ = ( + UniqueConstraint("conversation_id", "seq", name="uq_msg_conv_seq"), + ) + + id: int | None = Field(default=None, primary_key=True) + conversation_id: int = Field(index=True) + seq: int + role: str + content_json: str + created_at: datetime = Field(default_factory=_utcnow) + + +__all__ = ["AuditLog", "Conversation", "ConversationMessage", "Token"] diff --git a/uv.lock b/uv.lock index 7c97451..29e7262 100644 --- a/uv.lock +++ b/uv.lock @@ -287,7 +287,7 @@ local = [ { name = "raycast-api", version = "0.1.0", source = { editable = "../raycast-api" } }, ] prod = [ - { name = "claude-code-api", version = "0.1.0", source = { git = "https://git.kotikot.com/beaver/claude-code-api.git#bf6116dc8b7f3708685c5a6e27061859e73eb4c9" } }, + { name = "claude-code-api", version = "0.1.0", source = { git = "https://git.kotikot.com/beaver/claude-code-api.git#1f20cef7d49d290f2b620ebb8a7aca92cdbd0e2a" } }, { name = "raycast-api", version = "0.1.0", source = { git = "https://git.kotikot.com/beaver/raycast-api.git#e73894c8e435da5c0709f92f69f11bcd0dab9afe" } }, ] @@ -419,7 +419,7 @@ wheels = [ [[package]] name = "claude-code-api" version = "0.1.0" -source = { git = "https://git.kotikot.com/beaver/claude-code-api.git#bf6116dc8b7f3708685c5a6e27061859e73eb4c9" } +source = { git = "https://git.kotikot.com/beaver/claude-code-api.git#1f20cef7d49d290f2b620ebb8a7aca92cdbd0e2a" } resolution-markers = [ "python_full_version >= '3.14'", "python_full_version < '3.14'",