Files
raycast-api/src/raycast_api/client/streaming.py
T

172 lines
5.7 KiB
Python

r"""SSE parser for the Raycast streaming endpoint.

The wire format is the standard W3C `text/event-stream`, reduced to what
the Raycast backend actually emits (per `BUNDLE_NOTES.md` §SSE):

- Line endings: `\n` or `\r\n` (both stripped). A lone `\r` is not treated
  as a line ending.
- `:` prefix → comment, ignored.
- Field separator: first `:` on the line; one optional space after the colon
  is stripped from the value.
- Recognised fields: `id`, `event`, `data`. Multiple `data` lines per event
  are joined with `\n`. Unknown fields are dropped.
- Event boundary: an empty line. At EOF, `SSEParser.flush()` dispatches any
  pending event.

The parser is byte-stream-driven (`SSEParser.feed(chunk_bytes)`) so it can
sit directly behind an aiohttp client response body
(`aiohttp.ClientResponse.content`) without caring about chunk boundaries.
A line can be split across two TCP chunks and the parser still emits the
right event.

It tracks `last_event_id` for caller-driven resume bookkeeping: every event
that carries an `id:` field updates the attribute, and that's the value the
caller should pass back in `Last-Event-ID` on the resume GET.
"""
from __future__ import annotations

import json
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from collections.abc import AsyncIterable, AsyncIterator, Iterable, Iterator
@dataclass(frozen=True)
class SSEEvent:
    r"""A single dispatched server-sent event (immutable).

    Attributes:
        id: Value of the event's `id:` field. None when no `id:` line was
            seen while this event was being accumulated.
        event: Value of the `event:` field. None stands for the default,
            unnamed event, which is what most chunks use. Raycast names its
            terminator `complete` and its failures `error`.
        data: All `data:` lines of the event, joined with `\n`. Empty string
            when the event carried no `data:` line. This is rare: heartbeats
            usually arrive as `:` comments, which never produce an event.
    """

    id: str | None
    event: str | None
    data: str

    def json(self) -> Any:
        """Decode the `data` payload as JSON.

        Raises:
            json.JSONDecodeError: if `data` is not valid JSON.
        """
        payload = self.data
        return json.loads(payload)

    @property
    def is_terminal(self) -> bool:
        """Whether this event marks the end of the stream.

        Two end-of-stream forms are recognised:

        - An event named `complete`. This is the production form; Raycast
          pairs it with `data: {"complete":true}`, but only the name is
          checked here.
        - An unnamed event whose stripped `data` is exactly `[DONE]`. This
          is the legacy form, still supported by the client.
        """
        if self.event is None:
            return self.data.strip() == "[DONE]"
        return self.event == "complete"

    @property
    def is_error(self) -> bool:
        """Whether the server named this an `error` event."""
        return self.event == "error"
class SSEParser:
    """Stateful byte-stream SSE parser.

    Feed it bytes as they arrive; it yields `SSEEvent`s once complete events
    are buffered. Line and event state persist across `.feed()` calls, so
    chunk boundaries are transparent.

    Attributes:
        last_event_id: The most recent `id:` value seen on the stream, or
            None until one arrives. Updated as soon as the `id:` line is
            parsed, not when its event is dispatched.
    """

    __slots__ = ("_buf", "_data", "_event", "_id", "last_event_id")

    def __init__(self) -> None:
        self._buf = bytearray()  # raw bytes not yet split into lines
        self._id: str | None = None
        self._event: str | None = None
        self._data: list[str] = []
        self.last_event_id: str | None = None

    def feed(self, chunk: bytes) -> Iterable[SSEEvent]:
        """Push bytes into the parser and return an iterator of completed events.

        `chunk` is buffered immediately, even if the returned iterator is
        never consumed. Complete lines are parsed lazily as the iterator
        advances. Lines left unparsed, because the iterator was not consumed
        or was abandoned early, stay buffered. The next `feed()` or
        `flush()` handles them, in order.
        """
        # Buffer eagerly. If this were itself a generator function, its body
        # (and therefore this extend) would not run unless the caller
        # iterated the result, and `chunk` would be silently lost.
        self._buf.extend(chunk)
        return self._drain()

    def flush(self) -> Iterable[SSEEvent]:
        """Yield any events still pending at EOF.

        The W3C spec says EOF without a trailing blank line does NOT dispatch
        the pending event, but Raycast's parser (`Rkt` in the bundle) does
        flush. Matching that behavior keeps us robust to abruptly-closed
        connections that hold the last chunk.

        Work is done in this order:

        1. Complete lines still in the buffer (from a `feed()` result that
           was not fully consumed).
        2. The unterminated trailing line.
        3. Whatever event those steps left pending.
        """
        yield from self._drain()
        if self._buf:
            raw = bytes(self._buf).removesuffix(b"\r")
            self._buf.clear()
            # The trailing line may be empty, e.g. a lone CR left over from
            # a CRLF split at EOF. An empty line is an event boundary and
            # dispatches the pending event here, so that event must be
            # yielded, not dropped.
            event = self._consume_line(raw)
            if event is not None:
                yield event
        if self._has_pending():
            yield self._dispatch()

    def _drain(self) -> Iterator[SSEEvent]:
        """Parse each newline-terminated line in the buffer; yield dispatched events."""
        while (nl := self._buf.find(b"\n")) != -1:
            raw = bytes(self._buf[:nl])
            # Drop the line from the buffer before yielding, so an abandoned
            # iterator never causes it to be parsed twice.
            del self._buf[: nl + 1]
            event = self._consume_line(raw.removesuffix(b"\r"))
            if event is not None:
                yield event

    def _has_pending(self) -> bool:
        """Return True if any field of a not-yet-dispatched event is buffered."""
        return bool(self._data) or self._event is not None or self._id is not None

    def _consume_line(self, raw: bytes) -> SSEEvent | None:
        """Apply one line (terminator already stripped) to the pending event.

        Returns the dispatched event when `raw` is empty (an event boundary)
        and something is pending. Returns None otherwise.
        """
        if not raw:
            return self._dispatch() if self._has_pending() else None
        # Decoding only whole lines means a multi-byte UTF-8 sequence split
        # across chunks has already been reassembled in `_buf`.
        line = raw.decode("utf-8", errors="replace")
        if line.startswith(":"):
            return None  # comment line (e.g. a heartbeat); ignored
        # With no colon, partition() yields an empty value, so the whole
        # line becomes the field name.
        field_name, _, value = line.partition(":")
        value = value.removeprefix(" ")  # at most one space is syntax
        if field_name == "data":
            self._data.append(value)
        elif field_name == "event":
            self._event = value
        elif field_name == "id":
            self._id = value
            self.last_event_id = value
        # Any other field name is dropped.
        return None

    def _dispatch(self) -> SSEEvent:
        """Build an `SSEEvent` from the pending fields and reset them."""
        event = SSEEvent(id=self._id, event=self._event, data="\n".join(self._data))
        self._id = None
        self._event = None
        self._data = []
        return event
async def iter_sse(byte_stream: AsyncIterable[bytes]) -> AsyncIterator[SSEEvent]:
    """Adapt an async stream of raw bytes into a stream of `SSEEvent`s.

    A fresh `SSEParser` is fed each chunk as it arrives. Once the byte
    stream is exhausted, the parser is flushed, so a final event that lacks
    a trailing blank line is still delivered. Typical use with an aiohttp
    response:

        async for evt in iter_sse(resp.content.iter_any()):
            ...
    """
    sse = SSEParser()
    async for raw_chunk in byte_stream:
        completed = sse.feed(raw_chunk)
        for item in completed:
            yield item
    # End of input: emit whatever the parser is still holding.
    leftovers = sse.flush()
    for item in leftovers:
        yield item