Files
claude-code-api/src/claude_code_api/backend.py
T

405 lines
14 KiB
Python

"""The gateway-facing public API.
`ClaudeCodeBackend` is the only class the gateway needs to know
about. It owns:
- a pool of live `claude` sessions, keyed by a fingerprint of conversation
history, so a continuing turn reuses an existing PTY (and the
server-side prompt cache) instead of paying a fresh-spawn tax;
- the choice between `native_jsonl` (default) and `concat_message`
(fallback) for seeding a session with prior history that the gateway
sends in but no live session matches;
- the conversion from `BackendOptions` (high-level, takes a dict of MCP
servers) into `PtyProcessOptions` (low-level, takes argv-ready flags),
including materializing an `--mcp-config` file when `mcp_servers` is set.
"""
from __future__ import annotations

import asyncio
import contextlib
import inspect
import json
import os
import tempfile
import uuid
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Mapping
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Literal, Self

from claude_code_api.errors import MessageParseError
from claude_code_api.events import (
    AssistantMessage,
    ContentBlock,
    Event,
    TextBlock,
    ThinkingBlock,
    ToolResultBlock,
    ToolUseBlock,
)
from claude_code_api.injection import (
    build_concat_prompt,
    build_seed_jsonl,
    hash_history,
)
from claude_code_api.paths import resolve_jsonl_path
from claude_code_api.pty import PtyClaudeProcess, PtyProcessOptions
from claude_code_api.turn import TurnManager
from claude_code_api.watcher import JsonlWatcher
# Strategies for replaying prior history into a freshly spawned session.
HistoryInjectionMode = Literal["native_jsonl", "concat_message"]

# Invoked with the parse error and the raw record that failed to parse.
ParseErrorCallback = Callable[[MessageParseError, dict[str, Any]], None]

# An assistant message carrying one of these stop reasons is the one that
# ends a turn; it is what gets fingerprinted into the next history key.
_TERMINAL_STOP_REASONS: frozenset[str] = frozenset(
    ("end_turn", "max_tokens", "stop_sequence", "refusal")
)
@dataclass(frozen=True)
class BackendOptions:
    """High-level configuration for `ClaudeCodeBackend`.

    Mirrors `PtyProcessOptions` shape but speaks the gateway's vocabulary:
    `mcp_servers` is a `{name: config}` mapping (materialized into a temp
    `--mcp-config` file under the hood) rather than a tuple of file paths.

    Most fields are forwarded unchanged into `PtyProcessOptions` each time a
    session is spawned. The trailing group (`history_injection_mode` onward)
    configures the backend itself and the per-session `TurnManager`.
    """

    # Passed to `claude` as its cwd; also used to locate (and seed) the
    # session's JSONL transcript via `resolve_jsonl_path`.
    cwd: str | os.PathLike[str]

    # --- Forwarded as-is to PtyProcessOptions ---
    model: str | None = None
    system_prompt: str | None = None
    append_system_prompt: str | None = None
    allowed_tools: tuple[str, ...] = ()
    disallowed_tools: tuple[str, ...] = ()
    # {server name: server config}; written as {"mcpServers": {...}} JSON to
    # a temp file whose path is forwarded as PtyProcessOptions.mcp_config.
    mcp_servers: Mapping[str, Mapping[str, Any]] | None = None
    permission_mode: str = "bypassPermissions"
    dangerously_skip_permissions: bool = False
    effort: str | None = None
    add_dir: tuple[str, ...] = ()
    settings: str | None = None
    extra_args: tuple[str, ...] = ()
    extra_env: Mapping[str, str] = field(default_factory=dict)
    preserve_provider_env: bool = False

    # --- Backend / TurnManager behavior ---
    # How prior history is replayed when no live session matches:
    # "native_jsonl" seeds a JSONL transcript and starts with --resume;
    # "concat_message" folds the history into the first user prompt.
    history_injection_mode: HistoryInjectionMode = "native_jsonl"
    # The remaining fields are passed straight to TurnManager.
    # NOTE(review): the *_delay / *_timeout values are presumably seconds —
    # confirm against TurnManager.
    wait_for_turn_duration: bool = False
    include_meta_user: bool = False
    startup_delay: float = 1.0
    file_wait_timeout: float = 30.0
    turn_duration_timeout: float = 5.0
@dataclass
class _LiveSession:
    """One live PTY + watcher + turn manager. Created per conversation.

    The backend pops a session out of its pool before driving a turn with
    it, so a given session serves one `complete()` call at a time.
    """

    # The `claude` process; the source of the session id.
    pty: PtyClaudeProcess
    # Watches this session's JSONL transcript (from a start offset).
    watcher: JsonlWatcher
    # Built from `pty` + `watcher`; sends user messages and yields events.
    tm: TurnManager

    @property
    def session_id(self) -> str:
        """The claude session id, as reported by the PTY process."""
        return self.pty.session_id

    async def aclose(self) -> None:
        """Shut the session down by closing its turn manager.

        NOTE(review): only `tm` is closed here; presumably
        `TurnManager.aclose()` also tears down `pty` and `watcher` — confirm.
        """
        await self.tm.aclose()
# Test hook that replaces real PTY spawning.
# Signature: (backend, session_id, resume, jsonl_path, start_offset).
# It returns either a ready `_LiveSession` or an awaitable that resolves to
# one, typically the coroutine produced by an `async def` factory.
SessionFactory = Callable[
    ["ClaudeCodeBackend", str, bool, Path, int],
    "Awaitable[_LiveSession] | _LiveSession",
]
class ClaudeCodeBackend:
    """Persistent multi-session wrapper around the subscription `claude` CLI.

    Lifecycle:

        async with ClaudeCodeBackend(opts) as backend:
            async for event in backend.complete([{"role": "user", "content": "hi"}]):
                ...

    Each call to `complete()` either reuses a live PTY (if the new
    `messages[:-1]` matches one we already have running) or spawns a fresh
    session, optionally seeding it with prior history. On success, the
    session is stashed under a new fingerprint that incorporates this
    turn, so the next request can find it.

    Concurrency: `self._lock` serializes session checkout (pool lookup and
    spawning). It is deliberately NOT held while a turn streams, so
    independent conversations run concurrently. It also means a consumer
    that abandons a `complete()` generator mid-stream cannot wedge every
    other request. Returning a session to the pool is purely synchronous,
    so it is atomic with respect to other tasks on the event loop.
    """

    def __init__(
        self,
        options: BackendOptions,
        *,
        on_parse_error: ParseErrorCallback | None = None,
        _session_factory: SessionFactory | None = None,
    ) -> None:
        self._opts = options
        self._on_parse_error = on_parse_error
        # Idle sessions keyed by the fingerprint of their full history
        # (including the assistant reply to their last turn).
        self._sessions: dict[str, _LiveSession] = {}
        self._mcp_config_path: Path | None = None
        # Test hook: when set, replaces real PTY spawning.
        self._session_factory = _session_factory
        self._closed = False
        self._lock = asyncio.Lock()

    @property
    def options(self) -> BackendOptions:
        return self._opts

    @property
    def live_session_count(self) -> int:
        """Number of idle sessions in the pool (excludes in-flight turns)."""
        return len(self._sessions)

    async def complete(self, messages: list[Mapping[str, Any]]) -> AsyncIterator[Event]:
        """Run one turn against the matching session (or spawn one).

        `messages` is an Anthropic-Messages-API style list: alternating
        user/assistant entries ending with a user entry. The backend uses
        `messages[:-1]` to look up a live session by fingerprint. If none
        matches, it creates one (seeded with that history if non-empty).

        Yields typed events as they arrive; the final event is the
        synthesized `ResultMessage` from `TurnManager`.

        This is an async generator, so the errors below surface on the first
        iteration rather than at call time.

        Raises:
            RuntimeError: the backend has been closed.
            ValueError: `messages` is empty, its last entry is not a user
                message, or that message carries no usable text.
        """
        self._ensure_open()
        if not messages:
            msg = "messages must not be empty"
            raise ValueError(msg)
        last = messages[-1]
        if last.get("role") != "user":
            msg = "last message must have role='user'"
            raise ValueError(msg)
        last_text = _user_text_payload(last.get("content"))
        prior = list(messages[:-1])

        async with self._lock:
            # aclose() may have run while we waited for the lock; spawning
            # now would create a session nobody cleans up.
            self._ensure_open()
            session, send_text = await self._checkout_session(prior, last_text)

        events: list[Event] = []
        try:
            async for event in session.tm.send_user_message(send_text):
                events.append(event)
                yield event
        except BaseException:
            # Failed, cancelled, or abandoned mid-turn: the session's state
            # no longer matches any history fingerprint, so it can't be reused.
            with contextlib.suppress(Exception):
                await session.aclose()
            raise

        synthetic_asst = _synthesize_assistant_dict(events)
        new_fp = hash_history([*messages, synthetic_asst])
        # No await between here and the pool update, so this is atomic on
        # the event loop (no lock needed, no cancellation window).
        stale: _LiveSession | None
        if self._closed:
            # aclose() ran while this turn streamed. It could not see this
            # checked-out session, so stashing it would leak its PTY.
            stale = session
        else:
            stale = self._sessions.pop(new_fp, None)
            self._sessions[new_fp] = session
        if stale is not None:
            with contextlib.suppress(Exception):
                await stale.aclose()

    async def aclose(self) -> None:
        """Shut down all pooled sessions; remove the temp mcp-config file.

        Safe to call more than once. Sessions checked out by an in-flight
        `complete()` are not in the pool; that call closes its own session
        when the turn ends instead of stashing it.
        """
        self._closed = True
        sessions = list(self._sessions.values())
        self._sessions.clear()
        for s in sessions:
            with contextlib.suppress(Exception):
                await s.aclose()
        if self._mcp_config_path is not None:
            with contextlib.suppress(OSError):
                self._mcp_config_path.unlink()
            self._mcp_config_path = None

    async def __aenter__(self) -> Self:
        return self

    async def __aexit__(self, _exc_type: object, _exc: object, _tb: object) -> None:
        await self.aclose()

    def _ensure_open(self) -> None:
        """Raise `RuntimeError` if `aclose()` has been called."""
        if self._closed:
            msg = "ClaudeCodeBackend is closed"
            raise RuntimeError(msg)

    async def _checkout_session(
        self, prior: list[Mapping[str, Any]], last_text: str
    ) -> tuple[_LiveSession, str]:
        """Take a session out of the pool, or spawn one, for this turn.

        Caller must hold `self._lock`. Returns the session plus the text to
        type into it. In `concat_message` mode, a freshly spawned session
        gets the prior history folded into that text.
        """
        fp_prior = hash_history(prior)
        if prior and fp_prior in self._sessions:
            return self._sessions.pop(fp_prior), last_text

        # Build the prompt before spawning, so a failure here can't leak a PTY.
        if prior and self._opts.history_injection_mode == "concat_message":
            send_text = build_concat_prompt(prior, last_text)
        else:
            send_text = last_text

        session = await self._create_session(prior)
        if self._closed:
            # aclose() ran while we were spawning; don't hand out a session
            # the backend will never clean up.
            with contextlib.suppress(Exception):
                await session.aclose()
            self._ensure_open()
        return session, send_text

    async def _create_session(self, history: list[Mapping[str, Any]]) -> _LiveSession:
        """Spawn a fresh PTY + watcher + manager, optionally seeded.

        `native_jsonl` (default): write a hand-crafted JSONL transcript at
        `~/.claude/projects/<key>/<session_id>.jsonl`, then start claude
        with `--resume <session_id>`. The watcher starts at the seed
        file's end so it sees only fresh records. If the spawn fails, the
        seed file is removed again.

        `concat_message` (fallback): spawn fresh; the history is injected
        into the first user prompt instead (handled by the caller).
        """
        session_id = str(uuid.uuid4())
        cwd = os.fspath(self._opts.cwd)
        jsonl_path = resolve_jsonl_path(cwd, session_id)
        resume = bool(history) and self._opts.history_injection_mode == "native_jsonl"
        start_offset = 0
        if resume:
            jsonl_path.parent.mkdir(parents=True, exist_ok=True)
            seed = build_seed_jsonl(history, session_id=session_id, cwd=cwd)
            jsonl_path.write_text(seed, encoding="utf-8")
            start_offset = jsonl_path.stat().st_size

        try:
            if self._session_factory is not None:
                result = self._session_factory(
                    self, session_id, resume, jsonl_path, start_offset
                )
                # Await coroutines AND futures/other awaitables alike; the
                # SessionFactory contract allows either.
                if inspect.isawaitable(result):
                    result = await result
                return result  # type: ignore[return-value]
            return await self._spawn_real_session(
                session_id=session_id,
                resume=resume,
                jsonl_path=jsonl_path,
                start_offset=start_offset,
            )
        except BaseException:
            if resume:
                # Don't leave an orphaned seed transcript for a session that
                # never came up.
                with contextlib.suppress(OSError):
                    jsonl_path.unlink()
            raise

    async def _spawn_real_session(
        self, *, session_id: str, resume: bool, jsonl_path: Path, start_offset: int
    ) -> _LiveSession:
        """Build the real PTY/watcher/TurnManager trio and start it."""
        pty_opts = self._build_pty_options(session_id=session_id, resume=resume)
        pty = PtyClaudeProcess(pty_opts)
        watcher = JsonlWatcher(jsonl_path, start_offset=start_offset)
        tm = TurnManager(
            pty,
            watcher,
            wait_for_turn_duration=self._opts.wait_for_turn_duration,
            include_meta_user=self._opts.include_meta_user,
            file_wait_timeout=self._opts.file_wait_timeout,
            turn_duration_timeout=self._opts.turn_duration_timeout,
            startup_delay=self._opts.startup_delay,
            on_parse_error=self._on_parse_error,
        )
        try:
            await tm.start()
        except BaseException:
            # A failed or cancelled start may already have launched the PTY.
            # Tear it down rather than orphan the process.
            # NOTE(review): assumes TurnManager.aclose() tolerates a
            # partially-started manager; errors from it are suppressed.
            with contextlib.suppress(Exception):
                await tm.aclose()
            raise
        return _LiveSession(pty=pty, watcher=watcher, tm=tm)

    def _build_pty_options(self, *, session_id: str, resume: bool) -> PtyProcessOptions:
        """Translate `BackendOptions` into argv-level `PtyProcessOptions`.

        A resumed (seeded) session passes `resume_session_id`; a fresh one
        pins its id via `session_id`.
        """
        mcp_config = self._mcp_config_argument()
        kwargs: dict[str, Any] = {
            "cwd": self._opts.cwd,
            "model": self._opts.model,
            "system_prompt": self._opts.system_prompt,
            "append_system_prompt": self._opts.append_system_prompt,
            "allowed_tools": self._opts.allowed_tools,
            "disallowed_tools": self._opts.disallowed_tools,
            "mcp_config": mcp_config,
            "add_dir": self._opts.add_dir,
            "permission_mode": self._opts.permission_mode,
            "dangerously_skip_permissions": self._opts.dangerously_skip_permissions,
            "effort": self._opts.effort,
            "settings": self._opts.settings,
            "extra_args": self._opts.extra_args,
            "preserve_provider_env": self._opts.preserve_provider_env,
            "extra_env": self._opts.extra_env,
        }
        if resume:
            kwargs["resume_session_id"] = session_id
        else:
            kwargs["session_id"] = session_id
        return PtyProcessOptions(**kwargs)

    def _mcp_config_argument(self) -> tuple[str, ...]:
        """Materialize `mcp_servers` into a `--mcp-config` file path tuple.

        The temp file lives for the backend's lifetime and is cleaned up in
        `aclose()`. It is written lazily, so a backend that never spawns a
        session leaves no debris.
        """
        servers = self._opts.mcp_servers
        if not servers:
            return ()
        if self._mcp_config_path is None:
            fd, path = tempfile.mkstemp(prefix="claude-mcp-", suffix=".json")
            try:
                with os.fdopen(fd, "w", encoding="utf-8") as f:
                    json.dump({"mcpServers": dict(servers)}, f)
            except Exception:
                with contextlib.suppress(OSError):
                    Path(path).unlink()
                raise
            self._mcp_config_path = Path(path)
        return (str(self._mcp_config_path),)
def _user_text_payload(content: Any) -> str:
"""Extract the text we'll write to the PTY for the last user message.
A string `content` passes through as-is. A list of blocks is flattened
to its text content; tool_result blocks are not faithfully
reproducible through stdin and are skipped.
"""
if isinstance(content, str):
return content
if isinstance(content, list):
chunks: list[str] = []
for block in content:
if isinstance(block, Mapping) and block.get("type") == "text":
text = block.get("text")
if isinstance(text, str):
chunks.append(text)
if not chunks:
msg = "last user message content must include at least one text block"
raise ValueError(msg)
return " ".join(chunks)
msg = f"last user message content must be str or list, got {type(content).__name__}"
raise ValueError(msg)
def _synthesize_assistant_dict(events: Iterable[Event]) -> dict[str, Any]:
    """Render the turn's closing assistant message in Anthropic Messages format.

    The closing message is the LAST `AssistantMessage` whose stop_reason is
    terminal. If there is none, an assistant message with empty content is
    returned.
    """
    newest_first = reversed(list(events))
    terminal = next(
        (
            ev
            for ev in newest_first
            if isinstance(ev, AssistantMessage)
            and ev.stop_reason in _TERMINAL_STOP_REASONS
        ),
        None,
    )
    content = [] if terminal is None else [_block_to_dict(b) for b in terminal.content]
    return {"role": "assistant", "content": content}
def _block_to_dict(block: ContentBlock) -> dict[str, Any]:
if isinstance(block, TextBlock):
return {"type": "text", "text": block.text}
if isinstance(block, ToolUseBlock):
return {
"type": "tool_use",
"id": block.id,
"name": block.name,
"input": block.input,
}
if isinstance(block, ToolResultBlock):
return {
"type": "tool_result",
"tool_use_id": block.tool_use_id,
"content": block.content,
"is_error": block.is_error,
}
if isinstance(block, ThinkingBlock):
return {
"type": "thinking",
"thinking": block.thinking,
"signature": block.signature,
}
msg = f"unknown content block type: {type(block).__name__}"
raise TypeError(msg)
# Public surface of this module; everything else is an implementation detail.
__all__ = ["BackendOptions", "ClaudeCodeBackend", "HistoryInjectionMode", "ParseErrorCallback"]