[plotter] use getter instead of direct save access
This commit is contained in:
+98
-127
@@ -63,19 +63,6 @@ class PlotRule(Rule):
|
||||
Add an extra method, plot, that take the reference to an open hdf5 file (from pytables)
|
||||
"""
|
||||
|
||||
def plot(self, save, arg, **kwargs):
|
||||
"""
|
||||
Set the plotter's storage to 'save' and execute the rule
|
||||
|
||||
Parameters
|
||||
----------
|
||||
save : opended pytables hdf5 file, where to find the data
|
||||
arg : main argument of the plotting function
|
||||
kargs : optional keyword arguments to the plotting function
|
||||
"""
|
||||
self.postproc.save = save
|
||||
return self.process_fn(arg, **kwargs)
|
||||
|
||||
def datafile(self, name, arg):
|
||||
if arg is not None:
|
||||
name = name + "_" + str(arg)
|
||||
@@ -189,8 +176,7 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
# generate astrophysix's simulations object
|
||||
self.gen_simus()
|
||||
|
||||
self.save = None
|
||||
self.current_snap = None
|
||||
self.current_processor = None
|
||||
|
||||
def gen_simus(self):
|
||||
self.simulations = {}
|
||||
@@ -344,35 +330,24 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
else:
|
||||
raise
|
||||
|
||||
# Find plot save
|
||||
if from_cells or rule.kind == "cells":
|
||||
self.current_snap = self.snaps[run][num]
|
||||
if not os.path.exists(self.current_snap.cells_filename):
|
||||
self.snaps[run][num].load_cells()
|
||||
self.snaps[run][num].unload_cells()
|
||||
save = tables.open_file(self.snaps[run][num].cells_filename)
|
||||
elif rule.kind == "snapshot":
|
||||
self.current_snap = self.snaps[run][num]
|
||||
save = tables.open_file( self.current_snap.filename)
|
||||
# Find underlying processor
|
||||
if rule.kind == "snapshot":
|
||||
self.current_processor = self.snaps[run][num]
|
||||
else:
|
||||
save = tables.open_file(self.study.filename, "r")
|
||||
self.current_processor = self.study
|
||||
|
||||
# Call plot routine
|
||||
try:
|
||||
close = (not onefigure) or (i == len(run_num) - 1)
|
||||
plot_info = self._plot_rule(
|
||||
rule,
|
||||
save,
|
||||
arg,
|
||||
plot_filename,
|
||||
overwrite,
|
||||
ax=real_ax,
|
||||
close=close,
|
||||
run=run,
|
||||
**kwargs,
|
||||
)
|
||||
finally:
|
||||
save.close()
|
||||
close = (not onefigure) or (i == len(run_num) - 1)
|
||||
plot_info = self._plot_rule(
|
||||
rule,
|
||||
arg,
|
||||
plot_filename,
|
||||
overwrite,
|
||||
ax=real_ax,
|
||||
close=close,
|
||||
run=run,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Save in astrophysix format
|
||||
df = rule.datafile(name, arg)
|
||||
@@ -404,14 +379,14 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
return datafiles
|
||||
|
||||
def _plot_rule(
|
||||
self, rule, save, arg, plot_filename, overwrite, ax, close=True, **kwargs
|
||||
self, rule, arg, plot_filename, overwrite, ax, close=True, **kwargs
|
||||
):
|
||||
"""
|
||||
Once all dependencies are met, actually process the rule
|
||||
"""
|
||||
plt.sca(ax)
|
||||
if self._needs_computation(overwrite, plot_filename):
|
||||
plot_info = rule.plot(save, arg, **kwargs)
|
||||
plot_info = rule.process(arg, **kwargs)
|
||||
|
||||
if not self.params.out.interactive and close:
|
||||
plt.tight_layout(pad=1)
|
||||
@@ -506,23 +481,22 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
label_run = label
|
||||
return label_run
|
||||
|
||||
def _ax_label_unit(self, node, label, unit, unit_coeff, put_units=True):
|
||||
def _ax_label_unit(self, node_name, label, unit, unit_coeff, put_units=True):
|
||||
"""
|
||||
Find appropriate labels for axis
|
||||
"""
|
||||
if label is None:
|
||||
if "label" in node._v_attrs:
|
||||
label = node._v_attrs.label
|
||||
elif node._v_name in self.label_convert:
|
||||
label = self.label_convert[node._v_name]
|
||||
elif not node._v_title == "":
|
||||
label = node._v_title
|
||||
else:
|
||||
label = node._v_name
|
||||
try:
|
||||
label = self.current_processor.get_attribute("node_name", "label")
|
||||
except KeyError:
|
||||
if os.path.basename(node_name) in self.label_convert:
|
||||
label = self.label_convert[os.path.basename(node_name)]
|
||||
else:
|
||||
label = os.path.basename(node_name)
|
||||
|
||||
if "unit" in node._v_attrs:
|
||||
unit_old = node._v_attrs.unit
|
||||
else:
|
||||
try:
|
||||
unit_old = self.current_processor.get_attribute("node_name", "unit")
|
||||
except KeyError:
|
||||
unit_old = U.none
|
||||
|
||||
if unit is None:
|
||||
@@ -542,7 +516,7 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
title = self.get_label_run(run, title, nml_key)
|
||||
|
||||
if put_time:
|
||||
time = self.save.root._v_attrs.time * self.study.info["unit_time"]
|
||||
time = self.current_processor.info["time"] * self.study.info["unit_time"]
|
||||
u_str = unit_str(unit_time, format="{unit}")
|
||||
time_str = self.params.plot.time_fmt.format(time.express(unit_time), u_str)
|
||||
if len(title) > 0:
|
||||
@@ -590,26 +564,26 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
ax_h = self._axes_h[ax_los]
|
||||
ax_v = self._axes_v[ax_los]
|
||||
|
||||
im_extent = np.array(self.save.root.maps._v_attrs.im_extent)
|
||||
unit_length = self.save.root._v_attrs["unit_length"]
|
||||
im_extent = np.array(self.current_processor.im_extent)
|
||||
unit_length = self.current_processor.info["unit_length"]
|
||||
|
||||
if embeded:
|
||||
axes = False
|
||||
scalebar = True
|
||||
|
||||
if center_space:
|
||||
center = self.save.root.maps._v_attrs.center
|
||||
center = self.current_processor.get_attribute("/maps", "center")
|
||||
center_h = center[self._ax_nb[ax_h]]
|
||||
center_v = center[self._ax_nb[ax_v]]
|
||||
im_extent[:2] = im_extent[:2] - center_h
|
||||
im_extent[2:] = im_extent[2:] - center_v
|
||||
im_extent = im_extent * unit_length.express(unit_space)
|
||||
|
||||
node = self.save.get_node("/maps/{}_{}".format(name, ax_los))
|
||||
dmap = node.read()
|
||||
node_name = f"/maps/{name}_{ax_los}"
|
||||
dmap = self.current_processor.get_value(node_name)
|
||||
|
||||
label, unit_old, unit = self._ax_label_unit(
|
||||
node, label, unit, unit_coeff, put_units
|
||||
node_name, label, unit, unit_coeff, put_units
|
||||
)
|
||||
|
||||
dmap = dmap * unit_old.express(unit)
|
||||
@@ -736,7 +710,7 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
"""
|
||||
Add an overlay : contour of other map
|
||||
"""
|
||||
map_contour = self.save.get_node("/maps/{}_{}".format(map_name, ax_los)).read()
|
||||
map_contour = self.current_processor.get_value("/maps/{}_{}".format(map_name, ax_los))
|
||||
if log:
|
||||
map_contour = np.log10(map_contour)
|
||||
# Computing linewidths
|
||||
@@ -797,29 +771,27 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
center_space=False,
|
||||
parts = True,
|
||||
sinks = False,
|
||||
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
Add an overlay with particles data
|
||||
"""
|
||||
|
||||
unit_length = self.current_snap.info["unit_length"]
|
||||
unit_length = self.current_processor.info["unit_length"]
|
||||
|
||||
if sinks:
|
||||
self.current_snap.sinks()
|
||||
sinks = pd.DataFrame(self.current_snap.get_value("/datasets/sinks"))
|
||||
self.current_processor.sinks()
|
||||
sinks = pd.DataFrame(self.current_processor.get_value("/datasets/sinks"))
|
||||
part_pos = sinks[["x", "y", "z"]].values
|
||||
mass = sinks.M
|
||||
|
||||
if parts:
|
||||
# Open particle HDF5 filetype_from_ext
|
||||
filename = self.save.get_node("/hdf5/write_particles").read()[0].decode()
|
||||
hdf5_parts = tables.open_file(filename, "r")
|
||||
part_pos = hdf5_parts.get_node("/data/pos").read()
|
||||
mass = hdf5_parts.get_node("/data/mass").read()
|
||||
mass *= self.current_snap.info["unit_mass"].express(U.Msun)
|
||||
hdf5_parts.close()
|
||||
self.current_processor.load_parts(keys=["pos", "mass"])
|
||||
part_pos = self.current_processor.parts.pos
|
||||
mass = self.current_processor.parts.mass
|
||||
mass *= self.current_processor.info["unit_mass"].express(U.Msun)
|
||||
self.current_processor.unload_parts()
|
||||
|
||||
# index of the horizontal axis
|
||||
ih = self._ax_nb[self._axes_h[ax_los]]
|
||||
@@ -833,7 +805,7 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
if center_space:
|
||||
ax_h = self._axes_h[ax_los]
|
||||
ax_v = self._axes_v[ax_los]
|
||||
center = self.save.root.maps._v_attrs.center
|
||||
center = self.current_processor.get_attribute("/maps", "center")
|
||||
center_h = center[self._ax_nb[ax_h]]
|
||||
center_v = center[self._ax_nb[ax_v]]
|
||||
part_h -= center_h
|
||||
@@ -861,11 +833,10 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
"""
|
||||
Add an overlay : velocity vector field
|
||||
"""
|
||||
dmap_vh_node = self.save.get_node("/maps/speed_h_{}".format(ax_los))
|
||||
dmap_vh = dmap_vh_node.read()
|
||||
dmap_vv = self.save.get_node("/maps/speed_v_{}".format(ax_los)).read()
|
||||
dmap_vh = self.current_processor.get_value("/maps/speed_h_{}".format(ax_los))
|
||||
dmap_vv = self.current_processor.get_value("/maps/speed_v_{}".format(ax_los))
|
||||
|
||||
label, unit_old, unit = self._ax_label_unit(dmap_vh_node, "", unit, unit_coeff)
|
||||
label, unit_old, unit = self._ax_label_unit(f"/maps/speed_h_{ax_los}", "", unit, unit_coeff)
|
||||
|
||||
vel_red = self.params.plot.vel_red
|
||||
|
||||
@@ -915,16 +886,15 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
"""
|
||||
Add an overlay : magnetic streamlines
|
||||
"""
|
||||
dmap_Bh_node = self.save.get_node("/maps/B_h_{}".format(ax_los))
|
||||
dmap_Bh = dmap_Bh_node.read()
|
||||
dmap_Bv = self.save.get_node("/maps/B_v_{}".format(ax_los)).read()
|
||||
dmap_Bh = self.current_processor.get_value(f"/maps/B_h_{ax_los}")
|
||||
dmap_Bv = self.current_processor.get_value(f"/maps/B_v_{ax_los}")
|
||||
|
||||
# TODO : redo this with im_extent
|
||||
|
||||
vel_red = self.params.plot.vel_red
|
||||
radius = self.save.root.maps._v_attrs.radius
|
||||
center = self.save.root.maps._v_attrs.center
|
||||
lbox = self.save.root._v_attrs.lbox
|
||||
radius = self.current_processor.attribute("/maps", "radius")
|
||||
center = self.current_processor.attribute("/maps", "center")
|
||||
lbox = self.current_processor.lbox
|
||||
|
||||
map_Bh_red = dmap_Bh[::vel_red, ::vel_red] # take only a subset of velocities
|
||||
map_Bv_red = dmap_Bv[::vel_red, ::vel_red]
|
||||
@@ -973,22 +943,23 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
# Get node
|
||||
if ax_los is not None:
|
||||
name = name + "_" + ax_los
|
||||
node = self.save.get_node(group + name)
|
||||
node_name = group + name
|
||||
if xlog is None:
|
||||
try:
|
||||
xlog = node._v_attrs.logbins
|
||||
xlog = self.current_processor.get_attribute(node_name, "logbins")
|
||||
except AttributeError:
|
||||
xlog = False
|
||||
|
||||
# get label and units
|
||||
xlabel, unit_old, unit = self._ax_label_unit(node, label, unit, unit_coeff)
|
||||
xlabel, unit_old, unit = self._ax_label_unit(node_name, label, unit, unit_coeff)
|
||||
|
||||
# Read data
|
||||
node = self.current_processor.get_value(node_name)
|
||||
if "mean" in node:
|
||||
index = node["runs"].read().index(run.encode())
|
||||
values, centers = node["mean"].read()[index]
|
||||
index = node["runs"].index(run.encode())
|
||||
values, centers = node["mean"][index]
|
||||
else:
|
||||
values, centers = node.read()
|
||||
values, centers = node
|
||||
if xlog:
|
||||
centers = centers + np.log10(unit_old.express(unit))
|
||||
else:
|
||||
@@ -1034,16 +1005,16 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
plt.ylabel(ylabel)
|
||||
|
||||
# Also diplay fit, previously saved
|
||||
if ax_los is not None and "/hist/fit_" + name + "_" + ax_los in self.save:
|
||||
slope = node.attrs.slope
|
||||
origin = node.attrs.origin
|
||||
plt.plot(
|
||||
centers,
|
||||
10 ** (slope * centers + origin),
|
||||
"--",
|
||||
linewidth=2,
|
||||
color="orange",
|
||||
)
|
||||
# if ax_los is not None and "/hist/fit_" + name + "_" + ax_los in self.save:
|
||||
# slope = node.attrs.slope
|
||||
# origin = node.attrs.origin
|
||||
# plt.plot(
|
||||
# centers,
|
||||
# 10 ** (slope * centers + origin),
|
||||
# "--",
|
||||
# linewidth=2,
|
||||
# color="orange",
|
||||
# )
|
||||
# or a new one
|
||||
if fit is not None:
|
||||
self._overlay_fit(
|
||||
@@ -1108,28 +1079,28 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
name_x, name_y = name_x + "_" + node_arg, name_y + "_" + node_arg
|
||||
|
||||
# Get hdf5 nodes
|
||||
node_x = self.save.get_node(name_x)
|
||||
node_y = self.save.get_node(name_y)
|
||||
node_x = self.current_processor.get_value(name_x)
|
||||
node_y = self.current_processor.get_value(name_y)
|
||||
|
||||
# If the actual data is in another file, fetch it
|
||||
if subname_x:
|
||||
hdf5_x = tables.open_file(node_x.read())
|
||||
node_x = hdf5_x.get_node(subname_x)
|
||||
hdf5_x = tables.open_file(node_x)
|
||||
node_x = hdf5_x.get_node(subname_x).read()
|
||||
if subname_y:
|
||||
hdf5_y = tables.open_file(node_y.read())
|
||||
node_y = hdf5_y.get_node(subname_y)
|
||||
hdf5_y = tables.open_file(node_y)
|
||||
node_y = hdf5_y.get_node(subname_y).read()
|
||||
|
||||
# Find proper labels
|
||||
xlabel, xunit_old, xunit = self._ax_label_unit(
|
||||
node_x, xlabel, xunit, xunit_coeff
|
||||
name_x, xlabel, xunit, xunit_coeff
|
||||
)
|
||||
ylabel, yunit_old, yunit = self._ax_label_unit(
|
||||
node_y, ylabel, yunit, yunit_coeff
|
||||
name_y, ylabel, yunit, yunit_coeff
|
||||
)
|
||||
|
||||
# If relevent, get time
|
||||
if put_time:
|
||||
time = self.save.root._v_attrs.time * self.study.info["unit_time"]
|
||||
time = self.current_processor.time * self.study.info["unit_time"]
|
||||
time_str = self.params.plot.time_fmt.format(
|
||||
time.express(unit_time), unit_time.latex.replace("text", "math")
|
||||
)
|
||||
@@ -1141,28 +1112,28 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
|
||||
# Manage the different forms in which the data may be stored :
|
||||
# Possibilities are : plain array, dict of arrays (mean, std, ..) or dict of array (runs)
|
||||
if node_y._v_attrs.CLASS == "ARRAY":
|
||||
x = node_x.read() * xunit_old.express(xunit)
|
||||
y = node_y.read() * yunit_old.express(yunit)
|
||||
if isinstance(node_y, np.ndarray):
|
||||
x = node_x * xunit_old.express(xunit)
|
||||
y = node_y * yunit_old.express(yunit)
|
||||
mask = np.isfinite(x) & np.isfinite(y)
|
||||
x, y = x[mask], y[mask]
|
||||
elif "mean" in node_y:
|
||||
x = node_x.read() * xunit_old.express(xunit)
|
||||
y = node_y.mean.read() * yunit_old.express(yunit)
|
||||
x = node_x * xunit_old.express(xunit)
|
||||
y = node_y.mean * yunit_old.express(yunit)
|
||||
|
||||
if yerr_kind == "std":
|
||||
std = node_y.std.read() * yunit_old.express(yunit)
|
||||
std = node_y["std"] * yunit_old.express(yunit)
|
||||
yerr_min = y - sigma_err * std
|
||||
yerr_max = y + sigma_err * std
|
||||
elif yerr_kind == "min_max":
|
||||
yerr_min = node_y.min.read() * yunit_old.express(yunit)
|
||||
yerr_max = node_y.max.read() * yunit_old.express(yunit)
|
||||
yerr_min = node_y["min"] * yunit_old.express(yunit)
|
||||
yerr_max = node_y["max"] * yunit_old.express(yunit)
|
||||
elif yerr_kind == "95per":
|
||||
yerr_min = node_y.q025.read() * yunit_old.express(yunit)
|
||||
yerr_max = node_y.q975.read() * yunit_old.express(yunit)
|
||||
yerr_min = node_y["q025"] * yunit_old.express(yunit)
|
||||
yerr_max = node_y["q975"] * yunit_old.express(yunit)
|
||||
elif yerr_kind == "68per":
|
||||
yerr_min = node_y.q16.read() * yunit_old.express(yunit)
|
||||
yerr_max = node_y.q84.read() * yunit_old.express(yunit)
|
||||
yerr_min = node_y["q16"] * yunit_old.express(yunit)
|
||||
yerr_max = node_y["q84"] * yunit_old.express(yunit)
|
||||
else:
|
||||
yerr_min = y
|
||||
yerr_max = y
|
||||
@@ -1176,13 +1147,13 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
yerr_max[mask],
|
||||
)
|
||||
else:
|
||||
x = node_x[run].read() * xunit_old.express(xunit)
|
||||
y = node_y[run].read() * yunit_old.express(yunit)
|
||||
x = node_x[run] * xunit_old.express(xunit)
|
||||
y = node_y[run] * yunit_old.express(yunit)
|
||||
mask = np.isfinite(x) & np.isfinite(y)
|
||||
x, y = x[mask], y[mask]
|
||||
|
||||
if isinstance(yerr, str):
|
||||
yerr = self.save.get_node(yerr).read()
|
||||
yerr = self.current_processor.get_value(yerr)
|
||||
|
||||
# Apply transformations on x
|
||||
if xtransform is not None:
|
||||
@@ -1211,7 +1182,7 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
color = colors[run]
|
||||
elif nml_color == "time":
|
||||
time = (
|
||||
self.save.root._v_attrs.time * self.study.info["unit_time"]
|
||||
self.current_processor.time * self.current_processor.info["unit_time"]
|
||||
).express(unit_time)
|
||||
color = colors(time)
|
||||
else:
|
||||
@@ -1261,8 +1232,8 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
Plot power spectrum (wrapper around pspec_read)
|
||||
"""
|
||||
del kwargs["run"]
|
||||
file_pspec = self.save.get_node("/hdf5/pspec").read()
|
||||
num = self.save.root._v_attrs.num
|
||||
file_pspec = self.current_processor.get_value("/hdf5/pspec")
|
||||
num = self.current_processor.num
|
||||
getattr(pspec_read, "pspec_" + name)(file_pspec, ".", num, **kwargs)
|
||||
|
||||
def _overlay_fit(self, x, y, yerr=None, kind="linear", label=None, **kwargs):
|
||||
|
||||
Reference in New Issue
Block a user