Spaces:
Build error
Build error
liuyiyang01
commited on
Commit
·
6964b3e
1
Parent(s):
4f482ac
oss support
Browse files- app.py +1 -1
- app_utils.py +30 -14
app.py
CHANGED
|
@@ -24,7 +24,7 @@ header_html = """
|
|
| 24 |
</div>
|
| 25 |
<div style="display: flex; gap: 15px; align-items: center;">
|
| 26 |
<a href="https://github.com/InternRobotics" target="_blank" style="text-decoration: none; transition: transform 0.2s;" onmouseover="this.style.transform='scale(1.1)'" onmouseout="this.style.transform='scale(1)'">
|
| 27 |
-
<img src="
|
| 28 |
</a>
|
| 29 |
<a href="https://huggingface.co/InternRobotics" target="_blank" style="text-decoration: none; transition: transform 0.2s;" onmouseover="this.style.transform='scale(1.1)'" onmouseout="this.style.transform='scale(1)'">
|
| 30 |
<img src="https://huggingface.co/front/assets/huggingface_logo-noborder.svg" alt="HuggingFace" style="height: 30px;">
|
|
|
|
| 24 |
</div>
|
| 25 |
<div style="display: flex; gap: 15px; align-items: center;">
|
| 26 |
<a href="https://github.com/InternRobotics" target="_blank" style="text-decoration: none; transition: transform 0.2s;" onmouseover="this.style.transform='scale(1.1)'" onmouseout="this.style.transform='scale(1)'">
|
| 27 |
+
<img src="https://github.githubassets.com/images/modules/logos_page/GitHub-Mark.png" alt="GitHub" style="height: 30px;">
|
| 28 |
</a>
|
| 29 |
<a href="https://huggingface.co/InternRobotics" target="_blank" style="text-decoration: none; transition: transform 0.2s;" onmouseover="this.style.transform='scale(1.1)'" onmouseout="this.style.transform='scale(1)'">
|
| 30 |
<img src="https://huggingface.co/front/assets/huggingface_logo-noborder.svg" alt="HuggingFace" style="height: 30px;">
|
app_utils.py
CHANGED
|
@@ -19,7 +19,7 @@ os.makedirs(TMP_ROOT, exist_ok=True)
|
|
| 19 |
|
| 20 |
|
| 21 |
# 后端API配置(可配置化)
|
| 22 |
-
BACKEND_URL = os.getenv("BACKEND_URL")
|
| 23 |
API_ENDPOINTS = {
|
| 24 |
"submit_task": f"{BACKEND_URL}/predict/video",
|
| 25 |
"query_status": f"{BACKEND_URL}/predict/task",
|
|
@@ -183,6 +183,13 @@ def download_oss_file(oss_path: str, local_path: str):
|
|
| 183 |
"""从OSS下载文件到本地"""
|
| 184 |
bucket.get_object_to_file(oss_path, local_path)
|
| 185 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
def stream_simulation_results(result_folder: str, task_id: str, request: gr.Request, fps: int = 30):
|
| 187 |
"""
|
| 188 |
流式输出仿真结果,从OSS读取图片
|
|
@@ -190,8 +197,9 @@ def stream_simulation_results(result_folder: str, task_id: str, request: gr.Requ
|
|
| 190 |
参数:
|
| 191 |
result_folder: OSS上包含生成图片的文件夹路径
|
| 192 |
task_id: 后端任务ID用于状态查询
|
| 193 |
-
fps: 输出视频的帧率
|
| 194 |
request: Gradio请求对象
|
|
|
|
|
|
|
| 195 |
|
| 196 |
生成:
|
| 197 |
生成的视频文件路径 (分段输出)
|
|
@@ -204,12 +212,13 @@ def stream_simulation_results(result_folder: str, task_id: str, request: gr.Requ
|
|
| 204 |
processed_files = set()
|
| 205 |
width, height = 0, 0
|
| 206 |
last_status_check = 0
|
| 207 |
-
status_check_interval =
|
| 208 |
max_time = 240
|
| 209 |
|
| 210 |
# 创建临时目录存储下载的图片
|
| 211 |
user_dir = os.path.join(TMP_ROOT, str(request.session_hash))
|
| 212 |
-
local_image_dir = os.path.join(user_dir, "tasks", "images")
|
|
|
|
| 213 |
os.makedirs(local_image_dir, exist_ok=True)
|
| 214 |
|
| 215 |
while max_time > 0:
|
|
@@ -219,7 +228,7 @@ def stream_simulation_results(result_folder: str, task_id: str, request: gr.Requ
|
|
| 219 |
# 定期检查后端状态
|
| 220 |
if current_time - last_status_check > status_check_interval:
|
| 221 |
status = get_task_status(task_id)
|
| 222 |
-
print("status: ", status)
|
| 223 |
if status.get("status") == "completed":
|
| 224 |
# 确保处理完所有已生成的图片
|
| 225 |
process_remaining_oss_images(image_folder, local_image_dir, processed_files, frame_buffer)
|
|
@@ -275,7 +284,7 @@ def create_video_segment(frames: List[np.ndarray], fps: int, width: int, height:
|
|
| 275 |
"""创建视频片段"""
|
| 276 |
user_dir = os.path.join(TMP_ROOT, str(req.session_hash))
|
| 277 |
os.makedirs(user_dir, exist_ok=True)
|
| 278 |
-
video_chunk_path = os.path.join(user_dir, "
|
| 279 |
os.makedirs(video_chunk_path, exist_ok=True)
|
| 280 |
segment_name = os.path.join(video_chunk_path, f"output_{uuid.uuid4()}.mp4")
|
| 281 |
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
|
@@ -364,7 +373,7 @@ def get_task_status(task_id: str) -> dict:
|
|
| 364 |
)
|
| 365 |
return response.json()
|
| 366 |
except Exception as e:
|
| 367 |
-
return {"status": "error", "message": str(e)}
|
| 368 |
|
| 369 |
def terminate_task(task_id: str) -> Optional[dict]:
|
| 370 |
"""
|
|
@@ -431,6 +440,7 @@ def run_simulation(
|
|
| 431 |
# 记录用户提交
|
| 432 |
user_ip = request.client.host if request else "unknown"
|
| 433 |
session_id = request.session_hash
|
|
|
|
| 434 |
|
| 435 |
if not is_request_allowed(user_ip):
|
| 436 |
log_submission(scene, prompt, model, max_step, user_ip, "IP blocked temporarily")
|
|
@@ -454,19 +464,20 @@ def run_simulation(
|
|
| 454 |
status = get_task_status(task_id)
|
| 455 |
print("first status: ", status)
|
| 456 |
result_folder = status.get("result", "")
|
|
|
|
| 457 |
except Exception as e:
|
| 458 |
log_submission(scene, prompt, model, max_step, user_ip, str(e))
|
| 459 |
raise gr.Error(f"error occurred when parsing submission result from backend: {str(e)}")
|
| 460 |
|
| 461 |
|
| 462 |
-
if not os.path.exists(result_folder):
|
| 463 |
-
|
| 464 |
-
|
| 465 |
|
| 466 |
|
| 467 |
# 流式输出视频片段
|
| 468 |
try:
|
| 469 |
-
for video_path in stream_simulation_results(result_folder, task_id):
|
| 470 |
if video_path:
|
| 471 |
yield video_path, history
|
| 472 |
except Exception as e:
|
|
@@ -477,9 +488,14 @@ def run_simulation(
|
|
| 477 |
status = get_task_status(task_id)
|
| 478 |
print("status: ", status)
|
| 479 |
if status.get("status") == "completed":
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 483 |
|
| 484 |
# 创建新的历史记录条目
|
| 485 |
new_entry = {
|
|
|
|
| 19 |
|
| 20 |
|
| 21 |
# 后端API配置(可配置化)
|
| 22 |
+
BACKEND_URL = os.getenv("BACKEND_URL", "http://localhost:8000")
|
| 23 |
API_ENDPOINTS = {
|
| 24 |
"submit_task": f"{BACKEND_URL}/predict/video",
|
| 25 |
"query_status": f"{BACKEND_URL}/predict/task",
|
|
|
|
| 183 |
"""从OSS下载文件到本地"""
|
| 184 |
bucket.get_object_to_file(oss_path, local_path)
|
| 185 |
|
| 186 |
+
def oss_file_exists(oss_path):
|
| 187 |
+
try:
|
| 188 |
+
# Assuming you have an OSS bucket object
|
| 189 |
+
return bucket.object_exists(oss_path)
|
| 190 |
+
except Exception as e:
|
| 191 |
+
print(f"Error checking if file exists in OSS: {str(e)}")
|
| 192 |
+
return False
|
| 193 |
def stream_simulation_results(result_folder: str, task_id: str, request: gr.Request, fps: int = 30):
|
| 194 |
"""
|
| 195 |
流式输出仿真结果,从OSS读取图片
|
|
|
|
| 197 |
参数:
|
| 198 |
result_folder: OSS上包含生成图片的文件夹路径
|
| 199 |
task_id: 后端任务ID用于状态查询
|
|
|
|
| 200 |
request: Gradio请求对象
|
| 201 |
+
fps: 输出视频的帧率
|
| 202 |
+
|
| 203 |
|
| 204 |
生成:
|
| 205 |
生成的视频文件路径 (分段输出)
|
|
|
|
| 212 |
processed_files = set()
|
| 213 |
width, height = 0, 0
|
| 214 |
last_status_check = 0
|
| 215 |
+
status_check_interval = 1 # 每5秒检查一次后端状态
|
| 216 |
max_time = 240
|
| 217 |
|
| 218 |
# 创建临时目录存储下载的图片
|
| 219 |
user_dir = os.path.join(TMP_ROOT, str(request.session_hash))
|
| 220 |
+
local_image_dir = os.path.join(user_dir, task_id, "tasks", "images")
|
| 221 |
+
|
| 222 |
os.makedirs(local_image_dir, exist_ok=True)
|
| 223 |
|
| 224 |
while max_time > 0:
|
|
|
|
| 228 |
# 定期检查后端状态
|
| 229 |
if current_time - last_status_check > status_check_interval:
|
| 230 |
status = get_task_status(task_id)
|
| 231 |
+
print(str(request.session_hash), "status: ", status)
|
| 232 |
if status.get("status") == "completed":
|
| 233 |
# 确保处理完所有已生成的图片
|
| 234 |
process_remaining_oss_images(image_folder, local_image_dir, processed_files, frame_buffer)
|
|
|
|
| 284 |
"""创建视频片段"""
|
| 285 |
user_dir = os.path.join(TMP_ROOT, str(req.session_hash))
|
| 286 |
os.makedirs(user_dir, exist_ok=True)
|
| 287 |
+
video_chunk_path = os.path.join(user_dir, "video_chunk")
|
| 288 |
os.makedirs(video_chunk_path, exist_ok=True)
|
| 289 |
segment_name = os.path.join(video_chunk_path, f"output_{uuid.uuid4()}.mp4")
|
| 290 |
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
|
|
|
| 373 |
)
|
| 374 |
return response.json()
|
| 375 |
except Exception as e:
|
| 376 |
+
return {"status": "error get_task_status", "message": str(e)}
|
| 377 |
|
| 378 |
def terminate_task(task_id: str) -> Optional[dict]:
|
| 379 |
"""
|
|
|
|
| 440 |
# 记录用户提交
|
| 441 |
user_ip = request.client.host if request else "unknown"
|
| 442 |
session_id = request.session_hash
|
| 443 |
+
user_dir = os.path.join(TMP_ROOT, str(request.session_hash))
|
| 444 |
|
| 445 |
if not is_request_allowed(user_ip):
|
| 446 |
log_submission(scene, prompt, model, max_step, user_ip, "IP blocked temporarily")
|
|
|
|
| 464 |
status = get_task_status(task_id)
|
| 465 |
print("first status: ", status)
|
| 466 |
result_folder = status.get("result", "")
|
| 467 |
+
result_folder = "gradio_demo/tasks/" + task_id
|
| 468 |
except Exception as e:
|
| 469 |
log_submission(scene, prompt, model, max_step, user_ip, str(e))
|
| 470 |
raise gr.Error(f"error occurred when parsing submission result from backend: {str(e)}")
|
| 471 |
|
| 472 |
|
| 473 |
+
# if not os.path.exists(result_folder):
|
| 474 |
+
# log_submission(scene, prompt, model, max_step, user_ip, "Result folder provided by backend doesn't exist")
|
| 475 |
+
# raise gr.Error(f"Result folder provided by backend doesn't exist: <PATH>{result_folder}")
|
| 476 |
|
| 477 |
|
| 478 |
# 流式输出视频片段
|
| 479 |
try:
|
| 480 |
+
for video_path in stream_simulation_results(result_folder, task_id, request):
|
| 481 |
if video_path:
|
| 482 |
yield video_path, history
|
| 483 |
except Exception as e:
|
|
|
|
| 488 |
status = get_task_status(task_id)
|
| 489 |
print("status: ", status)
|
| 490 |
if status.get("status") == "completed":
|
| 491 |
+
# time.sleep(3)
|
| 492 |
+
oss_video_path = os.path.join(result_folder, "manipulation.mp4")
|
| 493 |
+
local_video_path = os.path.join(user_dir, task_id, "tasks", "manipulation.mp4")
|
| 494 |
+
download_oss_file(oss_video_path, local_video_path)
|
| 495 |
+
print("oss_video_path: ", oss_video_path)
|
| 496 |
+
print("local_video_path: ", local_video_path)
|
| 497 |
+
|
| 498 |
+
video_path = convert_to_h264(local_video_path)
|
| 499 |
|
| 500 |
# 创建新的历史记录条目
|
| 501 |
new_entry = {
|