53d3ab7769
部署到群晖 / deploy (push) Failing after 34s
三处协同修复: 1) worker 异常分支补 _release_stuck_task, 标记 FailedDownload、推进 finish_task、清残留。 2) download_media 中 fetch_message 加 try-except, 连接异常返回 FailedDownload,不再让异常冒泡。 3) download_chat_task 用 try/finally 兜底回写 chat_download_config.total_task,避免 _wait_node_finish 误判频道已完成而切到下一个。 Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
932 lines
31 KiB
Python
932 lines
31 KiB
Python
"""Downloads media from telegram."""
|
||
import asyncio
|
||
import logging
|
||
import os
|
||
import shutil
|
||
import time
|
||
from typing import List, Optional, Tuple, Union
|
||
|
||
import pyrogram
|
||
from loguru import logger
|
||
from pyrogram.types import Audio, Document, Photo, Video, VideoNote, Voice
|
||
from rich.logging import RichHandler
|
||
|
||
import module.database as db
|
||
from module.app import Application, ChatDownloadConfig, DownloadStatus, TaskNode
|
||
from module.bot import start_download_bot, stop_download_bot
|
||
from module.download_stat import (
|
||
update_download_status,
|
||
update_task_progress,
|
||
reset_task_progress,
|
||
snapshot_current_chat,
|
||
increment_task_stat,
|
||
get_task_progress,
|
||
is_message_skipped,
|
||
skip_message,
|
||
remove_download_entry,
|
||
clear_skipped_message,
|
||
)
|
||
from module.get_chat_history_v2 import get_chat_history_v2
|
||
from module.language import _t
|
||
from module.pyrogram_extension import (
|
||
HookClient,
|
||
fetch_message,
|
||
get_extension,
|
||
record_download_status,
|
||
report_bot_download_status,
|
||
set_max_concurrent_transmissions,
|
||
set_meta_data,
|
||
update_cloud_upload_stat,
|
||
upload_telegram_chat,
|
||
)
|
||
from module.web import init_web, shutdown_web
|
||
from utils.format import truncate_filename, validate_title
|
||
from utils.log import LogFilter
|
||
from utils.meta import print_meta
|
||
from utils.meta_data import MetaData
|
||
|
||
# Route standard-library logging (pyrogram logs through it) to rich's console handler.
logging.basicConfig(
    level=logging.INFO,
    format="%(message)s",
    datefmt="[%X]",
    handlers=[RichHandler()],
)

# Names handed to Application: config file, data file and application name.
CONFIG_NAME = "config.yaml"
DATA_FILE_NAME = "data.yaml"
APPLICATION_NAME = "media_downloader"
# Module-wide application object shared by every function in this file.
app = Application(CONFIG_NAME, DATA_FILE_NAME, APPLICATION_NAME)

# Work queue of (message, node) tuples: add_download_task() puts, worker() gets.
# NOTE(review): created at import time; on Python < 3.10 an asyncio.Queue binds
# to the event loop current at creation — confirm that matches app.loop.
queue: asyncio.Queue = asyncio.Queue()
# Seconds download_media() sleeps before retrying after BadRequest / TypeError.
RETRY_TIME_OUT = 3

# Attach LogFilter to pyrogram's session and client loggers.
logging.getLogger("pyrogram.session.session").addFilter(LogFilter())
logging.getLogger("pyrogram.client").addFilter(LogFilter())

# Drop pyrogram records below WARNING.
logging.getLogger("pyrogram").setLevel(logging.WARNING)
|
||
|
||
|
||
def _check_download_finish(media_size: int, download_path: str, ui_file_name: str):
    """Verify that a finished download has the size Telegram reported.

    Parameters
    ----------
    media_size: int
        Expected size in bytes, taken from the media object.
    download_path: str
        Path of the file that was just downloaded.
    ui_file_name: str
        File name shown in log output (may be masked).

    Raises
    ------
    pyrogram.errors.exceptions.bad_request_400.BadRequest
        When the on-disk size differs from ``media_size``. The partial file
        is deleted first so the caller's BadRequest handler can refetch and
        retry from scratch.
    """
    actual_size = os.path.getsize(download_path)
    if actual_size != media_size:
        logger.warning(
            f"{_t('Media downloaded with wrong size')}: "
            f"{actual_size}, {_t('actual')}: "
            f"{media_size}, {_t('file name')}: {ui_file_name}"
        )
        os.remove(download_path)
        raise pyrogram.errors.exceptions.bad_request_400.BadRequest()

    logger.success(f"{_t('Successfully downloaded')} - {ui_file_name}")
|
||
|
||
|
||
def _move_to_download_path(temp_download_path: str, download_path: str):
    """Move a finished download from its temporary location to its final path.

    Parameters
    ----------
    temp_download_path: str
        Path of the file in the temporary download directory.
    download_path: str
        Final destination path; missing parent directories are created.
    """
    directory = os.path.dirname(download_path)
    # os.makedirs("") raises FileNotFoundError, so only create a parent
    # directory when the destination actually has one.
    if directory:
        os.makedirs(directory, exist_ok=True)
    shutil.move(temp_download_path, download_path)
|
||
|
||
|
||
def _check_timeout(retry: int, _: int):
    """Report whether *retry* is the last attempt of download_media's retry loop.

    Parameters
    ----------
    retry: int
        Zero-based attempt counter from the ``range(3)`` retry loop.
    _: int
        Message id; accepted for call-site compatibility but unused.

    Returns
    -------
    bool
        True on the final attempt (``retry == 2``), otherwise False.
    """
    return retry == 2
|
||
|
||
|
||
def _can_download(_type: str, file_formats: dict, file_format: Optional[str]) -> bool:
    """
    Check if the given file format can be downloaded.

    Only ``audio``, ``document`` and ``video`` are format-filtered; every
    other media type is always allowed.

    Parameters
    ----------
    _type: str
        Type of media object.
    file_formats: dict
        Dictionary containing the list of file_formats
        to be downloaded for `audio`, `document` & `video`
        media types. A list whose first entry is ``"all"`` allows every
        format; an empty list allows none.
    file_format: str
        Format of the current file to be downloaded.

    Returns
    -------
    bool
        True if the file format can be downloaded else False.
    """
    if _type not in ("audio", "document", "video"):
        return True

    allowed_formats: list = file_formats[_type]
    # An empty whitelist allows nothing; indexing [0] on it would raise IndexError.
    if not allowed_formats:
        return False
    return allowed_formats[0] == "all" or file_format in allowed_formats
|
||
|
||
|
||
def _is_exist(file_path: str) -> bool:
    """
    Tell whether *file_path* exists and is not a directory.

    Parameters
    ----------
    file_path: str
        Path of the file to be checked.

    Returns
    -------
    bool
        False for directories and missing paths, True otherwise.
    """
    if os.path.isdir(file_path):
        return False
    return os.path.exists(file_path)
|
||
|
||
|
||
# pylint: disable = R0912
|
||
|
||
|
||
async def _get_media_meta(
    chat_id: Union[int, str],
    message: pyrogram.types.Message,
    media_obj: Union[Audio, Document, Photo, Video, VideoNote, Voice],
    _type: str,
) -> Tuple[str, str, Optional[str]]:
    """Build the final path, temporary path and format for a media object.

    Side effect (non voice/video_note media only): a caption on *message* is
    stored via ``app.set_caption_name`` / ``app.set_caption_entities`` under
    ``message.media_group_id``. A message without a caption reuses whatever
    ``app.get_caption_name`` returns for that group.

    Parameters
    ----------
    chat_id: Union[int, str]
        Chat id; used as the directory name when the chat has no title.
    message: pyrogram.types.Message
        Message that carries the media.
    media_obj: Union[Audio, Document, Photo, Video, VideoNote, Voice]
        Media object to be extracted.
    _type: str
        Type of media object (the attribute name on the message).

    Returns
    -------
    Tuple[str, str, Optional[str]]
        ``(file_name, temp_file_name, file_format)``. The first two are the
        final save path and the temporary download path, both passed through
        ``truncate_filename``. ``file_format`` is the MIME subtype for
        audio/document/video/voice/video_note and None otherwise (e.g. photo).
    """
    if _type in ["audio", "document", "video"]:
        # pylint: disable = C0301
        file_format: Optional[str] = media_obj.mime_type.split("/")[-1]  # type: ignore
    else:
        file_format = None

    file_name = None
    temp_file_name = None
    # Directory name: the chat title when available, otherwise the chat id.
    dirname = validate_title(f"{chat_id}")
    if message.chat and message.chat.title:
        dirname = validate_title(f"{message.chat.title}")

    if message.date:
        datetime_dir_name = message.date.strftime(app.date_format)
    else:
        datetime_dir_name = "0"

    if _type in ["voice", "video_note"]:
        # Name is built from "<message id> - <type>_<media date>.<format>".
        # pylint: disable = C0209
        file_format = media_obj.mime_type.split("/")[-1]  # type: ignore
        file_save_path = app.get_file_save_path(_type, dirname, datetime_dir_name)
        file_name = "{} - {}_{}.{}".format(
            message.id,
            _type,
            media_obj.date.isoformat(),  # type: ignore
            file_format,
        )
        file_name = validate_title(file_name)
        temp_file_name = os.path.join(app.temp_save_path, dirname, file_name)

        file_name = os.path.join(file_save_path, file_name)
    else:
        file_name = getattr(media_obj, "file_name", None)
        caption = getattr(message, "caption", None)

        # NOTE(review): this default is overwritten on every path below.
        file_name_suffix = ".unknown"
        if not file_name:
            # No original name: derive the extension from file id / MIME type.
            file_name_suffix = get_extension(
                media_obj.file_id, getattr(media_obj, "mime_type", "")
            )
        else:
            # file_name = file_name.split(".")[0]
            # Keep only the base name (drop any directory part), then split off the suffix.
            _, file_name_without_suffix = os.path.split(os.path.normpath(file_name))
            file_name, file_name_suffix = os.path.splitext(file_name_without_suffix)
            if not file_name_suffix:
                file_name_suffix = get_extension(
                    media_obj.file_id, getattr(media_obj, "mime_type", "")
                )

        if caption:
            caption = validate_title(caption)
            app.set_caption_name(chat_id, message.media_group_id, caption)
            app.set_caption_entities(
                chat_id, message.media_group_id, message.caption_entities
            )
        else:
            # Fall back to the caption previously stored for this media group (if any).
            caption = app.get_caption_name(chat_id, message.media_group_id)

        if not file_name and message.photo:
            file_name = f"{message.photo.file_unique_id}"

        gen_file_name = (
            app.get_file_name(message.id, file_name, caption) + file_name_suffix
        )

        file_save_path = app.get_file_save_path(_type, dirname, datetime_dir_name)

        temp_file_name = os.path.join(app.temp_save_path, dirname, gen_file_name)

        file_name = os.path.join(file_save_path, gen_file_name)
    return truncate_filename(file_name), truncate_filename(temp_file_name), file_format
|
||
|
||
|
||
async def add_download_task(
    message: pyrogram.types.Message,
    node: TaskNode,
):
    """Queue *message* for download on behalf of *node*.

    Marks the message as Downloading on the node, enqueues ``(message, node)``
    for the workers and increments ``node.total_task``.

    Returns
    -------
    bool
        False for empty messages (nothing is queued), True otherwise.
    """
    if not message.empty:
        node.download_status[message.id] = DownloadStatus.Downloading
        await queue.put((message, node))
        node.total_task += 1
        return True
    return False
|
||
|
||
|
||
async def save_msg_to_file(
    app, chat_id: Union[int, str], message: pyrogram.types.Message
):
    """Write the message text into a ``.txt`` file.

    Returns
    -------
    tuple
        ``(DownloadStatus.SkipDownload, None)`` when the target file already
        exists, otherwise ``(DownloadStatus.SuccessDownload, target_path)``.
    """
    chat_dir = validate_title(
        message.chat.title if message.chat and message.chat.title else str(chat_id)
    )
    date_dir = "0"
    if message.date:
        date_dir = message.date.strftime(app.date_format)

    save_dir = app.get_file_save_path("msg", chat_dir, date_dir)
    # NOTE(review): os.path.join drops temp_save_path if save_dir is absolute.
    target = os.path.join(
        app.temp_save_path,
        save_dir,
        f"{app.get_file_name(message.id, None, None)}.txt",
    )
    os.makedirs(os.path.dirname(target), exist_ok=True)

    if _is_exist(target):
        return DownloadStatus.SkipDownload, None

    with open(target, "w", encoding="utf-8") as fh:
        fh.write(message.text or "")
    return DownloadStatus.SuccessDownload, target
|
||
|
||
|
||
async def download_task(
    client: pyrogram.Client, message: pyrogram.types.Message, node: TaskNode
):
    """Download one message, record the outcome and forward/upload the result.

    Steps:
    1. Download the media with ``download_media``, or write the text via
       ``save_msg_to_file`` when text download is enabled.
    2. Record the status on the node (and via ``app.set_download_id`` when
       not in bot mode).
    3. Hand the result to ``upload_telegram_chat`` and, when configured,
       ``app.upload_file``.
    4. Report to the bot and drop the live progress entry.
    """
    # Track download start. The matching decrement lives in ``finally`` so the
    # "downloading_files" gauge cannot drift upward if download_media raises.
    increment_task_stat("downloading_files")
    try:
        download_status, file_name = await download_media(
            client, message, app.media_types, app.file_formats, node
        )
    finally:
        increment_task_stat("downloading_files", -1)

    if download_status == DownloadStatus.SuccessDownload:
        increment_task_stat("completed_files")
    elif download_status == DownloadStatus.FailedDownload:
        increment_task_stat("failed_files")

    if app.enable_download_txt and message.text and not message.media:
        download_status, file_name = await save_msg_to_file(app, node.chat_id, message)

    if not node.bot:
        app.set_download_id(node, message.id, download_status)

    node.download_status[message.id] = download_status

    file_size = os.path.getsize(file_name) if file_name else 0

    await upload_telegram_chat(
        client,
        node.upload_user if node.upload_user else client,
        app,
        node,
        message,
        download_status,
        file_name,
    )

    # rclone upload
    if (
        not node.upload_telegram_chat_id
        and download_status is DownloadStatus.SuccessDownload
    ):
        ui_file_name = file_name
        if app.hide_file_name:
            ui_file_name = f"****{os.path.splitext(file_name)[-1]}"
        if await app.upload_file(
            file_name, update_cloud_upload_stat, (node, message.id, ui_file_name)
        ):
            node.upload_success_count += 1

    await report_bot_download_status(
        node.bot,
        node,
        download_status,
        file_size,
    )

    # Whatever the outcome (success / failure / skip), drop the live progress
    # entry so failed or interrupted "zombie" records do not linger in
    # _download_result and look like active downloads in the UI. Finished
    # history is stored in the database (db.record_download), not this dict.
    # The key type matches what update_download_status writes (raw node.chat_id).
    remove_download_entry(node.chat_id, message.id)
|
||
|
||
|
||
# pylint: disable = R0915,R0914
|
||
|
||
|
||
@record_download_status
async def download_media(
    client: pyrogram.client.Client,
    message: pyrogram.types.Message,
    media_types: List[str],
    file_formats: dict,
    node: TaskNode,
):
    """
    Download media from Telegram.

    Each file gets up to 3 download attempts. File-reference (BadRequest) and
    timeout (TypeError) errors sleep ``RETRY_TIME_OUT`` seconds before the
    next attempt; FloodWait sleeps for the time Telegram requests.

    Parameters
    ----------
    client: pyrogram.client.Client
        Client to interact with Telegram APIs.
    message: pyrogram.types.Message
        Message object retrieved from telegram.
    media_types: list
        List of strings of media types to be downloaded.
        Ex : `["audio", "photo"]`
        Supported formats:
            * audio
            * document
            * photo
            * video
            * voice
    file_formats: dict
        Dictionary containing the list of file_formats
        to be downloaded for `audio`, `document` & `video`
        media types.
    node: TaskNode
        Task node the message belongs to.

    Returns
    -------
    Tuple[DownloadStatus, Optional[str]]
        ``(SuccessDownload, file_path)`` after a size-verified download.
        ``(SkipDownload, None)`` when the message carries none of the wanted
        media, the file already exists, ``db.should_skip`` says so, the
        format is filtered out, or the message was flagged via
        ``is_message_skipped``.
        ``(FailedDownload, None)`` on any other failure, including
        connection errors while (re)fetching the message.
    """

    # pylint: disable = R0912

    file_name: str = ""
    ui_file_name: str = ""
    temp_file_name: str = ""
    task_start_time: float = time.time()
    media_size = 0
    _media = None
    # fetch_message is the first network call before any download. A connection
    # error must not bubble up: the worker would only log it and the status
    # would stay at Downloading forever, freezing the UI.
    try:
        message = await fetch_message(client, message)
    except Exception as fetch_err:
        logger.warning(
            f"Message[{getattr(message, 'id', '?')}] 拉取消息失败(可能连接断开): {fetch_err}"
        )
        return DownloadStatus.FailedDownload, None

    try:
        # Pick the first wanted media type present on the message.
        for _type in media_types:
            _media = getattr(message, _type, None)
            if _media is None:
                continue
            file_name, temp_file_name, file_format = await _get_media_meta(
                node.chat_id, message, _media, _type
            )
            media_size = getattr(_media, "file_size", 0)

            ui_file_name = file_name
            if app.hide_file_name:
                ui_file_name = f"****{os.path.splitext(file_name)[-1]}"

            if _can_download(_type, file_formats, file_format):
                if _is_exist(file_name):
                    file_size = os.path.getsize(file_name)
                    if file_size or file_size == media_size:
                        logger.info(
                            f"id={message.id} {ui_file_name} "
                            f"{_t('already download,download skipped')}.\n"
                        )
                        # Update skip counter
                        increment_task_stat("skipped_files")
                        # Skips found to be already downloaded during this run
                        # are counted separately for the frontend's denominator.
                        increment_task_stat("existing_skipped")
                        increment_task_stat("checked_messages")
                        return DownloadStatus.SkipDownload, None
                else:
                    _should_skip, _reason = db.should_skip(str(node.chat_id), message.id)
                    if _should_skip:
                        logger.info(
                            f"id={message.id} {ui_file_name} {_reason},跳过。\n"
                        )
                        increment_task_stat("skipped_files")
                        # A database-driven skip also counts as a skip in this run.
                        increment_task_stat("existing_skipped")
                        increment_task_stat("checked_messages")
                        return DownloadStatus.SkipDownload, None
            else:
                increment_task_stat("checked_messages")
                return DownloadStatus.SkipDownload, None

            break
    except Exception as e:
        logger.error(
            f"Message[{message.id}]: "
            f"{_t('could not be downloaded due to following exception')}:\n[{e}].",
            exc_info=True,
        )
        return DownloadStatus.FailedDownload, None

    if _media is None:
        return DownloadStatus.SkipDownload, None

    message_id = message.id

    for retry in range(3):
        if is_message_skipped(str(node.chat_id), message_id):
            clear_skipped_message(str(node.chat_id), message_id)
            # Same key type as update_download_status / download_task (raw chat_id).
            remove_download_entry(node.chat_id, message_id)
            return DownloadStatus.SkipDownload, None
        try:
            temp_download_path = await client.download_media(
                message,
                file_name=temp_file_name,
                progress=update_download_status,
                progress_args=(
                    message_id,
                    ui_file_name,
                    task_start_time,
                    node,
                    client,
                ),
            )

            if temp_download_path and isinstance(temp_download_path, str):
                # Raises BadRequest (after deleting the file) on a size mismatch,
                # which routes into the refetch-and-retry handler below.
                _check_download_finish(media_size, temp_download_path, ui_file_name)
                await asyncio.sleep(0.5)
                _move_to_download_path(temp_download_path, file_name)
                chat_title = ""
                if message.chat and message.chat.title:
                    chat_title = message.chat.title
                db.record_download(
                    chat_id=str(node.chat_id),
                    chat_title=chat_title,
                    message_id=message.id,
                    file_name=os.path.basename(file_name),
                    file_path=file_name,
                    file_size=media_size,
                    media_type=_type,
                    status="success",
                )
                # TODO: if not exist file size or media
                return DownloadStatus.SuccessDownload, file_name
        except pyrogram.errors.exceptions.bad_request_400.BadRequest:
            logger.warning(
                f"Message[{message.id}]: {_t('file reference expired, refetching')}..."
            )
            await asyncio.sleep(RETRY_TIME_OUT)
            # The refetch is a network call too. Guard it like the initial
            # fetch so a dropped connection ends as FailedDownload instead of
            # escaping to the worker.
            try:
                message = await fetch_message(client, message)
            except Exception as fetch_err:
                logger.warning(
                    f"Message[{getattr(message, 'id', '?')}] 拉取消息失败(可能连接断开): {fetch_err}"
                )
                return DownloadStatus.FailedDownload, None
            if _check_timeout(retry, message.id):
                # pylint: disable = C0301
                logger.error(
                    f"Message[{message.id}]: "
                    f"{_t('file reference expired for 3 retries, download skipped.')}"
                )
        except pyrogram.errors.exceptions.flood_420.FloodWait as wait_err:
            await asyncio.sleep(wait_err.value)
            logger.warning("Message[{}]: FlowWait {}", message.id, wait_err.value)
            _check_timeout(retry, message.id)
        except TypeError:
            # pylint: disable = C0301
            logger.warning(
                f"{_t('Timeout Error occurred when downloading Message')}[{message.id}], "
                f"{_t('retrying after')} {RETRY_TIME_OUT} {_t('seconds')}"
            )
            await asyncio.sleep(RETRY_TIME_OUT)
            if _check_timeout(retry, message.id):
                logger.error(
                    f"Message[{message.id}]: {_t('Timing out after 3 reties, download skipped.')}"
                )
        except Exception as e:
            # pylint: disable = C0301
            logger.error(
                f"Message[{message.id}]: "
                f"{_t('could not be downloaded due to following exception')}:\n[{e}].",
                exc_info=True,
            )
            break

    return DownloadStatus.FailedDownload, None
|
||
|
||
|
||
def _load_config():
    """Populate the module-level ``app`` from its configuration files."""
    app.load_config()
|
||
|
||
|
||
def _check_config() -> bool:
    """Print the banner, load the config and attach the rotating file log sink.

    Returns
    -------
    bool
        True when loading the config and adding the ``tdl.log`` sink both
        succeeded; False if either raised (the error is logged).
    """
    print_meta(logger)
    try:
        _load_config()
        log_path = os.path.join(app.log_file_path, "tdl.log")
        logger.add(
            log_path,
            rotation="10 MB",
            retention="10 days",
            level=app.log_level,
        )
    except Exception as e:
        logger.exception(f"load config error: {e}")
        return False
    else:
        return True
|
||
|
||
|
||
async def worker(client: pyrogram.client.Client):
    """Consumer coroutine: take ``(message, node)`` items off the queue and download them.

    Runs while ``app.is_running``. Items for stopped nodes and messages
    flagged as skipped are released with SkipDownload. Any exception raised
    while handling an item is logged and the item is released with
    FailedDownload, so the node's finish count keeps pace with its total.
    """
    while app.is_running:
        message = node = None
        try:
            item = await queue.get()
            message, node = item[0], item[1]

            if node.is_stop_transmission:
                # Transfer was stopped: clear the queued Downloading state so
                # the UI does not keep showing it as in progress.
                _release_stuck_task(node, message, DownloadStatus.SkipDownload)
                continue

            chat_key = str(node.chat_id)
            if is_message_skipped(chat_key, message.id):
                skip_message(chat_key, message.id)
                _release_stuck_task(node, message, DownloadStatus.SkipDownload)
                continue

            # A node may carry its own client; otherwise use the shared one.
            await download_task(node.client or client, message, node)
        except Exception as e:
            logger.exception(f"worker 捕获到未处理异常: {e}")
            # The swallowed exception must still advance the task state, or
            # finish_task would never catch up with total_task.
            if node is not None and message is not None:
                _release_stuck_task(node, message, DownloadStatus.FailedDownload)
|
||
|
||
|
||
def _release_stuck_task(
    node: "TaskNode",
    message: "pyrogram.types.Message",
    status: "DownloadStatus",
):
    """Release an aborted or failed task from the "downloading" state.

    This keeps the task queue from stalling forever:

    1) record *status* in ``node.download_status`` so the UI stops showing
       the message as downloading;
    2) for non-bot nodes, call ``app.set_download_id`` (the path that
       advances ``finish_task``) so ``_wait_node_finish`` can return;
    3) drop the message's entry via ``remove_download_entry``.

    Messages without an ``id`` are ignored. The whole routine is best-effort:
    failures are logged as warnings and never raised.
    """
    try:
        message_id = getattr(message, "id", None)
        if message_id is not None:
            node.download_status[message_id] = status
            # download_task never calls set_download_id for bot nodes, so do
            # the same here to avoid double counting.
            if not node.bot:
                try:
                    app.set_download_id(node, message_id, status)
                except Exception as inner:
                    logger.warning(f"释放卡死任务时 set_download_id 失败 msg_id={message_id}: {inner}")
            remove_download_entry(node.chat_id, message_id)
    except Exception as e:
        logger.warning(f"释放卡死任务清理失败: {e}")
|
||
|
||
|
||
async def download_chat_task(
    client: pyrogram.Client,
    chat_download_config: ChatDownloadConfig,
    node: TaskNode,
):
    """Scan one chat's history and queue every message that passes the filter.

    Parameters
    ----------
    client: pyrogram.Client
        Client used to read the chat.
    chat_download_config: ChatDownloadConfig
        Per-chat configuration. Its ``need_check`` and ``total_task`` are
        written back when scanning ends, whether normally or by exception.
    node: TaskNode
        Task node that collects this chat's download tasks.

    Raises
    ------
    Exception
        Network errors from ``get_messages`` or the history iterator
        propagate to the caller, after the bookkeeping in ``finally`` has run.
    """
    # Before switching chats, snapshot the previous chat's final progress into _completed_chats.
    snapshot_current_chat()
    # Reset and update task progress
    reset_task_progress()

    # Try to get chat title
    try:
        chat = await client.get_chat(node.chat_id)
        chat_title = chat.title or chat.first_name or str(node.chat_id)
    except Exception:
        chat_title = str(node.chat_id)

    update_task_progress(
        current_chat=str(node.chat_id),
        current_chat_title=chat_title,
        is_checking=True,
    )

    # Change A: try the pre-scan cache; on a hit the banner shows 0 / N right away.
    filter_key = db.build_filter_key(
        chat_download_config.download_filter,
        app.media_types,
        app.file_formats,
    )
    cached_total = db.get_scan_cache(str(node.chat_id), filter_key)
    if cached_total is not None:
        update_task_progress(estimated_total=cached_total)

    messages_iter = get_chat_history_v2(
        client,
        node.chat_id,
        limit=node.limit,
        max_id=node.end_offset_id,
        offset_id=chat_download_config.last_read_message_id,
        reverse=True,
    )

    chat_download_config.node = node

    # Every network call that can enqueue tasks (the retry fetch and the
    # history iteration) lives inside this try/finally. If "Connection lost"
    # is raised midway and total_task is not written back, _wait_node_finish
    # wrongly treats the chat as done, switches to the next chat while workers
    # are still running, and the UI shows "downloading" forever.
    iter_completed = False
    try:
        if chat_download_config.ids_to_retry:
            logger.info(f"{_t('Downloading files failed during last run')}...")
            skipped_messages: list = await client.get_messages(  # type: ignore
                chat_id=node.chat_id, message_ids=chat_download_config.ids_to_retry
            )

            for message in skipped_messages:
                await add_download_task(message, node)

        async for message in messages_iter:  # type: ignore
            # Update checking progress for each message
            increment_task_stat("checked_messages")

            meta_data = MetaData()

            caption = message.caption
            if caption:
                caption = validate_title(caption)
                app.set_caption_name(node.chat_id, message.media_group_id, caption)
                app.set_caption_entities(
                    node.chat_id, message.media_group_id, message.caption_entities
                )
            else:
                caption = app.get_caption_name(node.chat_id, message.media_group_id)
            set_meta_data(meta_data, message, caption)

            if app.need_skip_message(chat_download_config, message.id):
                continue

            if app.exec_filter(chat_download_config, meta_data):
                # Change B: count messages that pass the filter; this is the live X/N denominator.
                increment_task_stat("qualified_files")
                await add_download_task(message, node)
            else:
                node.download_status[message.id] = DownloadStatus.SkipDownload
                increment_task_stat("skipped_files")
                if message.media_group_id:
                    await upload_telegram_chat(
                        client,
                        node.upload_user,
                        app,
                        node,
                        message,
                        DownloadStatus.SkipDownload,
                    )
        iter_completed = True
    finally:
        # Whether or not scanning completed, sync node.total_task into
        # chat_download_config; otherwise _wait_node_finish returns at once
        # using the stale total_task=0 and skips waiting for the workers.
        chat_download_config.need_check = True
        chat_download_config.total_task = node.total_task
        node.is_running = True

        # Change C: on normal completion, write the actual qualified_files to
        # the cache and overwrite estimated_total. Only done on completion, so
        # an interrupted scan never leaves a dirty cache entry.
        if iter_completed:
            actual_total = get_task_progress().get("qualified_files", 0)
            db.save_scan_cache(str(node.chat_id), filter_key, actual_total)
            update_task_progress(estimated_total=actual_total, is_checking=False)
        else:
            # Also switch is_checking off on interruption so the UI does not
            # stay on "scanning" forever.
            update_task_progress(is_checking=False)
|
||
|
||
|
||
async def _wait_node_finish(chat_config, timeout: int = 3600):
    """Wait until all download tasks of one chat (including retries) have finished.

    Returns as soon as ``chat_config.need_check`` is set and ``finish_task``
    has reached ``total_task``, polling every 0.5 s. After *timeout* seconds
    it logs a warning and returns anyway, so a stuck chat cannot block the
    queue forever.
    """
    loop = asyncio.get_event_loop()
    deadline = loop.time() + timeout
    while not (
        chat_config.need_check and chat_config.finish_task >= chat_config.total_task
    ):
        if loop.time() > deadline:
            logger.warning(
                f"等待频道任务完成超时({timeout}s),强制跳过,"
                f"finish={chat_config.finish_task} total={chat_config.total_task}"
            )
            return
        await asyncio.sleep(0.5)
|
||
|
||
|
||
async def download_all_chat(client: pyrogram.Client):
    """Download every configured chat in turn.

    For each chat: scan it with download_chat_task (errors are logged, not
    raised), mark it as needing a check, wait for its downloads to finish,
    then snapshot and reset the progress counters.
    """
    # Iterate over a copy to avoid "dictionary changed size during iteration".
    for chat_id, chat_config in list(app.chat_download_config.items()):
        chat_config.node = TaskNode(chat_id=chat_id)
        try:
            await download_chat_task(client, chat_config, chat_config.node)
        except Exception as e:
            logger.warning(f"Download {chat_id} error: {e}")
        finally:
            chat_config.need_check = True
            # Wait for this chat's downloads before snapshotting; otherwise
            # completed_files would still read 0.
            await _wait_node_finish(chat_config)
            snapshot_current_chat()
            reset_task_progress()
    # Each chat was already snapshotted and reset inside the loop; this only
    # guarantees a clean final state.
    reset_task_progress()
|
||
|
||
|
||
async def run_until_all_task_finish():
    """Keep the event loop running until a program restart is requested.

    Polls ``app.restart_program`` once per second and returns when it becomes
    truthy. Whether all chats have finished downloading is not an exit
    condition. The per-chat "all finished" flag this loop used to compute was
    never read, so it has been removed.
    """
    while not app.restart_program:
        await asyncio.sleep(1)
|
||
|
||
|
||
def _exec_loop():
    """Block on ``app.loop`` until run_until_all_task_finish() completes."""
    loop = app.loop
    loop.run_until_complete(run_until_all_task_finish())
|
||
|
||
|
||
async def start_server(client: pyrogram.Client):
    """
    Connect and start the given Telegram client.
    """
    return await client.start()
|
||
|
||
|
||
async def stop_server(client: pyrogram.Client):
    """
    Stop the given Telegram client.
    """
    return await client.stop()
|
||
|
||
|
||
def main():
    """Main function of the downloader.

    With API credentials configured, this starts:
    - the Telegram client;
    - the web UI;
    - the chat scan and the download workers;
    - the optional bot.
    It then blocks in _exec_loop(). Without credentials it starts in "web
    configuration mode" and waits only for a restart request.

    Shutdown is best-effort and always reaches the config update. A failure
    to stop the bot or the client is logged and does not prevent cancelling
    the workers or saving progress.

    Returns
    -------
    bool
        ``app.restart_program``: whether the caller should restart the process.
    """
    tasks = []

    # Without API credentials, start in "web configuration mode" and skip the
    # Telegram client initialisation.
    if app.is_configured():
        client = HookClient(
            "media_downloader",
            api_id=app.api_id,
            api_hash=app.api_hash,
            proxy=app.proxy,
            workdir=app.session_file_path,
            start_timeout=app.start_timeout,
        )
    else:
        client = None
        logger.info(
            "未检测到 API 凭证,以 Web 配置模式启动。"
            f"请访问 http://0.0.0.0:{app.web_port} 完成初始设置。"
        )

    try:
        app.pre_run()
        db.init_db(os.path.join(os.path.abspath("."), "appdata", "downloads.db"))
        init_web(app, client, add_download_task, download_chat_task)

        if client is not None:
            set_max_concurrent_transmissions(client, app.max_concurrent_transmissions)
            app.loop.run_until_complete(start_server(client))
            logger.success(_t("Successfully started (Press Ctrl+C to stop)"))

            app.loop.create_task(download_all_chat(client))
            for _ in range(app.max_download_task):
                task = app.loop.create_task(worker(client))
                tasks.append(task)

            if app.bot_token:
                app.loop.run_until_complete(
                    start_download_bot(app, client, add_download_task, download_chat_task)
                )
            _exec_loop()
        else:
            # Web configuration mode: keep the process alive until the web
            # wizard finishes and requests a restart.
            async def _wait_for_restart():
                while not app.restart_program:
                    await asyncio.sleep(1)

            app.loop.run_until_complete(_wait_for_restart())
    except KeyboardInterrupt:
        logger.info(_t("KeyboardInterrupt"))
    except Exception as e:
        logger.exception("{}", e)
    finally:
        app.is_running = False
        if client is not None:
            if app.bot_token:
                # A failure here must not skip stopping the client, cancelling
                # the workers or saving the config below.
                try:
                    app.loop.run_until_complete(stop_download_bot())
                except Exception as e:
                    logger.warning(f"stop download bot error: {e}")
            try:
                app.loop.run_until_complete(stop_server(client))
            except Exception as e:
                # Best-effort shutdown: log and continue so progress is still saved.
                logger.warning(f"stop client error: {e}")
        for task in tasks:
            task.cancel()
        logger.info(_t("Stopped!"))
        # check_for_updates(app.proxy)
        # On a restart from web configuration mode, do not overwrite the config
        # file the wizard just wrote with the empty in-memory credentials.
        if not (client is None and app.restart_program):
            logger.info(f"{_t('update config')}......")
            app.update_config()
            logger.success(
                f"{_t('Updated last read message_id to config file')},"
                f"{_t('total download')} {app.total_download_task}, "
                f"{_t('total upload file')} "
                f"{app.cloud_drive_config.total_upload_success_file_count}"
            )

    # Return whether restart is needed
    return app.restart_program
|
||
|
||
|
||
def run_with_restart():
    """Run main(); if it requests a restart, re-exec the interpreter in place.

    os.execv replaces the current process image with a fresh interpreter
    started with the same argv, giving a clean restart.
    """
    import sys

    if not main():
        return

    logger.info("🔄 正在重启程序以应用新配置...")
    # Stop the web server first so the new process can bind the same port.
    shutdown_web()
    # Give the OS a moment to release the port.
    time.sleep(1)
    interpreter = sys.executable
    os.execv(interpreter, [interpreter] + sys.argv)
|
||
|
||
|
||
# Script entry point: only start when the configuration loads cleanly.
if __name__ == "__main__" and _check_config():
    run_with_restart()
|