Files
pipeline/baseprocessor.py
2023-07-12 11:13:03 +02:00

561 lines
18 KiB
Python

# coding: utf-8
import copy
import os
import time
import logging
from abc import ABCMeta
from functools import partial
import numpy as np
import tables
from tables import HDF5ExtError, NoSuchNodeError
from tables.registry import class_name_dict
from utils.params import default_params, load_params
from utils.units import U
import sys
import traceback
class Rule:
    """Description of one processing step that a processor can run.

    Parameters
    ----------
    process : callable
        Function doing the actual work. It is called as
        ``process(arg, **kwargs)`` when an argument is given and as
        ``process(**kwargs)`` otherwise.
    description : str or dict, optional
        Description handed over to the processor when the result is saved.
    group : str, optional
        Prefix used by the processor to build the full name of the result
        (``group + "/" + name``).
    dependencies : list or dict, optional
        Names of the rules to process before this one. When a dict is given,
        its values are the arguments passed to each dependency. Defaults to
        a new empty list.
    kind : str, optional
        Stored as is; not interpreted by this class.
    unit : optional
        Unit handed over to the processor when the result is saved.
    name : str, optional
        Name of the rule, needed when the Rule object itself (rather than
        its name) is given to a processor.
    """

    def __init__(
        self,
        process,
        description="",
        group="",
        dependencies=None,
        kind="snapshot",
        unit=U.none,
        name="",
    ):
        self.name = name
        self.process_fn = process
        # A fresh list per instance: a literal ``[]`` default would be one
        # single object shared (and possibly mutated) by every Rule.
        self.dependencies = [] if dependencies is None else dependencies
        self.group = group
        self.description = description
        self.unit = unit
        self.kind = kind

    def process(self, arg, **kwargs):
        """Run the wrapped function and return its result.

        `arg` is forwarded as the only positional argument, unless it is
        None, in which case the function is called with `kwargs` only.
        """
        if arg is None:
            return self.process_fn(**kwargs)
        return self.process_fn(arg, **kwargs)
class BaseProcessor:
    """
    Base class for processors, should not be instantiated directly.

    A processor owns a mapping of named rules (``rules``) and knows how to
    process one of them after having processed its dependencies.
    Subclasses must provide ``_save_data(name_full, data, description, unit)``,
    which is called by ``_process_rule`` but is not defined here.
    """

    # Python 2 style metaclass declaration: under Python 3 this is a plain
    # class attribute and does not make the class abstract.
    __metaclass__ = ABCMeta
    # Name of the logger used by the instances
    log_id = "base"
    # Mapping rule name -> rule object. Defined at class level, hence shared
    # by every class/instance that does not override it.
    rules = {}
    # Whether the dependencies found in ``rules`` are processed by this
    # processor itself (otherwise they are handed over to ``_not_self_dep``)
    solve_self_dep = True

    def __init__(self, path, path_out=".", params=None, tag=None):
        """Set up the parameters, the output directory and the logger.

        Parameters
        ----------
        path : str
            Only used here as output directory when `path_out` is None.
        path_out : str or None, optional
            Output directory.
        params : None, str or parameter object, optional
            None selects ``default_params()``, a string is given to
            ``load_params`` and anything else is deep-copied.
        tag : optional
            When not None, stored in ``params.out.tag``.
        """
        if params is None:
            self.params = default_params()
        elif isinstance(params, str):
            self.params = load_params(params)
        else:
            self.params = copy.deepcopy(params)

        # Determining output directory
        self.path_out = path if path_out is None else path_out

        if tag is not None:
            self.params.out.tag = tag

        # Processing state, reset at each call of ``process``. Defined here
        # so that the helper methods can also be used before the first call.
        self.overwrite_dep = False
        self.just_done = []

        # Initialize logger
        self.logger = logging.getLogger(self.log_id)
        self.logger.propagate = False
        logging_format = "%(levelname)s | %(asctime)s | %(name)s.%(funcName)s:%(lineno)d | %(message)s"
        formatter = logging.Formatter(logging_format, datefmt="%H:%M:%S")

        if not self.logger.hasHandlers():
            stream = logging.StreamHandler(sys.stdout)
            stream.setFormatter(formatter)
            self.logger.addHandler(stream)

        logfile = self.params.process.logfile
        if len(logfile) > 0:
            # Loggers are shared between the processors having the same
            # ``log_id``: attach a given log file only once, otherwise each
            # record would be written once per instantiated processor.
            logfile_abs = os.path.abspath(logfile)
            already_attached = any(
                isinstance(handler, logging.FileHandler)
                and handler.baseFilename == logfile_abs
                for handler in self.logger.handlers
            )
            if not already_attached:
                file_handler = logging.FileHandler(logfile)
                file_handler.setFormatter(formatter)
                self.logger.addHandler(file_handler)

        if self.params.process.verbose:
            self.logger.setLevel(logging.DEBUG)
        else:
            self.logger.setLevel(logging.WARNING)

        for handler in self.logger.handlers:
            handler.setFormatter(formatter)

    def process(
        self,
        to_process,
        arg=None,
        overwrite=False,
        overwrite_dep=False,
        skip_dep=False,
        select=None,
        **kwargs,
    ):
        """Process the rule 'to_process'

        Parameters
        ----------
        to_process : str or Rule
            name of the rule to process or Rule object with nonempty rule.name
        arg : optional
            argument to give to the rule
        overwrite : bool, optional
            Force redo if already done
        overwrite_dep : bool, optional
            Force redoing of the dependencies even if already done
        skip_dep : bool, optional
            Skip the dependency checks (assume they are already done)
        select : dict, optional
            Select object (see RunSelector) to only select some run/snapshot
        **kwargs
            Forwarded to the processing function of the rule

        Returns
        -------
        The output of self._process_rule, or None if the rule is unknown
        (an error is logged in that case)
        """
        self.overwrite_dep = overwrite_dep
        self.just_done = []

        if to_process in self.rules:
            rule = self.rules[to_process]
            return self._solve_and_process_rule(
                to_process, rule, arg, overwrite, skip_dep, select, **kwargs
            )

        if isinstance(to_process, Rule):
            rule = to_process
            return self._solve_and_process_rule(
                rule.name, rule, arg, overwrite, skip_dep, select, **kwargs
            )

        self.logger.error(
            "{} is unknown, allowed rules are {}".format(
                to_process, self.rules.keys()
            )
        )
        return None

    def _solve_and_process_rule(
        self, name, rule, arg, overwrite=False, skip_dep=False, select=None, **kwargs
    ):
        """Resolve dependencies and proceed in the processing of a rule

        Parameters
        ----------
        name : str
            name of the rule
        rule : Rule
            rule object
        arg :
            argument to give to the rule
        overwrite : bool, optional
            Force redo if already done
        skip_dep : bool, optional
            Skip the dependency checks (assume they are already done)
        select : dict, optional
            Select object (see RunSelector) to only select some run/snapshot

        Returns
        -------
        The output of self._process_rule
        """
        updated = False
        if not skip_dep:
            updated = self._solve_dependencies(name, rule, arg, overwrite, select)
        # A rule is redone as soon as one of its dependencies was redone
        overwrite_rule = overwrite or updated
        return self._process_rule(name, rule, arg, overwrite_rule, select, **kwargs)

    def _solve_dependencies(self, name, rule, arg, overwrite=False, select=None):
        """Process the dependencies of a rule.

        The dependencies are redone according to ``self.overwrite_dep``; the
        `overwrite` parameter is not used here.

        Returns
        -------
        bool
            True if at least one dependency (at any depth) was (re)computed
        """
        # The counter must be local: this method is re-entered for each
        # dependency, which would overwrite a counter stored on the instance
        # and hide the updates done by the previous dependencies.
        done_before = len(self.just_done)
        # Attribute kept for backward compatibility only; not used here
        self.done_before_dep = done_before

        # Solve dependencies
        for dep in rule.dependencies:
            # Get the argument of the dependency: the value associated to it
            # when the dependencies are a dict, the argument of the parent
            # rule otherwise
            try:
                dep_arg = rule.dependencies[dep]
            except (TypeError, KeyError):
                dep_arg = arg
            if dep_arg == "__parent__":
                dep_arg = arg

            # Whether the processor solves its own dependencies or it gives
            # it to a child processor
            if self.solve_self_dep and dep in self.rules:
                rule_dep = self.rules[dep]
                # `select` has to be given by keyword: the next positional
                # parameter is `skip_dep`
                self._solve_and_process_rule(
                    dep, rule_dep, dep_arg, self.overwrite_dep, select=select
                )
            else:
                self._not_self_dep(name, dep, dep_arg, self.overwrite_dep, select)

        # Whether dependencies were updated
        return len(self.just_done) > done_before

    def _not_self_dep(self, name, dep, dep_arg, overwrite, select=None):
        """Handle a dependency that this processor does not process itself.

        The base implementation only logs an error.
        """
        self.logger.error("Dependency {} for {} is unknown".format(dep, name))

    def _needs_computation(self, overwrite, name_full):
        """Tell whether `name_full` has to be computed (here: iff `overwrite`)."""
        return overwrite

    def _process_rule(self, name, rule, arg, overwrite=False, select=None, **kwargs):
        """Compute and save the result of a rule if needed.

        The full name of the result is ``rule.group + "/" + name``, followed
        by ``"_" + str(arg)`` when `arg` is neither None nor a processor.
        Nothing is done if this name was already handled since the last call
        of ``process``. `select` is not used here.

        Returns
        -------
        The computed data, or None if nothing was computed
        """
        if arg is not None and not isinstance(arg, BaseProcessor):
            name_full = rule.group + "/" + name + "_" + str(arg)
        else:
            name_full = rule.group + "/" + name

        if name_full in self.just_done:
            return None

        if not self._needs_computation(overwrite, name_full):
            self.logger.info("Data for {} is already computed.".format(name_full))
            return None

        self.logger.debug("Processing {}".format(name_full))
        data = rule.process(arg, **kwargs)
        self._save_data(name_full, data, rule.description, rule.unit)
        self.logger.info("Data for {} computed".format(name_full))
        self.just_done.append(name_full)
        return data

    def def_rules(self):
        """Expose each rule as an attribute ``self.<rule name>(...)``.

        Each attribute calls ``self.process`` with the name of the rule as
        first argument and has the description of the rule as docstring.
        """
        for rule in self.rules:
            func = partial(self.process, rule)
            func.__doc__ = self.rules[rule].description
            setattr(self, rule, func)
class HDF5Container(BaseProcessor):
    """Processor storing the results of its rules in a HDF5 file (PyTables).

    The file is opened on demand by the public accessors and closed again
    afterwards, unless it was already open when they were called.
    ``_get_units`` additionally reads ``self.info``, which is not defined in
    this class and has to be provided by the subclass.
    """

    # Path of the HDF5 file
    filename = ""
    # PyTables file handle, set by ``open``
    save = None
    # Whether ``save`` is currently open
    opened = False

    def open(self):
        """Open ``self.filename`` in append mode, unless it is already open."""
        if self.opened:
            return
        try:
            self.save = tables.open_file(self.filename, mode="a")
        except HDF5ExtError:
            # Wait a bit in case the lock was not released yet and retry once
            time.sleep(3)
            self.save = tables.open_file(self.filename, mode="a")
        self.opened = True

    def close(self):
        """Close the HDF5 file if it is open."""
        if self.opened:
            self.save.close()
            self.opened = False

    def _needs_computation(self, overwrite, name_full):
        """Compute if forced to, or if `name_full` is not in the file yet."""
        return overwrite or name_full not in self.save

    def _process_rule(self, name, rule, arg, overwrite=False, select=None, **kwargs):
        """Process a rule with the HDF5 file open.

        If ``params.process.allow_error`` is set, an exception raised while
        processing is logged (with its traceback) instead of being raised.

        Returns
        -------
        The computed data, or None if nothing was computed or if an error
        was logged
        """
        # Like the other accessors, leave the file in the state it was found:
        # a rule may trigger another rule while its own result is not saved yet.
        open_before = self.opened
        if not open_before:
            self.open()
        try:
            return super()._process_rule(name, rule, arg, overwrite, select, **kwargs)
        except Exception as e:
            if not self.params.process.allow_error:
                raise
            traceback_lines = traceback.format_exc().splitlines()
            # The last line is the exception itself, logged below via repr
            for line in traceback_lines:
                if line != traceback_lines[-1]:
                    self.logger.error(line)
            self.logger.error(f"{repr(e)}")
            return None
        finally:
            if not open_before:
                self.close()

    def get_value(self, node_name, unit=None, unit_old=None):
        """Read a node of the file, optionally converting its unit.

        Parameters
        ----------
        node_name : str
            Path of the node in the file
        unit : unit or dict, optional
            Unit in which the value is wanted. A dict gives the unit per
            node base name; nodes missing from it are not converted.
        unit_old : unit, optional
            Unit of the stored value, used when the node has no ``unit``
            attribute (a group hands its own unit down to its children).

        Returns
        -------
        The content of the node, or a dict {child name: value} for a group

        Raises
        ------
        NoSuchNodeError
            If the node does not exist (an error is logged as well)
        """
        open_before = self.opened
        if not open_before:
            self.open()
        try:
            node = self.save.get_node(node_name)
            if "unit" in node._v_attrs:
                unit_old = node._v_attrs.unit

            if node._v_attrs.CLASS == "GROUP":
                value = {
                    child_name: self.get_value(
                        node_name + "/" + child_name, unit, unit_old
                    )
                    for child_name in node._v_children
                }
            else:
                value = node.read()
                if isinstance(unit, dict):
                    unit = unit.get(os.path.basename(node_name))
                if not (unit is None or unit_old is None or unit_old == U.none):
                    value = value * unit_old.express(unit)
        except NoSuchNodeError:
            self.logger.error(
                f"The value {node_name} is not available", stack_info=True
            )
            raise
        finally:
            if not open_before:
                self.close()
        return value

    def set_value(self, node_name, data, description, unit):
        """Save `data` under `node_name` (see ``_save_data``)."""
        open_before = self.opened
        if not open_before:
            self.open()
        try:
            self._save_data(node_name, data, description, unit)
        finally:
            if not open_before:
                self.close()

    def get_attribute(self, node_name, attr_name):
        """Return the attribute `attr_name` of the node `node_name`."""
        open_before = self.opened
        if not open_before:
            self.open()
        try:
            node = self.save.get_node(node_name)
            attr = node._v_attrs[attr_name]
        finally:
            if not open_before:
                self.close()
        return attr

    def set_attribute(self, node_name, attr_name, attr_value):
        """Set the attribute `attr_name` of the node `node_name`."""
        open_before = self.opened
        if not open_before:
            self.open()
        try:
            node = self.save.get_node(node_name)
            node._v_attrs[attr_name] = attr_value
        finally:
            if not open_before:
                self.close()

    def _get_units(self, unit, data=None):
        """
        Get real units from info files

        unit is either:
            1. An instance of U.Unit (pymses unit class)
            2. A string beginning by "unit_", referring to a code unit,
               available in self.info
            3. A dict {unit1 : exp1, unit2: exp2, ...} with unitX as 2.
               and expX a float, referring to the compound unit
               unit1**exp1 * unit2**exp2
            4. A dict {key: unit, ...} where key is a field name (eg. 'time',
               or 'mass') and unit the corresponding unit (in one of the
               above formats). Only accepted if `data` is a dict containing
               the first key.

        Returns:
            1-3. : a U.Unit instance
            4. : a new dict {key: unit, ...} with same keys as input and unit
                 being U.Unit instances; the input dict is left untouched

        Raises:
            ValueError if the unit matches none of the formats above
        """
        if isinstance(unit, U.Unit):
            return unit

        if isinstance(unit, str) and unit[:5] == "unit_":
            res = self.info[unit]
            if unit == "unit_length":
                # The length unit read from the info is divided by the box length
                res = res / self.info["boxlen"]
            return res

        keys = list(unit)
        if not keys:
            raise ValueError("Invalid unit")

        if keys[0][:5] == "unit_":
            new_unit = U.none
            for base_unit_str in unit:
                expo = unit[base_unit_str]
                base_unit = self._get_units(base_unit_str)
                new_unit = new_unit * base_unit**expo
            return new_unit

        if isinstance(data, dict) and keys[0] in data:
            # Build a new dict: the input usually belongs to a rule shared
            # between processors, and resolving it in place would freeze the
            # units obtained from the first ``self.info`` it was used with.
            return {key: self._get_units(unit[key]) for key in unit}

        raise ValueError("Invalid unit")

    def _save_data(self, name_full, data, description, unit):
        """
        Save data in the HDF5 structure, overwrite if necessary

        Parameters
        ----------
        name_full : str
            Path of the node to write
        data :
            Either the payload, or a tuple (payload, attrs) where attrs is a
            dict of attributes to set on the node. A dict payload is saved as
            a group with one child per key (recursively), anything else as an
            array. Nothing is written for a None or empty payload (the
            previous node, if any, is still removed).
        description : str or dict
            Title of the node, or dict of the titles of the children
        unit :
            Unit of the data, in any format accepted by ``_get_units``
        """
        # Separate the attributes first: the units have to be resolved
        # against the actual payload
        attrs = None
        if isinstance(data, tuple):
            attrs = data[1]
            data = data[0]

        unit = self._get_units(unit, data=data)

        if name_full in self.save:
            self.save.remove_node(name_full, recursive=True)

        if isinstance(data, dict):
            title = description if isinstance(description, str) else ""
            self.save.create_group(
                os.path.dirname(name_full),
                os.path.basename(name_full),
                title,
                createparents=True,
            )
            if not isinstance(unit, dict):
                self.save.get_node(name_full)._v_attrs.unit = unit

            for key in data:
                # Node names have to be strings, but the lookups must use the
                # key as it is in the dicts
                child_description = (
                    description[key] if isinstance(description, dict) else ""
                )
                child_unit = unit[key] if isinstance(unit, dict) else unit
                self._save_data(
                    name_full + "/" + str(key),
                    data[key],
                    child_description,
                    child_unit,
                )
        else:
            try:
                if data is None or len(data) == 0:
                    return
            except TypeError:
                # Objects without length (scalars) are saved as 1-element arrays
                data = np.array([data])

            group_name = os.path.dirname(name_full)
            if group_name in self.save:
                group = self.save.get_node(group_name)
                if not isinstance(group, class_name_dict["Group"]):
                    self.logger.warning(
                        f"{group_name} already there and not a group, deleting"
                    )
                    self.save.remove_node(group)

            self.save.create_array(
                group_name,
                os.path.basename(name_full),
                data,
                description,
                createparents=True,
            )
            self.save.get_node(name_full).attrs.unit = unit

        if attrs is not None:
            node_attrs = self.save.get_node(name_full)._v_attrs
            for key in attrs:
                node_attrs[str(key)] = attrs[key]

    def _transform(self, name, transform_fn, group="/maps", **kwargs):
        """Read the node ``group + "/" + name`` and return
        ``transform_fn(content, **kwargs)``. The file must be open."""
        src = self.save.get_node(group + "/" + name).read()
        return transform_fn(src, **kwargs)

    def _gen_rule_transform(
        self,
        rule_src_name,
        transform_fn,
        transform_name,
        subarray_name=None,
        group=None,
    ):
        """Register a rule applying `transform_fn` to the result of a rule.

        The new rule is stored in ``self.rules`` under the name
        ``transform_name + "_" + rule_src_name`` and depends on
        `rule_src_name`.

        Parameters
        ----------
        rule_src_name : str
            Name of the rule whose saved result is transformed
        transform_fn : callable
            Called with the saved array and the keyword arguments of the rule
        transform_name : str
            Prefix of the name of the new rule
        subarray_name : str, optional
            If given, transform only this child of the result of the source
            rule; unit and description of the source rule are then expected
            to be indexable by this name
        group : str, optional
            Group of the new rule, by default the group where the source
            array is read
        """
        rule_src = self.rules[rule_src_name]
        if subarray_name is None:
            src_name = rule_src_name
            group_src = rule_src.group
            unit = rule_src.unit
            description = rule_src.description
        else:
            src_name = subarray_name
            group_src = rule_src.group + "/" + rule_src_name
            unit = rule_src.unit[subarray_name]
            description = rule_src.description[subarray_name]

        def fn(arg=None, **kwargs):
            # Results of rules processed with an argument are suffixed by it
            node_name = src_name if arg is None else src_name + "_" + str(arg)
            return self._transform(
                node_name, transform_fn, group=group_src, **kwargs
            )

        if group is None:
            group = group_src

        name = transform_name + "_" + rule_src_name
        self.rules[name] = Rule(
            fn,
            group=group,
            unit=unit,
            description=description,
            dependencies=[rule_src_name],
        )

    def apply(
        self,
        fn,
        name_array_in,
        name_array_out,
        unit_out=U.none,
        description="",
        recursive=True,
    ):
        """Apply `fn` to a saved value and save the result in another node.

        Parameters
        ----------
        fn : callable
            Function of one argument. If it raises a TypeError, the input
            value is saved unchanged.
        name_array_in : str
            Path of the node to read
        name_array_out : str
            Path of the node to write
        unit_out : optional
            Unit of the result
        description : str, optional
            Description of the result
        recursive : bool, optional
            If the input is a group, apply `fn` to each of its children
            instead of to the dict of its values
        """
        array_in = self.get_value(name_array_in)
        if recursive and isinstance(array_in, dict):
            for key in array_in:
                self.apply(
                    fn,
                    name_array_in + "/" + key,
                    name_array_out + "/" + key,
                    unit_out,
                    description,
                )
            self.set_attribute(name_array_out, "unit", unit_out)
        else:
            try:
                array_out = fn(array_in)
            except TypeError:
                # Best effort: values `fn` cannot handle are copied as they are
                self.logger.debug(
                    "Could not apply %s to %s, saving it unchanged",
                    fn,
                    name_array_in,
                )
                array_out = array_in
            self.set_value(name_array_out, array_out, description, unit_out)
def simple_getter(name, dset):
    """Return the field `name` of `dset`, i.e. ``dset[name]``."""
    field = dset[name]
    return field
def vect_getter(name, i, dset):
    """Return ``dset[name][:, i]``: the entry `i` along the second axis of
    the field `name`."""
    field = dset[name]
    component = field[:, i]
    return component
def oct_vect_getter(name, i, dset):
    """Return ``dset[name][:, :, i]``: the entry `i` along the third axis of
    the field `name`."""
    field = dset[name]
    component = field[:, :, i]
    return component
def norm_getter(name, dset):
    """Return the square root of the sum over axis 1 of the squared field
    `name` of `dset`."""
    squared = dset[name] ** 2
    summed = np.sum(squared, axis=1)
    return np.sqrt(summed)
def oct_norm_getter(name, dset):
    """Return the square root of the sum over axis 2 of the squared field
    `name` of `dset`."""
    squared = dset[name] ** 2
    summed = np.sum(squared, axis=2)
    return np.sqrt(summed)