# coding: utf-8

import copy
import os
import time
from abc import ABCMeta
from functools import partial

import numpy as np
import tables
from tables import HDF5ExtError

from pp_params import default_params, load_params
from units import U


class Rule:
    """
    A named processing step: a callable plus the metadata used to store
    its result (group, description, unit) and the rules it depends on.
    """

    def __init__(
        self,
        postproc,
        process,
        description="",
        group="",
        dependencies=None,
        kind="classic",
        unit=U.none,
    ):
        """
        postproc: the processor owning this rule
        process: callable doing the actual work, see `Rule.process`
        description: text (or dict of texts) stored along with the result
        group: prefix of the name under which the result is stored
        dependencies: iterable of rule names, or dict {rule name: argument}.
            Defaults to a fresh empty list for each rule.
        kind: free-form tag, only stored here
        unit: unit specification of the result
        """
        self.postproc = postproc
        self.process_fn = process
        # A new list per instance, so rules never share a default object
        self.dependencies = [] if dependencies is None else dependencies
        self.group = group
        self.description = description
        self.unit = unit
        self.kind = kind

    def process(self, arg, **kwargs):
        """
        Call the processing function, passing `arg` as first positional
        argument only when it is not None.
        """
        if arg is not None:
            return self.process_fn(arg, **kwargs)
        return self.process_fn(**kwargs)


class BaseProcessor:
    """
    Base class for processors, should not be instantiated
    """

    # NOTE(review): Python 2 style declaration, it has no effect on
    # Python 3. Kept as is to avoid changing the metaclass of subclasses.
    __metaclass__ = ABCMeta
    log_id = ""
    # NOTE(review): class-level dict, shared by every instance that does not
    # assign its own `rules` -- presumably subclasses do; confirm.
    rules = {}
    solve_self_dep = True

    def __init__(self, path, path_out=None, pp_params=None, tag=None):
        """
        path: input path, also used as output path if `path_out` is None
        path_out: output path
        pp_params: None (default parameters), a string (given to
            `load_params`) or a parameter object (deep-copied)
        tag: if not None, overrides `pp_params.out.tag`
        """
        if pp_params is None:
            self.pp_params = default_params()
        elif isinstance(pp_params, str):
            self.pp_params = load_params(pp_params)
        else:
            self.pp_params = copy.deepcopy(pp_params)
        if tag is not None:
            self.pp_params.out.tag = tag

        # Determining output directory
        if path_out is None:
            self.path_out = path
        else:
            self.path_out = path_out

    def _log(self, string, status=""):
        """
        Print a message prefixed by `log_id` (and by `status` if any),
        only when `pp_params.process.verbose` is set.
        """
        if self.pp_params.process.verbose:
            if status:
                print(status + ": " + self.log_id + string)
            else:
                print(self.log_id + string)

    def process(
        self,
        to_process,
        arg=None,
        overwrite=False,
        overwrite_dep=False,
        select=None,
        **kwargs,
    ):
        """
        Process the rule `to_process`

        Dependencies are solved first. Returns what `_process_rule` returns,
        or None (after logging an error) if the rule is unknown.
        """
        self.overwrite_dep = overwrite_dep
        self.just_done = []
        if to_process in self.rules:
            rule = self.rules[to_process]
            return self._solve_and_process_rule(
                to_process, rule, arg, overwrite, select, **kwargs
            )
        else:
            self._log(
                "{} is unknown, allowed rules are {}".format(
                    to_process, self.rules.keys()
                ),
                "ERROR",
            )

    def _solve_and_process_rule(
        self, name, rule, arg, overwrite=False, select=None, **kwargs
    ):
        """
        Solve the dependencies of `rule`, then process it. The rule is
        overwritten if asked to, or if one of its dependencies was updated.
        """
        updated = self._solve_dependencies(name, rule, arg, overwrite, select)
        overwrite_rule = overwrite or updated
        return self._process_rule(name, rule, arg, overwrite_rule, select, **kwargs)

    def _solve_dependencies(self, name, rule, arg, overwrite=False, select=None):
        """
        Process every dependency of `rule`.

        Returns True if at least one entry was added to `just_done` while
        doing so, i.e. if something was (re)computed.
        """
        # The baseline has to be a local variable: this method is re-entered
        # for each dependency, which overwrites the attribute below and would
        # hide an update done by an earlier dependency.
        done_before_dep = len(self.just_done)
        # Attribute still set, in case code outside this file reads it
        self.done_before_dep = done_before_dep

        # Solve dependencies
        for dep in rule.dependencies:
            # get arguments
            try:
                dep_arg = rule.dependencies[dep]
            except (TypeError, KeyError):
                dep_arg = arg
            # isinstance guard: comparing e.g. an array to a string is not a
            # plain boolean
            if isinstance(dep_arg, str) and dep_arg == "__parent__":
                dep_arg = arg

            # Whether the processor solves its own dependencies or it gives
            # it to a child processor
            if self.solve_self_dep and dep in self.rules:
                rule_dep = self.rules[dep]
                self._solve_and_process_rule(
                    dep, rule_dep, dep_arg, self.overwrite_dep, select
                )
            else:
                self._not_self_dep(name, dep, dep_arg, self.overwrite_dep, select)

        # Whether dependencies were updated
        return len(self.just_done) > done_before_dep

    def _not_self_dep(self, name, dep, dep_arg, overwrite, select=None):
        """
        Called for a dependency this processor does not solve itself.
        Here it only logs an error.
        """
        self._log("Dependency {} for {} is unknown".format(dep, name), "ERROR")

    def _needs_computation(self, overwrite, name_full):
        """Whether `name_full` has to be computed. Here: only on overwrite."""
        return overwrite

    def _process_rule(self, name, rule, arg, overwrite=False, select=None, **kwargs):
        """
        Compute and save the data of `rule` if needed.

        The data is stored under `<group>/<name>` or `<group>/<name>_<arg>`
        through `self._save_data`, which is not defined in this class.
        Returns the computed data, or None when nothing was computed
        (already done during this `process` call, or not needed).
        """
        if arg is not None:
            name_full = rule.group + "/" + name + "_" + str(arg)
        else:
            name_full = rule.group + "/" + name

        if name_full not in self.just_done:
            if self._needs_computation(overwrite, name_full):
                self._log("Processing {}".format(name_full))
                data = rule.process(arg, **kwargs)
                self._save_data(name_full, data, rule.description, rule.unit)
                self._log("Data for {} computed".format(name_full), "SUCCESS")
                self.just_done.append(name_full)
                return data
            else:
                self._log(
                    "Data for {} is already computed, skipping...".format(name_full)
                )

    def def_rules(self):
        """
        Expose each rule as an attribute `self.<rule name>(...)`, a partial
        of `self.process` documented by the description of the rule.
        """
        for rule in self.rules:
            func = partial(self.process, rule)
            func.__doc__ = self.rules[rule].description
            setattr(self, rule, func)


class HDF5Container(BaseProcessor):
    """
    Processor storing its results in the HDF5 file `filename`, through
    PyTables.
    """

    filename = ""
    save = None  # PyTables file handle while the file is open
    opened = False

    def open(self):
        """Open the file in append mode, unless it is already open."""
        if not self.opened:
            try:
                self.save = tables.open_file(self.filename, mode="a")
            except HDF5ExtError:
                # Wait a bit if the lock was not still released
                time.sleep(3)
                self.save = tables.open_file(self.filename, mode="a")
            self.opened = True

    def close(self):
        """Close the file if it is open."""
        if self.opened:
            self.save.close()
            self.opened = False

    def _needs_computation(self, overwrite, name_full):
        """Compute on overwrite, or if the node is not in the file yet."""
        return overwrite or name_full not in self.save

    def _process_rule(self, name, rule, arg, overwrite=False, select=None, **kwargs):
        """
        Same as `BaseProcessor._process_rule`, with the file open for the
        duration of the call. The file is always closed afterwards.

        Returns the value of the base implementation (the computed data, or
        None if nothing was computed).
        """
        self.open()
        try:
            return super(HDF5Container, self)._process_rule(
                name, rule, arg, overwrite, select, **kwargs
            )
        finally:
            self.close()

    def get_value(self, node_name, unit=None, unit_old=None):
        """
        Read the node `node_name`.

        A group is returned as a dict {child name: value}, read recursively.
        A leaf is returned as read, multiplied by `unit_old.express(unit)`
        when both units are known and `unit_old` is not `U.none`.
        `unit_old` is replaced by the `unit` attribute of the node (or of
        the closest parent group read on the way) when there is one.

        The file is left in the state (open or closed) it was found in.
        """
        open_before = self.opened
        if not open_before:
            self.open()
        try:
            node = self.save.get_node(node_name)
            if "unit" in node._v_attrs:
                unit_old = node._v_attrs.unit
            if node._v_attrs.CLASS == "GROUP":
                value = {}
                for child_name in node._v_children:
                    value[child_name] = self.get_value(
                        node_name + "/" + child_name, unit, unit_old
                    )
            else:
                value = node.read()
                # Conversion only makes sense on leaves, children of a group
                # are converted by the recursive calls above
                if not (unit is None or unit_old is None or unit_old == U.none):
                    value = value * unit_old.express(unit)
        finally:
            if not open_before:
                self.close()
        return value

    def _get_units(self, unit, data=None):
        """
        Get real units from info files

        unit is either:
            1. An instance of U.Unit (pymses unit class)
            2. A string beginning by "unit_", referring to a code unit,
               available in self.info
            3. A dict {unit1 : exp1, unit2: exp2, ...} with unitX as 2. and
               expX a float, referring to the compound unit
               unit1**exp1 * unit2**exp2
            4. A dict {key: unit, ...} where key is a field name
               (eg. 'time', or 'mass') and unit the corresponding unit
               (on one of the above format). `data` must then be a dict
               containing the first key.

        Returns:
            1-3. : a U.Unit instance
            4. : a new dict {key: unit, ...} with same keys as input and unit
                 being U.Unit instances. The input dict is left untouched.

        Raises:
            ValueError if `unit` matches none of the above
        """
        if isinstance(unit, U.Unit):
            return unit

        if isinstance(unit, str) and unit[:5] == "unit_":
            res = self.info[unit]
            if unit == "unit_length":
                res = res / self.info["boxlen"]
            return res

        try:
            keys = list(unit)
        except TypeError as err:
            raise ValueError("Invalid unit") from err
        if not keys:
            raise ValueError("Invalid unit")
        first_key = keys[0]

        if isinstance(first_key, str) and first_key[:5] == "unit_":
            new_unit = U.none
            for base_unit_str in unit:
                expo = unit[base_unit_str]
                base_unit = self._get_units(base_unit_str)
                new_unit = new_unit * base_unit ** expo
            return new_unit

        if isinstance(data, dict) and first_key in data:
            # Build a new dict: `unit` is usually the `unit` attribute of a
            # rule, which has to keep its original specification
            return {key: self._get_units(unit[key]) for key in unit}

        raise ValueError("Invalid unit")

    def _save_data(self, name_full, data, description, unit):
        """
        Save data in the HDF5 structure, overwrite if necessary

        name_full: path of the node to write
        data: value to store. A dict is stored as a group with one child
            per key (recursively), anything else as an array. A tuple is
            read as (data, attrs), attrs being extra node attributes.
        description: title of the node, or dict of titles indexed like data
        unit: unit specification, see `_get_units`

        The file has to be open.
        """
        # Unpack first, so that per-field units (a dict) are checked against
        # the data dict and not against the (data, attrs) tuple
        attrs = None
        if isinstance(data, tuple):
            attrs = data[1]
            data = data[0]

        unit = self._get_units(unit, data=data)

        if name_full in self.save:
            self.save.remove_node(name_full, recursive=True)

        if isinstance(data, dict):
            title = description if isinstance(description, str) else ""
            self.save.create_group(
                os.path.dirname(name_full),
                os.path.basename(name_full),
                title,
                createparents=True,
            )
            if not isinstance(unit, dict):
                self.save.get_node(name_full)._v_attrs.unit = unit

            for key in data:
                # Look up with the original key, only the node name has to
                # be a string
                if isinstance(description, dict):
                    child_description = description[key]
                else:
                    child_description = ""
                child_unit = unit[key] if isinstance(unit, dict) else unit
                self._save_data(
                    name_full + "/" + str(key),
                    data[key],
                    child_description,
                    child_unit,
                )
        else:
            self.save.create_array(
                os.path.dirname(name_full),
                os.path.basename(name_full),
                data,
                description,
                createparents=True,
            )
            self.save.get_node(name_full).attrs.unit = unit

        if attrs is not None:
            node_attrs = self.save.get_node(name_full)._v_attrs
            for key in attrs:
                node_attrs[str(key)] = attrs[key]

    def set_value(self, node_name, data, description, unit):
        """
        Write `data` in the node `node_name`, see `_save_data`.
        The file is left in the state (open or closed) it was found in.
        """
        open_before = self.opened
        self.open()
        try:
            self._save_data(node_name, data, description, unit)
        finally:
            if not open_before:
                self.close()

    def get_attribute(self, node_name, attr_name):
        """
        Return the attribute `attr_name` of the node `node_name`.
        The file is left in the state (open or closed) it was found in.
        """
        open_before = self.opened
        self.open()
        try:
            node = self.save.get_node(node_name)
            attr = node._v_attrs[attr_name]
        finally:
            if not open_before:
                self.close()
        return attr

    def set_attribute(self, node_name, attr_name, attr_value):
        """
        Set the attribute `attr_name` of the node `node_name`.
        The file is left in the state (open or closed) it was found in.
        """
        open_before = self.opened
        self.open()
        try:
            node = self.save.get_node(node_name)
            node._v_attrs[attr_name] = attr_value
        finally:
            if not open_before:
                self.close()

    def _transform(self, name, transform_fn, group="/maps", **kwargs):
        """
        Read the node `<group>/<name>` and return `transform_fn` applied to
        it. The file has to be open.
        """
        src = self.save.get_node(group + "/" + name).read()
        return transform_fn(src, **kwargs)

    def _gen_rule_transform(
        self,
        rule_src_name,
        transform_fn,
        transform_name,
        subarray_name=None,
        group=None,
    ):
        """
        Register the rule `<transform_name>_<rule_src_name>`, which applies
        `transform_fn` to the stored result of the rule `rule_src_name`.

        subarray_name: if given, the transform is applied to this child of
            the source result; unit and description of the source rule are
            then indexed by it.
        group: group of the new rule, defaults to the group the source data
            is read from.
        """
        rule_src = self.rules[rule_src_name]
        if subarray_name is None:
            src_name = rule_src_name
            group_src = rule_src.group
            unit = rule_src.unit
            description = rule_src.description
        else:
            src_name = subarray_name
            group_src = rule_src.group + "/" + rule_src_name
            unit = rule_src.unit[subarray_name]
            description = rule_src.description[subarray_name]

        def fn(arg=None, **kwargs):
            if arg is None:
                return self._transform(
                    src_name, transform_fn, group=group_src, **kwargs
                )
            else:
                return self._transform(
                    src_name + "_" + str(arg), transform_fn, group=group_src, **kwargs
                )

        if group is None:
            group = group_src
        name = transform_name + "_" + rule_src_name
        self.rules[name] = Rule(
            self,
            fn,
            group=group,
            unit=unit,
            description=description,
            dependencies=[rule_src_name],
        )

    def apply(
        self,
        fn,
        name_array_in,
        name_array_out,
        unit_out=U.none,
        description="",
        recursive=True,
    ):
        """
        Apply `fn` to the node `name_array_in` and store the result in
        `name_array_out`.

        With `recursive`, a group is processed child by child, and the
        `unit` attribute of the output group is then set to `unit_out`.
        """
        array_in = self.get_value(name_array_in)
        if recursive and isinstance(array_in, dict):
            for key in array_in:
                self.apply(
                    fn,
                    name_array_in + "/" + key,
                    name_array_out + "/" + key,
                    unit_out,
                    description,
                )
            self.set_attribute(name_array_out, "unit", unit_out)
        else:
            try:
                array_out = fn(array_in)
            except TypeError:
                # NOTE(review): the input is stored unchanged when `fn`
                # raises TypeError -- presumably for values `fn` does not
                # apply to; this also hides TypeErrors from inside `fn`.
                array_out = array_in
            self.set_value(name_array_out, array_out, description, unit_out)


def simple_getter(name, dset):
    """Return the field `name` of `dset`."""
    return dset[name]


def vect_getter(name, i, dset):
    """Return the column `i` of the field `name` of `dset`."""
    return dset[name][:, i]


def norm_getter(name, dset):
    """Return the Euclidean norm along axis 1 of the field `name` of `dset`."""
    return np.sqrt(np.sum(dset[name] ** 2, axis=1))