561 lines
18 KiB
Python
561 lines
18 KiB
Python
# coding: utf-8
|
|
|
|
import copy
|
|
import os
|
|
import time
|
|
import logging
|
|
from abc import ABCMeta
|
|
from functools import partial
|
|
|
|
import numpy as np
|
|
|
|
import tables
|
|
from tables import HDF5ExtError, NoSuchNodeError
|
|
from tables.registry import class_name_dict
|
|
|
|
from utils.params import default_params, load_params
|
|
from utils.units import U
|
|
import sys
|
|
import traceback
|
|
|
|
|
|
class Rule:
    """Description of one processing step.

    A rule bundles a callable with the metadata that the processors of this
    module read when they run it and save its result.

    Parameters
    ----------
    process : callable
        Function called by :meth:`process`; stored as ``self.process_fn``.
    description : optional
        Description of the produced data.
    group : str, optional
        Prefix used by the processors to build the full name of the result
        (``group + "/" + name``).
    dependencies : list or dict, optional
        Rules that have to be processed before this one. The processors
        iterate over it and also try ``dependencies[dep]`` to get a
        dependency specific argument. A new empty list is used when omitted.
    kind : str, optional
        Stored as ``self.kind``; not used in this class.
    unit : optional
        Unit of the produced data.
    name : str, optional
        Name of the rule.
    """

    def __init__(
        self,
        process,
        description="",
        group="",
        dependencies=None,
        kind="snapshot",
        unit=U.none,
        name="",
    ):
        self.name = name
        self.process_fn = process
        # A fresh list per instance: a literal ``[]`` default would be shared
        # by every rule created without explicit dependencies.
        self.dependencies = [] if dependencies is None else dependencies
        self.group = group
        self.description = description
        self.unit = unit
        self.kind = kind

    def process(self, arg, **kwargs):
        """Call the wrapped function and return its result.

        ``arg`` is passed as the single positional argument unless it is
        None, in which case only the keyword arguments are forwarded.
        """
        if arg is None:
            return self.process_fn(**kwargs)
        return self.process_fn(arg, **kwargs)
|
|
|
|
|
|
class BaseProcessor:
    """
    Base class for processors, should not be instantiated

    A processor holds a dictionary of named rules (``self.rules``) and runs
    them on demand, after running the rules they depend on. ``_process_rule``
    calls ``self._save_data(name_full, data, description, unit)``, which this
    class does not define: subclasses have to provide it.
    """

    # NOTE: Python 2 style declaration. Under Python 3 this is a plain class
    # attribute and does not change the metaclass of the class.
    __metaclass__ = ABCMeta

    log_id = "base"  # name of the logger requested from the logging module
    rules = {}  # rule name -> rule object
    solve_self_dep = True  # look dependencies up in self.rules when True

    def __init__(self, path, path_out=".", params=None, tag=None):
        """Load the parameters and configure the logger.

        Parameters
        ----------
        path :
            only used here as output directory when ``path_out`` is None
        path_out : optional
            output directory, stored as ``self.path_out``
        params : optional
            None to use ``default_params()``, a str to call ``load_params``
            on it, anything else is deep-copied
        tag : optional
            when not None, stored in ``self.params.out.tag``
        """
        if params is None:
            self.params = default_params()
        elif isinstance(params, str):
            self.params = load_params(params)
        else:
            self.params = copy.deepcopy(params)

        # Determining output directory
        self.path_out = path if path_out is None else path_out

        if tag is not None:
            self.params.out.tag = tag

        # Initialize logger
        self.logger = logging.getLogger(self.log_id)
        self.logger.propagate = False
        logging_format = "%(levelname)s | %(asctime)s | %(name)s.%(funcName)s:%(lineno)d | %(message)s"
        formatter = logging.Formatter(logging_format, datefmt="%H:%M:%S")

        # The logger is shared by every instance using the same log_id, so
        # handlers are only attached the first time.
        # NOTE(review): nesting of the logfile handler inside this guard was
        # reconstructed from a listing without indentation -- confirm.
        if not self.logger.hasHandlers():
            stream = logging.StreamHandler(sys.stdout)
            stream.setFormatter(formatter)
            self.logger.addHandler(stream)

            if len(self.params.process.logfile) > 0:
                fileHandler = logging.FileHandler(self.params.process.logfile)
                fileHandler.setFormatter(formatter)
                self.logger.addHandler(fileHandler)

        if self.params.process.verbose:
            self.logger.setLevel(logging.DEBUG)
        else:
            self.logger.setLevel(logging.WARNING)

        for handler in self.logger.handlers:
            handler.setFormatter(formatter)

    def process(
        self,
        to_process,
        arg=None,
        overwrite=False,
        overwrite_dep=False,
        skip_dep=False,
        select=None,
        **kwargs,
    ):
        """Process the rule 'to_process'

        Parameters
        ----------
        to_process : str or Rule
            name of the rule to process or Rule object with nonempty rule.name
        arg : optional
            argument to give to the rule
        overwrite : bool, optional
            Force redo if already done
        overwrite_dep : bool, optional
            Force redoing of the dependencies even if already done
        skip_dep : bool, optional
            Skip the dependency checks (assume they are already done)
        select : dict, optional
            Select object (see RunSelector) to only select some run/snapshot
        **kwargs :
            forwarded to the rule itself (not to its dependencies)

        Returns
        -------
        The output of self._solve_and_process_rule, or None (after logging
        an error) when 'to_process' is neither a known name nor a Rule.
        """
        # State shared with the recursive dependency resolution
        self.overwrite_dep = overwrite_dep
        self.just_done = []

        if to_process in self.rules:
            rule = self.rules[to_process]
            return self._solve_and_process_rule(
                to_process, rule, arg, overwrite, skip_dep, select, **kwargs
            )
        elif isinstance(to_process, Rule):
            rule = to_process
            return self._solve_and_process_rule(
                rule.name, rule, arg, overwrite, skip_dep, select, **kwargs
            )
        else:
            self.logger.error(
                "{} is unknown, allowed rules are {}".format(
                    to_process, self.rules.keys()
                )
            )

    def _solve_and_process_rule(
        self, name, rule, arg, overwrite=False, skip_dep=False, select=None, **kwargs
    ):
        """Resolve dependencies and proceed in the processing of a rule

        Parameters
        ----------
        name : str
            name of the rule
        rule : Rule
            rule object
        arg :
            argument to give to the rule
        overwrite : bool, optional
            Force redo if already done
        skip_dep : bool, optional
            Skip the dependency checks (assume they are already done)
        select : dict, optional
            Select object (see RunSelector) to only select some run/snapshot

        Returns
        -------
        The output of self._process_rule
        """
        updated = False
        if not skip_dep:
            updated = self._solve_dependencies(name, rule, arg, overwrite, select)
        # A rule is redone as soon as one of its dependencies was redone
        overwrite_rule = overwrite or updated
        return self._process_rule(name, rule, arg, overwrite_rule, select, **kwargs)

    def _solve_dependencies(self, name, rule, arg, overwrite=False, select=None):
        """Process every dependency of 'rule'.

        Returns
        -------
        bool
            True when at least one rule was computed while solving the
            dependencies (``self.just_done`` grew).
        """
        # Local snapshot: the attribute is overwritten by the nested calls
        # made below, so it cannot be used for the final comparison.
        done_before_dep = len(self.just_done)
        self.done_before_dep = done_before_dep

        # Solve dependencies
        for dep in rule.dependencies:
            # get arguments: dependencies may be a dict {dep: arg}, otherwise
            # the dependency receives the argument of the parent rule
            try:
                dep_arg = rule.dependencies[dep]
            except (TypeError, KeyError):
                dep_arg = arg

            if dep_arg == "__parent__":
                dep_arg = arg

            # Whether the processor solves its own dependencies or it gives
            # it to a child processor
            if self.solve_self_dep and dep in self.rules:
                rule_dep = self.rules[dep]
                # select has to be given by keyword: the fifth positional
                # parameter of _solve_and_process_rule is skip_dep
                self._solve_and_process_rule(
                    dep, rule_dep, dep_arg, self.overwrite_dep, select=select
                )
            else:
                self._not_self_dep(name, dep, dep_arg, self.overwrite_dep, select)

        # Whether dependencies were updated
        return len(self.just_done) > done_before_dep

    def _not_self_dep(self, name, dep, dep_arg, overwrite, select=None):
        """Handle a dependency that is not solved by this processor.

        The base implementation only logs an error.
        """
        self.logger.error("Dependency {} for {} is unknown".format(dep, name))

    def _needs_computation(self, overwrite, name_full):
        """Tell whether 'name_full' has to be computed; here only if forced."""
        return overwrite

    def _process_rule(self, name, rule, arg, overwrite=False, select=None, **kwargs):
        """Run the rule and save its result unless it is already available.

        The full name of the result is ``rule.group + "/" + name``, followed
        by ``"_" + str(arg)`` when 'arg' is neither None nor a BaseProcessor.

        Returns
        -------
        The data returned by the rule when it was computed, None otherwise.
        """
        if arg is not None and not isinstance(arg, BaseProcessor):
            name_full = rule.group + "/" + name + "_" + str(arg)
        else:
            name_full = rule.group + "/" + name

        # Rules already computed during this call of process() are skipped
        if name_full not in self.just_done:
            if self._needs_computation(overwrite, name_full):
                self.logger.debug("Processing {}".format(name_full))
                data = rule.process(arg, **kwargs)
                self._save_data(name_full, data, rule.description, rule.unit)
                self.logger.info("Data for {} computed".format(name_full))
                self.just_done.append(name_full)
                return data
            else:
                self.logger.info("Data for {} is already computed.".format(name_full))

    def def_rules(self):
        """Expose each rule as an attribute of the instance.

        After the call, ``self.<rule name>(...)`` is ``self.process`` with the
        rule name already bound, documented by the rule description.
        """
        for rule in self.rules:
            func = partial(self.process, rule)
            func.__doc__ = self.rules[rule].description
            setattr(self, rule, func)
|
|
|
|
|
|
class HDF5Container(BaseProcessor):
    """Processor saving the results of its rules in an HDF5 file.

    The file ``self.filename`` is opened with PyTables in append mode when
    needed and closed afterwards. ``_get_units`` reads ``self.info``, which
    is not set in this class and has to be provided elsewhere.
    """

    filename = ""  # path of the HDF5 file
    save = None  # PyTables file handle, set by open()
    opened = False  # whether self.save is currently open

    def open(self):
        """Open the HDF5 file in append mode if it is not already open."""
        if not self.opened:
            try:
                self.save = tables.open_file(self.filename, mode="a")
            except HDF5ExtError:
                # Wait a bit if the lock was not released yet, then retry once
                time.sleep(3)
                self.save = tables.open_file(self.filename, mode="a")
            self.opened = True

    def close(self):
        """Close the HDF5 file if it is open."""
        if self.opened:
            self.save.close()
            self.opened = False

    def _needs_computation(self, overwrite, name_full):
        """Compute when forced or when the node is missing from the file."""
        return overwrite or name_full not in self.save

    def _process_rule(self, name, rule, arg, overwrite, select, **kwargs):
        """Process the rule with the file open and close it afterwards.

        Returns
        -------
        The output of the base implementation (the computed data, or None
        when nothing was computed). None is also returned when the rule
        raised and ``self.params.process.allow_error`` is set; in that case
        the traceback is logged. Otherwise the exception propagates.
        """
        self.open()
        try:
            return super()._process_rule(name, rule, arg, overwrite, select, **kwargs)
        except Exception as e:
            if not self.params.process.allow_error:
                raise
            # Log the traceback without its final line(s), replaced by the
            # repr of the exception
            traceback_lines = traceback.format_exc().splitlines()
            for line in traceback_lines:
                if line != traceback_lines[-1]:
                    self.logger.error(line)
            self.logger.error(f"{repr(e)}")
            return None
        finally:
            self.close()

    def get_value(self, node_name, unit=None, unit_old=None):
        """Read a node of the file, recursively for groups.

        Parameters
        ----------
        node_name : str
            path of the node in the file
        unit : optional
            unit to express the values in; either a unit or a dict
            {node base name: unit}. No conversion when None.
        unit_old : optional
            unit assumed for the stored values when the node has no "unit"
            attribute (groups pass their own unit down to their children)

        Returns
        -------
        A dict {child name: value} for a group, the content of the node
        otherwise, multiplied by ``unit_old.express(unit)`` when both units
        are known and unit_old is not U.none.

        Raises
        ------
        NoSuchNodeError
            when the node does not exist (an error is logged first)
        """
        open_before = self.opened
        if not open_before:
            self.open()
        try:
            node = self.save.get_node(node_name)

            if "unit" in node._v_attrs:
                unit_old = node._v_attrs.unit

            if node._v_attrs.CLASS == "GROUP":
                value = {}
                for child_name in node._v_children:
                    value[child_name] = self.get_value(
                        node_name + "/" + child_name, unit, unit_old
                    )
            else:
                value = node.read()
                if isinstance(unit, dict):
                    # Per-node units: pick the one matching this node, if any
                    unit = unit.get(os.path.basename(node_name))
                if not (unit is None or unit_old is None or unit_old == U.none):
                    value = value * unit_old.express(unit)
        except NoSuchNodeError:
            self.logger.error(
                f"The value {node_name} is not available", stack_info=True
            )
            raise
        finally:
            # Leave the file in the state it was found
            if not open_before:
                self.close()
        return value

    def set_value(self, node_name, data, description, unit):
        """Save 'data' at 'node_name' (see _save_data), opening the file if needed."""
        open_before = self.opened
        if not open_before:
            self.open()
        try:
            self._save_data(node_name, data, description, unit)
        finally:
            if not open_before:
                self.close()

    def get_attribute(self, node_name, attr_name):
        """Return the attribute 'attr_name' of the node 'node_name'."""
        open_before = self.opened
        if not open_before:
            self.open()
        try:
            node = self.save.get_node(node_name)
            attr = node._v_attrs[attr_name]
        finally:
            if not open_before:
                self.close()
        return attr

    def set_attribute(self, node_name, attr_name, attr_value):
        """Set the attribute 'attr_name' of the node 'node_name'."""
        open_before = self.opened
        if not open_before:
            self.open()
        try:
            node = self.save.get_node(node_name)
            node._v_attrs[attr_name] = attr_value
        finally:
            if not open_before:
                self.close()

    def _get_units(self, unit, data=None):
        """
        Get real units from info files
        unit is either:
            1. An instance of U.Unit (pymses unit class)
            2. A string beginning by "unit_", referring to a code unit,
               available in self.info
            3. A dict {unit1 : exp1, unit2: exp2, ...} with unitX as 2.
               and expX a float, referring to the compound unit
               unit1**exp1 * unit2**exp2
            4. A dict {key: unit, ...} where key is a field name of 'data'
               (eg. 'time', or 'mass') and unit the corresponding unit
               (in one of the above formats)

        Returns:
            1-3. : a U.Unit instance
            4.   : the input dict, updated in place so that every value is
                   a U.Unit instance

        Raises:
            ValueError when unit matches none of the formats above
        """
        if isinstance(unit, U.Unit):
            return unit
        if isinstance(unit, str) and unit[:5] == "unit_":
            res = self.info[unit]
            if unit == "unit_length":
                res = res / self.info["boxlen"]
            return res

        # The first key tells formats 3. and 4. apart. An empty container is
        # an invalid unit rather than an IndexError.
        first_key = next(iter(unit), None)
        if first_key is None:
            raise ValueError("Invalid unit")

        if isinstance(first_key, str) and first_key[:5] == "unit_":
            new_unit = U.none
            for base_unit_str in unit:
                expo = unit[base_unit_str]
                base_unit = self._get_units(base_unit_str)
                new_unit = new_unit * base_unit**expo
            return new_unit

        if isinstance(data, dict) and isinstance(unit, dict) and first_key in data:
            for key in unit:
                unit[key] = self._get_units(unit[key])
            return unit

        raise ValueError("Invalid unit")

    def _save_data(self, name_full, data, description, unit):
        """
        Save data in the HDF5 structure, overwrite if necessary

        Parameters
        ----------
        name_full : str
            path of the node to write
        data :
            a dict (saved as a group, one child per key, recursively), any
            other value (saved as an array), or a tuple (data, attrs) where
            attrs is a dict of attributes to attach to the node. Nothing is
            written for None or empty data.
        description : str or dict
            title of the node, or dict {key: description} for dict data
        unit :
            unit of the data, in one of the formats accepted by _get_units
        """
        # Split (data, attrs) first: unit resolution needs the actual data
        # to recognise per-field unit dicts.
        attrs = None
        if isinstance(data, tuple):
            attrs = data[1]
            data = data[0]

        unit = self._get_units(unit, data=data)

        if name_full in self.save:
            self.save.remove_node(name_full, recursive=True)

        if isinstance(data, dict):
            title = description if isinstance(description, str) else ""
            self.save.create_group(
                os.path.dirname(name_full),
                os.path.basename(name_full),
                title,
                createparents=True,
            )

            if not isinstance(unit, dict):
                self.save.get_node(name_full)._v_attrs.unit = unit

            for key in data:
                # Node names are strings, but the dicts are indexed with
                # the original key so that non-string keys can be saved.
                child_description = (
                    description[key] if isinstance(description, dict) else ""
                )
                child_unit = unit[key] if isinstance(unit, dict) else unit
                self._save_data(
                    name_full + "/" + str(key),
                    data[key],
                    child_description,
                    child_unit,
                )
        else:
            try:
                if data is None or len(data) == 0:
                    return
            except TypeError:
                # Scalars have no length: store them as 1-element arrays
                data = np.array([data])

            group_name = os.path.dirname(name_full)
            if group_name in self.save:
                group = self.save.get_node(group_name)
                if not isinstance(group, class_name_dict["Group"]):
                    self.logger.warning(
                        f"{group_name} already there and no a group, deleting"
                    )
                    self.save.remove_node(group)
            self.save.create_array(
                group_name,
                os.path.basename(name_full),
                data,
                description,
                createparents=True,
            )
            self.save.get_node(name_full).attrs.unit = unit

        if attrs is not None:
            node_attrs = self.save.get_node(name_full)._v_attrs
            for key in attrs:
                node_attrs[str(key)] = attrs[key]

    def _transform(self, name, transform_fn, group="/maps", **kwargs):
        """Apply 'transform_fn' to the content of the node group/name.

        The file has to be open already.
        """
        src = self.save.get_node(group + "/" + name).read()
        return transform_fn(src, **kwargs)

    def _gen_rule_transform(
        self,
        rule_src_name,
        transform_fn,
        transform_name,
        subarray_name=None,
        group=None,
    ):
        """Register a rule applying 'transform_fn' to the result of a rule.

        The new rule is stored in self.rules under the name
        ``transform_name + "_" + rule_src_name`` and depends on the source
        rule.

        Parameters
        ----------
        rule_src_name : str
            name of the source rule, which has to be in self.rules
        transform_fn : callable
            called with the stored source data and the keyword arguments
        transform_name : str
            prefix of the name of the new rule
        subarray_name : str, optional
            when given, the transform reads the child of that name inside
            the group saved by the source rule; unit and description of the
            source rule are then indexed with it
        group : str, optional
            group of the new rule, same as the source data by default
        """
        rule_src = self.rules[rule_src_name]

        if subarray_name is None:
            src_name = rule_src_name
            group_src = rule_src.group
            unit = rule_src.unit
            description = rule_src.description
        else:
            src_name = subarray_name
            group_src = rule_src.group + "/" + rule_src_name
            unit = rule_src.unit[subarray_name]
            description = rule_src.description[subarray_name]

        def fn(arg=None, **kwargs):
            # Source nodes are suffixed with the argument they were computed for
            node_name = src_name if arg is None else src_name + "_" + str(arg)
            return self._transform(node_name, transform_fn, group=group_src, **kwargs)

        if group is None:
            group = group_src

        name = transform_name + "_" + rule_src_name

        self.rules[name] = Rule(
            fn,
            group=group,
            unit=unit,
            description=description,
            dependencies=[rule_src_name],
        )

    def apply(
        self,
        fn,
        name_array_in,
        name_array_out,
        unit_out=U.none,
        description="",
        recursive=True,
    ):
        """Apply 'fn' to a stored value and save the result.

        Parameters
        ----------
        fn : callable
            function applied to the value read from the file
        name_array_in : str
            path of the node to read
        name_array_out : str
            path of the node to write
        unit_out : optional
            unit of the result
        description : str, optional
            description of the result
        recursive : bool, optional
            when the input is a group, apply 'fn' to each of its children
            instead of the dict of its values
        """
        array_in = self.get_value(name_array_in)
        if recursive and isinstance(array_in, dict):
            for key in array_in:
                self.apply(
                    fn,
                    name_array_in + "/" + key,
                    name_array_out + "/" + key,
                    unit_out,
                    description,
                )
            self.set_attribute(name_array_out, "unit", unit_out)
        else:
            try:
                array_out = fn(array_in)
            except TypeError:
                # Best effort: values that 'fn' cannot handle are copied as is
                array_out = array_in

            self.set_value(name_array_out, array_out, description, unit_out)
|
|
|
|
|
|
def simple_getter(name, dset):
    """Return the entry ``name`` of ``dset`` unchanged."""
    field = dset[name]
    return field
|
|
|
|
|
|
def vect_getter(name, i, dset):
    """Return ``dset[name][:, i]``: index ``i`` along the second axis."""
    field = dset[name]
    every_row = slice(None)
    return field[every_row, i]
|
|
|
|
|
|
def oct_vect_getter(name, i, dset):
    """Return ``dset[name][:, :, i]``: index ``i`` along the third axis."""
    field = dset[name]
    everything = slice(None)
    return field[everything, everything, i]
|
|
|
|
|
|
def norm_getter(name, dset):
    """Return the square root of the sum of squares of ``dset[name]`` over axis 1."""
    squared = dset[name] ** 2
    summed = np.sum(squared, axis=1)
    return np.sqrt(summed)
|
|
|
|
|
|
def oct_norm_getter(name, dset):
    """Return the square root of the sum of squares of ``dset[name]`` over axis 2."""
    squared = dset[name] ** 2
    summed = np.sum(squared, axis=2)
    return np.sqrt(summed)
|