Files
telegram-downloader/media_downloader.py
T
yuming 53d3ab7769
部署到群晖 / deploy (push) Failing after 34s
fix: 网络断开后任务永久卡在"下载中"
三处协同修复:
1) worker 异常分支补 _release_stuck_task,
   标记 FailedDownload、推进 finish_task、清残留。
2) download_media 中 fetch_message 加 try-except,
   连接异常返回 FailedDownload,不再让异常冒泡。
3) download_chat_task 用 try/finally 兜底回写
   chat_download_config.total_task,避免 _wait_node_finish
   误判频道已完成而切到下一个。

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-11 22:46:05 +08:00

932 lines
31 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Downloads media from telegram."""
import asyncio
import logging
import os
import shutil
import time
from typing import List, Optional, Tuple, Union
import pyrogram
from loguru import logger
from pyrogram.types import Audio, Document, Photo, Video, VideoNote, Voice
from rich.logging import RichHandler
import module.database as db
from module.app import Application, ChatDownloadConfig, DownloadStatus, TaskNode
from module.bot import start_download_bot, stop_download_bot
from module.download_stat import (
update_download_status,
update_task_progress,
reset_task_progress,
snapshot_current_chat,
increment_task_stat,
get_task_progress,
is_message_skipped,
skip_message,
remove_download_entry,
clear_skipped_message,
)
from module.get_chat_history_v2 import get_chat_history_v2
from module.language import _t
from module.pyrogram_extension import (
HookClient,
fetch_message,
get_extension,
record_download_status,
report_bot_download_status,
set_max_concurrent_transmissions,
set_meta_data,
update_cloud_upload_stat,
upload_telegram_chat,
)
from module.web import init_web, shutdown_web
from utils.format import truncate_filename, validate_title
from utils.log import LogFilter
from utils.meta import print_meta
from utils.meta_data import MetaData
# Route stdlib `logging` (used by pyrogram) through rich's console handler;
# the application's own messages go through loguru's `logger` imported above.
logging.basicConfig(
level=logging.INFO,
format="%(message)s",
datefmt="[%X]",
handlers=[RichHandler()],
)
CONFIG_NAME = "config.yaml"
DATA_FILE_NAME = "data.yaml"
APPLICATION_NAME = "media_downloader"
# Process-wide application object: config, per-chat download configs and the
# event loop (`app.loop`) used by main(); shared by every function below.
app = Application(CONFIG_NAME, DATA_FILE_NAME, APPLICATION_NAME)
# Download queue of (message, TaskNode) tuples: filled by add_download_task,
# drained by the worker() coroutines.
queue: asyncio.Queue = asyncio.Queue()
# Seconds to sleep before retrying after an expired file reference or a timeout.
RETRY_TIME_OUT = 3
# Attach the project's LogFilter (rules live in utils.log) to pyrogram's
# session/client loggers, and only show WARNING+ from pyrogram itself.
logging.getLogger("pyrogram.session.session").addFilter(LogFilter())
logging.getLogger("pyrogram.client").addFilter(LogFilter())
logging.getLogger("pyrogram").setLevel(logging.WARNING)
def _check_download_finish(media_size: int, download_path: str, ui_file_name: str):
    """Verify that a finished download has the size Telegram reported.

    Parameters
    ----------
    media_size: int
        Expected size in bytes, taken from the media object.
    download_path: str
        Path of the file that was just downloaded.
    ui_file_name: str
        File name used in log output (may be masked).

    Raises
    ------
    pyrogram.errors.exceptions.bad_request_400.BadRequest
        When the on-disk size differs. The file is deleted first so the
        caller's retry loop starts over from scratch.
    """
    actual_size = os.path.getsize(download_path)
    if actual_size != media_size:
        logger.warning(
            f"{_t('Media downloaded with wrong size')}: "
            f"{actual_size}, {_t('actual')}: "
            f"{media_size}, {_t('file name')}: {ui_file_name}"
        )
        os.remove(download_path)
        # Reuse the BadRequest path so download_media re-fetches and retries.
        raise pyrogram.errors.exceptions.bad_request_400.BadRequest()
    logger.success(f"{_t('Successfully downloaded')} - {ui_file_name}")
def _move_to_download_path(temp_download_path: str, download_path: str):
"""Move file to download path
Parameters
----------
temp_download_path: str
Temporary download path
download_path: str
Download path
"""
directory, _ = os.path.split(download_path)
os.makedirs(directory, exist_ok=True)
shutil.move(temp_download_path, download_path)
def _check_timeout(retry: int, _: int):
"""Check if message download timeout, then add message id into failed_ids
Parameters
----------
retry: int
Retry download message times
message_id: int
Try to download message 's id
"""
if retry == 2:
return True
return False
def _can_download(_type: str, file_formats: dict, file_format: Optional[str]) -> bool:
"""
Check if the given file format can be downloaded.
Parameters
----------
_type: str
Type of media object.
file_formats: dict
Dictionary containing the list of file_formats
to be downloaded for `audio`, `document` & `video`
media types
file_format: str
Format of the current file to be downloaded.
Returns
-------
bool
True if the file format can be downloaded else False.
"""
if _type in ["audio", "document", "video"]:
allowed_formats: list = file_formats[_type]
if not file_format in allowed_formats and allowed_formats[0] != "all":
return False
return True
def _is_exist(file_path: str) -> bool:
"""
Check if a file exists and it is not a directory.
Parameters
----------
file_path: str
Absolute path of the file to be checked.
Returns
-------
bool
True if the file exists else False.
"""
return not os.path.isdir(file_path) and os.path.exists(file_path)
# pylint: disable = R0912
async def _get_media_meta(
    chat_id: Union[int, str],
    message: pyrogram.types.Message,
    media_obj: Union[Audio, Document, Photo, Video, VideoNote, Voice],
    _type: str,
) -> Tuple[str, str, Optional[str]]:
    """Build the final and temporary save paths for a media object.

    For messages with a caption, the caption is also stored on ``app`` per
    media group (``set_caption_name`` / ``set_caption_entities``). For
    messages without one, the stored caption of the media group is reused.

    Parameters
    ----------
    chat_id: Union[int, str]
        Chat id; used as the directory name when the chat has no title.
    message: pyrogram.types.Message
        Message carrying the media (date, caption, media group, chat title).
    media_obj: Union[Audio, Document, Photo, Video, VideoNote, Voice]
        Media object to be described.
    _type: str
        Media type name (e.g. ``"audio"``, ``"photo"``, ``"voice"``).

    Returns
    -------
    Tuple[str, str, Optional[str]]
        ``(file_name, temp_file_name, file_format)``:
        - ``file_name`` is the final save path.
        - ``temp_file_name`` is the path under ``app.temp_save_path`` to
          download into. Both paths go through ``truncate_filename``.
        - ``file_format`` is the MIME subtype for audio/document/video/voice/
          video_note media, or ``None`` for other types.
    """
    if _type in ["audio", "document", "video"]:
        # pylint: disable = C0301
        file_format: Optional[str] = media_obj.mime_type.split("/")[-1]  # type: ignore
    else:
        file_format = None

    file_name = None
    temp_file_name = None
    dirname = validate_title(f"{chat_id}")
    if message.chat and message.chat.title:
        dirname = validate_title(f"{message.chat.title}")

    if message.date:
        datetime_dir_name = message.date.strftime(app.date_format)
    else:
        datetime_dir_name = "0"

    if _type in ["voice", "video_note"]:
        # These media types carry no file name:
        # synthesize "<message id> - <type>_<media date>.<subtype>".
        # pylint: disable = C0209
        file_format = media_obj.mime_type.split("/")[-1]  # type: ignore
        file_save_path = app.get_file_save_path(_type, dirname, datetime_dir_name)
        file_name = "{} - {}_{}.{}".format(
            message.id,
            _type,
            media_obj.date.isoformat(),  # type: ignore
            file_format,
        )
        file_name = validate_title(file_name)
        temp_file_name = os.path.join(app.temp_save_path, dirname, file_name)
        file_name = os.path.join(file_save_path, file_name)
    else:
        file_name = getattr(media_obj, "file_name", None)
        caption = getattr(message, "caption", None)
        file_name_suffix = ".unknown"
        if not file_name:
            file_name_suffix = get_extension(
                media_obj.file_id, getattr(media_obj, "mime_type", "")
            )
        else:
            # Keep only the base name (drop any directory part), then split
            # off the extension; fall back to a MIME-derived one if missing.
            _, file_name_without_suffix = os.path.split(os.path.normpath(file_name))
            file_name, file_name_suffix = os.path.splitext(file_name_without_suffix)
            if not file_name_suffix:
                file_name_suffix = get_extension(
                    media_obj.file_id, getattr(media_obj, "mime_type", "")
                )

        if caption:
            caption = validate_title(caption)
            app.set_caption_name(chat_id, message.media_group_id, caption)
            app.set_caption_entities(
                chat_id, message.media_group_id, message.caption_entities
            )
        else:
            caption = app.get_caption_name(chat_id, message.media_group_id)

        if not file_name and message.photo:
            file_name = f"{message.photo.file_unique_id}"

        gen_file_name = (
            app.get_file_name(message.id, file_name, caption) + file_name_suffix
        )

        file_save_path = app.get_file_save_path(_type, dirname, datetime_dir_name)

        temp_file_name = os.path.join(app.temp_save_path, dirname, gen_file_name)

        file_name = os.path.join(file_save_path, gen_file_name)
    return truncate_filename(file_name), truncate_filename(temp_file_name), file_format
async def add_download_task(
    message: pyrogram.types.Message,
    node: TaskNode,
):
    """Queue ``message`` for download by a worker.

    Marks the message as ``DownloadStatus.Downloading`` on ``node`` and bumps
    ``node.total_task`` once the item is on the queue.

    Returns
    -------
    bool
        False for an empty message (nothing queued), True otherwise.
    """
    if not message.empty:
        node.download_status[message.id] = DownloadStatus.Downloading
        await queue.put((message, node))
        node.total_task += 1
        return True
    return False
async def save_msg_to_file(
    app, chat_id: Union[int, str], message: pyrogram.types.Message
):
    """Persist the text of ``message`` as a ``.txt`` file.

    Returns
    -------
    tuple
        ``(DownloadStatus.SkipDownload, None)`` if the target file already
        exists, otherwise ``(DownloadStatus.SuccessDownload, <path written>)``.
    """
    use_title = message.chat and message.chat.title
    dirname = validate_title(message.chat.title if use_title else str(chat_id))
    if message.date:
        datetime_dir_name = message.date.strftime(app.date_format)
    else:
        datetime_dir_name = "0"

    save_dir = app.get_file_save_path("msg", dirname, datetime_dir_name)
    base_name = app.get_file_name(message.id, None, None)
    # NOTE(review): the path is joined under app.temp_save_path; if save_dir is
    # absolute, os.path.join discards the temp prefix. Confirm this is intended.
    txt_path = os.path.join(app.temp_save_path, save_dir, f"{base_name}.txt")

    os.makedirs(os.path.dirname(txt_path), exist_ok=True)
    if _is_exist(txt_path):
        return DownloadStatus.SkipDownload, None

    with open(txt_path, "w", encoding="utf-8") as out:
        out.write(message.text or "")
    return DownloadStatus.SuccessDownload, txt_path
async def download_task(
    client: pyrogram.Client, message: pyrogram.types.Message, node: TaskNode
):
    """Download one queued message, record its result and run post-download work.

    Steps:
    1. Download the media, or save the plain text when
       ``app.enable_download_txt`` is set and the message has no media.
    2. Record the final status on ``node`` (and via ``app.set_download_id``
       outside bot mode).
    3. Forward to a Telegram chat, or upload through ``app.upload_file`` when
       no upload chat is configured.
    4. Report to the bot.
    5. Drop the message's live-progress entry.

    If ``download_media`` itself raises, the exception propagates to the
    worker (which releases the task). Failures after the status is recorded
    are logged here instead.
    """
    # Track download start
    increment_task_stat("downloading_files")
    try:
        download_status, file_name = await download_media(
            client, message, app.media_types, app.file_formats, node
        )
    finally:
        # Undo the in-flight counter even when download_media raises (e.g. a
        # connection error while re-fetching the message); otherwise the UI's
        # "downloading" count would only ever grow.
        increment_task_stat("downloading_files", -1)

    if download_status == DownloadStatus.SuccessDownload:
        increment_task_stat("completed_files")
    elif download_status == DownloadStatus.FailedDownload:
        increment_task_stat("failed_files")

    if app.enable_download_txt and message.text and not message.media:
        download_status, file_name = await save_msg_to_file(app, node.chat_id, message)

    if not node.bot:
        app.set_download_id(node, message.id, download_status)
    node.download_status[message.id] = download_status

    try:
        file_size = os.path.getsize(file_name) if file_name else 0

        await upload_telegram_chat(
            client,
            node.upload_user if node.upload_user else client,
            app,
            node,
            message,
            download_status,
            file_name,
        )

        # rclone upload
        if (
            not node.upload_telegram_chat_id
            and download_status is DownloadStatus.SuccessDownload
        ):
            ui_file_name = file_name
            if app.hide_file_name:
                ui_file_name = f"****{os.path.splitext(file_name)[-1]}"
            if await app.upload_file(
                file_name, update_cloud_upload_stat, (node, message.id, ui_file_name)
            ):
                node.upload_success_count += 1

        await report_bot_download_status(
            node.bot,
            node,
            download_status,
            file_size,
        )
    except Exception as err:  # pylint: disable = W0703
        # The status is already recorded above. Letting this escape would make
        # worker() call _release_stuck_task, which calls app.set_download_id a
        # second time for this message and overwrites the status with Failed.
        logger.exception(f"Message[{message.id}]: post-download step failed: {err}")
    finally:
        # Always clear the live-progress cache once the task ends (success,
        # failure or skip) so a stale entry can't make the UI show it as
        # still downloading. Finished history lives in the database
        # (db.record_download), not in this dict. The key uses node.chat_id in
        # its original type, matching how update_download_status writes it.
        remove_download_entry(node.chat_id, message.id)
# pylint: disable = R0915,R0914
@record_download_status
async def download_media(
    client: pyrogram.client.Client,
    message: pyrogram.types.Message,
    media_types: List[str],
    file_formats: dict,
    node: TaskNode,
):
    """
    Download media from Telegram.

    Each file is attempted up to 3 times. After an expired file reference or
    a timeout, the next attempt waits ``RETRY_TIME_OUT`` seconds; after a
    FloodWait, it waits the time Telegram asks for.

    Parameters
    ----------
    client: pyrogram.client.Client
        Client to interact with Telegram APIs.
    message: pyrogram.types.Message
        Message object retrieved from telegram.
    media_types: list
        List of strings of media types to be downloaded.
        Ex : `["audio", "photo"]`
        Supported formats:
            * audio
            * document
            * photo
            * video
            * voice
    file_formats: dict
        Dictionary containing the list of file_formats
        to be downloaded for `audio`, `document` & `video`
        media types.
    node: TaskNode
        Task node the message belongs to; supplies ``chat_id`` and is passed
        to the progress callback.

    Returns
    -------
    Tuple[DownloadStatus, Optional[str]]
        ``(SuccessDownload, final_path)`` on success.
        ``(SkipDownload, None)`` in any of these cases:
        - no wanted media;
        - format filtered out;
        - file already present;
        - database says skip;
        - user skipped it.
        ``(FailedDownload, None)`` otherwise, including network errors while
        fetching or re-fetching the message.
    """
    # pylint: disable = R0912
    file_name: str = ""
    ui_file_name: str = ""
    task_start_time: float = time.time()
    media_size = 0
    _media = None

    # fetch_message is the first network call. On a dropped connection it must
    # not propagate, otherwise the worker only logs it and the status stays
    # stuck at Downloading.
    try:
        message = await fetch_message(client, message)
    except Exception as fetch_err:
        logger.warning(
            f"Message[{getattr(message, 'id', '?')}] 拉取消息失败(可能连接断开): {fetch_err}"
        )
        return DownloadStatus.FailedDownload, None

    try:
        for _type in media_types:
            _media = getattr(message, _type, None)
            if _media is None:
                continue
            file_name, temp_file_name, file_format = await _get_media_meta(
                node.chat_id, message, _media, _type
            )
            media_size = getattr(_media, "file_size", 0)

            ui_file_name = file_name
            if app.hide_file_name:
                ui_file_name = f"****{os.path.splitext(file_name)[-1]}"

            if _can_download(_type, file_formats, file_format):
                if _is_exist(file_name):
                    file_size = os.path.getsize(file_name)
                    if file_size or file_size == media_size:
                        logger.info(
                            f"id={message.id} {ui_file_name} "
                            f"{_t('already download,download skipped')}.\n"
                        )
                        # Update skip counter
                        increment_task_stat("skipped_files")
                        # "Found already downloaded in this run" is counted
                        # separately so the front end can compute its denominator.
                        increment_task_stat("existing_skipped")
                        increment_task_stat("checked_messages")
                        return DownloadStatus.SkipDownload, None
                else:
                    _should_skip, _reason = db.should_skip(str(node.chat_id), message.id)
                    if _should_skip:
                        logger.info(
                            f"id={message.id} {ui_file_name} {_reason},跳过。\n"
                        )
                        increment_task_stat("skipped_files")
                        # A DB-marked skip also counts as "skipped in this run".
                        increment_task_stat("existing_skipped")
                        increment_task_stat("checked_messages")
                        return DownloadStatus.SkipDownload, None
            else:
                increment_task_stat("checked_messages")
                return DownloadStatus.SkipDownload, None

            # Only the first media type present on the message is downloaded.
            break
    except Exception as e:
        logger.error(
            f"Message[{message.id}]: "
            f"{_t('could not be downloaded due to following exception')}:\n[{e}].",
            exc_info=True,
        )
        return DownloadStatus.FailedDownload, None

    if _media is None:
        return DownloadStatus.SkipDownload, None

    message_id = message.id
    for retry in range(3):
        if is_message_skipped(str(node.chat_id), message_id):
            clear_skipped_message(str(node.chat_id), message_id)
            # Live-progress entries are keyed by node.chat_id in its original
            # type (same key download_task removes); str() would miss an
            # int-keyed entry.
            remove_download_entry(node.chat_id, message_id)
            return DownloadStatus.SkipDownload, None
        try:
            temp_download_path = await client.download_media(
                message,
                file_name=temp_file_name,
                progress=update_download_status,
                progress_args=(
                    message_id,
                    ui_file_name,
                    task_start_time,
                    node,
                    client,
                ),
            )

            # A falsy result (e.g. a stopped transmission) falls through to the
            # next attempt, whose is_message_skipped check handles user skips.
            if temp_download_path and isinstance(temp_download_path, str):
                _check_download_finish(media_size, temp_download_path, ui_file_name)
                await asyncio.sleep(0.5)
                _move_to_download_path(temp_download_path, file_name)
                chat_title = ""
                if message.chat and message.chat.title:
                    chat_title = message.chat.title
                db.record_download(
                    chat_id=str(node.chat_id),
                    chat_title=chat_title,
                    message_id=message.id,
                    file_name=os.path.basename(file_name),
                    file_path=file_name,
                    file_size=media_size,
                    media_type=_type,
                    status="success",
                )
                # TODO: if not exist file size or media
                return DownloadStatus.SuccessDownload, file_name
        except pyrogram.errors.exceptions.bad_request_400.BadRequest:
            logger.warning(
                f"Message[{message.id}]: {_t('file reference expired, refetching')}..."
            )
            await asyncio.sleep(RETRY_TIME_OUT)
            # Same rule as the initial fetch: a connection error while
            # re-fetching must end this download as Failed, not escape.
            try:
                message = await fetch_message(client, message)
            except Exception as fetch_err:
                logger.warning(
                    f"Message[{message_id}] 重新拉取消息失败(可能连接断开): {fetch_err}"
                )
                return DownloadStatus.FailedDownload, None
            if _check_timeout(retry, message.id):
                # pylint: disable = C0301
                logger.error(
                    f"Message[{message.id}]: "
                    f"{_t('file reference expired for 3 retries, download skipped.')}"
                )
        except pyrogram.errors.exceptions.flood_420.FloodWait as wait_err:
            await asyncio.sleep(wait_err.value)
            logger.warning("Message[{}]: FlowWait {}", message.id, wait_err.value)
            _check_timeout(retry, message.id)
        except TypeError:
            # pylint: disable = C0301
            logger.warning(
                f"{_t('Timeout Error occurred when downloading Message')}[{message.id}], "
                f"{_t('retrying after')} {RETRY_TIME_OUT} {_t('seconds')}"
            )
            await asyncio.sleep(RETRY_TIME_OUT)
            if _check_timeout(retry, message.id):
                logger.error(
                    f"Message[{message.id}]: {_t('Timing out after 3 reties, download skipped.')}"
                )
        except Exception as e:
            # pylint: disable = C0301
            logger.error(
                f"Message[{message.id}]: "
                f"{_t('could not be downloaded due to following exception')}:\n[{e}].",
                exc_info=True,
            )
            break

    return DownloadStatus.FailedDownload, None
def _load_config():
    """Delegate configuration loading to the module-level ``app`` object."""
    return app.load_config() and None
def _check_config() -> bool:
    """Print program metadata, load the config and attach the file log sink.

    Returns
    -------
    bool
        True when everything succeeded. False if loading the config or
        adding the log sink raised; the traceback is logged first.
    """
    print_meta(logger)
    try:
        _load_config()
        log_file = os.path.join(app.log_file_path, "tdl.log")
        logger.add(
            log_file,
            rotation="10 MB",
            retention="10 days",
            level=app.log_level,
        )
    except Exception as e:
        logger.exception(f"load config error: {e}")
        return False
    else:
        return True
async def worker(client: pyrogram.client.Client):
    """Consume ``(message, node)`` items from the download queue until shutdown.

    Items whose node was stopped, or whose message the user skipped, are
    released as skipped without downloading. Everything else goes to
    ``download_task`` using the node's own client when it has one. Any
    exception is logged, and the item is released as failed so the chat's
    finish counter still advances.
    """
    while app.is_running:
        message, node = None, None
        try:
            item = await queue.get()
            message, node = item[0], item[1]

            if node.is_stop_transmission:
                # Stopped on purpose: clear the leftover Downloading state so
                # the UI does not keep showing it.
                _release_stuck_task(node, message, DownloadStatus.SkipDownload)
                continue

            if is_message_skipped(str(node.chat_id), message.id):
                skip_message(str(node.chat_id), message.id)
                _release_stuck_task(node, message, DownloadStatus.SkipDownload)
                continue

            await download_task(node.client or client, message, node)
        except Exception as e:
            logger.exception(f"worker 捕获到未处理异常: {e}")
            # After swallowing the error the status must still advance,
            # otherwise finish_task never catches up with total_task.
            if node is not None and message is not None:
                _release_stuck_task(node, message, DownloadStatus.FailedDownload)
def _release_stuck_task(
    node: "TaskNode",
    message: "pyrogram.types.Message",
    status: "DownloadStatus",
):
    """Move an aborted or failed task out of the "downloading" state.

    1) Record ``status`` in ``node.download_status`` so the UI stops showing
       it as downloading.
    2) Outside bot mode, call ``app.set_download_id`` so the chat's finish
       counter advances and ``_wait_node_finish`` can return. Bot mode skips
       this, mirroring ``download_task``.
    3) Remove the message's live-progress entry.

    Messages without an ``id`` are ignored. Errors are logged, never raised.
    """
    try:
        msg_id = getattr(message, "id", None)
        if msg_id is not None:
            node.download_status[msg_id] = status
            if not node.bot:
                try:
                    app.set_download_id(node, msg_id, status)
                except Exception as inner:
                    logger.warning(f"释放卡死任务时 set_download_id 失败 msg_id={msg_id}: {inner}")
            remove_download_entry(node.chat_id, msg_id)
    except Exception as e:
        logger.warning(f"释放卡死任务清理失败: {e}")
async def download_chat_task(
    client: pyrogram.Client,
    chat_download_config: ChatDownloadConfig,
    node: TaskNode,
):
    """Scan one chat's history and queue every qualifying message for download.

    Message ids that failed in the previous run (``ids_to_retry``) are
    re-queued first. However the scan ends, on exit
    ``chat_download_config.need_check`` is set and ``total_task`` is synced
    from ``node.total_task``, so ``_wait_node_finish`` waits for the workers
    instead of treating the chat as already done.

    Parameters
    ----------
    client: pyrogram.Client
        Client used to read the chat.
    chat_download_config: ChatDownloadConfig
        Per-chat config (filter, last read id, retry ids); updated in place.
    node: TaskNode
        Task node that collects the queued messages.
    """
    # Before switching chats, snapshot the previous chat's final progress.
    snapshot_current_chat()
    # Reset and update task progress
    reset_task_progress()

    # Best-effort chat title for the progress banner.
    try:
        chat = await client.get_chat(node.chat_id)
        chat_title = chat.title or chat.first_name or str(node.chat_id)
    except Exception:
        chat_title = str(node.chat_id)

    update_task_progress(
        current_chat=str(node.chat_id),
        current_chat_title=chat_title,
        is_checking=True,
    )

    # Everything that can raise (DB lookups, re-fetching the retry ids,
    # iterating history) sits inside try/finally. On e.g. "Connection lost",
    # total_task must still be written back; otherwise _wait_node_finish
    # compares against a stale value and moves on while workers still run.
    iter_completed = False
    filter_key = None
    try:
        filter_key = db.build_filter_key(
            chat_download_config.download_filter,
            app.media_types,
            app.file_formats,
        )
        # Pre-scan cache hit: the banner can show 0 / N immediately.
        cached_total = db.get_scan_cache(str(node.chat_id), filter_key)
        if cached_total is not None:
            update_task_progress(estimated_total=cached_total)

        messages_iter = get_chat_history_v2(
            client,
            node.chat_id,
            limit=node.limit,
            max_id=node.end_offset_id,
            offset_id=chat_download_config.last_read_message_id,
            reverse=True,
        )

        chat_download_config.node = node

        if chat_download_config.ids_to_retry:
            logger.info(f"{_t('Downloading files failed during last run')}...")
            skipped_messages: list = await client.get_messages(  # type: ignore
                chat_id=node.chat_id, message_ids=chat_download_config.ids_to_retry
            )
            for message in skipped_messages:
                await add_download_task(message, node)

        async for message in messages_iter:  # type: ignore
            # Update checking progress for each message
            increment_task_stat("checked_messages")
            meta_data = MetaData()

            caption = message.caption
            if caption:
                caption = validate_title(caption)
                app.set_caption_name(node.chat_id, message.media_group_id, caption)
                app.set_caption_entities(
                    node.chat_id, message.media_group_id, message.caption_entities
                )
            else:
                caption = app.get_caption_name(node.chat_id, message.media_group_id)
            set_meta_data(meta_data, message, caption)

            if app.need_skip_message(chat_download_config, message.id):
                continue

            if app.exec_filter(chat_download_config, meta_data):
                # Messages passing the filter form the live denominator of X / N.
                increment_task_stat("qualified_files")
                await add_download_task(message, node)
            else:
                node.download_status[message.id] = DownloadStatus.SkipDownload
                increment_task_stat("skipped_files")
                if message.media_group_id:
                    # Fall back to the main client like download_task does,
                    # instead of passing a possibly-None upload_user.
                    await upload_telegram_chat(
                        client,
                        node.upload_user if node.upload_user else client,
                        app,
                        node,
                        message,
                        DownloadStatus.SkipDownload,
                    )
        iter_completed = True
    finally:
        # Sync total_task no matter how the scan ended. A stale value (e.g. 0)
        # would make _wait_node_finish return immediately without waiting.
        chat_download_config.need_check = True
        chat_download_config.total_task = node.total_task
        node.is_running = True

        if iter_completed:
            # Only a complete scan refreshes the cache; an interrupted one
            # would store a partial count.
            actual_total = get_task_progress().get("qualified_files", 0)
            db.save_scan_cache(str(node.chat_id), filter_key, actual_total)
            update_task_progress(estimated_total=actual_total, is_checking=False)
        else:
            # Still clear is_checking so the UI doesn't stay on "scanning".
            update_task_progress(is_checking=False)
async def _wait_node_finish(chat_config, timeout: int = 3600):
"""等待单个频道的所有下载任务完成(含重试中的任务),超时后强制跳过避免队列卡死"""
deadline = asyncio.get_event_loop().time() + timeout
while True:
if chat_config.need_check and chat_config.finish_task >= chat_config.total_task:
break
if asyncio.get_event_loop().time() > deadline:
logger.warning(
f"等待频道任务完成超时({timeout}s),强制跳过,"
f"finish={chat_config.finish_task} total={chat_config.total_task}"
)
break
await asyncio.sleep(0.5)
async def download_all_chat(client: pyrogram.Client):
    """Download every configured chat, one after another.

    For each chat:
    1. Attach a fresh ``TaskNode``.
    2. Scan and queue its messages.
    3. Wait for its workers to drain (or time out).
    4. Snapshot and reset the progress stats.

    A failure in one chat is logged and does not stop the others.
    """
    # Iterate over a snapshot of the items: the dict may change while we
    # await, which would raise "dictionary changed size during iteration".
    chat_items = list(app.chat_download_config.items())
    for chat_key, chat_config in chat_items:
        chat_config.node = TaskNode(chat_id=chat_key)
        try:
            await download_chat_task(client, chat_config, chat_config.node)
        except Exception as e:
            logger.warning(f"Download {chat_key} error: {e}")
        finally:
            chat_config.need_check = True
            # Snapshot only after this chat's downloads are done; otherwise
            # completed_files would still read 0.
            await _wait_node_finish(chat_config)
            snapshot_current_chat()
            reset_task_progress()

    # Every chat was snapshotted and reset inside the loop; this only
    # guarantees a clean final state.
    reset_task_progress()
async def run_until_all_task_finish():
    """Keep the event loop alive until a program restart is requested.

    The actual work runs in the worker / scanner tasks started by ``main``.
    This coroutine only polls ``app.restart_program`` once per second.
    """
    # Checks first, then sleeps. The previous per-tick scan of
    # app.chat_download_config computed a `finish` flag that was never used.
    while not app.restart_program:
        await asyncio.sleep(1)
def _exec_loop():
    """Block on ``app.loop`` until ``run_until_all_task_finish`` returns."""
    loop = app.loop
    loop.run_until_complete(run_until_all_task_finish())
async def start_server(client: pyrogram.Client):
    """Await ``client.start()``; a coroutine wrapper for ``run_until_complete``."""
    await client.start()
async def stop_server(client: pyrogram.Client):
    """Await ``client.stop()``; a coroutine wrapper for ``run_until_complete``."""
    await client.stop()
def main():
    """Run the downloader until it stops, then report whether to restart.

    With API credentials:
    - start the Telegram client, the chat scanner and the download workers;
    - start the bot if a token is configured;
    - block until a restart is requested.

    Without credentials ("Web setup mode"), only the web UI is served while
    waiting for the setup wizard to request a restart.

    Returns
    -------
    bool
        ``app.restart_program`` at shutdown.
    """
    tasks = []
    # No API credentials yet: start in "Web setup mode" and skip creating the
    # Telegram client.
    if app.is_configured():
        client = HookClient(
            "media_downloader",
            api_id=app.api_id,
            api_hash=app.api_hash,
            proxy=app.proxy,
            workdir=app.session_file_path,
            start_timeout=app.start_timeout,
        )
    else:
        client = None
        logger.info(
            "未检测到 API 凭证,以 Web 配置模式启动。"
            f"请访问 http://0.0.0.0:{app.web_port} 完成初始设置。"
        )

    try:
        app.pre_run()
        db.init_db(os.path.join(os.path.abspath("."), "appdata", "downloads.db"))
        init_web(app, client, add_download_task, download_chat_task)
        if client is not None:
            set_max_concurrent_transmissions(client, app.max_concurrent_transmissions)
            app.loop.run_until_complete(start_server(client))
            logger.success(_t("Successfully started (Press Ctrl+C to stop)"))
            # Keep a reference: the event loop only holds weak references to
            # tasks, so an unreferenced task may be garbage-collected mid-run.
            tasks.append(app.loop.create_task(download_all_chat(client)))
            for _ in range(app.max_download_task):
                tasks.append(app.loop.create_task(worker(client)))
            if app.bot_token:
                app.loop.run_until_complete(
                    start_download_bot(app, client, add_download_task, download_chat_task)
                )
            _exec_loop()
        else:
            # Web setup mode: keep the process alive until the wizard finishes
            # and requests a restart.
            async def _wait_for_restart():
                while not app.restart_program:
                    await asyncio.sleep(1)

            app.loop.run_until_complete(_wait_for_restart())
    except KeyboardInterrupt:
        logger.info(_t("KeyboardInterrupt"))
    except Exception as e:
        logger.exception("{}", e)
    finally:
        app.is_running = False
        if client is not None:
            # Guard each shutdown step so a failure in one (e.g. the bot)
            # cannot skip stopping the client or saving the config below.
            if app.bot_token:
                try:
                    app.loop.run_until_complete(stop_download_bot())
                except Exception as err:
                    logger.warning(f"stop download bot failed: {err}")
            try:
                app.loop.run_until_complete(stop_server(client))
            except Exception as err:
                logger.warning(f"stop telegram client failed: {err}")
        for task in tasks:
            task.cancel()
        logger.info(_t("Stopped!"))
        # On a Web-setup-mode restart, don't overwrite the config file the
        # wizard just wrote with the empty in-memory credentials.
        if not (client is None and app.restart_program):
            logger.info(f"{_t('update config')}......")
            app.update_config()
            logger.success(
                f"{_t('Updated last read message_id to config file')},"
                f"{_t('total download')} {app.total_download_task}, "
                f"{_t('total upload file')} "
                f"{app.cloud_drive_config.total_upload_success_file_count}"
            )

    # Tell the caller whether a restart was requested.
    return app.restart_program
def run_with_restart():
    """Run ``main``; if it asks for a restart, re-exec the current process.

    ``os.execv`` replaces the running process image, so the restarted program
    starts from scratch with the same interpreter and command-line arguments.
    """
    import sys

    if not main():
        return

    logger.info("🔄 正在重启程序以应用新配置...")
    # Release the web server's port before the new process tries to bind it.
    shutdown_web()
    # Give the OS a moment to actually free the port.
    time.sleep(1)
    interpreter = sys.executable
    os.execv(interpreter, [interpreter] + sys.argv)
# Script entry point: only start when the configuration loads cleanly.
if __name__ == "__main__" and _check_config():
    run_with_restart()