Files
pipeline/plotter.py
T
2021-07-25 17:58:33 +02:00

1809 lines
58 KiB
Python

# coding: utf-8
"""
This software is a helper to use pymses tools to read and analyse RAMSES outputs.
It is a rule-based interface.
This is the plotter module.
@author Noé Brucy 2019-2021
"""
import os
from functools import partial
import matplotlib as mpl
import numpy as np
import tables
from astrophysix.simdm.datafiles import Datafile, PlotInfo, PlotType
from astrophysix.utils.file import FileType
from numpy.polynomial.polynomial import polyfit
from scipy import optimize
from scipy.ndimage.filters import gaussian_filter1d
from scipy.stats import linregress
if os.environ.get("DISPLAY", "") == "":
print("No display found. Using non-interactive Agg backend")
mpl.use("Agg")
import datetime
import matplotlib.pyplot as plt
import pspec_read
from baseprocessor import Rule, BaseProcessor
from aggregator import Aggregator
from studyprocessor import StudyProcessor
from run_selector import RunSelector
from units import U, unit_str, convert_exp
from astrophysix.simdm.experiment import (
ParameterSetting,
ParameterVisibility,
Simulation,
)
from ramses_astrophysix import ramses
# Lookup table from a file extension (as listed by astrophysix) to its
# FileType; used to tag the saved plots in the astrophysix datafiles.
filetype_from_ext = dict(
    (extension, file_type)
    for file_type in FileType
    for extension in file_type.extension_list
)
def not_array_error(err):
    """
    Tell whether an exception was raised because an object cannot be indexed.

    Parameters
    ----------
    err : exception instance (typically a TypeError raised by `obj[i]`)

    Returns
    -------
    True if the message of `err` ends with the Python 2 wording
    ("object does not support indexing") or the Python 3 wording
    ("object is not subscriptable"), False otherwise.
    """
    return str(err).endswith(
        ("object does not support indexing", "object is not subscriptable")
    )
class PlotRule(Rule):
    """
    Rule class specific to plots.

    Adds two methods to `Rule`:
    - `plot`, which runs the rule on an already opened pytables HDF5 file;
    - `datafile`, which builds the astrophysix Datafile describing the plot.
    """

    def plot(self, save, arg, **kwargs):
        """
        Set the plotter's storage to `save` and execute the rule.

        Parameters
        ----------
        save : opened pytables HDF5 file, where the data to plot is read
        arg : main argument of the plotting function
        kwargs : optional keyword arguments of the plotting function

        Returns
        -------
        Whatever the plotting function returns (in the plotter module, a
        PlotInfo or None).
        """
        self.postproc.save = save
        return self.process_fn(arg, **kwargs)

    def datafile(self, name, arg):
        """
        Build the astrophysix Datafile describing the output of this rule.

        Parameters
        ----------
        name : name of the rule
        arg : argument the rule was processed with, or None. When not None,
              it is appended to the datafile name ("{name}_{arg}") and to
              its description ("{description} ({arg})").
        """
        description = self.description
        if arg is not None:
            name = name + "_" + str(arg)
            # Only mention the argument when there is one, instead of
            # writing " (None)" into the description
            description = description + " ({})".format(arg)
        return Datafile(name=name, description=description)
class Plotter(Aggregator, BaseProcessor):
"""
This class loads derived quantities and plot them
"""
solve_self_dep = False
# Axes information
_ax_nb = {"x": 0, "y": 1, "z": 2} # Number of each axes
_axes_h = {"x": "y", "y": "x", "z": "x"} # Associated horizontal axe
_axes_v = {"x": "z", "y": "z", "z": "y"} # Associated vertical axe
_ax_title = {"x": r"$x$", "y": r"$y$", "z": r"$z$"}
G = 1.0 # Gravitational constant
# Conversion table from namelist keys (from amses config file) into LaTex strings
label_convert = {
"turb_rms": "$f_{rms}$",
"beta": "$\\beta$",
"beta_cool": "$\\beta$",
"dens0": "$n_0$",
"coldens0": "$\\Sigma_0$",
"sfr_avg_window": "window",
"bx_bound": "$B_0$",
"levelmax": "$l_{\\max}$",
"levelmin": "$l_{\\min}$",
"comp_frac": "$1 - \\zeta$",
}
# Conversion table from namelist values (from amses config file) into LaTex strings
value_convert = {
"sfr_avg_window": lambda x: "${:g}$ Myr".format(80 * x),
"bx_bound": lambda x: "${:g}$ $\\mu G$".format(5.267501272979475 * x),
}
def __init__(
    self,
    path,
    in_runs=None,
    in_nums=None,
    path_out=None,
    params=None,
    selector=None,
    tag=None,
    unit_time=U.year,
    **kwargs,
):
    """
    Build a Plotter. Runs and outputs are picked by a RunSelector.

    Parameters
    ----------
    path : path to the main folder of the simulations (e.g. '~/simus/myproject')
    in_runs : list of the runs to consider (e.g. ['run1', 'run2'])
    in_nums : list or dict of the output numbers to consider (e.g. [3, 5]
              or {'run1': [3, 5], 'run2': [4, 6]})
    path_out : path where the plots are saved. By default set to `path`.
    params : postprocessing parameters (see the params module)
    selector : existing RunSelector instance. When given, `in_runs`,
               `in_nums` and `kwargs` are ignored.
    tag : string added to the output and data file names
    unit_time : time unit handed to the StudyProcessor
    kwargs : keyword arguments for RunSelector
    """
    super(Plotter, self).__init__(path, path_out, params, tag)
    # Build a selector only when none was supplied
    self.selector = (
        selector
        if selector is not None
        else RunSelector(
            path, in_runs, in_nums, self.params.input.nml_filename, **kwargs
        )
    )
    # Keep the selection at hand
    self.path = path
    self.runs, self.nums = self.selector.runs, self.selector.nums
    # Study-level processor; its `snaps` holds the per-snapshot processors
    self.study = StudyProcessor(
        path,
        self.runs,
        self.nums,
        path_out,
        self.params,
        unit_time=unit_time,
        selector=self.selector,
    )
    self.snaps = self.study.snaps
    # Prefix of the log messages of this plotter
    self.log_id = "[plot {}] ".format(self.params.out.tag)
    self.def_rules()
    # One astrophysix Simulation object per run
    self.gen_simus()
    self.save = None
def gen_simus(self):
    """
    Build one astrophysix `Simulation` per selected run, stored in
    `self.simulations` (keyed by run name).

    Name and description come from the `simu_fmt` / `descr_fmt` format
    strings of the astrophysix parameters (fields: run, tag, nml). Every
    RAMSES input parameter that `self.study.get_nml` finds for the run is
    attached as a ParameterSetting; missing keys are logged as warnings.
    """
    self.simulations = {}
    simu_fmt = self.params.astrophysix.simu_fmt
    descr_fmt = self.params.astrophysix.descr_fmt
    tag = self.params.out.tag

    def make_setting(param, value):
        # Single place where the parameter settings are built
        return ParameterSetting(
            input_param=param,
            value=value,
            visibility=ParameterVisibility.BASIC_DISPLAY,
        )

    for run in self.runs:
        # The first selected output stands for the whole run
        pp = self.snaps[run][self.nums[run][0]]
        nml = self.study.namelist[run]
        name = simu_fmt.format(run=run, tag=tag, nml=nml)
        # NOTE(review): on Unix st_ctime is the last metadata change time,
        # not a creation time; it is used as the best available estimate.
        # Seconds precision: the microseconds are dropped.
        exec_time = datetime.datetime.fromtimestamp(
            os.stat(pp.path).st_ctime
        ).isoformat(sep=" ", timespec="seconds")
        description = descr_fmt.format(run=run, tag=tag, nml=nml)
        simu = Simulation(
            simu_code=ramses,
            name=name,
            alias=name.upper(),
            description=description,
            directory_path=pp.path,
            execution_time=exec_time,
        )
        for param in ramses.input_parameters:
            try:
                value = self.study.get_nml(param.key, run)
            except KeyError as e:
                self._log("key {} not found".format(e), "WARNING")
                continue
            if value is None:
                continue
            try:
                simu.parameter_settings.add(make_setting(param, value))
            except AttributeError:
                # Some values are rejected with an AttributeError: store
                # their string form instead
                simu.parameter_settings.add(make_setting(param, str(value)))
        self.simulations[run] = simu
def _not_self_dep(self, name, dep, dep_arg, overwrite, select):
    """
    Resolve a dependency that is not one of this plotter's own rules.

    A dependency known to the study processor is processed there, and any
    non-None result is appended to `self.just_done`. Any other dependency
    is delegated to the parent class implementation.
    """
    if dep not in self.study.rules:
        super(Plotter, self)._not_self_dep(name, dep, dep_arg, overwrite, select)
        return
    outcome = self.study.process(
        dep, dep_arg, overwrite, self.overwrite_dep, select
    )
    if outcome is not None:
        self.just_done.append(outcome)
def _needs_computation(self, overwrite, plot_filename):
"""
Returns true if the plot needs to be redone
"""
return (
self.params.out.interactive
or overwrite
or not os.path.exists(plot_filename)
)
def _process_rule(
    self,
    name,
    rule,
    arg,
    overwrite=False,
    select=None,
    ax=None,
    from_cells=False,
    **kwargs,
):
    """
    Open the data storage and the figure(s) a rule needs, then plot it.

    One plot is made per (run, output) for "snapshot"/"cells" rules, a single
    one for "comp" rules, and one per run otherwise. When `ax` is given, all
    plots go to the provided axes (`ax[i]` if `ax` is indexable, else `ax`
    itself) and the figure is saved/closed once, after the last plot.

    Parameters
    ----------
    name : rule name
    rule : the PlotRule to process
    arg : main argument of the rule (also used in the file name)
    overwrite : redo the plot even if its file already exists
    select : optional keyword arguments for `self.selector.select`
    ax : optional matplotlib axes (or sequence of axes) to draw on
    from_cells : read the data from the cells file of the snapshot
    kwargs : forwarded to the plotting function

    Returns
    -------
    list of astrophysix Datafile objects, one per plot.
    """
    # Full name: rule name plus the argument, without the characters that
    # are awkward in file names ("," becomes "_")
    if arg is not None:
        name_full = name + "_" + str(arg).translate(str.maketrans(",", "_", " []'/"))
    else:
        name_full = name
    # File type of the output, for the astrophysix datafiles
    filetype = filetype_from_ext[self.params.out.ext]
    # Select runs and nums
    if select is not None:
        runs, nums = self.selector.select(**select)
    else:
        runs, nums = self.runs, self.nums
    # One plot per output, per run, or a single one for comparisons
    if rule.kind in ("snapshot", "cells"):
        run_num = [(run, num) for run in runs for num in nums[run]]
    elif rule.kind == "comp":
        run_num = [(None, None)]
    else:
        run_num = [(run, None) for run in runs]
    # If axes are provided, only save/close once
    onefigure = ax is not None
    plot_filename = self._find_filename(name_full)
    datafiles = []
    for i, (run, num) in enumerate(run_num):
        if not onefigure:
            plot_filename = self._find_filename(name_full, run, num)
        # Find the axes: a new figure when none were given, ax[i] for a
        # sequence of axes, ax itself for a single (non-indexable) axes
        try:
            real_ax = ax[i]
        except TypeError as e:
            if ax is None:
                _, real_ax = plt.subplots(1, 1)
            elif not_array_error(e):
                real_ax = ax
            else:
                raise
        # Open the storage holding the data to plot
        if from_cells or rule.kind == "cells":
            # self.snaps holds the per-snapshot postprocessors
            pp = self.snaps[run][num]
            if not os.path.exists(pp.cells_filename):
                # NOTE(review): presumably loading then unloading the cells
                # writes the cells file — confirm in the postprocessor
                pp.load_cells()
                pp.unload_cells()
            save = tables.open_file(pp.cells_filename)
        elif rule.kind == "snapshot":
            save = tables.open_file(self.snaps[run][num].filename)
        else:
            save = tables.open_file(self.study.filename, "r")
        # Call plot routine
        try:
            close = (not onefigure) or (i == len(run_num) - 1)
            plot_info = self._plot_rule(
                rule,
                save,
                arg,
                plot_filename,
                overwrite,
                ax=real_ax,
                close=close,
                run=run,
                **kwargs,
            )
        finally:
            save.close()
        # Save in astrophysix format
        df = rule.datafile(name, arg)
        df[filetype] = plot_filename
        if plot_info is not None:
            df.plot_info = plot_info
        if num is not None:
            snap = self.snaps[run][num].snapshot
            # On overwrite, replace the previous datafile by the new one
            # (it used to be deleted without the new one being added)
            if overwrite and df.name in snap.datafiles:
                del snap.datafiles[df.name]
            if df.name not in snap.datafiles:
                snap.datafiles.add(df)
            if snap not in self.simulations[run].snapshots:
                self.simulations[run].snapshots.add(snap)
        datafiles.append(df)
    return datafiles
def _plot_rule(
    self, rule, save, arg, plot_filename, overwrite, ax, close=True, **kwargs
):
    """
    Draw a rule on `ax`, its dependencies being already satisfied.

    Nothing is drawn (and None is returned) when `_needs_computation` finds
    the plot up to date. Otherwise the rule is plotted. Outside interactive
    mode and when `close` is set, the layout is tightened before saving and
    the figure is closed at the end. The figure is written to
    `plot_filename` when `params.out.save` is set.

    Returns
    -------
    What the rule's plot function returned, or None when skipped.
    """
    plt.sca(ax)
    if not self._needs_computation(overwrite, plot_filename):
        self._log("Plot {} is already done, skipping...".format(plot_filename))
        return None
    plot_info = rule.plot(save, arg, **kwargs)
    finalize = not self.params.out.interactive and close
    if finalize:
        plt.tight_layout(pad=1)
    if self.params.out.save:
        plt.savefig(plot_filename)
        message = "{} plotted".format(plot_filename)
    else:
        message = "{} plotted".format(os.path.basename(plot_filename))
    self._log(message, "SUCCESS")
    if finalize:
        plt.close()
    return plot_info
def _find_filename(self, name_full, run=None, num=None, fmt=None):
"""
Determine a filename based on rule name, run, output and parameters
"""
tag_name = self.params.out.tag
if fmt is None and self.params.out.fmt == "":
if not self.params.out.tag == "":
tag_name = "_" + tag_name
if run is not None and num is not None:
fmt = "{out}/{run}/{name}{tag}_{run}_{num:05}{ext}"
elif run is not None:
fmt = "{out}/{run}/{name}{tag}_{run}{ext}"
else:
fmt = "{out}/{name}{tag}{ext}"
elif fmt is None:
fmt = self.params.out.fmt
nml = None
if run is not None:
nml = self.study.namelist[run]
return fmt.format(
run=run,
name=name_full,
tag=tag_name,
num=num,
nml=nml,
out=self.path_out,
ext=self.params.out.ext,
)
def get_label_run(self, run, label=None, nml_key=None, time=None):
    """
    Build a label for a run from the namelist and parameters.

    - Without `label` nor `nml_key`: first line of the run's label file
      (`{path}/{run}/{params.input.label_filename}`), or the run name when
      that file does not exist.
    - With `nml_key` (a namelist key or a list of keys): "key = value"
      pairs joined by ", ", put in parentheses after `label` when given.
    - With only `label`: `label` itself.

    `time` is currently unused.
    """

    def get_label_nml(nml_key):
        # "label = value" for one namelist key, using the LaTeX conversion
        # tables of the class when they know the key
        prop_name = os.path.basename(nml_key)
        prop_label = self.label_convert.get(prop_name, prop_name)
        prop_value = self.study.get_nml(nml_key, run)
        if prop_name in self.value_convert:
            prop_value_str = self.value_convert[prop_name](prop_value)
        elif type(prop_value) in [int, float]:
            prop_value_str = convert_exp(prop_value, digits=5)
        else:
            prop_value_str = str(prop_value)
        return r"{} = {}".format(prop_label, prop_value_str)

    def get_label_file(run):
        label_filename = f"{self.path}/{run}/{self.params.input.label_filename}"
        if not os.path.exists(label_filename):
            return run
        with open(label_filename, "r") as label_file:
            # Strip only the line terminator, so that a file without a
            # trailing newline keeps its last character
            return label_file.readline().rstrip("\r\n")

    if nml_key is None and label is None:
        label_run = get_label_file(run)
    elif nml_key is not None:
        if not isinstance(nml_key, list):
            nml_key = [nml_key]
        label_run = ", ".join(map(get_label_nml, nml_key))
        if label is not None:
            label_run = label + " (" + label_run + ")"
    else:
        label_run = label
    return label_run
def _ax_label_unit(self, node, label, unit, unit_coeff, put_units=True):
    """
    Find the label and units to use for the data stored in `node`.

    Parameters
    ----------
    node : pytables node holding the data
    label : label to use. When None, it is taken from the node's `label`
            attribute, else from `label_convert[node name]`, else from the
            node title, else from the node name.
    unit : unit to express the data in (default: the node's `unit`
           attribute, or U.none)
    unit_coeff : scale factor applied to `unit`
    put_units : append the unit to the label

    Returns
    -------
    (label, unit_old, unit): the label, the unit the data is stored in, and
    the (scaled) target unit, so that the data is `raw * unit_old.express(unit)`.
    """
    if label is None:
        if "label" in node._v_attrs:
            label = node._v_attrs.label
        elif node._v_name in self.label_convert:
            label = self.label_convert[node._v_name]
        elif not node._v_title == "":
            label = node._v_title
        else:
            label = node._v_name
    if "unit" in node._v_attrs:
        unit_old = node._v_attrs.unit
    else:
        unit_old = U.none
    if unit is None:
        unit = unit_old
    base = unit
    if unit_coeff != 1:
        # The scaling applies to the data whether or not the unit is shown
        unit = unit_coeff * unit
    if put_units:
        if unit_coeff != 1:
            label = label + unit_str(unit, base=base)
        else:
            label = label + unit_str(unit)
    return label, unit_old, unit
def snapshot_title(self, run, title, nml_key, put_time, unit_time=U.Myr):
    """
    Build the title of a snapshot plot.

    The title is the run label (see `get_label_run`). When `put_time` is set,
    the snapshot time (in `unit_time`, formatted with
    `params.plot.time_fmt`) is appended after " | ", or used alone when the
    run label is empty or None.
    """
    title = self.get_label_run(run, title, nml_key)
    if put_time:
        time = self.save.root._v_attrs.time * self.study.info["unit_time"]
        u_str = unit_str(unit_time, format="{unit}")
        time_str = self.params.plot.time_fmt.format(time.express(unit_time), u_str)
        # `title` may be None (e.g. no run and no label): avoid len(None)
        title = title + " | " + time_str if title else time_str
    return title
def _plot_map(
    self,
    name,
    ax_los,
    run,
    xlabel=None,
    ylabel=None,
    label=None,
    unit=None,
    unit_coeff=1.0,
    overlays=None,
    overlays_kwargs=None,
    title=None,
    put_title=True,
    nml_key=None,
    put_time=True,
    unit_time=U.Myr,
    put_units=True,
    unit_space=U.pc,
    center_space=False,
    cmap="plasma",
    norm="log",
    put_cbar=True,
    transform=None,
    vmin=None,
    vmax=None,
    **kwargs,
):
    """
    Plot the map `/maps/{name}_{ax_los}` of the current storage.

    Parameters
    ----------
    name : name of the map, without the line-of-sight suffix
    ax_los : line of sight, "x", "y" or "z"
    run : run the map belongs to (used in the title)
    xlabel, ylabel : axis labels (default: name of the map axes)
    label : colorbar label (default: from the node, see `_ax_label_unit`)
    unit, unit_coeff : unit of the map values (see `_ax_label_unit`)
    overlays : list of overlays drawn on the map, each a key of
               `self.overlays` or a callable `(ax_los, im_extent, **kw)`
    overlays_kwargs : list of keyword-argument dicts matched by position
                      with `overlays`; missing entries mean no arguments
    title, put_title, nml_key, put_time, unit_time : title settings
                      (see `snapshot_title`)
    put_units : add the units to the labels
    unit_space : unit of the map axes
    center_space : put the map center at the origin
    cmap : colormap
    norm : "log", "linear" or a matplotlib Normalize instance
    put_cbar : draw a colorbar
    transform : optional function applied to the map values
    vmin, vmax : color scale limits (default: min/max of the map)
    kwargs : forwarded to `plt.imshow`

    Returns
    -------
    PlotInfo describing the image (for Galactica).
    """
    # None defaults avoid shared mutable default lists
    if overlays is None:
        overlays = []
    if overlays_kwargs is None:
        overlays_kwargs = []
    ax_h = self._axes_h[ax_los]
    ax_v = self._axes_v[ax_los]
    im_extent = np.array(self.save.root.maps._v_attrs.im_extent)
    unit_length = self.save.root._v_attrs["unit_length"]
    if center_space:
        center = self.save.root.maps._v_attrs.center
        im_extent[:2] = im_extent[:2] - center[self._ax_nb[ax_h]]
        im_extent[2:] = im_extent[2:] - center[self._ax_nb[ax_v]]
    im_extent = im_extent * unit_length.express(unit_space)
    node = self.save.get_node("/maps/{}_{}".format(name, ax_los))
    dmap = node.read()
    label, unit_old, unit = self._ax_label_unit(
        node, label, unit, unit_coeff, put_units
    )
    dmap = dmap * unit_old.express(unit)
    if transform is not None:
        dmap = transform(dmap)
    if vmin is None:
        vmin = np.min(dmap)
    if vmax is None:
        vmax = np.max(dmap)
    if norm == "log":
        norm = mpl.colors.LogNorm(vmin=vmin, vmax=vmax)
    elif norm == "linear":
        norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
    im = plt.imshow(
        dmap, extent=im_extent, origin="lower", norm=norm, cmap=cmap, **kwargs
    )
    plt.locator_params(axis="both", nbins=self.params.plot.ntick)
    if xlabel is None:
        xlabel = self._ax_title[ax_h]
    if ylabel is None:
        ylabel = self._ax_title[ax_v]
    if put_units:
        xlabel = xlabel + unit_str(unit_space)
        ylabel = ylabel + unit_str(unit_space)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    cbar = None
    if put_cbar:
        try:
            # Use the dedicated colorbar axes when the current axes have a
            # `cax` attribute (e.g. axes_grid1 axes)
            cbar = plt.colorbar(im, cax=plt.gca().cax)
        except AttributeError:
            cbar = plt.colorbar()
    if put_title:
        title = self.snapshot_title(run, title, nml_key, put_time, unit_time)
        plt.title(title)
    if cbar is not None and label is not None:
        cbar.set_label(label)
    for i, plot_overlay in enumerate(overlays):
        if plot_overlay in self.overlays:
            if plot_overlay == "particles":
                plot_overlay = partial(
                    self.overlays[plot_overlay],
                    unit_space=unit_space,
                    center_space=center_space,
                )
            else:
                plot_overlay = self.overlays[plot_overlay]
        # Look the keyword arguments up before the call, so an IndexError
        # raised inside the overlay is not mistaken for a missing entry
        overlay_kwargs = overlays_kwargs[i] if i < len(overlays_kwargs) else {}
        plot_overlay(ax_los, im_extent, **overlay_kwargs)
    # imshow puts rows along the vertical axis and columns along the
    # horizontal one: x edges follow shape[1], y edges follow shape[0]
    return PlotInfo(
        plot_type=PlotType.IMAGE,
        xaxis_values=np.linspace(im_extent[0], im_extent[1], dmap.shape[1] + 1),
        yaxis_values=np.linspace(im_extent[2], im_extent[3], dmap.shape[0] + 1),
        values=dmap,
        xaxis_log_scale=False,
        yaxis_log_scale=False,
        values_log_scale=isinstance(norm, mpl.colors.LogNorm),
        xaxis_label=xlabel,
        yaxis_label=ylabel,
        values_label=label,
        xaxis_unit=unit_space,
        yaxis_unit=unit_space,
        values_unit=unit,
        plot_title=title,
    )
def _overlay_contour(
    self,
    ax_los,
    im_extent,
    map_name,
    log=False,
    lvl_array=None,
    lw=None,
    lvl_th=None,
    lvl_max_lbl=np.inf,
    lvl_offset=0,
    lbl_fmt="%g",
    **kwargs,
):
    """
    Add an overlay: labelled contours of another map of the storage.

    Parameters
    ----------
    ax_los : line of sight; the map read is `/maps/{map_name}_{ax_los}`
    im_extent : image extent, as given to imshow
    map_name : name of the map to draw contours of
    log : draw contours of log10 of the map
    lvl_array : contour levels (default: steps of 1 from the minimum to the
                maximum finite value of the map)
    lw : line widths. When None, computed: 2 everywhere, or, if `lvl_th`
         is set, 2 ** (lvl_th - level) for levels >= lvl_th (thinner for
         higher levels) and 1 below.
    lvl_th : threshold level for the computed line widths
    lvl_max_lbl : only levels (after offset) below this value are labelled
    lvl_offset : value added to the levels before labelling
    lbl_fmt : format of the contour labels
    kwargs : forwarded to plt.contour
    """
    map_contour = self.save.get_node("/maps/{}_{}".format(map_name, ax_los)).read()
    if log:
        map_contour = np.log10(map_contour)
    # Only finite values are used to compute the default levels
    mask_fin = np.isfinite(map_contour)
    if lvl_array is None:
        lvl_array = np.arange(
            np.min(map_contour[mask_fin]), np.max(map_contour[mask_fin]) + 1
        )
    # Computing linewidths (assumes lvl_array is a numpy array here)
    if lw is None:
        lw = np.ones(lvl_array.size) * 2
        if lvl_th:
            lw[lvl_array >= lvl_th] = lw[lvl_array >= lvl_th] ** (
                lvl_th - lvl_array[lvl_array >= lvl_th]
            )
            lw[lvl_array < lvl_th] = 1.0
    cont = plt.contour(
        map_contour,
        extent=im_extent,
        origin="lower",
        linewidths=lw,
        levels=lvl_array,
        **kwargs,
    )
    # used levels, shifted by lvl_offset so the labels show shifted values.
    # NOTE(review): reassigning ContourSet.levels after creation relies on
    # matplotlib internals — check with recent matplotlib versions
    lvls = np.array(cont.levels) + lvl_offset
    cont.levels = lvls
    plt.clabel(
        cont,
        lvls[np.array(lvls) < lvl_max_lbl],
        inline=1,
        fontsize=8.0,
        fmt=lbl_fmt,
    )
def _overlay_levels(self, ax_los, im_extent, **kwargs):
    """
    Add an overlay: contours of the AMR levels.

    Wrapper around `_overlay_contour` on the "levels" map: labels are
    shifted by one, use an integer format and are only shown below 11;
    line widths shrink above level 8. Extra keyword arguments are passed
    through (repeating one of the fixed options raises a TypeError).
    """
    level_options = {
        "lbl_fmt": "%1d",
        "lvl_offset": 1,
        "lvl_th": 8,
        "lvl_max_lbl": 11,
    }
    return self._overlay_contour(
        ax_los, im_extent, "levels", **level_options, **kwargs
    )
def _overlay_particles(
    self, ax_los, im_extent, unit_space=U.pc, center_space=False, **kwargs
):
    """
    Add an overlay: scatter plot of the particle positions.

    The positions (`/data/pos`) are read from the HDF5 file whose name is
    stored in `/hdf5/particles`. They are projected on the map plane,
    optionally taken relative to the map center, converted to `unit_space`
    and restricted to the image extent.

    Parameters
    ----------
    ax_los : line of sight, "x", "y" or "z"
    im_extent : image extent [h_min, h_max, v_min, v_max] in `unit_space`
    unit_space : unit of the map axes
    center_space : positions relative to the map center
    kwargs : forwarded to plt.scatter
    """
    # Open the particle HDF5 file; the context manager closes it even if
    # the read fails
    filename = self.save.get_node("/hdf5/particles").read()[0].decode()
    with tables.open_file(filename, "r") as hdf5_parts:
        part_pos = hdf5_parts.get_node("/data/pos").read()
    unit_length = self.save.root._v_attrs["unit_length"]
    # index of the horizontal and vertical axes
    ih = self._ax_nb[self._axes_h[ax_los]]
    iv = self._ax_nb[self._axes_v[ax_los]]
    # horizontal and vertical coordinates (views on part_pos, updated in place)
    part_h = part_pos[:, ih]
    part_v = part_pos[:, iv]
    if center_space:
        center = self.save.root.maps._v_attrs.center
        part_h -= center[ih]
        part_v -= center[iv]
    part_h *= unit_length.express(unit_space)
    part_v *= unit_length.express(unit_space)
    # Keep only the particles inside the image
    mask = (
        (im_extent[0] <= part_h)
        & (part_h <= im_extent[1])
        & (im_extent[2] <= part_v)
        & (part_v <= im_extent[3])
    )
    plt.scatter(part_h[mask], part_v[mask], **kwargs)
def _overlay_speed(
    self, ax_los, im_extent, unit=U.km_s, unit_coeff=1.0, key_v=None, **kwargs
):
    """
    Add an overlay: in-plane velocity vector field.

    Reads the maps `/maps/speed_h_{ax_los}` and `/maps/speed_v_{ax_los}`,
    keeps one cell every `params.plot.vel_red` in each direction, converts
    them to `unit` (scaled by `unit_coeff`) and draws them with plt.quiver.
    A reference arrow of length `key_v` (default: mean of the smallest and
    largest displayed norms) is added with plt.quiverkey at figure
    coordinates (0.6, 0.98).

    Parameters
    ----------
    ax_los : line of sight, "x", "y" or "z"
    im_extent : image extent [h_min, h_max, v_min, v_max], as given to imshow
    unit, unit_coeff : unit of the velocities (see `_ax_label_unit`)
    key_v : length of the reference arrow, in `unit`
    kwargs : forwarded to plt.quiver
    """
    dmap_vh_node = self.save.get_node("/maps/speed_h_{}".format(ax_los))
    dmap_vh = dmap_vh_node.read()
    dmap_vv = self.save.get_node("/maps/speed_v_{}".format(ax_los)).read()
    # Empty base label: `label` only holds the unit string, for the key
    label, unit_old, unit = self._ax_label_unit(dmap_vh_node, "", unit, unit_coeff)
    vel_red = self.params.plot.vel_red
    # take only a subset of velocities
    map_vh_red = dmap_vh[::vel_red, ::vel_red] * unit_old.express(unit)
    map_vv_red = dmap_vv[::vel_red, ::vel_red] * unit_old.express(unit)
    # get norm information
    norm_v = np.sqrt(map_vh_red ** 2 + map_vv_red ** 2)
    max_v = np.max(norm_v)
    min_v = np.min(norm_v)
    # Number of selected vectors.
    # NOTE(review): shape[0] (rows) is used as the horizontal count and
    # shape[1] as the vertical one, while imshow puts rows along the
    # vertical axis; both agree only for square maps — confirm orientation
    nh = map_vh_red.shape[0]
    nv = map_vh_red.shape[1]
    # Creates vectors position grid
    size_h = im_extent[1] - im_extent[0]  # size of the map
    dh = size_h / dmap_vh.shape[0]  # size of cell
    seph = size_h / nh  # separation between vectors
    # NOTE(review): the first vector sits one full cell (dh) from the edge
    # rather than at the center of its cell — confirm the offset is intended
    h = im_extent[0] + dh + np.arange(nh) * seph
    size_v = im_extent[3] - im_extent[2]
    dv = size_v / dmap_vh.shape[1]
    sepv = size_v / nv
    v = im_extent[2] + dv + np.arange(nv) * sepv
    hh, vv = np.meshgrid(h, v)
    # plot vector field
    vec_field = plt.quiver(hh, vv, map_vh_red, map_vv_red, units="width", **kwargs)
    # add vector key
    if key_v is None:
        key_v = (max_v + min_v) / 2.0
    plt.quiverkey(
        vec_field,
        0.6,
        0.98,
        key_v,
        r"${:g}$".format(key_v) + label,
        labelpos="E",
        coordinates="figure",
    )
def _overlay_B(self, ax_los, im_extent, **kwargs):
    """
    Add an overlay: magnetic field streamlines.

    Streamlines are drawn from the in-plane components stored in
    `/maps/B_h_{ax_los}` and `/maps/B_v_{ax_los}`, keeping one cell every
    `params.plot.vel_red`. Their positions are the centers of the kept cells
    spread over `im_extent`, so they share the axes (units, centering) of
    the underlying map for any line of sight.

    Parameters
    ----------
    ax_los : line of sight, "x", "y" or "z"
    im_extent : image extent [h_min, h_max, v_min, v_max], as given to imshow
    kwargs : forwarded to plt.streamplot
    """
    dmap_Bh = self.save.get_node("/maps/B_h_{}".format(ax_los)).read()
    dmap_Bv = self.save.get_node("/maps/B_v_{}".format(ax_los)).read()
    vel_red = self.params.plot.vel_red
    # take only a subset of the field
    map_Bh_red = dmap_Bh[::vel_red, ::vel_red]
    map_Bv_red = dmap_Bv[::vel_red, ::vel_red]
    # As for the imshow-ed map, rows run along the vertical axis and
    # columns along the horizontal one
    nv, nh = map_Bh_red.shape
    # Cell centers of nh (resp. nv) equal cells spanning the image; built
    # from im_extent rather than from the box center/radius, which ignored
    # the line of sight and the plotting units/centering
    vec_h = im_extent[0] + (np.arange(nh) + 0.5) * (im_extent[1] - im_extent[0]) / nh
    vec_v = im_extent[2] + (np.arange(nv) + 0.5) * (im_extent[3] - im_extent[2]) / nv
    hh, vv = np.meshgrid(vec_h, vec_v)
    plt.streamplot(hh, vv, map_Bh_red, map_Bv_red, **kwargs)
def _plot_hist(
    self,
    name,
    ax_los=None,
    run=None,
    group="/hist/",
    xlabel=None,
    unit=None,
    unit_coeff=1.0,
    ytransform=None,
    label=None,
    put_title=True,
    title=None,
    nml_key=None,
    put_time=True,
    unit_time=U.Myr,
    xlog=None,
    ylog=False,
    kind="bar",
    ylabel="$\\mathcal{P}$",
    color=None,
    colors=None,
    nml_color=None,
    fit=None,
    fitlabel=None,
    **kwargs,
):
    """
    Plot a histogram (PDF, etc.) stored at `{group}{name}[_{ax_los}]`.

    The node is either an array (values, centers) or a group holding a
    "mean" array per run, selected through its "runs" array.

    Parameters
    ----------
    name : name of the histogram node
    ax_los : optional line of sight, appended as "_{ax_los}" to `name`
    run : run to plot (also used in the title and colors)
    group : HDF5 group of the histogram
    xlabel, unit, unit_coeff : x axis label and unit (see `_ax_label_unit`)
    ytransform : optional function applied to the histogram values
    label : legend label (default: the title)
    put_title, title, nml_key, put_time, unit_time : title settings
        (see `snapshot_title`)
    xlog : bins are log10 values (default: node attribute `logbins`)
    ylog : logarithmic y axis
    kind : "bar" or "step"
    ylabel : y axis label
    color, colors, nml_color : color, or `colors` indexed by run, by the
        value of namelist key `nml_color`, or called with that value
    fit, fitlabel : optional fit overlay (see `_overlay_fit`)
    kwargs : forwarded to plt.bar / plt.step

    Returns
    -------
    PlotInfo describing the histogram (for Galactica).

    Raises
    ------
    ValueError if `kind` is neither "bar" nor "step".
    """
    # Get node
    if ax_los is not None:
        name = name + "_" + ax_los
    node = self.save.get_node(group + name)
    if xlog is None:
        try:
            xlog = node._v_attrs.logbins
        except AttributeError:
            xlog = False
    # Axis label and units; the legend label is handled separately below
    xlabel, unit_old, unit = self._ax_label_unit(node, xlabel, unit, unit_coeff)
    # Read data
    if "mean" in node:
        index = node["runs"].read().index(run.encode())
        values, centers = node["mean"].read()[index]
    else:
        values, centers = node.read()
    if xlog:
        centers = centers + np.log10(unit_old.express(unit))
    else:
        centers = centers * unit_old.express(unit)
    if ytransform is not None:
        values = ytransform(values)
    width = centers[1] - centers[0]
    # Set title
    title = self.snapshot_title(run, title, nml_key, put_time, unit_time)
    if put_title:
        plt.title(title)
    if label is None:
        label = title
    # Set colors
    if color is None and colors is not None:
        if nml_color is None:
            color = colors[run]
        else:
            nml = self.study.get_nml(nml_color, run)
            try:
                color = colors[nml]
            except TypeError:
                color = colors(nml)
    # Actual plot
    if kind == "bar":
        plt.bar(
            centers, values, width, log=ylog, color=color, label=label, **kwargs
        )
    elif kind == "step":
        if ylog:
            plt.yscale("log")
        plt.step(centers, values, where="mid", color=color, label=label, **kwargs)
    else:
        raise ValueError("kind must be 'bar' or 'step'")
    # put labels
    if xlabel is not None:
        plt.xlabel(xlabel)
    if ylabel is not None:
        plt.ylabel(ylabel)
    # Also display a fit previously saved; `name` already carries the
    # line-of-sight suffix. _v_attrs works for both arrays and groups.
    if ax_los is not None and "/hist/fit_" + name in self.save:
        slope = node._v_attrs.slope
        origin = node._v_attrs.origin
        plt.plot(
            centers,
            10 ** (slope * centers + origin),
            "--",
            linewidth=2,
            color="orange",
        )
    # or a new one
    if fit is not None:
        self._overlay_fit(
            centers, values, kind=fit, ls="--", lw=1.5, label=fitlabel
        )
    # returns PlotInfo (for Galactica)
    edges = np.append(centers - width / 2.0, centers[-1] + width / 2.0)
    return PlotInfo(
        plot_type=PlotType.HISTOGRAM,
        xaxis_values=edges,
        yaxis_values=values,
        xaxis_log_scale=False,
        yaxis_log_scale=ylog,
        xaxis_label=xlabel,
        yaxis_label=ylabel,
        xaxis_unit=unit,
        yaxis_unit=U.none,
        plot_title=title,
    )
def _plot(
    self,
    name_x,
    name_y,
    node_arg=None,
    xlabel=None,
    ylabel=None,
    label=None,
    xunit=None,
    yunit=None,
    xunit_coeff=1.0,
    yunit_coeff=1.0,
    xtransform=None,
    ytransform=None,
    xscale="linear",
    yscale="linear",
    fit=None,
    fitlabel=None,
    smooth=0,
    nml_key=None,
    run=None,
    yerr=None,
    yerr_kind="std",
    sigma_err=2.0,
    grid=False,
    put_time=False,
    unit_time=U.Myr,
    colors=None,
    nml_color=None,
    legend=None,
    subname_x=None,
    subname_y=None,
    **kwargs,
):
    """
    Generic plot routine: plot `name_y` against `name_x`, two paths in the
    current HDF5 storage.

    The y node is either a plain array, a group of statistics ("mean" with
    "std", "min"/"max", "q025"/"q975" or "q16"/"q84"), or a group with one
    array per run (selected with `run`). Non-finite points are dropped.

    Parameters
    ----------
    name_x, name_y : HDF5 paths of the x and y data
    node_arg : suffix "_{node_arg}" appended to both paths
    xlabel, ylabel, xunit, yunit, xunit_coeff, yunit_coeff : axis labels and
        units (see `_ax_label_unit`)
    label : legend label, completed with the run label and, with
        `put_time`, the snapshot time
    xtransform, ytransform : optional functions applied to x and y
    xscale, yscale : matplotlib axis scales
    fit, fitlabel : optional fit overlay (see `_overlay_fit`)
    smooth : sigma of a gaussian smoothing of y (0: no smoothing)
    nml_key : namelist key(s) shown in the run label
    run : run to plot
    yerr : error bars, as an array or an HDF5 path; for a statistics group
        they are computed according to `yerr_kind` ("std" with
        `sigma_err`, "min_max", "95per" or "68per")
    grid, legend : draw a grid / a legend
    put_time, unit_time : add the snapshot time to the label
    colors, nml_color : `colors` indexed by run, by the value of namelist
        key `nml_color`, or called with it (or with the time when
        `nml_color == "time"`)
    subname_x, subname_y : when set, the x (resp. y) node holds the name of
        another HDF5 file, and the data is read at this path inside it
    kwargs : forwarded to plt.plot / plt.errorbar
    """
    # Get proper hdf5 names
    if node_arg is not None:
        name_x, name_y = name_x + "_" + node_arg, name_y + "_" + node_arg
    # Get hdf5 nodes
    node_x = self.save.get_node(name_x)
    node_y = self.save.get_node(name_y)
    # Extra HDF5 files opened below; always closed in the `finally` clause
    opened = []
    try:
        # If the actual data is in another file, fetch it
        if subname_x:
            hdf5_x = tables.open_file(node_x.read())
            opened.append(hdf5_x)
            node_x = hdf5_x.get_node(subname_x)
        if subname_y:
            hdf5_y = tables.open_file(node_y.read())
            opened.append(hdf5_y)
            node_y = hdf5_y.get_node(subname_y)
        # Find proper labels
        xlabel, xunit_old, xunit = self._ax_label_unit(
            node_x, xlabel, xunit, xunit_coeff
        )
        ylabel, yunit_old, yunit = self._ax_label_unit(
            node_y, ylabel, yunit, yunit_coeff
        )
        # If relevant, get time
        if put_time:
            time = self.save.root._v_attrs.time * self.study.info["unit_time"]
            time_str = self.params.plot.time_fmt.format(
                time.express(unit_time), unit_time.latex.replace("text", "math")
            )
            time_str = f"${time_str}$"
            if label is not None and len(label) > 0:
                label = label + " | " + time_str
            else:
                label = time_str
        # Manage the different forms in which the data may be stored:
        # plain array, group of statistics (mean, std, ...) or group of
        # arrays (one per run)
        if node_y._v_attrs.CLASS == "ARRAY":
            x = node_x.read() * xunit_old.express(xunit)
            y = node_y.read() * yunit_old.express(yunit)
            mask = np.isfinite(x) & np.isfinite(y)
            x, y = x[mask], y[mask]
        elif "mean" in node_y:
            x = node_x.read() * xunit_old.express(xunit)
            yconv = yunit_old.express(yunit)
            y = node_y.mean.read() * yconv
            if yerr_kind == "std":
                std = node_y.std.read() * yconv
                yerr_min = y - sigma_err * std
                yerr_max = y + sigma_err * std
            elif yerr_kind == "min_max":
                yerr_min = node_y.min.read() * yconv
                yerr_max = node_y.max.read() * yconv
            elif yerr_kind == "95per":
                yerr_min = node_y.q025.read() * yconv
                yerr_max = node_y.q975.read() * yconv
            elif yerr_kind == "68per":
                yerr_min = node_y.q16.read() * yconv
                yerr_max = node_y.q84.read() * yconv
            else:
                yerr_min = y
                yerr_max = y
            # NOTE(review): the full width max - min is passed to errorbar,
            # which draws it as a symmetric +/- error around y
            yerr = yerr_max - yerr_min
            mask = np.isfinite(x) & np.isfinite(y) & np.isfinite(yerr)
            x, y, yerr = x[mask], y[mask], yerr[mask]
        else:
            x = node_x[run].read() * xunit_old.express(xunit)
            y = node_y[run].read() * yunit_old.express(yunit)
            mask = np.isfinite(x) & np.isfinite(y)
            x, y = x[mask], y[mask]
        if isinstance(yerr, str):
            yerr = self.save.get_node(yerr).read()
        # Apply transformations on x
        if xtransform is not None:
            x = xtransform(x)
        # Apply transformations on y
        if ytransform is not None:
            y = ytransform(y)
            if yerr is not None:
                self._log(
                    "Errorbar may be meaning less when ytransform is used",
                    "WARNING",
                )
        if smooth > 0:
            y = gaussian_filter1d(y, sigma=smooth)
        if run is not None:
            label = self.get_label_run(run, label, nml_key)
        # Look if special colors method is used
        if colors is None:
            color_kwargs = {}
        else:
            if nml_color is None:
                color = colors[run]
            elif nml_color == "time":
                time = (
                    self.save.root._v_attrs.time * self.study.info["unit_time"]
                ).express(unit_time)
                color = colors(time)
            else:
                nml = self.study.get_nml(nml_color, run)
                try:
                    color = colors[nml]
                except TypeError:
                    color = colors(nml)
            color_kwargs = {"color": color}
        if yerr is None:
            (base_line,) = plt.plot(x, y, label=label, **color_kwargs, **kwargs)
        else:
            base_line, _, _ = plt.errorbar(
                x, y, yerr=yerr, label=label, **color_kwargs, **kwargs
            )
        # Ax decorations
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        if grid:
            plt.grid()
        if legend:
            plt.legend()
        # Ax scale
        plt.xscale(xscale)
        plt.yscale(yscale)
        if fit is not None:
            self._overlay_fit(
                x,
                y,
                yerr,
                kind=fit,
                ls="--",
                lw=1.5,
                color=base_line.get_color(),
                label=fitlabel,
            )
    finally:
        for hdf5_file in opened:
            hdf5_file.close()
def _pspec(self, name, **kwargs):
    """
    Plot a power spectrum (wrapper around `pspec_read.pspec_{name}`).

    The power-spectrum file is the one referenced by `/hdf5/pspec`, and the
    output number is the `num` attribute of the storage root. The `run`
    keyword (added by the rule machinery) is dropped if present; the other
    keyword arguments are forwarded.
    """
    # pop with a default: no KeyError when the caller did not pass `run`
    kwargs.pop("run", None)
    file_pspec = self.save.get_node("/hdf5/pspec").read()
    num = self.save.root._v_attrs.num
    getattr(pspec_read, "pspec_" + name)(file_pspec, ".", num, **kwargs)
def _overlay_fit(self, x, y, yerr=None, kind="linear", label=None, **kwargs):
    """
    Add an overlay: least-squares fit of y(x), linear or power law.

    Parameters
    ----------
    x, y : data to fit
    yerr : optional errors on y. When given and not all zero they weight
           the fit; otherwise an unweighted linear regression is used.
    kind : "linear" (y = a x + b) or "power_law" (y = 10**b * x**a, fitted
           as a straight line in log10-log10 space). Other values draw
           nothing.
    label : legend label of the fit (default: describes the fitted slope)
    kwargs : forwarded to plt.plot
    """
    if kind == "linear":
        if yerr is None or np.sum(np.abs(yerr)) == 0:
            a, b, r_value, _p_value, stderr = linregress(x, y)
            # linregress returns the correlation coefficient r, not R^2
            r_squared = r_value ** 2
            self._log(
                "Linear fit y = {} x + {} with R^2 = {} and error is {}".format(
                    a, b, r_squared, stderr
                )
            )
            if label is None:
                label = r"Linear fit with slope ${:.3g}$ and $R^2 = {:.3f}$".format(
                    a, r_squared
                )
        else:
            # polyfit weights are 1 / sigma; coefficients go from low to high
            # degree: intercept first, then slope
            coefs, (residuals, _rank, _sv, _rcond) = polyfit(
                x, y, 1, w=1.0 / np.asarray(yerr, dtype=float), full=True
            )
            b, a = coefs[0], coefs[1]
            # residuals is empty for an exact (or rank-deficient) fit
            residual = residuals[0] if len(residuals) else 0.0
            self._log(
                "Linear fit y = {} x + {} with residual {}".format(a, b, residual)
            )
            if label is None:
                label = r"Linear fit with slope ${:.3g}$".format(a)
        plt.plot(x, a * x + b, label=label, **kwargs)
    elif kind == "power_law":
        if yerr is None or np.sum(np.abs(yerr)) == 0:
            a, b, r_value, _p_value, stderr = linregress(np.log10(x), np.log10(y))
            self._log(
                "Power law fit y = x^({}) * {} with R^2 = {} and error is {}".format(
                    a, 10 ** b, r_value ** 2, stderr
                )
            )
        else:

            def fitfunc(p, x):
                return p[0] + p[1] * x

            def errfunc(p, x, y, err):
                return (y - fitfunc(p, x)) / err

            # Relative errors: proportional to the errors on log10(y), the
            # constant ln(10) factor does not change the optimum
            pinit = [1.0, -1.0]
            out = optimize.leastsq(
                errfunc,
                pinit,
                args=(np.log10(x), np.log10(y), yerr / y),
                full_output=1,
            )
            c = out[0]
            b, a = c[0], c[1]
            residual = errfunc(c, np.log10(x), np.log10(y), yerr / y)
            self._log(
                "Power law fit y = x^({}) * {} with residual {}".format(
                    a, 10 ** b, residual
                )
            )
        if label is None:
            label = r"Power-law fit with index {:.1f}".format(a)
        plt.plot(x, (10 ** b) * x ** a, label=label, **kwargs)
def overlay_kennicutt(self, n0, step):
    """
    Add an overlay: grey dashed guide lines for a Kennicutt-like star
    formation rate at density `n0`.

    The lines have slope `2.5e-9 * n0**1.4 * 1e6 * (unit_length in pc)**2`
    and are shifted along the time axis every `step`. The current axis
    limits are kept, except that the x range is extended by 20.
    NOTE(review): the 2.5e-9 normalisation and the 1e6 factor (presumably
    Myr to yr) are assumptions of the original author — confirm the units.

    Parameters
    ----------
    n0 : density entering the n0**1.4 law
    step : spacing of the guide lines along the x axis
    """
    plt.grid(False)
    ylim = plt.ylim()
    (tmin, tmax) = plt.xlim()
    tmax = tmax + 20
    ymax = ylim[1]
    ssfr_sun = 2.5e-9
    ssfr_ken = ssfr_sun * n0 ** 1.4
    coeff = ssfr_ken * 1e6 * (self.study.info["unit_length"].express(U.pc)) ** 2
    # Loop-invariant time samples, computed once
    t = np.linspace(0, tmax, 1000)
    for i in np.arange(tmin, max(tmax, tmin + ymax / coeff), step):
        plt.plot(t + i, t * coeff, ls="--", lw=0.9, color="grey")
        plt.plot(t + tmin, (t + i - tmin) * coeff, ls="--", lw=0.9, color="grey")
    plt.xlim(tmin, tmax)
    plt.ylim(ylim)
def _gen_from_log(self, logrule, name_y, name_x="time", description="Generated"):
    """
    Register a "run" plot rule drawing the series `name_y` against `name_x`,
    both read under `/series/{logrule}/`.

    The rule is named `name_y` when plotted against time and
    `{name_y}_{name_x}` otherwise. It depends on the rule `logrule`.
    """
    rule_name = name_y if name_x == "time" else name_y + "_" + name_x
    series_path = "/series/" + logrule + "/"
    plot_fn = partial(self._plot, series_path + name_x, series_path + name_y)
    self.rules[rule_name] = PlotRule(
        self,
        plot_fn,
        description=description,
        kind="run",
        dependencies=[logrule],
    )
def def_rules(self):
"""
This is where rules are defined
"""
self.rules = {
"plot": PlotRule(
self, lambda arg, **kwargs: self._plot(*arg, **kwargs), kind="comp"
),
"plot_snapshot": PlotRule(
self, lambda arg, **kwargs: self._plot(*arg, **kwargs)
),
"plot_map": PlotRule(
self, lambda mapname, **kwargs: self._plot_map(mapname, **kwargs)
),
"coldens": PlotRule(
self,
partial(
self._plot_map,
"coldens",
label=r"$\Sigma$",
# unit=U.coldens
),
"Column density map",
dependencies=["coldens"],
),
"slice_T": PlotRule(
self,
partial(
self._plot_map,
"T",
label=r"$T$",
),
"Slice of temperature",
dependencies=["T"],
),
"alpha_disk": PlotRule(
self,
partial(self._plot_map, "alpha_disk", label=r"$\alpha$"),
"Map of the Shakura&Sunaev alpha parameter for disks",
dependencies=["alpha_disk"],
),
"alpha_grav": PlotRule(
self,
partial(self._plot_map, "alpha_grav", label=r"$\alpha_g$"),
"Map of the grav Shakura&Sunaev alpha parameter for disks",
dependencies=["alpha_grav"],
),
"coldens_l": PlotRule(
self,
partial(
self._plot_map,
"coldens",
label=r"$\Sigma$",
unit=U.coldens,
overlays=[self._overlay_levels],
),
"Column density with level overlay",
dependencies=["coldens", "levels"],
),
"slice_rho_v": PlotRule(
self,
partial(
self._plot_map,
"slice_rho",
label=r"$\rho$",
unit=U.Msun_pc3,
overlays=[self._overlay_speed],
),
"Density slice with speed overlay",
dependencies=["slice_rho", "speed_h", "speed_v"],
),
"slice_rho_B": PlotRule(
self,
partial(
self._plot_map,
"slice_rho",
label=r"$\rho$",
unit=U.Msun_pc3,
overlays=[self._overlay_B],
),
"Density slice with magnetic field overlay",
dependencies=["slice_rho", "B_h", "B_v"],
),
"slice_rho_B_vel": PlotRule(
self,
partial(
self._plot_map,
"slice_rho",
label=r"$\rho$",
unit=U.Msun_pc3,
overlays=[self._overlay_B, self._overlay_speed],
),
"Density slice with magnetic field and velocity overlay",
dependencies=["slice_rho", "B_h", "B_v", "speed_h", "speed_v"],
),
"jeans_ratio": PlotRule(
self,
partial(
self._plot_map,
"jeans_ratio",
vmin=0.1,
vmax=100,
cmap="RdBu_r",
overlays=[self._overlay_levels],
),
"Jeans' lenght divided by the max resolution",
dependencies=["jeans_ratio", "levels"],
),
"Q": PlotRule(
self,
partial(
self._plot_map,
"Q",
label=r"$Q$",
vmin=0.01,
vmax=100,
cmap="RdBu_r",
),
"Toomre Q parameter for a Keplerian disk",
dependencies=["Q"],
),
"rho_pdf": PlotRule(
self,
partial(self._plot_hist, "rho_pdf"),
"$\rho$-PDF",
dependencies=["rho_pdf"],
),
"rho_pdf_mw": PlotRule(
self,
partial(self._plot_hist, "rho_pdf_mw"),
"Mass weighted $\rho$-PDF",
dependencies=["rho_pdf_mw"],
),
"cos_pdf": PlotRule(
self,
partial(self._plot_hist, "cos_pdf"),
"cos-PDF",
dependencies=["cos_pdf", "mwa_speed"],
),
"avg_coldens_pdf": PlotRule(
self,
partial(
self._plot_hist,
"avg_time_coldens_pdf_z",
group="/comp/",
xlog=True,
put_time=False,
),
"Column density PDF, averaged in time",
kind="runs",
dependencies={"avg_time_coldens_pdf": "z"},
),
"T_pdf": PlotRule(
self,
partial(self._plot_hist, "T_pdf"),
"T-PDF on a 2D slice",
dependencies=["T_pdf"],
),
"P_pdf": PlotRule(
self,
partial(self._plot_hist, "P_pdf"),
"P-PDF on a 2D slice ",
dependencies=["P_pdf"],
),
"B_int": PlotRule(
self,
partial(
self._plot_map, "B_int", label=r"$\mid \mathrm{B} \mid$", unit=U.T
),
"Magnetic intensity map",
dependencies=["B_int"],
),
"Brho": PlotRule(
self,
partial(
self._plot,
"/datasets/Brho/rho",
"/datasets/Brho/B",
label=r"$\mathrm{B} $",
put_time=True,
),
"Brho on a 2D slice ",
dependencies=["Brho"],
),
"Ek_Eb_rho": PlotRule(
self,
partial(
self._plot,
"/datasets/Ek_Eb_rho/rho",
"/datasets/Ek_Eb_rho/Ek_Eb_rho",
label=r"Ek/Eb",
put_time=True,
),
"Ek/Eb on a 2D slice ",
dependencies=["Ek_Eb_rho", "mwa_speed"],
),
"rho_prof": PlotRule(
self,
partial(self._plot, "/profile/axis", "/profile/rho_prof"),
"Density profile",
dependencies=["axis", "rho_prof"],
),
"pspec": PlotRule(self, self._pspec, dependencies={"pspec": None}),
"sbeta": PlotRule(
self,
partial(
self._plot,
"/comp/nml_cloud_params/beta_cool",
"/comp/avg_time_pdf_slope_coldens",
),
"Slope of the Sigma-PDF against cooling beta factor",
kind="comp",
dependencies={
"nml": "cloud_params/beta_cool",
"avg_time_pdf_slope_coldens": None,
},
),
"sbeta_onavg": PlotRule(
self,
partial(
self._plot,
"/comp/sbeta_onavg/beta",
"/comp/sbeta_onavg/slope",
yerr="/comp/sbeta_onavg/stderr",
),
"Slope of the time averaged Sigma-PDF against cooling beta factor",
kind="comp",
dependencies=["sbeta_onavg"],
),
"sink_mass": PlotRule(
self,
partial(
self._plot,
"/series/sinks_from_log/time",
"/series/sinks_from_log/mass_sink",
xunit=U.Myr,
yunit=U.Msun,
),
"Mass of the sinks as a function of time",
kind="run",
dependencies=["sinks_from_log"],
),
"ssm": PlotRule(
self,
partial(
self._plot,
"/series/sinks_from_log/time",
"/series/sinks_from_log/ssm",
xunit=U.Myr,
yunit=U.Msun / U.pc ** 2,
),
"Mass of the sinks as a function of time divided by surface",
kind="run",
dependencies=["ssm"],
),
"assfr": PlotRule(
self,
partial(
self._plot,
"/series/sfr_from_log/time",
"/series/sfr_from_log/sfr",
ylabel="Averaged surfacic SFR",
xunit=U.Myr,
yunit=U.ssfr,
),
kind="run",
dependencies=["sfr_from_log"],
),
"issfr": PlotRule(
self,
partial(
self._plot,
"/series/sinks_from_log/time",
"/series/sinks_from_log/issfr",
ylabel="Surfacic SFR",
xunit=U.Myr,
yunit=U.ssfr,
),
kind="run",
dependencies=["issfr"],
),
"turb_rms": PlotRule(
self,
partial(
self._plot,
"/series/rms_from_log/time",
"/series/rms_from_log/turb_rms",
xunit=U.Myr,
),
"Turbulent RMS",
kind="run",
dependencies=["rms_from_log"],
),
"turb_energy": PlotRule(
self,
partial(
self._plot,
"/series/rms_from_log/time",
"/series/rms_from_log/turb_energy",
xunit=U.Myr,
),
"Turbulent energy",
kind="run",
dependencies=["rms_from_log"],
),
"turb_power": PlotRule(
self,
partial(
self._plot,
"/series/rms_from_log/time",
"/series/rms_from_log/turb_power",
xunit=U.Myr,
),
"Turbulent power",
kind="run",
dependencies=["turb_power"],
),
"sigma": PlotRule(
self,
partial(
self._plot,
"/series/time",
"/series/time_sigma",
ylabel="$\\sigma$",
xunit=U.Myr,
yunit=U.km_s,
),
"Velocity dispersion",
kind="run",
dependencies=["time_sigma"],
),
"mwa_B_int": PlotRule(
self,
partial(
self._plot,
"/series/time",
"/series/time_mwa_B_int",
xunit=U.Myr,
yunit=U.uG,
),
"Magnetic intensity average",
kind="run",
dependencies=["time_mwa_B_int"],
),
"mass": PlotRule(
self,
partial(
self._plot,
"/series/time",
"/series/time_mass",
xunit=U.Myr,
yunit=U.Msun,
),
"Total mass in the box",
kind="run",
dependencies=["time_mass"],
),
"max_fluct_coldens": PlotRule(
self,
partial(
self._plot,
"/series/time",
"/series/time_max_fluct_coldens_z",
ylabel="$\\max(\\Sigma/\\overline{\\Sigma})$",
xunit=U.Myr,
),
"Maximal fluctuation of the column density against time",
kind="run",
dependencies={"time_max_fluct_coldens": "z"},
),
}
averageables = [
"coldens",
"Q",
"T",
"T_mwavg",
"alpha_disk",
"alpha_grav",
]
# Generic rules directly from Ramses fields
for field in self.params.pymses.variables:
def generic_rule(name):
self.rules["slice_" + name] = PlotRule(
self,
partial(self._plot_map, "slice_" + name),
"{} slice".format(name),
dependencies=["slice_" + name],
)
self.rules[name + "_mwavg"] = PlotRule(
self,
partial(self._plot_map, name + "_mwavg"),
"Ax mass-weighted averaged {}".format(name),
dependencies=[name + "_mwavg"],
)
self.rules[name + "_avg"] = PlotRule(
self,
partial(self._plot_map, name + "_avg"),
"Ax averaged {}".format(name),
dependencies=[name + "_avg"],
)
averageables.append("slice_" + name)
averageables.append(name + "_mwavg")
averageables.append(name + "_avg")
# special for vectors
if field in ["g", "vel"]:
# Components
for i, dir in enumerate(["x", "y", "z"]):
generic_rule(field + dir)
# Radial
generic_rule(field + "r")
# Orthoradial
generic_rule(field + "phi")
# Norm
generic_rule(field + "_norm")
else:
generic_rule(field)
for name in averageables:
self.rules["rad_" + name] = PlotRule(
self,
partial(
self._plot,
"/radial/radial_centers",
"/radial/rad_mwavg_" + name,
),
"Azimuthal mass weighted average of {}".format(name),
dependencies=["radial_centers", "rad_mwavg_" + name],
)
self.rules["dispersion_rad_" + name] = PlotRule(
self,
partial(
self._plot,
"/radial/radial_centers",
"/radial/dispersion_rad_" + name,
),
"Radial dispersion of {}".format(name),
dependencies=["radial_centers", "dispersion_rad_" + name],
)
self.rules["avg_map_" + name] = PlotRule(
self,
partial(self._plot_map, "avg_map_" + name),
"Map of the radial average of {}".format(name),
dependencies=["avg_map_" + name],
)
self.rules["mwavg_map_" + name] = PlotRule(
self,
partial(self._plot_map, "mwavg_map_" + name),
"Map of the mass weighted radial average of {}".format(name),
dependencies=["avg_map_" + name],
)
self.rules["fluct_" + name] = PlotRule(
self,
partial(self._plot_map, "fluct_" + name, cmap="RdBu_r"),
"Fluctuation of {}".format(name),
dependencies=["fluct_" + name],
)
self.rules["pdf_" + name] = PlotRule(
self,
partial(self._plot_hist, "pdf_" + name, ylog=True),
"Probability density function of {} fluctuations".format(name),
dependencies=["pdf_" + name],
)
for name_bin in averageables:
if name_bin is not name:
group = "mbb_{}_{}".format(name, name_bin)
self.rules["mbb_" + name + "_" + name_bin] = PlotRule(
self,
partial(self._plot_hist, group, ylabel=r"$\alpha$"),
"Mean of {} by bins of {}".format(name, name_bin),
dependencies=[group],
)
for name in [
"step",
"mcons",
"econs",
"epot",
"ekin",
"eint",
"emag",
"elapsed",
]:
self._gen_from_log("coarse_step_from_log", name)
for name in [
"time",
"mcons",
"econs",
"epot",
"ekin",
"eint",
"emag",
"elapsed",
]:
self._gen_from_log("coarse_step_from_log", name_y=name, name_x="step")
for name in ["fine_step", "dt", "a", "mem_cells", "mem_parts"]:
self._gen_from_log("fine_step_from_log", name)
for name in ["time", "dt", "a", "mem_cells", "mem_parts"]:
self._gen_from_log("fine_step_from_log", name_y=name, name_x="fine_step")
# Dict of overlays
self.overlays = {
"B": self._overlay_B,
"speed": self._overlay_speed,
"levels": self._overlay_levels,
"contour": self._overlay_contour,
"particles": self._overlay_particles,
}
super(Plotter, self).def_rules()