bc926bd972
部署到群晖 / deploy (push) Successful in 47s
UI "正在下载" 数会随时间膨胀(如设置并发=5 却显示 54 个)的根因: _download_result 仅在用户手动「跳过」时被清理,下载失败/中断(3 次重试 都失败、_check_download_finish 大小不匹配抛 BadRequest 等)路径不会清, 导致中途失败的进度记录永久残留在内存字典里被 UI 误判为"还在下载"。 修复:在 download_task 末尾统一调用 remove_download_entry,无论成功/ 失败/跳过任务结束就清掉。该 dict 仅用于 UI/Bot 实时进度展示,不参与 下载触发或重试决策(重试由 ids_to_retry + worker 内部 3 次重试两层兜底), 清理后下载逻辑零影响。 Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
860 lines
28 KiB
Python
860 lines
28 KiB
Python
"""Downloads media from telegram."""
|
|
import asyncio
|
|
import logging
|
|
import os
|
|
import shutil
|
|
import time
|
|
from typing import List, Optional, Tuple, Union
|
|
|
|
import pyrogram
|
|
from loguru import logger
|
|
from pyrogram.types import Audio, Document, Photo, Video, VideoNote, Voice
|
|
from rich.logging import RichHandler
|
|
|
|
import module.database as db
|
|
from module.app import Application, ChatDownloadConfig, DownloadStatus, TaskNode
|
|
from module.bot import start_download_bot, stop_download_bot
|
|
from module.download_stat import (
|
|
update_download_status,
|
|
update_task_progress,
|
|
reset_task_progress,
|
|
snapshot_current_chat,
|
|
increment_task_stat,
|
|
get_task_progress,
|
|
is_message_skipped,
|
|
skip_message,
|
|
remove_download_entry,
|
|
clear_skipped_message,
|
|
)
|
|
from module.get_chat_history_v2 import get_chat_history_v2
|
|
from module.language import _t
|
|
from module.pyrogram_extension import (
|
|
HookClient,
|
|
fetch_message,
|
|
get_extension,
|
|
record_download_status,
|
|
report_bot_download_status,
|
|
set_max_concurrent_transmissions,
|
|
set_meta_data,
|
|
update_cloud_upload_stat,
|
|
upload_telegram_chat,
|
|
)
|
|
from module.web import init_web, shutdown_web
|
|
from utils.format import truncate_filename, validate_title
|
|
from utils.log import LogFilter
|
|
from utils.meta import print_meta
|
|
from utils.meta_data import MetaData
|
|
|
|
# Route stdlib logging (used by pyrogram) through Rich, printing only the
# message text with an [HH:MM:SS]-style time column.
logging.basicConfig(
    handlers=[RichHandler()],
    level=logging.INFO,
    datefmt="[%X]",
    format="%(message)s",
)

# File names / identifier handed to the Application instance.
CONFIG_NAME = "config.yaml"
DATA_FILE_NAME = "data.yaml"
APPLICATION_NAME = "media_downloader"
app = Application(CONFIG_NAME, DATA_FILE_NAME, APPLICATION_NAME)

# Shared work queue: add_download_task puts (message, node) pairs, worker
# coroutines take them off.
queue: asyncio.Queue = asyncio.Queue()
# Seconds to wait before retrying a failed download attempt.
RETRY_TIME_OUT = 3

# Attach the project's LogFilter to the chattier pyrogram loggers.
for _noisy_logger_name in ("pyrogram.session.session", "pyrogram.client"):
    logging.getLogger(_noisy_logger_name).addFilter(LogFilter())
del _noisy_logger_name

logging.getLogger("pyrogram").setLevel(logging.WARNING)
|
|
|
|
|
|
def _check_download_finish(media_size: int, download_path: str, ui_file_name: str):
    """Verify that a finished download has the size Telegram reported.

    Parameters
    ----------
    media_size: int
        Expected size of the resource in bytes.

    download_path: str
        Path of the file that was just downloaded.

    ui_file_name: str
        File name as it should appear in logs / the UI.

    Raises
    ------
    pyrogram.errors.exceptions.bad_request_400.BadRequest
        When the on-disk size differs from ``media_size``. The file is
        deleted before raising.
    """
    actual_size = os.path.getsize(download_path)
    if actual_size != media_size:
        logger.warning(
            f"{_t('Media downloaded with wrong size')}: "
            f"{actual_size}, {_t('actual')}: "
            f"{media_size}, {_t('file name')}: {ui_file_name}"
        )
        os.remove(download_path)
        raise pyrogram.errors.exceptions.bad_request_400.BadRequest()

    logger.success(f"{_t('Successfully downloaded')} - {ui_file_name}")
|
|
|
|
|
|
def _move_to_download_path(temp_download_path: str, download_path: str):
    """Move a finished temporary download to its final location.

    Parameters
    ----------
    temp_download_path: str
        Path of the completed file in the temporary download directory.

    download_path: str
        Final destination path. Missing parent directories are created.
    """
    directory = os.path.dirname(download_path)
    # os.makedirs("") raises FileNotFoundError, so only create a parent when
    # the destination actually has one (a bare file name targets the CWD).
    if directory:
        os.makedirs(directory, exist_ok=True)
    shutil.move(temp_download_path, download_path)
|
|
|
|
|
|
def _check_timeout(retry: int, _: int, max_retries: int = 3) -> bool:
    """Return True when ``retry`` is the last of ``max_retries`` attempts.

    Parameters
    ----------
    retry: int
        Zero-based index of the current download attempt.

    _: int
        Id of the message being downloaded. Unused; kept so existing
        ``_check_timeout(retry, message.id)`` call sites keep working.

    max_retries: int
        Total number of attempts. Defaults to 3, matching the
        ``range(3)`` retry loop in ``download_media``.

    Returns
    -------
    bool
        True if no further attempt will be made after this one.
    """
    return retry >= max_retries - 1
|
|
|
|
|
|
def _can_download(_type: str, file_formats: dict, file_format: Optional[str]) -> bool:
    """
    Check if the given file format can be downloaded.

    Only ``audio``, ``document`` and ``video`` media are filtered by format;
    every other media type is always allowed.

    Parameters
    ----------
    _type: str
        Type of media object.
    file_formats: dict
        Maps ``audio`` / ``document`` / ``video`` to the list of formats
        to download. A list whose first entry is ``"all"`` allows every
        format; an empty list allows none.
    file_format: Optional[str]
        Format of the current file to be downloaded.

    Returns
    -------
    bool
        True if the file format can be downloaded else False.

    Raises
    ------
    KeyError
        If ``_type`` is one of the filtered types but has no entry in
        ``file_formats``.
    """
    if _type not in ("audio", "document", "video"):
        return True

    allowed_formats: list = file_formats[_type]
    # Check emptiness before indexing: an empty list used to raise IndexError.
    if allowed_formats and allowed_formats[0] == "all":
        return True
    return file_format in allowed_formats
|
|
|
|
|
|
def _is_exist(file_path: str) -> bool:
    """
    Tell whether ``file_path`` names something that exists and is not a directory.

    Parameters
    ----------
    file_path: str
        Path to test.

    Returns
    -------
    bool
        False for directories and missing paths, True otherwise.
    """
    if os.path.isdir(file_path):
        return False
    return os.path.exists(file_path)
|
|
|
|
|
|
# pylint: disable = R0912
|
|
|
|
|
|
async def _get_media_meta(
    chat_id: Union[int, str],
    message: pyrogram.types.Message,
    media_obj: Union[Audio, Document, Photo, Video, VideoNote, Voice],
    _type: str,
) -> Tuple[str, str, Optional[str]]:
    """Work out the final path, temporary path and format of a media object.

    Parameters
    ----------
    chat_id: Union[int, str]
        Chat the message belongs to. Used as the directory name when the
        chat has no title, and as the key for caption bookkeeping.
    message: pyrogram.types.Message
        Message that carries ``media_obj``.
    media_obj: Union[Audio, Document, Photo, Video, VideoNote, Voice]
        Media object to be extracted.
    _type: str
        Type of media object (the ``Message`` attribute it was read from).

    Returns
    -------
    Tuple[str, str, Optional[str]]
        ``(file_name, temp_file_name, file_format)``. The first two are the
        final save path and the path under ``app.temp_save_path`` to
        download into, both passed through ``truncate_filename``.
        ``file_format`` is the MIME subtype for audio, document, video,
        voice and video_note media, and ``None`` for every other type.
    """
    if _type in ["audio", "document", "video"]:
        # pylint: disable = C0301
        file_format: Optional[str] = media_obj.mime_type.split("/")[-1]  # type: ignore
    else:
        file_format = None

    dirname = validate_title(f"{chat_id}")
    if message.chat and message.chat.title:
        dirname = validate_title(f"{message.chat.title}")

    if message.date:
        datetime_dir_name = message.date.strftime(app.date_format)
    else:
        datetime_dir_name = "0"

    if _type in ["voice", "video_note"]:
        # These media have no original file name: build one from the
        # message id, the media type and the media timestamp.
        # pylint: disable = C0209
        file_format = media_obj.mime_type.split("/")[-1]  # type: ignore
        file_save_path = app.get_file_save_path(_type, dirname, datetime_dir_name)
        file_name = "{} - {}_{}.{}".format(
            message.id,
            _type,
            media_obj.date.isoformat(),  # type: ignore
            file_format,
        )
        file_name = validate_title(file_name)
        temp_file_name = os.path.join(app.temp_save_path, dirname, file_name)

        file_name = os.path.join(file_save_path, file_name)
    else:
        file_name = getattr(media_obj, "file_name", None)
        caption = getattr(message, "caption", None)

        if file_name:
            # Keep only the base name (normpath + basename drops any directory
            # part) and split off its extension.
            base_name = os.path.basename(os.path.normpath(file_name))
            file_name, file_name_suffix = os.path.splitext(base_name)
        else:
            file_name_suffix = ""
        if not file_name_suffix:
            # No usable extension from the original name: derive one.
            file_name_suffix = get_extension(
                media_obj.file_id, getattr(media_obj, "mime_type", "")
            )

        # Remember the caption per media group so later members of the same
        # group (which carry no caption) reuse it.
        if caption:
            caption = validate_title(caption)
            app.set_caption_name(chat_id, message.media_group_id, caption)
            app.set_caption_entities(
                chat_id, message.media_group_id, message.caption_entities
            )
        else:
            caption = app.get_caption_name(chat_id, message.media_group_id)

        if not file_name and message.photo:
            file_name = f"{message.photo.file_unique_id}"

        gen_file_name = (
            app.get_file_name(message.id, file_name, caption) + file_name_suffix
        )

        file_save_path = app.get_file_save_path(_type, dirname, datetime_dir_name)

        temp_file_name = os.path.join(app.temp_save_path, dirname, gen_file_name)

        file_name = os.path.join(file_save_path, gen_file_name)
    return truncate_filename(file_name), truncate_filename(temp_file_name), file_format
|
|
|
|
|
|
async def add_download_task(
    message: pyrogram.types.Message,
    node: TaskNode,
):
    """Queue ``message`` for download on behalf of ``node``.

    Empty messages are ignored. Otherwise the message is marked as
    ``Downloading`` in ``node.download_status``, enqueued as a
    ``(message, node)`` pair for the workers, and ``node.total_task`` is
    incremented.

    Returns
    -------
    bool
        True if the message was queued, False if it was empty.
    """
    if not message.empty:
        node.download_status[message.id] = DownloadStatus.Downloading
        await queue.put((message, node))
        node.total_task += 1
        return True
    return False
|
|
|
|
|
|
async def save_msg_to_file(
    app, chat_id: Union[int, str], message: pyrogram.types.Message
):
    """Write the text of a message into a ``.txt`` file.

    The file is placed in ``app.get_file_save_path("msg", <chat dir>, <date dir>)``
    and named ``app.get_file_name(message.id, None, None) + ".txt"``. The
    chat directory is the sanitised chat title, falling back to ``chat_id``.
    The date directory is ``message.date`` formatted with
    ``app.date_format``, or ``"0"`` when the message has no date.

    Parameters
    ----------
    app:
        Application instance providing the path helpers (shadows the
        module-level ``app``).
    chat_id: Union[int, str]
        Chat the message belongs to.
    message: pyrogram.types.Message
        Message whose ``text`` is written (an empty file if it has none).

    Returns
    -------
    tuple
        ``(DownloadStatus.SkipDownload, None)`` if the file already exists,
        otherwise ``(DownloadStatus.SuccessDownload, file_path)``.
    """
    dirname = validate_title(
        message.chat.title if message.chat and message.chat.title else str(chat_id)
    )
    datetime_dir_name = message.date.strftime(app.date_format) if message.date else "0"

    file_save_path = app.get_file_save_path("msg", dirname, datetime_dir_name)
    # Write straight into the final save directory, which is where media
    # files end up too (_get_media_meta uses os.path.join(file_save_path, name)
    # as the final path). Prefixing app.temp_save_path was wrong either way:
    # os.path.join discards it for an absolute save path, and for a relative
    # one the text file landed in the temp directory and was never moved out.
    file_name = os.path.join(
        file_save_path,
        f"{app.get_file_name(message.id, None, None)}.txt",
    )

    os.makedirs(os.path.dirname(file_name), exist_ok=True)

    if _is_exist(file_name):
        return DownloadStatus.SkipDownload, None

    with open(file_name, "w", encoding="utf-8") as f:
        f.write(message.text or "")

    return DownloadStatus.SuccessDownload, file_name
|
|
|
|
|
|
async def download_task(
    client: pyrogram.Client, message: pyrogram.types.Message, node: TaskNode
):
    """Download one message's media, then forward/upload it and report status.

    Updates the live task counters and records the final status on ``node``
    (and through ``app.set_download_id`` for non-bot tasks). It then hands
    the result to ``upload_telegram_chat``. When no upload chat is
    configured, successful downloads are uploaded through ``app.upload_file``.
    Finally the outcome is reported via ``report_bot_download_status``.

    Whatever happens — success, failure, skip, or an exception raised by
    any of those steps — the message's entry in the real-time progress
    cache is removed. Without this, failed or interrupted downloads linger
    in ``_download_result`` and the UI keeps counting them as "downloading".
    Finished downloads are recorded in the database (``db.record_download``)
    and do not depend on that cache.

    Parameters
    ----------
    client: pyrogram.Client
        Client used to download the media.
    message: pyrogram.types.Message
        Message to process.
    node: TaskNode
        Task node the message belongs to.
    """
    try:
        increment_task_stat("downloading_files")
        try:
            download_status, file_name = await download_media(
                client, message, app.media_types, app.file_formats, node
            )
        finally:
            # Decrement even when download_media raises (e.g. fetch_message
            # fails before its own try block); otherwise the counter drifts up.
            increment_task_stat("downloading_files", -1)

        if download_status == DownloadStatus.SuccessDownload:
            increment_task_stat("completed_files")
        elif download_status == DownloadStatus.FailedDownload:
            increment_task_stat("failed_files")

        if app.enable_download_txt and message.text and not message.media:
            download_status, file_name = await save_msg_to_file(
                app, node.chat_id, message
            )

        if not node.bot:
            app.set_download_id(node, message.id, download_status)

        node.download_status[message.id] = download_status

        file_size = os.path.getsize(file_name) if file_name else 0

        await upload_telegram_chat(
            client,
            node.upload_user if node.upload_user else client,
            app,
            node,
            message,
            download_status,
            file_name,
        )

        # rclone upload
        if (
            not node.upload_telegram_chat_id
            and download_status is DownloadStatus.SuccessDownload
        ):
            ui_file_name = file_name
            if app.hide_file_name:
                ui_file_name = f"****{os.path.splitext(file_name)[-1]}"
            if await app.upload_file(
                file_name, update_cloud_upload_stat, (node, message.id, ui_file_name)
            ):
                node.upload_success_count += 1

        await report_bot_download_status(
            node.bot,
            node,
            download_status,
            file_size,
        )
    finally:
        # Key type matches what update_download_status writes (the raw
        # node.chat_id), as stated by the original author.
        remove_download_entry(node.chat_id, message.id)
|
|
|
|
|
|
# pylint: disable = R0915,R0914
|
|
|
|
|
|
@record_download_status
async def download_media(
    client: pyrogram.client.Client,
    message: pyrogram.types.Message,
    media_types: List[str],
    file_formats: dict,
    node: TaskNode,
):
    """
    Download media from Telegram.

    The message is re-fetched, and the first attribute listed in
    ``media_types`` that is present on it is selected. That media is
    downloaded into the temporary directory, size-checked, moved to its
    final path and recorded in the database. Up to 3 download attempts are
    made. After a BadRequest (e.g. an expired file reference, or a size
    mismatch reported by ``_check_download_finish``) the message is
    re-fetched after ``RETRY_TIME_OUT`` seconds. After a TypeError (treated
    as a timeout) the retry waits ``RETRY_TIME_OUT`` seconds. After a
    FloodWait it waits the duration Telegram requests. Any other exception
    aborts the retries.

    Parameters
    ----------
    client: pyrogram.client.Client
        Client to interact with Telegram APIs.
    message: pyrogram.types.Message
        Message object retrieved from telegram.
    media_types: list
        List of strings of media types to be downloaded.
        Ex : `["audio", "photo"]`
        Supported formats:
            * audio
            * document
            * photo
            * video
            * voice
    file_formats: dict
        Dictionary containing the list of file_formats
        to be downloaded for `audio`, `document` & `video`
        media types.
    node: TaskNode
        Task node the message belongs to.

    Returns
    -------
    Tuple[DownloadStatus, Optional[str]]
        Value returned by this function body (before ``record_download_status``
        wraps it):
        ``(SuccessDownload, final_path)`` on success;
        ``(SkipDownload, None)`` when the message has no matching media,
        the format is not allowed, the file already exists or the database
        says to skip it, or the user skipped it during download;
        ``(FailedDownload, None)`` otherwise.
    """

    # pylint: disable = R0912

    file_name: str = ""
    temp_file_name: str = ""
    ui_file_name: str = ""
    task_start_time: float = time.time()
    media_size = 0
    _media = None
    message = await fetch_message(client, message)
    try:
        for _type in media_types:
            _media = getattr(message, _type, None)
            if _media is None:
                continue
            file_name, temp_file_name, file_format = await _get_media_meta(
                node.chat_id, message, _media, _type
            )
            media_size = getattr(_media, "file_size", 0)

            ui_file_name = file_name
            if app.hide_file_name:
                ui_file_name = f"****{os.path.splitext(file_name)[-1]}"

            if _can_download(_type, file_formats, file_format):
                if _is_exist(file_name):
                    file_size = os.path.getsize(file_name)
                    # Any non-empty existing file counts as already downloaded.
                    if file_size or file_size == media_size:
                        logger.info(
                            f"id={message.id} {ui_file_name} "
                            f"{_t('already download,download skipped')}.\n"
                        )
                        # Update skip counter
                        increment_task_stat("skipped_files")
                        # "Already downloaded, found during this task" skips are
                        # counted separately so the frontend can compute its
                        # denominator.
                        increment_task_stat("existing_skipped")
                        increment_task_stat("checked_messages")
                        return DownloadStatus.SkipDownload, None
                else:
                    _should_skip, _reason = db.should_skip(
                        str(node.chat_id), message.id
                    )
                    if _should_skip:
                        logger.info(
                            f"id={message.id} {ui_file_name} {_reason},跳过。\n"
                        )
                        increment_task_stat("skipped_files")
                        # A skip marked in the db also counts as "skipped in this task".
                        increment_task_stat("existing_skipped")
                        increment_task_stat("checked_messages")
                        return DownloadStatus.SkipDownload, None
            else:
                increment_task_stat("checked_messages")
                return DownloadStatus.SkipDownload, None

            break
    except Exception as e:
        # loguru ignores stdlib's exc_info=...; logger.exception attaches the
        # traceback and, with no args, does not str.format the message (which
        # would break on braces inside the exception text).
        logger.exception(
            f"Message[{message.id}]: "
            f"{_t('could not be downloaded due to following exception')}:\n[{e}]."
        )
        return DownloadStatus.FailedDownload, None

    if _media is None:
        return DownloadStatus.SkipDownload, None

    message_id = message.id

    for retry in range(3):
        # The user may skip this message (UI/bot) between attempts.
        if is_message_skipped(str(node.chat_id), message_id):
            clear_skipped_message(str(node.chat_id), message_id)
            remove_download_entry(str(node.chat_id), message_id)
            return DownloadStatus.SkipDownload, None
        try:
            temp_download_path = await client.download_media(
                message,
                file_name=temp_file_name,
                progress=update_download_status,
                progress_args=(
                    message_id,
                    ui_file_name,
                    task_start_time,
                    node,
                    client,
                ),
            )

            if temp_download_path and isinstance(temp_download_path, str):
                # Raises BadRequest (after deleting the file) on a size
                # mismatch, which routes into the retry handler below.
                _check_download_finish(media_size, temp_download_path, ui_file_name)
                await asyncio.sleep(0.5)
                _move_to_download_path(temp_download_path, file_name)
                chat_title = ""
                if message.chat and message.chat.title:
                    chat_title = message.chat.title
                db.record_download(
                    chat_id=str(node.chat_id),
                    chat_title=chat_title,
                    message_id=message.id,
                    file_name=os.path.basename(file_name),
                    file_path=file_name,
                    file_size=media_size,
                    media_type=_type,
                    status="success",
                )
                # TODO: if not exist file size or media
                return DownloadStatus.SuccessDownload, file_name
        except pyrogram.errors.exceptions.bad_request_400.BadRequest:
            logger.warning(
                f"Message[{message.id}]: {_t('file reference expired, refetching')}..."
            )
            await asyncio.sleep(RETRY_TIME_OUT)
            message = await fetch_message(client, message)
            if _check_timeout(retry, message.id):
                # pylint: disable = C0301
                logger.error(
                    f"Message[{message.id}]: "
                    f"{_t('file reference expired for 3 retries, download skipped.')}"
                )
        except pyrogram.errors.exceptions.flood_420.FloodWait as wait_err:
            # Log before sleeping so the wait is visible while it happens.
            logger.warning("Message[{}]: FloodWait {}", message.id, wait_err.value)
            await asyncio.sleep(wait_err.value)
            if _check_timeout(retry, message.id):
                logger.error(
                    "Message[{}]: FloodWait on final retry, download failed.",
                    message.id,
                )
        except TypeError:
            # pylint: disable = C0301
            logger.warning(
                f"{_t('Timeout Error occurred when downloading Message')}[{message.id}], "
                f"{_t('retrying after')} {RETRY_TIME_OUT} {_t('seconds')}"
            )
            await asyncio.sleep(RETRY_TIME_OUT)
            if _check_timeout(retry, message.id):
                logger.error(
                    f"Message[{message.id}]: {_t('Timing out after 3 reties, download skipped.')}"
                )
        except Exception as e:
            # pylint: disable = C0301
            logger.exception(
                f"Message[{message.id}]: "
                f"{_t('could not be downloaded due to following exception')}:\n[{e}]."
            )
            break

    return DownloadStatus.FailedDownload, None
|
|
|
|
|
|
def _load_config():
    """Load configuration into the module-level ``app`` via ``app.load_config()``."""
    loader = app.load_config
    loader()
|
|
|
|
|
|
def _check_config() -> bool:
    """Print program metadata, load the config and add the rotating log file.

    Returns
    -------
    bool
        False if loading the config or adding the ``tdl.log`` sink raised
        (the error is logged with its traceback), True otherwise.
    """
    print_meta(logger)
    try:
        _load_config()
        log_file = os.path.join(app.log_file_path, "tdl.log")
        logger.add(
            log_file,
            rotation="10 MB",
            retention="10 days",
            level=app.log_level,
        )
    except Exception as e:
        logger.exception(f"load config error: {e}")
        return False
    else:
        return True
|
|
|
|
|
|
async def worker(client: pyrogram.client.Client):
    """Consume ``(message, node)`` pairs from the queue while the app is running.

    Items whose node has stopped transmitting are dropped. Items the user
    marked as skipped are passed to ``skip_message`` and dropped. Everything
    else is downloaded with ``node.client`` when the node has one, falling
    back to ``client`` otherwise. Exceptions are logged so one bad item
    never kills the worker.
    """
    while app.is_running:
        try:
            item = await queue.get()
            message, node = item[0], item[1]

            if node.is_stop_transmission:
                continue

            chat_key = str(node.chat_id)
            if is_message_skipped(chat_key, message.id):
                skip_message(chat_key, message.id)
                continue

            await download_task(node.client or client, message, node)
        except Exception as e:
            logger.exception(f"{e}")
|
|
|
|
|
|
async def download_chat_task(
    client: pyrogram.Client,
    chat_download_config: ChatDownloadConfig,
    node: TaskNode,
):
    """Scan one chat's history and queue the messages that should be downloaded.

    Steps:
    1. Snapshot the previous chat's progress and reset the counters.
    2. Publish this chat's title to the progress banner.
    3. Seed the expected total from the scan cache.
    4. Re-queue ``chat_download_config.ids_to_retry``.
    5. Walk the history from ``last_read_message_id`` (``reverse=True``),
       applying ``app.need_skip_message`` and ``app.exec_filter``. Messages
       that pass the filter are queued. The rest are marked ``SkipDownload``;
       members of a media group are still handed to ``upload_telegram_chat``.

    Only when the walk completes normally is the number of qualified files
    written back to the scan cache.

    Parameters
    ----------
    client: pyrogram.Client
        Client used to read the chat.
    chat_download_config: ChatDownloadConfig
        Per-chat settings (filter, resume point, ids to retry).
    node: TaskNode
        Task node that collects this chat's download tasks.
    """
    # Before switching to this chat, snapshot the previous chat's final
    # progress into _completed_chats.
    snapshot_current_chat()
    # Reset and update task progress
    reset_task_progress()

    # Try to get chat title
    try:
        chat = await client.get_chat(node.chat_id)
        chat_title = chat.title or chat.first_name or str(node.chat_id)
    except Exception:
        # The title is only for display; fall back to the chat id.
        chat_title = str(node.chat_id)

    update_task_progress(
        current_chat=str(node.chat_id),
        current_chat_title=chat_title,
        is_checking=True,
    )

    # Change A: try the pre-scan cache; on a hit the banner immediately
    # shows 0 / N.
    filter_key = db.build_filter_key(
        chat_download_config.download_filter,
        app.media_types,
        app.file_formats,
    )
    cached_total = db.get_scan_cache(str(node.chat_id), filter_key)
    if cached_total is not None:
        update_task_progress(estimated_total=cached_total)

    messages_iter = get_chat_history_v2(
        client,
        node.chat_id,
        limit=node.limit,
        max_id=node.end_offset_id,
        offset_id=chat_download_config.last_read_message_id,
        reverse=True,
    )

    chat_download_config.node = node

    if chat_download_config.ids_to_retry:
        logger.info(f"{_t('Downloading files failed during last run')}...")
        retry_messages: list = await client.get_messages(  # type: ignore
            chat_id=node.chat_id, message_ids=chat_download_config.ids_to_retry
        )

        for message in retry_messages:
            await add_download_task(message, node)

    async for message in messages_iter:  # type: ignore
        # Update checking progress for each message
        increment_task_stat("checked_messages")

        meta_data = MetaData()

        caption = message.caption
        if caption:
            caption = validate_title(caption)
            app.set_caption_name(node.chat_id, message.media_group_id, caption)
            app.set_caption_entities(
                node.chat_id, message.media_group_id, message.caption_entities
            )
        else:
            caption = app.get_caption_name(node.chat_id, message.media_group_id)
        set_meta_data(meta_data, message, caption)

        if app.need_skip_message(chat_download_config, message.id):
            continue

        if app.exec_filter(chat_download_config, meta_data):
            # Change B: count messages that pass the filter; this is the
            # live-growing denominator N of X / N.
            increment_task_stat("qualified_files")
            await add_download_task(message, node)
        else:
            node.download_status[message.id] = DownloadStatus.SkipDownload
            increment_task_stat("skipped_files")
            if message.media_group_id:
                # Same fallback as download_task: use the main client when
                # no dedicated upload user is configured.
                await upload_telegram_chat(
                    client,
                    node.upload_user if node.upload_user else client,
                    app,
                    node,
                    message,
                    DownloadStatus.SkipDownload,
                )

    chat_download_config.need_check = True
    chat_download_config.total_task = node.total_task
    node.is_running = True

    # Change C: the walk finished normally, so store the real qualified_files
    # count in the cache and use it as estimated_total. This only runs on
    # normal completion (not on exception/interruption), so the cache never
    # holds a partial count.
    actual_total = get_task_progress().get("qualified_files", 0)
    db.save_scan_cache(str(node.chat_id), filter_key, actual_total)
    update_task_progress(estimated_total=actual_total, is_checking=False)
|
|
|
|
|
|
async def download_all_chat(client: pyrogram.Client):
    """Run ``download_chat_task`` for every configured chat, one after another.

    Each chat gets a fresh ``TaskNode``. An error in one chat is logged as a
    warning and does not stop the others, and ``need_check`` is set on every
    chat whether its task succeeded or not.
    """
    # Iterate over a snapshot so the config dict may change while we await
    # without raising "dictionary changed size during iteration".
    config_snapshot = list(app.chat_download_config.items())
    for chat_id, chat_config in config_snapshot:
        chat_config.node = TaskNode(chat_id=chat_id)
        try:
            await download_chat_task(client, chat_config, chat_config.node)
        except Exception as e:
            logger.warning(f"Download {chat_id} error: {e}")
        finally:
            chat_config.need_check = True

    # After all chats have run, snapshot the last one as well. Otherwise it
    # would stay only in _task_progress, never reach _completed_chats, and
    # the frontend would keep showing it as "downloading".
    snapshot_current_chat()
    # The snapshot only copies _task_progress. Reset it, or the frontend
    # would count the last chat's skipped_files both in the completed_chats
    # total and in the current value.
    reset_task_progress()
|
|
|
|
|
|
async def run_until_all_task_finish():
    """Keep the event loop busy until a program restart is requested.

    Polls ``app.restart_program`` once per second and returns as soon as it
    is truthy. Finishing the download tasks does not end this loop, so the
    process keeps running after every chat has been processed.
    """
    # The previous version also recomputed a per-chat "finish" flag every
    # second but never read it; that dead scan of app.chat_download_config
    # is dropped.
    while not app.restart_program:
        await asyncio.sleep(1)
|
|
|
|
|
|
def _exec_loop():
    """Block on ``app.loop`` until ``run_until_all_task_finish`` returns."""
    event_loop = app.loop
    event_loop.run_until_complete(run_until_all_task_finish())
|
|
|
|
|
|
async def start_server(client: pyrogram.Client):
    """
    Start ``client`` by awaiting its ``start()`` coroutine.
    """
    start = client.start
    await start()
|
|
|
|
|
|
async def stop_server(client: pyrogram.Client):
    """
    Stop ``client`` by awaiting its ``stop()`` coroutine.
    """
    stop = client.stop
    await stop()
|
|
|
|
|
|
def main():
    """Main function of the downloader.

    With API credentials, main:
    - builds the Telegram client, initialises the database and web UI;
    - starts the client, the chat scanner, ``app.max_download_task``
      workers and (optionally) the bot;
    - blocks until a restart is requested.

    Without credentials it starts in web-configuration mode (no client) and
    just waits for the web wizard to request a restart.

    On exit it stops the bot and client, cancels the background tasks, and
    writes the config back. Writing the config is skipped when restarting
    out of web-configuration mode.

    Returns
    -------
    The value of ``app.restart_program``: whether the caller should restart
    the process.
    """
    tasks = []

    # Without API credentials, start in "web configuration mode" and skip
    # creating the Telegram client.
    if app.is_configured():
        client = HookClient(
            "media_downloader",
            api_id=app.api_id,
            api_hash=app.api_hash,
            proxy=app.proxy,
            workdir=app.session_file_path,
            start_timeout=app.start_timeout,
        )
    else:
        client = None
        logger.info(
            "未检测到 API 凭证,以 Web 配置模式启动。"
            f"请访问 http://0.0.0.0:{app.web_port} 完成初始设置。"
        )

    try:
        app.pre_run()
        db.init_db(os.path.join(os.path.abspath("."), "appdata", "downloads.db"))
        init_web(app, client, add_download_task, download_chat_task)

        if client is not None:
            set_max_concurrent_transmissions(client, app.max_concurrent_transmissions)
            app.loop.run_until_complete(start_server(client))
            logger.success(_t("Successfully started (Press Ctrl+C to stop)"))

            # Keep a reference: the event loop holds only weak references to
            # tasks, so an unreferenced task may be garbage-collected before it
            # finishes. Keeping it in `tasks` also cancels it on shutdown.
            tasks.append(app.loop.create_task(download_all_chat(client)))
            for _ in range(app.max_download_task):
                tasks.append(app.loop.create_task(worker(client)))

            if app.bot_token:
                app.loop.run_until_complete(
                    start_download_bot(app, client, add_download_task, download_chat_task)
                )
            _exec_loop()
        else:
            # Web configuration mode: keep the process alive until the user
            # finishes the web wizard and a restart is requested.
            async def _wait_for_restart():
                while not app.restart_program:
                    await asyncio.sleep(1)

            app.loop.run_until_complete(_wait_for_restart())
    except KeyboardInterrupt:
        logger.info(_t("KeyboardInterrupt"))
    except Exception as e:
        logger.exception("{}", e)
    finally:
        app.is_running = False
        if client is not None:
            if app.bot_token:
                # A failing bot shutdown must not prevent the client from stopping.
                try:
                    app.loop.run_until_complete(stop_download_bot())
                except Exception as stop_err:
                    logger.warning("stop download bot error: {}", stop_err)
            try:
                app.loop.run_until_complete(stop_server(client))
            except Exception as stop_err:
                # Best effort during shutdown, but don't hide the failure.
                logger.warning("stop client error: {}", stop_err)
        for task in tasks:
            task.cancel()
        logger.info(_t("Stopped!"))
        # check_for_updates(app.proxy)
        # When restarting out of web configuration mode, don't overwrite the
        # config file the wizard just wrote with the empty in-memory credentials.
        if not (client is None and app.restart_program):
            logger.info(f"{_t('update config')}......")
            app.update_config()
            logger.success(
                f"{_t('Updated last read message_id to config file')},"
                f"{_t('total download')} {app.total_download_task}, "
                f"{_t('total upload file')} "
                f"{app.cloud_drive_config.total_upload_success_file_count}"
            )

    # Return whether restart is needed
    return app.restart_program
|
|
|
|
|
|
def run_with_restart():
    """Run ``main()``; if it requests a restart, replace this process with a fresh one.

    The web server is shut down first, and the function sleeps one second
    before ``os.execv`` re-runs the same interpreter with the same argv.
    """
    import sys

    if not main():
        return

    logger.info("🔄 正在重启程序以应用新配置...")
    # Stop the web server so the new process can bind the same port.
    shutdown_web()
    # Give the OS a moment to release the port.
    time.sleep(1)
    # Replace the current process image with a new interpreter run.
    interpreter = sys.executable
    os.execv(interpreter, [interpreter, *sys.argv])
|
|
|
|
|
|
# Script entry point: only run when the configuration loads cleanly.
if __name__ == "__main__" and _check_config():
    run_with_restart()
|