"""Interactive matplotlib tools (draggable points/lines and selectors) used to
restrict the disk post-processing analysis to a region and fit it."""

from functools import partial

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, Button, RadioButtons, LassoSelector
from matplotlib.widgets import SpanSelector, PolygonSelector, CheckButtons
import matplotlib.patches as patches
from matplotlib.lines import Line2D
from matplotlib.path import Path
from scipy.stats import linregress
from postprocessor import *  # presumably provides default_params, PostProcessor
from skimage.draw import line


def _make_span_selector(ax, onselect):
    """Return a horizontal ``SpanSelector`` on *ax* whose span stays drawn.

    Matplotlib >= 3.5 keeps the span visible with ``interactive=True``; the
    older ``span_stays=True`` keyword was deprecated in 3.5 and removed in
    3.7, so the new spelling is tried first and old releases fall back.
    """
    try:
        return SpanSelector(
            ax, onselect, "horizontal", useblit=True, interactive=True
        )
    except TypeError:
        return SpanSelector(
            ax, onselect, "horizontal", useblit=True, span_stays=True
        )


def _safe_linregress(x, y):
    """Run ``scipy.stats.linregress(x, y)`` without letting it raise.

    Returns the regression result, or ``None`` (after printing a warning)
    when scipy rejects the input, e.g. empty arrays or identical x values.
    """
    try:
        return linregress(x, y)
    except ValueError as e:
        print("Warning in linregress : {}".format(e))
        return None


class DraggablePoint:
    """Draggable ellipse marking one end of a ``DraggableLine``.

    Adapted from
    http://stackoverflow.com/questions/21654008/matplotlib-drag-overlapping-points-interactively

    The first point created for a parent is green; the second one is red and
    owns the ``Line2D`` (``self.line``) joining it to the first point.
    *parent* must expose ``ax``, ``list_points`` and ``draw()``.
    """

    lock = None  # only one can be animated at a time

    def __init__(self, parent, x=0.1, y=0.1, size=0.1):
        self.parent = parent
        color = "r" if self.parent.list_points else "g"
        # height is 3x the width -- NOTE(review): presumably to look round on
        # the non-equal data aspect of the axes; confirm
        self.point = patches.Ellipse(
            (x, y), size, size * 3, fc=color, alpha=0.5, edgecolor=color
        )
        self.x = x
        self.y = y
        self.parent.ax.add_patch(self.point)
        self.press = None
        self.background = None
        self.connect()
        if self.parent.list_points:
            first = self.parent.list_points[0]
            self.line = Line2D(
                [first.x, self.x], [first.y, self.y], color="k", alpha=0.5
            )
            self.parent.ax.add_line(self.line)

    def _link_line(self):
        """Return the segment joining the parent's two points (owned by the second)."""
        return self.parent.list_points[1].line

    def connect(self):
        "connect to all the events we need"
        # keep the canvas: after ``self.point.remove()`` the patch no longer
        # knows its figure, but we still need it to disconnect
        self.canvas = self.point.figure.canvas
        self.cidpress = self.canvas.mpl_connect("button_press_event", self.on_press)
        self.cidrelease = self.canvas.mpl_connect(
            "button_release_event", self.on_release
        )
        self.cidmotion = self.canvas.mpl_connect(
            "motion_notify_event", self.on_motion
        )

    def on_press(self, event):
        """Start dragging if the click hits this point and nothing else is dragged."""
        if event.inaxes != self.point.axes:
            return
        if DraggablePoint.lock is not None:
            return
        contains, _ = self.point.contains(event)
        if not contains:
            return
        self.press = self.point.center, event.xdata, event.ydata
        DraggablePoint.lock = self

        # draw everything but the animated artists and cache the pixel buffer
        # so motion events only need to restore it and blit
        axes = self.point.axes
        link = self._link_line()
        self.point.set_animated(True)
        link.set_animated(True)
        self.canvas.draw()
        self.background = self.canvas.copy_from_bbox(axes.bbox)
        axes.draw_artist(self.point)
        axes.draw_artist(link)
        self.canvas.blit(axes.bbox)

    def on_motion(self, event):
        """Move the point with the mouse and update the joining segment."""
        if DraggablePoint.lock is not self:
            return
        if event.inaxes != self.point.axes:
            return
        (cx, cy), xpress, ypress = self.press
        self.point.center = (cx + event.xdata - xpress, cy + event.ydata - ypress)
        self.x, self.y = self.point.center

        # update the segment *before* drawing it so it follows the cursor
        first, second = self.parent.list_points
        link = second.line
        link.set_data([first.x, second.x], [first.y, second.y])

        axes = self.point.axes
        self.canvas.restore_region(self.background)
        axes.draw_artist(self.point)
        axes.draw_artist(link)
        self.canvas.blit(axes.bbox)

    def on_release(self, event):
        "on release we reset the press data"
        if DraggablePoint.lock is not self:
            return
        self.press = None
        DraggablePoint.lock = None
        # turn off animation, drop the cached background and redraw everything
        self.point.set_animated(False)
        self._link_line().set_animated(False)
        self.background = None
        self.canvas.draw()
        self.x, self.y = self.point.center
        self.parent.draw()

    def disconnect(self):
        "disconnect all the stored connection ids"
        for cid in (self.cidpress, self.cidrelease, self.cidmotion):
            self.canvas.mpl_disconnect(cid)


class DraggableLine:
    """Two ``DraggablePoint``s on *ax* joined by a segment.

    *parent* must provide ``draw(update_map=...)``; it is called with
    ``update_map=False`` whenever one of the points is released.
    """

    def __init__(self, parent, ax, xy1, xy2, size):
        self.parent = parent
        self.ax = ax
        self.list_points = []
        # append one at a time: the second point looks at list_points to
        # pick its colour and to build the joining segment
        self.list_points.append(DraggablePoint(self, xy1[0], xy1[1], size))
        self.list_points.append(DraggablePoint(self, xy2[0], xy2[1], size))

    def coords(self):
        """Return ``(xa, ya, xb, yb)`` for the first and second points."""
        a, b = self.list_points
        return a.x, a.y, b.x, b.y

    def slope(self):
        """Return the segment slope, or NaN when it is vertical."""
        xa, ya, xb, yb = self.coords()
        if xb == xa:
            return np.nan
        return (yb - ya) / (xb - xa)

    def draw(self):
        self.parent.draw(update_map=False)

    def clear(self):
        """Disconnect the points' event handlers and remove their artists."""
        for p in self.list_points:
            p.disconnect()
            p.point.remove()
        self.list_points[1].line.remove()


class FitterTool:
    """
    Flexible fitter tool: a horizontal span selector choosing the fit bounds
    plus a draggable line on the same axes.
    """

    def __init__(self, ax, bounds=None):
        self.ax = ax
        self.bounds = bounds
        self.bounds_selector = _make_span_selector(self.ax, self.onselect)
        self.fitline = DraggableLine(self, self.ax, [0, 0], [1, 0.3], 0.05)

    def draw(self, first=False, update_map=True):
        """Redraw the canvas; called by ``fitline`` when a point is released."""
        self.ax.figure.canvas.draw_idle()

    def onselect(self, vmin, vmax):
        """Store the selected x-range as ``self.bounds``."""
        self.bounds = (vmin, vmax)


class InteractiveGUI:
    """
    This is a matplotlib interactive session to restrict the analysis to a
    specific area.

    Panels: fluctuation map (top left), Γ histogram (bottom left), PDF of the
    map values (top right) and density profile along a draggable segment
    (bottom right).
    """

    def __init__(self, num, path="./", pp_params=None, datamap_key="fluct_coldens_z"):
        """
        Interactive plotting

        num, path: passed through to ``PostProcessor``.
        pp_params: post-processing parameters; when None, ``default_params()``
            is used with the overrides below.
        datamap_key: name of the map under "/maps/" shown in the map panel.
        """
        if pp_params is None:
            # NOTE(review): these overrides only apply to the default params;
            # callers passing their own pp_params must enable e.g. the "P"
            # variable themselves -- confirm this is intended
            pp_params = default_params()
            pp_params.input.nml_filename = "disk.nml"
            pp_params.out.interactive = True
            pp_params.pymses.map_size = 4096
            pp_params.pymses.zoom = 4
            pp_params.pymses.variables = ["rho", "vel", "P"]
            pp_params.disk.enable = True
            pp_params.disk.nb_bin = 200
            pp_params.pdf.nb_bin = 100

        self.fig = plt.figure(figsize=(10, 8))
        self.pp = PostProcessor(path, num, pp_params=pp_params, tag="interactive")
        self.pp.pdf_coldens("z")
        self.pp.pdf_rho("z")
        self.pp.pdf_T("z")
        self.pp.load_cells()

        # 2D maps
        self.im_extent = np.array(self.pp.get_attribute("/maps", "im_extent")) - 0.5
        self.datamap = self.pp.get_value("/maps/" + datamap_key)
        self.rho_map = np.log10(self.pp.get_value("/maps/rho_z"))
        self.frho_map = np.log10(self.pp.get_value("/maps/fluct_rho_z"))
        self.T_map = self.pp.get_value("/maps/T_z")
        self.fcs_map = np.log10(np.sqrt(self.pp.get_value("/maps/fluct_T_z")))
        self.shape = self.datamap.shape

        # pixel positions: rows follow y and columns follow x (maps are shown
        # with imshow(origin="lower") and indexed as map[y, x])
        x = np.linspace(self.im_extent[0], self.im_extent[1], self.shape[1])
        y = np.linspace(self.im_extent[2], self.im_extent[3], self.shape[0])
        self.xx, self.yy = np.meshgrid(x, y)
        self.xy = np.column_stack((self.xx.flatten(), self.yy.flatten()))
        self.rr = np.sqrt(self.xx ** 2 + self.yy ** 2)
        self.rhocs = np.column_stack((self.frho_map.flatten(), self.fcs_map.flatten()))

        self.rmin = self.pp.pp_params.disk.rmin_pdf / 2.0
        self.rmax = self.pp.pp_params.disk.rmax_pdf / 2.0
        self.mask = (self.rr >= self.rmin) & (self.rr <= self.rmax)
        self.mask_flat = self.mask.flatten()

        # 3D cells
        self.rho3d = np.log10(self.pp.cells["rho"])
        self.P3d = np.log10(self.pp.cells["P"])
        self.xy3d = self.pp.cells["pos"][:, :2] - 0.5
        self.rr3d = np.sqrt(np.sum(self.xy3d ** 2, axis=1))
        self.mask3d = np.isfinite(self.rho3d) & np.isfinite(self.P3d)
        self.mask3d = self.mask3d & (self.rr3d >= self.rmin) & (self.rr3d <= self.rmax)

        # panels
        self.ax_fluct = plt.axes([0.05, 0.6, 0.4, 0.35])
        self.ax_gamma = plt.axes([0.05, 0.15, 0.4, 0.35])
        self.ax_pdf = plt.axes([0.55, 0.6, 0.4, 0.35])
        self.ax_profile = plt.axes([0.55, 0.15, 0.4, 0.35])

        self.fit_prof_vmin = 1e-2
        self.fit_prof_vmax = 1e-1
        self.fit_profile_selector = _make_span_selector(
            self.ax_profile, self.fit_selector
        )

        self.line_profile = DraggableLine(self, self.ax_fluct, [0, 0], [0.1, 0], 0.005)
        self.line_gamma = DraggableLine(self, self.ax_gamma, [0, 0], [1, 0.3], 0.05)

        # widgets
        zoom = pp_params.pymses.zoom
        ax_rmax = plt.axes([0.05, 0.07, 0.15, 0.02])
        self.srmax = Slider(ax_rmax, r"$r_{max}$", 0, 0.7 / zoom, valinit=self.rmax)
        ax_rmin = plt.axes([0.05, 0.03, 0.15, 0.02])
        self.srmin = Slider(ax_rmin, r"$r_{min}$", 0, 0.7 / zoom, valinit=self.rmin)

        ax_lasso = plt.axes([0.6, 0.07, 0.19, 0.02])
        ax_poly = plt.axes([0.8, 0.07, 0.19, 0.02])
        ax_gamma_3d = plt.axes([0.3, 0.01, 0.19, 0.09])
        self.gamma_3d_button = CheckButtons(
            ax_gamma_3d,
            [r"3D $\Gamma$", r"Fit $\Gamma$", "Union"],
            [True, False, False],
        )
        self.gamma_3d_button.on_clicked(
            lambda val: self.draw(first=False, update_map=True)
        )

        self.lasso = False
        self.lasso_gamma = False
        self.lasso_button = Button(ax_lasso, "Lasso selector")
        self.poly_button = Button(ax_poly, "Polygon selector")
        self.lasso_button.on_clicked(
            partial(self.clicked_select, selector=LassoSelector, button=self.lasso_button)
        )
        self.poly_button.on_clicked(
            partial(self.clicked_select, selector=PolygonSelector, button=self.poly_button)
        )

        ax_reset = plt.axes([0.6, 0.03, 0.19, 0.02])
        self.reset_button = Button(ax_reset, "Reset")
        self.reset_button.on_clicked(self.clicked_reset)

        self.srmin.on_changed(self.update_rmin)
        self.srmax.on_changed(self.update_rmax)

        self.draw(first=True, update_map=True)
        plt.show()

    def update_rmin(self, val):
        """Slider callback: set the inner radius and reset the selection."""
        self.rmin = self.srmin.val
        self.clicked_reset(None)

    def update_rmax(self, val):
        """Slider callback: set the outer radius and reset the selection."""
        self.rmax = self.srmax.val
        self.clicked_reset(None)

    def draw(self, first=False, update_map=True):
        """Redraw the four panels.

        first: create the artists instead of updating existing ones.
        update_map: recompute the masked map and the PDF, and (when "Fit Γ"
            is checked) refit the Γ line. ``False`` is used when a draggable
            point is released: only the Γ title and the profile are refreshed.
        """
        if first or update_map:
            self._draw_map(first)
        self._draw_gamma(first, update_map)
        if first or update_map:
            self._draw_pdf(first)
        self._draw_profile(first)
        plt.draw()
        self.unselect()

    def _draw_map(self, first):
        """Mask the data map outside the selection and show it (top left)."""
        self.fmap = np.copy(self.datamap)
        self.frho = np.copy(self.frho_map)
        self.fcs = np.copy(self.fcs_map)
        self.fmap[np.logical_not(self.mask)] = np.nan

        plt.sca(self.ax_fluct)
        if first:
            # limits belong to the norm: passing vmin/vmax alongside a norm
            # instance is rejected by recent matplotlib
            self.im = plt.imshow(
                self.fmap,
                origin="lower",
                cmap="RdBu_r",
                extent=self.im_extent,
                norm=mpl.colors.LogNorm(vmin=1e-2, vmax=1e2),
            )
            plt.title("Fluctuations")
        else:
            self.im.set_data(self.fmap)

    def _draw_gamma(self, first, update_map):
        """2D histogram of the selected data (bottom left), titled with the Γ
        implied by the slope of ``line_gamma``."""
        plt.sca(self.ax_gamma)
        use_3d, fit_gamma, _ = self.gamma_3d_button.get_status()
        if use_3d:
            rho, cs = self.rho3d[self.mask3d], self.P3d[self.mask3d]
            plt.xlabel(r"$\log(\rho)$")
            plt.ylabel(r"$\log(P)$")
            # log P vs log rho: the slope is Γ itself
            self.get_gamma = lambda a: a
        else:
            rho = self.frho_map[self.mask]
            cs = self.fcs_map[self.mask]
            plt.xlabel(r"$\log(\rho / \bar{\rho})$")
            plt.ylabel(r"$\log(c_s / \bar{c_s})$")
            # log c_s vs log rho: slope a maps to Γ = 2a + 1
            self.get_gamma = lambda a: 2 * a + 1

        # hist2d and linregress cannot cope with NaN/inf (log10 of empty pixels)
        finite = np.isfinite(rho) & np.isfinite(cs)
        rho, cs = rho[finite], cs[finite]

        if not first:
            self.gamma_hist.remove()
        _, _, _, self.gamma_hist = plt.hist2d(
            rho, cs, bins=100, norm=mpl.colors.LogNorm(), cmap="plasma"
        )

        lr = None
        if update_map and fit_gamma:
            high = rho > 2  # fit only the high-x tail
            lr = _safe_linregress(rho[high], cs[high])
        if lr is not None:
            a, b = lr.slope, lr.intercept
            self.line_gamma.clear()
            rhomin, rhomax = np.min(rho), np.max(rho)
            self.line_gamma = DraggableLine(
                self,
                self.ax_gamma,
                [rhomin, a * rhomin + b],
                [rhomax, a * rhomax + b],
                0.05,
            )
            print("Gamma linregress : {}".format(lr))
        else:
            # no (successful) fit: use the slope of the hand-placed line
            a = self.line_gamma.slope()
        self.ax_gamma.set_title(r"$\Gamma$ = {:.3g}".format(self.get_gamma(a)))

    def _draw_pdf(self, first):
        """Histogram log10 of the selected map values (top right) and fit a
        line to log10(PDF) between ``pdf.xmin_fit`` and ``pdf.xmax_fit``."""
        plt.sca(self.ax_pdf)
        nb_cells = np.sum(self.mask_flat)
        if first:
            # normalisation and bins are frozen on the first draw so later,
            # smaller selections are shown relative to the initial one
            self.std_nb_cells = nb_cells
        bins = self.pp.pp_params.pdf.nb_bin if first else self.edges
        values, edges = np.histogram(
            np.log10(self.fmap.flatten()[self.mask_flat]),
            bins,
            weights=np.ones(nb_cells) / self.std_nb_cells,
        )
        if first:
            self.edges = edges
            plt.xlabel(r"$\log(\Sigma / \bar{\Sigma})$")
            plt.ylabel(r"$\mathcal{P}_\Sigma$")

        centers = 0.5 * (edges[1:] + edges[:-1])
        pdf_params = self.pp.pp_params.pdf
        mask_fit = (
            (centers > pdf_params.xmin_fit)
            & (centers < pdf_params.xmax_fit)
            & (values > 0)
        )
        lr = _safe_linregress(centers[mask_fit], np.log10(values[mask_fit]))
        if lr is None:
            a, b, r = np.nan, np.nan, np.nan
        else:
            a, b, r = lr.slope, lr.intercept, lr.rvalue
        label = r"a = {:.3g}, $R^2$ = {:.3g}".format(a, r ** 2)

        if first:
            # faint reference copy of the initial PDF; never updated
            plt.step(centers, values, where="mid", alpha=0.3)
            (self.step,) = plt.step(centers, values, where="mid")
            (self.fit,) = plt.plot(
                centers, 10 ** (a * centers + b), "--", color="navy", label=label
            )
            plt.yscale("log")
            plt.title("PDF")
        else:
            self.step.set_ydata(values)
            self.fit.set_ydata(10 ** (a * centers + b))
            self.fit.set_label(label)
        plt.legend()

    def _draw_profile(self, first):
        """Plot log10(rho) along ``line_profile`` and the vertical column under
        its first end (bottom right), with a log-log fit of the segment profile
        restricted to [fit_prof_vmin, fit_prof_vmax]."""
        plt.sca(self.ax_profile)
        xa, ya, xb, yb = self.line_profile.coords()
        zoom = self.pp.pp_params.pymses.zoom
        ny, nx = self.shape
        npix = np.array([nx, ny, nx, ny])

        # data coordinates -> pixel indices; clip so a segment dragged past
        # the map edge neither raises nor wraps through negative indices
        pix = ((np.array([xa, ya, xb, yb]) * zoom + 0.5) * npix).astype(int)
        pix = np.clip(pix, 0, npix - 1)
        xpp, ypp = line(*(int(p) for p in pix))
        rho_prof = self.rho_map[ypp, xpp]
        xp = (xpp / float(nx) - 0.5) / zoom
        yp = (ypp / float(ny) - 0.5) / zoom

        # distance along the segment from its first pixel
        x = np.hypot(xp - xp[0], yp - yp[0])
        mask_fit_prof = (x >= self.fit_prof_vmin) & (x <= self.fit_prof_vmax)
        x_fit = x[mask_fit_prof]
        lr = _safe_linregress(np.log10(x_fit), rho_prof[mask_fit_prof])
        if lr is None:
            a, b, r = np.nan, np.nan, np.nan
        else:
            a, b, r = lr.slope, lr.intercept, lr.rvalue
        y_fit = a * np.log10(x_fit) + b
        label = r"a = {:.3g}, $R^2$ = {:.3g}".format(a, r ** 2)

        # vertical profile: cells with xy3d <= (xa, ya) < xy3d + dx in both
        # directions (NOTE(review): assumes pos - 0.5 is the cell's lower
        # corner; if pos is the centre this test is off by dx/2)
        xv, yv = float(xa), float(ya)
        dx = self.pp.cells["dx"]
        x0, y0 = self.xy3d[:, 0], self.xy3d[:, 1]
        mask_vert = (xv >= x0) & (xv < x0 + dx) & (yv >= y0) & (yv < y0 + dx)
        rho_vert = np.log10(self.pp.cells["rho"][mask_vert])
        z_vert = self.pp.cells["pos"][mask_vert][:, 2] - 0.5
        order = np.argsort(z_vert)
        rho_vert, z_vert = rho_vert[order], z_vert[order]

        if first:
            (self.prof,) = plt.semilogx(x, rho_prof)
            (self.prof_z,) = plt.semilogx(z_vert, rho_vert)
            plt.title("Profil")
            (self.fit_prof,) = plt.plot(x_fit, y_fit, "--", color="navy", label=label)
            self.ax_profile.set_xlabel(r"$r$")
            self.ax_profile.set_ylabel(r"$\log(\rho)$")
        else:
            self.prof.set_data(x, rho_prof)
            self.prof_z.set_data(z_vert, rho_vert)
            self.fit_prof.set_data(x_fit, y_fit)
            self.fit_prof.set_label(label)
        plt.legend()

    def onselect(self, verts, select_data):
        """Combine a lasso/polygon selection with the masks, then redraw.

        verts: vertices of the selection polygon.
        select_data: one 2D point per map pixel tested against the polygon
            (pixel positions for the map panel, (log rho, log c_s) pairs for
            the Γ panel).
        With "Union" checked the selection is OR-ed into the mask, otherwise
        AND-ed; in 3D mode the cell mask is updated from the cells' (x, y).
        """
        path = Path(verts)
        use_3d, _, union = self.gamma_3d_button.get_status()
        combine = np.logical_or if union else np.logical_and
        self.mask_flat = combine(self.mask.flatten(), path.contains_points(select_data))
        self.mask = self.mask_flat.reshape(self.shape)
        if use_3d:
            self.mask3d = combine(self.mask3d, path.contains_points(self.xy3d))
        self.draw()

    def clicked_select(self, val, selector, button):
        """Toggle a *selector* widget on the map (and, in 2D mode, Γ) panel."""
        if self.lasso:
            self.unselect()
            return
        self.lasso = selector(self.ax_fluct, partial(self.onselect, select_data=self.xy))
        if not self.gamma_3d_button.get_status()[0]:
            self.lasso_gamma = selector(
                self.ax_gamma, partial(self.onselect, select_data=self.rhocs)
            )
        button.color = "0.55"

    def unselect(self):
        """Deactivate the active selectors and reset the selector buttons."""
        # dropping the reference alone leaves the widget's canvas callbacks
        # alive until garbage collection; disconnect them explicitly
        for selector in (self.lasso, self.lasso_gamma):
            if selector:
                selector.disconnect_events()
        self.lasso = False
        self.lasso_gamma = False
        self.lasso_button.color = "0.85"
        self.poly_button.color = "0.85"

    def clicked_reset(self, val):
        """Reset the 2D and 3D masks to the [rmin, rmax] annulus and redraw."""
        self.mask = (self.rr >= self.rmin) & (self.rr <= self.rmax)
        self.mask_flat = self.mask.flatten()
        self.mask3d = np.isfinite(self.rho3d) & np.isfinite(self.P3d)
        self.mask3d = self.mask3d & (self.rr3d >= self.rmin) & (self.rr3d <= self.rmax)
        self.unselect()
        self.draw()

    def fit_selector(self, vmin, vmax):
        """Span-selector callback: store the profile fit range and redraw."""
        self.fit_prof_vmin = vmin
        self.fit_prof_vmax = vmax
        self.draw()