add astrophysix support, fix labels

This commit is contained in:
Noe Brucy
2020-12-14 10:08:01 +01:00
parent 653e64d782
commit 7f7216abf6
6 changed files with 519 additions and 355 deletions
+310 -303
View File
@@ -17,7 +17,8 @@ from scipy.stats import linregress
from numpy.polynomial.polynomial import polyfit
from scipy.ndimage.filters import gaussian_filter1d
from scipy import optimize
from astrophysix.simdm.datafiles import Datafile, PlotType, PlotInfo
from astrophysix.utils.file import FileType
import matplotlib as mpl
if os.environ.get("DISPLAY", "") == "":
@@ -26,12 +27,15 @@ if os.environ.get("DISPLAY", "") == "":
import pylab as P
from comparator import *
import pspec_read
import datetime
P.rcParams["image.cmap"] = "plasma"
P.rcParams["savefig.dpi"] = 400
filetype_from_ext = {ext: ft for ft in FileType for ext in ft.extension_list}
tex_params = {"text.latex.preamble": r"\usepackage{amsmath}"}
P.rcParams.update(tex_params)
def not_array_error(err):
epy2 = "object does not support indexing"
epy3 = "object is not subscriptable"
return str(err)[-len(epy2) :] == epy2 or str(err)[-len(epy3) :] == epy3
class PlotRule(Rule):
@@ -42,7 +46,7 @@ class PlotRule(Rule):
def plot(self, save, arg, **kwargs):
"""
Set the plotter's storage to 'save' and exetute the rule
Set the plotter's storage to 'save' and execute the rule
Parameters
----------
@@ -53,6 +57,12 @@ class PlotRule(Rule):
self.postproc.save = save
return self.process_fn(arg, **kwargs)
def datafile(self, arg):
return Datafile(
name=self.description + "_" + arg,
description=self.description + " ({})".format(arg),
)
class Plotter(Aggregator, BaseProcessor):
"""
@@ -99,6 +109,7 @@ class Plotter(Aggregator, BaseProcessor):
pp_params=None,
selector=None,
tag=None,
unit_time=cst.year,
**kwargs,
):
@@ -135,7 +146,13 @@ class Plotter(Aggregator, BaseProcessor):
# Get comparator object
self.comp = Comparator(
path, self.runs, self.nums, path_out, self.pp_params, selector=self.selector
path,
self.runs,
self.nums,
path_out,
self.pp_params,
unit_time=unit_time,
selector=self.selector,
)
# Get postprocesor objets for each run
@@ -147,17 +164,56 @@ class Plotter(Aggregator, BaseProcessor):
# Define rules
self.def_rules()
# generate astrophysix's simulations object
self.gen_simus()
self.save = None
def gen_simus(self):
self.simulations = {}
simu_fmt = self.pp_params.astrophysix.simu_fmt
descr_fmt = self.pp_params.astrophysix.descr_fmt
tag = self.pp_params.out.tag
for run in self.runs:
pp = self.pp[run][self.nums[run][0]]
nml = self.comp.namelist[run]
name = simu_fmt.format(run=run, tag=tag, nml=nml)
exec_time = str(datetime.datetime.fromtimestamp(os.stat(pp.path).st_ctime))
description = descr_fmt.format(run=run, tag=tag, nml=nml)
simu = Simulation(
simu_code=ramses,
name=name,
alias=name.upper(),
description=description,
directory_path=pp.path,
execution_time=exec_time,
)
for param in ramses.input_parameters:
try:
param_setting = ParameterSetting(
input_param=param,
value=self.comp.get_nml(param.key, run),
visibility=ParameterVisibility.BASIC_DISPLAY,
)
simu.parameter_settings.add(param_setting)
except KeyError as e:
self._log("key {} not found".format(e), "WARNING")
except AttributeError as e:
self._log("{}".format(e), "WARNING")
self.simulations[run] = simu
def _not_self_dep(self, name, dep, dep_arg, overwrite, **kwargs):
"""
Check if the dependency belongs to the plotter object or to another one (comp, pp, ..)
"""
if dep in self.comp.rules:
done = self.comp.process(
result = self.comp.process(
dep, dep_arg, overwrite, overwrite_dep=self.overwrite_dep
)
self.just_done.extend(done)
if result is not None:
self.just_done.append(done)
else:
super(Plotter, self)._not_self_dep(name, dep, dep_arg, overwrite, **kwargs)
@@ -172,19 +228,13 @@ class Plotter(Aggregator, BaseProcessor):
)
def _process_rule(
self,
name,
rule,
arg,
overwrite=False,
ax=None,
movie=False,
from_cells=False,
**kwargs,
self, name, rule, arg, overwrite=False, ax=None, from_cells=False, **kwargs
):
"""
Open storage and figure if needed before processing a rule
"""
# Set full name according to argument
if not arg is None:
name_full = (
name
@@ -200,144 +250,89 @@ class Plotter(Aggregator, BaseProcessor):
else:
name_full = name
if rule.is_valid(arg):
if rule.kind == "classic" or rule.kind == "cells":
if "select" in kwargs:
select = kwargs.pop("select")
runs, nums = self.selector.select(**select)
elif "runs" in kwargs:
runs = kwargs.pop("runs")
if isinstance(runs, RunSelector):
nums = runs.nums
runs = runs.runs
else:
nums = self.nums
else:
runs = self.runs
nums = self.nums
i = 0
for run in runs:
files = []
for num in nums[run]:
plot_filename = self._find_filename(name_full, run, num)
if from_cells or rule.kind == "cells":
if not os.exists(self.pp[run][num].cells_filename):
self.pp[run][num].load_cells()
self.pp[run][num].unload_cells()
save = tables.open_file(self.pp[run][num].cells_filename)
elif rule.kind == "classic":
save = tables.open_file(self.pp[run][num].filename)
else:
save = tables.open_file(self.comp.filename, "r")
try:
self._plot_rule(
rule,
save,
arg,
plot_filename,
overwrite,
ax=ax[i],
run=run,
**kwargs,
)
except TypeError as e:
if str(e) in [
"'LocatableAxes' object does not support indexing",
"'AxesSubplot' object does not support indexing",
"'AxesSubplot' object is not subscriptable",
"'Axes' object is not subscriptable",
"'LocatableAxes' object is not subscriptable",
]:
self._plot_rule(
rule,
save,
arg,
plot_filename,
overwrite,
ax=ax,
run=run,
**kwargs,
)
elif ax is None:
fig = P.figure()
self._plot_rule(
rule,
save,
arg,
plot_filename,
overwrite,
ax=P.gca(),
run=run,
**kwargs,
)
else:
raise
finally:
save.close()
i = i + 1
files.append(plot_filename)
else:
if "select" in kwargs and not "runs" in kwargs:
select = kwargs.pop("select")
runs, nums = self.selector.select(**select)
if not rule.kind == "runs":
kwargs["runs"] = runs
elif rule.kind == "runs" and "runs" in kwargs:
runs = kwargs.pop("runs")
else:
runs = self.runs
if ax is None:
ax = P.gca()
if rule.kind == "series" and len(runs) == 1:
run = self.runs[0]
plot_filename = self._find_filename(name_full, run)
else:
plot_filename = self._find_filename(name_full)
save = tables.open_file(self.comp.filename, "r")
try:
if rule.kind == "runs":
for i, run in enumerate(runs):
try:
self._plot_rule(
rule,
save,
arg,
plot_filename,
overwrite,
ax=ax[i],
run=run,
**kwargs,
)
except TypeError as e:
if str(e) in [
"'LocatableAxes' object does not support indexing",
"'AxesSubplot' object does not support indexing",
"'AxesSubplot' object is not subscriptable",
"'Axes' object is not subscriptable",
"'LocatableAxes' object is not subscriptable",
]:
self._plot_rule(
rule,
save,
arg,
plot_filename,
overwrite,
ax=ax,
run=run,
**kwargs,
)
else:
self._plot_rule(
rule, save, arg, plot_filename, overwrite, ax, **kwargs
)
finally:
save.close()
else:
# Exit if not valid
if not rule.is_valid(arg):
self._log("{} is not valid in this context".format(name_full), "ERROR")
return
# get filetype of the output
filetype = filetype_from_ext[self.pp_params.out.ext]
# Select runs and nums
if "select" in kwargs:
select = kwargs.pop("select")
runs, nums = self.selector.select(**select)
else:
runs = self.runs
nums = self.nums
datafiles = []
# Several plots
if rule.kind == "classic" or rule.kind == "cells":
run_num = [(run, num) for run in runs for num in nums[run]]
else:
run_num = [(run, None) for run in runs]
for i, (run, num) in enumerate(run_num):
# Find filename
plot_filename = self._find_filename(name_full, run, num)
# Find ax
try:
real_ax = ax[i]
except TypeError as e:
if ax is None:
fig, real_ax = P.subplots(1, 1)
elif not_array_error(e):
real_ax = ax
else:
raise
# Find plot save
if from_cells or rule.kind == "cells":
if not os.exists(self.pp[run][num].cells_filename):
self.pp[run][num].load_cells()
self.pp[run][num].unload_cells()
save = tables.open_file(self.pp[run][num].cells_filename)
elif rule.kind == "classic":
save = tables.open_file(self.pp[run][num].filename)
else:
save = tables.open_file(self.comp.filename, "r")
# Call plot routine
try:
plot_info = self._plot_rule(
rule,
save,
arg,
plot_filename,
overwrite,
ax=real_ax,
run=run,
**kwargs,
)
finally:
save.close()
# Save in astrophysix format
df = rule.datafile(arg)
df[filetype] = plot_filename
if plot_info is not None:
df.plot_info = plot_info
if num is not None:
snap = self.pp[run][num].snapshot
if overwrite and df.name in snap.datafiles:
del snap.datafiles[df.name]
elif df.name not in snap.datafiles:
snap.datafiles.add(df)
if snap not in self.simulations[run].snapshots:
self.simulations[run].snapshots.add(snap)
datafiles.append(df)
return datafiles
def _plot_rule(self, rule, save, arg, plot_filename, overwrite, ax, **kwargs):
"""
@@ -345,7 +340,7 @@ class Plotter(Aggregator, BaseProcessor):
"""
P.sca(ax)
if self._needs_computation(overwrite, plot_filename):
rule.plot(save, arg, **kwargs)
plot_info = rule.plot(save, arg, **kwargs)
if not self.pp_params.out.interactive:
P.tight_layout(pad=1)
@@ -360,6 +355,7 @@ class Plotter(Aggregator, BaseProcessor):
if not self.pp_params.out.interactive:
P.close()
return plot_info
else:
self._log("Plot {} is already done, skipping...".format(plot_filename))
@@ -396,7 +392,7 @@ class Plotter(Aggregator, BaseProcessor):
ext=self.pp_params.out.ext,
)
def _label_run(self, run, node, label, nml_key):
def _label_run(self, run, node, label, nml_key, time=None):
"""
Set up a label for the run from the namelist and parameters
"""
@@ -435,7 +431,7 @@ class Plotter(Aggregator, BaseProcessor):
label_run = label
return label_run
def _ax_label_unit(self, node, label, unit, unit_coeff):
def _ax_label_unit(self, node, label, unit, unit_coeff, put_units=True):
"""
Find appropriate labels for axis
"""
@@ -457,20 +453,38 @@ class Plotter(Aggregator, BaseProcessor):
if unit is None:
unit = unit_old
if not unit_coeff == 1:
base = unit
unit = unit_coeff * unit
label = label + unit_str(unit, base=base)
else:
label = label + unit_str(unit)
if put_units:
if not unit_coeff == 1:
base = unit
unit = unit_coeff * unit
label = label + unit_str(unit, base=base)
else:
label = label + unit_str(unit)
return label, unit_old, unit
def _snapshot_title(self, run, node, title, nml_key, put_time, unit_time=cst.Myr):
title = self._label_run(run, node, title, nml_key)
if put_time:
time = self.save.root._v_attrs.time * self.comp.info["unit_time"]
u_str = unit_str(unit_time, format="{unit}")
time_str = self.pp_params.plot.time_fmt.format(
time.express(unit_time), u_str
)
if len(title) > 0:
title = title + " | " + time_str
else:
title = time_str
return title
def _plot_map(
self,
name,
ax_los,
run,
xlabel=None,
ylabel=None,
label=None,
unit=None,
unit_coeff=1.0,
@@ -480,12 +494,14 @@ class Plotter(Aggregator, BaseProcessor):
put_title=True,
nml_key=None,
put_time=True,
time_unit=cst.Myr,
unit_time=cst.Myr,
put_units=True,
unit_space=cst.pc,
cmap="plasma",
norm="log",
put_cbar=True,
autoscale=True,
transform=None,
**kwargs,
):
"""
@@ -503,9 +519,13 @@ class Plotter(Aggregator, BaseProcessor):
node = self.save.get_node("/maps/{}_{}".format(name, ax_los))
dmap = node.read()
label, unit_old, unit = self._ax_label_unit(node, label, unit, unit_coeff)
label, unit_old, unit = self._ax_label_unit(
node, label, unit, unit_coeff, put_units
)
dmap = dmap * unit_old.express(unit)
if transform is not None:
dmap = transform(dmap)
if norm == "log":
norm = mpl.colors.LogNorm()
@@ -521,32 +541,28 @@ class Plotter(Aggregator, BaseProcessor):
P.locator_params(axis="both", nbins=self.pp_params.plot.ntick)
P.xlabel(self._ax_title[ax_h] + unit_str(unit_space))
P.ylabel(self._ax_title[ax_v] + unit_str(unit_space))
if xlabel is None:
xlabel = self._ax_title[ax_h]
if ylabel is None:
ylabel = self._ax_title[ax_v]
if put_units:
xlabel = xlabel + unit_str(unit_space)
ylabel = ylabel + unit_str(unit_space)
P.xlabel(xlabel)
P.ylabel(ylabel)
try:
cbar = P.colorbar(im, cax=P.gca().cax)
except AttributeError:
cbar = P.colorbar()
if put_title:
title = self._snapshot_title(run, node, title, nml_key, put_time, unit_time)
P.title(title)
if not label is None:
cbar.set_label(label)
if put_title:
title = self._label_run(run, node, title, nml_key)
if put_time:
time = self.save.root._v_attrs.time * self.comp.info["unit_time"]
time_str = self.pp_params.plot.time_fmt.format(
time.express(time_unit), time_unit.latex
)
if len(title) > 0:
title = title + " | " + time_str
else:
title = time_str
P.title(title)
for i, plot_overlay in enumerate(overlays):
if plot_overlay in self.overlays:
plot_overlay = self.overlays[plot_overlay]
@@ -556,6 +572,23 @@ class Plotter(Aggregator, BaseProcessor):
except:
plot_overlay(ax_los, im_extent)
return PlotInfo(
plot_type=PlotType.IMAGE,
xaxis_values=np.linspace(im_extent[0], im_extent[1], dmap.shape[0] + 1),
yaxis_values=np.linspace(im_extent[2], im_extent[3], dmap.shape[1] + 1),
values=dmap,
xaxis_log_scale=False,
yaxis_log_scale=False,
values_log_scale=(norm == mpl.colors.LogNorm()),
xaxis_label=xlabel,
yaxis_label=ylabel,
values_label=label,
xaxis_unit=unit_space,
yaxis_unit=unit_space,
values_unit=unit,
plot_title=title,
)
def _overlay_contour(
self,
ax_los,
@@ -724,7 +757,7 @@ class Plotter(Aggregator, BaseProcessor):
nml_key=None,
put_title=True,
put_time=True,
time_unit=cst.Myr,
unit_time=cst.Myr,
**kwargs,
):
"""
@@ -749,26 +782,13 @@ class Plotter(Aggregator, BaseProcessor):
if not ylabel is None:
P.ylabel(ylabel)
title = self._snapshot_title(run, node, title, nml_key, put_time, unit_time)
if put_title:
title = self._label_run(run, node, title, nml_key)
if put_time:
time = self.save.root._v_attrs.time * self.comp.info["unit_time"]
time_str = self.pp_params.plot.time_fmt.format(
time.express(time_unit), time_unit.latex
)
if len(title) > 0:
title = title + " | " + time_str
else:
title = time_str
P.title(title)
if label is None:
if label == None:
label = title
P.plot(bin_centers, mean_bin, label=label, **kwargs)
P.plot(bin_centers, mean_bin, label=title, **kwargs)
P.plot(bin_centers, mean_bin, label=label, **kwargs)
def _plot_hist(
self,
@@ -781,10 +801,11 @@ class Plotter(Aggregator, BaseProcessor):
unit_coeff=1.0,
ytransform=None,
label=None,
put_title=True,
title=None,
nml_key=None,
put_time=True,
time_unit=cst.Myr,
unit_time=cst.Myr,
xlog=None,
ylog=False,
kind="bar",
@@ -799,47 +820,41 @@ class Plotter(Aggregator, BaseProcessor):
"""
Plot an histogram (PDF, etc ...)
"""
# Get node
if not ax_los is None:
name = name + "_" + ax_los
node = self.save.get_node(group + name)
if xlog is None:
try:
xlog = node._v_attrs_.logbins
except:
xlog = False
# get label and units
xlabel, unit_old, unit = self._ax_label_unit(node, label, unit, unit_coeff)
# Read data
if "mean" in node:
index = node["runs"].read().index(run.encode())
values, centers = node["mean"].read()[index]
else:
values, centers = node.read()
if xlog:
centers = centers + np.log10(unit_old.express(unit))
else:
centers = centers * unit_old.express(unit)
if ytransform is not None:
values = ytransform(values)
width = centers[1] - centers[0]
title = self._label_run(run, node, title, nml_key)
if put_time:
time = self.save.root._v_attrs.time * self.comp.info["unit_time"]
time_str = self.pp_params.plot.time_fmt.format(
time.express(time_unit), time_unit.latex
)
if len(title) > 0:
title = title + " | " + time_str
else:
title = time_str
P.title(title)
# Set title
title = self._snapshot_title(run, node, title, nml_key, put_time, unit_time)
if put_title:
P.title(title)
if label == None:
label = title
# Set colors
if color is None and not colors is None:
if nml_color is None:
color = colors[run]
@@ -850,11 +865,8 @@ class Plotter(Aggregator, BaseProcessor):
except:
color = colors(nml)
if label == None:
label = title
# Actual plot
if kind == "bar":
width = centers[1] - centers[0]
P.bar(centers, values, width, log=ylog, color=color, label=label, **kwargs)
elif kind == "step":
if ylog:
@@ -863,11 +875,13 @@ class Plotter(Aggregator, BaseProcessor):
else:
raise ValueError("kind must be 'bar' or 'step'")
# put labels
if not label is None:
P.xlabel(xlabel)
if not ylabel is None:
P.ylabel(ylabel)
# Also diplay fit, previously saved
if ax_los is not None and "/hist/fit_" + name + "_" + ax_los in self.save:
slope = node.attrs.slope
origin = node.attrs.origin
@@ -878,14 +892,27 @@ class Plotter(Aggregator, BaseProcessor):
linewidth=2,
color="orange",
)
P.ylim([None, 1.0])
# or a new one
if not fit is None:
self._overlay_fit(
centers, values, kind=fit, ls="--", lw=1.5, label=fitlabel
)
# returns PlotInfo (for Galactica)
edges = np.append(centers - width / 2.0, centers[-1] + width / 2.0)
return PlotInfo(
plot_type=PlotType.HISTOGRAM,
xaxis_values=edges,
yaxis_values=values,
xaxis_log_scale=False,
yaxis_log_scale=ylog,
xaxis_label=xlabel,
yaxis_label=ylabel,
xaxis_unit=unit,
yaxis_unit=cst.none,
plot_title=title,
)
def _plot(
self,
name_x,
@@ -898,19 +925,17 @@ class Plotter(Aggregator, BaseProcessor):
yunit=None,
xunit_coeff=1.0,
yunit_coeff=1.0,
ylog=False,
fit=None,
fitlabel=None,
smooth=0,
nml_key=None,
run=None,
runs=None,
yerr=None,
yerr_kind="std",
sigma_err=2.0,
grid=False,
put_time=False,
time_unit=cst.Myr,
unit_time=cst.Myr,
colors=None,
nml_color=None,
legend=None,
@@ -922,12 +947,15 @@ class Plotter(Aggregator, BaseProcessor):
Generic plot routine, with name_x and name_y two path in the hdf5 file
"""
# Get proper hdf5 names
if not node_arg is None:
name_x, name_y = name_x + "_" + node_arg, name_y + "_" + node_arg
# Get hdf5 nodes
node_x = self.save.get_node(name_x)
node_y = self.save.get_node(name_y)
# If the actual data is in a,other file, fetch it
if subname_x:
hdf5_x = tables.open_file(node_x.read())
node_x = hdf5_x.get_node(subname_x)
@@ -935,6 +963,7 @@ class Plotter(Aggregator, BaseProcessor):
hdf5_y = tables.open_file(node_y.read())
node_y = hdf5_y.get_node(subname_y)
# Find proper labels
xlabel, xunit_old, xunit = self._ax_label_unit(
node_x, xlabel, xunit, xunit_coeff
)
@@ -942,58 +971,24 @@ class Plotter(Aggregator, BaseProcessor):
node_y, ylabel, yunit, yunit_coeff
)
P.xlabel(xlabel)
P.ylabel(ylabel)
if grid:
P.grid()
if ylog:
P.yscale("log")
# If relevent, get time
if put_time:
time = self.save.root._v_attrs.time * self.comp.info["unit_time"]
time_str = self.pp_params.plot.time_fmt.format(
time.express(time_unit), time_unit.latex
time.express(unit_time), unit_time.latex
)
if label is not None and len(label) > 0:
label = label + " | " + time_str
else:
label = time_str
# Manage the different forms in which the data may be stored :
# Possibilities are : plain array, dict of arrays (mean, std, ..) or dict of array (runs)
if node_y._v_attrs.CLASS == "ARRAY":
x = node_x.read() * xunit_old.express(xunit)
y = node_y.read() * yunit_old.express(yunit)
mask = np.isfinite(x) & np.isfinite(y)
x, y = x[mask], y[mask]
if smooth > 0:
y = gaussian_filter1d(y, sigma=smooth)
if not run is None:
label = self._label_run(run, node_y, label, nml_key)
if colors is None:
(base_line,) = P.plot(x, y, label=label, **kwargs)
else:
if nml_color is None:
color = colors[run]
elif nml_color == "time":
time = (
self.save.root._v_attrs.time * self.comp.info["unit_time"]
).express(time_unit)
color = colors(time)
else:
nml = self.comp.get_nml(nml_color, run)
try:
color = colors[nml]
except:
color = colors(nml)
if yerr is None:
(base_line,) = P.plot(x, y, label=label, color=color, **kwargs)
else:
if isinstance(yerr, str):
yerr = self.save.get_node(yerr).read()
base_line, _, _ = P.errorbar(x, y, yerr=yerr, label=label, **kwargs)
elif "mean" in node_y:
x = node_x.read() * xunit_old.express(xunit)
y = node_y.mean.read() * yunit_old.express(yunit)
@@ -1014,7 +1009,6 @@ class Plotter(Aggregator, BaseProcessor):
else:
yerr_min = y
yerr_max = y
yerr = yerr_max - yerr_min
mask = np.isfinite(x) & np.isfinite(y) & np.isfinite(yerr)
x, y, yerr, yerr_min, yerr_max = (
@@ -1024,44 +1018,51 @@ class Plotter(Aggregator, BaseProcessor):
yerr_min[mask],
yerr_max[mask],
)
if not run is None:
label = self._label_run(run, node_y, label, nml_key)
else:
x, y = node_x[run], node_y[run]
mask = np.isfinite(x) & np.isfinite(y)
x, y = x[mask], y[mask]
if yerr_kind is None:
yerr = None
if isinstance(yerr, str):
yerr = self.save.get_node(yerr).read()
if smooth > 0:
y = gaussian_filter1d(y, sigma=smooth)
if not run is None:
label = self._label_run(run, node_y, label, nml_key)
# Look if special colors method is used
if colors is None:
if yerr is None:
(base_line,) = P.plot(x, y, label=label, **kwargs)
else:
base_line, _, _ = P.errorbar(x, y, yerr=yerr, label=label, **kwargs)
else:
if nml_color is None:
color = colors[run]
elif nml_color == "time":
time = (
self.save.root._v_attrs.time * self.comp.info["unit_time"]
).express(unit_time)
color = colors(time)
else:
nml = self.comp.get_nml(nml_color, run)
try:
color = colors[nml]
except:
color = colors(nml)
if yerr is None:
(base_line,) = P.plot(x, y, label=label, color=color, **kwargs)
else:
base_line, _, _ = P.errorbar(
x, y, yerr=[y - yerr_min, yerr_max - y], label=label, **kwargs
x, y, yerr=yerr, color=color, label=label, **kwargs
)
else:
if runs is None:
runs = self.runs
for i, run in enumerate(runs):
x_run, y_run = node_x[run], node_y[run]
x = x_run.read() * xunit_old.express(xunit)
y = y_run.read() * yunit_old.express(yunit)
mask = np.isfinite(x) & np.isfinite(y)
x, y = x[mask], y[mask]
if smooth > 0:
y = gaussian_filter1d(y, sigma=smooth)
label_run = self._label_run(run, y_run, label, nml_key)
if colors is None:
(base_line,) = P.plot(x, y, label=label_run, **kwargs)
else:
if nml_color is None:
color = colors[i % len(colors)]
else:
nml = self.comp.get_nml(nml_color, run)
try:
color = colors[nml]
except:
color = colors(nml)
(base_line,) = P.plot(x, y, label=label_run, color=color, **kwargs)
if legend is None:
legend = True
# Ax decorations
P.xlabel(xlabel)
P.ylabel(ylabel)
if grid:
P.grid()
if legend:
P.legend()
@@ -1076,6 +1077,7 @@ class Plotter(Aggregator, BaseProcessor):
color=base_line.get_color(),
label=fitlabel,
)
if subname_x:
hdf5_x.close()
if subname_y:
@@ -1603,11 +1605,16 @@ class Plotter(Aggregator, BaseProcessor):
dependencies=["radial_bins", "rad_avg_" + name],
)
self.rules["avg_map_" + name] = PlotRule(
self,
partial(self._plot_map, "avg_map_" + name),
"Map of the radial average of {}".format(name),
dependencies=["avg_map_" + name],
)
self.rules["fluct_" + name] = PlotRule(
self,
partial(
self._plot_map, "fluct_" + name, vmin=0.01, vmax=100, cmap="RdBu_r"
),
partial(self._plot_map, "fluct_" + name, cmap="RdBu_r"),
"Fluctuation of {}".format(name),
dependencies=["fluct_" + name],
)
@@ -1666,7 +1673,7 @@ class Plotter(Aggregator, BaseProcessor):
# Radial
generic_rule(field + "r")
# Othoradial
# Orthoradial
generic_rule(field + "phi")
# Norm
generic_rule(field + "_norm")