From c2f866b37f7e320b6a6c553925fd4ddf513af763 Mon Sep 17 00:00:00 2001 From: Noe Brucy Date: Sat, 14 Mar 2020 18:09:04 +0100 Subject: [PATCH] improvements --- comparator.py | 21 ++++++----- plotter.py | 91 +++++++++++++++++++++++++++++++++++++----------- postprocessor.py | 4 +-- run_selector.py | 11 ++++-- units.py | 7 ++++ 5 files changed, 97 insertions(+), 37 deletions(-) diff --git a/comparator.py b/comparator.py index 44d3780..e5610c5 100644 --- a/comparator.py +++ b/comparator.py @@ -92,21 +92,19 @@ class Comparator(Aggregator, HDF5Container): "Get real units from info files" if isinstance(unit, cst.Unit): return unit - elif isinstance(unit, str): - # assert(not run is None) - return self.info[unit] # [run][unit] - # elif unit.keys()[0] in self.runs: - # for run in unit: - # unit[run] = self._get_units(unit[run], run=run) - # return unit - elif unit.keys()[0] in self.info: + if isinstance(unit, str): + res = self.info[unit] + if unit == "unit_length": + res = res / self.info["boxlen"] + return res + if unit.keys()[0] in self.info: new_unit = cst.none for base_unit_str in unit: expo = unit[base_unit_str] base_unit = self._get_units(base_unit_str) new_unit = new_unit * base_unit ** expo return new_unit - elif (not data is None) and isinstance(data, dict) and unit.keys()[0] in data: + if (not data is None) and isinstance(data, dict) and unit.keys()[0] in data: for key in unit: unit[key] = self._get_units(unit[key]) return unit @@ -124,8 +122,9 @@ class Comparator(Aggregator, HDF5Container): for run in self.runs: series[run] = [] for i, num in enumerate(self.nums[run]): - series[run].apend(getter(run, num, arg=arg)) - return np.array(series) + series[run].append(getter(run, num, arg=arg)) + series[run] = np.array(series[run]) + return series def _comp(self, getter, use_num=True): prop = np.zeros(len(self.runs)) diff --git a/plotter.py b/plotter.py index 2e745d7..ba6db0b 100644 --- a/plotter.py +++ b/plotter.py @@ -119,7 +119,9 @@ class Plotter(Aggregator, BaseProcessor): or not os.path.exists(plot_filename) ) - def _process_rule(self, name, rule, arg, overwrite=False, ax=None, **kwargs): + def _process_rule( + self, name, rule, arg, overwrite=False, ax=None, movie=False, **kwargs + ): if not arg is None: name_full = name + "_" + str(arg) else: @@ -133,11 +135,9 @@ class Plotter(Aggregator, BaseProcessor): runs = runs.runs except KeyError: runs = self.runs - - if ax is None: - ax = [P.subplots(1, 1)[1] for run in runs for num in self.nums[run]] i = 0 for run in runs: + files = [] for num in self.nums[run]: plot_filename = self._find_filename(name_full, run, num) save = tables.open_file(self.pp[run][num].filename) @@ -153,24 +153,38 @@ class Plotter(Aggregator, BaseProcessor): **kwargs ) except TypeError as e: - if ( - str(e) - != "'LocatableAxes' object does not support indexing" - ): + if str(e) in [ + "'LocatableAxes' object does not support indexing", + "'AxesSubplot' object does not support indexing", + ]: + self._plot_rule( + rule, + save, + arg, + plot_filename, + overwrite, + ax=ax, + run=run, + **kwargs + ) + elif ax is None: + fig = P.figure() + self._plot_rule( + rule, + save, + arg, + plot_filename, + overwrite, + ax=P.gca(), + run=run, + **kwargs + ) + else: raise - self._plot_rule( - rule, - save, - arg, - plot_filename, - overwrite, - ax=ax, - run=run, - **kwargs - ) finally: save.close() i = i + 1 + files.append(plot_filename) else: if ax is None: ax = P.gca() @@ -479,6 +493,7 @@ class Plotter(Aggregator, BaseProcessor): xlog=None, ylog=False, kind="bar", + ylabel="$\mathcal{P}$", color=None, colors=None, nml_color=None, @@ -505,7 +520,7 @@ class Plotter(Aggregator, BaseProcessor): if put_time: time = self.save.root._v_attrs.time * self.comp.info["unit_time"] - time_str = self.pp_params.out.time_fmt.format( + time_str = self.pp_params.plot.time_fmt.format( time.express(time_unit), time_unit.latex ) if len(title) > 0: @@ -538,6 +553,8 @@ class Plotter(Aggregator, BaseProcessor): if not label is None: P.xlabel(label) + if not ylabel is None: + P.ylabel(ylabel) if not ax_los is None and "/hist/fit_" + name + "_" + ax_los in self.save: slope = node.attrs.slope @@ -573,8 +590,11 @@ class Plotter(Aggregator, BaseProcessor): yerr_kind="std", sigma_err=2.0, grid=True, + put_time=False, + time_unit=cst.Myr, colors=None, nml_color=None, + legend=None, **kwargs ): @@ -597,6 +617,16 @@ class Plotter(Aggregator, BaseProcessor): if grid: P.grid() + if put_time: + time = self.save.root._v_attrs.time * self.comp.info["unit_time"] + time_str = self.pp_params.plot.time_fmt.format( + time.express(time_unit), time_unit.latex + ) + if len(label) > 0: + label = label + " | " + time_str + else: + label = time_str + yerr = None if node_y._v_attrs.CLASS == "ARRAY": x = node_x.read() * xunit_old.express(xunit) @@ -607,7 +637,23 @@ class Plotter(Aggregator, BaseProcessor): y = gaussian_filter1d(y, sigma=smooth) if not run is None: label = self._label_run(run, node_y, label, nml_key) - (base_line,) = P.plot(x, y, label=label, **kwargs) + if colors is None: + (base_line,) = P.plot(x, y, label=label, **kwargs) + else: + if nml_color is None: + color = colors[run] + elif nml_color == "time": + time = ( + self.save.root._v_attrs.time * self.comp.info["unit_time"] + ).express(time_unit) + color = colors(time) + else: + nml = self.comp.get_nml(nml_color, run) + try: + color = colors[nml] + except: + color = colors(nml) + (base_line,) = P.plot(x, y, label=label, color=color, **kwargs) elif "mean" in node_y: x = node_x.read() * xunit_old.express(xunit) y = node_y.mean.read() * yunit_old.express(yunit) @@ -664,8 +710,11 @@ class Plotter(Aggregator, BaseProcessor): except: color = colors(nml) (base_line,) = P.plot(x, y, label=label_run, color=color, **kwargs) + if legend is None: + legend = True - P.legend() + if legend: + P.legend() if not fit is None: self._overlay_fit( diff --git a/postprocessor.py b/postprocessor.py index d80fe31..2aa5fce 100644 --- a/postprocessor.py +++ b/postprocessor.py @@ -226,9 +226,7 @@ class PostProcessor(HDF5Container): self.load_cells() return np.sort(np.unique(self.cells["pos"][:, axis])) - def _plane_avg_uniform( - self, getter, axis, unit=cst.none, mass_weighted=True, surf_qty=False - ): + def _plane_avg_uniform(self, getter, axis, unit=cst.none, mass_weighted=False): """ Profile of the average of a quantity (given by getter) perpendicular to an axis WARNING : This version only works on an uniform grid, need of a box version for AMR diff --git a/run_selector.py b/run_selector.py index ab79e89..8df4a03 100644 --- a/run_selector.py +++ b/run_selector.py @@ -44,6 +44,7 @@ class RunSelector: sort_run_by=None, time_min=None, time_max=None, + time=None, ): self.path_in = path_in self.pp_params = pp_params @@ -66,7 +67,7 @@ class RunSelector: in_nums[run] = nums_temp for i, run in enumerate(self.runs): - self.nums[run] = self.get_nums(run, in_nums[run], time_min, time_max) + self.nums[run] = self.get_nums(run, in_nums[run], time_min, time_max, time) def load_namelist(self, run): path_run = self.path_in + "/" + run @@ -147,7 +148,7 @@ class RunSelector: info_file.close() return info - def get_nums(self, run, in_nums=None, time_min=None, time_max=None): + def get_nums(self, run, in_nums=None, time_min=None, time_max=None, time=None): def try_load_info(num): try: self.info[run][num] = self.load_info(run, num) @@ -191,4 +192,10 @@ class RunSelector: nums = filter(lambda n: self.info[run][n]["time"] >= time_min, nums) if not time_max is None: nums = filter(lambda n: self.info[run][n]["time"] <= time_max, nums) + + if not time is None: + times = np.asarray([[self.info[run][n]["time"], n] for n in nums]) + idx = (np.abs(times[:, 0] - time)).argmin() + nums = [int(times[idx, 1])] + return nums diff --git a/units.py b/units.py index 196914a..4ed863c 100644 --- a/units.py +++ b/units.py @@ -68,3 +68,10 @@ cst.ssfr = cst.create_unit( descr="Surfacic SFR", latex="M$_{\odot}$.yr$^{-1}$.pc$^{-2}$", ) + +cst.ssfrG = cst.create_unit( + "Msun.Gyr^-1.pc^-2", + base_unit=1e-9 * cst.Msun / cst.year / cst.pc ** 2, + descr="Surfacic SFR", + latex="M$_{\odot}$.Gyr$^{-1}$.pc$^{-2}$", +)