Files
2024-03-22 15:06:07 +01:00

1991 lines
64 KiB
Python

# coding: utf-8
"""
This software is a helper to use PyMSES tools to read and analyse RAMSES outputs.
It provides a rule-based interface.
This is the plotter module.
@author Noé Brucy 2019-2021
"""
import os
from functools import partial
import matplotlib as mpl
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
import numpy as np
import tables
from astrophysix.simdm.datafiles import Datafile, PlotInfo, PlotType
from astrophysix.utils.file import FileType
from numpy.polynomial.polynomial import polyfit
from scipy import optimize
from scipy.ndimage.filters import gaussian_filter1d
from scipy.stats import linregress
import pandas as pd
if os.environ.get("DISPLAY", "") == "":
print("No display found. Using non-interactive Agg backend")
mpl.use("Agg")
import datetime
import matplotlib.pyplot as plt
try:
from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
except ModuleNotFoundError:
print("WARNING: no movie support (missing module moviepy)")
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from baseprocessor import Rule, BaseProcessor
from aggregator import Aggregator
from studyprocessor import StudyProcessor
from utils.runselector import RunSelector
from utils.units import U, unit_str, convert_exp
try:
import lic
except ModuleNotFoundError:
print("WARNING: no LIC support (missing module lic)")
from matplotlib.cm import ScalarMappable
from astrophysix.simdm.experiment import (
ParameterSetting,
ParameterVisibility,
Simulation,
)
from galactica.ramses_astrophysix import ramses
# Map every file extension known to astrophysix (the entries of each
# FileType's extension_list, presumably with a leading dot such as ".png" —
# TODO confirm) to its FileType. Used to tag plots saved as Datafiles.
filetype_from_ext = {ext: ft for ft in FileType for ext in ft.extension_list}
def not_array_error(err):
    """
    Tell whether *err* is the TypeError raised when indexing a non-sequence.

    Python 2 and Python 3 word this error differently, so both message
    endings are accepted.

    Parameters
    ----------
    err : the caught exception (typically a TypeError)

    Returns
    -------
    bool : True if the message ends with one of the known "not indexable"
        phrasings.
    """
    known_suffixes = (
        "object does not support indexing",  # Python 2 wording
        "object is not subscriptable",  # Python 3 wording
    )
    return str(err).endswith(known_suffixes)
def gethv(map_h, map_v, extent):
    """
    Build the coordinate grid matching a 2D vector-component map.

    Parameters
    ----------
    map_h : 2D array of the horizontal component. The maps follow the
        ``imshow(origin="lower")`` convention: rows run along the vertical
        axis and columns along the horizontal axis.
    map_v : 2D array of the vertical component. Only kept for interface
        symmetry; the grid is deduced from ``map_h`` alone.
    extent : (h_min, h_max, v_min, v_max) of the map

    Returns
    -------
    (hh, vv) : 2D arrays with the same shape as ``map_h``, holding the
        horizontal and vertical coordinate of each cell centre.
    """
    # Rows are vertical and columns horizontal, so the number of horizontal
    # positions is the number of columns (shape[1]), not the number of rows.
    nv, nh = map_h.shape[0], map_h.shape[1]
    # Cell size along each axis
    dh = (extent[1] - extent[0]) / nh
    dv = (extent[3] - extent[2]) / nv
    # Put each vector at the centre of its cell
    h = extent[0] + (np.arange(nh) + 0.5) * dh
    v = extent[2] + (np.arange(nv) + 0.5) * dv
    return np.meshgrid(h, v)
def streamplot(ax, map_h, map_v, extent, **kwargs):
    """
    Overlay streamlines of the (map_h, map_v) vector field on *ax*.

    The coordinate grid is built by ``gethv`` from the map shape and
    ``extent``; every extra keyword is forwarded to ``ax.streamplot``.
    """
    grid_h, grid_v = gethv(map_h, map_v, extent)
    ax.streamplot(grid_h, grid_v, map_h, map_v, **kwargs)
def quiver(ax, map_h, map_v, extent, key_v=None, lognorm=False, label="", **kwargs):
    """
    Overlay a vector field as arrows on *ax*, with a reference-arrow key.

    Parameters
    ----------
    ax : matplotlib axes to draw on
    map_h, map_v : 2D arrays of the horizontal / vertical components. They
        are not modified.
    extent : (h_min, h_max, v_min, v_max) of the map
    key_v : length of the reference arrow shown in the key. Defaults to the
        mid-range of the vector norms.
    lognorm : if True, each arrow is rescaled so that its length is log10
        of its norm (direction kept). Zero vectors stay zero.
    label : unit label appended to the key text
    kwargs : forwarded to ``ax.quiver``
    """
    hh, vv = gethv(map_h, map_v, extent)
    # Norm of each vector: used for the default key and the log rescaling
    norm_v = np.sqrt(map_h**2 + map_v**2)
    if key_v is None:
        key_v = (np.max(norm_v) + np.min(norm_v)) / 2.0
    # The key text always shows the physical (non-log) value
    key = f"${key_v:g}$ {label}"
    if lognorm:
        # Rescale factor log10(|v|) / |v|; zero vectors get 0 instead of nan.
        # NOTE(review): norms below 1 give a negative log, i.e. reversed arrows.
        with np.errstate(divide="ignore", invalid="ignore"):
            scale = np.where(norm_v > 0, np.log10(norm_v) / norm_v, 0.0)
        # Out-of-place so the caller's arrays are left untouched
        map_h = map_h * scale
        map_v = map_v * scale
        key_v = np.log10(key_v)
    vec_field = ax.quiver(hh, vv, map_h, map_v, units="width", **kwargs)
    ax.quiverkey(
        vec_field,
        0.6,
        0.98,
        key_v,
        key,
        labelpos="E",
        coordinates="figure",
    )
def line_integral_convolution(ax, map_h, map_v, extent, **kwargs):
    """
    Overlay a line integral convolution (LIC) texture of a vector field.

    Adapted from code by Adnan Ali Ahmad. The LIC texture is clipped to
    [0.1, 0.9] to boost contrast, coloured with the "binary" colormap, and
    its alpha channel is set to the clipped value rescaled to [0, 1].
    Extra keywords are forwarded to ``ax.imshow``; cmap, extent and origin
    are always overridden.
    """
    low, high = 0.1, 0.9
    # lic.lic receives the vertical component first, then the horizontal one
    texture = np.clip(lic.lic(map_v, map_h, length=20), low, high)
    rgba = ScalarMappable(norm=None, cmap="binary").to_rgba(texture)
    rgba[..., 3] = (texture - low) / (high - low)
    imshow_kwargs = dict(kwargs, cmap="binary", extent=extent, origin="lower")
    ax.imshow(rgba, **imshow_kwargs)
class PlotRule(Rule):
    """
    The rule class, specific to plot.
    """

    def datafile(self, name, arg):
        """
        Build the astrophysix Datafile describing the output of this rule.

        Parameters
        ----------
        name : base name of the datafile
        arg : rule argument. When not None it is appended to the name
            ("<name>_<arg>") and to the description ("<description> (<arg>)").

        Returns
        -------
        Datafile
        """
        description = self.description
        if arg is not None:
            name = name + "_" + str(arg)
            # Only mention the argument when there is one (no " (None)")
            description = description + " ({})".format(arg)
        return Datafile(name=name, description=description)
class Plotter(Aggregator, BaseProcessor):
"""
This class loads derived quantities and plot them
"""
solve_self_dep = False
# Axes information
_ax_nb = {"x": 0, "y": 1, "z": 2}  # Index of each axis
# Horizontal axis of the map seen along each line of sight
_axes_h = {"x": "y", "y": "x", "z": "x"}
# Vertical axis of the map seen along each line of sight
_axes_v = {"x": "z", "y": "z", "z": "y"}
_ax_title = {"x": r"$x$", "y": r"$y$", "z": r"$z$"}
G = 1.0  # Gravitational constant
# Conversion table from namelist keys (from the ramses config file) into LaTeX strings
label_convert = {
    "turb_rms": "$f_{rms}$",
    "beta": "$\\beta$",
    "beta_cool": "$\\beta$",
    "dens0": "$n_0$",
    "coldens0": "$\\Sigma_0$",
    "sfr_avg_window": "window",
    "bx_bound": "$B_0$",
    "levelmax": "$l_{\\max}$",
    "levelmin": "$l_{\\min}$",
    # Escaped backslash: "\c" is an invalid escape sequence (SyntaxWarning)
    "comp_frac": "$\\chi$",
}
# Conversion table from namelist values (from the ramses config file) into LaTeX strings
value_str = {
    "sfr_avg_window": lambda x: "${:g}$ Myr".format(80 * x),
    "Bx": lambda x: "${:.1f}$ $\\mu G$".format(7.6189439 * x),
}
# Conversion table from namelist values (from the ramses config file) into suitable units
value_convert = {
    "sfr_avg_window": lambda x: 80 * x,  # Myr
    "Bx": lambda x: x * 7.6189439,  # micro Gauss, as displayed by value_str
}
def __init__(
    self,
    path,
    runs=None,
    nums=None,
    path_out=".",
    params=None,
    selector=None,
    tag=None,
    unit_time=U.year,
    **kwargs,
):
    """
    Build a Plotter over a set of RAMSES runs and outputs.

    Runs and outputs are chosen through a RunSelector, created here unless
    one is supplied.

    Parameters
    ----------
    path : main folder of the simulations (e.g. '~/simus/myproject')
    runs : runs to consider (e.g. ['run1', 'run2'])
    nums : outputs to consider, either a list shared by all runs (e.g. [3, 5])
        or a per-run dict (e.g. {'run1': [3, 5], 'run2': [4, 6]})
    path_out : folder where plots are saved, handed to the base class
        (default ".")
    params : post-processing parameters, see the params module
    selector : an existing RunSelector. When given, ``runs``, ``nums`` and
        ``kwargs`` are ignored.
    tag : string added to output and data file names
    unit_time : time unit forwarded to the RunSelector and StudyProcessor
    kwargs : extra keyword arguments for the RunSelector
    """
    self.log_id = "plotter({})".format(tag)
    super(Plotter, self).__init__(path, path_out, params, tag)

    # Build a selector unless the caller provided one.
    # self.params is presumably set by the base-class initialisation above.
    if selector is not None:
        self.selector = selector
    else:
        self.selector = RunSelector(
            path,
            runs,
            nums,
            self.params.input.nml_filename,
            unit_time=unit_time,
            **kwargs,
        )

    self.path = path
    self.runs, self.nums = self.selector.runs, self.selector.nums

    # Study-level processor sharing the same selection
    self.study = StudyProcessor(
        path,
        self.runs,
        self.nums,
        self.path_out,
        self.params,
        tag=tag,
        unit_time=unit_time,
        selector=self.selector,
    )
    # Per-run, per-output post-processors
    self.snaps = self.study.snaps

    self.def_rules()
    if self.params.astrophysix.generate:
        self.gen_simus()
    # Processor used by the plotting routines, set while processing a rule
    self.current_processor = None
def gen_simus(self):
    """
    Create one astrophysix Simulation per run, stored in ``self.simulations``
    (keyed by run name).

    Each simulation is named and described with the ``simu_fmt`` /
    ``descr_fmt`` templates of ``params.astrophysix`` (fields: run, tag, nml).
    It points to the folder of the run's first selected output and receives a
    BASIC_DISPLAY parameter setting for every ramses input parameter that has
    a non-None value in the run's namelist.
    """

    def add_setting(simu, param, value):
        # Attach one parameter setting to the simulation
        simu.parameter_settings.add(
            ParameterSetting(
                input_param=param,
                value=value,
                visibility=ParameterVisibility.BASIC_DISPLAY,
            )
        )

    self.simulations = {}
    name_template = self.params.astrophysix.simu_fmt
    descr_template = self.params.astrophysix.descr_fmt
    tag = self.params.out.tag
    for run in self.runs:
        first_snap = self.snaps[run][self.nums[run][0]]
        fields = {"run": run, "tag": tag, "nml": self.study.namelist[run]}
        simu_name = name_template.format(**fields)
        # st_ctime of the output folder, without fractional seconds
        stamp = datetime.datetime.fromtimestamp(os.stat(first_snap.path).st_ctime)
        simu = Simulation(
            simu_code=ramses,
            name=simu_name,
            alias=simu_name.upper(),
            description=descr_template.format(**fields),
            directory_path=first_snap.path,
            execution_time=str(stamp).split(".")[0],
        )
        for param in ramses.input_parameters:
            try:
                value = self.study.get_nml(param.key, run)
            except KeyError as e:
                self.logger.warning("key {} not found".format(e))
                continue
            if value is None:
                continue
            try:
                add_setting(simu, param, value)
            except AttributeError:
                # Fall back to the string form when the raw value is rejected
                add_setting(simu, param, str(value))
        self.simulations[run] = simu
def _not_self_dep(self, name, dep, dep_arg, overwrite, select):
    """
    Resolve dependency *dep* when it is not one of this plotter's own rules.

    Rules known to the StudyProcessor are processed by it (with
    ``self.overwrite_dep``), and a non-None result is appended to
    ``self.just_done``. Anything else is delegated to the parent class.
    """
    if dep not in self.study.rules:
        super(Plotter, self)._not_self_dep(name, dep, dep_arg, overwrite, select)
        return
    outcome = self.study.process(dep, dep_arg, overwrite, self.overwrite_dep, select)
    if outcome is not None:
        self.just_done.append(outcome)
def _needs_computation(self, overwrite, plot_filename):
    """
    Decide whether a plot has to be (re)drawn.

    That is the case in interactive mode, when overwriting is requested, or
    when ``plot_filename`` does not exist yet. The checks short-circuit in
    that order.
    """
    out_params = self.params.out
    return out_params.interactive or overwrite or not os.path.exists(plot_filename)
def _process_rule(
    self,
    name,
    rule,
    arg,
    overwrite=False,
    select=None,
    ax=None,
    from_cells=False,
    movie=False,
    movie_fps=15,
    **kwargs,
):
    """
    Open figure(s) if needed and run a plot rule over the selected runs/outputs.

    Parameters
    ----------
    name : rule name, used to build the output file names
    rule : the PlotRule to process. ``rule.kind`` selects the iteration:
        - "snapshot"/"cells": one plot per (run, output)
        - "comp": a single plot
        - anything else: one plot per run
    arg : rule argument, appended (sanitised) to the file name when not None
    overwrite : redo plots even if their file already exists
    select : keyword arguments for ``self.selector.select`` restricting the
        runs/outputs. Defaults to all runs/outputs of the plotter.
    ax : axes to draw on. An indexable container provides one axes per plot;
        a single (non-indexable) axes is reused for every plot. When given,
        all plots go to one file, finished only after the last plot (except
        in movie mode, where each frame keeps its own file).
    from_cells : not used here, kept for interface compatibility
    movie : for per-output rules, also assemble each run's frames into an
        mp4 movie
    movie_fps : frames per second of the movie
    kwargs : forwarded to the rule

    Returns
    -------
    list of the astrophysix Datafiles created (empty unless
    ``params.astrophysix.generate`` is set)
    """
    with plt.rc_context(self.params.rcParams):
        if arg is None:
            name_full = name
        else:
            # File-name friendly argument: drop spaces, brackets, quotes and
            # slashes, and turn commas into underscores.
            cleanup = str.maketrans(
                {",": "_", " ": None, "[": None, "]": None, "'": None, "/": None}
            )
            name_full = name + "_" + str(arg).translate(cleanup)
        filetype = filetype_from_ext[self.params.out.ext]

        if select is not None:
            runs, nums = self.selector.select(**select)
        else:
            runs, nums = self.runs, self.nums

        datafiles = []
        if rule.kind in ("snapshot", "cells"):
            run_num = [(run, num) for run in runs for num in nums[run]]
            if movie:
                filenames = {run: [] for run in runs}
        else:
            if rule.kind == "comp":
                run_num = [(None, None)]
            else:
                run_num = [(run, None) for run in runs]
            if movie:
                self.logger.warning(f"No movie possible for rule {name}")
                movie = False

        # With user-provided axes everything goes into a single figure/file
        onefigure = ax is not None
        if onefigure and not movie:
            plot_filename = self._find_filename(name_full)

        for i, (run, num) in enumerate(run_num):
            if movie or not onefigure:
                plot_filename = self._find_filename(name_full, run, num)
            # Pick the axes of this plot
            try:
                real_ax = ax[i]
            except TypeError as e:
                if ax is None:
                    _, real_ax = plt.subplots(1, 1)
                elif not_array_error(e):
                    real_ax = ax
                else:
                    raise
            # Processor providing the data
            if rule.kind == "snapshot":
                self.current_processor = self.snaps[run][num]
            else:
                self.current_processor = self.study
            # Finish the figure after each plot, or only after the last one
            # when all plots share the user-provided axes
            close = (not onefigure) or (i == len(run_num) - 1)
            plot_info = self._plot_rule(
                rule,
                arg,
                plot_filename,
                overwrite,
                ax=real_ax,
                close=close,
                run=run,
                **kwargs,
            )
            if movie:
                filenames[run].append(plot_filename)
            # Record the plot in astrophysix format
            if self.params.astrophysix.generate:
                df = rule.datafile(name, arg)
                df[filetype] = plot_filename
                if plot_info is not None:
                    df.plot_info = plot_info
                if num is not None:
                    snap = self.snaps[run][num].snapshot
                    # When overwriting, drop the stale datafile first, then
                    # make sure the current one is attached to the snapshot
                    # (previously it was deleted but never re-added).
                    if overwrite and df.name in snap.datafiles:
                        del snap.datafiles[df.name]
                    if df.name not in snap.datafiles:
                        snap.datafiles.add(df)
                    if snap not in self.simulations[run].snapshots:
                        self.simulations[run].snapshots.add(snap)
                datafiles.append(df)

        if movie:
            for run in runs:
                clip = ImageSequenceClip(filenames[run], fps=movie_fps)
                movie_filename = self._find_filename(name_full, run=run, ext=".mp4")
                os.makedirs(os.path.dirname(movie_filename), exist_ok=True)
                clip.write_videofile(movie_filename)
        return datafiles
def _plot_rule(self, rule, arg, plot_filename, overwrite, ax, close=True, **kwargs):
    """
    Draw one plot on *ax* once all dependencies are met, then save/close it.

    Parameters
    ----------
    rule : the PlotRule whose ``process`` draws the plot
    arg : rule argument
    plot_filename : destination file
    overwrite : redo the plot even if the file exists
    ax : axes made current before drawing
    close : whether this call finishes the figure (tight layout, closing)
    kwargs : forwarded to ``rule.process``

    Returns
    -------
    What ``rule.process`` returns (a PlotInfo or None), or None when the
    plot already exists and is skipped.
    """
    plt.sca(ax)
    if not self._needs_computation(overwrite, plot_filename):
        self.logger.info(f"Plot {plot_filename} is already done.")
        # The caller may have opened a fresh figure for this plot; close it so
        # skipped plots do not pile up open figures. A skipped plot is never
        # in interactive mode, since interactive mode always recomputes.
        if close:
            plt.close()
        return None

    plot_info = rule.process(arg, **kwargs)
    if self.params.plot.tight_layout and close:
        plt.tight_layout(pad=1)
    if self.params.out.save:
        os.makedirs(os.path.dirname(plot_filename), exist_ok=True)
        plt.savefig(plot_filename)
        self.logger.info(f"{plot_filename} plotted")
    else:
        self.logger.info(f"{os.path.basename(plot_filename)} plotted")
    if not self.params.out.interactive and close:
        plt.close()
    return plot_info
def _find_filename(self, name_full, run=None, num=None, fmt=None, ext=None):
    """
    Build the output path of a plot (or movie).

    Parameters
    ----------
    name_full : plot name (rule name plus sanitised argument)
    run, num : run name and output number, or None
    fmt : explicit format string. Defaults to ``params.out.fmt`` when that is
        non-empty; otherwise to "<out>/<run>[/<ext>]/<name>_<tag>_<run>_<num>"
        (num zero-padded to 5 digits), trimmed to the available fields.
    ext : extension including its leading dot (defaults to ``params.out.ext``)

    The format string may use the fields run, name, tag, num, nml (the run's
    namelist), out (``self.path_out``), ext and subfolder.
    """
    out_params = self.params.out
    ext = out_params.ext if ext is None else ext
    # Optional per-extension subfolder, e.g. "/png" for ".png"
    subfolder = f"/{ext[1:]}" if out_params.ext_subfolder else ""
    tag_name = out_params.tag
    if fmt is None:
        if out_params.fmt == "":
            # Default layout: the tag is glued to the name with "_"
            if tag_name != "":
                tag_name = "_" + tag_name
            if run is None:
                fmt = "{out}{subfolder}/{name}{tag}{ext}"
            elif num is None:
                fmt = "{out}/{run}{subfolder}/{name}{tag}_{run}{ext}"
            else:
                fmt = "{out}/{run}{subfolder}/{name}{tag}_{run}_{num:05}{ext}"
        else:
            fmt = out_params.fmt
    nml = self.study.namelist[run] if run is not None else None
    return fmt.format(
        run=run,
        name=name_full,
        tag=tag_name,
        num=num,
        nml=nml,
        out=self.path_out,
        ext=ext,
        subfolder=subfolder,
    )
def get_label_run(self, run, label=None, nml_key=None, time=None):
    """
    Build a legend/title label for *run*.

    Parameters
    ----------
    run : run name
    label : explicit label. Used alone when ``nml_key`` is None; otherwise
        prefixed to the namelist description as "label (description)".
    nml_key : namelist key (or list/tuple of keys) whose values describe the
        run. Each is rendered as "<key label> = <value>" and they are joined
        with ", ". Keys missing from the namelist are skipped.
    time : not used, kept for interface compatibility

    Returns
    -------
    str : the label. When neither ``label`` nor ``nml_key`` is given, the
    first line of the run's label file (``params.input.label_filename``) is
    used, or the run name if that file does not exist.
    """

    def get_label_nml(key):
        # Human-readable parameter name: last component of the key
        prop_name = os.path.basename(key)
        prop_label = self.label_convert.get(prop_name, prop_name)
        try:
            prop_value = self.study.get_nml(key, run)
        except KeyError:
            return ""
        if prop_name in self.value_str:
            prop_value_str = self.value_str[prop_name](prop_value)
        elif prop_name in self.value_convert:
            prop_value_str = self.value_convert[prop_name](prop_value)
        elif type(prop_value) in [int, float]:
            # Exact type check: bools fall through to str()
            prop_value_str = convert_exp(prop_value, digits=4)
        else:
            prop_value_str = str(prop_value)
        return r"{} = {}".format(prop_label, prop_value_str)

    def get_label_file():
        label_filename = f"{self.path}/{run}/{self.params.input.label_filename}"
        if not os.path.exists(label_filename):
            return run
        with open(label_filename, "r") as label_file:
            # Strip only the line terminator: slicing off the last character
            # would eat a real one when the file has no final newline.
            return label_file.readline().rstrip("\r\n")

    if nml_key is None:
        return label if label else get_label_file()

    keys = nml_key if isinstance(nml_key, (list, tuple)) else [nml_key]
    # Namelist descriptions, skipping keys that were not found
    nml_label = ", ".join(lbl for lbl in map(get_label_nml, keys) if lbl)
    if label:
        # Avoid a dangling " ()" when none of the keys was found
        return f"{label} ({nml_label})" if nml_label else label
    return nml_label
def _ax_label_unit(self, node_name, label, unit, unit_coeff, put_units=True):
    """
    Work out the axis label and the units of a stored quantity.

    Parameters
    ----------
    node_name : path of the node in the current processor's storage
    label : label to use. If None it is read from the node's "label"
        attribute, then from ``label_convert``, and finally the node's base
        name is used.
    unit : target unit. If None, the stored unit is kept.
    unit_coeff : extra factor applied to the target unit (when != 1, the
        label shows it relative to ``unit``)
    put_units : append the unit string to the label

    Returns
    -------
    (label, unit_old, unit) : the final label, the unit of the stored data
        (``U.none`` if the node has none), and the target unit including
        ``unit_coeff``.
    """
    if label is None:
        try:
            label = self.current_processor.get_attribute(node_name, "label")
        except KeyError:
            base_name = os.path.basename(node_name)
            label = self.label_convert.get(base_name, base_name)
    try:
        unit_old = self.current_processor.get_attribute(node_name, "unit")
    except KeyError:
        unit_old = U.none
    if unit is None:
        unit = unit_old
    if unit_coeff != 1:
        # Scale the target unit even when no unit is shown in the label, so
        # that data converted with unit_old.express(unit) honours unit_coeff.
        base = unit
        unit = unit_coeff * unit
        if put_units:
            label = label + unit_str(unit, base=base)
    elif put_units:
        label = label + unit_str(unit)
    return label, unit_old, unit
def snapshot_title(self, run, title, nml_key, put_time, unit_time=U.Myr):
    """
    Build the title of a snapshot plot.

    The run label (from ``get_label_run``) is followed, when ``put_time`` is
    set, by the snapshot time in ``unit_time`` formatted with
    ``params.plot.time_fmt``. The two parts are separated by " | "; an empty
    run label leaves only the time.
    """
    title = self.get_label_run(run, title, nml_key)
    if not put_time:
        return title
    snap_time = self.current_processor.info["time"] * self.study.info["unit_time"]
    unit_label = unit_str(unit_time, format="{unit}")
    time_str = self.params.plot.time_fmt.format(snap_time.express(unit_time), unit_label)
    return title + " | " + time_str if len(title) > 0 else time_str
def _plot_map(
    self,
    name,
    ax_los,
    run,
    xlabel=None,
    ylabel=None,
    label=None,
    unit=None,
    unit_coeff=1.0,
    overlays=None,
    overlays_kwargs=None,
    title=None,
    put_title=True,
    nml_key=None,
    put_time=True,
    unit_time=U.Myr,
    put_units=True,
    unit_space=U.pc,
    center_space=False,
    cmap="plasma",
    norm="log",
    transform=None,
    vmin=None,
    vmax=None,
    scalebar=None,
    scalebar_size=1,
    axes=True,
    colorbar=True,
    embeded=False,
    text_embeded=None,
    text_kwargs=None,
    colorbar_embeded=None,
    axes_indicator=None,
    overtext_color="w",
    **kwargs,
):
    """
    Plot the map "/maps/<name>_<ax_los>" on the current axes.

    Parameters
    ----------
    name : map name
    ax_los : line of sight ("x", "y" or "z")
    run : run name (used for the title)
    xlabel, ylabel : axis labels (default: axis names, plus units)
    label, unit, unit_coeff, put_units : colorbar label and units of the
        values, see ``_ax_label_unit``
    overlays : overlays to draw on top. Each is either a key of
        ``self.overlays`` or a callable ``f(ax_los, im_extent, **kw)``.
    overlays_kwargs : keyword dicts matching ``overlays`` by position
    title, put_title, nml_key, put_time, unit_time : see ``snapshot_title``
    unit_space : length unit of the axes
    center_space : put the map centre at the origin
    cmap : colormap
    norm : "log", "linear" or a matplotlib Normalize instance
    transform : function applied to the map values after unit conversion
    vmin, vmax : colour range (default: min/max of the map)
    scalebar, scalebar_size : draw a scale bar of this size (in unit_space)
    axes : draw axis labels and ticks
    colorbar : draw a colorbar
    embeded : compact layout. Hides the axes and, unless set explicitly,
        enables scale bar, axes indicator, embedded colorbar and embedded title.
    text_embeded, text_kwargs : write the title inside the map, with these
        ``ax.text`` options
    colorbar_embeded : draw the colorbar inside the map
    axes_indicator : draw small arrows naming the map axes
    overtext_color : colour of everything drawn over the map
    kwargs : forwarded to ``plt.imshow``

    Returns
    -------
    PlotInfo when ``params.astrophysix.generate`` is set, else None
    """
    # Fresh containers at each call instead of shared mutable defaults
    overlays = [] if overlays is None else overlays
    overlays_kwargs = [] if overlays_kwargs is None else overlays_kwargs
    text_kwargs = {} if text_kwargs is None else text_kwargs

    ax = plt.gca()
    ax_h = self._axes_h[ax_los]
    ax_v = self._axes_v[ax_los]
    # Float copy: the extent is shifted/scaled below and must not be truncated
    im_extent = np.array(
        self.current_processor.get_attribute("/maps", "im_extent"), dtype=float
    )
    unit_length = self.current_processor.info["unit_length"]

    if embeded:
        axes = False
        # Embedded maps draw their decorations inside the image by default
        if scalebar is None:
            scalebar = True
        if axes_indicator is None:
            axes_indicator = True
        if colorbar_embeded is None:
            colorbar_embeded = True
        if text_embeded is None:
            text_embeded = True

    if center_space:
        center = self.current_processor.get_attribute("/maps", "center")
        im_extent[:2] = im_extent[:2] - center[self._ax_nb[ax_h]]
        im_extent[2:] = im_extent[2:] - center[self._ax_nb[ax_v]]
    im_extent = im_extent * unit_length.express(unit_space)

    # Map values, converted to the requested unit
    node_name = f"/maps/{name}_{ax_los}"
    dmap = self.current_processor.get_value(node_name)
    label, unit_old, unit = self._ax_label_unit(
        node_name, label, unit, unit_coeff, put_units
    )
    dmap = dmap * unit_old.express(unit)
    if transform is not None:
        dmap = transform(dmap)
    if vmin is None:
        vmin = np.min(dmap)
    if vmax is None:
        vmax = np.max(dmap)
    if norm == "log":
        norm = mpl.colors.LogNorm(vmin=vmin, vmax=vmax)
    elif norm == "linear":
        norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
    im = plt.imshow(
        dmap, extent=im_extent, origin="lower", norm=norm, cmap=cmap, **kwargs
    )
    plt.locator_params(axis="both", nbins=self.params.plot.ntick)

    if scalebar:
        bar = AnchoredSizeBar(
            plt.gca().transData,
            scalebar_size,
            f"{scalebar_size} {unit_str(unit_space)[2:-1]}",
            "lower left",
            pad=1,
            color=overtext_color,
            frameon=False,
        )
        plt.gca().add_artist(bar)

    if axes_indicator:
        # Two small arrows in the lower right corner naming the map axes
        plt.annotate(
            "",
            xy=(0.97, 0.1),
            xycoords="axes fraction",
            xytext=(0.865, 0.1),
            arrowprops={"arrowstyle": "->", "color": overtext_color},
        )
        plt.annotate(
            "",
            xy=(0.87, 0.2),
            xycoords="axes fraction",
            xytext=(0.87, 0.095),
            arrowprops={"arrowstyle": "->", "color": overtext_color},
        )
        plt.annotate(
            self._ax_title[ax_h],
            xy=(0.87, 0.2),
            xytext=(0.89, 0.05),
            color=overtext_color,
            xycoords="axes fraction",
        )
        plt.annotate(
            self._ax_title[ax_v],
            xy=(0.87, 0.2),
            xytext=(0.83, 0.12),
            color=overtext_color,
            xycoords="axes fraction",
        )

    if axes:
        if xlabel is None:
            xlabel = self._ax_title[ax_h]
        if ylabel is None:
            ylabel = self._ax_title[ax_v]
        if put_units:
            xlabel = xlabel + unit_str(unit_space)
            ylabel = ylabel + unit_str(unit_space)
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
    else:
        plt.xticks([])
        plt.yticks([])

    if colorbar:
        # The image is passed explicitly so the colorbar cannot pick up an
        # image of another axes (the inset axes becomes the current one).
        if colorbar_embeded:
            cbaxes = inset_axes(
                ax, width="10%", height="100%", loc="right", borderpad=0
            )
            cbar = plt.colorbar(im, cax=cbaxes, orientation="vertical")
            cbaxes.yaxis.set_ticks_position("left")
            cbaxes.yaxis.set_label_position("left")
            cbaxes.yaxis.set_tick_params(color=overtext_color, which="both")
            plt.setp(plt.getp(cbaxes.axes, "yticklabels"), color=overtext_color)
            cbar.outline.set_edgecolor(overtext_color)
            cbaxes.tick_params(axis="y", direction="in", pad=-25)
            plt.sca(ax)
        else:
            try:
                # Axes providing a dedicated colorbar slot (e.g. in a grid)
                cbar = plt.colorbar(im, cax=plt.gca().cax)
            except AttributeError:
                cbar = plt.colorbar(im)
        if label is not None:
            if colorbar_embeded:
                cbar.set_label(" " + label, color=overtext_color, loc="bottom")
            else:
                cbar.set_label(label)

    if put_title:
        title = self.snapshot_title(run, title, nml_key, put_time, unit_time)
        if text_embeded:
            ax.text(
                x=0.05,
                y=0.91,
                s=title,
                color=overtext_color,
                transform=ax.transAxes,
                **text_kwargs,
            )
        else:
            plt.title(title)

    for i, overlay in enumerate(overlays):
        if overlay in self.overlays:
            if overlay == "particles":
                # Particles need the map's length unit and centring
                draw = partial(
                    self.overlays[overlay],
                    unit_space=unit_space,
                    center_space=center_space,
                )
            else:
                draw = self.overlays[overlay]
        elif callable(overlay):
            draw = overlay
        else:
            self.logger.warning(f"Unknown overlay {overlay}, skipped")
            continue
        # Overlays must not change the view limits of the map
        xlim, ylim = plt.xlim(), plt.ylim()
        if i < len(overlays_kwargs):
            draw(ax_los, im_extent, **overlays_kwargs[i])
        else:
            draw(ax_los, im_extent)
        plt.xlim(xlim)
        plt.ylim(ylim)

    if self.params.astrophysix.generate:
        # Image rows run along the vertical axis and columns along the
        # horizontal one: x edges follow shape[1], y edges follow shape[0].
        return PlotInfo(
            plot_type=PlotType.IMAGE,
            xaxis_values=np.linspace(im_extent[0], im_extent[1], dmap.shape[1] + 1),
            yaxis_values=np.linspace(im_extent[2], im_extent[3], dmap.shape[0] + 1),
            values=dmap,
            xaxis_log_scale=False,
            yaxis_log_scale=False,
            values_log_scale=False,
            xaxis_label=xlabel,
            yaxis_label=ylabel,
            values_label=label,
            xaxis_unit=unit_space,
            yaxis_unit=unit_space,
            values_unit=unit,
            plot_title=title,
        )
def _overlay_contour(
    self,
    ax_los,
    im_extent,
    map_name,
    log=False,
    lvl_array=None,
    lw=None,
    lvl_th=None,
    lvl_max_lbl=np.inf,
    lvl_offset=0,
    lbl_fmt="%g",
    **kwargs,
):
    """
    Add an overlay: labelled contours of another map.

    Parameters
    ----------
    ax_los : line of sight. The map "/maps/<map_name>_<ax_los>" is contoured.
    im_extent : extent of the image the contours are drawn on
    map_name : name of the map to contour
    log : contour log10 of the map
    lvl_array : contour levels (array-like). Defaults to unit steps from the
        minimum to the maximum finite value.
    lw : line widths. When None, every level gets width 2; and if ``lvl_th``
        is set, levels >= lvl_th get 2**(lvl_th - level) while levels below
        get width 1.
    lvl_th : threshold level for the default line widths (0 is a valid value)
    lvl_max_lbl : only levels below this value (after offset) are labelled
    lvl_offset : offset added to the level values before labelling
    lbl_fmt : format of the contour labels
    kwargs : forwarded to ``plt.contour``
    """
    map_contour = self.current_processor.get_value(
        "/maps/{}_{}".format(map_name, ax_los)
    )
    if log:
        map_contour = np.log10(map_contour)
    if lvl_array is None:
        finite_values = map_contour[np.isfinite(map_contour)]
        lvl_array = np.arange(np.min(finite_values), np.max(finite_values) + 1)
    else:
        # Accept lists/tuples: the width computation needs array operations
        lvl_array = np.asarray(lvl_array)
    if lw is None:
        lw = np.full(lvl_array.size, 2.0)
        if lvl_th is not None:
            above = lvl_array >= lvl_th
            lw[above] = lw[above] ** (lvl_th - lvl_array[above])
            lw[lvl_array < lvl_th] = 1.0
    cont = plt.contour(
        map_contour,
        extent=im_extent,
        origin="lower",
        linewidths=lw,
        levels=lvl_array,
        **kwargs,
    )
    # Shift the level values by lvl_offset, then label those below lvl_max_lbl
    lvls = np.array(cont.levels) + lvl_offset
    cont.levels = lvls
    plt.clabel(
        cont,
        lvls[lvls < lvl_max_lbl],
        inline=1,
        fontsize=8.0,
        fmt=lbl_fmt,
    )
def _overlay_levels(self, ax_los, im_extent, **kwargs):
    """
    Add an overlay: contours of the AMR level map ("levels").

    Delegates to ``_overlay_contour`` with integer labels shifted by +1, a
    line-width threshold at level 8, and labels only below 11 (after the
    shift). Extra keywords are forwarded and must not repeat these settings.
    """
    level_style = dict(lbl_fmt="%1d", lvl_offset=1, lvl_th=8, lvl_max_lbl=11)
    return self._overlay_contour(
        ax_los, im_extent, "levels", **level_style, **kwargs
    )
def _overlay_particles(
    self,
    ax_los,
    im_extent,
    unit_space=U.pc,
    center_space=False,
    parts=True,
    sinks=False,
    filter_fun=None,
    s=None,
    c=None,
    **kwargs,
):
    """
    Add an overlay: scatter plot of particles or sink particles.

    If both ``sinks`` and ``parts`` are True, only sinks are shown. If both
    are False, nothing is drawn.

    Parameters
    ----------
    ax_los : line of sight; positions are projected on the two other axes
    im_extent : extent of the map (in ``unit_space``); particles outside are
        dropped
    unit_space : length unit of the map
    center_space : shift positions by the map centre, as done for the map
    parts : overlay the particles loaded with ``load_parts``. Their masses
        are converted to Msun in place in the loaded data.
    sinks : overlay sink particles (from the "load_sinks_rule" dataset)
    filter_fun : function taking the particle data and returning a boolean
        array selecting the particles to show
    s, c : marker size / colour. Either a key of the particle data, a
        function of the particle data, or anything ``plt.scatter`` accepts.
        For sinks, ``s`` defaults to msink / 5e3.
    kwargs : forwarded to ``plt.scatter``

    Returns
    -------
    The collection returned by ``plt.scatter``, or None if nothing is drawn
    """
    unit_length = self.current_processor.info["unit_length"]
    if sinks:
        try:
            self.current_processor.load_sinks_rule()
            data = pd.DataFrame(
                self.current_processor.get_value("/datasets/load_sinks_rule")
            )
            part_pos = data[["x", "y", "z"]].values
            # Sink positions use another normalisation (divided by lbox,
            # presumably box units — TODO confirm). Out-of-place so the
            # shared info unit is never modified.
            unit_length = unit_length / self.current_processor.lbox
        except KeyError:
            self.current_processor.logger.warning("No sinks particles")
            return None
    elif parts:
        self.current_processor.load_parts()
        data = self.current_processor.parts
        part_pos = data["pos"]
        masses = data["mass"]
        # In place: the loaded particle data itself holds masses in Msun
        masses *= self.current_processor.info["unit_mass"].express(U.Msun)
        self.current_processor.unload_parts()
    else:
        # Nothing requested
        return None

    # Projected coordinates, copied so the particle data is not modified
    ih = self._ax_nb[self._axes_h[ax_los]]
    iv = self._ax_nb[self._axes_v[ax_los]]
    part_h = np.array(part_pos[:, ih], dtype=float)
    part_v = np.array(part_pos[:, iv], dtype=float)
    if center_space:
        center = self.current_processor.get_attribute("/maps", "center")
        part_h -= center[ih]
        part_v -= center[iv]
    part_h *= unit_length.express(unit_space)
    part_v *= unit_length.express(unit_space)

    # Keep particles inside the map, then apply the user filter
    mask = (
        (im_extent[0] <= part_h)
        & (part_h <= im_extent[1])
        & (im_extent[2] <= part_v)
        & (part_v <= im_extent[3])
    )
    if filter_fun is not None:
        mask = mask & filter_fun(data)
    part_h = part_h[mask]
    part_v = part_v[mask]

    # Size and colour: data keys and callables are resolved, anything else
    # (numbers, colour names not in the data, arrays) goes to scatter as is
    if s is None and sinks:
        s = data.msink[mask] / 5e3
    elif isinstance(s, str) and s in data.keys():
        s = data[s][mask]
    elif callable(s):
        s = s(data)[mask]
    if isinstance(c, str) and c in data.keys():
        c = data[c][mask]
    elif callable(c):
        c = c(data)[mask]

    return plt.scatter(part_h, part_v, s=s, c=c, **kwargs)
def _overlay_vector(
    self,
    name,
    ax_los,
    extent,
    unit=U.km_s,
    unit_coeff=1.0,
    reduce_res=1,
    kind="quiver",
    **kwargs,
):
    """
    Add an overlay: in-plane vector field built from two slices.

    The slices "slice_<name><h>" and "slice_<name><v>" are computed if needed.
    Here h and v are the map's horizontal and vertical axes for line of sight
    ``ax_los``. The slices are converted to ``unit`` (times ``unit_coeff``)
    and subsampled every ``reduce_res`` cells.

    Parameters
    ----------
    name : quantity prefix of the slices (e.g. "vel")
    ax_los : line of sight
    extent : extent of the map
    unit, unit_coeff : unit of the vectors
    reduce_res : keep one cell out of ``reduce_res`` along each axis
    kind : "quiver" (arrows), "streamplot" (streamlines) or "lic" (line
        integral convolution)
    kwargs : forwarded to the drawing function

    Raises
    ------
    ValueError : if ``kind`` is not one of the supported values
    """
    if kind not in ("quiver", "streamplot", "lic"):
        raise ValueError("kind must be 'quiver', 'streamplot' or 'lic'")
    ax_h = self._axes_h[ax_los]
    ax_v = self._axes_v[ax_los]
    node_h = f"/maps/slice_{name}{ax_h}_{ax_los}"
    node_v = f"/maps/slice_{name}{ax_v}_{ax_los}"
    self.current_processor.process(f"slice_{name}{ax_h}", ax_los)
    self.current_processor.process(f"slice_{name}{ax_v}", ax_los)
    map_h = self.current_processor.get_value(node_h)
    map_v = self.current_processor.get_value(node_v)
    # Units are read from the horizontal component and applied to both
    label, unit_old, unit = self._ax_label_unit(node_h, "", unit, unit_coeff)
    factor = unit_old.express(unit)
    map_h = map_h[::reduce_res, ::reduce_res] * factor
    map_v = map_v[::reduce_res, ::reduce_res] * factor
    ax = plt.gca()
    if kind == "quiver":
        quiver(ax, map_h, map_v, extent=extent, label=label, **kwargs)
    elif kind == "streamplot":
        streamplot(ax, map_h, map_v, extent=extent, **kwargs)
    else:
        line_integral_convolution(ax, map_h, map_v, extent=extent, **kwargs)
def _overlay_speed(self, ax_los, extent, **kwargs):
    """Add an overlay: velocity field from the "slice_vel<axis>" slices."""
    return self._overlay_vector("vel", ax_los, extent, **kwargs)
def _overlay_B(self, ax_los, extent, **kwargs):
    """Add an overlay: magnetic field from the "slice_Bl<axis>" slices."""
    return self._overlay_vector("Bl", ax_los, extent, **kwargs)
def _plot_hist(
    self,
    name,
    ax_los=None,
    run=None,
    group="/hist/",
    xlabel=None,
    unit=None,
    unit_coeff=1.0,
    ytransform=None,
    label=None,
    put_title=True,
    title=None,
    nml_key=None,
    put_time=True,
    unit_time=U.Myr,
    xlog=None,
    ylog=False,
    kind="bar",
    ylabel="$\\mathcal{P}$",
    color=None,
    colors=None,
    nml_color=None,
    fit=None,
    fitlabel=None,
    **kwargs,
):
    """
    Plot a histogram (PDF, etc.) stored at "<group><name>[_<ax_los>]".

    Parameters
    ----------
    name : histogram name; "_<ax_los>" is appended when ``ax_los`` is given
    run : run name. Used for the title and colour, and to pick the run in
        averaged histograms holding "mean" and "runs" entries.
    group : storage group of the histogram
    xlabel : x-axis label. When None, ``label`` is used (historical
        behaviour), and if that is None too the stored label is used.
    unit, unit_coeff : unit of the bin centres on the x axis
    ytransform : function applied to the histogram values
    label : legend label (defaults to the title)
    put_title, title, nml_key, put_time, unit_time : see ``snapshot_title``
    xlog : whether bins are in log10. Read from the "logbins" attribute when
        None (False if missing).
    ylog : log scale for the values
    kind : "bar" or "step"
    ylabel : y-axis label
    color : explicit colour. Otherwise, when ``colors`` is given:
        - ``colors[run]`` if ``nml_color`` is None
        - ``colors(time)`` if ``nml_color`` is "time"
        - ``colors[value]`` (or ``colors(value)``) for the namelist value of
          ``nml_color`` otherwise
    fit, fitlabel : optional fit overlay ("linear"/"power_law") and its label
    kwargs : forwarded to ``plt.bar`` / ``plt.step``

    Returns
    -------
    PlotInfo when ``params.astrophysix.generate`` is set, else None

    Raises
    ------
    ValueError : if ``kind`` is neither "bar" nor "step"
    """
    if ax_los is not None:
        name = name + "_" + ax_los
    node_name = group + name
    if xlog is None:
        try:
            xlog = self.current_processor.get_attribute(node_name, "logbins")
        except (AttributeError, KeyError):
            # Missing attribute: assume linear bins
            xlog = False
    # Axis label and units. ``xlabel`` used to be ignored in favour of
    # ``label``; ``label`` is kept as a fallback for existing callers.
    xlabel, unit_old, unit = self._ax_label_unit(
        node_name, xlabel if xlabel is not None else label, unit, unit_coeff
    )

    node = self.current_processor.get_value(node_name)
    if "mean" in node:
        # Averaged histograms: one (values, centers) pair per run
        index = node["runs"].index(run.encode())
        values, centers = node["mean"][index]
    else:
        values, centers = node
    if xlog:
        centers = centers + np.log10(unit_old.express(unit))
    else:
        centers = centers * unit_old.express(unit)
    if ytransform is not None:
        values = ytransform(values)
    width = centers[1] - centers[0]

    title = self.snapshot_title(run, title, nml_key, put_time, unit_time)
    if put_title:
        plt.title(title)
    if label is None:
        label = title

    if color is None and colors is not None:
        if nml_color is None:
            color = colors[run]
        elif nml_color == "time":
            time = (
                self.current_processor.time
                * self.current_processor.info["unit_time"]
            ).express(unit_time)
            color = colors(time)
        else:
            nml_value = self.study.get_nml(nml_color, run)
            if os.path.basename(nml_color) in self.value_convert:
                nml_value = self.value_convert[os.path.basename(nml_color)](
                    nml_value
                )
            try:
                color = colors[nml_value]
            except TypeError:
                # colors is a callable (e.g. a colormap) rather than a mapping
                color = colors(nml_value)

    if kind == "bar":
        plt.bar(
            centers, values, width, log=ylog, color=color, label=label, **kwargs
        )
    elif kind == "step":
        if ylog:
            plt.yscale("log")
        plt.step(centers, values, where="mid", color=color, label=label, **kwargs)
    else:
        raise ValueError("kind must be 'bar' or 'step'")

    plt.xlabel(xlabel)
    if ylabel is not None:
        plt.ylabel(ylabel)

    if fit is not None:
        self._overlay_fit(
            centers, values, kind=fit, ls="--", lw=1.5, label=fitlabel
        )

    # PlotInfo for Galactica, with bin edges rebuilt from the centres
    if self.params.astrophysix.generate:
        edges = np.append(centers - width / 2.0, centers[-1] + width / 2.0)
        return PlotInfo(
            plot_type=PlotType.HISTOGRAM,
            xaxis_values=edges,
            yaxis_values=values,
            xaxis_log_scale=False,
            yaxis_log_scale=ylog,
            xaxis_label=xlabel,
            yaxis_label=ylabel,
            xaxis_unit=unit,
            yaxis_unit=U.none,
            plot_title=title,
        )
def plot(
    self,
    x: np.array,
    y: np.array,
    xlabel: str = "",
    ylabel: str = "",
    label: str = "",
    xscale: str = "linear",
    yscale: str = "linear",
    fit: str = None,
    fitlabel: str = None,
    smooth: float = 0,
    nml_key=None,
    run: str = None,
    yerr: np.array = None,
    grid: bool = False,
    put_time: bool = False,
    unit_time=U.Myr,
    colors=None,
    nml_color=None,
    legend: bool = False,
    **kwargs,
):
    """
    Generic plot routine for two numpy arrays x and y.

    Parameters
    ----------
    x, y : data
    xlabel, ylabel : axis labels
    label : curve label. When ``run`` is given it is combined with the run
        label (see ``get_label_run``).
    xscale, yscale : axis scales
    fit, fitlabel : optional fit overlay ("linear" or "power_law") and label
    smooth : if > 0, y is smoothed with a gaussian filter of this sigma
        (in samples); this alters the data, use with care
    nml_key : namelist key(s) used for the run label
    run : run the curve belongs to
    yerr : optional error bars on y
    grid : draw a grid
    put_time : append the current processor time to the label
    unit_time : unit of that time
    colors, nml_color : colour selection. Uses ``colors[run]`` if
        ``nml_color`` is None, ``colors(time)`` if ``nml_color`` is "time",
        and otherwise ``colors[value]`` (or ``colors(value)``) for the
        namelist value of ``nml_color``.
    legend : draw a legend, including the fit line
    kwargs : forwarded to ``plt.plot`` / ``plt.errorbar``
    """
    if smooth > 0:
        y = gaussian_filter1d(y, sigma=smooth)
    # Special label if the plot applies to a given run
    if run is not None:
        label = self.get_label_run(run, label, nml_key)
    if put_time:
        time = self.current_processor.time * self.study.info["unit_time"]
        time_str = self.params.plot.time_fmt.format(
            time.express(unit_time), unit_time.latex.replace("text", "math")
        )
        time_str = f"${time_str}$"
        # `if label` also copes with label=None
        label = label + " | " + time_str if label else time_str

    line_kwargs = {"label": label}
    if colors is not None:
        if nml_color is None:
            color = colors[run]
        elif nml_color == "time":
            time = (
                self.current_processor.time
                * self.current_processor.info["unit_time"]
            ).express(unit_time)
            color = colors(time)
        else:
            nml_value = self.study.get_nml(nml_color, run)
            if os.path.basename(nml_color) in self.value_convert:
                nml_value = self.value_convert[os.path.basename(nml_color)](
                    nml_value
                )
            try:
                color = colors[nml_value]
            except TypeError:
                # colors is a callable (e.g. a colormap) rather than a mapping
                color = colors(nml_value)
        line_kwargs["color"] = color

    if yerr is None:
        (base_line,) = plt.plot(x, y, **line_kwargs, **kwargs)
    else:
        base_line, _, _ = plt.errorbar(x, y, yerr=yerr, **line_kwargs, **kwargs)

    # Axes decorations
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    if grid:
        plt.grid()
    plt.xscale(xscale)
    plt.yscale(yscale)
    if fit is not None:
        self._overlay_fit(
            x,
            y,
            yerr,
            kind=fit,
            ls="--",
            lw=1.5,
            color=base_line.get_color(),
            label=fitlabel,
        )
    # Legend last, so that the fit line (if any) is included
    if legend:
        plt.legend()
def _plot(
    self,
    name_x: str,
    name_y: str,
    node_arg=None,
    xlabel=None,
    ylabel=None,
    xunit=None,
    yunit=None,
    put_units=True,
    xunit_coeff=1.0,
    yunit_coeff=1.0,
    xtransform=None,
    ytransform=None,
    run=None,
    yerr=None,
    yerr_kind="std",
    sigma_err=2.0,
    subname_x=None,
    subname_y=None,
    wait_until_over=None,
    **kwargs,
):
    """
    Generic plot routine, with name_x and name_y two paths in the storage.

    Parameters
    ----------
    name_x, name_y : storage paths of x and y ("_<node_arg>" appended when
        ``node_arg`` is given)
    xlabel, ylabel, xunit, yunit, put_units, xunit_coeff, yunit_coeff :
        labels and units, see ``_ax_label_unit``
    xtransform, ytransform : functions applied to x / y after unit conversion
    run : run to plot when y is stored as a per-run dict
    yerr : error bars, or a storage path to read them from. Overwritten for
        averaged data (dict with "mean"), where it is derived from ``yerr_kind``:
        - "std": mean +/- sigma_err * std
        - "min_max": min/max band
        - "95per": q025/q975 band
        - "68per": q16/q84 band
        - anything else: no error
        The symmetric error bar is half the band width (centred on the mean).
    subname_x, subname_y : when set, the stored value is the path of another
        HDF5 file, and the data is read from this node of that file
    wait_until_over : shift x so that it starts where y first exceeds this
        value (x[0] if it never does)
    kwargs : forwarded to ``plot``

    Supported storage forms for y: a plain array, a dict of statistics
    ("mean", "std", ...), or a dict of arrays keyed by run.
    """
    if node_arg is not None:
        name_x, name_y = name_x + "_" + node_arg, name_y + "_" + node_arg
    node_x = self.current_processor.get_value(name_x)
    node_y = self.current_processor.get_value(name_y)
    # External data: read it into memory and close the file right away, so
    # the file is released even if plotting fails later on.
    if subname_x:
        with tables.open_file(node_x) as hdf5_x:
            node_x = hdf5_x.get_node(subname_x).read()
    if subname_y:
        with tables.open_file(node_y) as hdf5_y:
            node_y = hdf5_y.get_node(subname_y).read()

    xlabel, xunit_old, xunit = self._ax_label_unit(
        name_x,
        xlabel,
        xunit,
        xunit_coeff,
        put_units=put_units,
    )
    ylabel, yunit_old, yunit = self._ax_label_unit(
        name_y,
        ylabel,
        yunit,
        yunit_coeff,
        put_units=put_units,
    )
    x_factor = xunit_old.express(xunit)
    y_factor = yunit_old.express(yunit)

    if isinstance(node_y, np.ndarray):
        x = node_x * x_factor
        y = node_y * y_factor
        mask = np.isfinite(x) & np.isfinite(y)
        x, y = x[mask], y[mask]
    elif "mean" in node_y:
        x = node_x * x_factor
        y = node_y["mean"] * y_factor
        band_keys = {
            "min_max": ("min", "max"),
            "95per": ("q025", "q975"),
            "68per": ("q16", "q84"),
        }
        if yerr_kind == "std":
            std = node_y["std"] * y_factor
            yerr_min = y - sigma_err * std
            yerr_max = y + sigma_err * std
        elif yerr_kind in band_keys:
            low_key, high_key = band_keys[yerr_kind]
            yerr_min = node_y[low_key] * y_factor
            yerr_max = node_y[high_key] * y_factor
        else:
            yerr_min = y
            yerr_max = y
        # errorbar draws y +/- yerr, so use half the band width: the full
        # width used previously doubled the bars (e.g. +/-4 sigma for
        # sigma_err=2). For asymmetric bands this is a symmetric approximation.
        yerr = (yerr_max - yerr_min) / 2.0
        mask = np.isfinite(x) & np.isfinite(y) & np.isfinite(yerr)
        x, y, yerr = x[mask], y[mask], yerr[mask]
    else:
        x = node_x[run] * x_factor
        y = node_y[run] * y_factor
        mask = np.isfinite(x) & np.isfinite(y)
        x, y = x[mask], y[mask]
    if isinstance(yerr, str):
        yerr = self.current_processor.get_value(yerr)

    if xtransform is not None:
        x = xtransform(x)
    if ytransform is not None:
        y = ytransform(y)
        if yerr is not None:
            self.logger.warning(
                "Errorbar may be meaningless when ytransform is used"
            )
    if wait_until_over is not None:
        # First index where y exceeds the threshold (0 if it never does)
        offset = np.argmax(y > wait_until_over)
        x = x - x[offset]
    self.plot(x, y, yerr=yerr, xlabel=xlabel, ylabel=ylabel, run=run, **kwargs)
def _overlay_fit(self, x, y, yerr=None, kind="linear", label=None, **kwargs):
    """
    Add an overlay: fit a linear or power-law curve to (x, y) and plot it.

    Parameters
    ----------
    x, y : array-like
        Data points to fit.
    yerr : array-like, scalar or None
        Uncertainties on ``y`` (one per point, or a single value for all
        points). When None or all zero, an unweighted
        ``scipy.stats.linregress`` fit is used; otherwise the fit is
        weighted by ``1 / yerr``.
    kind : str
        ``"linear"`` fits ``y = a x + b``.
        ``"power_law"`` fits ``y = 10**b * x**a`` as a straight line in
        log10-log10 space, using only the points where both x and y are
        strictly positive.
        Any other value only logs a warning and adds nothing.
    label : str or None
        Legend label of the fitted curve. When None, a default label is
        built from the fit parameters.
    **kwargs
        Forwarded to ``plt.plot``.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    weighted = yerr is not None and np.sum(np.abs(yerr)) != 0
    if weighted:
        # Accept a single error value as well as one error per point.
        yerr = np.broadcast_to(np.asarray(yerr, dtype=float), y.shape)

    if kind == "linear":
        if not weighted:
            a, b, r_value, _p_value, stderr = linregress(x, y)
            # linregress returns the correlation coefficient r, not R^2.
            r_squared = r_value**2
            self.logger.info(
                "Linear fit y = {} x + {} with R^2 = {} and error is {}".format(
                    a, b, r_squared, stderr
                )
            )
            if label is None:
                label = r"Linear fit with slope ${:.3g}$ and $R^2 = {:.3f}$".format(
                    a, r_squared
                )
        else:
            coefs, (resids, _rank, _sv, _rcond) = polyfit(
                x, y, 1, w=1.0 / yerr, full=True
            )
            # numpy.polynomial orders coefficients by increasing degree.
            b, a = coefs[0], coefs[1]
            # lstsq returns an empty residual array for exactly determined or
            # rank-deficient problems (e.g. only two points).
            residual = resids[0] if np.size(resids) else np.nan
            self.logger.info(
                "Linear fit y = {} x + {} with residual {}".format(a, b, residual)
            )
            if label is None:
                label = r"Linear fit with slope ${:.3g}$".format(a)
        plt.plot(x, a * x + b, label=label, **kwargs)

    elif kind == "power_law":
        # The fit is done in log space: only strictly positive points are usable.
        positive = (x > 0) & (y > 0)
        if not np.all(positive):
            self.logger.warning(
                "Power law fit: ignoring {} non-positive point(s)".format(
                    np.count_nonzero(~positive)
                )
            )
        log_x = np.log10(x[positive])
        log_y = np.log10(y[positive])
        if not weighted:
            a, b, r_value, _p_value, stderr = linregress(log_x, log_y)
            self.logger.info(
                "Power law fit y = x^({}) * {} with R^2 = {} and error is {}".format(
                    a, 10**b, r_value**2, stderr
                )
            )
        else:
            # Error propagation: sigma(log10 y) = yerr / (y ln 10)
            log_err = yerr[positive] / (y[positive] * np.log(10))

            def errfunc(p, lx, ly, err):
                # Weighted residuals of the straight line p[0] + p[1] * lx
                return (ly - (p[0] + p[1] * lx)) / err

            out = optimize.leastsq(
                errfunc,
                [1.0, -1.0],
                args=(log_x, log_y, log_err),
                full_output=1,
            )
            b, a = out[0][0], out[0][1]
            chi2 = np.sum(errfunc(out[0], log_x, log_y, log_err) ** 2)
            self.logger.info(
                "Power law fit y = x^({}) * {} with residual {}".format(
                    a, 10**b, chi2
                )
            )
        if label is None:
            label = r"Power-law fit with index {:.1f}".format(a)
        x_plot = x[x > 0]
        plt.plot(x_plot, (10**b) * x_plot**a, label=label, **kwargs)

    else:
        self.logger.warning(
            "Unknown fit kind '{}': no fit overlay added".format(kind)
        )
def _gen_from_log(self, logrule, name_y, name_x="time", description="Generated"):
    """
    Register a "run" plot rule for a curve extracted from a log-parsing rule.

    The rule plots ``/series/<logrule>/<name_y>`` against
    ``/series/<logrule>/<name_x>`` and depends on the rule ``logrule``.
    It is stored in ``self.rules`` under ``name_y`` when ``name_x`` is
    ``"time"``, and under ``name_y + "_" + name_x`` otherwise.
    """
    if name_x != "time":
        rule_key = name_y + "_" + name_x
    else:
        rule_key = name_y
    series_group = "/series/" + logrule + "/"
    plot_function = partial(self._plot, series_group + name_x, series_group + name_y)
    self.rules[rule_key] = PlotRule(
        plot_function,
        description=description,
        kind="run",
        dependencies=[logrule],
    )
def def_rules(self):
    """
    Define the plotting rules and overlays of this processor.

    Fills ``self.rules`` (rule name -> PlotRule) with:

    * a fixed set of map, histogram and curve rules;
    * for every variable in ``self.params.pymses.variables``: slice,
      mass-weighted averaged and averaged map rules. The vector fields ``g``,
      ``vel``, ``Bl`` and ``Br`` get one set for each component (x, y, z),
      the radial and orthoradial components, and the norm;
    * for every "averageable" quantity: radial profile, radial dispersion,
      averaged maps, fluctuation map, PDF, and "mean by bins" rules against
      every other averageable quantity;
    * curves extracted from the simulation logs (via ``_gen_from_log``).

    Also fills ``self.overlays`` (overlay name -> callable), then calls the
    parent class ``def_rules``.
    """
    self.rules = {
        "plot_comp": PlotRule(
            lambda arg, **kwargs: self._plot(*arg, **kwargs), kind="comp"
        ),
        "plot_run": PlotRule(
            lambda arg, **kwargs: self._plot(*arg, **kwargs), kind="run"
        ),
        "plot_snapshot": PlotRule(lambda arg, **kwargs: self._plot(*arg, **kwargs)),
        "plot_map": PlotRule(
            lambda mapname, **kwargs: self._plot_map(mapname, **kwargs)
        ),
        "coldens": PlotRule(
            partial(
                self._plot_map,
                "coldens",
                label=r"$\Sigma$",
                # unit=U.coldens
            ),
            "Column density map",
            dependencies=["coldens"],
        ),
        "slice_T": PlotRule(
            partial(
                self._plot_map,
                "T",
                label=r"$T$",
            ),
            "Slice of temperature",
            dependencies=["T"],
        ),
        "alpha_disk": PlotRule(
            partial(self._plot_map, "alpha_disk", label=r"$\alpha$"),
            "Map of the Shakura & Sunyaev alpha parameter for disks",
            dependencies=["alpha_disk"],
        ),
        "alpha_grav": PlotRule(
            partial(self._plot_map, "alpha_grav", label=r"$\alpha_g$"),
            "Map of the grav Shakura & Sunyaev alpha parameter for disks",
            dependencies=["alpha_grav"],
        ),
        "coldens_l": PlotRule(
            partial(
                self._plot_map,
                "coldens",
                label=r"$\Sigma$",
                unit=U.coldens,
                overlays=[self._overlay_levels],
            ),
            "Column density with level overlay",
            dependencies=["coldens", "levels"],
        ),
        "slice_rho_v": PlotRule(
            partial(
                self._plot_map,
                "slice_rho",
                label=r"$\rho$",
                unit=U.Msun_pc3,
                overlays=[self._overlay_speed],
            ),
            "Density slice with speed overlay",
            dependencies=["slice_rho"],
        ),
        "jeans_ratio": PlotRule(
            partial(
                self._plot_map,
                "jeans_ratio",
                vmin=0.1,
                vmax=100,
                cmap="RdBu_r",
                overlays=[self._overlay_levels],
            ),
            "Jeans' length divided by the max resolution",
            dependencies=["jeans_ratio", "levels"],
        ),
        "Q": PlotRule(
            partial(
                self._plot_map,
                "Q",
                label=r"$Q$",
                vmin=0.01,
                vmax=100,
                cmap="RdBu_r",
            ),
            "Toomre Q parameter for a Keplerian disk",
            dependencies=["Q"],
        ),
        # Raw strings: in a plain literal "\r" is a carriage return.
        "rho_pdf": PlotRule(
            partial(self._plot_hist, "rho_pdf"),
            r"$\rho$-PDF",
            dependencies=["rho_pdf"],
        ),
        "rho_pdf_mw": PlotRule(
            partial(self._plot_hist, "rho_pdf_mw"),
            r"Mass weighted $\rho$-PDF",
            dependencies=["rho_pdf_mw"],
        ),
        "cos_pdf": PlotRule(
            partial(self._plot_hist, "cos_pdf"),
            "cos-PDF",
            dependencies=["cos_pdf", "mwa_speed"],
        ),
        "avg_coldens_pdf": PlotRule(
            partial(
                self._plot_hist,
                "avg_time_coldens_pdf_z",
                group="/comp/",
                xlog=True,
                put_time=False,
            ),
            "Column density PDF, averaged in time",
            kind="runs",
            dependencies={"avg_time_coldens_pdf": "z"},
        ),
        "T_pdf": PlotRule(
            partial(self._plot_hist, "T_pdf"),
            "T-PDF on a 2D slice",
            dependencies=["T_pdf"],
        ),
        "P_pdf": PlotRule(
            partial(self._plot_hist, "P_pdf"),
            "P-PDF on a 2D slice ",
            dependencies=["P_pdf"],
        ),
        "Brho": PlotRule(
            partial(
                self._plot,
                "/datasets/Brho/rho",
                "/datasets/Brho/B",
                label=r"$\mathrm{B} $",
                put_time=True,
            ),
            "Brho on a 2D slice ",
            dependencies=["Brho"],
        ),
        "Ek_Eb_rho": PlotRule(
            partial(
                self._plot,
                "/datasets/Ek_Eb_rho/rho",
                "/datasets/Ek_Eb_rho/Ek_Eb_rho",
                label=r"Ek/Eb",
                put_time=True,
            ),
            "Ek/Eb on a 2D slice ",
            dependencies=["Ek_Eb_rho", "mwa_speed"],
        ),
        "rho_prof": PlotRule(
            partial(self._plot, "/profile/axis", "/profile/rho_prof"),
            "Density profile",
            dependencies=["axis", "rho_prof"],
        ),
        "sbeta": PlotRule(
            partial(
                self._plot,
                "/comp/nml_cloud_params/beta_cool",
                "/comp/avg_time_pdf_slope_coldens",
            ),
            "Slope of the Sigma-PDF against cooling beta factor",
            kind="comp",
            dependencies={
                "nml": "cloud_params/beta_cool",
                "avg_time_pdf_slope_coldens": None,
            },
        ),
        "sbeta_onavg": PlotRule(
            partial(
                self._plot,
                "/comp/sbeta_onavg/beta",
                "/comp/sbeta_onavg/slope",
                yerr="/comp/sbeta_onavg/stderr",
            ),
            "Slope of the time averaged Sigma-PDF against cooling beta factor",
            kind="comp",
            dependencies=["sbeta_onavg"],
        ),
        "sink_mass": PlotRule(
            partial(
                self._plot,
                "/series/sinks_from_log/time",
                "/series/sinks_from_log/mass_sink",
                xunit=U.Myr,
                yunit=U.Msun,
            ),
            "Mass of the sinks as a function of time",
            kind="run",
            dependencies=["sinks_from_log"],
        ),
        "ssm": PlotRule(
            partial(
                self._plot,
                "/series/sinks_from_log/time",
                "/series/sinks_from_log/ssm",
                xunit=U.Myr,
                yunit=U.Msun / U.pc**2,
            ),
            "Mass of the sinks as a function of time divided by surface",
            kind="run",
            dependencies=["ssm"],
        ),
        "assfr": PlotRule(
            partial(
                self._plot,
                "/series/sfr_from_log/time",
                "/series/sfr_from_log/sfr",
                ylabel="Averaged surfacic SFR",
                xunit=U.Myr,
                yunit=U.ssfr,
            ),
            "Averaged surfacic SFR as a function of time",
            kind="run",
            dependencies=["sfr_from_log"],
        ),
        "issfr": PlotRule(
            partial(
                self._plot,
                "/series/sinks_from_log/time",
                "/series/sinks_from_log/issfr",
                ylabel="Surfacic SFR",
                xunit=U.Myr,
                yunit=U.ssfr,
            ),
            "Surfacic SFR as a function of time",
            kind="run",
            dependencies=["issfr"],
        ),
        "turb_rms": PlotRule(
            partial(
                self._plot,
                "/series/rms_from_log/time",
                "/series/rms_from_log/turb_rms",
                xunit=U.Myr,
            ),
            "Turbulent RMS",
            kind="run",
            dependencies=["rms_from_log"],
        ),
        "turb_energy": PlotRule(
            partial(
                self._plot,
                "/series/rms_from_log/time",
                "/series/rms_from_log/turb_energy",
                xunit=U.Myr,
            ),
            "Turbulent energy",
            kind="run",
            dependencies=["rms_from_log"],
        ),
        "turb_power": PlotRule(
            partial(
                self._plot,
                "/series/rms_from_log/time",
                "/series/rms_from_log/turb_power",
                xunit=U.Myr,
            ),
            "Turbulent power",
            kind="run",
            dependencies=["turb_power"],
        ),
        "sigma": PlotRule(
            partial(
                self._plot,
                "/series/time",
                "/series/time_sigma",
                ylabel="$\\sigma$",
                xunit=U.Myr,
                yunit=U.km_s,
            ),
            "Velocity dispersion",
            kind="run",
            dependencies=["time_sigma"],
        ),
        "mwa_B_int": PlotRule(
            partial(
                self._plot,
                "/series/time",
                "/series/time_mwa_B_int",
                xunit=U.Myr,
                yunit=U.uG,
            ),
            "Magnetic intensity average",
            kind="run",
            dependencies=["time_mwa_B_int"],
        ),
        "mass": PlotRule(
            partial(
                self._plot,
                "/series/time",
                "/series/time_mass",
                xunit=U.Myr,
                yunit=U.Msun,
            ),
            "Total mass in the box",
            kind="run",
            dependencies=["time_mass"],
        ),
        "max_fluct_coldens": PlotRule(
            partial(
                self._plot,
                "/series/time",
                "/series/time_max_fluct_coldens_z",
                ylabel="$\\max(\\Sigma/\\overline{\\Sigma})$",
                xunit=U.Myr,
            ),
            "Maximal fluctuation of the column density against time",
            kind="run",
            dependencies={"time_max_fluct_coldens": "z"},
        ),
    }

    # Quantities for which radial averages, fluctuations, PDFs and
    # "mean by bins" rules are generated. Extended below with the
    # quantities derived from the Ramses fields.
    averageables = [
        "coldens",
        "Q",
        "T",
        "T_mwavg",
        "alpha_disk",
        "alpha_grav",
    ]

    def generic_rule(name):
        """Add the slice, mass-weighted average and average map rules of `name`."""
        slice_name = "slice_" + name
        mwavg_name = name + "_mwavg"
        avg_name = name + "_avg"
        self.rules[slice_name] = PlotRule(
            partial(self._plot_map, slice_name),
            "{} slice".format(name),
            dependencies=[slice_name],
        )
        self.rules[mwavg_name] = PlotRule(
            partial(self._plot_map, mwavg_name),
            "Ax mass-weighted averaged {}".format(name),
            dependencies=[mwavg_name],
        )
        self.rules[avg_name] = PlotRule(
            partial(self._plot_map, avg_name),
            "Ax averaged {}".format(name),
            dependencies=[avg_name],
        )
        averageables.extend([slice_name, mwavg_name, avg_name])

    # Generic rules directly from Ramses fields
    for field in self.params.pymses.variables:
        if field in ("g", "vel", "Bl", "Br"):
            # Vector field: components, radial, orthoradial and norm
            for suffix in ("x", "y", "z", "r", "phi", "_norm"):
                generic_rule(field + suffix)
        else:
            generic_rule(field)

    # Drop duplicates while keeping the order: a generated name (e.g.
    # "T_mwavg") may also be listed explicitly above.
    unique_averageables = list(dict.fromkeys(averageables))

    for name in unique_averageables:
        self.rules["rad_" + name] = PlotRule(
            partial(
                self._plot,
                "/radial/radial_centers",
                "/radial/rad_mwavg_" + name,
            ),
            "Azimuthal mass weighted average of {}".format(name),
            dependencies=["radial_centers", "rad_mwavg_" + name],
        )
        self.rules["dispersion_rad_" + name] = PlotRule(
            partial(
                self._plot,
                "/radial/radial_centers",
                "/radial/dispersion_rad_" + name,
            ),
            "Radial dispersion of {}".format(name),
            dependencies=["radial_centers", "dispersion_rad_" + name],
        )
        self.rules["avg_map_" + name] = PlotRule(
            partial(self._plot_map, "avg_map_" + name),
            "Map of the radial average of {}".format(name),
            dependencies=["avg_map_" + name],
        )
        # Depends on the map it actually plots (was "avg_map_" + name).
        self.rules["mwavg_map_" + name] = PlotRule(
            partial(self._plot_map, "mwavg_map_" + name),
            "Map of the mass weighted radial average of {}".format(name),
            dependencies=["mwavg_map_" + name],
        )
        self.rules["fluct_" + name] = PlotRule(
            partial(self._plot_map, "fluct_" + name, cmap="RdBu_r"),
            "Fluctuation of {}".format(name),
            dependencies=["fluct_" + name],
        )
        self.rules["pdf_" + name] = PlotRule(
            partial(self._plot_hist, "pdf_" + name, ylog=True),
            "Probability density function of {} fluctuations".format(name),
            dependencies=["pdf_" + name],
        )
        for name_bin in unique_averageables:
            # Equality, not identity: equal strings may be distinct objects.
            if name_bin == name:
                continue
            group = "mbb_{}_{}".format(name, name_bin)
            # The plotting function is the first argument of PlotRule (the
            # processor itself must not be passed).
            self.rules[group] = PlotRule(
                partial(self._plot_hist, group, ylabel=r"$\alpha$"),
                "Mean of {} by bins of {}".format(name, name_bin),
                dependencies=[group],
            )

    # Curves extracted from the simulation logs
    coarse_quantities = ("mcons", "econs", "epot", "ekin", "eint", "emag", "elapsed")
    for name in ("step",) + coarse_quantities:
        self._gen_from_log("coarse_step_from_log", name)
    for name in ("time",) + coarse_quantities:
        self._gen_from_log("coarse_step_from_log", name_y=name, name_x="step")
    fine_quantities = ("dt", "a", "mem_cells", "mem_parts")
    for name in ("fine_step",) + fine_quantities:
        self._gen_from_log("fine_step_from_log", name)
    for name in ("time",) + fine_quantities:
        self._gen_from_log("fine_step_from_log", name_y=name, name_x="fine_step")
    self._gen_from_log("SN_momentum_from_log", name_x="time", name_y="SN_momentum")

    # Dict of overlays
    self.overlays = {
        "g": partial(self._overlay_vector, "g"),
        "B": self._overlay_B,
        "vel": self._overlay_speed,
        "speed": self._overlay_speed,
        "levels": self._overlay_levels,
        "contour": self._overlay_contour,
        "particles": self._overlay_particles,
    }
    super(Plotter, self).def_rules()