172 lines
5.7 KiB
Python
172 lines
5.7 KiB
Python
r"""SSE parser for the Raycast streaming endpoint.

The wire format is the standard W3C `text/event-stream`, reduced to what
the Raycast backend actually emits (per `BUNDLE_NOTES.md` §SSE):

- Line endings: `\n` or `\r\n` (both stripped). A bare `\r` terminator,
  which the W3C spec also allows, is not recognised.
- `:` prefix → comment, ignored.
- Field separator: first `:` on the line; one optional space after the colon
  is stripped from the value. A line with no colon is treated as a field
  name with an empty value.
- Recognised fields: `id`, `event`, `data`. Multiple `data` lines per event
  are joined with `\n`. Unknown fields are dropped.
- Event boundary: an empty line. EOF flushes any pending event.

The parser is byte-stream-driven (`SSEParser.feed(chunk_bytes)`) so it can
sit directly behind `aiohttp.ClientResponse.content` and not care about
chunk boundaries — a line can be split across two TCP chunks and the parser
still emits the right event.

It tracks `last_event_id` for caller-driven resume bookkeeping: every event
that carries an `id:` field updates the attribute, and that's the value the
caller should pass back in `Last-Event-ID` on the resume GET.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
from dataclasses import dataclass
|
|
from typing import TYPE_CHECKING, Any
|
|
|
|
if TYPE_CHECKING:
|
|
from collections.abc import AsyncIterable, AsyncIterator, Iterable
|
|
|
|
|
|
@dataclass(frozen=True)
class SSEEvent:
    r"""A single dispatched server-sent event.

    Attributes:
        id: value of the event's `id:` line, or None when the event had none.
        event: value of the `event:` line. None denotes the default event type
            (the usual case for streamed chunks). Raycast sends `complete` to
            terminate a stream and `error` to report a failure.
        data: the `data:` payload. Several `data:` lines are joined with `\n`.
            The string is empty when the event had no `data:` line at all
            (rare — heartbeats typically arrive as `:` comments, which never
            produce an event).
    """

    id: str | None
    event: str | None
    data: str

    def json(self) -> Any:
        """Decode `data` as JSON; bad payloads raise `json.JSONDecodeError`."""
        # Name lookup inside a method skips the class namespace, so `json`
        # here is the stdlib module, not this method.
        return json.loads(self.data)

    @property
    def is_terminal(self) -> bool:
        """Whether this event marks the end of the stream.

        Two end-of-stream forms are recognised:

        - `event: complete` — production form (the server sends
          `{"complete":true}` as data; the payload itself is not inspected).
        - a default-type event (no `event:` line) whose data is `[DONE]`,
          ignoring surrounding whitespace — legacy form, still supported.
        """
        if self.event is not None:
            return self.event == "complete"
        return self.data.strip() == "[DONE]"

    @property
    def is_error(self) -> bool:
        """Whether the server flagged this event as a failure (`event: error`)."""
        return self.event == "error"
|
|
|
|
|
|
class SSEParser:
    """Stateful byte-stream SSE parser.

    Feed it bytes as they arrive; it yields `SSEEvent`s when complete events
    are buffered. Holds line + event state across `.feed()` calls so chunk
    boundaries are transparent.

    `last_event_id` holds the `id:` of the most recent *dispatched* event that
    carried one (None until then). It is updated when the event is dispatched,
    not when its `id:` line is read, so after a mid-event disconnect it never
    names an event the caller did not receive — safe to send back as
    `Last-Event-ID` on resume.
    """

    __slots__ = ("_buf", "_data", "_event", "_id", "last_event_id")

    def __init__(self) -> None:
        # Bytes received but not yet split into lines.
        self._buf = bytearray()
        # Fields of the event currently being assembled.
        self._id: str | None = None
        self._event: str | None = None
        self._data: list[str] = []
        self.last_event_id: str | None = None

    def feed(self, chunk: bytes) -> Iterable[SSEEvent]:
        """Push bytes into the parser; return an iterator of completed events.

        `chunk` is buffered immediately, even if the returned iterator is
        never consumed. Lines are parsed lazily as the iterator advances;
        complete lines left unparsed (e.g. the caller stopped iterating after
        a terminal event) are handled by the next `feed()` or `flush()` call.
        """
        if chunk:
            self._buf.extend(chunk)
        return self._drain()

    def flush(self) -> Iterable[SSEEvent]:
        """Yield any pending events at EOF.

        The W3C spec says EOF without a trailing blank line does NOT dispatch
        the pending event, but Raycast's parser (`Rkt` in the bundle) does
        flush — matching that behavior keeps us robust to abruptly-closed
        connections that hold the last chunk.

        This is a generator: nothing happens until it is iterated.
        """
        # Complete lines can still be buffered if a feed() iterator was not
        # exhausted; parse them as individual lines rather than as one blob.
        yield from self._drain()
        if self._buf:
            raw = bytes(self._buf)
            self._buf.clear()
            # A lone trailing b"\r" strips to an empty line, which dispatches
            # the pending event right here — it must be yielded, not dropped.
            event = self._consume_line(raw.removesuffix(b"\r"))
            if event is not None:
                yield event
        if self._has_pending():
            yield self._dispatch()

    def _drain(self) -> Iterable[SSEEvent]:
        """Consume every newline-terminated line in the buffer, yielding events."""
        while True:
            nl = self._buf.find(b"\n")
            if nl == -1:
                return
            raw = bytes(self._buf[:nl])
            del self._buf[: nl + 1]
            # Accept both "\n" and "\r\n" line terminators.
            event = self._consume_line(raw.removesuffix(b"\r"))
            if event is not None:
                yield event

    def _has_pending(self) -> bool:
        """True if any field of the in-progress event has been set."""
        return (
            bool(self._data) or self._event is not None or self._id is not None
        )

    def _consume_line(self, raw: bytes) -> SSEEvent | None:
        """Apply one line (terminator already removed) to the pending event.

        Returns the dispatched event when `raw` is a blank line that closes a
        non-empty event; otherwise None.
        """
        if not raw:
            return self._dispatch() if self._has_pending() else None
        # Lines are split on b"\n" before decoding, and 0x0A never occurs
        # inside a multi-byte UTF-8 sequence, so characters are never cut.
        line = raw.decode("utf-8", errors="replace")
        if line.startswith(":"):
            return None  # comment / heartbeat
        # With no colon, partition() returns (line, "", ""): the whole line is
        # the field name and the value is empty.
        field_name, _, value = line.partition(":")
        value = value.removeprefix(" ")
        if field_name == "data":
            self._data.append(value)
        elif field_name == "event":
            self._event = value
        elif field_name == "id":
            self._id = value
        # Any other field name is ignored.
        return None

    def _dispatch(self) -> SSEEvent:
        """Build an event from the pending fields, record its id, and reset."""
        event = SSEEvent(
            id=self._id, event=self._event, data="\n".join(self._data)
        )
        # Only a delivered event may advance the resume cursor.
        if event.id is not None:
            self.last_event_id = event.id
        self._id = None
        self._event = None
        self._data = []
        return event
|
|
|
|
|
|
async def iter_sse(
    byte_stream: AsyncIterable[bytes],
    *,
    parser: SSEParser | None = None,
) -> AsyncIterator[SSEEvent]:
    """Async generator: bytes → `SSEEvent`s.

    Useful when you have an aiohttp response and just want events out:

        async for evt in iter_sse(resp.content.iter_any()):
            ...

    Args:
        byte_stream: async iterable of raw byte chunks, split arbitrarily.
        parser: optional parser to drive. Pass your own to keep access to its
            `last_event_id` after the stream ends or raises (e.g. to send
            `Last-Event-ID` on a resume request). By default a fresh
            `SSEParser` is created and not exposed.

    Pending events are flushed only when `byte_stream` is exhausted normally;
    if it raises, the exception propagates and nothing is flushed.
    """
    if parser is None:
        parser = SSEParser()
    async for chunk in byte_stream:
        for evt in parser.feed(chunk):
            yield evt
    for evt in parser.flush():
        yield evt
|