feat: vibed out some slop over here also
This commit is contained in:
@@ -0,0 +1,852 @@
|
||||
"""Unit + smoke tests for Layer 5 (`ClaudeCodeBackend`).
|
||||
|
||||
Unit tests inject a `FakePty`-backed session factory so we can drive the
|
||||
dispatch logic end-to-end — fingerprint lookup, fresh spawn vs continuation,
|
||||
native_jsonl seeding vs concat_message preamble, post-turn fingerprint
|
||||
stash — without launching `claude`. The smoke test at the bottom spawns
|
||||
the real binary behind `RUN_CLAUDE_SMOKE=1`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from claude_code_api import (
|
||||
AssistantMessage,
|
||||
BackendOptions,
|
||||
ClaudeCodeBackend,
|
||||
ResultMessage,
|
||||
SessionError,
|
||||
TextBlock,
|
||||
UserMessage,
|
||||
)
|
||||
from claude_code_api.backend import _LiveSession
|
||||
from claude_code_api.injection import hash_history
|
||||
from claude_code_api.paths import resolve_jsonl_path
|
||||
from claude_code_api.watcher import JsonlWatcher
|
||||
from claude_code_api.turn import TurnManager
|
||||
|
||||
# --- fakes -----------------------------------------------------------------
|
||||
|
||||
|
||||
class FakePty:
    """In-memory stand-in for the PTY layer that replays scripted JSONL.

    Every `write()` is recorded in `writes`. While scripted batches remain,
    the call also appends the next batch from `scripts` to the JSONL file,
    one JSON document per line. Tests pre-load one batch per expected turn.
    The shape mirrors the fake used in `test_turn_manager.py`.
    """

    def __init__(
        self,
        jsonl_path: Path,
        *,
        session_id: str,
        scripts: list[list[dict[str, Any]]],
    ) -> None:
        # Public state inspected by the tests.
        self.session_id = session_id
        self.cwd = str(jsonl_path.parent)
        self.writes: list[str] = []
        self.started = False
        self.closed = False
        # Replay state: where batches go and which one is flushed next.
        self._jsonl = jsonl_path
        self._scripts = scripts
        self._write_count = 0

    async def start(self) -> None:
        """Flag the fake as started; nothing is actually spawned."""
        self.started = True

    async def write(self, text: str, *, newline: bool = True) -> int:
        """Record `text` and flush the next scripted batch, if one is left.

        `newline` is accepted for signature compatibility and is not used.
        Returns `len(text)`.
        """
        self.writes.append(text)
        batch_index = self._write_count
        if batch_index >= len(self._scripts):
            # Script exhausted: the write is recorded but nothing hits disk.
            return len(text)

        self._jsonl.parent.mkdir(parents=True, exist_ok=True)
        # Append mode creates the file even when the batch itself is empty.
        with self._jsonl.open("a", encoding="utf-8") as sink:
            sink.writelines(
                json.dumps(record) + "\n" for record in self._scripts[batch_index]
            )
        self._write_count = batch_index + 1
        return len(text)

    async def aclose(self) -> None:
        """Flag the fake as closed."""
        self.closed = True
|
||||
|
||||
|
||||
def _user_rec(text: str, session_id: str) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "user",
|
||||
"uuid": f"u-{text[:8]}",
|
||||
"sessionId": session_id,
|
||||
"parentUuid": None,
|
||||
"message": {"role": "user", "content": text},
|
||||
}
|
||||
|
||||
|
||||
def _assistant_rec(
|
||||
text: str,
|
||||
session_id: str,
|
||||
*,
|
||||
stop_reason: str = "end_turn",
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "assistant",
|
||||
"uuid": f"a-{text[:8]}",
|
||||
"sessionId": session_id,
|
||||
"parentUuid": None,
|
||||
"message": {
|
||||
"id": "msg_x",
|
||||
"role": "assistant",
|
||||
"model": "claude-test",
|
||||
"content": [{"type": "text", "text": text}],
|
||||
"stop_reason": stop_reason,
|
||||
"usage": {"input_tokens": 1, "output_tokens": 1},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class FakeFactoryHarness:
    """Callable usable as the backend's `_session_factory` in unit tests.

    Each invocation consumes the next per-session script from the queue,
    builds a `FakePty` from it, and wraps it in a real `JsonlWatcher` and
    `TurnManager`, so only the bottom layer is faked. Every `FakePty` is
    kept in `spawned`. For a `resume` call whose JSONL already exists, the
    file's bytes at spawn time are kept in `seed_files`.
    """

    def __init__(self, scripts_per_session: list[list[list[dict[str, Any]]]]) -> None:
        # Copy the outer list so popping does not mutate the caller's list.
        self._scripts = list(scripts_per_session)
        self.spawned: list[FakePty] = []
        self.seed_files: list[tuple[Path, bytes]] = []

    def __call__(
        self,
        backend: ClaudeCodeBackend,
        session_id: str,
        resume: bool,
        jsonl_path: Path,
        start_offset: int,
    ) -> Any:
        """Return an awaitable that resolves to a started `_LiveSession`."""
        if not self._scripts:
            raise AssertionError("FakeFactoryHarness ran out of scripts")
        session_scripts = self._scripts.pop(0)

        # Snapshot whatever was seeded before the fake PTY appends to it.
        if resume and jsonl_path.exists():
            snapshot = jsonl_path.read_bytes()
            self.seed_files.append((jsonl_path, snapshot))

        pty = FakePty(jsonl_path, session_id=session_id, scripts=session_scripts)
        self.spawned.append(pty)

        watcher = JsonlWatcher(jsonl_path, poll_interval=0.01, start_offset=start_offset)
        turn_manager = TurnManager(
            pty,  # type: ignore[arg-type]
            watcher,
            startup_delay=0.0,
            file_wait_timeout=2.0,
        )

        async def _boot() -> _LiveSession:
            await turn_manager.start()
            return _LiveSession(pty=pty, watcher=watcher, tm=turn_manager)  # type: ignore[arg-type]

        return _boot()
|
||||
|
||||
|
||||
# --- option / validation tests --------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_complete_rejects_empty_messages(tmp_path: Path) -> None:
    """An empty `messages` list is rejected with a ValueError matching "empty"."""
    options = BackendOptions(cwd=str(tmp_path))
    backend = ClaudeCodeBackend(options)
    with pytest.raises(ValueError, match="empty"):
        # Draining the stream is what triggers validation.
        _ = [event async for event in backend.complete([])]
    await backend.aclose()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_complete_rejects_non_user_last_message(tmp_path: Path) -> None:
    """A history whose last message is not a user message raises ValueError."""
    options = BackendOptions(cwd=str(tmp_path))
    backend = ClaudeCodeBackend(options)
    assistant_only = [{"role": "assistant", "content": "hi"}]
    with pytest.raises(ValueError, match="user"):
        _ = [event async for event in backend.complete(assistant_only)]
    await backend.aclose()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_complete_after_aclose_raises(tmp_path: Path) -> None:
    """`complete()` on an already-closed backend raises RuntimeError ("closed")."""
    closed_backend = ClaudeCodeBackend(BackendOptions(cwd=str(tmp_path)))
    await closed_backend.aclose()

    prompt = [{"role": "user", "content": "hi"}]
    with pytest.raises(RuntimeError, match="closed"):
        _ = [event async for event in closed_backend.complete(prompt)]
|
||||
|
||||
|
||||
# --- single-turn fresh session -------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_complete_fresh_session_yields_events(tmp_path: Path) -> None:
    """One message → spawn a fresh session, run one turn, get events back.

    Because there's no prior history, no seed JSONL gets written. The fake
    PTY's `write()` appends a scripted `(user, assistant)` pair to the JSONL
    on disk; the real watcher tails it and the real TurnManager closes the
    turn on the terminal assistant.
    """
    # The session id is chosen by the backend and handed to the factory, so
    # the `sessionId` fields in the scripted records are decorative here.
    scripts_per_session = [
        # session 0:
        [
            # turn 0 batch (written on first write())
            [
                _user_rec("hi", "S0"),
                _assistant_rec("hello there", "S0"),
            ],
        ],
    ]
    harness = FakeFactoryHarness(scripts_per_session)
    backend = ClaudeCodeBackend(
        BackendOptions(cwd=str(tmp_path)),
        _session_factory=harness,
    )

    events: list[Any] = []
    try:
        async for event in backend.complete([{"role": "user", "content": "hi"}]):
            events.append(event)
    finally:
        # Close even when complete() raises, so a failing run does not leave
        # the spawned session open.
        await backend.aclose()

    assert len(harness.spawned) == 1
    assert harness.spawned[0].writes == ["hi"]
    assert any(isinstance(e, UserMessage) for e in events)
    assert any(isinstance(e, AssistantMessage) for e in events)
    assert isinstance(events[-1], ResultMessage)
    assert events[-1].stop_reason == "end_turn"
    # No seed was written — first turn has empty prior history.
    assert harness.seed_files == []
|
||||
|
||||
|
||||
# --- multi-turn fingerprint reuse ----------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_continuation_reuses_live_session(tmp_path: Path) -> None:
    """Second `complete()` whose `messages[:-1]` matches the post-turn
    fingerprint of the first call must hit the live session — no new PTY,
    no seed file.
    """
    scripts_per_session = [
        # session 0 handles BOTH turns (two write() calls).
        [
            [_user_rec("hi", "S0"), _assistant_rec("hello there", "S0")],
            [_user_rec("again", "S0"), _assistant_rec("hi again", "S0")],
        ],
    ]
    harness = FakeFactoryHarness(scripts_per_session)
    backend = ClaudeCodeBackend(
        BackendOptions(cwd=str(tmp_path)),
        _session_factory=harness,
    )

    events2: list[Any] = []
    try:
        # First turn: only drained; its events are not inspected by this test.
        async for _ in backend.complete([{"role": "user", "content": "hi"}]):
            pass

        # Build the continuation: client echoes back our synthesized assistant
        # in canonical Anthropic shape (list of blocks).
        continuation = [
            {"role": "user", "content": "hi"},
            {"role": "assistant", "content": [{"type": "text", "text": "hello there"}]},
            {"role": "user", "content": "again"},
        ]
        async for e in backend.complete(continuation):
            events2.append(e)
    finally:
        # Close even when a turn raises, so the fake session is not left open.
        await backend.aclose()

    # Only ONE session was spawned across both turns.
    assert len(harness.spawned) == 1
    assert harness.spawned[0].writes == ["hi", "again"]
    # Second turn's events are clean (turn_count bookkeeping):
    assert isinstance(events2[-1], ResultMessage)
    assert events2[-1].num_turns == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_unmatched_history_spawns_new_session_via_native_jsonl(
    tmp_path: Path,
) -> None:
    """When prior history doesn't match any live session, the backend
    seeds a JSONL with that history and spawns a fresh `--resume` session
    (native_jsonl default mode).
    """
    scripts_per_session = [
        # one session for one turn — the only write() is the new user message
        [
            [_user_rec("how are you?", "S0"), _assistant_rec("good", "S0")],
        ],
    ]
    harness = FakeFactoryHarness(scripts_per_session)
    backend = ClaudeCodeBackend(
        BackendOptions(cwd=str(tmp_path)),
        _session_factory=harness,
    )

    # Three messages, no live session in the pool — must seed.
    messages = [
        {"role": "user", "content": "remember beaver"},
        {"role": "assistant", "content": "ok"},
        {"role": "user", "content": "how are you?"},
    ]
    events: list[Any] = []
    try:
        async for e in backend.complete(messages):
            events.append(e)
    finally:
        # Close even when complete() raises, so the fake session is not left open.
        await backend.aclose()

    assert len(harness.spawned) == 1
    # Only the LAST user message is sent into the PTY — history went via seed.
    assert harness.spawned[0].writes == ["how are you?"]
    # A seed file was captured by the harness (snapshot taken at spawn time,
    # before the fake PTY appended its scripted records).
    assert len(harness.seed_files) == 1
    _seed_path, seed_bytes = harness.seed_files[0]
    seed_lines = [
        json.loads(line) for line in seed_bytes.decode("utf-8").strip().splitlines()
    ]
    # Two seeded records (one user + one assistant) for the prior turn.
    assert [r["type"] for r in seed_lines] == ["user", "assistant"]
    assert seed_lines[0]["message"]["content"] == "remember beaver"
    assert seed_lines[1]["message"]["content"] == [{"type": "text", "text": "ok"}]
    assert isinstance(events[-1], ResultMessage)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_unmatched_history_uses_concat_message_when_configured(
    tmp_path: Path,
) -> None:
    """In `concat_message` mode the backend does NOT write a seed JSONL —
    it concatenates the prior history into the first stdin payload."""
    scripts_per_session = [
        [
            [_user_rec("how are you?", "S0"), _assistant_rec("good", "S0")],
        ],
    ]
    harness = FakeFactoryHarness(scripts_per_session)
    backend = ClaudeCodeBackend(
        BackendOptions(cwd=str(tmp_path), history_injection_mode="concat_message"),
        _session_factory=harness,
    )

    messages = [
        {"role": "user", "content": "remember beaver"},
        {"role": "assistant", "content": "ok"},
        {"role": "user", "content": "how are you?"},
    ]
    try:
        async for _ in backend.complete(messages):
            pass
    finally:
        # Close even when complete() raises, so the fake session is not left open.
        await backend.aclose()

    assert harness.seed_files == []  # no native injection in concat mode
    assert len(harness.spawned) == 1
    sent = harness.spawned[0].writes[0]
    # The first payload is the concat preamble + the new user prompt.
    assert "Previous conversation context:" in sent
    assert "[User]: remember beaver" in sent
    assert "[Assistant]: ok" in sent
    assert "Continue from here. New user message: how are you?" in sent
|
||||
|
||||
|
||||
# --- failure handling ----------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_complete_failure_does_not_stash_broken_session(tmp_path: Path) -> None:
    """If the turn iteration raises, the session must be closed and NOT
    re-stored under any fingerprint.
    """

    class BrokenFactory:
        """Factory whose sessions never produce a JSONL file, so the turn fails."""

        def __init__(self) -> None:
            self.spawned: list[FakePty] = []

        def __call__(
            self,
            backend: ClaudeCodeBackend,
            session_id: str,
            resume: bool,
            jsonl_path: Path,
            start_offset: int,
        ) -> Any:
            # Empty script list: write() records the text but never creates the file.
            fake = FakePty(jsonl_path, session_id=session_id, scripts=[])
            self.spawned.append(fake)
            watcher = JsonlWatcher(jsonl_path, poll_interval=0.01)
            tm = TurnManager(
                fake,  # type: ignore[arg-type]
                watcher,
                startup_delay=0.0,
                file_wait_timeout=0.05,  # fires fast — no JSONL ever appears
            )

            async def _start() -> _LiveSession:
                await tm.start()
                return _LiveSession(pty=fake, watcher=watcher, tm=tm)  # type: ignore[arg-type]

            return _start()

    factory = BrokenFactory()
    backend = ClaudeCodeBackend(
        BackendOptions(cwd=str(tmp_path)),
        _session_factory=factory,
    )

    try:
        with pytest.raises(SessionError):
            async for _ in backend.complete([{"role": "user", "content": "hi"}]):
                pass

        # Both checks must run BEFORE aclose(), which would mask a stashed session.
        assert backend.live_session_count == 0
        assert factory.spawned[0].closed is True
    finally:
        # Close even when an assertion above fails.
        await backend.aclose()
|
||||
|
||||
|
||||
# --- cancellation (Stage 9) ----------------------------------------------
|
||||
|
||||
|
||||
class _HangingFactory:
|
||||
"""Factory whose sessions never produce records — perfect for cancel tests.
|
||||
|
||||
`write()` creates the JSONL (so `wait_for_file()` returns immediately) but
|
||||
leaves it empty, so `TurnManager.send_user_message` enters its poll loop
|
||||
and stays there until something cancels it from outside.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.spawned: list[FakePty] = []
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
backend: ClaudeCodeBackend,
|
||||
session_id: str,
|
||||
resume: bool,
|
||||
jsonl_path: Path,
|
||||
start_offset: int,
|
||||
) -> Any:
|
||||
fake = FakePty(jsonl_path, session_id=session_id, scripts=[[]])
|
||||
self.spawned.append(fake)
|
||||
watcher = JsonlWatcher(jsonl_path, poll_interval=0.01)
|
||||
tm = TurnManager(
|
||||
fake, # type: ignore[arg-type]
|
||||
watcher,
|
||||
startup_delay=0.0,
|
||||
file_wait_timeout=2.0,
|
||||
)
|
||||
|
||||
async def _start() -> _LiveSession:
|
||||
await tm.start()
|
||||
return _LiveSession(pty=fake, watcher=watcher, tm=tm) # type: ignore[arg-type]
|
||||
|
||||
return _start()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_cancel_mid_turn_closes_session_and_leaves_pool_empty(
    tmp_path: Path,
) -> None:
    """task.cancel() on a consumer iterating `complete()` must:
    - propagate CancelledError to the consumer,
    - tear down the live session (PTY closed via TurnManager.aclose),
    - leave the live-session pool empty (broken session is never re-stashed).
    """
    factory = _HangingFactory()
    backend = ClaudeCodeBackend(
        BackendOptions(cwd=str(tmp_path)),
        _session_factory=factory,
    )

    async def consumer() -> None:
        # The hanging session yields nothing; this loop just parks in complete().
        async for _ in backend.complete([{"role": "user", "content": "hi"}]):
            pass

    task = asyncio.create_task(consumer())
    try:
        # Let the turn enter its poll loop. The poll interval is 10ms; 200ms is
        # plenty for the FakePty.write() + first read_once() to land.
        await asyncio.sleep(0.2)
        task.cancel()
        with pytest.raises(asyncio.CancelledError):
            await task

        # Checked before aclose(), which would close the session regardless.
        assert backend.live_session_count == 0
        assert len(factory.spawned) == 1
        assert factory.spawned[0].closed is True
    finally:
        # Close even when an assertion above fails.
        await backend.aclose()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_cancel_releases_lock_so_next_complete_works(tmp_path: Path) -> None:
    """After a cancelled turn, the backend's internal lock must be released
    so a subsequent `complete()` can run. We follow up with a normal call
    against a healthy session and assert it completes end-to-end.
    """

    class HangThenRespondFactory:
        """First spawn hangs (cancel target); second spawn completes a turn."""

        def __init__(self) -> None:
            self._spawn_index = 0
            self.spawned: list[FakePty] = []

        def __call__(
            self,
            backend: ClaudeCodeBackend,
            session_id: str,
            resume: bool,
            jsonl_path: Path,
            start_offset: int,
        ) -> Any:
            idx = self._spawn_index
            self._spawn_index += 1
            if idx == 0:
                # One empty batch: the file is created but no record ever lands.
                scripts: list[list[dict[str, Any]]] = [[]]  # hangs
            else:
                scripts = [
                    [
                        _user_rec("hi", "S1"),
                        _assistant_rec("hello", "S1"),
                    ]
                ]
            fake = FakePty(jsonl_path, session_id=session_id, scripts=scripts)
            self.spawned.append(fake)
            watcher = JsonlWatcher(jsonl_path, poll_interval=0.01)
            tm = TurnManager(
                fake,  # type: ignore[arg-type]
                watcher,
                startup_delay=0.0,
                file_wait_timeout=2.0,
            )

            async def _start() -> _LiveSession:
                await tm.start()
                return _LiveSession(pty=fake, watcher=watcher, tm=tm)  # type: ignore[arg-type]

            return _start()

    factory = HangThenRespondFactory()
    backend = ClaudeCodeBackend(
        BackendOptions(cwd=str(tmp_path)),
        _session_factory=factory,
    )

    # First call: cancel mid-stream.
    async def consumer() -> None:
        async for _ in backend.complete([{"role": "user", "content": "hi"}]):
            pass

    task = asyncio.create_task(consumer())
    try:
        await asyncio.sleep(0.2)
        task.cancel()
        with pytest.raises(asyncio.CancelledError):
            await task

        # Second call: must proceed without deadlocking on the lock.
        events: list[Any] = []
        async for e in backend.complete([{"role": "user", "content": "hi"}]):
            events.append(e)

        assert len(factory.spawned) == 2
        assert factory.spawned[0].closed is True  # cancelled session is dead
        assert isinstance(events[-1], ResultMessage)
        assert events[-1].num_turns == 1  # fresh session, fresh counter
    finally:
        # Close even when the follow-up turn or an assertion fails.
        await backend.aclose()
|
||||
|
||||
|
||||
# --- mcp_servers materialization -----------------------------------------
|
||||
|
||||
|
||||
def test_mcp_config_argument_writes_temp_file_lazily(tmp_path: Path) -> None:
    """`mcp_servers` lifts to a temp `--mcp-config` JSON written on first
    access; the file is removed in `aclose()`."""
    backend = ClaudeCodeBackend(
        BackendOptions(
            cwd=str(tmp_path),
            mcp_servers={"echo": {"command": "/bin/echo", "args": []}},
        )
    )
    try:
        paths = backend._mcp_config_argument()  # type: ignore[attr-defined]
        assert len(paths) == 1
        p = Path(paths[0])
        assert p.exists()
        # Explicit encoding so the check does not depend on the platform locale.
        body = json.loads(p.read_text(encoding="utf-8"))
        assert body == {"mcpServers": {"echo": {"command": "/bin/echo", "args": []}}}

        # Calling again returns the same path; no second file.
        paths2 = backend._mcp_config_argument()  # type: ignore[attr-defined]
        assert paths2 == paths
    finally:
        # Always close, so a failing assertion above does not leak the temp file.
        asyncio.run(backend.aclose())

    # aclose() removes the file.
    assert not p.exists()
|
||||
|
||||
|
||||
def test_no_mcp_config_returns_empty_tuple(tmp_path: Path) -> None:
    """Without `mcp_servers`, the MCP config argument is an empty tuple."""
    options = BackendOptions(cwd=str(tmp_path))
    backend = ClaudeCodeBackend(options)
    mcp_args = backend._mcp_config_argument()  # type: ignore[attr-defined]
    assert mcp_args == ()
|
||||
|
||||
|
||||
# --- post-turn fingerprint key shape -------------------------------------
|
||||
|
||||
|
||||
def test_post_turn_fingerprint_matches_canonical_continuation(tmp_path: Path) -> None:
    """Regression: stash key and lookup key must agree for block-shaped replies.

    The backend stashes the live session under
    hash_history(messages + [synthesized_assistant]), where the synthesized
    assistant uses the `[{"type": "text", "text": ...}]` block shape. A
    gateway echoing that shape back as prior history on the next request
    must hash to the same fingerprint. This pins both sides of the contract.
    """
    first_request = [{"role": "user", "content": "hi"}]
    # What the backend synthesizes after one turn that yielded "hello there".
    synthesized_reply = {
        "role": "assistant",
        "content": [{"type": "text", "text": "hello there"}],
    }
    stash_key = hash_history(first_request + [synthesized_reply])

    # The prior-history slice of the follow-up request, as the gateway sends it.
    echoed_prior = [
        {"role": "user", "content": "hi"},
        {"role": "assistant", "content": [{"type": "text", "text": "hello there"}]},
    ]
    lookup_key = hash_history(echoed_prior)

    assert stash_key == lookup_key
|
||||
|
||||
|
||||
# --- smoke test (real claude) --------------------------------------------
|
||||
|
||||
|
||||
# Name of the environment variable that opts in to the real-`claude` smoke
# tests below; they are skipped unless it is set to "1".
_SMOKE_ENV = "RUN_CLAUDE_SMOKE"
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    os.environ.get(_SMOKE_ENV) != "1",
    reason=f"set {_SMOKE_ENV}=1 to run the real-`claude` smoke test",
)
@pytest.mark.asyncio
async def test_smoke_backend_round_trip(tmp_path: Path) -> None:
    """End-to-end against real claude through the public API.

    A single `complete()` call with no prior history runs on a fresh
    session. The stream must contain an assistant message with a terminal
    stop reason and at least one text block, and must end with a
    `ResultMessage` carrying that same stop reason.
    """
    options = BackendOptions(cwd=str(tmp_path), dangerously_skip_permissions=True)
    backend = ClaudeCodeBackend(options)

    prompt = [{"role": "user", "content": "say hi"}]
    try:
        events: list[Any] = [event async for event in backend.complete(prompt)]
    finally:
        await backend.aclose()

    # First assistant message whose stop reason ends a turn.
    terminal_reasons = {"end_turn", "max_tokens", "stop_sequence", "refusal"}
    terminal = None
    for candidate in events:
        if isinstance(candidate, AssistantMessage) and candidate.stop_reason in terminal_reasons:
            terminal = candidate
            break

    assert terminal is not None
    assert any(isinstance(block, TextBlock) for block in terminal.content)
    final_event = events[-1]
    assert isinstance(final_event, ResultMessage)
    assert final_event.stop_reason == terminal.stop_reason
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    os.environ.get(_SMOKE_ENV) != "1",
    reason=f"set {_SMOKE_ENV}=1 to run the real-`claude` smoke test",
)
@pytest.mark.asyncio
async def test_smoke_backend_native_jsonl_injection(tmp_path: Path) -> None:
    """Real claude, real injection.

    A 3-message history is sent with no live session available, so the
    backend has to seed a JSONL and resume from it. The reply must mention
    the seeded name, and the session's JSONL on disk must contain the
    seeded user text.
    """
    options = BackendOptions(cwd=str(tmp_path), dangerously_skip_permissions=True)
    backend = ClaudeCodeBackend(options)

    messages = [
        {"role": "user", "content": "My name is Beaver. Please remember it."},
        {"role": "assistant", "content": "Got it — your name is Beaver."},
        {"role": "user", "content": "What is my name? Answer with just the name, one word."},
    ]
    try:
        events: list[Any] = [event async for event in backend.complete(messages)]
    finally:
        await backend.aclose()

    # Primary check: the reply itself proves the seeded context was loaded.
    terminal = None
    for candidate in events:
        if isinstance(candidate, AssistantMessage) and candidate.stop_reason == "end_turn":
            terminal = candidate
            break
    assert terminal is not None
    text_parts = [block.text for block in terminal.content if isinstance(block, TextBlock)]
    text = " ".join(text_parts)
    assert "beaver" in text.lower(), f"injection failed to plant context; got {text!r}"

    # Sanity: the file the backend resumed against exists and contains our seed.
    session_id = events[-1].session_id  # type: ignore[union-attr]
    assert isinstance(session_id, str)
    jsonl_path = resolve_jsonl_path(str(tmp_path), session_id)
    assert jsonl_path.exists()
    # The seeded user record's content text is in the file.
    assert "My name is Beaver" in jsonl_path.read_text()
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    os.environ.get(_SMOKE_ENV) != "1",
    reason=f"set {_SMOKE_ENV}=1 to run the real-`claude` smoke test",
)
@pytest.mark.asyncio
async def test_smoke_cancellation_kills_pty_no_zombie(tmp_path: Path) -> None:
    """Smoke 4 (Stage 9): cancel a real long-running turn, assert the PTY
    dies cleanly with no zombie left behind.

    Strategy:
    - prompt claude with something verbose so the turn stays in flight
      long enough for us to cancel mid-stream;
    - wrap the spawn through `_session_factory` so we can capture the
      live `PtyClaudeProcess` while it's still in flight (the backend
      does NOT keep in-flight sessions in `_sessions`);
    - cancel the consumer task as soon as we've seen at least one event
      (proving the turn really started — otherwise we'd be cancelling a
      not-yet-spawned session);
    - after the cancel propagates, assert: PTY is dead (no `kill -0`),
      pool is empty, and a second `complete()` on the same backend still
      works (lock was released).
    """
    import signal as _signal

    captured: list[Any] = []  # collected _LiveSession objects

    def capturing_factory(
        backend: ClaudeCodeBackend,
        session_id: str,
        resume: bool,
        jsonl_path: Path,
        start_offset: int,
    ) -> Any:
        # Delegate to the backend's real spawn path and remember the result.
        async def _real() -> Any:
            session = await backend._spawn_real_session(  # type: ignore[attr-defined]
                session_id=session_id,
                resume=resume,
                jsonl_path=jsonl_path,
                start_offset=start_offset,
            )
            captured.append(session)
            return session

        return _real()

    backend = ClaudeCodeBackend(
        BackendOptions(cwd=str(tmp_path), dangerously_skip_permissions=True),
        _session_factory=capturing_factory,
    )

    saw_event = asyncio.Event()
    events: list[Any] = []

    long_prompt = (
        "Please count slowly from 1 to 500, one number per line, in plain text. "
        "Do not stop until you reach 500."
    )

    async def consumer() -> None:
        async for event in backend.complete([{"role": "user", "content": long_prompt}]):
            events.append(event)
            saw_event.set()

    task = asyncio.create_task(consumer())
    try:
        # Wait until we have at least one event so we know the turn is in
        # flight on a live PTY. 30s is comfortably above the typical
        # spawn + first-record latency (~3-5s for cold claude startup).
        await asyncio.wait_for(saw_event.wait(), timeout=30.0)

        assert len(captured) == 1, (
            f"expected exactly one captured session at cancel time; got {len(captured)}"
        )
        live = captured[0]
        pid = live.pty.pid
        assert pid is not None and pid > 0
        assert live.pty.is_alive() is True

        task.cancel()
        with pytest.raises(asyncio.CancelledError):
            await task

        # SIGTERM ladder runs inside session.aclose() during cleanup, so by
        # the time `await task` returns the PTY has been reaped.
        assert live.pty.is_alive() is False, "PTY still alive after cancel cleanup"
        assert backend.live_session_count == 0, (
            "cancelled session must not be re-stashed in the live pool"
        )

        # Belt-and-suspenders: confirm the OS no longer has the pid.
        # `os.kill(pid, 0)` raises ProcessLookupError when the process is gone.
        # If the first probe succeeds, the pid is still claimable, so we wait
        # briefly and require the second probe to raise.
        try:
            os.kill(pid, 0)
            # If we got here, the pid is still claimable. With pty.close(force=True)
            # in _reap that shouldn't happen, but on macOS the reap might race
            # very briefly — give it one more beat.
            await asyncio.sleep(0.2)
            with pytest.raises(ProcessLookupError):
                os.kill(pid, 0)
        except ProcessLookupError:
            pass  # expected: process is gone

        # Lock released — a fresh call must still work end-to-end.
        followup_events: list[Any] = []
        async for ev in backend.complete([{"role": "user", "content": "say hi"}]):
            followup_events.append(ev)
        assert isinstance(followup_events[-1], ResultMessage), (
            "follow-up turn failed; backend may have leaked state after cancel"
        )
    finally:
        # Defensive: if anything above failed, make sure we don't leave a
        # zombie claude around for the next test run.
        if not task.done():
            task.cancel()
            with contextlib.suppress(BaseException):
                await task
        for s in captured:
            if s.pty.is_alive():
                with contextlib.suppress(BaseException):
                    s.pty._pty.kill(_signal.SIGKILL)  # type: ignore[union-attr]
        await backend.aclose()
|
||||
@@ -0,0 +1,125 @@
|
||||
"""Unit tests for the Stage 10 error hierarchy + PTY-output classifier."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from claude_code_api import (
|
||||
AuthError,
|
||||
BackendError,
|
||||
CLINotFoundError,
|
||||
MessageParseError,
|
||||
ProcessError,
|
||||
RateLimitError,
|
||||
SessionError,
|
||||
classify_pty_failure,
|
||||
)
|
||||
|
||||
|
||||
def test_hierarchy_roots_under_backend_error() -> None:
    """Every public error type descends from ``BackendError``.

    A gateway can therefore install a single catch-all handler for the
    whole family. ``CLINotFoundError`` is additionally pinned as a
    ``ProcessError`` subclass.
    """
    backend_error_types = (
        AuthError,
        MessageParseError,
        ProcessError,
        RateLimitError,
        SessionError,
    )
    assert all(issubclass(err_type, BackendError) for err_type in backend_error_types)
    assert issubclass(CLINotFoundError, ProcessError)
|
||||
|
||||
|
||||
def test_process_error_carries_exit_code_and_stderr_in_message() -> None:
    """``ProcessError`` keeps its fields and renders them in ``str()``."""
    stderr_text = "line1\nline2"
    err = ProcessError("boom", exit_code=7, stderr=stderr_text)

    # Structured attributes survive construction untouched.
    assert err.exit_code == 7
    assert err.stderr == stderr_text

    # The rendered form mentions the message, the exit code and the stderr tail.
    message = str(err)
    for fragment in ("boom", "exit code: 7", "line1"):
        assert fragment in message
|
||||
|
||||
|
||||
def test_process_error_tail_caps_huge_stderr() -> None:
    """A 5 KB stderr blob must not be embedded wholesale in the message."""
    huge_stderr = "x" * 5000
    message = str(ProcessError("oops", stderr=huge_stderr))

    # The test expects a tail of at most 2000 chars; the extra 10 is slack
    # for any literal 'x' that the message prefix itself may contain.
    max_x_count = 2000 + 10
    assert message.count("x") <= max_x_count
|
||||
|
||||
|
||||
def test_cli_not_found_appends_executable() -> None:
    """``CLINotFoundError`` exposes the executable and names it in its message."""
    binary = "/usr/local/bin/claude"
    with_path = CLINotFoundError(executable=binary)
    assert binary in str(with_path)
    assert with_path.executable == binary

    # Constructing with no arguments is valid too and still says "not found".
    without_path = CLINotFoundError()
    assert "not found" in str(without_path).lower()
|
||||
|
||||
|
||||
def test_classify_pty_failure_returns_none_when_no_marker() -> None:
    """Output that carries no failure marker classifies as ``None``."""
    for benign_output in (b"the model is thinking...", ""):
        assert classify_pty_failure(benign_output) is None
|
||||
|
||||
|
||||
def test_classify_auth_markers() -> None:
    """Each authentication-failure phrasing classifies as ``AuthError``."""
    auth_outputs = (
        b"Failed to authenticate (status 401)",
        b"API Error: 403 Forbidden",
        # "Please run /login" must match both in plain form and when ANSI
        # escape sequences sit between the words.
        b"Please run /login to continue.",
        b"\x1b[31mPlease\x1b[0m run /login",
    )
    for raw_output in auth_outputs:
        assert classify_pty_failure(raw_output) is AuthError
|
||||
|
||||
|
||||
def test_classify_rate_limit_markers() -> None:
    """Each rate-limit phrasing classifies as ``RateLimitError``."""
    rate_limit_outputs = (
        b"You've hit your limit. Try again later.",
        b"You have hit your limit.",
        # Bare form, without the leading pronoun.
        b"hit your limit",
    )
    for raw_output in rate_limit_outputs:
        assert classify_pty_failure(raw_output) is RateLimitError
|
||||
|
||||
|
||||
def test_classify_strips_ansi_before_matching() -> None:
    """SGR colour sequences around the marker must not prevent a match."""
    sgr_bold_red, sgr_reset = b"\x1b[1;31m", b"\x1b[0m"
    wrapped = sgr_bold_red + b"You've hit your limit" + sgr_reset
    assert classify_pty_failure(wrapped) is RateLimitError
|
||||
|
||||
|
||||
def test_classify_accepts_str_or_bytes() -> None:
    """The classifier gives the same verdict for ``str`` and ``bytes`` input."""
    marker = "Failed to authenticate"
    for candidate in (marker, marker.encode("ascii")):
        assert classify_pty_failure(candidate) is AuthError
|
||||
|
||||
|
||||
def test_auth_and_rate_limit_default_messages() -> None:
    """Argument-less errors still produce a descriptive default message."""
    auth_text = str(AuthError()).lower()
    assert "auth" in auth_text

    # Either wording is acceptable for the rate-limit default.
    rate_text = str(RateLimitError()).lower()
    assert "rate" in rate_text or "limit" in rate_text
|
||||
|
||||
|
||||
def test_session_error_is_plain_backend_error() -> None:
    """``SessionError`` is a typed marker: a ``BackendError`` with a message."""
    detail = "never appeared"
    err = SessionError(detail)
    assert isinstance(err, BackendError)
    assert detail in str(err)
|
||||
|
||||
|
||||
def test_message_parse_error_carries_data() -> None:
    """The offending payload is kept by identity on ``MessageParseError.data``."""
    offending = {"oops": True}
    err = MessageParseError("bad shape", data=offending)
    # Identity, not equality: the very same object must be attached.
    assert err.data is offending
|
||||
|
||||
|
||||
def test_session_error_is_not_a_timeout_error() -> None:
    """Pin that ``SessionError`` no longer derives from ``TimeoutError``.

    The lineage was broken deliberately; gateways that used to catch
    ``TimeoutError`` must migrate to ``SessionError``.
    """
    derives_from_timeout = issubclass(SessionError, TimeoutError)
    assert not derives_from_timeout
|
||||
|
||||
|
||||
def test_raise_chain_smoke() -> None:
    """Errors can be raised and caught by exact type and by the base type."""
    # Caught by its own type.
    with pytest.raises(AuthError):
        raise AuthError()

    # Caught through the shared ``BackendError`` base.
    with pytest.raises(BackendError):
        raise RateLimitError()
|
||||
@@ -0,0 +1,193 @@
|
||||
"""Unit tests for the `claude_code_api.injection` helpers.
|
||||
|
||||
Pure functions, no claude / no filesystem. The seed-JSONL shape is regression
|
||||
tested against the same minimal contract that `probe_jsonl_injection.py`
|
||||
proved out empirically (see FINDINGS § *Native JSONL injection works on
|
||||
--resume*).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from claude_code_api.injection import (
|
||||
build_concat_prompt,
|
||||
build_seed_jsonl,
|
||||
hash_history,
|
||||
)
|
||||
|
||||
# --- hash_history ---------------------------------------------------------
|
||||
|
||||
|
||||
def test_hash_history_empty_is_stable() -> None:
    """Hashing an empty history twice yields the same value."""
    first = hash_history([])
    second = hash_history([])
    assert first == second
|
||||
|
||||
|
||||
def test_hash_history_distinguishes_content() -> None:
    """Histories that differ only in message text hash differently."""
    greeting = [{"role": "user", "content": "hi"}]
    farewell = [{"role": "user", "content": "bye"}]
    assert hash_history(greeting) != hash_history(farewell)
|
||||
|
||||
|
||||
def test_hash_history_ignores_block_key_order() -> None:
    """The same block serialized with different key orders hashes identically.

    Two clients may emit one block's keys in different orders; those
    histories must collide.
    """
    canonical_block = {"type": "tool_use", "id": "t1", "name": "echo", "input": {"x": 1}}
    # Same key/value pairs, written in the opposite insertion order.
    reordered_block = {"input": {"x": 1}, "name": "echo", "id": "t1", "type": "tool_use"}

    first = [{"role": "assistant", "content": [canonical_block]}]
    second = [{"role": "assistant", "content": [reordered_block]}]
    assert hash_history(first) == hash_history(second)
|
||||
|
||||
|
||||
def test_hash_history_rejects_unknown_role() -> None:
    """A role other than user/assistant raises ``ValueError`` mentioning 'role'."""
    bad_history = [{"role": "system", "content": "x"}]
    with pytest.raises(ValueError, match="role"):
        hash_history(bad_history)
|
||||
|
||||
|
||||
def test_hash_history_text_blocks_collide_with_string_form() -> None:
    """String content and its single-text-block equivalent hash DIFFERENTLY.

    NOTE: despite "collide" in the test name, this pins that the two forms
    do NOT collide. They carry the same semantic content but differ on the
    wire, so the gateway must pick one form per role and stay consistent;
    no normalization is attempted here.
    """
    string_form = [{"role": "user", "content": "hello"}]
    block_form = [{"role": "user", "content": [{"type": "text", "text": "hello"}]}]
    assert hash_history(string_form) != hash_history(block_form)
|
||||
|
||||
|
||||
# --- build_seed_jsonl -----------------------------------------------------
|
||||
|
||||
|
||||
def test_build_seed_jsonl_empty_is_empty_string() -> None:
    """An empty history produces an empty seed string."""
    seed = build_seed_jsonl([], session_id="s", cwd="/tmp")
    assert seed == ""
|
||||
|
||||
|
||||
def test_build_seed_jsonl_two_records_for_one_turn() -> None:
    """One user/assistant exchange yields exactly two linked JSONL records."""
    history = [
        {"role": "user", "content": "My name is Beaver."},
        {"role": "assistant", "content": "Got it."},
    ]
    raw_seed = build_seed_jsonl(history, session_id="sid-1", cwd="/work")
    records = [json.loads(row) for row in raw_seed.strip().splitlines()]
    assert len(records) == 2
    user_rec, asst_rec = records

    # User record: envelope fields plus the message passed through as given.
    assert user_rec["type"] == "user"
    assert user_rec["sessionId"] == "sid-1"
    assert user_rec["cwd"] == "/work"
    assert user_rec["parentUuid"] is None
    assert user_rec["message"] == {"role": "user", "content": "My name is Beaver."}
    assert user_rec["isMeta"] is False
    assert "uuid" in user_rec and "timestamp" in user_rec

    # Assistant record: chained to the user record, string content wrapped
    # into a single text block, and marked as a finished turn.
    assert asst_rec["type"] == "assistant"
    assert asst_rec["parentUuid"] == user_rec["uuid"]
    asst_message = asst_rec["message"]
    assert asst_message["role"] == "assistant"
    assert asst_message["content"] == [{"type": "text", "text": "Got it."}]
    assert asst_message["stop_reason"] == "end_turn"
    assert asst_rec["sessionId"] == "sid-1"
|
||||
|
||||
|
||||
def test_build_seed_jsonl_chains_parent_uuids_across_turns() -> None:
    """``parentUuid`` links form one linear chain across multiple turns.

    That chain is how claude reconstructs conversation order on resume.
    """
    history = [
        {"role": "user", "content": "u1"},
        {"role": "assistant", "content": "a1"},
        {"role": "user", "content": "u2"},
        {"role": "assistant", "content": "a2"},
    ]
    raw_seed = build_seed_jsonl(history, session_id="s", cwd="/tmp")
    records = [json.loads(row) for row in raw_seed.strip().splitlines()]
    assert len(records) == 4

    # The first record is the root; every later one points at its predecessor.
    assert records[0]["parentUuid"] is None
    for previous, current in zip(records, records[1:]):
        assert current["parentUuid"] == previous["uuid"]
|
||||
|
||||
|
||||
def test_build_seed_jsonl_passes_list_content_through_for_user() -> None:
    """List-form user content (a ``tool_result`` block) round-trips verbatim.

    A tool_result block is the only list-form user content claude itself
    writes.
    """
    history = [
        {
            "role": "user",
            "content": [
                {"type": "tool_result", "tool_use_id": "t1", "content": "42"},
            ],
        }
    ]
    raw_seed = build_seed_jsonl(history, session_id="s", cwd="/tmp")
    record = json.loads(raw_seed.strip())
    expected_content = [
        {"type": "tool_result", "tool_use_id": "t1", "content": "42"},
    ]
    assert record["message"]["content"] == expected_content
|
||||
|
||||
|
||||
def test_build_seed_jsonl_rejects_unknown_role() -> None:
    """A role other than user/assistant raises ``ValueError`` mentioning 'role'."""
    bad_history = [{"role": "system", "content": "x"}]
    with pytest.raises(ValueError, match="role"):
        build_seed_jsonl(bad_history, session_id="s", cwd="/tmp")
|
||||
|
||||
|
||||
# --- build_concat_prompt --------------------------------------------------
|
||||
|
||||
|
||||
def test_build_concat_prompt_empty_history_returns_just_last_user() -> None:
    """With no history the prompt is the new user message, unchanged."""
    prompt = build_concat_prompt([], "hello")
    assert prompt == "hello"
|
||||
|
||||
|
||||
def test_build_concat_prompt_renders_alternating_history() -> None:
    """History renders as labelled lines followed by the new user message."""
    history = [
        {"role": "user", "content": "u1"},
        {"role": "assistant", "content": "a1"},
        {"role": "user", "content": "u2"},
        {"role": "assistant", "content": "a2"},
    ]
    prompt = build_concat_prompt(history, "u3")

    expected_fragments = (
        "Previous conversation context:",
        "[User]: u1",
        "[Assistant]: a1",
        "[User]: u2",
        "[Assistant]: a2",
        "Continue from here. New user message: u3",
    )
    for fragment in expected_fragments:
        assert fragment in prompt

    # The new prompt must come after the history, not interleaved with it.
    last_history_at = prompt.index("[Assistant]: a2")
    continuation_at = prompt.index("Continue from here")
    assert last_history_at < continuation_at
|
||||
|
||||
|
||||
def test_build_concat_prompt_flattens_text_blocks_and_skips_tools() -> None:
    """List-form content is flattened to text and tool blocks are skipped.

    Text blocks are joined into one line for the role; tool blocks are
    dropped because they don't round-trip through stdin in any useful form.
    """
    history = [
        {
            "role": "assistant",
            "content": [
                {"type": "text", "text": "hello"},
                {"type": "tool_use", "id": "t1", "name": "x", "input": {}},
                {"type": "text", "text": "world"},
            ],
        },
    ]
    out = build_concat_prompt(history, "ping")

    # The two text blocks are joined around the dropped tool block.
    assert "[Assistant]: hello world" in out
    # The "skips tools" half of the contract: nothing from the tool_use block
    # may leak into the rendered prompt.
    assert "tool_use" not in out
|
||||
@@ -0,0 +1,421 @@
|
||||
"""Unit tests for Layer 3 (`event_normalizer.normalize`).
|
||||
|
||||
All fixtures are hand-built dicts shaped like real records observed under
|
||||
``~/.claude/projects/``; no `claude` is invoked. The normalizer is a pure
|
||||
function so every test is a one-shot ``normalize(record) -> Event | None``
|
||||
assertion.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from claude_code_api import (
|
||||
AssistantMessage,
|
||||
MessageParseError,
|
||||
SystemMessage,
|
||||
TextBlock,
|
||||
ThinkingBlock,
|
||||
ToolResultBlock,
|
||||
ToolUseBlock,
|
||||
UserMessage,
|
||||
normalize,
|
||||
)
|
||||
|
||||
# --- envelope metadata shared by every record observed in the wild ---------
|
||||
|
||||
# Envelope metadata attached to every fixture record in this file.
_ENVELOPE: dict[str, Any] = {
    # Record identity and lineage.
    "parentUuid": "parent-uuid",
    "isSidechain": False,
    "uuid": "rec-uuid",
    "timestamp": "2026-05-16T20:17:27.664Z",
    # Origin of the record.
    "userType": "external",
    "entrypoint": "cli",
    "cwd": "/some/cwd",
    "sessionId": "sess-uuid",
    "version": "2.1.143",
    "gitBranch": "HEAD",
}


def _envelope(extra: dict[str, Any]) -> dict[str, Any]:
    """Return a new record: the standard envelope overlaid with ``extra``.

    Keys in ``extra`` win over envelope keys; neither input is mutated.
    """
    record: dict[str, Any] = dict(_ENVELOPE)
    record.update(extra)
    return record
|
||||
|
||||
|
||||
# --- user records ----------------------------------------------------------
|
||||
|
||||
|
||||
def test_user_string_content() -> None:
    """A user record with string content becomes a ``UserMessage``."""
    record = _envelope(
        {"type": "user", "message": {"role": "user", "content": "hello there"}}
    )
    user_event = normalize(record)
    assert isinstance(user_event, UserMessage)
    assert user_event.content == "hello there"

    # Envelope identifiers are carried over onto the event.
    assert user_event.uuid == "rec-uuid"
    assert user_event.session_id == "sess-uuid"
    assert user_event.parent_uuid == "parent-uuid"
|
||||
|
||||
|
||||
def test_user_tool_result_content() -> None:
    """List-form user content is parsed into ``ToolResultBlock`` objects."""
    raw_block = {
        "type": "tool_result",
        "tool_use_id": "toolu_01",
        "content": "stdout body",
        "is_error": False,
    }
    record = _envelope(
        {"type": "user", "message": {"role": "user", "content": [raw_block]}}
    )
    user_event = normalize(record)
    assert isinstance(user_event, UserMessage)
    assert isinstance(user_event.content, list)

    expected_blocks = [
        ToolResultBlock(
            tool_use_id="toolu_01",
            content="stdout body",
            is_error=False,
        )
    ]
    assert user_event.content == expected_blocks
|
||||
|
||||
|
||||
def test_user_meta_filtered_by_default() -> None:
    """A user record flagged ``isMeta`` is dropped unless opted in."""
    meta_message = {"role": "user", "content": "<local-command-caveat>...</...>"}
    record = _envelope({"type": "user", "isMeta": True, "message": meta_message})
    assert normalize(record) is None
|
||||
|
||||
|
||||
def test_user_meta_emitted_when_opt_in() -> None:
    """``include_meta_user=True`` surfaces an ``isMeta`` user record."""
    record = _envelope(
        {"type": "user", "isMeta": True, "message": {"role": "user", "content": "x"}}
    )
    user_event = normalize(record, include_meta_user=True)
    assert isinstance(user_event, UserMessage)
    assert user_event.content == "x"
|
||||
|
||||
|
||||
def test_user_missing_message_raises() -> None:
    """A user record without a ``message`` key raises ``MessageParseError``."""
    record_without_message = _envelope({"type": "user"})
    with pytest.raises(MessageParseError, match="user record missing"):
        normalize(record_without_message)
|
||||
|
||||
|
||||
def test_user_content_wrong_type_raises() -> None:
    """User content that is neither ``str`` nor ``list`` is rejected."""
    record = _envelope({"type": "user", "message": {"content": 42}})
    with pytest.raises(MessageParseError, match="content must be str or list"):
        normalize(record)
|
||||
|
||||
|
||||
# --- assistant records -----------------------------------------------------
|
||||
|
||||
|
||||
def test_assistant_text_only() -> None:
    """A text-only assistant record maps every message field onto the event."""
    message = {
        "model": "claude-opus-4-7",
        "id": "msg_01",
        "role": "assistant",
        "content": [{"type": "text", "text": "hi"}],
        "stop_reason": "end_turn",
        "usage": {"input_tokens": 1, "output_tokens": 2},
    }
    assistant_event = normalize(_envelope({"type": "assistant", "message": message}))
    assert isinstance(assistant_event, AssistantMessage)
    assert assistant_event.content == [TextBlock(text="hi")]
    assert assistant_event.model == "claude-opus-4-7"
    assert assistant_event.message_id == "msg_01"
    assert assistant_event.stop_reason == "end_turn"
    assert assistant_event.usage == {"input_tokens": 1, "output_tokens": 2}
|
||||
|
||||
|
||||
def test_assistant_all_block_types() -> None:
    """Thinking, text and tool_use blocks are each parsed, in order."""
    raw_blocks = [
        {"type": "thinking", "thinking": "...", "signature": "sig"},
        {"type": "text", "text": "calling tool"},
        {
            "type": "tool_use",
            "id": "toolu_01",
            "name": "Bash",
            "input": {"command": "ls"},
        },
    ]
    message = {
        "model": "claude-opus-4-7",
        "role": "assistant",
        "content": raw_blocks,
        "stop_reason": "tool_use",
    }
    assistant_event = normalize(_envelope({"type": "assistant", "message": message}))
    assert isinstance(assistant_event, AssistantMessage)

    expected_blocks = [
        ThinkingBlock(thinking="...", signature="sig"),
        TextBlock(text="calling tool"),
        ToolUseBlock(id="toolu_01", name="Bash", input={"command": "ls"}),
    ]
    assert assistant_event.content == expected_blocks
    assert assistant_event.stop_reason == "tool_use"
|
||||
|
||||
|
||||
def test_assistant_streaming_chunk_has_null_stop_reason() -> None:
    """A partial assistant record keeps ``stop_reason`` as ``None``.

    claude writes partial assistant records mid-turn with stop_reason=null;
    surfacing the ``None`` lets TurnManager tell partial from terminal.
    """
    partial_message = {
        "model": "claude-opus-4-7",
        "role": "assistant",
        "content": [{"type": "text", "text": "partial"}],
        "stop_reason": None,
    }
    record = _envelope({"type": "assistant", "message": partial_message})
    assistant_event = normalize(record)
    assert isinstance(assistant_event, AssistantMessage)
    assert assistant_event.stop_reason is None
|
||||
|
||||
|
||||
def test_assistant_missing_model_raises() -> None:
    """An assistant message without ``model`` raises ``MessageParseError``."""
    message_without_model = {"role": "assistant", "content": []}
    record = _envelope({"type": "assistant", "message": message_without_model})
    with pytest.raises(MessageParseError, match="assistant record missing"):
        normalize(record)
|
||||
|
||||
|
||||
def test_assistant_content_not_list_raises() -> None:
    """Assistant content must be a list; a bare string is rejected."""
    message = {
        "model": "claude-opus-4-7",
        "role": "assistant",
        "content": "not a list",
    }
    record = _envelope({"type": "assistant", "message": message})
    with pytest.raises(MessageParseError, match="content must be a list"):
        normalize(record)
|
||||
|
||||
|
||||
def test_assistant_unknown_block_type_raises() -> None:
    """A content block of an unrecognized type raises ``MessageParseError``."""
    message = {
        "model": "claude-opus-4-7",
        "role": "assistant",
        "content": [{"type": "image", "data": "..."}],
    }
    record = _envelope({"type": "assistant", "message": message})
    with pytest.raises(MessageParseError, match="unknown content block type"):
        normalize(record)
|
||||
|
||||
|
||||
def test_assistant_tool_use_missing_id_raises() -> None:
    """A ``tool_use`` block without an ``id`` raises ``MessageParseError``."""
    tool_use_without_id = {"type": "tool_use", "name": "X", "input": {}}
    message = {
        "model": "claude-opus-4-7",
        "role": "assistant",
        "content": [tool_use_without_id],
    }
    record = _envelope({"type": "assistant", "message": message})
    with pytest.raises(MessageParseError, match="tool_use block missing"):
        normalize(record)
|
||||
|
||||
|
||||
# --- system records --------------------------------------------------------
|
||||
|
||||
|
||||
def test_system_turn_duration_surfaced() -> None:
    """A ``turn_duration`` system record is emitted as a ``SystemMessage``."""
    record = _envelope(
        {
            "type": "system",
            "subtype": "turn_duration",
            "durationMs": 1234,
            "messageCount": 5,
            "isMeta": False,
        }
    )
    system_event = normalize(record)
    assert isinstance(system_event, SystemMessage)
    assert system_event.subtype == "turn_duration"
    assert system_event.session_id == "sess-uuid"

    # `data` exposes the raw record's fields, so callers can read
    # `durationMs` without re-parsing.
    payload = system_event.data
    assert payload["durationMs"] == 1234
    assert payload["messageCount"] == 5
|
||||
|
||||
|
||||
def test_system_stop_hook_summary_filtered() -> None:
    """``stop_hook_summary`` system records are dropped (normalize -> None)."""
    hook_summary = {
        "type": "system",
        "subtype": "stop_hook_summary",
        "hookCount": 0,
        "hookInfos": [],
    }
    assert normalize(_envelope(hook_summary)) is None
|
||||
|
||||
|
||||
def test_system_local_command_filtered() -> None:
    """``local_command`` system records are dropped (normalize -> None)."""
    local_command = {
        "type": "system",
        "subtype": "local_command",
        "content": "<local-command-stdout></local-command-stdout>",
    }
    assert normalize(_envelope(local_command)) is None
|
||||
|
||||
|
||||
def test_system_missing_subtype_raises() -> None:
    """A system record without ``subtype`` raises ``MessageParseError``."""
    record_without_subtype = _envelope({"type": "system"})
    with pytest.raises(MessageParseError, match="system record missing 'subtype'"):
        normalize(record_without_subtype)
|
||||
|
||||
|
||||
# --- filtered top-level types ---------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "record_type",
    [
        "attachment",
        "file-history-snapshot",
        "last-prompt",
        "ai-title",
        "permission-mode",
        "queue-operation",
    ],
)
def test_bookkeeping_types_filtered(record_type: str) -> None:
    """Each listed bookkeeping record type is dropped (normalize -> None)."""
    bookkeeping_record = _envelope({"type": record_type})
    dropped = normalize(bookkeeping_record)
    assert dropped is None
|
||||
|
||||
|
||||
def test_unknown_type_silently_dropped() -> None:
    """An unrecognized top-level record type is dropped rather than raised.

    Forward-compatibility: a brand-new record type from a future claude
    version must not break normalization.
    """
    future_record = _envelope({"type": "some-new-record-type"})
    assert normalize(future_record) is None
|
||||
|
||||
|
||||
# --- error path ------------------------------------------------------------
|
||||
|
||||
|
||||
def test_non_dict_record_raises() -> None:
    """Passing something other than a dict raises ``MessageParseError``."""
    not_a_record = "not a dict"
    with pytest.raises(MessageParseError, match="must be a dict"):
        normalize(not_a_record)  # type: ignore[arg-type]
|
||||
|
||||
|
||||
def test_record_missing_type_raises() -> None:
    """A record with envelope fields but no ``type`` raises ``MessageParseError``."""
    envelope_only = _envelope({})
    with pytest.raises(MessageParseError, match="record missing 'type'"):
        normalize(envelope_only)
|
||||
|
||||
|
||||
# --- regression fixtures from real session ---------------------------------
|
||||
|
||||
|
||||
def test_real_user_string_record() -> None:
    """Regression fixture: a user prompt record copied from a 2.1.143 session."""
    real_record = {
        "parentUuid": None,
        "isSidechain": False,
        "promptId": "364db1ee-f587-4096-bc6c-0dc4323512dc",
        "type": "user",
        "message": {"role": "user", "content": "What is my name?"},
        "uuid": "97968a26-6466-4410-84db-2077e65573e1",
        "timestamp": "2026-05-16T20:17:27.664Z",
        "userType": "external",
        "entrypoint": "cli",
        "cwd": "/Users/h/projects/playgrounds/claude-code-sdk",
        "sessionId": "4df01eee-6026-4782-bdba-d67ab47a3e5b",
        "version": "2.1.143",
        "gitBranch": "HEAD",
    }
    user_event = normalize(real_record)
    assert isinstance(user_event, UserMessage)
    assert user_event.content == "What is my name?"
    # The first record of a session has no parent.
    assert user_event.parent_uuid is None
|
||||
|
||||
|
||||
def test_real_assistant_tool_use_record() -> None:
    """Regression fixture: a real ``stop_reason=tool_use`` assistant record."""
    token_usage = {
        "input_tokens": 6,
        "cache_creation_input_tokens": 11211,
        "cache_read_input_tokens": 17654,
        "output_tokens": 172,
    }
    real_message = {
        "model": "claude-opus-4-7",
        "id": "msg_019Sy3eBbN24Y6YwgxuMvN7g",
        "type": "message",
        "role": "assistant",
        "content": [
            {"type": "thinking", "thinking": "...", "signature": "sig"},
            {
                "type": "tool_use",
                "id": "toolu_01XCXcKt7TaDbAKscRPpvumi",
                "name": "Bash",
                "input": {"command": "ls"},
            },
        ],
        "stop_reason": "tool_use",
        "usage": token_usage,
    }
    real_record = {
        "parentUuid": "97968a26-6466-4410-84db-2077e65573e1",
        "isSidechain": False,
        "message": real_message,
        "requestId": "req_011Cb6s6f7fhCRgo2yhNZY9G",
        "type": "assistant",
        "uuid": "14e394aa-9faa-4448-8a6c-1365bf2acb8a",
        "sessionId": "4df01eee-6026-4782-bdba-d67ab47a3e5b",
    }
    assistant_event = normalize(real_record)
    assert isinstance(assistant_event, AssistantMessage)
    assert assistant_event.stop_reason == "tool_use"

    # Usage counters, including the cache fields, are passed through.
    assert assistant_event.usage is not None
    assert assistant_event.usage["cache_read_input_tokens"] == 17654

    # Two blocks: thinking first, then the tool call.
    assert len(assistant_event.content) == 2
    tool_block = assistant_event.content[1]
    assert isinstance(tool_block, ToolUseBlock)
    assert tool_block.name == "Bash"
|
||||
@@ -0,0 +1,101 @@
|
||||
"""Unit tests for `claude_code_api.paths` — pure string transforms + light fs lookup."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from claude_code_api.paths import (
|
||||
claude_home,
|
||||
encode_project_key,
|
||||
find_jsonl_by_session_id,
|
||||
projects_root,
|
||||
resolve_jsonl_path,
|
||||
session_dir,
|
||||
)
|
||||
|
||||
# ----- encode_project_key -----------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("cwd", "expected"),
    [
        # Observed in this repo: plain alphanumerics, slashes and literal dashes.
        (
            "/Users/h/projects/playgrounds/claude-code-sdk",
            "-Users-h-projects-playgrounds-claude-code-sdk",
        ),
        # Observed: a dot-prefixed directory yields a doubled dash, while
        # path segments that already contain dashes come through unchanged.
        (
            "/Users/h/.t3/worktrees/cars-system/t3code-9d8591ad",
            "-Users-h--t3-worktrees-cars-system-t3code-9d8591ad",
        ),
        # A trailing slash becomes a trailing dash. claude would not normally
        # see this input, but the encoder is deterministic.
        ("/Users/h/", "-Users-h-"),
        # Filesystem root.
        ("/", "-"),
        # Spaces, parentheses and other punctuation all become dashes.
        ("/tmp/My Project (v2)", "-tmp-My-Project--v2-"),
    ],
)
def test_encode_known_paths(cwd: str, expected: str) -> None:
    """``encode_project_key`` maps each known cwd to its expected key."""
    encoded = encode_project_key(cwd)
    assert encoded == expected
|
||||
|
||||
|
||||
def test_encode_rejects_relative() -> None:
    """A relative path raises ``ValueError`` mentioning 'absolute'."""
    relative_cwd = "relative/path"
    with pytest.raises(ValueError, match="absolute"):
        encode_project_key(relative_cwd)
|
||||
|
||||
|
||||
def test_encode_rejects_empty() -> None:
    """An empty cwd raises ``ValueError`` mentioning 'empty'."""
    empty_cwd = ""
    with pytest.raises(ValueError, match="empty"):
        encode_project_key(empty_cwd)
|
||||
|
||||
|
||||
# ----- resolve_jsonl_path / session_dir --------------------------------------
|
||||
|
||||
|
||||
def test_resolve_jsonl_path_under_fake_home(tmp_path):
    """The JSONL path is ``<home>/.claude/projects/<key>/<session>.jsonl``."""
    session_id = "deadbeef-0000-4000-8000-000000000001"
    resolved = resolve_jsonl_path("/foo/bar", session_id, home=tmp_path)
    expected = tmp_path / ".claude" / "projects" / "-foo-bar" / f"{session_id}.jsonl"
    assert resolved == expected
|
||||
|
||||
|
||||
def test_session_dir_matches_resolve_parent(tmp_path):
    """``session_dir`` equals the parent directory of the resolved JSONL path."""
    session_id = "deadbeef-0000-4000-8000-000000000002"
    jsonl_parent = resolve_jsonl_path("/a/b", session_id, home=tmp_path).parent
    directory = session_dir("/a/b", home=tmp_path)
    assert jsonl_parent == directory
|
||||
|
||||
|
||||
def test_resolve_rejects_empty_session_id(tmp_path):
    """An empty session id raises ``ValueError`` mentioning 'session_id'."""
    empty_session_id = ""
    with pytest.raises(ValueError, match="session_id"):
        resolve_jsonl_path("/foo", empty_session_id, home=tmp_path)
|
||||
|
||||
|
||||
def test_claude_home_and_projects_root_honor_override(tmp_path):
    """Both helpers build their paths from the supplied home directory."""
    expected_home = tmp_path / ".claude"
    assert claude_home(tmp_path) == expected_home
    assert projects_root(tmp_path) == expected_home / "projects"
|
||||
|
||||
|
||||
# ----- find_jsonl_by_session_id ---------------------------------------------
|
||||
|
||||
|
||||
def test_find_returns_none_when_root_missing(tmp_path):
    """Lookup returns ``None`` when ``.claude/projects`` does not exist."""
    # Nothing has been created under tmp_path, so the projects root is absent.
    found = find_jsonl_by_session_id("nope", home=tmp_path)
    assert found is None
|
||||
|
||||
|
||||
def test_find_locates_existing_session(tmp_path):
    """A JSONL file written at the resolved path is found by session id alone."""
    session_id = "abcdef00-1111-4000-8000-000000000000"
    jsonl_file = resolve_jsonl_path("/some/cwd", session_id, home=tmp_path)

    # Materialize the session file on disk before searching for it.
    jsonl_file.parent.mkdir(parents=True)
    jsonl_file.write_text("{}\n")

    assert find_jsonl_by_session_id(session_id, home=tmp_path) == jsonl_file
|
||||
|
||||
|
||||
def test_find_rejects_empty_session_id(tmp_path):
    """An empty session id raises ``ValueError`` mentioning 'session_id'."""
    empty_session_id = ""
    with pytest.raises(ValueError, match="session_id"):
        find_jsonl_by_session_id(empty_session_id, home=tmp_path)
|
||||
@@ -0,0 +1,261 @@
|
||||
"""Unit + smoke tests for Layer 1 (`PtyClaudeProcess`).
|
||||
|
||||
Unit tests exercise pure argv/env construction and don't require `claude`.
|
||||
The smoke test spawns the real binary and is opt-in via env var because it
|
||||
hits the user's OAuth state and the wider system.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from claude_code_api import CLINotFoundError
|
||||
from claude_code_api.pty import (
|
||||
PtyClaudeProcess,
|
||||
PtyProcessOptions,
|
||||
build_argv,
|
||||
build_env,
|
||||
)
|
||||
|
||||
# --- argv construction ----------------------------------------------------
|
||||
|
||||
|
||||
def test_build_argv_minimal_uses_session_id_and_permission_mode() -> None:
    """Default options yield claude + session id + bypassPermissions only."""
    argv = build_argv(PtyProcessOptions(cwd="/tmp"), session_id="abc-123")

    assert argv[0] == "claude"
    # --session-id must come early so it can be observed in `ps` output even
    # if later flags are mistyped or dropped.
    assert argv[1:3] == ["--session-id", "abc-123"]

    assert "--permission-mode" in argv
    mode_value = argv[argv.index("--permission-mode") + 1]
    assert mode_value == "bypassPermissions"

    # Headless-only flags must never appear.
    headless_flags = ("--print", "-p", "--output-format", "--input-format")
    assert all(flag not in argv for flag in headless_flags)
|
||||
|
||||
|
||||
def test_build_argv_dangerously_skip_permissions_excludes_permission_mode() -> None:
    """The skip-permissions flag replaces ``--permission-mode`` entirely."""
    options = PtyProcessOptions(cwd="/tmp", dangerously_skip_permissions=True)
    argv = build_argv(options, session_id="s")

    assert "--dangerously-skip-permissions" in argv
    assert "--permission-mode" not in argv
|
||||
|
||||
|
||||
def test_build_argv_includes_optional_flags_when_set() -> None:
    """Every optional setting is rendered as its flag with the right value(s)."""
    options = PtyProcessOptions(
        cwd="/tmp",
        model="claude-opus-4-7",
        system_prompt="be brief",
        append_system_prompt="also be kind",
        allowed_tools=("Read", "Glob"),
        disallowed_tools=("Bash",),
        mcp_config=("/tmp/a.json", "/tmp/b.json"),
        add_dir=("/srv/x", "/srv/y"),
        effort="high",
        settings="/tmp/settings.json",
        extra_args=("--brief",),
    )
    argv = build_argv(options, session_id="s")

    def values_after(flag: str) -> list[str]:
        """Return the argument following each occurrence of ``flag``."""
        return [following for current, following in zip(argv, argv[1:]) if current == flag]

    # Single-valued flags pair with exactly one value each.
    assert values_after("--model") == ["claude-opus-4-7"]
    assert values_after("--system-prompt") == ["be brief"]
    assert values_after("--append-system-prompt") == ["also be kind"]
    # Tool lists use CSV form per claude CLI conventions.
    assert values_after("--allowedTools") == ["Read,Glob"]
    assert values_after("--disallowedTools") == ["Bash"]
    # One --mcp-config flag per config file.
    assert values_after("--mcp-config") == ["/tmp/a.json", "/tmp/b.json"]
    assert values_after("--effort") == ["high"]
    assert values_after("--settings") == ["/tmp/settings.json"]

    # --add-dir is variadic in claude CLI: one flag, multiple values.
    first_dir_at = argv.index("--add-dir") + 1
    assert argv[first_dir_at : first_dir_at + 2] == ["/srv/x", "/srv/y"]

    # extra_args are passed through at the very end.
    assert argv[-1] == "--brief"
|
||||
|
||||
|
||||
def test_build_argv_omits_unset_optionals() -> None:
    """With only ``cwd`` set, none of the optional flags is emitted."""
    argv = build_argv(PtyProcessOptions(cwd="/tmp"), session_id="s")
    optional_flags = (
        "--model",
        "--system-prompt",
        "--append-system-prompt",
        "--allowedTools",
        "--disallowedTools",
        "--mcp-config",
        "--add-dir",
        "--effort",
        "--settings",
    )
    assert all(flag not in argv for flag in optional_flags)
|
||||
|
||||
|
||||
def test_build_argv_resume_session_id_replaces_session_id_flag() -> None:
    """In resume mode argv carries `--resume` and never `--session-id`.

    claude refuses the two flags together unless `--fork-session` is also
    given, and forking would branch the conversation into a new JSONL.
    Resume mode is what higher layers choose after hand-seeding a JSONL that
    claude must pick up instead of creating a fresh one, so the
    `session_id` argument passed to `build_argv` must be ignored here.
    """
    resume_opts = PtyProcessOptions(cwd="/tmp", resume_session_id="resume-uuid")
    argv = build_argv(resume_opts, session_id="ignored-fresh-uuid")

    assert "--session-id" not in argv
    assert argv[1:3] == ["--resume", "resume-uuid"]
|
||||
|
||||
|
||||
def test_options_reject_session_id_with_resume_session_id() -> None:
    """Passing both `session_id` and `resume_session_id` is a `ValueError`."""
    conflicting = {"session_id": "a", "resume_session_id": "b"}
    with pytest.raises(ValueError, match="session_id"):
        PtyProcessOptions(cwd="/tmp", **conflicting)
|
||||
|
||||
|
||||
def test_pty_process_reports_resume_session_id_as_session_id() -> None:
    """A process built in resume mode exposes the resumed id as `session_id`.

    That id names the JSONL already on disk; higher layers derive the JSONL
    path from `pty.session_id`, so a freshly generated uuid would point them
    at the wrong file.
    """
    resume_opts = PtyProcessOptions(cwd="/tmp", resume_session_id="seeded-123")
    proc = PtyClaudeProcess(resume_opts)

    assert proc.session_id == "seeded-123"
    argv = proc.argv
    assert "--resume" in argv
    assert "--session-id" not in argv
|
||||
|
||||
|
||||
def test_options_reject_invalid_permission_mode() -> None:
    """An unrecognised `permission_mode` value is rejected at construction."""
    bogus_mode = "banana"
    with pytest.raises(ValueError, match="permission_mode"):
        PtyProcessOptions(cwd="/tmp", permission_mode=bogus_mode)
|
||||
|
||||
|
||||
def test_options_reject_nonpositive_dimensions() -> None:
    """A zero entry in `dimensions` is rejected at construction."""
    zero_first = (0, 80)
    with pytest.raises(ValueError, match="dimensions"):
        PtyProcessOptions(cwd="/tmp", dimensions=zero_first)
|
||||
|
||||
|
||||
# --- env construction -----------------------------------------------------
|
||||
|
||||
|
||||
def test_build_env_strips_provider_vars_by_default() -> None:
    """Default options drop the ANTHROPIC_* provider variables from the env.

    Unrelated variables from `base` pass through unchanged, and the result
    carries `TERM=xterm-256color` and `NO_COLOR=1`.
    """
    provider_vars = {
        "ANTHROPIC_API_KEY": "sk-xxx",
        "ANTHROPIC_AUTH_TOKEN": "tok",
        "ANTHROPIC_BASE_URL": "https://x.example",
    }
    base = {"PATH": "/usr/bin", "HOME": "/home/x", **provider_vars}

    env = build_env(PtyProcessOptions(cwd="/tmp"), base=base)

    survivors = [name for name in provider_vars if name in env]
    assert not survivors
    expected = {
        "PATH": "/usr/bin",
        "HOME": "/home/x",
        "TERM": "xterm-256color",
        "NO_COLOR": "1",
    }
    for key, value in expected.items():
        assert env[key] == value, key
|
||||
|
||||
|
||||
def test_build_env_preserve_provider_env_keeps_keys() -> None:
    """`preserve_provider_env=True` leaves the API key in the child env."""
    env = build_env(
        PtyProcessOptions(cwd="/tmp", preserve_provider_env=True),
        base={"ANTHROPIC_API_KEY": "sk-xxx", "PATH": "/usr/bin"},
    )
    assert env["ANTHROPIC_API_KEY"] == "sk-xxx"
|
||||
|
||||
|
||||
def test_build_env_extra_env_overrides_base() -> None:
    """Entries in `extra_env` are added and win over other sources of `TERM`."""
    overrides = {"FOO": "bar", "TERM": "vt100"}
    env = build_env(
        PtyProcessOptions(cwd="/tmp", extra_env=overrides),
        base={"PATH": "/usr/bin", "TERM": "dumb"},
    )
    # The explicit TERM override must beat both the inherited value and the
    # default TERM that build_env sets.
    assert env["TERM"] == "vt100"
    assert env["FOO"] == "bar"
|
||||
|
||||
|
||||
# --- construction-only PtyClaudeProcess sanity ----------------------------
|
||||
|
||||
|
||||
def test_session_id_is_autogenerated_when_omitted() -> None:
    """Without an explicit id the process generates a UUID-shaped session id.

    Construction alone must not spawn anything: the process reports not
    alive and has no pid.
    """
    import uuid

    proc = PtyClaudeProcess(PtyProcessOptions(cwd="/tmp"))
    # A bare length check would accept any 36-character string. Round-tripping
    # through uuid.UUID pins the canonical dashed form (which is 36 chars).
    assert str(uuid.UUID(proc.session_id)) == proc.session_id
    assert proc.is_alive() is False
    assert proc.pid is None
|
||||
|
||||
|
||||
def test_session_id_is_passed_through_when_provided() -> None:
    """An explicit `session_id` is reported as-is and forwarded via argv."""
    proc = PtyClaudeProcess(PtyProcessOptions(cwd="/tmp", session_id="custom-id"))

    assert proc.session_id == "custom-id"
    argv = proc.argv
    assert "--session-id" in argv
    value_slot = argv.index("--session-id") + 1
    assert argv[value_slot] == "custom-id"
|
||||
|
||||
|
||||
# --- error mapping (Stage 10) ---------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_start_raises_cli_not_found_when_executable_missing(tmp_path) -> None:
    """A missing executable surfaces from `start()` as `CLINotFoundError`.

    ptyprocess raises `FileNotFoundError` from its pre-fork `which()`
    lookup; `PtyClaudeProcess.start()` translates that into our typed error
    so callers never depend on the underlying library. The error names the
    executable both in its message and on its `executable` attribute.
    """
    missing = "claude-binary-that-does-not-exist-xyz"
    proc = PtyClaudeProcess(
        PtyProcessOptions(
            cwd=str(tmp_path),
            executable=missing,
            dangerously_skip_permissions=True,
        )
    )

    with pytest.raises(CLINotFoundError) as info:
        await proc.start()

    err = info.value
    assert missing in str(err)
    assert err.executable == missing
|
||||
|
||||
|
||||
# --- smoke test (real claude) ---------------------------------------------
|
||||
|
||||
# Name of the environment variable that opts in to the real-binary smoke test.
_SMOKE_ENV = "RUN_CLAUDE_SMOKE"


@pytest.mark.skipif(
    os.environ.get(_SMOKE_ENV) != "1",
    reason=f"set {_SMOKE_ENV}=1 to run the real-`claude` smoke test",
)
@pytest.mark.asyncio
async def test_smoke_start_write_terminate(tmp_path) -> None:
    """Layer 1 lifecycle check against the installed `claude` binary.

    Starts claude under a PTY, verifies it is still running half a second
    later, writes one message (no turn is expected to complete here), and
    then terminates it. Only lifecycle invariants are asserted; JSONL
    parsing and turn semantics belong to later layers.
    """
    proc = PtyClaudeProcess(
        PtyProcessOptions(
            cwd=str(tmp_path),
            dangerously_skip_permissions=True,
        )
    )
    await proc.start()
    pid = proc.pid
    try:
        assert pid is not None and pid > 0
        # Let claude paint its TUI before asking it to exit. A process that
        # cannot survive even this long points at a broken spawn (auth
        # blocked, missing HOME, etc.).
        await asyncio.sleep(0.5)
        captured = proc.captured_output()
        assert proc.is_alive(), (
            f"claude exited within 0.5s of spawn; captured {len(captured)} bytes:\n"
            f"{captured[:1000]!r}"
        )
        await proc.write("hello")
    finally:
        # Always reap the child, even when an assertion above failed.
        exit_status = await proc.terminate(grace=5.0)

    assert proc.is_alive() is False
    # An exit code or a signal are both acceptable; only `None` is a failure.
    assert exit_status is not None, (
        f"terminate() returned None for pid={pid}; output:\n{proc.captured_output()[:1000]!r}"
    )
|
||||
@@ -0,0 +1,934 @@
|
||||
"""Unit + smoke tests for Layer 4 (`TurnManager`).
|
||||
|
||||
Unit tests use a `FakePty` that, on `write()`, dumps a scripted list of JSONL
|
||||
records into a real temp file. A real `JsonlWatcher` tails that file so the
|
||||
manager's read/normalize/turn-end loop is exercised end-to-end without
|
||||
launching `claude`. The smoke test at the bottom spawns the real binary
|
||||
behind `RUN_CLAUDE_SMOKE=1` and also serves as the empirical probe for
|
||||
Open Q #2 (PTY echo / buffering).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from claude_code_api import (
|
||||
AssistantMessage,
|
||||
AuthError,
|
||||
ProcessError,
|
||||
RateLimitError,
|
||||
ResultMessage,
|
||||
SessionError,
|
||||
SystemMessage,
|
||||
TextBlock,
|
||||
ToolResultBlock,
|
||||
ToolUseBlock,
|
||||
UserMessage,
|
||||
)
|
||||
from claude_code_api.paths import resolve_jsonl_path
|
||||
from claude_code_api.watcher import JsonlWatcher
|
||||
from claude_code_api.pty import PtyClaudeProcess, PtyProcessOptions
|
||||
from claude_code_api.turn import TurnManager
|
||||
|
||||
# --- fakes -----------------------------------------------------------------
|
||||
|
||||
|
||||
class FakePty:
    """Test double for `PtyClaudeProcess` driven by scripted JSONL batches.

    `scripts` is a list of record batches. The Nth call to `write()` appends
    the Nth batch to `<tmp_path>/<session_id>.jsonl`, one JSON object per
    line, so a purely synchronous setup can drive a whole turn loop without
    a real `claude`. Writes beyond the end of the script append nothing.

    Stage 10 additions: the `alive` and `output` knobs let a test simulate a
    dead sub-process and the error chrome captured from the PTY drain
    buffer, both of which `TurnManager` consults when classifying failures.
    """

    def __init__(
        self,
        tmp_path: Path,
        *,
        session_id: str = "fake-session-0001",
        scripts: list[list[dict[str, Any]]] | None = None,
        alive: bool = True,
        output: bytes = b"",
    ) -> None:
        # Identity, mirroring the attributes callers read off the PTY.
        self.cwd = str(tmp_path)
        self.session_id = session_id
        self._jsonl = tmp_path / f"{session_id}.jsonl"

        # Scripted batches and the cursor into them.
        self._scripts = [] if scripts is None else scripts
        self._write_count = 0

        # Observable bookkeeping for assertions.
        self.writes: list[str] = []
        self.started = False
        self.closed = False

        # Stage 10 knobs.
        self._alive = alive
        self._output = output

    async def start(self) -> None:
        """Record that the manager started the PTY."""
        self.started = True

    async def write(self, text: str, *, newline: bool = True) -> int:
        """Record `text` and flush the next scripted batch, if any remains.

        `newline` is accepted for signature compatibility and is not used.
        Returns `len(text)`.
        """
        self.writes.append(text)
        cursor = self._write_count
        if cursor < len(self._scripts):
            # Opening in append mode creates the file even for an empty batch.
            with self._jsonl.open("a", encoding="utf-8") as sink:
                sink.writelines(json.dumps(rec) + "\n" for rec in self._scripts[cursor])
        self._write_count = cursor + 1
        return len(text)

    async def aclose(self) -> None:
        """Record that the manager closed the PTY."""
        self.closed = True

    # --- Stage 10 surface ----------------------------------------------
    def is_alive(self) -> bool:
        """Return the simulated liveness flag."""
        return self._alive

    def captured_output(self) -> bytes:
        """Return the simulated PTY drain buffer."""
        return self._output

    def set_alive(self, alive: bool) -> None:
        """Change the simulated liveness flag mid-test."""
        self._alive = alive

    def set_output(self, output: bytes) -> None:
        """Replace the simulated PTY drain buffer mid-test."""
        self._output = output
|
||||
|
||||
|
||||
def _user_rec(text: str, *, session_id: str = "fake-session-0001") -> dict[str, Any]:
    """Build a raw JSONL `user` record whose message content is `text`.

    The record's `uuid` is derived from the first eight characters of
    `text`. `session_id` defaults to `FakePty`'s default id; pass it
    explicitly when the fake was built with a different `session_id` so the
    record's `sessionId` matches.
    """
    return {
        "type": "user",
        "uuid": f"u-{text[:8]}",
        "sessionId": session_id,
        "parentUuid": None,
        "message": {"role": "user", "content": text},
    }
|
||||
|
||||
|
||||
def _assistant_rec(
    text: str,
    *,
    stop_reason: str | None = "end_turn",
    usage: dict[str, Any] | None = None,
    session_id: str = "fake-session-0001",
) -> dict[str, Any]:
    """Build a raw JSONL `assistant` record with a single text block.

    Args:
        text: Text of the one content block; its first eight characters
            also seed the record's `uuid`.
        stop_reason: Value stored under `message.stop_reason`.
        usage: Usage dict stored as-is. Only `None` selects the default of
            one input and one output token, so an explicit empty dict is
            kept rather than silently replaced.
        session_id: Value for `sessionId`; defaults to `FakePty`'s default
            id.
    """
    if usage is None:
        usage = {"input_tokens": 1, "output_tokens": 1}
    return {
        "type": "assistant",
        "uuid": f"a-{text[:8]}",
        "sessionId": session_id,
        "parentUuid": None,
        "message": {
            "id": "msg_x",
            "role": "assistant",
            "model": "claude-test",
            "content": [{"type": "text", "text": text}],
            "stop_reason": stop_reason,
            "usage": usage,
        },
    }
|
||||
|
||||
|
||||
def _tool_use_assistant_rec(
    name: str,
    tool_id: str,
    *,
    session_id: str = "fake-session-0001",
) -> dict[str, Any]:
    """Build a raw JSONL `assistant` record holding one `tool_use` block.

    The block calls tool `name` with id `tool_id` and an empty input, and
    the message's `stop_reason` is `"tool_use"`. `session_id` defaults to
    `FakePty`'s default id; override it when the fake uses a different one.
    """
    return {
        "type": "assistant",
        "uuid": f"a-tu-{tool_id}",
        "sessionId": session_id,
        "parentUuid": None,
        "message": {
            "id": "msg_y",
            "role": "assistant",
            "model": "claude-test",
            "content": [{"type": "tool_use", "id": tool_id, "name": name, "input": {}}],
            "stop_reason": "tool_use",
            "usage": {"input_tokens": 1, "output_tokens": 1},
        },
    }
|
||||
|
||||
|
||||
def _tool_result_user_rec(
    tool_id: str,
    content: str,
    *,
    session_id: str = "fake-session-0001",
) -> dict[str, Any]:
    """Build a raw JSONL `user` record carrying one `tool_result` block.

    The block answers the tool call identified by `tool_id` with `content`.
    `session_id` defaults to `FakePty`'s default id; override it when the
    fake uses a different one.
    """
    return {
        "type": "user",
        "uuid": f"u-tr-{tool_id}",
        "sessionId": session_id,
        "parentUuid": None,
        "message": {
            "role": "user",
            "content": [{"type": "tool_result", "tool_use_id": tool_id, "content": content}],
        },
    }
|
||||
|
||||
|
||||
def _turn_duration_rec(
    duration_ms: int = 1234,
    *,
    session_id: str = "fake-session-0001",
) -> dict[str, Any]:
    """Build a raw JSONL `system` record of subtype `turn_duration`.

    `duration_ms` is stored under `durationMs`. `session_id` defaults to
    `FakePty`'s default id; override it when the fake uses a different one.
    """
    return {
        "type": "system",
        "subtype": "turn_duration",
        "uuid": "sys-td",
        "sessionId": session_id,
        "durationMs": duration_ms,
    }
|
||||
|
||||
|
||||
def _make_manager(
    fake: FakePty,
    *,
    wait_for_turn_duration: bool = False,
    startup_delay: float = 0.0,
    turn_duration_timeout: float | None = 1.0,
    on_parse_error: Any = None,
) -> TurnManager:
    """Return a `TurnManager` over `fake` with a real `JsonlWatcher`.

    The watcher tails `<fake.cwd>/<fake.session_id>.jsonl`, the file
    `FakePty.write()` appends to, polling every 10 ms. The keyword
    arguments are forwarded to `TurnManager` unchanged.
    """
    jsonl_path = Path(fake.cwd) / f"{fake.session_id}.jsonl"
    watcher = JsonlWatcher(jsonl_path, poll_interval=0.01)
    manager_kwargs = {
        "wait_for_turn_duration": wait_for_turn_duration,
        "startup_delay": startup_delay,
        "turn_duration_timeout": turn_duration_timeout,
        "on_parse_error": on_parse_error,
    }
    return TurnManager(fake, watcher, **manager_kwargs)  # type: ignore[arg-type]
|
||||
|
||||
|
||||
# --- construction validation ----------------------------------------------
|
||||
|
||||
|
||||
def test_init_rejects_negative_file_wait_timeout(tmp_path: Path) -> None:
    """A negative `file_wait_timeout` is rejected by the constructor."""
    pty = FakePty(tmp_path)
    jsonl_watcher = JsonlWatcher(tmp_path / "x.jsonl")
    bad_kwargs = {"file_wait_timeout": -1}
    with pytest.raises(ValueError, match="file_wait_timeout"):
        TurnManager(pty, jsonl_watcher, **bad_kwargs)  # type: ignore[arg-type]
|
||||
|
||||
|
||||
def test_init_rejects_negative_startup_delay(tmp_path: Path) -> None:
    """A negative `startup_delay` is rejected by the constructor."""
    pty = FakePty(tmp_path)
    jsonl_watcher = JsonlWatcher(tmp_path / "x.jsonl")
    bad_kwargs = {"startup_delay": -0.5}
    with pytest.raises(ValueError, match="startup_delay"):
        TurnManager(pty, jsonl_watcher, **bad_kwargs)  # type: ignore[arg-type]
|
||||
|
||||
|
||||
def test_init_rejects_negative_turn_duration_timeout(tmp_path: Path) -> None:
    """A negative `turn_duration_timeout` is rejected by the constructor."""
    pty = FakePty(tmp_path)
    jsonl_watcher = JsonlWatcher(tmp_path / "x.jsonl")
    bad_kwargs = {"turn_duration_timeout": -1}
    with pytest.raises(ValueError, match="turn_duration_timeout"):
        TurnManager(pty, jsonl_watcher, **bad_kwargs)  # type: ignore[arg-type]
|
||||
|
||||
|
||||
# --- lifecycle guards -----------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_before_start_raises(tmp_path: Path) -> None:
    """Iterating a turn on a manager that was never started is an error."""
    manager = _make_manager(FakePty(tmp_path))
    with pytest.raises(RuntimeError, match="before start"):
        # Draining the stream is what triggers the guard.
        [event async for event in manager.send_user_message("hi")]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_start_is_idempotent(tmp_path: Path) -> None:
    """Calling `start()` twice raises nothing and leaves the PTY started."""
    fake = FakePty(tmp_path)
    manager = _make_manager(fake)
    for _ in range(2):
        await manager.start()
    # `FakePty.start()` sets `started` however many times it is invoked, so
    # this only pins "no exception, and the PTY did get started".
    assert fake.started is True
|
||||
|
||||
|
||||
# --- happy path: one turn, terminal end_turn -------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_basic_turn_yields_user_assistant_then_result(tmp_path: Path) -> None:
    """One scripted turn streams user, assistant, then a synthesized result."""
    script = [
        _user_rec("say hi"),
        _assistant_rec("hi!", stop_reason="end_turn"),
        # Written to the file, but with wait_for_turn_duration=False the
        # manager returns at the terminal assistant and never yields it.
        _turn_duration_rec(),
    ]
    fake = FakePty(tmp_path, scripts=[script])
    tm = _make_manager(fake)
    await tm.start()
    events: list[Any] = [event async for event in tm.send_user_message("say hi")]
    await tm.aclose()

    assert fake.writes == ["say hi"]

    user, assistant, result = events[0], events[1], events[-1]
    assert isinstance(user, UserMessage)
    assert isinstance(assistant, AssistantMessage)
    assert assistant.stop_reason == "end_turn"
    assert isinstance(assistant.content[0], TextBlock)

    assert isinstance(result, ResultMessage)
    assert result.stop_reason == "end_turn"
    assert result.num_turns == 1
    assert result.session_id == fake.session_id
    # No turn_duration consumed, so the synthesized result reports 0.
    assert result.duration_ms == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_wait_for_turn_duration_carries_duration_ms(tmp_path: Path) -> None:
    """With `wait_for_turn_duration=True` the result reports `durationMs`."""
    script = [
        _user_rec("ping"),
        _assistant_rec("pong", stop_reason="end_turn"),
        _turn_duration_rec(duration_ms=4242),
    ]
    tm = _make_manager(FakePty(tmp_path, scripts=[script]), wait_for_turn_duration=True)
    await tm.start()
    events = [event async for event in tm.send_user_message("ping")]
    await tm.aclose()

    # The system event itself must also be visible in the stream.
    duration_events = [
        e for e in events if isinstance(e, SystemMessage) and e.subtype == "turn_duration"
    ]
    assert duration_events

    result = events[-1]
    assert isinstance(result, ResultMessage)
    assert result.duration_ms == 4242
|
||||
|
||||
|
||||
# --- tool loop continues until next terminal -----------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_tool_use_stop_reason_does_not_close_turn(tmp_path: Path) -> None:
    """An assistant record with `stop_reason="tool_use"` keeps the turn open.

    The stream must continue through the tool result until the next
    terminal assistant record.
    """
    script = [
        _user_rec("compute"),
        _tool_use_assistant_rec("Bash", "tool_1"),
        _tool_result_user_rec("tool_1", "42"),
        _assistant_rec("the answer is 42", stop_reason="end_turn"),
    ]
    tm = _make_manager(FakePty(tmp_path, scripts=[script]))
    await tm.start()
    events = [event async for event in tm.send_user_message("compute")]
    await tm.aclose()

    # Both assistant records arrive: the tool_use one did not end the loop.
    stop_reasons = [e.stop_reason for e in events if isinstance(e, AssistantMessage)]
    assert len(stop_reasons) == 2
    assert stop_reasons[0] == "tool_use"
    assert stop_reasons[1] == "end_turn"

    result = events[-1]
    assert isinstance(result, ResultMessage)
    assert result.stop_reason == "end_turn"
|
||||
|
||||
|
||||
# --- error & misuse paths -------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_parse_error_callback_keeps_stream_alive(tmp_path: Path) -> None:
    """A malformed record is reported via the callback and then skipped.

    The bogus record (an assistant with no `message`) sits between two valid
    ones. The callback must fire exactly once with that record, and the
    stream must still end in a `ResultMessage`.
    """
    malformed = {"type": "assistant", "uuid": "x", "sessionId": "fake-session-0001"}
    script = [
        _user_rec("hi"),
        malformed,
        _assistant_rec("ok", stop_reason="end_turn"),
    ]
    reported: list[tuple[Exception, dict[str, Any]]] = []

    def record_error(exc: Exception, rec: dict[str, Any]) -> None:
        reported.append((exc, rec))

    tm = _make_manager(FakePty(tmp_path, scripts=[script]), on_parse_error=record_error)
    await tm.start()
    events = [event async for event in tm.send_user_message("hi")]
    await tm.aclose()

    assert len(reported) == 1
    # The record went through the JSONL file, so it comes back as an equal
    # dict; equality also holds if the very same object is handed back.
    assert reported[0][1] == malformed
    assert isinstance(events[-1], ResultMessage)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_double_send_raises_while_turn_in_progress(tmp_path: Path) -> None:
    """A second send while the first turn is still polling is rejected.

    The script holds one empty batch, so the first turn never sees a
    terminal assistant and stays inside its polling loop.
    """
    fake = FakePty(tmp_path, scripts=[[]])
    # Pre-create the JSONL so the file-wait phase does not block forever.
    jsonl = tmp_path / f"{fake.session_id}.jsonl"
    jsonl.touch()
    tm = _make_manager(fake)
    await tm.start()

    # Advance the first generator in the background so it enters the turn.
    first_turn = tm.send_user_message("first")
    pending = asyncio.create_task(first_turn.__anext__())
    await asyncio.sleep(0.05)  # let _iter_turn flip turn_in_progress

    with pytest.raises(RuntimeError, match="turn is in progress"):
        [event async for event in tm.send_user_message("second")]

    pending.cancel()
    with pytest.raises((asyncio.CancelledError, StopAsyncIteration)):
        await pending
    await tm.aclose()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_aclose_terminates_owned_pty(tmp_path: Path) -> None:
    """By default the manager closes the PTY when it is itself closed."""
    pty = FakePty(tmp_path)
    manager = _make_manager(pty)
    await manager.start()
    await manager.aclose()
    assert pty.closed is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_aclose_skips_pty_when_not_owned(tmp_path: Path) -> None:
    """With `owns_pty=False` the manager leaves the PTY open on `aclose()`."""
    pty = FakePty(tmp_path)
    jsonl_path = tmp_path / f"{pty.session_id}.jsonl"
    manager = TurnManager(  # type: ignore[arg-type]
        pty,
        JsonlWatcher(jsonl_path, poll_interval=0.01),
        owns_pty=False,
        startup_delay=0.0,
    )
    await manager.start()
    await manager.aclose()
    assert pty.closed is False
|
||||
|
||||
|
||||
# --- Stage 10: error mapping ---------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_session_error_raised_when_jsonl_never_appears(tmp_path: Path) -> None:
    """A JSONL file that never shows up is reported as `SessionError`.

    With an empty script `FakePty.write()` creates no file, so the
    file-wait timeout fires; the manager must raise the typed error rather
    than a raw `asyncio.TimeoutError`.
    """
    fake = FakePty(tmp_path, scripts=[])  # write() never touches the JSONL
    jsonl_path = tmp_path / f"{fake.session_id}.jsonl"
    tm = TurnManager(
        fake,  # type: ignore[arg-type]
        JsonlWatcher(jsonl_path, poll_interval=0.01),
        startup_delay=0.0,
        file_wait_timeout=0.05,  # fire fast
    )
    await tm.start()
    with pytest.raises(SessionError):
        [event async for event in tm.send_user_message("hi")]
    await tm.aclose()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_auth_marker_in_pty_output_raises_auth_error(tmp_path: Path) -> None:
    """A missing JSONL plus an auth-block marker is raised as `AuthError`.

    The captured PTY output decides the classification: with the marker
    present the failure is promoted from the generic `SessionError`.
    """
    fake = FakePty(
        tmp_path,
        scripts=[],
        output=b"Failed to authenticate. Please run /login.\r\n",
    )
    jsonl_path = tmp_path / f"{fake.session_id}.jsonl"
    tm = TurnManager(
        fake,  # type: ignore[arg-type]
        JsonlWatcher(jsonl_path, poll_interval=0.01),
        startup_delay=0.0,
        file_wait_timeout=0.05,
    )
    await tm.start()
    with pytest.raises(AuthError):
        [event async for event in tm.send_user_message("hi")]
    await tm.aclose()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_rate_limit_marker_promotes_session_error_to_rate_limit(
    tmp_path: Path,
) -> None:
    """A missing JSONL plus a rate-limit marker is raised as `RateLimitError`.

    Same path as the auth case; here the marker in the captured output is
    wrapped in ANSI colour escapes.
    """
    fake = FakePty(
        tmp_path,
        scripts=[],
        output=b"\x1b[31mYou've hit your limit\x1b[0m. Try again at 9pm.",
    )
    jsonl_path = tmp_path / f"{fake.session_id}.jsonl"
    tm = TurnManager(
        fake,  # type: ignore[arg-type]
        JsonlWatcher(jsonl_path, poll_interval=0.01),
        startup_delay=0.0,
        file_wait_timeout=0.05,
    )
    await tm.start()
    with pytest.raises(RateLimitError):
        [event async for event in tm.send_user_message("hi")]
    await tm.aclose()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_process_death_mid_poll_raises_process_error(tmp_path: Path) -> None:
    """A PTY that dies before any terminal assistant raises `ProcessError`.

    The JSONL does appear (so the wait-for-file phase is passed), but only
    a user record is ever written and the PTY is then reported dead. The
    failure is detected inside the poll loop, and the exception carries the
    captured output so a gateway can log what claude wrote before exiting.
    """
    fake = FakePty(
        tmp_path,
        scripts=[[_user_rec("hi")]],  # user record only, no assistant follows
        output=b"some claude chrome before death\r\n",
    )
    jsonl_path = tmp_path / f"{fake.session_id}.jsonl"
    tm = TurnManager(
        fake,  # type: ignore[arg-type]
        JsonlWatcher(jsonl_path, poll_interval=0.01),
        startup_delay=0.0,
        file_wait_timeout=2.0,
    )
    await tm.start()

    async def drain_turn() -> list[Any]:
        seen: list[Any] = []
        async for event in tm.send_user_message("hi"):
            seen.append(event)
            # After the user record arrives, flip the PTY to dead so the
            # next polling pass takes the failure branch.
            if isinstance(event, UserMessage):
                fake.set_alive(False)
        return seen

    with pytest.raises(ProcessError) as info:
        await drain_turn()

    err = info.value
    assert "exited before a terminal" in str(err)
    assert err.stderr is not None
    assert "claude chrome before death" in err.stderr
    await tm.aclose()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_process_death_with_rate_limit_marker_raises_rate_limit(
    tmp_path: Path,
) -> None:
    """Process death with a rate-limit notice raises `RateLimitError`.

    The process-death classifier defers to markers in the PTY buffer, so
    the typed error wins over the generic `ProcessError`.
    """
    fake = FakePty(
        tmp_path,
        scripts=[[_user_rec("hi")]],
        output=b"You've hit your limit. Cooling off.",
    )
    jsonl_path = tmp_path / f"{fake.session_id}.jsonl"
    tm = TurnManager(
        fake,  # type: ignore[arg-type]
        JsonlWatcher(jsonl_path, poll_interval=0.01),
        startup_delay=0.0,
        file_wait_timeout=2.0,
    )
    await tm.start()

    async def drain_turn() -> None:
        async for event in tm.send_user_message("hi"):
            # Declare the PTY dead as soon as the user record is seen.
            if isinstance(event, UserMessage):
                fake.set_alive(False)

    with pytest.raises(RateLimitError):
        await drain_turn()
    await tm.aclose()
|
||||
|
||||
|
||||
# --- multi-turn (Stage 6) -------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_two_consecutive_turns_each_yield_only_fresh_records(tmp_path: Path) -> None:
    """Stage 6 core: each turn streams only records appended since the last.

    One watcher is reused for both turns and keeps the byte offset itself
    (PROGRESS.md decision log: "TurnManager does NOT own
    JsonlWatcher.offset"). This test pins that contract.
    """
    fake = FakePty(
        tmp_path,
        scripts=[
            [_user_rec("Q1"), _assistant_rec("A1", stop_reason="end_turn")],
            [_user_rec("Q2"), _assistant_rec("A2", stop_reason="end_turn")],
        ],
    )
    tm = _make_manager(fake)
    await tm.start()
    turn1 = [event async for event in tm.send_user_message("Q1")]
    turn2 = [event async for event in tm.send_user_message("Q2")]
    await tm.aclose()

    assert fake.writes == ["Q1", "Q2"]

    expected_shape = ["UserMessage", "AssistantMessage", "ResultMessage"]

    # Turn 1: user("Q1"), assistant("A1"), result.
    assert [type(e).__name__ for e in turn1] == expected_shape
    assert turn1[0].content == "Q1"
    reply1 = turn1[1]
    assert isinstance(reply1, AssistantMessage)
    assert isinstance(reply1.content[0], TextBlock)
    assert reply1.content[0].text == "A1"
    result1 = turn1[-1]
    assert isinstance(result1, ResultMessage)
    assert result1.num_turns == 1

    # Turn 2 must NOT replay any of turn 1's records.
    assert [type(e).__name__ for e in turn2] == expected_shape
    assert turn2[0].content == "Q2"
    reply2 = turn2[1]
    assert isinstance(reply2, AssistantMessage)
    assert isinstance(reply2.content[0], TextBlock)
    assert reply2.content[0].text == "A2"

    # The turn counter increments across turns; the session id stays stable.
    result2 = turn2[-1]
    assert isinstance(result2, ResultMessage)
    assert result2.num_turns == 2
    assert result2.session_id == result1.session_id == fake.session_id
    assert tm.turn_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_multi_turn_with_wait_for_turn_duration_carries_each_duration(
    tmp_path: Path,
) -> None:
    """With `wait_for_turn_duration=True` every turn reports its own duration.

    The watcher offset must move past the first turn's turn_duration record
    so the second turn starts clean.
    """

    def scripted_turn(prompt: str, reply: str, duration_ms: int) -> list[dict[str, Any]]:
        return [
            _user_rec(prompt),
            _assistant_rec(reply, stop_reason="end_turn"),
            _turn_duration_rec(duration_ms=duration_ms),
        ]

    fake = FakePty(
        tmp_path,
        scripts=[
            scripted_turn("ping1", "pong1", 111),
            scripted_turn("ping2", "pong2", 222),
        ],
    )
    tm = _make_manager(fake, wait_for_turn_duration=True)
    await tm.start()
    turn1 = [event async for event in tm.send_user_message("ping1")]
    turn2 = [event async for event in tm.send_user_message("ping2")]
    await tm.aclose()

    result1 = turn1[-1]
    assert isinstance(result1, ResultMessage)
    assert result1.duration_ms == 111
    assert result1.num_turns == 1

    result2 = turn2[-1]
    assert isinstance(result2, ResultMessage)
    assert result2.duration_ms == 222
    assert result2.num_turns == 2
|
||||
|
||||
|
||||
# --- smoke test (real claude) ---------------------------------------------
|
||||
|
||||
# Name of the environment variable that opts in to the real-binary smoke tests.
_SMOKE_ENV = "RUN_CLAUDE_SMOKE"


@pytest.mark.skipif(
    os.environ.get(_SMOKE_ENV) != "1",
    reason=f"set {_SMOKE_ENV}=1 to run the real-`claude` smoke test",
)
@pytest.mark.asyncio
async def test_smoke_send_hi(tmp_path: Path) -> None:
    """Smoke 1: a single end-to-end turn against the real claude binary.

    Exercises PTY spawn, JSONL discovery, watcher tail, normalizer mapping,
    turn-end detection and ResultMessage synthesis together. It is also the
    empirical probe for Open Q #2: if claude never picks up the prompt after
    `pty.write("say hi\\r")`, the JSONL does not grow and the file-wait
    timeout fires, which tells us that carriage return plus the 1s startup
    delay is not enough and a different submit mechanism is needed.
    """
    pty = PtyClaudeProcess(
        PtyProcessOptions(
            cwd=str(tmp_path),
            dangerously_skip_permissions=True,
        )
    )
    watcher = JsonlWatcher(resolve_jsonl_path(pty.cwd, pty.session_id))
    tm = TurnManager(pty, watcher)

    events: list[Any] = []
    try:
        await tm.start()
        async for event in tm.send_user_message("say hi"):
            events.append(event)
    finally:
        # Always tear the manager down, even when the turn failed.
        await tm.aclose()

    assistants = [e for e in events if isinstance(e, AssistantMessage)]
    assert assistants, (
        f"no AssistantMessage in stream; got {[type(e).__name__ for e in events]}"
    )

    terminal_reasons = {"end_turn", "max_tokens", "stop_sequence", "refusal"}
    terminal = None
    for candidate in assistants:
        if candidate.stop_reason in terminal_reasons:
            terminal = candidate
            break
    assert terminal is not None, (
        f"no terminal stop_reason; got {[a.stop_reason for a in assistants]}"
    )
    assert any(isinstance(b, TextBlock) for b in terminal.content)

    result = events[-1]
    assert isinstance(result, ResultMessage)
    assert result.stop_reason == terminal.stop_reason
    assert result.session_id == pty.session_id
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    os.environ.get(_SMOKE_ENV) != "1",
    reason=f"set {_SMOKE_ENV}=1 to run the real-`claude` smoke test",
)
@pytest.mark.asyncio
async def test_smoke_multi_turn_context_persists(tmp_path: Path) -> None:
    """Smoke 2 (Stage 6): context from turn one must be visible in turn two.

    Both turns run on a single TurnManager. The first user message plants a
    memorable token and the second asks for it back. If the PTY started with
    one `--session-id` really accumulates context (as the JSONL design
    implies), the second assistant reply contains the token; had each turn
    run in isolation, the reply could not know it.

    The token is the proper noun "Beaver" (also used in the JSONL injection
    probe). It is unlikely to show up spontaneously, which keeps false
    positives improbable while the prompt stays natural.
    """
    process = PtyClaudeProcess(
        PtyProcessOptions(
            cwd=str(tmp_path),
            dangerously_skip_permissions=True,
        )
    )
    manager = TurnManager(
        process,
        JsonlWatcher(resolve_jsonl_path(process.cwd, process.session_id)),
    )

    first_turn: list[Any] = []
    second_turn: list[Any] = []
    # Each prompt is paired with the list that collects its events.
    script = (
        ("Please remember: my name is Beaver. Reply with just 'ok'.", first_turn),
        ("What is my name? Answer with the single word only.", second_turn),
    )
    try:
        await manager.start()
        for prompt, bucket in script:
            async for item in manager.send_user_message(prompt):
                bucket.append(item)
    finally:
        await manager.aclose()

    # Each turn ends in a synthesized result and the turn counter advances.
    first_result = first_turn[-1]
    second_result = second_turn[-1]
    assert isinstance(first_result, ResultMessage)
    assert isinstance(second_result, ResultMessage)
    assert first_result.num_turns == 1
    assert second_result.num_turns == 2
    assert first_result.session_id == second_result.session_id == process.session_id
    assert manager.turn_count == 2

    # The closing assistant message of turn two has to mention the token.
    second_replies = [e for e in second_turn if isinstance(e, AssistantMessage)]
    terminal_reasons = {"end_turn", "max_tokens", "stop_sequence", "refusal"}
    closing = None
    for reply in second_replies:
        if reply.stop_reason in terminal_reasons:
            closing = reply
            break
    assert closing is not None, (
        f"no terminal stop_reason in turn 2; got {[a.stop_reason for a in second_replies]}"
    )
    text_parts = [part.text for part in closing.content if isinstance(part, TextBlock)]
    reply_text = " ".join(text_parts)
    assert "beaver" in reply_text.lower(), (
        f"turn 2 did not inherit context from turn 1; reply was: {reply_text!r}"
    )
|
||||
|
||||
|
||||
# --- Stage 7: tool calls via external MCP server -------------------------
|
||||
|
||||
|
||||
# Two directory levels up from this test module's resolved path.
_REPO_ROOT = Path(__file__).resolve().parent.parent
# Location of the stdio echo MCP server script used by the Stage 7 smoke test.
_ECHO_MCP_SCRIPT = _REPO_ROOT.joinpath("scripts", "echo_mcp_server.py")
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    os.environ.get(_SMOKE_ENV) != "1",
    reason=f"set {_SMOKE_ENV}=1 to run the real-`claude` smoke test",
)
@pytest.mark.asyncio
async def test_smoke_tool_call_via_mcp(tmp_path: Path) -> None:
    """Smoke 3 (Stage 7): a tool call travels through an external MCP server.

    The real claude is expected to route the call through a stdio MCP server
    and the resulting `tool_use` + `tool_result` records must surface as
    typed events.

    Setup:
      - `scripts/echo_mcp_server.py` is a zero-dep stdio MCP server exposing
        one tool, `echo`, which returns its `text` argument verbatim.
      - claude is pointed at it through a temporary `--mcp-config` JSON file
        declaring a single server named "echo". `--strict-mcp-config` keeps
        the user's ambient `.mcp.json` from leaking in and changing the
        tool surface.

    Assertions:
      - Some `AssistantMessage.content` holds a `ToolUseBlock` whose name
        references the echo tool (external MCP tools are exposed as
        `mcp__<server>__<tool>`, here `mcp__echo__echo`).
      - A follow-up `UserMessage` holds a `ToolResultBlock` containing the
        marker token we asked the tool to echo. That token can only come
        from the MCP server, so seeing it proves the full round-trip.
      - A terminal assistant message closes the turn and the synthesized
        `ResultMessage` carries the same stop_reason.
    """
    assert _ECHO_MCP_SCRIPT.exists(), f"missing echo MCP server at {_ECHO_MCP_SCRIPT}"

    marker = "banana42xyz"  # low-collision sentinel expected inside tool_result

    server_config = {
        "mcpServers": {
            "echo": {
                "command": sys.executable,
                "args": [str(_ECHO_MCP_SCRIPT)],
            },
        },
    }
    mcp_config_path = tmp_path / "mcp_config.json"
    mcp_config_path.write_text(json.dumps(server_config))

    process = PtyClaudeProcess(
        PtyProcessOptions(
            cwd=str(tmp_path),
            dangerously_skip_permissions=True,
            mcp_config=(str(mcp_config_path),),
        )
    )
    watcher = JsonlWatcher(resolve_jsonl_path(process.cwd, process.session_id))

    # claude spawns external MCP servers while starting up, so its input box
    # mounts later than in a bare session; a 60s file-wait leaves headroom
    # even when the first MCP handshake is slow.
    manager = TurnManager(process, watcher, file_wait_timeout=60.0)

    prompt = f"Call mcp__echo__echo with text={marker!r}, then reply 'done'."

    events: list[Any] = []
    try:
        await manager.start()
        async for item in manager.send_user_message(prompt):
            events.append(item)
    finally:
        await manager.aclose()

    # --- assertions ---
    tool_uses: list[ToolUseBlock] = [
        part
        for ev in events
        if isinstance(ev, AssistantMessage)
        for part in ev.content
        if isinstance(part, ToolUseBlock)
    ]
    assert tool_uses, (
        "no ToolUseBlock in any assistant message; got "
        f"{[type(e).__name__ for e in events]}"
    )
    echo_uses = [use for use in tool_uses if "echo" in use.name.lower()]
    assert echo_uses, (
        f"no tool_use referenced the echo tool; saw names {[t.name for t in tool_uses]}"
    )

    # Only the MCP server can produce the marker, so finding it inside a
    # tool_result block shows the round-trip really completed.
    tool_results: list[ToolResultBlock] = [
        part
        for ev in events
        if isinstance(ev, UserMessage) and isinstance(ev.content, list)
        for part in ev.content
        if isinstance(part, ToolResultBlock)
    ]
    assert tool_results, "no ToolResultBlock in any user message after the tool call"

    def _result_text(block: ToolResultBlock) -> str:
        """Flatten a tool result's content (str or list of dict parts) to text."""
        payload = block.content
        if isinstance(payload, str):
            return payload
        if not isinstance(payload, list):
            return ""
        return " ".join(
            part["text"]
            for part in payload
            if isinstance(part, dict) and isinstance(part.get("text"), str)
        )

    result_texts = [_result_text(block) for block in tool_results]
    assert any(marker in text for text in result_texts), (
        f"marker {marker!r} did not appear in any tool_result; got "
        f"{result_texts}"
    )

    # First assistant message whose stop_reason closes the turn.
    terminal_reasons = {"end_turn", "max_tokens", "stop_sequence", "refusal"}
    terminal_assistant = None
    for ev in events:
        if isinstance(ev, AssistantMessage) and ev.stop_reason in terminal_reasons:
            terminal_assistant = ev
            break
    assert terminal_assistant is not None, (
        "no terminal assistant after tool round-trip; got stop_reasons "
        f"{[e.stop_reason for e in events if isinstance(e, AssistantMessage)]}"
    )
    final = events[-1]
    assert isinstance(final, ResultMessage)
    assert final.stop_reason == terminal_assistant.stop_reason
|
||||
@@ -0,0 +1,364 @@
|
||||
"""Unit tests for Layer 2 (`JsonlWatcher`).
|
||||
|
||||
All tests use temp files; no `claude` involved. The watcher is exercised both
|
||||
in its single-pass mode (`read_once`) and in its long-running mode (`tail`).
|
||||
For `tail`, a producer task appends to the file while a consumer pulls from
|
||||
the async iterator; both run under one event loop with a short poll interval
|
||||
so tests stay quick.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from claude_code_api.watcher import JsonlWatcher
|
||||
|
||||
|
||||
def _write_records(path: Path, records: list[dict]) -> None:
    """Serialize *records* as JSON lines and append them to *path* in one write.

    Every record, including the last one, is terminated by a newline. The
    file is opened in append mode, so existing content is preserved.
    """
    lines = [f"{json.dumps(record)}\n" for record in records]
    with path.open("a", encoding="utf-8") as sink:
        sink.write("".join(lines))
|
||||
|
||||
|
||||
# --- construction validation ------------------------------------------------
|
||||
|
||||
|
||||
def test_init_rejects_nonpositive_poll_interval(tmp_path: Path) -> None:
    """Zero and negative poll intervals are both refused at construction."""
    target = tmp_path / "x.jsonl"
    for bad_interval in (0, -1):
        with pytest.raises(ValueError, match="poll_interval"):
            JsonlWatcher(target, poll_interval=bad_interval)
|
||||
|
||||
|
||||
def test_init_rejects_negative_start_offset(tmp_path: Path) -> None:
    """A negative `start_offset` raises ValueError naming the parameter."""
    target = tmp_path / "x.jsonl"
    rejection = pytest.raises(ValueError, match="start_offset")
    with rejection:
        JsonlWatcher(target, start_offset=-1)
|
||||
|
||||
|
||||
def test_init_rejects_nonpositive_read_chunk(tmp_path: Path) -> None:
    """A `read_chunk` of zero raises ValueError naming the parameter."""
    target = tmp_path / "x.jsonl"
    rejection = pytest.raises(ValueError, match="read_chunk")
    with rejection:
        JsonlWatcher(target, read_chunk=0)
|
||||
|
||||
|
||||
def test_path_is_exposed(tmp_path: Path) -> None:
    """A fresh watcher reports the path it was given and a zero offset."""
    location = tmp_path / "x.jsonl"
    watcher = JsonlWatcher(location)
    assert watcher.path == location
    assert watcher.offset == 0
|
||||
|
||||
|
||||
# --- read_once: synchronous behavior ---------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_read_once_returns_empty_when_file_missing(tmp_path: Path) -> None:
    """Reading a nonexistent file yields nothing and leaves the offset at 0."""
    watcher = JsonlWatcher(tmp_path / "missing.jsonl")
    batch = await watcher.read_once()
    assert batch == []
    # With nothing to read, the offset must stay where it started.
    assert watcher.offset == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_read_once_returns_all_existing_records(tmp_path: Path) -> None:
    """One pass returns every record on disk and moves the offset to EOF."""
    source = tmp_path / "s.jsonl"
    expected = [
        {"type": kind, "i": index}
        for index, kind in enumerate(("user", "assistant", "system"))
    ]
    _write_records(source, expected)

    watcher = JsonlWatcher(source)
    batch = await watcher.read_once()
    assert batch == expected
    # After a full read the offset equals the file size.
    assert watcher.offset == source.stat().st_size
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_read_once_is_incremental(tmp_path: Path) -> None:
    """Successive passes return only records appended since the last pass."""
    source = tmp_path / "s.jsonl"
    _write_records(source, [{"i": 0}])
    watcher = JsonlWatcher(source)

    first = await watcher.read_once()
    assert first == [{"i": 0}]

    # No new bytes were written, so the next pass comes back empty.
    idle = await watcher.read_once()
    assert idle == []

    # Only the freshly appended records are returned.
    appended = [{"i": 1}, {"i": 2}]
    _write_records(source, appended)
    latest = await watcher.read_once()
    assert latest == [{"i": 1}, {"i": 2}]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_read_once_buffers_partial_line(tmp_path: Path) -> None:
    """An unterminated trailing line is held back until its newline arrives."""
    source = tmp_path / "s.jsonl"
    whole = {"complete": True}
    fragment = '{"complete":'
    # One complete record followed by the start of another (no newline).
    source.write_text(json.dumps(whole) + "\n" + fragment, encoding="utf-8")

    watcher = JsonlWatcher(source)
    first_pass = await watcher.read_once()
    assert first_pass == [whole]
    # The unfinished fragment still counts toward the offset.
    assert watcher.offset == source.stat().st_size

    # Complete the dangling line.
    with source.open("a", encoding="utf-8") as fh:
        fh.write(" false}\n")

    second_pass = await watcher.read_once()
    assert second_pass == [{"complete": False}]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_read_once_skips_blank_lines(tmp_path: Path) -> None:
    """Empty and whitespace-only lines are ignored, not treated as errors."""
    source = tmp_path / "s.jsonl"
    # Blank lines interleaved with two valid records.
    pieces = (
        "\n",
        json.dumps({"i": 0}) + "\n",
        " \n",
        json.dumps({"i": 1}) + "\n",
        "\n",
    )
    with source.open("w", encoding="utf-8") as fh:
        fh.write("".join(pieces))

    watcher = JsonlWatcher(source)
    batch = await watcher.read_once()
    assert batch == [{"i": 0}, {"i": 1}]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_read_once_invokes_parse_error_callback(tmp_path: Path) -> None:
    """A malformed line is reported to `on_parse_error` and skipped."""
    source = tmp_path / "s.jsonl"
    content = "".join(
        (
            json.dumps({"i": 0}) + "\n",
            "this is not json\n",
            json.dumps({"i": 2}) + "\n",
        )
    )
    with source.open("w", encoding="utf-8") as fh:
        fh.write(content)

    failures: list[tuple[bytes, Exception]] = []

    def record_failure(line, exc):
        failures.append((line, exc))

    watcher = JsonlWatcher(source, on_parse_error=record_failure)
    batch = await watcher.read_once()
    # The valid neighbours survive; the bad line is dropped.
    assert batch == [{"i": 0}, {"i": 2}]
    assert len(failures) == 1
    raw_line, error = failures[0]
    assert raw_line == b"this is not json"
    assert isinstance(error, json.JSONDecodeError)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_read_once_drops_malformed_silently_without_callback(tmp_path: Path) -> None:
    """Without a callback, a malformed line is skipped and reading continues."""
    source = tmp_path / "s.jsonl"
    content = "garbage\n" + json.dumps({"i": 1}) + "\n"
    with source.open("w", encoding="utf-8") as fh:
        fh.write(content)

    watcher = JsonlWatcher(source)  # deliberately no on_parse_error
    batch = await watcher.read_once()
    assert batch == [{"i": 1}]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_read_once_handles_chunk_boundary(tmp_path: Path) -> None:
    """Records longer than `read_chunk` are still returned intact."""
    source = tmp_path / "s.jsonl"
    oversized = {"payload": "x" * 8000, "i": 0}
    tiny = {"i": 1}
    _write_records(source, [oversized, tiny])

    # A 128-byte chunk makes the large record span many reads.
    watcher = JsonlWatcher(source, read_chunk=128)
    batch = await watcher.read_once()
    assert batch == [oversized, tiny]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_start_offset_skips_initial_content(tmp_path: Path) -> None:
    """A watcher started at EOF ignores existing records and sees new ones."""
    source = tmp_path / "s.jsonl"
    _write_records(source, [{"i": 0}, {"i": 1}])
    eof = source.stat().st_size

    # Begin at the current end of file so only later appends are visible.
    watcher = JsonlWatcher(source, start_offset=eof)
    before_append = await watcher.read_once()
    assert before_append == []

    _write_records(source, [{"i": 2}])
    after_append = await watcher.read_once()
    assert after_append == [{"i": 2}]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_read_once_resets_on_truncation(tmp_path: Path) -> None:
    """When the file is replaced by a shorter one, reading restarts from it."""
    source = tmp_path / "s.jsonl"
    _write_records(source, [{"i": 0}, {"i": 1}])
    watcher = JsonlWatcher(source)
    initial = await watcher.read_once()
    assert initial == [{"i": 0}, {"i": 1}]

    # Simulate truncation/rotation by overwriting with a shorter file.
    replacement = json.dumps({"reset": True}) + "\n"
    source.write_text(replacement, encoding="utf-8")
    after_reset = await watcher.read_once()
    assert after_reset == [{"reset": True}]
    assert watcher.offset == source.stat().st_size
|
||||
|
||||
|
||||
# --- wait_for_file ----------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_wait_for_file_returns_immediately_if_exists(tmp_path: Path) -> None:
    """`wait_for_file` completes within the timeout when the file is present."""
    target = tmp_path / "exists.jsonl"
    target.write_text("", encoding="utf-8")
    watcher = JsonlWatcher(target, poll_interval=0.01)
    # The outer timeout turns a hang into a test failure.
    pending = watcher.wait_for_file(timeout=1.0)
    await asyncio.wait_for(pending, timeout=1.0)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_wait_for_file_picks_up_late_creation(tmp_path: Path) -> None:
    """`wait_for_file` returns once a file created after the wait began appears."""
    target = tmp_path / "later.jsonl"
    watcher = JsonlWatcher(target, poll_interval=0.01)

    async def touch_after_delay() -> None:
        await asyncio.sleep(0.05)
        target.write_text("", encoding="utf-8")

    helper = asyncio.create_task(touch_after_delay())
    try:
        await asyncio.wait_for(watcher.wait_for_file(timeout=1.0), timeout=1.0)
    finally:
        # Always await the helper so it cannot outlive the test.
        await helper
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_wait_for_file_times_out(tmp_path: Path) -> None:
    """Waiting on a file that never appears raises TimeoutError."""
    watcher = JsonlWatcher(tmp_path / "never.jsonl", poll_interval=0.01)
    expectation = pytest.raises(TimeoutError)
    with expectation:
        await watcher.wait_for_file(timeout=0.05)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_wait_for_file_rejects_negative_timeout(tmp_path: Path) -> None:
    """A negative timeout raises ValueError naming the parameter."""
    watcher = JsonlWatcher(tmp_path / "x.jsonl")
    expectation = pytest.raises(ValueError, match="timeout")
    with expectation:
        await watcher.wait_for_file(timeout=-1)
|
||||
|
||||
|
||||
# --- tail: long-running async iteration ------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_tail_yields_existing_records_first(tmp_path: Path) -> None:
    """Records already on disk are yielded by `tail` in file order."""
    source = tmp_path / "s.jsonl"
    _write_records(source, [{"i": 0}, {"i": 1}])
    watcher = JsonlWatcher(source, poll_interval=0.01)

    collected: list[dict] = []

    async def drain() -> None:
        # Stop after the two preloaded records; tail itself never ends.
        async for record in watcher.tail():
            collected.append(record)
            if len(collected) >= 2:
                break

    await asyncio.wait_for(drain(), timeout=2.0)
    assert collected == [{"i": 0}, {"i": 1}]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_tail_waits_for_file_then_yields(tmp_path: Path) -> None:
    """`tail` started before the file exists yields once a record is written."""
    source = tmp_path / "delayed.jsonl"
    watcher = JsonlWatcher(source, poll_interval=0.01)

    collected: list[dict] = []

    async def drain() -> None:
        async for record in watcher.tail():
            collected.append(record)
            if len(collected) >= 1:
                break

    async def write_late() -> None:
        # Create the file only after the consumer has started.
        await asyncio.sleep(0.05)
        _write_records(source, [{"late": True}])

    reader = asyncio.create_task(drain())
    writer = asyncio.create_task(write_late())
    await asyncio.wait_for(asyncio.gather(reader, writer), timeout=2.0)
    assert collected == [{"late": True}]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_tail_streams_incremental_appends(tmp_path: Path) -> None:
    """Records appended one at a time are yielded by `tail` in write order."""
    source = tmp_path / "s.jsonl"
    source.write_text("", encoding="utf-8")
    watcher = JsonlWatcher(source, poll_interval=0.01)

    expected = [{"i": 0}, {"i": 1}, {"i": 2}, {"i": 3}]
    collected: list[dict] = []

    async def drain() -> None:
        async for record in watcher.tail():
            collected.append(record)
            if len(collected) >= len(expected):
                break

    async def feed() -> None:
        # One record per write, pausing between writes.
        for record in expected:
            _write_records(source, [record])
            await asyncio.sleep(0.02)

    reader = asyncio.create_task(drain())
    writer = asyncio.create_task(feed())
    await asyncio.wait_for(asyncio.gather(reader, writer), timeout=3.0)
    assert collected == expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_tail_handles_appends_arriving_mid_line(tmp_path: Path) -> None:
    """A record written in two pieces is yielded once, after the newline lands.

    The first write carries no newline; the record must come out as a single
    parsed object only when the second write completes the line.
    """
    source = tmp_path / "s.jsonl"
    source.write_text("", encoding="utf-8")
    watcher = JsonlWatcher(source, poll_interval=0.01)

    collected: list[dict] = []

    async def drain() -> None:
        async for record in watcher.tail():
            collected.append(record)
            if len(collected) >= 1:
                break

    async def feed() -> None:
        # First half, a pause longer than one poll interval, then the rest.
        for position, piece in enumerate(('{"split":', " true}\n")):
            if position:
                await asyncio.sleep(0.05)
            with source.open("a", encoding="utf-8") as fh:
                fh.write(piece)
                fh.flush()

    reader = asyncio.create_task(drain())
    writer = asyncio.create_task(feed())
    await asyncio.wait_for(asyncio.gather(reader, writer), timeout=2.0)
    assert collected == [{"split": True}]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_tail_is_cancellable(tmp_path: Path) -> None:
    """Cancelling a task blocked in `tail` surfaces CancelledError."""
    source = tmp_path / "s.jsonl"
    source.write_text("", encoding="utf-8")
    watcher = JsonlWatcher(source, poll_interval=0.01)

    async def drain_forever() -> None:
        async for _record in watcher.tail():
            continue

    running = asyncio.create_task(drain_forever())
    # Let a few poll ticks pass so the task is idling inside tail().
    await asyncio.sleep(0.05)
    running.cancel()
    with pytest.raises(asyncio.CancelledError):
        await running
|
||||
Reference in New Issue
Block a user