Files
telegram-downloader/media_downloader.py
T
yuming bc926bd972
部署到群晖 / deploy (push) Successful in 47s
fix: 清理 _download_result 中失败/中断任务的僵尸记录
UI "正在下载" 数会随时间膨胀(如设置并发=5 却显示 54 个)的根因:
_download_result 仅在用户手动「跳过」时被清理,下载失败/中断(3 次重试
都失败、_check_download_finish 大小不匹配抛 BadRequest 等)路径不会清,
导致中途失败的进度记录永久残留在内存字典里被 UI 误判为"还在下载"。

修复:在 download_task 末尾统一调用 remove_download_entry,无论成功/
失败/跳过任务结束就清掉。该 dict 仅用于 UI/Bot 实时进度展示,不参与
下载触发或重试决策(重试由 ids_to_retry + worker 内部 3 次重试两层兜底),
清理后下载逻辑零影响。

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-26 15:50:13 +08:00

860 lines
28 KiB
Python

"""Downloads media from telegram."""
import asyncio
import logging
import os
import shutil
import time
from typing import List, Optional, Tuple, Union
import pyrogram
from loguru import logger
from pyrogram.types import Audio, Document, Photo, Video, VideoNote, Voice
from rich.logging import RichHandler
import module.database as db
from module.app import Application, ChatDownloadConfig, DownloadStatus, TaskNode
from module.bot import start_download_bot, stop_download_bot
from module.download_stat import (
update_download_status,
update_task_progress,
reset_task_progress,
snapshot_current_chat,
increment_task_stat,
get_task_progress,
is_message_skipped,
skip_message,
remove_download_entry,
clear_skipped_message,
)
from module.get_chat_history_v2 import get_chat_history_v2
from module.language import _t
from module.pyrogram_extension import (
HookClient,
fetch_message,
get_extension,
record_download_status,
report_bot_download_status,
set_max_concurrent_transmissions,
set_meta_data,
update_cloud_upload_stat,
upload_telegram_chat,
)
from module.web import init_web, shutdown_web
from utils.format import truncate_filename, validate_title
from utils.log import LogFilter
from utils.meta import print_meta
from utils.meta_data import MetaData
# Route stdlib logging (used by pyrogram) through rich's handler; only the
# message text is printed and timestamps use the "[%X]" format.
logging.basicConfig(
    level=logging.INFO,
    format="%(message)s",
    datefmt="[%X]",
    handlers=[RichHandler()],
)
# File names and application name handed to Application below.
CONFIG_NAME = "config.yaml"
DATA_FILE_NAME = "data.yaml"
APPLICATION_NAME = "media_downloader"
# Module-wide application state read by every function in this file.
app = Application(CONFIG_NAME, DATA_FILE_NAME, APPLICATION_NAME)
# Work queue of (message, TaskNode) pairs: filled by add_download_task and
# drained by worker().
queue: asyncio.Queue = asyncio.Queue()
# Seconds to wait before retrying after a BadRequest / timeout failure.
RETRY_TIME_OUT = 3
# Attach the project's LogFilter to pyrogram's session/client loggers and
# only let pyrogram messages of WARNING or above through.
logging.getLogger("pyrogram.session.session").addFilter(LogFilter())
logging.getLogger("pyrogram.client").addFilter(LogFilter())
logging.getLogger("pyrogram").setLevel(logging.WARNING)
def _check_download_finish(media_size: int, download_path: str, ui_file_name: str):
    """Verify that a finished download has the size Telegram reported.

    Parameters
    ----------
    media_size: int
        Size in bytes reported by Telegram for the media.
    download_path: str
        Path of the file that was just written.
    ui_file_name: str
        File name shown in log output (may be masked).

    Raises
    ------
    pyrogram.errors.exceptions.bad_request_400.BadRequest
        When the on-disk size differs. The partial file is deleted first so
        the caller's retry loop starts from scratch.
    """
    actual_size = os.path.getsize(download_path)
    if actual_size != media_size:
        logger.warning(
            f"{_t('Media downloaded with wrong size')}: "
            f"{actual_size}, {_t('actual')}: "
            f"{media_size}, {_t('file name')}: {ui_file_name}"
        )
        os.remove(download_path)
        raise pyrogram.errors.exceptions.bad_request_400.BadRequest()
    logger.success(f"{_t('Successfully downloaded')} - {ui_file_name}")
def _move_to_download_path(temp_download_path: str, download_path: str):
"""Move file to download path
Parameters
----------
temp_download_path: str
Temporary download path
download_path: str
Download path
"""
directory, _ = os.path.split(download_path)
os.makedirs(directory, exist_ok=True)
shutil.move(temp_download_path, download_path)
def _check_timeout(retry: int, _: int):
"""Check if message download timeout, then add message id into failed_ids
Parameters
----------
retry: int
Retry download message times
message_id: int
Try to download message 's id
"""
if retry == 2:
return True
return False
def _can_download(_type: str, file_formats: dict, file_format: Optional[str]) -> bool:
"""
Check if the given file format can be downloaded.
Parameters
----------
_type: str
Type of media object.
file_formats: dict
Dictionary containing the list of file_formats
to be downloaded for `audio`, `document` & `video`
media types
file_format: str
Format of the current file to be downloaded.
Returns
-------
bool
True if the file format can be downloaded else False.
"""
if _type in ["audio", "document", "video"]:
allowed_formats: list = file_formats[_type]
if not file_format in allowed_formats and allowed_formats[0] != "all":
return False
return True
def _is_exist(file_path: str) -> bool:
"""
Check if a file exists and it is not a directory.
Parameters
----------
file_path: str
Absolute path of the file to be checked.
Returns
-------
bool
True if the file exists else False.
"""
return not os.path.isdir(file_path) and os.path.exists(file_path)
def _mime_subtype(media_obj) -> Optional[str]:
    """Return the subtype of ``media_obj.mime_type`` ("video/mp4" -> "mp4").

    Returns None when the object has no ``mime_type`` or it is None, instead
    of raising AttributeError on ``None.split``.
    """
    mime_type = getattr(media_obj, "mime_type", None)
    if mime_type is None:
        return None
    return mime_type.split("/")[-1]


# pylint: disable = R0912
async def _get_media_meta(
    chat_id: Union[int, str],
    message: pyrogram.types.Message,
    media_obj: Union[Audio, Document, Photo, Video, VideoNote, Voice],
    _type: str,
) -> Tuple[str, str, Optional[str]]:
    """Build the final path, temporary path and format for a media object.

    Parameters
    ----------
    chat_id: Union[int, str]
        Chat the message belongs to. It is used as the directory name when
        the chat has no title, and as the key for the caption cache.
    message: pyrogram.types.Message
        Message carrying the media.
    media_obj: Union[Audio, Document, Photo, Video, VideoNote, Voice]
        Media object to be extracted.
    _type: str
        Name of the media attribute on the message (e.g. "video", "voice").

    Returns
    -------
    Tuple[str, str, Optional[str]]
        ``(file_name, temp_file_name, file_format)``:
        - the truncated final save path;
        - the truncated temporary download path;
        - the MIME subtype, or None for types whose format is not tracked.
    """
    if _type in ["audio", "document", "video"]:
        file_format: Optional[str] = _mime_subtype(media_obj)
    else:
        file_format = None

    dirname = validate_title(f"{chat_id}")
    if message.chat and message.chat.title:
        dirname = validate_title(f"{message.chat.title}")

    if message.date:
        datetime_dir_name = message.date.strftime(app.date_format)
    else:
        datetime_dir_name = "0"

    if _type in ["voice", "video_note"]:
        # Voice / video notes have no original file name: build one from the
        # message id, media type and media date.
        # pylint: disable = C0209
        file_format = _mime_subtype(media_obj)
        file_save_path = app.get_file_save_path(_type, dirname, datetime_dir_name)
        file_name = "{} - {}_{}.{}".format(
            message.id,
            _type,
            media_obj.date.isoformat(),  # type: ignore
            file_format,
        )
        file_name = validate_title(file_name)
        temp_file_name = os.path.join(app.temp_save_path, dirname, file_name)
        file_name = os.path.join(file_save_path, file_name)
    else:
        file_name = getattr(media_obj, "file_name", None)
        caption = getattr(message, "caption", None)
        if not file_name:
            file_name_suffix = get_extension(
                media_obj.file_id, getattr(media_obj, "mime_type", "")
            )
        else:
            # Keep only the base name (drop any directory part) and split off
            # the extension; fall back to a MIME-derived extension if none.
            _, file_name_without_suffix = os.path.split(os.path.normpath(file_name))
            file_name, file_name_suffix = os.path.splitext(file_name_without_suffix)
            if not file_name_suffix:
                file_name_suffix = get_extension(
                    media_obj.file_id, getattr(media_obj, "mime_type", "")
                )
        # A caption is stored per (chat_id, media_group_id); messages without
        # their own caption read the stored one back (presumably so every
        # item of a media group shares the group's caption).
        if caption:
            caption = validate_title(caption)
            app.set_caption_name(chat_id, message.media_group_id, caption)
            app.set_caption_entities(
                chat_id, message.media_group_id, message.caption_entities
            )
        else:
            caption = app.get_caption_name(chat_id, message.media_group_id)
        if not file_name and message.photo:
            file_name = f"{message.photo.file_unique_id}"
        gen_file_name = (
            app.get_file_name(message.id, file_name, caption) + file_name_suffix
        )
        file_save_path = app.get_file_save_path(_type, dirname, datetime_dir_name)
        temp_file_name = os.path.join(app.temp_save_path, dirname, gen_file_name)
        file_name = os.path.join(file_save_path, gen_file_name)
    return truncate_filename(file_name), truncate_filename(temp_file_name), file_format
async def add_download_task(
    message: pyrogram.types.Message,
    node: TaskNode,
):
    """Queue ``message`` for download under ``node``.

    Empty messages are ignored and yield False. Otherwise the message is
    marked as Downloading on the node, put on the module queue consumed by
    worker(), ``node.total_task`` is bumped, and True is returned.
    """
    if not message.empty:
        node.download_status[message.id] = DownloadStatus.Downloading
        await queue.put((message, node))
        node.total_task += 1
        return True
    return False
async def save_msg_to_file(
    app, chat_id: Union[int, str], message: pyrogram.types.Message
):
    """Dump the text of ``message`` into a ``.txt`` file.

    Returns
    -------
    Tuple[DownloadStatus, Optional[str]]
        ``(SkipDownload, None)`` when the target file already exists,
        otherwise ``(SuccessDownload, path_of_written_file)``.
    """
    if message.chat and message.chat.title:
        raw_dirname = message.chat.title
    else:
        raw_dirname = str(chat_id)
    dirname = validate_title(raw_dirname)
    date_dir = message.date.strftime(app.date_format) if message.date else "0"
    save_dir = app.get_file_save_path("msg", dirname, date_dir)
    # NOTE(review): when save_dir is absolute, os.path.join discards
    # app.temp_save_path and the file is written directly into save_dir.
    txt_path = os.path.join(
        app.temp_save_path,
        save_dir,
        f"{app.get_file_name(message.id, None, None)}.txt",
    )
    os.makedirs(os.path.dirname(txt_path), exist_ok=True)
    if _is_exist(txt_path):
        return DownloadStatus.SkipDownload, None
    with open(txt_path, "w", encoding="utf-8") as out:
        out.write(message.text or "")
    return DownloadStatus.SuccessDownload, txt_path
async def _download_and_count(
    client: pyrogram.Client, message: pyrogram.types.Message, node: TaskNode
):
    """Run download_media while keeping the task-progress counters in sync.

    ``downloading_files`` is decremented in ``finally`` so an exception
    escaping download_media cannot leave the counter permanently inflated.
    """
    increment_task_stat("downloading_files")
    try:
        download_status, file_name = await download_media(
            client, message, app.media_types, app.file_formats, node
        )
    finally:
        increment_task_stat("downloading_files", -1)
    if download_status == DownloadStatus.SuccessDownload:
        increment_task_stat("completed_files")
    elif download_status == DownloadStatus.FailedDownload:
        increment_task_stat("failed_files")
    return download_status, file_name


async def download_task(
    client: pyrogram.Client, message: pyrogram.types.Message, node: TaskNode
):
    """Download one queued message and run every post-download step.

    Steps:
    1. Download the media and update the downloading/completed/failed
       counters.
    2. Optionally save a text-only message to a ``.txt`` file.
    3. Record the status on ``node``.
    4. Hand the result to upload_telegram_chat.
    5. Upload successful files with ``app.upload_file`` when
       ``node.upload_telegram_chat_id`` is not set.
    6. Report the outcome to the bot.

    The live progress entry for the message is always removed in
    ``finally``. This covers success, failure and skip, and also an
    exception from any step, so failed or interrupted downloads do not
    linger in the progress cache and look "still downloading" to the UI.
    Completed downloads are persisted separately (db.record_download in
    download_media), not in this cache.
    """
    try:
        download_status, file_name = await _download_and_count(client, message, node)
        if app.enable_download_txt and message.text and not message.media:
            download_status, file_name = await save_msg_to_file(
                app, node.chat_id, message
            )
        if not node.bot:
            app.set_download_id(node, message.id, download_status)
        node.download_status[message.id] = download_status
        file_size = os.path.getsize(file_name) if file_name else 0
        await upload_telegram_chat(
            client,
            node.upload_user if node.upload_user else client,
            app,
            node,
            message,
            download_status,
            file_name,
        )
        # rclone upload: only when not uploading to a Telegram chat.
        if (
            not node.upload_telegram_chat_id
            and download_status is DownloadStatus.SuccessDownload
        ):
            ui_file_name = file_name
            if app.hide_file_name:
                ui_file_name = f"****{os.path.splitext(file_name)[-1]}"
            if await app.upload_file(
                file_name, update_cloud_upload_stat, (node, message.id, ui_file_name)
            ):
                node.upload_success_count += 1
        await report_bot_download_status(
            node.bot,
            node,
            download_status,
            file_size,
        )
    finally:
        # Key is node.chat_id in its original type. NOTE(review): assumed to
        # match the key update_download_status writes — confirm in
        # module.download_stat.
        remove_download_entry(node.chat_id, message.id)
# pylint: disable = R0915,R0914
@record_download_status
async def download_media(
    client: pyrogram.client.Client,
    message: pyrogram.types.Message,
    media_types: List[str],
    file_formats: dict,
    node: TaskNode,
):
    """
    Download media from Telegram.

    How it works:
    - The message is re-fetched first.
    - The first attribute in ``media_types`` that is present on the message
      is selected.
    - The download is attempted up to 3 times.
    - After a BadRequest (including a size mismatch detected by
      _check_download_finish) or a timeout (TypeError), the next attempt
      waits ``RETRY_TIME_OUT`` seconds.
    - After a FloodWait, it waits the time Telegram requested.
    - Any other exception aborts the retries.

    Parameters
    ----------
    client: pyrogram.client.Client
        Client to interact with Telegram APIs.
    message: pyrogram.types.Message
        Message object retrieved from telegram.
    media_types: list
        List of strings of media types to be downloaded.
        Ex : `["audio", "photo"]`
        Supported formats:
            * audio
            * document
            * photo
            * video
            * voice
    file_formats: dict
        Dictionary containing the list of file_formats
        to be downloaded for `audio`, `document` & `video`
        media types.
    node: TaskNode
        Task the message belongs to (chat id, skip state, progress).

    Returns
    -------
    Tuple[DownloadStatus, Optional[str]]
        - ``(SuccessDownload, final_path)`` on success.
        - ``(SkipDownload, None)`` when the message has no wanted media, its
          format is filtered out, the file already exists, the DB says to
          skip it, or it was marked skipped during the retries.
        - ``(FailedDownload, None)`` when all attempts failed or an
          unexpected exception occurred.
    """
    # pylint: disable = R0912
    file_name: str = ""
    ui_file_name: str = ""
    task_start_time: float = time.time()
    media_size = 0
    _media = None
    message = await fetch_message(client, message)
    try:
        for _type in media_types:
            _media = getattr(message, _type, None)
            if _media is None:
                continue
            file_name, temp_file_name, file_format = await _get_media_meta(
                node.chat_id, message, _media, _type
            )
            media_size = getattr(_media, "file_size", 0)
            ui_file_name = file_name
            if app.hide_file_name:
                ui_file_name = f"****{os.path.splitext(file_name)[-1]}"
            if _can_download(_type, file_formats, file_format):
                if _is_exist(file_name):
                    file_size = os.path.getsize(file_name)
                    # Any non-empty final file (or an empty one when Telegram
                    # reports size 0) counts as already downloaded. Within this
                    # module, files reach the final path only after the size
                    # check.
                    if file_size or file_size == media_size:
                        logger.info(
                            f"id={message.id} {ui_file_name} "
                            f"{_t('already download,download skipped')}.\n"
                        )
                        # Update skip counter
                        increment_task_stat("skipped_files")
                        # This kind of skip means "found already downloaded during
                        # this task"; counted separately so the frontend can
                        # compute its denominator.
                        increment_task_stat("existing_skipped")
                        increment_task_stat("checked_messages")
                        return DownloadStatus.SkipDownload, None
                else:
                    _should_skip, _reason = db.should_skip(str(node.chat_id), message.id)
                    if _should_skip:
                        logger.info(
                            f"id={message.id} {ui_file_name} {_reason},跳过。\n"
                        )
                        increment_task_stat("skipped_files")
                        # A DB-marked skip also counts as "skipped in this task".
                        increment_task_stat("existing_skipped")
                        increment_task_stat("checked_messages")
                        return DownloadStatus.SkipDownload, None
            else:
                increment_task_stat("checked_messages")
                return DownloadStatus.SkipDownload, None
            break
    except Exception as e:
        logger.error(
            f"Message[{message.id}]: "
            f"{_t('could not be downloaded due to following exception')}:\n[{e}].",
            exc_info=True,
        )
        return DownloadStatus.FailedDownload, None
    if _media is None:
        return DownloadStatus.SkipDownload, None
    message_id = message.id
    for retry in range(3):
        # Stop retrying if the message was marked skipped (presumably from the
        # UI/bot) while earlier attempts were running.
        if is_message_skipped(str(node.chat_id), message_id):
            clear_skipped_message(str(node.chat_id), message_id)
            # Use node.chat_id in its original type, the same key download_task
            # clears. NOTE(review): assumed to match update_download_status's key.
            remove_download_entry(node.chat_id, message_id)
            return DownloadStatus.SkipDownload, None
        try:
            temp_download_path = await client.download_media(
                message,
                file_name=temp_file_name,
                progress=update_download_status,
                progress_args=(
                    message_id,
                    ui_file_name,
                    task_start_time,
                    node,
                    client,
                ),
            )
            if temp_download_path and isinstance(temp_download_path, str):
                _check_download_finish(media_size, temp_download_path, ui_file_name)
                await asyncio.sleep(0.5)
                _move_to_download_path(temp_download_path, file_name)
                chat_title = ""
                if message.chat and message.chat.title:
                    chat_title = message.chat.title
                db.record_download(
                    chat_id=str(node.chat_id),
                    chat_title=chat_title,
                    message_id=message.id,
                    file_name=os.path.basename(file_name),
                    file_path=file_name,
                    file_size=media_size,
                    media_type=_type,
                    status="success",
                )
                # TODO: if not exist file size or media
                return DownloadStatus.SuccessDownload, file_name
        except pyrogram.errors.exceptions.bad_request_400.BadRequest:
            logger.warning(
                f"Message[{message.id}]: {_t('file reference expired, refetching')}..."
            )
            await asyncio.sleep(RETRY_TIME_OUT)
            message = await fetch_message(client, message)
            if _check_timeout(retry, message.id):
                # pylint: disable = C0301
                logger.error(
                    f"Message[{message.id}]: "
                    f"{_t('file reference expired for 3 retries, download skipped.')}"
                )
        except pyrogram.errors.exceptions.flood_420.FloodWait as wait_err:
            await asyncio.sleep(wait_err.value)
            logger.warning("Message[{}]: FlowWait {}", message.id, wait_err.value)
            _check_timeout(retry, message.id)
        except TypeError:
            # pylint: disable = C0301
            logger.warning(
                f"{_t('Timeout Error occurred when downloading Message')}[{message.id}], "
                f"{_t('retrying after')} {RETRY_TIME_OUT} {_t('seconds')}"
            )
            await asyncio.sleep(RETRY_TIME_OUT)
            if _check_timeout(retry, message.id):
                logger.error(
                    f"Message[{message.id}]: {_t('Timing out after 3 reties, download skipped.')}"
                )
        except Exception as e:
            # pylint: disable = C0301
            logger.error(
                f"Message[{message.id}]: "
                f"{_t('could not be downloaded due to following exception')}:\n[{e}].",
                exc_info=True,
            )
            break
    return DownloadStatus.FailedDownload, None
def _load_config():
    """Populate the module-level ``app`` from its configuration file."""
    app.load_config()
def _check_config() -> bool:
    """Print the banner, load the config and attach the rotating file log.

    Returns
    -------
    bool
        True on success. False if loading the config or adding the log sink
        raised; the exception is logged.
    """
    print_meta(logger)
    try:
        _load_config()
        log_path = os.path.join(app.log_file_path, "tdl.log")
        logger.add(
            log_path,
            rotation="10 MB",
            retention="10 days",
            level=app.log_level,
        )
    except Exception as err:
        logger.exception(f"load config error: {err}")
        return False
    else:
        return True
async def worker(client: pyrogram.client.Client):
    """Consume queued (message, node) pairs and download them one at a time.

    The worker runs while ``app.is_running`` is true.
    - Items whose node has ``is_stop_transmission`` set are dropped.
    - Items already marked skipped are passed to skip_message and dropped.
    - Every other item goes to download_task, using the node's own client
      when it has one.

    Exceptions are logged so that one bad item does not stop the worker.
    """
    while app.is_running:
        try:
            item = await queue.get()
            message, node = item[0], item[1]
            if node.is_stop_transmission:
                continue
            chat_key = str(node.chat_id)
            if is_message_skipped(chat_key, message.id):
                skip_message(chat_key, message.id)
                continue
            active_client = node.client if node.client else client
            await download_task(active_client, message, node)
        except Exception as e:
            logger.exception(f"{e}")
async def download_chat_task(
    client: pyrogram.Client,
    chat_download_config: ChatDownloadConfig,
    node: TaskNode,
):
    """Scan one chat's history and queue every message that passes the filter.

    Steps:
    1. Snapshot the previous chat's progress and reset the counters.
    2. Publish the chat title (falling back to the chat id). If a pre-scan
       total is cached for this chat + filter, show it immediately.
    3. Re-queue messages whose download failed in the last run
       (``ids_to_retry``).
    4. Walk the history starting at ``last_read_message_id`` with
       ``reverse=True``. Messages accepted by ``app.exec_filter`` are queued;
       the rest are marked SkipDownload.
    5. On normal completion, cache the real number of qualified messages.
    """
    # Before switching chats, snapshot the previous chat's final progress
    # into _completed_chats.
    snapshot_current_chat()
    # Reset and update task progress
    reset_task_progress()
    # Try to get chat title
    try:
        chat = await client.get_chat(node.chat_id)
        chat_title = chat.title or chat.first_name or str(node.chat_id)
    except Exception as e:
        # Best effort: here the title is only used for progress display.
        logger.debug(f"get_chat({node.chat_id}) failed, using chat id as title: {e}")
        chat_title = str(node.chat_id)
    update_task_progress(
        current_chat=str(node.chat_id),
        current_chat_title=chat_title,
        is_checking=True,
    )
    # Change A: try the pre-scan cache; on a hit the banner shows 0 / N at once.
    filter_key = db.build_filter_key(
        chat_download_config.download_filter,
        app.media_types,
        app.file_formats,
    )
    cached_total = db.get_scan_cache(str(node.chat_id), filter_key)
    if cached_total is not None:
        update_task_progress(estimated_total=cached_total)
    messages_iter = get_chat_history_v2(
        client,
        node.chat_id,
        limit=node.limit,
        max_id=node.end_offset_id,
        offset_id=chat_download_config.last_read_message_id,
        reverse=True,
    )
    chat_download_config.node = node
    if chat_download_config.ids_to_retry:
        logger.info(f"{_t('Downloading files failed during last run')}...")
        skipped_messages: list = await client.get_messages(  # type: ignore
            chat_id=node.chat_id, message_ids=chat_download_config.ids_to_retry
        )
        for message in skipped_messages:
            await add_download_task(message, node)
    async for message in messages_iter:  # type: ignore
        # Update checking progress for each message
        increment_task_stat("checked_messages")
        meta_data = MetaData()
        caption = message.caption
        if caption:
            caption = validate_title(caption)
            app.set_caption_name(node.chat_id, message.media_group_id, caption)
            app.set_caption_entities(
                node.chat_id, message.media_group_id, message.caption_entities
            )
        else:
            caption = app.get_caption_name(node.chat_id, message.media_group_id)
        set_meta_data(meta_data, message, caption)
        if app.need_skip_message(chat_download_config, message.id):
            continue
        if app.exec_filter(chat_download_config, meta_data):
            # Change B: count messages that pass the filter; this is the live
            # denominator N of the X / N display.
            increment_task_stat("qualified_files")
            await add_download_task(message, node)
        else:
            node.download_status[message.id] = DownloadStatus.SkipDownload
            increment_task_stat("skipped_files")
            if message.media_group_id:
                # Same upload-user fallback as download_task uses.
                await upload_telegram_chat(
                    client,
                    node.upload_user if node.upload_user else client,
                    app,
                    node,
                    message,
                    DownloadStatus.SkipDownload,
                )
    chat_download_config.need_check = True
    chat_download_config.total_task = node.total_task
    node.is_running = True
    # Change C: the scan completed normally, so persist the real qualified_files
    # count and overwrite estimated_total. An exception or interruption never
    # reaches this point, so a partial count is never cached.
    actual_total = get_task_progress().get("qualified_files", 0)
    db.save_scan_cache(str(node.chat_id), filter_key, actual_total)
    update_task_progress(estimated_total=actual_total, is_checking=False)
async def download_all_chat(client: pyrogram.Client):
    """Run download_chat_task for every configured chat, one after another.

    A failure in one chat is logged as a warning and does not stop the
    others. ``need_check`` is set on every chat config whatever the outcome.
    """
    # Iterate over a snapshot to avoid "dictionary changed size during
    # iteration" while the loop awaits.
    chat_items = list(app.chat_download_config.items())
    for chat_id, chat_config in chat_items:
        chat_config.node = TaskNode(chat_id=chat_id)
        try:
            await download_chat_task(client, chat_config, chat_config.node)
        except Exception as e:
            logger.warning(f"Download {chat_id} error: {e}")
        finally:
            chat_config.need_check = True
    # After all chats ran serially, snapshot the last chat too. Otherwise it
    # would stay only in _task_progress, never reach _completed_chats, and
    # the frontend would keep showing it as downloading.
    snapshot_current_chat()
    # The snapshot only copies _task_progress, so reset it. Otherwise the
    # frontend would count the last chat's skipped_files both in the
    # completed-chats total and in the current value.
    reset_task_progress()
async def run_until_all_task_finish():
    """Keep the event loop busy until a program restart is requested.

    Polls ``app.restart_program`` once per second, checking it before each
    sleep.

    NOTE(review): finishing every chat (need_check and
    total_task == finish_task) does not end this loop; only a restart request
    does. The previous per-iteration "all finished" computation had no effect
    and was removed. Add that condition here explicitly if the process should
    exit once all downloads complete.
    """
    while not app.restart_program:
        await asyncio.sleep(1)
def _exec_loop():
    """Block on the app's event loop until run_until_all_task_finish returns."""
    waiter = run_until_all_task_finish()
    app.loop.run_until_complete(waiter)
async def start_server(client: pyrogram.Client):
    """Await ``client.start()`` for the given Telegram client."""
    await client.start()
async def stop_server(client: pyrogram.Client):
    """Await ``client.stop()`` for the given Telegram client."""
    await client.stop()
def main():
    """Run the downloader until it stops; return whether a restart was requested.

    With API credentials configured, main():
    - starts the Telegram client;
    - schedules the chat scan and ``app.max_download_task`` workers;
    - optionally starts the bot;
    - then blocks in _exec_loop.

    Without credentials it runs in "web configuration mode" and only waits
    for the web wizard to request a restart.

    Shutdown steps are best effort. A failure while stopping the bot or the
    client is logged and does not prevent cancelling the workers or saving
    the config.

    Returns
    -------
    bool
        ``app.restart_program``.
    """
    tasks = []
    # Without API credentials, start in "web configuration mode" and skip
    # creating the Telegram client.
    if app.is_configured():
        client = HookClient(
            "media_downloader",
            api_id=app.api_id,
            api_hash=app.api_hash,
            proxy=app.proxy,
            workdir=app.session_file_path,
            start_timeout=app.start_timeout,
        )
    else:
        client = None
        logger.info(
            "未检测到 API 凭证,以 Web 配置模式启动。"
            f"请访问 http://0.0.0.0:{app.web_port} 完成初始设置。"
        )
    try:
        app.pre_run()
        db.init_db(os.path.join(os.path.abspath("."), "appdata", "downloads.db"))
        init_web(app, client, add_download_task, download_chat_task)
        if client is not None:
            set_max_concurrent_transmissions(client, app.max_concurrent_transmissions)
            app.loop.run_until_complete(start_server(client))
            logger.success(_t("Successfully started (Press Ctrl+C to stop)"))
            app.loop.create_task(download_all_chat(client))
            for _ in range(app.max_download_task):
                task = app.loop.create_task(worker(client))
                tasks.append(task)
            if app.bot_token:
                app.loop.run_until_complete(
                    start_download_bot(app, client, add_download_task, download_chat_task)
                )
            _exec_loop()
        else:
            # Web configuration mode: keep the process alive until the web
            # wizard completes and requests a restart.
            async def _wait_for_restart():
                while not app.restart_program:
                    await asyncio.sleep(1)

            app.loop.run_until_complete(_wait_for_restart())
    except KeyboardInterrupt:
        logger.info(_t("KeyboardInterrupt"))
    except Exception as e:
        logger.exception("{}", e)
    finally:
        app.is_running = False
        if client is not None:
            if app.bot_token:
                # A failure here must not skip stopping the client, cancelling
                # workers, or saving the config below.
                try:
                    app.loop.run_until_complete(stop_download_bot())
                except Exception as e:
                    logger.warning(f"stop download bot error: {e}")
            try:
                app.loop.run_until_complete(stop_server(client))
            except Exception as e:
                logger.warning(f"stop server error: {e}")
        for task in tasks:
            task.cancel()
        logger.info(_t("Stopped!"))
        # check_for_updates(app.proxy)
        # On a restart from web configuration mode, do not overwrite the config
        # file the wizard just wrote with the empty in-memory credentials.
        if not (client is None and app.restart_program):
            logger.info(f"{_t('update config')}......")
            app.update_config()
            logger.success(
                f"{_t('Updated last read message_id to config file')},"
                f"{_t('total download')} {app.total_download_task}, "
                f"{_t('total upload file')} "
                f"{app.cloud_drive_config.total_upload_success_file_count}"
            )
    # Return whether restart is needed
    return app.restart_program
def run_with_restart():
    """Run main(); if it requested a restart, re-exec this process in place.

    os.execv replaces the current process image, so the restarted program
    begins from a fresh interpreter with the same executable and arguments.
    """
    import sys

    if not main():
        return
    logger.info("🔄 正在重启程序以应用新配置...")
    # Release the web server's port before the new process tries to bind it.
    shutdown_web()
    # Give the OS a moment to actually free the port.
    time.sleep(1)
    interpreter = sys.executable
    os.execv(interpreter, [interpreter, *sys.argv])
# Script entry point: only start when the config loaded successfully.
if __name__ == "__main__" and _check_config():
    run_with_restart()