Files
pipeline/plotter.py
T
2021-10-25 17:37:02 +02:00

1882 lines
61 KiB
Python

# coding: utf-8
"""
This software is a helper for using PyMSES tools to read and analyse RAMSES outputs.
It is a rule-based interface.
This is the plotter module.
@author Noé Brucy 2019-2021
"""
import os
from functools import partial
import matplotlib as mpl
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
import numpy as np
import tables
from astrophysix.simdm.datafiles import Datafile, PlotInfo, PlotType
from astrophysix.utils.file import FileType
from numpy.polynomial.polynomial import polyfit
from scipy import optimize
from scipy.ndimage.filters import gaussian_filter1d
from scipy.stats import linregress
import pandas as pd
if os.environ.get("DISPLAY", "") == "":
print("No display found. Using non-interactive Agg backend")
mpl.use("Agg")
import datetime
import matplotlib.pyplot as plt
from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import pspec_read
from baseprocessor import Rule, BaseProcessor
from aggregator import Aggregator
from studyprocessor import StudyProcessor
from run_selector import RunSelector
from units import U, unit_str, convert_exp
from astrophysix.simdm.experiment import (
ParameterSetting,
ParameterVisibility,
Simulation,
)
from ramses_astrophysix import ramses
# Lookup table: file extension -> astrophysix FileType declaring that extension
# (when several FileTypes list the same extension, the last one wins)
filetype_from_ext = dict(
    (extension, filetype)
    for filetype in FileType
    for extension in filetype.extension_list
)
def not_array_error(err):
    """
    Tell whether an exception is the TypeError raised when indexing a non-indexable object.

    Used to tell "`ax` is a single Axes, not a sequence of Axes" apart from other
    TypeErrors.

    Parameters
    ----------
    err : the caught exception

    Returns
    -------
    True if the exception message ends with the Python 2 wording ("object does not
    support indexing") or the Python 3 wording ("object is not subscriptable").
    """
    epy2 = "object does not support indexing"
    epy3 = "object is not subscriptable"
    return str(err).endswith((epy2, epy3))
class PlotRule(Rule):
    """
    Rule class specific to plots.

    Adds a `datafile` method building the astrophysix Datafile that describes the
    plot produced by the rule.
    """

    def datafile(self, name, arg):
        """
        Create the astrophysix Datafile describing this rule's output.

        Parameters
        ----------
        name : base name of the datafile (the rule name)
        arg : optional rule argument. When not None it is appended to the name
            ("<name>_<arg>") and to the description ("<description> (<arg>)").

        Returns
        -------
        A new `Datafile` (the caller attaches the actual file to it).
        """
        description = self.description
        if arg is not None:
            name = name + "_" + str(arg)
            # Mention the argument only when there is one; previously " (None)" was
            # appended to the description of argument-less rules.
            description = description + " ({})".format(arg)
        return Datafile(
            name=name,
            description=description,
        )
class Plotter(Aggregator, BaseProcessor):
"""
This class loads derived quantities and plot them
"""
# Flag read by the base processor classes (not used in this part of the file);
# presumably controls whether dependencies on this processor's own rules are resolved — confirm
solve_self_dep = False
# Axes information
_ax_nb = {"x": 0, "y": 1, "z": 2} # Index of each axis in position / center vectors
_axes_h = {"x": "y", "y": "x", "z": "x"} # Horizontal map axis for each line of sight
_axes_v = {"x": "z", "y": "z", "z": "y"} # Vertical map axis for each line of sight
_ax_title = {"x": r"$x$", "y": r"$y$", "z": r"$z$"} # LaTeX label of each axis
G = 1.0 # Gravitational constant (not used in this part of the file; presumably code units — confirm)
# Conversion table from namelist keys (RAMSES namelist file) into LaTeX labels
label_convert = {
"turb_rms": "$f_{rms}$",
"beta": "$\\beta$",
"beta_cool": "$\\beta$",
"dens0": "$n_0$",
"coldens0": "$\\Sigma_0$",
"sfr_avg_window": "window",
"bx_bound": "$B_0$",
"levelmax": "$l_{\\max}$",
"levelmin": "$l_{\\min}$",
"comp_frac": "$1 - \\zeta$",
}
# Formatters turning namelist values (RAMSES namelist file) into LaTeX strings.
# The numeric factors rescale the raw namelist value before display (80 for a value
# shown in Myr, 5.2675... for a value shown in microgauss); presumably code-unit
# conversion factors — confirm.
value_convert = {
"sfr_avg_window": lambda x: "${:g}$ Myr".format(80 * x),
"bx_bound": lambda x: "${:g}$ $\\mu G$".format(5.267501272979475 * x),
}
def __init__(
    self,
    path,
    runs=None,
    nums=None,
    path_out=".",
    params=None,
    selector=None,
    tag=None,
    unit_time=U.year,
    **kwargs,
):
    """
    Create a new Plotter. Runs and outputs are chosen by a RunSelector.

    Parameters
    ----------
    path : path to the main folder of the simulations (ex '~/simus/myproject')
    runs : list of the runs to consider (ex ['run1', 'run2'])
    nums : list or dict of the output numbers to consider (ex [3, 5]
        or {'run1': [3, 5], 'run2': [4, 6]})
    path_out : folder where the plots are written (default ".")
    params : postprocessing parameters (see the params module)
    selector : existing RunSelector instance. When given, `runs`, `nums` and
        `kwargs` are ignored.
    tag : string added to the output and data file names
    unit_time : time unit forwarded to the StudyProcessor
    kwargs : keyword arguments for RunSelector
    """
    super(Plotter, self).__init__(path, path_out, params, tag)
    # Reuse the provided selector, or build one from runs / nums
    self.selector = (
        selector
        if selector is not None
        else RunSelector(path, runs, nums, self.params.input.nml_filename, **kwargs)
    )
    self.path = path
    self.runs, self.nums = self.selector.runs, self.selector.nums
    # Study-level processor; it also owns the per-snapshot processors
    self.study = StudyProcessor(
        path,
        self.runs,
        self.nums,
        self.path_out,
        self.params,
        unit_time=unit_time,
        selector=self.selector,
    )
    self.snaps = self.study.snaps
    self.log_id = "[plot {}] ".format(self.params.out.tag)
    # Rules must exist before the astrophysix simulations are generated
    self.def_rules()
    self.gen_simus()
    # Processor used by the plotting routines; set for each plot by _process_rule
    self.current_processor = None
def gen_simus(self):
    """
    Build one astrophysix `Simulation` per run and store them in `self.simulations`.

    Names and descriptions come from the `simu_fmt` / `descr_fmt` templates of
    `params.astrophysix` (formatted with `run`, `tag` and `nml`). The execution time
    is the ctime of the directory of the run's first selected output. Every RAMSES
    input parameter found in the run's namelist is attached as a ParameterSetting.
    """
    self.simulations = {}
    simu_template = self.params.astrophysix.simu_fmt
    descr_template = self.params.astrophysix.descr_fmt
    tag = self.params.out.tag
    for run in self.runs:
        first_snap = self.snaps[run][self.nums[run][0]]
        nml = self.study.namelist[run]
        simu_name = simu_template.format(run=run, tag=tag, nml=nml)
        created = datetime.datetime.fromtimestamp(os.stat(first_snap.path).st_ctime)
        # "YYYY-MM-DD HH:MM:SS", without the sub-second part
        exec_time = created.replace(microsecond=0).isoformat(sep=" ")
        simu = Simulation(
            simu_code=ramses,
            name=simu_name,
            alias=simu_name.upper(),
            description=descr_template.format(run=run, tag=tag, nml=nml),
            directory_path=first_snap.path,
            execution_time=exec_time,
        )
        for param in ramses.input_parameters:
            try:
                param_value = self.study.get_nml(param.key, run)
            except KeyError as err:
                self._log("key {} not found".format(err), "WARNING")
                continue
            if param_value is None:
                continue
            try:
                simu.parameter_settings.add(
                    ParameterSetting(
                        input_param=param,
                        value=param_value,
                        visibility=ParameterVisibility.BASIC_DISPLAY,
                    )
                )
            except AttributeError:
                # Fall back to the string form of the value (presumably for value
                # types astrophysix cannot handle — confirm)
                simu.parameter_settings.add(
                    ParameterSetting(
                        input_param=param,
                        value=str(param_value),
                        visibility=ParameterVisibility.BASIC_DISPLAY,
                    )
                )
        self.simulations[run] = simu
def _not_self_dep(self, name, dep, dep_arg, overwrite, select):
    """
    Resolve a dependency that is not one of this plotter's own rules.

    Dependencies known to the study processor are processed there; a non-None
    result is appended to `self.just_done`. Anything else is delegated to the
    parent class.
    """
    if dep not in self.study.rules:
        super(Plotter, self)._not_self_dep(name, dep, dep_arg, overwrite, select)
        return
    outcome = self.study.process(dep, dep_arg, overwrite, self.overwrite_dep, select)
    if outcome is not None:
        self.just_done.append(outcome)
def _needs_computation(self, overwrite, plot_filename):
    """
    Tell whether a plot has to be (re)drawn.

    The answer is truthy in interactive mode, when overwriting is requested, or when
    `plot_filename` does not exist yet.
    """
    forced = self.params.out.interactive or overwrite
    if forced:
        return forced
    return not os.path.exists(plot_filename)
def _process_rule(
    self,
    name,
    rule,
    arg,
    overwrite=False,
    select=None,
    ax=None,
    from_cells=False,
    movie=False,
    movie_fps=15,
    **kwargs,
):
    """
    Open storage and figures as needed, then process a plot rule.

    How many times the rule runs depends on `rule.kind`:
    - "snapshot" / "cells": once per selected (run, output)
    - "comp": once in total
    - other kinds: once per selected run
    Each produced file is registered as an astrophysix Datafile, and attached to the
    snapshot when an output number is known.

    Parameters
    ----------
    name : rule name, used to build the output file name
    rule : the PlotRule to process
    arg : optional rule argument, appended (sanitised) to the file name when not None
    overwrite : redo plots even if their file already exists
    select : optional keyword arguments for `RunSelector.select`, restricting runs/outputs
    ax : optional Axes, or sequence of Axes indexed like the plots. When given,
        everything is drawn there and saved to a single file.
    from_cells : not used by this method
    movie : for "snapshot"/"cells" rules, also assemble each run's images into an mp4
    movie_fps : frame rate of the movies
    kwargs : forwarded to the rule's plotting routine

    Returns
    -------
    list of the created astrophysix Datafile objects
    """
    with plt.rc_context(self.params.rcParams):
        # Full name: rule name plus the argument stripped of characters unsafe in file names
        if arg is not None:
            unsafe_chars = str.maketrans(
                {" ": None, "[": None, "]": None, ",": "_", "'": None, "/": None}
            )
            name_full = name + "_" + str(arg).translate(unsafe_chars)
        else:
            name_full = name
        filetype = filetype_from_ext[self.params.out.ext]
        if select is not None:
            runs, nums = self.selector.select(**select)
        else:
            runs, nums = self.runs, self.nums
        datafiles = []
        filenames = {}
        if rule.kind in ("snapshot", "cells"):
            run_num = [(run, num) for run in runs for num in nums[run]]
            filenames = {run: [] for run in runs}
        else:
            if rule.kind == "comp":
                run_num = [(None, None)]
            else:
                run_num = [(run, None) for run in runs]
            if movie:
                self._log(f"No movie possible for rule {name}", "WARNING")
                movie = False
        # When axes are provided, everything goes into one figure, saved/closed once
        onefigure = ax is not None
        if onefigure:
            plot_filename = self._find_filename(name_full)
        for i, (run, num) in enumerate(run_num):
            if not onefigure:
                plot_filename = self._find_filename(name_full, run, num)
            # Pick the axes: the i-th of a sequence, the single Axes given, or a new figure
            try:
                real_ax = ax[i]
            except TypeError as e:
                if ax is None:
                    _, real_ax = plt.subplots(1, 1)
                elif not_array_error(e):
                    real_ax = ax
                else:
                    raise
            if rule.kind == "snapshot":
                self.current_processor = self.snaps[run][num]
            else:
                self.current_processor = self.study
            close = (not onefigure) or (i == len(run_num) - 1)
            plot_info = self._plot_rule(
                rule,
                arg,
                plot_filename,
                overwrite,
                ax=real_ax,
                close=close,
                run=run,
                **kwargs,
            )
            # Register the result in astrophysix format
            df = rule.datafile(name, arg)
            df[filetype] = plot_filename
            if movie:
                filenames[run].append(plot_filename)
            if plot_info is not None:
                df.plot_info = plot_info
            if num is not None:
                snap = self.snaps[run][num].snapshot
                # When overwriting, replace the existing datafile. Previously the old
                # entry was deleted and the new one was never added (`elif`).
                if overwrite and df.name in snap.datafiles:
                    del snap.datafiles[df.name]
                if df.name not in snap.datafiles:
                    snap.datafiles.add(df)
                if snap not in self.simulations[run].snapshots:
                    self.simulations[run].snapshots.add(snap)
            datafiles.append(df)
        if movie:
            for run in runs:
                clip = ImageSequenceClip(filenames[run], fps=movie_fps)
                movie_filename = self._find_filename(name_full, run=run, ext=".mp4")
                os.makedirs(os.path.dirname(movie_filename), exist_ok=True)
                clip.write_videofile(movie_filename)
        return datafiles
def _plot_rule(self, rule, arg, plot_filename, overwrite, ax, close=True, **kwargs):
    """
    Draw one plot on `ax` once all dependencies are met, then save and/or close it.

    Parameters
    ----------
    rule : the PlotRule to run (`rule.process(arg, **kwargs)`)
    arg : rule argument
    plot_filename : output file
    overwrite : redraw even if `plot_filename` exists
    ax : Axes made current before drawing
    close : whether this call finishes the figure (tight layout, closing)
    kwargs : forwarded to `rule.process`

    Returns
    -------
    The value returned by `rule.process` (PlotInfo or None), or None when the plot
    was skipped because it already exists.
    """
    plt.sca(ax)
    if not self._needs_computation(overwrite, plot_filename):
        self._log("Plot {} is already done, skipping...".format(plot_filename))
        # _process_rule opens a new figure for every plot. Close it here too;
        # previously each skipped plot leaked one open figure.
        if not self.params.out.interactive and close:
            plt.close(ax.figure)
        return None
    plot_info = rule.process(arg, **kwargs)
    if self.params.plot.tight_layout and close:
        plt.tight_layout(pad=1)
    if self.params.out.save:
        os.makedirs(os.path.dirname(plot_filename), exist_ok=True)
        plt.savefig(plot_filename)
        self._log("{} plotted".format(plot_filename), "SUCCESS")
    else:
        self._log("{} plotted".format(os.path.basename(plot_filename)), "SUCCESS")
    if not self.params.out.interactive and close:
        plt.close()
    return plot_info
def _find_filename(self, name_full, run=None, num=None, fmt=None, ext=None):
    """
    Build the output file name of a plot.

    Without an explicit `fmt` and with an empty `params.out.fmt`, the default layout is
    - run and num given: "<out>/<run><subfolder>/<name>_<tag>_<run>_<num:05><ext>"
    - only run given:    "<out>/<run><subfolder>/<name>_<tag>_<run><ext>"
    - otherwise:         "<out><subfolder>/<name>_<tag><ext>"
    The "_<tag>" part is dropped when the tag is empty. `<subfolder>` is "/<ext without
    its first character>" when `params.out.ext_subfolder` is set, else empty.
    A custom format (argument or `params.out.fmt`) receives the fields run, name, tag
    (raw), num, nml (run namelist or None), out, ext and subfolder.
    """
    out_params = self.params.out
    tag_name = out_params.tag
    ext = out_params.ext if ext is None else ext
    subfolder = f"/{ext[1:]}" if out_params.ext_subfolder else ""
    if fmt is None:
        if out_params.fmt == "":
            if out_params.tag != "":
                tag_name = "_" + tag_name
            if run is None:
                fmt = "{out}{subfolder}/{name}{tag}{ext}"
            elif num is None:
                fmt = "{out}/{run}{subfolder}/{name}{tag}_{run}{ext}"
            else:
                fmt = "{out}/{run}{subfolder}/{name}{tag}_{run}_{num:05}{ext}"
        else:
            fmt = out_params.fmt
    nml = None if run is None else self.study.namelist[run]
    return fmt.format(
        run=run,
        name=name_full,
        tag=tag_name,
        num=num,
        nml=nml,
        out=self.path_out,
        ext=ext,
        subfolder=subfolder,
    )
def get_label_run(self, run, label=None, nml_key=None, time=None):
    """
    Build a label for a run from its label file, namelist values and/or a given label.

    - Neither `label` nor `nml_key`: first line of `<path>/<run>/<params.input.label_filename>`,
      or the run name when that file does not exist.
    - `nml_key` given (one key or a list of keys): "<key label> = <value>" for each key,
      comma separated, appended in parentheses to `label` when `label` is given.
    - Only `label` given: `label` itself.

    `time` is accepted for backward compatibility but is not used.
    """

    def get_label_nml(key):
        # "<LaTeX label> = <formatted value>" for one namelist key
        prop_name = os.path.basename(key)
        prop_label = self.label_convert.get(prop_name, prop_name)
        prop_value = self.study.get_nml(key, run)
        if prop_name in self.value_convert:
            prop_value_str = self.value_convert[prop_name](prop_value)
        elif type(prop_value) in (int, float):
            # Exact type check on purpose: bools and numpy scalars fall back to str()
            prop_value_str = convert_exp(prop_value, digits=5)
        else:
            prop_value_str = str(prop_value)
        return r"{} = {}".format(prop_label, prop_value_str)

    def get_label_file():
        # First line of the run's label file, or the run name if there is none
        label_filename = f"{self.path}/{run}/{self.params.input.label_filename}"
        if not os.path.exists(label_filename):
            return run
        with open(label_filename, "r") as label_file:
            # Strip only the line terminator. Slicing off the last character cut the
            # label when the file had no trailing newline.
            return label_file.readline().rstrip("\r\n")

    if nml_key is None and label is None:
        return get_label_file()
    if nml_key is None:
        return label
    keys = nml_key if isinstance(nml_key, list) else [nml_key]
    label_run = ", ".join(map(get_label_nml, keys))
    if label is not None:
        label_run = label + " (" + label_run + ")"
    return label_run
def _ax_label_unit(self, node_name, label, unit, unit_coeff, put_units=True):
    """
    Find the axis label and units of a data node.

    Parameters
    ----------
    node_name : path of the node in the current processor's storage
    label : label to use; when None, the node's "label" attribute, or else the node's
        base name (converted through `label_convert` when possible)
    unit : target unit; when None, the node's own unit
    unit_coeff : factor applied to the target unit (a unit of `unit_coeff * unit`)
    put_units : append the unit string to the label

    Returns
    -------
    (label, unit_old, unit): the final label, the node's stored unit (`U.none` when
    the node has no "unit" attribute) and the final target unit.
    """
    if label is None:
        try:
            label = self.current_processor.get_attribute(node_name, "label")
        except KeyError:
            base_name = os.path.basename(node_name)
            label = self.label_convert.get(base_name, base_name)
    try:
        unit_old = self.current_processor.get_attribute(node_name, "unit")
    except KeyError:
        unit_old = U.none
    if unit is None:
        unit = unit_old
    if put_units:
        if unit_coeff == 1:
            label = label + unit_str(unit)
        else:
            base_unit = unit
            unit = unit_coeff * unit
            label = label + unit_str(unit, base=base_unit)
    return label, unit_old, unit
def snapshot_title(self, run, title, nml_key, put_time, unit_time=U.Myr):
    """
    Build the title of a snapshot plot.

    The run label (see `get_label_run`) is optionally followed by " | " and the
    snapshot time, formatted with `params.plot.time_fmt` and expressed in `unit_time`.
    NOTE(review): `len(title)` fails if `get_label_run` returns None (e.g. run=None
    with no label) and `put_time` is set — confirm callers never do this.
    """
    title = self.get_label_run(run, title, nml_key)
    if not put_time:
        return title
    snap_time = self.current_processor.info["time"] * self.study.info["unit_time"]
    unit_label = unit_str(unit_time, format="{unit}")
    time_str = self.params.plot.time_fmt.format(snap_time.express(unit_time), unit_label)
    return title + " | " + time_str if len(title) > 0 else time_str
def _plot_map(
    self,
    name,
    ax_los,
    run,
    xlabel=None,
    ylabel=None,
    label=None,
    unit=None,
    unit_coeff=1.0,
    overlays=None,
    overlays_kwargs=None,
    title=None,
    put_title=True,
    nml_key=None,
    put_time=True,
    unit_time=U.Myr,
    put_units=True,
    unit_space=U.pc,
    center_space=False,
    cmap="plasma",
    norm="log",
    transform=None,
    vmin=None,
    vmax=None,
    scalebar=None,
    scalebar_size=1,
    axes=True,
    colorbar=True,
    embeded=False,
    **kwargs,
):
    """
    Plot the 2D map `/maps/<name>_<ax_los>` with imshow.

    Parameters
    ----------
    name : map name
    ax_los : line of sight ("x", "y" or "z"). The horizontal/vertical axes are given by
        `_axes_h` / `_axes_v`.
    run : run name (used in the title)
    xlabel, ylabel : axis labels (default: the axis names, plus `unit_space` units)
    label, unit, unit_coeff : colorbar label and units (see `_ax_label_unit`)
    overlays : list of overlays. Each is a key of `self.overlays` or a callable
        `f(ax_los, im_extent, **kw)`.
    overlays_kwargs : list of keyword dicts matched by position with `overlays`.
        Overlays without a matching entry are called without extra keywords.
    title, put_title, nml_key, put_time, unit_time : title options (see `snapshot_title`)
    put_units : append units to the labels
    unit_space : unit of the spatial axes
    center_space : express positions relative to the map center
    cmap : colormap
    norm : "log", "linear", or any other value passed unchanged to imshow
    transform : function applied to the (unit-converted) data before plotting
    vmin, vmax : color limits (default: data min / max)
    scalebar, scalebar_size : draw a white scale bar of `scalebar_size` (in `unit_space`).
        Defaults to True only in embedded mode.
    axes : draw axis ticks and labels
    colorbar : draw a colorbar
    embeded : compact mode (no axes; colorbar, scale bar and title drawn inside the image)
    kwargs : forwarded to `plt.imshow`

    Returns
    -------
    PlotInfo describing the image (for Galactica)
    """
    # None defaults instead of shared mutable lists
    overlays = [] if overlays is None else overlays
    overlays_kwargs = [] if overlays_kwargs is None else overlays_kwargs
    ax = plt.gca()
    ax_h = self._axes_h[ax_los]
    ax_v = self._axes_v[ax_los]
    im_extent = np.array(self.current_processor.get_attribute("/maps", "im_extent"))
    unit_length = self.current_processor.info["unit_length"]
    if embeded:
        axes = False
        # Put a scalebar by default
        if scalebar is None:
            scalebar = True
    if center_space:
        center = self.current_processor.get_attribute("/maps", "center")
        center_h = center[self._ax_nb[ax_h]]
        center_v = center[self._ax_nb[ax_v]]
        im_extent[:2] = im_extent[:2] - center_h
        im_extent[2:] = im_extent[2:] - center_v
    im_extent = im_extent * unit_length.express(unit_space)
    node_name = f"/maps/{name}_{ax_los}"
    dmap = self.current_processor.get_value(node_name)
    label, unit_old, unit = self._ax_label_unit(
        node_name, label, unit, unit_coeff, put_units
    )
    dmap = dmap * unit_old.express(unit)
    if transform is not None:
        dmap = transform(dmap)
    if vmin is None:
        vmin = np.min(dmap)
    if vmax is None:
        vmax = np.max(dmap)
    if norm == "log":
        norm = mpl.colors.LogNorm(vmin=vmin, vmax=vmax)
    elif norm == "linear":
        norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
    im = plt.imshow(
        dmap, extent=im_extent, origin="lower", norm=norm, cmap=cmap, **kwargs
    )
    plt.locator_params(axis="both", nbins=self.params.plot.ntick)
    if scalebar:
        scalebar = AnchoredSizeBar(
            plt.gca().transData,
            scalebar_size,
            f"{scalebar_size} {unit_str(unit_space)[2:-1]}",
            "lower left",
            pad=1,
            color="white",
            frameon=False,
            size_vertical=1,
        )
        plt.gca().add_artist(scalebar)
    if axes:
        if xlabel is None:
            xlabel = self._ax_title[ax_h]
        if ylabel is None:
            ylabel = self._ax_title[ax_v]
        if put_units:
            xlabel = xlabel + unit_str(unit_space)
            ylabel = ylabel + unit_str(unit_space)
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
    else:
        plt.xticks([])
        plt.yticks([])
    if colorbar:
        if embeded:
            # Colorbar inset on the right edge, drawn in white inside the image
            cbaxes = inset_axes(
                ax, width="10%", height="100%", loc="right", borderpad=0
            )
            cbar = plt.colorbar(cax=cbaxes, orientation="vertical")
            cbaxes.yaxis.set_ticks_position("left")
            cbaxes.yaxis.set_label_position("left")
            cbaxes.yaxis.set_tick_params(color="white", which="both")
            plt.setp(plt.getp(cbaxes.axes, "yticklabels"), color="white")
            cbar.outline.set_edgecolor("white")
            cbaxes.tick_params(axis="y", direction="in", pad=-25)
            plt.sca(ax)
        else:
            # Axes coming from an axes grid carry a dedicated colorbar axes (`cax`)
            try:
                cax = plt.gca().cax
            except AttributeError:
                cbar = plt.colorbar()
            else:
                cbar = plt.colorbar(im, cax=cax)
        if label is not None:
            if embeded:
                cbar.set_label(" " + label, color="white", loc="bottom")
            else:
                cbar.set_label(label)
    if put_title:
        title = self.snapshot_title(run, title, nml_key, put_time, unit_time)
        if embeded:
            ax.text(x=0.05, y=0.95, s=title, color="white", transform=ax.transAxes)
        else:
            plt.title(title)
    for i, plot_overlay in enumerate(overlays):
        if plot_overlay in self.overlays:
            if plot_overlay == "particles":
                plot_overlay = partial(
                    self.overlays[plot_overlay],
                    unit_space=unit_space,
                    center_space=center_space,
                )
            else:
                plot_overlay = self.overlays[plot_overlay]
        # Look the kwargs up beforehand. Catching IndexError around the call also
        # caught IndexErrors raised inside the overlay and retried it without kwargs.
        overlay_kwargs = overlays_kwargs[i] if i < len(overlays_kwargs) else {}
        xlim = plt.xlim()
        ylim = plt.ylim()
        try:
            plot_overlay(ax_los, im_extent, **overlay_kwargs)
        finally:
            # Restore previous limits in case the overlay changed them
            plt.xlim(xlim)
            plt.ylim(ylim)
    if embeded:
        plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
    # NOTE(review): with imshow, dmap.shape[0] counts rows (vertical axis), yet it sizes the
    # x edges here. This is only equivalent for square maps — confirm the axis order
    # astrophysix expects before changing it.
    return PlotInfo(
        plot_type=PlotType.IMAGE,
        xaxis_values=np.linspace(im_extent[0], im_extent[1], dmap.shape[0] + 1),
        yaxis_values=np.linspace(im_extent[2], im_extent[3], dmap.shape[1] + 1),
        values=dmap,
        xaxis_log_scale=False,
        yaxis_log_scale=False,
        values_log_scale=False,
        xaxis_label=xlabel,
        yaxis_label=ylabel,
        values_label=label,
        xaxis_unit=unit_space,
        yaxis_unit=unit_space,
        values_unit=unit,
        plot_title=title,
    )
def _overlay_contour(
    self,
    ax_los,
    im_extent,
    map_name,
    log=False,
    lvl_array=None,
    lw=None,
    lvl_th=None,
    lvl_max_lbl=np.inf,
    lvl_offset=0,
    lbl_fmt="%g",
    **kwargs,
):
    """
    Overlay labelled contours of another map, `/maps/<map_name>_<ax_los>`.

    Parameters
    ----------
    ax_los : line of sight
    im_extent : extent of the displayed map (passed to `plt.contour`)
    map_name : name of the map to contour
    log : contour log10 of the map
    lvl_array : contour levels, any sequence (default: integer steps from the finite
        minimum to the finite maximum)
    lw : line widths. When None: 2 for every level, replaced by 2**(lvl_th - level) for
        levels >= `lvl_th` and 1 below it (only when `lvl_th` is set).
    lvl_th : level threshold for the default line widths
    lvl_max_lbl : only levels (after offset) below this value are labelled
    lvl_offset : offset added to the levels before labelling
    lbl_fmt : label format
    kwargs : forwarded to `plt.contour`
    """
    map_contour = self.current_processor.get_value(
        "/maps/{}_{}".format(map_name, ax_los)
    )
    if log:
        map_contour = np.log10(map_contour)
    if lvl_array is None:
        finite_values = map_contour[np.isfinite(map_contour)]
        lvl_array = np.arange(np.min(finite_values), np.max(finite_values) + 1)
    else:
        # Accept plain sequences: `.size` and the boolean masks below need an array
        lvl_array = np.asarray(lvl_array)
    if lw is None:
        lw = np.ones(lvl_array.size) * 2
        if lvl_th:
            above = lvl_array >= lvl_th
            lw[above] = lw[above] ** (lvl_th - lvl_array[above])
            lw[lvl_array < lvl_th] = 1.0
    cont = plt.contour(
        map_contour,
        extent=im_extent,
        origin="lower",
        linewidths=lw,
        levels=lvl_array,
        **kwargs,
    )
    # Shift the levels used for labelling (e.g. _overlay_levels uses an offset of 1)
    lvls = np.array(cont.levels) + lvl_offset
    cont.levels = lvls
    plt.clabel(
        cont,
        lvls[lvls < lvl_max_lbl],
        inline=1,
        fontsize=8.0,
        fmt=lbl_fmt,
    )
def _overlay_levels(self, ax_los, im_extent, **kwargs):
    """
    Overlay the AMR level map (`/maps/levels_<ax_los>`) as labelled contours.

    Preset options passed to `_overlay_contour`: lbl_fmt="%1d", lvl_offset=1,
    lvl_th=8, lvl_max_lbl=11. Any of them can now be overridden through `kwargs`
    (previously that raised a "multiple values for keyword" TypeError). Other kwargs
    are forwarded to `_overlay_contour`.
    """
    options = dict(lbl_fmt="%1d", lvl_offset=1, lvl_th=8, lvl_max_lbl=11)
    options.update(kwargs)
    return self._overlay_contour(ax_los, im_extent, "levels", **options)
def _overlay_particles(
    self,
    ax_los,
    im_extent,
    unit_space=U.pc,
    center_space=False,
    parts=True,
    sinks=False,
    **kwargs,
):
    """
    Overlay particles as a scatter plot whose marker area is proportional to mass.

    Parameters
    ----------
    ax_los : line of sight ("x", "y" or "z")
    im_extent : [hmin, hmax, vmin, vmax] of the displayed map, in `unit_space`.
        Particles outside it are not drawn.
    unit_space : unit of the map axes
    center_space : positions relative to the map center (as in `_plot_map`)
    parts : draw the particles loaded with `load_parts` (masses converted to Msun)
    sinks : draw the sink particles of `/datasets/sinks` instead (takes precedence over
        `parts`). Nothing is drawn if they are unavailable.
    kwargs : forwarded to `plt.scatter`
    """
    unit_length = self.current_processor.info["unit_length"]
    if sinks:
        try:
            self.current_processor.sinks()
            sink_table = pd.DataFrame(
                self.current_processor.get_value("/datasets/sinks")
            )
            part_pos = sink_table[["x", "y", "z"]].values
            mass = sink_table["M"].values
            # Not in-place: avoid modifying the processor's unit object
            unit_length = unit_length / self.current_processor.lbox
        except KeyError:
            self.current_processor._log("No sinks particles", "WARNING")
            return
    elif parts:
        self.current_processor.load_parts(keys=["pos", "mass"])
        part_pos = self.current_processor.parts.pos
        # Not in-place: `*=` rescaled the processor's own particle mass array
        mass = self.current_processor.parts.mass * self.current_processor.info[
            "unit_mass"
        ].express(U.Msun)
        self.current_processor.unload_parts()
    else:
        # Nothing requested (previously a NameError on `part_pos`)
        return
    # Column indices of the horizontal / vertical map axes
    ih = self._ax_nb[self._axes_h[ax_los]]
    iv = self._ax_nb[self._axes_v[ax_los]]
    offset_h = offset_v = 0.0
    if center_space:
        center = self.current_processor.get_attribute("/maps", "center")
        offset_h = center[ih]
        offset_v = center[iv]
    scale = unit_length.express(unit_space)
    # New arrays (no in-place update of the particle positions)
    part_h = (part_pos[:, ih] - offset_h) * scale
    part_v = (part_pos[:, iv] - offset_v) * scale
    # Keep only the particles inside the displayed map
    mask = (
        (im_extent[0] <= part_h)
        & (part_h <= im_extent[1])
        & (im_extent[2] <= part_v)
        & (part_v <= im_extent[3])
    )
    plt.scatter(part_h[mask], part_v[mask], s=mass[mask] / 5e3, **kwargs)
def _overlay_speed(
    self, ax_los, im_extent, unit=U.km_s, unit_coeff=1.0, key_v=None, **kwargs
):
    """
    Overlay the in-plane velocity field as arrows (quiver).

    Parameters
    ----------
    ax_los : line of sight. The maps `/maps/speed_h_<ax_los>` and `/maps/speed_v_<ax_los>`
        give the horizontal and vertical components.
    im_extent : [hmin, hmax, vmin, vmax] of the displayed map
    unit, unit_coeff : velocity unit used for the arrows and the key
    key_v : length of the key arrow (default: mean of the min and max displayed speeds)
    kwargs : forwarded to `plt.quiver`

    One vector is kept every `params.plot.vel_red` pixels along each axis.
    """
    dmap_vh = self.current_processor.get_value("/maps/speed_h_{}".format(ax_los))
    dmap_vv = self.current_processor.get_value("/maps/speed_v_{}".format(ax_los))
    label, unit_old, unit = self._ax_label_unit(
        f"/maps/speed_h_{ax_los}", "", unit, unit_coeff
    )
    conversion = unit_old.express(unit)
    vel_red = self.params.plot.vel_red
    # Keep only a subset of the vectors
    map_vh_red = dmap_vh[::vel_red, ::vel_red] * conversion
    map_vv_red = dmap_vv[::vel_red, ::vel_red] * conversion
    norm_v = np.sqrt(map_vh_red ** 2 + map_vv_red ** 2)
    max_v = np.max(norm_v)
    min_v = np.min(norm_v)
    # Maps are drawn with imshow: rows follow the vertical axis and columns the
    # horizontal one. quiver needs X/Y grids with the same (rows, cols) shape as U/V.
    # The counts used to be swapped, which only worked for square maps.
    nv, nh = map_vh_red.shape
    size_h = im_extent[1] - im_extent[0]
    dh = size_h / dmap_vh.shape[1]  # size of one cell
    h = im_extent[0] + dh + np.arange(nh) * (size_h / nh)
    size_v = im_extent[3] - im_extent[2]
    dv = size_v / dmap_vh.shape[0]
    v = im_extent[2] + dv + np.arange(nv) * (size_v / nv)
    hh, vv = np.meshgrid(h, v)
    vec_field = plt.quiver(hh, vv, map_vh_red, map_vv_red, units="width", **kwargs)
    # Key arrow
    if key_v is None:
        key_v = (max_v + min_v) / 2.0
    plt.quiverkey(
        vec_field,
        0.6,
        0.98,
        key_v,
        r"${:g}$".format(key_v) + label,
        labelpos="E",
        coordinates="figure",
    )
def _overlay_B(self, ax_los, im_extent, **kwargs):
    """
    Overlay magnetic field lines (streamplot) from `/maps/B_h_<ax_los>` and `/maps/B_v_<ax_los>`.

    One sample every `params.plot.vel_red` pixels is kept. Samples sit at the centres
    of the subsampled cells spanning `im_extent`, so the lines use the same coordinates
    (units, centering, line of sight) as the displayed map.

    Parameters
    ----------
    ax_los : line of sight
    im_extent : [hmin, hmax, vmin, vmax] of the displayed map
    kwargs : forwarded to `plt.streamplot`
    """
    dmap_Bh = self.current_processor.get_value(f"/maps/B_h_{ax_los}")
    dmap_Bv = self.current_processor.get_value(f"/maps/B_v_{ax_los}")
    vel_red = self.params.plot.vel_red
    map_Bh_red = dmap_Bh[::vel_red, ::vel_red]
    map_Bv_red = dmap_Bv[::vel_red, ::vel_red]
    # Rows follow the vertical axis, columns the horizontal one (imshow convention).
    # The grid is built from im_extent: the old one always used center[0]/center[1]
    # whatever the line of sight, and lbox units instead of the displayed units.
    nv, nh = map_Bh_red.shape
    vec_h = im_extent[0] + (np.arange(nh) + 0.5) * (im_extent[1] - im_extent[0]) / nh
    vec_v = im_extent[2] + (np.arange(nv) + 0.5) * (im_extent[3] - im_extent[2]) / nv
    hh, vv = np.meshgrid(vec_h, vec_v)
    plt.streamplot(hh, vv, map_Bh_red, map_Bv_red, **kwargs)
def _plot_hist(
    self,
    name,
    ax_los=None,
    run=None,
    group="/hist/",
    xlabel=None,
    unit=None,
    unit_coeff=1.0,
    ytransform=None,
    label=None,
    put_title=True,
    title=None,
    nml_key=None,
    put_time=True,
    unit_time=U.Myr,
    xlog=None,
    ylog=False,
    kind="bar",
    ylabel="$\\mathcal{P}$",
    color=None,
    colors=None,
    nml_color=None,
    fit=None,
    fitlabel=None,
    **kwargs,
):
    """
    Plot a histogram (PDF, ...) stored at `<group><name>[_<ax_los>]`.

    The node is either a (values, centers) pair, or a mapping with "runs" and "mean"
    entries, from which the (values, centers) of `run` are taken.

    Parameters
    ----------
    name, ax_los, group : location of the histogram node
    run : run name (title, colors, selection in aggregated nodes)
    xlabel : x-axis label. When None, `label` (or the node's label) is used, as before.
    unit, unit_coeff : unit of the bin centers
    ytransform : function applied to the bin values
    label : legend label (defaults to the title). Also the x-axis label when `xlabel`
        is None.
    put_title, title, nml_key, put_time, unit_time : title options (see `snapshot_title`)
    xlog : whether the centers are log10 values (default: the node's "logbins"
        attribute, False when absent)
    ylog : logarithmic y axis
    kind : "bar" or "step"
    ylabel : y-axis label
    color, colors, nml_color : explicit color, or a mapping/callable giving the color from
        the run name or from the namelist value `nml_color`
    fit, fitlabel : optional fit overlay (see `_overlay_fit`)
    kwargs : forwarded to `plt.bar` / `plt.step`

    Returns
    -------
    PlotInfo describing the histogram (for Galactica)

    Raises
    ------
    ValueError : if `kind` is neither "bar" nor "step"
    """
    if ax_los is not None:
        name = name + "_" + ax_los
    node_name = group + name
    if xlog is None:
        try:
            xlog = self.current_processor.get_attribute(node_name, "logbins")
        except (KeyError, AttributeError):
            # Elsewhere in this class a missing attribute is reported as KeyError
            xlog = False
    # `xlabel` used to be ignored: `label` was always passed instead
    xlabel, unit_old, unit = self._ax_label_unit(
        node_name, label if xlabel is None else xlabel, unit, unit_coeff
    )
    node = self.current_processor.get_value(node_name)
    if "mean" in node:
        index = node["runs"].index(run.encode())
        values, centers = node["mean"][index]
    else:
        values, centers = node
    if xlog:
        centers = centers + np.log10(unit_old.express(unit))
    else:
        centers = centers * unit_old.express(unit)
    if ytransform is not None:
        values = ytransform(values)
    width = centers[1] - centers[0]
    title = self.snapshot_title(run, title, nml_key, put_time, unit_time)
    if put_title:
        plt.title(title)
    if label is None:
        label = title
    if color is None and colors is not None:
        if nml_color is None:
            color = colors[run]
        else:
            nml = self.study.get_nml(nml_color, run)
            # `colors` may be a mapping or a callable (e.g. a colormap)
            try:
                color = colors[nml]
            except TypeError:
                color = colors(nml)
    if kind == "bar":
        plt.bar(
            centers, values, width, log=ylog, color=color, label=label, **kwargs
        )
    elif kind == "step":
        if ylog:
            plt.yscale("log")
        plt.step(centers, values, where="mid", color=color, label=label, **kwargs)
    else:
        raise ValueError("kind must be 'bar' or 'step'")
    if xlabel is not None:
        plt.xlabel(xlabel)
    if ylabel is not None:
        plt.ylabel(ylabel)
    if fit is not None:
        self._overlay_fit(
            centers, values, kind=fit, ls="--", lw=1.5, label=fitlabel
        )
    # Bin edges from the (uniformly spaced) centers, for Galactica
    edges = np.append(centers - width / 2.0, centers[-1] + width / 2.0)
    return PlotInfo(
        plot_type=PlotType.HISTOGRAM,
        xaxis_values=edges,
        yaxis_values=values,
        xaxis_log_scale=False,
        yaxis_log_scale=ylog,
        xaxis_label=xlabel,
        yaxis_label=ylabel,
        xaxis_unit=unit,
        yaxis_unit=U.none,
        plot_title=title,
    )
def _plot(
    self,
    name_x,
    name_y,
    node_arg=None,
    xlabel=None,
    ylabel=None,
    label=None,
    xunit=None,
    yunit=None,
    xunit_coeff=1.0,
    yunit_coeff=1.0,
    xtransform=None,
    ytransform=None,
    xscale="linear",
    yscale="linear",
    fit=None,
    fitlabel=None,
    smooth=0,
    nml_key=None,
    run=None,
    yerr=None,
    yerr_kind="std",
    sigma_err=2.0,
    grid=False,
    put_time=False,
    unit_time=U.Myr,
    colors=None,
    nml_color=None,
    legend=None,
    subname_x=None,
    subname_y=None,
    **kwargs,
):
    """
    Generic plot routine: plot node `name_y` against node `name_x`.

    The y node may be
    - a plain array;
    - a mapping with a "mean" entry plus statistics ("std", "min"/"max", "q025"/"q975",
      "q16"/"q84") that give the error bars;
    - or a mapping indexed by run name (x is then indexed by `run` too).
    Non-finite points are dropped.

    Parameters
    ----------
    name_x, name_y : node paths. "_<node_arg>" is appended to both when `node_arg` is given.
    subname_x, subname_y : when given, the node value is an HDF5 file path and the data
        is read from this node inside that file
    xlabel, ylabel, xunit, yunit, xunit_coeff, yunit_coeff : labels and units
        (see `_ax_label_unit`)
    label : curve label, completed with the time (`put_time`) and with the run label /
        namelist values (`run`, `nml_key`)
    xtransform, ytransform : functions applied to x / y after unit conversion
    xscale, yscale : axis scales
    fit, fitlabel : optional fit overlay (see `_overlay_fit`)
    smooth : sigma (in points) of a Gaussian smoothing of y; 0 disables it
    run : run name
    yerr : error on y (array or node path). For "mean" nodes it is replaced by the
        node statistics.
    yerr_kind : "std" (mean ± sigma_err * std), "min_max", "95per" or "68per" (bounds).
        None disables error bars.
    sigma_err : number of standard deviations for yerr_kind="std"
    grid, legend : add a grid / a legend
    put_time, unit_time : add the snapshot time to the label
    colors, nml_color : mapping/callable giving the color from the run name
        (nml_color=None), the snapshot time (nml_color="time") or the namelist value
        `nml_color`
    kwargs : forwarded to `plt.plot` / `plt.errorbar`
    """
    if node_arg is not None:
        name_x, name_y = name_x + "_" + node_arg, name_y + "_" + node_arg
    node_x = self.current_processor.get_value(name_x)
    node_y = self.current_processor.get_value(name_y)
    # Data stored in an external HDF5 file: read it into memory and close the file
    # right away. Previously the files stayed open if anything failed later on.
    if subname_x:
        with tables.open_file(node_x) as hdf5_x:
            node_x = hdf5_x.get_node(subname_x).read()
    if subname_y:
        with tables.open_file(node_y) as hdf5_y:
            node_y = hdf5_y.get_node(subname_y).read()
    xlabel, xunit_old, xunit = self._ax_label_unit(
        name_x, xlabel, xunit, xunit_coeff
    )
    ylabel, yunit_old, yunit = self._ax_label_unit(
        name_y, ylabel, yunit, yunit_coeff
    )
    if put_time:
        time = self.current_processor.time * self.study.info["unit_time"]
        time_str = self.params.plot.time_fmt.format(
            time.express(unit_time), unit_time.latex.replace("text", "math")
        )
        time_str = f"${time_str}$"
        if label is not None and len(label) > 0:
            label = label + " | " + time_str
        else:
            label = time_str
    xconv = xunit_old.express(xunit)
    yconv = yunit_old.express(yunit)
    yerr_bars = None  # what is passed to plt.errorbar
    if isinstance(node_y, np.ndarray):
        x = node_x * xconv
        y = node_y * yconv
        mask = np.isfinite(x) & np.isfinite(y)
        x, y = x[mask], y[mask]
    elif "mean" in node_y:
        x = node_x * xconv
        y = node_y["mean"] * yconv
        bound_keys = {
            "min_max": ("min", "max"),
            "95per": ("q025", "q975"),
            "68per": ("q16", "q84"),
        }
        if yerr_kind == "std":
            std = node_y["std"] * yconv
            yerr_min = y - sigma_err * std
            yerr_max = y + sigma_err * std
        elif yerr_kind in bound_keys:
            key_min, key_max = bound_keys[yerr_kind]
            yerr_min = node_y[key_min] * yconv
            yerr_max = node_y[key_max] * yconv
        else:
            yerr_min = yerr_max = y
        err_low = y - yerr_min
        err_high = yerr_max - y
        mask = (
            np.isfinite(x)
            & np.isfinite(y)
            & np.isfinite(err_low)
            & np.isfinite(err_high)
        )
        x, y = x[mask], y[mask]
        err_low, err_high = err_low[mask], err_high[mask]
        # plt.errorbar draws [y - low, y + high], so pass the distance to each bound.
        # The full width (max - min) used to be passed, which doubled the bars.
        # Clipped at 0: the mean may lie outside quantile bounds, and matplotlib
        # rejects negative error values.
        yerr_bars = np.vstack(
            [np.clip(err_low, 0, None), np.clip(err_high, 0, None)]
        )
        # Symmetric half-width, used to weight the fit
        yerr = (err_low + err_high) / 2.0
    else:
        x = node_x[run] * xconv
        y = node_y[run] * yconv
        mask = np.isfinite(x) & np.isfinite(y)
        x, y = x[mask], y[mask]
    if isinstance(yerr, str):
        yerr = self.current_processor.get_value(yerr)
    if yerr_bars is None:
        yerr_bars = yerr
    if xtransform is not None:
        x = xtransform(x)
    if ytransform is not None:
        y = ytransform(y)
        if yerr is not None:
            self._log(
                "Errorbar may be meaningless when ytransform is used", "WARNING"
            )
    if smooth > 0:
        y = gaussian_filter1d(y, sigma=smooth)
    if run is not None:
        label = self.get_label_run(run, label, nml_key)
    if colors is not None:
        if nml_color is None:
            color = colors[run]
        elif nml_color == "time":
            time = (
                self.current_processor.time
                * self.current_processor.info["unit_time"]
            ).express(unit_time)
            color = colors(time)
        else:
            nml = self.study.get_nml(nml_color, run)
            # `colors` may be a mapping or a callable (e.g. a colormap)
            try:
                color = colors[nml]
            except TypeError:
                color = colors(nml)
        kwargs["color"] = color
    # Same rule with or without `colors` (previously only the colors branch
    # honoured yerr_kind=None)
    if yerr_bars is None or yerr_kind is None:
        (base_line,) = plt.plot(x, y, label=label, **kwargs)
    else:
        base_line, _, _ = plt.errorbar(x, y, yerr=yerr_bars, label=label, **kwargs)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    if grid:
        plt.grid()
    if legend:
        plt.legend()
    plt.xscale(xscale)
    plt.yscale(yscale)
    if fit is not None:
        self._overlay_fit(
            x,
            y,
            yerr,
            kind=fit,
            ls="--",
            lw=1.5,
            color=base_line.get_color(),
            label=fitlabel,
        )
def _pspec(self, name, **kwargs):
    """
    Plot a power spectrum by calling `pspec_read.pspec_<name>`.

    The function receives the file stored at `/hdf5/pspec` of the current processor,
    the directory ".", the output number and the remaining kwargs. The `run` keyword
    added by `_process_rule` is not forwarded.
    """
    # pop with a default: `del` raised KeyError when `run` was not passed
    kwargs.pop("run", None)
    file_pspec = self.current_processor.get_value("/hdf5/pspec")
    num = self.current_processor.num
    getattr(pspec_read, "pspec_" + name)(file_pspec, ".", num, **kwargs)
def _overlay_fit(self, x, y, yerr=None, kind="linear", label=None, **kwargs):
    """
    Fit y(x) and overlay the fitted curve.

    Parameters
    ----------
    x, y : data arrays
    yerr : optional uncertainties on y. When given and not all zero, the fit is weighted:
        weights 1 / yerr for the linear fit, relative errors yerr / y in log space for
        the power law. Otherwise an unweighted regression is used.
    kind : "linear" (y = a x + b) or "power_law" (y = 10**b * x**a, fitted on log10
        values). Other values only log a warning.
    label : legend label. When None, a default describing the fit is used.
    kwargs : forwarded to `plt.plot`
    """
    unweighted = yerr is None or np.sum(np.abs(yerr)) == 0
    if kind == "linear":
        if unweighted:
            a, b, r_value, _p_value, stderr = linregress(x, y)
            # linregress returns the correlation coefficient r, not R^2
            r2 = r_value ** 2
            self._log(
                "Linear fit y = {} x + {} with R^2 = {} and error is {}".format(
                    a, b, r2, stderr
                )
            )
            if label is None:
                label = r"Linear fit with slope ${:.3g}$ and $R^2 = {:.3f}$".format(
                    a, r2
                )
        else:
            coefs, (residuals, _rank, _sv, _rcond) = polyfit(
                x, y, 1, w=1.0 / np.asarray(yerr, dtype=float), full=True
            )
            b, a = coefs[0], coefs[1]
            # The residual array is empty for rank-deficient / exactly determined fits
            residual = residuals[0] if len(residuals) else 0.0
            self._log(
                "Linear fit y = {} x + {} with residual {}".format(a, b, residual)
            )
            if label is None:
                label = r"Linear fit with slope ${:.3g}$".format(a)
        plt.plot(x, a * x + b, label=label, **kwargs)
    elif kind == "power_law":
        log_x, log_y = np.log10(x), np.log10(y)
        if unweighted:
            a, b, r_value, _p_value, stderr = linregress(log_x, log_y)
            self._log(
                "Power law fit y = x^({}) * {} with R^2 = {} and error is {}".format(
                    a, 10 ** b, r_value ** 2, stderr
                )
            )
        else:
            rel_err = yerr / y

            def errfunc(p, lx, ly, err):
                # Weighted residuals of the straight line p[0] + p[1] * lx in log space
                return (ly - (p[0] + p[1] * lx)) / err

            best = optimize.leastsq(
                errfunc,
                [1.0, -1.0],
                args=(log_x, log_y, rel_err),
                full_output=1,
            )[0]
            b, a = best[0], best[1]
            residual = errfunc(best, log_x, log_y, rel_err)
            self._log(
                "Power law fit y = x^({}) * {} with residual {}".format(
                    a, 10 ** b, residual
                )
            )
        if label is None:
            label = r"Power-law fit with index {:.1f}".format(a)
        plt.plot(x, (10 ** b) * x ** a, label=label, **kwargs)
    else:
        # Previously ignored silently
        self._log(
            f"Unknown fit kind {kind!r}, expected 'linear' or 'power_law'", "WARNING"
        )
def _gen_from_log(self, logrule, name_y, name_x="time", description="Generated"):
    """
    Register a "run" plot rule drawing `/series/<logrule>/<name_y>` against
    `/series/<logrule>/<name_x>`.

    The rule is named `name_y` when plotted against "time", and "<name_y>_<name_x>"
    otherwise. It depends on the `logrule` rule.
    """
    series_group = "/series/" + logrule + "/"
    rule_name = name_y if name_x == "time" else name_y + "_" + name_x
    plot_func = partial(self._plot, series_group + name_x, series_group + name_y)
    self.rules[rule_name] = PlotRule(
        self,
        plot_func,
        description=description,
        kind="run",
        dependencies=[logrule],
    )
def def_rules(self):
    """
    This is where rules are defined.

    Populates ``self.rules`` with:

    * fixed plot rules (maps, PDFs, profiles, time series, ...);
    * generic map rules (slice, mass-weighted average, average) for every
      field in ``self.params.pymses.variables``; the vector fields ``g`` and
      ``vel`` are expanded into x/y/z, radial (``r``), orthoradial (``phi``)
      components and norm (``_norm``);
    * radial-average, dispersion, average-map, fluctuation, PDF and
      mean-by-bin rules for every "averageable" quantity;
    * series rules generated from the coarse and fine step logs.

    It also fills ``self.overlays`` and then calls the parent ``def_rules``.
    """
    self.rules = {
        "plot": PlotRule(
            self, lambda arg, **kwargs: self._plot(*arg, **kwargs), kind="comp"
        ),
        "plot_run": PlotRule(
            self, lambda arg, **kwargs: self._plot(*arg, **kwargs), kind="run"
        ),
        "plot_snapshot": PlotRule(
            self, lambda arg, **kwargs: self._plot(*arg, **kwargs)
        ),
        "plot_map": PlotRule(
            self, lambda mapname, **kwargs: self._plot_map(mapname, **kwargs)
        ),
        "coldens": PlotRule(
            self,
            partial(
                self._plot_map,
                "coldens",
                label=r"$\Sigma$",
                # unit=U.coldens
            ),
            "Column density map",
            dependencies=["coldens"],
        ),
        "slice_T": PlotRule(
            self,
            partial(
                self._plot_map,
                "T",
                label=r"$T$",
            ),
            "Slice of temperature",
            dependencies=["T"],
        ),
        "alpha_disk": PlotRule(
            self,
            partial(self._plot_map, "alpha_disk", label=r"$\alpha$"),
            "Map of the Shakura&Sunaev alpha parameter for disks",
            dependencies=["alpha_disk"],
        ),
        "alpha_grav": PlotRule(
            self,
            partial(self._plot_map, "alpha_grav", label=r"$\alpha_g$"),
            "Map of the grav Shakura&Sunaev alpha parameter for disks",
            dependencies=["alpha_grav"],
        ),
        "coldens_l": PlotRule(
            self,
            partial(
                self._plot_map,
                "coldens",
                label=r"$\Sigma$",
                unit=U.coldens,
                overlays=[self._overlay_levels],
            ),
            "Column density with level overlay",
            dependencies=["coldens", "levels"],
        ),
        "slice_rho_v": PlotRule(
            self,
            partial(
                self._plot_map,
                "slice_rho",
                label=r"$\rho$",
                unit=U.Msun_pc3,
                overlays=[self._overlay_speed],
            ),
            "Density slice with speed overlay",
            dependencies=["slice_rho", "speed_h", "speed_v"],
        ),
        "slice_rho_B": PlotRule(
            self,
            partial(
                self._plot_map,
                "slice_rho",
                label=r"$\rho$",
                unit=U.Msun_pc3,
                overlays=[self._overlay_B],
            ),
            "Density slice with magnetic field overlay",
            dependencies=["slice_rho", "B_h", "B_v"],
        ),
        "slice_rho_B_vel": PlotRule(
            self,
            partial(
                self._plot_map,
                "slice_rho",
                label=r"$\rho$",
                unit=U.Msun_pc3,
                overlays=[self._overlay_B, self._overlay_speed],
            ),
            "Density slice with magnetic field and velocity overlay",
            dependencies=["slice_rho", "B_h", "B_v", "speed_h", "speed_v"],
        ),
        "jeans_ratio": PlotRule(
            self,
            partial(
                self._plot_map,
                "jeans_ratio",
                vmin=0.1,
                vmax=100,
                cmap="RdBu_r",
                overlays=[self._overlay_levels],
            ),
            "Jeans' lenght divided by the max resolution",
            dependencies=["jeans_ratio", "levels"],
        ),
        "Q": PlotRule(
            self,
            partial(
                self._plot_map,
                "Q",
                label=r"$Q$",
                vmin=0.01,
                vmax=100,
                cmap="RdBu_r",
            ),
            "Toomre Q parameter for a Keplerian disk",
            dependencies=["Q"],
        ),
        "rho_pdf": PlotRule(
            self,
            partial(self._plot_hist, "rho_pdf"),
            # Raw string: "\r" would otherwise be a carriage return.
            r"$\rho$-PDF",
            dependencies=["rho_pdf"],
        ),
        "rho_pdf_mw": PlotRule(
            self,
            partial(self._plot_hist, "rho_pdf_mw"),
            r"Mass weighted $\rho$-PDF",
            dependencies=["rho_pdf_mw"],
        ),
        "cos_pdf": PlotRule(
            self,
            partial(self._plot_hist, "cos_pdf"),
            "cos-PDF",
            dependencies=["cos_pdf", "mwa_speed"],
        ),
        "avg_coldens_pdf": PlotRule(
            self,
            partial(
                self._plot_hist,
                "avg_time_coldens_pdf_z",
                group="/comp/",
                xlog=True,
                put_time=False,
            ),
            "Column density PDF, averaged in time",
            kind="runs",
            dependencies={"avg_time_coldens_pdf": "z"},
        ),
        "T_pdf": PlotRule(
            self,
            partial(self._plot_hist, "T_pdf"),
            "T-PDF on a 2D slice",
            dependencies=["T_pdf"],
        ),
        "P_pdf": PlotRule(
            self,
            partial(self._plot_hist, "P_pdf"),
            "P-PDF on a 2D slice ",
            dependencies=["P_pdf"],
        ),
        "B_int": PlotRule(
            self,
            partial(
                self._plot_map, "B_int", label=r"$\mid \mathrm{B} \mid$", unit=U.T
            ),
            "Magnetic intensity map",
            dependencies=["B_int"],
        ),
        "Brho": PlotRule(
            self,
            partial(
                self._plot,
                "/datasets/Brho/rho",
                "/datasets/Brho/B",
                label=r"$\mathrm{B} $",
                put_time=True,
            ),
            "Brho on a 2D slice ",
            dependencies=["Brho"],
        ),
        "Ek_Eb_rho": PlotRule(
            self,
            partial(
                self._plot,
                "/datasets/Ek_Eb_rho/rho",
                "/datasets/Ek_Eb_rho/Ek_Eb_rho",
                label=r"Ek/Eb",
                put_time=True,
            ),
            "Ek/Eb on a 2D slice ",
            dependencies=["Ek_Eb_rho", "mwa_speed"],
        ),
        "rho_prof": PlotRule(
            self,
            partial(self._plot, "/profile/axis", "/profile/rho_prof"),
            "Density profile",
            dependencies=["axis", "rho_prof"],
        ),
        "pspec": PlotRule(self, self._pspec, dependencies={"pspec": None}),
        "sbeta": PlotRule(
            self,
            partial(
                self._plot,
                "/comp/nml_cloud_params/beta_cool",
                "/comp/avg_time_pdf_slope_coldens",
            ),
            "Slope of the Sigma-PDF against cooling beta factor",
            kind="comp",
            dependencies={
                "nml": "cloud_params/beta_cool",
                "avg_time_pdf_slope_coldens": None,
            },
        ),
        "sbeta_onavg": PlotRule(
            self,
            partial(
                self._plot,
                "/comp/sbeta_onavg/beta",
                "/comp/sbeta_onavg/slope",
                yerr="/comp/sbeta_onavg/stderr",
            ),
            "Slope of the time averaged Sigma-PDF against cooling beta factor",
            kind="comp",
            dependencies=["sbeta_onavg"],
        ),
        "sink_mass": PlotRule(
            self,
            partial(
                self._plot,
                "/series/sinks_from_log/time",
                "/series/sinks_from_log/mass_sink",
                xunit=U.Myr,
                yunit=U.Msun,
            ),
            "Mass of the sinks as a function of time",
            kind="run",
            dependencies=["sinks_from_log"],
        ),
        "ssm": PlotRule(
            self,
            partial(
                self._plot,
                "/series/sinks_from_log/time",
                "/series/sinks_from_log/ssm",
                xunit=U.Myr,
                yunit=U.Msun / U.pc ** 2,
            ),
            "Mass of the sinks as a function of time divided by surface",
            kind="run",
            dependencies=["ssm"],
        ),
        "assfr": PlotRule(
            self,
            partial(
                self._plot,
                "/series/sfr_from_log/time",
                "/series/sfr_from_log/sfr",
                ylabel="Averaged surfacic SFR",
                xunit=U.Myr,
                yunit=U.ssfr,
            ),
            kind="run",
            dependencies=["sfr_from_log"],
        ),
        "issfr": PlotRule(
            self,
            partial(
                self._plot,
                "/series/sinks_from_log/time",
                "/series/sinks_from_log/issfr",
                ylabel="Surfacic SFR",
                xunit=U.Myr,
                yunit=U.ssfr,
            ),
            kind="run",
            dependencies=["issfr"],
        ),
        "turb_rms": PlotRule(
            self,
            partial(
                self._plot,
                "/series/rms_from_log/time",
                "/series/rms_from_log/turb_rms",
                xunit=U.Myr,
            ),
            "Turbulent RMS",
            kind="run",
            dependencies=["rms_from_log"],
        ),
        "turb_energy": PlotRule(
            self,
            partial(
                self._plot,
                "/series/rms_from_log/time",
                "/series/rms_from_log/turb_energy",
                xunit=U.Myr,
            ),
            "Turbulent energy",
            kind="run",
            dependencies=["rms_from_log"],
        ),
        "turb_power": PlotRule(
            self,
            partial(
                self._plot,
                "/series/rms_from_log/time",
                "/series/rms_from_log/turb_power",
                xunit=U.Myr,
            ),
            "Turbulent power",
            kind="run",
            dependencies=["turb_power"],
        ),
        "sigma": PlotRule(
            self,
            partial(
                self._plot,
                "/series/time",
                "/series/time_sigma",
                ylabel="$\\sigma$",
                xunit=U.Myr,
                yunit=U.km_s,
            ),
            "Velocity dispersion",
            kind="run",
            dependencies=["time_sigma"],
        ),
        "mwa_B_int": PlotRule(
            self,
            partial(
                self._plot,
                "/series/time",
                "/series/time_mwa_B_int",
                xunit=U.Myr,
                yunit=U.uG,
            ),
            "Magnetic intensity average",
            kind="run",
            dependencies=["time_mwa_B_int"],
        ),
        "mass": PlotRule(
            self,
            partial(
                self._plot,
                "/series/time",
                "/series/time_mass",
                xunit=U.Myr,
                yunit=U.Msun,
            ),
            "Total mass in the box",
            kind="run",
            dependencies=["time_mass"],
        ),
        "max_fluct_coldens": PlotRule(
            self,
            partial(
                self._plot,
                "/series/time",
                "/series/time_max_fluct_coldens_z",
                ylabel="$\\max(\\Sigma/\\overline{\\Sigma})$",
                xunit=U.Myr,
            ),
            "Maximal fluctuation of the column density against time",
            kind="run",
            dependencies={"time_max_fluct_coldens": "z"},
        ),
    }

    # Quantities that get radial-average / fluctuation / PDF / mean-by-bin
    # rules below; generic_rule appends the field-derived maps to it.
    averageables = [
        "coldens",
        "Q",
        "T",
        "T_mwavg",
        "alpha_disk",
        "alpha_grav",
    ]

    def generic_rule(name):
        # Slice, mass-weighted average and average map rules for ``name``;
        # each resulting map is also registered as averageable.
        self.rules["slice_" + name] = PlotRule(
            self,
            partial(self._plot_map, "slice_" + name),
            "{} slice".format(name),
            dependencies=["slice_" + name],
        )
        self.rules[name + "_mwavg"] = PlotRule(
            self,
            partial(self._plot_map, name + "_mwavg"),
            "Ax mass-weighted averaged {}".format(name),
            dependencies=[name + "_mwavg"],
        )
        self.rules[name + "_avg"] = PlotRule(
            self,
            partial(self._plot_map, name + "_avg"),
            "Ax averaged {}".format(name),
            dependencies=[name + "_avg"],
        )
        averageables.extend(["slice_" + name, name + "_mwavg", name + "_avg"])

    # Generic rules directly from Ramses fields
    for field in self.params.pymses.variables:
        if field in ["g", "vel"]:
            # Vector fields: Cartesian components, then radial ("r"),
            # orthoradial ("phi") components and norm ("_norm").
            for suffix in ["x", "y", "z", "r", "phi", "_norm"]:
                generic_rule(field + suffix)
        else:
            generic_rule(field)

    for name in averageables:
        self.rules["rad_" + name] = PlotRule(
            self,
            partial(
                self._plot,
                "/radial/radial_centers",
                "/radial/rad_mwavg_" + name,
            ),
            "Azimuthal mass weighted average of {}".format(name),
            dependencies=["radial_centers", "rad_mwavg_" + name],
        )
        self.rules["dispersion_rad_" + name] = PlotRule(
            self,
            partial(
                self._plot,
                "/radial/radial_centers",
                "/radial/dispersion_rad_" + name,
            ),
            "Radial dispersion of {}".format(name),
            dependencies=["radial_centers", "dispersion_rad_" + name],
        )
        self.rules["avg_map_" + name] = PlotRule(
            self,
            partial(self._plot_map, "avg_map_" + name),
            "Map of the radial average of {}".format(name),
            dependencies=["avg_map_" + name],
        )
        self.rules["mwavg_map_" + name] = PlotRule(
            self,
            partial(self._plot_map, "mwavg_map_" + name),
            "Map of the mass weighted radial average of {}".format(name),
            # Depend on the map actually plotted (was "avg_map_" by mistake).
            dependencies=["mwavg_map_" + name],
        )
        self.rules["fluct_" + name] = PlotRule(
            self,
            partial(self._plot_map, "fluct_" + name, cmap="RdBu_r"),
            "Fluctuation of {}".format(name),
            dependencies=["fluct_" + name],
        )
        self.rules["pdf_" + name] = PlotRule(
            self,
            partial(self._plot_hist, "pdf_" + name, ylog=True),
            "Probability density function of {} fluctuations".format(name),
            dependencies=["pdf_" + name],
        )
        for name_bin in averageables:
            # Compare by value: equal names can be distinct string objects
            # (e.g. a literal and a concatenation built in generic_rule).
            if name_bin != name:
                group = "mbb_{}_{}".format(name, name_bin)
                self.rules["mbb_" + name + "_" + name_bin] = PlotRule(
                    self,
                    partial(self._plot_hist, group, ylabel=r"$\alpha$"),
                    "Mean of {} by bins of {}".format(name, name_bin),
                    dependencies=[group],
                )

    # Series extracted from the coarse step log
    coarse_quantities = ["mcons", "econs", "epot", "ekin", "eint", "emag", "elapsed"]
    for name in ["step"] + coarse_quantities:
        self._gen_from_log("coarse_step_from_log", name)
    for name in ["time"] + coarse_quantities:
        self._gen_from_log("coarse_step_from_log", name_y=name, name_x="step")

    # Series extracted from the fine step log
    fine_quantities = ["dt", "a", "mem_cells", "mem_parts"]
    for name in ["fine_step"] + fine_quantities:
        self._gen_from_log("fine_step_from_log", name)
    for name in ["time"] + fine_quantities:
        self._gen_from_log("fine_step_from_log", name_y=name, name_x="fine_step")

    # Dict of overlays
    self.overlays = {
        "B": self._overlay_B,
        "speed": self._overlay_speed,
        "levels": self._overlay_levels,
        "contour": self._overlay_contour,
        "particles": self._overlay_particles,
    }
    super(Plotter, self).def_rules()