Admin and start handlers work. Added testing txt2img from first bot version (will be replaced soon)

This commit is contained in:
BarsTiger
2023-02-14 13:49:58 +02:00
parent 7ba5482e6a
commit dcdb4ef60d
20 changed files with 234 additions and 19 deletions

View File

@@ -4,7 +4,7 @@ from dotenv import load_dotenv
load_dotenv()
TOKEN = os.getenv('TOKEN')
ADMIN = os.getenv('ADMIN')
ADMIN = int(os.getenv('ADMIN'))
DB_CHAT = os.getenv('DB_CHAT')
_DB_PATH = os.getenv('DB_PATH')
DB = _DB_PATH + '/db'

View File

@@ -40,7 +40,6 @@ async def pull():
from .db import db
for table in DBTables.tables:
db[table].clear()
new_table = SqliteDict(DB + 'b', tablename=table)
for key in new_table.keys():
db[table][key] = new_table[key]

View File

@@ -0,0 +1,6 @@
from bot.common import dp
from .aliases import *
def register():
dp.register_message_handler(set_endpoint, commands='setendpoint')

View File

@@ -0,0 +1,23 @@
from aiogram import types
from bot.db import db, DBTables
import validators
from bot.config import ADMIN
from bot.utils.cooldown import throttle
@throttle(5)
async def set_endpoint(message: types.Message):
if message.from_id not in db[DBTables.config].get('admins') and message.from_id != ADMIN:
await message.reply('❌ You are not permitted to do that. '
'It is only for this bot instance maintainers and admins')
return
if not message.get_args() or not validators.url(message.get_args()):
await message.reply("❌ Specify correct url for endpoint")
return
db[DBTables.config]['endpoint'] = message.get_args()
await db[DBTables.config].write()
await message.reply("✅ New url set")

View File

@@ -1,2 +0,0 @@
help_data = {
}

View File

@@ -0,0 +1,6 @@
from bot.common import dp
from . import help_handler
def register():
dp.register_message_handler(help_handler.help_command, commands='help')

View File

@@ -1,9 +1,7 @@
from aiogram import types
from bot.common import dp
from .help_strings import help_data
@dp.message_handler(commands='help')
async def help_command(message: types.Message):
if message.get_args() == "":
await message.reply(

View File

@@ -0,0 +1,3 @@
help_data = {
'setendpoint': '(admin) Set StableDiffusion API endpoint'
}

View File

@@ -1 +1,8 @@
pass
from bot.common import dp, bot
from .start import *
from .all_messages import *
def register():
dp.register_message_handler(all_messages.sync_db_filter, lambda *_: not hasattr(bot, 'cloudmeta_message_text'))
dp.register_message_handler(start.start_command, commands='start')

View File

@@ -0,0 +1,7 @@
from aiogram.types import Message
from bot.db.pull_db import pull
async def sync_db_filter(message: Message):
await pull()
await message.reply(f'🔄️ Bot database synchronised. If you tried to run a command, run it again')

View File

@@ -1,7 +0,0 @@
from bot.common import dp
from bot.db.pull_db import pull
@dp.message_handler()
async def pull_db_if_new(_):
await pull()

View File

@@ -0,0 +1,23 @@
from aiogram import types
from bot.db import db, DBTables
from bot.config import ADMIN
from bot.utils.cooldown import throttle
@throttle(10)
async def start_command(message: types.Message):
if message.from_id == ADMIN:
await message.reply(f'👋 Hello, {message.from_user.username}. You are admin of this instance, '
f'so we will check config for you now')
if not isinstance(db[DBTables.config].get('admins'), list):
db[DBTables.config]['admins'] = list()
if ADMIN not in db[DBTables.config].get('admins'):
admins_ = db[DBTables.config].get('admins')
admins_.append(ADMIN)
db[DBTables.config]['admins'] = admins_
await db[DBTables.config].write()
await message.reply(f'✅ Added {message.from_user.username} to admins. You can add other admins, '
f'check bot settings menu')
return
await message.reply(f'👋 Hello, {message.from_user.username}. Use /help to see available commands.')

View File

@@ -2,10 +2,13 @@ from rich import print
def import_handlers():
import bot.handlers.help.help
assert bot.handlers.help.help
from bot.handlers import (
initialize, admin, help_command, txt2img
)
import bot.handlers.initialize.pull_db
assert bot.handlers.initialize.pull_db
initialize.register()
admin.register()
help_command.register()
txt2img.register()
print('[gray]All handlers imported[/]')

View File

@@ -0,0 +1,6 @@
from bot.common import dp
from .txt2img import *
def register():
dp.register_message_handler(txt2img, commands='txt2img')

View File

@@ -0,0 +1,31 @@
from aiogram import types
from bot.db import db, DBTables
from bot.utils.cooldown import throttle
import os.path
from bot.modules.api.txt2img import txt2img
@throttle(cooldown=30, admin_ids=db[DBTables.config].get('admins'))
async def txt2img_comand(message: types.Message):
temp_message = await message.reply("⏳ Generating image...")
if not message.get_args():
await temp_message.edit_text("Specify prompt for this command. Check /help txt2img")
return
if not os.path.isfile('adstring'):
with open('adstring', 'w') as f:
f.write('@aiistop_bot')
try:
image = await txt2img(message.get_args())
await message.reply_photo(photo=image[0], caption=str(
image[1]["infotexts"][0]) + "\n\n" + open('adstring').read())
except Exception as e:
assert e
await message.reply("We ran into error while processing your request. StableDiffusion models may not be "
"configured on specified endpoint or server with StableDiffusion may be turned "
"off. Ask admins of this bot instance if you have contacts for further info")
await temp_message.delete()
return
await temp_message.delete()

31
bot/modules/api/models.py Normal file
View File

@@ -0,0 +1,31 @@
import aiohttp
from bot.db import db, DBTables
from rich import print
async def get_models():
endpoint = db[DBTables.config].get('endpoint')
try:
async with aiohttp.ClientSession() as session:
r = await session.get(endpoint + "/sdapi/v1/sd-models")
if r.status != 200:
return None
return [x["title"] for x in await r.json()]
except Exception as e:
print(e)
return None
async def set_model(model_name: str):
endpoint = db[DBTables.config].get('endpoint')
try:
async with aiohttp.ClientSession() as session:
r = await session.post(endpoint + "/sdapi/v1/options", json={
"sd_model_checkpoint": model_name
})
if r.status != 200:
return False
return True
except Exception as e:
print(e)
return False

View File

@@ -0,0 +1,32 @@
import aiohttp
from bot.db import db, DBTables
import json
import base64
async def txt2img(prompt: str, negative_prompt: str = None, steps: int = 20,
cfg_scale: int = 7, width: int = 768, height: int = 768,
restore_faces: bool = True, sampler: str = "Euler a") -> list[bytes, dict] | None:
endpoint = db[DBTables.config].get('endpoint')
try:
async with aiohttp.ClientSession() as session:
r = await session.post(
endpoint + "/sdapi/v1/txt2img",
json={
"prompt": prompt,
"steps": steps,
"cfg_scale": cfg_scale,
"width": width,
"height": height,
"restore_faces": restore_faces,
"negative_prompt": negative_prompt,
"sampler_index": sampler
}
)
if r.status != 200:
return None
return [base64.b64decode((await r.json())["images"][0]),
json.loads((await r.json())["info"])]
except Exception as e:
assert e
return None

View File

@@ -4,7 +4,7 @@ from bot.common import bot
async def set_commands():
from bot.handlers.help.help_strings import help_data
from bot.handlers.help_command.help_strings import help_data
await bot.set_my_commands(
commands=list(

49
bot/utils/cooldown.py Normal file
View File

@@ -0,0 +1,49 @@
from functools import wraps
import datetime
from bot.common import bot
from bot.db import db, DBTables
import asyncio
from aiogram import types
def not_allowed(message: types.Message, cd: int, by_id: bool):
return asyncio.create_task(message.reply(
text=
f"❌ Wait for cooldown ({cd}s for this command)"
f"{'. Please note that this cooldown is for all users' if not by_id else ''}"
))
def throttle(cooldown: int = 5, by_id: bool = True, admin_ids: list = None):
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
user_id = int(args[0]["from"]["id"])
if admin_ids and user_id in admin_ids:
return asyncio.create_task(func(*args, **kwargs))
user_id = str(user_id) if by_id else "0"
now = datetime.datetime.now()
delta = now - datetime.timedelta(seconds=cooldown)
try:
last_time = db[DBTables.cooldown].get(func.__name__).get(user_id)
except AttributeError:
last_time = None
if not last_time:
last_time = delta
if last_time <= delta:
try:
db[DBTables.cooldown][func.__name__][user_id] = now
except KeyError:
db[DBTables.cooldown][func.__name__] = dict()
db[DBTables.cooldown][func.__name__][user_id] = now
try:
return asyncio.create_task(func(*args, **kwargs))
except Exception as e:
assert e
else:
return not_allowed(*args, cooldown, by_id)
return wrapper
return decorator