# coding: utf-8
import copy
import logging
import os
import sys
import time
import traceback
from abc import ABCMeta
from contextlib import contextmanager
from functools import partial

import numpy as np
import tables
from tables import HDF5ExtError, NoSuchNodeError
from tables.registry import class_name_dict

from utils.params import default_params, load_params
from utils.units import U


class Rule:
    """A named processing step wrapping a callable.

    Parameters
    ----------
    process : callable
        Function doing the actual work. It is called as
        ``process(arg, **kwargs)`` when an argument is given and as
        ``process(**kwargs)`` otherwise.
    description : str or dict, optional
        Description stored alongside the result.
    group : str, optional
        Prefix under which the result is stored (``group + "/" + name``).
    dependencies : list or dict, optional
        Names of the rules that must be processed first. When it is a
        mapping, the value associated with a name is the argument handed to
        that dependency (``"__parent__"`` meaning "same argument as the
        parent rule").
    kind : str, optional
        Stored as-is on the rule; not interpreted in this module.
    unit : optional
        Unit of the result, in any form accepted by the processor.
    name : str, optional
        Name of the rule; needed when the Rule object itself is handed to
        ``BaseProcessor.process``.
    """

    def __init__(
        self,
        process,
        description="",
        group="",
        dependencies=None,
        kind="snapshot",
        unit=U.none,
        name="",
    ):
        self.name = name
        self.process_fn = process
        # A fresh list per instance: a mutable default would be shared by
        # every Rule built without explicit dependencies.
        self.dependencies = [] if dependencies is None else dependencies
        self.group = group
        self.description = description
        self.unit = unit
        self.kind = kind

    def process(self, arg, **kwargs):
        """Run the wrapped function, omitting ``arg`` when it is None."""
        if arg is not None:
            return self.process_fn(arg, **kwargs)
        return self.process_fn(**kwargs)


class BaseProcessor:
    """
    Base class for processors, should not be instantiated.

    Subclasses provide the ``rules`` mapping (rule name -> Rule) and
    override ``_needs_computation`` / ``_save_data`` as needed.
    """

    # NOTE: Python 2 spelling; on Python 3 this is a plain class attribute
    # and does not change the metaclass. Left untouched so the class
    # hierarchy keeps its current behaviour.
    __metaclass__ = ABCMeta

    log_id = "base"
    rules = {}
    solve_self_dep = True

    def __init__(self, path, path_out=".", params=None, tag=None):
        """Set up parameters, output directory and logging.

        Parameters
        ----------
        path : str
            Input path; also used as output directory when ``path_out`` is
            None.
        path_out : str or None, optional
            Output directory.
        params : None, str or params object, optional
            None loads the default parameters, a string is handed to
            ``load_params``, anything else is deep-copied.
        tag : optional
            When given, overrides ``params.out.tag``.
        """
        if params is None:
            self.params = default_params()
        elif isinstance(params, str):
            self.params = load_params(params)
        else:
            self.params = copy.deepcopy(params)

        # Determining output directory
        self.path_out = path if path_out is None else path_out

        if tag is not None:
            self.params.out.tag = tag

        # Initialize logger
        self.logger = logging.getLogger(self.log_id)
        self.logger.propagate = False
        logging_format = "%(levelname)s | %(asctime)s | %(name)s.%(funcName)s:%(lineno)d | %(message)s"
        formatter = logging.Formatter(logging_format, datefmt="%H:%M:%S")

        # Loggers are shared per name: only attach handlers the first time,
        # otherwise every new processor would duplicate the output.
        if not self.logger.hasHandlers():
            stream = logging.StreamHandler(sys.stdout)
            stream.setFormatter(formatter)
            self.logger.addHandler(stream)

            if self.params.process.logfile:
                file_handler = logging.FileHandler(self.params.process.logfile)
                file_handler.setFormatter(formatter)
                self.logger.addHandler(file_handler)

        if self.params.process.verbose:
            self.logger.setLevel(logging.DEBUG)
        else:
            self.logger.setLevel(logging.WARNING)

        # Handlers that already existed get the current format as well.
        for handler in self.logger.handlers:
            handler.setFormatter(formatter)

    def process(
        self,
        to_process,
        arg=None,
        overwrite=False,
        overwrite_dep=False,
        skip_dep=False,
        select=None,
        **kwargs,
    ):
        """
        Process the rule 'to_process'

        Parameters
        ----------
        to_process : str or Rule
            name of the rule to process or Rule object with nonempty rule.name
        arg : optional
            argument to give to the rule
        overwrite : bool, optional
            Force redo if already done
        overwrite_dep : bool, optional
            Force redoing of the dependencies even if already done
        skip_dep : bool, optional
            Skip the dependency checks (assume they are already done)
        select : dict, optional
            Select object (see RunSelector) to only select some run/snapshot

        Returns
        -------
        The output of the rule, or None when the rule is unknown (an error
        is logged in that case).
        """
        self.overwrite_dep = overwrite_dep
        # Rules computed during this call; used to avoid recomputation and
        # to detect that a dependency was refreshed.
        self.just_done = []

        if to_process in self.rules:
            rule = self.rules[to_process]
            return self._solve_and_process_rule(
                to_process, rule, arg, overwrite, skip_dep, select, **kwargs
            )
        elif isinstance(to_process, Rule):
            rule = to_process
            return self._solve_and_process_rule(
                rule.name, rule, arg, overwrite, skip_dep, select, **kwargs
            )
        else:
            self.logger.error(
                "{} is unknown, allowed rules are {}".format(
                    to_process, self.rules.keys()
                )
            )

    def _solve_and_process_rule(
        self, name, rule, arg, overwrite=False, skip_dep=False, select=None, **kwargs
    ):
        """Resolve dependencies and proceed in the processing of a rule

        Parameters
        ----------
        name : str
            name of the rule
        rule : Rule
            rule object
        arg : optional
            argument to give to the rule
        overwrite : bool, optional
            Force redo if already done
        skip_dep : bool, optional
            Skip the dependency checks (assume they are already done)
        select : dict, optional
            Select object (see RunSelector) to only select some run/snapshot

        Returns
        -------
        The output of self._process_rule
        """
        updated = False
        if not skip_dep:
            updated = self._solve_dependencies(name, rule, arg, overwrite, select)
        # A refreshed dependency forces the rule itself to be recomputed.
        overwrite_rule = overwrite or updated
        return self._process_rule(name, rule, arg, overwrite_rule, select, **kwargs)

    def _solve_dependencies(self, name, rule, arg, overwrite=False, select=None):
        """Process every dependency of ``rule``.

        Returns
        -------
        bool
            True when at least one rule was computed while solving the
            dependencies.
        """
        # Keep the baseline in a local: this method is re-entered through
        # _solve_and_process_rule, and an attribute would be overwritten by
        # the nested call, hiding updates made by earlier dependencies.
        done_before = len(self.just_done)
        self.done_before_dep = done_before  # kept for backward compatibility

        # Solve dependencies
        for dep in rule.dependencies:
            # get arguments: a mapping gives a per-dependency argument, a
            # plain sequence means "reuse the parent argument"
            try:
                dep_arg = rule.dependencies[dep]
            except (TypeError, KeyError):
                dep_arg = arg
            if dep_arg == "__parent__":
                dep_arg = arg

            # Whether the processor solves its own dependencies or it gives
            # it to a child processor
            if self.solve_self_dep and dep in self.rules:
                rule_dep = self.rules[dep]
                # `select` must go by keyword: the fifth positional
                # parameter of _solve_and_process_rule is `skip_dep`.
                self._solve_and_process_rule(
                    dep, rule_dep, dep_arg, self.overwrite_dep, select=select
                )
            else:
                self._not_self_dep(name, dep, dep_arg, self.overwrite_dep, select)

        # Whether dependencies were updated
        return len(self.just_done) > done_before

    def _not_self_dep(self, name, dep, dep_arg, overwrite, select=None):
        """Handle a dependency this processor cannot solve (logs an error)."""
        self.logger.error("Dependency {} for {} is unknown".format(dep, name))

    def _needs_computation(self, overwrite, name_full):
        """Tell whether ``name_full`` must be computed; here only on overwrite."""
        return overwrite

    def _process_rule(self, name, rule, arg, overwrite=False, select=None, **kwargs):
        """Compute and save one rule unless it was already done.

        Returns
        -------
        The computed data, or None when nothing was computed.
        """
        # Results for a given argument are stored under a suffixed name,
        # except when the argument is itself a processor.
        if arg is not None and not isinstance(arg, BaseProcessor):
            name_full = rule.group + "/" + name + "_" + str(arg)
        else:
            name_full = rule.group + "/" + name

        if name_full not in self.just_done:
            if self._needs_computation(overwrite, name_full):
                self.logger.debug("Processing {}".format(name_full))
                data = rule.process(arg, **kwargs)
                self._save_data(name_full, data, rule.description, rule.unit)
                self.logger.info("Data for {} computed".format(name_full))
                self.just_done.append(name_full)
                return data
            else:
                self.logger.info("Data for {} is already computed.".format(name_full))

    def def_rules(self):
        """Expose every rule as a method ``self.<rule_name>(...)``."""
        for rule in self.rules:
            func = partial(self.process, rule)
            func.__doc__ = self.rules[rule].description
            setattr(self, rule, func)


class HDF5Container(BaseProcessor):
    """Processor storing its results in an HDF5 file through PyTables."""

    filename = ""
    save = None
    opened = False

    def open(self):
        """Open ``self.filename`` in append mode if it is not open yet."""
        if not self.opened:
            try:
                self.save = tables.open_file(self.filename, mode="a")
            except HDF5ExtError:
                # Wait a bit in case the lock was not released yet
                time.sleep(3)
                self.save = tables.open_file(self.filename, mode="a")
            self.opened = True

    def close(self):
        """Close the file if it is open."""
        if self.opened:
            self.save.close()
            self.opened = False

    @contextmanager
    def _ensure_open(self):
        """Open the file for the duration of the block, restoring prior state.

        The file is closed on exit only when this context opened it.
        """
        was_open = self.opened
        if not was_open:
            self.open()
        try:
            yield
        finally:
            if not was_open:
                self.close()

    def _needs_computation(self, overwrite, name_full):
        """Compute when forced or when the node is absent from the file."""
        return overwrite or name_full not in self.save

    def _process_rule(self, name, rule, arg, overwrite=False, select=None, **kwargs):
        """Process a rule with the file open, closing it afterwards.

        When ``params.process.allow_error`` is set, exceptions are logged
        and swallowed (None is returned); otherwise they propagate.

        Returns
        -------
        The computed data, or None when nothing was computed.
        """
        self.open()
        try:
            # Propagate the result like the base implementation does.
            return super(HDF5Container, self)._process_rule(
                name, rule, arg, overwrite, select, **kwargs
            )
        except Exception as e:
            if self.params.process.allow_error:
                traceback_lines = traceback.format_exc().splitlines()
                # The last line is the exception itself, logged via repr
                # below.
                for line in traceback_lines[:-1]:
                    self.logger.error(line)
                self.logger.error(f"{repr(e)}")
            else:
                raise
        finally:
            self.close()

    def get_value(self, node_name, unit=None, unit_old=None):
        """Read a node, recursively for groups, optionally converting units.

        Parameters
        ----------
        node_name : str
            Path of the node in the file.
        unit : optional
            Target unit, or a dict mapping leaf names to target units.
        unit_old : optional
            Unit to assume when the node has no ``unit`` attribute (used to
            hand a group's unit down to its children).

        Returns
        -------
        The array read from a leaf, or a dict ``{child_name: value}`` for a
        group.

        Raises
        ------
        NoSuchNodeError
            When ``node_name`` does not exist (an error is logged first).
        """
        with self._ensure_open():
            try:
                node = self.save.get_node(node_name)
                if "unit" in node._v_attrs:
                    unit_old = node._v_attrs.unit
                if node._v_attrs.CLASS == "GROUP":
                    value = {}
                    for child_name in node._v_children:
                        value[child_name] = self.get_value(
                            node_name + "/" + child_name, unit, unit_old
                        )
                else:
                    value = node.read()
                    if isinstance(unit, dict):
                        # Per-leaf target units: no entry means no conversion
                        unit = unit.get(os.path.basename(node_name))
                    if not (unit is None or unit_old is None or unit_old == U.none):
                        value = value * unit_old.express(unit)
            except NoSuchNodeError:
                self.logger.error(
                    f"The value {node_name} is not available", stack_info=True
                )
                raise
        return value

    def set_value(self, node_name, data, description, unit):
        """Save ``data`` under ``node_name``, opening the file if needed."""
        with self._ensure_open():
            self._save_data(node_name, data, description, unit)

    def get_attribute(self, node_name, attr_name):
        """Return attribute ``attr_name`` of node ``node_name``."""
        with self._ensure_open():
            node = self.save.get_node(node_name)
            attr = node._v_attrs[attr_name]
        return attr

    def set_attribute(self, node_name, attr_name, attr_value):
        """Set attribute ``attr_name`` of node ``node_name``."""
        with self._ensure_open():
            node = self.save.get_node(node_name)
            node._v_attrs[attr_name] = attr_value

    def _get_units(self, unit, data=None):
        """
        Get real units from info files

        unit is either:
          1. An instance of U.Unit (pymses unit class)
          2. A string beginning by "unit_", referring to a code unit,
             available in self.info
          3. A dict {unit1 : exp1, unit2: exp2, ...} with unitX as 2. and
             expX a float, referring to the compound unit
             unit1**exp1 * unit2**exp2
          4. A dict {key: unit, ...} where key is a field name of ``data``
             (eg. 'time', or 'mass') and unit the corresponding unit (in one
             of the above formats)

        Returns:
          1-3. : a U.Unit instance
          4. : a new dict {key: unit, ...} with same keys as input and unit
               being U.Unit instances (the input dict is left untouched)

        Raises:
          ValueError if ``unit`` matches none of the formats above.

        NOTE(review): ``self.info`` is not defined in this class; it is
        presumably provided by subclasses - confirm.
        """
        if isinstance(unit, U.Unit):
            return unit

        if isinstance(unit, str):
            if unit.startswith("unit_"):
                res = self.info[unit]
                if unit == "unit_length":
                    res = res / self.info["boxlen"]
                return res
            raise ValueError("Invalid unit")

        if isinstance(unit, dict) and unit:
            first_key = next(iter(unit))

            # Compound unit: product of code units raised to exponents
            if isinstance(first_key, str) and first_key.startswith("unit_"):
                new_unit = U.none
                for base_unit_str, expo in unit.items():
                    base_unit = self._get_units(base_unit_str)
                    new_unit = new_unit * base_unit**expo
                return new_unit

            # Per-field units: build a new dict rather than rewriting the
            # caller's one (it usually belongs to a shared Rule object).
            if isinstance(data, dict) and first_key in data:
                return {key: self._get_units(value) for key, value in unit.items()}

        raise ValueError("Invalid unit")

    def _save_data(self, name_full, data, description, unit):
        """
        Save data in the HDF5 structure, overwrite if necessary

        Parameters
        ----------
        name_full : str
            Path of the node to write.
        data : array-like, scalar, dict, None or tuple
            A dict is saved as a group with one child per key (recursively).
            A tuple is read as ``(data, attrs)`` where ``attrs`` is a dict
            of extra node attributes. None or empty data is not saved.
        description : str or dict
            Node title; a dict gives one description per key of ``data``.
        unit :
            Unit in any form accepted by ``_get_units``.
        """
        # Unpack (data, attrs) first so that per-field units can be matched
        # against the actual data dict.
        attrs = None
        if isinstance(data, tuple):
            attrs = data[1]
            data = data[0]

        unit = self._get_units(unit, data=data)

        if name_full in self.save:
            self.save.remove_node(name_full, recursive=True)

        if isinstance(data, dict):
            title = description if isinstance(description, str) else ""
            self.save.create_group(
                os.path.dirname(name_full),
                os.path.basename(name_full),
                title,
                createparents=True,
            )

            # A single unit for the whole group is stored on the group
            if not isinstance(unit, dict):
                self.save.get_node(name_full)._v_attrs.unit = unit

            for key, child_data in data.items():
                # Look up with the original key; only the node name needs
                # to be a string.
                child_desc = description[key] if isinstance(description, dict) else ""
                child_unit = unit[key] if isinstance(unit, dict) else unit
                self._save_data(
                    name_full + "/" + str(key), child_data, child_desc, child_unit
                )
        else:
            try:
                if data is None or len(data) == 0:
                    return
            except TypeError:
                # Scalars have no len(): store them as one-element arrays
                data = np.array([data])
            group_name = os.path.dirname(name_full)
            if group_name in self.save:
                group = self.save.get_node(group_name)
                if not isinstance(group, class_name_dict["Group"]):
                    self.logger.warning(
                        f"{group_name} already there and not a group, deleting"
                    )
                    self.save.remove_node(group)
            self.save.create_array(
                group_name,
                os.path.basename(name_full),
                data,
                description,
                createparents=True,
            )
            self.save.get_node(name_full).attrs.unit = unit

        if attrs is not None:
            node_attrs = self.save.get_node(name_full)._v_attrs
            for key in attrs:
                node_attrs[str(key)] = attrs[key]

    def _transform(self, name, transform_fn, group="/maps", **kwargs):
        """Read node ``group/name`` and return ``transform_fn`` applied to it.

        The file must already be open.
        """
        src = self.save.get_node(group + "/" + name).read()
        return transform_fn(src, **kwargs)

    def _gen_rule_transform(
        self,
        rule_src_name,
        transform_fn,
        transform_name,
        subarray_name=None,
        group=None,
    ):
        """Register a rule applying ``transform_fn`` to the output of another.

        The new rule is stored in ``self.rules`` under
        ``transform_name + "_" + rule_src_name`` and depends on
        ``rule_src_name``.

        Parameters
        ----------
        rule_src_name : str
            Name of the source rule.
        transform_fn : callable
            Called with the source array (and the rule kwargs).
        transform_name : str
            Prefix of the new rule name.
        subarray_name : str, optional
            When the source rule produces a group, name of the child to
            transform; unit and description are then taken from the
            corresponding entry of the source rule's dicts.
        group : str, optional
            Group of the new rule; defaults to the group of the source data.
        """
        rule_src = self.rules[rule_src_name]

        if subarray_name is None:
            src_name = rule_src_name
            group_src = rule_src.group
            unit = rule_src.unit
            description = rule_src.description
        else:
            src_name = subarray_name
            group_src = rule_src.group + "/" + rule_src_name
            unit = rule_src.unit[subarray_name]
            description = rule_src.description[subarray_name]

        def fn(arg=None, **kwargs):
            # Source nodes computed with an argument carry a "_<arg>" suffix
            if arg is None:
                return self._transform(
                    src_name, transform_fn, group=group_src, **kwargs
                )
            return self._transform(
                src_name + "_" + str(arg), transform_fn, group=group_src, **kwargs
            )

        if group is None:
            group = group_src

        name = transform_name + "_" + rule_src_name
        self.rules[name] = Rule(
            fn,
            group=group,
            unit=unit,
            description=description,
            dependencies=[rule_src_name],
            name=name,
        )

    def apply(
        self,
        fn,
        name_array_in,
        name_array_out,
        unit_out=U.none,
        description="",
        recursive=True,
    ):
        """Apply ``fn`` to a stored value and save the result.

        Parameters
        ----------
        fn : callable
            Function applied to each array.
        name_array_in : str
            Path of the node to read.
        name_array_out : str
            Path of the node to write.
        unit_out : optional
            Unit recorded on the output.
        description : str, optional
            Description recorded on the output.
        recursive : bool, optional
            When True and the input is a group, ``fn`` is applied to every
            child separately.
        """
        array_in = self.get_value(name_array_in)

        if recursive and isinstance(array_in, dict):
            for key in array_in:
                self.apply(
                    fn,
                    name_array_in + "/" + key,
                    name_array_out + "/" + key,
                    unit_out,
                    description,
                    recursive=recursive,
                )
            self.set_attribute(name_array_out, "unit", unit_out)
        else:
            try:
                array_out = fn(array_in)
            except TypeError:
                # Best effort: values fn cannot handle are copied unchanged
                array_out = array_in
            self.set_value(name_array_out, array_out, description, unit_out)


def simple_getter(name, dset):
    """Return ``dset[name]``."""
    return dset[name]


def vect_getter(name, i, dset):
    """Return component ``i`` along axis 1 of ``dset[name]``."""
    return dset[name][:, i]


def oct_vect_getter(name, i, dset):
    """Return component ``i`` along axis 2 of ``dset[name]``."""
    return dset[name][:, :, i]


def norm_getter(name, dset):
    """Return the Euclidean norm of ``dset[name]`` taken along axis 1."""
    return np.sqrt(np.sum(dset[name] ** 2, axis=1))


def oct_norm_getter(name, dset):
    """Return the Euclidean norm of ``dset[name]`` taken along axis 2."""
    return np.sqrt(np.sum(dset[name] ** 2, axis=2))