Files
pipeline/aggregator.py
T
2021-06-24 10:47:53 +02:00

86 lines
2.3 KiB
Python

# coding: utf-8
import numpy as np
from functools import partial
import postprocessor
try:
from mpi4py.futures import MPIPoolExecutor
mpi = True
except ModuleNotFoundError:
from mypool import MyPool
mpi = False
def _map_aux(fun, path, path_out, pp_params, run_num, **kwargs):
    """Build a PostProcessor for one (run, num) pair and apply ``fun`` to it.

    This is the per-task function used by ``Aggregator.map`` in its parallel
    branches: rather than receiving an existing PostProcessor object, the
    worker constructs one itself from the paths and parameters it is given.

    Args:
        fun: callable invoked as ``fun(pp, **kwargs)``.
        path: base input directory; the run's input is ``path + "/" + run``.
        path_out: base output directory; the run's output is
            ``path_out + "/" + run``.
        pp_params: parameters passed unchanged to
            ``postprocessor.PostProcessor``.
        run_num: ``(run, num)`` pair identifying the post-processor.
        **kwargs: forwarded to ``fun``.

    Returns:
        Whatever ``fun`` returns.

    Raises:
        Any exception raised while constructing the PostProcessor. It is first
        reported on stderr together with the failing ``run_num`` and then
        re-raised unchanged.
    """
    import sys

    try:
        pp = postprocessor.PostProcessor(
            path + "/" + run_num[0], run_num[1], path_out + "/" + run_num[0], pp_params
        )
    except Exception as e:
        # Many tasks can fail concurrently in a pool; naming the (run, num)
        # makes the report attributable. Flush so the message shows up
        # promptly even when the worker's stream is redirected to a pipe.
        print(
            f"Failed to build PostProcessor for (run, num)={run_num!r}: {e!r}",
            file=sys.stderr,
            flush=True,
        )
        raise
    return fun(pp, **kwargs)
def _map_rule(pp, rule, arg, overwrite, overwrite_dep):
    """Run a single processing rule on one post-processor.

    Adapter with the ``func(pp, **kwargs)`` shape that ``Aggregator.map``
    calls: it forwards ``rule``, ``arg``, ``overwrite`` and ``overwrite_dep``
    positionally, in that order, to ``pp.process`` and returns its result.
    """
    process_args = (rule, arg, overwrite, overwrite_dep)
    return pp.process(*process_args)
class Aggregator:
def get_pp_list(self, select=None):
    """Return the post-processor objects for the selected (run, num) pairs.

    Args:
        select: optional mapping of keyword arguments forwarded to
            ``self.selector.select``, which must return ``(runs, nums)``.
            When ``None``, ``self.runs`` and ``self.nums`` are used instead.

    Returns:
        list: ``self.pp[run][num]`` for every ``run`` in ``runs`` and every
        ``num`` in ``nums[run]``, ordered by run first, then by the order of
        ``nums[run]``.
    """
    if select is None:
        runs, nums = self.runs, self.nums
    else:
        runs, nums = self.selector.select(**select)

    selected = []
    for run in runs:
        # Only index self.pp for numbers that are actually listed, so a run
        # with no selected numbers is never looked up.
        for num in nums[run]:
            selected.append(self.pp[run][num])
    return selected
def map(self, func, select=None, num_process=None, **kwargs):
    """Apply ``func`` to every selected post-processor and collect the results.

    Execution backend:
      * ``num_process == 1``: serial, in this process, on the objects stored
        in ``self.pp`` (as returned by ``get_pp_list``).
      * otherwise: in a worker pool -- ``MPIPoolExecutor`` when ``mpi4py`` was
        importable at module load, ``MyPool`` otherwise. Each worker rebuilds
        its own PostProcessor from ``(pp.run, pp.num)`` via ``_map_aux``.

    Args:
        func: callable invoked as ``func(pp, **kwargs)``. In the parallel
            branches it is shipped to worker processes, so it presumably has
            to be picklable (e.g. a module-level function such as
            ``_map_rule``) -- TODO confirm for ``MyPool``.
        select: optional keyword mapping forwarded to ``get_pp_list`` to
            restrict the (run, num) pairs; ``None`` selects everything.
        num_process: number of worker processes; ``None`` falls back to
            ``self.pp_params.process.num_process``.
        **kwargs: extra keyword arguments forwarded to ``func``.

    Returns:
        list: one result per selected post-processor, in the same order as
        ``self.get_pp_list(select)`` for every backend. An empty selection
        returns ``[]`` without starting any pool.
    """
    pp_list = self.get_pp_list(select)
    if not pp_list:
        # Zero tasks: don't pay for spawning MPI or multiprocessing workers.
        return []

    if num_process is None:
        num_process = self.pp_params.process.num_process
    if num_process == 1:
        return [func(pp, **kwargs) for pp in pp_list]

    # Workers receive only the (run, num) identifiers and construct their own
    # PostProcessor inside _map_aux.
    run_num = [(pp.run, pp.num) for pp in pp_list]
    map_fn = partial(
        _map_aux, func, self.path, self.path_out, self.pp_params, **kwargs
    )
    # Never start more workers than there are tasks.
    workers = min(num_process, len(run_num))

    if mpi:
        # Ordered map (the default): results line up with pp_list, like the
        # serial and MyPool branches. list() waits for every task anyway, so
        # completion-order (unordered) results bought nothing.
        with MPIPoolExecutor(max_workers=workers) as executor:
            return list(executor.map(map_fn, run_num))

    pool = MyPool(processes=workers)
    try:
        return pool.map(map_fn, run_num)
    finally:
        pool.close()
        pool.join()
def _not_self_dep(self, name, dep, dep_arg, overwrite, select):
result = self.map(
_map_rule,
select,
None,
rule=dep,
arg=dep_arg,
overwrite=overwrite,
overwrite_dep=overwrite,
)
if np.any([res is not None for res in result]):
self.just_done.append(dep)