refactored the gui; improved it a lot

This commit is contained in:
Noe Brucy
2020-05-06 17:22:11 +02:00
committed by Noe Brucy
parent 8dfdf918a0
commit 0818372e0c
2 changed files with 439 additions and 117 deletions
+438 -90
View File
@@ -1,5 +1,173 @@
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, Button, RadioButtons, LassoSelector, SpanSelector
import matplotlib.patches as patches
from matplotlib.lines import Line2D
from matplotlib.path import Path
from scipy.stats import linregress
from postprocessor import *
from skimage.draw import line
class DraggablePoint:
# http://stackoverflow.com/questions/21654008/matplotlib-drag-overlapping-points-interactively
lock = None # only one can be animated at a time
def __init__(self, parent, x=0.1, y=0.1, size=0.1):
self.parent = parent
if self.parent.list_points:
color = "r"
else:
color = "g"
self.point = patches.Ellipse(
(x, y), size, size * 3, fc=color, alpha=0.5, edgecolor=color
)
self.x = x
self.y = y
self.parent.ax.add_patch(self.point)
self.press = None
self.background = None
self.connect()
if self.parent.list_points:
line_x = [self.parent.list_points[0].x, self.x]
line_y = [self.parent.list_points[0].y, self.y]
self.line = Line2D(line_x, line_y, color="k", alpha=0.5)
self.parent.ax.add_line(self.line)
def connect(self):
"connect to all the events we need"
self.cidpress = self.point.figure.canvas.mpl_connect(
"button_press_event", self.on_press
)
self.cidrelease = self.point.figure.canvas.mpl_connect(
"button_release_event", self.on_release
)
self.cidmotion = self.point.figure.canvas.mpl_connect(
"motion_notify_event", self.on_motion
)
def on_press(self, event):
if event.inaxes != self.point.axes:
return
if DraggablePoint.lock is not None:
return
contains, attrd = self.point.contains(event)
if not contains:
return
self.press = (self.point.center), event.xdata, event.ydata
DraggablePoint.lock = self
# draw everything but the selected rectangle and store the pixel buffer
canvas = self.point.figure.canvas
axes = self.point.axes
self.point.set_animated(True)
if self == self.parent.list_points[1]:
self.line.set_animated(True)
else:
self.parent.list_points[1].line.set_animated(True)
canvas.draw()
self.background = canvas.copy_from_bbox(self.point.axes.bbox)
# now redraw just the rectangle
axes.draw_artist(self.point)
# and blit just the redrawn area
canvas.blit(axes.bbox)
def on_motion(self, event):
if DraggablePoint.lock is not self:
return
if event.inaxes != self.point.axes:
return
self.point.center, xpress, ypress = self.press
dx = event.xdata - xpress
dy = event.ydata - ypress
self.point.center = (self.point.center[0] + dx, self.point.center[1] + dy)
canvas = self.point.figure.canvas
axes = self.point.axes
# restore the background region
canvas.restore_region(self.background)
# redraw just the current rectangle
axes.draw_artist(self.point)
if self == self.parent.list_points[1]:
axes.draw_artist(self.line)
else:
self.parent.list_points[1].line.set_animated(True)
axes.draw_artist(self.parent.list_points[1].line)
self.x = self.point.center[0]
self.y = self.point.center[1]
if self == self.parent.list_points[1]:
line_x = [self.parent.list_points[0].x, self.x]
line_y = [self.parent.list_points[0].y, self.y]
self.line.set_data(line_x, line_y)
else:
line_x = [self.x, self.parent.list_points[1].x]
line_y = [self.y, self.parent.list_points[1].y]
self.parent.list_points[1].line.set_data(line_x, line_y)
# blit just the redrawn area
canvas.blit(axes.bbox)
def on_release(self, event):
"on release we reset the press data"
if DraggablePoint.lock is not self:
return
self.press = None
DraggablePoint.lock = None
# turn off the rect animation property and reset the background
self.point.set_animated(False)
if self == self.parent.list_points[1]:
self.line.set_animated(False)
else:
self.parent.list_points[1].line.set_animated(False)
self.background = None
# redraw the full figure
self.point.figure.canvas.draw()
self.x = self.point.center[0]
self.y = self.point.center[1]
self.parent.draw()
def disconnect(self):
"disconnect all the stored connection ids"
self.point.figure.canvas.mpl_disconnect(self.cidpress)
self.point.figure.canvas.mpl_disconnect(self.cidrelease)
self.point.figure.canvas.mpl_disconnect(self.cidmotion)
class DraggableLine:
def __init__(self, parent, ax, xy1, xy2, size):
self.parent = parent
self.ax = ax
self.list_points = []
self.list_points.append(DraggablePoint(self, xy1[0], xy1[1], size))
self.list_points.append(DraggablePoint(self, xy2[0], xy2[1], size))
def draw(self):
self.parent.draw(update_map=False)
class InteractiveGUI:
@@ -7,103 +175,283 @@ class InteractiveGUI:
This is a matplotlib interactive session to restrain analysis to a specific area
"""
def onbuttonrelease(self, event):
"""Deal with click events"""
button = ["left", "middle", "right"]
toolbar = plt.get_current_fig_manager().toolbar
if toolbar.mode == "zoom rect" and event.inaxes == self.ax_col:
print("zooming ")
xlim = self.ax_col.get_xlim()
ylim = self.ax_col.get_ylim()
self.reset_mask()
elif self.add_mask and event.inaxes == self.ax_col:
self.plot_side()
plt.draw()
def onbuttonpress(self, event):
"""Deal with click events"""
button = ["left", "middle", "right"]
toolbar = plt.get_current_fig_manager().toolbar
if toolbar.mode != "":
print(
"You clicked on something, but toolbar is in mode {:s}.".format(
toolbar.mode
)
)
print(self.add_mask)
if self.add_mask and toolbar.mode == "" and event.inaxes == self.ax_col:
ix, iy = event.xdata, event.ydata
print("Add patch {}, {}".format(ix, iy))
xlim = self.ax_col.get_xlim()
ylim = self.ax_col.get_ylim()
radius = 0.05 * min(abs(xlim[1] - xlim[0]), abs(ylim[1] - ylim[0]))
circle = mpatches.Circle(
[ix, iy], radius, color="black", alpha=0.1, ec="none"
)
self.circles.append(circle)
self.ax_col.add_artist(circle)
self.ax_col.draw_artist(circle)
self.patch_mask = self.patch_mask | (
(self.xx - ix) ** 2 + (self.yy - iy) ** 2 < radius ** 2
)
# self.plot_side()
def onkeypress(self, event):
"""whenever a key is pressed"""
if not event.inaxes:
return
if event.key == "t":
self.add_mask = not self.add_mask
print("Add mode is {}".format(self.add_mask))
elif event.key == "r":
self.reset_mask()
def plot_side(self):
if self.add_mask:
mask = (self.patch_mask & self.mask).flatten()
else:
mask = self.mask.flatten()
self.ax_gamma.clear()
plt.sca(self.ax_gamma)
plot_dcsdrho(self.fluct_maps, mask, tag=self.tag)
self.ax_pdf.clear()
plt.sca(self.ax_pdf)
sigma_pdf(self.fluct_maps, mask, tag=self.tag, nb_bin_hist=self.args.pdf_nb_bin)
def reset_mask(self):
xlim = self.ax_col.get_xlim()
ylim = self.ax_col.get_ylim()
self.mask = (
(self.xx >= xlim[0])
& (self.xx <= xlim[1])
& (self.yy >= ylim[0])
& (self.yy <= ylim[1])
)
self.patch_mask = np.full(self.mask.shape, False)
for circle in self.circles:
circle.remove()
self.circles = []
self.plot_side()
def update_rmin(self, val):
# amp is the current value of the slider
self.rmin = self.srmin.val
# update curve
self.draw()
# redraw canvas while idle
plt.draw()
def __init__(self, num, path="./", pp_params=None):
def update_rmax(self, val):
# amp is the current value of the slider
self.rmax = self.srmax.val
# update curve
self.draw()
# redraw canvas while idle
plt.draw()
def draw(self, first=False, update_map=True):
if first or update_map:
self.fmap = np.copy(self.datamap)
self.frho = np.copy(self.frho_map)
self.fcs = np.copy(self.fcs_map)
self.fmap[np.logical_not(self.mask)] = np.nan
## Map
plt.sca(self.ax_fluct)
if first:
self.im = plt.imshow(
self.fmap,
origin="lower",
cmap="RdBu_r",
norm=mpl.colors.LogNorm(),
vmin=1e-2,
vmax=1e2,
)
plt.title("Fluctuations")
# cbar = plt.colorbar()
else:
self.im.set_data(self.fmap)
## Gamma
plt.sca(self.ax_gamma)
# (a, b, r, _, _) = linregress(self.frho_map[self.mask], self.fcs_map[self.mask])
if first:
_, _, _, self.gamma_hist = plt.hist2d(
self.frho_map[self.mask],
self.fcs_map[self.mask],
bins=100,
norm=mpl.colors.LogNorm(),
cmap=plt.get_cmap("plasma"),
)
plt.xlabel(r"$\log(\rho / \bar{\rho})$")
plt.ylabel(r"$\log(c_s / \bar{c_s})$")
# cbar = plt.colorbar()
else:
self.gamma_hist.remove()
_, _, _, self.gamma_hist = plt.hist2d(
self.frho_map[self.mask],
self.fcs_map[self.mask],
bins=100,
norm=mpl.colors.LogNorm(),
cmap=plt.get_cmap("plasma"),
)
plt.legend()
lps = self.line_gamma.list_points
xa, ya, xb, yb = lps[0].x, lps[0].y, lps[1].x, lps[1].y
a = (yb - ya) / (xb - xa)
self.ax_gamma.set_title("$\Gamma$ = {:.3g}".format(2 * a + 1))
## PDF
if first or update_map:
plt.sca(self.ax_pdf)
nb_cells = np.sum(self.mask_flat.flatten())
if first:
self.std_nb_cells = nb_cells
values, self.edges = np.histogram(
np.log10(self.fmap.flatten()[self.mask_flat]),
self.pp.pp_params.pdf.nb_bin,
weights=np.ones(nb_cells) / self.std_nb_cells,
)
edges = self.edges
plt.xlabel(r"$\log(\Sigma / \bar{\Sigma})$")
plt.ylabel(r"$\mathcal{P}_\Sigma$")
else:
values, edges = np.histogram(
np.log10(self.fmap.flatten()[self.mask_flat]),
self.edges,
weights=np.ones(nb_cells) / self.std_nb_cells,
)
centers = 0.5 * (edges[1:] + edges[:-1])
mask_fit = (
(centers > self.pp.pp_params.pdf.xmin_fit)
& (centers < self.pp.pp_params.pdf.xmax_fit)
& (values > 0)
)
(a, b, r, _, _) = linregress(centers[mask_fit], np.log10(values[mask_fit]))
if first:
plt.step(centers, values, where="mid", alpha=0.3)
(self.step,) = plt.step(centers, values, where="mid")
(self.fit,) = plt.plot(
centers,
10 ** (a * centers + b),
"--",
color="navy",
label=r"a = {:.3g}, $R^2$ = {:.3g}".format(a, r ** 2),
)
plt.yscale("log")
plt.title("PDF")
else:
self.step.set_ydata(values)
self.fit.set_ydata(10 ** (a * centers + b))
self.fit.set_label(r"a = {:.3g}, $R^2$ = {:.3g}".format(a, r ** 2))
plt.legend()
### PROFILE
plt.sca(self.ax_profile)
lps = self.line_profile.list_points
xa, ya, xb, yb = lps[0].x, lps[0].y, lps[1].x, lps[1].y
xp, yp = line(int(xa), int(ya), int(xb), int(yb))
rho_prof = self.frho_map[yp, xp]
# Position on the line
# x = np.linspace(0, 1, len(rho_prof))
x = np.sqrt(np.abs((xp - xa) * (xb - xa) + (yp - ya) * (yb - ya)))
x = (x - float(x[0])) / self.shape[0] # /(x[-1] - x[0])
mask_fit_prof = (x >= self.fit_prof_vmin) & (x <= self.fit_prof_vmax)
try:
(a, b, r, _, _) = linregress(
np.log10(x[mask_fit_prof]), rho_prof[mask_fit_prof]
)
except ValueError as e:
print("Warning in linregress : {}".format(e))
a, b, r = np.nan, np.nan, np.nan
if first:
(self.prof,) = plt.semilogx(x, rho_prof)
plt.xlim([None, None])
plt.title("Profil")
(self.fit_prof,) = plt.plot(
x[mask_fit_prof],
a * np.log10(x[mask_fit_prof]) + b,
"--",
color="navy",
label=r"a = {:.3g}, $R^2$ = {:.3g}".format(a, r ** 2),
)
self.ax_profile.set_xlabel(r"$r$")
self.ax_profile.set_ylabel(r"$\log(\rho)$")
else:
self.prof.set_data(x, rho_prof)
plt.xlim([None, None])
self.fit_prof.set_data(x[mask_fit_prof], a * np.log10(x[mask_fit_prof]) + b)
self.fit_prof.set_label(r"a = {:.3g}, $R^2$ = {:.3g}".format(a, r ** 2))
plt.legend()
plt.draw()
self.unselect_lasso()
def onselect(self, verts, select_data):
# update curve
path = Path(verts)
self.mask_flat = self.mask.flatten()
self.mask_flat = self.mask_flat & path.contains_points(select_data)
self.mask = self.mask_flat.reshape(self.shape)
self.draw()
def clicked_select(self, val):
if not self.lasso:
self.lasso = LassoSelector(
self.ax_fluct, partial(self.onselect, select_data=self.points)
)
self.lasso_gamma = LassoSelector(
self.ax_gamma, partial(self.onselect, select_data=self.rhocs)
)
self.lasso_button.color = "0.55"
else:
self.unselect_lasso()
def unselect_lasso(self):
self.lasso = False
self.lasso_gamma = False
self.lasso_button.color = "0.85"
def clicked_reset(self, val):
self.mask = (self.rr > self.rmin) & (self.rr < self.rmax)
self.mask_flat = self.mask.flatten()
self.lasso = False
self.lasso_button.color = "0.85"
self.draw()
def fit_selector(self, vmin, vmax):
self.fit_prof_vmin = vmin
self.fit_prof_vmax = vmax
self.draw()
def __init__(self, num, path="./", pp_params=None, datamap_key="fluct_coldens_z"):
"""
Interactive plotting
"""
pp = PostProcessor(path, num, pp_params=pp_params, tag="interactive")
pp.pdf_coldens("z")
fluct_map = pp.get_value("/maps/avg_map_coldens_z")
rr = pp.get_value("/maps/rr_z")
if pp_params is None:
pp_params = default_params()
pp_params.input.nml_filename = "disk.nml"
pp_params.out.interactive = True
pp_params.pymses.map_size = 4096
pp_params.pymses.zoom = 4
fig, ax = plt.subplots(2)
im = ax[0].imshow(fluct_map, origin="lower", cmap="RdBu_r")
cbar = plt.colorbar()
pp_params.pymses.variables = ["rho", "vel", "P"]
fig.canvas.mpl_connect("button_release_event", self.onbuttonrelease)
fig.canvas.mpl_connect("button_press_event", self.onbuttonpress)
fig.canvas.mpl_connect("key_press_event", self.onkeypress)
pp_params.disk.enable = True
pp_params.disk.nb_bin = 200
pp_params.pdf.nb_bin = 100
self.fig = plt.figure(figsize=(10, 8))
self.pp = PostProcessor(path, num, pp_params=pp_params, tag="interactive")
self.pp.pdf_coldens("z")
self.pp.pdf_rho("z")
self.pp.pdf_T("z")
self.datamap = self.pp.get_value("/maps/" + datamap_key)
self.frho_map = np.log10(self.pp.get_value("/maps/fluct_rho_z"))
self.T_map = self.pp.get_value("/maps/T_z")
self.fcs_map = np.log10(np.sqrt(self.pp.get_value("/maps/fluct_T_z")))
self.rr = self.pp.get_value("/maps/rr_z")
self.shape = self.datamap.shape
x = np.arange(self.shape[0])
y = np.arange(self.shape[1])
self.xx, self.yy = np.meshgrid(x, y)
self.points = np.column_stack((self.xx.flatten(), self.yy.flatten()))
self.rhocs = np.column_stack((self.frho_map.flatten(), self.fcs_map.flatten()))
self.rmin = self.pp.pp_params.disk.rmin_pdf
self.rmax = self.pp.pp_params.disk.rmax_pdf
self.ax_fluct = plt.axes([0.05, 0.6, 0.4, 0.35])
self.ax_gamma = plt.axes([0.05, 0.15, 0.4, 0.35])
self.ax_pdf = plt.axes([0.55, 0.6, 0.4, 0.35])
self.ax_profile = plt.axes([0.55, 0.15, 0.4, 0.35])
self.fit_prof_vmin = 1e-2
self.fit_prof_vmax = 1e-1
self.fit_profile_selector = SpanSelector(
self.ax_profile,
self.fit_selector,
"horizontal",
useblit=True,
span_stays=True,
)
self.line_profile = DraggableLine(
self, self.ax_fluct, [100, 2000], [2000, 2000], 50
)
self.line_gamma = DraggableLine(self, self.ax_gamma, [0, 0], [1, 0.3], 0.03)
self.mask = (self.rr > self.rmin) & (self.rr < self.rmax)
self.mask_flat = self.mask.flatten()
ax_rmax = plt.axes([0.1, 0.03, 0.4, 0.02])
self.srmax = Slider(ax_rmax, r"$r_{max}$", 0, 0.5, valinit=self.rmax)
ax_rmin = plt.axes([0.1, 0.07, 0.4, 0.02])
self.srmin = Slider(ax_rmin, r"$r_{min}$", 0, 0.5, valinit=self.rmin)
ax_lasso = plt.axes([0.6, 0.07, 0.2, 0.02])
self.lasso = False
self.lasso_gamma = False
self.lasso_button = Button(ax_lasso, "Select region")
self.lasso_button.on_clicked(self.clicked_select)
ax_reset = plt.axes([0.6, 0.03, 0.2, 0.02])
self.reset_button = Button(ax_reset, "Reset")
self.reset_button.on_clicked(self.clicked_reset)
self.srmin.on_changed(self.update_rmin)
self.srmax.on_changed(self.update_rmax)
self.draw(True)
plt.tight_layout()
plt.show()
+1 -27
View File
@@ -215,7 +215,7 @@ class PostProcessor(HDF5Container):
else:
op = ScalarOperator(getter, unit)
if pp_params.pymses.fft:
if self.pp_params.pymses.fft:
rt = splatting.SplatterProcessor(self._amr, self._ro.info, op)
else:
rt = raytracing.RayTracer(self._amr, self._ro.info, op)
@@ -545,31 +545,6 @@ class PostProcessor(HDF5Container):
pdf.attrs.var = np.var
return True
def _clumps(self):
name = self.path_out + "/" + self.tag + "_" + str(self.num).zfill(5)
hop_save = name + "_hop" + "_prop_struct.save"
me.make_clump_hop(
self.path,
self.num,
name + "_hop",
self.pp_params.hop.rho_thres,
self.pp_params.hop.lvl_thres,
[0.5, 0.5, 0.5],
1,
path_out=path_out + "/",
path_hop="./",
force=True,
gcomp=False,
)
hop_save = me.clump_properties(
name + "_hop", path, num, path_out=path_out + "/", gcomp=False
)
f = open(path_out + "/" + hop_save)
hop_data = pickle.load(f)
f.close()
return hop_data
def _sinks(self):
csv_name = (
self.path
@@ -699,7 +674,6 @@ class PostProcessor(HDF5Container):
"Teff": cst.K,
},
),
"clumps": Rule(self, self._clumps, group="/datasets"),
# Helpers
"radial_bins": Rule(self, self._radial_bins, "Radial bins", "/radial"),
"rr": Rule(self, self._rr, "Coordinate map", "/maps"),