89 lines
2.5 KiB
Python
89 lines
2.5 KiB
Python
# coding: utf-8
|
|
|
|
import numpy as np
|
|
from functools import partial
|
|
from baseprocessor import Rule
|
|
import snapshotprocessor
|
|
from utils.mypool import MyPool
|
|
|
|
# Optional MPI support: ``mpi_loaded`` records whether MPIPoolExecutor could
# be imported; when it is False, Aggregator.map uses MyPool instead.
try:
    from mpi4py.futures import MPIPoolExecutor

    mpi_loaded = True
except ImportError:
    # ImportError is the parent of ModuleNotFoundError, so this also covers an
    # mpi4py installation that is present but whose MPI shared library fails
    # to load -- that case raises a plain ImportError and previously crashed
    # the import of this module instead of falling back.
    mpi_loaded = False
|
|
|
|
|
|
def _map_aux(fun, path, path_out, params, run_num, **kwargs):
    """Build one snapshot processor and apply ``fun`` to it.

    Used as the per-item function in the parallel branch of
    ``Aggregator.map`` (bound with ``functools.partial``).

    Parameters
    ----------
    fun : callable
        Called as ``fun(snapshot, **kwargs)``.
    path, path_out : str (assumed)
        Base locations; the run identifier is appended with ``"/"`` to form
        the input and output locations handed to ``SnapshotProcessor``.
    params
        Passed through unchanged to ``SnapshotProcessor``.
    run_num : sequence
        ``(run, num)`` pair; only its first two items are read.
    **kwargs
        Forwarded to ``fun``.

    Returns
    -------
    Whatever ``fun`` returns.

    Any exception raised while constructing the processor is printed and then
    re-raised unchanged, so the message also appears in the worker's output.
    """
    try:
        run, num = run_num[0], run_num[1]
        source_dir = path + "/" + run
        target_dir = path_out + "/" + run
        snapshot = snapshotprocessor.SnapshotProcessor(
            source_dir, num, target_dir, params
        )
    except Exception as exc:
        print(exc)
        raise
    return fun(snapshot, **kwargs)
|
|
|
|
|
|
def _map_rule(snap, rule, **kwargs):
|
|
return snap.process(rule, **kwargs)
|
|
|
|
|
|
class Aggregator:
    """Apply functions or rules across a collection of snapshots.

    The methods read several attributes that this class never sets; they must
    be provided by whoever owns the instance (e.g. a subclass -- not visible
    here):

    * ``runs`` / ``nums`` -- default run identifiers and, per run, the
      snapshot numbers to visit (``nums[run]`` is iterated).
    * ``snaps`` -- nested mapping indexed as ``snaps[run][num]``; each item
      must expose ``.run`` and ``.num`` for the parallel path.
    * ``selector`` -- object whose ``select(**select)`` returns
      ``(runs, nums)`` in the same shape as above.
    * ``params`` -- configuration; :meth:`map` reads
      ``params.process.num_process`` and ``params.process.mpi``.
    * ``path`` / ``path_out`` -- base locations forwarded to ``_map_aux``
      in parallel mode.
    * ``just_done`` -- list; :meth:`_not_self_dep` appends a dependency to it
      when processing that dependency produced at least one non-None result.
    """

    def get_snap_list(self, select=None):
        """Return the snapshot objects to operate on, as a flat list.

        Parameters
        ----------
        select : dict or None
            Keyword arguments forwarded to ``self.selector.select``. When
            ``None``, the instance's own ``runs`` and ``nums`` are used.

        Returns
        -------
        list
            ``self.snaps[run][num]`` for each run (in run order) and each num
            listed for that run (in listed order).
        """
        if select is not None:
            runs, nums = self.selector.select(**select)
        else:
            runs, nums = self.runs, self.nums
        return [self.snaps[run][num] for run in runs for num in nums[run]]

    def map(self, func, select=None, num_process=None, **kwargs):
        """Call ``func`` on every selected snapshot and collect the results.

        Parameters
        ----------
        func : callable or Rule
            Called as ``func(snap, **kwargs)``. A ``Rule`` instance is instead
            applied via ``snap.process(rule, **kwargs)`` (through
            ``_map_rule``).
        select : dict or None
            Snapshot selection, as accepted by :meth:`get_snap_list`.
        num_process : int or None
            Number of workers. ``None`` means
            ``self.params.process.num_process``. ``1`` runs serially in this
            process on the objects held in ``self.snaps``.
        **kwargs
            Extra keyword arguments forwarded to ``func``.

        Returns
        -------
        list
            One result per snapshot, in the same order as
            ``get_snap_list(select)``, in every execution mode.

        Notes
        -----
        In parallel mode, each worker receives only ``(snap.run, snap.num)``
        and builds a fresh snapshot object in ``_map_aux``. Any state
        ``func`` sets on that object therefore does not reach the instances
        held in ``self.snaps``.
        """
        if isinstance(func, Rule):
            return self.map(_map_rule, select, num_process, rule=func, **kwargs)

        snaps = self.get_snap_list(select)

        if num_process is None:
            num_process = self.params.process.num_process

        if num_process == 1:
            return [func(snap, **kwargs) for snap in snaps]

        if not snaps:
            # Nothing to distribute: skip the cost of starting a worker pool.
            return []

        run_num = [(snap.run, snap.num) for snap in snaps]
        map_fn = partial(
            _map_aux, func, self.path, self.path_out, self.params, **kwargs
        )

        if mpi_loaded and self.params.process.mpi:
            # Use the default *ordered* map. The iterator is drained into a
            # list anyway, so unordered=True gained nothing. It only made the
            # result order differ from ``snaps`` and from the other two paths.
            # The context manager shuts the executor down (waiting for workers)
            # on exit, including when map raises.
            with MPIPoolExecutor(max_workers=num_process) as executor:
                return list(executor.map(map_fn, run_num))

        pool = MyPool(processes=num_process)
        try:
            return pool.map(map_fn, run_num)
        finally:
            # Explicit close()+join() rather than ``with``: assuming MyPool is
            # a multiprocessing.Pool subclass, its context manager would call
            # terminate() instead of letting workers exit cleanly.
            pool.close()
            pool.join()

    def _not_self_dep(self, name, dep, dep_arg, overwrite, select):
        """Process dependency ``dep`` over the selected snapshots.

        Each selected snapshot runs ``process(dep, arg=dep_arg,
        overwrite=overwrite, overwrite_dep=overwrite)``, using the default
        number of workers. If at least one snapshot returns a non-None result,
        ``dep`` is appended to ``self.just_done``.

        ``name`` is not used here; it is kept so the signature stays
        unchanged for callers.
        """
        result = self.map(
            _map_rule,
            select,
            None,
            rule=dep,
            arg=dep_arg,
            overwrite=overwrite,
            overwrite_dep=overwrite,
        )

        if any(res is not None for res in result):
            self.just_done.append(dep)
|