248 lines
8.0 KiB
Python
248 lines
8.0 KiB
Python
"""SSE parser tests.
|
|
|
|

Strategy: the parser is byte-driven. Most tests feed the input in one shot;
the chunking tests feed the same input in fixed-size chunks (1-byte chunks
split it at every possible byte boundary) and assert the output is identical.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
from typing import AsyncIterator
|
|
|
|
import pytest
|
|
|
|
from raycast_api.client.streaming import SSEEvent, SSEParser, iter_sse
|
|
|
|
|
|
def _drain_bytes(data: bytes) -> list[SSEEvent]:
    """Parse ``data`` with a fresh ``SSEParser`` in a single ``feed`` call.

    Returns the events produced by that feed, followed by whatever ``flush``
    emits afterwards (e.g. a final event that lacked its blank-line
    terminator).
    """
    parser = SSEParser()
    fed = list(parser.feed(data))
    flushed = list(parser.flush())
    return fed + flushed
|
|
|
|
|
|
def _drain_chunked(data: bytes, chunk_size: int) -> list[SSEEvent]:
    """Parse ``data`` with a fresh ``SSEParser``, ``chunk_size`` bytes per feed.

    The last chunk may be shorter than ``chunk_size``. The parser is flushed
    after the final chunk, so the result is comparable with
    ``_drain_bytes(data)``.

    Raises:
        ValueError: if ``chunk_size`` is not positive. Without this check a
            negative size makes ``range`` empty, so no bytes would ever be
            fed and a test could pass or fail for the wrong reason.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
    parser = SSEParser()
    events: list[SSEEvent] = []
    for start in range(0, len(data), chunk_size):
        events.extend(parser.feed(data[start : start + chunk_size]))
    events.extend(parser.flush())
    return events
|
|
|
|
|
|
class TestBasicShape:
    """Single well-formed events: field extraction and the terminal flag."""

    def test_single_event(self) -> None:
        parsed = _drain_bytes(b'data: {"text":"hi"}\n\n')
        assert len(parsed) == 1
        first = parsed[0]
        assert first.id is None
        assert first.event is None
        assert first.data == '{"text":"hi"}'
        assert first.json() == {"text": "hi"}

    def test_id_and_data(self) -> None:
        parsed = _drain_bytes(b'id: 7\ndata: {"text":"hi"}\n\n')
        assert len(parsed) == 1
        first = parsed[0]
        # The id is kept as the raw string, not converted to an int.
        assert first.id == "7"
        assert first.data == '{"text":"hi"}'

    def test_event_field(self) -> None:
        first = _drain_bytes(b'event: complete\ndata: {"complete":true}\n\n')[0]
        assert first.event == "complete"
        # A named "complete" event is reported as terminal.
        assert first.is_terminal

    def test_default_event_is_none(self) -> None:
        first = _drain_bytes(b"data: x\n\n")[0]
        # No ``event:`` line leaves the name unset; no default is substituted.
        assert first.event is None

    def test_done_terminator_legacy(self) -> None:
        first = _drain_bytes(b"data: [DONE]\n\n")[0]
        # The legacy ``[DONE]`` data sentinel also counts as terminal,
        # without being flagged as an error.
        assert first.is_terminal
        assert not first.is_error
|
|
|
|
|
|
class TestLineEndings:
    """CRLF and LF line terminators, alone and mixed within one event."""

    def test_crlf_line_endings(self) -> None:
        parsed = _drain_bytes(b'id: 1\r\ndata: {"text":"x"}\r\n\r\n')
        assert len(parsed) == 1
        first = parsed[0]
        assert first.id == "1"
        assert first.data == '{"text":"x"}'

    def test_mixed_endings(self) -> None:
        parsed = _drain_bytes(b'id: 1\ndata: {"text":"x"}\r\n\n')
        assert len(parsed) == 1
        assert parsed[0].id == "1"

    def test_lone_cr_is_not_a_terminator(self) -> None:
        # NOTE: the WHATWG SSE spec treats a bare CR as a line terminator.
        # This parser keeps it inside the field value, and this test pins
        # that behavior.
        parsed = _drain_bytes(b"data: a\rcontinued\n\n")
        assert len(parsed) == 1
        assert "\r" in parsed[0].data
|
|
|
|
|
|
class TestMultilineData:
    """Consecutive ``data:`` lines are joined with a single LF."""

    def test_two_data_lines_joined_with_newline(self) -> None:
        joined = _drain_bytes(b"data: line1\ndata: line2\n\n")[0].data
        assert joined == "line1\nline2"

    def test_empty_data_line_still_contributes(self) -> None:
        # The empty middle ``data:`` line still adds its own LF,
        # leaving a blank line between the two payload lines.
        joined = _drain_bytes(b"data: line1\ndata:\ndata: line3\n\n")[0].data
        assert joined == "line1\n\nline3"
|
|
|
|
|
|
class TestCommentsAndUnknownFields:
    """Comment lines and ignored fields must not leak into dispatched events."""

    def test_colon_prefix_is_comment(self) -> None:
        parsed = _drain_bytes(b": keepalive\ndata: x\n\n")
        assert len(parsed) == 1
        assert parsed[0].data == "x"

    def test_unknown_field_dropped(self) -> None:
        # NOTE(review): ``retry`` is a spec-defined SSE field. This test only
        # checks that it does not disturb the event; presumably the parser
        # ignores it.
        parsed = _drain_bytes(b"retry: 5000\ndata: x\n\n")
        assert len(parsed) == 1
        assert parsed[0].data == "x"

    def test_only_comments_no_event(self) -> None:
        # A block consisting solely of comments dispatches nothing.
        assert _drain_bytes(b": just\n: comments\n\n") == []

    def test_line_without_colon(self) -> None:
        # A bare field name with no colon is a field with an empty value.
        parsed = _drain_bytes(b"data\n\n")
        assert len(parsed) == 1
        assert parsed[0].data == ""
|
|
|
|
|
|
class TestFieldParsing:
    """Splitting ``name:value`` lines: first colon, then at most one space."""

    def test_single_space_after_colon_is_stripped(self) -> None:
        events = _drain_bytes(b"data: x\n\n")
        assert events[0].data == "x"

    def test_no_space_after_colon(self) -> None:
        events = _drain_bytes(b"data:x\n\n")
        assert events[0].data == "x"

    def test_multiple_spaces_preserve_second(self) -> None:
        # Two spaces after the colon: only the first is stripped, so the
        # second survives. The input must really contain two spaces; with a
        # single space it is identical to the input of
        # test_single_space_after_colon_is_stripped and the two
        # expectations would contradict each other.
        events = _drain_bytes(b"data:  x\n\n")
        assert events[0].data == " x"

    def test_extra_colons_in_value(self) -> None:
        # Only the first colon separates name from value; later colons
        # (here, inside the JSON object) belong to the value.
        events = _drain_bytes(b'data: {"k":"v"}\n\n')
        assert events[0].data == '{"k":"v"}'
|
|
|
|
|
|
class TestEventBoundaries:
    """Blank lines separate events; end of input still yields a pending one."""

    def test_two_events(self) -> None:
        stream = b"id: 1\ndata: a\n\nid: 2\ndata: b\n\n"
        parsed = _drain_bytes(stream)
        assert [evt.id for evt in parsed] == ["1", "2"]
        assert [evt.data for evt in parsed] == ["a", "b"]

    def test_no_trailing_blank_line_flushes_at_eof(self) -> None:
        # The input ends mid-event with no terminating blank line; the event
        # must still be produced once the parser is flushed.
        parsed = _drain_bytes(b"data: tail")
        assert len(parsed) == 1
        assert parsed[0].data == "tail"
|
|
|
|
|
|
class TestChunkingRobustness:
    """Same bytes split at every possible boundary must yield identical events."""

    PAYLOAD = (
        b'id: 0\ndata: {"reasoning":"","text":""}\n\n'
        b'id: 1\ndata: {"text":"hello"}\n\n'
        b'event: complete\ndata: {"complete":true}\n\n'
    )

    def _expected(self) -> list[tuple[str | None, str | None, str]]:
        """Return (id, event, data) triples from parsing PAYLOAD in one feed."""
        reference = [(e.id, e.event, e.data) for e in _drain_bytes(self.PAYLOAD)]
        # Pin the reference itself. Otherwise a parser that mangled or
        # dropped events the same way in one-shot and chunked mode would
        # still pass every comparison below.
        assert reference == [
            ("0", None, '{"reasoning":"","text":""}'),
            ("1", None, '{"text":"hello"}'),
            (None, "complete", '{"complete":true}'),
        ]
        return reference

    @pytest.mark.parametrize("size", [1, 2, 3, 5, 7, 13, 64])
    def test_split_at_arbitrary_size(self, size: int) -> None:
        chunks = _drain_chunked(self.PAYLOAD, size)
        observed = [(e.id, e.event, e.data) for e in chunks]
        assert observed == self._expected()

    def test_split_at_every_single_boundary(self) -> None:
        # Fixed-size chunking cuts the payload at many points at once. Here
        # each interior offset is tried as the only cut (two non-empty
        # chunks), so a failure reports the exact offending byte position.
        expected = self._expected()
        for cut in range(1, len(self.PAYLOAD)):
            parser = SSEParser()
            events = list(parser.feed(self.PAYLOAD[:cut]))
            events.extend(parser.feed(self.PAYLOAD[cut:]))
            events.extend(parser.flush())
            observed = [(e.id, e.event, e.data) for e in events]
            assert observed == expected, f"mismatch when split at byte {cut}"

    def test_split_inside_field_value(self) -> None:
        a = b'id: 1\ndata: {"text":"hel'
        b = b'lo"}\n\n'
        p = SSEParser()
        # No blank line has been seen yet, so nothing may be dispatched.
        events = list(p.feed(a))
        assert events == []
        events.extend(p.feed(b))
        events.extend(p.flush())
        assert len(events) == 1
        assert events[0].json() == {"text": "hello"}

    def test_split_between_cr_and_lf(self) -> None:
        # A CRLF pair straddling two feeds must still count as one line
        # ending, not as an extra terminator.
        a = b"id: 1\r"
        b = b"\ndata: x\r\n\r\n"
        p = SSEParser()
        out = list(p.feed(a))
        out.extend(p.feed(b))
        out.extend(p.flush())
        assert len(out) == 1
        assert out[0].id == "1"
        assert out[0].data == "x"
|
|
|
|
|
|
class TestLastEventIdTracking:
    """``SSEParser.last_event_id`` follows the most recent ``id:`` field."""

    def test_advances_with_id_field(self) -> None:
        parser = SSEParser()
        steps = (
            (b"id: 5\ndata: a\n\n", "5"),
            (b"id: 9\ndata: b\n\n", "9"),
        )
        for chunk, want in steps:
            # Consume feed()'s result so the chunk is fully processed even if
            # feed is lazy.
            list(parser.feed(chunk))
            assert parser.last_event_id == want

    def test_event_without_id_keeps_previous(self) -> None:
        parser = SSEParser()
        # The second event carries no id: line, so the remembered id stays.
        list(parser.feed(b"id: 5\ndata: a\n\ndata: b\n\n"))
        assert parser.last_event_id == "5"

    def test_starts_none(self) -> None:
        assert SSEParser().last_event_id is None
|
|
|
|
|
|
class TestIterSseAsync:
    """``iter_sse`` over an async byte source with arbitrary chunk boundaries."""

    @pytest.mark.asyncio
    async def test_async_iteration_over_chunks(self) -> None:
        # The first two pieces split the ``data`` field name itself.
        pieces = (
            b"id: 1\nda",
            b'ta: {"text":"a"}\n\n',
            b'id: 2\ndata: {"text":"b"}\n\nevent: complete\ndata: {"complete":true}\n\n',
        )

        async def source() -> AsyncIterator[bytes]:
            for piece in pieces:
                yield piece

        seen = [evt async for evt in iter_sse(source())]
        assert [evt.id for evt in seen] == ["1", "2", None]
        assert [evt.event for evt in seen] == [None, None, "complete"]
        assert seen[-1].is_terminal

    @pytest.mark.asyncio
    async def test_empty_stream(self) -> None:
        async def source() -> AsyncIterator[bytes]:
            return
            yield b""  # unreachable; only makes this an async generator

        assert [evt async for evt in iter_sse(source())] == []
|
|
|
|
|
|
class TestErrorEvent:
    """``is_error`` reflects an ``event: error`` field."""

    def test_error_event_recognised(self) -> None:
        failure = _drain_bytes(b'event: error\ndata: {"message":"boom"}\n\n')[0]
        assert failure.is_error
        assert failure.json() == {"message": "boom"}

    def test_default_event_not_error(self) -> None:
        plain = _drain_bytes(b'data: {"text":"x"}\n\n')[0]
        # Identity check: the flag must be the bool False, not merely falsy.
        assert plain.is_error is False
|
|
|
|
|
|
class TestJsonParsing:
    """``SSEEvent.json()`` decodes on demand; ``data`` keeps the raw text."""

    def test_invalid_json_raises_on_demand(self) -> None:
        # Parsing the stream must succeed; only the explicit json() call fails.
        bad = _drain_bytes(b"data: not json\n\n")[0]
        with pytest.raises(json.JSONDecodeError):
            bad.json()

    def test_event_data_kept_raw(self) -> None:
        assert _drain_bytes(b"data: not json\n\n")[0].data == "not json"
|