diff --git a/gui.py b/gui.py index c656332..17fdef0 100644 --- a/gui.py +++ b/gui.py @@ -1,5 +1,173 @@ +import matplotlib as mpl import matplotlib.pyplot as plt +from matplotlib.widgets import Slider, Button, RadioButtons, LassoSelector, SpanSelector +import matplotlib.patches as patches +from matplotlib.lines import Line2D +from matplotlib.path import Path +from scipy.stats import linregress from postprocessor import * +from skimage.draw import line + + +class DraggablePoint: + + # http://stackoverflow.com/questions/21654008/matplotlib-drag-overlapping-points-interactively + + lock = None # only one can be animated at a time + + def __init__(self, parent, x=0.1, y=0.1, size=0.1): + + self.parent = parent + if self.parent.list_points: + color = "r" + else: + color = "g" + self.point = patches.Ellipse( + (x, y), size, size * 3, fc=color, alpha=0.5, edgecolor=color + ) + self.x = x + self.y = y + self.parent.ax.add_patch(self.point) + self.press = None + self.background = None + self.connect() + + if self.parent.list_points: + line_x = [self.parent.list_points[0].x, self.x] + line_y = [self.parent.list_points[0].y, self.y] + + self.line = Line2D(line_x, line_y, color="k", alpha=0.5) + self.parent.ax.add_line(self.line) + + def connect(self): + + "connect to all the events we need" + + self.cidpress = self.point.figure.canvas.mpl_connect( + "button_press_event", self.on_press + ) + self.cidrelease = self.point.figure.canvas.mpl_connect( + "button_release_event", self.on_release + ) + self.cidmotion = self.point.figure.canvas.mpl_connect( + "motion_notify_event", self.on_motion + ) + + def on_press(self, event): + + if event.inaxes != self.point.axes: + return + if DraggablePoint.lock is not None: + return + contains, attrd = self.point.contains(event) + if not contains: + return + self.press = (self.point.center), event.xdata, event.ydata + DraggablePoint.lock = self + + # draw everything but the selected rectangle and store the pixel buffer + canvas = self.point.figure.canvas + axes = self.point.axes + self.point.set_animated(True) + if self == self.parent.list_points[1]: + self.line.set_animated(True) + else: + self.parent.list_points[1].line.set_animated(True) + canvas.draw() + self.background = canvas.copy_from_bbox(self.point.axes.bbox) + + # now redraw just the rectangle + axes.draw_artist(self.point) + + # and blit just the redrawn area + canvas.blit(axes.bbox) + + def on_motion(self, event): + + if DraggablePoint.lock is not self: + return + if event.inaxes != self.point.axes: + return + self.point.center, xpress, ypress = self.press + dx = event.xdata - xpress + dy = event.ydata - ypress + self.point.center = (self.point.center[0] + dx, self.point.center[1] + dy) + + canvas = self.point.figure.canvas + axes = self.point.axes + # restore the background region + canvas.restore_region(self.background) + + # redraw just the current rectangle + axes.draw_artist(self.point) + + if self == self.parent.list_points[1]: + axes.draw_artist(self.line) + else: + self.parent.list_points[1].line.set_animated(True) + axes.draw_artist(self.parent.list_points[1].line) + + self.x = self.point.center[0] + self.y = self.point.center[1] + + if self == self.parent.list_points[1]: + line_x = [self.parent.list_points[0].x, self.x] + line_y = [self.parent.list_points[0].y, self.y] + self.line.set_data(line_x, line_y) + else: + line_x = [self.x, self.parent.list_points[1].x] + line_y = [self.y, self.parent.list_points[1].y] + + self.parent.list_points[1].line.set_data(line_x, line_y) + + # blit just the redrawn area + canvas.blit(axes.bbox) + + def on_release(self, event): + + "on release we reset the press data" + if DraggablePoint.lock is not self: + return + + self.press = None + DraggablePoint.lock = None + + # turn off the rect animation property and reset the background + self.point.set_animated(False) + if self == self.parent.list_points[1]: + self.line.set_animated(False) + else: + self.parent.list_points[1].line.set_animated(False) + + self.background = None + + # redraw the full figure + self.point.figure.canvas.draw() + + self.x = self.point.center[0] + self.y = self.point.center[1] + self.parent.draw() + + def disconnect(self): + + "disconnect all the stored connection ids" + + self.point.figure.canvas.mpl_disconnect(self.cidpress) + self.point.figure.canvas.mpl_disconnect(self.cidrelease) + self.point.figure.canvas.mpl_disconnect(self.cidmotion) + + +class DraggableLine: + def __init__(self, parent, ax, xy1, xy2, size): + + self.parent = parent + self.ax = ax + self.list_points = [] + self.list_points.append(DraggablePoint(self, xy1[0], xy1[1], size)) + self.list_points.append(DraggablePoint(self, xy2[0], xy2[1], size)) + + def draw(self): + self.parent.draw(update_map=False) class InteractiveGUI: @@ -7,103 +175,283 @@ class InteractiveGUI: This is a matplotlib interactive session to restrain analysis to a specific area """ - def onbuttonrelease(self, event): - """Deal with click events""" - button = ["left", "middle", "right"] - toolbar = plt.get_current_fig_manager().toolbar - if toolbar.mode == "zoom rect" and event.inaxes == self.ax_col: - print("zooming ") - xlim = self.ax_col.get_xlim() - ylim = self.ax_col.get_ylim() - self.reset_mask() - elif self.add_mask and event.inaxes == self.ax_col: - self.plot_side() - plt.draw() - - def onbuttonpress(self, event): - """Deal with click events""" - button = ["left", "middle", "right"] - toolbar = plt.get_current_fig_manager().toolbar - if toolbar.mode != "": - print( - "You clicked on something, but toolbar is in mode {:s}.".format( - toolbar.mode - ) - ) - print(self.add_mask) - if self.add_mask and toolbar.mode == "" and event.inaxes == self.ax_col: - ix, iy = event.xdata, event.ydata - print("Add patch {}, {}".format(ix, iy)) - xlim = self.ax_col.get_xlim() - ylim = self.ax_col.get_ylim() - radius = 0.05 * min(abs(xlim[1] - xlim[0]), abs(ylim[1] - ylim[0])) - circle = mpatches.Circle( - [ix, iy], radius, color="black", alpha=0.1, ec="none" - ) - self.circles.append(circle) - self.ax_col.add_artist(circle) - self.ax_col.draw_artist(circle) - self.patch_mask = self.patch_mask | ( - (self.xx - ix) ** 2 + (self.yy - iy) ** 2 < radius ** 2 - ) - # self.plot_side() - - def onkeypress(self, event): - """whenever a key is pressed""" - if not event.inaxes: - return - if event.key == "t": - self.add_mask = not self.add_mask - print("Add mode is {}".format(self.add_mask)) - elif event.key == "r": - self.reset_mask() - - def plot_side(self): - if self.add_mask: - mask = (self.patch_mask & self.mask).flatten() - else: - mask = self.mask.flatten() - self.ax_gamma.clear() - plt.sca(self.ax_gamma) - plot_dcsdrho(self.fluct_maps, mask, tag=self.tag) - - self.ax_pdf.clear() - plt.sca(self.ax_pdf) - sigma_pdf(self.fluct_maps, mask, tag=self.tag, nb_bin_hist=self.args.pdf_nb_bin) - - def reset_mask(self): - xlim = self.ax_col.get_xlim() - ylim = self.ax_col.get_ylim() - self.mask = ( - (self.xx >= xlim[0]) - & (self.xx <= xlim[1]) - & (self.yy >= ylim[0]) - & (self.yy <= ylim[1]) - ) - self.patch_mask = np.full(self.mask.shape, False) - for circle in self.circles: - circle.remove() - self.circles = [] - self.plot_side() + def update_rmin(self, val): + # amp is the current value of the slider + self.rmin = self.srmin.val + # update curve + self.draw() + # redraw canvas while idle plt.draw() - def __init__(self, num, path="./", pp_params=None): + def update_rmax(self, val): + # amp is the current value of the slider + self.rmax = self.srmax.val + # update curve + self.draw() + # redraw canvas while idle + plt.draw() + + def draw(self, first=False, update_map=True): + + if first or update_map: + self.fmap = np.copy(self.datamap) + self.frho = np.copy(self.frho_map) + self.fcs = np.copy(self.fcs_map) + self.fmap[np.logical_not(self.mask)] = np.nan + + ## Map + plt.sca(self.ax_fluct) + if first: + self.im = plt.imshow( + self.fmap, + origin="lower", + cmap="RdBu_r", + norm=mpl.colors.LogNorm(), + vmin=1e-2, + vmax=1e2, + ) + plt.title("Fluctuations") + # cbar = plt.colorbar() + else: + self.im.set_data(self.fmap) + + ## Gamma + plt.sca(self.ax_gamma) + # (a, b, r, _, _) = linregress(self.frho_map[self.mask], self.fcs_map[self.mask]) + + if first: + _, _, _, self.gamma_hist = plt.hist2d( + self.frho_map[self.mask], + self.fcs_map[self.mask], + bins=100, + norm=mpl.colors.LogNorm(), + cmap=plt.get_cmap("plasma"), + ) + plt.xlabel(r"$\log(\rho / \bar{\rho})$") + plt.ylabel(r"$\log(c_s / \bar{c_s})$") + # cbar = plt.colorbar() + else: + self.gamma_hist.remove() + _, _, _, self.gamma_hist = plt.hist2d( + self.frho_map[self.mask], + self.fcs_map[self.mask], + bins=100, + norm=mpl.colors.LogNorm(), + cmap=plt.get_cmap("plasma"), + ) + plt.legend() + + lps = self.line_gamma.list_points + xa, ya, xb, yb = lps[0].x, lps[0].y, lps[1].x, lps[1].y + a = (yb - ya) / (xb - xa) + self.ax_gamma.set_title("$\Gamma$ = {:.3g}".format(2 * a + 1)) + + ## PDF + if first or update_map: + plt.sca(self.ax_pdf) + nb_cells = np.sum(self.mask_flat.flatten()) + if first: + self.std_nb_cells = nb_cells + values, self.edges = np.histogram( + np.log10(self.fmap.flatten()[self.mask_flat]), + self.pp.pp_params.pdf.nb_bin, + weights=np.ones(nb_cells) / self.std_nb_cells, + ) + edges = self.edges + plt.xlabel(r"$\log(\Sigma / \bar{\Sigma})$") + plt.ylabel(r"$\mathcal{P}_\Sigma$") + else: + values, edges = np.histogram( + np.log10(self.fmap.flatten()[self.mask_flat]), + self.edges, + weights=np.ones(nb_cells) / self.std_nb_cells, + ) + + centers = 0.5 * (edges[1:] + edges[:-1]) + mask_fit = ( + (centers > self.pp.pp_params.pdf.xmin_fit) + & (centers < self.pp.pp_params.pdf.xmax_fit) + & (values > 0) + ) + (a, b, r, _, _) = linregress(centers[mask_fit], np.log10(values[mask_fit])) + + if first: + plt.step(centers, values, where="mid", alpha=0.3) + (self.step,) = plt.step(centers, values, where="mid") + (self.fit,) = plt.plot( + centers, + 10 ** (a * centers + b), + "--", + color="navy", + label=r"a = {:.3g}, $R^2$ = {:.3g}".format(a, r ** 2), + ) + plt.yscale("log") + plt.title("PDF") + else: + self.step.set_ydata(values) + self.fit.set_ydata(10 ** (a * centers + b)) + self.fit.set_label(r"a = {:.3g}, $R^2$ = {:.3g}".format(a, r ** 2)) + plt.legend() + + ### PROFILE + plt.sca(self.ax_profile) + lps = self.line_profile.list_points + xa, ya, xb, yb = lps[0].x, lps[0].y, lps[1].x, lps[1].y + xp, yp = line(int(xa), int(ya), int(xb), int(yb)) + rho_prof = self.frho_map[yp, xp] + # Position on the line + + # x = np.linspace(0, 1, len(rho_prof)) + x = np.sqrt(np.abs((xp - xa) * (xb - xa) + (yp - ya) * (yb - ya))) + x = (x - float(x[0])) / self.shape[0] # /(x[-1] - x[0]) + mask_fit_prof = (x >= self.fit_prof_vmin) & (x <= self.fit_prof_vmax) + try: + (a, b, r, _, _) = linregress( + np.log10(x[mask_fit_prof]), rho_prof[mask_fit_prof] + ) + except ValueError as e: + print("Warning in linregress : {}".format(e)) + a, b, r = np.nan, np.nan, np.nan + + if first: + (self.prof,) = plt.semilogx(x, rho_prof) + plt.xlim([None, None]) + plt.title("Profil") + (self.fit_prof,) = plt.plot( + x[mask_fit_prof], + a * np.log10(x[mask_fit_prof]) + b, + "--", + color="navy", + label=r"a = {:.3g}, $R^2$ = {:.3g}".format(a, r ** 2), + ) + self.ax_profile.set_xlabel(r"$r$") + self.ax_profile.set_ylabel(r"$\log(\rho)$") + else: + self.prof.set_data(x, rho_prof) + plt.xlim([None, None]) + self.fit_prof.set_data(x[mask_fit_prof], a * np.log10(x[mask_fit_prof]) + b) + self.fit_prof.set_label(r"a = {:.3g}, $R^2$ = {:.3g}".format(a, r ** 2)) + + plt.legend() + plt.draw() + self.unselect_lasso() + + def onselect(self, verts, select_data): + # update curve + + path = Path(verts) + self.mask_flat = self.mask.flatten() + + self.mask_flat = self.mask_flat & path.contains_points(select_data) + self.mask = self.mask_flat.reshape(self.shape) + self.draw() + + def clicked_select(self, val): + if not self.lasso: + self.lasso = LassoSelector( + self.ax_fluct, partial(self.onselect, select_data=self.points) + ) + self.lasso_gamma = LassoSelector( + self.ax_gamma, partial(self.onselect, select_data=self.rhocs) + ) + self.lasso_button.color = "0.55" + else: + self.unselect_lasso() + + def unselect_lasso(self): + self.lasso = False + self.lasso_gamma = False + self.lasso_button.color = "0.85" + + def clicked_reset(self, val): + self.mask = (self.rr > self.rmin) & (self.rr < self.rmax) + self.mask_flat = self.mask.flatten() + self.lasso = False + self.lasso_button.color = "0.85" + self.draw() + + def fit_selector(self, vmin, vmax): + self.fit_prof_vmin = vmin + self.fit_prof_vmax = vmax + self.draw() + + def __init__(self, num, path="./", pp_params=None, datamap_key="fluct_coldens_z"): """ Interactive plotting """ - pp = PostProcessor(path, num, pp_params=pp_params, tag="interactive") - pp.pdf_coldens("z") - fluct_map = pp.get_value("/maps/avg_map_coldens_z") - rr = pp.get_value("/maps/rr_z") + if pp_params is None: + pp_params = default_params() + pp_params.input.nml_filename = "disk.nml" + pp_params.out.interactive = True + pp_params.pymses.map_size = 4096 + pp_params.pymses.zoom = 4 - fig, ax = plt.subplots(2) - im = ax[0].imshow(fluct_map, origin="lower", cmap="RdBu_r") - cbar = plt.colorbar() + pp_params.pymses.variables = ["rho", "vel", "P"] - fig.canvas.mpl_connect("button_release_event", self.onbuttonrelease) - fig.canvas.mpl_connect("button_press_event", self.onbuttonpress) - fig.canvas.mpl_connect("key_press_event", self.onkeypress) + pp_params.disk.enable = True + pp_params.disk.nb_bin = 200 + pp_params.pdf.nb_bin = 100 + + self.fig = plt.figure(figsize=(10, 8)) + + self.pp = PostProcessor(path, num, pp_params=pp_params, tag="interactive") + self.pp.pdf_coldens("z") + self.pp.pdf_rho("z") + self.pp.pdf_T("z") + self.datamap = self.pp.get_value("/maps/" + datamap_key) + self.frho_map = np.log10(self.pp.get_value("/maps/fluct_rho_z")) + self.T_map = self.pp.get_value("/maps/T_z") + self.fcs_map = np.log10(np.sqrt(self.pp.get_value("/maps/fluct_T_z"))) + self.rr = self.pp.get_value("/maps/rr_z") + + self.shape = self.datamap.shape + x = np.arange(self.shape[0]) + y = np.arange(self.shape[1]) + self.xx, self.yy = np.meshgrid(x, y) + self.points = np.column_stack((self.xx.flatten(), self.yy.flatten())) + self.rhocs = np.column_stack((self.frho_map.flatten(), self.fcs_map.flatten())) + + self.rmin = self.pp.pp_params.disk.rmin_pdf + self.rmax = self.pp.pp_params.disk.rmax_pdf + + self.ax_fluct = plt.axes([0.05, 0.6, 0.4, 0.35]) + self.ax_gamma = plt.axes([0.05, 0.15, 0.4, 0.35]) + self.ax_pdf = plt.axes([0.55, 0.6, 0.4, 0.35]) + self.ax_profile = plt.axes([0.55, 0.15, 0.4, 0.35]) + self.fit_prof_vmin = 1e-2 + self.fit_prof_vmax = 1e-1 + self.fit_profile_selector = SpanSelector( + self.ax_profile, + self.fit_selector, + "horizontal", + useblit=True, + span_stays=True, + ) + self.line_profile = DraggableLine( + self, self.ax_fluct, [100, 2000], [2000, 2000], 50 + ) + self.line_gamma = DraggableLine(self, self.ax_gamma, [0, 0], [1, 0.3], 0.03) + self.mask = (self.rr > self.rmin) & (self.rr < self.rmax) + self.mask_flat = self.mask.flatten() + + ax_rmax = plt.axes([0.1, 0.03, 0.4, 0.02]) + self.srmax = Slider(ax_rmax, r"$r_{max}$", 0, 0.5, valinit=self.rmax) + ax_rmin = plt.axes([0.1, 0.07, 0.4, 0.02]) + self.srmin = Slider(ax_rmin, r"$r_{min}$", 0, 0.5, valinit=self.rmin) + ax_lasso = plt.axes([0.6, 0.07, 0.2, 0.02]) + self.lasso = False + self.lasso_gamma = False + self.lasso_button = Button(ax_lasso, "Select region") + self.lasso_button.on_clicked(self.clicked_select) + + ax_reset = plt.axes([0.6, 0.03, 0.2, 0.02]) + self.reset_button = Button(ax_reset, "Reset") + self.reset_button.on_clicked(self.clicked_reset) + + self.srmin.on_changed(self.update_rmin) + self.srmax.on_changed(self.update_rmax) + + self.draw(True) - plt.tight_layout() plt.show() diff --git a/postprocessor.py b/postprocessor.py index 8323519..ae8ef19 100644 --- a/postprocessor.py +++ b/postprocessor.py @@ -215,7 +215,7 @@ class PostProcessor(HDF5Container): else: op = ScalarOperator(getter, unit) - if pp_params.pymses.fft: + if self.pp_params.pymses.fft: rt = splatting.SplatterProcessor(self._amr, self._ro.info, op) else: rt = raytracing.RayTracer(self._amr, self._ro.info, op) @@ -545,31 +545,6 @@ class PostProcessor(HDF5Container): pdf.attrs.var = np.var return True - def _clumps(self): - name = self.path_out + "/" + self.tag + "_" + str(self.num).zfill(5) - hop_save = name + "_hop" + "_prop_struct.save" - - me.make_clump_hop( - self.path, - self.num, - name + "_hop", - self.pp_params.hop.rho_thres, - self.pp_params.hop.lvl_thres, - [0.5, 0.5, 0.5], - 1, - path_out=path_out + "/", - path_hop="./", - force=True, - gcomp=False, - ) - hop_save = me.clump_properties( - name + "_hop", path, num, path_out=path_out + "/", gcomp=False - ) - f = open(path_out + "/" + hop_save) - hop_data = pickle.load(f) - f.close() - return hop_data - def _sinks(self): csv_name = ( self.path @@ -699,7 +674,6 @@ class PostProcessor(HDF5Container): "Teff": cst.K, }, ), - "clumps": Rule(self, self._clumps, group="/datasets"), # Helpers "radial_bins": Rule(self, self._radial_bins, "Radial bins", "/radial"), "rr": Rule(self, self._rr, "Coordinate map", "/maps"),