improvements

This commit is contained in:
Noe Brucy
2020-03-14 18:09:04 +01:00
parent 1c2750a7bd
commit c2f866b37f
5 changed files with 97 additions and 37 deletions
+70 -21
View File
@@ -119,7 +119,9 @@ class Plotter(Aggregator, BaseProcessor):
or not os.path.exists(plot_filename)
)
def _process_rule(self, name, rule, arg, overwrite=False, ax=None, **kwargs):
def _process_rule(
self, name, rule, arg, overwrite=False, ax=None, movie=False, **kwargs
):
if not arg is None:
name_full = name + "_" + str(arg)
else:
@@ -133,11 +135,9 @@ class Plotter(Aggregator, BaseProcessor):
runs = runs.runs
except KeyError:
runs = self.runs
if ax is None:
ax = [P.subplots(1, 1)[1] for run in runs for num in self.nums[run]]
i = 0
for run in runs:
files = []
for num in self.nums[run]:
plot_filename = self._find_filename(name_full, run, num)
save = tables.open_file(self.pp[run][num].filename)
@@ -153,24 +153,38 @@ class Plotter(Aggregator, BaseProcessor):
**kwargs
)
except TypeError as e:
if (
str(e)
!= "'LocatableAxes' object does not support indexing"
):
if str(e) in [
"'LocatableAxes' object does not support indexing",
"'AxesSubplot' object does not support indexing",
]:
self._plot_rule(
rule,
save,
arg,
plot_filename,
overwrite,
ax=ax,
run=run,
**kwargs
)
elif ax is None:
fig = P.figure()
self._plot_rule(
rule,
save,
arg,
plot_filename,
overwrite,
ax=P.gca(),
run=run,
**kwargs
)
else:
raise
self._plot_rule(
rule,
save,
arg,
plot_filename,
overwrite,
ax=ax,
run=run,
**kwargs
)
finally:
save.close()
i = i + 1
files.append(plot_filename)
else:
if ax is None:
ax = P.gca()
@@ -479,6 +493,7 @@ class Plotter(Aggregator, BaseProcessor):
xlog=None,
ylog=False,
kind="bar",
ylabel="$\mathcal{P}$",
color=None,
colors=None,
nml_color=None,
@@ -505,7 +520,7 @@ class Plotter(Aggregator, BaseProcessor):
if put_time:
time = self.save.root._v_attrs.time * self.comp.info["unit_time"]
time_str = self.pp_params.out.time_fmt.format(
time_str = self.pp_params.plot.time_fmt.format(
time.express(time_unit), time_unit.latex
)
if len(title) > 0:
@@ -538,6 +553,8 @@ class Plotter(Aggregator, BaseProcessor):
if not label is None:
P.xlabel(label)
if not ylabel is None:
P.ylabel(ylabel)
if not ax_los is None and "/hist/fit_" + name + "_" + ax_los in self.save:
slope = node.attrs.slope
@@ -573,8 +590,11 @@ class Plotter(Aggregator, BaseProcessor):
yerr_kind="std",
sigma_err=2.0,
grid=True,
put_time=False,
time_unit=cst.Myr,
colors=None,
nml_color=None,
legend=None,
**kwargs
):
@@ -597,6 +617,16 @@ class Plotter(Aggregator, BaseProcessor):
if grid:
P.grid()
if put_time:
time = self.save.root._v_attrs.time * self.comp.info["unit_time"]
time_str = self.pp_params.plot.time_fmt.format(
time.express(time_unit), time_unit.latex
)
if len(label) > 0:
label = label + " | " + time_str
else:
label = time_str
yerr = None
if node_y._v_attrs.CLASS == "ARRAY":
x = node_x.read() * xunit_old.express(xunit)
@@ -607,7 +637,23 @@ class Plotter(Aggregator, BaseProcessor):
y = gaussian_filter1d(y, sigma=smooth)
if not run is None:
label = self._label_run(run, node_y, label, nml_key)
(base_line,) = P.plot(x, y, label=label, **kwargs)
if colors is None:
(base_line,) = P.plot(x, y, label=label, **kwargs)
else:
if nml_color is None:
color = colors[run]
elif nml_color == "time":
time = (
self.save.root._v_attrs.time * self.comp.info["unit_time"]
).express(time_unit)
color = colors(time)
else:
nml = self.comp.get_nml(nml_color, run)
try:
color = colors[nml]
except:
color = colors(nml)
(base_line,) = P.plot(x, y, label=label, color=color, **kwargs)
elif "mean" in node_y:
x = node_x.read() * xunit_old.express(xunit)
y = node_y.mean.read() * yunit_old.express(yunit)
@@ -664,8 +710,11 @@ class Plotter(Aggregator, BaseProcessor):
except:
color = colors(nml)
(base_line,) = P.plot(x, y, label=label_run, color=color, **kwargs)
if legend is None:
legend = True
P.legend()
if legend:
P.legend()
if not fit is None:
self._overlay_fit(