diff --git a/aggregator.py b/aggregator.py index 3498703..549a0d5 100644 --- a/aggregator.py +++ b/aggregator.py @@ -7,7 +7,7 @@ from mypool import MyPool from postprocessor import PostProcessor -def _map_rule(rule, arg, overwrite, path, path_out, pp_params, run_num): +def _map_aux(fun, path, path_out, pp_params, run_num, **kwargs): try: pp = PostProcessor( path + "/" + run_num[0], run_num[1], path_out + "/" + run_num[0], pp_params @@ -15,35 +15,56 @@ def _map_rule(rule, arg, overwrite, path, path_out, pp_params, run_num): except Exception as e: print(e) raise - return pp.process(rule, arg, overwrite, overwrite) + return fun(pp, **kwargs) + + +def _map_rule(pp, rule, arg, overwrite, overwrite_dep): + return pp.process(rule, arg, overwrite, overwrite_dep) class Aggregator: - def get_pp_list(self): - return [self.pp[run][num] for run in self.runs for num in self.nums[run]] + def get_pp_list(self, select=None): - def map(self, func): - return [func(pp) for pp in self.get_pp_list()] - - def _not_self_dep(self, name, dep, dep_arg, overwrite, select): if select is not None: runs, nums = self.selector.select(**select) else: runs = self.runs nums = self.nums + return [self.pp[run][num] for run in runs for num in nums[run]] - run_num = [(run, num) for run in runs for num in nums[run]] - map_fn = partial( - _map_rule, dep, dep_arg, overwrite, self.path, self.path_out, self.pp_params - ) + def map(self, func, select=None, num_process=None, **kwargs): - if self.pp_params.process.num_process > 1: - pool = MyPool(processes=self.pp_params.process.num_process) + pp_list = self.get_pp_list(select) + + if num_process is None: + num_process = self.pp_params.process.num_process + + if num_process == 1: + result = [func(pp, **kwargs) for pp in pp_list] + else: + run_num = [(pp.run, pp.num) for pp in pp_list] + map_fn = partial( + _map_aux, func, self.path, self.path_out, self.pp_params, **kwargs + ) + + pool = MyPool(processes=num_process) result = pool.map(map_fn, run_num) pool.close() pool.join() - else: - result = map(map_fn, run_num) + + return result + + def _not_self_dep(self, name, dep, dep_arg, overwrite, select): + + result = self.map( + _map_rule, + select, + None, + rule=dep, + arg=dep_arg, + overwrite=overwrite, + overwrite_dep=overwrite, + ) if np.any([res is not None for res in result]): self.just_done.append(dep)