Files
pipeline/baseprocessor.py
T

472 lines
14 KiB
Python

# coding: utf-8
import copy
import os
import time
from abc import ABCMeta
from functools import partial
import numpy as np
import tables
from tables import HDF5ExtError
from params import default_params, load_params
from units import U
class Rule:
    """
    Description of one processing rule.

    A rule bundles the callable that produces the data together with the
    metadata the processor reads when running it: the group used to build
    the output name, a description, a unit and the dependencies to solve
    first.

    Parameters:
        postproc: object stored as ``self.postproc``; it is not used by the
            rule itself.
        process: callable producing the data, stored as ``self.process_fn``.
        description: description of the produced data.
        group: group prefix used by the processor to build the output name.
        dependencies: names of the rules to solve before this one. ``None``
            (the default) means "no dependency".
        kind: free-form tag stored as ``self.kind``.
        unit: unit of the produced data.
    """

    def __init__(
        self,
        postproc,
        process,
        description="",
        group="",
        dependencies=None,
        kind="snapshot",
        unit=U.none,
    ):
        self.postproc = postproc
        self.process_fn = process
        # A fresh list per instance: a literal ``[]`` default would be one
        # single object shared by every rule created without dependencies.
        self.dependencies = [] if dependencies is None else dependencies
        self.group = group
        self.description = description
        self.unit = unit
        self.kind = kind

    def process(self, arg, **kwargs):
        """
        Run the rule's callable and return its result.

        ``arg`` is forwarded as the single positional argument unless it is
        ``None``, in which case the callable only receives ``kwargs``.
        """
        if arg is None:
            return self.process_fn(**kwargs)
        return self.process_fn(arg, **kwargs)
class BaseProcessor:
    """
    Base class for processors, should not be instantiated.

    Subclasses are expected to fill ``rules`` (rule name -> rule object) and
    to provide ``_save_data(name_full, data, description, unit)``, which is
    called by :meth:`_process_rule` but not defined here.
    """

    # NOTE(review): this is the Python 2 spelling of a metaclass declaration;
    # Python 3 treats it as a plain class attribute. Left untouched so that
    # the class hierarchy keeps behaving exactly as it does today.
    __metaclass__ = ABCMeta
    # Prefix prepended to every message printed by ``_log``.
    log_id = ""
    # Rule name -> rule object.
    rules = {}
    # When True, dependencies found in ``self.rules`` are processed by this
    # object; any other dependency is handed to ``_not_self_dep``.
    solve_self_dep = True

    def __init__(self, path, path_out=".", params=None, tag=None):
        """
        Store the parameters and the output directory.

        Parameters:
            path: used as output directory when ``path_out`` is ``None``.
            path_out: output directory, ``None`` to fall back on ``path``.
            params: ``None`` for the default parameters, a string handed to
                ``load_params``, or a parameter object that is deep-copied.
            tag: if given, overrides ``self.params.out.tag``.
        """
        if params is None:
            self.params = default_params()
        elif isinstance(params, str):
            self.params = load_params(params)
        else:
            # Work on a private copy so the caller's object is never modified
            # (``tag`` below may write into it).
            self.params = copy.deepcopy(params)

        # Determining output directory
        self.path_out = path if path_out is None else path_out

        if tag is not None:
            self.params.out.tag = tag

    def _log(self, string, status=""):
        """Print ``string`` (prefixed by ``status`` if any) in verbose mode."""
        if not self.params.process.verbose:
            return
        if status:
            print(status + ": " + self.log_id + string)
        else:
            print(self.log_id + string)

    def process(
        self,
        to_process,
        arg=None,
        overwrite=False,
        overwrite_dep=False,
        select=None,
        **kwargs,
    ):
        """
        Process the rule `to_process`

        Parameters:
            to_process: name of the rule, a key of ``self.rules``.
            arg: optional argument forwarded to the rule.
            overwrite: force the computation of the rule itself.
            overwrite_dep: force the computation of its dependencies.
            select: forwarded untouched to the internal helpers.
            **kwargs: forwarded to the rule's callable.

        Returns:
            Whatever ``_process_rule`` returns; ``None`` when the rule is
            unknown (an error is logged in that case).
        """
        self.overwrite_dep = overwrite_dep
        # Full names of everything computed during this call.
        self.just_done = []
        if to_process not in self.rules:
            self._log(
                "{} is unknown, allowed rules are {}".format(
                    to_process, self.rules.keys()
                ),
                "ERROR",
            )
            return None
        rule = self.rules[to_process]
        return self._solve_and_process_rule(
            to_process, rule, arg, overwrite, select, **kwargs
        )

    def _solve_and_process_rule(
        self, name, rule, arg, overwrite=False, select=None, **kwargs
    ):
        """Solve the dependencies of ``rule``, then process the rule itself."""
        updated = self._solve_dependencies(name, rule, arg, overwrite, select)
        # A rule must be recomputed as soon as one of its dependencies was.
        overwrite_rule = overwrite or updated
        return self._process_rule(name, rule, arg, overwrite_rule, select, **kwargs)

    def _solve_dependencies(self, name, rule, arg, overwrite=False, select=None):
        """
        Process every dependency of ``rule``.

        ``rule.dependencies`` is either a sequence of rule names, in which
        case each dependency receives ``arg``, or a mapping from rule name to
        the argument to use (the special value ``"__parent__"`` meaning
        ``arg``). Dependencies are forced with ``self.overwrite_dep``; the
        ``overwrite`` parameter is not used here.

        Returns:
            True if at least one item was computed while solving the
            dependencies.
        """
        # The baseline has to be a local variable: this method is re-entered
        # for every dependency, so a baseline stored on ``self`` would be
        # overwritten by the nested calls and the comparison below would
        # miss dependencies updated before the last nested call.
        done_before_dep = len(self.just_done)
        # Still exposed as an attribute for code that may read it.
        self.done_before_dep = done_before_dep

        # Solve dependencies
        for dep in rule.dependencies:
            # get arguments
            try:
                dep_arg = rule.dependencies[dep]
            except (TypeError, KeyError):
                dep_arg = arg
            if dep_arg == "__parent__":
                dep_arg = arg

            # Whether the processor solves its own dependencies or it gives
            # it to a child processor
            if self.solve_self_dep and dep in self.rules:
                rule_dep = self.rules[dep]
                self._solve_and_process_rule(
                    dep, rule_dep, dep_arg, self.overwrite_dep, select
                )
            else:
                self._not_self_dep(name, dep, dep_arg, self.overwrite_dep, select)

        # Whether dependencies were updated
        return len(self.just_done) > done_before_dep

    def _not_self_dep(self, name, dep, dep_arg, overwrite, select=None):
        """Handle a dependency this processor cannot solve: log an error."""
        self._log("Dependency {} for {} is unknown".format(dep, name), "ERROR")

    def _needs_computation(self, overwrite, name_full):
        """Tell whether ``name_full`` has to be computed (here: if forced)."""
        return overwrite

    def _process_rule(self, name, rule, arg, overwrite=False, select=None, **kwargs):
        """
        Compute and save the data of ``rule`` if needed.

        The full name is ``<group>/<name>`` or ``<group>/<name>_<arg>`` when
        ``arg`` is not ``None``.

        Returns:
            The computed data, or ``None`` when nothing was computed (already
            done during this ``process`` call, or no computation needed).
        """
        name_full = rule.group + "/" + name
        if arg is not None:
            name_full = name_full + "_" + str(arg)

        if name_full in self.just_done:
            return None

        if not self._needs_computation(overwrite, name_full):
            self._log("Data for {} is already computed, skipping...".format(name_full))
            return None

        self._log("Processing {}".format(name_full))
        data = rule.process(arg, **kwargs)
        self._save_data(name_full, data, rule.description, rule.unit)
        self._log("Data for {} computed".format(name_full), "SUCCESS")
        self.just_done.append(name_full)
        return data

    def def_rules(self):
        """
        Expose every rule as an attribute of the instance.

        After this call ``self.<rule>(...)`` is equivalent to
        ``self.process("<rule>", ...)`` and carries the rule description as
        its docstring.
        """
        for rule in self.rules:
            func = partial(self.process, rule)
            func.__doc__ = self.rules[rule].description
            setattr(self, rule, func)
class HDF5Container(BaseProcessor):
    """
    Processor storing its results in an HDF5 file through PyTables.

    The file is opened on demand: every public accessor opens it if needed
    and closes it again afterwards, unless it was already open when the
    accessor was entered.
    """

    # Path handed to ``tables.open_file``.
    filename = ""
    # Handle returned by ``tables.open_file``; ``None`` until ``open`` ran.
    save = None
    # Whether ``save`` is currently open.
    opened = False

    def open(self):
        """
        Open ``self.filename`` in append mode unless it is already open.

        If the first attempt raises ``HDF5ExtError``, wait 3 seconds and try
        once more; a second failure propagates to the caller.
        """
        if self.opened:
            return
        try:
            self.save = tables.open_file(self.filename, mode="a")
        except HDF5ExtError:
            # Wait a bit in case the lock has not been released yet
            time.sleep(3)
            self.save = tables.open_file(self.filename, mode="a")
        self.opened = True

    def close(self):
        """Close the file if it is open."""
        if self.opened:
            self.save.close()
            self.opened = False

    def _needs_computation(self, overwrite, name_full):
        """Compute when forced or when ``name_full`` is not in the file yet."""
        return overwrite or name_full not in self.save

    def _process_rule(self, name, rule, arg, overwrite, select, **kwargs):
        """
        Run the base implementation with the file open.

        Returns:
            The value returned by ``BaseProcessor._process_rule`` (it used to
            be dropped, which made ``process`` always return ``None``).
        """
        self.open()
        try:
            return super(HDF5Container, self)._process_rule(
                name, rule, arg, overwrite, select, **kwargs
            )
        finally:
            self.close()

    def get_value(self, node_name, unit=None, unit_old=None):
        """
        Read the node ``node_name``.

        Parameters:
            node_name: path of the node in the file.
            unit: if given, unit in which the values are expressed.
            unit_old: unit of the stored values, used when the node has no
                ``unit`` attribute of its own (children of a group receive
                the unit found on their parent this way).

        Returns:
            For a group, a dict ``{child name: value}`` built recursively;
            otherwise the content of the node, converted to ``unit`` when
            both units are known and the stored unit is not ``U.none``.
        """
        open_before = self.opened
        if not open_before:
            self.open()
        try:
            node = self.save.get_node(node_name)
            if "unit" in node._v_attrs:
                unit_old = node._v_attrs.unit
            if node._v_attrs.CLASS == "GROUP":
                value = {}
                for child_name in node._v_children:
                    value[child_name] = self.get_value(
                        node_name + "/" + child_name, unit, unit_old
                    )
            else:
                value = node.read()
                if not (unit is None or unit_old is None or unit_old == U.none):
                    value = value * unit_old.express(unit)
        finally:
            # Only close what this call opened.
            if not open_before:
                self.close()
        return value

    def set_value(self, node_name, data, description, unit):
        """Save ``data`` under ``node_name`` (see ``_save_data``)."""
        open_before = self.opened
        if not open_before:
            self.open()
        try:
            self._save_data(node_name, data, description, unit)
        finally:
            if not open_before:
                self.close()

    def get_attribute(self, node_name, attr_name):
        """Return the attribute ``attr_name`` of the node ``node_name``."""
        open_before = self.opened
        if not open_before:
            self.open()
        try:
            node = self.save.get_node(node_name)
            attr = node._v_attrs[attr_name]
        finally:
            if not open_before:
                self.close()
        return attr

    def set_attribute(self, node_name, attr_name, attr_value):
        """Set the attribute ``attr_name`` of the node ``node_name``."""
        open_before = self.opened
        if not open_before:
            self.open()
        try:
            node = self.save.get_node(node_name)
            node._v_attrs[attr_name] = attr_value
        finally:
            if not open_before:
                self.close()

    def _get_units(self, unit, data=None):
        """
        Get real units from info files

        unit is either:
            1. An instance of U.Unit (pymses unit class)
            2. A string beginning by "unit_", referring to a code unit,
               available in self.info
            3. A dict {unit1 : exp1, unit2: exp2, ...} with unitX as 2.
               and expX a float, referring to the compound unit
               unit1**exp1 * unit2**exp2
            4. A dict {key: unit, ...} where key is a field name (eg. 'time',
               or 'mass') and unit the corresponding unit (in one of the
               above formats). Only recognised when ``data`` is a dict that
               contains the first key of ``unit``.

        Returns:
            1-3. : a U.Unit instance
            4. : a new dict {key: unit, ...} with same keys as input and unit
                 being U.Unit instances; the input dict is left untouched

        Raises:
            ValueError: if ``unit`` matches none of the formats above.

        ``self.info`` is not defined in this class and has to be provided by
        the subclass.
        """
        if isinstance(unit, U.Unit):
            return unit

        if isinstance(unit, str) and unit[:5] == "unit_":
            res = self.info[unit]
            if unit == "unit_length":
                res = res / self.info["boxlen"]
            return res

        first_key = list(unit)[0]
        if first_key[:5] == "unit_":
            new_unit = U.none
            for base_unit_str in unit:
                expo = unit[base_unit_str]
                base_unit = self._get_units(base_unit_str)
                new_unit = new_unit * base_unit ** expo
            return new_unit

        if isinstance(data, dict) and first_key in data:
            # Build a new dict: the argument usually belongs to a rule and
            # must not be rewritten as a side effect of saving data.
            return {key: self._get_units(unit[key]) for key in unit}

        raise ValueError("Invalid unit")

    def _save_data(self, name_full, data, description, unit):
        """
        Save data in the HDF5 structure, overwrite if necessary

        Parameters:
            name_full: path of the node to write; the file must be open.
            data: the payload, or a tuple ``(payload, attrs)`` where
                ``attrs`` is a dict of extra attributes set on the node.
                A dict payload becomes a group with one child per key;
                anything else becomes an array (scalars are wrapped in a
                one-element array, empty payloads are not written).
            description: title of the node, or a dict of titles with the
                same keys as a dict payload.
            unit: unit of the payload, in any format accepted by
                ``_get_units``.
        """
        # The payload has to be unpacked before resolving the unit: a
        # per-key unit dict is only recognised against a dict payload.
        attrs = None
        if isinstance(data, tuple):
            attrs = data[1]
            data = data[0]

        unit = self._get_units(unit, data=data)

        if name_full in self.save:
            self.save.remove_node(name_full, recursive=True)

        parent = os.path.dirname(name_full)
        leaf = os.path.basename(name_full)

        if isinstance(data, dict):
            title = description if isinstance(description, str) else ""
            self.save.create_group(parent, leaf, title, createparents=True)
            if not isinstance(unit, dict):
                self.save.get_node(name_full)._v_attrs.unit = unit
            for key in data:
                # Look values up with the original key; only the node name
                # needs the string form.
                if isinstance(description, dict):
                    child_description = description[key]
                else:
                    child_description = ""
                child_unit = unit[key] if isinstance(unit, dict) else unit
                self._save_data(
                    name_full + "/" + str(key),
                    data[key],
                    child_description,
                    child_unit,
                )
        else:
            try:
                if len(data) == 0:
                    return
            except TypeError:
                # No length: scalar value, store it as a one-element array
                data = np.array([data])
            self.save.create_array(
                parent,
                leaf,
                data,
                description,
                createparents=True,
            )
            self.save.get_node(name_full).attrs.unit = unit

        if attrs is not None:
            for key in attrs:
                self.save.get_node(name_full)._v_attrs[str(key)] = attrs[key]

    def _transform(self, name, transform_fn, group="/maps", **kwargs):
        """Read the node ``group/name`` and return ``transform_fn`` of it."""
        src = self.save.get_node(group + "/" + name).read()
        return transform_fn(src, **kwargs)

    def _gen_rule_transform(
        self,
        rule_src_name,
        transform_fn,
        transform_name,
        subarray_name=None,
        group=None,
    ):
        """
        Register a rule applying ``transform_fn`` to the output of a rule.

        The new rule is stored in ``self.rules`` under
        ``<transform_name>_<rule_src_name>`` and depends on
        ``rule_src_name``.

        Parameters:
            rule_src_name: name of the source rule in ``self.rules``.
            transform_fn: callable applied to the stored source data.
            transform_name: prefix of the new rule name.
            subarray_name: if given, transform this child of the source
                rule's group instead of the source node itself; unit and
                description are then read from the source rule's dicts.
            group: group of the new rule, defaults to the source group.
        """
        rule_src = self.rules[rule_src_name]
        if subarray_name is None:
            src_name = rule_src_name
            group_src = rule_src.group
            unit = rule_src.unit
            description = rule_src.description
        else:
            src_name = subarray_name
            group_src = rule_src.group + "/" + rule_src_name
            unit = rule_src.unit[subarray_name]
            description = rule_src.description[subarray_name]

        def fn(arg=None, **kwargs):
            # Source nodes are suffixed with ``_<arg>`` when an argument is
            # used, mirroring the naming done in ``_process_rule``.
            if arg is None:
                return self._transform(
                    src_name, transform_fn, group=group_src, **kwargs
                )
            else:
                return self._transform(
                    src_name + "_" + str(arg), transform_fn, group=group_src, **kwargs
                )

        if group is None:
            group = group_src
        name = transform_name + "_" + rule_src_name
        self.rules[name] = Rule(
            self,
            fn,
            group=group,
            unit=unit,
            description=description,
            dependencies=[rule_src_name],
        )

    def apply(
        self,
        fn,
        name_array_in,
        name_array_out,
        unit_out=U.none,
        description="",
        recursive=True,
    ):
        """
        Apply ``fn`` to a stored node and save the result in another node.

        Parameters:
            fn: callable applied to the value read from ``name_array_in``.
            name_array_in: path of the node to read.
            name_array_out: path of the node to write.
            unit_out: unit attached to the output.
            description: description attached to the output.
            recursive: when the input is a group (read as a dict), apply
                ``fn`` to each child separately instead of to the dict.
        """
        array_in = self.get_value(name_array_in)
        if recursive and isinstance(array_in, dict):
            for key in array_in:
                self.apply(
                    fn,
                    name_array_in + "/" + key,
                    name_array_out + "/" + key,
                    unit_out,
                    description,
                )
            self.set_attribute(name_array_out, "unit", unit_out)
        else:
            try:
                array_out = fn(array_in)
            except TypeError:
                # NOTE(review): values ``fn`` cannot handle are copied
                # unchanged; this also hides TypeErrors raised inside ``fn``.
                array_out = array_in
            self.set_value(name_array_out, array_out, description, unit_out)
def simple_getter(name, dset):
    """Return the entry ``name`` of ``dset``."""
    field = dset[name]
    return field
def vect_getter(name, i, dset):
    """Return ``dset[name]`` indexed with ``[:, i]``."""
    field = dset[name]
    return field[:, i]
def oct_vect_getter(name, i, dset):
    """Return ``dset[name]`` indexed with ``[:, :, i]``."""
    field = dset[name]
    return field[:, :, i]
def norm_getter(name, dset):
    """Return the square root of the sum over axis 1 of ``dset[name] ** 2``."""
    squared = dset[name] ** 2
    summed = np.sum(squared, axis=1)
    return np.sqrt(summed)