refactored the gui; improved it a lot

This commit is contained in:
Noe Brucy
2020-05-06 17:22:11 +02:00
committed by Noe Brucy
parent 8dfdf918a0
commit 0818372e0c
2 changed files with 439 additions and 117 deletions
+431 -83
View File
@@ -1,5 +1,173 @@
import matplotlib as mpl
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, Button, RadioButtons, LassoSelector, SpanSelector
import matplotlib.patches as patches
from matplotlib.lines import Line2D
from matplotlib.path import Path
from scipy.stats import linregress
from postprocessor import * from postprocessor import *
from skimage.draw import line
class DraggablePoint:
# http://stackoverflow.com/questions/21654008/matplotlib-drag-overlapping-points-interactively
lock = None # only one can be animated at a time
def __init__(self, parent, x=0.1, y=0.1, size=0.1):
self.parent = parent
if self.parent.list_points:
color = "r"
else:
color = "g"
self.point = patches.Ellipse(
(x, y), size, size * 3, fc=color, alpha=0.5, edgecolor=color
)
self.x = x
self.y = y
self.parent.ax.add_patch(self.point)
self.press = None
self.background = None
self.connect()
if self.parent.list_points:
line_x = [self.parent.list_points[0].x, self.x]
line_y = [self.parent.list_points[0].y, self.y]
self.line = Line2D(line_x, line_y, color="k", alpha=0.5)
self.parent.ax.add_line(self.line)
def connect(self):
"connect to all the events we need"
self.cidpress = self.point.figure.canvas.mpl_connect(
"button_press_event", self.on_press
)
self.cidrelease = self.point.figure.canvas.mpl_connect(
"button_release_event", self.on_release
)
self.cidmotion = self.point.figure.canvas.mpl_connect(
"motion_notify_event", self.on_motion
)
def on_press(self, event):
if event.inaxes != self.point.axes:
return
if DraggablePoint.lock is not None:
return
contains, attrd = self.point.contains(event)
if not contains:
return
self.press = (self.point.center), event.xdata, event.ydata
DraggablePoint.lock = self
# draw everything but the selected rectangle and store the pixel buffer
canvas = self.point.figure.canvas
axes = self.point.axes
self.point.set_animated(True)
if self == self.parent.list_points[1]:
self.line.set_animated(True)
else:
self.parent.list_points[1].line.set_animated(True)
canvas.draw()
self.background = canvas.copy_from_bbox(self.point.axes.bbox)
# now redraw just the rectangle
axes.draw_artist(self.point)
# and blit just the redrawn area
canvas.blit(axes.bbox)
def on_motion(self, event):
if DraggablePoint.lock is not self:
return
if event.inaxes != self.point.axes:
return
self.point.center, xpress, ypress = self.press
dx = event.xdata - xpress
dy = event.ydata - ypress
self.point.center = (self.point.center[0] + dx, self.point.center[1] + dy)
canvas = self.point.figure.canvas
axes = self.point.axes
# restore the background region
canvas.restore_region(self.background)
# redraw just the current rectangle
axes.draw_artist(self.point)
if self == self.parent.list_points[1]:
axes.draw_artist(self.line)
else:
self.parent.list_points[1].line.set_animated(True)
axes.draw_artist(self.parent.list_points[1].line)
self.x = self.point.center[0]
self.y = self.point.center[1]
if self == self.parent.list_points[1]:
line_x = [self.parent.list_points[0].x, self.x]
line_y = [self.parent.list_points[0].y, self.y]
self.line.set_data(line_x, line_y)
else:
line_x = [self.x, self.parent.list_points[1].x]
line_y = [self.y, self.parent.list_points[1].y]
self.parent.list_points[1].line.set_data(line_x, line_y)
# blit just the redrawn area
canvas.blit(axes.bbox)
def on_release(self, event):
"on release we reset the press data"
if DraggablePoint.lock is not self:
return
self.press = None
DraggablePoint.lock = None
# turn off the rect animation property and reset the background
self.point.set_animated(False)
if self == self.parent.list_points[1]:
self.line.set_animated(False)
else:
self.parent.list_points[1].line.set_animated(False)
self.background = None
# redraw the full figure
self.point.figure.canvas.draw()
self.x = self.point.center[0]
self.y = self.point.center[1]
self.parent.draw()
def disconnect(self):
"disconnect all the stored connection ids"
self.point.figure.canvas.mpl_disconnect(self.cidpress)
self.point.figure.canvas.mpl_disconnect(self.cidrelease)
self.point.figure.canvas.mpl_disconnect(self.cidmotion)
class DraggableLine:
def __init__(self, parent, ax, xy1, xy2, size):
self.parent = parent
self.ax = ax
self.list_points = []
self.list_points.append(DraggablePoint(self, xy1[0], xy1[1], size))
self.list_points.append(DraggablePoint(self, xy2[0], xy2[1], size))
def draw(self):
self.parent.draw(update_map=False)
class InteractiveGUI: class InteractiveGUI:
@@ -7,103 +175,283 @@ class InteractiveGUI:
This is a matplotlib interactive session to restrain analysis to a specific area This is a matplotlib interactive session to restrain analysis to a specific area
""" """
def onbuttonrelease(self, event): def update_rmin(self, val):
"""Deal with click events""" # amp is the current value of the slider
button = ["left", "middle", "right"] self.rmin = self.srmin.val
toolbar = plt.get_current_fig_manager().toolbar # update curve
if toolbar.mode == "zoom rect" and event.inaxes == self.ax_col: self.draw()
print("zooming ") # redraw canvas while idle
xlim = self.ax_col.get_xlim()
ylim = self.ax_col.get_ylim()
self.reset_mask()
elif self.add_mask and event.inaxes == self.ax_col:
self.plot_side()
plt.draw() plt.draw()
def onbuttonpress(self, event): def update_rmax(self, val):
"""Deal with click events""" # amp is the current value of the slider
button = ["left", "middle", "right"] self.rmax = self.srmax.val
toolbar = plt.get_current_fig_manager().toolbar # update curve
if toolbar.mode != "": self.draw()
print( # redraw canvas while idle
"You clicked on something, but toolbar is in mode {:s}.".format( plt.draw()
toolbar.mode
)
)
print(self.add_mask)
if self.add_mask and toolbar.mode == "" and event.inaxes == self.ax_col:
ix, iy = event.xdata, event.ydata
print("Add patch {}, {}".format(ix, iy))
xlim = self.ax_col.get_xlim()
ylim = self.ax_col.get_ylim()
radius = 0.05 * min(abs(xlim[1] - xlim[0]), abs(ylim[1] - ylim[0]))
circle = mpatches.Circle(
[ix, iy], radius, color="black", alpha=0.1, ec="none"
)
self.circles.append(circle)
self.ax_col.add_artist(circle)
self.ax_col.draw_artist(circle)
self.patch_mask = self.patch_mask | (
(self.xx - ix) ** 2 + (self.yy - iy) ** 2 < radius ** 2
)
# self.plot_side()
def onkeypress(self, event): def draw(self, first=False, update_map=True):
"""whenever a key is pressed"""
if not event.inaxes:
return
if event.key == "t":
self.add_mask = not self.add_mask
print("Add mode is {}".format(self.add_mask))
elif event.key == "r":
self.reset_mask()
def plot_side(self): if first or update_map:
if self.add_mask: self.fmap = np.copy(self.datamap)
mask = (self.patch_mask & self.mask).flatten() self.frho = np.copy(self.frho_map)
self.fcs = np.copy(self.fcs_map)
self.fmap[np.logical_not(self.mask)] = np.nan
## Map
plt.sca(self.ax_fluct)
if first:
self.im = plt.imshow(
self.fmap,
origin="lower",
cmap="RdBu_r",
norm=mpl.colors.LogNorm(),
vmin=1e-2,
vmax=1e2,
)
plt.title("Fluctuations")
# cbar = plt.colorbar()
else: else:
mask = self.mask.flatten() self.im.set_data(self.fmap)
self.ax_gamma.clear()
## Gamma
plt.sca(self.ax_gamma) plt.sca(self.ax_gamma)
plot_dcsdrho(self.fluct_maps, mask, tag=self.tag) # (a, b, r, _, _) = linregress(self.frho_map[self.mask], self.fcs_map[self.mask])
self.ax_pdf.clear() if first:
plt.sca(self.ax_pdf) _, _, _, self.gamma_hist = plt.hist2d(
sigma_pdf(self.fluct_maps, mask, tag=self.tag, nb_bin_hist=self.args.pdf_nb_bin) self.frho_map[self.mask],
self.fcs_map[self.mask],
def reset_mask(self): bins=100,
xlim = self.ax_col.get_xlim() norm=mpl.colors.LogNorm(),
ylim = self.ax_col.get_ylim() cmap=plt.get_cmap("plasma"),
self.mask = (
(self.xx >= xlim[0])
& (self.xx <= xlim[1])
& (self.yy >= ylim[0])
& (self.yy <= ylim[1])
) )
self.patch_mask = np.full(self.mask.shape, False) plt.xlabel(r"$\log(\rho / \bar{\rho})$")
for circle in self.circles: plt.ylabel(r"$\log(c_s / \bar{c_s})$")
circle.remove() # cbar = plt.colorbar()
self.circles = [] else:
self.plot_side() self.gamma_hist.remove()
plt.draw() _, _, _, self.gamma_hist = plt.hist2d(
self.frho_map[self.mask],
self.fcs_map[self.mask],
bins=100,
norm=mpl.colors.LogNorm(),
cmap=plt.get_cmap("plasma"),
)
plt.legend()
def __init__(self, num, path="./", pp_params=None): lps = self.line_gamma.list_points
xa, ya, xb, yb = lps[0].x, lps[0].y, lps[1].x, lps[1].y
a = (yb - ya) / (xb - xa)
self.ax_gamma.set_title("$\Gamma$ = {:.3g}".format(2 * a + 1))
## PDF
if first or update_map:
plt.sca(self.ax_pdf)
nb_cells = np.sum(self.mask_flat.flatten())
if first:
self.std_nb_cells = nb_cells
values, self.edges = np.histogram(
np.log10(self.fmap.flatten()[self.mask_flat]),
self.pp.pp_params.pdf.nb_bin,
weights=np.ones(nb_cells) / self.std_nb_cells,
)
edges = self.edges
plt.xlabel(r"$\log(\Sigma / \bar{\Sigma})$")
plt.ylabel(r"$\mathcal{P}_\Sigma$")
else:
values, edges = np.histogram(
np.log10(self.fmap.flatten()[self.mask_flat]),
self.edges,
weights=np.ones(nb_cells) / self.std_nb_cells,
)
centers = 0.5 * (edges[1:] + edges[:-1])
mask_fit = (
(centers > self.pp.pp_params.pdf.xmin_fit)
& (centers < self.pp.pp_params.pdf.xmax_fit)
& (values > 0)
)
(a, b, r, _, _) = linregress(centers[mask_fit], np.log10(values[mask_fit]))
if first:
plt.step(centers, values, where="mid", alpha=0.3)
(self.step,) = plt.step(centers, values, where="mid")
(self.fit,) = plt.plot(
centers,
10 ** (a * centers + b),
"--",
color="navy",
label=r"a = {:.3g}, $R^2$ = {:.3g}".format(a, r ** 2),
)
plt.yscale("log")
plt.title("PDF")
else:
self.step.set_ydata(values)
self.fit.set_ydata(10 ** (a * centers + b))
self.fit.set_label(r"a = {:.3g}, $R^2$ = {:.3g}".format(a, r ** 2))
plt.legend()
### PROFILE
plt.sca(self.ax_profile)
lps = self.line_profile.list_points
xa, ya, xb, yb = lps[0].x, lps[0].y, lps[1].x, lps[1].y
xp, yp = line(int(xa), int(ya), int(xb), int(yb))
rho_prof = self.frho_map[yp, xp]
# Position on the line
# x = np.linspace(0, 1, len(rho_prof))
x = np.sqrt(np.abs((xp - xa) * (xb - xa) + (yp - ya) * (yb - ya)))
x = (x - float(x[0])) / self.shape[0] # /(x[-1] - x[0])
mask_fit_prof = (x >= self.fit_prof_vmin) & (x <= self.fit_prof_vmax)
try:
(a, b, r, _, _) = linregress(
np.log10(x[mask_fit_prof]), rho_prof[mask_fit_prof]
)
except ValueError as e:
print("Warning in linregress : {}".format(e))
a, b, r = np.nan, np.nan, np.nan
if first:
(self.prof,) = plt.semilogx(x, rho_prof)
plt.xlim([None, None])
plt.title("Profil")
(self.fit_prof,) = plt.plot(
x[mask_fit_prof],
a * np.log10(x[mask_fit_prof]) + b,
"--",
color="navy",
label=r"a = {:.3g}, $R^2$ = {:.3g}".format(a, r ** 2),
)
self.ax_profile.set_xlabel(r"$r$")
self.ax_profile.set_ylabel(r"$\log(\rho)$")
else:
self.prof.set_data(x, rho_prof)
plt.xlim([None, None])
self.fit_prof.set_data(x[mask_fit_prof], a * np.log10(x[mask_fit_prof]) + b)
self.fit_prof.set_label(r"a = {:.3g}, $R^2$ = {:.3g}".format(a, r ** 2))
plt.legend()
plt.draw()
self.unselect_lasso()
def onselect(self, verts, select_data):
# update curve
path = Path(verts)
self.mask_flat = self.mask.flatten()
self.mask_flat = self.mask_flat & path.contains_points(select_data)
self.mask = self.mask_flat.reshape(self.shape)
self.draw()
def clicked_select(self, val):
if not self.lasso:
self.lasso = LassoSelector(
self.ax_fluct, partial(self.onselect, select_data=self.points)
)
self.lasso_gamma = LassoSelector(
self.ax_gamma, partial(self.onselect, select_data=self.rhocs)
)
self.lasso_button.color = "0.55"
else:
self.unselect_lasso()
def unselect_lasso(self):
self.lasso = False
self.lasso_gamma = False
self.lasso_button.color = "0.85"
def clicked_reset(self, val):
self.mask = (self.rr > self.rmin) & (self.rr < self.rmax)
self.mask_flat = self.mask.flatten()
self.lasso = False
self.lasso_button.color = "0.85"
self.draw()
def fit_selector(self, vmin, vmax):
self.fit_prof_vmin = vmin
self.fit_prof_vmax = vmax
self.draw()
def __init__(self, num, path="./", pp_params=None, datamap_key="fluct_coldens_z"):
""" """
Interactive plotting Interactive plotting
""" """
pp = PostProcessor(path, num, pp_params=pp_params, tag="interactive") if pp_params is None:
pp.pdf_coldens("z") pp_params = default_params()
fluct_map = pp.get_value("/maps/avg_map_coldens_z") pp_params.input.nml_filename = "disk.nml"
rr = pp.get_value("/maps/rr_z") pp_params.out.interactive = True
pp_params.pymses.map_size = 4096
pp_params.pymses.zoom = 4
fig, ax = plt.subplots(2) pp_params.pymses.variables = ["rho", "vel", "P"]
im = ax[0].imshow(fluct_map, origin="lower", cmap="RdBu_r")
cbar = plt.colorbar()
fig.canvas.mpl_connect("button_release_event", self.onbuttonrelease) pp_params.disk.enable = True
fig.canvas.mpl_connect("button_press_event", self.onbuttonpress) pp_params.disk.nb_bin = 200
fig.canvas.mpl_connect("key_press_event", self.onkeypress) pp_params.pdf.nb_bin = 100
self.fig = plt.figure(figsize=(10, 8))
self.pp = PostProcessor(path, num, pp_params=pp_params, tag="interactive")
self.pp.pdf_coldens("z")
self.pp.pdf_rho("z")
self.pp.pdf_T("z")
self.datamap = self.pp.get_value("/maps/" + datamap_key)
self.frho_map = np.log10(self.pp.get_value("/maps/fluct_rho_z"))
self.T_map = self.pp.get_value("/maps/T_z")
self.fcs_map = np.log10(np.sqrt(self.pp.get_value("/maps/fluct_T_z")))
self.rr = self.pp.get_value("/maps/rr_z")
self.shape = self.datamap.shape
x = np.arange(self.shape[0])
y = np.arange(self.shape[1])
self.xx, self.yy = np.meshgrid(x, y)
self.points = np.column_stack((self.xx.flatten(), self.yy.flatten()))
self.rhocs = np.column_stack((self.frho_map.flatten(), self.fcs_map.flatten()))
self.rmin = self.pp.pp_params.disk.rmin_pdf
self.rmax = self.pp.pp_params.disk.rmax_pdf
self.ax_fluct = plt.axes([0.05, 0.6, 0.4, 0.35])
self.ax_gamma = plt.axes([0.05, 0.15, 0.4, 0.35])
self.ax_pdf = plt.axes([0.55, 0.6, 0.4, 0.35])
self.ax_profile = plt.axes([0.55, 0.15, 0.4, 0.35])
self.fit_prof_vmin = 1e-2
self.fit_prof_vmax = 1e-1
self.fit_profile_selector = SpanSelector(
self.ax_profile,
self.fit_selector,
"horizontal",
useblit=True,
span_stays=True,
)
self.line_profile = DraggableLine(
self, self.ax_fluct, [100, 2000], [2000, 2000], 50
)
self.line_gamma = DraggableLine(self, self.ax_gamma, [0, 0], [1, 0.3], 0.03)
self.mask = (self.rr > self.rmin) & (self.rr < self.rmax)
self.mask_flat = self.mask.flatten()
ax_rmax = plt.axes([0.1, 0.03, 0.4, 0.02])
self.srmax = Slider(ax_rmax, r"$r_{max}$", 0, 0.5, valinit=self.rmax)
ax_rmin = plt.axes([0.1, 0.07, 0.4, 0.02])
self.srmin = Slider(ax_rmin, r"$r_{min}$", 0, 0.5, valinit=self.rmin)
ax_lasso = plt.axes([0.6, 0.07, 0.2, 0.02])
self.lasso = False
self.lasso_gamma = False
self.lasso_button = Button(ax_lasso, "Select region")
self.lasso_button.on_clicked(self.clicked_select)
ax_reset = plt.axes([0.6, 0.03, 0.2, 0.02])
self.reset_button = Button(ax_reset, "Reset")
self.reset_button.on_clicked(self.clicked_reset)
self.srmin.on_changed(self.update_rmin)
self.srmax.on_changed(self.update_rmax)
self.draw(True)
plt.tight_layout()
plt.show() plt.show()
+1 -27
View File
@@ -215,7 +215,7 @@ class PostProcessor(HDF5Container):
else: else:
op = ScalarOperator(getter, unit) op = ScalarOperator(getter, unit)
if pp_params.pymses.fft: if self.pp_params.pymses.fft:
rt = splatting.SplatterProcessor(self._amr, self._ro.info, op) rt = splatting.SplatterProcessor(self._amr, self._ro.info, op)
else: else:
rt = raytracing.RayTracer(self._amr, self._ro.info, op) rt = raytracing.RayTracer(self._amr, self._ro.info, op)
@@ -545,31 +545,6 @@ class PostProcessor(HDF5Container):
pdf.attrs.var = np.var pdf.attrs.var = np.var
return True return True
def _clumps(self):
name = self.path_out + "/" + self.tag + "_" + str(self.num).zfill(5)
hop_save = name + "_hop" + "_prop_struct.save"
me.make_clump_hop(
self.path,
self.num,
name + "_hop",
self.pp_params.hop.rho_thres,
self.pp_params.hop.lvl_thres,
[0.5, 0.5, 0.5],
1,
path_out=path_out + "/",
path_hop="./",
force=True,
gcomp=False,
)
hop_save = me.clump_properties(
name + "_hop", path, num, path_out=path_out + "/", gcomp=False
)
f = open(path_out + "/" + hop_save)
hop_data = pickle.load(f)
f.close()
return hop_data
def _sinks(self): def _sinks(self):
csv_name = ( csv_name = (
self.path self.path
@@ -699,7 +674,6 @@ class PostProcessor(HDF5Container):
"Teff": cst.K, "Teff": cst.K,
}, },
), ),
"clumps": Rule(self, self._clumps, group="/datasets"),
# Helpers # Helpers
"radial_bins": Rule(self, self._radial_bins, "Radial bins", "/radial"), "radial_bins": Rule(self, self._radial_bins, "Radial bins", "/radial"),
"rr": Rule(self, self._rr, "Coordinate map", "/maps"), "rr": Rule(self, self._rr, "Coordinate map", "/maps"),