[plotter] use getter instead of direct save access
This commit is contained in:
+98
-127
@@ -63,19 +63,6 @@ class PlotRule(Rule):
|
|||||||
Add an extra method, plot, that take the reference to an open hdf5 file (from pytables)
|
Add an extra method, plot, that take the reference to an open hdf5 file (from pytables)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def plot(self, save, arg, **kwargs):
|
|
||||||
"""
|
|
||||||
Set the plotter's storage to 'save' and execute the rule
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
save : opended pytables hdf5 file, where to find the data
|
|
||||||
arg : main argument of the plotting function
|
|
||||||
kargs : optional keyword arguments to the plotting function
|
|
||||||
"""
|
|
||||||
self.postproc.save = save
|
|
||||||
return self.process_fn(arg, **kwargs)
|
|
||||||
|
|
||||||
def datafile(self, name, arg):
|
def datafile(self, name, arg):
|
||||||
if arg is not None:
|
if arg is not None:
|
||||||
name = name + "_" + str(arg)
|
name = name + "_" + str(arg)
|
||||||
@@ -189,8 +176,7 @@ class Plotter(Aggregator, BaseProcessor):
|
|||||||
# generate astrophysix's simulations object
|
# generate astrophysix's simulations object
|
||||||
self.gen_simus()
|
self.gen_simus()
|
||||||
|
|
||||||
self.save = None
|
self.current_processor = None
|
||||||
self.current_snap = None
|
|
||||||
|
|
||||||
def gen_simus(self):
|
def gen_simus(self):
|
||||||
self.simulations = {}
|
self.simulations = {}
|
||||||
@@ -344,35 +330,24 @@ class Plotter(Aggregator, BaseProcessor):
|
|||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
# Find plot save
|
# Find underlying processor
|
||||||
if from_cells or rule.kind == "cells":
|
if rule.kind == "snapshot":
|
||||||
self.current_snap = self.snaps[run][num]
|
self.current_processor = self.snaps[run][num]
|
||||||
if not os.path.exists(self.current_snap.cells_filename):
|
|
||||||
self.snaps[run][num].load_cells()
|
|
||||||
self.snaps[run][num].unload_cells()
|
|
||||||
save = tables.open_file(self.snaps[run][num].cells_filename)
|
|
||||||
elif rule.kind == "snapshot":
|
|
||||||
self.current_snap = self.snaps[run][num]
|
|
||||||
save = tables.open_file( self.current_snap.filename)
|
|
||||||
else:
|
else:
|
||||||
save = tables.open_file(self.study.filename, "r")
|
self.current_processor = self.study
|
||||||
|
|
||||||
# Call plot routine
|
# Call plot routine
|
||||||
try:
|
close = (not onefigure) or (i == len(run_num) - 1)
|
||||||
close = (not onefigure) or (i == len(run_num) - 1)
|
plot_info = self._plot_rule(
|
||||||
plot_info = self._plot_rule(
|
rule,
|
||||||
rule,
|
arg,
|
||||||
save,
|
plot_filename,
|
||||||
arg,
|
overwrite,
|
||||||
plot_filename,
|
ax=real_ax,
|
||||||
overwrite,
|
close=close,
|
||||||
ax=real_ax,
|
run=run,
|
||||||
close=close,
|
**kwargs,
|
||||||
run=run,
|
)
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
save.close()
|
|
||||||
|
|
||||||
# Save in astrophysix format
|
# Save in astrophysix format
|
||||||
df = rule.datafile(name, arg)
|
df = rule.datafile(name, arg)
|
||||||
@@ -404,14 +379,14 @@ class Plotter(Aggregator, BaseProcessor):
|
|||||||
return datafiles
|
return datafiles
|
||||||
|
|
||||||
def _plot_rule(
|
def _plot_rule(
|
||||||
self, rule, save, arg, plot_filename, overwrite, ax, close=True, **kwargs
|
self, rule, arg, plot_filename, overwrite, ax, close=True, **kwargs
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Once all dependencies are met, actually process the rule
|
Once all dependencies are met, actually process the rule
|
||||||
"""
|
"""
|
||||||
plt.sca(ax)
|
plt.sca(ax)
|
||||||
if self._needs_computation(overwrite, plot_filename):
|
if self._needs_computation(overwrite, plot_filename):
|
||||||
plot_info = rule.plot(save, arg, **kwargs)
|
plot_info = rule.process(arg, **kwargs)
|
||||||
|
|
||||||
if not self.params.out.interactive and close:
|
if not self.params.out.interactive and close:
|
||||||
plt.tight_layout(pad=1)
|
plt.tight_layout(pad=1)
|
||||||
@@ -506,23 +481,22 @@ class Plotter(Aggregator, BaseProcessor):
|
|||||||
label_run = label
|
label_run = label
|
||||||
return label_run
|
return label_run
|
||||||
|
|
||||||
def _ax_label_unit(self, node, label, unit, unit_coeff, put_units=True):
|
def _ax_label_unit(self, node_name, label, unit, unit_coeff, put_units=True):
|
||||||
"""
|
"""
|
||||||
Find appropriate labels for axis
|
Find appropriate labels for axis
|
||||||
"""
|
"""
|
||||||
if label is None:
|
if label is None:
|
||||||
if "label" in node._v_attrs:
|
try:
|
||||||
label = node._v_attrs.label
|
label = self.current_processor.get_attribute("node_name", "label")
|
||||||
elif node._v_name in self.label_convert:
|
except KeyError:
|
||||||
label = self.label_convert[node._v_name]
|
if os.path.basename(node_name) in self.label_convert:
|
||||||
elif not node._v_title == "":
|
label = self.label_convert[os.path.basename(node_name)]
|
||||||
label = node._v_title
|
else:
|
||||||
else:
|
label = os.path.basename(node_name)
|
||||||
label = node._v_name
|
|
||||||
|
|
||||||
if "unit" in node._v_attrs:
|
try:
|
||||||
unit_old = node._v_attrs.unit
|
unit_old = self.current_processor.get_attribute("node_name", "unit")
|
||||||
else:
|
except KeyError:
|
||||||
unit_old = U.none
|
unit_old = U.none
|
||||||
|
|
||||||
if unit is None:
|
if unit is None:
|
||||||
@@ -542,7 +516,7 @@ class Plotter(Aggregator, BaseProcessor):
|
|||||||
title = self.get_label_run(run, title, nml_key)
|
title = self.get_label_run(run, title, nml_key)
|
||||||
|
|
||||||
if put_time:
|
if put_time:
|
||||||
time = self.save.root._v_attrs.time * self.study.info["unit_time"]
|
time = self.current_processor.info["time"] * self.study.info["unit_time"]
|
||||||
u_str = unit_str(unit_time, format="{unit}")
|
u_str = unit_str(unit_time, format="{unit}")
|
||||||
time_str = self.params.plot.time_fmt.format(time.express(unit_time), u_str)
|
time_str = self.params.plot.time_fmt.format(time.express(unit_time), u_str)
|
||||||
if len(title) > 0:
|
if len(title) > 0:
|
||||||
@@ -590,26 +564,26 @@ class Plotter(Aggregator, BaseProcessor):
|
|||||||
ax_h = self._axes_h[ax_los]
|
ax_h = self._axes_h[ax_los]
|
||||||
ax_v = self._axes_v[ax_los]
|
ax_v = self._axes_v[ax_los]
|
||||||
|
|
||||||
im_extent = np.array(self.save.root.maps._v_attrs.im_extent)
|
im_extent = np.array(self.current_processor.im_extent)
|
||||||
unit_length = self.save.root._v_attrs["unit_length"]
|
unit_length = self.current_processor.info["unit_length"]
|
||||||
|
|
||||||
if embeded:
|
if embeded:
|
||||||
axes = False
|
axes = False
|
||||||
scalebar = True
|
scalebar = True
|
||||||
|
|
||||||
if center_space:
|
if center_space:
|
||||||
center = self.save.root.maps._v_attrs.center
|
center = self.current_processor.get_attribute("/maps", "center")
|
||||||
center_h = center[self._ax_nb[ax_h]]
|
center_h = center[self._ax_nb[ax_h]]
|
||||||
center_v = center[self._ax_nb[ax_v]]
|
center_v = center[self._ax_nb[ax_v]]
|
||||||
im_extent[:2] = im_extent[:2] - center_h
|
im_extent[:2] = im_extent[:2] - center_h
|
||||||
im_extent[2:] = im_extent[2:] - center_v
|
im_extent[2:] = im_extent[2:] - center_v
|
||||||
im_extent = im_extent * unit_length.express(unit_space)
|
im_extent = im_extent * unit_length.express(unit_space)
|
||||||
|
|
||||||
node = self.save.get_node("/maps/{}_{}".format(name, ax_los))
|
node_name = f"/maps/{name}_{ax_los}"
|
||||||
dmap = node.read()
|
dmap = self.current_processor.get_value(node_name)
|
||||||
|
|
||||||
label, unit_old, unit = self._ax_label_unit(
|
label, unit_old, unit = self._ax_label_unit(
|
||||||
node, label, unit, unit_coeff, put_units
|
node_name, label, unit, unit_coeff, put_units
|
||||||
)
|
)
|
||||||
|
|
||||||
dmap = dmap * unit_old.express(unit)
|
dmap = dmap * unit_old.express(unit)
|
||||||
@@ -736,7 +710,7 @@ class Plotter(Aggregator, BaseProcessor):
|
|||||||
"""
|
"""
|
||||||
Add an overlay : contour of other map
|
Add an overlay : contour of other map
|
||||||
"""
|
"""
|
||||||
map_contour = self.save.get_node("/maps/{}_{}".format(map_name, ax_los)).read()
|
map_contour = self.current_processor.get_value("/maps/{}_{}".format(map_name, ax_los))
|
||||||
if log:
|
if log:
|
||||||
map_contour = np.log10(map_contour)
|
map_contour = np.log10(map_contour)
|
||||||
# Computing linewidths
|
# Computing linewidths
|
||||||
@@ -797,29 +771,27 @@ class Plotter(Aggregator, BaseProcessor):
|
|||||||
center_space=False,
|
center_space=False,
|
||||||
parts = True,
|
parts = True,
|
||||||
sinks = False,
|
sinks = False,
|
||||||
|
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Add an overlay with particles data
|
Add an overlay with particles data
|
||||||
"""
|
"""
|
||||||
|
|
||||||
unit_length = self.current_snap.info["unit_length"]
|
unit_length = self.current_processor.info["unit_length"]
|
||||||
|
|
||||||
if sinks:
|
if sinks:
|
||||||
self.current_snap.sinks()
|
self.current_processor.sinks()
|
||||||
sinks = pd.DataFrame(self.current_snap.get_value("/datasets/sinks"))
|
sinks = pd.DataFrame(self.current_processor.get_value("/datasets/sinks"))
|
||||||
part_pos = sinks[["x", "y", "z"]].values
|
part_pos = sinks[["x", "y", "z"]].values
|
||||||
mass = sinks.M
|
mass = sinks.M
|
||||||
|
|
||||||
if parts:
|
if parts:
|
||||||
# Open particle HDF5 filetype_from_ext
|
# Open particle HDF5 filetype_from_ext
|
||||||
filename = self.save.get_node("/hdf5/write_particles").read()[0].decode()
|
self.current_processor.load_parts(keys=["pos", "mass"])
|
||||||
hdf5_parts = tables.open_file(filename, "r")
|
part_pos = self.current_processor.parts.pos
|
||||||
part_pos = hdf5_parts.get_node("/data/pos").read()
|
mass = self.current_processor.parts.mass
|
||||||
mass = hdf5_parts.get_node("/data/mass").read()
|
mass *= self.current_processor.info["unit_mass"].express(U.Msun)
|
||||||
mass *= self.current_snap.info["unit_mass"].express(U.Msun)
|
self.current_processor.unload_parts()
|
||||||
hdf5_parts.close()
|
|
||||||
|
|
||||||
# index of the horizontal axis
|
# index of the horizontal axis
|
||||||
ih = self._ax_nb[self._axes_h[ax_los]]
|
ih = self._ax_nb[self._axes_h[ax_los]]
|
||||||
@@ -833,7 +805,7 @@ class Plotter(Aggregator, BaseProcessor):
|
|||||||
if center_space:
|
if center_space:
|
||||||
ax_h = self._axes_h[ax_los]
|
ax_h = self._axes_h[ax_los]
|
||||||
ax_v = self._axes_v[ax_los]
|
ax_v = self._axes_v[ax_los]
|
||||||
center = self.save.root.maps._v_attrs.center
|
center = self.current_processor.get_attribute("/maps", "center")
|
||||||
center_h = center[self._ax_nb[ax_h]]
|
center_h = center[self._ax_nb[ax_h]]
|
||||||
center_v = center[self._ax_nb[ax_v]]
|
center_v = center[self._ax_nb[ax_v]]
|
||||||
part_h -= center_h
|
part_h -= center_h
|
||||||
@@ -861,11 +833,10 @@ class Plotter(Aggregator, BaseProcessor):
|
|||||||
"""
|
"""
|
||||||
Add an overlay : velocity vector field
|
Add an overlay : velocity vector field
|
||||||
"""
|
"""
|
||||||
dmap_vh_node = self.save.get_node("/maps/speed_h_{}".format(ax_los))
|
dmap_vh = self.current_processor.get_value("/maps/speed_h_{}".format(ax_los))
|
||||||
dmap_vh = dmap_vh_node.read()
|
dmap_vv = self.current_processor.get_value("/maps/speed_v_{}".format(ax_los))
|
||||||
dmap_vv = self.save.get_node("/maps/speed_v_{}".format(ax_los)).read()
|
|
||||||
|
|
||||||
label, unit_old, unit = self._ax_label_unit(dmap_vh_node, "", unit, unit_coeff)
|
label, unit_old, unit = self._ax_label_unit(f"/maps/speed_h_{ax_los}", "", unit, unit_coeff)
|
||||||
|
|
||||||
vel_red = self.params.plot.vel_red
|
vel_red = self.params.plot.vel_red
|
||||||
|
|
||||||
@@ -915,16 +886,15 @@ class Plotter(Aggregator, BaseProcessor):
|
|||||||
"""
|
"""
|
||||||
Add an overlay : magnetic streamlines
|
Add an overlay : magnetic streamlines
|
||||||
"""
|
"""
|
||||||
dmap_Bh_node = self.save.get_node("/maps/B_h_{}".format(ax_los))
|
dmap_Bh = self.current_processor.get_value(f"/maps/B_h_{ax_los}")
|
||||||
dmap_Bh = dmap_Bh_node.read()
|
dmap_Bv = self.current_processor.get_value(f"/maps/B_v_{ax_los}")
|
||||||
dmap_Bv = self.save.get_node("/maps/B_v_{}".format(ax_los)).read()
|
|
||||||
|
|
||||||
# TODO : redo this with im_extent
|
# TODO : redo this with im_extent
|
||||||
|
|
||||||
vel_red = self.params.plot.vel_red
|
vel_red = self.params.plot.vel_red
|
||||||
radius = self.save.root.maps._v_attrs.radius
|
radius = self.current_processor.attribute("/maps", "radius")
|
||||||
center = self.save.root.maps._v_attrs.center
|
center = self.current_processor.attribute("/maps", "center")
|
||||||
lbox = self.save.root._v_attrs.lbox
|
lbox = self.current_processor.lbox
|
||||||
|
|
||||||
map_Bh_red = dmap_Bh[::vel_red, ::vel_red] # take only a subset of velocities
|
map_Bh_red = dmap_Bh[::vel_red, ::vel_red] # take only a subset of velocities
|
||||||
map_Bv_red = dmap_Bv[::vel_red, ::vel_red]
|
map_Bv_red = dmap_Bv[::vel_red, ::vel_red]
|
||||||
@@ -973,22 +943,23 @@ class Plotter(Aggregator, BaseProcessor):
|
|||||||
# Get node
|
# Get node
|
||||||
if ax_los is not None:
|
if ax_los is not None:
|
||||||
name = name + "_" + ax_los
|
name = name + "_" + ax_los
|
||||||
node = self.save.get_node(group + name)
|
node_name = group + name
|
||||||
if xlog is None:
|
if xlog is None:
|
||||||
try:
|
try:
|
||||||
xlog = node._v_attrs.logbins
|
xlog = self.current_processor.get_attribute(node_name, "logbins")
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
xlog = False
|
xlog = False
|
||||||
|
|
||||||
# get label and units
|
# get label and units
|
||||||
xlabel, unit_old, unit = self._ax_label_unit(node, label, unit, unit_coeff)
|
xlabel, unit_old, unit = self._ax_label_unit(node_name, label, unit, unit_coeff)
|
||||||
|
|
||||||
# Read data
|
# Read data
|
||||||
|
node = self.current_processor.get_value(node_name)
|
||||||
if "mean" in node:
|
if "mean" in node:
|
||||||
index = node["runs"].read().index(run.encode())
|
index = node["runs"].index(run.encode())
|
||||||
values, centers = node["mean"].read()[index]
|
values, centers = node["mean"][index]
|
||||||
else:
|
else:
|
||||||
values, centers = node.read()
|
values, centers = node
|
||||||
if xlog:
|
if xlog:
|
||||||
centers = centers + np.log10(unit_old.express(unit))
|
centers = centers + np.log10(unit_old.express(unit))
|
||||||
else:
|
else:
|
||||||
@@ -1034,16 +1005,16 @@ class Plotter(Aggregator, BaseProcessor):
|
|||||||
plt.ylabel(ylabel)
|
plt.ylabel(ylabel)
|
||||||
|
|
||||||
# Also diplay fit, previously saved
|
# Also diplay fit, previously saved
|
||||||
if ax_los is not None and "/hist/fit_" + name + "_" + ax_los in self.save:
|
# if ax_los is not None and "/hist/fit_" + name + "_" + ax_los in self.save:
|
||||||
slope = node.attrs.slope
|
# slope = node.attrs.slope
|
||||||
origin = node.attrs.origin
|
# origin = node.attrs.origin
|
||||||
plt.plot(
|
# plt.plot(
|
||||||
centers,
|
# centers,
|
||||||
10 ** (slope * centers + origin),
|
# 10 ** (slope * centers + origin),
|
||||||
"--",
|
# "--",
|
||||||
linewidth=2,
|
# linewidth=2,
|
||||||
color="orange",
|
# color="orange",
|
||||||
)
|
# )
|
||||||
# or a new one
|
# or a new one
|
||||||
if fit is not None:
|
if fit is not None:
|
||||||
self._overlay_fit(
|
self._overlay_fit(
|
||||||
@@ -1108,28 +1079,28 @@ class Plotter(Aggregator, BaseProcessor):
|
|||||||
name_x, name_y = name_x + "_" + node_arg, name_y + "_" + node_arg
|
name_x, name_y = name_x + "_" + node_arg, name_y + "_" + node_arg
|
||||||
|
|
||||||
# Get hdf5 nodes
|
# Get hdf5 nodes
|
||||||
node_x = self.save.get_node(name_x)
|
node_x = self.current_processor.get_value(name_x)
|
||||||
node_y = self.save.get_node(name_y)
|
node_y = self.current_processor.get_value(name_y)
|
||||||
|
|
||||||
# If the actual data is in another file, fetch it
|
# If the actual data is in another file, fetch it
|
||||||
if subname_x:
|
if subname_x:
|
||||||
hdf5_x = tables.open_file(node_x.read())
|
hdf5_x = tables.open_file(node_x)
|
||||||
node_x = hdf5_x.get_node(subname_x)
|
node_x = hdf5_x.get_node(subname_x).read()
|
||||||
if subname_y:
|
if subname_y:
|
||||||
hdf5_y = tables.open_file(node_y.read())
|
hdf5_y = tables.open_file(node_y)
|
||||||
node_y = hdf5_y.get_node(subname_y)
|
node_y = hdf5_y.get_node(subname_y).read()
|
||||||
|
|
||||||
# Find proper labels
|
# Find proper labels
|
||||||
xlabel, xunit_old, xunit = self._ax_label_unit(
|
xlabel, xunit_old, xunit = self._ax_label_unit(
|
||||||
node_x, xlabel, xunit, xunit_coeff
|
name_x, xlabel, xunit, xunit_coeff
|
||||||
)
|
)
|
||||||
ylabel, yunit_old, yunit = self._ax_label_unit(
|
ylabel, yunit_old, yunit = self._ax_label_unit(
|
||||||
node_y, ylabel, yunit, yunit_coeff
|
name_y, ylabel, yunit, yunit_coeff
|
||||||
)
|
)
|
||||||
|
|
||||||
# If relevent, get time
|
# If relevent, get time
|
||||||
if put_time:
|
if put_time:
|
||||||
time = self.save.root._v_attrs.time * self.study.info["unit_time"]
|
time = self.current_processor.time * self.study.info["unit_time"]
|
||||||
time_str = self.params.plot.time_fmt.format(
|
time_str = self.params.plot.time_fmt.format(
|
||||||
time.express(unit_time), unit_time.latex.replace("text", "math")
|
time.express(unit_time), unit_time.latex.replace("text", "math")
|
||||||
)
|
)
|
||||||
@@ -1141,28 +1112,28 @@ class Plotter(Aggregator, BaseProcessor):
|
|||||||
|
|
||||||
# Manage the different forms in which the data may be stored :
|
# Manage the different forms in which the data may be stored :
|
||||||
# Possibilities are : plain array, dict of arrays (mean, std, ..) or dict of array (runs)
|
# Possibilities are : plain array, dict of arrays (mean, std, ..) or dict of array (runs)
|
||||||
if node_y._v_attrs.CLASS == "ARRAY":
|
if isinstance(node_y, np.ndarray):
|
||||||
x = node_x.read() * xunit_old.express(xunit)
|
x = node_x * xunit_old.express(xunit)
|
||||||
y = node_y.read() * yunit_old.express(yunit)
|
y = node_y * yunit_old.express(yunit)
|
||||||
mask = np.isfinite(x) & np.isfinite(y)
|
mask = np.isfinite(x) & np.isfinite(y)
|
||||||
x, y = x[mask], y[mask]
|
x, y = x[mask], y[mask]
|
||||||
elif "mean" in node_y:
|
elif "mean" in node_y:
|
||||||
x = node_x.read() * xunit_old.express(xunit)
|
x = node_x * xunit_old.express(xunit)
|
||||||
y = node_y.mean.read() * yunit_old.express(yunit)
|
y = node_y.mean * yunit_old.express(yunit)
|
||||||
|
|
||||||
if yerr_kind == "std":
|
if yerr_kind == "std":
|
||||||
std = node_y.std.read() * yunit_old.express(yunit)
|
std = node_y["std"] * yunit_old.express(yunit)
|
||||||
yerr_min = y - sigma_err * std
|
yerr_min = y - sigma_err * std
|
||||||
yerr_max = y + sigma_err * std
|
yerr_max = y + sigma_err * std
|
||||||
elif yerr_kind == "min_max":
|
elif yerr_kind == "min_max":
|
||||||
yerr_min = node_y.min.read() * yunit_old.express(yunit)
|
yerr_min = node_y["min"] * yunit_old.express(yunit)
|
||||||
yerr_max = node_y.max.read() * yunit_old.express(yunit)
|
yerr_max = node_y["max"] * yunit_old.express(yunit)
|
||||||
elif yerr_kind == "95per":
|
elif yerr_kind == "95per":
|
||||||
yerr_min = node_y.q025.read() * yunit_old.express(yunit)
|
yerr_min = node_y["q025"] * yunit_old.express(yunit)
|
||||||
yerr_max = node_y.q975.read() * yunit_old.express(yunit)
|
yerr_max = node_y["q975"] * yunit_old.express(yunit)
|
||||||
elif yerr_kind == "68per":
|
elif yerr_kind == "68per":
|
||||||
yerr_min = node_y.q16.read() * yunit_old.express(yunit)
|
yerr_min = node_y["q16"] * yunit_old.express(yunit)
|
||||||
yerr_max = node_y.q84.read() * yunit_old.express(yunit)
|
yerr_max = node_y["q84"] * yunit_old.express(yunit)
|
||||||
else:
|
else:
|
||||||
yerr_min = y
|
yerr_min = y
|
||||||
yerr_max = y
|
yerr_max = y
|
||||||
@@ -1176,13 +1147,13 @@ class Plotter(Aggregator, BaseProcessor):
|
|||||||
yerr_max[mask],
|
yerr_max[mask],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
x = node_x[run].read() * xunit_old.express(xunit)
|
x = node_x[run] * xunit_old.express(xunit)
|
||||||
y = node_y[run].read() * yunit_old.express(yunit)
|
y = node_y[run] * yunit_old.express(yunit)
|
||||||
mask = np.isfinite(x) & np.isfinite(y)
|
mask = np.isfinite(x) & np.isfinite(y)
|
||||||
x, y = x[mask], y[mask]
|
x, y = x[mask], y[mask]
|
||||||
|
|
||||||
if isinstance(yerr, str):
|
if isinstance(yerr, str):
|
||||||
yerr = self.save.get_node(yerr).read()
|
yerr = self.current_processor.get_value(yerr)
|
||||||
|
|
||||||
# Apply transformations on x
|
# Apply transformations on x
|
||||||
if xtransform is not None:
|
if xtransform is not None:
|
||||||
@@ -1211,7 +1182,7 @@ class Plotter(Aggregator, BaseProcessor):
|
|||||||
color = colors[run]
|
color = colors[run]
|
||||||
elif nml_color == "time":
|
elif nml_color == "time":
|
||||||
time = (
|
time = (
|
||||||
self.save.root._v_attrs.time * self.study.info["unit_time"]
|
self.current_processor.time * self.current_processor.info["unit_time"]
|
||||||
).express(unit_time)
|
).express(unit_time)
|
||||||
color = colors(time)
|
color = colors(time)
|
||||||
else:
|
else:
|
||||||
@@ -1261,8 +1232,8 @@ class Plotter(Aggregator, BaseProcessor):
|
|||||||
Plot power spectrum (wrapper around pspec_read)
|
Plot power spectrum (wrapper around pspec_read)
|
||||||
"""
|
"""
|
||||||
del kwargs["run"]
|
del kwargs["run"]
|
||||||
file_pspec = self.save.get_node("/hdf5/pspec").read()
|
file_pspec = self.current_processor.get_value("/hdf5/pspec")
|
||||||
num = self.save.root._v_attrs.num
|
num = self.current_processor.num
|
||||||
getattr(pspec_read, "pspec_" + name)(file_pspec, ".", num, **kwargs)
|
getattr(pspec_read, "pspec_" + name)(file_pspec, ".", num, **kwargs)
|
||||||
|
|
||||||
def _overlay_fit(self, x, y, yerr=None, kind="linear", label=None, **kwargs):
|
def _overlay_fit(self, x, y, yerr=None, kind="linear", label=None, **kwargs):
|
||||||
|
|||||||
Reference in New Issue
Block a user