| | import torch |
| | import torch.nn as nn |
| | import threading |
| | from torch._utils import ExceptionWrapper |
| | import logging |
| |
|
def get_a_var(obj):
    """Return the first ``torch.Tensor`` found inside ``obj``, or ``None``.

    ``obj`` may be a tensor itself, or an arbitrarily nested combination of
    lists, tuples and dicts. Containers are searched depth-first in iteration
    order; for dicts both keys and values are inspected, because each
    ``(key, value)`` item is searched as a tuple. Any other kind of object
    yields ``None``.
    """
    if isinstance(obj, torch.Tensor):
        return obj

    # Pick the sequence of children to descend into, if there is one.
    if isinstance(obj, (list, tuple)):
        children = obj
    elif isinstance(obj, dict):
        children = obj.items()
    else:
        return None

    for child in children:
        found = get_a_var(child)
        if isinstance(found, torch.Tensor):
            return found
    return None
| |
|
def parallel_apply(fct, model, inputs, device_ids):
    """Call ``fct`` on one replica of ``model`` per entry of ``inputs``.

    ``model`` is replicated with ``nn.parallel.replicate(model, device_ids)``.
    Replica ``i`` is paired with ``inputs[i]`` and ``fct(replica, *input)`` is
    executed inside a ``torch.cuda.device`` context for the device reported by
    the first tensor found in that input. When there is more than one replica
    every call runs in its own thread; a single replica is run inline.

    Args:
        fct: Callable invoked as ``fct(module, *input)`` for every replica.
        model: Module to replicate onto ``device_ids``.
        inputs: One entry per replica. An entry that is not a list or tuple
            is wrapped into a 1-tuple before being unpacked into ``fct``.
        device_ids: Devices handed to ``nn.parallel.replicate``.

    Returns:
        A list with the value returned by ``fct`` for each input, in input
        order.

    Raises:
        ValueError: If the number of replicas differs from ``len(inputs)``.
        Exception: Anything raised while processing an input (including an
            input that contains no tensor) is captured in an
            ``ExceptionWrapper`` and re-raised here through its ``reraise()``.
    """
    modules = nn.parallel.replicate(model, device_ids)
    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if len(modules) != len(inputs):
        raise ValueError(
            "parallel_apply got {} replicas but {} inputs".format(
                len(modules), len(inputs)))

    lock = threading.Lock()
    results = {}
    grad_enabled = torch.is_grad_enabled()

    def _worker(index, module, args):
        # Apply the grad mode that was captured in the calling thread.
        torch.set_grad_enabled(grad_enabled)
        device = None
        try:
            # Resolve the device inside the ``try`` so that a failure is
            # recorded for this replica instead of escaping the thread and
            # leaving ``results`` without an entry for ``index``.
            tensor = get_a_var(args)
            if tensor is None:
                raise ValueError(
                    "no tensor found in the input; device cannot be resolved")
            device = tensor.get_device()
            with torch.cuda.device(device):
                if not isinstance(args, (list, tuple)):
                    args = (args,)
                output = fct(module, *args)
            with lock:
                results[index] = output
        except Exception:
            # ExceptionWrapper captures the active exception so that it can
            # be re-raised from the calling thread below.
            with lock:
                results[index] = ExceptionWrapper(
                    where="in replica {} on device {}".format(index, device))

    if len(modules) > 1:
        threads = [
            threading.Thread(target=_worker, args=(index, module, args))
            for index, (module, args) in enumerate(zip(modules, inputs))
        ]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        _worker(0, modules[0], inputs[0])

    outputs = []
    for index in range(len(inputs)):
        output = results[index]
        if isinstance(output, ExceptionWrapper):
            output.reraise()
        outputs.append(output)
    return outputs
| |
|
def get_logger(filename=None):
    """Return the shared ``'logger'`` logger, optionally logging to a file.

    The logger named ``'logger'`` is set to ``DEBUG``. ``logging.basicConfig``
    is called with an ``INFO`` level and a timestamped format; it only has an
    effect while the root logger has no handlers yet.

    Args:
        filename: Optional path of a log file. When given, a ``DEBUG``-level
            ``FileHandler`` for that path is attached to the *root* logger.
            Calling this function again with the same path does not attach a
            second handler, so log lines are not duplicated in the file.

    Returns:
        The ``logging.Logger`` named ``'logger'``.
    """
    import os  # local import: ``os`` is not among the imports visible here

    logger = logging.getLogger('logger')
    logger.setLevel(logging.DEBUG)
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO)
    if filename is not None:
        root = logging.getLogger()
        # FileHandler stores the absolute path of its file in ``baseFilename``.
        target = os.path.abspath(os.fspath(filename))
        already_attached = any(
            isinstance(existing, logging.FileHandler)
            and getattr(existing, 'baseFilename', None) == target
            for existing in root.handlers
        )
        if not already_attached:
            handler = logging.FileHandler(filename)
            handler.setLevel(logging.DEBUG)
            handler.setFormatter(
                logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
            root.addHandler(handler)
    return logger