Files
claude-code-api/src/claude_code_api/backend.py
T

533 lines
20 KiB
Python

"""The gateway-facing public API.
`ClaudeCodeBackend` is the only class the gateway needs to know
about. It owns:
- a pool of live `claude` sessions, keyed by a fingerprint of conversation
history, so a continuing turn reuses an existing PTY (and the
server-side prompt cache) instead of paying a fresh-spawn tax;
- the choice between `native_jsonl` (default) and `concat_message`
(fallback) for seeding a session with prior history that the gateway
sends in but no live session matches;
- the conversion from `BackendOptions` (high-level, takes a dict of MCP
servers) into `PtyProcessOptions` (low-level, takes argv-ready flags),
including materializing an `--mcp-config` file when `mcp_servers` is set.
"""
from __future__ import annotations

import asyncio
import contextlib
import inspect
import json
import logging
import os
import tempfile
import uuid
from collections.abc import AsyncIterator, Callable, Coroutine, Iterable, Mapping
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Literal, Self

from claude_code_api.errors import MessageParseError
from claude_code_api.events import (
    AssistantMessage,
    ContentBlock,
    Event,
    TextBlock,
    ThinkingBlock,
    ToolResultBlock,
    ToolUseBlock,
    UserMessage,
)
from claude_code_api.injection import (
    build_concat_prompt,
    build_seed_jsonl,
    hash_history,
)
from claude_code_api.paths import resolve_jsonl_path
from claude_code_api.pty import PtyClaudeProcess, PtyProcessOptions
from claude_code_api.turn import TurnManager
from claude_code_api.watcher import JsonlWatcher

_log = logging.getLogger("claude_code_api.backend")
# How prior conversation history is replayed into a freshly spawned session:
# "native_jsonl" writes a seed JSONL transcript and starts claude with
# `--resume`; "concat_message" folds the history into the first prompt text.
HistoryInjectionMode = Literal["native_jsonl", "concat_message"]
# Callback forwarded to `TurnManager(on_parse_error=...)`. Receives the parse
# error plus a dict — presumably the raw JSONL record that failed to parse
# (TODO confirm against claude_code_api.turn).
ParseErrorCallback = Callable[[MessageParseError, dict[str, Any]], None]
@dataclass(frozen=True)
class BackendOptions:
    """High-level configuration for `ClaudeCodeBackend`.

    Mirrors `PtyProcessOptions` shape but speaks the gateway's vocabulary:
    `mcp_servers` is a `{name: config}` mapping (materialized into a temp
    `--mcp-config` file under the hood) rather than a tuple of file paths.

    Raises:
        ValueError: `history_injection_mode` is not a supported mode.
    """

    # Working directory for the `claude` process; also used to locate the
    # session's JSONL transcript.
    cwd: str | os.PathLike[str]
    # --- Forwarded to `PtyProcessOptions` (CLI flags / environment) ---
    model: str | None = None
    system_prompt: str | None = None
    append_system_prompt: str | None = None
    allowed_tools: tuple[str, ...] = ()
    disallowed_tools: tuple[str, ...] = ()
    # `{name: config}`; written to a temp JSON file passed as `--mcp-config`.
    mcp_servers: Mapping[str, Mapping[str, Any]] | None = None
    permission_mode: str = "bypassPermissions"
    dangerously_skip_permissions: bool = False
    effort: str | None = None
    add_dir: tuple[str, ...] = ()
    settings: str | None = None
    extra_args: tuple[str, ...] = ()
    extra_env: Mapping[str, str] = field(default_factory=dict)
    preserve_provider_env: bool = False
    # --- Backend behaviour ---
    # Strategy used to seed a new session with prior history on a pool miss.
    history_injection_mode: HistoryInjectionMode = "native_jsonl"
    # --- Forwarded to `TurnManager` (durations presumably in seconds —
    # confirm in claude_code_api.turn) ---
    wait_for_turn_duration: bool = False
    include_meta_user: bool = False
    startup_delay: float = 10.0
    file_wait_timeout: float = 30.0
    turn_duration_timeout: float = 5.0

    def __post_init__(self) -> None:
        # An unknown mode falls through both injection paths (no seed
        # transcript, no concatenated prompt) and would silently drop the
        # prior history, so reject it up front.
        # Keep in sync with `HistoryInjectionMode`.
        valid_modes = ("native_jsonl", "concat_message")
        if self.history_injection_mode not in valid_modes:
            msg = (
                f"history_injection_mode must be one of {valid_modes!r}, "
                f"got {self.history_injection_mode!r}"
            )
            raise ValueError(msg)
@dataclass
class _LiveSession:
    """Bundle of PTY process, JSONL watcher and turn manager for one conversation.

    Built by `ClaudeCodeBackend` when a session is spawned and kept in the
    backend's pool between turns.
    """

    pty: PtyClaudeProcess
    watcher: JsonlWatcher
    tm: TurnManager

    @property
    def session_id(self) -> str:
        """Session id reported by the underlying PTY process."""
        process = self.pty
        return process.session_id

    async def aclose(self) -> None:
        """Tear the session down by closing its turn manager.

        Only `tm.aclose()` is awaited; the PTY and watcher are not closed
        directly here (presumably the turn manager owns them — confirm in
        claude_code_api.turn).
        """
        manager = self.tm
        await manager.aclose()
# Test hook for `ClaudeCodeBackend(_session_factory=...)`. Called as
# ``factory(backend, session_id, resume, jsonl_path, start_offset)`` and may
# either return the session directly or be an ``async def`` returning a
# coroutine that resolves to one. (The previous annotation advertised
# ``asyncio.Future``, which the backend did not await.)
SessionFactory = Callable[
    ["ClaudeCodeBackend", str, bool, Path, int],
    "Coroutine[Any, Any, _LiveSession] | _LiveSession",
]
class ClaudeCodeBackend:
    """Persistent multi-session wrapper around the subscription `claude` CLI.

    Lifecycle::

        async with ClaudeCodeBackend(opts) as backend:
            async for event in backend.complete([{"role": "user", "content": "hi"}]):
                ...

    Each call to `complete()` either reuses a live PTY (if the new
    `messages[:-1]` matches one we already have running) or spawns a fresh
    session, optionally seeding it with prior history. On success, the
    session is stashed under a new fingerprint that incorporates this
    turn, so the next request can find it.

    Session selection/creation and pool teardown are serialized by an
    internal `asyncio.Lock`; the turn itself streams outside the lock.
    """

    def __init__(
        self,
        options: BackendOptions,
        *,
        on_parse_error: ParseErrorCallback | None = None,
        _session_factory: SessionFactory | None = None,
    ) -> None:
        """Create an idle backend; nothing is spawned until `complete()`.

        Args:
            options: Configuration applied to every spawned session.
            on_parse_error: Forwarded to the `TurnManager` of each session
                spawned by `_spawn_real_session`.
            _session_factory: Test hook replacing the real spawn. Called as
                ``factory(backend, session_id, resume, jsonl_path,
                start_offset)``; may return a `_LiveSession` or an awaitable
                resolving to one.
        """
        self._opts = options
        self._on_parse_error = on_parse_error
        # Idle sessions keyed by the fingerprint of the history they hold.
        self._sessions: dict[str, _LiveSession] = {}
        # Sessions currently running a turn, keyed by session_id.
        self._active: dict[str, _LiveSession] = {}
        self._mcp_config_path: Path | None = None
        self._session_factory = _session_factory
        self._closed = False
        # Guards session selection/creation in complete() and pool teardown.
        self._lock = asyncio.Lock()

    @property
    def options(self) -> BackendOptions:
        """The configuration this backend was constructed with."""
        return self._opts

    @property
    def live_session_count(self) -> int:
        """Number of idle pooled sessions (sessions mid-turn are not counted)."""
        return len(self._sessions)

    @property
    def live_sessions(self) -> dict[str, PtyClaudeProcess]:
        """Snapshot of live PTY processes keyed by ``session_id``.

        Returned dict is a copy — caller may iterate freely without
        worrying about concurrent ``complete()`` calls reshuffling
        ``_sessions`` (which is keyed by history fingerprint, not
        ``session_id``). Intended for debug surfaces (e.g. admin
        terminal viewer) that need to look up a session by id.

        Covers both idle pooled sessions (``_sessions``) and the turn(s)
        currently running (``_active``) — so a live terminal is visible
        *during* its turn, not only after it's repooled. ``_active`` wins
        on key collision, but a session is never in both at once.
        """
        out = {s.session_id: s.pty for s in self._sessions.values()}
        out.update({sid: s.pty for sid, s in self._active.items()})
        return out

    async def complete(self, messages: list[Mapping[str, Any]]) -> AsyncIterator[Event]:
        """Run one turn against the matching session (or spawn one).

        `messages` is an Anthropic-Messages-API style list — alternating
        user/assistant entries ending with a user entry. The backend uses
        `messages[:-1]` to look up a live session by fingerprint; if none
        matches it creates one (seeded with that history if non-empty).

        Yields typed events as they arrive; the final event is the
        synthesized `ResultMessage` from `TurnManager`.

        If the turn fails — including cancellation or the consumer closing
        the generator early — the session is closed rather than repooled.
        If the backend was closed while the turn was streaming, the session
        is closed instead of being returned to the pool.

        Raises:
            RuntimeError: the backend is closed.
            ValueError: `messages` is empty, does not end with a user
                message, or the last message has no usable text.
        """
        if self._closed:
            msg = "ClaudeCodeBackend is closed"
            raise RuntimeError(msg)
        if not messages:
            msg = "messages must not be empty"
            raise ValueError(msg)
        last = messages[-1]
        if last.get("role") != "user":
            msg = "last message must have role='user'"
            raise ValueError(msg)
        last_text = _user_text_payload(last.get("content"))
        async with self._lock:
            # aclose() may have run while we waited for the lock; spawning
            # now would leak a PTY and recreate the mcp-config temp file.
            if self._closed:
                msg = "ClaudeCodeBackend is closed"
                raise RuntimeError(msg)
            prior = list(messages[:-1])
            fp_prior = hash_history(prior)
            _log.info(
                "complete: n_msgs=%d prior_fp=%s last_text_len=%d pool_size=%d",
                len(messages),
                fp_prior[:12],
                len(last_text),
                len(self._sessions),
            )
            session: _LiveSession
            if prior and fp_prior in self._sessions:
                session = self._sessions.pop(fp_prior)
                send_text = last_text
                _log.info(
                    "complete: POOL HIT fp=%s -> reusing session_id=%s",
                    fp_prior[:12],
                    session.session_id,
                )
            else:
                _log.info(
                    "complete: POOL MISS fp=%s (prior=%d msgs) -> spawning new session",
                    fp_prior[:12],
                    len(prior),
                )
                # Build the prompt *before* spawning so a failure here cannot
                # orphan a freshly started session.
                if prior and self._opts.history_injection_mode == "concat_message":
                    send_text = build_concat_prompt(prior, last_text)
                    _log.info(
                        "complete: concat_message mode, send_text_len=%d",
                        len(send_text),
                    )
                else:
                    send_text = last_text
                session = await self._create_session(prior)

        events: list[Event] = []
        n_assistant = n_user = n_system = 0
        self._active[session.session_id] = session
        try:
            _log.info(
                "complete: sending user msg to session_id=%s (text_len=%d)",
                session.session_id,
                len(send_text),
            )
            async for event in session.tm.send_user_message(send_text):
                events.append(event)
                if isinstance(event, AssistantMessage):
                    n_assistant += 1
                elif isinstance(event, UserMessage):
                    n_user += 1
                else:
                    n_system += 1
                yield event
        except BaseException as exc:
            if isinstance(exc, (GeneratorExit, asyncio.CancelledError)):
                # Consumer stopped iterating or the task was cancelled: the
                # session is unusable mid-turn, but this is not a failure
                # worth a traceback.
                _log.info(
                    "complete: session_id=%s ABANDONED mid-turn (%s) -> closing session",
                    session.session_id,
                    type(exc).__name__,
                )
            else:
                _log.exception(
                    "complete: session_id=%s FAILED (events so far: a=%d u=%d sys=%d): %s",
                    session.session_id,
                    n_assistant,
                    n_user,
                    n_system,
                    exc,
                )
            await self._close_quietly(session)
            raise
        finally:
            self._active.pop(session.session_id, None)

        try:
            synthesized_cycle = synthesize_turn_messages(events)
            new_fp = hash_history([*messages, *synthesized_cycle])
        except Exception:
            # Without a fingerprint the session can never be matched again;
            # close it instead of leaking the PTY.
            await self._close_quietly(session)
            raise
        if self._closed:
            # aclose() ran while this turn was streaming and has already
            # drained the pool; pooling now would leak the PTY.
            _log.info(
                "complete: session_id=%s DONE but backend closed -> closing session",
                session.session_id,
            )
            await self._close_quietly(session)
            return
        displaced = self._sessions.get(new_fp)
        self._sessions[new_fp] = session
        _log.info(
            "complete: session_id=%s DONE a=%d u=%d sys=%d synth=%d -> repooled under fp=%s",
            session.session_id,
            n_assistant,
            n_user,
            n_system,
            len(synthesized_cycle),
            new_fp[:12],
        )
        if displaced is not None:
            # Two turns converged on an identical history; keep the newer
            # session and close the older one so it is not orphaned.
            _log.info(
                "complete: closing displaced session_id=%s previously under fp=%s",
                displaced.session_id,
                new_fp[:12],
            )
            await self._close_quietly(displaced)

    async def aclose(self) -> None:
        """Shut down all pooled sessions; remove the temp mcp-config file.

        Marks the backend closed first so new `complete()` calls fail fast,
        then waits for any in-progress session spawn (which holds the lock)
        before draining the pool. Sessions mid-turn are left to their
        `complete()` generator, which closes them instead of repooling once
        it sees the backend is closed. Calling this more than once is safe.
        """
        self._closed = True
        async with self._lock:
            sessions = list(self._sessions.values())
            self._sessions.clear()
            mcp_config_path, self._mcp_config_path = self._mcp_config_path, None
        for s in sessions:
            await self._close_quietly(s)
        if mcp_config_path is not None:
            with contextlib.suppress(OSError):
                mcp_config_path.unlink()

    async def __aenter__(self) -> Self:
        return self

    async def __aexit__(self, _exc_type: object, _exc: object, _tb: object) -> None:
        await self.aclose()

    @staticmethod
    async def _close_quietly(session: _LiveSession) -> None:
        """Best-effort session teardown: log failures instead of raising."""
        try:
            await session.aclose()
        except Exception:
            _log.warning(
                "failed to close session_id=%s", session.session_id, exc_info=True
            )

    async def _create_session(self, history: list[Mapping[str, Any]]) -> _LiveSession:
        """Spawn a fresh PTY + watcher + manager, optionally seeded.

        `native_jsonl` (default): write a hand-crafted JSONL transcript at
        the path returned by `resolve_jsonl_path`, then start claude with
        `--resume <session_id>`. The watcher starts at the seed file's end
        so it sees only fresh records. If the spawn fails, the seed file is
        removed so no orphaned transcript is left behind.

        `concat_message` (fallback): spawn fresh; the history is injected
        into the first user prompt instead (handled by `complete()`).
        """
        session_id = str(uuid.uuid4())
        cwd = os.fspath(self._opts.cwd)
        jsonl_path = resolve_jsonl_path(cwd, session_id)
        resume = bool(history) and self._opts.history_injection_mode == "native_jsonl"
        if resume:
            jsonl_path.parent.mkdir(parents=True, exist_ok=True)
            seed = build_seed_jsonl(history, session_id=session_id, cwd=cwd)
            jsonl_path.write_text(seed, encoding="utf-8")
            start_offset = jsonl_path.stat().st_size
            _log.info(
                "_create_session: session_id=%s SEED jsonl=%s bytes=%d history_msgs=%d resume=True",
                session_id,
                jsonl_path,
                start_offset,
                len(history),
            )
        else:
            start_offset = 0
            _log.info(
                "_create_session: session_id=%s FRESH jsonl=%s history_msgs=%d resume=False mode=%s",
                session_id,
                jsonl_path,
                len(history),
                self._opts.history_injection_mode,
            )
        try:
            if self._session_factory is not None:
                result = self._session_factory(
                    self, session_id, resume, jsonl_path, start_offset
                )
                # Accept coroutines as well as Futures/Tasks (any awaitable).
                if inspect.isawaitable(result):
                    return await result
                return result  # type: ignore[return-value]
            return await self._spawn_real_session(
                session_id=session_id,
                resume=resume,
                jsonl_path=jsonl_path,
                start_offset=start_offset,
            )
        except BaseException:
            if resume:
                with contextlib.suppress(OSError):
                    jsonl_path.unlink()
            raise

    async def _spawn_real_session(
        self, *, session_id: str, resume: bool, jsonl_path: Path, start_offset: int
    ) -> _LiveSession:
        """Construct PTY, watcher and `TurnManager`, then start the manager.

        If `tm.start()` fails, the manager is closed best-effort before the
        error propagates — it may already have launched the child process
        (not verifiable from here).
        """
        _log.info(
            "_spawn_real_session: session_id=%s resume=%s start_offset=%d",
            session_id,
            resume,
            start_offset,
        )
        pty_opts = self._build_pty_options(session_id=session_id, resume=resume)
        pty = PtyClaudeProcess(pty_opts)
        watcher = JsonlWatcher(jsonl_path, start_offset=start_offset)
        tm = TurnManager(
            pty,
            watcher,
            wait_for_turn_duration=self._opts.wait_for_turn_duration,
            include_meta_user=self._opts.include_meta_user,
            file_wait_timeout=self._opts.file_wait_timeout,
            turn_duration_timeout=self._opts.turn_duration_timeout,
            startup_delay=self._opts.startup_delay,
            on_parse_error=self._on_parse_error,
        )
        try:
            await tm.start()
        except BaseException:
            with contextlib.suppress(Exception):
                await tm.aclose()
            raise
        _log.info(
            "_spawn_real_session: session_id=%s STARTED pid=%s",
            session_id,
            getattr(pty, "pid", "?"),
        )
        return _LiveSession(pty=pty, watcher=watcher, tm=tm)

    def _build_pty_options(self, *, session_id: str, resume: bool) -> PtyProcessOptions:
        """Translate `BackendOptions` into low-level `PtyProcessOptions`.

        A seeded session is started with `resume_session_id`; a fresh one
        with `session_id`.
        """
        mcp_config = self._mcp_config_argument()
        kwargs: dict[str, Any] = {
            "cwd": self._opts.cwd,
            "model": self._opts.model,
            "system_prompt": self._opts.system_prompt,
            "append_system_prompt": self._opts.append_system_prompt,
            "allowed_tools": self._opts.allowed_tools,
            "disallowed_tools": self._opts.disallowed_tools,
            "mcp_config": mcp_config,
            "add_dir": self._opts.add_dir,
            "permission_mode": self._opts.permission_mode,
            "dangerously_skip_permissions": self._opts.dangerously_skip_permissions,
            "effort": self._opts.effort,
            "settings": self._opts.settings,
            "extra_args": self._opts.extra_args,
            "preserve_provider_env": self._opts.preserve_provider_env,
            "extra_env": self._opts.extra_env,
        }
        if resume:
            kwargs["resume_session_id"] = session_id
        else:
            kwargs["session_id"] = session_id
        return PtyProcessOptions(**kwargs)

    @staticmethod
    def _json_default(obj: object) -> dict[Any, Any]:
        """`json.dump` fallback: encode non-`dict` mappings as plain dicts."""
        if isinstance(obj, Mapping):
            return dict(obj)
        msg = f"Object of type {type(obj).__name__} is not JSON serializable"
        raise TypeError(msg)

    def _mcp_config_argument(self) -> tuple[str, ...]:
        """Materialize `mcp_servers` into a `--mcp-config` file path tuple.

        The temp file lives for the backend's lifetime — cleaned up in
        `aclose()`. Written lazily so a backend that never spawns a
        session leaves no debris. Nested config values may be any
        `Mapping` (e.g. `MappingProxyType`), not only `dict`.

        Returns:
            ``()`` when no servers are configured, otherwise a 1-tuple
            holding the config file path (reused across calls).
        """
        servers = self._opts.mcp_servers
        if not servers:
            return ()
        if self._mcp_config_path is None:
            fd, path = tempfile.mkstemp(prefix="claude-mcp-", suffix=".json")
            try:
                with os.fdopen(fd, "w", encoding="utf-8") as f:
                    json.dump(
                        {"mcpServers": dict(servers)},
                        f,
                        default=self._json_default,
                    )
            except Exception:
                with contextlib.suppress(OSError):
                    Path(path).unlink()
                raise
            self._mcp_config_path = Path(path)
        return (str(self._mcp_config_path),)
def _user_text_payload(content: Any) -> str:
    """Return the text to type into the PTY for the final user message.

    Plain-string content is returned unchanged. For block-list content,
    every ``{"type": "text", "text": <str>}`` block is kept in order and the
    pieces are joined with single spaces; every other block (including
    ``tool_result``, which cannot be replayed faithfully over stdin) is
    ignored.

    Raises:
        ValueError: content is neither ``str`` nor ``list``, or is a list
            without a single usable text block.
    """
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        msg = f"last user message content must be str or list, got {type(content).__name__}"
        raise ValueError(msg)
    pieces = [
        piece
        for item in content
        if isinstance(item, Mapping)
        and item.get("type") == "text"
        and isinstance(piece := item.get("text"), str)
    ]
    if pieces:
        return " ".join(pieces)
    msg = "last user message content must include at least one text block"
    raise ValueError(msg)
def synthesize_turn_messages(events: Iterable[Event]) -> list[dict[str, Any]]:
    """Render a turn's full assistant↔tool cycle as Anthropic-shape messages.

    One ``complete()`` call may yield several ``AssistantMessage`` records —
    each tool-use cycle is its own record, followed by a ``UserMessage``
    carrying the matching ``tool_result`` blocks. This folds that sequence
    into canonical ``{"role", "content"}`` messages, i.e. what the Anthropic
    Messages API would have seen over the wire. The session fingerprint is
    computed over this output, and it is what gets seeded into JSONL on a
    cache-miss re-spawn, so a live PTY and a freshly-resumed one stay
    semantically equivalent.

    ``UserMessage`` records whose content is not a non-empty list (e.g. the
    echoed prompt as a plain string) are dropped; other event types are
    ignored.
    """
    messages: list[dict[str, Any]] = []
    for record in events:
        if isinstance(record, AssistantMessage):
            role = "assistant"
            blocks = record.content
        elif isinstance(record, UserMessage):
            blocks = record.content
            # String content is claude's own echo of the prompt, not a reply.
            if not (isinstance(blocks, list) and blocks):
                continue
            role = "user"
        else:
            continue
        messages.append({"role": role, "content": [_block_to_dict(b) for b in blocks]})
    return messages
def _block_to_dict(block: ContentBlock) -> dict[str, Any]:
if isinstance(block, TextBlock):
return {"type": "text", "text": block.text}
if isinstance(block, ToolUseBlock):
return {
"type": "tool_use",
"id": block.id,
"name": block.name,
"input": block.input,
}
if isinstance(block, ToolResultBlock):
# ``is_error`` is optional on the wire; emitting ``null`` makes the
# Anthropic API reject the message with "Input should be a valid
# boolean". Omit it entirely when we don't have a value.
result: dict[str, Any] = {
"type": "tool_result",
"tool_use_id": block.tool_use_id,
"content": block.content,
}
if block.is_error is not None:
result["is_error"] = block.is_error
return result
if isinstance(block, ThinkingBlock):
return {
"type": "thinking",
"thinking": block.thinking,
"signature": block.signature,
}
msg = f"unknown content block type: {type(block).__name__}"
raise TypeError(msg)
# Public surface of this module: the backend class, its options, the two
# type aliases that appear in its public signatures, and the turn-to-messages
# helper. Private helpers and the `SessionFactory` test hook are not exported.
__all__ = [
"BackendOptions",
"ClaudeCodeBackend",
"HistoryInjectionMode",
"ParseErrorCallback",
"synthesize_turn_messages",
]