[plotter] use getter instead of direct save access

This commit is contained in:
Noe Brucy
2021-07-22 17:48:24 +02:00
parent 4fa72eccc5
commit fa952afdb0
+98 -127
View File
@@ -63,19 +63,6 @@ class PlotRule(Rule):
Add an extra method, plot, that take the reference to an open hdf5 file (from pytables) Add an extra method, plot, that take the reference to an open hdf5 file (from pytables)
""" """
def plot(self, save, arg, **kwargs):
"""
Set the plotter's storage to 'save' and execute the rule
Parameters
----------
save : opended pytables hdf5 file, where to find the data
arg : main argument of the plotting function
kargs : optional keyword arguments to the plotting function
"""
self.postproc.save = save
return self.process_fn(arg, **kwargs)
def datafile(self, name, arg): def datafile(self, name, arg):
if arg is not None: if arg is not None:
name = name + "_" + str(arg) name = name + "_" + str(arg)
@@ -189,8 +176,7 @@ class Plotter(Aggregator, BaseProcessor):
# generate astrophysix's simulations object # generate astrophysix's simulations object
self.gen_simus() self.gen_simus()
self.save = None self.current_processor = None
self.current_snap = None
def gen_simus(self): def gen_simus(self):
self.simulations = {} self.simulations = {}
@@ -344,35 +330,24 @@ class Plotter(Aggregator, BaseProcessor):
else: else:
raise raise
# Find plot save # Find underlying processor
if from_cells or rule.kind == "cells": if rule.kind == "snapshot":
self.current_snap = self.snaps[run][num] self.current_processor = self.snaps[run][num]
if not os.path.exists(self.current_snap.cells_filename):
self.snaps[run][num].load_cells()
self.snaps[run][num].unload_cells()
save = tables.open_file(self.snaps[run][num].cells_filename)
elif rule.kind == "snapshot":
self.current_snap = self.snaps[run][num]
save = tables.open_file( self.current_snap.filename)
else: else:
save = tables.open_file(self.study.filename, "r") self.current_processor = self.study
# Call plot routine # Call plot routine
try: close = (not onefigure) or (i == len(run_num) - 1)
close = (not onefigure) or (i == len(run_num) - 1) plot_info = self._plot_rule(
plot_info = self._plot_rule( rule,
rule, arg,
save, plot_filename,
arg, overwrite,
plot_filename, ax=real_ax,
overwrite, close=close,
ax=real_ax, run=run,
close=close, **kwargs,
run=run, )
**kwargs,
)
finally:
save.close()
# Save in astrophysix format # Save in astrophysix format
df = rule.datafile(name, arg) df = rule.datafile(name, arg)
@@ -404,14 +379,14 @@ class Plotter(Aggregator, BaseProcessor):
return datafiles return datafiles
def _plot_rule( def _plot_rule(
self, rule, save, arg, plot_filename, overwrite, ax, close=True, **kwargs self, rule, arg, plot_filename, overwrite, ax, close=True, **kwargs
): ):
""" """
Once all dependencies are met, actually process the rule Once all dependencies are met, actually process the rule
""" """
plt.sca(ax) plt.sca(ax)
if self._needs_computation(overwrite, plot_filename): if self._needs_computation(overwrite, plot_filename):
plot_info = rule.plot(save, arg, **kwargs) plot_info = rule.process(arg, **kwargs)
if not self.params.out.interactive and close: if not self.params.out.interactive and close:
plt.tight_layout(pad=1) plt.tight_layout(pad=1)
@@ -506,23 +481,22 @@ class Plotter(Aggregator, BaseProcessor):
label_run = label label_run = label
return label_run return label_run
def _ax_label_unit(self, node, label, unit, unit_coeff, put_units=True): def _ax_label_unit(self, node_name, label, unit, unit_coeff, put_units=True):
""" """
Find appropriate labels for axis Find appropriate labels for axis
""" """
if label is None: if label is None:
if "label" in node._v_attrs: try:
label = node._v_attrs.label label = self.current_processor.get_attribute("node_name", "label")
elif node._v_name in self.label_convert: except KeyError:
label = self.label_convert[node._v_name] if os.path.basename(node_name) in self.label_convert:
elif not node._v_title == "": label = self.label_convert[os.path.basename(node_name)]
label = node._v_title else:
else: label = os.path.basename(node_name)
label = node._v_name
if "unit" in node._v_attrs: try:
unit_old = node._v_attrs.unit unit_old = self.current_processor.get_attribute("node_name", "unit")
else: except KeyError:
unit_old = U.none unit_old = U.none
if unit is None: if unit is None:
@@ -542,7 +516,7 @@ class Plotter(Aggregator, BaseProcessor):
title = self.get_label_run(run, title, nml_key) title = self.get_label_run(run, title, nml_key)
if put_time: if put_time:
time = self.save.root._v_attrs.time * self.study.info["unit_time"] time = self.current_processor.info["time"] * self.study.info["unit_time"]
u_str = unit_str(unit_time, format="{unit}") u_str = unit_str(unit_time, format="{unit}")
time_str = self.params.plot.time_fmt.format(time.express(unit_time), u_str) time_str = self.params.plot.time_fmt.format(time.express(unit_time), u_str)
if len(title) > 0: if len(title) > 0:
@@ -590,26 +564,26 @@ class Plotter(Aggregator, BaseProcessor):
ax_h = self._axes_h[ax_los] ax_h = self._axes_h[ax_los]
ax_v = self._axes_v[ax_los] ax_v = self._axes_v[ax_los]
im_extent = np.array(self.save.root.maps._v_attrs.im_extent) im_extent = np.array(self.current_processor.im_extent)
unit_length = self.save.root._v_attrs["unit_length"] unit_length = self.current_processor.info["unit_length"]
if embeded: if embeded:
axes = False axes = False
scalebar = True scalebar = True
if center_space: if center_space:
center = self.save.root.maps._v_attrs.center center = self.current_processor.get_attribute("/maps", "center")
center_h = center[self._ax_nb[ax_h]] center_h = center[self._ax_nb[ax_h]]
center_v = center[self._ax_nb[ax_v]] center_v = center[self._ax_nb[ax_v]]
im_extent[:2] = im_extent[:2] - center_h im_extent[:2] = im_extent[:2] - center_h
im_extent[2:] = im_extent[2:] - center_v im_extent[2:] = im_extent[2:] - center_v
im_extent = im_extent * unit_length.express(unit_space) im_extent = im_extent * unit_length.express(unit_space)
node = self.save.get_node("/maps/{}_{}".format(name, ax_los)) node_name = f"/maps/{name}_{ax_los}"
dmap = node.read() dmap = self.current_processor.get_value(node_name)
label, unit_old, unit = self._ax_label_unit( label, unit_old, unit = self._ax_label_unit(
node, label, unit, unit_coeff, put_units node_name, label, unit, unit_coeff, put_units
) )
dmap = dmap * unit_old.express(unit) dmap = dmap * unit_old.express(unit)
@@ -736,7 +710,7 @@ class Plotter(Aggregator, BaseProcessor):
""" """
Add an overlay : contour of other map Add an overlay : contour of other map
""" """
map_contour = self.save.get_node("/maps/{}_{}".format(map_name, ax_los)).read() map_contour = self.current_processor.get_value("/maps/{}_{}".format(map_name, ax_los))
if log: if log:
map_contour = np.log10(map_contour) map_contour = np.log10(map_contour)
# Computing linewidths # Computing linewidths
@@ -797,29 +771,27 @@ class Plotter(Aggregator, BaseProcessor):
center_space=False, center_space=False,
parts = True, parts = True,
sinks = False, sinks = False,
**kwargs **kwargs
): ):
""" """
Add an overlay with particles data Add an overlay with particles data
""" """
unit_length = self.current_snap.info["unit_length"] unit_length = self.current_processor.info["unit_length"]
if sinks: if sinks:
self.current_snap.sinks() self.current_processor.sinks()
sinks = pd.DataFrame(self.current_snap.get_value("/datasets/sinks")) sinks = pd.DataFrame(self.current_processor.get_value("/datasets/sinks"))
part_pos = sinks[["x", "y", "z"]].values part_pos = sinks[["x", "y", "z"]].values
mass = sinks.M mass = sinks.M
if parts: if parts:
# Open particle HDF5 filetype_from_ext # Open particle HDF5 filetype_from_ext
filename = self.save.get_node("/hdf5/write_particles").read()[0].decode() self.current_processor.load_parts(keys=["pos", "mass"])
hdf5_parts = tables.open_file(filename, "r") part_pos = self.current_processor.parts.pos
part_pos = hdf5_parts.get_node("/data/pos").read() mass = self.current_processor.parts.mass
mass = hdf5_parts.get_node("/data/mass").read() mass *= self.current_processor.info["unit_mass"].express(U.Msun)
mass *= self.current_snap.info["unit_mass"].express(U.Msun) self.current_processor.unload_parts()
hdf5_parts.close()
# index of the horizontal axis # index of the horizontal axis
ih = self._ax_nb[self._axes_h[ax_los]] ih = self._ax_nb[self._axes_h[ax_los]]
@@ -833,7 +805,7 @@ class Plotter(Aggregator, BaseProcessor):
if center_space: if center_space:
ax_h = self._axes_h[ax_los] ax_h = self._axes_h[ax_los]
ax_v = self._axes_v[ax_los] ax_v = self._axes_v[ax_los]
center = self.save.root.maps._v_attrs.center center = self.current_processor.get_attribute("/maps", "center")
center_h = center[self._ax_nb[ax_h]] center_h = center[self._ax_nb[ax_h]]
center_v = center[self._ax_nb[ax_v]] center_v = center[self._ax_nb[ax_v]]
part_h -= center_h part_h -= center_h
@@ -861,11 +833,10 @@ class Plotter(Aggregator, BaseProcessor):
""" """
Add an overlay : velocity vector field Add an overlay : velocity vector field
""" """
dmap_vh_node = self.save.get_node("/maps/speed_h_{}".format(ax_los)) dmap_vh = self.current_processor.get_value("/maps/speed_h_{}".format(ax_los))
dmap_vh = dmap_vh_node.read() dmap_vv = self.current_processor.get_value("/maps/speed_v_{}".format(ax_los))
dmap_vv = self.save.get_node("/maps/speed_v_{}".format(ax_los)).read()
label, unit_old, unit = self._ax_label_unit(dmap_vh_node, "", unit, unit_coeff) label, unit_old, unit = self._ax_label_unit(f"/maps/speed_h_{ax_los}", "", unit, unit_coeff)
vel_red = self.params.plot.vel_red vel_red = self.params.plot.vel_red
@@ -915,16 +886,15 @@ class Plotter(Aggregator, BaseProcessor):
""" """
Add an overlay : magnetic streamlines Add an overlay : magnetic streamlines
""" """
dmap_Bh_node = self.save.get_node("/maps/B_h_{}".format(ax_los)) dmap_Bh = self.current_processor.get_value(f"/maps/B_h_{ax_los}")
dmap_Bh = dmap_Bh_node.read() dmap_Bv = self.current_processor.get_value(f"/maps/B_v_{ax_los}")
dmap_Bv = self.save.get_node("/maps/B_v_{}".format(ax_los)).read()
# TODO : redo this with im_extent # TODO : redo this with im_extent
vel_red = self.params.plot.vel_red vel_red = self.params.plot.vel_red
radius = self.save.root.maps._v_attrs.radius radius = self.current_processor.attribute("/maps", "radius")
center = self.save.root.maps._v_attrs.center center = self.current_processor.attribute("/maps", "center")
lbox = self.save.root._v_attrs.lbox lbox = self.current_processor.lbox
map_Bh_red = dmap_Bh[::vel_red, ::vel_red] # take only a subset of velocities map_Bh_red = dmap_Bh[::vel_red, ::vel_red] # take only a subset of velocities
map_Bv_red = dmap_Bv[::vel_red, ::vel_red] map_Bv_red = dmap_Bv[::vel_red, ::vel_red]
@@ -973,22 +943,23 @@ class Plotter(Aggregator, BaseProcessor):
# Get node # Get node
if ax_los is not None: if ax_los is not None:
name = name + "_" + ax_los name = name + "_" + ax_los
node = self.save.get_node(group + name) node_name = group + name
if xlog is None: if xlog is None:
try: try:
xlog = node._v_attrs.logbins xlog = self.current_processor.get_attribute(node_name, "logbins")
except AttributeError: except AttributeError:
xlog = False xlog = False
# get label and units # get label and units
xlabel, unit_old, unit = self._ax_label_unit(node, label, unit, unit_coeff) xlabel, unit_old, unit = self._ax_label_unit(node_name, label, unit, unit_coeff)
# Read data # Read data
node = self.current_processor.get_value(node_name)
if "mean" in node: if "mean" in node:
index = node["runs"].read().index(run.encode()) index = node["runs"].index(run.encode())
values, centers = node["mean"].read()[index] values, centers = node["mean"][index]
else: else:
values, centers = node.read() values, centers = node
if xlog: if xlog:
centers = centers + np.log10(unit_old.express(unit)) centers = centers + np.log10(unit_old.express(unit))
else: else:
@@ -1034,16 +1005,16 @@ class Plotter(Aggregator, BaseProcessor):
plt.ylabel(ylabel) plt.ylabel(ylabel)
# Also diplay fit, previously saved # Also diplay fit, previously saved
if ax_los is not None and "/hist/fit_" + name + "_" + ax_los in self.save: # if ax_los is not None and "/hist/fit_" + name + "_" + ax_los in self.save:
slope = node.attrs.slope # slope = node.attrs.slope
origin = node.attrs.origin # origin = node.attrs.origin
plt.plot( # plt.plot(
centers, # centers,
10 ** (slope * centers + origin), # 10 ** (slope * centers + origin),
"--", # "--",
linewidth=2, # linewidth=2,
color="orange", # color="orange",
) # )
# or a new one # or a new one
if fit is not None: if fit is not None:
self._overlay_fit( self._overlay_fit(
@@ -1108,28 +1079,28 @@ class Plotter(Aggregator, BaseProcessor):
name_x, name_y = name_x + "_" + node_arg, name_y + "_" + node_arg name_x, name_y = name_x + "_" + node_arg, name_y + "_" + node_arg
# Get hdf5 nodes # Get hdf5 nodes
node_x = self.save.get_node(name_x) node_x = self.current_processor.get_value(name_x)
node_y = self.save.get_node(name_y) node_y = self.current_processor.get_value(name_y)
# If the actual data is in another file, fetch it # If the actual data is in another file, fetch it
if subname_x: if subname_x:
hdf5_x = tables.open_file(node_x.read()) hdf5_x = tables.open_file(node_x)
node_x = hdf5_x.get_node(subname_x) node_x = hdf5_x.get_node(subname_x).read()
if subname_y: if subname_y:
hdf5_y = tables.open_file(node_y.read()) hdf5_y = tables.open_file(node_y)
node_y = hdf5_y.get_node(subname_y) node_y = hdf5_y.get_node(subname_y).read()
# Find proper labels # Find proper labels
xlabel, xunit_old, xunit = self._ax_label_unit( xlabel, xunit_old, xunit = self._ax_label_unit(
node_x, xlabel, xunit, xunit_coeff name_x, xlabel, xunit, xunit_coeff
) )
ylabel, yunit_old, yunit = self._ax_label_unit( ylabel, yunit_old, yunit = self._ax_label_unit(
node_y, ylabel, yunit, yunit_coeff name_y, ylabel, yunit, yunit_coeff
) )
# If relevent, get time # If relevent, get time
if put_time: if put_time:
time = self.save.root._v_attrs.time * self.study.info["unit_time"] time = self.current_processor.time * self.study.info["unit_time"]
time_str = self.params.plot.time_fmt.format( time_str = self.params.plot.time_fmt.format(
time.express(unit_time), unit_time.latex.replace("text", "math") time.express(unit_time), unit_time.latex.replace("text", "math")
) )
@@ -1141,28 +1112,28 @@ class Plotter(Aggregator, BaseProcessor):
# Manage the different forms in which the data may be stored : # Manage the different forms in which the data may be stored :
# Possibilities are : plain array, dict of arrays (mean, std, ..) or dict of array (runs) # Possibilities are : plain array, dict of arrays (mean, std, ..) or dict of array (runs)
if node_y._v_attrs.CLASS == "ARRAY": if isinstance(node_y, np.ndarray):
x = node_x.read() * xunit_old.express(xunit) x = node_x * xunit_old.express(xunit)
y = node_y.read() * yunit_old.express(yunit) y = node_y * yunit_old.express(yunit)
mask = np.isfinite(x) & np.isfinite(y) mask = np.isfinite(x) & np.isfinite(y)
x, y = x[mask], y[mask] x, y = x[mask], y[mask]
elif "mean" in node_y: elif "mean" in node_y:
x = node_x.read() * xunit_old.express(xunit) x = node_x * xunit_old.express(xunit)
y = node_y.mean.read() * yunit_old.express(yunit) y = node_y.mean * yunit_old.express(yunit)
if yerr_kind == "std": if yerr_kind == "std":
std = node_y.std.read() * yunit_old.express(yunit) std = node_y["std"] * yunit_old.express(yunit)
yerr_min = y - sigma_err * std yerr_min = y - sigma_err * std
yerr_max = y + sigma_err * std yerr_max = y + sigma_err * std
elif yerr_kind == "min_max": elif yerr_kind == "min_max":
yerr_min = node_y.min.read() * yunit_old.express(yunit) yerr_min = node_y["min"] * yunit_old.express(yunit)
yerr_max = node_y.max.read() * yunit_old.express(yunit) yerr_max = node_y["max"] * yunit_old.express(yunit)
elif yerr_kind == "95per": elif yerr_kind == "95per":
yerr_min = node_y.q025.read() * yunit_old.express(yunit) yerr_min = node_y["q025"] * yunit_old.express(yunit)
yerr_max = node_y.q975.read() * yunit_old.express(yunit) yerr_max = node_y["q975"] * yunit_old.express(yunit)
elif yerr_kind == "68per": elif yerr_kind == "68per":
yerr_min = node_y.q16.read() * yunit_old.express(yunit) yerr_min = node_y["q16"] * yunit_old.express(yunit)
yerr_max = node_y.q84.read() * yunit_old.express(yunit) yerr_max = node_y["q84"] * yunit_old.express(yunit)
else: else:
yerr_min = y yerr_min = y
yerr_max = y yerr_max = y
@@ -1176,13 +1147,13 @@ class Plotter(Aggregator, BaseProcessor):
yerr_max[mask], yerr_max[mask],
) )
else: else:
x = node_x[run].read() * xunit_old.express(xunit) x = node_x[run] * xunit_old.express(xunit)
y = node_y[run].read() * yunit_old.express(yunit) y = node_y[run] * yunit_old.express(yunit)
mask = np.isfinite(x) & np.isfinite(y) mask = np.isfinite(x) & np.isfinite(y)
x, y = x[mask], y[mask] x, y = x[mask], y[mask]
if isinstance(yerr, str): if isinstance(yerr, str):
yerr = self.save.get_node(yerr).read() yerr = self.current_processor.get_value(yerr)
# Apply transformations on x # Apply transformations on x
if xtransform is not None: if xtransform is not None:
@@ -1211,7 +1182,7 @@ class Plotter(Aggregator, BaseProcessor):
color = colors[run] color = colors[run]
elif nml_color == "time": elif nml_color == "time":
time = ( time = (
self.save.root._v_attrs.time * self.study.info["unit_time"] self.current_processor.time * self.current_processor.info["unit_time"]
).express(unit_time) ).express(unit_time)
color = colors(time) color = colors(time)
else: else:
@@ -1261,8 +1232,8 @@ class Plotter(Aggregator, BaseProcessor):
Plot power spectrum (wrapper around pspec_read) Plot power spectrum (wrapper around pspec_read)
""" """
del kwargs["run"] del kwargs["run"]
file_pspec = self.save.get_node("/hdf5/pspec").read() file_pspec = self.current_processor.get_value("/hdf5/pspec")
num = self.save.root._v_attrs.num num = self.current_processor.num
getattr(pspec_read, "pspec_" + name)(file_pspec, ".", num, **kwargs) getattr(pspec_read, "pspec_" + name)(file_pspec, ".", num, **kwargs)
def _overlay_fit(self, x, y, yerr=None, kind="linear", label=None, **kwargs): def _overlay_fit(self, x, y, yerr=None, kind="linear", label=None, **kwargs):