diff --git a/plotter.py b/plotter.py index 83af97e..3ab4592 100644 --- a/plotter.py +++ b/plotter.py @@ -604,7 +604,15 @@ class Plotter(Aggregator, BaseProcessor): for i, plot_overlay in enumerate(overlays): if plot_overlay in self.overlays: - plot_overlay = self.overlays[plot_overlay] + + if plot_overlay == "particles": + plot_overlay = partial( + self.overlays[plot_overlay], + unit_space=unit_space, + center_space=center_space, + ) + else: + plot_overlay = self.overlays[plot_overlay] try: plot_overlay(ax_los, im_extent, **overlays_kwargs[i]) @@ -698,7 +706,9 @@ class Plotter(Aggregator, BaseProcessor): **kwargs, ) - def _overlay_particles(self, ax_los, im_extent, **kwargs): + def _overlay_particles( + self, ax_los, im_extent, unit_space=U.pc, center_space=False, **kwargs + ): """ Add an overlay with particles data """ @@ -708,6 +718,7 @@ class Plotter(Aggregator, BaseProcessor): hdf5_parts = tables.open_file(filename, "r") part_pos = hdf5_parts.get_node("/data/pos").read() hdf5_parts.close() + unit_length = self.save.root._v_attrs["unit_length"] # index of the horizontal axis ih = self._ax_nb[self._axes_h[ax_los]] @@ -718,9 +729,27 @@ class Plotter(Aggregator, BaseProcessor): part_h = part_pos[:, ih] part_v = part_pos[:, iv] - # Renormalize - part_h = im_extent[0] + (im_extent[1] - im_extent[0]) * part_h - part_v = im_extent[2] + (im_extent[3] - im_extent[2]) * part_v + if center_space: + ax_h = self._axes_h[ax_los] + ax_v = self._axes_v[ax_los] + center = self.save.root.maps._v_attrs.center + center_h = center[self._ax_nb[ax_h]] + center_v = center[self._ax_nb[ax_v]] + part_h -= center_h + part_v -= center_v + + part_h *= unit_length.express(unit_space) + part_v *= unit_length.express(unit_space) + + # Filter + mask = ( + (im_extent[0] <= part_h) + & (part_h <= im_extent[1]) + & (im_extent[2] <= part_v) + & (part_v <= im_extent[3]) + ) + part_h = part_h[mask] + part_v = part_v[mask] # Scatter plot plt.scatter(part_h, part_v, **kwargs)