[plotter] close only at the end is several run are plot on same fig

This commit is contained in:
Noe Brucy
2021-06-14 10:24:50 +02:00
parent 9cc1396ca5
commit 7b3dd2cead
+15 -5
View File
@@ -299,7 +299,6 @@ class Plotter(Aggregator, BaseProcessor):
datafiles = [] datafiles = []
# Several plots
if rule.kind == "snapshot" or rule.kind == "cells": if rule.kind == "snapshot" or rule.kind == "cells":
run_num = [(run, num) for run in runs for num in nums[run]] run_num = [(run, num) for run in runs for num in nums[run]]
elif rule.kind == "comp": elif rule.kind == "comp":
@@ -307,8 +306,15 @@ class Plotter(Aggregator, BaseProcessor):
else: else:
run_num = [(run, None) for run in runs] run_num = [(run, None) for run in runs]
onefigure = False # If axes are provided, only save/close once
if ax is not None:
onefigure = True
plot_filename = self._find_filename(name_full)
for i, (run, num) in enumerate(run_num): for i, (run, num) in enumerate(run_num):
# Find filename # Find filename
if not onefigure:
plot_filename = self._find_filename(name_full, run, num) plot_filename = self._find_filename(name_full, run, num)
# Find ax # Find ax
@@ -335,6 +341,7 @@ class Plotter(Aggregator, BaseProcessor):
# Call plot routine # Call plot routine
try: try:
close = (not onefigure) or (i == len(run_num) - 1)
plot_info = self._plot_rule( plot_info = self._plot_rule(
rule, rule,
save, save,
@@ -342,6 +349,7 @@ class Plotter(Aggregator, BaseProcessor):
plot_filename, plot_filename,
overwrite, overwrite,
ax=real_ax, ax=real_ax,
close=close,
run=run, run=run,
**kwargs, **kwargs,
) )
@@ -367,7 +375,9 @@ class Plotter(Aggregator, BaseProcessor):
datafiles.append(df) datafiles.append(df)
return datafiles return datafiles
def _plot_rule(self, rule, save, arg, plot_filename, overwrite, ax, **kwargs): def _plot_rule(
self, rule, save, arg, plot_filename, overwrite, ax, close=True, **kwargs
):
""" """
Once all dependencies are met, actually process the rule Once all dependencies are met, actually process the rule
""" """
@@ -375,10 +385,10 @@ class Plotter(Aggregator, BaseProcessor):
if self._needs_computation(overwrite, plot_filename): if self._needs_computation(overwrite, plot_filename):
plot_info = rule.plot(save, arg, **kwargs) plot_info = rule.plot(save, arg, **kwargs)
if not self.pp_params.out.interactive: if not self.pp_params.out.interactive and close:
plt.tight_layout(pad=1) plt.tight_layout(pad=1)
if self.pp_params.out.save: if self.pp_params.out.save and close:
plt.savefig(plot_filename) plt.savefig(plot_filename)
self._log("{} plotted".format(plot_filename), "SUCCESS") self._log("{} plotted".format(plot_filename), "SUCCESS")
else: else:
@@ -386,7 +396,7 @@ class Plotter(Aggregator, BaseProcessor):
"{} plotted".format(os.path.basename(plot_filename)), "SUCCESS" "{} plotted".format(os.path.basename(plot_filename)), "SUCCESS"
) )
if not self.pp_params.out.interactive: if not self.pp_params.out.interactive and close:
plt.close() plt.close()
return plot_info return plot_info
else: else: