[plotter] add filter and scatter options to parts
This commit is contained in:
+90
-29
@@ -29,6 +29,7 @@ if os.environ.get("DISPLAY", "") == "":
|
|||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
|
from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
|
||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
@@ -67,7 +68,6 @@ def not_array_error(err):
|
|||||||
return str(err)[-len(epy2) :] == epy2 or str(err)[-len(epy3) :] == epy3
|
return str(err)[-len(epy2) :] == epy2 or str(err)[-len(epy3) :] == epy3
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def gethv(map_h, map_v, extent):
|
def gethv(map_h, map_v, extent):
|
||||||
# Number of selected vectors
|
# Number of selected vectors
|
||||||
nh = map_h.shape[0]
|
nh = map_h.shape[0]
|
||||||
@@ -86,6 +86,7 @@ def gethv(map_h, map_v, extent):
|
|||||||
|
|
||||||
return np.meshgrid(h, v)
|
return np.meshgrid(h, v)
|
||||||
|
|
||||||
|
|
||||||
def streamplot(ax, map_h, map_v, extent, **kwargs):
|
def streamplot(ax, map_h, map_v, extent, **kwargs):
|
||||||
"""
|
"""
|
||||||
Add an overlay : streamlines
|
Add an overlay : streamlines
|
||||||
@@ -96,7 +97,6 @@ def streamplot(ax, map_h, map_v, extent, **kwargs):
|
|||||||
|
|
||||||
def quiver(ax, map_h, map_v, extent, key_v=None, label="", **kwargs):
|
def quiver(ax, map_h, map_v, extent, key_v=None, label="", **kwargs):
|
||||||
|
|
||||||
|
|
||||||
hh, vv = gethv(map_h, map_v, extent)
|
hh, vv = gethv(map_h, map_v, extent)
|
||||||
|
|
||||||
# plot vector field
|
# plot vector field
|
||||||
@@ -119,19 +119,20 @@ def quiver(ax, map_h, map_v, extent, key_v=None, label="", **kwargs):
|
|||||||
coordinates="figure",
|
coordinates="figure",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def line_integral_convolution(ax, map_h, map_v, extent, **kwargs):
|
def line_integral_convolution(ax, map_h, map_v, extent, **kwargs):
|
||||||
"""
|
"""
|
||||||
from Adnan Ali Ahmad
|
from Adnan Ali Ahmad
|
||||||
"""
|
"""
|
||||||
|
|
||||||
lic_res = lic.lic(map_v, map_h,length=20) #compute line integral convolution
|
lic_res = lic.lic(map_v, map_h, length=20) # compute line integral convolution
|
||||||
|
|
||||||
# Amplify contrast on lic
|
# Amplify contrast on lic
|
||||||
lim=(.1,.9)
|
lim = (0.1, 0.9)
|
||||||
lic_data_clip = np.clip(lic_res,lim[0],lim[1])
|
lic_data_clip = np.clip(lic_res, lim[0], lim[1])
|
||||||
lic_data_rgba = ScalarMappable(norm=None, cmap="binary").to_rgba(lic_data_clip)
|
lic_data_rgba = ScalarMappable(norm=None, cmap="binary").to_rgba(lic_data_clip)
|
||||||
lic_data_clip_rescale = (lic_data_clip-lim[0])/(lim[1]-lim[0])
|
lic_data_clip_rescale = (lic_data_clip - lim[0]) / (lim[1] - lim[0])
|
||||||
lic_data_rgba[...,3] = lic_data_clip_rescale * 1
|
lic_data_rgba[..., 3] = lic_data_clip_rescale * 1
|
||||||
|
|
||||||
args = [lic_data_rgba]
|
args = [lic_data_rgba]
|
||||||
plot_args = {**kwargs}
|
plot_args = {**kwargs}
|
||||||
@@ -141,7 +142,6 @@ def line_integral_convolution(ax, map_h, map_v, extent, **kwargs):
|
|||||||
ax.imshow(*args, **plot_args)
|
ax.imshow(*args, **plot_args)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class PlotRule(Rule):
|
class PlotRule(Rule):
|
||||||
"""
|
"""
|
||||||
The rule class, speficic to plot.
|
The rule class, speficic to plot.
|
||||||
@@ -228,7 +228,12 @@ class Plotter(Aggregator, BaseProcessor):
|
|||||||
# Select runs
|
# Select runs
|
||||||
if selector is None:
|
if selector is None:
|
||||||
self.selector = RunSelector(
|
self.selector = RunSelector(
|
||||||
path, runs, nums, self.params.input.nml_filename, unit_time=unit_time, **kwargs
|
path,
|
||||||
|
runs,
|
||||||
|
nums,
|
||||||
|
self.params.input.nml_filename,
|
||||||
|
unit_time=unit_time,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.selector = selector
|
self.selector = selector
|
||||||
@@ -677,8 +682,6 @@ class Plotter(Aggregator, BaseProcessor):
|
|||||||
if text_embeded is None:
|
if text_embeded is None:
|
||||||
text_embeded = True
|
text_embeded = True
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if center_space:
|
if center_space:
|
||||||
center = self.current_processor.get_attribute("/maps", "center")
|
center = self.current_processor.get_attribute("/maps", "center")
|
||||||
center_h = center[self._ax_nb[ax_h]]
|
center_h = center[self._ax_nb[ax_h]]
|
||||||
@@ -725,17 +728,37 @@ class Plotter(Aggregator, BaseProcessor):
|
|||||||
frameon=False,
|
frameon=False,
|
||||||
)
|
)
|
||||||
plt.gca().add_artist(scalebar)
|
plt.gca().add_artist(scalebar)
|
||||||
|
|
||||||
if axes_indicator:
|
if axes_indicator:
|
||||||
# A liitle drawing saying what are the axes
|
# A liitle drawing saying what are the axes
|
||||||
plt.annotate('', xy=(0.97, 0.1), xycoords='axes fraction', xytext=(0.865, 0.1),
|
plt.annotate(
|
||||||
arrowprops={'arrowstyle': '->', "color" : overtext_color})
|
"",
|
||||||
plt.annotate('', xy=(0.87, 0.2), xycoords='axes fraction', xytext=(0.87, 0.095),
|
xy=(0.97, 0.1),
|
||||||
arrowprops={'arrowstyle': '->', "color" : overtext_color})
|
xycoords="axes fraction",
|
||||||
plt.annotate(self._ax_title[ax_h], xy=(0.87, 0.2), xytext=(0.89, 0.05),
|
xytext=(0.865, 0.1),
|
||||||
color=overtext_color, xycoords='axes fraction')
|
arrowprops={"arrowstyle": "->", "color": overtext_color},
|
||||||
plt.annotate(self._ax_title[ax_v], xy=(0.87, 0.2), xytext=(0.83, 0.12),
|
)
|
||||||
color=overtext_color, xycoords='axes fraction')
|
plt.annotate(
|
||||||
|
"",
|
||||||
|
xy=(0.87, 0.2),
|
||||||
|
xycoords="axes fraction",
|
||||||
|
xytext=(0.87, 0.095),
|
||||||
|
arrowprops={"arrowstyle": "->", "color": overtext_color},
|
||||||
|
)
|
||||||
|
plt.annotate(
|
||||||
|
self._ax_title[ax_h],
|
||||||
|
xy=(0.87, 0.2),
|
||||||
|
xytext=(0.89, 0.05),
|
||||||
|
color=overtext_color,
|
||||||
|
xycoords="axes fraction",
|
||||||
|
)
|
||||||
|
plt.annotate(
|
||||||
|
self._ax_title[ax_v],
|
||||||
|
xy=(0.87, 0.2),
|
||||||
|
xytext=(0.83, 0.12),
|
||||||
|
color=overtext_color,
|
||||||
|
xycoords="axes fraction",
|
||||||
|
)
|
||||||
if axes:
|
if axes:
|
||||||
if xlabel is None:
|
if xlabel is None:
|
||||||
xlabel = self._ax_title[ax_h]
|
xlabel = self._ax_title[ax_h]
|
||||||
@@ -777,7 +800,13 @@ class Plotter(Aggregator, BaseProcessor):
|
|||||||
if put_title:
|
if put_title:
|
||||||
title = self.snapshot_title(run, title, nml_key, put_time, unit_time)
|
title = self.snapshot_title(run, title, nml_key, put_time, unit_time)
|
||||||
if text_embeded:
|
if text_embeded:
|
||||||
ax.text(x=0.05, y=0.95, s=title, color=overtext_color, transform=ax.transAxes)
|
ax.text(
|
||||||
|
x=0.05,
|
||||||
|
y=0.95,
|
||||||
|
s=title,
|
||||||
|
color=overtext_color,
|
||||||
|
transform=ax.transAxes,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
plt.title(title)
|
plt.title(title)
|
||||||
|
|
||||||
@@ -900,10 +929,15 @@ class Plotter(Aggregator, BaseProcessor):
|
|||||||
center_space=False,
|
center_space=False,
|
||||||
parts=True,
|
parts=True,
|
||||||
sinks=False,
|
sinks=False,
|
||||||
|
filter_fun=None,
|
||||||
|
s=None,
|
||||||
|
c=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Add an overlay with particles data
|
Add an overlay with particles data
|
||||||
|
if both sinks and parts are set to true, only sinks are overlayed
|
||||||
|
filter_fun : function that take an array like value and returns an array of boolean
|
||||||
"""
|
"""
|
||||||
|
|
||||||
unit_length = self.current_processor.info["unit_length"]
|
unit_length = self.current_processor.info["unit_length"]
|
||||||
@@ -915,16 +949,17 @@ class Plotter(Aggregator, BaseProcessor):
|
|||||||
self.current_processor.get_value("/datasets/sinks")
|
self.current_processor.get_value("/datasets/sinks")
|
||||||
)
|
)
|
||||||
part_pos = sinks[["x", "y", "z"]].values
|
part_pos = sinks[["x", "y", "z"]].values
|
||||||
mass = sinks.msink
|
|
||||||
unit_length /= self.current_processor.lbox
|
unit_length /= self.current_processor.lbox
|
||||||
|
data = sinks
|
||||||
except KeyError:
|
except KeyError:
|
||||||
self.current_processor._log("No sinks particles", "WARNING")
|
self.current_processor._log("No sinks particles", "WARNING")
|
||||||
return
|
return
|
||||||
elif parts:
|
elif parts:
|
||||||
# Open particle HDF5 filetype_from_ext
|
# Open particle HDF5 filetype_from_ext
|
||||||
self.current_processor.load_parts(keys=["pos", "mass"])
|
self.current_processor.load_parts(keys=["pos", "mass"])
|
||||||
part_pos = self.current_processor.parts.pos
|
data = self.current_processor.parts
|
||||||
mass = self.current_processor.parts.mass
|
part_pos = self.current_processor.parts["pos"]
|
||||||
|
mass = self.current_processor.parts["mass"]
|
||||||
mass *= self.current_processor.info["unit_mass"].express(U.Msun)
|
mass *= self.current_processor.info["unit_mass"].express(U.Msun)
|
||||||
self.current_processor.unload_parts()
|
self.current_processor.unload_parts()
|
||||||
|
|
||||||
@@ -956,14 +991,39 @@ class Plotter(Aggregator, BaseProcessor):
|
|||||||
& (im_extent[2] <= part_v)
|
& (im_extent[2] <= part_v)
|
||||||
& (part_v <= im_extent[3])
|
& (part_v <= im_extent[3])
|
||||||
)
|
)
|
||||||
|
if filter_fun is not None:
|
||||||
|
mask = mask & filter_fun(data)
|
||||||
|
|
||||||
part_h = part_h[mask]
|
part_h = part_h[mask]
|
||||||
part_v = part_v[mask]
|
part_v = part_v[mask]
|
||||||
|
|
||||||
|
# Size and color
|
||||||
|
if s is None and sinks:
|
||||||
|
s = data.msink[mask] / 5e3
|
||||||
|
|
||||||
|
if isinstance(s, str):
|
||||||
|
s = data[s][mask]
|
||||||
|
elif callable(s):
|
||||||
|
s = s(data)[mask]
|
||||||
|
|
||||||
|
if isinstance(c, str):
|
||||||
|
c = data[c][mask]
|
||||||
|
elif callable(c):
|
||||||
|
c = c(data)[mask]
|
||||||
|
|
||||||
# Scatter plot
|
# Scatter plot
|
||||||
plt.scatter(part_h, part_v, s=mass[mask] / 5e3, **kwargs)
|
plt.scatter(part_h, part_v, s=s, c=c, **kwargs)
|
||||||
|
|
||||||
def _overlay_vector(
|
def _overlay_vector(
|
||||||
self, name, ax_los, extent, unit=U.km_s, unit_coeff=1.0, reduce_res=1, kind="quiver", **kwargs
|
self,
|
||||||
|
name,
|
||||||
|
ax_los,
|
||||||
|
extent,
|
||||||
|
unit=U.km_s,
|
||||||
|
unit_coeff=1.0,
|
||||||
|
reduce_res=1,
|
||||||
|
kind="quiver",
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Add an overlay : vector field
|
Add an overlay : vector field
|
||||||
@@ -977,8 +1037,9 @@ class Plotter(Aggregator, BaseProcessor):
|
|||||||
|
|
||||||
map_h = self.current_processor.get_value(f"/maps/slice_{name}{ax_h}_{ax_los}")
|
map_h = self.current_processor.get_value(f"/maps/slice_{name}{ax_h}_{ax_los}")
|
||||||
map_v = self.current_processor.get_value(f"/maps/slice_{name}{ax_v}_{ax_los}")
|
map_v = self.current_processor.get_value(f"/maps/slice_{name}{ax_v}_{ax_los}")
|
||||||
label, unit_old, unit = self._ax_label_unit(f"/maps/slice_{name}{ax_h}_{ax_los}", "", unit, unit_coeff)
|
label, unit_old, unit = self._ax_label_unit(
|
||||||
|
f"/maps/slice_{name}{ax_h}_{ax_los}", "", unit, unit_coeff
|
||||||
|
)
|
||||||
|
|
||||||
# take only a subset
|
# take only a subset
|
||||||
map_h = map_h[::reduce_res, ::reduce_res] * unit_old.express(unit)
|
map_h = map_h[::reduce_res, ::reduce_res] * unit_old.express(unit)
|
||||||
@@ -991,7 +1052,7 @@ class Plotter(Aggregator, BaseProcessor):
|
|||||||
elif kind == "lic":
|
elif kind == "lic":
|
||||||
line_integral_convolution(plt.gca(), map_h, map_v, extent=extent, **kwargs)
|
line_integral_convolution(plt.gca(), map_h, map_v, extent=extent, **kwargs)
|
||||||
|
|
||||||
def _overlay_speed(self, ax_los, extent, **kwargs):
|
def _overlay_speed(self, ax_los, extent, **kwargs):
|
||||||
self._overlay_vector("vel", ax_los, extent, **kwargs)
|
self._overlay_vector("vel", ax_los, extent, **kwargs)
|
||||||
|
|
||||||
def _overlay_B(self, ax_los, extent, **kwargs):
|
def _overlay_B(self, ax_los, extent, **kwargs):
|
||||||
|
|||||||
Reference in New Issue
Block a user