"""Download Stat""" import asyncio import time from enum import Enum from pyrogram import Client from module.app import TaskNode class DownloadState(Enum): """Download state""" Downloading = 1 StopDownload = 2 _download_result: dict = {} _paused_messages: set = set() # (chat_id, message_id) 单条暂停 _skipped_messages: set = set() # (chat_id, message_id) 单条跳过 _total_download_speed: int = 0 _total_download_size: int = 0 _last_download_time: float = time.time() _download_state: DownloadState = DownloadState.Downloading # Task progress tracking _task_progress: dict = { "current_chat": "", "current_chat_title": "", "total_messages": 0, "checked_messages": 0, "skipped_files": 0, "downloading_files": 0, "completed_files": 0, "failed_files": 0, # 当次遍历已通过 filter 的消息条数(实时递增,无缓存时用它做分母) "qualified_files": 0, # 缓存命中后的预计下载总数;未命中时遍历结束后再赋值 "estimated_total": 0, # 本次任务中"通过了 filter 但因已下载/被标记而跳过"的数量,用于前端算"真正要下载"的分母 "existing_skipped": 0, "is_checking": False, "last_update": 0, } # 多频道任务队列与已完成快照(独立于 _task_progress,不被单频道 reset 影响) # _task_queue: 任务开始时由 API 设置,[{chat_id, chat_title}, ...] 
# _completed_chats: before switching to the next chat, the previous chat's
# final progress is snapshotted here (see snapshot_current_chat).
_task_queue: list = []
_completed_chats: list = []


def get_download_result() -> dict:
    """Return the global download result mapping.

    Shape, as built by ``update_download_status``::

        {chat_id: {message_id: {"down_byte", "total_size", "file_name",
                                "start_time", "end_time", "download_speed",
                                "each_second_total_download", "task_id"}}}
    """
    return _download_result


def get_total_download_speed() -> int:
    """Return the most recently computed aggregate download speed
    (``down_byte`` units per second, presumably bytes/s)."""
    return _total_download_speed


def get_download_state() -> DownloadState:
    """Return the current global download state."""
    return _download_state


def get_task_progress() -> dict:
    """Return a copy of the task progress plus the multi-chat overview.

    When a chat is active and ``last_update`` is set, ``is_checking`` is
    re-derived: an update within the last 3 seconds means still checking;
    otherwise, once any message was checked or skipped, checking is done.
    The result also carries ``task_queue`` and ``completed_chats`` (shallow
    list copies) for the frontend queue cards.
    """
    progress = _task_progress.copy()
    if progress["current_chat"] and progress["last_update"] > 0:
        idle = time.time() - progress["last_update"]
        if idle <= 3:
            progress["is_checking"] = True
        elif progress["skipped_files"] > 0 or progress["checked_messages"] > 0:
            progress["is_checking"] = False
    progress["task_queue"] = list(_task_queue)
    progress["completed_chats"] = list(_completed_chats)
    return progress


def snapshot_current_chat():
    """Append the current chat's final progress to ``_completed_chats``.

    No-op when ``current_chat`` is empty. During sequential multi-chat
    downloads this is called before each chat task starts so the previous
    chat's final numbers are preserved. ``total`` is ``estimated_total``
    (falling back to ``qualified_files``) minus ``existing_skipped``,
    floored at 0.
    """
    chat_id = _task_progress.get("current_chat", "")
    if not chat_id:
        return
    qual = _task_progress.get("qualified_files", 0) or 0
    est = _task_progress.get("estimated_total", 0) or 0
    existing = _task_progress.get("existing_skipped", 0) or 0
    raw_total = est or qual
    real_total = max(0, raw_total - existing)
    # append() mutates the list in place, so no `global` is needed here.
    _completed_chats.append({
        "chat_id": chat_id,
        "chat_title": _task_progress.get("current_chat_title", "") or chat_id,
        "done": _task_progress.get("completed_files", 0) or 0,
        "total": real_total,
        "skip": _task_progress.get("skipped_files", 0) or 0,
        "existing_skip": existing,
        "failed": _task_progress.get("failed_files", 0) or 0,
    })


def set_task_queue(items: list):
    """Replace the task's chat queue.

    Each item is a mapping with ``chat_id`` and optional ``chat_title``.
    Items whose stripped ``chat_id`` is empty are dropped; a missing title
    falls back to the chat id.
    """
    global _task_queue
    _task_queue = []
    for it in (items or []):
        cid = str(it.get("chat_id", "") or "").strip()
        if not cid:
            continue
        _task_queue.append({
            "chat_id": cid,
            "chat_title": it.get("chat_title", "") or cid,
        })


def clear_completed_chats():
    """Empty the completed-chat list (called when a new task starts)."""
    global _completed_chats
    _completed_chats = []


def update_task_progress(
    current_chat: str = None,
    current_chat_title: str = None,
    total_messages: int = None,
    checked_messages: int = None,
    skipped_files: int = None,
    downloading_files: int = None,
    completed_files: int = None,
    failed_files: int = None,
    qualified_files: int = None,
    estimated_total: int = None,
    is_checking: bool = None,
    existing_skipped: int = None,
):
    """Overwrite each given (non-None) progress field and refresh
    ``last_update``.

    ``existing_skipped`` is a trailing optional addition so that field can be
    set directly instead of only through ``increment_task_stat``.
    """
    updates = {
        "current_chat": current_chat,
        "current_chat_title": current_chat_title,
        "total_messages": total_messages,
        "checked_messages": checked_messages,
        "skipped_files": skipped_files,
        "downloading_files": downloading_files,
        "completed_files": completed_files,
        "failed_files": failed_files,
        "qualified_files": qualified_files,
        "estimated_total": estimated_total,
        "is_checking": is_checking,
        "existing_skipped": existing_skipped,
    }
    for key, value in updates.items():
        if value is not None:
            _task_progress[key] = value
    _task_progress["last_update"] = time.time()


def reset_task_progress():
    """Reset the task progress to its defaults for a new (per-chat) task."""
    global _task_progress
    _task_progress = {
        "current_chat": "",
        "current_chat_title": "",
        "total_messages": 0,
        "checked_messages": 0,
        "skipped_files": 0,
        "downloading_files": 0,
        "completed_files": 0,
        "failed_files": 0,
        "qualified_files": 0,
        "estimated_total": 0,
        "existing_skipped": 0,
        "is_checking": False,
        "last_update": time.time(),
    }


def increment_task_stat(stat_type: str, count: int = 1):
    """Add ``count`` to the integer progress field ``stat_type``.

    Unknown keys and non-integer fields are ignored. ``bool`` is excluded
    explicitly because it subclasses ``int``; otherwise ``is_checking``
    would be turned into a number.
    """
    value = _task_progress.get(stat_type)
    if isinstance(value, int) and not isinstance(value, bool):
        _task_progress[stat_type] = value + count
        _task_progress["last_update"] = time.time()


# pylint: disable = W0603
def set_download_state(state: DownloadState):
    """Set the global download state."""
    global _download_state
    _download_state = state


def pause_message(chat_id: str, message_id: int):
    """Pause the download of a single message."""
    _paused_messages.add((chat_id, message_id))


def resume_message(chat_id: str, message_id: int):
    """Resume the download of a single message."""
    _paused_messages.discard((chat_id, message_id))


def skip_message(chat_id: str, message_id: int):
    """Skip a single message download (also clears its pause flag)."""
    _skipped_messages.add((chat_id, message_id))
    _paused_messages.discard((chat_id, message_id))


def is_message_paused(chat_id: str, message_id: int) -> bool:
    """Return True if the message is individually paused."""
    return (chat_id, message_id) in _paused_messages


def is_message_skipped(chat_id: str, message_id: int) -> bool:
    """Return True if the message is individually skipped."""
    return (chat_id, message_id) in _skipped_messages


def clear_skipped_message(chat_id: str, message_id: int):
    """Clear the skip flag (called when the download flow exits)."""
    _skipped_messages.discard((chat_id, message_id))


def remove_download_entry(chat_id, message_id):
    """Remove one message from the download result so it no longer shows in
    the in-progress list. Missing chats/messages are ignored."""
    chat_entries = _download_result.get(chat_id)
    if chat_entries is not None:
        chat_entries.pop(message_id, None)


async def _abort_for_skip_or_wait_for_pause(
    node: TaskNode, client: Client, chat_id, message_id: int
) -> bool:
    """Honour per-message skip/pause and the global pause.

    Returns True when the message was skipped and the caller should return.
    Blocks (polling) while the message or the whole downloader is paused.
    """
    msg_key = (str(chat_id), message_id)
    # Do not clear the skip flag here so the retry loop can still see it.
    if is_message_skipped(*msg_key):
        remove_download_entry(chat_id, message_id)
        client.stop_transmission()
        return True
    # Per-message pause: block until resumed or skipped.
    while is_message_paused(*msg_key):
        if is_message_skipped(*msg_key):
            remove_download_entry(chat_id, message_id)
            client.stop_transmission()
            return True
        # Same cancellation check as the global pause loop, so a paused
        # message can still be stopped.
        if node.is_stop_transmission:
            client.stop_transmission()
        await asyncio.sleep(0.5)
    # Global pause.
    while get_download_state() == DownloadState.StopDownload:
        if node.is_stop_transmission:
            client.stop_transmission()
        await asyncio.sleep(1)
    return False


def _record_message_progress(
    chat_id,
    message_id: int,
    down_byte: int,
    total_size: int,
    file_name: str,
    start_time: float,
    cur_time: float,
    task_id,
) -> int:
    """Create or update the per-message entry; return the newly counted
    byte delta to add to the aggregate total."""
    chat_entries = _download_result.get(chat_id)
    if not chat_entries:
        chat_entries = _download_result[chat_id] = {}

    entry = chat_entries.get(message_id)
    if entry:
        delta = down_byte - entry["down_byte"]
        window_bytes = entry["each_second_total_download"] + delta
        window = cur_time - entry["end_time"]
        speed = entry["download_speed"]
        end_time = entry["end_time"]
        # Recompute the per-message speed at most once per second.
        if window >= 1.0:
            speed = int(window_bytes / window)
            end_time = cur_time
            window_bytes = 0
        entry["down_byte"] = down_byte
        entry["end_time"] = end_time
        entry["download_speed"] = max(speed, 0)
        entry["each_second_total_download"] = window_bytes
        return delta

    elapsed = cur_time - start_time
    chat_entries[message_id] = {
        "down_byte": down_byte,
        "total_size": total_size,
        "file_name": file_name,
        "start_time": start_time,
        "end_time": cur_time,
        # The first callback can land in the same clock tick as start_time;
        # dividing by a zero interval would raise ZeroDivisionError.
        "download_speed": down_byte / elapsed if elapsed > 0 else 0,
        "each_second_total_download": down_byte,
        "task_id": node_task_id_placeholder if False else task_id,
    }
    return down_byte


async def update_download_status(
    down_byte: int,
    total_size: int,
    message_id: int,
    file_name: str,
    start_time: float,
    node: TaskNode,
    client: Client,
):
    """Progress callback: record per-message and aggregate download stats.

    Honours per-message skip/pause and the global pause before recording.
    The aggregate speed is recomputed at most once per second.
    """
    # pylint: disable = W0603
    global _total_download_speed
    global _total_download_size
    global _last_download_time

    cur_time = time.time()
    if node.is_stop_transmission:
        client.stop_transmission()

    chat_id = node.chat_id
    if await _abort_for_skip_or_wait_for_pause(node, client, chat_id, message_id):
        return

    _total_download_size += _record_message_progress(
        chat_id,
        message_id,
        down_byte,
        total_size,
        file_name,
        start_time,
        cur_time,
        node.task_id,
    )

    if cur_time - _last_download_time >= 1.0:
        _total_download_speed = max(
            int(_total_download_size / (cur_time - _last_download_time)), 0
        )
        _total_download_size = 0
        _last_download_time = cur_time