From fa952afdb0b21377d836691fec8ffc53dfa4cf28 Mon Sep 17 00:00:00 2001 From: Noe Brucy Date: Thu, 22 Jul 2021 17:48:24 +0200 Subject: [PATCH] [plotter] use getter instead of direct save access --- plotter.py | 225 +++++++++++++++++++++++------------------------------ 1 file changed, 98 insertions(+), 127 deletions(-) diff --git a/plotter.py b/plotter.py index ef8cbb1..8758edb 100644 --- a/plotter.py +++ b/plotter.py @@ -63,19 +63,6 @@ class PlotRule(Rule): Add an extra method, plot, that take the reference to an open hdf5 file (from pytables) """ - def plot(self, save, arg, **kwargs): - """ - Set the plotter's storage to 'save' and execute the rule - - Parameters - ---------- - save : opended pytables hdf5 file, where to find the data - arg : main argument of the plotting function - kargs : optional keyword arguments to the plotting function - """ - self.postproc.save = save - return self.process_fn(arg, **kwargs) - def datafile(self, name, arg): if arg is not None: name = name + "_" + str(arg) @@ -189,8 +176,7 @@ class Plotter(Aggregator, BaseProcessor): # generate astrophysix's simulations object self.gen_simus() - self.save = None - self.current_snap = None + self.current_processor = None def gen_simus(self): self.simulations = {} @@ -344,35 +330,24 @@ class Plotter(Aggregator, BaseProcessor): else: raise - # Find plot save - if from_cells or rule.kind == "cells": - self.current_snap = self.snaps[run][num] - if not os.path.exists(self.current_snap.cells_filename): - self.snaps[run][num].load_cells() - self.snaps[run][num].unload_cells() - save = tables.open_file(self.snaps[run][num].cells_filename) - elif rule.kind == "snapshot": - self.current_snap = self.snaps[run][num] - save = tables.open_file( self.current_snap.filename) + # Find underlying processor + if rule.kind == "snapshot": + self.current_processor = self.snaps[run][num] else: - save = tables.open_file(self.study.filename, "r") + self.current_processor = self.study # Call plot routine - try: - close = (not onefigure) or (i == len(run_num) - 1) - plot_info = self._plot_rule( - rule, - save, - arg, - plot_filename, - overwrite, - ax=real_ax, - close=close, - run=run, - **kwargs, - ) - finally: - save.close() + close = (not onefigure) or (i == len(run_num) - 1) + plot_info = self._plot_rule( + rule, + arg, + plot_filename, + overwrite, + ax=real_ax, + close=close, + run=run, + **kwargs, + ) # Save in astrophysix format df = rule.datafile(name, arg) @@ -404,14 +379,14 @@ class Plotter(Aggregator, BaseProcessor): return datafiles def _plot_rule( - self, rule, save, arg, plot_filename, overwrite, ax, close=True, **kwargs + self, rule, arg, plot_filename, overwrite, ax, close=True, **kwargs ): """ Once all dependencies are met, actually process the rule """ plt.sca(ax) if self._needs_computation(overwrite, plot_filename): - plot_info = rule.plot(save, arg, **kwargs) + plot_info = rule.process(arg, **kwargs) if not self.params.out.interactive and close: plt.tight_layout(pad=1) @@ -506,23 +481,22 @@ class Plotter(Aggregator, BaseProcessor): label_run = label return label_run - def _ax_label_unit(self, node, label, unit, unit_coeff, put_units=True): + def _ax_label_unit(self, node_name, label, unit, unit_coeff, put_units=True): """ Find appropriate labels for axis """ if label is None: - if "label" in node._v_attrs: - label = node._v_attrs.label - elif node._v_name in self.label_convert: - label = self.label_convert[node._v_name] - elif not node._v_title == "": - label = node._v_title - else: - label = node._v_name + try: + label = self.current_processor.get_attribute("node_name", "label") + except KeyError: + if os.path.basename(node_name) in self.label_convert: + label = self.label_convert[os.path.basename(node_name)] + else: + label = os.path.basename(node_name) - if "unit" in node._v_attrs: - unit_old = node._v_attrs.unit - else: + try: + unit_old = self.current_processor.get_attribute("node_name", "unit") + except KeyError: unit_old = U.none if unit is None: @@ -542,7 +516,7 @@ class Plotter(Aggregator, BaseProcessor): title = self.get_label_run(run, title, nml_key) if put_time: - time = self.save.root._v_attrs.time * self.study.info["unit_time"] + time = self.current_processor.info["time"] * self.study.info["unit_time"] u_str = unit_str(unit_time, format="{unit}") time_str = self.params.plot.time_fmt.format(time.express(unit_time), u_str) if len(title) > 0: @@ -590,26 +564,26 @@ class Plotter(Aggregator, BaseProcessor): ax_h = self._axes_h[ax_los] ax_v = self._axes_v[ax_los] - im_extent = np.array(self.save.root.maps._v_attrs.im_extent) - unit_length = self.save.root._v_attrs["unit_length"] + im_extent = np.array(self.current_processor.im_extent) + unit_length = self.current_processor.info["unit_length"] if embeded: axes = False scalebar = True if center_space: - center = self.save.root.maps._v_attrs.center + center = self.current_processor.get_attribute("/maps", "center") center_h = center[self._ax_nb[ax_h]] center_v = center[self._ax_nb[ax_v]] im_extent[:2] = im_extent[:2] - center_h im_extent[2:] = im_extent[2:] - center_v im_extent = im_extent * unit_length.express(unit_space) - node = self.save.get_node("/maps/{}_{}".format(name, ax_los)) - dmap = node.read() + node_name = f"/maps/{name}_{ax_los}" + dmap = self.current_processor.get_value(node_name) label, unit_old, unit = self._ax_label_unit( - node, label, unit, unit_coeff, put_units + node_name, label, unit, unit_coeff, put_units ) dmap = dmap * unit_old.express(unit) @@ -736,7 +710,7 @@ class Plotter(Aggregator, BaseProcessor): """ Add an overlay : contour of other map """ - map_contour = self.save.get_node("/maps/{}_{}".format(map_name, ax_los)).read() + map_contour = self.current_processor.get_value("/maps/{}_{}".format(map_name, ax_los)) if log: map_contour = np.log10(map_contour) # Computing linewidths @@ -797,29 +771,27 @@ class Plotter(Aggregator, BaseProcessor): center_space=False, parts = True, sinks = False, - **kwargs ): """ Add an overlay with particles data """ - unit_length = self.current_snap.info["unit_length"] + unit_length = self.current_processor.info["unit_length"] if sinks: - self.current_snap.sinks() - sinks = pd.DataFrame(self.current_snap.get_value("/datasets/sinks")) + self.current_processor.sinks() + sinks = pd.DataFrame(self.current_processor.get_value("/datasets/sinks")) part_pos = sinks[["x", "y", "z"]].values mass = sinks.M if parts: # Open particle HDF5 filetype_from_ext - filename = self.save.get_node("/hdf5/write_particles").read()[0].decode() - hdf5_parts = tables.open_file(filename, "r") - part_pos = hdf5_parts.get_node("/data/pos").read() - mass = hdf5_parts.get_node("/data/mass").read() - mass *= self.current_snap.info["unit_mass"].express(U.Msun) - hdf5_parts.close() + self.current_processor.load_parts(keys=["pos", "mass"]) + part_pos = self.current_processor.parts.pos + mass = self.current_processor.parts.mass + mass *= self.current_processor.info["unit_mass"].express(U.Msun) + self.current_processor.unload_parts() # index of the horizontal axis ih = self._ax_nb[self._axes_h[ax_los]] @@ -833,7 +805,7 @@ class Plotter(Aggregator, BaseProcessor): if center_space: ax_h = self._axes_h[ax_los] ax_v = self._axes_v[ax_los] - center = self.save.root.maps._v_attrs.center + center = self.current_processor.get_attribute("/maps", "center") center_h = center[self._ax_nb[ax_h]] center_v = center[self._ax_nb[ax_v]] part_h -= center_h @@ -861,11 +833,10 @@ class Plotter(Aggregator, BaseProcessor): """ Add an overlay : velocity vector field """ - dmap_vh_node = self.save.get_node("/maps/speed_h_{}".format(ax_los)) - dmap_vh = dmap_vh_node.read() - dmap_vv = self.save.get_node("/maps/speed_v_{}".format(ax_los)).read() + dmap_vh = self.current_processor.get_value("/maps/speed_h_{}".format(ax_los)) + dmap_vv = self.current_processor.get_value("/maps/speed_v_{}".format(ax_los)) - label, unit_old, unit = self._ax_label_unit(dmap_vh_node, "", unit, unit_coeff) + label, unit_old, unit = self._ax_label_unit(f"/maps/speed_h_{ax_los}", "", unit, unit_coeff) vel_red = self.params.plot.vel_red @@ -915,16 +886,15 @@ class Plotter(Aggregator, BaseProcessor): """ Add an overlay : magnetic streamlines """ - dmap_Bh_node = self.save.get_node("/maps/B_h_{}".format(ax_los)) - dmap_Bh = dmap_Bh_node.read() - dmap_Bv = self.save.get_node("/maps/B_v_{}".format(ax_los)).read() + dmap_Bh = self.current_processor.get_value(f"/maps/B_h_{ax_los}") + dmap_Bv = self.current_processor.get_value(f"/maps/B_v_{ax_los}") # TODO : redo this with im_extent vel_red = self.params.plot.vel_red - radius = self.save.root.maps._v_attrs.radius - center = self.save.root.maps._v_attrs.center - lbox = self.save.root._v_attrs.lbox + radius = self.current_processor.attribute("/maps", "radius") + center = self.current_processor.attribute("/maps", "center") + lbox = self.current_processor.lbox map_Bh_red = dmap_Bh[::vel_red, ::vel_red] # take only a subset of velocities map_Bv_red = dmap_Bv[::vel_red, ::vel_red] @@ -973,22 +943,23 @@ class Plotter(Aggregator, BaseProcessor): # Get node if ax_los is not None: name = name + "_" + ax_los - node = self.save.get_node(group + name) + node_name = group + name if xlog is None: try: - xlog = node._v_attrs.logbins + xlog = self.current_processor.get_attribute(node_name, "logbins") except AttributeError: xlog = False # get label and units - xlabel, unit_old, unit = self._ax_label_unit(node, label, unit, unit_coeff) + xlabel, unit_old, unit = self._ax_label_unit(node_name, label, unit, unit_coeff) # Read data + node = self.current_processor.get_value(node_name) if "mean" in node: - index = node["runs"].read().index(run.encode()) - values, centers = node["mean"].read()[index] + index = node["runs"].index(run.encode()) + values, centers = node["mean"][index] else: - values, centers = node.read() + values, centers = node if xlog: centers = centers + np.log10(unit_old.express(unit)) else: @@ -1034,16 +1005,16 @@ class Plotter(Aggregator, BaseProcessor): plt.ylabel(ylabel) # Also diplay fit, previously saved - if ax_los is not None and "/hist/fit_" + name + "_" + ax_los in self.save: - slope = node.attrs.slope - origin = node.attrs.origin - plt.plot( - centers, - 10 ** (slope * centers + origin), - "--", - linewidth=2, - color="orange", - ) + # if ax_los is not None and "/hist/fit_" + name + "_" + ax_los in self.save: + # slope = node.attrs.slope + # origin = node.attrs.origin + # plt.plot( + # centers, + # 10 ** (slope * centers + origin), + # "--", + # linewidth=2, + # color="orange", + # ) # or a new one if fit is not None: self._overlay_fit( @@ -1108,28 +1079,28 @@ class Plotter(Aggregator, BaseProcessor): name_x, name_y = name_x + "_" + node_arg, name_y + "_" + node_arg # Get hdf5 nodes - node_x = self.save.get_node(name_x) - node_y = self.save.get_node(name_y) + node_x = self.current_processor.get_value(name_x) + node_y = self.current_processor.get_value(name_y) # If the actual data is in another file, fetch it if subname_x: - hdf5_x = tables.open_file(node_x.read()) - node_x = hdf5_x.get_node(subname_x) + hdf5_x = tables.open_file(node_x) + node_x = hdf5_x.get_node(subname_x).read() if subname_y: - hdf5_y = tables.open_file(node_y.read()) - node_y = hdf5_y.get_node(subname_y) + hdf5_y = tables.open_file(node_y) + node_y = hdf5_y.get_node(subname_y).read() # Find proper labels xlabel, xunit_old, xunit = self._ax_label_unit( - node_x, xlabel, xunit, xunit_coeff + name_x, xlabel, xunit, xunit_coeff ) ylabel, yunit_old, yunit = self._ax_label_unit( - node_y, ylabel, yunit, yunit_coeff + name_y, ylabel, yunit, yunit_coeff ) # If relevent, get time if put_time: - time = self.save.root._v_attrs.time * self.study.info["unit_time"] + time = self.current_processor.time * self.study.info["unit_time"] time_str = self.params.plot.time_fmt.format( time.express(unit_time), unit_time.latex.replace("text", "math") ) @@ -1141,28 +1112,28 @@ class Plotter(Aggregator, BaseProcessor): # Manage the different forms in which the data may be stored : # Possibilities are : plain array, dict of arrays (mean, std, ..) or dict of array (runs) - if node_y._v_attrs.CLASS == "ARRAY": - x = node_x.read() * xunit_old.express(xunit) - y = node_y.read() * yunit_old.express(yunit) + if isinstance(node_y, np.ndarray): + x = node_x * xunit_old.express(xunit) + y = node_y * yunit_old.express(yunit) mask = np.isfinite(x) & np.isfinite(y) x, y = x[mask], y[mask] elif "mean" in node_y: - x = node_x.read() * xunit_old.express(xunit) - y = node_y.mean.read() * yunit_old.express(yunit) + x = node_x * xunit_old.express(xunit) + y = node_y.mean * yunit_old.express(yunit) if yerr_kind == "std": - std = node_y.std.read() * yunit_old.express(yunit) + std = node_y["std"] * yunit_old.express(yunit) yerr_min = y - sigma_err * std yerr_max = y + sigma_err * std elif yerr_kind == "min_max": - yerr_min = node_y.min.read() * yunit_old.express(yunit) - yerr_max = node_y.max.read() * yunit_old.express(yunit) + yerr_min = node_y["min"] * yunit_old.express(yunit) + yerr_max = node_y["max"] * yunit_old.express(yunit) elif yerr_kind == "95per": - yerr_min = node_y.q025.read() * yunit_old.express(yunit) - yerr_max = node_y.q975.read() * yunit_old.express(yunit) + yerr_min = node_y["q025"] * yunit_old.express(yunit) + yerr_max = node_y["q975"] * yunit_old.express(yunit) elif yerr_kind == "68per": - yerr_min = node_y.q16.read() * yunit_old.express(yunit) - yerr_max = node_y.q84.read() * yunit_old.express(yunit) + yerr_min = node_y["q16"] * yunit_old.express(yunit) + yerr_max = node_y["q84"] * yunit_old.express(yunit) else: yerr_min = y yerr_max = y @@ -1176,13 +1147,13 @@ class Plotter(Aggregator, BaseProcessor): yerr_max[mask], ) else: - x = node_x[run].read() * xunit_old.express(xunit) - y = node_y[run].read() * yunit_old.express(yunit) + x = node_x[run] * xunit_old.express(xunit) + y = node_y[run] * yunit_old.express(yunit) mask = np.isfinite(x) & np.isfinite(y) x, y = x[mask], y[mask] if isinstance(yerr, str): - yerr = self.save.get_node(yerr).read() + yerr = self.current_processor.get_value(yerr) # Apply transformations on x if xtransform is not None: @@ -1211,7 +1182,7 @@ class Plotter(Aggregator, BaseProcessor): color = colors[run] elif nml_color == "time": time = ( - self.save.root._v_attrs.time * self.study.info["unit_time"] + self.current_processor.time * self.current_processor.info["unit_time"] ).express(unit_time) color = colors(time) else: @@ -1261,8 +1232,8 @@ class Plotter(Aggregator, BaseProcessor): Plot power spectrum (wrapper around pspec_read) """ del kwargs["run"] - file_pspec = self.save.get_node("/hdf5/pspec").read() - num = self.save.root._v_attrs.num + file_pspec = self.current_processor.get_value("/hdf5/pspec") + num = self.current_processor.num getattr(pspec_read, "pspec_" + name)(file_pspec, ".", num, **kwargs) def _overlay_fit(self, x, y, yerr=None, kind="linear", label=None, **kwargs):