diff --git a/src/raycast_api/__init__.py b/src/raycast_api/__init__.py index 3b9e026..6d81115 100644 --- a/src/raycast_api/__init__.py +++ b/src/raycast_api/__init__.py @@ -15,6 +15,7 @@ if TYPE_CHECKING: Tool, ToolCall, UserPreferences, + UserPreferencesArg, ) from raycast_api.client import Client, RetryPolicy, SSEEvent from raycast_api.config import Config @@ -38,6 +39,7 @@ __all__ = [ "Tool", "ToolCall", "UserPreferences", + "UserPreferencesArg", ] @@ -65,6 +67,7 @@ def __getattr__(name: str) -> Any: "Tool", "ToolCall", "UserPreferences", + "UserPreferencesArg", }: from raycast_api.ai import ( Attachment, @@ -78,6 +81,7 @@ def __getattr__(name: str) -> Any: Tool, ToolCall, UserPreferences, + UserPreferencesArg, ) return { @@ -92,5 +96,6 @@ def __getattr__(name: str) -> Any: "Tool": Tool, "ToolCall": ToolCall, "UserPreferences": UserPreferences, + "UserPreferencesArg": UserPreferencesArg, }[name] raise AttributeError(name) diff --git a/src/raycast_api/ai/__init__.py b/src/raycast_api/ai/__init__.py index efe9901..48ac381 100644 --- a/src/raycast_api/ai/__init__.py +++ b/src/raycast_api/ai/__init__.py @@ -6,7 +6,7 @@ dataclasses, and own the small amount of business logic the chat endpoint needs (preamble injection, source-specific defaults, SSE → typed chunks). """ -from raycast_api.ai.chat import ChatAPI, ChatResult, ChatStreamChunk +from raycast_api.ai.chat import ChatAPI, ChatResult, ChatStreamChunk, UserPreferencesArg from raycast_api.ai.files import FileMetadata, FilesAPI from raycast_api.ai.me import MeAPI from raycast_api.ai.models import ModelInfo, ModelsAPI, ModelsResponse @@ -37,4 +37,5 @@ __all__ = [ "Tool", "ToolCall", "UserPreferences", + "UserPreferencesArg", ] diff --git a/src/raycast_api/ai/chat.py b/src/raycast_api/ai/chat.py index 3bfaba5..df62f10 100644 --- a/src/raycast_api/ai/chat.py +++ b/src/raycast_api/ai/chat.py @@ -26,8 +26,9 @@ from __future__ import annotations import json import uuid +from collections.abc import Callable from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, TypeAlias, Union from raycast_api.ai.types import ( ChatStreamChunk, @@ -40,12 +41,20 @@ from raycast_api.ai.types import ( ) if TYPE_CHECKING: - from collections.abc import AsyncIterator, Callable + from collections.abc import AsyncIterator from raycast_api.ai.models import ModelInfo from raycast_api.client.http import Client +UserPreferencesArg: TypeAlias = Union[ # noqa: UP007 + UserPreferences, + bool, + None, + Callable[[], "UserPreferencesArg"], +] + + _SOURCE_DEFAULTS: dict[Source, dict[str, Any]] = { Source.AI_CHAT: { "system_instructions": "markdown", @@ -354,7 +363,7 @@ class ChatAPI: message_id: str | None = None, system_instructions: str | None = None, additional_system_instructions: str | None = None, - user_preferences: UserPreferences | None | bool = True, + user_preferences: UserPreferencesArg = True, temperature: float | None = None, reasoning_effort: str | None = None, tools: list[Tool | RemoteTool | str] | None = None, @@ -481,7 +490,7 @@ class ChatAPI: source: Source = Source.AI_CHAT, system_instructions: str | None = None, additional_system_instructions: str | None = None, - user_preferences: UserPreferences | None | bool = True, + user_preferences: UserPreferencesArg = True, temperature: float | None = None, reasoning_effort: str | None = None, tools: list[Tool | RemoteTool | str] | None = None, @@ -519,13 +528,15 @@ class ChatAPI: @staticmethod def _coerce_preferences( - value: UserPreferences | None | bool, # noqa: FBT001 + value: UserPreferencesArg, ) -> UserPreferences | None: + if isinstance(value, UserPreferences): + return value if value is True: return UserPreferences.auto() if value is False or value is None: return None - return value + return ChatAPI._coerce_preferences(value()) @staticmethod def _today_iso() -> str: