"""Concurrent HTTP downloader: ranged multi-part fetch with retry and progress."""

import asyncio
import re
from mimetypes import guess_extension as extension
from pathlib import Path
from urllib.parse import unquote

import aiohttp
from tqdm import tqdm

# Bytes per read from the response stream.
CHUNK_READ_SIZE = 65536
MAX_RETRIES = 5
# Exponential backoff (seconds), indexed by attempt number.
RETRY_DELAYS = [1, 2, 4, 8, 16]
# Transient failures worth retrying (includes aiohttp HTTP-status errors).
_RETRYABLE = (TimeoutError, OSError, aiohttp.ClientError)


async def fetch_chunk(  # noqa: PLR0913
    session: aiohttp.ClientSession,
    url: str,
    start_byte: int,
    end_byte: int,
    filepath: Path,
    pbar: tqdm,
) -> None:
    """Download bytes ``start_byte``..``end_byte`` (inclusive) of *url* into
    *filepath* at the matching offset.

    Retries transient failures with exponential backoff; progress already
    reported to *pbar* is rolled back before each retry so the bar stays exact.

    Raises:
        The last retryable error if all ``MAX_RETRIES`` attempts fail, or an
        ``aiohttp.ClientError`` if the server does not honor the Range request.
    """
    for attempt in range(MAX_RETRIES):
        bytes_written = 0
        try:
            headers = {"Range": f"bytes={start_byte}-{end_byte}"}
            async with session.get(url, headers=headers) as response:
                # A 200 means the server ignored the Range header; writing the
                # full body at start_byte from many concurrent tasks would
                # corrupt the file, so fail instead of writing.
                if response.status != 206:
                    response.raise_for_status()  # surfaces real HTTP errors (4xx/5xx)
                    msg = f"expected 206 Partial Content, got {response.status}"
                    raise aiohttp.ClientPayloadError(msg)
                with filepath.open("r+b") as f:
                    f.seek(start_byte)
                    while True:
                        chunk = await response.content.read(CHUNK_READ_SIZE)
                        if not chunk:
                            break
                        f.write(chunk)
                        bytes_written += len(chunk)
                        pbar.update(len(chunk))
        except _RETRYABLE:
            # Undo partial progress so the retry does not double-count.
            pbar.update(-bytes_written)
            if attempt < MAX_RETRIES - 1:
                await asyncio.sleep(RETRY_DELAYS[attempt])
            else:
                raise
        else:
            return


async def fetch_single_stream(
    session: aiohttp.ClientSession,
    url: str,
    filepath: Path,
    pbar: tqdm,
) -> None:
    """Download *url* into *filepath* as one sequential stream.

    Fallback for servers without Range support. Retries transient failures
    from the beginning (the file is reopened with "wb", truncating any
    partial data), rolling back *pbar* progress before each retry.

    Raises:
        The last retryable error if all ``MAX_RETRIES`` attempts fail.
    """
    for attempt in range(MAX_RETRIES):
        bytes_written = 0
        try:
            async with session.get(url) as response:
                # Do not save an error page (404/500 body) as the download.
                response.raise_for_status()
                with filepath.open("wb") as f:
                    while True:
                        chunk = await response.content.read(CHUNK_READ_SIZE)
                        if not chunk:
                            break
                        f.write(chunk)
                        bytes_written += len(chunk)
                        pbar.update(len(chunk))
        except _RETRYABLE:
            # Undo partial progress so the retry does not double-count.
            pbar.update(-bytes_written)
            if attempt < MAX_RETRIES - 1:
                await asyncio.sleep(RETRY_DELAYS[attempt])
            else:
                raise
        else:
            return


def get_filename(response: aiohttp.ClientResponse) -> str:
    """Derive an output filename for *response*.

    Preference order: the Content-Disposition ``filename``, then the last URL
    path segment (query string stripped), then that segment plus an extension
    guessed from Content-Type. Falls back to ``"download"`` for bare hosts.
    """
    headers = response.headers
    try:
        if (
            "content-disposition" in headers
            and "filename" in headers["content-disposition"]
        ):
            match = re.match(r".*filename=\"(.+)\".*", headers["content-disposition"])
            if match:
                return unquote(match.group(1))
    except (KeyError, AttributeError):
        pass
    url = str(response.url).split("?")[0]
    filename = url.rstrip("/").split("/")[-1]
    if not filename:
        return "download"
    # Keep the name as-is when it already ends in a plausible extension
    # (a dot followed by 2-4 word characters starting with letters).
    if re.findall(r"\.[a-zA-Z]{2}\w{0,2}$", filename):
        return unquote(filename)
    # Otherwise try to append an extension inferred from the media type.
    content_type = headers.get("Content-Type", "")
    ct_match = re.findall(r"([a-z]{4,11}/[\w+\-.]+)", content_type)
    if ct_match:
        ext = extension(ct_match[0])  # call guess_extension once, not twice
        if ext:
            filename += ext
    return unquote(filename)


async def _head_with_retry(
    session: aiohttp.ClientSession,
    url: str,
) -> tuple[str, str | None, str]:
    """HEAD *url* (following redirects) with retry.

    Returns:
        ``(filename, content_length_or_None, accept_ranges)`` where the last
        two come straight from the response headers.

    Raises:
        The last retryable error if all ``MAX_RETRIES`` attempts fail.
    """
    for attempt in range(MAX_RETRIES):
        try:
            async with session.head(url, allow_redirects=True) as response:
                # Fail fast on HTTP errors rather than downloading an error body.
                response.raise_for_status()
                return (
                    get_filename(response),
                    response.headers.get("Content-Length"),
                    response.headers.get("Accept-Ranges", ""),
                )
        except _RETRYABLE:
            if attempt < MAX_RETRIES - 1:
                await asyncio.sleep(RETRY_DELAYS[attempt])
            else:
                raise
    raise RuntimeError("unreachable")  # pragma: no cover


async def download_file(url: str, num_parts: int = 20, *, position: int = 0) -> None:
    """Download *url* to the current directory, showing a tqdm progress bar.

    Uses up to *num_parts* concurrent ranged requests when the server
    advertises ``Accept-Ranges: bytes`` and a known, non-zero length;
    otherwise streams the file sequentially. Data is written to a
    ``<name>.fastdl`` temp file and renamed into place on success.

    Args:
        url: Source URL.
        num_parts: Maximum number of concurrent range requests.
        position: tqdm bar slot, for stacking bars of parallel downloads.
    """
    timeout = aiohttp.ClientTimeout(total=None, connect=30, sock_read=60)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        filename, content_length_str, accept_ranges = await _head_with_retry(
            session,
            url,
        )
        supports_ranges = accept_ranges == "bytes" and content_length_str is not None
        temp_path = Path(filename + ".fastdl")
        content_length = int(content_length_str) if content_length_str else None
        # The ranged path needs a known non-zero length; a zero/unknown length
        # would yield an invalid Range header such as "bytes=0--1".
        if supports_ranges and content_length:
            # Never use more parts than there are bytes: otherwise
            # chunk_size == 0 and every part but the last gets an empty,
            # invalid byte range.
            parts = min(num_parts, content_length)
            chunk_size = content_length // parts
            # Pre-size the file so each part can seek-and-write in place.
            with temp_path.open("wb") as f:  # noqa: ASYNC230
                f.truncate(content_length)
            with tqdm(
                total=content_length,
                unit="B",
                unit_scale=True,
                desc=filename,
                position=position,
                leave=True,
            ) as pbar:
                tasks = []
                for i in range(parts):
                    start_byte = i * chunk_size
                    if i == parts - 1:
                        # Last part absorbs the remainder of the division.
                        end_byte = content_length - 1
                    else:
                        end_byte = start_byte + chunk_size - 1
                    tasks.append(
                        fetch_chunk(
                            session,
                            url,
                            start_byte,
                            end_byte,
                            temp_path,
                            pbar,
                        ),
                    )
                await asyncio.gather(*tasks)
        else:
            with tqdm(
                total=content_length,
                unit="B",
                unit_scale=True,
                desc=filename,
                position=position,
                leave=True,
            ) as pbar:
                await fetch_single_stream(session, url, temp_path, pbar)
    # Only reached when every part finished without raising.
    temp_path.rename(filename)