feat: create backend skeleton
This commit is contained in:
@@ -0,0 +1,11 @@
|
||||
import uvicorn
|
||||
|
||||
from utils.env import env
|
||||
|
||||
|
||||
def main() -> None:
|
||||
uvicorn.run("api.app:app", host=env.api.host, port=env.api.port)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,30 @@
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import asyncpg
|
||||
from dishka.integrations.fastapi import FromDishka, inject, setup_dishka
|
||||
from fastapi import FastAPI
|
||||
|
||||
from dependencies.container import container
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app_: FastAPI) -> AsyncGenerator[None]:
|
||||
yield
|
||||
await app_.state.dishka_container.close()
|
||||
|
||||
|
||||
app = FastAPI(title="beavergram API", lifespan=lifespan)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
@inject
|
||||
async def health(pool: FromDishka[asyncpg.Pool]) -> dict[str, bool]:
|
||||
db_ok = await pool.fetchval("SELECT 1") == 1
|
||||
timescale_ok = await pool.fetchval(
|
||||
"SELECT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'timescaledb')"
|
||||
)
|
||||
return {"db": db_ok, "timescaledb": bool(timescale_ok)}
|
||||
|
||||
|
||||
setup_dishka(container, app)
|
||||
@@ -1 +0,0 @@
|
||||
pass
|
||||
@@ -0,0 +1,3 @@
|
||||
from .integrations.asyncio import FromDishka, Scope, inject
|
||||
|
||||
__all__ = ["FromDishka", "Scope", "inject"]
|
||||
@@ -0,0 +1,5 @@
|
||||
from dishka import make_async_container
|
||||
|
||||
from dependencies.providers.postgres import DbProvider
|
||||
|
||||
container = make_async_container(DbProvider())
|
||||
@@ -0,0 +1,3 @@
|
||||
from . import asyncio
|
||||
|
||||
__all__ = ["asyncio"]
|
||||
@@ -0,0 +1,25 @@
|
||||
from collections.abc import Callable
|
||||
|
||||
from dishka import FromDishka, Scope
|
||||
from dishka.integrations.base import wrap_injection
|
||||
|
||||
from dependencies.container import container
|
||||
|
||||
|
||||
def inject[**P, T](
|
||||
_func: Callable[P, T] | None = None, *, scope: Scope | None = None
|
||||
) -> Callable[P, T] | Callable[[Callable[P, T]], Callable[P, T]]:
|
||||
def decorator(func: Callable[P, T]) -> Callable[P, T]:
|
||||
return wrap_injection(
|
||||
func=func,
|
||||
is_async=True,
|
||||
container_getter=lambda _args, _kwargs: container,
|
||||
scope=scope,
|
||||
)
|
||||
|
||||
if _func is None:
|
||||
return decorator
|
||||
return decorator(_func)
|
||||
|
||||
|
||||
__all__ = ["FromDishka", "Scope", "container", "inject"]
|
||||
@@ -0,0 +1,27 @@
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
import asyncpg
|
||||
from dishka import Provider, Scope, provide
|
||||
|
||||
from utils.env import env
|
||||
|
||||
|
||||
async def _init_connection(conn: asyncpg.Connection) -> None:
|
||||
await conn.execute("SET timezone TO 'UTC'")
|
||||
|
||||
|
||||
class DbProvider(Provider):
|
||||
@provide(scope=Scope.APP)
|
||||
async def get_pool(self) -> AsyncGenerator[asyncpg.Pool]:
|
||||
# noinspection PyUnresolvedReferences
|
||||
pool = await asyncpg.create_pool(
|
||||
dsn=env.db.connection_url,
|
||||
min_size=env.db.min_pool_size,
|
||||
max_size=env.db.max_pool_size,
|
||||
command_timeout=60,
|
||||
init=_init_connection,
|
||||
)
|
||||
try:
|
||||
yield pool
|
||||
finally:
|
||||
await pool.close()
|
||||
@@ -0,0 +1,4 @@
|
||||
from .runner import main
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,3 @@
|
||||
from .client import PyroClient
|
||||
|
||||
__all__ = ["PyroClient"]
|
||||
@@ -0,0 +1,19 @@
|
||||
from pyrogram import Client, enums
|
||||
|
||||
|
||||
class PyroClient(Client):
|
||||
def __init__(self, name: str, *, workdir: str = "sessions") -> None:
|
||||
super().__init__(
|
||||
name,
|
||||
workdir=workdir,
|
||||
api_id=2040,
|
||||
api_hash="b18441a1ff607e10a989891a5462e627",
|
||||
device_model="Desktop",
|
||||
system_version="Windows 11 x64",
|
||||
app_version="6.2.4 x64",
|
||||
lang_pack="tdesktop",
|
||||
client_platform=enums.ClientPlatform.DESKTOP,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["PyroClient"]
|
||||
@@ -0,0 +1,98 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import asyncpg
|
||||
import uvloop
|
||||
|
||||
from dependencies.container import container
|
||||
from userbot.modules import PyroClient
|
||||
from utils.env import env
|
||||
from utils.logging import logger, setup_logging
|
||||
|
||||
setup_logging()
|
||||
|
||||
_UPSERT_ACCOUNT = """
|
||||
INSERT INTO accounts
|
||||
(tg_user_id, label, phone, session_name, is_active, raw, updated_at)
|
||||
VALUES ($1, $2, $3, $4, TRUE, $5::jsonb, now())
|
||||
ON CONFLICT (tg_user_id) DO UPDATE SET
|
||||
label = EXCLUDED.label,
|
||||
phone = EXCLUDED.phone,
|
||||
session_name = EXCLUDED.session_name,
|
||||
is_active = TRUE,
|
||||
raw = EXCLUDED.raw,
|
||||
updated_at = now()
|
||||
"""
|
||||
|
||||
|
||||
def _discover_sessions(sessions_dir: Path) -> list[Path]:
|
||||
sessions_dir.mkdir(parents=True, exist_ok=True)
|
||||
return sorted(sessions_dir.glob("*.session"))
|
||||
|
||||
|
||||
async def _sync_account(
|
||||
pool: asyncpg.Pool, client: PyroClient, session_name: str
|
||||
) -> None:
|
||||
me = client.me
|
||||
if not me:
|
||||
return
|
||||
raw = json.dumps(
|
||||
{
|
||||
"id": me.id,
|
||||
"first_name": me.first_name,
|
||||
"last_name": me.last_name,
|
||||
"username": me.username,
|
||||
"phone_number": me.phone_number,
|
||||
}
|
||||
)
|
||||
label = " ".join(filter(None, [me.first_name, me.last_name])) or me.username
|
||||
await pool.execute(
|
||||
_UPSERT_ACCOUNT, me.id, label, me.phone_number, session_name, raw
|
||||
)
|
||||
logger.info(f"[green]Account synced:[/] {label} ({me.id})")
|
||||
|
||||
|
||||
async def runner() -> None:
|
||||
pool = await container.get(asyncpg.Pool)
|
||||
|
||||
sessions_dir = Path(env.tg.sessions_dir)
|
||||
session_files = _discover_sessions(sessions_dir)
|
||||
|
||||
if not session_files:
|
||||
logger.warning(
|
||||
f"[yellow]No .session files in {sessions_dir}/. "
|
||||
f"Log in first, then restart userbot.[/]"
|
||||
)
|
||||
|
||||
clients: list[PyroClient] = []
|
||||
try:
|
||||
for session_path in session_files:
|
||||
session_name = session_path.stem
|
||||
client = PyroClient(session_name, workdir=str(sessions_dir))
|
||||
await client.start()
|
||||
clients.append(client)
|
||||
logger.info(
|
||||
f"[green]Client started:[/] "
|
||||
f"{client.me.full_name if client.me else 'unknown'} "
|
||||
f"{client.me.id if client.me else 'unknown'}"
|
||||
)
|
||||
await _sync_account(pool, client, session_name)
|
||||
|
||||
if clients:
|
||||
logger.info("[green]Userbot running. Idle (no handlers until phase 3).[/]")
|
||||
await asyncio.Event().wait()
|
||||
finally:
|
||||
for client in clients:
|
||||
with contextlib.suppress(Exception):
|
||||
await client.stop()
|
||||
await container.close()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
uvloop.install()
|
||||
logger.info("Starting userbot...")
|
||||
with contextlib.suppress(KeyboardInterrupt):
|
||||
asyncio.run(runner())
|
||||
logger.info("[red]Userbot stopped.[/]")
|
||||
@@ -0,0 +1,5 @@
|
||||
from .env import env
|
||||
from .logging import logger, setup_logging
|
||||
from .storage import ContentAddressedStorage
|
||||
|
||||
__all__ = ["ContentAddressedStorage", "env", "logger", "setup_logging"]
|
||||
@@ -0,0 +1,3 @@
|
||||
from . import models
|
||||
|
||||
__all__ = ["models"]
|
||||
@@ -0,0 +1,50 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import BigInteger, Column, DateTime, func
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
|
||||
class Account(SQLModel, table=True):
|
||||
__tablename__ = "accounts"
|
||||
|
||||
account_id: int | None = Field(default=None, primary_key=True)
|
||||
tg_user_id: int | None = Field(
|
||||
default=None, sa_column=Column(BigInteger, unique=True)
|
||||
)
|
||||
label: str | None = None
|
||||
phone: str | None = None
|
||||
session_name: str
|
||||
is_active: bool = True
|
||||
raw: dict[str, Any] = Field(
|
||||
default_factory=dict, sa_column=Column(JSONB, nullable=False)
|
||||
)
|
||||
created_at: datetime = Field(
|
||||
sa_column=Column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
)
|
||||
updated_at: datetime = Field(
|
||||
sa_column=Column(
|
||||
DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class Message(SQLModel, table=True):
|
||||
__tablename__ = "messages"
|
||||
|
||||
account_id: int = Field(primary_key=True)
|
||||
chat_id: int = Field(sa_column=Column(BigInteger, primary_key=True))
|
||||
message_id: int = Field(sa_column=Column(BigInteger, primary_key=True))
|
||||
date: datetime = Field(sa_column=Column(DateTime(timezone=True), primary_key=True))
|
||||
raw: dict[str, Any] = Field(
|
||||
default_factory=dict, sa_column=Column(JSONB, nullable=False)
|
||||
)
|
||||
deleted_at: datetime | None = Field(
|
||||
default=None, sa_column=Column(DateTime(timezone=True))
|
||||
)
|
||||
@@ -0,0 +1,63 @@
|
||||
import os
|
||||
|
||||
from pydantic import Field, SecretStr
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class DatabaseSettings(BaseSettings):
|
||||
host: str = "postgres"
|
||||
port: int = 5432
|
||||
user: str = "beavergram"
|
||||
password: SecretStr = SecretStr("beavergram")
|
||||
db_name: str = "beavergram"
|
||||
min_pool_size: int = 5
|
||||
max_pool_size: int = 20
|
||||
scripts_connection_url: str = (
|
||||
"postgresql://beavergram:beavergram@localhost:5433/beavergram"
|
||||
)
|
||||
|
||||
@property
|
||||
def connection_url(self) -> str:
|
||||
if os.getenv("RUN_ENVIRONMENT") != "prod":
|
||||
return self.scripts_connection_url
|
||||
return (
|
||||
f"postgresql://{self.user}:{self.password.get_secret_value()}"
|
||||
f"@{self.host}:{self.port}/{self.db_name}"
|
||||
)
|
||||
|
||||
|
||||
class TelegramSettings(BaseSettings):
|
||||
session_name: str = "beavergram"
|
||||
sessions_dir: str = "sessions"
|
||||
|
||||
|
||||
class ApiSettings(BaseSettings):
|
||||
host: str = "0.0.0.0" # noqa: S104
|
||||
port: int = 8080
|
||||
|
||||
|
||||
class StorageSettings(BaseSettings):
|
||||
root: str = "storage"
|
||||
shard_depth: int = 2
|
||||
|
||||
|
||||
class LogSettings(BaseSettings):
|
||||
level: str = "INFO"
|
||||
level_external: str = "WARNING"
|
||||
show_time: bool = False
|
||||
console_width: int = 150
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
db: DatabaseSettings = Field(default_factory=DatabaseSettings)
|
||||
tg: TelegramSettings = Field(default_factory=TelegramSettings)
|
||||
api: ApiSettings = Field(default_factory=ApiSettings)
|
||||
storage: StorageSettings = Field(default_factory=StorageSettings)
|
||||
log: LogSettings = Field(default_factory=LogSettings)
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
case_sensitive=False, env_file=".env", env_nested_delimiter="__", extra="ignore"
|
||||
)
|
||||
|
||||
|
||||
env = Settings()
|
||||
@@ -0,0 +1,33 @@
|
||||
import logging
|
||||
|
||||
from rich.console import Console
|
||||
from rich.logging import RichHandler
|
||||
from rich.traceback import install
|
||||
|
||||
from .env import env
|
||||
|
||||
console = Console(width=env.log.console_width, color_system="auto", force_terminal=True)
|
||||
|
||||
|
||||
def setup_logging() -> None:
|
||||
logging.basicConfig(
|
||||
level=env.log.level_external,
|
||||
format="",
|
||||
datefmt=None,
|
||||
handlers=[
|
||||
RichHandler(
|
||||
console=console,
|
||||
markup=True,
|
||||
rich_tracebacks=True,
|
||||
enable_link_path=False,
|
||||
tracebacks_show_locals=True,
|
||||
omit_repeated_times=False,
|
||||
show_time=env.log.show_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
install(console=console, show_locals=True)
|
||||
|
||||
|
||||
logger = logging.getLogger("beavergram")
|
||||
logger.setLevel(env.log.level)
|
||||
@@ -0,0 +1,29 @@
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class ContentAddressedStorage:
|
||||
def __init__(self, root: str | Path, shard_depth: int = 2) -> None:
|
||||
self._root = Path(root)
|
||||
self._shard_depth = shard_depth
|
||||
|
||||
def _path(self, key: str) -> Path:
|
||||
shards = [key[i * 2 : i * 2 + 2] for i in range(self._shard_depth)]
|
||||
return self._root.joinpath(*shards, key)
|
||||
|
||||
def put(self, data: bytes) -> str:
|
||||
key = hashlib.sha256(data).hexdigest()
|
||||
path = self._path(key)
|
||||
if not path.exists():
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_bytes(data)
|
||||
return key
|
||||
|
||||
def get(self, key: str) -> bytes:
|
||||
return self._path(key).read_bytes()
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
return self._path(key).exists()
|
||||
|
||||
def url(self, key: str) -> str:
|
||||
return str(self._path(key))
|
||||
Reference in New Issue
Block a user