diff --git a/comparator.py b/comparator.py index 0410110..ee050ef 100644 --- a/comparator.py +++ b/comparator.py @@ -17,6 +17,7 @@ class Comparator(Aggregator, HDF5Container): pp_params=default_params(), selector=None, tag=None, + unit_time=cst.year, **kwargs ): """ @@ -53,7 +54,11 @@ class Comparator(Aggregator, HDF5Container): for num in self.nums[run]: self.pp[run][num] = PostProcessor( - path_run, num, path_out=path_out_run, pp_params=self.pp_params + path_run, + num, + path_out=path_out_run, + pp_params=self.pp_params, + unit_time=unit_time, ) run0 = self.runs[0] diff --git a/plotter.py b/plotter.py index ef7b6db..eb46242 100644 --- a/plotter.py +++ b/plotter.py @@ -17,7 +17,8 @@ from scipy.stats import linregress from numpy.polynomial.polynomial import polyfit from scipy.ndimage.filters import gaussian_filter1d from scipy import optimize - +from astrophysix.simdm.datafiles import Datafile, PlotType, PlotInfo +from astrophysix.utils.file import FileType import matplotlib as mpl if os.environ.get("DISPLAY", "") == "": @@ -26,12 +27,15 @@ if os.environ.get("DISPLAY", "") == "": import pylab as P from comparator import * import pspec_read +import datetime -P.rcParams["image.cmap"] = "plasma" -P.rcParams["savefig.dpi"] = 400 +filetype_from_ext = {ext: ft for ft in FileType for ext in ft.extension_list} -tex_params = {"text.latex.preamble": r"\usepackage{amsmath}"} -P.rcParams.update(tex_params) + +def not_array_error(err): + epy2 = "object does not support indexing" + epy3 = "object is not subscriptable" + return str(err)[-len(epy2) :] == epy2 or str(err)[-len(epy3) :] == epy3 class PlotRule(Rule): @@ -42,7 +46,7 @@ class PlotRule(Rule): def plot(self, save, arg, **kwargs): """ - Set the plotter's storage to 'save' and exetute the rule + Set the plotter's storage to 'save' and execute the rule Parameters ---------- @@ -53,6 +57,12 @@ class PlotRule(Rule): self.postproc.save = save return self.process_fn(arg, **kwargs) + def datafile(self, arg): + return Datafile( + name=self.description + "_" + arg, + description=self.description + " ({})".format(arg), + ) + class Plotter(Aggregator, BaseProcessor): """ @@ -99,6 +109,7 @@ class Plotter(Aggregator, BaseProcessor): pp_params=None, selector=None, tag=None, + unit_time=cst.year, **kwargs, ): @@ -135,7 +146,13 @@ class Plotter(Aggregator, BaseProcessor): # Get comparator object self.comp = Comparator( - path, self.runs, self.nums, path_out, self.pp_params, selector=self.selector + path, + self.runs, + self.nums, + path_out, + self.pp_params, + unit_time=unit_time, + selector=self.selector, ) # Get postprocesor objets for each run @@ -147,17 +164,56 @@ class Plotter(Aggregator, BaseProcessor): # Define rules self.def_rules() + # generate astrophysix's simulations object + self.gen_simus() + self.save = None + def gen_simus(self): + self.simulations = {} + simu_fmt = self.pp_params.astrophysix.simu_fmt + descr_fmt = self.pp_params.astrophysix.descr_fmt + tag = self.pp_params.out.tag + for run in self.runs: + pp = self.pp[run][self.nums[run][0]] + nml = self.comp.namelist[run] + name = simu_fmt.format(run=run, tag=tag, nml=nml) + exec_time = str(datetime.datetime.fromtimestamp(os.stat(pp.path).st_ctime)) + description = descr_fmt.format(run=run, tag=tag, nml=nml) + simu = Simulation( + simu_code=ramses, + name=name, + alias=name.upper(), + description=description, + directory_path=pp.path, + execution_time=exec_time, + ) + + for param in ramses.input_parameters: + try: + param_setting = ParameterSetting( + input_param=param, + value=self.comp.get_nml(param.key, run), + visibility=ParameterVisibility.BASIC_DISPLAY, + ) + simu.parameter_settings.add(param_setting) + except KeyError as e: + self._log("key {} not found".format(e), "WARNING") + except AttributeError as e: + self._log("{}".format(e), "WARNING") + + self.simulations[run] = simu + def _not_self_dep(self, name, dep, dep_arg, overwrite, **kwargs): """ Check if the dependency belongs to the plotter object or to another one (comp, pp, ..) """ if dep in self.comp.rules: - done = self.comp.process( + result = self.comp.process( dep, dep_arg, overwrite, overwrite_dep=self.overwrite_dep ) - self.just_done.extend(done) + if result is not None: + self.just_done.append(done) else: super(Plotter, self)._not_self_dep(name, dep, dep_arg, overwrite, **kwargs) @@ -172,19 +228,13 @@ class Plotter(Aggregator, BaseProcessor): ) def _process_rule( - self, - name, - rule, - arg, - overwrite=False, - ax=None, - movie=False, - from_cells=False, - **kwargs, + self, name, rule, arg, overwrite=False, ax=None, from_cells=False, **kwargs ): """ Open storage and figure if needed before processing a rule """ + + # Set full name according to argument if not arg is None: name_full = ( name @@ -200,144 +250,89 @@ class Plotter(Aggregator, BaseProcessor): else: name_full = name - if rule.is_valid(arg): - if rule.kind == "classic" or rule.kind == "cells": - if "select" in kwargs: - select = kwargs.pop("select") - runs, nums = self.selector.select(**select) - elif "runs" in kwargs: - runs = kwargs.pop("runs") - if isinstance(runs, RunSelector): - nums = runs.nums - runs = runs.runs - else: - nums = self.nums - else: - runs = self.runs - nums = self.nums - - i = 0 - for run in runs: - files = [] - for num in nums[run]: - plot_filename = self._find_filename(name_full, run, num) - - if from_cells or rule.kind == "cells": - if not os.exists(self.pp[run][num].cells_filename): - self.pp[run][num].load_cells() - self.pp[run][num].unload_cells() - save = tables.open_file(self.pp[run][num].cells_filename) - elif rule.kind == "classic": - save = tables.open_file(self.pp[run][num].filename) - else: - save = tables.open_file(self.comp.filename, "r") - try: - self._plot_rule( - rule, - save, - arg, - plot_filename, - overwrite, - ax=ax[i], - run=run, - **kwargs, - ) - except TypeError as e: - if str(e) in [ - "'LocatableAxes' object does not support indexing", - "'AxesSubplot' object does not support indexing", - "'AxesSubplot' object is not subscriptable", - "'Axes' object is not subscriptable", - "'LocatableAxes' object is not subscriptable", - ]: - self._plot_rule( - rule, - save, - arg, - plot_filename, - overwrite, - ax=ax, - run=run, - **kwargs, - ) - elif ax is None: - fig = P.figure() - self._plot_rule( - rule, - save, - arg, - plot_filename, - overwrite, - ax=P.gca(), - run=run, - **kwargs, - ) - else: - raise - finally: - save.close() - i = i + 1 - files.append(plot_filename) - else: - - if "select" in kwargs and not "runs" in kwargs: - select = kwargs.pop("select") - runs, nums = self.selector.select(**select) - if not rule.kind == "runs": - kwargs["runs"] = runs - elif rule.kind == "runs" and "runs" in kwargs: - runs = kwargs.pop("runs") - else: - runs = self.runs - - if ax is None: - ax = P.gca() - if rule.kind == "series" and len(runs) == 1: - run = self.runs[0] - plot_filename = self._find_filename(name_full, run) - else: - plot_filename = self._find_filename(name_full) - save = tables.open_file(self.comp.filename, "r") - try: - if rule.kind == "runs": - for i, run in enumerate(runs): - try: - self._plot_rule( - rule, - save, - arg, - plot_filename, - overwrite, - ax=ax[i], - run=run, - **kwargs, - ) - except TypeError as e: - if str(e) in [ - "'LocatableAxes' object does not support indexing", - "'AxesSubplot' object does not support indexing", - "'AxesSubplot' object is not subscriptable", - "'Axes' object is not subscriptable", - "'LocatableAxes' object is not subscriptable", - ]: - self._plot_rule( - rule, - save, - arg, - plot_filename, - overwrite, - ax=ax, - run=run, - **kwargs, - ) - else: - self._plot_rule( - rule, save, arg, plot_filename, overwrite, ax, **kwargs - ) - finally: - save.close() - else: + # Exit if not valid + if not rule.is_valid(arg): self._log("{} is not valid in this context".format(name_full), "ERROR") + return + + # get filetype of the output + filetype = filetype_from_ext[self.pp_params.out.ext] + + # Select runs and nums + if "select" in kwargs: + select = kwargs.pop("select") + runs, nums = self.selector.select(**select) + else: + runs = self.runs + nums = self.nums + + datafiles = [] + + # Several plots + if rule.kind == "classic" or rule.kind == "cells": + run_num = [(run, num) for run in runs for num in nums[run]] + else: + run_num = [(run, None) for run in runs] + + for i, (run, num) in enumerate(run_num): + # Find filename + plot_filename = self._find_filename(name_full, run, num) + + # Find ax + try: + real_ax = ax[i] + except TypeError as e: + if ax is None: + fig, real_ax = P.subplots(1, 1) + elif not_array_error(e): + real_ax = ax + else: + raise + + # Find plot save + if from_cells or rule.kind == "cells": + if not os.exists(self.pp[run][num].cells_filename): + self.pp[run][num].load_cells() + self.pp[run][num].unload_cells() + save = tables.open_file(self.pp[run][num].cells_filename) + elif rule.kind == "classic": + save = tables.open_file(self.pp[run][num].filename) + else: + save = tables.open_file(self.comp.filename, "r") + + # Call plot routine + try: + plot_info = self._plot_rule( + rule, + save, + arg, + plot_filename, + overwrite, + ax=real_ax, + run=run, + **kwargs, + ) + finally: + save.close() + + # Save in astrophysix format + df = rule.datafile(arg) + df[filetype] = plot_filename + if plot_info is not None: + df.plot_info = plot_info + if num is not None: + snap = self.pp[run][num].snapshot + + if overwrite and df.name in snap.datafiles: + del snap.datafiles[df.name] + elif df.name not in snap.datafiles: + snap.datafiles.add(df) + + if snap not in self.simulations[run].snapshots: + self.simulations[run].snapshots.add(snap) + + datafiles.append(df) + return datafiles def _plot_rule(self, rule, save, arg, plot_filename, overwrite, ax, **kwargs): """ @@ -345,7 +340,7 @@ class Plotter(Aggregator, BaseProcessor): """ P.sca(ax) if self._needs_computation(overwrite, plot_filename): - rule.plot(save, arg, **kwargs) + plot_info = rule.plot(save, arg, **kwargs) if not self.pp_params.out.interactive: P.tight_layout(pad=1) @@ -360,6 +355,7 @@ class Plotter(Aggregator, BaseProcessor): if not self.pp_params.out.interactive: P.close() + return plot_info else: self._log("Plot {} is already done, skipping...".format(plot_filename)) @@ -396,7 +392,7 @@ class Plotter(Aggregator, BaseProcessor): ext=self.pp_params.out.ext, ) - def _label_run(self, run, node, label, nml_key): + def _label_run(self, run, node, label, nml_key, time=None): """ Set up a label for the run from the namelist and parameters """ @@ -435,7 +431,7 @@ class Plotter(Aggregator, BaseProcessor): label_run = label return label_run - def _ax_label_unit(self, node, label, unit, unit_coeff): + def _ax_label_unit(self, node, label, unit, unit_coeff, put_units=True): """ Find appropriate labels for axis """ @@ -457,20 +453,38 @@ class Plotter(Aggregator, BaseProcessor): if unit is None: unit = unit_old - if not unit_coeff == 1: - base = unit - unit = unit_coeff * unit - label = label + unit_str(unit, base=base) - else: - label = label + unit_str(unit) + if put_units: + if not unit_coeff == 1: + base = unit + unit = unit_coeff * unit + label = label + unit_str(unit, base=base) + else: + label = label + unit_str(unit) return label, unit_old, unit + def _snapshot_title(self, run, node, title, nml_key, put_time, unit_time=cst.Myr): + title = self._label_run(run, node, title, nml_key) + + if put_time: + time = self.save.root._v_attrs.time * self.comp.info["unit_time"] + u_str = unit_str(unit_time, format="{unit}") + time_str = self.pp_params.plot.time_fmt.format( + time.express(unit_time), u_str + ) + if len(title) > 0: + title = title + " | " + time_str + else: + title = time_str + return title + def _plot_map( self, name, ax_los, run, + xlabel=None, + ylabel=None, label=None, unit=None, unit_coeff=1.0, @@ -480,12 +494,14 @@ class Plotter(Aggregator, BaseProcessor): put_title=True, nml_key=None, put_time=True, - time_unit=cst.Myr, + unit_time=cst.Myr, + put_units=True, unit_space=cst.pc, cmap="plasma", norm="log", put_cbar=True, autoscale=True, + transform=None, **kwargs, ): """ @@ -503,9 +519,13 @@ class Plotter(Aggregator, BaseProcessor): node = self.save.get_node("/maps/{}_{}".format(name, ax_los)) dmap = node.read() - label, unit_old, unit = self._ax_label_unit(node, label, unit, unit_coeff) + label, unit_old, unit = self._ax_label_unit( + node, label, unit, unit_coeff, put_units + ) dmap = dmap * unit_old.express(unit) + if transform is not None: + dmap = transform(dmap) if norm == "log": norm = mpl.colors.LogNorm() @@ -521,32 +541,28 @@ class Plotter(Aggregator, BaseProcessor): P.locator_params(axis="both", nbins=self.pp_params.plot.ntick) - P.xlabel(self._ax_title[ax_h] + unit_str(unit_space)) - P.ylabel(self._ax_title[ax_v] + unit_str(unit_space)) + if xlabel is None: + xlabel = self._ax_title[ax_h] + if ylabel is None: + ylabel = self._ax_title[ax_v] + if put_units: + xlabel = xlabel + unit_str(unit_space) + ylabel = ylabel + unit_str(unit_space) + P.xlabel(xlabel) + P.ylabel(ylabel) try: cbar = P.colorbar(im, cax=P.gca().cax) except AttributeError: cbar = P.colorbar() + if put_title: + title = self._snapshot_title(run, node, title, nml_key, put_time, unit_time) + P.title(title) + if not label is None: cbar.set_label(label) - if put_title: - title = self._label_run(run, node, title, nml_key) - - if put_time: - time = self.save.root._v_attrs.time * self.comp.info["unit_time"] - time_str = self.pp_params.plot.time_fmt.format( - time.express(time_unit), time_unit.latex - ) - if len(title) > 0: - title = title + " | " + time_str - else: - title = time_str - - P.title(title) - for i, plot_overlay in enumerate(overlays): if plot_overlay in self.overlays: plot_overlay = self.overlays[plot_overlay] @@ -556,6 +572,23 @@ class Plotter(Aggregator, BaseProcessor): except: plot_overlay(ax_los, im_extent) + return PlotInfo( + plot_type=PlotType.IMAGE, + xaxis_values=np.linspace(im_extent[0], im_extent[1], dmap.shape[0] + 1), + yaxis_values=np.linspace(im_extent[2], im_extent[3], dmap.shape[1] + 1), + values=dmap, + xaxis_log_scale=False, + yaxis_log_scale=False, + values_log_scale=(norm == mpl.colors.LogNorm()), + xaxis_label=xlabel, + yaxis_label=ylabel, + values_label=label, + xaxis_unit=unit_space, + yaxis_unit=unit_space, + values_unit=unit, + plot_title=title, + ) + def _overlay_contour( self, ax_los, @@ -724,7 +757,7 @@ class Plotter(Aggregator, BaseProcessor): nml_key=None, put_title=True, put_time=True, - time_unit=cst.Myr, + unit_time=cst.Myr, **kwargs, ): """ @@ -749,26 +782,13 @@ class Plotter(Aggregator, BaseProcessor): if not ylabel is None: P.ylabel(ylabel) + title = self._snapshot_title(run, node, title, nml_key, put_time, unit_time) if put_title: - title = self._label_run(run, node, title, nml_key) - - if put_time: - time = self.save.root._v_attrs.time * self.comp.info["unit_time"] - time_str = self.pp_params.plot.time_fmt.format( - time.express(time_unit), time_unit.latex - ) - if len(title) > 0: - title = title + " | " + time_str - else: - title = time_str - P.title(title) - - if label is None: + if label == None: label = title - P.plot(bin_centers, mean_bin, label=label, **kwargs) - P.plot(bin_centers, mean_bin, label=title, **kwargs) + P.plot(bin_centers, mean_bin, label=label, **kwargs) def _plot_hist( self, @@ -781,10 +801,11 @@ class Plotter(Aggregator, BaseProcessor): unit_coeff=1.0, ytransform=None, label=None, + put_title=True, title=None, nml_key=None, put_time=True, - time_unit=cst.Myr, + unit_time=cst.Myr, xlog=None, ylog=False, kind="bar", @@ -799,47 +820,41 @@ class Plotter(Aggregator, BaseProcessor): """ Plot an histogram (PDF, etc ...) """ + # Get node if not ax_los is None: name = name + "_" + ax_los - node = self.save.get_node(group + name) - if xlog is None: try: xlog = node._v_attrs_.logbins except: xlog = False + # get label and units xlabel, unit_old, unit = self._ax_label_unit(node, label, unit, unit_coeff) + # Read data if "mean" in node: index = node["runs"].read().index(run.encode()) values, centers = node["mean"].read()[index] else: values, centers = node.read() - if xlog: centers = centers + np.log10(unit_old.express(unit)) else: centers = centers * unit_old.express(unit) - if ytransform is not None: values = ytransform(values) + width = centers[1] - centers[0] - title = self._label_run(run, node, title, nml_key) - - if put_time: - time = self.save.root._v_attrs.time * self.comp.info["unit_time"] - time_str = self.pp_params.plot.time_fmt.format( - time.express(time_unit), time_unit.latex - ) - if len(title) > 0: - title = title + " | " + time_str - else: - title = time_str - - P.title(title) + # Set title + title = self._snapshot_title(run, node, title, nml_key, put_time, unit_time) + if put_title: + P.title(title) + if label == None: + label = title + # Set colors if color is None and not colors is None: if nml_color is None: color = colors[run] @@ -850,11 +865,8 @@ class Plotter(Aggregator, BaseProcessor): except: color = colors(nml) - if label == None: - label = title - + # Actual plot if kind == "bar": - width = centers[1] - centers[0] P.bar(centers, values, width, log=ylog, color=color, label=label, **kwargs) elif kind == "step": if ylog: @@ -863,11 +875,13 @@ class Plotter(Aggregator, BaseProcessor): else: raise ValueError("kind must be 'bar' or 'step'") + # put labels if not label is None: P.xlabel(xlabel) if not ylabel is None: P.ylabel(ylabel) + # Also diplay fit, previously saved if ax_los is not None and "/hist/fit_" + name + "_" + ax_los in self.save: slope = node.attrs.slope origin = node.attrs.origin @@ -878,14 +892,27 @@ class Plotter(Aggregator, BaseProcessor): linewidth=2, color="orange", ) - - P.ylim([None, 1.0]) - + # or a new one if not fit is None: self._overlay_fit( centers, values, kind=fit, ls="--", lw=1.5, label=fitlabel ) + # returns PlotInfo (for Galactica) + edges = np.append(centers - width / 2.0, centers[-1] + width / 2.0) + return PlotInfo( + plot_type=PlotType.HISTOGRAM, + xaxis_values=edges, + yaxis_values=values, + xaxis_log_scale=False, + yaxis_log_scale=ylog, + xaxis_label=xlabel, + yaxis_label=ylabel, + xaxis_unit=unit, + yaxis_unit=cst.none, + plot_title=title, + ) + def _plot( self, name_x, @@ -898,19 +925,17 @@ class Plotter(Aggregator, BaseProcessor): yunit=None, xunit_coeff=1.0, yunit_coeff=1.0, - ylog=False, fit=None, fitlabel=None, smooth=0, nml_key=None, run=None, - runs=None, yerr=None, yerr_kind="std", sigma_err=2.0, grid=False, put_time=False, - time_unit=cst.Myr, + unit_time=cst.Myr, colors=None, nml_color=None, legend=None, @@ -922,12 +947,15 @@ class Plotter(Aggregator, BaseProcessor): Generic plot routine, with name_x and name_y two path in the hdf5 file """ + # Get proper hdf5 names if not node_arg is None: name_x, name_y = name_x + "_" + node_arg, name_y + "_" + node_arg + # Get hdf5 nodes node_x = self.save.get_node(name_x) node_y = self.save.get_node(name_y) + # If the actual data is in a,other file, fetch it if subname_x: hdf5_x = tables.open_file(node_x.read()) node_x = hdf5_x.get_node(subname_x) @@ -935,6 +963,7 @@ class Plotter(Aggregator, BaseProcessor): hdf5_y = tables.open_file(node_y.read()) node_y = hdf5_y.get_node(subname_y) + # Find proper labels xlabel, xunit_old, xunit = self._ax_label_unit( node_x, xlabel, xunit, xunit_coeff ) @@ -942,58 +971,24 @@ class Plotter(Aggregator, BaseProcessor): node_y, ylabel, yunit, yunit_coeff ) - P.xlabel(xlabel) - P.ylabel(ylabel) - - if grid: - P.grid() - - if ylog: - P.yscale("log") - + # If relevent, get time if put_time: time = self.save.root._v_attrs.time * self.comp.info["unit_time"] time_str = self.pp_params.plot.time_fmt.format( - time.express(time_unit), time_unit.latex + time.express(unit_time), unit_time.latex ) if label is not None and len(label) > 0: label = label + " | " + time_str else: label = time_str + # Manage the different forms in which the data may be stored : + # Possibilities are : plain array, dict of arrays (mean, std, ..) or dict of array (runs) if node_y._v_attrs.CLASS == "ARRAY": x = node_x.read() * xunit_old.express(xunit) y = node_y.read() * yunit_old.express(yunit) mask = np.isfinite(x) & np.isfinite(y) x, y = x[mask], y[mask] - if smooth > 0: - y = gaussian_filter1d(y, sigma=smooth) - if not run is None: - label = self._label_run(run, node_y, label, nml_key) - if colors is None: - (base_line,) = P.plot(x, y, label=label, **kwargs) - else: - if nml_color is None: - color = colors[run] - elif nml_color == "time": - time = ( - self.save.root._v_attrs.time * self.comp.info["unit_time"] - ).express(time_unit) - color = colors(time) - else: - nml = self.comp.get_nml(nml_color, run) - try: - color = colors[nml] - except: - color = colors(nml) - if yerr is None: - (base_line,) = P.plot(x, y, label=label, color=color, **kwargs) - else: - if isinstance(yerr, str): - yerr = self.save.get_node(yerr).read() - - base_line, _, _ = P.errorbar(x, y, yerr=yerr, label=label, **kwargs) - elif "mean" in node_y: x = node_x.read() * xunit_old.express(xunit) y = node_y.mean.read() * yunit_old.express(yunit) @@ -1014,7 +1009,6 @@ class Plotter(Aggregator, BaseProcessor): else: yerr_min = y yerr_max = y - yerr = yerr_max - yerr_min mask = np.isfinite(x) & np.isfinite(y) & np.isfinite(yerr) x, y, yerr, yerr_min, yerr_max = ( @@ -1024,44 +1018,51 @@ class Plotter(Aggregator, BaseProcessor): yerr_min[mask], yerr_max[mask], ) - if not run is None: - label = self._label_run(run, node_y, label, nml_key) + else: + x, y = node_x[run], node_y[run] + mask = np.isfinite(x) & np.isfinite(y) + x, y = x[mask], y[mask] - if yerr_kind is None: - yerr = None + if isinstance(yerr, str): + yerr = self.save.get_node(yerr).read() + + if smooth > 0: + y = gaussian_filter1d(y, sigma=smooth) + if not run is None: + label = self._label_run(run, node_y, label, nml_key) + + # Look if special colors method is used + if colors is None: + if yerr is None: (base_line,) = P.plot(x, y, label=label, **kwargs) + else: + base_line, _, _ = P.errorbar(x, y, yerr=yerr, label=label, **kwargs) + else: + if nml_color is None: + color = colors[run] + elif nml_color == "time": + time = ( + self.save.root._v_attrs.time * self.comp.info["unit_time"] + ).express(unit_time) + color = colors(time) + else: + nml = self.comp.get_nml(nml_color, run) + try: + color = colors[nml] + except: + color = colors(nml) + if yerr is None: + (base_line,) = P.plot(x, y, label=label, color=color, **kwargs) else: base_line, _, _ = P.errorbar( - x, y, yerr=[y - yerr_min, yerr_max - y], label=label, **kwargs + x, y, yerr=yerr, color=color, label=label, **kwargs ) - else: - if runs is None: - runs = self.runs - for i, run in enumerate(runs): - x_run, y_run = node_x[run], node_y[run] - x = x_run.read() * xunit_old.express(xunit) - y = y_run.read() * yunit_old.express(yunit) - mask = np.isfinite(x) & np.isfinite(y) - x, y = x[mask], y[mask] - if smooth > 0: - y = gaussian_filter1d(y, sigma=smooth) - label_run = self._label_run(run, y_run, label, nml_key) - if colors is None: - (base_line,) = P.plot(x, y, label=label_run, **kwargs) - else: - if nml_color is None: - color = colors[i % len(colors)] - else: - nml = self.comp.get_nml(nml_color, run) - try: - color = colors[nml] - except: - color = colors(nml) - (base_line,) = P.plot(x, y, label=label_run, color=color, **kwargs) - if legend is None: - legend = True - + # Ax decorations + P.xlabel(xlabel) + P.ylabel(ylabel) + if grid: + P.grid() if legend: P.legend() @@ -1076,6 +1077,7 @@ class Plotter(Aggregator, BaseProcessor): color=base_line.get_color(), label=fitlabel, ) + if subname_x: hdf5_x.close() if subname_y: @@ -1603,11 +1605,16 @@ class Plotter(Aggregator, BaseProcessor): dependencies=["radial_bins", "rad_avg_" + name], ) + self.rules["avg_map_" + name] = PlotRule( + self, + partial(self._plot_map, "avg_map_" + name), + "Map of the radial average of {}".format(name), + dependencies=["avg_map_" + name], + ) + self.rules["fluct_" + name] = PlotRule( self, - partial( - self._plot_map, "fluct_" + name, vmin=0.01, vmax=100, cmap="RdBu_r" - ), + partial(self._plot_map, "fluct_" + name, cmap="RdBu_r"), "Fluctuation of {}".format(name), dependencies=["fluct_" + name], ) @@ -1666,7 +1673,7 @@ class Plotter(Aggregator, BaseProcessor): # Radial generic_rule(field + "r") - # Othoradial + # Orthoradial generic_rule(field + "phi") # Norm generic_rule(field + "_norm") diff --git a/postprocessor.py b/postprocessor.py index 3f15ad3..19e0104 100644 --- a/postprocessor.py +++ b/postprocessor.py @@ -149,7 +149,15 @@ class PostProcessor(HDF5Container): cells_loaded = False - def __init__(self, path=None, num=None, path_out=None, pp_params=None, tag=None): + def __init__( + self, + path=None, + num=None, + path_out=None, + pp_params=None, + tag=None, + unit_time=cst.year, + ): """ Creates the basic structures needed for the outputs """ @@ -170,6 +178,14 @@ class PostProcessor(HDF5Container): self.path_out + "/cells_" + tag_name + format(num, "05") + ".h5" ) + self.pspec_filename = ( + self.path_out + "/pspec_" + tag_name + format(num, "05") + ".h5" + ) + + self.filaments_filename = ( + self.path_out + "/filaments_" + tag_name + format(num, "05") + ".pickle" + ) + if not os.path.exists(self.path_out): os.makedirs(self.path_out) @@ -213,7 +229,7 @@ class PostProcessor(HDF5Container): if not self.pp_params.pymses.multiprocessing: self._rt.disable_multiprocessing() - # Set the extend of the image + # Set the extent of the image self._radius = 0.5 / self.pp_params.pymses.zoom self.lbox = self.info["boxlen"] @@ -273,12 +289,23 @@ class PostProcessor(HDF5Container): self.log_id = "[{}, {}] ".format(self.run, self.num) - if os.path.exists(self.path_out + "/filaments.pickle"): - with open(self.path_out + "/filaments.pickle", "rb") as f: + if os.path.exists(self.filaments_filename): + with open(self.filaments_filename, "rb") as f: self.fil = pickle.load(f) else: self.fil = None + self.snapshot = Snapshot( + name=str(self.num), + description="", + time=( + self.info["time"] * self.info["unit_time"].express(unit_time), + unit_time, + ), + directory_path=self.path, + data_reference="OUTPUT_{}".format(self.num), + ) + self.def_rules() def load_cells(self): @@ -964,12 +991,22 @@ class PostProcessor(HDF5Container): ) bins = np.zeros(r.shape, dtype=int) - for r0 in radial_bins[1:]: + for r0 in radial_bins[1:-1]: bins = bins + (r >= r0).astype(int) vr_mean = mean_bin_vr[bins] vphi_mean = mean_bin_vphi[bins] + # use linear interpolation + # v = ((r[i+1] - r)v[i] + (r - r[i])v[i + 1]) / (r[i+1] - r[i]) + vr_mean = (radial_bins[bins + 1] - r) * vr_mean + vr_mean = vr_mean + (r - radial_bins[bins]) * mean_bin_vr[bins + 1] + vr_mean = vr_mean / (radial_bins[bins + 1] - radial_bins[bins]) + + vphi_mean = (radial_bins[bins + 1] - r) * vphi_mean + vphi_mean = vphi_mean + (r - radial_bins[bins]) * mean_bin_vphi[bins + 1] + vphi_mean = vphi_mean / (radial_bins[bins + 1] - radial_bins[bins]) + vr = self.oct_getter_vr(dset) vphi = self.oct_getter_vphi(dset) alpha = (vphi - vphi_mean) * (vr - vr_mean) @@ -1055,7 +1092,7 @@ class PostProcessor(HDF5Container): return sinks_dict def _pspec(self): - outfile = self.path_out + "/pspec.h5" + outfile = self.pspec_filename pspec_new.pspec(repo=self.path, iouts=[self.num], outfile=outfile) return True @@ -1086,9 +1123,10 @@ class PostProcessor(HDF5Container): self.fil.create_mask( verbose=verbose, smooth_size=1 * u.pix, - adapt_thresh=2 * u.pix, + adapt_thresh=4 * u.pix, size_thresh=size_thresh * u.pix ** 2, glob_thresh=glob_thresh, + fill_hole_size=0.1 * u.pix ** 2, ) self.fil.medskel(verbose=verbose) self.fil.analyze_skeletons( @@ -1099,8 +1137,7 @@ class PostProcessor(HDF5Container): self.fil.exec_rht() self.fil.find_widths() - outfile = self.path_out + "/filaments.pickle" - with open(outfile, "wb") as f: + with open(self.filaments_filename, "wb") as f: pickle.dump(self.fil, f, pickle.HIGHEST_PROTOCOL) return True @@ -1121,7 +1158,7 @@ class PostProcessor(HDF5Container): def _filaments_forces(self): """ - Compute forces within a filament (for disks) + Compute forces within a filament (for disks), within the slice at z=0 """ GM = self.G * self.pp_params.disk.mass_star # Mass parameter @@ -1134,11 +1171,11 @@ class PostProcessor(HDF5Container): i_center, j_center = self._filaments_center() # Get slices and projections at z = 0 - vphi = self.save.get_node("/maps/slice_velphi_z").read() - gr = self.save.get_node("/maps/slice_gr_z").read() - Pz = self.save.get_node("/maps/slice_P_z").read() - coldens = self.save.get_node("/maps/coldens_z").read() - vr = self.save.get_node("/maps/slice_velr_z").read() + vphi = self.get_value("/maps/slice_velphi_z") + gr = self.get_value("/maps/slice_gr_z") + Pz = self.get_value("/maps/slice_P_z") + rho = self.get_value("/maps/slice_rho_z") + vr = self.get_value("/maps/slice_velr_z") # Get coordinates im_extent = np.array(self.save.root.maps._v_attrs.im_extent) * self.lbox @@ -1167,7 +1204,7 @@ class PostProcessor(HDF5Container): # Thermal support GPx, GPy = np.gradient(Pz) gradPr = (xx * GPx + yy * GPy) / rr - fP = gradPr / coldens + fP = gradPr / rho # Gravitational field e2 = (1.0 / 512) ** 2 @@ -1371,7 +1408,7 @@ class PostProcessor(HDF5Container): "slice_velphi": "z", "slice_gr": "z", "slice_P": "z", - "coldens": "z", + "slice_rho": "z", "slice_velr": "z", }, ), diff --git a/pp_params.yml b/pp_params.yml index b81239c..55876e3 100644 --- a/pp_params.yml +++ b/pp_params.yml @@ -22,22 +22,22 @@ disk: # Disk speficic parameters pdf: # parameters for probability density functions - nb_bin : 200 # Number of bins for the PDF + nb_bin : 100 # Number of bins for the PDF range : [-1.5, 2.5] # Range of the PDF (log of fluctuation) - xmin_fit : 0.1 # Lower boundary of the fit (log of fluctuation) + xmin_fit : 0.3 # Lower boundary of the fit (log of fluctuation) xmax_fit : 1.5 # Upper boundary of the fit (log of fluctuation) fit_cut : 1e-4 # Exclude value that are < fit_cut * maximum filaments: # parameters for FilFinder - datamap : "rho_avg" + datamap : "fluct_coldens" verbose : False rmin : 0.15 # In fraction of the box (zoom to be taken into account) rmax : 0.45 # In fraction of the box (idem) - size_thresh : 200 # in pixels**2 - skel_thresh : 100 # in pixels - branch_thresh : 100 # in pixels - glob_thresh : 40 # in map unit + size_thresh : 400 # in pixels**2 + skel_thresh : 50 # in pixels + branch_thresh : 50 # in pixels + glob_thresh : 1.5 # in datamap unit pymses: # Parameters for Pymses reader @@ -80,7 +80,7 @@ out: # Parameters for post processing # {ext} : Extension defined above # {name} : Name of the rule # {tag} : Tag defined above - # {nml[nml_key]} : The value of nml_key in the namelist (ex: amr_params/levelmin) + # {nml[nml_key]} : The value of nml_key in the namelist (ex: {nml[amr_params/levelmin]}) process: # General setting of the post-processor module @@ -91,3 +91,12 @@ process: # General setting of the post-processor module rules: # Specific rules parameters turb_energy_threshold : -1 # Remove invalid data (<0 = no threshold) + + +astrophysix: # Parameters for astrophysix and galactica + simu_fmt : "{tag}_{run}" # Format of the name of simulation + descr_fmt : "{tag}_{run}" # Format of the default description + # The following keys are accepted + # {run} : Name of the relevant run + # {tag} : Tag defined above + # {nml[nml_key]} : The value of nml_key in the namelist (ex: {nml[amr_params/levelmin]}) diff --git a/ramses_astrophysix.py b/ramses_astrophysix.py new file mode 100644 index 0000000..7fc9e99 --- /dev/null +++ b/ramses_astrophysix.py @@ -0,0 +1,92 @@ +# coding: utf-8 + + +from astrophysix.simdm.protocol import ( + SimulationCode, + AlgoType, + Algorithm, + InputParameter, + PhysicalProcess, + Physics, +) + + +# Simulation code definition # +ramses = SimulationCode( + name="Ramses 3 (MHD)", + code_name="Ramses", + code_version="3.10.1", + alias="RAMSES_3", + url="https://www.ics.uzh.ch/~teyssier/ramses/RAMSES.html", + description="Ramses MHD code", +) +# => Add algorithms : available algorithm types are : +# - AlgoType.AdaptiveMeshRefinement +# - AlgoType.VoronoiMovingMesh +# - AlgoType.SmoothParticleHydrodynamics +# - AlgoType.Godunov +# - AlgoType.PoissonMultigrid +# - AlgoType.PoissonConjugateGradient +# - AlgoType.ParticleMesh +# - AlgoType.FriendOfFriend +# - AlgoType.HLLCRiemann +# - AlgoType.RayTracer +ramses.algorithms.add( + Algorithm(algo_type=AlgoType.AdaptiveMeshRefinement, description="AMR") +) +ramses.algorithms.add( + Algorithm(algo_type=AlgoType.Godunov, description="Godunov scheme") +) +ramses.algorithms.add( + Algorithm(algo_type=AlgoType.HLLCRiemann, description="HLLC Riemann solver") +) +ramses.algorithms.add( + Algorithm( + algo_type=AlgoType.PoissonMultigrid, description="Multigrid Poisson solver" + ) +) +# => Add input parameters +ramses.input_parameters.add( + InputParameter( + key="amr_params/levelmin", + name="lmin", + description="min. level of AMR refinement", + ) +) +ramses.input_parameters.add( + InputParameter( + key="amr_params/levelmax", + name="lmax", + description="max. level of AMR refinement", + ) +) +ramses.input_parameters.add( + InputParameter( + key="amr_params/jeans_refine", + name="jeans_refine", + description="Array, number of cells needed to resolve the Jeans lenght at each level from lmin", + ) +) +ramses.input_parameters.add( + InputParameter( + key="cloud_params/beta_cool", name="beta", description="Cooling parameter" + ) +) + + +# => Add physical processes : available physics are : +# - Physics.SelfGravity +# - Physics.Hydrodynamics +# - Physics.MHD +# - Physics.StarFormation +# - Physics.SupernovaeFeedback +# - Physics.AGNFeedback +# - Physics.MolecularCooling +ramses.physical_processes.add( + PhysicalProcess( + physics=Physics.Hydrodynamics, description="Hydrodynamical equations are solved" + ) +) +ramses.physical_processes.add( + PhysicalProcess(physics=Physics.SelfGravity, description="Self-Gravity is applied.") +) diff --git a/units.py b/units.py index 7f72c2f..cafe8a8 100644 --- a/units.py +++ b/units.py @@ -1,14 +1,16 @@ # coding: utf-8 -import pymses.utils.constants as cst +import astrophysix.units as cst + +create_unit = cst.Unit.create_unit def parse_exp_unit(u): splitted = u.split("^") - name_u = cst.Unit.from_name(splitted[0]).latex + name_u = cst.Unit.from_name(splitted[0]).latex.replace("text", "math") exp = "" if len(splitted) > 1: - exp = "$^{" + str(splitted[1]) + "}$" + exp = "^{" + str(splitted[1]) + "}" return name_u + exp @@ -17,66 +19,78 @@ def convert_exp(number, digits=4): splitted = "{num:.{digits}g}".format(num=number, digits=digits).split("e") # If no need of scientific notation (low number of digits) if len(splitted) == 1: - return "${}$".format(splitted[0]) + return "{}".format(splitted[0]) else: coeff = splitted[0] exp = splitted[1] exp_str = "10^{" + str(int(exp)) + "}" if float(coeff) == 1.0: - return "$" + exp_str + "$" + return exp_str else: - return "${}\\times {}$".format(coeff, exp_str) + return "{}\\times {}".format(coeff, exp_str) -def unit_str(unit, base=None, prefix=""): +def unit_str(unit, base=None, prefix="", format=" [{unit}]"): + """ + Format a unit in matplotlib parsable latex + + unit : astrophysics.units.unit.Unit + base : astrophysics.units.unit.Unit to use as base unit (if None `unit` is used) + prefix : str to put befor the unit + format : str with the {unit} key, to put external decoration + """ if unit == cst.none: return "" elif not base is None: coeff = unit.express(base) - return unit_str(base, prefix=convert_exp(coeff) + " ") + u_str = unit_str(base, prefix=convert_exp(coeff) + " ") elif len(unit.latex) > 0: - if ("." in unit.latex or "^" in unit.latex) and not "$" in unit.latex: + if "." in unit.latex or "^" in unit.latex: base_str = ".".join(map(parse_exp_unit, unit.name.split("."))) - return r" [{}{}]".format(prefix, base_str) + u_str = r"${}{}$".format(prefix, base_str) else: - return r" [{}{}]".format(prefix, unit.latex) + u_str = r"${}{}$".format(prefix, unit.latex.replace("text", "math")) elif len(unit.name) > 0: try: base_str = ".".join(map(parse_exp_unit, unit.name.split("."))) - u_str = r" [{}{}]".format(prefix, base_str) + u_str = r"${}{}$".format(prefix, base_str) except: - u_str = r" [{}{}]".format(prefix, unit.name) - return u_str + u_str = r"${}{}$".format(prefix, unit.name) else: base_str = ".".join( map(parse_exp_unit, unit._decompose_base_units().split(".")) ) - return r" [{}{} {}]".format(prefix, unit.coeff, base_str) + u_str = r"${}{} {}$".format(prefix, unit.coeff, base_str) + return format.format(unit=u_str) -cst.Msun_pc3 = cst.create_unit( +cst.coldens = create_unit( + "Msun.pc^-2", base_unit=cst.Msun / cst.pc ** 2, descr="Column density" +) +cst.km_s = create_unit("km.s^-1", base_unit=cst.km / cst.s, descr="Speed") + + +cst.Msun_pc3 = create_unit( "Msun.pc^-3", base_unit=cst.Msun / cst.pc ** 3, descr="Density" ) -cst.Msun_pc3 = cst.create_unit( - "Msun.pc^-3", base_unit=cst.Msun / cst.pc ** 3, descr="Density" -) +cst.kg_m3 = create_unit("kg.m^-3", base_unit=cst.kg / cst.m ** 3, descr="Density") -cst.ssfr = cst.create_unit( - "Msun.yr^-1.pc^-2", +cst.ssfr = create_unit( + "Msun.year^-1.pc^-2", base_unit=cst.Msun / cst.year / cst.pc ** 2, descr="Surfacic SFR", - latex="M$_{\odot}$.yr$^{-1}$.pc$^{-2}$", ) +# latex='M$_{\odot}$.yr$^{-1}$.pc$^{-2}$') -cst.ssfrG = cst.create_unit( +cst.ssfrG = create_unit( "Msun.Gyr^-1.pc^-2", base_unit=1e-9 * cst.Msun / cst.year / cst.pc ** 2, descr="Surfacic SFR", - latex="M$_{\odot}$.Gyr$^{-1}$.pc$^{-2}$", + latex="\mathrm{M}_{\odot}.\mathrm{Gyr}^{-1}.\mathrm{pc}^{-2}", ) -cst.uG = cst.create_unit( - "μG", base_unit=1e-10 * cst.T, descr="Micro Gauss", latex="$\mu\mathrm{G}$" +cst.uG = create_unit( + "μG", base_unit=1e-10 * cst.T, descr="Micro Gauss", latex="\\mu\\mathrm{G}" )