add astrophysix support, fix labels
This commit is contained in:
+310
-303
@@ -17,7 +17,8 @@ from scipy.stats import linregress
|
||||
from numpy.polynomial.polynomial import polyfit
|
||||
from scipy.ndimage.filters import gaussian_filter1d
|
||||
from scipy import optimize
|
||||
|
||||
from astrophysix.simdm.datafiles import Datafile, PlotType, PlotInfo
|
||||
from astrophysix.utils.file import FileType
|
||||
import matplotlib as mpl
|
||||
|
||||
if os.environ.get("DISPLAY", "") == "":
|
||||
@@ -26,12 +27,15 @@ if os.environ.get("DISPLAY", "") == "":
|
||||
import pylab as P
|
||||
from comparator import *
|
||||
import pspec_read
|
||||
import datetime
|
||||
|
||||
P.rcParams["image.cmap"] = "plasma"
|
||||
P.rcParams["savefig.dpi"] = 400
|
||||
filetype_from_ext = {ext: ft for ft in FileType for ext in ft.extension_list}
|
||||
|
||||
tex_params = {"text.latex.preamble": r"\usepackage{amsmath}"}
|
||||
P.rcParams.update(tex_params)
|
||||
|
||||
def not_array_error(err):
|
||||
epy2 = "object does not support indexing"
|
||||
epy3 = "object is not subscriptable"
|
||||
return str(err)[-len(epy2) :] == epy2 or str(err)[-len(epy3) :] == epy3
|
||||
|
||||
|
||||
class PlotRule(Rule):
|
||||
@@ -42,7 +46,7 @@ class PlotRule(Rule):
|
||||
|
||||
def plot(self, save, arg, **kwargs):
|
||||
"""
|
||||
Set the plotter's storage to 'save' and exetute the rule
|
||||
Set the plotter's storage to 'save' and execute the rule
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -53,6 +57,12 @@ class PlotRule(Rule):
|
||||
self.postproc.save = save
|
||||
return self.process_fn(arg, **kwargs)
|
||||
|
||||
def datafile(self, arg):
|
||||
return Datafile(
|
||||
name=self.description + "_" + arg,
|
||||
description=self.description + " ({})".format(arg),
|
||||
)
|
||||
|
||||
|
||||
class Plotter(Aggregator, BaseProcessor):
|
||||
"""
|
||||
@@ -99,6 +109,7 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
pp_params=None,
|
||||
selector=None,
|
||||
tag=None,
|
||||
unit_time=cst.year,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
@@ -135,7 +146,13 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
|
||||
# Get comparator object
|
||||
self.comp = Comparator(
|
||||
path, self.runs, self.nums, path_out, self.pp_params, selector=self.selector
|
||||
path,
|
||||
self.runs,
|
||||
self.nums,
|
||||
path_out,
|
||||
self.pp_params,
|
||||
unit_time=unit_time,
|
||||
selector=self.selector,
|
||||
)
|
||||
|
||||
# Get postprocesor objets for each run
|
||||
@@ -147,17 +164,56 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
# Define rules
|
||||
self.def_rules()
|
||||
|
||||
# generate astrophysix's simulations object
|
||||
self.gen_simus()
|
||||
|
||||
self.save = None
|
||||
|
||||
def gen_simus(self):
|
||||
self.simulations = {}
|
||||
simu_fmt = self.pp_params.astrophysix.simu_fmt
|
||||
descr_fmt = self.pp_params.astrophysix.descr_fmt
|
||||
tag = self.pp_params.out.tag
|
||||
for run in self.runs:
|
||||
pp = self.pp[run][self.nums[run][0]]
|
||||
nml = self.comp.namelist[run]
|
||||
name = simu_fmt.format(run=run, tag=tag, nml=nml)
|
||||
exec_time = str(datetime.datetime.fromtimestamp(os.stat(pp.path).st_ctime))
|
||||
description = descr_fmt.format(run=run, tag=tag, nml=nml)
|
||||
simu = Simulation(
|
||||
simu_code=ramses,
|
||||
name=name,
|
||||
alias=name.upper(),
|
||||
description=description,
|
||||
directory_path=pp.path,
|
||||
execution_time=exec_time,
|
||||
)
|
||||
|
||||
for param in ramses.input_parameters:
|
||||
try:
|
||||
param_setting = ParameterSetting(
|
||||
input_param=param,
|
||||
value=self.comp.get_nml(param.key, run),
|
||||
visibility=ParameterVisibility.BASIC_DISPLAY,
|
||||
)
|
||||
simu.parameter_settings.add(param_setting)
|
||||
except KeyError as e:
|
||||
self._log("key {} not found".format(e), "WARNING")
|
||||
except AttributeError as e:
|
||||
self._log("{}".format(e), "WARNING")
|
||||
|
||||
self.simulations[run] = simu
|
||||
|
||||
def _not_self_dep(self, name, dep, dep_arg, overwrite, **kwargs):
|
||||
"""
|
||||
Check if the dependency belongs to the plotter object or to another one (comp, pp, ..)
|
||||
"""
|
||||
if dep in self.comp.rules:
|
||||
done = self.comp.process(
|
||||
result = self.comp.process(
|
||||
dep, dep_arg, overwrite, overwrite_dep=self.overwrite_dep
|
||||
)
|
||||
self.just_done.extend(done)
|
||||
if result is not None:
|
||||
self.just_done.append(done)
|
||||
else:
|
||||
super(Plotter, self)._not_self_dep(name, dep, dep_arg, overwrite, **kwargs)
|
||||
|
||||
@@ -172,19 +228,13 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
)
|
||||
|
||||
def _process_rule(
|
||||
self,
|
||||
name,
|
||||
rule,
|
||||
arg,
|
||||
overwrite=False,
|
||||
ax=None,
|
||||
movie=False,
|
||||
from_cells=False,
|
||||
**kwargs,
|
||||
self, name, rule, arg, overwrite=False, ax=None, from_cells=False, **kwargs
|
||||
):
|
||||
"""
|
||||
Open storage and figure if needed before processing a rule
|
||||
"""
|
||||
|
||||
# Set full name according to argument
|
||||
if not arg is None:
|
||||
name_full = (
|
||||
name
|
||||
@@ -200,144 +250,89 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
else:
|
||||
name_full = name
|
||||
|
||||
if rule.is_valid(arg):
|
||||
if rule.kind == "classic" or rule.kind == "cells":
|
||||
if "select" in kwargs:
|
||||
select = kwargs.pop("select")
|
||||
runs, nums = self.selector.select(**select)
|
||||
elif "runs" in kwargs:
|
||||
runs = kwargs.pop("runs")
|
||||
if isinstance(runs, RunSelector):
|
||||
nums = runs.nums
|
||||
runs = runs.runs
|
||||
else:
|
||||
nums = self.nums
|
||||
else:
|
||||
runs = self.runs
|
||||
nums = self.nums
|
||||
|
||||
i = 0
|
||||
for run in runs:
|
||||
files = []
|
||||
for num in nums[run]:
|
||||
plot_filename = self._find_filename(name_full, run, num)
|
||||
|
||||
if from_cells or rule.kind == "cells":
|
||||
if not os.exists(self.pp[run][num].cells_filename):
|
||||
self.pp[run][num].load_cells()
|
||||
self.pp[run][num].unload_cells()
|
||||
save = tables.open_file(self.pp[run][num].cells_filename)
|
||||
elif rule.kind == "classic":
|
||||
save = tables.open_file(self.pp[run][num].filename)
|
||||
else:
|
||||
save = tables.open_file(self.comp.filename, "r")
|
||||
try:
|
||||
self._plot_rule(
|
||||
rule,
|
||||
save,
|
||||
arg,
|
||||
plot_filename,
|
||||
overwrite,
|
||||
ax=ax[i],
|
||||
run=run,
|
||||
**kwargs,
|
||||
)
|
||||
except TypeError as e:
|
||||
if str(e) in [
|
||||
"'LocatableAxes' object does not support indexing",
|
||||
"'AxesSubplot' object does not support indexing",
|
||||
"'AxesSubplot' object is not subscriptable",
|
||||
"'Axes' object is not subscriptable",
|
||||
"'LocatableAxes' object is not subscriptable",
|
||||
]:
|
||||
self._plot_rule(
|
||||
rule,
|
||||
save,
|
||||
arg,
|
||||
plot_filename,
|
||||
overwrite,
|
||||
ax=ax,
|
||||
run=run,
|
||||
**kwargs,
|
||||
)
|
||||
elif ax is None:
|
||||
fig = P.figure()
|
||||
self._plot_rule(
|
||||
rule,
|
||||
save,
|
||||
arg,
|
||||
plot_filename,
|
||||
overwrite,
|
||||
ax=P.gca(),
|
||||
run=run,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
raise
|
||||
finally:
|
||||
save.close()
|
||||
i = i + 1
|
||||
files.append(plot_filename)
|
||||
else:
|
||||
|
||||
if "select" in kwargs and not "runs" in kwargs:
|
||||
select = kwargs.pop("select")
|
||||
runs, nums = self.selector.select(**select)
|
||||
if not rule.kind == "runs":
|
||||
kwargs["runs"] = runs
|
||||
elif rule.kind == "runs" and "runs" in kwargs:
|
||||
runs = kwargs.pop("runs")
|
||||
else:
|
||||
runs = self.runs
|
||||
|
||||
if ax is None:
|
||||
ax = P.gca()
|
||||
if rule.kind == "series" and len(runs) == 1:
|
||||
run = self.runs[0]
|
||||
plot_filename = self._find_filename(name_full, run)
|
||||
else:
|
||||
plot_filename = self._find_filename(name_full)
|
||||
save = tables.open_file(self.comp.filename, "r")
|
||||
try:
|
||||
if rule.kind == "runs":
|
||||
for i, run in enumerate(runs):
|
||||
try:
|
||||
self._plot_rule(
|
||||
rule,
|
||||
save,
|
||||
arg,
|
||||
plot_filename,
|
||||
overwrite,
|
||||
ax=ax[i],
|
||||
run=run,
|
||||
**kwargs,
|
||||
)
|
||||
except TypeError as e:
|
||||
if str(e) in [
|
||||
"'LocatableAxes' object does not support indexing",
|
||||
"'AxesSubplot' object does not support indexing",
|
||||
"'AxesSubplot' object is not subscriptable",
|
||||
"'Axes' object is not subscriptable",
|
||||
"'LocatableAxes' object is not subscriptable",
|
||||
]:
|
||||
self._plot_rule(
|
||||
rule,
|
||||
save,
|
||||
arg,
|
||||
plot_filename,
|
||||
overwrite,
|
||||
ax=ax,
|
||||
run=run,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
self._plot_rule(
|
||||
rule, save, arg, plot_filename, overwrite, ax, **kwargs
|
||||
)
|
||||
finally:
|
||||
save.close()
|
||||
else:
|
||||
# Exit if not valid
|
||||
if not rule.is_valid(arg):
|
||||
self._log("{} is not valid in this context".format(name_full), "ERROR")
|
||||
return
|
||||
|
||||
# get filetype of the output
|
||||
filetype = filetype_from_ext[self.pp_params.out.ext]
|
||||
|
||||
# Select runs and nums
|
||||
if "select" in kwargs:
|
||||
select = kwargs.pop("select")
|
||||
runs, nums = self.selector.select(**select)
|
||||
else:
|
||||
runs = self.runs
|
||||
nums = self.nums
|
||||
|
||||
datafiles = []
|
||||
|
||||
# Several plots
|
||||
if rule.kind == "classic" or rule.kind == "cells":
|
||||
run_num = [(run, num) for run in runs for num in nums[run]]
|
||||
else:
|
||||
run_num = [(run, None) for run in runs]
|
||||
|
||||
for i, (run, num) in enumerate(run_num):
|
||||
# Find filename
|
||||
plot_filename = self._find_filename(name_full, run, num)
|
||||
|
||||
# Find ax
|
||||
try:
|
||||
real_ax = ax[i]
|
||||
except TypeError as e:
|
||||
if ax is None:
|
||||
fig, real_ax = P.subplots(1, 1)
|
||||
elif not_array_error(e):
|
||||
real_ax = ax
|
||||
else:
|
||||
raise
|
||||
|
||||
# Find plot save
|
||||
if from_cells or rule.kind == "cells":
|
||||
if not os.exists(self.pp[run][num].cells_filename):
|
||||
self.pp[run][num].load_cells()
|
||||
self.pp[run][num].unload_cells()
|
||||
save = tables.open_file(self.pp[run][num].cells_filename)
|
||||
elif rule.kind == "classic":
|
||||
save = tables.open_file(self.pp[run][num].filename)
|
||||
else:
|
||||
save = tables.open_file(self.comp.filename, "r")
|
||||
|
||||
# Call plot routine
|
||||
try:
|
||||
plot_info = self._plot_rule(
|
||||
rule,
|
||||
save,
|
||||
arg,
|
||||
plot_filename,
|
||||
overwrite,
|
||||
ax=real_ax,
|
||||
run=run,
|
||||
**kwargs,
|
||||
)
|
||||
finally:
|
||||
save.close()
|
||||
|
||||
# Save in astrophysix format
|
||||
df = rule.datafile(arg)
|
||||
df[filetype] = plot_filename
|
||||
if plot_info is not None:
|
||||
df.plot_info = plot_info
|
||||
if num is not None:
|
||||
snap = self.pp[run][num].snapshot
|
||||
|
||||
if overwrite and df.name in snap.datafiles:
|
||||
del snap.datafiles[df.name]
|
||||
elif df.name not in snap.datafiles:
|
||||
snap.datafiles.add(df)
|
||||
|
||||
if snap not in self.simulations[run].snapshots:
|
||||
self.simulations[run].snapshots.add(snap)
|
||||
|
||||
datafiles.append(df)
|
||||
return datafiles
|
||||
|
||||
def _plot_rule(self, rule, save, arg, plot_filename, overwrite, ax, **kwargs):
|
||||
"""
|
||||
@@ -345,7 +340,7 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
"""
|
||||
P.sca(ax)
|
||||
if self._needs_computation(overwrite, plot_filename):
|
||||
rule.plot(save, arg, **kwargs)
|
||||
plot_info = rule.plot(save, arg, **kwargs)
|
||||
|
||||
if not self.pp_params.out.interactive:
|
||||
P.tight_layout(pad=1)
|
||||
@@ -360,6 +355,7 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
|
||||
if not self.pp_params.out.interactive:
|
||||
P.close()
|
||||
return plot_info
|
||||
else:
|
||||
self._log("Plot {} is already done, skipping...".format(plot_filename))
|
||||
|
||||
@@ -396,7 +392,7 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
ext=self.pp_params.out.ext,
|
||||
)
|
||||
|
||||
def _label_run(self, run, node, label, nml_key):
|
||||
def _label_run(self, run, node, label, nml_key, time=None):
|
||||
"""
|
||||
Set up a label for the run from the namelist and parameters
|
||||
"""
|
||||
@@ -435,7 +431,7 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
label_run = label
|
||||
return label_run
|
||||
|
||||
def _ax_label_unit(self, node, label, unit, unit_coeff):
|
||||
def _ax_label_unit(self, node, label, unit, unit_coeff, put_units=True):
|
||||
"""
|
||||
Find appropriate labels for axis
|
||||
"""
|
||||
@@ -457,20 +453,38 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
if unit is None:
|
||||
unit = unit_old
|
||||
|
||||
if not unit_coeff == 1:
|
||||
base = unit
|
||||
unit = unit_coeff * unit
|
||||
label = label + unit_str(unit, base=base)
|
||||
else:
|
||||
label = label + unit_str(unit)
|
||||
if put_units:
|
||||
if not unit_coeff == 1:
|
||||
base = unit
|
||||
unit = unit_coeff * unit
|
||||
label = label + unit_str(unit, base=base)
|
||||
else:
|
||||
label = label + unit_str(unit)
|
||||
|
||||
return label, unit_old, unit
|
||||
|
||||
def _snapshot_title(self, run, node, title, nml_key, put_time, unit_time=cst.Myr):
|
||||
title = self._label_run(run, node, title, nml_key)
|
||||
|
||||
if put_time:
|
||||
time = self.save.root._v_attrs.time * self.comp.info["unit_time"]
|
||||
u_str = unit_str(unit_time, format="{unit}")
|
||||
time_str = self.pp_params.plot.time_fmt.format(
|
||||
time.express(unit_time), u_str
|
||||
)
|
||||
if len(title) > 0:
|
||||
title = title + " | " + time_str
|
||||
else:
|
||||
title = time_str
|
||||
return title
|
||||
|
||||
def _plot_map(
|
||||
self,
|
||||
name,
|
||||
ax_los,
|
||||
run,
|
||||
xlabel=None,
|
||||
ylabel=None,
|
||||
label=None,
|
||||
unit=None,
|
||||
unit_coeff=1.0,
|
||||
@@ -480,12 +494,14 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
put_title=True,
|
||||
nml_key=None,
|
||||
put_time=True,
|
||||
time_unit=cst.Myr,
|
||||
unit_time=cst.Myr,
|
||||
put_units=True,
|
||||
unit_space=cst.pc,
|
||||
cmap="plasma",
|
||||
norm="log",
|
||||
put_cbar=True,
|
||||
autoscale=True,
|
||||
transform=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -503,9 +519,13 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
node = self.save.get_node("/maps/{}_{}".format(name, ax_los))
|
||||
dmap = node.read()
|
||||
|
||||
label, unit_old, unit = self._ax_label_unit(node, label, unit, unit_coeff)
|
||||
label, unit_old, unit = self._ax_label_unit(
|
||||
node, label, unit, unit_coeff, put_units
|
||||
)
|
||||
|
||||
dmap = dmap * unit_old.express(unit)
|
||||
if transform is not None:
|
||||
dmap = transform(dmap)
|
||||
|
||||
if norm == "log":
|
||||
norm = mpl.colors.LogNorm()
|
||||
@@ -521,32 +541,28 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
|
||||
P.locator_params(axis="both", nbins=self.pp_params.plot.ntick)
|
||||
|
||||
P.xlabel(self._ax_title[ax_h] + unit_str(unit_space))
|
||||
P.ylabel(self._ax_title[ax_v] + unit_str(unit_space))
|
||||
if xlabel is None:
|
||||
xlabel = self._ax_title[ax_h]
|
||||
if ylabel is None:
|
||||
ylabel = self._ax_title[ax_v]
|
||||
if put_units:
|
||||
xlabel = xlabel + unit_str(unit_space)
|
||||
ylabel = ylabel + unit_str(unit_space)
|
||||
P.xlabel(xlabel)
|
||||
P.ylabel(ylabel)
|
||||
|
||||
try:
|
||||
cbar = P.colorbar(im, cax=P.gca().cax)
|
||||
except AttributeError:
|
||||
cbar = P.colorbar()
|
||||
|
||||
if put_title:
|
||||
title = self._snapshot_title(run, node, title, nml_key, put_time, unit_time)
|
||||
P.title(title)
|
||||
|
||||
if not label is None:
|
||||
cbar.set_label(label)
|
||||
|
||||
if put_title:
|
||||
title = self._label_run(run, node, title, nml_key)
|
||||
|
||||
if put_time:
|
||||
time = self.save.root._v_attrs.time * self.comp.info["unit_time"]
|
||||
time_str = self.pp_params.plot.time_fmt.format(
|
||||
time.express(time_unit), time_unit.latex
|
||||
)
|
||||
if len(title) > 0:
|
||||
title = title + " | " + time_str
|
||||
else:
|
||||
title = time_str
|
||||
|
||||
P.title(title)
|
||||
|
||||
for i, plot_overlay in enumerate(overlays):
|
||||
if plot_overlay in self.overlays:
|
||||
plot_overlay = self.overlays[plot_overlay]
|
||||
@@ -556,6 +572,23 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
except:
|
||||
plot_overlay(ax_los, im_extent)
|
||||
|
||||
return PlotInfo(
|
||||
plot_type=PlotType.IMAGE,
|
||||
xaxis_values=np.linspace(im_extent[0], im_extent[1], dmap.shape[0] + 1),
|
||||
yaxis_values=np.linspace(im_extent[2], im_extent[3], dmap.shape[1] + 1),
|
||||
values=dmap,
|
||||
xaxis_log_scale=False,
|
||||
yaxis_log_scale=False,
|
||||
values_log_scale=(norm == mpl.colors.LogNorm()),
|
||||
xaxis_label=xlabel,
|
||||
yaxis_label=ylabel,
|
||||
values_label=label,
|
||||
xaxis_unit=unit_space,
|
||||
yaxis_unit=unit_space,
|
||||
values_unit=unit,
|
||||
plot_title=title,
|
||||
)
|
||||
|
||||
def _overlay_contour(
|
||||
self,
|
||||
ax_los,
|
||||
@@ -724,7 +757,7 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
nml_key=None,
|
||||
put_title=True,
|
||||
put_time=True,
|
||||
time_unit=cst.Myr,
|
||||
unit_time=cst.Myr,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -749,26 +782,13 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
if not ylabel is None:
|
||||
P.ylabel(ylabel)
|
||||
|
||||
title = self._snapshot_title(run, node, title, nml_key, put_time, unit_time)
|
||||
if put_title:
|
||||
title = self._label_run(run, node, title, nml_key)
|
||||
|
||||
if put_time:
|
||||
time = self.save.root._v_attrs.time * self.comp.info["unit_time"]
|
||||
time_str = self.pp_params.plot.time_fmt.format(
|
||||
time.express(time_unit), time_unit.latex
|
||||
)
|
||||
if len(title) > 0:
|
||||
title = title + " | " + time_str
|
||||
else:
|
||||
title = time_str
|
||||
|
||||
P.title(title)
|
||||
|
||||
if label is None:
|
||||
if label == None:
|
||||
label = title
|
||||
P.plot(bin_centers, mean_bin, label=label, **kwargs)
|
||||
|
||||
P.plot(bin_centers, mean_bin, label=title, **kwargs)
|
||||
P.plot(bin_centers, mean_bin, label=label, **kwargs)
|
||||
|
||||
def _plot_hist(
|
||||
self,
|
||||
@@ -781,10 +801,11 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
unit_coeff=1.0,
|
||||
ytransform=None,
|
||||
label=None,
|
||||
put_title=True,
|
||||
title=None,
|
||||
nml_key=None,
|
||||
put_time=True,
|
||||
time_unit=cst.Myr,
|
||||
unit_time=cst.Myr,
|
||||
xlog=None,
|
||||
ylog=False,
|
||||
kind="bar",
|
||||
@@ -799,47 +820,41 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
"""
|
||||
Plot an histogram (PDF, etc ...)
|
||||
"""
|
||||
# Get node
|
||||
if not ax_los is None:
|
||||
name = name + "_" + ax_los
|
||||
|
||||
node = self.save.get_node(group + name)
|
||||
|
||||
if xlog is None:
|
||||
try:
|
||||
xlog = node._v_attrs_.logbins
|
||||
except:
|
||||
xlog = False
|
||||
|
||||
# get label and units
|
||||
xlabel, unit_old, unit = self._ax_label_unit(node, label, unit, unit_coeff)
|
||||
|
||||
# Read data
|
||||
if "mean" in node:
|
||||
index = node["runs"].read().index(run.encode())
|
||||
values, centers = node["mean"].read()[index]
|
||||
else:
|
||||
values, centers = node.read()
|
||||
|
||||
if xlog:
|
||||
centers = centers + np.log10(unit_old.express(unit))
|
||||
else:
|
||||
centers = centers * unit_old.express(unit)
|
||||
|
||||
if ytransform is not None:
|
||||
values = ytransform(values)
|
||||
width = centers[1] - centers[0]
|
||||
|
||||
title = self._label_run(run, node, title, nml_key)
|
||||
|
||||
if put_time:
|
||||
time = self.save.root._v_attrs.time * self.comp.info["unit_time"]
|
||||
time_str = self.pp_params.plot.time_fmt.format(
|
||||
time.express(time_unit), time_unit.latex
|
||||
)
|
||||
if len(title) > 0:
|
||||
title = title + " | " + time_str
|
||||
else:
|
||||
title = time_str
|
||||
|
||||
P.title(title)
|
||||
# Set title
|
||||
title = self._snapshot_title(run, node, title, nml_key, put_time, unit_time)
|
||||
if put_title:
|
||||
P.title(title)
|
||||
if label == None:
|
||||
label = title
|
||||
|
||||
# Set colors
|
||||
if color is None and not colors is None:
|
||||
if nml_color is None:
|
||||
color = colors[run]
|
||||
@@ -850,11 +865,8 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
except:
|
||||
color = colors(nml)
|
||||
|
||||
if label == None:
|
||||
label = title
|
||||
|
||||
# Actual plot
|
||||
if kind == "bar":
|
||||
width = centers[1] - centers[0]
|
||||
P.bar(centers, values, width, log=ylog, color=color, label=label, **kwargs)
|
||||
elif kind == "step":
|
||||
if ylog:
|
||||
@@ -863,11 +875,13 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
else:
|
||||
raise ValueError("kind must be 'bar' or 'step'")
|
||||
|
||||
# put labels
|
||||
if not label is None:
|
||||
P.xlabel(xlabel)
|
||||
if not ylabel is None:
|
||||
P.ylabel(ylabel)
|
||||
|
||||
# Also diplay fit, previously saved
|
||||
if ax_los is not None and "/hist/fit_" + name + "_" + ax_los in self.save:
|
||||
slope = node.attrs.slope
|
||||
origin = node.attrs.origin
|
||||
@@ -878,14 +892,27 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
linewidth=2,
|
||||
color="orange",
|
||||
)
|
||||
|
||||
P.ylim([None, 1.0])
|
||||
|
||||
# or a new one
|
||||
if not fit is None:
|
||||
self._overlay_fit(
|
||||
centers, values, kind=fit, ls="--", lw=1.5, label=fitlabel
|
||||
)
|
||||
|
||||
# returns PlotInfo (for Galactica)
|
||||
edges = np.append(centers - width / 2.0, centers[-1] + width / 2.0)
|
||||
return PlotInfo(
|
||||
plot_type=PlotType.HISTOGRAM,
|
||||
xaxis_values=edges,
|
||||
yaxis_values=values,
|
||||
xaxis_log_scale=False,
|
||||
yaxis_log_scale=ylog,
|
||||
xaxis_label=xlabel,
|
||||
yaxis_label=ylabel,
|
||||
xaxis_unit=unit,
|
||||
yaxis_unit=cst.none,
|
||||
plot_title=title,
|
||||
)
|
||||
|
||||
def _plot(
|
||||
self,
|
||||
name_x,
|
||||
@@ -898,19 +925,17 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
yunit=None,
|
||||
xunit_coeff=1.0,
|
||||
yunit_coeff=1.0,
|
||||
ylog=False,
|
||||
fit=None,
|
||||
fitlabel=None,
|
||||
smooth=0,
|
||||
nml_key=None,
|
||||
run=None,
|
||||
runs=None,
|
||||
yerr=None,
|
||||
yerr_kind="std",
|
||||
sigma_err=2.0,
|
||||
grid=False,
|
||||
put_time=False,
|
||||
time_unit=cst.Myr,
|
||||
unit_time=cst.Myr,
|
||||
colors=None,
|
||||
nml_color=None,
|
||||
legend=None,
|
||||
@@ -922,12 +947,15 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
Generic plot routine, with name_x and name_y two path in the hdf5 file
|
||||
"""
|
||||
|
||||
# Get proper hdf5 names
|
||||
if not node_arg is None:
|
||||
name_x, name_y = name_x + "_" + node_arg, name_y + "_" + node_arg
|
||||
|
||||
# Get hdf5 nodes
|
||||
node_x = self.save.get_node(name_x)
|
||||
node_y = self.save.get_node(name_y)
|
||||
|
||||
# If the actual data is in a,other file, fetch it
|
||||
if subname_x:
|
||||
hdf5_x = tables.open_file(node_x.read())
|
||||
node_x = hdf5_x.get_node(subname_x)
|
||||
@@ -935,6 +963,7 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
hdf5_y = tables.open_file(node_y.read())
|
||||
node_y = hdf5_y.get_node(subname_y)
|
||||
|
||||
# Find proper labels
|
||||
xlabel, xunit_old, xunit = self._ax_label_unit(
|
||||
node_x, xlabel, xunit, xunit_coeff
|
||||
)
|
||||
@@ -942,58 +971,24 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
node_y, ylabel, yunit, yunit_coeff
|
||||
)
|
||||
|
||||
P.xlabel(xlabel)
|
||||
P.ylabel(ylabel)
|
||||
|
||||
if grid:
|
||||
P.grid()
|
||||
|
||||
if ylog:
|
||||
P.yscale("log")
|
||||
|
||||
# If relevent, get time
|
||||
if put_time:
|
||||
time = self.save.root._v_attrs.time * self.comp.info["unit_time"]
|
||||
time_str = self.pp_params.plot.time_fmt.format(
|
||||
time.express(time_unit), time_unit.latex
|
||||
time.express(unit_time), unit_time.latex
|
||||
)
|
||||
if label is not None and len(label) > 0:
|
||||
label = label + " | " + time_str
|
||||
else:
|
||||
label = time_str
|
||||
|
||||
# Manage the different forms in which the data may be stored :
|
||||
# Possibilities are : plain array, dict of arrays (mean, std, ..) or dict of array (runs)
|
||||
if node_y._v_attrs.CLASS == "ARRAY":
|
||||
x = node_x.read() * xunit_old.express(xunit)
|
||||
y = node_y.read() * yunit_old.express(yunit)
|
||||
mask = np.isfinite(x) & np.isfinite(y)
|
||||
x, y = x[mask], y[mask]
|
||||
if smooth > 0:
|
||||
y = gaussian_filter1d(y, sigma=smooth)
|
||||
if not run is None:
|
||||
label = self._label_run(run, node_y, label, nml_key)
|
||||
if colors is None:
|
||||
(base_line,) = P.plot(x, y, label=label, **kwargs)
|
||||
else:
|
||||
if nml_color is None:
|
||||
color = colors[run]
|
||||
elif nml_color == "time":
|
||||
time = (
|
||||
self.save.root._v_attrs.time * self.comp.info["unit_time"]
|
||||
).express(time_unit)
|
||||
color = colors(time)
|
||||
else:
|
||||
nml = self.comp.get_nml(nml_color, run)
|
||||
try:
|
||||
color = colors[nml]
|
||||
except:
|
||||
color = colors(nml)
|
||||
if yerr is None:
|
||||
(base_line,) = P.plot(x, y, label=label, color=color, **kwargs)
|
||||
else:
|
||||
if isinstance(yerr, str):
|
||||
yerr = self.save.get_node(yerr).read()
|
||||
|
||||
base_line, _, _ = P.errorbar(x, y, yerr=yerr, label=label, **kwargs)
|
||||
|
||||
elif "mean" in node_y:
|
||||
x = node_x.read() * xunit_old.express(xunit)
|
||||
y = node_y.mean.read() * yunit_old.express(yunit)
|
||||
@@ -1014,7 +1009,6 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
else:
|
||||
yerr_min = y
|
||||
yerr_max = y
|
||||
|
||||
yerr = yerr_max - yerr_min
|
||||
mask = np.isfinite(x) & np.isfinite(y) & np.isfinite(yerr)
|
||||
x, y, yerr, yerr_min, yerr_max = (
|
||||
@@ -1024,44 +1018,51 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
yerr_min[mask],
|
||||
yerr_max[mask],
|
||||
)
|
||||
if not run is None:
|
||||
label = self._label_run(run, node_y, label, nml_key)
|
||||
else:
|
||||
x, y = node_x[run], node_y[run]
|
||||
mask = np.isfinite(x) & np.isfinite(y)
|
||||
x, y = x[mask], y[mask]
|
||||
|
||||
if yerr_kind is None:
|
||||
yerr = None
|
||||
if isinstance(yerr, str):
|
||||
yerr = self.save.get_node(yerr).read()
|
||||
|
||||
if smooth > 0:
|
||||
y = gaussian_filter1d(y, sigma=smooth)
|
||||
if not run is None:
|
||||
label = self._label_run(run, node_y, label, nml_key)
|
||||
|
||||
# Look if special colors method is used
|
||||
if colors is None:
|
||||
if yerr is None:
|
||||
(base_line,) = P.plot(x, y, label=label, **kwargs)
|
||||
else:
|
||||
base_line, _, _ = P.errorbar(x, y, yerr=yerr, label=label, **kwargs)
|
||||
else:
|
||||
if nml_color is None:
|
||||
color = colors[run]
|
||||
elif nml_color == "time":
|
||||
time = (
|
||||
self.save.root._v_attrs.time * self.comp.info["unit_time"]
|
||||
).express(unit_time)
|
||||
color = colors(time)
|
||||
else:
|
||||
nml = self.comp.get_nml(nml_color, run)
|
||||
try:
|
||||
color = colors[nml]
|
||||
except:
|
||||
color = colors(nml)
|
||||
if yerr is None:
|
||||
(base_line,) = P.plot(x, y, label=label, color=color, **kwargs)
|
||||
else:
|
||||
base_line, _, _ = P.errorbar(
|
||||
x, y, yerr=[y - yerr_min, yerr_max - y], label=label, **kwargs
|
||||
x, y, yerr=yerr, color=color, label=label, **kwargs
|
||||
)
|
||||
|
||||
else:
|
||||
if runs is None:
|
||||
runs = self.runs
|
||||
for i, run in enumerate(runs):
|
||||
x_run, y_run = node_x[run], node_y[run]
|
||||
x = x_run.read() * xunit_old.express(xunit)
|
||||
y = y_run.read() * yunit_old.express(yunit)
|
||||
mask = np.isfinite(x) & np.isfinite(y)
|
||||
x, y = x[mask], y[mask]
|
||||
if smooth > 0:
|
||||
y = gaussian_filter1d(y, sigma=smooth)
|
||||
label_run = self._label_run(run, y_run, label, nml_key)
|
||||
if colors is None:
|
||||
(base_line,) = P.plot(x, y, label=label_run, **kwargs)
|
||||
else:
|
||||
if nml_color is None:
|
||||
color = colors[i % len(colors)]
|
||||
else:
|
||||
nml = self.comp.get_nml(nml_color, run)
|
||||
try:
|
||||
color = colors[nml]
|
||||
except:
|
||||
color = colors(nml)
|
||||
(base_line,) = P.plot(x, y, label=label_run, color=color, **kwargs)
|
||||
if legend is None:
|
||||
legend = True
|
||||
|
||||
# Ax decorations
|
||||
P.xlabel(xlabel)
|
||||
P.ylabel(ylabel)
|
||||
if grid:
|
||||
P.grid()
|
||||
if legend:
|
||||
P.legend()
|
||||
|
||||
@@ -1076,6 +1077,7 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
color=base_line.get_color(),
|
||||
label=fitlabel,
|
||||
)
|
||||
|
||||
if subname_x:
|
||||
hdf5_x.close()
|
||||
if subname_y:
|
||||
@@ -1603,11 +1605,16 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
dependencies=["radial_bins", "rad_avg_" + name],
|
||||
)
|
||||
|
||||
self.rules["avg_map_" + name] = PlotRule(
|
||||
self,
|
||||
partial(self._plot_map, "avg_map_" + name),
|
||||
"Map of the radial average of {}".format(name),
|
||||
dependencies=["avg_map_" + name],
|
||||
)
|
||||
|
||||
self.rules["fluct_" + name] = PlotRule(
|
||||
self,
|
||||
partial(
|
||||
self._plot_map, "fluct_" + name, vmin=0.01, vmax=100, cmap="RdBu_r"
|
||||
),
|
||||
partial(self._plot_map, "fluct_" + name, cmap="RdBu_r"),
|
||||
"Fluctuation of {}".format(name),
|
||||
dependencies=["fluct_" + name],
|
||||
)
|
||||
@@ -1666,7 +1673,7 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
|
||||
# Radial
|
||||
generic_rule(field + "r")
|
||||
# Othoradial
|
||||
# Orthoradial
|
||||
generic_rule(field + "phi")
|
||||
# Norm
|
||||
generic_rule(field + "_norm")
|
||||
|
||||
Reference in New Issue
Block a user