# Source: claude-code-api/tests/test_turn.py (Python; 935 lines, 31 KiB)

"""Unit + smoke tests for Layer 4 (`TurnManager`).
Unit tests use a `FakePty` that, on `write()`, dumps a scripted list of JSONL
records into a real temp file. A real `JsonlWatcher` tails that file so the
manager's read/normalize/turn-end loop is exercised end-to-end without
launching `claude`. The smoke test at the bottom spawns the real binary
behind `RUN_CLAUDE_SMOKE=1` and also serves as the empirical probe for
Open Q #2 (PTY echo / buffering).
"""
from __future__ import annotations
import asyncio
import json
import os
import sys
from pathlib import Path
from typing import Any
import pytest
from claude_code_api import (
AssistantMessage,
AuthError,
ProcessError,
RateLimitError,
ResultMessage,
SessionError,
SystemMessage,
TextBlock,
ToolResultBlock,
ToolUseBlock,
UserMessage,
)
from claude_code_api.paths import resolve_jsonl_path
from claude_code_api.watcher import JsonlWatcher
from claude_code_api.pty import PtyClaudeProcess, PtyProcessOptions
from claude_code_api.turn import TurnManager
# --- fakes -----------------------------------------------------------------
class FakePty:
    """Scripted stand-in for `PtyClaudeProcess`.

    Each `write()` records the text it received and, while scripted batches
    remain, appends the next batch to ``<cwd>/<session_id>.jsonl`` as one JSON
    object per line: the Nth `write()` flushes ``scripts[N]``, and writes past
    the end of the script append nothing. A single synchronous setup can
    therefore drive the whole turn loop without async coordination and
    without launching a real `claude`.

    Stage 10 additions: the `alive` and `output` knobs simulate sub-process
    death and error chrome captured in the PTY drain buffer, which
    `TurnManager` consults when classifying failures.
    """

    def __init__(
        self,
        tmp_path: Path,
        *,
        session_id: str = "fake-session-0001",
        scripts: list[list[dict[str, Any]]] | None = None,
        alive: bool = True,
        output: bytes = b"",
    ) -> None:
        self.cwd = str(tmp_path)
        self.session_id = session_id
        self._jsonl = tmp_path / f"{session_id}.jsonl"
        self._scripts = [] if scripts is None else scripts
        self._write_count = 0  # index of the next scripted batch to flush
        self.writes: list[str] = []
        self.started = False
        self.closed = False
        self._alive = alive
        self._output = output

    async def start(self) -> None:
        """Flag the fake as started; nothing is spawned."""
        self.started = True

    async def write(self, text: str, *, newline: bool = True) -> int:
        """Record `text`, flush the next scripted batch if any, return ``len(text)``.

        `newline` exists only for signature compatibility and is ignored.
        """
        self.writes.append(text)
        if self._write_count >= len(self._scripts):
            # Script exhausted: behave like a PTY whose claude wrote nothing.
            return len(text)
        batch = self._scripts[self._write_count]
        with self._jsonl.open("a", encoding="utf-8") as sink:
            sink.writelines(json.dumps(record) + "\n" for record in batch)
        self._write_count += 1
        return len(text)

    async def aclose(self) -> None:
        """Flag the fake as closed."""
        self.closed = True

    # --- Stage 10 surface ----------------------------------------------

    def is_alive(self) -> bool:
        """Report the simulated liveness flag."""
        return self._alive

    def captured_output(self) -> bytes:
        """Return the simulated PTY drain-buffer contents."""
        return self._output

    def set_alive(self, alive: bool) -> None:
        """Flip the simulated liveness flag (e.g. to fake a crash)."""
        self._alive = alive

    def set_output(self, output: bytes) -> None:
        """Replace the simulated PTY drain-buffer contents."""
        self._output = output
def _user_rec(text: str) -> dict[str, Any]:
return {
"type": "user",
"uuid": f"u-{text[:8]}",
"sessionId": "fake-session-0001",
"parentUuid": None,
"message": {"role": "user", "content": text},
}
def _assistant_rec(
text: str,
*,
stop_reason: str | None = "end_turn",
usage: dict[str, Any] | None = None,
) -> dict[str, Any]:
return {
"type": "assistant",
"uuid": f"a-{text[:8]}",
"sessionId": "fake-session-0001",
"parentUuid": None,
"message": {
"id": "msg_x",
"role": "assistant",
"model": "claude-test",
"content": [{"type": "text", "text": text}],
"stop_reason": stop_reason,
"usage": usage or {"input_tokens": 1, "output_tokens": 1},
},
}
def _tool_use_assistant_rec(name: str, tool_id: str) -> dict[str, Any]:
return {
"type": "assistant",
"uuid": f"a-tu-{tool_id}",
"sessionId": "fake-session-0001",
"parentUuid": None,
"message": {
"id": "msg_y",
"role": "assistant",
"model": "claude-test",
"content": [{"type": "tool_use", "id": tool_id, "name": name, "input": {}}],
"stop_reason": "tool_use",
"usage": {"input_tokens": 1, "output_tokens": 1},
},
}
def _tool_result_user_rec(tool_id: str, content: str) -> dict[str, Any]:
return {
"type": "user",
"uuid": f"u-tr-{tool_id}",
"sessionId": "fake-session-0001",
"parentUuid": None,
"message": {
"role": "user",
"content": [{"type": "tool_result", "tool_use_id": tool_id, "content": content}],
},
}
def _turn_duration_rec(duration_ms: int = 1234) -> dict[str, Any]:
return {
"type": "system",
"subtype": "turn_duration",
"uuid": "sys-td",
"sessionId": "fake-session-0001",
"durationMs": duration_ms,
}
def _make_manager(
    fake: FakePty,
    *,
    wait_for_turn_duration: bool = False,
    startup_delay: float = 0.0,
    turn_duration_timeout: float | None = 1.0,
    on_parse_error: Any = None,
) -> TurnManager:
    """Return a TurnManager over `fake`, tailing the fake's JSONL with a real watcher.

    The watcher polls every 10 ms so scripted records are picked up quickly.
    """
    jsonl_path = Path(fake.cwd) / f"{fake.session_id}.jsonl"
    tailer = JsonlWatcher(jsonl_path, poll_interval=0.01)
    return TurnManager(
        fake,  # type: ignore[arg-type]
        tailer,
        on_parse_error=on_parse_error,
        startup_delay=startup_delay,
        turn_duration_timeout=turn_duration_timeout,
        wait_for_turn_duration=wait_for_turn_duration,
    )
# --- construction validation ----------------------------------------------
def test_init_rejects_negative_file_wait_timeout(tmp_path: Path) -> None:
    """A negative `file_wait_timeout` is rejected at construction time."""
    stub = FakePty(tmp_path)
    tailer = JsonlWatcher(tmp_path / "x.jsonl")
    with pytest.raises(ValueError, match="file_wait_timeout"):
        TurnManager(stub, tailer, file_wait_timeout=-1)  # type: ignore[arg-type]
def test_init_rejects_negative_startup_delay(tmp_path: Path) -> None:
    """A negative `startup_delay` is rejected at construction time."""
    stub = FakePty(tmp_path)
    tailer = JsonlWatcher(tmp_path / "x.jsonl")
    with pytest.raises(ValueError, match="startup_delay"):
        TurnManager(stub, tailer, startup_delay=-0.5)  # type: ignore[arg-type]
def test_init_rejects_negative_turn_duration_timeout(tmp_path: Path) -> None:
    """A negative `turn_duration_timeout` is rejected at construction time."""
    stub = FakePty(tmp_path)
    tailer = JsonlWatcher(tmp_path / "x.jsonl")
    with pytest.raises(ValueError, match="turn_duration_timeout"):
        TurnManager(stub, tailer, turn_duration_timeout=-1)  # type: ignore[arg-type]
# --- lifecycle guards -----------------------------------------------------
@pytest.mark.asyncio
async def test_send_before_start_raises(tmp_path: Path) -> None:
    """Sending a message before `start()` fails fast with RuntimeError."""
    manager = _make_manager(FakePty(tmp_path))
    with pytest.raises(RuntimeError, match="before start"):
        async for _event in manager.send_user_message("hi"):
            pass
@pytest.mark.asyncio
async def test_start_is_idempotent(tmp_path: Path) -> None:
    """Calling `start()` twice neither raises nor corrupts the manager."""
    stub = FakePty(tmp_path)
    manager = _make_manager(stub)
    for _attempt in range(2):
        await manager.start()
    # FakePty.start() sets `started` regardless of how often it runs; what
    # matters is that the second call raised nothing and left state stable.
    assert stub.started is True
# --- happy path: one turn, terminal end_turn -------------------------------
@pytest.mark.asyncio
async def test_basic_turn_yields_user_assistant_then_result(tmp_path: Path) -> None:
    """One turn ending in `end_turn` yields user, assistant, then a synthesized result."""
    script = [
        _user_rec("say hi"),
        _assistant_rec("hi!", stop_reason="end_turn"),
        # The turn_duration record is scripted, but with
        # wait_for_turn_duration=False the manager returns on the terminal
        # assistant, so this record stays queued and is not yielded.
        _turn_duration_rec(),
    ]
    fake = FakePty(tmp_path, scripts=[script])
    tm = _make_manager(fake)
    await tm.start()
    events: list[Any] = [event async for event in tm.send_user_message("say hi")]
    await tm.aclose()

    assert fake.writes == ["say hi"]
    assert isinstance(events[0], UserMessage)
    assistant = events[1]
    assert isinstance(assistant, AssistantMessage)
    assert assistant.stop_reason == "end_turn"
    assert isinstance(assistant.content[0], TextBlock)

    result = events[-1]
    assert isinstance(result, ResultMessage)
    assert result.stop_reason == "end_turn"
    assert result.num_turns == 1
    assert result.session_id == fake.session_id
    # turn_duration was never consumed, so the synthesized duration is 0.
    assert result.duration_ms == 0
@pytest.mark.asyncio
async def test_wait_for_turn_duration_carries_duration_ms(tmp_path: Path) -> None:
    """With wait_for_turn_duration=True the result takes its duration from the heartbeat."""
    script = [
        _user_rec("ping"),
        _assistant_rec("pong", stop_reason="end_turn"),
        _turn_duration_rec(duration_ms=4242),
    ]
    tm = _make_manager(FakePty(tmp_path, scripts=[script]), wait_for_turn_duration=True)
    await tm.start()
    events = [event async for event in tm.send_user_message("ping")]
    await tm.aclose()

    # The heartbeat itself must also surface in the stream.
    saw_heartbeat = any(
        isinstance(event, SystemMessage) and event.subtype == "turn_duration"
        for event in events
    )
    assert saw_heartbeat
    final = events[-1]
    assert isinstance(final, ResultMessage)
    assert final.duration_ms == 4242
# --- tool loop continues until next terminal -----------------------------
@pytest.mark.asyncio
async def test_tool_use_stop_reason_does_not_close_turn(tmp_path: Path) -> None:
    """A `tool_use` stop_reason keeps the turn open until a terminal assistant arrives."""
    script = [
        _user_rec("compute"),
        _tool_use_assistant_rec("Bash", "tool_1"),
        _tool_result_user_rec("tool_1", "42"),
        _assistant_rec("the answer is 42", stop_reason="end_turn"),
    ]
    tm = _make_manager(FakePty(tmp_path, scripts=[script]))
    await tm.start()
    events = [event async for event in tm.send_user_message("compute")]
    await tm.aclose()

    # Both assistant records must arrive: the tool_use one must not end the loop.
    reasons = [event.stop_reason for event in events if isinstance(event, AssistantMessage)]
    assert len(reasons) == 2
    assert reasons[0] == "tool_use"
    assert reasons[1] == "end_turn"
    assert isinstance(events[-1], ResultMessage)
    assert events[-1].stop_reason == "end_turn"
# --- error & misuse paths -------------------------------------------------
@pytest.mark.asyncio
async def test_parse_error_callback_keeps_stream_alive(tmp_path: Path) -> None:
    """A malformed record triggers `on_parse_error` once and does not kill the turn.

    An assistant record with no `message` key sits between two valid records.
    The callback must receive it exactly once, and the stream must still end
    with a synthesized ResultMessage.
    """
    bad = {"type": "assistant", "uuid": "x", "sessionId": "fake-session-0001"}
    fake = FakePty(
        tmp_path,
        scripts=[
            [
                _user_rec("hi"),
                bad,
                _assistant_rec("ok", stop_reason="end_turn"),
            ]
        ],
    )
    errors: list[tuple[Exception, dict[str, Any]]] = []
    tm = _make_manager(fake, on_parse_error=lambda exc, rec: errors.append((exc, rec)))
    await tm.start()
    events = [e async for e in tm.send_user_message("hi")]
    await tm.aclose()

    assert len(errors) == 1
    exc, rec = errors[0]
    assert isinstance(exc, Exception)
    # The record was serialized to the JSONL file and read back, so the
    # callback can only ever receive an equal copy of `bad`, never `bad`
    # itself. An identity check here would be dead code.
    assert rec == bad
    assert isinstance(events[-1], ResultMessage)
@pytest.mark.asyncio
async def test_double_send_raises_while_turn_in_progress(tmp_path: Path) -> None:
    """A second send while a turn is still open raises RuntimeError.

    The script contains no records at all, so the first turn never sees a
    terminal assistant and stays in its polling loop. One `__anext__` step is
    driven in a background task, then a concurrent second send is attempted.
    """
    fake = FakePty(tmp_path, scripts=[[]])
    # Create the JSONL up front so the file-wait phase cannot block forever.
    (tmp_path / f"{fake.session_id}.jsonl").touch()
    tm = _make_manager(fake)
    await tm.start()

    first_turn = tm.send_user_message("first")
    pending = asyncio.create_task(first_turn.__anext__())
    await asyncio.sleep(0.05)  # give _iter_turn time to mark the turn in progress

    with pytest.raises(RuntimeError, match="turn is in progress"):
        async for _event in tm.send_user_message("second"):
            pass

    pending.cancel()
    with pytest.raises((asyncio.CancelledError, StopAsyncIteration)):
        await pending
    await tm.aclose()
@pytest.mark.asyncio
async def test_aclose_terminates_owned_pty(tmp_path: Path) -> None:
    """By default the manager owns the PTY, so `aclose()` closes it too."""
    stub = FakePty(tmp_path)
    manager = _make_manager(stub)
    await manager.start()
    await manager.aclose()
    assert stub.closed is True
@pytest.mark.asyncio
async def test_aclose_skips_pty_when_not_owned(tmp_path: Path) -> None:
    """With owns_pty=False, `aclose()` leaves the PTY open for its real owner."""
    stub = FakePty(tmp_path)
    tailer = JsonlWatcher(tmp_path / f"{stub.session_id}.jsonl", poll_interval=0.01)
    manager = TurnManager(
        stub,  # type: ignore[arg-type]
        tailer,
        owns_pty=False,
        startup_delay=0.0,
    )
    await manager.start()
    await manager.aclose()
    assert stub.closed is False
# --- Stage 10: error mapping ---------------------------------------------
@pytest.mark.asyncio
async def test_session_error_raised_when_jsonl_never_appears(tmp_path: Path) -> None:
    """No script → FakePty.write() never creates the JSONL → the file-wait
    timeout fires → TurnManager raises SessionError (not the raw
    asyncio.TimeoutError).

    The manager is closed in a ``finally`` so it is torn down even when the
    expected error is not raised and ``pytest.raises`` fails the test.
    """
    fake = FakePty(tmp_path, scripts=[])  # write() is a no-op for JSONL
    watcher = JsonlWatcher(
        tmp_path / f"{fake.session_id}.jsonl",
        poll_interval=0.01,
    )
    tm = TurnManager(
        fake,  # type: ignore[arg-type]
        watcher,
        startup_delay=0.0,
        file_wait_timeout=0.05,  # fire fast
    )
    await tm.start()
    try:
        with pytest.raises(SessionError):
            async for _ in tm.send_user_message("hi"):
                pass
    finally:
        await tm.aclose()
@pytest.mark.asyncio
async def test_auth_marker_in_pty_output_raises_auth_error(tmp_path: Path) -> None:
    """When the JSONL never appears AND captured PTY output carries an
    auth-block marker, the classifier promotes the failure to AuthError
    (instead of the generic SessionError).

    The manager is closed in a ``finally`` so it is torn down even when the
    expected error is not raised and ``pytest.raises`` fails the test.
    """
    fake = FakePty(
        tmp_path,
        scripts=[],
        output=b"Failed to authenticate. Please run /login.\r\n",
    )
    watcher = JsonlWatcher(
        tmp_path / f"{fake.session_id}.jsonl",
        poll_interval=0.01,
    )
    tm = TurnManager(
        fake,  # type: ignore[arg-type]
        watcher,
        startup_delay=0.0,
        file_wait_timeout=0.05,
    )
    await tm.start()
    try:
        with pytest.raises(AuthError):
            async for _ in tm.send_user_message("hi"):
                pass
    finally:
        await tm.aclose()
@pytest.mark.asyncio
async def test_rate_limit_marker_promotes_session_error_to_rate_limit(
    tmp_path: Path,
) -> None:
    """Same path as the auth case, but the PTY output carries a rate-limit
    marker (wrapped in ANSI colour codes), so RateLimitError is raised.

    The manager is closed in a ``finally`` so it is torn down even when the
    expected error is not raised and ``pytest.raises`` fails the test.
    """
    fake = FakePty(
        tmp_path,
        scripts=[],
        output=b"\x1b[31mYou've hit your limit\x1b[0m. Try again at 9pm.",
    )
    watcher = JsonlWatcher(
        tmp_path / f"{fake.session_id}.jsonl",
        poll_interval=0.01,
    )
    tm = TurnManager(
        fake,  # type: ignore[arg-type]
        watcher,
        startup_delay=0.0,
        file_wait_timeout=0.05,
    )
    await tm.start()
    try:
        with pytest.raises(RateLimitError):
            async for _ in tm.send_user_message("hi"):
                pass
    finally:
        await tm.aclose()
@pytest.mark.asyncio
async def test_process_death_mid_poll_raises_process_error(tmp_path: Path) -> None:
    """The JSONL appears (so the wait-for-file phase is passed), but no
    terminal assistant ever arrives and the PTY reports dead. Detection
    fires from inside the poll loop, and the captured output is attached to
    the exception so a gateway can log what claude wrote before exiting.

    The manager is closed in a ``finally`` so it is torn down even when an
    assertion below fails.
    """
    fake = FakePty(
        tmp_path,
        scripts=[[_user_rec("hi")]],  # only the user record — no assistant
        output=b"some claude chrome before death\r\n",
    )
    watcher = JsonlWatcher(
        tmp_path / f"{fake.session_id}.jsonl",
        poll_interval=0.01,
    )
    tm = TurnManager(
        fake,  # type: ignore[arg-type]
        watcher,
        startup_delay=0.0,
        file_wait_timeout=2.0,
    )
    await tm.start()

    async def consumer() -> list[Any]:
        events: list[Any] = []
        async for ev in tm.send_user_message("hi"):
            events.append(ev)
            # Once the user record has been seen, declare the PTY dead so
            # the next polling pass enters the failure branch.
            if isinstance(ev, UserMessage):
                fake.set_alive(False)
        return events

    try:
        with pytest.raises(ProcessError) as info:
            await consumer()
        assert "exited before a terminal" in str(info.value)
        assert info.value.stderr is not None
        assert "claude chrome before death" in info.value.stderr
    finally:
        await tm.aclose()
@pytest.mark.asyncio
async def test_process_death_with_rate_limit_marker_raises_rate_limit(
    tmp_path: Path,
) -> None:
    """The process-death classifier defers to the PTY marker: if the buffer
    carries a rate-limit notice, the typed RateLimitError is raised instead
    of the generic ProcessError.

    The manager is closed in a ``finally`` so it is torn down even when the
    expected error is not raised and ``pytest.raises`` fails the test.
    """
    fake = FakePty(
        tmp_path,
        scripts=[[_user_rec("hi")]],
        output=b"You've hit your limit. Cooling off.",
    )
    watcher = JsonlWatcher(
        tmp_path / f"{fake.session_id}.jsonl",
        poll_interval=0.01,
    )
    tm = TurnManager(
        fake,  # type: ignore[arg-type]
        watcher,
        startup_delay=0.0,
        file_wait_timeout=2.0,
    )
    await tm.start()

    async def consumer() -> None:
        async for ev in tm.send_user_message("hi"):
            # Kill the fake PTY as soon as the user record is observed.
            if isinstance(ev, UserMessage):
                fake.set_alive(False)

    try:
        with pytest.raises(RateLimitError):
            await consumer()
    finally:
        await tm.aclose()
# --- multi-turn (Stage 6) -------------------------------------------------
@pytest.mark.asyncio
async def test_two_consecutive_turns_each_yield_only_fresh_records(tmp_path: Path) -> None:
    """Stage 6 core: a second `send_user_message()` on the same manager sees
    only the records appended after the first turn ended.

    The watcher is reused across turns and keeps its byte offset internally
    (see PROGRESS.md decision log: "TurnManager does NOT own
    JsonlWatcher.offset"). This test pins that contract.
    """
    expected_shape = ["UserMessage", "AssistantMessage", "ResultMessage"]
    fake = FakePty(
        tmp_path,
        scripts=[
            [_user_rec(question), _assistant_rec(answer, stop_reason="end_turn")]
            for question, answer in (("Q1", "A1"), ("Q2", "A2"))
        ],
    )
    tm = _make_manager(fake)
    await tm.start()
    turn1 = [e async for e in tm.send_user_message("Q1")]
    turn2 = [e async for e in tm.send_user_message("Q2")]
    await tm.aclose()

    assert fake.writes == ["Q1", "Q2"]

    # Each turn is exactly user → assistant → result, and turn 2 must not
    # leak any of turn 1's records back to the caller.
    for turn, question, answer, ordinal in (
        (turn1, "Q1", "A1", 1),
        (turn2, "Q2", "A2", 2),
    ):
        assert [type(e).__name__ for e in turn] == expected_shape
        assert turn[0].content == question
        assert isinstance(turn[1], AssistantMessage)
        assert isinstance(turn[1].content[0], TextBlock)
        assert turn[1].content[0].text == answer
        assert isinstance(turn[-1], ResultMessage)
        # Turn-count bookkeeping increments across turns.
        assert turn[-1].num_turns == ordinal

    # The session id is stable for the manager's lifetime.
    assert turn2[-1].session_id == turn1[-1].session_id == fake.session_id
    assert tm.turn_count == 2
@pytest.mark.asyncio
async def test_multi_turn_with_wait_for_turn_duration_carries_each_duration(
    tmp_path: Path,
) -> None:
    """With wait_for_turn_duration=True, each turn's synthesized result carries
    its own duration. The watcher offset moves past turn 1's turn_duration
    heartbeat, so turn 2 starts clean.
    """
    fake = FakePty(
        tmp_path,
        scripts=[
            [
                _user_rec(f"ping{n}"),
                _assistant_rec(f"pong{n}", stop_reason="end_turn"),
                _turn_duration_rec(duration_ms=ms),
            ]
            for n, ms in ((1, 111), (2, 222))
        ],
    )
    tm = _make_manager(fake, wait_for_turn_duration=True)
    await tm.start()
    turn1 = [e async for e in tm.send_user_message("ping1")]
    turn2 = [e async for e in tm.send_user_message("ping2")]
    await tm.aclose()

    for ordinal, (turn, expected_ms) in enumerate(((turn1, 111), (turn2, 222)), start=1):
        final = turn[-1]
        assert isinstance(final, ResultMessage)
        assert final.duration_ms == expected_ms
        assert final.num_turns == ordinal
# --- smoke test (real claude) ---------------------------------------------
_SMOKE_ENV = "RUN_CLAUDE_SMOKE"
@pytest.mark.skipif(
    os.environ.get(_SMOKE_ENV) != "1",
    reason=f"set {_SMOKE_ENV}=1 to run the real-`claude` smoke test",
)
@pytest.mark.asyncio
async def test_smoke_send_hi(tmp_path: Path) -> None:
    """Smoke 1: one end-to-end turn against the real claude binary.

    Checks that PTY spawn, JSONL discovery, watcher tailing, normalizer
    mapping, turn-end detection and ResultMessage synthesis all line up.
    It also serves as the empirical probe for Open Q #2: if claude ignores
    the prompt after `pty.write("say hi\\r")`, the JSONL never grows and the
    file-wait timeout fires. That failure would mean the carriage return plus
    the 1s startup delay is not enough, and a different submit mechanism is
    needed.
    """
    pty = PtyClaudeProcess(
        PtyProcessOptions(cwd=str(tmp_path), dangerously_skip_permissions=True)
    )
    tm = TurnManager(pty, JsonlWatcher(resolve_jsonl_path(pty.cwd, pty.session_id)))
    events: list[Any] = []
    try:
        await tm.start()
        async for event in tm.send_user_message("say hi"):
            events.append(event)
    finally:
        await tm.aclose()

    terminal_reasons = {"end_turn", "max_tokens", "stop_sequence", "refusal"}
    assistants = [e for e in events if isinstance(e, AssistantMessage)]
    assert assistants, (
        f"no AssistantMessage in stream; got {[type(e).__name__ for e in events]}"
    )
    terminal = next((a for a in assistants if a.stop_reason in terminal_reasons), None)
    assert terminal is not None, (
        f"no terminal stop_reason; got {[a.stop_reason for a in assistants]}"
    )
    assert any(isinstance(block, TextBlock) for block in terminal.content)

    result = events[-1]
    assert isinstance(result, ResultMessage)
    assert result.stop_reason == terminal.stop_reason
    assert result.session_id == pty.session_id
@pytest.mark.skipif(
    os.environ.get(_SMOKE_ENV) != "1",
    reason=f"set {_SMOKE_ENV}=1 to run the real-`claude` smoke test",
)
@pytest.mark.asyncio
async def test_smoke_multi_turn_context_persists(tmp_path: Path) -> None:
    """Smoke 2 (Stage 6): two turns on one TurnManager; the second must see
    the first turn's context.

    Turn 1 plants a memorable token in the user message and turn 2 asks for
    it back. If one `--session-id` PTY really accumulates context (as the
    JSONL design implies), the second reply contains the token. If each turn
    ran in isolation, the second reply could not know it.

    The token "Beaver" (also used in the JSONL injection probe) is an
    ordinary proper noun that is unlikely, though not impossible, to appear
    unprompted. That keeps false positives very unlikely while the prompt
    stays natural.
    """
    pty = PtyClaudeProcess(
        PtyProcessOptions(cwd=str(tmp_path), dangerously_skip_permissions=True)
    )
    tm = TurnManager(pty, JsonlWatcher(resolve_jsonl_path(pty.cwd, pty.session_id)))
    prompts = (
        "Please remember: my name is Beaver. Reply with just 'ok'.",
        "What is my name? Answer with the single word only.",
    )
    turn1_events: list[Any] = []
    turn2_events: list[Any] = []
    try:
        await tm.start()
        for prompt, sink in zip(prompts, (turn1_events, turn2_events)):
            async for event in tm.send_user_message(prompt):
                sink.append(event)
    finally:
        await tm.aclose()

    # Both turns end in a synthesized result and num_turns increments.
    first_result, second_result = turn1_events[-1], turn2_events[-1]
    assert isinstance(first_result, ResultMessage)
    assert isinstance(second_result, ResultMessage)
    assert first_result.num_turns == 1
    assert second_result.num_turns == 2
    assert first_result.session_id == second_result.session_id == pty.session_id
    assert tm.turn_count == 2

    # Turn 2's terminal assistant must mention the planted token.
    terminal_reasons = {"end_turn", "max_tokens", "stop_sequence", "refusal"}
    turn2_assistants = [e for e in turn2_events if isinstance(e, AssistantMessage)]
    terminal2 = next(
        (a for a in turn2_assistants if a.stop_reason in terminal_reasons), None
    )
    assert terminal2 is not None, (
        f"no terminal stop_reason in turn 2; got {[a.stop_reason for a in turn2_assistants]}"
    )
    reply = " ".join(b.text for b in terminal2.content if isinstance(b, TextBlock))
    assert "beaver" in reply.lower(), (
        f"turn 2 did not inherit context from turn 1; reply was: {reply!r}"
    )
# --- Stage 7: tool calls via external MCP server -------------------------
_REPO_ROOT = Path(__file__).resolve().parent.parent
_ECHO_MCP_SCRIPT = _REPO_ROOT / "scripts" / "echo_mcp_server.py"
@pytest.mark.skipif(
    os.environ.get(_SMOKE_ENV) != "1",
    reason=f"set {_SMOKE_ENV}=1 to run the real-`claude` smoke test",
)
@pytest.mark.asyncio
async def test_smoke_tool_call_via_mcp(tmp_path: Path) -> None:
    """Smoke 3 (Stage 7): real claude routes a tool call through an external
    stdio MCP server, and the resulting `tool_use` and `tool_result` records
    surface as typed events.

    Setup:
      - `scripts/echo_mcp_server.py` is a zero-dependency stdio MCP server
        exposing a single tool, `echo`, that returns its `text` argument
        verbatim.
      - claude is pointed at it through a temporary `--mcp-config` JSON file
        declaring one server named "echo". `--strict-mcp-config` stops the
        user's ambient `.mcp.json` from leaking in and changing the tool
        surface.

    Assertions:
      - Some `AssistantMessage` carries a `ToolUseBlock` whose name refers to
        the echo tool (claude exposes external MCP tools as
        `mcp__<server>__<tool>`, here `mcp__echo__echo`).
      - A following `UserMessage` carries a `ToolResultBlock` containing the
        marker token. That token can only originate from the MCP server, so
        seeing it round-trip proves the full path worked.
      - A terminal assistant closes the turn, and the synthesized
        `ResultMessage` mirrors its stop_reason.
    """
    assert _ECHO_MCP_SCRIPT.exists(), f"missing echo MCP server at {_ECHO_MCP_SCRIPT}"
    marker = "banana42xyz"  # low-collision sentinel; must appear in tool_result

    mcp_config = {
        "mcpServers": {
            "echo": {
                "command": sys.executable,
                "args": [str(_ECHO_MCP_SCRIPT)],
            },
        },
    }
    mcp_config_path = tmp_path / "mcp_config.json"
    mcp_config_path.write_text(json.dumps(mcp_config))

    pty = PtyClaudeProcess(
        PtyProcessOptions(
            cwd=str(tmp_path),
            dangerously_skip_permissions=True,
            mcp_config=(str(mcp_config_path),),
        )
    )
    watcher = JsonlWatcher(resolve_jsonl_path(pty.cwd, pty.session_id))
    # External MCP servers start during claude's own startup, so the input
    # box mounts later than for a bare session. A 60s file-wait leaves
    # headroom even for a slow first MCP handshake.
    tm = TurnManager(pty, watcher, file_wait_timeout=60.0)
    prompt = f"Call mcp__echo__echo with text={marker!r}, then reply 'done'."
    events: list[Any] = []
    try:
        await tm.start()
        async for event in tm.send_user_message(prompt):
            events.append(event)
    finally:
        await tm.aclose()

    # --- assertions ---
    assistant_events = [ev for ev in events if isinstance(ev, AssistantMessage)]
    tool_uses: list[ToolUseBlock] = [
        block
        for ev in assistant_events
        for block in ev.content
        if isinstance(block, ToolUseBlock)
    ]
    assert tool_uses, (
        "no ToolUseBlock in any assistant message; got "
        f"{[type(e).__name__ for e in events]}"
    )
    echo_uses = [use for use in tool_uses if "echo" in use.name.lower()]
    assert echo_uses, (
        f"no tool_use referenced the echo tool; saw names {[t.name for t in tool_uses]}"
    )

    # The marker text exists only on the MCP server side, so finding it in a
    # tool_result block proves the round trip actually completed.
    tool_results: list[ToolResultBlock] = [
        block
        for ev in events
        if isinstance(ev, UserMessage) and isinstance(ev.content, list)
        for block in ev.content
        if isinstance(block, ToolResultBlock)
    ]
    assert tool_results, "no ToolResultBlock in any user message after the tool call"

    def _result_text(block: ToolResultBlock) -> str:
        """Flatten a tool_result's content (str or list of text parts) to one string."""
        content = block.content
        if isinstance(content, str):
            return content
        if not isinstance(content, list):
            return ""
        return " ".join(
            part["text"]
            for part in content
            if isinstance(part, dict) and isinstance(part.get("text"), str)
        )

    assert any(marker in _result_text(b) for b in tool_results), (
        f"marker {marker!r} did not appear in any tool_result; got "
        f"{[_result_text(b) for b in tool_results]}"
    )

    terminal_reasons = {"end_turn", "max_tokens", "stop_sequence", "refusal"}
    terminal_assistant = next(
        (ev for ev in assistant_events if ev.stop_reason in terminal_reasons),
        None,
    )
    assert terminal_assistant is not None, (
        "no terminal assistant after tool round-trip; got stop_reasons "
        f"{[e.stop_reason for e in events if isinstance(e, AssistantMessage)]}"
    )
    assert isinstance(events[-1], ResultMessage)
    assert events[-1].stop_reason == terminal_assistant.stop_reason