import argparse
from dataclasses import dataclass, field
import json
import copy
import multiprocessing as mp
import uuid
from datetime import datetime, timedelta
from collections import defaultdict, deque
import io
import zipfile
import queue
import logging

from tensordict import TensorDict
import cv2
from flask import Flask, request, make_response, send_file
from PIL import Image
import torchvision.transforms as T
import numpy as np
import torch as th

from wham.utils import load_model_from_checkpoint, POS_BINS_BOUNDARIES, POS_BINS_MIDDLE

logging.basicConfig(level=logging.INFO)

parser = argparse.ArgumentParser(description="Simple Dreamer")
parser.add_argument("--model", type=str, required=True, help="Path to the model file for the local runs")
parser.add_argument("--debug", action="store_true", help="Enable Flask debug mode.")
parser.add_argument("--random_model", action="store_true", help="Use a randomly initialized model instead of the provided one")
parser.add_argument("--port", type=int, default=5000)

parser.add_argument("--max_concurrent_jobs", type=int, default=30, help="Maximum number of jobs that can be run concurrently on this server.")
parser.add_argument("--max_dream_steps_per_job", type=int, default=10, help="Maximum number of dream steps each job can request.")
parser.add_argument("--max_job_lifespan", type=int, default=60 * 10, help="Maximum number of seconds a job's results are kept around if not polled.")

parser.add_argument("--image_width", type=int, default=300, help="Width of the image")
parser.add_argument("--image_height", type=int, default=180, help="Height of the image")

parser.add_argument("--max_batch_size", type=int, default=3, help="Maximum batch size for the dreamer workers")

# Name of the JSON manifest written into each results zip.
PREDICTION_JSON_FILENAME = "predictions.json"
# How often to scan for (and delete) jobs that have not been polled.
JOB_CLEANUP_CHECK_RATE = timedelta(seconds=10)
# How many cancelled job ids each worker remembers.
MAX_CANCELLED_ID_QUEUE_SIZE = 100

DEFAULT_SAMPLING_SETTINGS = {
    "temperature": 0.9,
    "top_k": None,
    "top_p": 1.0,
    "max_context_length": 10,
}

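# Clients may override any of the settings above per request. Values arrive as
# strings and are parsed with float_or_none below, so (illustrative example,
# not enforced anywhere) a request carrying {"temperature": "0.7",
# "top_k": "none"} yields temperature=0.7 and top_k=None.
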
def float_or_none(string):
    """Parse a string as a float, treating "none" (case-insensitive) as None."""
    if string.lower() == "none":
        return None
    return float(string)


def be_image_preprocess(image, target_width, target_height):
    # Resize to the target resolution if necessary, then HWC -> CHW for torch.
    if target_width is not None and target_height is not None:
        if image.shape[1] != target_width or image.shape[0] != target_height:
            image = cv2.resize(image, (target_width, target_height))
    return np.transpose(image, (2, 0, 1))

def action_vector_to_be_action_vector(action):
    # The last four entries of an action vector are the continuous stick
    # positions. Discretize them into bin indices so the model can treat
    # them as tokens.
    action[-4:] = np.digitize(action[-4:], bins=POS_BINS_BOUNDARIES) - 1
    return action


def be_action_vector_to_action_vector(action):
    # Inverse of the above: map each stick bin index back to the midpoint
    # of its bin.
    for stick_index in range(-4, 0):
        action[stick_index] = POS_BINS_MIDDLE[int(action[stick_index])]
    return action

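# A minimal round-trip sketch of the two converters above (illustrative only:
# the real bin boundaries and midpoints come from wham.utils, and both
# converters mutate the list in place and return it):
#
#     action = [0, 1, 0, -0.5, 0.0, 0.25, 1.0]            # last 4 = stick axes
#     binned = action_vector_to_be_action_vector(action)   # sticks -> bin indices
#     approx = be_action_vector_to_action_vector(binned)   # indices -> midpoints
#     # approx[-4:] now holds the midpoints of the bins the sticks fell into.
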

@dataclass
class DreamJob:
    job_id: str
    sampling_settings: dict
    num_predictions_remaining: int
    num_predictions_done: int
    # Model context: (batch=1, time, ...) tensors plus a list of per-step tokens.
    context_images: th.Tensor
    context_actions: th.Tensor
    context_tokens: list
    # Optional (1, time, action_dim) tensor of actions to force, one per dream
    # step, instead of using the model's sampled actions.
    actions_to_take: th.Tensor = None

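# How a job flows through the pieces defined below: DreamerServer.add_new_job
# builds a DreamJob and puts it on a shared job queue; a dreamer_worker process
# batches compatible jobs, runs the model for one step, emits a DreamJobResult
# per job onto the result queue, and re-queues the job until no predictions
# remain; DreamerServer collects results until the client fetches them via
# /get_job_results (or they expire unpolled).
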

@dataclass
class DreamJobResult:
    job_id: str
    dream_step_index: int
    # A single dreamt step, kept as (1, 1, ...) tensors.
    dreamt_image: th.Tensor
    dreamt_action: th.Tensor
    dreamt_tokens: th.Tensor
    result_creation_time: datetime = field(default_factory=datetime.now)


def setup_and_load_model_be_model(args):
    model = load_model_from_checkpoint(args.model)
    th.set_float32_matmul_precision("high")
    th.backends.cuda.matmul.allow_tf32 = True
    return model


def get_job_batchable_information(job):
    """Return a comparable summary of job information. Used for batching."""
    context_length = job.context_images.shape[1]
    return (context_length, job.sampling_settings)


def fetch_list_of_batchable_jobs(job_queue, cancelled_ids_set, max_batch_size, timeout=1):
    """Return a list of jobs (possibly empty) that can be batched together."""
    batchable_jobs = []
    required_job_info = None
    while len(batchable_jobs) < max_batch_size:
        try:
            job = job_queue.get(timeout=timeout)
        except queue.Empty:
            break
        # The queue can raise OSError if it is closed under us (e.g. on shutdown).
        except OSError:
            break
        if job.job_id in cancelled_ids_set:
            # Silently drop jobs that were cancelled while queued.
            continue
        job_info = get_job_batchable_information(job)
        if required_job_info is None:
            required_job_info = job_info
        elif required_job_info != job_info:
            # This job cannot be batched with the ones we already have:
            # put it back for a later batch and close the current one.
            job_queue.put(job)
            break
        batchable_jobs.append(job)
    return batchable_jobs
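
# Note on the batching strategy above: a batch is grown greedily from the head
# of the queue and closed as soon as a job with a different (context length,
# sampling settings) signature appears; that job is pushed back rather than
# skipped, so batches stay homogeneous and the model can be called once per
# batch with a single set of sampling kwargs.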


def update_cancelled_jobs(cancelled_ids_queue, cancelled_ids_deque, cancelled_ids_set):
    """IN-PLACE update of cancelled_ids_set with new ids from the queue."""
    has_changed = False
    while not cancelled_ids_queue.empty():
        try:
            cancelled_id = cancelled_ids_queue.get_nowait()
        except queue.Empty:
            break
        cancelled_ids_deque.append(cancelled_id)
        has_changed = True

    if has_changed:
        # The deque is bounded, so rebuild the set from it to also forget
        # the oldest cancelled ids.
        cancelled_ids_set.clear()
        cancelled_ids_set.update(cancelled_ids_deque)


def predict_step(context_data, sampling_settings, model, tokens=None):
    # Returns whatever the model returns; callers expect a
    # (predicted step TensorDict, predicted image tokens) pair.
    with th.no_grad():
        predicted_step = model.predict_next_step(context_data, min_tokens_to_keep=1, tokens=tokens, **sampling_settings)
    return predicted_step

def dreamer_worker(job_queue, result_queue, cancelled_jobs_queue, quit_flag, device_to_use, args):
    logger = logging.getLogger(f"dreamer_worker {device_to_use}")
    logger.info("Loading up model...")
    model = setup_and_load_model_be_model(args)
    model = model.to(device_to_use)
    logger.info("Model loaded. Fetching jobs")

    cancelled_ids_deque = deque(maxlen=MAX_CANCELLED_ID_QUEUE_SIZE)
    cancelled_ids_set = set()

    while not quit_flag.is_set():
        update_cancelled_jobs(cancelled_jobs_queue, cancelled_ids_deque, cancelled_ids_set)
        batchable_jobs = fetch_list_of_batchable_jobs(job_queue, cancelled_ids_set, max_batch_size=args.max_batch_size)
        if len(batchable_jobs) == 0:
            continue
        sampling_settings = batchable_jobs[0].sampling_settings
        # All jobs in the batch share the same sampling settings (that is what
        # made them batchable). max_context_length is not a sampler kwarg, so
        # pop it here and restore it before re-queueing.
        max_context_length = sampling_settings.pop("max_context_length")

        images = [job.context_images[:, :max_context_length] for job in batchable_jobs]
        actions = [job.context_actions[:, :max_context_length] for job in batchable_jobs]
        tokens = [job.context_tokens for job in batchable_jobs]

        images = th.concat(images, dim=0).to(device_to_use)
        actions = th.concat(actions, dim=0).to(device_to_use)

        context_data = TensorDict({
            "images": images,
            "actions_output": actions
        }, batch_size=images.shape[:2])

        predicted_step, predicted_image_tokens = predict_step(context_data, sampling_settings, model, tokens)

        predicted_step = predicted_step.cpu()
        predicted_images = predicted_step["images"]
        predicted_actions = predicted_step["actions_output"]
        predicted_image_tokens = predicted_image_tokens.cpu()

        for job_i, job in enumerate(batchable_jobs):
            image_context = job.context_images
            action_context = job.context_actions
            token_context = job.context_tokens
            # Slice out this job's prediction, keeping (1, 1, ...) shapes.
            dreamt_image = predicted_images[job_i].unsqueeze(0)
            dreamt_action = predicted_actions[job_i].unsqueeze(0)
            dreamt_tokens = predicted_image_tokens[job_i].unsqueeze(0)

            # If the client provided actions to take, override the dreamt
            # action with the next forced action and consume it.
            actions_to_take = job.actions_to_take
            if actions_to_take is not None and actions_to_take.shape[1] > 0:
                dreamt_action = actions_to_take[:, 0:1]
                actions_to_take = actions_to_take[:, 1:]
                if actions_to_take.shape[1] == 0:
                    actions_to_take = None

            result_queue.put(DreamJobResult(
                job_id=job.job_id,
                dream_step_index=job.num_predictions_done,
                dreamt_image=dreamt_image,
                dreamt_action=dreamt_action,
                dreamt_tokens=dreamt_tokens
            ))

            # Re-queue the job if it still has steps left to dream.
            if job.num_predictions_remaining > 0:
                # Slide the context window if it is already full.
                if image_context.shape[1] >= max_context_length:
                    image_context = image_context[:, 1:]
                    action_context = action_context[:, 1:]
                    token_context = token_context[1:]
                image_context = th.cat([image_context, dreamt_image], dim=1)
                action_context = th.cat([action_context, dreamt_action], dim=1)
                token_context.append(dreamt_tokens[0, 0].tolist())

                # Restore the key we popped from the shared settings above.
                job.sampling_settings["max_context_length"] = max_context_length
                job_queue.put(DreamJob(
                    job_id=job.job_id,
                    sampling_settings=job.sampling_settings,
                    num_predictions_remaining=job.num_predictions_remaining - 1,
                    num_predictions_done=job.num_predictions_done + 1,
                    context_images=image_context,
                    context_actions=action_context,
                    context_tokens=token_context,
                    actions_to_take=actions_to_take
                ))


class DreamerServer:
    def __init__(self, num_workers, args):
        self.num_workers = num_workers
        self.args = args
        self.model = None
        self.jobs = mp.Queue(maxsize=args.max_concurrent_jobs)
        self.results_queue = mp.Queue()
        self.cancelled_jobs = set()
        # One cancellation queue per worker so every worker hears about
        # every cancelled job id.
        self.cancelled_jobs_queues = [mp.Queue() for _ in range(num_workers)]

        self._last_result_cleanup = datetime.now()
        self._max_job_lifespan_datetime = timedelta(seconds=args.max_job_lifespan)
        self.local_results = defaultdict(list)
        self.logger = logging.getLogger("DreamerServer")

    def get_details(self):
        details = {
            "model_file": self.args.model,
            "max_concurrent_jobs": self.args.max_concurrent_jobs,
            "max_dream_steps_per_job": self.args.max_dream_steps_per_job,
            "max_job_lifespan": self.args.max_job_lifespan,
        }
        return json.dumps(details)

    def _check_if_should_remove_old_jobs(self):
        time_now = datetime.now()
        # Rate-limit the cleanup scans.
        if time_now - self._last_result_cleanup < JOB_CLEANUP_CHECK_RATE:
            return

        self._last_result_cleanup = time_now
        # Drain the result queue first so the timestamps below are current.
        self._gather_new_results()
        # Copy the keys: we delete from the dict while iterating.
        job_ids = list(self.local_results.keys())
        for job_id in job_ids:
            results = self.local_results[job_id]
            # Delete jobs whose newest result is older than the max lifespan.
            if time_now - results[-1].result_creation_time > self._max_job_lifespan_datetime:
                self.logger.info(f"Deleted job {job_id} because it was too old. Last result was {results[-1].result_creation_time}")
                del self.local_results[job_id]

    def add_new_job(self, request, request_json):
        """
        Add a new dreaming job to the queues.

        The request JSON should contain:
            num_steps_to_predict (required): how many steps to dream.
            num_parallel_predictions (optional, default 1): how many
                independent dreams to run from the same context.
            temperature / top_k / top_p / max_context_length (optional):
                overrides for DEFAULT_SAMPLING_SETTINGS.
            steps (required): the context, a list of dicts with
                "image_name" (key into request.files), "action" and "tokens".
            future_actions (optional): list of dicts whose "action"s are
                forced during dreaming instead of sampled actions.

        Returns: json object with the new job ids.
        """
        self._check_if_should_remove_old_jobs()

        sampling_settings = copy.deepcopy(DEFAULT_SAMPLING_SETTINGS)
        if "num_steps_to_predict" not in request_json:
            return make_response("num_steps_to_predict not in request", 400)
        num_steps_to_predict = request_json["num_steps_to_predict"]
        if num_steps_to_predict > self.args.max_dream_steps_per_job:
            return make_response(f"num_steps_to_predict too large. Max {self.args.max_dream_steps_per_job}", 400)

        num_parallel_predictions = int(request_json["num_parallel_predictions"]) if "num_parallel_predictions" in request_json else 1

        if (self.jobs.qsize() + num_parallel_predictions) >= self.args.max_concurrent_jobs:
            return make_response(f"Too many jobs already running. Max {self.args.max_concurrent_jobs}", 400)

        for key in sampling_settings:
            sampling_settings[key] = float_or_none(request_json[key]) if key in request_json else sampling_settings[key]
        # max_context_length is used for slicing tensors, so keep it an int
        # (float_or_none above turns overrides into floats).
        if sampling_settings["max_context_length"] is not None:
            sampling_settings["max_context_length"] = int(sampling_settings["max_context_length"])

        context_images = []
        context_actions = []
        context_tokens = []

        for step in request_json["steps"]:
            image_path = step["image_name"]
            image = np.array(Image.open(request.files[image_path].stream))
            image = be_image_preprocess(image, target_width=self.args.image_width, target_height=self.args.image_height)
            context_images.append(th.from_numpy(image))

            action = step["action"]
            action = action_vector_to_be_action_vector(action)
            context_actions.append(th.tensor(action))

            tokens = step["tokens"]
            context_tokens.append(tokens)

        future_actions = None
        if "future_actions" in request_json:
            future_actions = []
            for step in request_json["future_actions"]:
                # Future steps carry only the action to take.
                action = step["action"]
                action = action_vector_to_be_action_vector(action)
                future_actions.append(th.tensor(action))

        # Stack along time and add a batch dimension: (1, time, ...).
        context_images = th.stack(context_images).unsqueeze(0)
        context_actions = th.stack(context_actions).unsqueeze(0)
        future_actions = th.stack(future_actions).unsqueeze(0) if future_actions is not None else None

        list_of_job_ids = []
        for _ in range(num_parallel_predictions):
            job_id = uuid.uuid4().hex
            self.jobs.put(DreamJob(
                job_id=job_id,
                sampling_settings=sampling_settings,
                num_predictions_remaining=num_steps_to_predict,
                num_predictions_done=0,
                context_images=context_images,
                context_actions=context_actions,
                context_tokens=context_tokens,
                actions_to_take=future_actions
            ))
            list_of_job_ids.append(job_id)

        job_queue_size = self.jobs.qsize()
        return json.dumps({"job_ids": list_of_job_ids, "current_jobs_in_queue": job_queue_size})

    def _gather_new_results(self):
        if not self.results_queue.empty():
            for _ in range(self.results_queue.qsize()):
                result = self.results_queue.get()
                if result.job_id in self.cancelled_jobs:
                    # Drop results for jobs that were cancelled.
                    continue
                self.local_results[result.job_id].append(result)

    def get_new_results(self, request, request_json):
        if "job_ids" not in request_json:
            return make_response("job_ids not in request", 400)
        self._gather_new_results()
        job_ids = request_json["job_ids"]
        if not isinstance(job_ids, list):
            job_ids = [job_ids]
        return_results = []
        for job_id in job_ids:
            if job_id in self.local_results:
                return_results.append(self.local_results[job_id])
                del self.local_results[job_id]

        if len(return_results) == 0:
            return make_response("No new responses", 204)

        output_json = []
        output_image_bytes = {}
        for job_results in return_results:
            for result in job_results:
                action = result.dreamt_action.numpy()
                # Undo the stick discretization before returning the action.
                action = be_action_vector_to_action_vector(action[0, 0].tolist())
                dreamt_tokens = result.dreamt_tokens[0, 0].tolist()
                image_filename = f"{result.job_id}_{result.dream_step_index}.png"
                output_json.append({
                    "job_id": result.job_id,
                    "dream_step_index": result.dream_step_index,
                    "action": action,
                    "tokens": dreamt_tokens,
                    "image_filename": image_filename
                })

                image_bytes = io.BytesIO()
                # Encode the dreamt frame (a CHW tensor) as a PNG.
                T.ToPILImage()(result.dreamt_image[0, 0]).save(image_bytes, format="PNG")
                output_image_bytes[image_filename] = image_bytes.getvalue()

        # Bundle everything into a single zip: one PNG per dream step plus a
        # JSON manifest describing each step.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3]
        zip_bytes = io.BytesIO()
        with zipfile.ZipFile(zip_bytes, "w") as z:
            for filename, image_data in output_image_bytes.items():
                z.writestr(filename, image_data)
            z.writestr(PREDICTION_JSON_FILENAME, json.dumps(output_json))

        zip_bytes.seek(0)

        return send_file(
            zip_bytes,
            mimetype="application/zip",
            as_attachment=True,
            download_name=f"dreaming_results_{timestamp}.zip"
        )

    def cancel_job(self, request, request_json):
        if "job_id" not in request_json:
            return make_response("job_id not in request", 400)
        job_id = request_json["job_id"]
        self.cancelled_jobs.add(job_id)
        # Tell every worker about the cancellation so queued steps are dropped.
        for cancelled_jobs_queue in self.cancelled_jobs_queues:
            cancelled_jobs_queue.put(job_id)
        return make_response("OK", 200)


def main_run(args):
    app = Flask(__name__)

    num_workers = th.cuda.device_count()
    if num_workers == 0:
        raise RuntimeError("No CUDA devices found. Cannot run Dreamer.")

    server = DreamerServer(num_workers, args)
    quit_flag = mp.Event()

    # One worker process per GPU, all sharing the same job and result queues.
    dreamer_worker_processes = []
    for device_i in range(num_workers):
        device = f"cuda:{device_i}"
        dreamer_worker_process = mp.Process(
            target=dreamer_worker,
            args=(server.jobs, server.results_queue, server.cancelled_jobs_queues[device_i], quit_flag, device, args)
        )
        dreamer_worker_process.daemon = True
        dreamer_worker_process.start()
        dreamer_worker_processes.append(dreamer_worker_process)

    @app.route('/')
    def details():
        return server.get_details()

    @app.route('/new_job', methods=['POST'])
    def new_job():
        request_json = json.loads(request.form["json"])
        return server.add_new_job(request, request_json)

    @app.route('/get_job_results', methods=['GET'])
    def get_results():
        # job_ids may be given multiple times as query parameters.
        request_json = {"job_ids": request.args.getlist("job_ids")}
        return server.get_new_results(request, request_json)

    @app.route('/cancel_job', methods=['GET'])
    def cancel_job():
        request_json = request.args.to_dict()
        return server.cancel_job(request, request_json)

    app.run(host="0.0.0.0", port=args.port, debug=args.debug)

    # app.run blocks until the server exits; then shut the workers down.
    quit_flag.set()
    for dreamer_worker_process in dreamer_worker_processes:
        dreamer_worker_process.join()


if __name__ == '__main__':
    args = parser.parse_args()
    main_run(args)
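
# A minimal client sketch for this server (illustrative only: it assumes the
# server runs on localhost:5000 and that `requests` is installed, which is not
# a dependency of this module). Field names follow the handlers above:
#
#     import json, requests
#     payload = {
#         "num_steps_to_predict": 5,
#         "steps": [{"image_name": "frame0.png",
#                    "action": [...],      # action vector, last 4 = sticks
#                    "tokens": [...]}],    # per-step image tokens
#     }
#     files = {"frame0.png": open("frame0.png", "rb")}
#     resp = requests.post("http://localhost:5000/new_job",
#                          data={"json": json.dumps(payload)}, files=files)
#     job_ids = resp.json()["job_ids"]
#     # Poll until a zip arrives (HTTP 204 means "nothing yet"):
#     resp = requests.get("http://localhost:5000/get_job_results",
#                         params=[("job_ids", jid) for jid in job_ids])
#     # resp.content is a zip with one PNG per step plus predictions.json.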