add astrophysix support, fix labels

This commit is contained in:
Noe Brucy
2020-12-14 10:08:01 +01:00
parent 653e64d782
commit 7f7216abf6
6 changed files with 519 additions and 355 deletions
+6 -1
View File
@@ -17,6 +17,7 @@ class Comparator(Aggregator, HDF5Container):
pp_params=default_params(), pp_params=default_params(),
selector=None, selector=None,
tag=None, tag=None,
unit_time=cst.year,
**kwargs **kwargs
): ):
""" """
@@ -53,7 +54,11 @@ class Comparator(Aggregator, HDF5Container):
for num in self.nums[run]: for num in self.nums[run]:
self.pp[run][num] = PostProcessor( self.pp[run][num] = PostProcessor(
path_run, num, path_out=path_out_run, pp_params=self.pp_params path_run,
num,
path_out=path_out_run,
pp_params=self.pp_params,
unit_time=unit_time,
) )
run0 = self.runs[0] run0 = self.runs[0]
+310 -303
View File
@@ -17,7 +17,8 @@ from scipy.stats import linregress
from numpy.polynomial.polynomial import polyfit from numpy.polynomial.polynomial import polyfit
from scipy.ndimage.filters import gaussian_filter1d from scipy.ndimage.filters import gaussian_filter1d
from scipy import optimize from scipy import optimize
from astrophysix.simdm.datafiles import Datafile, PlotType, PlotInfo
from astrophysix.utils.file import FileType
import matplotlib as mpl import matplotlib as mpl
if os.environ.get("DISPLAY", "") == "": if os.environ.get("DISPLAY", "") == "":
@@ -26,12 +27,15 @@ if os.environ.get("DISPLAY", "") == "":
import pylab as P import pylab as P
from comparator import * from comparator import *
import pspec_read import pspec_read
import datetime
P.rcParams["image.cmap"] = "plasma" filetype_from_ext = {ext: ft for ft in FileType for ext in ft.extension_list}
P.rcParams["savefig.dpi"] = 400
tex_params = {"text.latex.preamble": r"\usepackage{amsmath}"}
P.rcParams.update(tex_params) def not_array_error(err):
epy2 = "object does not support indexing"
epy3 = "object is not subscriptable"
return str(err)[-len(epy2) :] == epy2 or str(err)[-len(epy3) :] == epy3
class PlotRule(Rule): class PlotRule(Rule):
@@ -42,7 +46,7 @@ class PlotRule(Rule):
def plot(self, save, arg, **kwargs): def plot(self, save, arg, **kwargs):
""" """
Set the plotter's storage to 'save' and exetute the rule Set the plotter's storage to 'save' and execute the rule
Parameters Parameters
---------- ----------
@@ -53,6 +57,12 @@ class PlotRule(Rule):
self.postproc.save = save self.postproc.save = save
return self.process_fn(arg, **kwargs) return self.process_fn(arg, **kwargs)
def datafile(self, arg):
return Datafile(
name=self.description + "_" + arg,
description=self.description + " ({})".format(arg),
)
class Plotter(Aggregator, BaseProcessor): class Plotter(Aggregator, BaseProcessor):
""" """
@@ -99,6 +109,7 @@ class Plotter(Aggregator, BaseProcessor):
pp_params=None, pp_params=None,
selector=None, selector=None,
tag=None, tag=None,
unit_time=cst.year,
**kwargs, **kwargs,
): ):
@@ -135,7 +146,13 @@ class Plotter(Aggregator, BaseProcessor):
# Get comparator object # Get comparator object
self.comp = Comparator( self.comp = Comparator(
path, self.runs, self.nums, path_out, self.pp_params, selector=self.selector path,
self.runs,
self.nums,
path_out,
self.pp_params,
unit_time=unit_time,
selector=self.selector,
) )
# Get postprocesor objets for each run # Get postprocesor objets for each run
@@ -147,17 +164,56 @@ class Plotter(Aggregator, BaseProcessor):
# Define rules # Define rules
self.def_rules() self.def_rules()
# generate astrophysix's simulations object
self.gen_simus()
self.save = None self.save = None
def gen_simus(self):
self.simulations = {}
simu_fmt = self.pp_params.astrophysix.simu_fmt
descr_fmt = self.pp_params.astrophysix.descr_fmt
tag = self.pp_params.out.tag
for run in self.runs:
pp = self.pp[run][self.nums[run][0]]
nml = self.comp.namelist[run]
name = simu_fmt.format(run=run, tag=tag, nml=nml)
exec_time = str(datetime.datetime.fromtimestamp(os.stat(pp.path).st_ctime))
description = descr_fmt.format(run=run, tag=tag, nml=nml)
simu = Simulation(
simu_code=ramses,
name=name,
alias=name.upper(),
description=description,
directory_path=pp.path,
execution_time=exec_time,
)
for param in ramses.input_parameters:
try:
param_setting = ParameterSetting(
input_param=param,
value=self.comp.get_nml(param.key, run),
visibility=ParameterVisibility.BASIC_DISPLAY,
)
simu.parameter_settings.add(param_setting)
except KeyError as e:
self._log("key {} not found".format(e), "WARNING")
except AttributeError as e:
self._log("{}".format(e), "WARNING")
self.simulations[run] = simu
def _not_self_dep(self, name, dep, dep_arg, overwrite, **kwargs): def _not_self_dep(self, name, dep, dep_arg, overwrite, **kwargs):
""" """
Check if the dependency belongs to the plotter object or to another one (comp, pp, ..) Check if the dependency belongs to the plotter object or to another one (comp, pp, ..)
""" """
if dep in self.comp.rules: if dep in self.comp.rules:
done = self.comp.process( result = self.comp.process(
dep, dep_arg, overwrite, overwrite_dep=self.overwrite_dep dep, dep_arg, overwrite, overwrite_dep=self.overwrite_dep
) )
self.just_done.extend(done) if result is not None:
self.just_done.append(done)
else: else:
super(Plotter, self)._not_self_dep(name, dep, dep_arg, overwrite, **kwargs) super(Plotter, self)._not_self_dep(name, dep, dep_arg, overwrite, **kwargs)
@@ -172,19 +228,13 @@ class Plotter(Aggregator, BaseProcessor):
) )
def _process_rule( def _process_rule(
self, self, name, rule, arg, overwrite=False, ax=None, from_cells=False, **kwargs
name,
rule,
arg,
overwrite=False,
ax=None,
movie=False,
from_cells=False,
**kwargs,
): ):
""" """
Open storage and figure if needed before processing a rule Open storage and figure if needed before processing a rule
""" """
# Set full name according to argument
if not arg is None: if not arg is None:
name_full = ( name_full = (
name name
@@ -200,144 +250,89 @@ class Plotter(Aggregator, BaseProcessor):
else: else:
name_full = name name_full = name
if rule.is_valid(arg): # Exit if not valid
if rule.kind == "classic" or rule.kind == "cells": if not rule.is_valid(arg):
if "select" in kwargs:
select = kwargs.pop("select")
runs, nums = self.selector.select(**select)
elif "runs" in kwargs:
runs = kwargs.pop("runs")
if isinstance(runs, RunSelector):
nums = runs.nums
runs = runs.runs
else:
nums = self.nums
else:
runs = self.runs
nums = self.nums
i = 0
for run in runs:
files = []
for num in nums[run]:
plot_filename = self._find_filename(name_full, run, num)
if from_cells or rule.kind == "cells":
if not os.exists(self.pp[run][num].cells_filename):
self.pp[run][num].load_cells()
self.pp[run][num].unload_cells()
save = tables.open_file(self.pp[run][num].cells_filename)
elif rule.kind == "classic":
save = tables.open_file(self.pp[run][num].filename)
else:
save = tables.open_file(self.comp.filename, "r")
try:
self._plot_rule(
rule,
save,
arg,
plot_filename,
overwrite,
ax=ax[i],
run=run,
**kwargs,
)
except TypeError as e:
if str(e) in [
"'LocatableAxes' object does not support indexing",
"'AxesSubplot' object does not support indexing",
"'AxesSubplot' object is not subscriptable",
"'Axes' object is not subscriptable",
"'LocatableAxes' object is not subscriptable",
]:
self._plot_rule(
rule,
save,
arg,
plot_filename,
overwrite,
ax=ax,
run=run,
**kwargs,
)
elif ax is None:
fig = P.figure()
self._plot_rule(
rule,
save,
arg,
plot_filename,
overwrite,
ax=P.gca(),
run=run,
**kwargs,
)
else:
raise
finally:
save.close()
i = i + 1
files.append(plot_filename)
else:
if "select" in kwargs and not "runs" in kwargs:
select = kwargs.pop("select")
runs, nums = self.selector.select(**select)
if not rule.kind == "runs":
kwargs["runs"] = runs
elif rule.kind == "runs" and "runs" in kwargs:
runs = kwargs.pop("runs")
else:
runs = self.runs
if ax is None:
ax = P.gca()
if rule.kind == "series" and len(runs) == 1:
run = self.runs[0]
plot_filename = self._find_filename(name_full, run)
else:
plot_filename = self._find_filename(name_full)
save = tables.open_file(self.comp.filename, "r")
try:
if rule.kind == "runs":
for i, run in enumerate(runs):
try:
self._plot_rule(
rule,
save,
arg,
plot_filename,
overwrite,
ax=ax[i],
run=run,
**kwargs,
)
except TypeError as e:
if str(e) in [
"'LocatableAxes' object does not support indexing",
"'AxesSubplot' object does not support indexing",
"'AxesSubplot' object is not subscriptable",
"'Axes' object is not subscriptable",
"'LocatableAxes' object is not subscriptable",
]:
self._plot_rule(
rule,
save,
arg,
plot_filename,
overwrite,
ax=ax,
run=run,
**kwargs,
)
else:
self._plot_rule(
rule, save, arg, plot_filename, overwrite, ax, **kwargs
)
finally:
save.close()
else:
self._log("{} is not valid in this context".format(name_full), "ERROR") self._log("{} is not valid in this context".format(name_full), "ERROR")
return
# get filetype of the output
filetype = filetype_from_ext[self.pp_params.out.ext]
# Select runs and nums
if "select" in kwargs:
select = kwargs.pop("select")
runs, nums = self.selector.select(**select)
else:
runs = self.runs
nums = self.nums
datafiles = []
# Several plots
if rule.kind == "classic" or rule.kind == "cells":
run_num = [(run, num) for run in runs for num in nums[run]]
else:
run_num = [(run, None) for run in runs]
for i, (run, num) in enumerate(run_num):
# Find filename
plot_filename = self._find_filename(name_full, run, num)
# Find ax
try:
real_ax = ax[i]
except TypeError as e:
if ax is None:
fig, real_ax = P.subplots(1, 1)
elif not_array_error(e):
real_ax = ax
else:
raise
# Find plot save
if from_cells or rule.kind == "cells":
if not os.exists(self.pp[run][num].cells_filename):
self.pp[run][num].load_cells()
self.pp[run][num].unload_cells()
save = tables.open_file(self.pp[run][num].cells_filename)
elif rule.kind == "classic":
save = tables.open_file(self.pp[run][num].filename)
else:
save = tables.open_file(self.comp.filename, "r")
# Call plot routine
try:
plot_info = self._plot_rule(
rule,
save,
arg,
plot_filename,
overwrite,
ax=real_ax,
run=run,
**kwargs,
)
finally:
save.close()
# Save in astrophysix format
df = rule.datafile(arg)
df[filetype] = plot_filename
if plot_info is not None:
df.plot_info = plot_info
if num is not None:
snap = self.pp[run][num].snapshot
if overwrite and df.name in snap.datafiles:
del snap.datafiles[df.name]
elif df.name not in snap.datafiles:
snap.datafiles.add(df)
if snap not in self.simulations[run].snapshots:
self.simulations[run].snapshots.add(snap)
datafiles.append(df)
return datafiles
def _plot_rule(self, rule, save, arg, plot_filename, overwrite, ax, **kwargs): def _plot_rule(self, rule, save, arg, plot_filename, overwrite, ax, **kwargs):
""" """
@@ -345,7 +340,7 @@ class Plotter(Aggregator, BaseProcessor):
""" """
P.sca(ax) P.sca(ax)
if self._needs_computation(overwrite, plot_filename): if self._needs_computation(overwrite, plot_filename):
rule.plot(save, arg, **kwargs) plot_info = rule.plot(save, arg, **kwargs)
if not self.pp_params.out.interactive: if not self.pp_params.out.interactive:
P.tight_layout(pad=1) P.tight_layout(pad=1)
@@ -360,6 +355,7 @@ class Plotter(Aggregator, BaseProcessor):
if not self.pp_params.out.interactive: if not self.pp_params.out.interactive:
P.close() P.close()
return plot_info
else: else:
self._log("Plot {} is already done, skipping...".format(plot_filename)) self._log("Plot {} is already done, skipping...".format(plot_filename))
@@ -396,7 +392,7 @@ class Plotter(Aggregator, BaseProcessor):
ext=self.pp_params.out.ext, ext=self.pp_params.out.ext,
) )
def _label_run(self, run, node, label, nml_key): def _label_run(self, run, node, label, nml_key, time=None):
""" """
Set up a label for the run from the namelist and parameters Set up a label for the run from the namelist and parameters
""" """
@@ -435,7 +431,7 @@ class Plotter(Aggregator, BaseProcessor):
label_run = label label_run = label
return label_run return label_run
def _ax_label_unit(self, node, label, unit, unit_coeff): def _ax_label_unit(self, node, label, unit, unit_coeff, put_units=True):
""" """
Find appropriate labels for axis Find appropriate labels for axis
""" """
@@ -457,20 +453,38 @@ class Plotter(Aggregator, BaseProcessor):
if unit is None: if unit is None:
unit = unit_old unit = unit_old
if not unit_coeff == 1: if put_units:
base = unit if not unit_coeff == 1:
unit = unit_coeff * unit base = unit
label = label + unit_str(unit, base=base) unit = unit_coeff * unit
else: label = label + unit_str(unit, base=base)
label = label + unit_str(unit) else:
label = label + unit_str(unit)
return label, unit_old, unit return label, unit_old, unit
def _snapshot_title(self, run, node, title, nml_key, put_time, unit_time=cst.Myr):
title = self._label_run(run, node, title, nml_key)
if put_time:
time = self.save.root._v_attrs.time * self.comp.info["unit_time"]
u_str = unit_str(unit_time, format="{unit}")
time_str = self.pp_params.plot.time_fmt.format(
time.express(unit_time), u_str
)
if len(title) > 0:
title = title + " | " + time_str
else:
title = time_str
return title
def _plot_map( def _plot_map(
self, self,
name, name,
ax_los, ax_los,
run, run,
xlabel=None,
ylabel=None,
label=None, label=None,
unit=None, unit=None,
unit_coeff=1.0, unit_coeff=1.0,
@@ -480,12 +494,14 @@ class Plotter(Aggregator, BaseProcessor):
put_title=True, put_title=True,
nml_key=None, nml_key=None,
put_time=True, put_time=True,
time_unit=cst.Myr, unit_time=cst.Myr,
put_units=True,
unit_space=cst.pc, unit_space=cst.pc,
cmap="plasma", cmap="plasma",
norm="log", norm="log",
put_cbar=True, put_cbar=True,
autoscale=True, autoscale=True,
transform=None,
**kwargs, **kwargs,
): ):
""" """
@@ -503,9 +519,13 @@ class Plotter(Aggregator, BaseProcessor):
node = self.save.get_node("/maps/{}_{}".format(name, ax_los)) node = self.save.get_node("/maps/{}_{}".format(name, ax_los))
dmap = node.read() dmap = node.read()
label, unit_old, unit = self._ax_label_unit(node, label, unit, unit_coeff) label, unit_old, unit = self._ax_label_unit(
node, label, unit, unit_coeff, put_units
)
dmap = dmap * unit_old.express(unit) dmap = dmap * unit_old.express(unit)
if transform is not None:
dmap = transform(dmap)
if norm == "log": if norm == "log":
norm = mpl.colors.LogNorm() norm = mpl.colors.LogNorm()
@@ -521,32 +541,28 @@ class Plotter(Aggregator, BaseProcessor):
P.locator_params(axis="both", nbins=self.pp_params.plot.ntick) P.locator_params(axis="both", nbins=self.pp_params.plot.ntick)
P.xlabel(self._ax_title[ax_h] + unit_str(unit_space)) if xlabel is None:
P.ylabel(self._ax_title[ax_v] + unit_str(unit_space)) xlabel = self._ax_title[ax_h]
if ylabel is None:
ylabel = self._ax_title[ax_v]
if put_units:
xlabel = xlabel + unit_str(unit_space)
ylabel = ylabel + unit_str(unit_space)
P.xlabel(xlabel)
P.ylabel(ylabel)
try: try:
cbar = P.colorbar(im, cax=P.gca().cax) cbar = P.colorbar(im, cax=P.gca().cax)
except AttributeError: except AttributeError:
cbar = P.colorbar() cbar = P.colorbar()
if put_title:
title = self._snapshot_title(run, node, title, nml_key, put_time, unit_time)
P.title(title)
if not label is None: if not label is None:
cbar.set_label(label) cbar.set_label(label)
if put_title:
title = self._label_run(run, node, title, nml_key)
if put_time:
time = self.save.root._v_attrs.time * self.comp.info["unit_time"]
time_str = self.pp_params.plot.time_fmt.format(
time.express(time_unit), time_unit.latex
)
if len(title) > 0:
title = title + " | " + time_str
else:
title = time_str
P.title(title)
for i, plot_overlay in enumerate(overlays): for i, plot_overlay in enumerate(overlays):
if plot_overlay in self.overlays: if plot_overlay in self.overlays:
plot_overlay = self.overlays[plot_overlay] plot_overlay = self.overlays[plot_overlay]
@@ -556,6 +572,23 @@ class Plotter(Aggregator, BaseProcessor):
except: except:
plot_overlay(ax_los, im_extent) plot_overlay(ax_los, im_extent)
return PlotInfo(
plot_type=PlotType.IMAGE,
xaxis_values=np.linspace(im_extent[0], im_extent[1], dmap.shape[0] + 1),
yaxis_values=np.linspace(im_extent[2], im_extent[3], dmap.shape[1] + 1),
values=dmap,
xaxis_log_scale=False,
yaxis_log_scale=False,
values_log_scale=(norm == mpl.colors.LogNorm()),
xaxis_label=xlabel,
yaxis_label=ylabel,
values_label=label,
xaxis_unit=unit_space,
yaxis_unit=unit_space,
values_unit=unit,
plot_title=title,
)
def _overlay_contour( def _overlay_contour(
self, self,
ax_los, ax_los,
@@ -724,7 +757,7 @@ class Plotter(Aggregator, BaseProcessor):
nml_key=None, nml_key=None,
put_title=True, put_title=True,
put_time=True, put_time=True,
time_unit=cst.Myr, unit_time=cst.Myr,
**kwargs, **kwargs,
): ):
""" """
@@ -749,26 +782,13 @@ class Plotter(Aggregator, BaseProcessor):
if not ylabel is None: if not ylabel is None:
P.ylabel(ylabel) P.ylabel(ylabel)
title = self._snapshot_title(run, node, title, nml_key, put_time, unit_time)
if put_title: if put_title:
title = self._label_run(run, node, title, nml_key)
if put_time:
time = self.save.root._v_attrs.time * self.comp.info["unit_time"]
time_str = self.pp_params.plot.time_fmt.format(
time.express(time_unit), time_unit.latex
)
if len(title) > 0:
title = title + " | " + time_str
else:
title = time_str
P.title(title) P.title(title)
if label == None:
if label is None:
label = title label = title
P.plot(bin_centers, mean_bin, label=label, **kwargs)
P.plot(bin_centers, mean_bin, label=title, **kwargs) P.plot(bin_centers, mean_bin, label=label, **kwargs)
def _plot_hist( def _plot_hist(
self, self,
@@ -781,10 +801,11 @@ class Plotter(Aggregator, BaseProcessor):
unit_coeff=1.0, unit_coeff=1.0,
ytransform=None, ytransform=None,
label=None, label=None,
put_title=True,
title=None, title=None,
nml_key=None, nml_key=None,
put_time=True, put_time=True,
time_unit=cst.Myr, unit_time=cst.Myr,
xlog=None, xlog=None,
ylog=False, ylog=False,
kind="bar", kind="bar",
@@ -799,47 +820,41 @@ class Plotter(Aggregator, BaseProcessor):
""" """
Plot an histogram (PDF, etc ...) Plot an histogram (PDF, etc ...)
""" """
# Get node
if not ax_los is None: if not ax_los is None:
name = name + "_" + ax_los name = name + "_" + ax_los
node = self.save.get_node(group + name) node = self.save.get_node(group + name)
if xlog is None: if xlog is None:
try: try:
xlog = node._v_attrs_.logbins xlog = node._v_attrs_.logbins
except: except:
xlog = False xlog = False
# get label and units
xlabel, unit_old, unit = self._ax_label_unit(node, label, unit, unit_coeff) xlabel, unit_old, unit = self._ax_label_unit(node, label, unit, unit_coeff)
# Read data
if "mean" in node: if "mean" in node:
index = node["runs"].read().index(run.encode()) index = node["runs"].read().index(run.encode())
values, centers = node["mean"].read()[index] values, centers = node["mean"].read()[index]
else: else:
values, centers = node.read() values, centers = node.read()
if xlog: if xlog:
centers = centers + np.log10(unit_old.express(unit)) centers = centers + np.log10(unit_old.express(unit))
else: else:
centers = centers * unit_old.express(unit) centers = centers * unit_old.express(unit)
if ytransform is not None: if ytransform is not None:
values = ytransform(values) values = ytransform(values)
width = centers[1] - centers[0]
title = self._label_run(run, node, title, nml_key) # Set title
title = self._snapshot_title(run, node, title, nml_key, put_time, unit_time)
if put_time: if put_title:
time = self.save.root._v_attrs.time * self.comp.info["unit_time"] P.title(title)
time_str = self.pp_params.plot.time_fmt.format( if label == None:
time.express(time_unit), time_unit.latex label = title
)
if len(title) > 0:
title = title + " | " + time_str
else:
title = time_str
P.title(title)
# Set colors
if color is None and not colors is None: if color is None and not colors is None:
if nml_color is None: if nml_color is None:
color = colors[run] color = colors[run]
@@ -850,11 +865,8 @@ class Plotter(Aggregator, BaseProcessor):
except: except:
color = colors(nml) color = colors(nml)
if label == None: # Actual plot
label = title
if kind == "bar": if kind == "bar":
width = centers[1] - centers[0]
P.bar(centers, values, width, log=ylog, color=color, label=label, **kwargs) P.bar(centers, values, width, log=ylog, color=color, label=label, **kwargs)
elif kind == "step": elif kind == "step":
if ylog: if ylog:
@@ -863,11 +875,13 @@ class Plotter(Aggregator, BaseProcessor):
else: else:
raise ValueError("kind must be 'bar' or 'step'") raise ValueError("kind must be 'bar' or 'step'")
# put labels
if not label is None: if not label is None:
P.xlabel(xlabel) P.xlabel(xlabel)
if not ylabel is None: if not ylabel is None:
P.ylabel(ylabel) P.ylabel(ylabel)
# Also diplay fit, previously saved
if ax_los is not None and "/hist/fit_" + name + "_" + ax_los in self.save: if ax_los is not None and "/hist/fit_" + name + "_" + ax_los in self.save:
slope = node.attrs.slope slope = node.attrs.slope
origin = node.attrs.origin origin = node.attrs.origin
@@ -878,14 +892,27 @@ class Plotter(Aggregator, BaseProcessor):
linewidth=2, linewidth=2,
color="orange", color="orange",
) )
# or a new one
P.ylim([None, 1.0])
if not fit is None: if not fit is None:
self._overlay_fit( self._overlay_fit(
centers, values, kind=fit, ls="--", lw=1.5, label=fitlabel centers, values, kind=fit, ls="--", lw=1.5, label=fitlabel
) )
# returns PlotInfo (for Galactica)
edges = np.append(centers - width / 2.0, centers[-1] + width / 2.0)
return PlotInfo(
plot_type=PlotType.HISTOGRAM,
xaxis_values=edges,
yaxis_values=values,
xaxis_log_scale=False,
yaxis_log_scale=ylog,
xaxis_label=xlabel,
yaxis_label=ylabel,
xaxis_unit=unit,
yaxis_unit=cst.none,
plot_title=title,
)
def _plot( def _plot(
self, self,
name_x, name_x,
@@ -898,19 +925,17 @@ class Plotter(Aggregator, BaseProcessor):
yunit=None, yunit=None,
xunit_coeff=1.0, xunit_coeff=1.0,
yunit_coeff=1.0, yunit_coeff=1.0,
ylog=False,
fit=None, fit=None,
fitlabel=None, fitlabel=None,
smooth=0, smooth=0,
nml_key=None, nml_key=None,
run=None, run=None,
runs=None,
yerr=None, yerr=None,
yerr_kind="std", yerr_kind="std",
sigma_err=2.0, sigma_err=2.0,
grid=False, grid=False,
put_time=False, put_time=False,
time_unit=cst.Myr, unit_time=cst.Myr,
colors=None, colors=None,
nml_color=None, nml_color=None,
legend=None, legend=None,
@@ -922,12 +947,15 @@ class Plotter(Aggregator, BaseProcessor):
Generic plot routine, with name_x and name_y two path in the hdf5 file Generic plot routine, with name_x and name_y two path in the hdf5 file
""" """
# Get proper hdf5 names
if not node_arg is None: if not node_arg is None:
name_x, name_y = name_x + "_" + node_arg, name_y + "_" + node_arg name_x, name_y = name_x + "_" + node_arg, name_y + "_" + node_arg
# Get hdf5 nodes
node_x = self.save.get_node(name_x) node_x = self.save.get_node(name_x)
node_y = self.save.get_node(name_y) node_y = self.save.get_node(name_y)
# If the actual data is in a,other file, fetch it
if subname_x: if subname_x:
hdf5_x = tables.open_file(node_x.read()) hdf5_x = tables.open_file(node_x.read())
node_x = hdf5_x.get_node(subname_x) node_x = hdf5_x.get_node(subname_x)
@@ -935,6 +963,7 @@ class Plotter(Aggregator, BaseProcessor):
hdf5_y = tables.open_file(node_y.read()) hdf5_y = tables.open_file(node_y.read())
node_y = hdf5_y.get_node(subname_y) node_y = hdf5_y.get_node(subname_y)
# Find proper labels
xlabel, xunit_old, xunit = self._ax_label_unit( xlabel, xunit_old, xunit = self._ax_label_unit(
node_x, xlabel, xunit, xunit_coeff node_x, xlabel, xunit, xunit_coeff
) )
@@ -942,58 +971,24 @@ class Plotter(Aggregator, BaseProcessor):
node_y, ylabel, yunit, yunit_coeff node_y, ylabel, yunit, yunit_coeff
) )
P.xlabel(xlabel) # If relevent, get time
P.ylabel(ylabel)
if grid:
P.grid()
if ylog:
P.yscale("log")
if put_time: if put_time:
time = self.save.root._v_attrs.time * self.comp.info["unit_time"] time = self.save.root._v_attrs.time * self.comp.info["unit_time"]
time_str = self.pp_params.plot.time_fmt.format( time_str = self.pp_params.plot.time_fmt.format(
time.express(time_unit), time_unit.latex time.express(unit_time), unit_time.latex
) )
if label is not None and len(label) > 0: if label is not None and len(label) > 0:
label = label + " | " + time_str label = label + " | " + time_str
else: else:
label = time_str label = time_str
# Manage the different forms in which the data may be stored :
# Possibilities are : plain array, dict of arrays (mean, std, ..) or dict of array (runs)
if node_y._v_attrs.CLASS == "ARRAY": if node_y._v_attrs.CLASS == "ARRAY":
x = node_x.read() * xunit_old.express(xunit) x = node_x.read() * xunit_old.express(xunit)
y = node_y.read() * yunit_old.express(yunit) y = node_y.read() * yunit_old.express(yunit)
mask = np.isfinite(x) & np.isfinite(y) mask = np.isfinite(x) & np.isfinite(y)
x, y = x[mask], y[mask] x, y = x[mask], y[mask]
if smooth > 0:
y = gaussian_filter1d(y, sigma=smooth)
if not run is None:
label = self._label_run(run, node_y, label, nml_key)
if colors is None:
(base_line,) = P.plot(x, y, label=label, **kwargs)
else:
if nml_color is None:
color = colors[run]
elif nml_color == "time":
time = (
self.save.root._v_attrs.time * self.comp.info["unit_time"]
).express(time_unit)
color = colors(time)
else:
nml = self.comp.get_nml(nml_color, run)
try:
color = colors[nml]
except:
color = colors(nml)
if yerr is None:
(base_line,) = P.plot(x, y, label=label, color=color, **kwargs)
else:
if isinstance(yerr, str):
yerr = self.save.get_node(yerr).read()
base_line, _, _ = P.errorbar(x, y, yerr=yerr, label=label, **kwargs)
elif "mean" in node_y: elif "mean" in node_y:
x = node_x.read() * xunit_old.express(xunit) x = node_x.read() * xunit_old.express(xunit)
y = node_y.mean.read() * yunit_old.express(yunit) y = node_y.mean.read() * yunit_old.express(yunit)
@@ -1014,7 +1009,6 @@ class Plotter(Aggregator, BaseProcessor):
else: else:
yerr_min = y yerr_min = y
yerr_max = y yerr_max = y
yerr = yerr_max - yerr_min yerr = yerr_max - yerr_min
mask = np.isfinite(x) & np.isfinite(y) & np.isfinite(yerr) mask = np.isfinite(x) & np.isfinite(y) & np.isfinite(yerr)
x, y, yerr, yerr_min, yerr_max = ( x, y, yerr, yerr_min, yerr_max = (
@@ -1024,44 +1018,51 @@ class Plotter(Aggregator, BaseProcessor):
yerr_min[mask], yerr_min[mask],
yerr_max[mask], yerr_max[mask],
) )
if not run is None: else:
label = self._label_run(run, node_y, label, nml_key) x, y = node_x[run], node_y[run]
mask = np.isfinite(x) & np.isfinite(y)
x, y = x[mask], y[mask]
if yerr_kind is None: if isinstance(yerr, str):
yerr = None yerr = self.save.get_node(yerr).read()
if smooth > 0:
y = gaussian_filter1d(y, sigma=smooth)
if not run is None:
label = self._label_run(run, node_y, label, nml_key)
# Look if special colors method is used
if colors is None:
if yerr is None:
(base_line,) = P.plot(x, y, label=label, **kwargs) (base_line,) = P.plot(x, y, label=label, **kwargs)
else:
base_line, _, _ = P.errorbar(x, y, yerr=yerr, label=label, **kwargs)
else:
if nml_color is None:
color = colors[run]
elif nml_color == "time":
time = (
self.save.root._v_attrs.time * self.comp.info["unit_time"]
).express(unit_time)
color = colors(time)
else:
nml = self.comp.get_nml(nml_color, run)
try:
color = colors[nml]
except:
color = colors(nml)
if yerr is None:
(base_line,) = P.plot(x, y, label=label, color=color, **kwargs)
else: else:
base_line, _, _ = P.errorbar( base_line, _, _ = P.errorbar(
x, y, yerr=[y - yerr_min, yerr_max - y], label=label, **kwargs x, y, yerr=yerr, color=color, label=label, **kwargs
) )
else: # Ax decorations
if runs is None: P.xlabel(xlabel)
runs = self.runs P.ylabel(ylabel)
for i, run in enumerate(runs): if grid:
x_run, y_run = node_x[run], node_y[run] P.grid()
x = x_run.read() * xunit_old.express(xunit)
y = y_run.read() * yunit_old.express(yunit)
mask = np.isfinite(x) & np.isfinite(y)
x, y = x[mask], y[mask]
if smooth > 0:
y = gaussian_filter1d(y, sigma=smooth)
label_run = self._label_run(run, y_run, label, nml_key)
if colors is None:
(base_line,) = P.plot(x, y, label=label_run, **kwargs)
else:
if nml_color is None:
color = colors[i % len(colors)]
else:
nml = self.comp.get_nml(nml_color, run)
try:
color = colors[nml]
except:
color = colors(nml)
(base_line,) = P.plot(x, y, label=label_run, color=color, **kwargs)
if legend is None:
legend = True
if legend: if legend:
P.legend() P.legend()
@@ -1076,6 +1077,7 @@ class Plotter(Aggregator, BaseProcessor):
color=base_line.get_color(), color=base_line.get_color(),
label=fitlabel, label=fitlabel,
) )
if subname_x: if subname_x:
hdf5_x.close() hdf5_x.close()
if subname_y: if subname_y:
@@ -1603,11 +1605,16 @@ class Plotter(Aggregator, BaseProcessor):
dependencies=["radial_bins", "rad_avg_" + name], dependencies=["radial_bins", "rad_avg_" + name],
) )
self.rules["avg_map_" + name] = PlotRule(
self,
partial(self._plot_map, "avg_map_" + name),
"Map of the radial average of {}".format(name),
dependencies=["avg_map_" + name],
)
self.rules["fluct_" + name] = PlotRule( self.rules["fluct_" + name] = PlotRule(
self, self,
partial( partial(self._plot_map, "fluct_" + name, cmap="RdBu_r"),
self._plot_map, "fluct_" + name, vmin=0.01, vmax=100, cmap="RdBu_r"
),
"Fluctuation of {}".format(name), "Fluctuation of {}".format(name),
dependencies=["fluct_" + name], dependencies=["fluct_" + name],
) )
@@ -1666,7 +1673,7 @@ class Plotter(Aggregator, BaseProcessor):
# Radial # Radial
generic_rule(field + "r") generic_rule(field + "r")
# Othoradial # Orthoradial
generic_rule(field + "phi") generic_rule(field + "phi")
# Norm # Norm
generic_rule(field + "_norm") generic_rule(field + "_norm")
+54 -17
View File
@@ -149,7 +149,15 @@ class PostProcessor(HDF5Container):
cells_loaded = False cells_loaded = False
def __init__(self, path=None, num=None, path_out=None, pp_params=None, tag=None): def __init__(
self,
path=None,
num=None,
path_out=None,
pp_params=None,
tag=None,
unit_time=cst.year,
):
""" """
Creates the basic structures needed for the outputs Creates the basic structures needed for the outputs
""" """
@@ -170,6 +178,14 @@ class PostProcessor(HDF5Container):
self.path_out + "/cells_" + tag_name + format(num, "05") + ".h5" self.path_out + "/cells_" + tag_name + format(num, "05") + ".h5"
) )
self.pspec_filename = (
self.path_out + "/pspec_" + tag_name + format(num, "05") + ".h5"
)
self.filaments_filename = (
self.path_out + "/filaments_" + tag_name + format(num, "05") + ".pickle"
)
if not os.path.exists(self.path_out): if not os.path.exists(self.path_out):
os.makedirs(self.path_out) os.makedirs(self.path_out)
@@ -213,7 +229,7 @@ class PostProcessor(HDF5Container):
if not self.pp_params.pymses.multiprocessing: if not self.pp_params.pymses.multiprocessing:
self._rt.disable_multiprocessing() self._rt.disable_multiprocessing()
# Set the extend of the image # Set the extent of the image
self._radius = 0.5 / self.pp_params.pymses.zoom self._radius = 0.5 / self.pp_params.pymses.zoom
self.lbox = self.info["boxlen"] self.lbox = self.info["boxlen"]
@@ -273,12 +289,23 @@ class PostProcessor(HDF5Container):
self.log_id = "[{}, {}] ".format(self.run, self.num) self.log_id = "[{}, {}] ".format(self.run, self.num)
if os.path.exists(self.path_out + "/filaments.pickle"): if os.path.exists(self.filaments_filename):
with open(self.path_out + "/filaments.pickle", "rb") as f: with open(self.filaments_filename, "rb") as f:
self.fil = pickle.load(f) self.fil = pickle.load(f)
else: else:
self.fil = None self.fil = None
self.snapshot = Snapshot(
name=str(self.num),
description="",
time=(
self.info["time"] * self.info["unit_time"].express(unit_time),
unit_time,
),
directory_path=self.path,
data_reference="OUTPUT_{}".format(self.num),
)
self.def_rules() self.def_rules()
def load_cells(self): def load_cells(self):
@@ -964,12 +991,22 @@ class PostProcessor(HDF5Container):
) )
bins = np.zeros(r.shape, dtype=int) bins = np.zeros(r.shape, dtype=int)
for r0 in radial_bins[1:]: for r0 in radial_bins[1:-1]:
bins = bins + (r >= r0).astype(int) bins = bins + (r >= r0).astype(int)
vr_mean = mean_bin_vr[bins] vr_mean = mean_bin_vr[bins]
vphi_mean = mean_bin_vphi[bins] vphi_mean = mean_bin_vphi[bins]
# use linear interpolation
# v = ((r[i+1] - r)v[i] + (r - r[i])v[i + 1]) / (r[i+1] - r[i])
vr_mean = (radial_bins[bins + 1] - r) * vr_mean
vr_mean = vr_mean + (r - radial_bins[bins]) * mean_bin_vr[bins + 1]
vr_mean = vr_mean / (radial_bins[bins + 1] - radial_bins[bins])
vphi_mean = (radial_bins[bins + 1] - r) * vphi_mean
vphi_mean = vphi_mean + (r - radial_bins[bins]) * mean_bin_vphi[bins + 1]
vphi_mean = vphi_mean / (radial_bins[bins + 1] - radial_bins[bins])
vr = self.oct_getter_vr(dset) vr = self.oct_getter_vr(dset)
vphi = self.oct_getter_vphi(dset) vphi = self.oct_getter_vphi(dset)
alpha = (vphi - vphi_mean) * (vr - vr_mean) alpha = (vphi - vphi_mean) * (vr - vr_mean)
@@ -1055,7 +1092,7 @@ class PostProcessor(HDF5Container):
return sinks_dict return sinks_dict
def _pspec(self): def _pspec(self):
outfile = self.path_out + "/pspec.h5" outfile = self.pspec_filename
pspec_new.pspec(repo=self.path, iouts=[self.num], outfile=outfile) pspec_new.pspec(repo=self.path, iouts=[self.num], outfile=outfile)
return True return True
@@ -1086,9 +1123,10 @@ class PostProcessor(HDF5Container):
self.fil.create_mask( self.fil.create_mask(
verbose=verbose, verbose=verbose,
smooth_size=1 * u.pix, smooth_size=1 * u.pix,
adapt_thresh=2 * u.pix, adapt_thresh=4 * u.pix,
size_thresh=size_thresh * u.pix ** 2, size_thresh=size_thresh * u.pix ** 2,
glob_thresh=glob_thresh, glob_thresh=glob_thresh,
fill_hole_size=0.1 * u.pix ** 2,
) )
self.fil.medskel(verbose=verbose) self.fil.medskel(verbose=verbose)
self.fil.analyze_skeletons( self.fil.analyze_skeletons(
@@ -1099,8 +1137,7 @@ class PostProcessor(HDF5Container):
self.fil.exec_rht() self.fil.exec_rht()
self.fil.find_widths() self.fil.find_widths()
outfile = self.path_out + "/filaments.pickle" with open(self.filaments_filename, "wb") as f:
with open(outfile, "wb") as f:
pickle.dump(self.fil, f, pickle.HIGHEST_PROTOCOL) pickle.dump(self.fil, f, pickle.HIGHEST_PROTOCOL)
return True return True
@@ -1121,7 +1158,7 @@ class PostProcessor(HDF5Container):
def _filaments_forces(self): def _filaments_forces(self):
""" """
Compute forces within a filament (for disks) Compute forces within a filament (for disks), within the slice at z=0
""" """
GM = self.G * self.pp_params.disk.mass_star # Mass parameter GM = self.G * self.pp_params.disk.mass_star # Mass parameter
@@ -1134,11 +1171,11 @@ class PostProcessor(HDF5Container):
i_center, j_center = self._filaments_center() i_center, j_center = self._filaments_center()
# Get slices and projections at z = 0 # Get slices and projections at z = 0
vphi = self.save.get_node("/maps/slice_velphi_z").read() vphi = self.get_value("/maps/slice_velphi_z")
gr = self.save.get_node("/maps/slice_gr_z").read() gr = self.get_value("/maps/slice_gr_z")
Pz = self.save.get_node("/maps/slice_P_z").read() Pz = self.get_value("/maps/slice_P_z")
coldens = self.save.get_node("/maps/coldens_z").read() rho = self.get_value("/maps/slice_rho_z")
vr = self.save.get_node("/maps/slice_velr_z").read() vr = self.get_value("/maps/slice_velr_z")
# Get coordinates # Get coordinates
im_extent = np.array(self.save.root.maps._v_attrs.im_extent) * self.lbox im_extent = np.array(self.save.root.maps._v_attrs.im_extent) * self.lbox
@@ -1167,7 +1204,7 @@ class PostProcessor(HDF5Container):
# Thermal support # Thermal support
GPx, GPy = np.gradient(Pz) GPx, GPy = np.gradient(Pz)
gradPr = (xx * GPx + yy * GPy) / rr gradPr = (xx * GPx + yy * GPy) / rr
fP = gradPr / coldens fP = gradPr / rho
# Gravitational field # Gravitational field
e2 = (1.0 / 512) ** 2 e2 = (1.0 / 512) ** 2
@@ -1371,7 +1408,7 @@ class PostProcessor(HDF5Container):
"slice_velphi": "z", "slice_velphi": "z",
"slice_gr": "z", "slice_gr": "z",
"slice_P": "z", "slice_P": "z",
"coldens": "z", "slice_rho": "z",
"slice_velr": "z", "slice_velr": "z",
}, },
), ),
+17 -8
View File
@@ -22,22 +22,22 @@ disk: # Disk speficic parameters
pdf: # parameters for probability density functions pdf: # parameters for probability density functions
nb_bin : 200 # Number of bins for the PDF nb_bin : 100 # Number of bins for the PDF
range : [-1.5, 2.5] # Range of the PDF (log of fluctuation) range : [-1.5, 2.5] # Range of the PDF (log of fluctuation)
xmin_fit : 0.1 # Lower boundary of the fit (log of fluctuation) xmin_fit : 0.3 # Lower boundary of the fit (log of fluctuation)
xmax_fit : 1.5 # Upper boundary of the fit (log of fluctuation) xmax_fit : 1.5 # Upper boundary of the fit (log of fluctuation)
fit_cut : 1e-4 # Exclude value that are < fit_cut * maximum fit_cut : 1e-4 # Exclude value that are < fit_cut * maximum
filaments: # parameters for FilFinder filaments: # parameters for FilFinder
datamap : "rho_avg" datamap : "fluct_coldens"
verbose : False verbose : False
rmin : 0.15 # In fraction of the box (zoom to be taken into account) rmin : 0.15 # In fraction of the box (zoom to be taken into account)
rmax : 0.45 # In fraction of the box (idem) rmax : 0.45 # In fraction of the box (idem)
size_thresh : 200 # in pixels**2 size_thresh : 400 # in pixels**2
skel_thresh : 100 # in pixels skel_thresh : 50 # in pixels
branch_thresh : 100 # in pixels branch_thresh : 50 # in pixels
glob_thresh : 40 # in map unit glob_thresh : 1.5 # in datamap unit
pymses: # Parameters for Pymses reader pymses: # Parameters for Pymses reader
@@ -80,7 +80,7 @@ out: # Parameters for post processing
# {ext} : Extension defined above # {ext} : Extension defined above
# {name} : Name of the rule # {name} : Name of the rule
# {tag} : Tag defined above # {tag} : Tag defined above
# {nml[nml_key]} : The value of nml_key in the namelist (ex: amr_params/levelmin) # {nml[nml_key]} : The value of nml_key in the namelist (ex: {nml[amr_params/levelmin]})
process: # General setting of the post-processor module process: # General setting of the post-processor module
@@ -91,3 +91,12 @@ process: # General setting of the post-processor module
rules: # Specific rules parameters rules: # Specific rules parameters
turb_energy_threshold : -1 # Remove invalid data (<0 = no threshold) turb_energy_threshold : -1 # Remove invalid data (<0 = no threshold)
astrophysix: # Parameters for astrophysix and galactica
simu_fmt : "{tag}_{run}" # Format of the name of simulation
descr_fmt : "{tag}_{run}" # Format of the default description
# The following keys are accepted
# {run} : Name of the relevant run
# {tag} : Tag defined above
# {nml[nml_key]} : The value of nml_key in the namelist (ex: {nml[amr_params/levelmin]})
+92
View File
@@ -0,0 +1,92 @@
# coding: utf-8
from astrophysix.simdm.protocol import (
SimulationCode,
AlgoType,
Algorithm,
InputParameter,
PhysicalProcess,
Physics,
)
# Simulation code definition #
ramses = SimulationCode(
name="Ramses 3 (MHD)",
code_name="Ramses",
code_version="3.10.1",
alias="RAMSES_3",
url="https://www.ics.uzh.ch/~teyssier/ramses/RAMSES.html",
description="Ramses MHD code",
)
# => Add algorithms : available algorithm types are :
# - AlgoType.AdaptiveMeshRefinement
# - AlgoType.VoronoiMovingMesh
# - AlgoType.SmoothParticleHydrodynamics
# - AlgoType.Godunov
# - AlgoType.PoissonMultigrid
# - AlgoType.PoissonConjugateGradient
# - AlgoType.ParticleMesh
# - AlgoType.FriendOfFriend
# - AlgoType.HLLCRiemann
# - AlgoType.RayTracer
ramses.algorithms.add(
Algorithm(algo_type=AlgoType.AdaptiveMeshRefinement, description="AMR")
)
ramses.algorithms.add(
Algorithm(algo_type=AlgoType.Godunov, description="Godunov scheme")
)
ramses.algorithms.add(
Algorithm(algo_type=AlgoType.HLLCRiemann, description="HLLC Riemann solver")
)
ramses.algorithms.add(
Algorithm(
algo_type=AlgoType.PoissonMultigrid, description="Multigrid Poisson solver"
)
)
# => Add input parameters
ramses.input_parameters.add(
InputParameter(
key="amr_params/levelmin",
name="lmin",
description="min. level of AMR refinement",
)
)
ramses.input_parameters.add(
InputParameter(
key="amr_params/levelmax",
name="lmax",
description="max. level of AMR refinement",
)
)
ramses.input_parameters.add(
InputParameter(
key="amr_params/jeans_refine",
name="jeans_refine",
description="Array, number of cells needed to resolve the Jeans lenght at each level from lmin",
)
)
ramses.input_parameters.add(
InputParameter(
key="cloud_params/beta_cool", name="beta", description="Cooling parameter"
)
)
# => Add physical processes : available physics are :
# - Physics.SelfGravity
# - Physics.Hydrodynamics
# - Physics.MHD
# - Physics.StarFormation
# - Physics.SupernovaeFeedback
# - Physics.AGNFeedback
# - Physics.MolecularCooling
ramses.physical_processes.add(
PhysicalProcess(
physics=Physics.Hydrodynamics, description="Hydrodynamical equations are solved"
)
)
ramses.physical_processes.add(
PhysicalProcess(physics=Physics.SelfGravity, description="Self-Gravity is applied.")
)
+40 -26
View File
@@ -1,14 +1,16 @@
# coding: utf-8 # coding: utf-8
import pymses.utils.constants as cst import astrophysix.units as cst
create_unit = cst.Unit.create_unit
def parse_exp_unit(u): def parse_exp_unit(u):
splitted = u.split("^") splitted = u.split("^")
name_u = cst.Unit.from_name(splitted[0]).latex name_u = cst.Unit.from_name(splitted[0]).latex.replace("text", "math")
exp = "" exp = ""
if len(splitted) > 1: if len(splitted) > 1:
exp = "$^{" + str(splitted[1]) + "}$" exp = "^{" + str(splitted[1]) + "}"
return name_u + exp return name_u + exp
@@ -17,66 +19,78 @@ def convert_exp(number, digits=4):
splitted = "{num:.{digits}g}".format(num=number, digits=digits).split("e") splitted = "{num:.{digits}g}".format(num=number, digits=digits).split("e")
# If no need of scientific notation (low number of digits) # If no need of scientific notation (low number of digits)
if len(splitted) == 1: if len(splitted) == 1:
return "${}$".format(splitted[0]) return "{}".format(splitted[0])
else: else:
coeff = splitted[0] coeff = splitted[0]
exp = splitted[1] exp = splitted[1]
exp_str = "10^{" + str(int(exp)) + "}" exp_str = "10^{" + str(int(exp)) + "}"
if float(coeff) == 1.0: if float(coeff) == 1.0:
return "$" + exp_str + "$" return exp_str
else: else:
return "${}\\times {}$".format(coeff, exp_str) return "{}\\times {}".format(coeff, exp_str)
def unit_str(unit, base=None, prefix=""): def unit_str(unit, base=None, prefix="", format=" [{unit}]"):
"""
Format a unit in matplotlib parsable latex
unit : astrophysics.units.unit.Unit
base : astrophysics.units.unit.Unit to use as base unit (if None `unit` is used)
prefix : str to put befor the unit
format : str with the {unit} key, to put external decoration
"""
if unit == cst.none: if unit == cst.none:
return "" return ""
elif not base is None: elif not base is None:
coeff = unit.express(base) coeff = unit.express(base)
return unit_str(base, prefix=convert_exp(coeff) + " ") u_str = unit_str(base, prefix=convert_exp(coeff) + " ")
elif len(unit.latex) > 0: elif len(unit.latex) > 0:
if ("." in unit.latex or "^" in unit.latex) and not "$" in unit.latex: if "." in unit.latex or "^" in unit.latex:
base_str = ".".join(map(parse_exp_unit, unit.name.split("."))) base_str = ".".join(map(parse_exp_unit, unit.name.split(".")))
return r" [{}{}]".format(prefix, base_str) u_str = r"${}{}$".format(prefix, base_str)
else: else:
return r" [{}{}]".format(prefix, unit.latex) u_str = r"${}{}$".format(prefix, unit.latex.replace("text", "math"))
elif len(unit.name) > 0: elif len(unit.name) > 0:
try: try:
base_str = ".".join(map(parse_exp_unit, unit.name.split("."))) base_str = ".".join(map(parse_exp_unit, unit.name.split(".")))
u_str = r" [{}{}]".format(prefix, base_str) u_str = r"${}{}$".format(prefix, base_str)
except: except:
u_str = r" [{}{}]".format(prefix, unit.name) u_str = r"${}{}$".format(prefix, unit.name)
return u_str
else: else:
base_str = ".".join( base_str = ".".join(
map(parse_exp_unit, unit._decompose_base_units().split(".")) map(parse_exp_unit, unit._decompose_base_units().split("."))
) )
return r" [{}{} {}]".format(prefix, unit.coeff, base_str) u_str = r"${}{} {}$".format(prefix, unit.coeff, base_str)
return format.format(unit=u_str)
cst.Msun_pc3 = cst.create_unit( cst.coldens = create_unit(
"Msun.pc^-2", base_unit=cst.Msun / cst.pc ** 2, descr="Column density"
)
cst.km_s = create_unit("km.s^-1", base_unit=cst.km / cst.s, descr="Speed")
cst.Msun_pc3 = create_unit(
"Msun.pc^-3", base_unit=cst.Msun / cst.pc ** 3, descr="Density" "Msun.pc^-3", base_unit=cst.Msun / cst.pc ** 3, descr="Density"
) )
cst.Msun_pc3 = cst.create_unit( cst.kg_m3 = create_unit("kg.m^-3", base_unit=cst.kg / cst.m ** 3, descr="Density")
"Msun.pc^-3", base_unit=cst.Msun / cst.pc ** 3, descr="Density"
)
cst.ssfr = cst.create_unit( cst.ssfr = create_unit(
"Msun.yr^-1.pc^-2", "Msun.year^-1.pc^-2",
base_unit=cst.Msun / cst.year / cst.pc ** 2, base_unit=cst.Msun / cst.year / cst.pc ** 2,
descr="Surfacic SFR", descr="Surfacic SFR",
latex="M$_{\odot}$.yr$^{-1}$.pc$^{-2}$",
) )
# latex='M$_{\odot}$.yr$^{-1}$.pc$^{-2}$')
cst.ssfrG = cst.create_unit( cst.ssfrG = create_unit(
"Msun.Gyr^-1.pc^-2", "Msun.Gyr^-1.pc^-2",
base_unit=1e-9 * cst.Msun / cst.year / cst.pc ** 2, base_unit=1e-9 * cst.Msun / cst.year / cst.pc ** 2,
descr="Surfacic SFR", descr="Surfacic SFR",
latex="M$_{\odot}$.Gyr$^{-1}$.pc$^{-2}$", latex="\mathrm{M}_{\odot}.\mathrm{Gyr}^{-1}.\mathrm{pc}^{-2}",
) )
cst.uG = cst.create_unit( cst.uG = create_unit(
"μG", base_unit=1e-10 * cst.T, descr="Micro Gauss", latex="$\mu\mathrm{G}$" "μG", base_unit=1e-10 * cst.T, descr="Micro Gauss", latex="\\mu\\mathrm{G}"
) )