Files
raycast-api/tests/test_streaming.py
T

248 lines
8.0 KiB
Python

"""SSE parser tests.

Strategy: the parser is byte-driven. Most tests feed the bytes in one shot;
the chunking tests re-feed the same input in smaller pieces (down to one
byte at a time, which splits it at every byte boundary) and assert that the
resulting events are identical.
"""
from __future__ import annotations
import json
from typing import AsyncIterator
import pytest
from raycast_api.client.streaming import SSEEvent, SSEParser, iter_sse
def _drain_bytes(data: bytes) -> list[SSEEvent]:
    """Feed `data` to a fresh parser in a single call and collect every event.

    Events produced by the final ``flush()`` are appended after those
    produced by ``feed()``.
    """
    parser = SSEParser()
    events: list[SSEEvent] = [*parser.feed(data)]
    events += parser.flush()
    return events
def _drain_chunked(data: bytes, chunk_size: int) -> list[SSEEvent]:
p = SSEParser()
out: list[SSEEvent] = []
for i in range(0, len(data), chunk_size):
out.extend(p.feed(data[i : i + chunk_size]))
out.extend(p.flush())
return out
class TestBasicShape:
    """Single-event parsing: id, event and data fields plus terminal markers."""

    def test_single_event(self) -> None:
        events = _drain_bytes(b'data: {"text":"hi"}\n\n')
        assert len(events) == 1
        only = events[0]
        assert only.id is None
        assert only.event is None
        assert only.data == '{"text":"hi"}'
        assert only.json() == {"text": "hi"}

    def test_id_and_data(self) -> None:
        events = _drain_bytes(b'id: 7\ndata: {"text":"hi"}\n\n')
        assert len(events) == 1
        first = events[0]
        assert first.id == "7"
        assert first.data == '{"text":"hi"}'

    def test_event_field(self) -> None:
        first = _drain_bytes(b'event: complete\ndata: {"complete":true}\n\n')[0]
        assert first.event == "complete"
        assert first.is_terminal

    def test_default_event_is_none(self) -> None:
        first = _drain_bytes(b"data: x\n\n")[0]
        assert first.event is None

    def test_done_terminator_legacy(self) -> None:
        # The legacy `[DONE]` data sentinel counts as terminal but not as an error.
        first = _drain_bytes(b"data: [DONE]\n\n")[0]
        assert first.is_terminal
        assert not first.is_error
class TestLineEndings:
    """CRLF, LF and mixed line terminators, plus the handling of a bare CR."""

    def test_crlf_line_endings(self) -> None:
        events = _drain_bytes(b'id: 1\r\ndata: {"text":"x"}\r\n\r\n')
        assert len(events) == 1
        first = events[0]
        assert first.id == "1"
        assert first.data == '{"text":"x"}'

    def test_mixed_endings(self) -> None:
        events = _drain_bytes(b'id: 1\ndata: {"text":"x"}\r\n\n')
        assert len(events) == 1
        assert events[0].id == "1"

    def test_lone_cr_is_not_a_terminator(self) -> None:
        # NOTE(review): the WHATWG SSE grammar treats a bare CR as an end of
        # line; this test pins the parser to keeping it inside the field
        # value instead. Confirm that divergence is intended.
        events = _drain_bytes(b"data: a\rcontinued\n\n")
        assert len(events) == 1
        assert "\r" in events[0].data
class TestMultilineData:
    """Consecutive `data:` lines are joined with a newline into one payload."""

    def test_two_data_lines_joined_with_newline(self) -> None:
        payload = _drain_bytes(b"data: line1\ndata: line2\n\n")[0].data
        assert payload == "line1\nline2"

    def test_empty_data_line_still_contributes(self) -> None:
        # The empty middle `data:` line adds an empty segment, hence "\n\n".
        payload = _drain_bytes(b"data: line1\ndata:\ndata: line3\n\n")[0].data
        assert payload == "line1\n\nline3"
class TestCommentsAndUnknownFields:
    """Comment lines and unrecognised fields never surface as event data."""

    def test_colon_prefix_is_comment(self) -> None:
        parsed = _drain_bytes(b": keepalive\ndata: x\n\n")
        assert len(parsed) == 1
        assert parsed[0].data == "x"

    def test_unknown_field_dropped(self) -> None:
        parsed = _drain_bytes(b"retry: 5000\ndata: x\n\n")
        assert len(parsed) == 1
        assert parsed[0].data == "x"

    def test_only_comments_no_event(self) -> None:
        parsed = _drain_bytes(b": just\n: comments\n\n")
        assert parsed == []

    def test_line_without_colon(self) -> None:
        # A bare field name with no colon is a field whose value is empty.
        parsed = _drain_bytes(b"data\n\n")
        assert len(parsed) == 1
        assert parsed[0].data == ""
class TestFieldParsing:
    """Splitting a line into field name and value at the first colon."""

    def test_single_space_after_colon_is_stripped(self) -> None:
        events = _drain_bytes(b"data: x\n\n")
        assert events[0].data == "x"

    def test_no_space_after_colon(self) -> None:
        events = _drain_bytes(b"data:x\n\n")
        assert events[0].data == "x"

    def test_multiple_spaces_preserve_second(self) -> None:
        # Two spaces after the colon: only the first is stripped, so the
        # value keeps one leading space. (The previous fixture had a single
        # space, which cannot yield " x".)
        events = _drain_bytes(b"data:  x\n\n")
        assert events[0].data == " x"

    def test_extra_colons_in_value(self) -> None:
        # Only the first colon separates name from value; later ones stay.
        events = _drain_bytes(b'data: {"k":"v"}\n\n')
        assert events[0].data == '{"k":"v"}'
class TestEventBoundaries:
    """A blank line ends an event; flush() handles a trailing partial one."""

    def test_two_events(self) -> None:
        stream = b"id: 1\ndata: a\n\nid: 2\ndata: b\n\n"
        parsed = _drain_bytes(stream)
        ids = [evt.id for evt in parsed]
        payloads = [evt.data for evt in parsed]
        assert ids == ["1", "2"]
        assert payloads == ["a", "b"]

    def test_no_trailing_blank_line_flushes_at_eof(self) -> None:
        # NOTE(review): the WHATWG spec discards an event that is not
        # terminated by a blank line at EOF; this test pins flush() to
        # emitting it instead. Confirm that divergence is intended.
        parsed = _drain_bytes(b"data: tail")
        assert len(parsed) == 1
        assert parsed[0].data == "tail"
class TestChunkingRobustness:
    """Same bytes split at every possible boundary must yield identical events."""

    PAYLOAD = (
        b'id: 0\ndata: {"reasoning":"","text":""}\n\n'
        b'id: 1\ndata: {"text":"hello"}\n\n'
        b'event: complete\ndata: {"complete":true}\n\n'
    )

    # (id, event, data) triples pinned independently of the parser. Deriving
    # the reference from a one-shot parse alone would let a bug shared by the
    # one-shot and chunked paths go unnoticed.
    EXPECTED = (
        ("0", None, '{"reasoning":"","text":""}'),
        ("1", None, '{"text":"hello"}'),
        (None, "complete", '{"complete":true}'),
    )

    @staticmethod
    def _triples(events) -> list[tuple[str | None, str | None, str]]:
        """Project events onto comparable (id, event, data) tuples."""
        return [(e.id, e.event, e.data) for e in events]

    def _expected(self) -> list[tuple[str | None, str | None, str]]:
        return list(self.EXPECTED)

    def test_one_shot_matches_pinned_reference(self) -> None:
        assert self._triples(_drain_bytes(self.PAYLOAD)) == self._expected()

    @pytest.mark.parametrize("size", [1, 2, 3, 5, 7, 13, 64])
    def test_split_at_arbitrary_size(self, size: int) -> None:
        observed = self._triples(_drain_chunked(self.PAYLOAD, size))
        assert observed == self._expected()

    def test_two_way_split_at_every_offset(self) -> None:
        # Interior cuts only, so neither feed() call receives empty bytes.
        for cut in range(1, len(self.PAYLOAD)):
            p = SSEParser()
            out = list(p.feed(self.PAYLOAD[:cut]))
            out.extend(p.feed(self.PAYLOAD[cut:]))
            out.extend(p.flush())
            assert self._triples(out) == self._expected(), f"split at byte {cut}"

    def test_split_inside_field_value(self) -> None:
        a = b'id: 1\ndata: {"text":"hel'
        b = b'lo"}\n\n'
        p = SSEParser()
        events = list(p.feed(a))
        # No blank line yet, so nothing may be dispatched.
        assert events == []
        events.extend(p.feed(b))
        events.extend(p.flush())
        assert len(events) == 1
        assert events[0].json() == {"text": "hello"}

    def test_split_between_cr_and_lf(self) -> None:
        # A CRLF pair split across feeds must still count as one terminator.
        a = b"id: 1\r"
        b = b"\ndata: x\r\n\r\n"
        p = SSEParser()
        out = list(p.feed(a))
        out.extend(p.feed(b))
        out.extend(p.flush())
        assert len(out) == 1
        assert out[0].id == "1"
        assert out[0].data == "x"
class TestLastEventIdTracking:
    """`last_event_id` follows the most recent `id:` field seen by the parser."""

    @staticmethod
    def _consume(parser: SSEParser, chunk: bytes) -> None:
        # feed() output is drained fully so the parser processes every line.
        list(parser.feed(chunk))

    def test_advances_with_id_field(self) -> None:
        parser = SSEParser()
        self._consume(parser, b"id: 5\ndata: a\n\n")
        assert parser.last_event_id == "5"
        self._consume(parser, b"id: 9\ndata: b\n\n")
        assert parser.last_event_id == "9"

    def test_event_without_id_keeps_previous(self) -> None:
        parser = SSEParser()
        self._consume(parser, b"id: 5\ndata: a\n\ndata: b\n\n")
        assert parser.last_event_id == "5"

    def test_starts_none(self) -> None:
        assert SSEParser().last_event_id is None
class TestIterSseAsync:
    """`iter_sse` over an async byte source with arbitrary chunk boundaries."""

    @pytest.mark.asyncio
    async def test_async_iteration_over_chunks(self) -> None:
        chunks = (
            b"id: 1\nda",
            b'ta: {"text":"a"}\n\n',
            b'id: 2\ndata: {"text":"b"}\n\nevent: complete\ndata: {"complete":true}\n\n',
        )

        async def source() -> AsyncIterator[bytes]:
            for chunk in chunks:
                yield chunk

        seen = [evt async for evt in iter_sse(source())]
        assert [evt.id for evt in seen] == ["1", "2", None]
        assert [evt.event for evt in seen] == [None, None, "complete"]
        assert seen[-1].is_terminal

    @pytest.mark.asyncio
    async def test_empty_stream(self) -> None:
        async def source() -> AsyncIterator[bytes]:
            return
            yield b""  # unreachable; makes this function an async generator

        assert [evt async for evt in iter_sse(source())] == []
class TestErrorEvent:
    """`is_error` is set by an `event: error` line and by nothing else."""

    def test_error_event_recognised(self) -> None:
        evt = _drain_bytes(b'event: error\ndata: {"message":"boom"}\n\n')[0]
        assert evt.is_error
        assert evt.json() == {"message": "boom"}

    def test_default_event_not_error(self) -> None:
        evt = _drain_bytes(b'data: {"text":"x"}\n\n')[0]
        assert evt.is_error is False
class TestJsonParsing:
    """Event data is stored raw; JSON decoding happens only on `.json()`."""

    def test_invalid_json_raises_on_demand(self) -> None:
        evt = _drain_bytes(b"data: not json\n\n")[0]
        with pytest.raises(json.JSONDecodeError):
            evt.json()

    def test_event_data_kept_raw(self) -> None:
        evt = _drain_bytes(b"data: not json\n\n")[0]
        assert evt.data == "not json"