[plotter] use getter instead of direct save access

This commit is contained in:
Noe Brucy
2021-07-22 17:48:24 +02:00
parent 4fa72eccc5
commit fa952afdb0
+98 -127
View File
@@ -63,19 +63,6 @@ class PlotRule(Rule):
Add an extra method, plot, that take the reference to an open hdf5 file (from pytables)
"""
def plot(self, save, arg, **kwargs):
"""
Set the plotter's storage to 'save' and execute the rule
Parameters
----------
save : opended pytables hdf5 file, where to find the data
arg : main argument of the plotting function
kargs : optional keyword arguments to the plotting function
"""
self.postproc.save = save
return self.process_fn(arg, **kwargs)
def datafile(self, name, arg):
if arg is not None:
name = name + "_" + str(arg)
@@ -189,8 +176,7 @@ class Plotter(Aggregator, BaseProcessor):
# generate astrophysix's simulations object
self.gen_simus()
self.save = None
self.current_snap = None
self.current_processor = None
def gen_simus(self):
self.simulations = {}
@@ -344,35 +330,24 @@ class Plotter(Aggregator, BaseProcessor):
else:
raise
# Find plot save
if from_cells or rule.kind == "cells":
self.current_snap = self.snaps[run][num]
if not os.path.exists(self.current_snap.cells_filename):
self.snaps[run][num].load_cells()
self.snaps[run][num].unload_cells()
save = tables.open_file(self.snaps[run][num].cells_filename)
elif rule.kind == "snapshot":
self.current_snap = self.snaps[run][num]
save = tables.open_file( self.current_snap.filename)
# Find underlying processor
if rule.kind == "snapshot":
self.current_processor = self.snaps[run][num]
else:
save = tables.open_file(self.study.filename, "r")
self.current_processor = self.study
# Call plot routine
try:
close = (not onefigure) or (i == len(run_num) - 1)
plot_info = self._plot_rule(
rule,
save,
arg,
plot_filename,
overwrite,
ax=real_ax,
close=close,
run=run,
**kwargs,
)
finally:
save.close()
close = (not onefigure) or (i == len(run_num) - 1)
plot_info = self._plot_rule(
rule,
arg,
plot_filename,
overwrite,
ax=real_ax,
close=close,
run=run,
**kwargs,
)
# Save in astrophysix format
df = rule.datafile(name, arg)
@@ -404,14 +379,14 @@ class Plotter(Aggregator, BaseProcessor):
return datafiles
def _plot_rule(
self, rule, save, arg, plot_filename, overwrite, ax, close=True, **kwargs
self, rule, arg, plot_filename, overwrite, ax, close=True, **kwargs
):
"""
Once all dependencies are met, actually process the rule
"""
plt.sca(ax)
if self._needs_computation(overwrite, plot_filename):
plot_info = rule.plot(save, arg, **kwargs)
plot_info = rule.process(arg, **kwargs)
if not self.params.out.interactive and close:
plt.tight_layout(pad=1)
@@ -506,23 +481,22 @@ class Plotter(Aggregator, BaseProcessor):
label_run = label
return label_run
def _ax_label_unit(self, node, label, unit, unit_coeff, put_units=True):
def _ax_label_unit(self, node_name, label, unit, unit_coeff, put_units=True):
"""
Find appropriate labels for axis
"""
if label is None:
if "label" in node._v_attrs:
label = node._v_attrs.label
elif node._v_name in self.label_convert:
label = self.label_convert[node._v_name]
elif not node._v_title == "":
label = node._v_title
else:
label = node._v_name
try:
label = self.current_processor.get_attribute("node_name", "label")
except KeyError:
if os.path.basename(node_name) in self.label_convert:
label = self.label_convert[os.path.basename(node_name)]
else:
label = os.path.basename(node_name)
if "unit" in node._v_attrs:
unit_old = node._v_attrs.unit
else:
try:
unit_old = self.current_processor.get_attribute("node_name", "unit")
except KeyError:
unit_old = U.none
if unit is None:
@@ -542,7 +516,7 @@ class Plotter(Aggregator, BaseProcessor):
title = self.get_label_run(run, title, nml_key)
if put_time:
time = self.save.root._v_attrs.time * self.study.info["unit_time"]
time = self.current_processor.info["time"] * self.study.info["unit_time"]
u_str = unit_str(unit_time, format="{unit}")
time_str = self.params.plot.time_fmt.format(time.express(unit_time), u_str)
if len(title) > 0:
@@ -590,26 +564,26 @@ class Plotter(Aggregator, BaseProcessor):
ax_h = self._axes_h[ax_los]
ax_v = self._axes_v[ax_los]
im_extent = np.array(self.save.root.maps._v_attrs.im_extent)
unit_length = self.save.root._v_attrs["unit_length"]
im_extent = np.array(self.current_processor.im_extent)
unit_length = self.current_processor.info["unit_length"]
if embeded:
axes = False
scalebar = True
if center_space:
center = self.save.root.maps._v_attrs.center
center = self.current_processor.get_attribute("/maps", "center")
center_h = center[self._ax_nb[ax_h]]
center_v = center[self._ax_nb[ax_v]]
im_extent[:2] = im_extent[:2] - center_h
im_extent[2:] = im_extent[2:] - center_v
im_extent = im_extent * unit_length.express(unit_space)
node = self.save.get_node("/maps/{}_{}".format(name, ax_los))
dmap = node.read()
node_name = f"/maps/{name}_{ax_los}"
dmap = self.current_processor.get_value(node_name)
label, unit_old, unit = self._ax_label_unit(
node, label, unit, unit_coeff, put_units
node_name, label, unit, unit_coeff, put_units
)
dmap = dmap * unit_old.express(unit)
@@ -736,7 +710,7 @@ class Plotter(Aggregator, BaseProcessor):
"""
Add an overlay : contour of other map
"""
map_contour = self.save.get_node("/maps/{}_{}".format(map_name, ax_los)).read()
map_contour = self.current_processor.get_value("/maps/{}_{}".format(map_name, ax_los))
if log:
map_contour = np.log10(map_contour)
# Computing linewidths
@@ -797,29 +771,27 @@ class Plotter(Aggregator, BaseProcessor):
center_space=False,
parts = True,
sinks = False,
**kwargs
):
"""
Add an overlay with particles data
"""
unit_length = self.current_snap.info["unit_length"]
unit_length = self.current_processor.info["unit_length"]
if sinks:
self.current_snap.sinks()
sinks = pd.DataFrame(self.current_snap.get_value("/datasets/sinks"))
self.current_processor.sinks()
sinks = pd.DataFrame(self.current_processor.get_value("/datasets/sinks"))
part_pos = sinks[["x", "y", "z"]].values
mass = sinks.M
if parts:
# Open particle HDF5 filetype_from_ext
filename = self.save.get_node("/hdf5/write_particles").read()[0].decode()
hdf5_parts = tables.open_file(filename, "r")
part_pos = hdf5_parts.get_node("/data/pos").read()
mass = hdf5_parts.get_node("/data/mass").read()
mass *= self.current_snap.info["unit_mass"].express(U.Msun)
hdf5_parts.close()
self.current_processor.load_parts(keys=["pos", "mass"])
part_pos = self.current_processor.parts.pos
mass = self.current_processor.parts.mass
mass *= self.current_processor.info["unit_mass"].express(U.Msun)
self.current_processor.unload_parts()
# index of the horizontal axis
ih = self._ax_nb[self._axes_h[ax_los]]
@@ -833,7 +805,7 @@ class Plotter(Aggregator, BaseProcessor):
if center_space:
ax_h = self._axes_h[ax_los]
ax_v = self._axes_v[ax_los]
center = self.save.root.maps._v_attrs.center
center = self.current_processor.get_attribute("/maps", "center")
center_h = center[self._ax_nb[ax_h]]
center_v = center[self._ax_nb[ax_v]]
part_h -= center_h
@@ -861,11 +833,10 @@ class Plotter(Aggregator, BaseProcessor):
"""
Add an overlay : velocity vector field
"""
dmap_vh_node = self.save.get_node("/maps/speed_h_{}".format(ax_los))
dmap_vh = dmap_vh_node.read()
dmap_vv = self.save.get_node("/maps/speed_v_{}".format(ax_los)).read()
dmap_vh = self.current_processor.get_value("/maps/speed_h_{}".format(ax_los))
dmap_vv = self.current_processor.get_value("/maps/speed_v_{}".format(ax_los))
label, unit_old, unit = self._ax_label_unit(dmap_vh_node, "", unit, unit_coeff)
label, unit_old, unit = self._ax_label_unit(f"/maps/speed_h_{ax_los}", "", unit, unit_coeff)
vel_red = self.params.plot.vel_red
@@ -915,16 +886,15 @@ class Plotter(Aggregator, BaseProcessor):
"""
Add an overlay : magnetic streamlines
"""
dmap_Bh_node = self.save.get_node("/maps/B_h_{}".format(ax_los))
dmap_Bh = dmap_Bh_node.read()
dmap_Bv = self.save.get_node("/maps/B_v_{}".format(ax_los)).read()
dmap_Bh = self.current_processor.get_value(f"/maps/B_h_{ax_los}")
dmap_Bv = self.current_processor.get_value(f"/maps/B_v_{ax_los}")
# TODO : redo this with im_extent
vel_red = self.params.plot.vel_red
radius = self.save.root.maps._v_attrs.radius
center = self.save.root.maps._v_attrs.center
lbox = self.save.root._v_attrs.lbox
radius = self.current_processor.attribute("/maps", "radius")
center = self.current_processor.attribute("/maps", "center")
lbox = self.current_processor.lbox
map_Bh_red = dmap_Bh[::vel_red, ::vel_red] # take only a subset of velocities
map_Bv_red = dmap_Bv[::vel_red, ::vel_red]
@@ -973,22 +943,23 @@ class Plotter(Aggregator, BaseProcessor):
# Get node
if ax_los is not None:
name = name + "_" + ax_los
node = self.save.get_node(group + name)
node_name = group + name
if xlog is None:
try:
xlog = node._v_attrs.logbins
xlog = self.current_processor.get_attribute(node_name, "logbins")
except AttributeError:
xlog = False
# get label and units
xlabel, unit_old, unit = self._ax_label_unit(node, label, unit, unit_coeff)
xlabel, unit_old, unit = self._ax_label_unit(node_name, label, unit, unit_coeff)
# Read data
node = self.current_processor.get_value(node_name)
if "mean" in node:
index = node["runs"].read().index(run.encode())
values, centers = node["mean"].read()[index]
index = node["runs"].index(run.encode())
values, centers = node["mean"][index]
else:
values, centers = node.read()
values, centers = node
if xlog:
centers = centers + np.log10(unit_old.express(unit))
else:
@@ -1034,16 +1005,16 @@ class Plotter(Aggregator, BaseProcessor):
plt.ylabel(ylabel)
# Also diplay fit, previously saved
if ax_los is not None and "/hist/fit_" + name + "_" + ax_los in self.save:
slope = node.attrs.slope
origin = node.attrs.origin
plt.plot(
centers,
10 ** (slope * centers + origin),
"--",
linewidth=2,
color="orange",
)
# if ax_los is not None and "/hist/fit_" + name + "_" + ax_los in self.save:
# slope = node.attrs.slope
# origin = node.attrs.origin
# plt.plot(
# centers,
# 10 ** (slope * centers + origin),
# "--",
# linewidth=2,
# color="orange",
# )
# or a new one
if fit is not None:
self._overlay_fit(
@@ -1108,28 +1079,28 @@ class Plotter(Aggregator, BaseProcessor):
name_x, name_y = name_x + "_" + node_arg, name_y + "_" + node_arg
# Get hdf5 nodes
node_x = self.save.get_node(name_x)
node_y = self.save.get_node(name_y)
node_x = self.current_processor.get_value(name_x)
node_y = self.current_processor.get_value(name_y)
# If the actual data is in another file, fetch it
if subname_x:
hdf5_x = tables.open_file(node_x.read())
node_x = hdf5_x.get_node(subname_x)
hdf5_x = tables.open_file(node_x)
node_x = hdf5_x.get_node(subname_x).read()
if subname_y:
hdf5_y = tables.open_file(node_y.read())
node_y = hdf5_y.get_node(subname_y)
hdf5_y = tables.open_file(node_y)
node_y = hdf5_y.get_node(subname_y).read()
# Find proper labels
xlabel, xunit_old, xunit = self._ax_label_unit(
node_x, xlabel, xunit, xunit_coeff
name_x, xlabel, xunit, xunit_coeff
)
ylabel, yunit_old, yunit = self._ax_label_unit(
node_y, ylabel, yunit, yunit_coeff
name_y, ylabel, yunit, yunit_coeff
)
# If relevent, get time
if put_time:
time = self.save.root._v_attrs.time * self.study.info["unit_time"]
time = self.current_processor.time * self.study.info["unit_time"]
time_str = self.params.plot.time_fmt.format(
time.express(unit_time), unit_time.latex.replace("text", "math")
)
@@ -1141,28 +1112,28 @@ class Plotter(Aggregator, BaseProcessor):
# Manage the different forms in which the data may be stored :
# Possibilities are : plain array, dict of arrays (mean, std, ..) or dict of array (runs)
if node_y._v_attrs.CLASS == "ARRAY":
x = node_x.read() * xunit_old.express(xunit)
y = node_y.read() * yunit_old.express(yunit)
if isinstance(node_y, np.ndarray):
x = node_x * xunit_old.express(xunit)
y = node_y * yunit_old.express(yunit)
mask = np.isfinite(x) & np.isfinite(y)
x, y = x[mask], y[mask]
elif "mean" in node_y:
x = node_x.read() * xunit_old.express(xunit)
y = node_y.mean.read() * yunit_old.express(yunit)
x = node_x * xunit_old.express(xunit)
y = node_y.mean * yunit_old.express(yunit)
if yerr_kind == "std":
std = node_y.std.read() * yunit_old.express(yunit)
std = node_y["std"] * yunit_old.express(yunit)
yerr_min = y - sigma_err * std
yerr_max = y + sigma_err * std
elif yerr_kind == "min_max":
yerr_min = node_y.min.read() * yunit_old.express(yunit)
yerr_max = node_y.max.read() * yunit_old.express(yunit)
yerr_min = node_y["min"] * yunit_old.express(yunit)
yerr_max = node_y["max"] * yunit_old.express(yunit)
elif yerr_kind == "95per":
yerr_min = node_y.q025.read() * yunit_old.express(yunit)
yerr_max = node_y.q975.read() * yunit_old.express(yunit)
yerr_min = node_y["q025"] * yunit_old.express(yunit)
yerr_max = node_y["q975"] * yunit_old.express(yunit)
elif yerr_kind == "68per":
yerr_min = node_y.q16.read() * yunit_old.express(yunit)
yerr_max = node_y.q84.read() * yunit_old.express(yunit)
yerr_min = node_y["q16"] * yunit_old.express(yunit)
yerr_max = node_y["q84"] * yunit_old.express(yunit)
else:
yerr_min = y
yerr_max = y
@@ -1176,13 +1147,13 @@ class Plotter(Aggregator, BaseProcessor):
yerr_max[mask],
)
else:
x = node_x[run].read() * xunit_old.express(xunit)
y = node_y[run].read() * yunit_old.express(yunit)
x = node_x[run] * xunit_old.express(xunit)
y = node_y[run] * yunit_old.express(yunit)
mask = np.isfinite(x) & np.isfinite(y)
x, y = x[mask], y[mask]
if isinstance(yerr, str):
yerr = self.save.get_node(yerr).read()
yerr = self.current_processor.get_value(yerr)
# Apply transformations on x
if xtransform is not None:
@@ -1211,7 +1182,7 @@ class Plotter(Aggregator, BaseProcessor):
color = colors[run]
elif nml_color == "time":
time = (
self.save.root._v_attrs.time * self.study.info["unit_time"]
self.current_processor.time * self.current_processor.info["unit_time"]
).express(unit_time)
color = colors(time)
else:
@@ -1261,8 +1232,8 @@ class Plotter(Aggregator, BaseProcessor):
Plot power spectrum (wrapper around pspec_read)
"""
del kwargs["run"]
file_pspec = self.save.get_node("/hdf5/pspec").read()
num = self.save.root._v_attrs.num
file_pspec = self.current_processor.get_value("/hdf5/pspec")
num = self.current_processor.num
getattr(pspec_read, "pspec_" + name)(file_pspec, ".", num, **kwargs)
def _overlay_fit(self, x, y, yerr=None, kind="linear", label=None, **kwargs):