From f47e422b1a1dc3cf2290654cf485105017714b0e Mon Sep 17 00:00:00 2001 From: Noe Brucy Date: Mon, 14 Jun 2021 20:22:35 +0200 Subject: [PATCH] [aggregator] add mpi backend --- aggregator.py | 31 +++++++++++++++++++++++-------- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/aggregator.py b/aggregator.py index 92b5746..3ae8db4 100644 --- a/aggregator.py +++ b/aggregator.py @@ -2,14 +2,21 @@ import numpy as np from functools import partial -from mypool import MyPool +import postprocessor -from postprocessor import PostProcessor +try: + from mpi4py.futures import MPIPoolExecutor + + mpi = True +except ModuleNotFoundError: + from mypool import MyPool + + mpi = False def _map_aux(fun, path, path_out, pp_params, run_num, **kwargs): try: - pp = PostProcessor( + pp = postprocessor.PostProcessor( path + "/" + run_num[0], run_num[1], path_out + "/" + run_num[0], pp_params ) except Exception as e: @@ -46,11 +53,19 @@ class Aggregator: map_fn = partial( _map_aux, func, self.path, self.path_out, self.pp_params, **kwargs ) - - pool = MyPool(processes=num_process, maxtasksperchild=1) - result = pool.map(map_fn, run_num) - pool.close() - pool.join() + if mpi: + executor = MPIPoolExecutor(max_workers=num_process) + try: + result = list(executor.map(map_fn, run_num, unordered=True)) + finally: + executor.shutdown() + else: + pool = MyPool(processes=num_process) + try: + result = pool.map(map_fn, run_num) + finally: + pool.close() + pool.join() return result