[plotter] Improve logging + split plot function

This commit is contained in:
Noe Brucy
2022-09-02 16:15:57 +02:00
parent c841a34ceb
commit f1eb57a2b0
+120 -100
View File
@@ -240,7 +240,7 @@ class Plotter(Aggregator, BaseProcessor):
"""
# log info
self.log_id = "plot {}".format(tag)
self.log_id = "plotter({})".format(tag)
super(Plotter, self).__init__(path, path_out, params, tag)
@@ -409,12 +409,12 @@ class Plotter(Aggregator, BaseProcessor):
elif rule.kind == "comp":
run_num = [(None, None)]
if movie:
self._log(f"No movie possible for rule {name}", "WARNING")
self.logger.warning(f"No movie possible for rule {name}")
movie = False
else:
run_num = [(run, None) for run in runs]
if movie:
self._log(f"No movie possible for rule {name}", "WARNING")
self.logger.warning(f"No movie possible for rule {name}")
movie = False
onefigure = False # If axes are provided, only save/close once
@@ -499,20 +499,18 @@ class Plotter(Aggregator, BaseProcessor):
if self.params.plot.tight_layout and close:
plt.tight_layout(pad=1)
if self.params.out.save:
if self.params.out.save:
os.makedirs(os.path.dirname(plot_filename), exist_ok=True)
plt.savefig(plot_filename)
self._log("{} plotted".format(plot_filename), "SUCCESS")
self.logger.info(f"{plot_filename} plotted")
else:
self._log(
"{} plotted".format(os.path.basename(plot_filename)), "SUCCESS"
)
self.logger.info(f"{os.path.basename(plot_filename)} plotted")
if not self.params.out.interactive and close:
plt.close()
return plot_info
else:
self._log("Plot {} is already done, skipping...".format(plot_filename))
self.logger.info(f"Plot {plot_filename} is already done.")
def _find_filename(self, name_full, run=None, num=None, fmt=None, ext=None):
"""
@@ -591,7 +589,7 @@ class Plotter(Aggregator, BaseProcessor):
label = run
return label
if nml_key is None and label is None:
if nml_key is None and (label is None or len(label) == 0):
label_run = get_label_file(run)
elif nml_key is not None:
if not type(nml_key) == list:
@@ -600,7 +598,7 @@ class Plotter(Aggregator, BaseProcessor):
lbl_list = filter(lambda x: len(x) > 0, lbl_list) # Remove void labels
label_run = ", ".join(lbl_list)
if label is not None:
if label is not None and len(label) > 0:
label_run = label + " (" + label_run + ")"
else:
label_run = label
@@ -757,7 +755,7 @@ class Plotter(Aggregator, BaseProcessor):
plt.gca().add_artist(scalebar)
if axes_indicator:
# A liitle drawing saying what are the axes
# A little drawing saying what are the axes
plt.annotate(
"",
xy=(0.97, 0.1),
@@ -979,7 +977,7 @@ class Plotter(Aggregator, BaseProcessor):
part_pos = data[["x", "y", "z"]].values
unit_length /= self.current_processor.lbox
except KeyError:
self.current_processor._log("No sinks particles", "WARNING")
self.current_processor.logger.warning("No sinks particles")
return
elif parts:
# Open particle HDF5 filetype_from_ext
@@ -1219,14 +1217,115 @@ class Plotter(Aggregator, BaseProcessor):
plot_title=title,
)
def plot(
self,
x:np.array,
y:np.array,
xlabel:str="",
ylabel:str="",
label:str="",
xscale:str="linear",
yscale:str="linear",
fit:str=None,
fitlabel:str=None,
smooth:float=0,
nml_key=None,
run:str=None,
yerr:np.array=None,
grid:bool=False,
put_time:bool=False,
unit_time=U.Myr,
colors=None,
nml_color=None,
legend:bool=False,
**kwargs,
):
"""
Generic plot routine, with x, y two numpy arrauys
"""
# Option to smooth data for readability (beware)
if smooth > 0:
y = gaussian_filter1d(y, sigma=smooth)
# Special label if the plot apply to a given run
if run is not None:
label = self.get_label_run(run, label, nml_key)
# If relevant, get time
if put_time:
time = self.current_processor.time * self.study.info["unit_time"]
time_str = self.params.plot.time_fmt.format(
time.express(unit_time), unit_time.latex.replace("text", "math")
)
time_str = f"${time_str}$"
if len(label) > 0:
label = label + " | " + time_str
else:
label = time_str
# Look if special colors method is used
if colors is None:
if yerr is None:
(base_line,) = plt.plot(x, y, label=label, **kwargs)
else:
base_line, _, _ = plt.errorbar(x, y, yerr=yerr, label=label, **kwargs)
else:
if nml_color is None:
color = colors[run]
elif nml_color == "time":
time = (
self.current_processor.time
* self.current_processor.info["unit_time"]
).express(unit_time)
color = colors(time)
else:
nml_value = self.study.get_nml(nml_color, run)
if os.path.basename(nml_color) in self.value_convert:
nml_value = self.value_convert[os.path.basename(nml_color)](nml_value)
try:
color = colors[nml_value]
except TypeError:
color = colors(nml_value)
if yerr is None:
(base_line,) = plt.plot(x, y, label=label, color=color, **kwargs)
else:
base_line, _, _ = plt.errorbar(
x, y, yerr=yerr, color=color, label=label, **kwargs
)
# Ax decorations
plt.xlabel(xlabel)
plt.ylabel(ylabel)
if grid:
plt.grid()
if legend:
plt.legend()
# Ax scale
plt.xscale(xscale)
plt.yscale(yscale)
if fit is not None:
self._overlay_fit(
x,
y,
yerr,
kind=fit,
ls="--",
lw=1.5,
color=base_line.get_color(),
label=fitlabel,
)
def _plot(
self,
name_x,
name_y,
name_x:str,
name_y:str,
node_arg=None,
xlabel=None,
ylabel=None,
label=None,
xunit=None,
yunit=None,
put_units=True,
@@ -1234,22 +1333,10 @@ class Plotter(Aggregator, BaseProcessor):
yunit_coeff=1.0,
xtransform=None,
ytransform=None,
xscale="linear",
yscale="linear",
fit=None,
fitlabel=None,
smooth=0,
nml_key=None,
run=None,
yerr=None,
yerr_kind="std",
sigma_err=2.0,
grid=False,
put_time=False,
unit_time=U.Myr,
colors=None,
nml_color=None,
legend=None,
subname_x=None,
subname_y=None,
**kwargs,
@@ -1282,18 +1369,6 @@ class Plotter(Aggregator, BaseProcessor):
name_y, ylabel, yunit, yunit_coeff, put_units=put_units,
)
# If relevent, get time
if put_time:
time = self.current_processor.time * self.study.info["unit_time"]
time_str = self.params.plot.time_fmt.format(
time.express(unit_time), unit_time.latex.replace("text", "math")
)
time_str = f"${time_str}$"
if label is not None and len(label) > 0:
label = label + " | " + time_str
else:
label = time_str
# Manage the different forms in which the data may be stored :
# Possibilities are : plain array, dict of arrays (mean, std, ..) or dict of array (runs)
if isinstance(node_y, np.ndarray):
@@ -1347,67 +1422,12 @@ class Plotter(Aggregator, BaseProcessor):
if ytransform is not None:
y = ytransform(y)
if yerr is not None:
self._log(
"Errorbar may be meaningless when ytransform is used", "WARNING"
)
if smooth > 0:
y = gaussian_filter1d(y, sigma=smooth)
if run is not None:
label = self.get_label_run(run, label, nml_key)
# Look if special colors method is used
if colors is None:
if yerr is None:
(base_line,) = plt.plot(x, y, label=label, **kwargs)
else:
base_line, _, _ = plt.errorbar(x, y, yerr=yerr, label=label, **kwargs)
else:
if nml_color is None:
color = colors[run]
elif nml_color == "time":
time = (
self.current_processor.time
* self.current_processor.info["unit_time"]
).express(unit_time)
color = colors(time)
else:
nml_value = self.study.get_nml(nml_color, run)
if os.path.basename(nml_color) in self.value_convert:
nml_value = self.value_convert[os.path.basename(nml_color)](nml_value)
try:
color = colors[nml_value]
except TypeError:
color = colors(nml_value)
if yerr is None or yerr_kind is None:
(base_line,) = plt.plot(x, y, label=label, color=color, **kwargs)
else:
base_line, _, _ = plt.errorbar(
x, y, yerr=yerr, color=color, label=label, **kwargs
self.logger.warning(
"Errorbar may be meaningless when ytransform is used"
)
# Ax decorations
plt.xlabel(xlabel)
plt.ylabel(ylabel)
if grid:
plt.grid()
if legend:
plt.legend()
# Ax scale
plt.xscale(xscale)
plt.yscale(yscale)
if fit is not None:
self._overlay_fit(
x,
y,
yerr,
kind=fit,
ls="--",
lw=1.5,
color=base_line.get_color(),
label=fitlabel,
)
self.plot(x, y, yerr=yerr, xlabel=xlabel,
ylabel=ylabel, run=run, **kwargs)
if subname_x:
hdf5_x.close()
@@ -1507,7 +1527,7 @@ class Plotter(Aggregator, BaseProcessor):
This is where rules are defined
"""
self.rules = {
"plot": PlotRule(lambda arg, **kwargs: self._plot(*arg, **kwargs), kind="comp"
"plot_arrays": PlotRule(lambda arg, **kwargs: self._plot(*arg, **kwargs), kind="comp"
),
"plot_run": PlotRule(lambda arg, **kwargs: self._plot(*arg, **kwargs), kind="run"
),