diff --git a/aggregator.py b/aggregator.py index 15e51e5..30ff8c4 100644 --- a/aggregator.py +++ b/aggregator.py @@ -8,17 +8,27 @@ def _map_rule(rule, arg, overwrite, path, path_out, pp_params, run_num): ) except Exception as e: print(e) + raise return pp.process(rule, arg, overwrite, overwrite) class Aggregator: def _not_self_dep(self, name, dep, dep_arg, overwrite, **kwargs): - if "runs" in kwargs: - dep_runs = [run for run in self.runs if run in kwargs["runs"]] + if "select" in kwargs: + select = kwargs["select"] + runs, nums = self.selector.select(**select) + elif "runs" in kwargs: + runs = kwargs["runs"] + if isinstance(runs, RunSelector): + nums = runs.nums + runs = runs.runs + else: + nums = self.nums else: - dep_runs = self.runs + runs = self.runs + nums = self.nums - run_num = [(run, num) for run in dep_runs for num in self.nums[run]] + run_num = [(run, num) for run in runs for num in nums[run]] map_fn = partial( _map_rule, dep, dep_arg, overwrite, self.path, self.path_out, self.pp_params ) diff --git a/baseprocessor.py b/baseprocessor.py index 34bf4f7..9e835a6 100644 --- a/baseprocessor.py +++ b/baseprocessor.py @@ -4,8 +4,10 @@ import sys import os import glob as glob import copy - +import time import tables +from tables import HDF5ExtError + import pymses import numpy as np from numpy.polynomial.polynomial import polyfit @@ -211,7 +213,12 @@ class HDF5Container(BaseProcessor): def open(self): if not self.opened: - self.save = tables.open_file(self.filename, mode="a") + try: + self.save = tables.open_file(self.filename, mode="a") + except HDF5ExtError: + # Wait a bit if the lock was not still released + time.sleep(3) + self.save = tables.open_file(self.filename, mode="a") self.opened = True def close(self): @@ -253,10 +260,52 @@ class HDF5Container(BaseProcessor): self.close() return value + def _get_units(self, unit, data=None): + """ + Get real units from info files + unit is either: + 1. An instance of cst.Unit (pymses unit class) + 2. A string beginning by "unit_", referring to a code unit, + available in self.info + 3. A dict {unit1 : exp1, unit2: exp2, ...} with unitX as 2. + and expX a float, referring to the compound unit + unit1**exp1 * unit2**exp2 + 4. A dict {key: unit, ...} where key is a field name (eg. 'time', or 'mass') + and unit the corresponding unit (on one on the above format) + + Returns: + 1-3. : a cst.Unit instance + 4. : a dict {key: unit, ...} with same key as input and unit being cst.Unit instances + """ + if isinstance(unit, cst.Unit): + return unit + if isinstance(unit, str) and unit[:5] == "unit_": + res = self.info[unit] + if unit == "unit_length": + res = res / self.info["boxlen"] + return res + if list(unit)[0][:5] == "unit_": + new_unit = cst.none + for base_unit_str in unit: + expo = unit[base_unit_str] + base_unit = self._get_units(base_unit_str) + new_unit = new_unit * base_unit ** expo + return new_unit + if (not data is None) and isinstance(data, dict) and list(unit)[0] in data: + for key in unit: + unit[key] = self._get_units(unit[key]) + return unit + + else: + raise ValueError("Invalid unit") + def _save_data(self, name_full, data, description, unit): """ Save data in the HDF5 structure, overwrite if necessary """ + + unit = self._get_units(unit, data=data) + if name_full in self.save: self.save.remove_node(name_full, recursive=True) @@ -285,6 +334,7 @@ class HDF5Container(BaseProcessor): self.save.get_node(name_full)._v_attrs.unit = unit for key in data: + key = str(key) if isinstance(description, dict): if isinstance(unit, dict): self._save_data( @@ -314,6 +364,7 @@ class HDF5Container(BaseProcessor): if not attrs is None: for key in attrs: + key = str(key) self.save.get_node(name_full)._v_attrs[key] = attrs[key] def set_value(self, node_name, data, description, unit): @@ -412,3 +463,11 @@ class HDF5Container(BaseProcessor): def simple_getter(name, dset): return dset[name] + + +def vect_getter(name, i, dset): + return dset[name][:, i] + + +def norm_getter(name, dset): + return np.sqrt(np.sum(dset[name] ** 2, axis=1)) diff --git a/comparator.py b/comparator.py index 36a0865..e74faf8 100644 --- a/comparator.py +++ b/comparator.py @@ -88,47 +88,7 @@ class Comparator(Aggregator, HDF5Container): ) return missing_nums - def _get_units(self, unit, data=None): - """ - Get real units from info files - unit is either: - 1. An instance of cst.Unit (pymses unit class) - 2. A string beginning by "unit_", referring to a code unit, - available in self.info - 3. A dict {unit1 : exp1, unit2: exp2, ...} with unitX as 2. - and expX a float, referring to the compound unit - unit1**exp1 * unit2**exp2 - 4. A dict {key: unit, ...} where key is a field name (eg. 'time', or 'mass') - and unit the corresponding unit (on one on the above format) - - Returns: - 1-3. : a cst.Unit instance - 4. : a dict {key: unit, ...} with same key as input and unit being cst.Unit instances - """ - if isinstance(unit, cst.Unit): - return unit - if isinstance(unit, str) and unit[:5] == "unit_": - res = self.info[unit] - if unit == "unit_length": - res = res / self.info["boxlen"] - return res - if list(unit)[0][:5] == "unit_": - new_unit = cst.none - for base_unit_str in unit: - expo = unit[base_unit_str] - base_unit = self._get_units(base_unit_str) - new_unit = new_unit * base_unit ** expo - return new_unit - if (not data is None) and isinstance(data, dict) and list(unit)[0] in data: - for key in unit: - unit[key] = self._get_units(unit[key]) - return unit - - else: - raise ValueError("Invalid unit") - def _save_data(self, name_full, data, description, unit): - unit = self._get_units(unit, data=data) super(Comparator, self)._save_data(name_full, data, description, unit) self.save.get_node(name_full)._v_attrs.nums = self.nums @@ -276,6 +236,24 @@ class Comparator(Aggregator, HDF5Container): series["sfr"][run].append(sfr) return series + def _extract_cons_from_log(self, series, log_filename, run): + cmd_grep = "grep 'Main step' {} -A 2".format(log_filename) + content = os.popen(cmd_grep).readlines() + for i in range(0, len(content), 4): + series["time"][run].append( + np.float(content[i + 2].split("=")[2].split()[0]) + ) + series["step"][run].append(np.int(content[i].split("=")[1].split()[0])) + series["mcons"][run].append(np.float(content[i].split("=")[2].split()[0])) + series["econs"][run].append(np.float(content[i].split("=")[3].split()[0])) + series["epot"][run].append(np.float(content[i].split("=")[4].split()[0])) + series["ekin"][run].append(np.float(content[i].split("=")[5].split()[0])) + series["eint"][run].append(np.float(content[i].split("=")[6].split()[0])) + series["emag"][run].append( + np.float(content[i + 1].split("=")[1].split()[0]) + ) + return series + def _extract_rms_from_log(self, series, log_filename, run): cmd_grep = "grep 'turbulent rms' {} -C 1".format(log_filename) content = os.popen(cmd_grep).readlines() @@ -334,7 +312,8 @@ class Comparator(Aggregator, HDF5Container): ssfr = {} for run in self.runs: # Surface of the box in pc^2 - surface = (self.info["unit_length"].express(cst.pc)) ** 2 + info = self.pp[run][self.nums[run][0]].info + surface = (info["unit_length"].express(cst.pc)) ** 2 # WARNING : We do not multiply by boxlen since already done in 'unit_length' (pymses) time = self.save.get_node("/series/sinks_from_log/time/" + run).read() @@ -360,6 +339,22 @@ class Comparator(Aggregator, HDF5Container): return ssfr, {"avg_window": avg_window} + def _surfacic_sink_mass(self): + mass_unit = self.save.get_node("/series/sinks_from_log/mass_sink")._v_attrs.unit + ssm = {} + for run in self.runs: + # Surface of the box in pc^2 + info = self.pp[run][self.nums[run][0]].info + surface = (info["unit_length"].express(cst.pc)) ** 2 + mass_sink = self.save.get_node( + "/series/sinks_from_log/mass_sink/" + run + ).read() + mass_sink = mass_sink * mass_unit.express(cst.Msun) + + ssm[run] = mass_sink / surface + + return ssm + def _turb_power(self): turb_power = {} for run in self.runs: @@ -475,6 +470,14 @@ class Comparator(Aggregator, HDF5Container): description="Instantaneous surfacic star formation rate", dependencies=["sinks_from_log"], ), + "ssm": Rule( + self, + self._surfacic_sink_mass, + group="/series/sinks_from_log", + unit=cst.Msun / cst.pc ** 2, + description="Surfacic sink mass", + dependencies=["sinks_from_log"], + ), "sfr_from_log": Rule( self, partial(self._from_log, ["time", "sfr"], self._extract_sfr_from_log), @@ -510,6 +513,25 @@ class Comparator(Aggregator, HDF5Container): "turb_energy": "Injected turbulent energy", }, ), + "cons_from_log": Rule( + self, + partial( + self._from_log, + ["time", "step", "mcons", "econs", "epot", "ekin", "eint", "emag"], + self._extract_cons_from_log, + ), + group="/series", + unit={ + "time": "unit_time", + "step": cst.none, + "mcons": cst.none, + "econs": cst.none, + "epot": cst.none, # TODO find unit + "ekin": cst.none, + "eint": cst.none, + "emag": cst.none, + }, + ), "turb_power": Rule( self, self._turb_power, @@ -574,6 +596,9 @@ class Comparator(Aggregator, HDF5Container): self._gen_rule_time_global("mwa_sigma", "time_sigma", unit="unit_velocity") self._gen_rule_time_global("max_fluct_coldens") + self._gen_rule_time_global( + "mass", unit=self.info["unit_density"] * self.info["unit_length"] ** 3 + ) self._gen_rule_time_global("mwa_B_int", unit="unit_mag") for name in [ diff --git a/plotter.py b/plotter.py index 867aa11..2b2a5e5 100644 --- a/plotter.py +++ b/plotter.py @@ -30,7 +30,7 @@ import pspec_read P.rcParams["image.cmap"] = "plasma" P.rcParams["savefig.dpi"] = 400 -tex_params = {"text.latex.preamble": [r"\usepackage{amsmath}"]} +tex_params = {"text.latex.preamble": r"\usepackage{amsmath}"} P.rcParams.update(tex_params) @@ -73,7 +73,7 @@ class Plotter(Aggregator, BaseProcessor): label_convert = { "turb_rms": "$f_{rms}$", "beta": "$\\beta$", - "beta_cool": "$\\beta_{c}$", + "beta_cool": "$\\beta$", "dens0": "$n_0$", "coldens0": "$\Sigma_0$", "sfr_avg_window": "window", @@ -122,16 +122,20 @@ class Plotter(Aggregator, BaseProcessor): # Select runs if selector is None: - selector = RunSelector(path, in_runs, in_nums, self.pp_params, **kwargs) + self.selector = RunSelector( + path, in_runs, in_nums, self.pp_params, **kwargs + ) + else: + self.selector = selector # Save infos self.path = path - self.runs = selector.runs - self.nums = selector.nums + self.runs = self.selector.runs + self.nums = self.selector.nums # Get comparator object self.comp = Comparator( - path, self.runs, self.nums, path_out, self.pp_params, selector=selector + path, self.runs, self.nums, path_out, self.pp_params, selector=self.selector ) # Get postprocesor objets for each run @@ -182,26 +186,38 @@ class Plotter(Aggregator, BaseProcessor): Open storage and figure if needed before processing a rule """ if not arg is None: - name_full = name + "_" + str(arg) + name_full = ( + name + + "_" + + str(arg) + .replace(" ", "") + .replace("[", "") + .replace("]", "") + .replace(",", "_") + ) else: name_full = name if rule.is_valid(arg): - if rule.kind == "classic" or rule.kind == "runs": - try: + if rule.kind == "classic" or rule.kind == "cells": + if "select" in kwargs: + select = kwargs.pop("select") + runs, nums = self.selector.select(**select) + elif "runs" in kwargs: runs = kwargs.pop("runs") if isinstance(runs, RunSelector): + nums = runs.nums runs = runs.runs - except KeyError: + else: + nums = self.nums + else: runs = self.runs + nums = self.nums + i = 0 for run in runs: files = [] - if rule.kind == "classic": - nums = self.nums[run] - else: - nums = [None] - for num in nums: + for num in nums[run]: plot_filename = self._find_filename(name_full, run, num) if from_cells or rule.kind == "cells": @@ -229,6 +245,7 @@ class Plotter(Aggregator, BaseProcessor): "'LocatableAxes' object does not support indexing", "'AxesSubplot' object does not support indexing", "'AxesSubplot' object is not subscriptable", + "'Axes' object is not subscriptable", "'LocatableAxes' object is not subscriptable", ]: self._plot_rule( @@ -260,18 +277,61 @@ class Plotter(Aggregator, BaseProcessor): i = i + 1 files.append(plot_filename) else: + + if "select" in kwargs and not "runs" in kwargs: + select = kwargs.pop("select") + runs, nums = self.selector.select(**select) + if not rule.kind == "runs": + kwargs["runs"] = runs + elif rule.kind == "runs" and "runs" in kwargs: + runs = kwargs.pop("runs") + else: + runs = self.runs + if ax is None: ax = P.gca() - if rule.kind == "series" and len(self.runs) == 1: + if rule.kind == "series" and len(runs) == 1: run = self.runs[0] plot_filename = self._find_filename(name_full, run) else: plot_filename = self._find_filename(name_full) save = tables.open_file(self.comp.filename, "r") try: - self._plot_rule( - rule, save, arg, plot_filename, overwrite, ax, **kwargs - ) + if rule.kind == "runs": + for i, run in enumerate(runs): + try: + self._plot_rule( + rule, + save, + arg, + plot_filename, + overwrite, + ax=ax[i], + run=run, + **kwargs, + ) + except TypeError as e: + if str(e) in [ + "'LocatableAxes' object does not support indexing", + "'AxesSubplot' object does not support indexing", + "'AxesSubplot' object is not subscriptable", + "'Axes' object is not subscriptable", + "'LocatableAxes' object is not subscriptable", + ]: + self._plot_rule( + rule, + save, + arg, + plot_filename, + overwrite, + ax=ax, + run=run, + **kwargs, + ) + else: + self._plot_rule( + rule, save, arg, plot_filename, overwrite, ax, **kwargs + ) finally: save.close() else: @@ -284,15 +344,20 @@ class Plotter(Aggregator, BaseProcessor): P.sca(ax) if self._needs_computation(overwrite, plot_filename): rule.plot(save, arg, **kwargs) - P.tight_layout(pad=1) + if not self.pp_params.out.interactive: + P.tight_layout(pad=1) + + if self.pp_params.out.save: P.savefig(plot_filename) - P.close() self._log("{} plotted".format(plot_filename), "SUCCESS") else: self._log( "{} plotted".format(os.path.basename(plot_filename)), "SUCCESS" ) + + if not self.pp_params.out.interactive: + P.close() else: self._log("Plot {} is already done, skipping...".format(plot_filename)) @@ -452,8 +517,7 @@ class Plotter(Aggregator, BaseProcessor): dmap, extent=im_extent, origin="lower", norm=norm, cmap=cmap, **kwargs ) - P.locator_params(axis=ax_h, nbins=self.pp_params.plot.ntick) - P.locator_params(axis=ax_v, nbins=self.pp_params.plot.ntick) + P.locator_params(axis="both", nbins=self.pp_params.plot.ntick) P.xlabel(self._ax_title[ax_h] + unit_str(unit_space)) P.ylabel(self._ax_title[ax_v] + unit_str(unit_space)) @@ -680,7 +744,6 @@ class Plotter(Aggregator, BaseProcessor): P.xscale("log") if ylog: P.yscale("log") - P.plot(bin_centers, mean_bin, **kwargs) if not ylabel is None: P.ylabel(ylabel) @@ -700,6 +763,8 @@ class Plotter(Aggregator, BaseProcessor): P.title(title) + P.plot(bin_centers, mean_bin, label=title, **kwargs) + def _plot_hist( self, name, @@ -743,7 +808,7 @@ class Plotter(Aggregator, BaseProcessor): xlabel, unit_old, unit = self._ax_label_unit(node, label, unit, unit_coeff) if "mean" in node: - index = node["runs"].read().index(run) + index = node["runs"].read().index(run.encode()) values, centers = node["mean"].read()[index] else: values, centers = node.read() @@ -949,9 +1014,15 @@ class Plotter(Aggregator, BaseProcessor): ) if not run is None: label = self._label_run(run, node_y, label, nml_key) - base_line, _, _ = P.errorbar( - x, y, yerr=[y - yerr_min, yerr_max - y], label=label, **kwargs - ) + + if yerr_kind is None: + yerr = None + (base_line,) = P.plot(x, y, label=label, **kwargs) + else: + base_line, _, _ = P.errorbar( + x, y, yerr=[y - yerr_min, yerr_max - y], label=label, **kwargs + ) + else: if runs is None: runs = self.runs @@ -1067,7 +1138,7 @@ class Plotter(Aggregator, BaseProcessor): def overlay_kennicutt(self, n0, step): """ - Add an overlay : kennicutt mass accretion + Add an overlay : Kennicutt mass accretion """ P.grid(False) ylim = P.ylim() @@ -1085,6 +1156,20 @@ class Plotter(Aggregator, BaseProcessor): P.xlim(tmin, tmax) P.ylim(ylim) + def _gen_from_log(self, logrule, name, description="Generated"): + self.rules[name] = PlotRule( + self, + partial( + self._plot, + "/series/" + logrule + "/time", + "/series/" + logrule + "/" + name, + xunit=cst.Myr, + ), + description=description, + kind="series", + dependencies=[logrule], + ) + def def_rules(self): """ This is where rules are defined @@ -1105,6 +1190,16 @@ class Plotter(Aggregator, BaseProcessor): "Column density map", dependencies=["coldens"], ), + "T": PlotRule( + self, + partial( + self._plot_map, + "T", + label=r"$T$", + ), + "Temperature map", + dependencies=["T"], + ), "alpha_disk": PlotRule( self, partial(self._plot_map, "alpha_disk", label=r"$\alpha$"), @@ -1139,18 +1234,6 @@ class Plotter(Aggregator, BaseProcessor): "Radial speed", dependencies=["vr"], ), - "P_avg": PlotRule( - self, - partial(self._plot_map, "P_avg", label=r"$P$"), - "Pressure (average)", - dependencies=["P_avg"], - ), - "rho_avg": PlotRule( - self, - partial(self._plot_map, "rho_avg", label=r"$\rho$"), - "Density (average)", - dependencies=["rho_avg"], - ), "rho": PlotRule( self, partial( @@ -1242,6 +1325,12 @@ class Plotter(Aggregator, BaseProcessor): "$\rho$-PDF", dependencies=["rho_pdf"], ), + "rho_pdf_mw": PlotRule( + self, + partial(self._plot_hist, "rho_pdf_mw"), + "Mass weighted $\rho$-PDF", + dependencies=["rho_pdf_mw"], + ), "cos_pdf": PlotRule( self, partial(self._plot_hist, "cos_pdf"), @@ -1335,10 +1424,23 @@ class Plotter(Aggregator, BaseProcessor): xunit=cst.Myr, yunit=cst.Msun, ), - "Mass of the sinks against time", + "Mass of the sinks as a function of time", kind="series", dependencies=["sinks_from_log"], ), + "ssm": PlotRule( + self, + partial( + self._plot, + "/series/sinks_from_log/time", + "/series/sinks_from_log/ssm", + xunit=cst.Myr, + yunit=cst.Msun / cst.pc ** 2, + ), + "Mass of the sinks as a function of time divided by surface", + kind="series", + dependencies=["ssm"], + ), "assfr": PlotRule( self, partial( @@ -1422,12 +1524,25 @@ class Plotter(Aggregator, BaseProcessor): "/series/time", "/series/time_mwa_B_int", xunit=cst.Myr, - yunit=cst.T, + yunit=cst.uG, ), "Magnetic intensity average", kind="series", dependencies=["time_mwa_B_int"], ), + "mass": PlotRule( + self, + partial( + self._plot, + "/series/time", + "/series/time_mass", + xunit=cst.Myr, + yunit=cst.Msun, + ), + "Total mass in the box", + kind="series", + dependencies=["time_mass"], + ), "max_fluct_coldens": PlotRule( self, partial( @@ -1459,13 +1574,7 @@ class Plotter(Aggregator, BaseProcessor): for name in averageables: self.rules["rad_" + name] = PlotRule( self, - partial( - self._plot_radial, - "rad_avg_" + name, - label=name, - xlog=True, - ylog=True, - ), + partial(self._plot_radial, "rad_avg_" + name, xlog=True, ylog=True), "Azimuthal average of {}".format(name), dependencies=["radial_bins", "rad_avg_" + name], ) @@ -1473,12 +1582,7 @@ class Plotter(Aggregator, BaseProcessor): self.rules["fluct_" + name] = PlotRule( self, partial( - self._plot_map, - "fluct_" + name, - vmin=0.01, - vmax=100, - cmap="RdBu_r", - label="{}/avg({})".format(name, name), + self._plot_map, "fluct_" + name, vmin=0.01, vmax=100, cmap="RdBu_r" ), "Fluctuation of {}".format(name), dependencies=["fluct_" + name], @@ -1486,12 +1590,7 @@ class Plotter(Aggregator, BaseProcessor): self.rules["pdf_" + name] = PlotRule( self, - partial( - self._plot_hist, - "pdf_" + name, - ylog=True, - label="{}/avg({})".format(name, name), - ), + partial(self._plot_hist, "pdf_" + name, ylog=True), "Probability density function of {} fluctuations".format(name), dependencies=["fit_pdf_" + name], ) @@ -1506,6 +1605,50 @@ class Plotter(Aggregator, BaseProcessor): dependencies=[group], ) + for name in ["step", "mcons", "econs", "epot", "ekin", "eint", "emag"]: + self._gen_from_log("cons_from_log", name) + + # Generic rules directly from Ramses fields + for field in self.pp_params.pymses.variables: + + def generic_rule(name): + + self.rules["slice_" + name] = PlotRule( + self, + partial(self._plot_map, "slice_" + name), + "{} slice".format(name), + dependencies=["slice_" + name], + ) + + self.rules[name + "_mwavg"] = PlotRule( + self, + partial(self._plot_map, name + "_mwavg"), + "Ax mass-weighted averaged {}".format(name), + dependencies=[name + "_mwavg"], + ) + + self.rules[name + "_avg"] = PlotRule( + self, + partial(self._plot_map, name + "_avg"), + "Ax averaged {}".format(name), + dependencies=[name + "_avg"], + ) + + # special for vectors + if field in ["g", "vel"]: + # Components + for i, dir in enumerate(["x", "y", "z"]): + generic_rule(field + "x") + + # Radial + generic_rule(field + "r") + # Othoradial + generic_rule(field + "phi") + # Norm + generic_rule(field + "_norm") + else: + generic_rule(field) + # Dict of overlays self.overlays = { "B": self._overlay_B, diff --git a/postprocessor.py b/postprocessor.py index 38f407a..2d82c93 100644 --- a/postprocessor.py +++ b/postprocessor.py @@ -4,16 +4,16 @@ import pspec_new from baseprocessor import * import pymses.utils.regions as reg from pymses.filters import RegionFilter - +import astropy.units as u +from fil_finder import FilFinder2D +import pickle +from skimage.morphology import medial_axis # Getters def mass_func(dset): - try: - dx = dset["dx"] - except: - dx = dset.get_sizes() + dx = dset["dx"] return dset["rho"] * dx ** 3 # Mass function @@ -47,7 +47,7 @@ def getter_rho(dset): def getter_v_norm(dset): - v_norm = np.sqrt(np.sum(dset["Br"] ** 2, axis=1)) + v_norm = np.sqrt(np.sum(dset["vel"] ** 2, axis=1)) return v_norm @@ -82,15 +82,6 @@ def mean_by_bins( # For each cell, bin_number contains the number of the bins it belongs to bin_number = np.zeros(len(y)) - # Go through the min value of x of each bin - for x_min in x_bins[1:-1]: - bin_number = bin_number + (x > x_min).astype(int) - - # Compute the mean in each bin - y_mean = np.zeros(len(x_bins) - 1) - for i in range(len(y_mean)): - y_mean[i] = np.mean(y[bin_number == i]) - # Get the center of each bin if logbins: centers = 10 ** (0.5 * (np.log10(x_bins[1:]) + np.log10(x_bins[:-1]))) @@ -100,6 +91,39 @@ def mean_by_bins( return centers, y_mean +# Filament helpers + + +def find_center(distance, skeleton, i_center, j_center, i, j): + """ + Given a distance array, find the cells at a center of a filament at a given postion + """ + if skeleton[i, j]: + i_center[i, j], j_center[i, j] = i, j + return i, j + elif i_center[i, j] or j_center[i, j]: + return i_center[i, j], j_center[i, j] + else: + i_neigh = np.array([i - 1, i, i + 1]) + i_neigh = i_neigh[(i_neigh > 0) & (i_neigh < distance.shape[0])] + j_neigh = np.array([j - 1, j, j + 1]) + j_neigh = j_neigh[(j_neigh > 0) & (j_neigh < distance.shape[1])] + ii_neigh, jj_neigh = np.meshgrid(i_neigh, j_neigh) + d_neigh = distance[ii_neigh, jj_neigh] + ind_max = np.unravel_index(np.argmax(d_neigh), d_neigh.shape) + i_max, j_max = ii_neigh[ind_max], jj_neigh[ind_max] + if i_max == i and j_max == j: + i_center[i, j], j_center[i, j] = i, j + else: + i_center[i, j], j_center[i, j] = find_center( + distance, skeleton, i_center, j_center, i_max, j_max + ) + return i_center[i, j], j_center[i, j] + + +# PostProcessor class + + class PostProcessor(HDF5Container): """ This class enable to compute and save derived quantities from the raw output @@ -110,6 +134,17 @@ class PostProcessor(HDF5Container): _axes_h = {"x": "y", "y": "x", "z": "x"} # Associated horizontal axe _axes_v = {"x": "z", "y": "z", "z": "y"} # Associated vertical axe + # Pymses unit key of amr fiels + unit_key = { + "rho": "unit_density", + "vel": "unit_velocity", + "Br": "unit_mag", + "Bl": "unit_mag", + "P": "unit_pressure", + "g": {"unit_gravpot": 1, "unit_length": -1}, + "phi": "unit_gravpot", + } + G = 1.0 # Gravitational constant cells_loaded = False @@ -238,6 +273,12 @@ class PostProcessor(HDF5Container): self.log_id = "[{}, {}] ".format(self.run, self.num) + if os.path.exists(self.path_out + "/filaments.pickle"): + with open(self.path_out + "/filaments.pickle", "rb") as f: + self.fil = pickle.load(f) + else: + self.fil = None + self.def_rules() def load_cells(self): @@ -290,22 +331,47 @@ class PostProcessor(HDF5Container): """ Returns the position in normalized units centered on the position of the star """ - pos = dset.get_cell_centers() + pos = dset.points pos = pos - (np.array(self.pp_params.disk.pos_star) / self.lbox) return pos def getter_vect_r(self, dset, name_vect): """ Radial component of a vector """ - r = self.getter_pos_disk(dset)[:, :, :2] + r = self.getter_pos_disk(dset)[:, :2] + ur = np.transpose((np.transpose(r) / np.sqrt(np.sum(r ** 2, axis=1)))) + return np.einsum("ij, ij -> i", dset[name_vect][:, :2], ur) + + def getter_vect_phi(self, dset, name_vect): + """ Azimuthal component of a vector """ + + r = self.getter_pos_disk(dset)[:, :2] + r_norm = np.sqrt(np.sum(r ** 2, axis=1)) + rot = np.array([[0, -1], [1, 0]]) + uphi = np.transpose(np.einsum("ij, kj -> ik", rot, r) / r_norm) + vect = dset[name_vect][:, :2] + + return np.einsum("ij,ij -> i", vect, uphi) + + def oct_getter_pos_disk(self, dset): + """ + Returns the position in normalized units centered on the position of the star + """ + pos = dset.get_cell_centers() + pos = pos - (np.array(self.pp_params.disk.pos_star) / self.lbox) + return pos + + def oct_getter_vect_r(self, dset, name_vect): + """ Radial component of a vector """ + r = self.oct_getter_pos_disk(dset)[:, :, :2] ur = np.transpose( (np.transpose(r, (2, 0, 1)) / np.sqrt(np.sum(r ** 2, axis=2))), (1, 2, 0) ) return np.einsum("ikj, ikj -> ik", dset[name_vect][:, :, :2], ur) - def getter_vect_phi(self, dset, name_vect): + def oct_getter_vect_phi(self, dset, name_vect): """ Azimuthal component of a vector """ - r = self.getter_pos_disk(dset)[:, :, :2] + r = self.oct_getter_pos_disk(dset)[:, :, :2] r_norm = np.sqrt(np.sum(r ** 2, axis=2)) rot = np.array([[0, -1], [1, 0]]) uphi = np.transpose(np.einsum("ij, klj -> ikl", rot, r) / r_norm, (1, 2, 0)) @@ -313,14 +379,14 @@ class PostProcessor(HDF5Container): return np.einsum("ikj,ikj -> ik", vect, uphi) - def getter_vr(self, dset): - return self.getter_vect_r(dset, "vel") + def oct_getter_vr(self, dset): + return self.oct_getter_vect_r(dset, "vel") - def getter_vphi(self, dset): + def oct_getter_vphi(self, dset): """ Azimuthal velocity """ - return self.getter_vect_phi(dset, "vel") + return self.oct_getter_vect_phi(dset, "vel") - def _slice(self, getter, ax_los="z", z=0, unit=cst.none): + def _slice(self, getter, ax_los="z", z=0.0, unit=cst.none): """ Slice process function. Return a slice of the source box. @@ -343,6 +409,7 @@ class PostProcessor(HDF5Container): ------- A numpy array containing the slice """ + unit = self._get_units(unit) op = ScalarOperator(getter, unit) datamap = slicing.SliceMap(self._amr, self._cam[ax_los], op, z=z) return datamap.map.T @@ -356,6 +423,7 @@ class PostProcessor(HDF5Container): If surf_qty is set (projection mode), mass_weighted is ignored """ + unit = self._get_units(unit) if surf_qty: op = ScalarOperator(getter, unit) else: @@ -405,6 +473,7 @@ class PostProcessor(HDF5Container): WARNING : This version only works on an uniform grid, need of a box version for AMR Returns 1D array if getter returns a scalar quantity """ + unit = self._get_units(unit) self.load_cells() if isinstance(axis, str): axis = self._ax_nb[axis] @@ -426,10 +495,10 @@ class PostProcessor(HDF5Container): return df.groupby("axis").mean().values[:, 0] - def _vol_avg(self, getter, mass_weighted=True): + def _sum(self, getter, mass_weighted=True): """ - Global volumic (or mass_weighted) average of the quantity returned by getter - Returns a scalar (or a vctor if the quantity returned by getter is a getter, eg. speed) + Global sum of the quantity returned by getter (variable must be extensive) + Returns a scalar (or a vector if the quantity returned by getter is a getter, eg. speed) """ self.load_cells() value = getter(self.cells) @@ -444,6 +513,24 @@ class PostProcessor(HDF5Container): self.unload_cells() return data + def _vol_avg(self, getter, mass_weighted=True): + """ + Global volumic (or mass_weighted) average of the quantity returned by getter + Returns a scalar (or a vector if the quantity returned by getter is a getter, eg. speed) + """ + self.load_cells() + value = getter(self.cells) + if mass_weighted: + weight = mass_func(self.cells) + else: + weight = vol_func(self.cells) + # Transpose (.T) is for vectorial values + data = np.sum((weight * value.T).T, axis=0) / np.sum(weight) + + if self.pp_params.process.unload_cells: + self.unload_cells() + return data + def _vol_pdf(self, getter, bins=100, logbins=False, weight_func=vol_func): self.load_cells() data = getter(self.cells) @@ -656,7 +743,7 @@ class PostProcessor(HDF5Container): # Operator to compute the angular speed times rho def omega_rho_func(dset): - pos = self.getter_pos_disk(dset) + pos = self.oct_getter_pos_disk(dset) xx = pos[:, :, 0] yy = pos[:, :, 1] rc = np.sqrt(xx ** 2 + yy ** 2) # cylindrical radius @@ -743,9 +830,17 @@ class PostProcessor(HDF5Container): map_size = self.pp_params.pymses.map_size pos_star = self.pp_params.disk.pos_star - x = np.linspace(im_extent[0], im_extent[1], map_size) - y = np.linspace(im_extent[2], im_extent[3], map_size) + # Physical size of cells + dx = (im_extent[1] - im_extent[0]) / map_size + dy = (im_extent[3] - im_extent[2]) / map_size + + # Physical coordinates of the center of the cells + x = np.linspace(im_extent[0], im_extent[1], map_size) + 0.5 * dx + y = np.linspace(im_extent[2], im_extent[3], map_size) + 0.5 * dy + xx, yy = np.meshgrid(x, y) + + # Physical radius rr = np.sqrt((xx - pos_star[0]) ** 2 + (yy - pos_star[1]) ** 2) return rr @@ -810,14 +905,17 @@ class PostProcessor(HDF5Container): fluct_map = self.save.get_node("/maps/fluct_" + name + "_" + ax_los).read() rr = self.save.get_node("/maps/rr_" + ax_los).read() - mask_pdf = (rr > self.pp_params.disk.rmin_pdf) & ( - rr < self.pp_params.disk.rmax_pdf + mask_pdf = ( + (rr > self.pp_params.disk.rmin_pdf) + & (rr < self.pp_params.disk.rmax_pdf) + & (fluct_map > 0) ) nb_cells = np.sum(mask_pdf.flatten()) values, edges = np.histogram( np.log10(fluct_map[mask_pdf].flatten()), self.pp_params.pdf.nb_bin, + range=self.pp_params.pdf.range, weights=np.ones(nb_cells) / nb_cells, ) centers = 0.5 * (edges[1:] + edges[:-1]) @@ -848,7 +946,7 @@ class PostProcessor(HDF5Container): # Mean part - T_avg = self.save.get_node("/maps/avg_map_T_avg_z").read() + T_avg = self.save.get_node("/maps/avg_map_T_mwavg_z").read() radial_bins = self.save.get_node("/radial/radial_bins_" + ax_los).read() mean_bin_vr = self.save.get_node( @@ -862,7 +960,9 @@ class PostProcessor(HDF5Container): # Fluct part def getter_alpha_num(dset): - r = np.sqrt(np.sum((self.lbox * self.getter_pos_disk(dset)) ** 2, axis=2)) + r = np.sqrt( + np.sum((self.lbox * self.oct_getter_pos_disk(dset)) ** 2, axis=2) + ) bins = np.zeros(r.shape, dtype=int) for r0 in radial_bins[1:]: @@ -871,8 +971,8 @@ class PostProcessor(HDF5Container): vr_mean = mean_bin_vr[bins] vphi_mean = mean_bin_vphi[bins] - vr = self.getter_vr(dset) - vphi = self.getter_vphi(dset) + vr = self.oct_getter_vr(dset) + vphi = self.oct_getter_vphi(dset) alpha = (vphi - vphi_mean) * (vr - vr_mean) return alpha @@ -889,15 +989,15 @@ class PostProcessor(HDF5Container): "Map of the gravitational contribution to the Shakura&Sunaev alpha parameter for disks" assert ax_los == "z" - T_avg = self.save.get_node("/maps/avg_map_T_avg_z").read() + T_avg = self.save.get_node("/maps/avg_map_T_mwavg_z").read() coldens = self.save.get_node("/maps/avg_map_coldens_z").read() def getter_alpha_grav(dset): - r2 = np.sum((self.lbox * self.getter_pos_disk(dset)) ** 2, axis=2) + r2 = np.sum((self.lbox * self.oct_getter_pos_disk(dset)) ** 2, axis=2) e2 = (1.0 / 256.0) ** 2 gstar = -self.G * self.pp_params.disk.mass_star / (e2 + r2) - gr = self.getter_vect_r(dset, "g") - gstar - gphi = self.getter_vect_phi(dset, "g") + gr = self.oct_getter_vect_r(dset, "g") - gstar + gphi = self.oct_getter_vect_phi(dset, "g") return gr * gphi / (4 * np.pi * self.G) alpha_g = self._ax_avg(getter_alpha_grav, "z", unit=cst.none, surf_qty=True) / ( @@ -908,6 +1008,14 @@ class PostProcessor(HDF5Container): alpha_g = (2.0 / 3) * alpha_g return alpha_g + alpha_g = self._ax_avg(getter_alpha_grav, "z", unit=cst.none, surf_qty=True) / ( + coldens * T_avg + ) + + # alpha + alpha_g = (2.0 / 3) * alpha_g + return alpha_g + def _sinks(self): csv_name = ( self.path @@ -950,7 +1058,139 @@ class PostProcessor(HDF5Container): def _pspec(self): outfile = self.path_out + "/pspec.h5" pspec_new.pspec(repo=self.path, iouts=[self.num], outfile=outfile) - return outfile + return True + + def _filaments(self): + + datamap_name = self.pp_params.filaments.datamap + verbose = self.pp_params.filaments.verbose + rmin_frac = self.pp_params.filaments.rmin + rmax_frac = self.pp_params.filaments.rmax + size_thresh = self.pp_params.filaments.size_thresh + skel_thresh = self.pp_params.filaments.skel_thresh + branch_thresh = self.pp_params.filaments.branch_thresh + glob_thresh = self.pp_params.filaments.glob_thresh + + datamap = self.save.get_node("/maps/" + datamap_name + "_z").read() + shape = datamap.shape + x = np.arange(shape[0]) - shape[0] / 2 + y = np.arange(shape[1]) - shape[1] / 2 + xx, yy = np.meshgrid(x, y) + rr = np.sqrt(xx ** 2 + yy ** 2) + rmin = int(rmin_frac * shape[0]) + rmax = int(rmax_frac * shape[0]) + mask = (rr >= rmin) & (rr <= rmax) + + datamap[np.logical_not(mask)] = np.nan + self.fil = FilFinder2D(datamap, distance=1 * u.cm, beamwidth=1 * u.pix) + self.fil.preprocess_image(flatten_percent=95) + self.fil.create_mask( + verbose=verbose, + smooth_size=1 * u.pix, + adapt_thresh=2 * u.pix, + size_thresh=size_thresh * u.pix ** 2, + glob_thresh=glob_thresh, + ) + self.fil.medskel(verbose=verbose) + self.fil.analyze_skeletons( + skel_thresh=skel_thresh * u.pix, + branch_thresh=branch_thresh * u.pix, + relintens_thresh=0.1, + ) + self.fil.exec_rht() + self.fil.find_widths() + + outfile = self.path_out + "/filaments.pickle" + with open(outfile, "wb") as f: + pickle.dump(self.fil, f, pickle.HIGHEST_PROTOCOL) + return True + + def _filaments_center(self): + """ + Fill an array with center postion for each cell in a filament + """ + fil = self.fil + mask = fil.mask.copy() + _, distance = medial_axis(mask, return_distance=True) + skel = fil.skeleton + i_center = np.zeros(distance.shape, dtype=int) + j_center = np.zeros(distance.shape, dtype=int) + x_mask, y_mask = np.where(mask) + for k in range(len(x_mask)): + find_center(distance, skel, i_center, j_center, x_mask[k], y_mask[k]) + return np.stack([i_center, j_center]) + + def _filaments_forces(self): + """ + Compute forces within a filament (for disks) + """ + + GM = self.G * self.pp_params.disk.mass_star # Mass parameter + + # Get mask for filaments + fil = self.fil + mask_fil = np.asarray(fil.mask.copy(), dtype=bool) + + # Find center of filaments + i_center, j_center = self._filaments_center() + + # Get slices and projections at z = 0 + vphi = self.save.get_node("/maps/slice_velphi_z").read() + gr = self.save.get_node("/maps/slice_gr_z").read() + Pz = self.save.get_node("/maps/slice_P_z").read() + coldens = self.save.get_node("/maps/coldens_z").read() + vr = self.save.get_node("/maps/slice_velr_z").read() + + # Get coordinates + im_extent = np.array(self.save.root.maps._v_attrs.im_extent) * self.lbox + map_size = self.pp_params.pymses.map_size + pos_star = self.pp_params.disk.pos_star + + # Physical size of cells + dx = (im_extent[1] - im_extent[0]) / map_size + dy = (im_extent[3] - im_extent[2]) / map_size + + # Physical coordinates of the center of the cells + x = np.linspace(im_extent[0], im_extent[1], map_size) + 0.5 * dx + y = np.linspace(im_extent[2], im_extent[3], map_size) + 0.5 * dy + + xx, yy = np.meshgrid(x, y) + rr = np.sqrt((xx - pos_star[0]) ** 2 + (yy - pos_star[1]) ** 2) + + # Rotational support + R = vphi ** 2 / rr + + # Equilibrium + Gvrx, Gvry = np.gradient(vr) + gradvr = (xx * Gvrx + yy * Gvry) / rr + dvr = gradvr + vr * gradvr # Complete derivative + + # Thermal support + GPx, GPy = np.gradient(Pz) + gradPr = (xx * GPx + yy * GPy) / rr + fP = gradPr / coldens + + # Gravitational field + e2 = (1.0 / 512) ** 2 + gstar = -GM / (rr ** 2 + e2) + + # Substract gravitational field from the star + Rdisk = R + gstar + gdisk = gr - gstar + + # Forces at the center of filaments + Rdisk_center = Rdisk[i_center, j_center] + gr_center = gdisk[i_center, j_center] + fP_center = fP[i_center, j_center] + dvr_center = dvr[i_center, j_center] + + # Forces for the filaments equilibrium + Rfil = Rdisk - Rdisk_center + gfil = gdisk - gr_center + fPfil = fP - fP_center + dvr_fil = dvr - dvr_center + + return {"gfil": gfil, "Rfil": Rfil, "fPfil": fPfil, "dvr": dvr_fil} def def_rules(self): @@ -967,7 +1207,7 @@ class PostProcessor(HDF5Container): self, partial( self._ax_avg, - self.getter_vr, + self.oct_getter_vr, mass_weighted=True, unit=self.info["unit_velocity"], ), @@ -979,7 +1219,7 @@ class PostProcessor(HDF5Container): self, partial( self._ax_avg, - self.getter_vphi, + self.oct_getter_vphi, mass_weighted=True, unit=self.info["unit_velocity"], ), @@ -987,31 +1227,7 @@ class PostProcessor(HDF5Container): "/maps", unit=self.info["unit_velocity"], ), - "rho_avg": Rule( - self, - partial( - self._ax_avg, - getter_rho, - mass_weighted=False, - unit=self.info["unit_density"], - ), - "Ax mass-weighted averaged azimuthal density", - "/maps", - unit=self.info["unit_density"], - ), - "P_avg": Rule( - self, - partial( - self._ax_avg, - getter_P, - mass_weighted=True, - unit=self.info["unit_pressure"], - ), - "Ax mass-weighted averaged azimuthal pressure", - "/maps", - unit=self.info["unit_pressure"], - ), - "T_avg": Rule( + "T_mwavg": Rule( self, partial( self._ax_avg, @@ -1031,7 +1247,7 @@ class PostProcessor(HDF5Container): unit=cst.none, dependencies=[ "avg_map_rho_avg", - "avg_map_T_avg", + "avg_map_T_mwavg", "avg_map_vr", "avg_map_vphi", ], @@ -1043,7 +1259,7 @@ class PostProcessor(HDF5Container): Shakura&Sunaev alpha parameter for disks", "/maps", unit=cst.none, - dependencies=["avg_map_coldens", "avg_map_T_avg"], + dependencies=["avg_map_coldens", "avg_map_T_mwavg"], ), "rho": Rule( self, @@ -1139,6 +1355,27 @@ class PostProcessor(HDF5Container): }, ), "pspec": Rule(self, self._pspec, "Power spectrum", "/hdf5"), + "filaments": Rule( + self, + self._filaments, + "Filaments", + "/datasets", + dependencies={self.pp_params.filaments.datamap: "z"}, + ), + "filaments_forces": Rule( + self, + self._filaments_forces, + "Filaments", + "/datasets", + dependencies={ + "filaments": None, + "slice_velphi": "z", + "slice_gr": "z", + "slice_P": "z", + "coldens": "z", + "slice_velr": "z", + }, + ), # Helpers "radial_bins": Rule(self, self._radial_bins, "Radial bins", "/radial"), "rr": Rule(self, self._rr, "Coordinate map", "/maps"), @@ -1165,6 +1402,18 @@ class PostProcessor(HDF5Container): "/hist", unit=self.info["unit_density"], ), + "rho_pdf_mw": Rule( + self, + partial( + self._vol_pdf, + partial(simple_getter, "rho"), + weight_func=mass_func, + logbins=True, + ), + "Global rho-PDF", + "/hist", + unit=self.info["unit_density"], + ), "T_pdf": Rule( self, partial(self._vol_pdf, getter_T, logbins=True), @@ -1226,6 +1475,13 @@ class PostProcessor(HDF5Container): "/globals", unit=self.info["unit_time"], ), + "mass": Rule( + self, + partial(self._sum, mass_func), + "Total mass", + "/globals", + unit=self.info["unit_density"] * self.info["unit_length"] ** 3, + ), "mwa_speed": Rule( self, partial(self._vol_avg, partial(simple_getter, "vel")), @@ -1261,7 +1517,8 @@ class PostProcessor(HDF5Container): "rho_avg", "P_avg", "T_avg", - "alpha_disk", + "P_mwavg", + "T_mwavg" "alpha_disk", "alpha_grav", ] for name in averageables: @@ -1312,7 +1569,72 @@ class PostProcessor(HDF5Container): dependencies=[name, name_bin], ) - self._gen_rule_transform("fluct_coldens", np.max, "max", group="/globals") + self._gen_rule_transform("fluct_coldens", np.nanmax, "max", group="/globals") + + # Generic rules directly from Ramses fields + for field in self.pp_params.pymses.variables: + + def generic_rule(name, getter, unit, oct_getter=None): + if oct_getter is None: + oct_getter = getter + + self.rules["slice_" + name] = Rule( + self, + partial(self._slice, getter, z=0.0, unit=unit), + "{} slice".format(name), + "/maps", + unit=unit, + ) + + self.rules[name + "_mwavg"] = Rule( + self, + partial(self._ax_avg, oct_getter, mass_weighted=True, unit=unit), + "Ax mass-weighted averaged {}".format(name), + "/maps", + unit=unit, + ) + + self.rules[name + "_avg"] = Rule( + self, + partial(self._ax_avg, oct_getter, mass_weighted=False, unit=unit), + "Ax averaged {}".format(name), + "/maps", + unit=unit, + ) + + # special for vectors + if field in ["g", "vel"]: + # Components + for i, dir in enumerate(["x", "y", "z"]): + generic_rule( + field + dir, + partial(vect_getter, field, i), + self.unit_key[field], + ) + + # Radial + generic_rule( + field + "r", + partial(self.getter_vect_r, name_vect=field), + self.unit_key[field], + oct_getter=self.oct_getter_vect_r, + ) + + # Othoradial + generic_rule( + field + "phi", + partial(self.getter_vect_phi, name_vect=field), + self.unit_key[field], + oct_getter=self.oct_getter_vect_phi, + ) + + # Norm + generic_rule( + field + "_norm", partial(norm_getter, field), self.unit_key[field] + ) + + else: + generic_rule(field, partial(simple_getter, field), self.unit_key[field]) super(PostProcessor, self).def_rules() diff --git a/pp_params.yml b/pp_params.yml index 40de283..bf95340 100644 --- a/pp_params.yml +++ b/pp_params.yml @@ -22,9 +22,21 @@ disk: # Disk speficic parameters pdf: # parameters for probability density functions - nb_bin : 50 # Number of bins for the PDF - xmin_fit : 0. # Lower boundary of the fit - xmax_fit : 1.25 # Upper boundary of the fit + nb_bin : 100 # Number of bins for the PDF + range : [-1.5, 2.5] # Range of the PDF (log of fluctuation) + xmin_fit : 0. # Lower boundary of the fit (log of fluctuation) + xmax_fit : 1.25 # Upper boundary of the fit (log of fluctuation) + + +filaments: # parameters for FilFinder + datamap : "rho_avg" + verbose : False + rmin : 0.15 # In fraction of the box (zoom to be taken into account) + rmax : 0.45 # In fraction of the box (idem) + size_thresh : 200 # in pixels**2 + skel_thresh : 100 # in pixels + branch_thresh : 100 # in pixels + glob_thresh : 40 # in map unit pymses: # Parameters for Pymses reader @@ -56,7 +68,8 @@ input: # Parameters on how to look for input files (= output from Ramses) out: # Parameters for post processing tag : "" # Tag for the image - interactive : False # Interactive mode (do not save the plots on the disk) + interactive : False # Interactive mode (keep figures open) + save : True # Save the plots on the disk ext : '.jpeg' # extension for plots fmt : "" # Format of the output filename for plots # The following keys are accepted @@ -70,7 +83,7 @@ out: # Parameters for post processing process: # General setting of the post-processor module - verbose : True # Give more infos on what is going on + verbose : False # Give more infos on what is going on num_process : 1 # Number of forks save_cells : True # Save cells structure on disk unload_cells : True # Save memory usage diff --git a/run_selector.py b/run_selector.py index a415cc9..c3ae12a 100644 --- a/run_selector.py +++ b/run_selector.py @@ -106,10 +106,70 @@ class RunSelector: for run in self.runs: in_nums[run] = nums_temp - for i, run in enumerate(self.runs): - self.nums[run] = self.get_nums( - run, in_nums[run], time_min, time_max, time - ) + for i, run in enumerate(self.runs): + self.nums[run] = self.get_nums(run, in_nums[run], time_min, time_max, time) + + def select( + self, + runs=None, + nums="all", + filter_nml={}, + sort_run_by=None, + time_min=None, + time_max=None, + time=None, + ): + """ + Sub-select runs and outputs from already selected runs and outputs + + Parameters + --------- + runs : str or list of str. The name runs to consider. Default: all. + nums : int or list of int or str. + The output numbers to consider. + "last" select only the last output. + "all" preselect all outputs (default) + + filter_nml : tuple or list of tupple. + Filter runs by namelist. + tuples are in the following form: + (nml_key, operator, nml_value) + with nml_key a key from the namelist (eg. "cloud_params/dens0") + operator within ("=", "!=", "<", ">", "in") + and nml_value a string, float or int + time_min : float, select output where time >= time_min (in code units) + time_max : float, select output where time <= time_min (in code units) + time : float or list of float. For each value, select the output closer to it. + + sort_run_by : str, a key from the namelist used to sort the runs (by ascending order) + + Returns + ------- + + (selected_runs, selected_nums) + """ + + selected_runs = self.get_runs( + runs, "*", filter_nml, sort_run_by, do_tests=False + ) + + if len(selected_runs) == 0: + raise ValueError("No runs found") + + if not type(nums) == dict: + nums_temp = nums + nums = {} + for run in selected_runs: + nums[run] = nums_temp + + selected_nums = {} + + for i, run in enumerate(selected_runs): + selected_nums[run] = self.get_nums( + run, nums[run], time_min, time_max, time, do_tests=False + ) + + return selected_runs, selected_nums def load_namelist(self, run): path_run = self.path_in + "/" + run @@ -139,7 +199,14 @@ class RunSelector: runs = list(filter(lambda r: value[r] in operand, runs)) return runs - def get_runs(self, in_runs=None, filter_name="*", filter_nml={}, sort_run_by=None): + def get_runs( + self, + in_runs=None, + filter_name="*", + filter_nml={}, + sort_run_by=None, + do_tests=True, + ): def try_load_nml(run): try: self.namelist[run] = self.load_namelist(run) @@ -148,17 +215,25 @@ class RunSelector: success = False return success - runs = list( - map( - os.path.basename, - list( - filter(os.path.isdir, glob.glob(self.path_in + "/" + filter_name)) - ), + if do_tests: + runs = list( + map( + os.path.basename, + list( + filter( + os.path.isdir, glob.glob(self.path_in + "/" + filter_name) + ) + ), + ) ) - ) + else: + runs = self.runs + if in_runs is not None: runs = list(filter(lambda n: n in runs, in_runs)) - runs = list(filter(try_load_nml, runs)) + + if do_tests: + runs = list(filter(try_load_nml, runs)) # Select runs that match namelist conditions runs = self.nml_select(runs, filter_nml) @@ -194,23 +269,32 @@ class RunSelector: info_file.close() return info - def get_nums(self, run, in_nums=None, time_min=None, time_max=None, time=None): + def get_nums( + self, run, in_nums=None, time_min=None, time_max=None, time=None, do_tests=True + ): def try_load_info(num): - try: - self.info[run][num] = self.load_info(run, num) + if do_tests: + try: + self.info[run][num] = self.load_info(run, num) + success = True + except IOError: + success = False + else: success = True - except IOError: - success = False return success - names = glob.glob( - self.path_in + "/" + run + "/output_[0-9][0-9][0-9][0-9][0-9]" - ) - nums = list(map(lambda n: int(n.split("/")[-1].split("_")[1]), names)) - - if type(in_nums) == int: + if isinstance(in_nums, int): in_nums = [in_nums] - if type(in_nums) == list: + + if do_tests: + names = glob.glob( + self.path_in + "/" + run + "/output_[0-9][0-9][0-9][0-9][0-9]" + ) + nums = list(map(lambda n: int(n.split("/")[-1].split("_")[1]), names)) + else: + nums = self.nums[run] + + if isinstance(in_nums, list): nums = list(filter(lambda n: n in nums, in_nums)) nums = np.sort(nums)