Spaces:
Build error
Build error
| import gradio as gr | |
| import requests | |
| import json | |
| import os | |
| import subprocess | |
| import uuid | |
| import time | |
| import cv2 | |
| from typing import Optional, List | |
| import numpy as np | |
| from datetime import datetime, timedelta | |
| from collections import defaultdict | |
| import shutil | |
| from urllib.parse import urljoin | |
| import oss2 | |
| from natsort import natsorted | |
| from concurrent.futures import ThreadPoolExecutor, as_completed | |
| from tqdm import tqdm | |
| import hashlib | |
# Per-session scratch files live in a "tmp" folder beside this script, so the
# location does not depend on the process's current working directory.
TMP_ROOT = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "tmp",
)
os.makedirs(TMP_ROOT, exist_ok=True)
# Backend API configuration; the base URL can be overridden via BACKEND_URL.
# A trailing "/" is stripped so the joined endpoints never contain "//"
# (e.g. "http://host:8000//predict/video"), which many servers reject.
BACKEND_URL = os.getenv("BACKEND_URL", "http://localhost:8000").rstrip("/")
API_ENDPOINTS = {
    "submit_task": f"{BACKEND_URL}/predict/video",
    "query_status": f"{BACKEND_URL}/predict/task",
    "terminate_task": f"{BACKEND_URL}/predict/terminate",
}
# Simulated scenes offered in the UI, keyed by scene id.  Each scene carries a
# display description, the objects present in it, and a (relative) path to a
# preview image.
SCENE_CONFIGS = {
    "scene_1": dict(
        description="scene_1",
        objects=["milk carton", "ceramic bowl", "mug"],
        preview_image="assets/scene_1.png",
    ),
}
# Policy models selectable in the UI.  Other models (GR00T-N1, GR00T-1.5, Pi0,
# DP+CLIP, AcT+CLIP) are currently disabled.
MODEL_CHOICES = ["gr1"]
###############################################################################
# Maps a Gradio session hash to the id of the backend task it last started.
SESSION_TASKS = {}
# Per-IP timestamps of recent submissions, used by the sliding-window limiter.
IP_REQUEST_RECORDS = defaultdict(list)
IP_LIMIT = 5  # maximum submissions per IP within any one-minute window


def is_request_allowed(ip: str) -> bool:
    """Return True and record the attempt if *ip* is below its rate limit.

    Timestamps older than one minute are discarded first; the request is
    accepted only while fewer than ``IP_LIMIT`` recent timestamps remain.
    Rejected attempts are not recorded.
    """
    now = datetime.now()
    window_start = now - timedelta(minutes=1)
    recent = [stamp for stamp in IP_REQUEST_RECORDS[ip] if stamp > window_start]
    IP_REQUEST_RECORDS[ip] = recent
    if len(recent) >= IP_LIMIT:
        return False
    recent.append(now)
    return True
###############################################################################
# Log files are kept in a "logs" directory relative to the working directory.
LOG_DIR = "logs"
os.makedirs(LOG_DIR, exist_ok=True)
ACCESS_LOG, SUBMISSION_LOG = (
    os.path.join(LOG_DIR, name) for name in ("access.log", "submissions.log")
)
def log_access(user_ip: str = None, user_agent: str = None):
    """Append one JSON line describing a page visit to ACCESS_LOG.

    Missing (falsy) IP or user-agent values are recorded as "unknown".
    """
    entry = dict(
        timestamp=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        type="access",
        user_ip=user_ip or "unknown",
        user_agent=user_agent or "unknown",
    )
    line = json.dumps(entry) + "\n"
    with open(ACCESS_LOG, "a") as log_file:
        log_file.write(line)
def log_submission(scene: str, prompt: str, model: str, max_step: int, user: str = "anonymous", res: str = "unknown"):
    """Append one JSON line describing a simulation submission to SUBMISSION_LOG.

    ``max_step`` is stored as a string; ``res`` holds the outcome text
    (callers pass "success" or a description of what went wrong).
    """
    record = {
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "type": "submission",
    }
    record.update(
        user=user,
        scene=scene,
        prompt=prompt,
        model=model,
        max_step=str(max_step),
        res=res,
    )
    with open(SUBMISSION_LOG, "a") as sink:
        sink.write(f"{json.dumps(record)}\n")
# Record a page visit.
def record_access(request: gr.Request):
    """Log a visit from *request* and return the refreshed log Markdown.

    A missing request, or a request without client information, is logged
    with "unknown" fields instead of raising.
    """
    if request is None:
        user_ip = user_agent = "unknown"
    else:
        client = request.client
        user_ip = client.host if client is not None else "unknown"
        user_agent = request.headers.get("user-agent", "unknown")
    log_access(user_ip, user_agent)
    return update_log_display()
def read_logs(log_type: str = "all", max_entries: int = 50) -> list:
    """Read access and/or submission log entries, newest first.

    Args:
        log_type: "all", "access" or "submission" — which log file(s) to read.
        max_entries: Maximum number of entries to return.

    Returns:
        Log-entry dicts sorted by their "timestamp" string, newest first.
        Missing log files are ignored.  Blank or malformed lines (for example
        from an interrupted write) are skipped rather than aborting the read.
    """
    paths = []
    if log_type in ("all", "access"):
        paths.append(ACCESS_LOG)
    if log_type in ("all", "submission"):
        paths.append(SUBMISSION_LOG)

    logs = []
    for path in paths:
        try:
            with open(path, "r") as f:
                for raw_line in f:
                    line = raw_line.strip()
                    if not line:
                        continue
                    try:
                        entry = json.loads(line)
                    except json.JSONDecodeError:
                        print(f"Skipping malformed log line in {path}: {line[:80]}")
                        continue
                    if isinstance(entry, dict):
                        logs.append(entry)
        except FileNotFoundError:
            pass

    # Timestamps are "%Y-%m-%d %H:%M:%S" strings, so string order is time order.
    logs.sort(key=lambda x: x.get("timestamp", ""), reverse=True)
    return logs[:max_entries]
def _escape_md_cell(value) -> str:
    """Make *value* safe to place inside a single Markdown table cell."""
    text = str(value)
    # A raw "|" would start a new column and a line break would end the row.
    return text.replace("|", "\\|").replace("\r", " ").replace("\n", " ")


def format_logs_for_display(logs: list) -> str:
    """Render log entries as a Markdown table for the log panel.

    Args:
        logs: Entry dicts as produced by ``read_logs`` (access or submission).

    Returns:
        Markdown text, or a placeholder message when *logs* is empty.  Long
        non-"success" results are shortened to their first and last 20
        characters.  Cell contents are escaped so user-supplied text such as
        prompts or user agents cannot break the table layout.
    """
    if not logs:
        return "暂无日志记录"

    parts = [
        "### 系统访问日志\n\n",
        "| 时间 | 类型 | 用户/IP | 详细信息 |\n",
        "|------|------|---------|----------|\n",
    ]
    for log in logs:
        timestamp = log.get("timestamp", "unknown")
        if log.get("type") == "access":
            log_type = "访问"
            user = log.get("user_ip", "unknown")
            details = f"User-Agent: {log.get('user_agent', 'unknown')}"
        else:
            log_type = "提交"
            user = log.get("user", "anonymous")
            result = str(log.get("res", "unknown"))
            if result != "success" and len(result) > 40:
                result = f"{result[:20]}...{result[-20:]}"
            details = f"场景: {log.get('scene', 'unknown')}, 指令: {log.get('prompt', '')}, 模型: {log.get('model', 'unknown')}, max step: {log.get('max_step', '300')}, result: {result}"
        parts.append(
            f"| {_escape_md_cell(timestamp)} | {log_type} | "
            f"{_escape_md_cell(user)} | {_escape_md_cell(details)} |\n"
        )
    return "".join(parts)
###############################################################################
# OSS configuration, read from environment variables.
# NOTE(review): unset variables are passed to oss2 as None; whether that fails
# here or only on the first request is not visible from this file — confirm.
OSS_CONFIG = {
    "access_key_id": os.getenv("OSS_ACCESS_KEY_ID"),
    "access_key_secret": os.getenv("OSS_ACCESS_KEY_SECRET"),
    "endpoint": os.getenv("OSS_ENDPOINT"),
    "bucket_name": os.getenv("OSS_BUCKET_NAME")
}
# Module-level client used by the OSS helpers below (listing, downloads,
# existence checks).
auth = oss2.Auth(OSS_CONFIG["access_key_id"], OSS_CONFIG["access_key_secret"])
# enable_crc=False turns off oss2's CRC checksum verification of transfers.
bucket = oss2.Bucket(auth, OSS_CONFIG["endpoint"], OSS_CONFIG["bucket_name"], enable_crc=False)
def list_oss_files(folder_path: str) -> List[str]:
    """List the object keys stored under an OSS "folder".

    Args:
        folder_path: OSS key prefix of the folder, with or without a
            trailing "/".

    Returns:
        Keys of all objects below the folder, excluding directory placeholder
        keys that end in "/".  Keys are in natural order of their
        extension-less names, so "frame2" sorts before "frame10".  Streamed
        frames are appended in this order.
    """
    # Treat the path as a directory; without the trailing "/" the prefix would
    # also match sibling keys such as "image_backup/...".
    if folder_path and not folder_path.endswith("/"):
        prefix = folder_path + "/"
    else:
        prefix = folder_path
    files = [
        obj.key
        for obj in oss2.ObjectIterator(bucket, prefix=prefix)
        if not obj.key.endswith("/")
    ]
    # Natural sort, consistent with generate_whole_video; a plain string sort
    # would place "10.png" before "2.png" and scramble the frame order.
    return natsorted(files, key=lambda key: os.path.splitext(key)[0])
def parallel_download_oss_files(
    bucket,
    oss_folder: str,
    local_dir: str,
    file_list: list[str],
    max_workers: int = 5
) -> bool:
    """Download the given OSS objects concurrently into one local directory.

    Args:
        bucket: oss2 ``Bucket`` to download from.
        oss_folder: OSS folder the keys belong to.  It is not used by the
            implementation; the keys in *file_list* are used as-is.
        local_dir: Destination directory; created if missing.
        file_list: Object keys to fetch (a leading "/" is stripped).  Each
            object is saved under its base name, so equal base names from
            different sub-folders overwrite each other.
        max_workers: Size of the download thread pool.

    Returns:
        True if every download succeeded (also True for an empty list).
        Individual failures are printed, not raised.
    """
    def fetch(key, destination):
        try:
            bucket.get_object_to_file(key, destination)
        except Exception as err:
            print(f"下载失败 {key}: {str(err)}")
            return False
        return True

    os.makedirs(local_dir, exist_ok=True)
    jobs = [
        (key, os.path.join(local_dir, os.path.basename(key)))
        for key in (name.lstrip('/') for name in file_list)
    ]

    ok = 0
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = [pool.submit(fetch, key, dest) for key, dest in jobs]
        for done in tqdm(as_completed(pending), total=len(jobs), desc="下载进度"):
            ok += bool(done.result())

    print(f"下载完成: {ok}/{len(jobs)} 成功")
    return ok == len(jobs)
def download_oss_file(oss_path: str, local_path: str) -> bool:
    """Fetch a single OSS object into *local_path*.

    Prints the status code and elapsed time.  Returns True only when the
    response status is 200; any exception is printed and reported as False.
    """
    started = time.time()
    try:
        resp = bucket.get_object_to_file(oss_path, local_path)
        elapsed = time.time() - started
        print(f"下载: {oss_path}, 状态码: {resp.status}, 耗时: {elapsed:.2f}秒")
        ok = resp.status == 200
    except Exception as exc:
        print(f"下载失败: {exc}")
        return False
    return ok
def oss_file_exists(oss_path):
    """Return whether *oss_path* exists in the bucket.

    If the existence check itself raises, the error is printed and False is
    returned.
    """
    try:
        exists = bucket.object_exists(oss_path)
    except Exception as exc:
        print(f"Error checking if file exists in OSS: {str(exc)}")
        exists = False
    return exists
def stream_simulation_results(result_folder: str, task_id: str, request: gr.Request, fps: int = 30):
    """Stream simulation results as short video segments built from OSS frames.

    Polls the OSS prefix ``<result_folder>/image`` for newly uploaded frames
    and downloads them into the session's temp directory.  Whenever at least
    ``fps`` frames (about one second) have accumulated, it yields the path of
    a new mp4 segment.  The backend task status is polled every few seconds.
    Once the task reports "completed", the remaining frames are flushed as a
    final segment.

    Args:
        result_folder: OSS prefix of the task's result folder.
        task_id: Backend task id used for status queries.
        request: Gradio request; its ``session_hash`` selects the temp dir.
        fps: Frame rate of the generated segments.

    Yields:
        Local file paths of the generated video segments.

    Raises:
        gr.Error: If the backend reports the task as failed, or if it does
            not finish within ``max_time`` seconds.
    """
    # OSS key prefix (not a local path, so no local directory is created for it).
    image_folder = os.path.join(result_folder, "image")
    frame_buffer: List[np.ndarray] = []
    min_frames_per_segment = fps * 1  # emit a segment once ~1 s of frames is buffered
    processed_files = set()
    width, height = 0, 0
    last_status_check = 0.0  # 0 forces a status check on the first iteration
    status_check_interval = 5  # seconds between backend status queries
    max_time = 240  # overall wall-clock budget in seconds
    deadline = time.time() + max_time

    # Local temp directory for the downloaded frames of this task.
    user_dir = os.path.join(TMP_ROOT, str(request.session_hash))
    local_image_dir = os.path.join(user_dir, task_id, "tasks", "images")
    os.makedirs(local_image_dir, exist_ok=True)

    while time.time() < deadline:
        finished = False
        current_time = time.time()
        if current_time - last_status_check > status_check_interval:
            status = get_task_status(task_id)
            print(str(request.session_hash), "status: ", status)
            state = status.get("status")
            if state == "failed":
                raise gr.Error(f"任务执行失败: {status.get('result', '未知错误')}")
            if state == "terminated":
                return
            finished = state == "completed"
            last_status_check = current_time

        # Pull frames uploaded since the previous poll.  Download failures are
        # logged by the helper and retried on the next round.
        process_remaining_oss_images(image_folder, local_image_dir, processed_files, frame_buffer)
        if frame_buffer and width == 0:
            # Measure once, from whichever call first produced frames.
            height, width = frame_buffer[0].shape[:2]

        # Emit a segment when enough frames are buffered, or flush whatever is
        # left once the task has finished.
        if len(frame_buffer) >= min_frames_per_segment or (finished and frame_buffer):
            yield create_video_segment(frame_buffer, fps, width, height, request)
            frame_buffer = []
        if finished:
            return

        time.sleep(1)  # avoid hammering OSS / the backend

    raise gr.Error(f"timeout {max_time}s")
def create_video_segment(frames: List[np.ndarray], fps: int, width: int, height: int, req: gr.Request) -> str:
    """Encode *frames* into a new mp4 file in the session's video_chunk folder.

    Args:
        frames: Frames to write (callers pass images read with ``cv2.imread``).
        fps: Frame rate of the segment.
        width: Frame width.  If width or height is not positive, the size is
            taken from the first frame, so callers that have not measured the
            frames yet still get a valid video.
        height: Frame height (see *width*).
        req: Gradio request whose ``session_hash`` selects the output folder.

    Returns:
        Path of the written segment, named ``output_<uuid4>.mp4``.
    """
    video_chunk_path = os.path.join(TMP_ROOT, str(req.session_hash), "video_chunk")
    os.makedirs(video_chunk_path, exist_ok=True)  # also creates the session folder

    if (width <= 0 or height <= 0) and frames:
        height, width = frames[0].shape[:2]

    segment_name = os.path.join(video_chunk_path, f"output_{uuid.uuid4()}.mp4")
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(segment_name, fourcc, fps, (width, height))
    try:
        for frame in frames:
            out.write(frame)
    finally:
        out.release()  # always finalize the file, even if a write fails
    return segment_name
def process_remaining_oss_images(oss_folder: str, local_dir: str, processed_files: set, frame_buffer: List[np.ndarray]):
    """Download not-yet-processed frames from *oss_folder* and buffer them.

    New keys are downloaded into *local_dir* (by base name) and read with
    ``cv2.imread``.  Every frame that decodes is appended to *frame_buffer*
    and its key added to *processed_files*; both are mutated in place.  Keys
    whose image cannot be read stay unprocessed, so a later call retries them.

    If the batch download reports any failure, nothing is buffered in this
    call and the whole batch is retried next time.  Errors are printed,
    never raised.
    """
    try:
        oss_files = list_oss_files(oss_folder)
    except Exception as e:
        print(f"Error accessing OSS for remaining files: {e}")
        return

    new_files = [f for f in oss_files if f not in processed_files]
    if not new_files:
        return
    print(f"发现新文件: {len(new_files)} 个", new_files)

    try:
        success = parallel_download_oss_files(
            bucket=bucket,
            oss_folder=oss_folder + "/",
            local_dir=local_dir + "/",
            file_list=new_files,
            max_workers=5  # tune to the available network bandwidth
        )
    except Exception as e:
        print(f"Error accessing OSS for remaining files: {e}")
        return
    if not success:
        # Skip this batch explicitly; it is retried on the next call.
        print("Error accessing OSS for remaining files: 无法从OSS同步图片文件")
        return

    for oss_path in new_files:
        # Files were saved under their base name by the batch download.
        local_path = os.path.join(local_dir, os.path.basename(oss_path))
        try:
            frame = cv2.imread(local_path)
        except Exception as e:
            print(f"Error processing remaining {oss_path}: {e}")
            continue
        if frame is not None:
            frame_buffer.append(frame)
            processed_files.add(oss_path)
###############################################################################
def submit_to_backend(
    scene: str,
    prompt: str,
    model: str,
    max_step: int,
    user: str = "Gradio-user",
) -> dict:
    """Submit a robot-manipulation job to the backend's submit_task endpoint.

    The simulation parameters are JSON-encoded into the payload's ``data``
    field, and a fresh UUID4 ``job_id`` is generated per call.

    Returns:
        The backend's decoded JSON response, or
        ``{"status": "error", "message": ...}`` if the request or the
        decoding fails.
    """
    request_id = str(uuid.uuid4())
    sim_params = json.dumps({
        "scene_type": scene,
        "instruction": prompt,
        "model_type": model,
        "max_step": str(max_step),
    })
    body = dict(user=user, task="robot_manipulation", job_id=request_id, data=sim_params)
    try:
        reply = requests.post(
            API_ENDPOINTS["submit_task"],
            json=body,
            headers={"Content-Type": "application/json"},
            timeout=10,
        )
        return reply.json()
    except Exception as err:
        return {"status": "error", "message": str(err)}
def get_task_status(task_id: str) -> dict:
    """Fetch the status of *task_id* from the backend.

    Returns the decoded JSON response.  On any failure it returns
    ``{"status": "error get_task_status", "message": ...}``.  The elapsed
    time is printed in both cases.
    """
    t0 = time.time()
    try:
        resp = requests.get(f"{API_ENDPOINTS['query_status']}/{task_id}", timeout=5)
        took = time.time() - t0
        print(f"[查询任务状态] task_id: {task_id}, 耗时: {took:.3f}s")
        return resp.json()
    except Exception as e:
        took = time.time() - t0
        print(f"[查询任务状态失败] task_id: {task_id}, 错误: {str(e)}, 耗时: {took:.3f}s")
        return {"status": "error get_task_status", "message": str(e)}
def terminate_task(task_id: str) -> Optional[dict]:
    """Ask the backend to terminate *task_id*.

    Returns the decoded JSON response, or None (after printing the error)
    if the request or decoding fails.
    """
    try:
        return requests.post(f"{API_ENDPOINTS['terminate_task']}/{task_id}", timeout=3).json()
    except Exception as e:
        print(f"Error terminate task: {e}")
        return None
def convert_to_h264(video_path):
    """Re-encode *video_path* as an H.264 MP4 using ffmpeg.

    The output is written next to the input as ``<name>_h264.mp4``; an
    existing file with that name is overwritten.

    Returns:
        Path of the converted file.

    Raises:
        gr.Error: If ffmpeg fails, cannot be started, or produces no output.
    """
    base, _ext = os.path.splitext(video_path)
    video_path_h264 = f"{base}_h264.mp4"
    ffmpeg_cmd = [
        "ffmpeg",
        # Overwrite without asking: otherwise ffmpeg prompts "Overwrite? [y/N]"
        # on stdin when the output already exists.
        "-y",
        "-i", video_path,
        "-c:v", "libx264",
        "-preset", "slow",
        "-crf", "23",
        "-c:a", "aac",
        "-movflags", "+faststart",
        video_path_h264
    ]
    try:
        subprocess.run(
            ffmpeg_cmd,
            check=True,
            stdin=subprocess.DEVNULL,  # never let ffmpeg wait on interactive input
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
    except subprocess.CalledProcessError as e:
        stderr = e.stderr.decode("utf-8", errors="replace") if e.stderr else ""
        raise gr.Error(f"FFmpeg 转换失败: {stderr}") from e
    except Exception as e:
        raise gr.Error(f"转换过程中发生错误: {str(e)}") from e

    if not os.path.exists(video_path_h264):
        raise gr.Error(f"转换过程中发生错误: H.264 编码文件未生成: {video_path_h264}")
    return video_path_h264
def generate_whole_video(task_id: str, request: gr.Request, fps: int = 30) -> str:
    """Assemble all downloaded frames of a task into one mp4 video.

    Frames are read from ``<TMP_ROOT>/<session>/<task_id>/tasks/images``
    (.png/.jpg/.jpeg files in natural filename order).  They are written to
    ``<TMP_ROOT>/<session>/<task_id>/tasks/video/manipulation.mp4`` one at a
    time instead of all being held in memory.

    Args:
        task_id: Backend task id.
        request: Gradio request; its ``session_hash`` selects the session folder.
        fps: Output frame rate (default 30).

    Returns:
        Path of the generated video.

    Raises:
        FileNotFoundError: If the task's image folder does not exist.
        ValueError: If the folder has no images, or none can be decoded.
        RuntimeError: If the video file is missing or empty after writing.
    """
    task_dir = os.path.join(TMP_ROOT, str(request.session_hash), task_id, "tasks")
    image_folder = os.path.join(task_dir, "images")
    result_folder = os.path.join(task_dir, "video")
    os.makedirs(result_folder, exist_ok=True)

    # Natural order so "frame2" comes before "frame10".
    image_files = natsorted(
        f for f in os.listdir(image_folder) if f.lower().endswith(('.png', '.jpg', '.jpeg'))
    )
    if not image_files:
        raise ValueError(f"No image files found in {image_folder}")

    output_video_path = os.path.join(result_folder, "manipulation.mp4")
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # 'avc1' may give better compatibility
    video_writer = None
    width = height = 0
    try:
        for img_file in image_files:
            img_path = os.path.join(image_folder, img_file)
            try:
                frame = cv2.imread(img_path)
            except Exception as e:
                print(f"Error processing {img_path}: {e}")
                continue
            if frame is None:
                continue
            if video_writer is None:
                # The first decodable frame fixes the video size.
                height, width = frame.shape[:2]
                video_writer = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
            elif frame.shape[:2] != (height, width):
                # OpenCV's VideoWriter does not write frames whose size differs
                # from the size it was opened with, so resize instead of losing them.
                frame = cv2.resize(frame, (width, height))
            video_writer.write(frame)
    finally:
        if video_writer is not None:
            video_writer.release()

    if video_writer is None:
        raise ValueError("No valid frames found to create video")
    if not os.path.exists(output_video_path) or os.path.getsize(output_video_path) == 0:
        raise RuntimeError(f"Failed to create video at {output_video_path}")
    return output_video_path
def run_simulation(
    scene: str,
    prompt: str,
    model: str,
    max_step: int,
    history: list,
    request: gr.Request
):
    """Run one simulation on the backend, streaming video and updating history.

    This is a generator that yields ``(video_path, history)`` pairs.  It
    first yields the streamed video segments together with the unchanged
    history.  Once the task has completed or been terminated, it yields a
    final ``(None, history)``.  On success, an entry for the full-length
    video is appended to *history*.  History is ordered oldest to newest and
    capped at the 10 most recent runs.

    Raises:
        gr.Error: On rate limiting, submission failure, streaming errors,
            video-assembly errors, a backend-reported failure, or an
            unexpected final task status.
    """
    max_history = 10
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    user_ip = request.client.host if request else "unknown"
    session_id = request.session_hash

    if not is_request_allowed(user_ip):
        log_submission(scene, prompt, model, max_step, user_ip, "IP blocked temporarily")
        raise gr.Error("Too many requests from this IP. Please wait and try again one minute later.")

    # Submit the task to the backend.
    submission_result = submit_to_backend(scene, prompt, model, max_step, user_ip)
    print("submission_result: ", submission_result)
    if submission_result.get("status") != "pending":
        log_submission(scene, prompt, model, max_step, user_ip, "Submission failed")
        raise gr.Error(f"Submission failed: {submission_result.get('message', 'unknown issue')}")

    try:
        task_id = submission_result["task_id"]
        # Remembered so cleanup_session can terminate the task if the session ends early.
        SESSION_TASKS[session_id] = task_id
        gr.Info(f"Simulation started, task_id: {task_id}")
        time.sleep(5)  # give the backend a moment before the first status poll
        status = get_task_status(task_id)
        print("first status: ", status)
        # Results are read from this fixed OSS prefix; the backend's "result"
        # field is not used for locating them.
        result_folder = "gradio_demo/tasks/" + task_id
    except Exception as e:
        log_submission(scene, prompt, model, max_step, user_ip, str(e))
        raise gr.Error(f"error occurred when parsing submission result from backend: {str(e)}")

    # Stream video segments while the task runs.
    try:
        for video_path in stream_simulation_results(result_folder, task_id, request):
            if video_path:
                yield video_path, history
    except Exception as e:
        log_submission(scene, prompt, model, max_step, user_ip, str(e))
        raise gr.Error(f"Error while streaming: {str(e)}")

    # Final task status decides the outcome.
    status = get_task_status(task_id)
    print("status: ", status)
    state = status.get("status")

    if state == "completed":
        time.sleep(3)
        try:
            video_path = generate_whole_video(task_id, request)
        except Exception as e:
            log_submission(scene, prompt, model, max_step, user_ip, str(e))
            raise gr.Error(f"Error while generating video: {str(e)}")

        new_entry = {
            "timestamp": timestamp,
            "scene": scene,
            "model": model,
            "prompt": prompt,
            "max_step": max_step,
            "video_path": video_path,
            "task_id": task_id
        }
        # History runs oldest -> newest (the display opens the last entry), so
        # trimming must keep the most recent entries, including the new one.
        updated_history = (history + [new_entry])[-max_history:]
        print("updated_history:", updated_history)
        log_submission(scene, prompt, model, max_step, user_ip, "success")
        gr.Info("Simulation completed successfully!")
        yield None, updated_history
    elif state == "failed":
        log_submission(scene, prompt, model, max_step, user_ip, status.get('result', 'backend error'))
        raise gr.Error(f"Task execution failed: {status.get('result', 'backend unknown issue')}")
    elif state == "terminated":
        log_submission(scene, prompt, model, max_step, user_ip, "user end terminated")
        yield None, history
    else:
        log_submission(scene, prompt, model, max_step, user_ip, "missing task's status from backend (Pending?)")
        raise gr.Error("missing task's status from backend (Pending?)")
def update_history_display(history: list) -> list:
    """Build the Gradio updates for the 10 history slots.

    Each slot contributes four updates in a fixed order: its Column,
    Accordion, Video and detail Markdown.  Slots that have an entry are shown
    (only the accordion of the last entry is opened); the remaining slots
    are hidden and cleared.
    """
    print("更新历史记录显示")
    total = len(history)
    updates = []
    for slot in range(10):
        if slot >= total:
            updates += [
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(value=None, visible=False),
                gr.update(value=""),
            ]
            continue
        item = history[slot]
        number = slot + 1
        updates += [
            gr.update(visible=True),
            gr.update(
                visible=True,
                label=f"# {number} | {item['scene']} | {item['model']} | {item['prompt']}",
                open=(number == total),
            ),
            gr.update(value=item['video_path'], visible=True, autoplay=False),
            gr.update(value=f"{item['timestamp']}"),
        ]
    print("更新完成!")
    return updates
def update_scene_display(scene: str) -> tuple[str, Optional[str]]:
    """Return the Markdown description and preview-image path for *scene*.

    An unknown scene yields "No description", an empty object list and no
    image.
    """
    scene_cfg = SCENE_CONFIGS.get(scene, {})
    title = scene_cfg.get("description", "No description")
    object_list = ", ".join(scene_cfg.get("objects", []))
    text = f"**{title}** \nObjects in this scene: {object_list}"
    return text, scene_cfg.get("preview_image", None)
def update_log_display():
    """Return the most recent log entries rendered as Markdown."""
    return format_logs_for_display(read_logs())
###############################################################################
def cleanup_session(req: gr.Request):
    """Release backend and local resources when a Gradio session ends.

    If the session started a task that the backend still reports as
    "pending", a termination request is sent.  The session's temp folder
    under TMP_ROOT is then removed.  That folder only exists once a
    simulation has run, so a missing folder is not an error.
    """
    session_id = req.session_hash
    task_id = SESSION_TASKS.pop(session_id, None)
    if task_id:
        try:
            status = get_task_status(task_id)
            print("clean up check status: ", status)
            if status.get("status") == "pending":
                res = terminate_task(task_id)
                # terminate_task returns None when the request itself fails.
                if res is not None and res.get("status") == "success":
                    print(f"已终止任务 {task_id}")
                else:
                    reason = "no response" if res is None else res.get("status", "unknown issue")
                    print(f"终止任务失败 {task_id}: {reason}")
        except Exception as e:
            print(f"终止任务失败 {task_id}: {e}")

    user_dir = os.path.join(TMP_ROOT, str(session_id))
    # Best-effort cleanup: never fail the session teardown over temp files.
    shutil.rmtree(user_dir, ignore_errors=True)