add astrophysix support, fix labels

This commit is contained in:
Noe Brucy
2020-12-14 10:08:01 +01:00
parent 653e64d782
commit 7f7216abf6
6 changed files with 519 additions and 355 deletions
+6 -1
View File
@@ -17,6 +17,7 @@ class Comparator(Aggregator, HDF5Container):
pp_params=default_params(),
selector=None,
tag=None,
unit_time=cst.year,
**kwargs
):
"""
@@ -53,7 +54,11 @@ class Comparator(Aggregator, HDF5Container):
for num in self.nums[run]:
self.pp[run][num] = PostProcessor(
path_run, num, path_out=path_out_run, pp_params=self.pp_params
path_run,
num,
path_out=path_out_run,
pp_params=self.pp_params,
unit_time=unit_time,
)
run0 = self.runs[0]
+266 -259
View File
@@ -17,7 +17,8 @@ from scipy.stats import linregress
from numpy.polynomial.polynomial import polyfit
from scipy.ndimage.filters import gaussian_filter1d
from scipy import optimize
from astrophysix.simdm.datafiles import Datafile, PlotType, PlotInfo
from astrophysix.utils.file import FileType
import matplotlib as mpl
if os.environ.get("DISPLAY", "") == "":
@@ -26,12 +27,15 @@ if os.environ.get("DISPLAY", "") == "":
import pylab as P
from comparator import *
import pspec_read
import datetime
P.rcParams["image.cmap"] = "plasma"
P.rcParams["savefig.dpi"] = 400
filetype_from_ext = {ext: ft for ft in FileType for ext in ft.extension_list}
tex_params = {"text.latex.preamble": r"\usepackage{amsmath}"}
P.rcParams.update(tex_params)
def not_array_error(err):
epy2 = "object does not support indexing"
epy3 = "object is not subscriptable"
return str(err)[-len(epy2) :] == epy2 or str(err)[-len(epy3) :] == epy3
class PlotRule(Rule):
@@ -42,7 +46,7 @@ class PlotRule(Rule):
def plot(self, save, arg, **kwargs):
"""
Set the plotter's storage to 'save' and exetute the rule
Set the plotter's storage to 'save' and execute the rule
Parameters
----------
@@ -53,6 +57,12 @@ class PlotRule(Rule):
self.postproc.save = save
return self.process_fn(arg, **kwargs)
def datafile(self, arg):
return Datafile(
name=self.description + "_" + arg,
description=self.description + " ({})".format(arg),
)
class Plotter(Aggregator, BaseProcessor):
"""
@@ -99,6 +109,7 @@ class Plotter(Aggregator, BaseProcessor):
pp_params=None,
selector=None,
tag=None,
unit_time=cst.year,
**kwargs,
):
@@ -135,7 +146,13 @@ class Plotter(Aggregator, BaseProcessor):
# Get comparator object
self.comp = Comparator(
path, self.runs, self.nums, path_out, self.pp_params, selector=self.selector
path,
self.runs,
self.nums,
path_out,
self.pp_params,
unit_time=unit_time,
selector=self.selector,
)
# Get postprocesor objets for each run
@@ -147,17 +164,56 @@ class Plotter(Aggregator, BaseProcessor):
# Define rules
self.def_rules()
# generate astrophysix's simulations object
self.gen_simus()
self.save = None
def gen_simus(self):
self.simulations = {}
simu_fmt = self.pp_params.astrophysix.simu_fmt
descr_fmt = self.pp_params.astrophysix.descr_fmt
tag = self.pp_params.out.tag
for run in self.runs:
pp = self.pp[run][self.nums[run][0]]
nml = self.comp.namelist[run]
name = simu_fmt.format(run=run, tag=tag, nml=nml)
exec_time = str(datetime.datetime.fromtimestamp(os.stat(pp.path).st_ctime))
description = descr_fmt.format(run=run, tag=tag, nml=nml)
simu = Simulation(
simu_code=ramses,
name=name,
alias=name.upper(),
description=description,
directory_path=pp.path,
execution_time=exec_time,
)
for param in ramses.input_parameters:
try:
param_setting = ParameterSetting(
input_param=param,
value=self.comp.get_nml(param.key, run),
visibility=ParameterVisibility.BASIC_DISPLAY,
)
simu.parameter_settings.add(param_setting)
except KeyError as e:
self._log("key {} not found".format(e), "WARNING")
except AttributeError as e:
self._log("{}".format(e), "WARNING")
self.simulations[run] = simu
def _not_self_dep(self, name, dep, dep_arg, overwrite, **kwargs):
"""
Check if the dependency belongs to the plotter object or to another one (comp, pp, ..)
"""
if dep in self.comp.rules:
done = self.comp.process(
result = self.comp.process(
dep, dep_arg, overwrite, overwrite_dep=self.overwrite_dep
)
self.just_done.extend(done)
if result is not None:
self.just_done.append(done)
else:
super(Plotter, self)._not_self_dep(name, dep, dep_arg, overwrite, **kwargs)
@@ -172,19 +228,13 @@ class Plotter(Aggregator, BaseProcessor):
)
def _process_rule(
self,
name,
rule,
arg,
overwrite=False,
ax=None,
movie=False,
from_cells=False,
**kwargs,
self, name, rule, arg, overwrite=False, ax=None, from_cells=False, **kwargs
):
"""
Open storage and figure if needed before processing a rule
"""
# Set full name according to argument
if not arg is None:
name_full = (
name
@@ -200,28 +250,46 @@ class Plotter(Aggregator, BaseProcessor):
else:
name_full = name
if rule.is_valid(arg):
if rule.kind == "classic" or rule.kind == "cells":
# Exit if not valid
if not rule.is_valid(arg):
self._log("{} is not valid in this context".format(name_full), "ERROR")
return
# get filetype of the output
filetype = filetype_from_ext[self.pp_params.out.ext]
# Select runs and nums
if "select" in kwargs:
select = kwargs.pop("select")
runs, nums = self.selector.select(**select)
elif "runs" in kwargs:
runs = kwargs.pop("runs")
if isinstance(runs, RunSelector):
nums = runs.nums
runs = runs.runs
else:
nums = self.nums
else:
runs = self.runs
nums = self.nums
i = 0
for run in runs:
files = []
for num in nums[run]:
datafiles = []
# Several plots
if rule.kind == "classic" or rule.kind == "cells":
run_num = [(run, num) for run in runs for num in nums[run]]
else:
run_num = [(run, None) for run in runs]
for i, (run, num) in enumerate(run_num):
# Find filename
plot_filename = self._find_filename(name_full, run, num)
# Find ax
try:
real_ax = ax[i]
except TypeError as e:
if ax is None:
fig, real_ax = P.subplots(1, 1)
elif not_array_error(e):
real_ax = ax
else:
raise
# Find plot save
if from_cells or rule.kind == "cells":
if not os.exists(self.pp[run][num].cells_filename):
self.pp[run][num].load_cells()
@@ -231,113 +299,40 @@ class Plotter(Aggregator, BaseProcessor):
save = tables.open_file(self.pp[run][num].filename)
else:
save = tables.open_file(self.comp.filename, "r")
try:
self._plot_rule(
rule,
save,
arg,
plot_filename,
overwrite,
ax=ax[i],
run=run,
**kwargs,
)
except TypeError as e:
if str(e) in [
"'LocatableAxes' object does not support indexing",
"'AxesSubplot' object does not support indexing",
"'AxesSubplot' object is not subscriptable",
"'Axes' object is not subscriptable",
"'LocatableAxes' object is not subscriptable",
]:
self._plot_rule(
rule,
save,
arg,
plot_filename,
overwrite,
ax=ax,
run=run,
**kwargs,
)
elif ax is None:
fig = P.figure()
self._plot_rule(
rule,
save,
arg,
plot_filename,
overwrite,
ax=P.gca(),
run=run,
**kwargs,
)
else:
raise
finally:
save.close()
i = i + 1
files.append(plot_filename)
else:
if "select" in kwargs and not "runs" in kwargs:
select = kwargs.pop("select")
runs, nums = self.selector.select(**select)
if not rule.kind == "runs":
kwargs["runs"] = runs
elif rule.kind == "runs" and "runs" in kwargs:
runs = kwargs.pop("runs")
else:
runs = self.runs
if ax is None:
ax = P.gca()
if rule.kind == "series" and len(runs) == 1:
run = self.runs[0]
plot_filename = self._find_filename(name_full, run)
else:
plot_filename = self._find_filename(name_full)
save = tables.open_file(self.comp.filename, "r")
# Call plot routine
try:
if rule.kind == "runs":
for i, run in enumerate(runs):
try:
self._plot_rule(
plot_info = self._plot_rule(
rule,
save,
arg,
plot_filename,
overwrite,
ax=ax[i],
ax=real_ax,
run=run,
**kwargs,
)
except TypeError as e:
if str(e) in [
"'LocatableAxes' object does not support indexing",
"'AxesSubplot' object does not support indexing",
"'AxesSubplot' object is not subscriptable",
"'Axes' object is not subscriptable",
"'LocatableAxes' object is not subscriptable",
]:
self._plot_rule(
rule,
save,
arg,
plot_filename,
overwrite,
ax=ax,
run=run,
**kwargs,
)
else:
self._plot_rule(
rule, save, arg, plot_filename, overwrite, ax, **kwargs
)
finally:
save.close()
else:
self._log("{} is not valid in this context".format(name_full), "ERROR")
# Save in astrophysix format
df = rule.datafile(arg)
df[filetype] = plot_filename
if plot_info is not None:
df.plot_info = plot_info
if num is not None:
snap = self.pp[run][num].snapshot
if overwrite and df.name in snap.datafiles:
del snap.datafiles[df.name]
elif df.name not in snap.datafiles:
snap.datafiles.add(df)
if snap not in self.simulations[run].snapshots:
self.simulations[run].snapshots.add(snap)
datafiles.append(df)
return datafiles
def _plot_rule(self, rule, save, arg, plot_filename, overwrite, ax, **kwargs):
"""
@@ -345,7 +340,7 @@ class Plotter(Aggregator, BaseProcessor):
"""
P.sca(ax)
if self._needs_computation(overwrite, plot_filename):
rule.plot(save, arg, **kwargs)
plot_info = rule.plot(save, arg, **kwargs)
if not self.pp_params.out.interactive:
P.tight_layout(pad=1)
@@ -360,6 +355,7 @@ class Plotter(Aggregator, BaseProcessor):
if not self.pp_params.out.interactive:
P.close()
return plot_info
else:
self._log("Plot {} is already done, skipping...".format(plot_filename))
@@ -396,7 +392,7 @@ class Plotter(Aggregator, BaseProcessor):
ext=self.pp_params.out.ext,
)
def _label_run(self, run, node, label, nml_key):
def _label_run(self, run, node, label, nml_key, time=None):
"""
Set up a label for the run from the namelist and parameters
"""
@@ -435,7 +431,7 @@ class Plotter(Aggregator, BaseProcessor):
label_run = label
return label_run
def _ax_label_unit(self, node, label, unit, unit_coeff):
def _ax_label_unit(self, node, label, unit, unit_coeff, put_units=True):
"""
Find appropriate labels for axis
"""
@@ -457,6 +453,7 @@ class Plotter(Aggregator, BaseProcessor):
if unit is None:
unit = unit_old
if put_units:
if not unit_coeff == 1:
base = unit
unit = unit_coeff * unit
@@ -466,11 +463,28 @@ class Plotter(Aggregator, BaseProcessor):
return label, unit_old, unit
def _snapshot_title(self, run, node, title, nml_key, put_time, unit_time=cst.Myr):
title = self._label_run(run, node, title, nml_key)
if put_time:
time = self.save.root._v_attrs.time * self.comp.info["unit_time"]
u_str = unit_str(unit_time, format="{unit}")
time_str = self.pp_params.plot.time_fmt.format(
time.express(unit_time), u_str
)
if len(title) > 0:
title = title + " | " + time_str
else:
title = time_str
return title
def _plot_map(
self,
name,
ax_los,
run,
xlabel=None,
ylabel=None,
label=None,
unit=None,
unit_coeff=1.0,
@@ -480,12 +494,14 @@ class Plotter(Aggregator, BaseProcessor):
put_title=True,
nml_key=None,
put_time=True,
time_unit=cst.Myr,
unit_time=cst.Myr,
put_units=True,
unit_space=cst.pc,
cmap="plasma",
norm="log",
put_cbar=True,
autoscale=True,
transform=None,
**kwargs,
):
"""
@@ -503,9 +519,13 @@ class Plotter(Aggregator, BaseProcessor):
node = self.save.get_node("/maps/{}_{}".format(name, ax_los))
dmap = node.read()
label, unit_old, unit = self._ax_label_unit(node, label, unit, unit_coeff)
label, unit_old, unit = self._ax_label_unit(
node, label, unit, unit_coeff, put_units
)
dmap = dmap * unit_old.express(unit)
if transform is not None:
dmap = transform(dmap)
if norm == "log":
norm = mpl.colors.LogNorm()
@@ -521,32 +541,28 @@ class Plotter(Aggregator, BaseProcessor):
P.locator_params(axis="both", nbins=self.pp_params.plot.ntick)
P.xlabel(self._ax_title[ax_h] + unit_str(unit_space))
P.ylabel(self._ax_title[ax_v] + unit_str(unit_space))
if xlabel is None:
xlabel = self._ax_title[ax_h]
if ylabel is None:
ylabel = self._ax_title[ax_v]
if put_units:
xlabel = xlabel + unit_str(unit_space)
ylabel = ylabel + unit_str(unit_space)
P.xlabel(xlabel)
P.ylabel(ylabel)
try:
cbar = P.colorbar(im, cax=P.gca().cax)
except AttributeError:
cbar = P.colorbar()
if put_title:
title = self._snapshot_title(run, node, title, nml_key, put_time, unit_time)
P.title(title)
if not label is None:
cbar.set_label(label)
if put_title:
title = self._label_run(run, node, title, nml_key)
if put_time:
time = self.save.root._v_attrs.time * self.comp.info["unit_time"]
time_str = self.pp_params.plot.time_fmt.format(
time.express(time_unit), time_unit.latex
)
if len(title) > 0:
title = title + " | " + time_str
else:
title = time_str
P.title(title)
for i, plot_overlay in enumerate(overlays):
if plot_overlay in self.overlays:
plot_overlay = self.overlays[plot_overlay]
@@ -556,6 +572,23 @@ class Plotter(Aggregator, BaseProcessor):
except:
plot_overlay(ax_los, im_extent)
return PlotInfo(
plot_type=PlotType.IMAGE,
xaxis_values=np.linspace(im_extent[0], im_extent[1], dmap.shape[0] + 1),
yaxis_values=np.linspace(im_extent[2], im_extent[3], dmap.shape[1] + 1),
values=dmap,
xaxis_log_scale=False,
yaxis_log_scale=False,
values_log_scale=(norm == mpl.colors.LogNorm()),
xaxis_label=xlabel,
yaxis_label=ylabel,
values_label=label,
xaxis_unit=unit_space,
yaxis_unit=unit_space,
values_unit=unit,
plot_title=title,
)
def _overlay_contour(
self,
ax_los,
@@ -724,7 +757,7 @@ class Plotter(Aggregator, BaseProcessor):
nml_key=None,
put_title=True,
put_time=True,
time_unit=cst.Myr,
unit_time=cst.Myr,
**kwargs,
):
"""
@@ -749,26 +782,13 @@ class Plotter(Aggregator, BaseProcessor):
if not ylabel is None:
P.ylabel(ylabel)
title = self._snapshot_title(run, node, title, nml_key, put_time, unit_time)
if put_title:
title = self._label_run(run, node, title, nml_key)
if put_time:
time = self.save.root._v_attrs.time * self.comp.info["unit_time"]
time_str = self.pp_params.plot.time_fmt.format(
time.express(time_unit), time_unit.latex
)
if len(title) > 0:
title = title + " | " + time_str
else:
title = time_str
P.title(title)
if label is None:
if label == None:
label = title
P.plot(bin_centers, mean_bin, label=label, **kwargs)
P.plot(bin_centers, mean_bin, label=title, **kwargs)
P.plot(bin_centers, mean_bin, label=label, **kwargs)
def _plot_hist(
self,
@@ -781,10 +801,11 @@ class Plotter(Aggregator, BaseProcessor):
unit_coeff=1.0,
ytransform=None,
label=None,
put_title=True,
title=None,
nml_key=None,
put_time=True,
time_unit=cst.Myr,
unit_time=cst.Myr,
xlog=None,
ylog=False,
kind="bar",
@@ -799,47 +820,41 @@ class Plotter(Aggregator, BaseProcessor):
"""
Plot an histogram (PDF, etc ...)
"""
# Get node
if not ax_los is None:
name = name + "_" + ax_los
node = self.save.get_node(group + name)
if xlog is None:
try:
xlog = node._v_attrs_.logbins
except:
xlog = False
# get label and units
xlabel, unit_old, unit = self._ax_label_unit(node, label, unit, unit_coeff)
# Read data
if "mean" in node:
index = node["runs"].read().index(run.encode())
values, centers = node["mean"].read()[index]
else:
values, centers = node.read()
if xlog:
centers = centers + np.log10(unit_old.express(unit))
else:
centers = centers * unit_old.express(unit)
if ytransform is not None:
values = ytransform(values)
width = centers[1] - centers[0]
title = self._label_run(run, node, title, nml_key)
if put_time:
time = self.save.root._v_attrs.time * self.comp.info["unit_time"]
time_str = self.pp_params.plot.time_fmt.format(
time.express(time_unit), time_unit.latex
)
if len(title) > 0:
title = title + " | " + time_str
else:
title = time_str
# Set title
title = self._snapshot_title(run, node, title, nml_key, put_time, unit_time)
if put_title:
P.title(title)
if label == None:
label = title
# Set colors
if color is None and not colors is None:
if nml_color is None:
color = colors[run]
@@ -850,11 +865,8 @@ class Plotter(Aggregator, BaseProcessor):
except:
color = colors(nml)
if label == None:
label = title
# Actual plot
if kind == "bar":
width = centers[1] - centers[0]
P.bar(centers, values, width, log=ylog, color=color, label=label, **kwargs)
elif kind == "step":
if ylog:
@@ -863,11 +875,13 @@ class Plotter(Aggregator, BaseProcessor):
else:
raise ValueError("kind must be 'bar' or 'step'")
# put labels
if not label is None:
P.xlabel(xlabel)
if not ylabel is None:
P.ylabel(ylabel)
# Also diplay fit, previously saved
if ax_los is not None and "/hist/fit_" + name + "_" + ax_los in self.save:
slope = node.attrs.slope
origin = node.attrs.origin
@@ -878,14 +892,27 @@ class Plotter(Aggregator, BaseProcessor):
linewidth=2,
color="orange",
)
P.ylim([None, 1.0])
# or a new one
if not fit is None:
self._overlay_fit(
centers, values, kind=fit, ls="--", lw=1.5, label=fitlabel
)
# returns PlotInfo (for Galactica)
edges = np.append(centers - width / 2.0, centers[-1] + width / 2.0)
return PlotInfo(
plot_type=PlotType.HISTOGRAM,
xaxis_values=edges,
yaxis_values=values,
xaxis_log_scale=False,
yaxis_log_scale=ylog,
xaxis_label=xlabel,
yaxis_label=ylabel,
xaxis_unit=unit,
yaxis_unit=cst.none,
plot_title=title,
)
def _plot(
self,
name_x,
@@ -898,19 +925,17 @@ class Plotter(Aggregator, BaseProcessor):
yunit=None,
xunit_coeff=1.0,
yunit_coeff=1.0,
ylog=False,
fit=None,
fitlabel=None,
smooth=0,
nml_key=None,
run=None,
runs=None,
yerr=None,
yerr_kind="std",
sigma_err=2.0,
grid=False,
put_time=False,
time_unit=cst.Myr,
unit_time=cst.Myr,
colors=None,
nml_color=None,
legend=None,
@@ -922,12 +947,15 @@ class Plotter(Aggregator, BaseProcessor):
Generic plot routine, with name_x and name_y two path in the hdf5 file
"""
# Get proper hdf5 names
if not node_arg is None:
name_x, name_y = name_x + "_" + node_arg, name_y + "_" + node_arg
# Get hdf5 nodes
node_x = self.save.get_node(name_x)
node_y = self.save.get_node(name_y)
# If the actual data is in a,other file, fetch it
if subname_x:
hdf5_x = tables.open_file(node_x.read())
node_x = hdf5_x.get_node(subname_x)
@@ -935,6 +963,7 @@ class Plotter(Aggregator, BaseProcessor):
hdf5_y = tables.open_file(node_y.read())
node_y = hdf5_y.get_node(subname_y)
# Find proper labels
xlabel, xunit_old, xunit = self._ax_label_unit(
node_x, xlabel, xunit, xunit_coeff
)
@@ -942,58 +971,24 @@ class Plotter(Aggregator, BaseProcessor):
node_y, ylabel, yunit, yunit_coeff
)
P.xlabel(xlabel)
P.ylabel(ylabel)
if grid:
P.grid()
if ylog:
P.yscale("log")
# If relevent, get time
if put_time:
time = self.save.root._v_attrs.time * self.comp.info["unit_time"]
time_str = self.pp_params.plot.time_fmt.format(
time.express(time_unit), time_unit.latex
time.express(unit_time), unit_time.latex
)
if label is not None and len(label) > 0:
label = label + " | " + time_str
else:
label = time_str
# Manage the different forms in which the data may be stored :
# Possibilities are : plain array, dict of arrays (mean, std, ..) or dict of array (runs)
if node_y._v_attrs.CLASS == "ARRAY":
x = node_x.read() * xunit_old.express(xunit)
y = node_y.read() * yunit_old.express(yunit)
mask = np.isfinite(x) & np.isfinite(y)
x, y = x[mask], y[mask]
if smooth > 0:
y = gaussian_filter1d(y, sigma=smooth)
if not run is None:
label = self._label_run(run, node_y, label, nml_key)
if colors is None:
(base_line,) = P.plot(x, y, label=label, **kwargs)
else:
if nml_color is None:
color = colors[run]
elif nml_color == "time":
time = (
self.save.root._v_attrs.time * self.comp.info["unit_time"]
).express(time_unit)
color = colors(time)
else:
nml = self.comp.get_nml(nml_color, run)
try:
color = colors[nml]
except:
color = colors(nml)
if yerr is None:
(base_line,) = P.plot(x, y, label=label, color=color, **kwargs)
else:
if isinstance(yerr, str):
yerr = self.save.get_node(yerr).read()
base_line, _, _ = P.errorbar(x, y, yerr=yerr, label=label, **kwargs)
elif "mean" in node_y:
x = node_x.read() * xunit_old.express(xunit)
y = node_y.mean.read() * yunit_old.express(yunit)
@@ -1014,7 +1009,6 @@ class Plotter(Aggregator, BaseProcessor):
else:
yerr_min = y
yerr_max = y
yerr = yerr_max - yerr_min
mask = np.isfinite(x) & np.isfinite(y) & np.isfinite(yerr)
x, y, yerr, yerr_min, yerr_max = (
@@ -1024,44 +1018,51 @@ class Plotter(Aggregator, BaseProcessor):
yerr_min[mask],
yerr_max[mask],
)
else:
x, y = node_x[run], node_y[run]
mask = np.isfinite(x) & np.isfinite(y)
x, y = x[mask], y[mask]
if isinstance(yerr, str):
yerr = self.save.get_node(yerr).read()
if smooth > 0:
y = gaussian_filter1d(y, sigma=smooth)
if not run is None:
label = self._label_run(run, node_y, label, nml_key)
if yerr_kind is None:
yerr = None
# Look if special colors method is used
if colors is None:
if yerr is None:
(base_line,) = P.plot(x, y, label=label, **kwargs)
else:
base_line, _, _ = P.errorbar(
x, y, yerr=[y - yerr_min, yerr_max - y], label=label, **kwargs
)
else:
if runs is None:
runs = self.runs
for i, run in enumerate(runs):
x_run, y_run = node_x[run], node_y[run]
x = x_run.read() * xunit_old.express(xunit)
y = y_run.read() * yunit_old.express(yunit)
mask = np.isfinite(x) & np.isfinite(y)
x, y = x[mask], y[mask]
if smooth > 0:
y = gaussian_filter1d(y, sigma=smooth)
label_run = self._label_run(run, y_run, label, nml_key)
if colors is None:
(base_line,) = P.plot(x, y, label=label_run, **kwargs)
base_line, _, _ = P.errorbar(x, y, yerr=yerr, label=label, **kwargs)
else:
if nml_color is None:
color = colors[i % len(colors)]
color = colors[run]
elif nml_color == "time":
time = (
self.save.root._v_attrs.time * self.comp.info["unit_time"]
).express(unit_time)
color = colors(time)
else:
nml = self.comp.get_nml(nml_color, run)
try:
color = colors[nml]
except:
color = colors(nml)
(base_line,) = P.plot(x, y, label=label_run, color=color, **kwargs)
if legend is None:
legend = True
if yerr is None:
(base_line,) = P.plot(x, y, label=label, color=color, **kwargs)
else:
base_line, _, _ = P.errorbar(
x, y, yerr=yerr, color=color, label=label, **kwargs
)
# Ax decorations
P.xlabel(xlabel)
P.ylabel(ylabel)
if grid:
P.grid()
if legend:
P.legend()
@@ -1076,6 +1077,7 @@ class Plotter(Aggregator, BaseProcessor):
color=base_line.get_color(),
label=fitlabel,
)
if subname_x:
hdf5_x.close()
if subname_y:
@@ -1603,11 +1605,16 @@ class Plotter(Aggregator, BaseProcessor):
dependencies=["radial_bins", "rad_avg_" + name],
)
self.rules["avg_map_" + name] = PlotRule(
self,
partial(self._plot_map, "avg_map_" + name),
"Map of the radial average of {}".format(name),
dependencies=["avg_map_" + name],
)
self.rules["fluct_" + name] = PlotRule(
self,
partial(
self._plot_map, "fluct_" + name, vmin=0.01, vmax=100, cmap="RdBu_r"
),
partial(self._plot_map, "fluct_" + name, cmap="RdBu_r"),
"Fluctuation of {}".format(name),
dependencies=["fluct_" + name],
)
@@ -1666,7 +1673,7 @@ class Plotter(Aggregator, BaseProcessor):
# Radial
generic_rule(field + "r")
# Othoradial
# Orthoradial
generic_rule(field + "phi")
# Norm
generic_rule(field + "_norm")
+54 -17
View File
@@ -149,7 +149,15 @@ class PostProcessor(HDF5Container):
cells_loaded = False
def __init__(self, path=None, num=None, path_out=None, pp_params=None, tag=None):
def __init__(
self,
path=None,
num=None,
path_out=None,
pp_params=None,
tag=None,
unit_time=cst.year,
):
"""
Creates the basic structures needed for the outputs
"""
@@ -170,6 +178,14 @@ class PostProcessor(HDF5Container):
self.path_out + "/cells_" + tag_name + format(num, "05") + ".h5"
)
self.pspec_filename = (
self.path_out + "/pspec_" + tag_name + format(num, "05") + ".h5"
)
self.filaments_filename = (
self.path_out + "/filaments_" + tag_name + format(num, "05") + ".pickle"
)
if not os.path.exists(self.path_out):
os.makedirs(self.path_out)
@@ -213,7 +229,7 @@ class PostProcessor(HDF5Container):
if not self.pp_params.pymses.multiprocessing:
self._rt.disable_multiprocessing()
# Set the extend of the image
# Set the extent of the image
self._radius = 0.5 / self.pp_params.pymses.zoom
self.lbox = self.info["boxlen"]
@@ -273,12 +289,23 @@ class PostProcessor(HDF5Container):
self.log_id = "[{}, {}] ".format(self.run, self.num)
if os.path.exists(self.path_out + "/filaments.pickle"):
with open(self.path_out + "/filaments.pickle", "rb") as f:
if os.path.exists(self.filaments_filename):
with open(self.filaments_filename, "rb") as f:
self.fil = pickle.load(f)
else:
self.fil = None
self.snapshot = Snapshot(
name=str(self.num),
description="",
time=(
self.info["time"] * self.info["unit_time"].express(unit_time),
unit_time,
),
directory_path=self.path,
data_reference="OUTPUT_{}".format(self.num),
)
self.def_rules()
def load_cells(self):
@@ -964,12 +991,22 @@ class PostProcessor(HDF5Container):
)
bins = np.zeros(r.shape, dtype=int)
for r0 in radial_bins[1:]:
for r0 in radial_bins[1:-1]:
bins = bins + (r >= r0).astype(int)
vr_mean = mean_bin_vr[bins]
vphi_mean = mean_bin_vphi[bins]
# use linear interpolation
# v = ((r[i+1] - r)v[i] + (r - r[i])v[i + 1]) / (r[i+1] - r[i])
vr_mean = (radial_bins[bins + 1] - r) * vr_mean
vr_mean = vr_mean + (r - radial_bins[bins]) * mean_bin_vr[bins + 1]
vr_mean = vr_mean / (radial_bins[bins + 1] - radial_bins[bins])
vphi_mean = (radial_bins[bins + 1] - r) * vphi_mean
vphi_mean = vphi_mean + (r - radial_bins[bins]) * mean_bin_vphi[bins + 1]
vphi_mean = vphi_mean / (radial_bins[bins + 1] - radial_bins[bins])
vr = self.oct_getter_vr(dset)
vphi = self.oct_getter_vphi(dset)
alpha = (vphi - vphi_mean) * (vr - vr_mean)
@@ -1055,7 +1092,7 @@ class PostProcessor(HDF5Container):
return sinks_dict
def _pspec(self):
outfile = self.path_out + "/pspec.h5"
outfile = self.pspec_filename
pspec_new.pspec(repo=self.path, iouts=[self.num], outfile=outfile)
return True
@@ -1086,9 +1123,10 @@ class PostProcessor(HDF5Container):
self.fil.create_mask(
verbose=verbose,
smooth_size=1 * u.pix,
adapt_thresh=2 * u.pix,
adapt_thresh=4 * u.pix,
size_thresh=size_thresh * u.pix ** 2,
glob_thresh=glob_thresh,
fill_hole_size=0.1 * u.pix ** 2,
)
self.fil.medskel(verbose=verbose)
self.fil.analyze_skeletons(
@@ -1099,8 +1137,7 @@ class PostProcessor(HDF5Container):
self.fil.exec_rht()
self.fil.find_widths()
outfile = self.path_out + "/filaments.pickle"
with open(outfile, "wb") as f:
with open(self.filaments_filename, "wb") as f:
pickle.dump(self.fil, f, pickle.HIGHEST_PROTOCOL)
return True
@@ -1121,7 +1158,7 @@ class PostProcessor(HDF5Container):
def _filaments_forces(self):
"""
Compute forces within a filament (for disks)
Compute forces within a filament (for disks), within the slice at z=0
"""
GM = self.G * self.pp_params.disk.mass_star # Mass parameter
@@ -1134,11 +1171,11 @@ class PostProcessor(HDF5Container):
i_center, j_center = self._filaments_center()
# Get slices and projections at z = 0
vphi = self.save.get_node("/maps/slice_velphi_z").read()
gr = self.save.get_node("/maps/slice_gr_z").read()
Pz = self.save.get_node("/maps/slice_P_z").read()
coldens = self.save.get_node("/maps/coldens_z").read()
vr = self.save.get_node("/maps/slice_velr_z").read()
vphi = self.get_value("/maps/slice_velphi_z")
gr = self.get_value("/maps/slice_gr_z")
Pz = self.get_value("/maps/slice_P_z")
rho = self.get_value("/maps/slice_rho_z")
vr = self.get_value("/maps/slice_velr_z")
# Get coordinates
im_extent = np.array(self.save.root.maps._v_attrs.im_extent) * self.lbox
@@ -1167,7 +1204,7 @@ class PostProcessor(HDF5Container):
# Thermal support
GPx, GPy = np.gradient(Pz)
gradPr = (xx * GPx + yy * GPy) / rr
fP = gradPr / coldens
fP = gradPr / rho
# Gravitational field
e2 = (1.0 / 512) ** 2
@@ -1371,7 +1408,7 @@ class PostProcessor(HDF5Container):
"slice_velphi": "z",
"slice_gr": "z",
"slice_P": "z",
"coldens": "z",
"slice_rho": "z",
"slice_velr": "z",
},
),
+17 -8
View File
@@ -22,22 +22,22 @@ disk: # Disk speficic parameters
pdf: # parameters for probability density functions
nb_bin : 200 # Number of bins for the PDF
nb_bin : 100 # Number of bins for the PDF
range : [-1.5, 2.5] # Range of the PDF (log of fluctuation)
xmin_fit : 0.1 # Lower boundary of the fit (log of fluctuation)
xmin_fit : 0.3 # Lower boundary of the fit (log of fluctuation)
xmax_fit : 1.5 # Upper boundary of the fit (log of fluctuation)
fit_cut : 1e-4 # Exclude value that are < fit_cut * maximum
filaments: # parameters for FilFinder
datamap : "rho_avg"
datamap : "fluct_coldens"
verbose : False
rmin : 0.15 # In fraction of the box (zoom to be taken into account)
rmax : 0.45 # In fraction of the box (idem)
size_thresh : 200 # in pixels**2
skel_thresh : 100 # in pixels
branch_thresh : 100 # in pixels
glob_thresh : 40 # in map unit
size_thresh : 400 # in pixels**2
skel_thresh : 50 # in pixels
branch_thresh : 50 # in pixels
glob_thresh : 1.5 # in datamap unit
pymses: # Parameters for Pymses reader
@@ -80,7 +80,7 @@ out: # Parameters for post processing
# {ext} : Extension defined above
# {name} : Name of the rule
# {tag} : Tag defined above
# {nml[nml_key]} : The value of nml_key in the namelist (ex: amr_params/levelmin)
# {nml[nml_key]} : The value of nml_key in the namelist (ex: {nml[amr_params/levelmin]})
process: # General setting of the post-processor module
@@ -91,3 +91,12 @@ process: # General setting of the post-processor module
rules: # Specific rules parameters
turb_energy_threshold : -1 # Remove invalid data (<0 = no threshold)
astrophysix: # Parameters for astrophysix and galactica
simu_fmt : "{tag}_{run}" # Format of the name of simulation
descr_fmt : "{tag}_{run}" # Format of the default description
# The following keys are accepted
# {run} : Name of the relevant run
# {tag} : Tag defined above
# {nml[nml_key]} : The value of nml_key in the namelist (ex: {nml[amr_params/levelmin]})
+92
View File
@@ -0,0 +1,92 @@
# coding: utf-8
from astrophysix.simdm.protocol import (
SimulationCode,
AlgoType,
Algorithm,
InputParameter,
PhysicalProcess,
Physics,
)
# Simulation code definition #
ramses = SimulationCode(
name="Ramses 3 (MHD)",
code_name="Ramses",
code_version="3.10.1",
alias="RAMSES_3",
url="https://www.ics.uzh.ch/~teyssier/ramses/RAMSES.html",
description="Ramses MHD code",
)
# => Add algorithms : available algorithm types are :
# - AlgoType.AdaptiveMeshRefinement
# - AlgoType.VoronoiMovingMesh
# - AlgoType.SmoothParticleHydrodynamics
# - AlgoType.Godunov
# - AlgoType.PoissonMultigrid
# - AlgoType.PoissonConjugateGradient
# - AlgoType.ParticleMesh
# - AlgoType.FriendOfFriend
# - AlgoType.HLLCRiemann
# - AlgoType.RayTracer
ramses.algorithms.add(
Algorithm(algo_type=AlgoType.AdaptiveMeshRefinement, description="AMR")
)
ramses.algorithms.add(
Algorithm(algo_type=AlgoType.Godunov, description="Godunov scheme")
)
ramses.algorithms.add(
Algorithm(algo_type=AlgoType.HLLCRiemann, description="HLLC Riemann solver")
)
ramses.algorithms.add(
Algorithm(
algo_type=AlgoType.PoissonMultigrid, description="Multigrid Poisson solver"
)
)
# => Add input parameters
ramses.input_parameters.add(
InputParameter(
key="amr_params/levelmin",
name="lmin",
description="min. level of AMR refinement",
)
)
ramses.input_parameters.add(
InputParameter(
key="amr_params/levelmax",
name="lmax",
description="max. level of AMR refinement",
)
)
ramses.input_parameters.add(
InputParameter(
key="amr_params/jeans_refine",
name="jeans_refine",
description="Array, number of cells needed to resolve the Jeans lenght at each level from lmin",
)
)
ramses.input_parameters.add(
InputParameter(
key="cloud_params/beta_cool", name="beta", description="Cooling parameter"
)
)
# => Add physical processes : available physics are :
# - Physics.SelfGravity
# - Physics.Hydrodynamics
# - Physics.MHD
# - Physics.StarFormation
# - Physics.SupernovaeFeedback
# - Physics.AGNFeedback
# - Physics.MolecularCooling
ramses.physical_processes.add(
PhysicalProcess(
physics=Physics.Hydrodynamics, description="Hydrodynamical equations are solved"
)
)
ramses.physical_processes.add(
PhysicalProcess(physics=Physics.SelfGravity, description="Self-Gravity is applied.")
)
+40 -26
View File
@@ -1,14 +1,16 @@
# coding: utf-8
import pymses.utils.constants as cst
import astrophysix.units as cst
create_unit = cst.Unit.create_unit
def parse_exp_unit(u):
splitted = u.split("^")
name_u = cst.Unit.from_name(splitted[0]).latex
name_u = cst.Unit.from_name(splitted[0]).latex.replace("text", "math")
exp = ""
if len(splitted) > 1:
exp = "$^{" + str(splitted[1]) + "}$"
exp = "^{" + str(splitted[1]) + "}"
return name_u + exp
@@ -17,66 +19,78 @@ def convert_exp(number, digits=4):
splitted = "{num:.{digits}g}".format(num=number, digits=digits).split("e")
# If no need of scientific notation (low number of digits)
if len(splitted) == 1:
return "${}$".format(splitted[0])
return "{}".format(splitted[0])
else:
coeff = splitted[0]
exp = splitted[1]
exp_str = "10^{" + str(int(exp)) + "}"
if float(coeff) == 1.0:
return "$" + exp_str + "$"
return exp_str
else:
return "${}\\times {}$".format(coeff, exp_str)
return "{}\\times {}".format(coeff, exp_str)
def unit_str(unit, base=None, prefix=""):
def unit_str(unit, base=None, prefix="", format=" [{unit}]"):
"""
Format a unit in matplotlib parsable latex
unit : astrophysics.units.unit.Unit
base : astrophysics.units.unit.Unit to use as base unit (if None `unit` is used)
prefix : str to put befor the unit
format : str with the {unit} key, to put external decoration
"""
if unit == cst.none:
return ""
elif not base is None:
coeff = unit.express(base)
return unit_str(base, prefix=convert_exp(coeff) + " ")
u_str = unit_str(base, prefix=convert_exp(coeff) + " ")
elif len(unit.latex) > 0:
if ("." in unit.latex or "^" in unit.latex) and not "$" in unit.latex:
if "." in unit.latex or "^" in unit.latex:
base_str = ".".join(map(parse_exp_unit, unit.name.split(".")))
return r" [{}{}]".format(prefix, base_str)
u_str = r"${}{}$".format(prefix, base_str)
else:
return r" [{}{}]".format(prefix, unit.latex)
u_str = r"${}{}$".format(prefix, unit.latex.replace("text", "math"))
elif len(unit.name) > 0:
try:
base_str = ".".join(map(parse_exp_unit, unit.name.split(".")))
u_str = r" [{}{}]".format(prefix, base_str)
u_str = r"${}{}$".format(prefix, base_str)
except:
u_str = r" [{}{}]".format(prefix, unit.name)
return u_str
u_str = r"${}{}$".format(prefix, unit.name)
else:
base_str = ".".join(
map(parse_exp_unit, unit._decompose_base_units().split("."))
)
return r" [{}{} {}]".format(prefix, unit.coeff, base_str)
u_str = r"${}{} {}$".format(prefix, unit.coeff, base_str)
return format.format(unit=u_str)
cst.Msun_pc3 = cst.create_unit(
cst.coldens = create_unit(
"Msun.pc^-2", base_unit=cst.Msun / cst.pc ** 2, descr="Column density"
)
cst.km_s = create_unit("km.s^-1", base_unit=cst.km / cst.s, descr="Speed")
cst.Msun_pc3 = create_unit(
"Msun.pc^-3", base_unit=cst.Msun / cst.pc ** 3, descr="Density"
)
cst.Msun_pc3 = cst.create_unit(
"Msun.pc^-3", base_unit=cst.Msun / cst.pc ** 3, descr="Density"
)
cst.kg_m3 = create_unit("kg.m^-3", base_unit=cst.kg / cst.m ** 3, descr="Density")
cst.ssfr = cst.create_unit(
"Msun.yr^-1.pc^-2",
cst.ssfr = create_unit(
"Msun.year^-1.pc^-2",
base_unit=cst.Msun / cst.year / cst.pc ** 2,
descr="Surfacic SFR",
latex="M$_{\odot}$.yr$^{-1}$.pc$^{-2}$",
)
# latex='M$_{\odot}$.yr$^{-1}$.pc$^{-2}$')
cst.ssfrG = cst.create_unit(
cst.ssfrG = create_unit(
"Msun.Gyr^-1.pc^-2",
base_unit=1e-9 * cst.Msun / cst.year / cst.pc ** 2,
descr="Surfacic SFR",
latex="M$_{\odot}$.Gyr$^{-1}$.pc$^{-2}$",
latex="\mathrm{M}_{\odot}.\mathrm{Gyr}^{-1}.\mathrm{pc}^{-2}",
)
cst.uG = cst.create_unit(
"μG", base_unit=1e-10 * cst.T, descr="Micro Gauss", latex="$\mu\mathrm{G}$"
cst.uG = create_unit(
"μG", base_unit=1e-10 * cst.T, descr="Micro Gauss", latex="\\mu\\mathrm{G}"
)