diff --git a/aggregator.py b/aggregator.py index aa02f90..f353429 100644 --- a/aggregator.py +++ b/aggregator.py @@ -5,7 +5,6 @@ from functools import partial from mypool import MyPool from postprocessor import PostProcessor -from run_selector import RunSelector def _map_rule(rule, arg, overwrite, path, path_out, pp_params, run_num): @@ -20,17 +19,9 @@ def _map_rule(rule, arg, overwrite, path, path_out, pp_params, run_num): class Aggregator: - def _not_self_dep(self, name, dep, dep_arg, overwrite, **kwargs): - if "select" in kwargs: - select = kwargs["select"] + def _not_self_dep(self, name, dep, dep_arg, overwrite, select): + if select is not None: runs, nums = self.selector.select(**select) - elif "runs" in kwargs: - runs = kwargs["runs"] - if isinstance(runs, RunSelector): - nums = runs.nums - runs = runs.runs - else: - nums = self.nums else: runs = self.runs nums = self.nums diff --git a/baseprocessor.py b/baseprocessor.py index c9b019d..7ce2cf0 100644 --- a/baseprocessor.py +++ b/baseprocessor.py @@ -77,7 +77,13 @@ class BaseProcessor: print(self.log_id + string) def process( - self, to_process, arg=None, overwrite=False, overwrite_dep=False, **kwargs + self, + to_process, + arg=None, + overwrite=False, + overwrite_dep=False, + select=None, + **kwargs, ): """ Process the rule `to_process` @@ -89,7 +95,7 @@ class BaseProcessor: if to_process in self.rules: rule = self.rules[to_process] return self._solve_and_process_rule( - to_process, rule, arg, overwrite, **kwargs + to_process, rule, arg, overwrite, select, **kwargs ) else: self._log( @@ -99,12 +105,14 @@ class BaseProcessor: "ERROR", ) - def _solve_and_process_rule(self, name, rule, arg, overwrite=False, **kwargs): - updated = self._solve_dependencies(name, rule, arg, overwrite, **kwargs) + def _solve_and_process_rule( + self, name, rule, arg, overwrite=False, select=None, **kwargs + ): + updated = self._solve_dependencies(name, rule, arg, overwrite, select) overwrite_rule = overwrite or updated - return self._process_rule(name, rule, arg, overwrite_rule, **kwargs) + return self._process_rule(name, rule, arg, overwrite_rule, select, **kwargs) - def _solve_dependencies(self, name, rule, arg, overwrite=False, **kwargs): + def _solve_dependencies(self, name, rule, arg, overwrite=False, select=None): self.done_before_dep = len(self.just_done) @@ -113,7 +121,7 @@ class BaseProcessor: # get arguments try: dep_arg = rule.dependencies[dep] - except: + except (TypeError, KeyError): dep_arg = arg if dep_arg == "__parent__": @@ -123,20 +131,22 @@ class BaseProcessor: # it to a child processor if self.solve_self_dep and dep in self.rules: rule_dep = self.rules[dep] - self._solve_and_process_rule(dep, rule_dep, dep_arg, self.overwrite_dep) + self._solve_and_process_rule( + dep, rule_dep, dep_arg, self.overwrite_dep, select + ) else: - self._not_self_dep(name, dep, dep_arg, self.overwrite_dep, **kwargs) + self._not_self_dep(name, dep, dep_arg, self.overwrite_dep, select) # Whether dependencies where updated return len(self.just_done) > self.done_before_dep - def _not_self_dep(self, name, dep, dep_arg, overwrite, **kwargs): + def _not_self_dep(self, name, dep, dep_arg, overwrite, select=None): self._log("Dependency {} for {} is unknown".format(dep, name), "ERROR") def _needs_computation(self, overwrite, name_full): return overwrite - def _process_rule(self, name, rule, arg, overwrite=False, **kwargs): + def _process_rule(self, name, rule, arg, overwrite=False, select=None, **kwargs): if arg is not None: name_full = rule.group + "/" + name + "_" + str(arg) else: @@ -186,11 +196,11 @@ class HDF5Container(BaseProcessor): def _needs_computation(self, overwrite, name_full): return overwrite or not (name_full in self.save) - def _process_rule(self, name, rule, arg, overwrite, **kwargs): + def _process_rule(self, name, rule, arg, overwrite, select, **kwargs): self.open() try: super(HDF5Container, self)._process_rule( - name, rule, arg, overwrite, **kwargs + name, rule, arg, overwrite, select, **kwargs ) finally: self.close() diff --git a/plotter.py b/plotter.py index 2d7d4ac..80544a8 100644 --- a/plotter.py +++ b/plotter.py @@ -232,18 +232,18 @@ class Plotter(Aggregator, BaseProcessor): self.simulations[run] = simu - def _not_self_dep(self, name, dep, dep_arg, overwrite, **kwargs): + def _not_self_dep(self, name, dep, dep_arg, overwrite, select): """ Check if the dependency belongs to the plotter object or to another one (comp, pp, ..) """ if dep in self.comp.rules: result = self.comp.process( - dep, dep_arg, overwrite, overwrite_dep=self.overwrite_dep + dep, dep_arg, overwrite, self.overwrite_dep, select ) if result is not None: self.just_done.append(result) else: - super(Plotter, self)._not_self_dep(name, dep, dep_arg, overwrite, **kwargs) + super(Plotter, self)._not_self_dep(name, dep, dep_arg, overwrite, select) def _needs_computation(self, overwrite, plot_filename): """ @@ -256,7 +256,15 @@ class Plotter(Aggregator, BaseProcessor): ) def _process_rule( - self, name, rule, arg, overwrite=False, ax=None, from_cells=False, **kwargs + self, + name, + rule, + arg, + overwrite=False, + select=None, + ax=None, + from_cells=False, + **kwargs, ): """ Open storage and figure if needed before processing a rule @@ -282,8 +290,7 @@ class Plotter(Aggregator, BaseProcessor): filetype = filetype_from_ext[self.pp_params.out.ext] # Select runs and nums - if "select" in kwargs: - select = kwargs.pop("select") + if select is not None: runs, nums = self.selector.select(**select) else: runs = self.runs