"""Downloads media from telegram."""
import asyncio
import logging
import os
import shutil
import sys
import time
from typing import List, Optional, Tuple, Union

import pyrogram
from loguru import logger
from pyrogram.types import Audio, Document, Photo, Video, VideoNote, Voice
from rich.logging import RichHandler

import module.database as db
from module.app import Application, ChatDownloadConfig, DownloadStatus, TaskNode
from module.bot import start_download_bot, stop_download_bot
from module.download_stat import (
    update_download_status,
    update_task_progress,
    reset_task_progress,
    snapshot_current_chat,
    increment_task_stat,
    get_task_progress,
    is_message_skipped,
    skip_message,
    remove_download_entry,
    clear_skipped_message,
)
from module.get_chat_history_v2 import get_chat_history_v2
from module.language import _t
from module.pyrogram_extension import (
    HookClient,
    fetch_message,
    get_extension,
    record_download_status,
    report_bot_download_status,
    set_max_concurrent_transmissions,
    set_meta_data,
    update_cloud_upload_stat,
    upload_telegram_chat,
)
from module.web import init_web, shutdown_web
from utils.format import truncate_filename, validate_title
from utils.log import LogFilter
from utils.meta import print_meta
from utils.meta_data import MetaData

logging.basicConfig(
    level=logging.INFO,
    format="%(message)s",
    datefmt="[%X]",
    handlers=[RichHandler()],
)

CONFIG_NAME = "config.yaml"
DATA_FILE_NAME = "data.yaml"
APPLICATION_NAME = "media_downloader"
app = Application(CONFIG_NAME, DATA_FILE_NAME, APPLICATION_NAME)

# Shared work queue of (message, TaskNode) pairs consumed by `worker`.
queue: asyncio.Queue = asyncio.Queue()
# Seconds to wait before retrying after a recoverable download error.
RETRY_TIME_OUT = 3

logging.getLogger("pyrogram.session.session").addFilter(LogFilter())
logging.getLogger("pyrogram.client").addFilter(LogFilter())

logging.getLogger("pyrogram").setLevel(logging.WARNING)


def _check_download_finish(media_size: int, download_path: str, ui_file_name: str):
    """Verify that a finished download has the expected size.

    On a size mismatch the partial file is deleted and ``BadRequest`` is
    raised so that the retry loop in ``download_media`` refetches the message
    and tries again.

    Parameters
    ----------
    media_size: int
        The size Telegram reports for the resource.
    download_path: str
        Path where the resource was downloaded to.
    ui_file_name: str
        File name shown in logs (may be masked when hide_file_name is set).

    Raises
    ------
    pyrogram.errors.exceptions.bad_request_400.BadRequest
        If the downloaded size differs from ``media_size``.
    """
    download_size = os.path.getsize(download_path)
    if media_size == download_size:
        logger.success(f"{_t('Successfully downloaded')} - {ui_file_name}")
    else:
        logger.warning(
            f"{_t('Media downloaded with wrong size')}: "
            f"{download_size}, {_t('actual')}: "
            f"{media_size}, {_t('file name')}: {ui_file_name}"
        )
        os.remove(download_path)
        raise pyrogram.errors.exceptions.bad_request_400.BadRequest()


def _move_to_download_path(temp_download_path: str, download_path: str):
    """Move a finished file from the temp location to its final path.

    Parent directories of ``download_path`` are created if missing.

    Parameters
    ----------
    temp_download_path: str
        Temporary download path.
    download_path: str
        Final download path.
    """
    directory, _ = os.path.split(download_path)
    os.makedirs(directory, exist_ok=True)
    shutil.move(temp_download_path, download_path)


def _check_timeout(retry: int, _: int):
    """Return True when ``retry`` is the last attempt of the 3-try loop.

    Parameters
    ----------
    retry: int
        Zero-based retry counter.
    _: int
        Message id (unused; kept for call-site compatibility).
    """
    return retry == 2


def _can_download(_type: str, file_formats: dict, file_format: Optional[str]) -> bool:
    """
    Check if the given file format can be downloaded.

    Parameters
    ----------
    _type: str
        Type of media object.
    file_formats: dict
        Dictionary containing the list of file_formats
        to be downloaded for `audio`, `document` & `video`
        media types
    file_format: str
        Format of the current file to be downloaded.

    Returns
    -------
    bool
        True if the file format can be downloaded else False.
        Media types other than audio/document/video are always allowed.
    """
    if _type in ["audio", "document", "video"]:
        allowed_formats: list = file_formats[_type]
        # "all" as the first entry acts as a wildcard.
        if file_format not in allowed_formats and allowed_formats[0] != "all":
            return False
    return True


def _is_exist(file_path: str) -> bool:
    """
    Check if a file exists and it is not a directory.

    Parameters
    ----------
    file_path: str
        Absolute path of the file to be checked.

    Returns
    -------
    bool
        True if the file exists else False.
    """
    return not os.path.isdir(file_path) and os.path.exists(file_path)


# pylint: disable = R0912
async def _get_media_meta(
    chat_id: Union[int, str],
    message: pyrogram.types.Message,
    media_obj: Union[Audio, Document, Photo, Video, VideoNote, Voice],
    _type: str,
) -> Tuple[str, str, Optional[str]]:
    """Build the final path, temp path and format for a media object.

    As a side effect, a non-empty caption is stored via
    ``app.set_caption_name`` / ``app.set_caption_entities`` so later messages
    of the same media group can reuse it.

    Parameters
    ----------
    chat_id: Union[int, str]
        Chat the message belongs to (used as directory name fallback).
    message: pyrogram.types.Message
        Message carrying the media.
    media_obj: Union[Audio, Document, Photo, Video, VideoNote, Voice]
        Media object to be extracted.
    _type: str
        Type of media object.

    Returns
    -------
    Tuple[str, str, Optional[str]]
        (file_name, temp_file_name, file_format); both paths are passed
        through ``truncate_filename``. ``file_format`` is None for photos.
    """
    if _type in ["audio", "document", "video"]:
        # pylint: disable = C0301
        file_format: Optional[str] = media_obj.mime_type.split("/")[-1]  # type: ignore
    else:
        file_format = None

    file_name = None
    temp_file_name = None
    dirname = validate_title(f"{chat_id}")
    if message.chat and message.chat.title:
        dirname = validate_title(f"{message.chat.title}")

    if message.date:
        datetime_dir_name = message.date.strftime(app.date_format)
    else:
        datetime_dir_name = "0"

    if _type in ["voice", "video_note"]:
        # Voice notes / video notes have no file name: build one from id, type and date.
        # pylint: disable = C0209
        file_format = media_obj.mime_type.split("/")[-1]  # type: ignore
        file_save_path = app.get_file_save_path(_type, dirname, datetime_dir_name)
        file_name = "{} - {}_{}.{}".format(
            message.id,
            _type,
            media_obj.date.isoformat(),  # type: ignore
            file_format,
        )
        file_name = validate_title(file_name)
        temp_file_name = os.path.join(app.temp_save_path, dirname, file_name)

        file_name = os.path.join(file_save_path, file_name)
    else:
        file_name = getattr(media_obj, "file_name", None)
        caption = getattr(message, "caption", None)

        if not file_name:
            file_name_suffix = get_extension(
                media_obj.file_id, getattr(media_obj, "mime_type", "")
            )
        else:
            # Strip any directory component, then split stem and extension.
            _, file_name_without_suffix = os.path.split(os.path.normpath(file_name))
            file_name, file_name_suffix = os.path.splitext(file_name_without_suffix)
            if not file_name_suffix:
                file_name_suffix = get_extension(
                    media_obj.file_id, getattr(media_obj, "mime_type", "")
                )

        if caption:
            caption = validate_title(caption)
            app.set_caption_name(chat_id, message.media_group_id, caption)
            app.set_caption_entities(
                chat_id, message.media_group_id, message.caption_entities
            )
        else:
            # Media-group members without their own caption inherit the group's.
            caption = app.get_caption_name(chat_id, message.media_group_id)

        if not file_name and message.photo:
            file_name = f"{message.photo.file_unique_id}"

        gen_file_name = (
            app.get_file_name(message.id, file_name, caption) + file_name_suffix
        )

        file_save_path = app.get_file_save_path(_type, dirname, datetime_dir_name)

        temp_file_name = os.path.join(app.temp_save_path, dirname, gen_file_name)

        file_name = os.path.join(file_save_path, gen_file_name)
    return truncate_filename(file_name), truncate_filename(temp_file_name), file_format


async def add_download_task(
    message: pyrogram.types.Message,
    node: TaskNode,
):
    """Enqueue ``message`` for download under ``node``.

    Returns
    -------
    bool
        False for empty messages (nothing queued), True otherwise.
    """
    if message.empty:
        return False
    node.download_status[message.id] = DownloadStatus.Downloading
    await queue.put((message, node))
    node.total_task += 1
    return True


async def save_msg_to_file(
    app, chat_id: Union[int, str], message: pyrogram.types.Message
):
    """Write the message text into a ``.txt`` file.

    Returns
    -------
    Tuple[DownloadStatus, Optional[str]]
        (SkipDownload, None) if the file already exists, otherwise
        (SuccessDownload, path_of_written_file).
    """
    dirname = validate_title(
        message.chat.title if message.chat and message.chat.title else str(chat_id)
    )
    datetime_dir_name = message.date.strftime(app.date_format) if message.date else "0"

    file_save_path = app.get_file_save_path("msg", dirname, datetime_dir_name)
    file_name = os.path.join(
        app.temp_save_path,
        file_save_path,
        f"{app.get_file_name(message.id, None, None)}.txt",
    )

    os.makedirs(os.path.dirname(file_name), exist_ok=True)

    if _is_exist(file_name):
        return DownloadStatus.SkipDownload, None

    with open(file_name, "w", encoding="utf-8") as f:
        f.write(message.text or "")

    return DownloadStatus.SuccessDownload, file_name


async def download_task(
    client: pyrogram.Client, message: pyrogram.types.Message, node: TaskNode
):
    """Download one queued message, then forward/upload it and report status.

    The live progress entry for the message is always removed at the end,
    even if a step raises, so the UI never shows a stale "downloading" row.
    """
    try:
        # Track how many downloads are in flight; the decrement runs even if
        # download_media raises.
        increment_task_stat("downloading_files")
        try:
            download_status, file_name = await download_media(
                client, message, app.media_types, app.file_formats, node
            )
        finally:
            increment_task_stat("downloading_files", -1)

        if download_status == DownloadStatus.SuccessDownload:
            increment_task_stat("completed_files")
        elif download_status == DownloadStatus.FailedDownload:
            increment_task_stat("failed_files")

        if app.enable_download_txt and message.text and not message.media:
            download_status, file_name = await save_msg_to_file(
                app, node.chat_id, message
            )

        if not node.bot:
            app.set_download_id(node, message.id, download_status)

        node.download_status[message.id] = download_status

        file_size = os.path.getsize(file_name) if file_name else 0

        await upload_telegram_chat(
            client,
            node.upload_user if node.upload_user else client,
            app,
            node,
            message,
            download_status,
            file_name,
        )

        # rclone upload
        if (
            not node.upload_telegram_chat_id
            and download_status is DownloadStatus.SuccessDownload
        ):
            ui_file_name = file_name
            if app.hide_file_name:
                ui_file_name = f"****{os.path.splitext(file_name)[-1]}"
            if await app.upload_file(
                file_name, update_cloud_upload_stat, (node, message.id, ui_file_name)
            ):
                node.upload_success_count += 1

        await report_bot_download_status(
            node.bot,
            node,
            download_status,
            file_size,
        )
    finally:
        # When the task ends (success / failure / skip / exception) clear the
        # live progress cache so failed or interrupted downloads don't linger
        # as "zombie" entries the UI mistakes for active downloads.
        # Finished downloads are persisted via db.record_download and do not
        # depend on this cache. The key keeps node.chat_id's original type,
        # matching how update_download_status writes it.
        remove_download_entry(node.chat_id, message.id)


# pylint: disable = R0915,R0914
@record_download_status
async def download_media(
    client: pyrogram.client.Client,
    message: pyrogram.types.Message,
    media_types: List[str],
    file_formats: dict,
    node: TaskNode,
):
    """
    Download media from Telegram.

    Each file is attempted up to 3 times; recoverable errors wait
    RETRY_TIME_OUT seconds (or the FloodWait duration) before retrying.

    Parameters
    ----------
    client: pyrogram.client.Client
        Client to interact with Telegram APIs.
    message: pyrogram.types.Message
        Message object retrieved from telegram.
    media_types: list
        List of strings of media types to be downloaded.
        Ex : `["audio", "photo"]`
        Supported formats:
            * audio
            * document
            * photo
            * video
            * voice
            * video_note
    file_formats: dict
        Dictionary containing the list of file_formats
        to be downloaded for `audio`, `document` & `video`
        media types.
    node: TaskNode
        Task the message belongs to.

    Returns
    -------
    Tuple[DownloadStatus, Optional[str]]
        (SuccessDownload, final_path) on success, otherwise
        (SkipDownload | FailedDownload, None).
    """
    # pylint: disable = R0912

    file_name: str = ""
    ui_file_name: str = ""
    task_start_time: float = time.time()
    media_size = 0
    _media = None
    message = await fetch_message(client, message)
    try:
        for _type in media_types:
            _media = getattr(message, _type, None)
            if _media is None:
                continue
            file_name, temp_file_name, file_format = await _get_media_meta(
                node.chat_id, message, _media, _type
            )
            media_size = getattr(_media, "file_size", 0)

            ui_file_name = file_name
            if app.hide_file_name:
                ui_file_name = f"****{os.path.splitext(file_name)[-1]}"

            if _can_download(_type, file_formats, file_format):
                if _is_exist(file_name):
                    file_size = os.path.getsize(file_name)
                    # Any non-empty existing file (or a size match) counts as
                    # already downloaded.
                    if file_size or file_size == media_size:
                        logger.info(
                            f"id={message.id} {ui_file_name} "
                            f"{_t('already download,download skipped')}.\n"
                        )
                        # Update skip counter
                        increment_task_stat("skipped_files")
                        # This skip means "found already downloaded during this
                        # task"; counted separately so the frontend can compute
                        # its denominator.
                        increment_task_stat("existing_skipped")
                        increment_task_stat("checked_messages")
                        return DownloadStatus.SkipDownload, None
                else:
                    _should_skip, _reason = db.should_skip(
                        str(node.chat_id), message.id
                    )
                    if _should_skip:
                        logger.info(
                            f"id={message.id} {ui_file_name} {_reason},跳过。\n"
                        )
                        increment_task_stat("skipped_files")
                        # A message the DB marks as skippable also counts as
                        # skipped in this task.
                        increment_task_stat("existing_skipped")
                        increment_task_stat("checked_messages")
                        return DownloadStatus.SkipDownload, None
            else:
                increment_task_stat("checked_messages")
                return DownloadStatus.SkipDownload, None

            break
    except Exception as e:
        logger.error(
            f"Message[{message.id}]: "
            f"{_t('could not be downloaded due to following exception')}:\n[{e}].",
            exc_info=True,
        )
        return DownloadStatus.FailedDownload, None
    if _media is None:
        return DownloadStatus.SkipDownload, None

    message_id = message.id

    for retry in range(3):
        # The user may skip the message (e.g. from the UI) between retries.
        if is_message_skipped(str(node.chat_id), message_id):
            clear_skipped_message(str(node.chat_id), message_id)
            remove_download_entry(str(node.chat_id), message_id)
            return DownloadStatus.SkipDownload, None
        try:
            temp_download_path = await client.download_media(
                message,
                file_name=temp_file_name,
                progress=update_download_status,
                progress_args=(
                    message_id,
                    ui_file_name,
                    task_start_time,
                    node,
                    client,
                ),
            )

            if temp_download_path and isinstance(temp_download_path, str):
                _check_download_finish(media_size, temp_download_path, ui_file_name)
                await asyncio.sleep(0.5)
                _move_to_download_path(temp_download_path, file_name)
                chat_title = ""
                if message.chat and message.chat.title:
                    chat_title = message.chat.title
                db.record_download(
                    chat_id=str(node.chat_id),
                    chat_title=chat_title,
                    message_id=message.id,
                    file_name=os.path.basename(file_name),
                    file_path=file_name,
                    file_size=media_size,
                    media_type=_type,
                    status="success",
                )
                # TODO: if not exist file size or media
                return DownloadStatus.SuccessDownload, file_name
        except pyrogram.errors.exceptions.bad_request_400.BadRequest:
            # Also raised by _check_download_finish on a size mismatch.
            logger.warning(
                f"Message[{message.id}]: {_t('file reference expired, refetching')}..."
            )
            await asyncio.sleep(RETRY_TIME_OUT)
            message = await fetch_message(client, message)
            if _check_timeout(retry, message.id):
                # pylint: disable = C0301
                logger.error(
                    f"Message[{message.id}]: "
                    f"{_t('file reference expired for 3 retries, download skipped.')}"
                )
        except pyrogram.errors.exceptions.flood_420.FloodWait as wait_err:
            # Log before sleeping so the wait is visible while it happens.
            logger.warning("Message[{}]: FloodWait {}", message.id, wait_err.value)
            await asyncio.sleep(wait_err.value)
            if _check_timeout(retry, message.id):
                logger.error(
                    "Message[{}]: FloodWait persisted for 3 retries, download skipped.",
                    message.id,
                )
        except TypeError:
            # pylint: disable = C0301
            logger.warning(
                f"{_t('Timeout Error occurred when downloading Message')}[{message.id}], "
                f"{_t('retrying after')} {RETRY_TIME_OUT} {_t('seconds')}"
            )
            await asyncio.sleep(RETRY_TIME_OUT)
            if _check_timeout(retry, message.id):
                logger.error(
                    f"Message[{message.id}]: {_t('Timing out after 3 reties, download skipped.')}"
                )
        except Exception as e:
            # pylint: disable = C0301
            logger.error(
                f"Message[{message.id}]: "
                f"{_t('could not be downloaded due to following exception')}:\n[{e}].",
                exc_info=True,
            )
            break

    return DownloadStatus.FailedDownload, None


def _load_config():
    """Load config"""
    app.load_config()


def _check_config() -> bool:
    """Print metadata, load the config and attach the rotating file log.

    Returns
    -------
    bool
        False if loading the config raised, True otherwise.
    """
    print_meta(logger)
    try:
        _load_config()
        logger.add(
            os.path.join(app.log_file_path, "tdl.log"),
            rotation="10 MB",
            retention="10 days",
            level=app.log_level,
        )
    except Exception as e:
        logger.exception(f"load config error: {e}")
        return False

    return True


async def worker(client: pyrogram.client.Client):
    """Consume the download queue while the app is running.

    Messages whose node was stopped, or that were marked skipped, are
    dropped without being downloaded.
    """
    while app.is_running:
        try:
            item = await queue.get()

            message = item[0]
            node: TaskNode = item[1]

            if node.is_stop_transmission:
                continue

            if is_message_skipped(str(node.chat_id), message.id):
                skip_message(str(node.chat_id), message.id)
                continue

            if node.client:
                await download_task(node.client, message, node)
            else:
                await download_task(client, message, node)
        except Exception as e:
            logger.exception(f"{e}")


async def download_chat_task(
    client: pyrogram.Client,
    chat_download_config: ChatDownloadConfig,
    node: TaskNode,
):
    """Scan one chat's history and enqueue every message passing the filter."""
    # Before switching to the next chat, snapshot the previous chat's final
    # progress.
    snapshot_current_chat()

    # Reset and update task progress
    reset_task_progress()

    # Try to get chat title; fall back to the id if the lookup fails.
    try:
        chat = await client.get_chat(node.chat_id)
        chat_title = chat.title or chat.first_name or str(node.chat_id)
    except Exception:
        chat_title = str(node.chat_id)

    update_task_progress(
        current_chat=str(node.chat_id), current_chat_title=chat_title, is_checking=True
    )

    # Try the pre-scan cache; on a hit the progress banner can show 0 / N
    # immediately.
    filter_key = db.build_filter_key(
        chat_download_config.download_filter,
        app.media_types,
        app.file_formats,
    )
    cached_total = db.get_scan_cache(str(node.chat_id), filter_key)
    if cached_total is not None:
        update_task_progress(estimated_total=cached_total)

    messages_iter = get_chat_history_v2(
        client,
        node.chat_id,
        limit=node.limit,
        max_id=node.end_offset_id,
        offset_id=chat_download_config.last_read_message_id,
        reverse=True,
    )

    chat_download_config.node = node

    if chat_download_config.ids_to_retry:
        logger.info(f"{_t('Downloading files failed during last run')}...")
        skipped_messages: list = await client.get_messages(  # type: ignore
            chat_id=node.chat_id, message_ids=chat_download_config.ids_to_retry
        )

        for message in skipped_messages:
            await add_download_task(message, node)

    async for message in messages_iter:  # type: ignore
        # Update checking progress for each message
        increment_task_stat("checked_messages")

        meta_data = MetaData()

        caption = message.caption
        if caption:
            caption = validate_title(caption)
            app.set_caption_name(node.chat_id, message.media_group_id, caption)
            app.set_caption_entities(
                node.chat_id, message.media_group_id, message.caption_entities
            )
        else:
            caption = app.get_caption_name(node.chat_id, message.media_group_id)
        set_meta_data(meta_data, message, caption)

        if app.need_skip_message(chat_download_config, message.id):
            continue

        if app.exec_filter(chat_download_config, meta_data):
            # Messages passing the filter form the live-incrementing
            # denominator N of the X / N progress display.
            increment_task_stat("qualified_files")
            await add_download_task(message, node)
        else:
            node.download_status[message.id] = DownloadStatus.SkipDownload
            increment_task_stat("skipped_files")
            if message.media_group_id:
                await upload_telegram_chat(
                    client,
                    node.upload_user,
                    app,
                    node,
                    message,
                    DownloadStatus.SkipDownload,
                )

    chat_download_config.need_check = True
    chat_download_config.total_task = node.total_task
    node.is_running = True

    # The scan finished normally: store the actual qualified_files count in
    # the cache and overwrite estimated_total. This is only reached on normal
    # completion (not on exception/interruption), avoiding a stale cache.
    actual_total = get_task_progress().get("qualified_files", 0)
    db.save_scan_cache(str(node.chat_id), filter_key, actual_total)
    update_task_progress(estimated_total=actual_total, is_checking=False)


async def _wait_node_finish(node: TaskNode):
    """Wait until all download tasks of one chat have finished (including retries).

    Also returns when the node is stopped or the app shuts down: ``worker``
    drops queued messages of a stopped node without finishing them and exits
    when ``app.is_running`` is False, so ``finish_task`` could otherwise never
    reach ``total_task`` and this loop would spin forever.
    """
    while True:
        # NOTE(review): node.need_check is not set anywhere in this file;
        # confirm TaskNode maintains it.
        if node.need_check and node.finish_task >= node.total_task:
            break
        if node.is_stop_transmission or not app.is_running:
            break
        await asyncio.sleep(0.5)


async def download_all_chat(client: pyrogram.Client):
    """Download every configured chat, one chat at a time."""
    # Use list() to avoid "dictionary changed size during iteration" error
    for key, value in list(app.chat_download_config.items()):
        value.node = TaskNode(chat_id=key)
        try:
            await download_chat_task(client, value, value.node)
        except Exception as e:
            logger.warning(f"Download {key} error: {e}")
        finally:
            value.need_check = True
            # Wait for every download of this chat to finish before taking the
            # snapshot; otherwise completed_files would still be 0.
            await _wait_node_finish(value.node)
            snapshot_current_chat()
            reset_task_progress()
    # Every chat was already snapshotted and reset inside the loop; this only
    # guarantees a clean final state.
    reset_task_progress()


async def run_until_all_task_finish():
    """Keep the event loop alive until a program restart is requested.

    Downloads, workers and the web UI run as background tasks, so the process
    intentionally keeps running after all downloads have finished.
    """
    while not app.restart_program:
        await asyncio.sleep(1)


def _exec_loop():
    """Run the event loop until a restart is requested."""
    app.loop.run_until_complete(run_until_all_task_finish())


async def start_server(client: pyrogram.Client):
    """
    Start the server using the provided client.
    """
    await client.start()


async def stop_server(client: pyrogram.Client):
    """
    Stop the server using the provided client.
    """
    await client.stop()


def main():
    """Main function of the downloader.

    Returns
    -------
    bool
        ``app.restart_program``: whether the caller should restart the process.
    """
    tasks = []
    # Without API credentials, start in "web configuration mode" and skip
    # Telegram client initialization.
    if app.is_configured():
        client = HookClient(
            "media_downloader",
            api_id=app.api_id,
            api_hash=app.api_hash,
            proxy=app.proxy,
            workdir=app.session_file_path,
            start_timeout=app.start_timeout,
        )
    else:
        client = None
        logger.info(
            "未检测到 API 凭证,以 Web 配置模式启动。"
            f"请访问 http://0.0.0.0:{app.web_port} 完成初始设置。"
        )
    try:
        app.pre_run()
        db.init_db(os.path.join(os.path.abspath("."), "appdata", "downloads.db"))
        init_web(app, client, add_download_task, download_chat_task)

        if client is not None:
            set_max_concurrent_transmissions(client, app.max_concurrent_transmissions)

            app.loop.run_until_complete(start_server(client))
            logger.success(_t("Successfully started (Press Ctrl+C to stop)"))

            # Track the scan task too so it is cancelled on shutdown.
            tasks.append(app.loop.create_task(download_all_chat(client)))
            for _ in range(app.max_download_task):
                task = app.loop.create_task(worker(client))
                tasks.append(task)

            if app.bot_token:
                app.loop.run_until_complete(
                    start_download_bot(app, client, add_download_task, download_chat_task)
                )
            _exec_loop()
        else:
            # Web configuration mode: keep the process alive until the user
            # finishes the web setup wizard, which triggers a restart.
            async def _wait_for_restart():
                while not app.restart_program:
                    await asyncio.sleep(1)

            app.loop.run_until_complete(_wait_for_restart())
    except KeyboardInterrupt:
        logger.info(_t("KeyboardInterrupt"))
    except Exception as e:
        logger.exception("{}", e)
    finally:
        app.is_running = False
        if client is not None:
            # Shutdown is best-effort: a failure here must not prevent the
            # config (last read message ids) from being saved below.
            if app.bot_token:
                try:
                    app.loop.run_until_complete(stop_download_bot())
                except Exception as e:
                    logger.warning("stop download bot error: {}", e)
            try:
                app.loop.run_until_complete(stop_server(client))
            except Exception as e:
                logger.debug("stop server error: {}", e)
        for task in tasks:
            task.cancel()
        logger.info(_t("Stopped!"))
        # check_for_updates(app.proxy)
        # When restarting out of web configuration mode, don't overwrite the
        # config file the wizard just wrote with the empty in-memory credentials.
        if not (client is None and app.restart_program):
            logger.info(f"{_t('update config')}......")
            app.update_config()
            logger.success(
                f"{_t('Updated last read message_id to config file')},"
                f"{_t('total download')} {app.total_download_task}, "
                f"{_t('total upload file')} "
                f"{app.cloud_drive_config.total_upload_success_file_count}"
            )

    # Return whether restart is needed
    return app.restart_program


def run_with_restart():
    """Run ``main`` and, if it requests a restart, re-exec this process.

    ``os.execv`` replaces the current process image so the new config is
    loaded from a clean state.
    """
    should_restart = main()

    if should_restart:
        logger.info("🔄 正在重启程序以应用新配置...")
        # Shutdown web server to release the port
        shutdown_web()
        # Wait a moment for the port to be released
        time.sleep(1)
        # Use os.execv to completely restart the process
        python = sys.executable
        os.execv(python, [python] + sys.argv)


if __name__ == "__main__":
    if _check_config():
        run_with_restart()