From 5a8fd2ba93db7278e0f05371b9bc853f847993a4 Mon Sep 17 00:00:00 2001 From: Noe Brucy Date: Mon, 20 Dec 2021 14:45:49 +0100 Subject: [PATCH] [plotter] add filter and scatter options to parts --- plotter.py | 119 ++++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 90 insertions(+), 29 deletions(-) diff --git a/plotter.py b/plotter.py index 4ddd01f..9e483a7 100644 --- a/plotter.py +++ b/plotter.py @@ -29,6 +29,7 @@ if os.environ.get("DISPLAY", "") == "": import datetime import matplotlib.pyplot as plt + try: from moviepy.video.io.ImageSequenceClip import ImageSequenceClip except ModuleNotFoundError: @@ -67,7 +68,6 @@ def not_array_error(err): return str(err)[-len(epy2) :] == epy2 or str(err)[-len(epy3) :] == epy3 - def gethv(map_h, map_v, extent): # Number of selected vectors nh = map_h.shape[0] @@ -86,6 +86,7 @@ def gethv(map_h, map_v, extent): return np.meshgrid(h, v) + def streamplot(ax, map_h, map_v, extent, **kwargs): """ Add an overlay : streamlines @@ -96,7 +97,6 @@ def streamplot(ax, map_h, map_v, extent, **kwargs): def quiver(ax, map_h, map_v, extent, key_v=None, label="", **kwargs): - hh, vv = gethv(map_h, map_v, extent) # plot vector field @@ -119,19 +119,20 @@ def quiver(ax, map_h, map_v, extent, key_v=None, label="", **kwargs): coordinates="figure", ) + def line_integral_convolution(ax, map_h, map_v, extent, **kwargs): """ from Adnan Ali Ahmad """ - lic_res = lic.lic(map_v, map_h,length=20) #compute line integral convolution + lic_res = lic.lic(map_v, map_h, length=20) # compute line integral convolution # Amplify contrast on lic - lim=(.1,.9) - lic_data_clip = np.clip(lic_res,lim[0],lim[1]) + lim = (0.1, 0.9) + lic_data_clip = np.clip(lic_res, lim[0], lim[1]) lic_data_rgba = ScalarMappable(norm=None, cmap="binary").to_rgba(lic_data_clip) - lic_data_clip_rescale = (lic_data_clip-lim[0])/(lim[1]-lim[0]) - lic_data_rgba[...,3] = lic_data_clip_rescale * 1 + lic_data_clip_rescale = (lic_data_clip - lim[0]) / (lim[1] - lim[0]) + lic_data_rgba[..., 3] = lic_data_clip_rescale * 1 args = [lic_data_rgba] plot_args = {**kwargs} @@ -141,7 +142,6 @@ def line_integral_convolution(ax, map_h, map_v, extent, **kwargs): ax.imshow(*args, **plot_args) - class PlotRule(Rule): """ The rule class, speficic to plot. @@ -228,7 +228,12 @@ class Plotter(Aggregator, BaseProcessor): # Select runs if selector is None: self.selector = RunSelector( - path, runs, nums, self.params.input.nml_filename, unit_time=unit_time, **kwargs + path, + runs, + nums, + self.params.input.nml_filename, + unit_time=unit_time, + **kwargs, ) else: self.selector = selector @@ -677,8 +682,6 @@ class Plotter(Aggregator, BaseProcessor): if text_embeded is None: text_embeded = True - - if center_space: center = self.current_processor.get_attribute("/maps", "center") center_h = center[self._ax_nb[ax_h]] @@ -725,17 +728,37 @@ class Plotter(Aggregator, BaseProcessor): frameon=False, ) plt.gca().add_artist(scalebar) - + if axes_indicator: # A liitle drawing saying what are the axes - plt.annotate('', xy=(0.97, 0.1), xycoords='axes fraction', xytext=(0.865, 0.1), - arrowprops={'arrowstyle': '->', "color" : overtext_color}) - plt.annotate('', xy=(0.87, 0.2), xycoords='axes fraction', xytext=(0.87, 0.095), - arrowprops={'arrowstyle': '->', "color" : overtext_color}) - plt.annotate(self._ax_title[ax_h], xy=(0.87, 0.2), xytext=(0.89, 0.05), - color=overtext_color, xycoords='axes fraction') - plt.annotate(self._ax_title[ax_v], xy=(0.87, 0.2), xytext=(0.83, 0.12), - color=overtext_color, xycoords='axes fraction') + plt.annotate( + "", + xy=(0.97, 0.1), + xycoords="axes fraction", + xytext=(0.865, 0.1), + arrowprops={"arrowstyle": "->", "color": overtext_color}, + ) + plt.annotate( + "", + xy=(0.87, 0.2), + xycoords="axes fraction", + xytext=(0.87, 0.095), + arrowprops={"arrowstyle": "->", "color": overtext_color}, + ) + plt.annotate( + self._ax_title[ax_h], + xy=(0.87, 0.2), + xytext=(0.89, 0.05), + color=overtext_color, + xycoords="axes fraction", + ) + plt.annotate( + self._ax_title[ax_v], + xy=(0.87, 0.2), + xytext=(0.83, 0.12), + color=overtext_color, + xycoords="axes fraction", + ) if axes: if xlabel is None: xlabel = self._ax_title[ax_h] @@ -777,7 +800,13 @@ class Plotter(Aggregator, BaseProcessor): if put_title: title = self.snapshot_title(run, title, nml_key, put_time, unit_time) if text_embeded: - ax.text(x=0.05, y=0.95, s=title, color=overtext_color, transform=ax.transAxes) + ax.text( + x=0.05, + y=0.95, + s=title, + color=overtext_color, + transform=ax.transAxes, + ) else: plt.title(title) @@ -900,10 +929,15 @@ class Plotter(Aggregator, BaseProcessor): center_space=False, parts=True, sinks=False, + filter_fun=None, + s=None, + c=None, **kwargs, ): """ Add an overlay with particles data + if both sinks and parts are set to true, only sinks are overlayed + filter_fun : function that take an array like value and returns an array of boolean """ unit_length = self.current_processor.info["unit_length"] @@ -915,16 +949,17 @@ class Plotter(Aggregator, BaseProcessor): self.current_processor.get_value("/datasets/sinks") ) part_pos = sinks[["x", "y", "z"]].values - mass = sinks.msink unit_length /= self.current_processor.lbox + data = sinks except KeyError: self.current_processor._log("No sinks particles", "WARNING") return elif parts: # Open particle HDF5 filetype_from_ext self.current_processor.load_parts(keys=["pos", "mass"]) - part_pos = self.current_processor.parts.pos - mass = self.current_processor.parts.mass + data = self.current_processor.parts + part_pos = self.current_processor.parts["pos"] + mass = self.current_processor.parts["mass"] mass *= self.current_processor.info["unit_mass"].express(U.Msun) self.current_processor.unload_parts() @@ -956,14 +991,39 @@ class Plotter(Aggregator, BaseProcessor): & (im_extent[2] <= part_v) & (part_v <= im_extent[3]) ) + if filter_fun is not None: + mask = mask & filter_fun(data) + part_h = part_h[mask] part_v = part_v[mask] + # Size and color + if s is None and sinks: + s = data.msink[mask] / 5e3 + + if isinstance(s, str): + s = data[s][mask] + elif callable(s): + s = s(data)[mask] + + if isinstance(c, str): + c = data[c][mask] + elif callable(c): + c = c(data)[mask] + # Scatter plot - plt.scatter(part_h, part_v, s=mass[mask] / 5e3, **kwargs) + plt.scatter(part_h, part_v, s=s, c=c, **kwargs) def _overlay_vector( - self, name, ax_los, extent, unit=U.km_s, unit_coeff=1.0, reduce_res=1, kind="quiver", **kwargs + self, + name, + ax_los, + extent, + unit=U.km_s, + unit_coeff=1.0, + reduce_res=1, + kind="quiver", + **kwargs, ): """ Add an overlay : vector field @@ -977,8 +1037,9 @@ class Plotter(Aggregator, BaseProcessor): map_h = self.current_processor.get_value(f"/maps/slice_{name}{ax_h}_{ax_los}") map_v = self.current_processor.get_value(f"/maps/slice_{name}{ax_v}_{ax_los}") - label, unit_old, unit = self._ax_label_unit(f"/maps/slice_{name}{ax_h}_{ax_los}", "", unit, unit_coeff) - + label, unit_old, unit = self._ax_label_unit( + f"/maps/slice_{name}{ax_h}_{ax_los}", "", unit, unit_coeff + ) # take only a subset map_h = map_h[::reduce_res, ::reduce_res] * unit_old.express(unit) @@ -991,7 +1052,7 @@ class Plotter(Aggregator, BaseProcessor): elif kind == "lic": line_integral_convolution(plt.gca(), map_h, map_v, extent=extent, **kwargs) - def _overlay_speed(self, ax_los, extent, **kwargs): + def _overlay_speed(self, ax_los, extent, **kwargs): self._overlay_vector("vel", ax_los, extent, **kwargs) def _overlay_B(self, ax_los, extent, **kwargs):