from functools import partial

import matplotlib as mpl
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.lines import Line2D
from matplotlib.path import Path
from matplotlib.widgets import (
    Button,
    CheckButtons,
    LassoSelector,
    PolygonSelector,
    Slider,
    SpanSelector,
)
from scipy.stats import linregress
from skimage.draw import line

from snapshotprocessor import SnapshotProcessor
from params import default_params


def _safe_linregress(x, y):
    """Return ``linregress(x, y)``, or five NaNs if it raises ValueError.

    linregress raises ValueError on degenerate input (e.g. empty arrays),
    which happens here when a mask or fit window selects too few cells.
    A warning is printed instead so the GUI keeps running.
    """
    try:
        return linregress(x, y)
    except ValueError as e:
        print("Warning in linregress : {}".format(e))
        return (np.nan, np.nan, np.nan, np.nan, np.nan)


def _make_span_selector(ax, onselect):
    """Create a horizontal, blitted SpanSelector whose span stays visible.

    ``span_stays=True`` was deprecated in Matplotlib 3.5 in favour of
    ``interactive=True`` and removed in 3.7, so the new keyword is tried
    first and the old one is used as a fallback on older Matplotlib.
    """
    try:
        return SpanSelector(
            ax, onselect, "horizontal", useblit=True, interactive=True
        )
    except TypeError:
        return SpanSelector(
            ax, onselect, "horizontal", useblit=True, span_stays=True
        )


class DraggablePoint:
    """A draggable ellipse marking one endpoint of a :class:`DraggableLine`.

    The first point added to a parent is green; the second one is red and
    also owns ``self.line``, the ``Line2D`` joining both endpoints. While a
    point is dragged it is blitted over a cached background; on release the
    whole figure is redrawn and ``parent.draw()`` is called.

    Adapted from
    http://stackoverflow.com/questions/21654008/matplotlib-drag-overlapping-points-interactively
    """

    # The point currently being dragged; only one can be animated at a time.
    lock = None

    def __init__(self, parent, x=0.1, y=0.1, size=0.1):
        """
        Parameters
        ----------
        parent : object
            Exposes ``ax``, ``list_points`` and ``draw()`` (a DraggableLine).
        x, y : float
            Initial centre of the point, in data coordinates.
        size : float
            Width of the ellipse; its height is ``3 * size``.
        """
        self.parent = parent
        # The parent appends each point after constructing it, so a non-empty
        # list means this is the second endpoint.
        is_second = bool(self.parent.list_points)
        color = "r" if is_second else "g"
        self.point = patches.Ellipse(
            (x, y), size, size * 3, fc=color, alpha=0.5, edgecolor=color
        )
        self.x = x
        self.y = y
        self.parent.ax.add_patch(self.point)
        self.press = None
        self.background = None
        self.connect()
        if is_second:
            first = self.parent.list_points[0]
            self.line = Line2D(
                [first.x, self.x], [first.y, self.y], color="k", alpha=0.5
            )
            self.parent.ax.add_line(self.line)

    def _segment(self):
        """Return the Line2D joining both endpoints (owned by the second one)."""
        return self.parent.list_points[1].line

    def _update_segment(self):
        """Move the joining line onto the current endpoint positions."""
        points = self.parent.list_points
        points[1].line.set_data(
            [points[0].x, points[1].x], [points[0].y, points[1].y]
        )

    def connect(self):
        """Connect press/release/motion callbacks on the figure canvas."""
        # Keep the canvas: after ``self.point.remove()`` the patch no longer
        # knows its figure, but disconnect() still needs the canvas.
        self.canvas = self.point.figure.canvas
        self.cidpress = self.canvas.mpl_connect("button_press_event", self.on_press)
        self.cidrelease = self.canvas.mpl_connect(
            "button_release_event", self.on_release
        )
        self.cidmotion = self.canvas.mpl_connect(
            "motion_notify_event", self.on_motion
        )

    def on_press(self, event):
        """Start dragging if the click hits this point and no drag is active."""
        if event.inaxes != self.point.axes:
            return
        if DraggablePoint.lock is not None:
            return
        contains, _ = self.point.contains(event)
        if not contains:
            return
        self.press = (self.point.center, event.xdata, event.ydata)
        DraggablePoint.lock = self

        axes = self.point.axes
        segment = self._segment()
        # Draw everything but the animated point/segment and cache the pixels.
        self.point.set_animated(True)
        segment.set_animated(True)
        self.canvas.draw()
        self.background = self.canvas.copy_from_bbox(axes.bbox)
        # Redraw the animated artists on top and blit just the axes area.
        axes.draw_artist(self.point)
        axes.draw_artist(segment)
        self.canvas.blit(axes.bbox)

    def on_motion(self, event):
        """Move the dragged point by the mouse offset since the press."""
        if DraggablePoint.lock is not self:
            return
        if event.inaxes != self.point.axes:
            return
        (cx, cy), xpress, ypress = self.press
        self.point.center = (cx + event.xdata - xpress, cy + event.ydata - ypress)
        self.x, self.y = self.point.center
        # Update the segment *before* drawing it so it does not lag a frame.
        self._update_segment()

        axes = self.point.axes
        self.canvas.restore_region(self.background)
        axes.draw_artist(self.point)
        axes.draw_artist(self._segment())
        self.canvas.blit(axes.bbox)

    def on_release(self, event):
        """Finish the drag, redraw the full figure and notify the parent."""
        if DraggablePoint.lock is not self:
            return
        self.press = None
        DraggablePoint.lock = None
        self.point.set_animated(False)
        self._segment().set_animated(False)
        self.background = None
        self.canvas.draw()
        self.x, self.y = self.point.center
        self.parent.draw()

    def disconnect(self):
        """Disconnect all the stored callback ids from the canvas."""
        for cid in (self.cidpress, self.cidrelease, self.cidmotion):
            self.canvas.mpl_disconnect(cid)


class DraggableLine:
    """A segment between two :class:`DraggablePoint` endpoints drawn on ``ax``.

    ``parent`` must provide ``draw(update_map=...)``; it is called with
    ``update_map=False`` whenever an endpoint is released.
    """

    def __init__(self, parent, ax, xy1, xy2, size):
        self.parent = parent
        self.ax = ax
        self.list_points = []
        # Append one at a time: DraggablePoint inspects list_points to know
        # whether it is the first (green) or second (red, owns the line) end.
        self.list_points.append(DraggablePoint(self, xy1[0], xy1[1], size))
        self.list_points.append(DraggablePoint(self, xy2[0], xy2[1], size))

    def draw(self):
        """Forward an endpoint release to the parent without a map update."""
        self.parent.draw(update_map=False)

    def slope(self):
        """Return dy/dx between the two endpoints, or NaN if the line is vertical."""
        first, second = self.list_points[0], self.list_points[1]
        dx = second.x - first.x
        if dx == 0:
            return np.nan
        return (second.y - first.y) / dx

    def clear(self):
        """Remove both endpoints and the segment from the axes.

        Callbacks are disconnected first; otherwise the removed points would
        stay registered on the canvas and keep receiving mouse events.
        """
        for p in self.list_points:
            p.disconnect()
            p.point.remove()
        self.list_points[1].line.remove()


class InteractiveGUI:
    """
    This is a matplotlib interactive session to restrain analysis to a
    specific area.

    Four panels are shown: the fluctuation map, a Gamma 2D histogram, the PDF
    of the masked map values and a density profile along a draggable line.
    Radius sliders, lasso/polygon selectors and the check buttons change the
    masks that decide which cells enter the Gamma histogram and the PDF.
    """

    def update_rmin(self, val):
        """Slider callback: set the inner radius of the mask and reset it."""
        self.rmin = self.srmin.val
        self.clicked_reset(None)

    def update_rmax(self, val):
        """Slider callback: set the outer radius of the mask and reset it."""
        self.rmax = self.srmax.val
        self.clicked_reset(None)

    def draw(self, first=False, update_map=True):
        """Redraw the four panels.

        Parameters
        ----------
        first : bool
            Create the artists (initial call) instead of updating them.
        update_map : bool
            Recompute the masked map and the PDF and, when "Fit Gamma" is
            checked, refit the Gamma line. ``False`` is used when only a
            draggable line moved.
        """
        if first or update_map:
            self.fmap = np.copy(self.datamap)
            self.frho = np.copy(self.frho_map)
            self.fcs = np.copy(self.fcs_map)
            self.fmap[np.logical_not(self.mask)] = np.nan
        self._draw_map(first)
        self._draw_gamma(first, update_map)
        if first or update_map:
            self._draw_pdf(first)
        self._draw_profile(first)
        plt.draw()
        self.unselect()

    def _draw_map(self, first):
        """Show the masked fluctuation map (cells outside the mask are NaN)."""
        plt.sca(self.ax_fluct)
        if first:
            # Matplotlib >= 3.5 rejects vmin/vmax passed alongside a norm,
            # so the limits go into the LogNorm itself.
            self.im = plt.imshow(
                self.fmap,
                origin="lower",
                cmap="RdBu_r",
                extent=self.im_extent,
                norm=mpl.colors.LogNorm(vmin=1e-2, vmax=1e2),
            )
            plt.title("Fluctuations")
        else:
            self.im.set_data(self.fmap)

    def _draw_gamma(self, first, update_map):
        """Redraw the Gamma histogram and its title.

        In 3D mode the histogram is log(P) vs log(rho) of the masked cells and
        Gamma is the slope; otherwise it is log(c_s) vs log(rho) fluctuations
        of the masked pixels and Gamma = 2 * slope + 1.
        """
        plt.sca(self.ax_gamma)
        status = self.gamma_3d_button.get_status()
        if status[0]:
            rho, cs = self.rho3d[self.mask3d], self.P3d[self.mask3d]
            plt.xlabel(r"$\log(\rho)$")
            plt.ylabel(r"$\log(P)$")
            self.get_gamma = lambda a: a
        else:
            rho = self.frho_map[self.mask]
            cs = self.fcs_map[self.mask]
            plt.xlabel(r"$\log(\rho / \bar{\rho})$")
            plt.ylabel(r"$\log(c_s / \bar{c_s})$")
            self.get_gamma = lambda a: 2 * a + 1

        if not first:
            self.gamma_hist.remove()
        _, _, _, self.gamma_hist = plt.hist2d(
            rho,
            cs,
            bins=100,
            norm=mpl.colors.LogNorm(),
            cmap=plt.get_cmap("plasma"),
        )

        if update_map and status[1]:
            a = self._fit_gamma_line(rho, cs)
        else:
            a = self.line_gamma.slope()
        self.ax_gamma.set_title(r"$\Gamma$ = {:.3g}".format(self.get_gamma(a)))

    def _fit_gamma_line(self, rho, cs):
        """Fit cs vs rho on cells with rho > 2 and move the Gamma line onto it.

        The line is only replaced when the fit succeeded (finite slope and
        intercept); the fitted slope is returned either way.
        """
        dense = rho > 2
        lr = _safe_linregress(rho[dense], cs[dense])
        a, b = lr[0], lr[1]
        if np.isfinite(a) and np.isfinite(b):
            self.line_gamma.clear()
            rhomin, rhomax = np.min(rho), np.max(rho)
            self.line_gamma = DraggableLine(
                self,
                self.ax_gamma,
                [rhomin, a * rhomin + b],
                [rhomax, a * rhomax + b],
                0.05,
            )
        print("Gamma linregress : {}".format(lr))
        return a

    def _draw_pdf(self, first):
        """Histogram log10 of the masked map and fit a power law to it.

        Bin edges and the normalisation (number of cells in the initial mask)
        are fixed on the first call so later PDFs are directly comparable.
        """
        plt.sca(self.ax_pdf)
        nb_cells = np.sum(self.mask_flat)
        log_values = np.log10(self.fmap.flatten()[self.mask_flat])
        if first:
            self.std_nb_cells = nb_cells
            bins = self.pp.params.pdf.nb_bin
        else:
            bins = self.edges
        values, edges = np.histogram(
            log_values, bins, weights=np.ones(nb_cells) / self.std_nb_cells
        )
        if first:
            self.edges = edges
            plt.xlabel(r"$\log(\Sigma / \bar{\Sigma})$")
            plt.ylabel(r"$\mathcal{P}_\Sigma$")

        centers = 0.5 * (edges[1:] + edges[:-1])
        mask_fit = (
            (centers > self.pp.params.pdf.xmin_fit)
            & (centers < self.pp.params.pdf.xmax_fit)
            & (values > 0)
        )
        a, b, r, _, _ = _safe_linregress(
            centers[mask_fit], np.log10(values[mask_fit])
        )
        fit_label = r"a = {:.3g}, $R^2$ = {:.3g}".format(a, r ** 2)
        if first:
            # Faint copy of the initial PDF kept as a reference.
            plt.step(centers, values, where="mid", alpha=0.3)
            (self.step,) = plt.step(centers, values, where="mid")
            (self.fit,) = plt.plot(
                centers,
                10 ** (a * centers + b),
                "--",
                color="navy",
                label=fit_label,
            )
            plt.yscale("log")
            plt.title("PDF")
        else:
            self.step.set_ydata(values)
            self.fit.set_ydata(10 ** (a * centers + b))
            self.fit.set_label(fit_label)
        plt.legend()

    def _draw_profile(self, first):
        """Plot log(rho) along the profile line and under its first endpoint.

        The in-plane profile samples ``rho_map`` on the pixels crossed by the
        segment and fits ``a * log10(r) + b`` inside the span chosen with the
        SpanSelector. The vertical profile gathers the 3D cells whose (x, y)
        footprint contains the first (green) endpoint, sorted by z.
        """
        plt.sca(self.ax_profile)
        p0, p1 = self.line_profile.list_points[0], self.line_profile.list_points[1]
        xa, ya, xb, yb = p0.x, p0.y, p1.x, p1.y

        # Data -> pixel mapping. Assumes the map spans [-0.5, 0.5] / zoom on a
        # square grid (TODO confirm for non-square maps).
        zoom = self.pp.params.pymses.zoom
        npix = self.shape[0]
        coor_pix = list(
            np.asarray((np.array([xa, ya, xb, yb]) * zoom + 0.5) * npix, dtype=int)
        )
        xpp, ypp = line(*coor_pix)
        # Keep only pixels inside the map: an endpoint dragged past the edge
        # would otherwise raise IndexError or wrap via negative indices.
        ny, nx = self.rho_map.shape
        inside = (xpp >= 0) & (xpp < nx) & (ypp >= 0) & (ypp < ny)
        xpp, ypp = xpp[inside], ypp[inside]
        rho_prof = self.rho_map[ypp, xpp]
        # Pixel -> data mapping: exact inverse of the forward mapping above.
        xp = (xpp / float(npix) - 0.5) / zoom
        yp = (ypp / float(npix) - 0.5) / zoom

        # Distance along the segment = scalar projection of (p - a) on (b - a).
        seg_len = np.hypot(xb - xa, yb - ya)
        if seg_len > 0:
            x = np.abs((xp - xa) * (xb - xa) + (yp - ya) * (yb - ya)) / seg_len
        else:
            x = np.zeros(xp.shape)
        if x.size:
            # Measure from the first sampled pixel.
            x = x - float(x[0])

        mask_fit_prof = (x >= self.fit_prof_vmin) & (x <= self.fit_prof_vmax)
        x_fit = x[mask_fit_prof]
        a, b, r, _, _ = _safe_linregress(np.log10(x_fit), rho_prof[mask_fit_prof])
        fit_label = r"a = {:.3g}, $R^2$ = {:.3g}".format(a, r ** 2)

        # Vertical profile: cells whose footprint [x, x + dx) x [y, y + dx)
        # contains the first endpoint.
        xv, yv = float(xa), float(ya)
        cell_dx = self.pp.cells["dx"]
        x3d, y3d = self.xy3d[:, 0], self.xy3d[:, 1]
        mask_vert = (
            (xv >= x3d) & (xv < x3d + cell_dx) & (yv >= y3d) & (yv < y3d + cell_dx)
        )
        rho_vert = np.log10(self.pp.cells["rho"][mask_vert])
        z_vert = self.pp.cells["pos"][mask_vert][:, 2] - 0.5
        order = np.argsort(z_vert)
        rho_vert, z_vert = rho_vert[order], z_vert[order]

        if first:
            (self.prof,) = plt.semilogx(x, rho_prof)
            (self.prof_z,) = plt.semilogx(z_vert, rho_vert)
            plt.title("Profil")
            (self.fit_prof,) = plt.plot(
                x_fit,
                a * np.log10(x_fit) + b,
                "--",
                color="navy",
                label=fit_label,
            )
            self.ax_profile.set_xlabel(r"$r$")
            self.ax_profile.set_ylabel(r"$\log(\rho)$")
        else:
            self.prof.set_data(x, rho_prof)
            self.prof_z.set_data(z_vert, rho_vert)
            self.fit_prof.set_data(x_fit, a * np.log10(x_fit) + b)
            self.fit_prof.set_label(fit_label)
        plt.legend()

    def onselect(self, verts, select_data):
        """Selector callback: intersect (or unite) the masks with the polygon.

        ``select_data`` holds the points tested against the polygon: map
        coordinates for the map selector, (rho, c_s) pairs for the Gamma one.
        """
        path = Path(verts)
        status = self.gamma_3d_button.get_status()
        combine = np.logical_or if status[2] else np.logical_and
        self.mask_flat = combine(self.mask.flatten(), path.contains_points(select_data))
        self.mask = self.mask_flat.reshape(self.shape)
        if status[0]:
            self.mask3d = combine(self.mask3d, path.contains_points(self.xy3d))
        self.draw()

    def clicked_select(self, val, selector, button):
        """Toggle a lasso/polygon selector on the map (and on the 2D Gamma plot)."""
        if self.lasso:
            self.unselect()
            return
        self.lasso = selector(
            self.ax_fluct, partial(self.onselect, select_data=self.xy)
        )
        if not self.gamma_3d_button.get_status()[0]:
            self.lasso_gamma = selector(
                self.ax_gamma, partial(self.onselect, select_data=self.rhocs)
            )
        button.color = "0.55"

    def unselect(self):
        """Deactivate any running selector and reset both button colours."""
        for selector in (self.lasso, self.lasso_gamma):
            if selector:
                # Explicitly detach from the canvas instead of relying on
                # garbage collection of the dropped widget.
                selector.disconnect_events()
        self.lasso = False
        self.lasso_gamma = False
        self.lasso_button.color = "0.85"
        self.poly_button.color = "0.85"

    def _reset_masks(self):
        """Reset the 2D and 3D masks to the annulus rmin <= r <= rmax."""
        self.mask = (self.rr >= self.rmin) & (self.rr <= self.rmax)
        self.mask_flat = self.mask.flatten()
        self.mask3d = np.isfinite(self.rho3d) & np.isfinite(self.P3d)
        self.mask3d = self.mask3d & (self.rr3d >= self.rmin) & (self.rr3d <= self.rmax)

    def clicked_reset(self, val):
        """Reset button callback: restore the annulus masks and redraw."""
        self._reset_masks()
        self.unselect()
        self.draw()

    def fit_selector(self, vmin, vmax):
        """SpanSelector callback: set the profile fit window and redraw."""
        self.fit_prof_vmin = vmin
        self.fit_prof_vmax = vmax
        self.draw()

    def __init__(self, num, path="./", params=None, datamap_key="fluct_coldens_z"):
        """
        Interactive plotting of snapshot ``num`` found in ``path``.

        Parameters
        ----------
        num :
            Snapshot number passed to ``SnapshotProcessor``.
        path : str
            Directory passed to ``SnapshotProcessor``.
        params : parameter object or None
            When None, ``default_params()`` is used with the overrides below.
        datamap_key : str
            Name of the map under ``/maps/`` shown in the fluctuation panel.

        The constructor ends with ``plt.show()``.
        """
        if params is None:
            params = default_params()
            params.input.nml_filename = "disk.nml"
            params.out.interactive = True
            params.pymses.map_size = 4096
            params.pymses.zoom = 4
            params.pymses.variables = ["rho", "vel", "P"]
            params.disk.enable = True
            params.disk.nb_bin = 200
            params.pdf.nb_bin = 100

        self.fig = plt.figure(figsize=(10, 8))
        self.pp = SnapshotProcessor(path, num, params=params, tag="interactive")
        self.pp.pdf_coldens("z")
        self.pp.pdf_rho("z")
        self.pp.pdf_T("z")
        self.pp.load_cells()

        # 2D maps.
        self.im_extent = np.array(self.pp.get_attribute("/maps", "im_extent")) - 0.5
        self.datamap = self.pp.get_value("/maps/" + datamap_key)
        self.rho_map = np.log10(self.pp.get_value("/maps/rho_z"))
        self.frho_map = np.log10(self.pp.get_value("/maps/fluct_rho_z"))
        self.T_map = self.pp.get_value("/maps/T_z")
        self.fcs_map = np.log10(np.sqrt(self.pp.get_value("/maps/fluct_T_z")))
        self.shape = self.datamap.shape

        x = np.linspace(self.im_extent[0], self.im_extent[1], self.shape[0])
        y = np.linspace(self.im_extent[2], self.im_extent[3], self.shape[1])
        self.xx, self.yy = np.meshgrid(x, y)
        self.xy = np.column_stack((self.xx.flatten(), self.yy.flatten()))
        self.rr = np.sqrt(self.xx ** 2 + self.yy ** 2)
        self.rhocs = np.column_stack((self.frho_map.flatten(), self.fcs_map.flatten()))

        self.rmin = self.pp.params.disk.rmin_pdf / 2.0
        self.rmax = self.pp.params.disk.rmax_pdf / 2.0

        # 3D cells.
        self.rho3d = np.log10(self.pp.cells["rho"])
        self.P3d = np.log10(self.pp.cells["P"])
        self.xy3d = self.pp.cells["pos"][:, :2] - 0.5
        self.rr3d = np.sqrt(np.sum((self.xy3d) ** 2, axis=1))
        self._reset_masks()

        # Panels.
        self.ax_fluct = plt.axes([0.05, 0.6, 0.4, 0.35])
        self.ax_gamma = plt.axes([0.05, 0.15, 0.4, 0.35])
        self.ax_pdf = plt.axes([0.55, 0.6, 0.4, 0.35])
        self.ax_profile = plt.axes([0.55, 0.15, 0.4, 0.35])

        self.fit_prof_vmin = 1e-2
        self.fit_prof_vmax = 1e-1
        self.fit_profile_selector = _make_span_selector(
            self.ax_profile, self.fit_selector
        )

        self.line_profile = DraggableLine(self, self.ax_fluct, [0, 0], [0.1, 0], 0.005)
        self.line_gamma = DraggableLine(self, self.ax_gamma, [0, 0], [1, 0.3], 0.05)

        # Widgets.
        ax_rmax = plt.axes([0.05, 0.07, 0.15, 0.02])
        self.srmax = Slider(
            ax_rmax, r"$r_{max}$", 0, 0.7 / params.pymses.zoom, valinit=self.rmax
        )
        ax_rmin = plt.axes([0.05, 0.03, 0.15, 0.02])
        self.srmin = Slider(
            ax_rmin, r"$r_{min}$", 0, 0.7 / params.pymses.zoom, valinit=self.rmin
        )

        ax_lasso = plt.axes([0.6, 0.07, 0.19, 0.02])
        ax_poly = plt.axes([0.8, 0.07, 0.19, 0.02])
        ax_gamma_3d = plt.axes([0.3, 0.01, 0.19, 0.09])
        # Check buttons: [0] 3D Gamma, [1] fit Gamma, [2] union (vs intersection).
        self.gamma_3d_button = CheckButtons(
            ax_gamma_3d, [r"3D $\Gamma$", r"Fit $\Gamma$", "Union"], [True, False, False]
        )
        self.gamma_3d_button.on_clicked(
            lambda val: self.draw(first=False, update_map=True)
        )

        self.lasso = False
        self.lasso_gamma = False
        self.lasso_button = Button(ax_lasso, "Lasso selector")
        self.poly_button = Button(ax_poly, "Polygon selector")
        self.lasso_button.on_clicked(
            partial(
                self.clicked_select, selector=LassoSelector, button=self.lasso_button
            )
        )
        self.poly_button.on_clicked(
            partial(
                self.clicked_select, selector=PolygonSelector, button=self.poly_button
            )
        )

        ax_reset = plt.axes([0.6, 0.03, 0.19, 0.02])
        self.reset_button = Button(ax_reset, "Reset")
        self.reset_button.on_clicked(self.clicked_reset)

        self.srmin.on_changed(self.update_rmin)
        self.srmax.on_changed(self.update_rmax)

        self.draw(first=True, update_map=True)
        plt.show()