[plotter] correct the scaling for the particles plot
This commit is contained in:
+34
-5
@@ -604,7 +604,15 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
|
||||
for i, plot_overlay in enumerate(overlays):
|
||||
if plot_overlay in self.overlays:
|
||||
plot_overlay = self.overlays[plot_overlay]
|
||||
|
||||
if plot_overlay == "particles":
|
||||
plot_overlay = partial(
|
||||
self.overlays[plot_overlay],
|
||||
unit_space=unit_space,
|
||||
center_space=center_space,
|
||||
)
|
||||
else:
|
||||
plot_overlay = self.overlays[plot_overlay]
|
||||
|
||||
try:
|
||||
plot_overlay(ax_los, im_extent, **overlays_kwargs[i])
|
||||
@@ -698,7 +706,9 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _overlay_particles(self, ax_los, im_extent, **kwargs):
|
||||
def _overlay_particles(
|
||||
self, ax_los, im_extent, unit_space=U.pc, center_space=False, **kwargs
|
||||
):
|
||||
"""
|
||||
Add an overlay with particles data
|
||||
"""
|
||||
@@ -708,6 +718,7 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
hdf5_parts = tables.open_file(filename, "r")
|
||||
part_pos = hdf5_parts.get_node("/data/pos").read()
|
||||
hdf5_parts.close()
|
||||
unit_length = self.save.root._v_attrs["unit_length"]
|
||||
|
||||
# index of the horizontal axis
|
||||
ih = self._ax_nb[self._axes_h[ax_los]]
|
||||
@@ -718,9 +729,27 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
part_h = part_pos[:, ih]
|
||||
part_v = part_pos[:, iv]
|
||||
|
||||
# Renormalize
|
||||
part_h = im_extent[0] + (im_extent[1] - im_extent[0]) * part_h
|
||||
part_v = im_extent[2] + (im_extent[3] - im_extent[2]) * part_v
|
||||
if center_space:
|
||||
ax_h = self._axes_h[ax_los]
|
||||
ax_v = self._axes_v[ax_los]
|
||||
center = self.save.root.maps._v_attrs.center
|
||||
center_h = center[self._ax_nb[ax_h]]
|
||||
center_v = center[self._ax_nb[ax_v]]
|
||||
part_h -= center_h
|
||||
part_v -= center_v
|
||||
|
||||
part_h *= unit_length.express(unit_space)
|
||||
part_v *= unit_length.express(unit_space)
|
||||
|
||||
# Filter
|
||||
mask = (
|
||||
(im_extent[0] <= part_h)
|
||||
& (part_h <= im_extent[1])
|
||||
& (im_extent[2] <= part_v)
|
||||
& (part_v <= im_extent[3])
|
||||
)
|
||||
part_h = part_h[mask]
|
||||
part_v = part_v[mask]
|
||||
|
||||
# Scatter plot
|
||||
plt.scatter(part_h, part_v, **kwargs)
|
||||
|
||||
Reference in New Issue
Block a user