405 lines
14 KiB
Python
405 lines
14 KiB
Python
"""The gateway-facing public API.
|
|
|
|
`ClaudeCodeBackend` is the only class the gateway needs to know
|
|
about. It owns:
|
|
|
|
- a pool of live `claude` sessions, keyed by a fingerprint of conversation
|
|
history, so a continuing turn reuses an existing PTY (and the
|
|
server-side prompt cache) instead of paying a fresh-spawn tax;
|
|
- the choice between `native_jsonl` (default) and `concat_message`
|
|
(fallback) for seeding a session with prior history that the gateway
|
|
sends in but no live session matches;
|
|
- the conversion from `BackendOptions` (high-level, takes a dict of MCP
|
|
servers) into `PtyProcessOptions` (low-level, takes argv-ready flags),
|
|
including materializing an `--mcp-config` file when `mcp_servers` is set.
|
|
"""
|
|
|
|
from __future__ import annotations

import asyncio
import contextlib
import inspect
import json
import os
import tempfile
import uuid
from collections.abc import AsyncIterator, Callable, Coroutine, Iterable, Mapping
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Literal, Self

from claude_code_api.errors import MessageParseError
from claude_code_api.events import (
    AssistantMessage,
    ContentBlock,
    Event,
    TextBlock,
    ThinkingBlock,
    ToolResultBlock,
    ToolUseBlock,
)
from claude_code_api.injection import (
    build_concat_prompt,
    build_seed_jsonl,
    hash_history,
)
from claude_code_api.paths import resolve_jsonl_path
from claude_code_api.pty import PtyClaudeProcess, PtyProcessOptions
from claude_code_api.turn import TurnManager
from claude_code_api.watcher import JsonlWatcher
|
|
|
|
# How prior conversation history is replayed into a freshly spawned session
# when no live session matches. "native_jsonl" writes a seed JSONL transcript
# and starts claude resuming it; "concat_message" folds the history into the
# first prompt sent to a brand-new session.
HistoryInjectionMode = Literal["native_jsonl", "concat_message"]

# Signature of the optional parse-error hook that is forwarded to every
# TurnManager. It receives the MessageParseError plus a dict — presumably the
# raw record that failed to parse (TODO confirm against TurnManager).
ParseErrorCallback = Callable[[MessageParseError, dict[str, Any]], None]
|
|
|
|
_TERMINAL_STOP_REASONS: frozenset[str] = frozenset(
|
|
{"end_turn", "max_tokens", "stop_sequence", "refusal"}
|
|
)
|
|
|
|
|
|
@dataclass(frozen=True)
class BackendOptions:
    """High-level configuration for `ClaudeCodeBackend`.

    Mirrors `PtyProcessOptions` shape but speaks the gateway's vocabulary:
    `mcp_servers` is a `{name: config}` mapping (materialized into a temp
    `--mcp-config` file under the hood) rather than a tuple of file paths.

    Raises:
        ValueError: if `history_injection_mode` is not one of
            "native_jsonl" or "concat_message".
    """

    # Working directory for the claude process; also used to locate the
    # session's JSONL transcript.
    cwd: str | os.PathLike[str]

    # The following are forwarded unchanged to PtyProcessOptions.
    model: str | None = None
    system_prompt: str | None = None
    append_system_prompt: str | None = None
    allowed_tools: tuple[str, ...] = ()
    disallowed_tools: tuple[str, ...] = ()
    # `{name: config}`; written to a temp file as {"mcpServers": ...} and
    # passed as the mcp_config path.
    mcp_servers: Mapping[str, Mapping[str, Any]] | None = None
    permission_mode: str = "bypassPermissions"
    dangerously_skip_permissions: bool = False
    effort: str | None = None
    add_dir: tuple[str, ...] = ()
    settings: str | None = None
    extra_args: tuple[str, ...] = ()
    extra_env: Mapping[str, str] = field(default_factory=dict)
    preserve_provider_env: bool = False

    # Seeding strategy for a new session that must carry prior history.
    history_injection_mode: HistoryInjectionMode = "native_jsonl"

    # The following are forwarded to TurnManager. Timeouts/delays are
    # presumably in seconds — confirm against TurnManager.
    wait_for_turn_duration: bool = False
    include_meta_user: bool = False
    startup_delay: float = 1.0
    file_wait_timeout: float = 30.0
    turn_duration_timeout: float = 5.0

    def __post_init__(self) -> None:
        """Reject an unknown history-injection mode at construction time."""
        # The backend compares this field against the two literals with `==`.
        # Any other value matches neither branch, so prior history would be
        # silently dropped instead of injected. Keep in sync with
        # HistoryInjectionMode.
        if self.history_injection_mode not in ("native_jsonl", "concat_message"):
            msg = (
                "history_injection_mode must be 'native_jsonl' or "
                f"'concat_message', got {self.history_injection_mode!r}"
            )
            raise ValueError(msg)
|
|
|
|
|
|
@dataclass
|
|
class _LiveSession:
|
|
"""One live PTY + watcher + turn manager. Created per conversation."""
|
|
|
|
pty: PtyClaudeProcess
|
|
watcher: JsonlWatcher
|
|
tm: TurnManager
|
|
|
|
@property
|
|
def session_id(self) -> str:
|
|
return self.pty.session_id
|
|
|
|
async def aclose(self) -> None:
|
|
await self.tm.aclose()
|
|
|
|
|
|
# Test hook that replaces the real PTY spawn. It is called as
# factory(backend, session_id, resume, jsonl_path, start_offset) and returns
# either the session itself or a coroutine resolving to it (i.e. it may be a
# plain function or an `async def`).
SessionFactory = Callable[
    ["ClaudeCodeBackend", str, bool, Path, int],
    "Coroutine[Any, Any, _LiveSession] | _LiveSession",
]
|
|
|
|
|
|
class ClaudeCodeBackend:
    """Persistent multi-session wrapper around the subscription `claude` CLI.

    Lifecycle:

        async with ClaudeCodeBackend(opts) as backend:
            async for event in backend.complete([{"role": "user", "content": "hi"}]):
                ...

    Each call to `complete()` either reuses a live PTY (if the new
    `messages[:-1]` matches one we already have running) or spawns a fresh
    session, optionally seeding it with prior history. On success, the
    session is stashed under a new fingerprint that incorporates this
    turn, so the next request can find it.
    """

    def __init__(
        self,
        options: BackendOptions,
        *,
        on_parse_error: ParseErrorCallback | None = None,
        _session_factory: SessionFactory | None = None,
    ) -> None:
        """Create an idle backend; no process is spawned until `complete()`.

        Args:
            options: Backend configuration.
            on_parse_error: Optional hook forwarded to every TurnManager.
            _session_factory: Test hook replacing `_spawn_real_session`.
        """
        self._opts = options
        self._on_parse_error = on_parse_error
        # Idle sessions keyed by the fingerprint of the history they hold.
        self._sessions: dict[str, _LiveSession] = {}
        self._mcp_config_path: Path | None = None
        self._session_factory = _session_factory
        self._closed = False
        # Serializes session lookup/creation in `complete()`.
        self._lock = asyncio.Lock()

    @property
    def options(self) -> BackendOptions:
        return self._opts

    @property
    def live_session_count(self) -> int:
        """Number of idle sessions currently parked for reuse."""
        return len(self._sessions)

    async def complete(self, messages: list[Mapping[str, Any]]) -> AsyncIterator[Event]:
        """Run one turn against the matching session (or spawn one).

        `messages` is an Anthropic-Messages-API style list — alternating
        user/assistant entries ending with a user entry. The backend uses
        `messages[:-1]` to look up a live session by fingerprint; if none
        matches it creates one (seeded with that history if non-empty).
        Yields typed events as they arrive; the final event is the
        synthesized `ResultMessage` from `TurnManager`.

        Raises:
            RuntimeError: if the backend is (or becomes) closed before a
                session is acquired.
            ValueError: if `messages` is empty or its last entry is not a
                user message with usable text.
        """
        if self._closed:
            msg = "ClaudeCodeBackend is closed"
            raise RuntimeError(msg)
        if not messages:
            msg = "messages must not be empty"
            raise ValueError(msg)
        last = messages[-1]
        if last.get("role") != "user":
            msg = "last message must have role='user'"
            raise ValueError(msg)
        last_text = _user_text_payload(last.get("content"))

        async with self._lock:
            # aclose() may have run while we waited for the lock; don't spawn
            # a PTY that nothing would ever shut down.
            if self._closed:
                msg = "ClaudeCodeBackend is closed"
                raise RuntimeError(msg)
            prior = list(messages[:-1])
            fp_prior = hash_history(prior)

            session: _LiveSession
            if prior and fp_prior in self._sessions:
                # Popped, so a concurrent request with the same history will
                # not find (and reuse) this session while the turn runs.
                session = self._sessions.pop(fp_prior)
                send_text = last_text
            else:
                session = await self._create_session(prior)
                if prior and self._opts.history_injection_mode == "concat_message":
                    send_text = build_concat_prompt(prior, last_text)
                else:
                    send_text = last_text

        events: list[Event] = []
        try:
            async for event in session.tm.send_user_message(send_text):
                events.append(event)
                yield event
        except BaseException:
            # Includes cancellation and GeneratorExit (consumer stopped
            # iterating): the session's state is unknown, so never reuse it.
            with contextlib.suppress(Exception):
                await session.aclose()
            raise

        synthetic_asst = _synthesize_assistant_dict(events)
        await self._stash_session(hash_history([*messages, synthetic_asst]), session)

    async def _stash_session(self, key: str, session: _LiveSession) -> None:
        """Park `session` under `key` for reuse, closing whatever it displaces."""
        if self._closed:
            # The backend shut down mid-turn: aclose() already drained the
            # pool, so parking this session would leak its PTY.
            with contextlib.suppress(Exception):
                await session.aclose()
            return
        displaced = self._sessions.get(key)
        self._sessions[key] = session
        if displaced is not None and displaced is not session:
            # Two conversations converged on the same fingerprint; the old
            # session is now unreachable, so shut it down instead of leaking.
            with contextlib.suppress(Exception):
                await displaced.aclose()

    async def aclose(self) -> None:
        """Shut down all idle sessions; remove the temp mcp-config file.

        Safe to call more than once. Sessions that are mid-turn are not in
        the pool; `complete()` closes them when their turn finishes.
        """
        self._closed = True
        sessions = list(self._sessions.values())
        self._sessions.clear()
        for s in sessions:
            with contextlib.suppress(Exception):
                await s.aclose()
        if self._mcp_config_path is not None:
            with contextlib.suppress(OSError):
                self._mcp_config_path.unlink()
            self._mcp_config_path = None

    async def __aenter__(self) -> Self:
        return self

    async def __aexit__(self, _exc_type: object, _exc: object, _tb: object) -> None:
        await self.aclose()

    async def _create_session(self, history: list[Mapping[str, Any]]) -> _LiveSession:
        """Spawn a fresh PTY + watcher + manager, optionally seeded.

        `native_jsonl` (default): write a hand-crafted JSONL transcript at
        the path from `resolve_jsonl_path(cwd, session_id)`, then start claude
        with `--resume <session_id>`. The watcher starts at the seed
        file's end so it sees only fresh records.

        `concat_message` (fallback): spawn fresh; the history is injected
        into the first user prompt instead (handled by `complete()`).

        If spawning fails, the seed transcript written here is removed again.
        """
        session_id = str(uuid.uuid4())
        cwd = os.fspath(self._opts.cwd)
        jsonl_path = resolve_jsonl_path(cwd, session_id)

        seeded = bool(history) and self._opts.history_injection_mode == "native_jsonl"
        if seeded:
            jsonl_path.parent.mkdir(parents=True, exist_ok=True)
            seed = build_seed_jsonl(history, session_id=session_id, cwd=cwd)
            jsonl_path.write_text(seed, encoding="utf-8")
            start_offset = jsonl_path.stat().st_size
        else:
            start_offset = 0
        resume = seeded

        try:
            if self._session_factory is not None:
                result = self._session_factory(
                    self, session_id, resume, jsonl_path, start_offset
                )
                # Accept coroutines, Futures and Tasks alike; iscoroutine()
                # would hand an un-awaited Future back to the caller.
                if inspect.isawaitable(result):
                    return await result
                return result  # type: ignore[return-value]

            return await self._spawn_real_session(
                session_id=session_id,
                resume=resume,
                jsonl_path=jsonl_path,
                start_offset=start_offset,
            )
        except BaseException:
            if seeded:
                # Don't leave an orphaned seed transcript behind.
                with contextlib.suppress(OSError):
                    jsonl_path.unlink()
            raise

    async def _spawn_real_session(
        self, *, session_id: str, resume: bool, jsonl_path: Path, start_offset: int
    ) -> _LiveSession:
        """Build the real PTY, watcher and turn manager and start the manager."""
        pty_opts = self._build_pty_options(session_id=session_id, resume=resume)
        pty = PtyClaudeProcess(pty_opts)
        watcher = JsonlWatcher(jsonl_path, start_offset=start_offset)
        tm = TurnManager(
            pty,
            watcher,
            wait_for_turn_duration=self._opts.wait_for_turn_duration,
            include_meta_user=self._opts.include_meta_user,
            file_wait_timeout=self._opts.file_wait_timeout,
            turn_duration_timeout=self._opts.turn_duration_timeout,
            startup_delay=self._opts.startup_delay,
            on_parse_error=self._on_parse_error,
        )
        await tm.start()
        return _LiveSession(pty=pty, watcher=watcher, tm=tm)

    def _build_pty_options(self, *, session_id: str, resume: bool) -> PtyProcessOptions:
        """Translate `BackendOptions` into low-level `PtyProcessOptions`.

        A seeded session is resumed by id; an unseeded one is started with
        that id.
        """
        mcp_config = self._mcp_config_argument()
        kwargs: dict[str, Any] = {
            "cwd": self._opts.cwd,
            "model": self._opts.model,
            "system_prompt": self._opts.system_prompt,
            "append_system_prompt": self._opts.append_system_prompt,
            "allowed_tools": self._opts.allowed_tools,
            "disallowed_tools": self._opts.disallowed_tools,
            "mcp_config": mcp_config,
            "add_dir": self._opts.add_dir,
            "permission_mode": self._opts.permission_mode,
            "dangerously_skip_permissions": self._opts.dangerously_skip_permissions,
            "effort": self._opts.effort,
            "settings": self._opts.settings,
            "extra_args": self._opts.extra_args,
            "preserve_provider_env": self._opts.preserve_provider_env,
            "extra_env": self._opts.extra_env,
        }
        if resume:
            kwargs["resume_session_id"] = session_id
        else:
            kwargs["session_id"] = session_id
        return PtyProcessOptions(**kwargs)

    def _mcp_config_argument(self) -> tuple[str, ...]:
        """Materialize `mcp_servers` into a `--mcp-config` file path tuple.

        The temp file lives for the backend's lifetime — cleaned up in
        `aclose()`. Written lazily so a backend that never spawns a
        session leaves no debris. Non-dict `Mapping` values at any depth
        are serialized as JSON objects.
        """
        servers = self._opts.mcp_servers
        if not servers:
            return ()
        if self._mcp_config_path is None:
            fd, path = tempfile.mkstemp(prefix="claude-mcp-", suffix=".json")
            try:
                with os.fdopen(fd, "w", encoding="utf-8") as f:
                    json.dump(
                        {"mcpServers": dict(servers)},
                        f,
                        default=self._json_default,
                    )
            except Exception:
                with contextlib.suppress(OSError):
                    Path(path).unlink()
                raise
            self._mcp_config_path = Path(path)
        return (str(self._mcp_config_path),)

    @staticmethod
    def _json_default(obj: object) -> Any:
        """`json.dump` fallback: encode any `Mapping` (e.g. a mappingproxy) as a dict."""
        if isinstance(obj, Mapping):
            return dict(obj)
        msg = f"Object of type {type(obj).__name__} is not JSON serializable"
        raise TypeError(msg)
|
|
|
|
|
|
def _user_text_payload(content: Any) -> str:
|
|
"""Extract the text we'll write to the PTY for the last user message.
|
|
|
|
A string `content` passes through as-is. A list of blocks is flattened
|
|
to its text content; tool_result blocks are not faithfully
|
|
reproducible through stdin and are skipped.
|
|
"""
|
|
if isinstance(content, str):
|
|
return content
|
|
if isinstance(content, list):
|
|
chunks: list[str] = []
|
|
for block in content:
|
|
if isinstance(block, Mapping) and block.get("type") == "text":
|
|
text = block.get("text")
|
|
if isinstance(text, str):
|
|
chunks.append(text)
|
|
if not chunks:
|
|
msg = "last user message content must include at least one text block"
|
|
raise ValueError(msg)
|
|
return " ".join(chunks)
|
|
msg = f"last user message content must be str or list, got {type(content).__name__}"
|
|
raise ValueError(msg)
|
|
|
|
|
|
def _synthesize_assistant_dict(events: Iterable[Event]) -> dict[str, Any]:
|
|
"""Render the terminal assistant message in Anthropic Messages format."""
|
|
terminal: AssistantMessage | None = None
|
|
for ev in reversed(list(events)):
|
|
if (
|
|
isinstance(ev, AssistantMessage)
|
|
and ev.stop_reason in _TERMINAL_STOP_REASONS
|
|
):
|
|
terminal = ev
|
|
break
|
|
if terminal is None:
|
|
return {"role": "assistant", "content": []}
|
|
return {
|
|
"role": "assistant",
|
|
"content": [_block_to_dict(b) for b in terminal.content],
|
|
}
|
|
|
|
|
|
def _block_to_dict(block: ContentBlock) -> dict[str, Any]:
|
|
if isinstance(block, TextBlock):
|
|
return {"type": "text", "text": block.text}
|
|
if isinstance(block, ToolUseBlock):
|
|
return {
|
|
"type": "tool_use",
|
|
"id": block.id,
|
|
"name": block.name,
|
|
"input": block.input,
|
|
}
|
|
if isinstance(block, ToolResultBlock):
|
|
return {
|
|
"type": "tool_result",
|
|
"tool_use_id": block.tool_use_id,
|
|
"content": block.content,
|
|
"is_error": block.is_error,
|
|
}
|
|
if isinstance(block, ThinkingBlock):
|
|
return {
|
|
"type": "thinking",
|
|
"thinking": block.thinking,
|
|
"signature": block.signature,
|
|
}
|
|
msg = f"unknown content block type: {type(block).__name__}"
|
|
raise TypeError(msg)
|
|
|
|
|
|
# Public surface of this module; everything else is an implementation detail.
__all__ = ["BackendOptions", "ClaudeCodeBackend", "HistoryInjectionMode", "ParseErrorCallback"]
|