# fastdl — asynchronous segmented file downloader (asyncio + aiohttp + tqdm).
"""Asynchronous multi-connection file downloader built on aiohttp and tqdm."""

import asyncio
import re
from mimetypes import guess_extension as extension
from pathlib import Path
from urllib.parse import unquote

import aiohttp
from tqdm import tqdm

# Bytes pulled from the response stream per read() call.
CHUNK_READ_SIZE = 65536
# Maximum attempts per HTTP operation before the last error propagates.
MAX_RETRIES = 5
# Exponential backoff delays (seconds), indexed by attempt number.
RETRY_DELAYS = [1, 2, 4, 8, 16]
# Transient failures worth retrying; anything else propagates immediately.
_RETRYABLE = (TimeoutError, OSError, aiohttp.ClientError)
async def fetch_chunk(  # noqa: PLR0913
    session: aiohttp.ClientSession,
    url: str,
    start_byte: int,
    end_byte: int,
    filepath: Path,
    pbar: tqdm,
) -> None:
    """Download bytes ``start_byte``-``end_byte`` of *url* into *filepath*.

    The target file must already exist at its final size; the chunk is
    written at its own offset so many chunks can run concurrently against
    the same file. Transient errors are retried with exponential backoff,
    and progress added to *pbar* during a failed attempt is rolled back
    before the retry.

    Raises:
        aiohttp.ClientResponseError: for a non-2xx response, after retries.
        The last retryable error if all MAX_RETRIES attempts fail.
    """
    for attempt in range(MAX_RETRIES):
        bytes_written = 0
        try:
            headers = {"Range": f"bytes={start_byte}-{end_byte}"}
            async with session.get(url, headers=headers) as response:
                # A 4xx/5xx error body must not be written into the file
                # as if it were payload data.
                response.raise_for_status()
                # NOTE(review): a server that ignores Range replies 200
                # with the full body; consider asserting status == 206.
                with filepath.open("r+b") as f:
                    f.seek(start_byte)
                    while chunk := await response.content.read(CHUNK_READ_SIZE):
                        f.write(chunk)
                        bytes_written += len(chunk)
                        pbar.update(len(chunk))
        except _RETRYABLE:
            # Undo this attempt's progress so the bar stays accurate.
            pbar.update(-bytes_written)
            if attempt < MAX_RETRIES - 1:
                await asyncio.sleep(RETRY_DELAYS[attempt])
            else:
                raise
        else:
            return
async def fetch_single_stream(
    session: aiohttp.ClientSession, url: str, filepath: Path, pbar: tqdm,
) -> None:
    """Download *url* into *filepath* with one sequential GET request.

    Fallback path for servers without byte-range support. Each retry
    reopens the file in ``"wb"`` mode, so a failed partial attempt is
    discarded and the transfer restarts from the beginning; the progress
    already added to *pbar* is rolled back first.

    Raises:
        aiohttp.ClientResponseError: for a non-2xx response, after retries.
        The last retryable error if all MAX_RETRIES attempts fail.
    """
    for attempt in range(MAX_RETRIES):
        bytes_written = 0
        try:
            async with session.get(url) as response:
                # Don't save an error page (404/500 body) as the download.
                response.raise_for_status()
                with filepath.open("wb") as f:
                    while chunk := await response.content.read(CHUNK_READ_SIZE):
                        f.write(chunk)
                        bytes_written += len(chunk)
                        pbar.update(len(chunk))
        except _RETRYABLE:
            # Roll back this attempt's progress before retrying.
            pbar.update(-bytes_written)
            if attempt < MAX_RETRIES - 1:
                await asyncio.sleep(RETRY_DELAYS[attempt])
            else:
                raise
        else:
            return
def get_filename(response: aiohttp.ClientResponse) -> str:
|
|
headers = response.headers
|
|
|
|
try:
|
|
if (
|
|
"content-disposition" in headers
|
|
and "filename" in headers["content-disposition"]
|
|
):
|
|
match = re.match(r".*filename=\"(.+)\".*", headers["content-disposition"])
|
|
if match:
|
|
return unquote(match.group(1))
|
|
except (KeyError, AttributeError):
|
|
pass
|
|
|
|
url = str(response.url).split("?")[0]
|
|
filename = url.rstrip("/").split("/")[-1]
|
|
if not filename:
|
|
return "download"
|
|
if re.findall(r"\.[a-zA-Z]{2}\w{0,2}$", filename):
|
|
return unquote(filename)
|
|
|
|
content_type = headers.get("Content-Type", "")
|
|
ct_match = re.findall(r"([a-z]{4,11}/[\w+\-.]+)", content_type)
|
|
if ct_match and extension(ct_match[0]):
|
|
filename += extension(ct_match[0])
|
|
return unquote(filename)
|
|
|
|
|
|
async def _head_with_retry(
    session: aiohttp.ClientSession, url: str,
) -> tuple[str, str | None, str]:
    """Issue a redirect-following HEAD request with retry and backoff.

    Returns a ``(filename, content_length, accept_ranges)`` triple, where
    ``content_length`` is the raw header value (or ``None`` if absent) and
    ``accept_ranges`` defaults to ``""``. The last transient error is
    re-raised once all MAX_RETRIES attempts are exhausted.
    """
    last_attempt = MAX_RETRIES - 1
    for attempt in range(MAX_RETRIES):
        try:
            async with session.head(url, allow_redirects=True) as resp:
                name = get_filename(resp)
                length = resp.headers.get("Content-Length")
                ranges = resp.headers.get("Accept-Ranges", "")
                return name, length, ranges
        except _RETRYABLE:
            if attempt == last_attempt:
                raise
            await asyncio.sleep(RETRY_DELAYS[attempt])
    raise RuntimeError("unreachable")  # pragma: no cover
async def download_file(url: str, num_parts: int = 20, *, position: int = 0) -> None:
    """Download *url* into the current directory with a tqdm progress bar.

    When the server advertises byte-range support, the body is fetched as
    up to *num_parts* concurrent ranged requests into a pre-sized
    ``<name>.fastdl`` temp file; otherwise a single sequential stream is
    used. The temp file is renamed to its final name only after the whole
    download succeeds.

    Args:
        url: Location to download.
        num_parts: Upper bound on concurrent range requests.
        position: tqdm bar row, for stacking several downloads.
    """
    timeout = aiohttp.ClientTimeout(total=None, connect=30, sock_read=60)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        filename, content_length_str, accept_ranges = await _head_with_retry(
            session, url,
        )

        supports_ranges = accept_ranges == "bytes" and content_length_str is not None

        # NOTE(review): filename derives from server headers/URL; consider
        # sanitizing path separators before using it as a local path.
        temp_path = Path(filename + ".fastdl")

        # Ranged download only makes sense with at least one byte to
        # split; empty or tiny bodies fall through to the single stream
        # so chunk_size can never be 0 (a "bytes=0--1" Range is invalid).
        if supports_ranges and int(content_length_str) > 0:
            content_length = int(content_length_str)
            # Never more parts than bytes, never fewer than one.
            parts = max(1, min(num_parts, content_length))
            chunk_size = content_length // parts

            # Pre-size the file so chunks can seek+write concurrently.
            with temp_path.open("wb") as f:  # noqa: ASYNC230
                f.truncate(content_length)

            with tqdm(
                total=content_length, unit="B", unit_scale=True,
                desc=filename, position=position, leave=True,
            ) as pbar:
                tasks = []
                for i in range(parts):
                    start_byte = i * chunk_size
                    # The last part absorbs the division remainder.
                    end_byte = (
                        content_length - 1
                        if i == parts - 1
                        else start_byte + chunk_size - 1
                    )
                    tasks.append(
                        fetch_chunk(
                            session, url, start_byte, end_byte, temp_path, pbar,
                        ),
                    )
                await asyncio.gather(*tasks)
        else:
            total = int(content_length_str) if content_length_str else None
            with tqdm(
                total=total, unit="B", unit_scale=True,
                desc=filename, position=position, leave=True,
            ) as pbar:
                await fetch_single_stream(session, url, temp_path, pbar)

        temp_path.rename(filename)