[selector] make select a hardcoded argument, and propaget it (see #6)

This commit is contained in:
Noe Brucy
2021-01-28 11:58:16 +01:00
parent 9f8da0a8db
commit 8c82117dfc
3 changed files with 38 additions and 30 deletions
+2 -11
View File
@@ -5,7 +5,6 @@ from functools import partial
from mypool import MyPool
from postprocessor import PostProcessor
from run_selector import RunSelector
def _map_rule(rule, arg, overwrite, path, path_out, pp_params, run_num):
@@ -20,17 +19,9 @@ def _map_rule(rule, arg, overwrite, path, path_out, pp_params, run_num):
class Aggregator:
def _not_self_dep(self, name, dep, dep_arg, overwrite, **kwargs):
if "select" in kwargs:
select = kwargs["select"]
def _not_self_dep(self, name, dep, dep_arg, overwrite, select):
if select is not None:
runs, nums = self.selector.select(**select)
elif "runs" in kwargs:
runs = kwargs["runs"]
if isinstance(runs, RunSelector):
nums = runs.nums
runs = runs.runs
else:
nums = self.nums
else:
runs = self.runs
nums = self.nums
+23 -13
View File
@@ -77,7 +77,13 @@ class BaseProcessor:
print(self.log_id + string)
def process(
self, to_process, arg=None, overwrite=False, overwrite_dep=False, **kwargs
self,
to_process,
arg=None,
overwrite=False,
overwrite_dep=False,
select=None,
**kwargs,
):
"""
Process the rule `to_process`
@@ -89,7 +95,7 @@ class BaseProcessor:
if to_process in self.rules:
rule = self.rules[to_process]
return self._solve_and_process_rule(
to_process, rule, arg, overwrite, **kwargs
to_process, rule, arg, overwrite, select, **kwargs
)
else:
self._log(
@@ -99,12 +105,14 @@ class BaseProcessor:
"ERROR",
)
def _solve_and_process_rule(self, name, rule, arg, overwrite=False, **kwargs):
updated = self._solve_dependencies(name, rule, arg, overwrite, **kwargs)
def _solve_and_process_rule(
self, name, rule, arg, overwrite=False, select=None, **kwargs
):
updated = self._solve_dependencies(name, rule, arg, overwrite, select)
overwrite_rule = overwrite or updated
return self._process_rule(name, rule, arg, overwrite_rule, **kwargs)
return self._process_rule(name, rule, arg, overwrite_rule, select, **kwargs)
def _solve_dependencies(self, name, rule, arg, overwrite=False, **kwargs):
def _solve_dependencies(self, name, rule, arg, overwrite=False, select=None):
self.done_before_dep = len(self.just_done)
@@ -113,7 +121,7 @@ class BaseProcessor:
# get arguments
try:
dep_arg = rule.dependencies[dep]
except:
except (TypeError, KeyError):
dep_arg = arg
if dep_arg == "__parent__":
@@ -123,20 +131,22 @@ class BaseProcessor:
# it to a child processor
if self.solve_self_dep and dep in self.rules:
rule_dep = self.rules[dep]
self._solve_and_process_rule(dep, rule_dep, dep_arg, self.overwrite_dep)
self._solve_and_process_rule(
dep, rule_dep, dep_arg, self.overwrite_dep, select
)
else:
self._not_self_dep(name, dep, dep_arg, self.overwrite_dep, **kwargs)
self._not_self_dep(name, dep, dep_arg, self.overwrite_dep, select)
# Whether dependencies where updated
return len(self.just_done) > self.done_before_dep
def _not_self_dep(self, name, dep, dep_arg, overwrite, **kwargs):
def _not_self_dep(self, name, dep, dep_arg, overwrite, select=None):
self._log("Dependency {} for {} is unknown".format(dep, name), "ERROR")
def _needs_computation(self, overwrite, name_full):
return overwrite
def _process_rule(self, name, rule, arg, overwrite=False, **kwargs):
def _process_rule(self, name, rule, arg, overwrite=False, select=None, **kwargs):
if arg is not None:
name_full = rule.group + "/" + name + "_" + str(arg)
else:
@@ -186,11 +196,11 @@ class HDF5Container(BaseProcessor):
def _needs_computation(self, overwrite, name_full):
return overwrite or not (name_full in self.save)
def _process_rule(self, name, rule, arg, overwrite, **kwargs):
def _process_rule(self, name, rule, arg, overwrite, select, **kwargs):
self.open()
try:
super(HDF5Container, self)._process_rule(
name, rule, arg, overwrite, **kwargs
name, rule, arg, overwrite, select, **kwargs
)
finally:
self.close()
+13 -6
View File
@@ -232,18 +232,18 @@ class Plotter(Aggregator, BaseProcessor):
self.simulations[run] = simu
def _not_self_dep(self, name, dep, dep_arg, overwrite, **kwargs):
def _not_self_dep(self, name, dep, dep_arg, overwrite, select):
"""
Check if the dependency belongs to the plotter object or to another one (comp, pp, ..)
"""
if dep in self.comp.rules:
result = self.comp.process(
dep, dep_arg, overwrite, overwrite_dep=self.overwrite_dep
dep, dep_arg, overwrite, self.overwrite_dep, select
)
if result is not None:
self.just_done.append(result)
else:
super(Plotter, self)._not_self_dep(name, dep, dep_arg, overwrite, **kwargs)
super(Plotter, self)._not_self_dep(name, dep, dep_arg, overwrite, select)
def _needs_computation(self, overwrite, plot_filename):
"""
@@ -256,7 +256,15 @@ class Plotter(Aggregator, BaseProcessor):
)
def _process_rule(
self, name, rule, arg, overwrite=False, ax=None, from_cells=False, **kwargs
self,
name,
rule,
arg,
overwrite=False,
select=None,
ax=None,
from_cells=False,
**kwargs,
):
"""
Open storage and figure if needed before processing a rule
@@ -282,8 +290,7 @@ class Plotter(Aggregator, BaseProcessor):
filetype = filetype_from_ext[self.pp_params.out.ext]
# Select runs and nums
if "select" in kwargs:
select = kwargs.pop("select")
if select is not None:
runs, nums = self.selector.select(**select)
else:
runs = self.runs