[aggregator] add mpi backend

This commit is contained in:
Noe Brucy
2021-06-14 20:22:35 +02:00
parent 4b83de26bf
commit f47e422b1a
+19 -4
View File
@@ -2,14 +2,21 @@
import numpy as np import numpy as np
from functools import partial from functools import partial
import postprocessor
try:
from mpi4py.futures import MPIPoolExecutor
mpi = True
except ModuleNotFoundError:
from mypool import MyPool from mypool import MyPool
from postprocessor import PostProcessor mpi = False
def _map_aux(fun, path, path_out, pp_params, run_num, **kwargs): def _map_aux(fun, path, path_out, pp_params, run_num, **kwargs):
try: try:
pp = PostProcessor( pp = postprocessor.PostProcessor(
path + "/" + run_num[0], run_num[1], path_out + "/" + run_num[0], pp_params path + "/" + run_num[0], run_num[1], path_out + "/" + run_num[0], pp_params
) )
except Exception as e: except Exception as e:
@@ -46,9 +53,17 @@ class Aggregator:
map_fn = partial( map_fn = partial(
_map_aux, func, self.path, self.path_out, self.pp_params, **kwargs _map_aux, func, self.path, self.path_out, self.pp_params, **kwargs
) )
if mpi:
pool = MyPool(processes=num_process, maxtasksperchild=1) executor = MPIPoolExecutor(max_workers=num_process)
try:
result = list(executor.map(map_fn, run_num, unordered=True))
finally:
executor.shutdown()
else:
pool = MyPool(processes=num_process)
try:
result = pool.map(map_fn, run_num) result = pool.map(map_fn, run_num)
finally:
pool.close() pool.close()
pool.join() pool.join()