[aggregator] add mpi backend
This commit is contained in:
+23
-8
@@ -2,14 +2,21 @@
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from mypool import MyPool
|
import postprocessor
|
||||||
|
|
||||||
from postprocessor import PostProcessor
|
try:
|
||||||
|
from mpi4py.futures import MPIPoolExecutor
|
||||||
|
|
||||||
|
mpi = True
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
from mypool import MyPool
|
||||||
|
|
||||||
|
mpi = False
|
||||||
|
|
||||||
|
|
||||||
def _map_aux(fun, path, path_out, pp_params, run_num, **kwargs):
|
def _map_aux(fun, path, path_out, pp_params, run_num, **kwargs):
|
||||||
try:
|
try:
|
||||||
pp = PostProcessor(
|
pp = postprocessor.PostProcessor(
|
||||||
path + "/" + run_num[0], run_num[1], path_out + "/" + run_num[0], pp_params
|
path + "/" + run_num[0], run_num[1], path_out + "/" + run_num[0], pp_params
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -46,11 +53,19 @@ class Aggregator:
|
|||||||
map_fn = partial(
|
map_fn = partial(
|
||||||
_map_aux, func, self.path, self.path_out, self.pp_params, **kwargs
|
_map_aux, func, self.path, self.path_out, self.pp_params, **kwargs
|
||||||
)
|
)
|
||||||
|
if mpi:
|
||||||
pool = MyPool(processes=num_process, maxtasksperchild=1)
|
executor = MPIPoolExecutor(max_workers=num_process)
|
||||||
result = pool.map(map_fn, run_num)
|
try:
|
||||||
pool.close()
|
result = list(executor.map(map_fn, run_num, unordered=True))
|
||||||
pool.join()
|
finally:
|
||||||
|
executor.shutdown()
|
||||||
|
else:
|
||||||
|
pool = MyPool(processes=num_process)
|
||||||
|
try:
|
||||||
|
result = pool.map(map_fn, run_num)
|
||||||
|
finally:
|
||||||
|
pool.close()
|
||||||
|
pool.join()
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user