[selector] make select a hardcoded argument, and propaget it (see #6)

This commit is contained in:
Noe Brucy
2021-01-28 11:58:16 +01:00
parent 9f8da0a8db
commit 8c82117dfc
3 changed files with 38 additions and 30 deletions
+2 -11
View File
@@ -5,7 +5,6 @@ from functools import partial
from mypool import MyPool from mypool import MyPool
from postprocessor import PostProcessor from postprocessor import PostProcessor
from run_selector import RunSelector
def _map_rule(rule, arg, overwrite, path, path_out, pp_params, run_num): def _map_rule(rule, arg, overwrite, path, path_out, pp_params, run_num):
@@ -20,17 +19,9 @@ def _map_rule(rule, arg, overwrite, path, path_out, pp_params, run_num):
class Aggregator: class Aggregator:
def _not_self_dep(self, name, dep, dep_arg, overwrite, **kwargs): def _not_self_dep(self, name, dep, dep_arg, overwrite, select):
if "select" in kwargs: if select is not None:
select = kwargs["select"]
runs, nums = self.selector.select(**select) runs, nums = self.selector.select(**select)
elif "runs" in kwargs:
runs = kwargs["runs"]
if isinstance(runs, RunSelector):
nums = runs.nums
runs = runs.runs
else:
nums = self.nums
else: else:
runs = self.runs runs = self.runs
nums = self.nums nums = self.nums
+23 -13
View File
@@ -77,7 +77,13 @@ class BaseProcessor:
print(self.log_id + string) print(self.log_id + string)
def process( def process(
self, to_process, arg=None, overwrite=False, overwrite_dep=False, **kwargs self,
to_process,
arg=None,
overwrite=False,
overwrite_dep=False,
select=None,
**kwargs,
): ):
""" """
Process the rule `to_process` Process the rule `to_process`
@@ -89,7 +95,7 @@ class BaseProcessor:
if to_process in self.rules: if to_process in self.rules:
rule = self.rules[to_process] rule = self.rules[to_process]
return self._solve_and_process_rule( return self._solve_and_process_rule(
to_process, rule, arg, overwrite, **kwargs to_process, rule, arg, overwrite, select, **kwargs
) )
else: else:
self._log( self._log(
@@ -99,12 +105,14 @@ class BaseProcessor:
"ERROR", "ERROR",
) )
def _solve_and_process_rule(self, name, rule, arg, overwrite=False, **kwargs): def _solve_and_process_rule(
updated = self._solve_dependencies(name, rule, arg, overwrite, **kwargs) self, name, rule, arg, overwrite=False, select=None, **kwargs
):
updated = self._solve_dependencies(name, rule, arg, overwrite, select)
overwrite_rule = overwrite or updated overwrite_rule = overwrite or updated
return self._process_rule(name, rule, arg, overwrite_rule, **kwargs) return self._process_rule(name, rule, arg, overwrite_rule, select, **kwargs)
def _solve_dependencies(self, name, rule, arg, overwrite=False, **kwargs): def _solve_dependencies(self, name, rule, arg, overwrite=False, select=None):
self.done_before_dep = len(self.just_done) self.done_before_dep = len(self.just_done)
@@ -113,7 +121,7 @@ class BaseProcessor:
# get arguments # get arguments
try: try:
dep_arg = rule.dependencies[dep] dep_arg = rule.dependencies[dep]
except: except (TypeError, KeyError):
dep_arg = arg dep_arg = arg
if dep_arg == "__parent__": if dep_arg == "__parent__":
@@ -123,20 +131,22 @@ class BaseProcessor:
# it to a child processor # it to a child processor
if self.solve_self_dep and dep in self.rules: if self.solve_self_dep and dep in self.rules:
rule_dep = self.rules[dep] rule_dep = self.rules[dep]
self._solve_and_process_rule(dep, rule_dep, dep_arg, self.overwrite_dep) self._solve_and_process_rule(
dep, rule_dep, dep_arg, self.overwrite_dep, select
)
else: else:
self._not_self_dep(name, dep, dep_arg, self.overwrite_dep, **kwargs) self._not_self_dep(name, dep, dep_arg, self.overwrite_dep, select)
# Whether dependencies where updated # Whether dependencies where updated
return len(self.just_done) > self.done_before_dep return len(self.just_done) > self.done_before_dep
def _not_self_dep(self, name, dep, dep_arg, overwrite, **kwargs): def _not_self_dep(self, name, dep, dep_arg, overwrite, select=None):
self._log("Dependency {} for {} is unknown".format(dep, name), "ERROR") self._log("Dependency {} for {} is unknown".format(dep, name), "ERROR")
def _needs_computation(self, overwrite, name_full): def _needs_computation(self, overwrite, name_full):
return overwrite return overwrite
def _process_rule(self, name, rule, arg, overwrite=False, **kwargs): def _process_rule(self, name, rule, arg, overwrite=False, select=None, **kwargs):
if arg is not None: if arg is not None:
name_full = rule.group + "/" + name + "_" + str(arg) name_full = rule.group + "/" + name + "_" + str(arg)
else: else:
@@ -186,11 +196,11 @@ class HDF5Container(BaseProcessor):
def _needs_computation(self, overwrite, name_full): def _needs_computation(self, overwrite, name_full):
return overwrite or not (name_full in self.save) return overwrite or not (name_full in self.save)
def _process_rule(self, name, rule, arg, overwrite, **kwargs): def _process_rule(self, name, rule, arg, overwrite, select, **kwargs):
self.open() self.open()
try: try:
super(HDF5Container, self)._process_rule( super(HDF5Container, self)._process_rule(
name, rule, arg, overwrite, **kwargs name, rule, arg, overwrite, select, **kwargs
) )
finally: finally:
self.close() self.close()
+13 -6
View File
@@ -232,18 +232,18 @@ class Plotter(Aggregator, BaseProcessor):
self.simulations[run] = simu self.simulations[run] = simu
def _not_self_dep(self, name, dep, dep_arg, overwrite, **kwargs): def _not_self_dep(self, name, dep, dep_arg, overwrite, select):
""" """
Check if the dependency belongs to the plotter object or to another one (comp, pp, ..) Check if the dependency belongs to the plotter object or to another one (comp, pp, ..)
""" """
if dep in self.comp.rules: if dep in self.comp.rules:
result = self.comp.process( result = self.comp.process(
dep, dep_arg, overwrite, overwrite_dep=self.overwrite_dep dep, dep_arg, overwrite, self.overwrite_dep, select
) )
if result is not None: if result is not None:
self.just_done.append(result) self.just_done.append(result)
else: else:
super(Plotter, self)._not_self_dep(name, dep, dep_arg, overwrite, **kwargs) super(Plotter, self)._not_self_dep(name, dep, dep_arg, overwrite, select)
def _needs_computation(self, overwrite, plot_filename): def _needs_computation(self, overwrite, plot_filename):
""" """
@@ -256,7 +256,15 @@ class Plotter(Aggregator, BaseProcessor):
) )
def _process_rule( def _process_rule(
self, name, rule, arg, overwrite=False, ax=None, from_cells=False, **kwargs self,
name,
rule,
arg,
overwrite=False,
select=None,
ax=None,
from_cells=False,
**kwargs,
): ):
""" """
Open storage and figure if needed before processing a rule Open storage and figure if needed before processing a rule
@@ -282,8 +290,7 @@ class Plotter(Aggregator, BaseProcessor):
filetype = filetype_from_ext[self.pp_params.out.ext] filetype = filetype_from_ext[self.pp_params.out.ext]
# Select runs and nums # Select runs and nums
if "select" in kwargs: if select is not None:
select = kwargs.pop("select")
runs, nums = self.selector.select(**select) runs, nums = self.selector.select(**select)
else: else:
runs = self.runs runs = self.runs