[plotter] add filter and scatter options to parts

This commit is contained in:
Noe Brucy
2021-12-20 14:45:49 +01:00
parent 1eaa7318a3
commit 5a8fd2ba93
+90 -29
View File
@@ -29,6 +29,7 @@ if os.environ.get("DISPLAY", "") == "":
import datetime import datetime
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
try: try:
from moviepy.video.io.ImageSequenceClip import ImageSequenceClip from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
except ModuleNotFoundError: except ModuleNotFoundError:
@@ -67,7 +68,6 @@ def not_array_error(err):
return str(err)[-len(epy2) :] == epy2 or str(err)[-len(epy3) :] == epy3 return str(err)[-len(epy2) :] == epy2 or str(err)[-len(epy3) :] == epy3
def gethv(map_h, map_v, extent): def gethv(map_h, map_v, extent):
# Number of selected vectors # Number of selected vectors
nh = map_h.shape[0] nh = map_h.shape[0]
@@ -86,6 +86,7 @@ def gethv(map_h, map_v, extent):
return np.meshgrid(h, v) return np.meshgrid(h, v)
def streamplot(ax, map_h, map_v, extent, **kwargs): def streamplot(ax, map_h, map_v, extent, **kwargs):
""" """
Add an overlay : streamlines Add an overlay : streamlines
@@ -96,7 +97,6 @@ def streamplot(ax, map_h, map_v, extent, **kwargs):
def quiver(ax, map_h, map_v, extent, key_v=None, label="", **kwargs): def quiver(ax, map_h, map_v, extent, key_v=None, label="", **kwargs):
hh, vv = gethv(map_h, map_v, extent) hh, vv = gethv(map_h, map_v, extent)
# plot vector field # plot vector field
@@ -119,19 +119,20 @@ def quiver(ax, map_h, map_v, extent, key_v=None, label="", **kwargs):
coordinates="figure", coordinates="figure",
) )
def line_integral_convolution(ax, map_h, map_v, extent, **kwargs): def line_integral_convolution(ax, map_h, map_v, extent, **kwargs):
""" """
from Adnan Ali Ahmad from Adnan Ali Ahmad
""" """
lic_res = lic.lic(map_v, map_h,length=20) #compute line integral convolution lic_res = lic.lic(map_v, map_h, length=20) # compute line integral convolution
# Amplify contrast on lic # Amplify contrast on lic
lim=(.1,.9) lim = (0.1, 0.9)
lic_data_clip = np.clip(lic_res,lim[0],lim[1]) lic_data_clip = np.clip(lic_res, lim[0], lim[1])
lic_data_rgba = ScalarMappable(norm=None, cmap="binary").to_rgba(lic_data_clip) lic_data_rgba = ScalarMappable(norm=None, cmap="binary").to_rgba(lic_data_clip)
lic_data_clip_rescale = (lic_data_clip-lim[0])/(lim[1]-lim[0]) lic_data_clip_rescale = (lic_data_clip - lim[0]) / (lim[1] - lim[0])
lic_data_rgba[...,3] = lic_data_clip_rescale * 1 lic_data_rgba[..., 3] = lic_data_clip_rescale * 1
args = [lic_data_rgba] args = [lic_data_rgba]
plot_args = {**kwargs} plot_args = {**kwargs}
@@ -141,7 +142,6 @@ def line_integral_convolution(ax, map_h, map_v, extent, **kwargs):
ax.imshow(*args, **plot_args) ax.imshow(*args, **plot_args)
class PlotRule(Rule): class PlotRule(Rule):
""" """
The rule class, speficic to plot. The rule class, speficic to plot.
@@ -228,7 +228,12 @@ class Plotter(Aggregator, BaseProcessor):
# Select runs # Select runs
if selector is None: if selector is None:
self.selector = RunSelector( self.selector = RunSelector(
path, runs, nums, self.params.input.nml_filename, unit_time=unit_time, **kwargs path,
runs,
nums,
self.params.input.nml_filename,
unit_time=unit_time,
**kwargs,
) )
else: else:
self.selector = selector self.selector = selector
@@ -677,8 +682,6 @@ class Plotter(Aggregator, BaseProcessor):
if text_embeded is None: if text_embeded is None:
text_embeded = True text_embeded = True
if center_space: if center_space:
center = self.current_processor.get_attribute("/maps", "center") center = self.current_processor.get_attribute("/maps", "center")
center_h = center[self._ax_nb[ax_h]] center_h = center[self._ax_nb[ax_h]]
@@ -725,17 +728,37 @@ class Plotter(Aggregator, BaseProcessor):
frameon=False, frameon=False,
) )
plt.gca().add_artist(scalebar) plt.gca().add_artist(scalebar)
if axes_indicator: if axes_indicator:
# A liitle drawing saying what are the axes # A liitle drawing saying what are the axes
plt.annotate('', xy=(0.97, 0.1), xycoords='axes fraction', xytext=(0.865, 0.1), plt.annotate(
arrowprops={'arrowstyle': '->', "color" : overtext_color}) "",
plt.annotate('', xy=(0.87, 0.2), xycoords='axes fraction', xytext=(0.87, 0.095), xy=(0.97, 0.1),
arrowprops={'arrowstyle': '->', "color" : overtext_color}) xycoords="axes fraction",
plt.annotate(self._ax_title[ax_h], xy=(0.87, 0.2), xytext=(0.89, 0.05), xytext=(0.865, 0.1),
color=overtext_color, xycoords='axes fraction') arrowprops={"arrowstyle": "->", "color": overtext_color},
plt.annotate(self._ax_title[ax_v], xy=(0.87, 0.2), xytext=(0.83, 0.12), )
color=overtext_color, xycoords='axes fraction') plt.annotate(
"",
xy=(0.87, 0.2),
xycoords="axes fraction",
xytext=(0.87, 0.095),
arrowprops={"arrowstyle": "->", "color": overtext_color},
)
plt.annotate(
self._ax_title[ax_h],
xy=(0.87, 0.2),
xytext=(0.89, 0.05),
color=overtext_color,
xycoords="axes fraction",
)
plt.annotate(
self._ax_title[ax_v],
xy=(0.87, 0.2),
xytext=(0.83, 0.12),
color=overtext_color,
xycoords="axes fraction",
)
if axes: if axes:
if xlabel is None: if xlabel is None:
xlabel = self._ax_title[ax_h] xlabel = self._ax_title[ax_h]
@@ -777,7 +800,13 @@ class Plotter(Aggregator, BaseProcessor):
if put_title: if put_title:
title = self.snapshot_title(run, title, nml_key, put_time, unit_time) title = self.snapshot_title(run, title, nml_key, put_time, unit_time)
if text_embeded: if text_embeded:
ax.text(x=0.05, y=0.95, s=title, color=overtext_color, transform=ax.transAxes) ax.text(
x=0.05,
y=0.95,
s=title,
color=overtext_color,
transform=ax.transAxes,
)
else: else:
plt.title(title) plt.title(title)
@@ -900,10 +929,15 @@ class Plotter(Aggregator, BaseProcessor):
center_space=False, center_space=False,
parts=True, parts=True,
sinks=False, sinks=False,
filter_fun=None,
s=None,
c=None,
**kwargs, **kwargs,
): ):
""" """
Add an overlay with particles data Add an overlay with particles data
if both sinks and parts are set to true, only sinks are overlayed
filter_fun : function that take an array like value and returns an array of boolean
""" """
unit_length = self.current_processor.info["unit_length"] unit_length = self.current_processor.info["unit_length"]
@@ -915,16 +949,17 @@ class Plotter(Aggregator, BaseProcessor):
self.current_processor.get_value("/datasets/sinks") self.current_processor.get_value("/datasets/sinks")
) )
part_pos = sinks[["x", "y", "z"]].values part_pos = sinks[["x", "y", "z"]].values
mass = sinks.msink
unit_length /= self.current_processor.lbox unit_length /= self.current_processor.lbox
data = sinks
except KeyError: except KeyError:
self.current_processor._log("No sinks particles", "WARNING") self.current_processor._log("No sinks particles", "WARNING")
return return
elif parts: elif parts:
# Open particle HDF5 filetype_from_ext # Open particle HDF5 filetype_from_ext
self.current_processor.load_parts(keys=["pos", "mass"]) self.current_processor.load_parts(keys=["pos", "mass"])
part_pos = self.current_processor.parts.pos data = self.current_processor.parts
mass = self.current_processor.parts.mass part_pos = self.current_processor.parts["pos"]
mass = self.current_processor.parts["mass"]
mass *= self.current_processor.info["unit_mass"].express(U.Msun) mass *= self.current_processor.info["unit_mass"].express(U.Msun)
self.current_processor.unload_parts() self.current_processor.unload_parts()
@@ -956,14 +991,39 @@ class Plotter(Aggregator, BaseProcessor):
& (im_extent[2] <= part_v) & (im_extent[2] <= part_v)
& (part_v <= im_extent[3]) & (part_v <= im_extent[3])
) )
if filter_fun is not None:
mask = mask & filter_fun(data)
part_h = part_h[mask] part_h = part_h[mask]
part_v = part_v[mask] part_v = part_v[mask]
# Size and color
if s is None and sinks:
s = data.msink[mask] / 5e3
if isinstance(s, str):
s = data[s][mask]
elif callable(s):
s = s(data)[mask]
if isinstance(c, str):
c = data[c][mask]
elif callable(c):
c = c(data)[mask]
# Scatter plot # Scatter plot
plt.scatter(part_h, part_v, s=mass[mask] / 5e3, **kwargs) plt.scatter(part_h, part_v, s=s, c=c, **kwargs)
def _overlay_vector( def _overlay_vector(
self, name, ax_los, extent, unit=U.km_s, unit_coeff=1.0, reduce_res=1, kind="quiver", **kwargs self,
name,
ax_los,
extent,
unit=U.km_s,
unit_coeff=1.0,
reduce_res=1,
kind="quiver",
**kwargs,
): ):
""" """
Add an overlay : vector field Add an overlay : vector field
@@ -977,8 +1037,9 @@ class Plotter(Aggregator, BaseProcessor):
map_h = self.current_processor.get_value(f"/maps/slice_{name}{ax_h}_{ax_los}") map_h = self.current_processor.get_value(f"/maps/slice_{name}{ax_h}_{ax_los}")
map_v = self.current_processor.get_value(f"/maps/slice_{name}{ax_v}_{ax_los}") map_v = self.current_processor.get_value(f"/maps/slice_{name}{ax_v}_{ax_los}")
label, unit_old, unit = self._ax_label_unit(f"/maps/slice_{name}{ax_h}_{ax_los}", "", unit, unit_coeff) label, unit_old, unit = self._ax_label_unit(
f"/maps/slice_{name}{ax_h}_{ax_los}", "", unit, unit_coeff
)
# take only a subset # take only a subset
map_h = map_h[::reduce_res, ::reduce_res] * unit_old.express(unit) map_h = map_h[::reduce_res, ::reduce_res] * unit_old.express(unit)
@@ -991,7 +1052,7 @@ class Plotter(Aggregator, BaseProcessor):
elif kind == "lic": elif kind == "lic":
line_integral_convolution(plt.gca(), map_h, map_v, extent=extent, **kwargs) line_integral_convolution(plt.gca(), map_h, map_v, extent=extent, **kwargs)
def _overlay_speed(self, ax_los, extent, **kwargs): def _overlay_speed(self, ax_los, extent, **kwargs):
self._overlay_vector("vel", ax_los, extent, **kwargs) self._overlay_vector("vel", ax_los, extent, **kwargs)
def _overlay_B(self, ax_los, extent, **kwargs): def _overlay_B(self, ax_los, extent, **kwargs):