from functools import partial

import matplotlib as mpl
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.lines import Line2D
from matplotlib.path import Path
from matplotlib.widgets import (
    Button,
    CheckButtons,
    LassoSelector,
    PolygonSelector,
    RadioButtons,
    Slider,
    SpanSelector,
)
from scipy.stats import linregress
from skimage.draw import line

# Provides ``default_params`` and ``PostProcessor`` used below (they are not
# defined in this file).
from postprocessor import *


def _make_span_selector(ax, onselect):
    """Return a horizontal SpanSelector on ``ax`` whose span stays drawn.

    Matplotlib 3.5 renamed the ``span_stays`` keyword to ``interactive`` and
    later releases dropped the old name, so the new keyword is tried first and
    the old one is used as a fallback for older Matplotlib versions.
    """
    try:
        return SpanSelector(
            ax, onselect, "horizontal", useblit=True, interactive=True
        )
    except TypeError:
        return SpanSelector(
            ax, onselect, "horizontal", useblit=True, span_stays=True
        )


class DraggablePoint:
    """Semi-transparent ellipse on ``parent.ax`` that can be dragged with the mouse.

    Adapted from
    http://stackoverflow.com/questions/21654008/matplotlib-drag-overlapping-points-interactively

    ``parent`` must expose ``ax``, ``list_points`` and ``draw()`` (see
    ``DraggableLine``). The first point of a parent is drawn green; the second
    is drawn red and owns ``self.line``, the segment joining it to the first
    point. While the mouse button is held, the point and the segment are
    animated and blitted; on release the figure is redrawn and
    ``parent.draw()`` is called.
    """

    lock = None  # the point currently being dragged; only one at a time

    def __init__(self, parent, x=0.1, y=0.1, size=0.1):
        self.parent = parent
        # list_points is still empty while the parent builds its first point.
        color = "r" if self.parent.list_points else "g"
        # Ellipse of width ``size`` and height ``3 * size`` in data units.
        self.point = patches.Ellipse(
            (x, y), size, size * 3, fc=color, alpha=0.5, edgecolor=color
        )
        self.x = x
        self.y = y
        self.parent.ax.add_patch(self.point)
        self.press = None
        self.background = None
        self.connect()
        if self.parent.list_points:
            first = self.parent.list_points[0]
            self.line = Line2D(
                [first.x, self.x], [first.y, self.y], color="k", alpha=0.5
            )
            self.parent.ax.add_line(self.line)

    def _segment(self):
        """Return the segment joining the parent's points (owned by the second)."""
        return self.parent.list_points[1].line

    def connect(self):
        """Connect the press / release / motion callbacks to the canvas."""
        canvas = self.point.figure.canvas
        self.cidpress = canvas.mpl_connect("button_press_event", self.on_press)
        self.cidrelease = canvas.mpl_connect(
            "button_release_event", self.on_release
        )
        self.cidmotion = canvas.mpl_connect("motion_notify_event", self.on_motion)

    def on_press(self, event):
        """Start a drag if the click hits this point and nothing else is dragged."""
        if event.inaxes != self.point.axes or DraggablePoint.lock is not None:
            return
        contains, _ = self.point.contains(event)
        if not contains:
            return
        self.press = self.point.center, event.xdata, event.ydata
        DraggablePoint.lock = self

        canvas = self.point.figure.canvas
        axes = self.point.axes
        segment = self._segment()
        # Draw everything except the animated artists and cache that background.
        self.point.set_animated(True)
        segment.set_animated(True)
        canvas.draw()
        self.background = canvas.copy_from_bbox(axes.bbox)
        # Paint the animated artists on top (segment included, so it does not
        # disappear until the first motion event) and blit the axes area.
        axes.draw_artist(self.point)
        axes.draw_artist(segment)
        canvas.blit(axes.bbox)

    def on_motion(self, event):
        """Move the grabbed point with the mouse and blit it with its segment."""
        if DraggablePoint.lock is not self or event.inaxes != self.point.axes:
            return
        (x0, y0), xpress, ypress = self.press
        self.point.center = (x0 + event.xdata - xpress, y0 + event.ydata - ypress)
        self.x, self.y = self.point.center

        # Update the segment BEFORE drawing it so it does not lag one event
        # behind the point.
        first, second = self.parent.list_points[0], self.parent.list_points[1]
        segment = self._segment()
        segment.set_data([first.x, second.x], [first.y, second.y])

        canvas = self.point.figure.canvas
        axes = self.point.axes
        canvas.restore_region(self.background)
        axes.draw_artist(self.point)
        axes.draw_artist(segment)
        canvas.blit(axes.bbox)

    def on_release(self, event):
        """Finish the drag, redraw the figure and notify the parent."""
        if DraggablePoint.lock is not self:
            return
        self.press = None
        DraggablePoint.lock = None
        self.point.set_animated(False)
        self._segment().set_animated(False)
        self.background = None
        self.point.figure.canvas.draw()
        self.x, self.y = self.point.center
        self.parent.draw()

    def disconnect(self):
        """Disconnect all the stored canvas callback ids."""
        canvas = self.point.figure.canvas
        canvas.mpl_disconnect(self.cidpress)
        canvas.mpl_disconnect(self.cidrelease)
        canvas.mpl_disconnect(self.cidmotion)


class DraggableLine:
    """Two ``DraggablePoint`` end points on ``ax`` joined by a segment.

    ``parent`` must provide ``draw(update_map=...)``; it is called with
    ``update_map=False`` whenever an end point is released.
    """

    def __init__(self, parent, ax, xy1, xy2, size):
        self.parent = parent
        self.ax = ax
        self.list_points = []
        # Points are appended one at a time: each DraggablePoint reads
        # list_points to choose its colour and to build the segment.
        for xy in (xy1, xy2):
            self.list_points.append(DraggablePoint(self, xy[0], xy[1], size))

    def draw(self):
        """Forward a redraw request to the parent without recomputing maps."""
        self.parent.draw(update_map=False)

    def clear(self):
        """Disconnect the end points' callbacks and remove all artists."""
        for p in self.list_points:
            # Disconnect first: the canvas is reached through the patch's figure.
            p.disconnect()
            p.point.remove()
        self.list_points[1].line.remove()


class FitterTool:
    """Horizontal span selector plus a draggable straight line on one axes.

    ``bounds`` holds the last selected ``(vmin, vmax)`` span (or the value
    passed at construction). ``fitline`` is a ``DraggableLine`` on ``ax``.
    """

    def __init__(self, ax, bounds=None):
        self.ax = ax
        self.bounds = bounds
        self.bounds_selector = _make_span_selector(self.ax, self.onselect)
        self.fitline = DraggableLine(self, self.ax, [0, 0], [1, 0.3], 0.05)

    def draw(self, first=False, update_map=True):
        """Request a canvas redraw (called by ``DraggableLine`` after a drag)."""
        self.ax.figure.canvas.draw_idle()

    def onselect(self, vmin, vmax):
        """Store the selected horizontal span."""
        self.bounds = (vmin, vmax)


class InteractiveGUI:
    """Interactive matplotlib session to restrict the analysis to a specific area.

    Four panels are shown:
    - fluctuation map, where areas can be selected with a lasso or polygon;
    - rho / c_s (or 3D rho / P) histogram with a draggable gamma line;
    - PDF of the selected map values with a linear fit;
    - profile along a draggable line with a fit over a selectable span.
    """

    def __init__(self, num, path="./", pp_params=None, datamap_key="fluct_coldens_z"):
        """Load output ``num`` from ``path`` and open the interactive figure.

        Parameters
        ----------
        num : output number passed to ``PostProcessor``.
        path : simulation directory passed to ``PostProcessor``.
        pp_params : post-processing parameters; when None, defaults from
            ``default_params()`` are used with the overrides below.
        datamap_key : name of the map under ``/maps/`` shown in the first panel.
        """
        if pp_params is None:
            pp_params = default_params()
            pp_params.input.nml_filename = "disk.nml"
            pp_params.out.interactive = True
            pp_params.pymses.map_size = 4096
            pp_params.pymses.zoom = 4
            pp_params.pymses.variables = ["rho", "vel", "P"]
            pp_params.disk.enable = True
            pp_params.disk.nb_bin = 200
            pp_params.pdf.nb_bin = 100

        self.fig = plt.figure(figsize=(10, 8))
        self.pp = PostProcessor(path, num, pp_params=pp_params, tag="interactive")
        self.pp.pdf_coldens("z")
        self.pp.pdf_rho("z")
        self.pp.pdf_T("z")
        self.pp.load_cells()

        # 2D maps (log10 of densities / fluctuations).
        self.im_extent = np.array(self.pp.get_attribute("/maps", "im_extent")) - 0.5
        self.datamap = self.pp.get_value("/maps/" + datamap_key)
        self.rho_map = np.log10(self.pp.get_value("/maps/rho_z"))
        self.frho_map = np.log10(self.pp.get_value("/maps/fluct_rho_z"))
        self.T_map = self.pp.get_value("/maps/T_z")
        self.fcs_map = np.log10(np.sqrt(self.pp.get_value("/maps/fluct_T_z")))
        self.shape = self.datamap.shape

        x = np.linspace(self.im_extent[0], self.im_extent[1], self.shape[0])
        y = np.linspace(self.im_extent[2], self.im_extent[3], self.shape[1])
        self.xx, self.yy = np.meshgrid(x, y)
        self.xy = np.column_stack((self.xx.flatten(), self.yy.flatten()))
        self.rr = np.sqrt(self.xx ** 2 + self.yy ** 2)
        self.rhocs = np.column_stack((self.frho_map.flatten(), self.fcs_map.flatten()))

        self.rmin = self.pp.pp_params.disk.rmin_pdf / 2.0
        self.rmax = self.pp.pp_params.disk.rmax_pdf / 2.0

        # 3D cell data (log10 density / pressure, in-plane positions centred on 0).
        self.rho3d = np.log10(self.pp.cells["rho"])
        self.P3d = np.log10(self.pp.cells["P"])
        self.xy3d = self.pp.cells["pos"][:, :2] - 0.5
        self.rr3d = np.sqrt(np.sum(self.xy3d ** 2, axis=1))
        self._reset_masks()

        # Panels.
        self.ax_fluct = plt.axes([0.05, 0.6, 0.4, 0.35])
        self.ax_gamma = plt.axes([0.05, 0.15, 0.4, 0.35])
        self.ax_pdf = plt.axes([0.55, 0.6, 0.4, 0.35])
        self.ax_profile = plt.axes([0.55, 0.15, 0.4, 0.35])

        self.fit_prof_vmin = 1e-2
        self.fit_prof_vmax = 1e-1
        self.fit_profile_selector = _make_span_selector(
            self.ax_profile, self.fit_selector
        )
        self.line_profile = DraggableLine(self, self.ax_fluct, [0, 0], [0.1, 0], 0.005)
        self.line_gamma = DraggableLine(self, self.ax_gamma, [0, 0], [1, 0.3], 0.05)

        # Radius sliders.
        slider_max = 0.7 / pp_params.pymses.zoom
        ax_rmax = plt.axes([0.05, 0.07, 0.15, 0.02])
        self.srmax = Slider(ax_rmax, r"$r_{max}$", 0, slider_max, valinit=self.rmax)
        ax_rmin = plt.axes([0.05, 0.03, 0.15, 0.02])
        self.srmin = Slider(ax_rmin, r"$r_{min}$", 0, slider_max, valinit=self.rmin)

        # Buttons and check boxes.
        ax_lasso = plt.axes([0.6, 0.07, 0.19, 0.02])
        ax_poly = plt.axes([0.8, 0.07, 0.19, 0.02])
        ax_gamma_3d = plt.axes([0.3, 0.01, 0.19, 0.09])
        # Status indices: 0 = use 3D cells, 1 = fit gamma, 2 = union selection.
        self.gamma_3d_button = CheckButtons(
            ax_gamma_3d, [r"3D $\Gamma$", r"Fit $\Gamma$", "Union"], [True, False, False]
        )
        self.gamma_3d_button.on_clicked(
            lambda val: self.draw(first=False, update_map=True)
        )

        # Active selectors, or False when none is active.
        self.lasso = False
        self.lasso_gamma = False
        self.lasso_button = Button(ax_lasso, "Lasso selector")
        self.poly_button = Button(ax_poly, "Polygon selector")
        self.lasso_button.on_clicked(
            partial(self.clicked_select, selector=LassoSelector, button=self.lasso_button)
        )
        self.poly_button.on_clicked(
            partial(self.clicked_select, selector=PolygonSelector, button=self.poly_button)
        )

        ax_reset = plt.axes([0.6, 0.03, 0.19, 0.02])
        self.reset_button = Button(ax_reset, "Reset")
        self.reset_button.on_clicked(self.clicked_reset)
        self.srmin.on_changed(self.update_rmin)
        self.srmax.on_changed(self.update_rmax)

        self.draw(first=True, update_map=True)
        plt.show()

    def _reset_masks(self):
        """Restore the 2D and 3D selections to the annulus rmin <= r <= rmax."""
        self.mask = (self.rr >= self.rmin) & (self.rr <= self.rmax)
        self.mask_flat = self.mask.flatten()
        self.mask3d = (
            np.isfinite(self.rho3d)
            & np.isfinite(self.P3d)
            & (self.rr3d >= self.rmin)
            & (self.rr3d <= self.rmax)
        )

    def update_rmin(self, val):
        """Slider callback: set the inner radius and reset the selection."""
        self.rmin = self.srmin.val
        self.clicked_reset(None)

    def update_rmax(self, val):
        """Slider callback: set the outer radius and reset the selection."""
        self.rmax = self.srmax.val
        self.clicked_reset(None)

    @staticmethod
    def _safe_linregress(x, y):
        """Return ``(slope, intercept, rvalue)``, or NaNs (with a warning) on failure."""
        try:
            a, b, r, _, _ = linregress(x, y)
        except ValueError as e:
            print("Warning in linregress : {}".format(e))
            return np.nan, np.nan, np.nan
        return a, b, r

    @staticmethod
    def _line_slope(dline):
        """Slope of a ``DraggableLine``; NaN when its end points are vertical."""
        p0, p1 = dline.list_points[0], dline.list_points[1]
        dx = p1.x - p0.x
        if dx == 0:
            return np.nan
        return (p1.y - p0.y) / dx

    def draw(self, first=False, update_map=True):
        """Redraw all panels.

        Parameters
        ----------
        first : create the artists (True) instead of updating existing ones.
        update_map : recompute everything that depends on the current
            selection; False is used after dragging a line end point.
        """
        if first or update_map:
            self.fmap = np.copy(self.datamap)
            self.frho = np.copy(self.frho_map)
            self.fcs = np.copy(self.fcs_map)
            self.fmap[np.logical_not(self.mask)] = np.nan
        self._draw_map(first)
        self._draw_gamma(first, update_map)
        if first or update_map:
            self._draw_pdf(first)
        self._draw_profile(first)
        plt.draw()
        self.unselect()

    def _draw_map(self, first):
        """Show the masked fluctuation map (log colour scale)."""
        plt.sca(self.ax_fluct)
        if first:
            # vmin/vmax go into the norm: Matplotlib rejects passing them
            # alongside a Normalize instance.
            self.im = plt.imshow(
                self.fmap,
                origin="lower",
                cmap="RdBu_r",
                extent=self.im_extent,
                norm=mpl.colors.LogNorm(vmin=1e-2, vmax=1e2),
            )
            plt.title("Fluctuations")
        else:
            self.im.set_data(self.fmap)

    def _draw_gamma(self, first, update_map):
        """Redraw the 2D histogram and the gamma estimate in its title."""
        plt.sca(self.ax_gamma)
        a = None
        # The histogram only depends on the selection and check-box state, both
        # of which trigger a draw with update_map=True, so line drags skip it.
        if first or update_map:
            status = self.gamma_3d_button.get_status()
            if status[0]:
                rho, cs = self.rho3d[self.mask3d], self.P3d[self.mask3d]
                plt.xlabel(r"$\log(\rho)$")
                plt.ylabel(r"$\log(P)$")
                self.get_gamma = lambda a: a
            else:
                rho = self.frho_map[self.mask]
                cs = self.fcs_map[self.mask]
                plt.xlabel(r"$\log(\rho / \bar{\rho})$")
                plt.ylabel(r"$\log(c_s / \bar{c_s})$")
                self.get_gamma = lambda a: 2 * a + 1
            if not first:
                self.gamma_hist.remove()
            _, _, _, self.gamma_hist = plt.hist2d(
                rho,
                cs,
                bins=100,
                norm=mpl.colors.LogNorm(),
                cmap=plt.get_cmap("plasma"),
            )
            if update_map and status[1]:
                a = self._fit_gamma_line(rho, cs)
        if a is None:
            a = self._line_slope(self.line_gamma)
        self.ax_gamma.set_title(r"$\Gamma$ = {:.3g}".format(self.get_gamma(a)))

    def _fit_gamma_line(self, rho, cs):
        """Fit ``cs`` against ``rho`` and replace the gamma line with the fit.

        Returns the fitted slope, or None when the regression fails (e.g. an
        empty selection), in which case the current line is kept.
        """
        # NOTE(review): the rho > 2 cut is applied in both 2D and 3D modes —
        # confirm it is meant for both.
        fit = rho > 2
        try:
            lr = linregress(rho[fit], cs[fit])
        except ValueError as e:
            print("Warning in linregress : {}".format(e))
            return None
        a, b, _, _, _ = lr
        self.line_gamma.clear()
        rhomin, rhomax = np.min(rho), np.max(rho)
        self.line_gamma = DraggableLine(
            self,
            self.ax_gamma,
            [rhomin, a * rhomin + b],
            [rhomax, a * rhomax + b],
            0.05,
        )
        print("Gamma linregress : {}".format(lr))
        return a

    def _draw_pdf(self, first):
        """Histogram log10 of the selected map values and fit it in log space."""
        plt.sca(self.ax_pdf)
        pdf_params = self.pp.pp_params.pdf
        nb_cells = np.sum(self.mask_flat)
        if first:
            # Later histograms reuse this normalisation and these bin edges.
            self.std_nb_cells = nb_cells
            bins = pdf_params.nb_bin
        else:
            bins = self.edges
        values, edges = np.histogram(
            np.log10(self.fmap.flatten()[self.mask_flat]),
            bins,
            weights=np.ones(nb_cells) / self.std_nb_cells,
        )
        if first:
            self.edges = edges
            plt.xlabel(r"$\log(\Sigma / \bar{\Sigma})$")
            plt.ylabel(r"$\mathcal{P}_\Sigma$")

        centers = 0.5 * (edges[1:] + edges[:-1])
        mask_fit = (
            (centers > pdf_params.xmin_fit)
            & (centers < pdf_params.xmax_fit)
            & (values > 0)
        )
        a, b, r = self._safe_linregress(
            centers[mask_fit], np.log10(values[mask_fit])
        )
        fit_values = 10 ** (a * centers + b)
        label = r"a = {:.3g}, $R^2$ = {:.3g}".format(a, r ** 2)
        if first:
            # Faint copy of the initial PDF kept as a reference.
            plt.step(centers, values, where="mid", alpha=0.3)
            (self.step,) = plt.step(centers, values, where="mid")
            (self.fit,) = plt.plot(
                centers, fit_values, "--", color="navy", label=label
            )
            plt.yscale("log")
            plt.title("PDF")
        else:
            self.step.set_ydata(values)
            self.fit.set_ydata(fit_values)
            self.fit.set_label(label)
        plt.legend()

    def _draw_profile(self, first):
        """Plot log(rho) along the profile line plus a vertical profile under its start."""
        plt.sca(self.ax_profile)
        zoom = self.pp.pp_params.pymses.zoom
        lps = self.line_profile.list_points
        xa, ya, xb, yb = lps[0].x, lps[0].y, lps[1].x, lps[1].y

        # Data coordinates -> pixel indices. NOTE(review): assumes the map spans
        # [-0.5/zoom, 0.5/zoom] on both axes with a square shape — confirm.
        coords = np.array([xa, ya, xb, yb], dtype=float)
        coor_pix = list(np.asarray((coords * zoom + 0.5) * self.shape[0], dtype=int))
        xpp, ypp = line(*coor_pix)
        # Keep only pixels inside the map: indices past the edge would raise
        # IndexError and negative ones would silently wrap around.
        nrows, ncols = self.rho_map.shape
        inside = (xpp >= 0) & (xpp < ncols) & (ypp >= 0) & (ypp < nrows)
        xpp, ypp = xpp[inside], ypp[inside]
        rho_prof = self.rho_map[ypp, xpp]

        xp = ((xpp / float(self.shape[0])) - 0.5) / zoom
        yp = ((ypp / float(self.shape[1])) - 0.5) / zoom
        # Euclidean distance of each pixel from the line start, measured from
        # the first sampled pixel.
        x = np.hypot(xp - xa, yp - ya)
        if x.size:
            x = x - float(x[0])

        # Fit rho against log10(r) over the selected span (r > 0 for the log).
        mask_fit_prof = (x > 0) & (x >= self.fit_prof_vmin) & (x <= self.fit_prof_vmax)
        fit_x = x[mask_fit_prof]
        a, b, r = self._safe_linregress(np.log10(fit_x), rho_prof[mask_fit_prof])
        fit_y = a * np.log10(fit_x) + b

        # Vertical column of 3D cells under the line start.
        # NOTE(review): treats cell pos as the lower corner ([pos, pos + dx)) —
        # confirm against the cell convention.
        xv, yv = float(xa), float(ya)
        dx = self.pp.cells["dx"]
        mask_vert = (
            (xv >= self.xy3d[:, 0])
            & (xv < self.xy3d[:, 0] + dx)
            & (yv >= self.xy3d[:, 1])
            & (yv < self.xy3d[:, 1] + dx)
        )
        rho_vert = np.log10(self.pp.cells["rho"][mask_vert])
        z_vert = self.pp.cells["pos"][mask_vert][:, 2] - 0.5
        sorter = np.argsort(z_vert)
        rho_vert, z_vert = rho_vert[sorter], z_vert[sorter]

        label = r"a = {:.3g}, $R^2$ = {:.3g}".format(a, r ** 2)
        if first:
            (self.prof,) = plt.semilogx(x, rho_prof)
            (self.prof_z,) = plt.semilogx(z_vert, rho_vert)
            plt.title("Profil")
            (self.fit_prof,) = plt.plot(fit_x, fit_y, "--", color="navy", label=label)
            self.ax_profile.set_xlabel(r"$r$")
            self.ax_profile.set_ylabel(r"$\log(\rho)$")
        else:
            self.prof.set_data(x, rho_prof)
            self.prof_z.set_data(z_vert, rho_vert)
            self.fit_prof.set_data(fit_x, fit_y)
            self.fit_prof.set_label(label)
        plt.legend()

    def onselect(self, verts, select_data):
        """Selector callback: intersect (or union) the selection with a polygon.

        ``select_data`` are the (N, 2) points tested against the polygon: map
        coordinates for the map panel, (rho, c_s) pairs for the gamma panel.
        """
        path = Path(verts)
        status = self.gamma_3d_button.get_status()
        union = status[2]
        inside = path.contains_points(select_data)
        mask_flat = self.mask.flatten()
        self.mask_flat = (mask_flat | inside) if union else (mask_flat & inside)
        self.mask = self.mask_flat.reshape(self.shape)
        if status[0]:
            inside3d = path.contains_points(self.xy3d)
            self.mask3d = (self.mask3d | inside3d) if union else (self.mask3d & inside3d)
        self.draw()

    def clicked_select(self, val, selector, button):
        """Toggle a lasso / polygon selector on the map (and gamma) panels."""
        if self.lasso:
            self.unselect()
            return
        self.lasso = selector(self.ax_fluct, partial(self.onselect, select_data=self.xy))
        if not self.gamma_3d_button.get_status()[0]:
            # The (rho, c_s) selection only exists for the 2D histogram.
            self.lasso_gamma = selector(
                self.ax_gamma, partial(self.onselect, select_data=self.rhocs)
            )
        button.color = "0.55"

    def unselect(self):
        """Deactivate any active selectors and reset the button colours."""
        for selector in (self.lasso, self.lasso_gamma):
            if selector:
                # Disconnect explicitly rather than waiting for garbage
                # collection to drop the widget's callbacks.
                selector.disconnect_events()
        self.lasso = False
        self.lasso_gamma = False
        self.lasso_button.color = "0.85"
        self.poly_button.color = "0.85"

    def clicked_reset(self, val):
        """Reset the selection to the rmin/rmax annulus and redraw."""
        self._reset_masks()
        self.unselect()
        self.draw()

    def fit_selector(self, vmin, vmax):
        """Span-selector callback: set the profile fit range and redraw."""
        self.fit_prof_vmin = vmin
        self.fit_prof_vmax = vmax
        self.draw()