Refactoring, added docstrings
This commit is contained in:
@@ -15,12 +15,22 @@ class Connection(object):
|
|||||||
public_key: str
|
public_key: str
|
||||||
|
|
||||||
async def send_webmessage(self, obj: webmessages_union):
|
async def send_webmessage(self, obj: webmessages_union):
|
||||||
|
"""
|
||||||
|
Sends WebMessage object to this connection
|
||||||
|
:param obj: Should be some type of WebMessage
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
await self.ws.send_text(obj.to_json())
|
await self.ws.send_text(obj.to_json())
|
||||||
|
|
||||||
async def send_error(
|
async def send_error(
|
||||||
self,
|
self,
|
||||||
error_message: webmessage_error_message_literal
|
error_message: webmessage_error_message_literal
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Sends error with specified messages
|
||||||
|
:param error_message: See webmessage_error_message_literal for available
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
await self.send_webmessage(
|
await self.send_webmessage(
|
||||||
WebErrorMessage(
|
WebErrorMessage(
|
||||||
error_message=error_message
|
error_message=error_message
|
||||||
@@ -28,6 +38,10 @@ class Connection(object):
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def send_connect(self):
|
async def send_connect(self):
|
||||||
|
"""
|
||||||
|
When new user is connected, send info about user
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
await self.send_webmessage(
|
await self.send_webmessage(
|
||||||
WebUserMessage(
|
WebUserMessage(
|
||||||
type="connect",
|
type="connect",
|
||||||
|
|||||||
@@ -18,6 +18,12 @@ class Room(object):
|
|||||||
connections: Dict[str, Connection] = {}
|
connections: Dict[str, Connection] = {}
|
||||||
|
|
||||||
async def accept_connection(self, ws: WebSocket) -> Connection:
|
async def accept_connection(self, ws: WebSocket) -> Connection:
|
||||||
|
"""
|
||||||
|
Accepts connection, checks username availability and adds it to dict of
|
||||||
|
connections
|
||||||
|
:param ws: Websocket of connection
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
print('Incoming connection')
|
print('Incoming connection')
|
||||||
await ws.accept()
|
await ws.accept()
|
||||||
connection = Connection(
|
connection = Connection(
|
||||||
@@ -36,11 +42,22 @@ class Room(object):
|
|||||||
return connection
|
return connection
|
||||||
|
|
||||||
async def broadcast_webmessage(self, obj: webmessages_union):
|
async def broadcast_webmessage(self, obj: webmessages_union):
|
||||||
|
"""
|
||||||
|
Broadcasts WebMessages to all connections in room
|
||||||
|
:param obj:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
for connection in self.connections.values():
|
for connection in self.connections.values():
|
||||||
print(f'Sending to {connection.username}: {obj}')
|
print(f'Sending to {connection.username}: {obj}')
|
||||||
await connection.send_webmessage(obj)
|
await connection.send_webmessage(obj)
|
||||||
|
|
||||||
async def broadcast_message(self, from_username: str, message: str):
|
async def broadcast_message(self, from_username: str, message: str):
|
||||||
|
"""
|
||||||
|
Broadcasts message to every user in room
|
||||||
|
:param from_username: User that sent message
|
||||||
|
:param message: content
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
await self.broadcast_webmessage(
|
await self.broadcast_webmessage(
|
||||||
WebMessageMessage(
|
WebMessageMessage(
|
||||||
username=from_username,
|
username=from_username,
|
||||||
@@ -49,6 +66,11 @@ class Room(object):
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def broadcast_notification(self, message: str):
|
async def broadcast_notification(self, message: str):
|
||||||
|
"""
|
||||||
|
Broadcasts notification from server
|
||||||
|
:param message: Content
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
await self.broadcast_webmessage(
|
await self.broadcast_webmessage(
|
||||||
WebNotificationMessage(
|
WebNotificationMessage(
|
||||||
message=message
|
message=message
|
||||||
@@ -59,6 +81,11 @@ class Room(object):
|
|||||||
self,
|
self,
|
||||||
error_message: webmessage_error_message_literal
|
error_message: webmessage_error_message_literal
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Broadcasts server error
|
||||||
|
:param error_message: See webmessage_error_message_literal
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
await self.broadcast_webmessage(
|
await self.broadcast_webmessage(
|
||||||
WebErrorMessage(
|
WebErrorMessage(
|
||||||
error_message=error_message
|
error_message=error_message
|
||||||
@@ -66,6 +93,11 @@ class Room(object):
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def broadcast_user_disconnected(self, username: str):
|
async def broadcast_user_disconnected(self, username: str):
|
||||||
|
"""
|
||||||
|
Broadcasts that user is disconnected
|
||||||
|
:param username: Username of user that disconnected
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
await self.broadcast_webmessage(
|
await self.broadcast_webmessage(
|
||||||
WebUserMessage(
|
WebUserMessage(
|
||||||
type="disconnect",
|
type="disconnect",
|
||||||
@@ -74,11 +106,24 @@ class Room(object):
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def get_connection_by(self, attribute: str, value: str) -> Connection | None:
|
async def get_connection_by(self, attribute: str, value: str) -> Connection | None:
|
||||||
|
"""
|
||||||
|
Search for connection by attribute and value in it
|
||||||
|
:param attribute:
|
||||||
|
:param value:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
for connection in self.connections.values():
|
for connection in self.connections.values():
|
||||||
if getattr(connection, attribute) == value:
|
if getattr(connection, attribute) == value:
|
||||||
return connection
|
return connection
|
||||||
|
|
||||||
async def disconnect(self, connection: Connection, close_reason: str | None = None):
|
async def disconnect(self, connection: Connection, close_reason: str | None = None):
|
||||||
|
"""
|
||||||
|
Disconnects by connection object.
|
||||||
|
:param connection: Object of connection.
|
||||||
|
It can be obtained using get_connection_by
|
||||||
|
:param close_reason: Reason if exists
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
if connection not in self.connections.values():
|
if connection not in self.connections.values():
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|||||||
@@ -29,6 +29,11 @@ class Service(object):
|
|||||||
await room.broadcast_error(error_message)
|
await room.broadcast_error(error_message)
|
||||||
|
|
||||||
async def get_room_by_connection(self, connection: Connection) -> Room:
|
async def get_room_by_connection(self, connection: Connection) -> Room:
|
||||||
|
"""
|
||||||
|
Searches for room by valid connection object in it
|
||||||
|
:param connection: Connection in unknown room to search
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
for room in self.rooms.values():
|
for room in self.rooms.values():
|
||||||
if connection in room.connections.values():
|
if connection in room.connections.values():
|
||||||
return room
|
return room
|
||||||
@@ -36,11 +41,23 @@ class Service(object):
|
|||||||
async def get_connection_by_attribute(
|
async def get_connection_by_attribute(
|
||||||
self, attribute: str, value: str
|
self, attribute: str, value: str
|
||||||
) -> Connection:
|
) -> Connection:
|
||||||
|
"""
|
||||||
|
Gets connection in some room by attribute and value in it
|
||||||
|
:param attribute:
|
||||||
|
:param value:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
for room in self.rooms.values():
|
for room in self.rooms.values():
|
||||||
if connection := await room.get_connection_by(attribute, value):
|
if connection := await room.get_connection_by(attribute, value):
|
||||||
return connection
|
return connection
|
||||||
|
|
||||||
async def close_room(self, room_name: str, reason: str = 'Unknown reason'):
|
async def close_room(self, room_name: str, reason: str = 'Unknown reason'):
|
||||||
|
"""
|
||||||
|
Closes all connections in room
|
||||||
|
:param room_name: Close name
|
||||||
|
:param reason: Reason to close room, default is Unknown reason
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
room = self.rooms.get(room_name)
|
room = self.rooms.get(room_name)
|
||||||
if room is None:
|
if room is None:
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -10,6 +10,12 @@ service = Service()
|
|||||||
|
|
||||||
|
|
||||||
async def serve_websocket(websocket: WebSocket, room_name: str):
|
async def serve_websocket(websocket: WebSocket, room_name: str):
|
||||||
|
"""
|
||||||
|
Serves websocket
|
||||||
|
:param websocket: Ws to serve
|
||||||
|
:param room_name: Room name to connect ws to
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
print(f'Connection opened room {room_name}')
|
print(f'Connection opened room {room_name}')
|
||||||
room = await service.get_room(room_name)
|
room = await service.get_room(room_name)
|
||||||
connection = await room.accept_connection(websocket)
|
connection = await room.accept_connection(websocket)
|
||||||
|
|||||||
@@ -8,6 +8,12 @@ from rich import print
|
|||||||
|
|
||||||
|
|
||||||
def integrate_onion(port: int, name: str) -> Onion:
|
def integrate_onion(port: int, name: str) -> Onion:
|
||||||
|
"""
|
||||||
|
Starts onion service, writes it to config
|
||||||
|
:param port: Port, where local service is started
|
||||||
|
:param name: Name of service to get or write to config
|
||||||
|
:return: Onion object, that is connected and service is started
|
||||||
|
"""
|
||||||
onion = Onion()
|
onion = Onion()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -6,6 +6,12 @@ from .routes import router
|
|||||||
|
|
||||||
|
|
||||||
def get_app(port: int, name: str) -> FastAPI:
|
def get_app(port: int, name: str) -> FastAPI:
|
||||||
|
"""
|
||||||
|
Creates FastAPI object and runs integrate_onion
|
||||||
|
:param port: Must be same with port on which uvicorn is running
|
||||||
|
:param name: Name of service
|
||||||
|
:return: FastAPI object with onion.cleanup function on shutdown
|
||||||
|
"""
|
||||||
onion = integrate_onion(port, name)
|
onion = integrate_onion(port, name)
|
||||||
return FastAPI(
|
return FastAPI(
|
||||||
title=f'dragonion-server: {name}',
|
title=f'dragonion-server: {name}',
|
||||||
@@ -15,6 +21,12 @@ def get_app(port: int, name: str) -> FastAPI:
|
|||||||
|
|
||||||
|
|
||||||
def run(name: str, port: int | None = get_available_port()):
|
def run(name: str, port: int | None = get_available_port()):
|
||||||
|
"""
|
||||||
|
Runs service with specified name and starts onion
|
||||||
|
:param name: Name of service
|
||||||
|
:param port: Port where to start service, if not specified - gets random available
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
if port is None:
|
if port is None:
|
||||||
port = get_available_port()
|
port = get_available_port()
|
||||||
app = get_app(port, name)
|
app = get_app(port, name)
|
||||||
|
|||||||
@@ -4,16 +4,6 @@ import shutil
|
|||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
|
||||||
def dir_size(start_path):
|
|
||||||
total_size = 0
|
|
||||||
for dirpath, dirnames, filenames in os.walk(start_path):
|
|
||||||
for f in filenames:
|
|
||||||
fp = os.path.join(dirpath, f)
|
|
||||||
if not os.path.islink(fp):
|
|
||||||
total_size += os.path.getsize(fp)
|
|
||||||
return total_size
|
|
||||||
|
|
||||||
|
|
||||||
def get_resource_path(filename):
|
def get_resource_path(filename):
|
||||||
application_path = 'resources'
|
application_path = 'resources'
|
||||||
|
|
||||||
|
|||||||
@@ -1,52 +0,0 @@
|
|||||||
import os
|
|
||||||
import hashlib
|
|
||||||
import base64
|
|
||||||
import time
|
|
||||||
|
|
||||||
|
|
||||||
def random_string(num_bytes, output_len=None):
|
|
||||||
b = os.urandom(num_bytes)
|
|
||||||
h = hashlib.sha256(b).digest()[:16]
|
|
||||||
s = base64.b32encode(h).lower().replace(b"=", b"").decode("utf-8")
|
|
||||||
if not output_len:
|
|
||||||
return s
|
|
||||||
return s[:output_len]
|
|
||||||
|
|
||||||
|
|
||||||
def human_readable_filesize(b):
|
|
||||||
thresh = 1024.0
|
|
||||||
if b < thresh:
|
|
||||||
return "{:.1f} B".format(b)
|
|
||||||
units = ("KiB", "MiB", "GiB", "TiB", "PiB", "EiB", "ZiB", "YiB")
|
|
||||||
u = 0
|
|
||||||
b /= thresh
|
|
||||||
while b >= thresh:
|
|
||||||
b /= thresh
|
|
||||||
u += 1
|
|
||||||
return "{:.1f} {}".format(b, units[u])
|
|
||||||
|
|
||||||
|
|
||||||
def format_seconds(seconds):
|
|
||||||
days, seconds = divmod(seconds, 86400)
|
|
||||||
hours, seconds = divmod(seconds, 3600)
|
|
||||||
minutes, seconds = divmod(seconds, 60)
|
|
||||||
|
|
||||||
human_readable = []
|
|
||||||
if days:
|
|
||||||
human_readable.append("{:.0f}d".format(days))
|
|
||||||
if hours:
|
|
||||||
human_readable.append("{:.0f}h".format(hours))
|
|
||||||
if minutes:
|
|
||||||
human_readable.append("{:.0f}m".format(minutes))
|
|
||||||
if seconds or not human_readable:
|
|
||||||
human_readable.append("{:.0f}s".format(seconds))
|
|
||||||
return "".join(human_readable)
|
|
||||||
|
|
||||||
|
|
||||||
def estimated_time_remaining(bytes_downloaded, total_bytes, started):
|
|
||||||
now = time.time()
|
|
||||||
time_elapsed = now - started
|
|
||||||
download_rate = bytes_downloaded / time_elapsed
|
|
||||||
remaining_bytes = total_bytes - bytes_downloaded
|
|
||||||
eta = remaining_bytes / download_rate
|
|
||||||
return format_seconds(eta)
|
|
||||||
@@ -2,6 +2,11 @@ import sqlitedict
|
|||||||
|
|
||||||
|
|
||||||
class AuthFile(sqlitedict.SqliteDict):
|
class AuthFile(sqlitedict.SqliteDict):
|
||||||
|
"""
|
||||||
|
Valid AuthFile has fields:
|
||||||
|
host - .onion url of service
|
||||||
|
auth - v3 onion auth string in format, that can be written to .auth_private file
|
||||||
|
"""
|
||||||
def __init__(self, service):
|
def __init__(self, service):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
filename=f'{service}.auth',
|
filename=f'{service}.auth',
|
||||||
|
|||||||
@@ -134,6 +134,12 @@ class Onion(object):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def write_onion_service(name: str, port: int):
|
def write_onion_service(name: str, port: int):
|
||||||
|
"""
|
||||||
|
Writes onion service to config
|
||||||
|
:param name: Name of service
|
||||||
|
:param port: Port of real service on local machine to proxy
|
||||||
|
:return: ServiceModel object
|
||||||
|
"""
|
||||||
if name in services.keys():
|
if name in services.keys():
|
||||||
service: config.models.ServiceModel = services[name]
|
service: config.models.ServiceModel = services[name]
|
||||||
service.port = port
|
service.port = port
|
||||||
|
|||||||
@@ -8,6 +8,10 @@ from typing import Literal
|
|||||||
|
|
||||||
|
|
||||||
def get_latest_version() -> str:
|
def get_latest_version() -> str:
|
||||||
|
"""
|
||||||
|
Gets latest non-alfa version name from dist.torproject.org
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
r = requests.get('https://dist.torproject.org/torbrowser/').text
|
r = requests.get('https://dist.torproject.org/torbrowser/').text
|
||||||
|
|
||||||
results = re.findall(r'<a href=".+/">(.+)/</a>', r)
|
results = re.findall(r'<a href=".+/">(.+)/</a>', r)
|
||||||
@@ -22,6 +26,10 @@ def get_build() -> Literal[
|
|||||||
'macos-x86_64',
|
'macos-x86_64',
|
||||||
'macos-aarch64'
|
'macos-aarch64'
|
||||||
]:
|
]:
|
||||||
|
"""
|
||||||
|
Gets proper build name for your system
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
if sys.platform == 'win32':
|
if sys.platform == 'win32':
|
||||||
return 'windows-x86_64'
|
return 'windows-x86_64'
|
||||||
elif sys.platform == 'linux':
|
elif sys.platform == 'linux':
|
||||||
@@ -38,18 +46,32 @@ def get_build() -> Literal[
|
|||||||
|
|
||||||
def get_tor_expert_bundles(version: str = get_latest_version(),
|
def get_tor_expert_bundles(version: str = get_latest_version(),
|
||||||
platform: str = get_build()):
|
platform: str = get_build()):
|
||||||
|
"""
|
||||||
|
Returns a link for downloading tor expert bundle by version and platform
|
||||||
|
:param version: Tor expert bundle version that exists in dist.torproject.org
|
||||||
|
:param platform: Build type based on platform and arch, can be generated using
|
||||||
|
get_build()
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
return f'https://dist.torproject.org/torbrowser/{version}/tor-expert-bundle-' \
|
return f'https://dist.torproject.org/torbrowser/{version}/tor-expert-bundle-' \
|
||||||
f'{version}-{platform}.tar.gz'
|
f'{version}-{platform}.tar.gz'
|
||||||
|
|
||||||
|
|
||||||
def download_tor(url: str = get_tor_expert_bundles(), dist: str = 'tor'):
|
def download_tor(url: str = get_tor_expert_bundles(), dist: str = 'tor'):
|
||||||
|
"""
|
||||||
|
Downloads tor from url and unpacks it to specified directory. Note, that
|
||||||
|
it doesn't unpack only tor executable to dist folder, but creates there
|
||||||
|
tor folder, where tor executable and libs are stored
|
||||||
|
:param url: Direct link for downloading
|
||||||
|
:param dist: Directory where to unpack archive (tor folder will appear there)
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
if not os.path.exists(dist):
|
if not os.path.exists(dist):
|
||||||
os.makedirs(dist)
|
os.makedirs(dist)
|
||||||
|
|
||||||
(tar := tarfile.open(fileobj=io.BytesIO(requests.get(url).content),
|
(tar := tarfile.open(fileobj=io.BytesIO(requests.get(url).content),
|
||||||
mode='r:gz')).extractall(
|
mode='r:gz')).extractall(
|
||||||
members=
|
members=[
|
||||||
[
|
|
||||||
tarinfo
|
tarinfo
|
||||||
for tarinfo
|
for tarinfo
|
||||||
in tar.getmembers()
|
in tar.getmembers()
|
||||||
|
|||||||
Reference in New Issue
Block a user