diff --git a/galsec.py b/galsec.py index 5e0e538..c3db466 100644 --- a/galsec.py +++ b/galsec.py @@ -7,6 +7,8 @@ import numpy as np from astropy.table import QTable, hstack from astropy import units as u from astropy.units.quantity import Quantity +from scipy.interpolate import griddata +from scipy.fft import fftn def vect_r(position: np.array, vector: np.array) -> np.array: @@ -136,6 +138,66 @@ def aggregate( return binned_data +def get_bouncing_box_mask( + data: QTable, r: Quantity[u.kpc], phi: Quantity[u.rad], size: Quantity[u.kpc] +): + x = r * np.cos(phi) + y = r * np.sin(phi) + norm_inf = np.maximum( + np.abs(data["position"][:, 0] - x), np.abs(data["position"][:, 1] - y) + ) + mask = (norm_inf < size / 2) & (np.abs(data["position"][:, 2]) < size / 2) + return mask + + +def regrid( + position: Quantity, + value: Quantity, + resolution: Quantity[u.pc], +): + min_x, max_x = position[:, 0].min(), position[:, 0].max() + min_y, max_y = position[:, 1].min(), position[:, 1].max() + min_z, max_z = position[:, 2].min(), position[:, 2].max() + size = max([max_x - min_x, max_y - min_y, max_z - min_z]) + + nb_points = int(np.ceil(((size / resolution).to(u.dimensionless_unscaled).value))) + gx, gy, gz = np.mgrid[ + min_x.value : max_x.value : nb_points * 1j, + min_y.value : max_y.value : nb_points * 1j, + min_z.value : max_z.value : nb_points * 1j, + ] + + grid = griddata( + position.value, + value.value, + (gx, gy, gz), + method="nearest", + ) + + import matplotlib.pyplot as plt + + plt.imshow( + grid[:, :, len(grid) // 2].T, + origin="lower", + extent=[min_x.value, max_x.value, min_y.value, max_y.value], + ) + plt.show() + + return ( + grid, + (gx, gy, gz), + ) + +def fft( + position: Quantity, + value: Quantity, + resolution: Quantity[u.pc], + ): + + grid, (gx, gy, gz) = regrid(position, value, resolution) + + fftn(grid, overwrite_x=True) + class Galsec: """ Galactic sector extractor @@ -258,6 +320,8 @@ class Galsec: delta_l: Quantity[u.kpc] = u.kpc, rmin: Quantity[u.kpc] = 1 * u.kpc, rmax: Quantity[u.kpc] = 12 * u.kpc, + zmin: Quantity[u.kpc] = -0.5 * u.kpc, + zmax: Quantity[u.kpc] = 0.5 * u.kpc, ): """Compute the aggration of quantities in sectors bins @@ -274,7 +338,7 @@ class Galsec: """ self.sector_binning(delta_r, delta_l) - grouped_data = {} + self.grouped_data = {} self.sectors = {} for fluid in self.fluids: @@ -283,15 +347,20 @@ class Galsec: else: extensive_fields = ["mass", "ek"] filtered_data = self.data[fluid][ - np.logical_and( - self.data[fluid]["r"] > rmin, self.data[fluid]["r"] < rmax - ) + (self.data[fluid]["r"] > rmin) + & (self.data[fluid]["r"] < rmax) + & (self.data[fluid]["position"][:, 2] > zmin) + & (self.data[fluid]["position"][:, 2] < zmax) ] - grouped_data[fluid] = filtered_data.group_by(["r_bin", "phi_bin"]) + self.grouped_data[fluid] = filtered_data.group_by(["r_bin", "phi_bin"]) self.sectors[fluid] = hstack( [ - grouped_data[fluid]["r_bin", "phi_bin"].groups.aggregate(np.fmin), - aggregate(grouped_data[fluid], extensive_fields=extensive_fields), + self.grouped_data[fluid]["r_bin", "phi_bin"].groups.aggregate( + np.fmin + ), + aggregate( + self.grouped_data[fluid], extensive_fields=extensive_fields + ), ] ) self.sectors[fluid].rename_column("r_bin", "r") @@ -300,7 +369,7 @@ class Galsec: self.sectors["stars"]["sfr"] = ( np.zeros(len(self.sectors["stars"]["mass"])) * u.Msun / u.year ) - for i, group in enumerate(grouped_data["stars"].groups): + for i, group in enumerate(self.grouped_data["stars"].groups): self.sectors["stars"]["sfr"][i] = get_sfr(group, self.time) self.sectors["stars"]["sfr"][i] = get_sfr(group, self.time) @@ -310,6 +379,8 @@ class Galsec: delta_r: Quantity[u.kpc] = u.kpc, rmin: Quantity[u.kpc] = 1 * u.kpc, rmax: Quantity[u.kpc] = 12 * u.kpc, + zmin: Quantity[u.kpc] = -0.5 * u.kpc, + zmax: Quantity[u.kpc] = 0.5 * u.kpc, ): """Compute the aggration of quantities in radial bins @@ -333,9 +404,10 @@ class Galsec: else: extensive_fields = ["mass", "ek"] filtered_data = self.data[fluid][ - np.logical_and( - self.data[fluid]["r"] > rmin, self.data[fluid]["r"] < rmax - ) + (self.data[fluid]["r"] > rmin) + & (self.data[fluid]["r"] < rmax) + & (self.data[fluid]["z"] > zmin) + & (self.data[fluid]["z"] < zmax) ] grouped_data[fluid] = filtered_data.group_by(["r_bin"]) self.rings[fluid] = hstack( diff --git a/galsec_plot.py b/galsec_plot.py index 886c74a..68d12c7 100644 --- a/galsec_plot.py +++ b/galsec_plot.py @@ -50,6 +50,6 @@ def plot_radial(table: QTable): P[0, -1] = P[0, 0] P[1, -1] = P[1, 0] - pc = ax.pcolormesh(P, R, C, cmap="tab20c") #, vmin=6, vmax=9) + pc = ax.pcolormesh(P, R, C, cmap="tab20c") # , vmin=6, vmax=9) print(P, R, C) - #fig.colorbar(pc) + # fig.colorbar(pc)