import argparse
import json
import os
import os.path as osp
import sys

import torch

from models.modeling_xvla import XVLA
from models.processing_xvla import XVLAProcessor
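
# Example invocation (script name and paths are illustrative):
#   python launch_xvla_server.py --model_path /path/to/xvla_checkpoint \
#       --device auto --port 8010 --output_dir ./logs
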
def main():
    parser = argparse.ArgumentParser(description="Launch XVLA inference FastAPI server")
    parser.add_argument("--model_path", type=str, required=True,
                        help="Path to the pretrained XVLA model directory")
    parser.add_argument("--processor_path", type=str, default=None,
                        help="Path to the XVLAProcessor directory (defaults to --model_path)")
    parser.add_argument("--LoRA_path", type=str, default=None,
                        help="Optional LoRA adapter directory applied on top of the base model")
    parser.add_argument("--output_dir", type=str, default="./logs",
                        help="Directory to save runtime info (info.json)")
    parser.add_argument("--device", type=str, default="cuda",
                        help="Device to load the model on (cuda / cpu / auto)")
    parser.add_argument("--port", default=8010, type=int,
                        help="Port number for the FastAPI server")
    parser.add_argument("--host", default="0.0.0.0", type=str,
                        help="Host address for the FastAPI server")
    parser.add_argument("--disable_slurm", action="store_true",
                        help="Ignore SLURM environment variables when resolving the serving host")

    args = parser.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
| print("π Starting XVLA Inference Server...") |
| print(f"πΉ Model Path : {args.model_path}") |
| print(f"πΉ Output Dir : {args.output_dir}") |
| print(f"πΉ Device Arg : {args.device}") |
| print(f"πΉ Port : {args.port}") |
|
|
| |
| |
| |

    # Resolve the torch device; "auto" prefers CUDA when available.
    if args.device == "auto":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(args.device)
    print(f"🧠 Using device: {device}")
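    # e.g. `--device auto` resolves to "cuda" on a GPU machine and "cpu"
    # otherwise; an explicit value such as "cuda:1" pins a specific GPU.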

    # Load the processor; fall back to the model directory when no explicit
    # processor path is given.
    processor = None
    try:
        print("\n🧩 Loading XVLAProcessor...")
        processor_path = args.processor_path if args.processor_path else args.model_path
        processor = XVLAProcessor.from_pretrained(processor_path)
        print("✅ XVLAProcessor loaded successfully.")
    except Exception as e:
        print(f"⚠️ No processor found or failed to load: {e}")
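    # Note: startup continues even when no processor is found; `processor`
    # stays None and is passed to `model.run()` below unchanged.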

    print("\n📦 Loading XVLA model from pretrained checkpoint...")
    try:
        model = XVLA.from_pretrained(
            args.model_path,
            trust_remote_code=True,
            torch_dtype=torch.float32,
        ).to(device)

        if args.LoRA_path is not None:
            print(f"🌸 Applying LoRA weights from {args.LoRA_path} ...")
            from peft import PeftModel
            model = PeftModel.from_pretrained(
                model,
                args.LoRA_path,
                torch_dtype=torch.float32,
            ).to(device)
            print("✅ LoRA weights applied successfully.")

        print("✅ Model successfully loaded and moved to device.")
    except Exception as e:
        print(f"❌ Failed to load model: {e}")
        return
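    # Optional: peft's `PeftModel.merge_and_unload()` can fuse LoRA weights
    # into the base model for inference; this script keeps the adapter
    # wrapper exactly as loaded above.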

    # SLURM job metadata, used to figure out which node is serving.
    node_list = os.environ.get("SLURM_NODELIST")
    job_id = os.environ.get("SLURM_JOB_ID", "none")

    if node_list and not args.disable_slurm:
        print("\n🖥️ SLURM Environment Detected:")
        print(f"   Node list : {node_list}")
        print(f"   Job ID    : {job_id}")

        # Derive a reachable host from the node name. This assumes a
        # cluster-specific scheme in which hyphen-separated node names encode
        # the node's IP address.
        try:
            host = ".".join(node_list.split("-")[1:]) if "-" in node_list else node_list
        except Exception:
            host = args.host
    else:
        print(f"\n⚠️ No SLURM environment detected, defaulting to {args.host}")
        host = args.host
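    # Example of the scheme above: SLURM_NODELIST="node-10-20-30-40" yields
    # host "10.20.30.40". Compressed multi-node lists such as "node[1-4]"
    # would need additional parsing and are not handled here.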

    info_path = osp.join(args.output_dir, "info.json")
    infos = {
        "host": host,
        "port": args.port,
        "job_id": job_id,
        "node_list": node_list or "none",
    }
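    # The resulting info.json lets clients discover the server, e.g.
    # (values illustrative):
    #   {"host": "10.20.30.40", "port": 8010, "job_id": "123456", "node_list": "node-10-20-30-40"}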

    if osp.exists(info_path):
        print(f"❌ Error: {info_path} already exists. "
              f"This usually means another server is still running or the previous job did not clean up properly.")
        print("👉 Please remove it manually or use a different --output_dir.")
        sys.exit(1)

    try:
        with open(info_path, "w") as f:
            json.dump(infos, f, indent=4)
        print(f"📝 Server info written to {info_path}")
    except Exception as e:
        print(f"⚠️ Failed to write {info_path}: {e}")
        sys.exit(1)

    # Hand off to the model's FastAPI entrypoint; `run()` is expected to
    # block until the server shuts down.
    print(f"\n🚀 Launching FastAPI service at http://{host}:{args.port} ...")
    try:
        if hasattr(model, "run"):
            model.run(processor=processor, host=host, port=args.port)
        else:
            print("❌ The loaded model does not implement `.run()` (FastAPI entrypoint).")
    except KeyboardInterrupt:
        print("\n🛑 Server stopped manually.")
    except Exception as e:
        print(f"❌ Server failed to start: {e}")
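    # Note: info.json is not deleted on shutdown here; the stale-file check
    # above assumes an external wrapper or a manual step cleans it up.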


if __name__ == "__main__":
    main()