diff --git a/plotter.py b/plotter.py index 460ea3f..41eda37 100644 --- a/plotter.py +++ b/plotter.py @@ -240,7 +240,7 @@ class Plotter(Aggregator, BaseProcessor): """ # log info - self.log_id = "plot {}".format(tag) + self.log_id = "plotter({})".format(tag) super(Plotter, self).__init__(path, path_out, params, tag) @@ -409,12 +409,12 @@ class Plotter(Aggregator, BaseProcessor): elif rule.kind == "comp": run_num = [(None, None)] if movie: - self._log(f"No movie possible for rule {name}", "WARNING") + self.logger.warning(f"No movie possible for rule {name}") movie = False else: run_num = [(run, None) for run in runs] if movie: - self._log(f"No movie possible for rule {name}", "WARNING") + self.logger.warning(f"No movie possible for rule {name}") movie = False onefigure = False # If axes are provided, only save/close once @@ -499,20 +499,18 @@ class Plotter(Aggregator, BaseProcessor): if self.params.plot.tight_layout and close: plt.tight_layout(pad=1) - if self.params.out.save: + if self.params.out.save: os.makedirs(os.path.dirname(plot_filename), exist_ok=True) plt.savefig(plot_filename) - self._log("{} plotted".format(plot_filename), "SUCCESS") + self.logger.info(f"{plot_filename} plotted") else: - self._log( - "{} plotted".format(os.path.basename(plot_filename)), "SUCCESS" - ) + self.logger.info(f"{os.path.basename(plot_filename)} plotted") if not self.params.out.interactive and close: plt.close() return plot_info else: - self._log("Plot {} is already done, skipping...".format(plot_filename)) + self.logger.info(f"Plot {plot_filename} is already done.") def _find_filename(self, name_full, run=None, num=None, fmt=None, ext=None): """ @@ -591,7 +589,7 @@ class Plotter(Aggregator, BaseProcessor): label = run return label - if nml_key is None and label is None: + if nml_key is None and (label is None or len(label) == 0): label_run = get_label_file(run) elif nml_key is not None: if not type(nml_key) == list: @@ -600,7 +598,7 @@ class Plotter(Aggregator, BaseProcessor): lbl_list = filter(lambda x: len(x) > 0, lbl_list) # Remove void labels label_run = ", ".join(lbl_list) - if label is not None: + if label is not None and len(label) > 0: label_run = label + " (" + label_run + ")" else: label_run = label @@ -757,7 +755,7 @@ class Plotter(Aggregator, BaseProcessor): plt.gca().add_artist(scalebar) if axes_indicator: - # A liitle drawing saying what are the axes + # A little drawing saying what are the axes plt.annotate( "", xy=(0.97, 0.1), @@ -979,7 +977,7 @@ class Plotter(Aggregator, BaseProcessor): part_pos = data[["x", "y", "z"]].values unit_length /= self.current_processor.lbox except KeyError: - self.current_processor._log("No sinks particles", "WARNING") + self.current_processor.logger.warning("No sinks particles") return elif parts: # Open particle HDF5 filetype_from_ext @@ -1219,14 +1217,115 @@ class Plotter(Aggregator, BaseProcessor): plot_title=title, ) + def plot( + self, + x:np.array, + y:np.array, + xlabel:str="", + ylabel:str="", + label:str="", + xscale:str="linear", + yscale:str="linear", + fit:str=None, + fitlabel:str=None, + smooth:float=0, + nml_key=None, + run:str=None, + yerr:np.array=None, + grid:bool=False, + put_time:bool=False, + unit_time=U.Myr, + colors=None, + nml_color=None, + legend:bool=False, + **kwargs, + ): + """ + Generic plot routine, with x, y two numpy arrauys + """ + + + # Option to smooth data for readability (beware) + if smooth > 0: + y = gaussian_filter1d(y, sigma=smooth) + + # Special label if the plot apply to a given run + if run is not None: + label = self.get_label_run(run, label, nml_key) + + # If relevant, get time + if put_time: + time = self.current_processor.time * self.study.info["unit_time"] + time_str = self.params.plot.time_fmt.format( + time.express(unit_time), unit_time.latex.replace("text", "math") + ) + time_str = f"${time_str}$" + if len(label) > 0: + label = label + " | " + time_str + else: + label = time_str + + # Look if special colors method is used + if colors is None: + if yerr is None: + (base_line,) = plt.plot(x, y, label=label, **kwargs) + else: + base_line, _, _ = plt.errorbar(x, y, yerr=yerr, label=label, **kwargs) + else: + if nml_color is None: + color = colors[run] + elif nml_color == "time": + time = ( + self.current_processor.time + * self.current_processor.info["unit_time"] + ).express(unit_time) + color = colors(time) + else: + nml_value = self.study.get_nml(nml_color, run) + if os.path.basename(nml_color) in self.value_convert: + nml_value = self.value_convert[os.path.basename(nml_color)](nml_value) + try: + color = colors[nml_value] + except TypeError: + color = colors(nml_value) + if yerr is None: + (base_line,) = plt.plot(x, y, label=label, color=color, **kwargs) + else: + base_line, _, _ = plt.errorbar( + x, y, yerr=yerr, color=color, label=label, **kwargs + ) + + # Ax decorations + plt.xlabel(xlabel) + plt.ylabel(ylabel) + if grid: + plt.grid() + if legend: + plt.legend() + + # Ax scale + plt.xscale(xscale) + plt.yscale(yscale) + + if fit is not None: + self._overlay_fit( + x, + y, + yerr, + kind=fit, + ls="--", + lw=1.5, + color=base_line.get_color(), + label=fitlabel, + ) + def _plot( self, - name_x, - name_y, + name_x:str, + name_y:str, node_arg=None, xlabel=None, ylabel=None, - label=None, xunit=None, yunit=None, put_units=True, @@ -1234,22 +1333,10 @@ class Plotter(Aggregator, BaseProcessor): yunit_coeff=1.0, xtransform=None, ytransform=None, - xscale="linear", - yscale="linear", - fit=None, - fitlabel=None, - smooth=0, - nml_key=None, run=None, yerr=None, yerr_kind="std", sigma_err=2.0, - grid=False, - put_time=False, - unit_time=U.Myr, - colors=None, - nml_color=None, - legend=None, subname_x=None, subname_y=None, **kwargs, @@ -1282,18 +1369,6 @@ class Plotter(Aggregator, BaseProcessor): name_y, ylabel, yunit, yunit_coeff, put_units=put_units, ) - # If relevent, get time - if put_time: - time = self.current_processor.time * self.study.info["unit_time"] - time_str = self.params.plot.time_fmt.format( - time.express(unit_time), unit_time.latex.replace("text", "math") - ) - time_str = f"${time_str}$" - if label is not None and len(label) > 0: - label = label + " | " + time_str - else: - label = time_str - # Manage the different forms in which the data may be stored : # Possibilities are : plain array, dict of arrays (mean, std, ..) or dict of array (runs) if isinstance(node_y, np.ndarray): @@ -1347,67 +1422,12 @@ class Plotter(Aggregator, BaseProcessor): if ytransform is not None: y = ytransform(y) if yerr is not None: - self._log( - "Errorbar may be meaningless when ytransform is used", "WARNING" - ) - if smooth > 0: - y = gaussian_filter1d(y, sigma=smooth) - if run is not None: - label = self.get_label_run(run, label, nml_key) - - # Look if special colors method is used - if colors is None: - if yerr is None: - (base_line,) = plt.plot(x, y, label=label, **kwargs) - else: - base_line, _, _ = plt.errorbar(x, y, yerr=yerr, label=label, **kwargs) - else: - if nml_color is None: - color = colors[run] - elif nml_color == "time": - time = ( - self.current_processor.time - * self.current_processor.info["unit_time"] - ).express(unit_time) - color = colors(time) - else: - nml_value = self.study.get_nml(nml_color, run) - if os.path.basename(nml_color) in self.value_convert: - nml_value = self.value_convert[os.path.basename(nml_color)](nml_value) - try: - color = colors[nml_value] - except TypeError: - color = colors(nml_value) - if yerr is None or yerr_kind is None: - (base_line,) = plt.plot(x, y, label=label, color=color, **kwargs) - else: - base_line, _, _ = plt.errorbar( - x, y, yerr=yerr, color=color, label=label, **kwargs + self.logger.warning( + "Errorbar may be meaningless when ytransform is used" ) - # Ax decorations - plt.xlabel(xlabel) - plt.ylabel(ylabel) - if grid: - plt.grid() - if legend: - plt.legend() - - # Ax scale - plt.xscale(xscale) - plt.yscale(yscale) - - if fit is not None: - self._overlay_fit( - x, - y, - yerr, - kind=fit, - ls="--", - lw=1.5, - color=base_line.get_color(), - label=fitlabel, - ) + self.plot(x, y, yerr=yerr, xlabel=xlabel, + ylabel=ylabel, run=run, **kwargs) if subname_x: hdf5_x.close() @@ -1507,7 +1527,7 @@ class Plotter(Aggregator, BaseProcessor): This is where rules are defined """ self.rules = { - "plot": PlotRule(lambda arg, **kwargs: self._plot(*arg, **kwargs), kind="comp" + "plot_arrays": PlotRule(lambda arg, **kwargs: self._plot(*arg, **kwargs), kind="comp" ), "plot_run": PlotRule(lambda arg, **kwargs: self._plot(*arg, **kwargs), kind="run" ),