Files
pipeline/gui.py
T
2020-12-14 16:59:35 +01:00

596 lines
20 KiB
Python

import matplotlib as mpl
import matplotlib.patches as patches
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.path import Path
from matplotlib.widgets import (
Button,
CheckButtons,
LassoSelector,
PolygonSelector,
RadioButtons,
Slider,
SpanSelector,
)
from scipy.stats import linregress
from skimage.draw import line
from postprocessor import *
class DraggablePoint:
    """
    Draggable ellipse used as one end of a ``DraggableLine``.

    Adapted from
    http://stackoverflow.com/questions/21654008/matplotlib-drag-overlapping-points-interactively

    The first point created for a parent is green.  The second one is red and
    owns ``self.line``, the ``Line2D`` joining the two points.  Dragging uses
    blitting: the static background is cached when a drag starts, and only
    the moving point and the segment are redrawn on each mouse motion.

    Parameters
    ----------
    parent : DraggableLine
        Owner.  It must expose ``ax`` (the axes to draw on), ``list_points``
        (the points created before this one) and ``draw()`` (called when a
        drag ends).
    x, y : float
        Initial centre of the ellipse, in data coordinates.
    size : float
        Width of the ellipse; its height is ``3 * size``.
    """

    lock = None  # point currently dragged: only one can be animated at a time

    def __init__(self, parent, x=0.1, y=0.1, size=0.1):
        self.parent = parent
        # a parent that already holds a point makes this one the second end
        is_second = bool(self.parent.list_points)
        color = "r" if is_second else "g"
        self.point = patches.Ellipse(
            (x, y), size, size * 3, fc=color, alpha=0.5, edgecolor=color
        )
        self.x = x
        self.y = y
        self.parent.ax.add_patch(self.point)
        self.press = None
        self.background = None
        self.connect()
        if is_second:
            # the second point owns the segment joining both ends
            first = self.parent.list_points[0]
            self.line = Line2D(
                [first.x, self.x], [first.y, self.y], color="k", alpha=0.5
            )
            self.parent.ax.add_line(self.line)

    def _segment(self):
        """Return the Line2D joining the two points (owned by the second one)."""
        return self.parent.list_points[1].line

    def connect(self):
        "connect to all the events we need"
        # Keep the canvas so that disconnect() still works after the patch
        # has been removed from its axes.
        self.canvas = self.point.figure.canvas
        self.cidpress = self.canvas.mpl_connect("button_press_event", self.on_press)
        self.cidrelease = self.canvas.mpl_connect(
            "button_release_event", self.on_release
        )
        self.cidmotion = self.canvas.mpl_connect(
            "motion_notify_event", self.on_motion
        )

    def on_press(self, event):
        """Start dragging if the click hits this point and no drag is active."""
        if event.inaxes != self.point.axes:
            return
        if DraggablePoint.lock is not None:
            return
        contains, _ = self.point.contains(event)
        if not contains:
            return
        self.press = (self.point.center, event.xdata, event.ydata)
        DraggablePoint.lock = self
        canvas = self.point.figure.canvas
        axes = self.point.axes
        segment = self._segment()
        self.point.set_animated(True)
        segment.set_animated(True)
        # draw everything except the animated artists and cache the pixels
        canvas.draw()
        self.background = canvas.copy_from_bbox(axes.bbox)
        # redraw the moving artists on top so that neither vanishes
        axes.draw_artist(self.point)
        axes.draw_artist(segment)
        canvas.blit(axes.bbox)

    def on_motion(self, event):
        """Move the point with the mouse and redraw it with the segment."""
        if DraggablePoint.lock is not self:
            return
        if event.inaxes != self.point.axes:
            return
        (x0, y0), xpress, ypress = self.press
        dx = event.xdata - xpress
        dy = event.ydata - ypress
        self.point.center = (x0 + dx, y0 + dy)
        self.x, self.y = self.point.center[0], self.point.center[1]

        # update the segment BEFORE drawing it, otherwise it lags one event
        first, second = self.parent.list_points[0], self.parent.list_points[1]
        segment = self._segment()
        segment.set_data([first.x, second.x], [first.y, second.y])

        canvas = self.point.figure.canvas
        axes = self.point.axes
        canvas.restore_region(self.background)
        axes.draw_artist(self.point)
        axes.draw_artist(segment)
        canvas.blit(axes.bbox)

    def on_release(self, event):
        "on release we reset the press data"
        if DraggablePoint.lock is not self:
            return
        self.press = None
        DraggablePoint.lock = None
        # turn animation off, drop the cached background and redraw normally
        self.point.set_animated(False)
        self._segment().set_animated(False)
        self.background = None
        self.point.figure.canvas.draw()
        self.x, self.y = self.point.center[0], self.point.center[1]
        self.parent.draw()

    def disconnect(self):
        "disconnect all the stored connection ids"
        for cid in (self.cidpress, self.cidrelease, self.cidmotion):
            self.canvas.mpl_disconnect(cid)
class DraggableLine:
    """
    Segment between two ``DraggablePoint`` ends drawn on ``ax``.

    Parameters
    ----------
    parent : object
        Owner, notified through ``parent.draw(update_map=False)`` whenever an
        end is released after a drag.
    ax : matplotlib Axes
        Axes on which the points and the segment are drawn.
    xy1, xy2 : sequence of two floats
        Initial positions of the first and second ends.
    size : float
        Size passed to each ``DraggablePoint``.
    """

    def __init__(self, parent, ax, xy1, xy2, size):
        self.parent = parent
        self.ax = ax
        self.list_points = []
        # Points are appended one at a time: each DraggablePoint inspects
        # list_points to decide whether it is the first or the second end.
        for xy in (xy1, xy2):
            self.list_points.append(DraggablePoint(self, xy[0], xy[1], size))

    def draw(self):
        """Forward a redraw request to the owner without rebuilding its maps."""
        self.parent.draw(update_map=False)

    def clear(self):
        """Remove both points and the segment from the axes.

        Each point's canvas callbacks are disconnected first, while its patch
        is still attached to the figure.  Otherwise the discarded points keep
        receiving mouse events.
        """
        for p in self.list_points:
            p.disconnect()
            p.point.remove()
        self.list_points[1].line.remove()
class FitterTool:
    """
    Flexible fitter tool

    Attaches two interactive helpers to ``ax``:

    * a horizontal ``SpanSelector`` whose last selection is stored in
      ``self.bounds`` as ``(vmin, vmax)``;
    * a ``DraggableLine`` (``self.fitline``) whose slope is read with
      :meth:`slope` and which :meth:`fit` moves onto a least-squares fit.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes carrying the selector and the fit line.
    bounds : tuple of two floats or None
        Initial x-range restricting :meth:`fit`; ``None`` means no restriction.
    """

    def __init__(self, ax, bounds=None):
        self.ax = ax
        self.bounds = bounds
        self.bounds_selector = SpanSelector(
            self.ax, self.onselect, "horizontal", useblit=True, span_stays=True
        )
        self.fitline = DraggableLine(self, self.ax, [0, 0], [1, 0.3], 0.05)

    def draw(self, update_map=False):
        """Redraw the figure.

        ``DraggableLine`` calls this after one of its ends is released;
        ``update_map`` is accepted for that call signature and is unused here.
        """
        self.ax.figure.canvas.draw_idle()

    def slope(self):
        """Return the slope of ``self.fitline`` between its two ends."""
        lps = self.fitline.list_points
        xa, ya, xb, yb = lps[0].x, lps[0].y, lps[1].x, lps[1].y
        return (yb - ya) / (xb - xa)

    def fit(self, x, y):
        """Fit ``y`` against ``x`` by least squares and move ``fitline`` onto it.

        Only finite samples are used.  When ``self.bounds`` is set, only
        samples with ``vmin <= x <= vmax`` are used.  The new line spans the
        x-range of the samples actually fitted.

        Returns
        -------
        The ``scipy.stats.linregress`` result.

        Raises
        ------
        ValueError
            If fewer than two samples remain after masking.
        """
        x = np.asarray(x, dtype=float)
        y = np.asarray(y, dtype=float)
        keep = np.isfinite(x) & np.isfinite(y)
        if self.bounds is not None:
            vmin, vmax = self.bounds
            keep &= (x >= vmin) & (x <= vmax)
        if np.count_nonzero(keep) < 2:
            raise ValueError("need at least two finite samples inside the bounds to fit")
        x, y = x[keep], y[keep]
        result = linregress(x, y)
        a, b = result[0], result[1]
        xmin, xmax = np.min(x), np.max(x)
        self.fitline.clear()
        self.fitline = DraggableLine(
            self, self.ax, [xmin, a * xmin + b], [xmax, a * xmax + b], 0.05
        )
        self.draw()
        return result

    def onselect(self, vmin, vmax):
        """SpanSelector callback: remember the selected x-range."""
        self.bounds = (vmin, vmax)
class InteractiveGUI:
    """
    This is a matplotlib interactive session to restrain analysis to a specific area

    The figure has four panels:

    * ``ax_fluct``: the map ``datamap`` with unselected pixels hidden, on a
      log colour scale;
    * ``ax_gamma``: 2D histogram of log density against log pressure (3D
      cells) or against log sound speed (map pixels).  A draggable line (or a
      fit) gives the Gamma shown in the title;
    * ``ax_pdf``: normalised histogram of ``log10`` of the selected map
      pixels, with a linear fit of its logarithm;
    * ``ax_profile``: ``log10(rho)`` sampled along a draggable segment on the
      map.  It is overlaid with the ``log10(rho)`` of the 3D cells whose
      (x, y) window contains the first end of that segment, ordered by z.

    The selection is stored in ``self.mask`` / ``self.mask_flat`` (map
    pixels) and ``self.mask3d`` (3D cells).  It starts as the annulus
    ``rmin <= r <= rmax`` and can be refined with the lasso/polygon selectors.
    A selection is intersected with the current one, or merged with it when
    the "Union" box is ticked.
    """

    def update_rmin(self, val):
        """
        Slider callback for the inner radius.

        ``val`` is ignored: the radius is read back from ``self.srmin`` and
        the whole selection is rebuilt.
        """
        self.rmin = self.srmin.val
        self.clicked_reset(None)

    def update_rmax(self, val):
        """
        Slider callback for the outer radius (same behaviour as update_rmin).
        """
        self.rmax = self.srmax.val
        self.clicked_reset(None)

    def draw(self, first=False, update_map=True):
        """
        Redraw the four panels.

        Parameters
        ----------
        first : bool
            Create the artists (initial call) instead of updating existing ones.
        update_map : bool
            Rebuild the masked map and the PDF panel, and refit Gamma when the
            "Fit Gamma" box is ticked.  With ``update_map=False`` (used after
            a line is dragged), the map keeps its data and the PDF panel is
            left untouched.  Gamma is then read from the draggable line.

        Any active lasso/polygon selector is cancelled at the end.
        """
        if first or update_map:
            self.fmap = np.copy(self.datamap)
            self.frho = np.copy(self.frho_map)
            self.fcs = np.copy(self.fcs_map)
            # unselected pixels are hidden from the map and from the PDF
            self.fmap[np.logical_not(self.mask)] = np.nan
        self._draw_map(first)
        self._draw_gamma(first, update_map)
        if first or update_map:
            self._draw_pdf(first)
        self._draw_profile(first)
        plt.draw()
        self.unselect()

    def _draw_map(self, first):
        """Show ``self.fmap``, creating the image on the first call."""
        plt.sca(self.ax_fluct)
        if first:
            self.im = plt.imshow(
                self.fmap,
                origin="lower",
                cmap="RdBu_r",
                extent=self.im_extent,
                # limits go inside the norm: passing norm together with
                # vmin/vmax is rejected by recent Matplotlib versions
                norm=mpl.colors.LogNorm(vmin=1e-2, vmax=1e2),
            )
            plt.title("Fluctuations")
        else:
            self.im.set_data(self.fmap)

    def _draw_gamma(self, first, update_map):
        """Histogram the selected values and update the Gamma line and title."""
        plt.sca(self.ax_gamma)
        use_3d, fit_gamma = self.gamma_3d_button.get_status()[:2]
        if use_3d:
            rho, cs = self.rho3d[self.mask3d], self.P3d[self.mask3d]
            plt.xlabel(r"$\log(\rho)$")
            plt.ylabel(r"$\log(P)$")
            # the slope of log(P) against log(rho) is Gamma itself
            self.get_gamma = lambda a: a
        else:
            rho = self.frho_map[self.mask]
            cs = self.fcs_map[self.mask]
            plt.xlabel(r"$\log(\rho / \bar{\rho})$")
            plt.ylabel(r"$\log(c_s / \bar{c_s})$")
            # for a polytropic gas c_s^2 ~ rho^(Gamma - 1): slope = (Gamma - 1) / 2
            self.get_gamma = lambda a: 2 * a + 1
        # log10 of zero/negative map values gives -inf/NaN, which hist2d rejects
        finite = np.isfinite(rho) & np.isfinite(cs)
        rho, cs = rho[finite], cs[finite]

        if not first:
            self.gamma_hist.remove()
        _, _, _, self.gamma_hist = plt.hist2d(
            rho,
            cs,
            bins=100,
            norm=mpl.colors.LogNorm(),
            cmap=plt.get_cmap("plasma"),
        )

        if update_map and fit_gamma:
            # only samples with log(rho) > 2 enter the fit
            lr = linregress(rho[rho > 2], cs[rho > 2])
            a, b = lr[0], lr[1]
            rhomin, rhomax = np.min(rho), np.max(rho)
            self.line_gamma.clear()
            self.line_gamma = DraggableLine(
                self,
                self.ax_gamma,
                [rhomin, a * rhomin + b],
                [rhomax, a * rhomax + b],
                0.05,
            )
            print("Gamma linregress : {}".format(lr))
        else:
            lps = self.line_gamma.list_points
            xa, ya, xb, yb = lps[0].x, lps[0].y, lps[1].x, lps[1].y
            a = (yb - ya) / (xb - xa)
        self.ax_gamma.set_title(r"$\Gamma$ = {:.3g}".format(self.get_gamma(a)))

    def _draw_pdf(self, first):
        """Histogram log10 of the selected map pixels and fit its logarithm."""
        plt.sca(self.ax_pdf)
        pdf_params = self.pp.pp_params.pdf
        nb_cells = np.sum(self.mask_flat)
        if first:
            # every later PDF is normalised by the initial number of cells
            self.std_nb_cells = nb_cells
        log_values = np.log10(self.fmap.flatten()[self.mask_flat])
        weights = np.ones(nb_cells) / self.std_nb_cells
        if first:
            values, self.edges = np.histogram(
                log_values, pdf_params.nb_bin, weights=weights
            )
            plt.xlabel(r"$\log(\Sigma / \bar{\Sigma})$")
            plt.ylabel(r"$\mathcal{P}_\Sigma$")
        else:
            # reuse the initial bins so successive PDFs are comparable
            values, _ = np.histogram(log_values, self.edges, weights=weights)
        edges = self.edges
        centers = 0.5 * (edges[1:] + edges[:-1])
        mask_fit = (
            (centers > pdf_params.xmin_fit)
            & (centers < pdf_params.xmax_fit)
            & (values > 0)
        )
        (a, b, r, _, _) = linregress(centers[mask_fit], np.log10(values[mask_fit]))
        fit_values = 10 ** (a * centers + b)
        label = r"a = {:.3g}, $R^2$ = {:.3g}".format(a, r ** 2)
        if first:
            # faint copy of the initial PDF, never updated, kept as a reference
            plt.step(centers, values, where="mid", alpha=0.3)
            (self.step,) = plt.step(centers, values, where="mid")
            (self.fit,) = plt.plot(
                centers, fit_values, "--", color="navy", label=label
            )
            plt.yscale("log")
            plt.title("PDF")
        else:
            self.step.set_ydata(values)
            self.fit.set_ydata(fit_values)
            self.fit.set_label(label)
        plt.legend()

    def _draw_profile(self, first):
        """Plot log10(rho) along the profile segment and fit it against log(r)."""
        plt.sca(self.ax_profile)
        zoom = self.pp.pp_params.pymses.zoom
        lps = self.line_profile.list_points
        xa, ya, xb, yb = lps[0].x, lps[0].y, lps[1].x, lps[1].y
        xpp, ypp = line(*self._coords_to_pixels(xa, ya, xb, yb))
        rho_prof = self.rho_map[ypp, xpp]
        # back from pixel indices to the coordinates of the map extent
        xp = ((xpp / float(self.shape[0])) - 0.5) / zoom
        yp = ((ypp / float(self.shape[1])) - 0.5) / zoom
        x = self._profile_abscissa(xp, yp, xa, ya)

        mask_fit_prof = (x >= self.fit_prof_vmin) & (x <= self.fit_prof_vmax)
        try:
            (a, b, r, _, _) = linregress(
                np.log10(x[mask_fit_prof]), rho_prof[mask_fit_prof]
            )
        except ValueError as e:
            print("Warning in linregress : {}".format(e))
            a, b, r = np.nan, np.nan, np.nan
        x_fit = x[mask_fit_prof]
        y_fit = a * np.log10(x_fit) + b
        label = r"a = {:.3g}, $R^2$ = {:.3g}".format(a, r ** 2)

        z_vert, rho_vert = self._vertical_profile(xa, ya)

        if first:
            (self.prof,) = plt.semilogx(x, rho_prof)
            (self.prof_z,) = plt.semilogx(z_vert, rho_vert)
            plt.title("Profil")
            (self.fit_prof,) = plt.plot(
                x_fit, y_fit, "--", color="navy", label=label
            )
            self.ax_profile.set_xlabel(r"$r$")
            self.ax_profile.set_ylabel(r"$\log(\rho)$")
        else:
            self.prof.set_data(x, rho_prof)
            self.prof_z.set_data(z_vert, rho_vert)
            self.fit_prof.set_data(x_fit, y_fit)
            self.fit_prof.set_label(label)
        plt.legend()

    def _coords_to_pixels(self, xa, ya, xb, yb):
        """Convert segment end coordinates to integer pixel indices.

        Returns ``[xa, ya, xb, yb]`` in pixels, clipped to the map, so that an
        end dragged onto or past the border neither raises IndexError nor
        wraps around through negative indices.
        """
        zoom = self.pp.pp_params.pymses.zoom
        pixels = ((np.array([xa, ya, xb, yb]) * zoom + 0.5) * self.shape[0]).astype(int)
        # x pixels index axis 1 of the maps, y pixels index axis 0
        upper = np.array([self.shape[1], self.shape[0], self.shape[1], self.shape[0]]) - 1
        return np.clip(pixels, 0, upper)

    @staticmethod
    def _profile_abscissa(xp, yp, xa, ya):
        """Distance of each sampled pixel from end A, shifted to start at 0."""
        dist = np.hypot(xp - xa, yp - ya)
        return dist - float(dist[0])

    def _vertical_profile(self, xv, yv):
        """Return ``(z, log10(rho))`` of the 3D cells whose column holds (xv, yv).

        Results are sorted by increasing z.
        """
        xv, yv = float(xv), float(yv)
        dx = self.pp.cells["dx"]
        x3d, y3d = self.xy3d[:, 0], self.xy3d[:, 1]
        # NOTE(review): the window [pos, pos + dx) treats pos as the lower
        # corner of a cell; if cells["pos"] holds centres it should be
        # [pos - dx/2, pos + dx/2) — confirm.
        mask_vert = (
            (xv >= x3d) & (xv < x3d + dx) & (yv >= y3d) & (yv < y3d + dx)
        )
        rho_vert = np.log10(self.pp.cells["rho"][mask_vert])
        z_vert = self.pp.cells["pos"][mask_vert][:, 2] - 0.5
        order = np.argsort(z_vert)
        return z_vert[order], rho_vert[order]

    def onselect(self, verts, select_data):
        """Selector callback: combine the drawn region with the selection.

        ``select_data`` holds the 2D points (one per map pixel) tested against
        the region.  The result is intersected with ``self.mask``, or merged
        with it when "Union" is ticked.  In 3D mode the cells are also tested,
        through their (x, y) positions.
        """
        path = Path(verts)
        status = self.gamma_3d_button.get_status()
        union = status[2]
        inside = path.contains_points(select_data)
        self.mask_flat = self.mask.flatten()
        if union:
            self.mask_flat = self.mask_flat | inside
        else:
            self.mask_flat = self.mask_flat & inside
        self.mask = self.mask_flat.reshape(self.shape)
        if status[0]:
            inside3d = path.contains_points(self.xy3d)
            if union:
                self.mask3d = self.mask3d | inside3d
            else:
                self.mask3d = self.mask3d & inside3d
        self.draw()

    def clicked_select(self, val, selector, button):
        """Button callback: start a lasso/polygon selection, or cancel it.

        The selector is installed on the map.  Outside 3D mode it is also
        installed on the Gamma histogram, where pixels are tested through
        their (log rho, log c_s) values.
        """
        if self.lasso:
            self.unselect()
            return
        self.lasso = selector(
            self.ax_fluct, partial(self.onselect, select_data=self.xy)
        )
        if not self.gamma_3d_button.get_status()[0]:
            self.lasso_gamma = selector(
                self.ax_gamma, partial(self.onselect, select_data=self.rhocs)
            )
        button.color = "0.55"

    def unselect(self):
        """Cancel the active selectors and restore the button colours."""
        for active in (self.lasso, self.lasso_gamma):
            if active:
                # stop the discarded widget from reacting to mouse events
                active.disconnect_events()
        self.lasso = False
        self.lasso_gamma = False
        self.lasso_button.color = "0.85"
        self.poly_button.color = "0.85"

    def _reset_masks(self):
        """Select every pixel and finite 3D cell with rmin <= r <= rmax."""
        self.mask = (self.rr >= self.rmin) & (self.rr <= self.rmax)
        self.mask_flat = self.mask.flatten()
        self.mask3d = (
            np.isfinite(self.rho3d)
            & np.isfinite(self.P3d)
            & (self.rr3d >= self.rmin)
            & (self.rr3d <= self.rmax)
        )

    def clicked_reset(self, val):
        """Reset button callback: drop lasso refinements and redraw."""
        self._reset_masks()
        self.unselect()
        self.draw()

    def fit_selector(self, vmin, vmax):
        """SpanSelector callback: set the radius range of the profile fit."""
        self.fit_prof_vmin = vmin
        self.fit_prof_vmax = vmax
        self.draw()

    def __init__(self, num, path="./", pp_params=None, datamap_key="fluct_coldens_z"):
        """
        Interactive plotting

        Load snapshot ``num`` from ``path`` through ``PostProcessor``, build
        the figure, widgets and callbacks, draw everything and call
        ``plt.show()``.

        Parameters
        ----------
        num : snapshot identifier passed to PostProcessor
        path : str
            Directory passed to PostProcessor.
        pp_params : parameters object or None
            When None, parameters from ``default_params()`` are used, with the
            overrides below.
        datamap_key : str
            Name of the map under ``/maps/`` shown in the first panel.
        """
        if pp_params is None:
            # NOTE(review): the source this was restored from had lost its
            # indentation; these overrides are assumed to apply to the
            # default parameters only — confirm against the original.
            pp_params = default_params()
            pp_params.input.nml_filename = "disk.nml"
            pp_params.out.interactive = True
            pp_params.pymses.map_size = 4096
            pp_params.pymses.zoom = 4
            pp_params.pymses.variables = ["rho", "vel", "P"]
            pp_params.disk.enable = True
            pp_params.disk.nb_bin = 200
            pp_params.pdf.nb_bin = 100
        self.fig = plt.figure(figsize=(10, 8))
        self.pp = PostProcessor(path, num, pp_params=pp_params, tag="interactive")
        self.pp.pdf_coldens("z")
        self.pp.pdf_rho("z")
        self.pp.pdf_T("z")
        self.pp.load_cells()

        # --- 2D maps
        self.im_extent = np.array(self.pp.get_attribute("/maps", "im_extent")) - 0.5
        self.datamap = self.pp.get_value("/maps/" + datamap_key)
        self.rho_map = np.log10(self.pp.get_value("/maps/rho_z"))
        self.frho_map = np.log10(self.pp.get_value("/maps/fluct_rho_z"))
        self.T_map = self.pp.get_value("/maps/T_z")
        self.fcs_map = np.log10(np.sqrt(self.pp.get_value("/maps/fluct_T_z")))
        self.shape = self.datamap.shape
        x = np.linspace(self.im_extent[0], self.im_extent[1], self.shape[0])
        y = np.linspace(self.im_extent[2], self.im_extent[3], self.shape[1])
        self.xx, self.yy = np.meshgrid(x, y)
        self.xy = np.column_stack((self.xx.flatten(), self.yy.flatten()))
        self.rr = np.sqrt(self.xx ** 2 + self.yy ** 2)
        self.rhocs = np.column_stack((self.frho_map.flatten(), self.fcs_map.flatten()))
        self.rmin = self.pp.pp_params.disk.rmin_pdf / 2.0
        self.rmax = self.pp.pp_params.disk.rmax_pdf / 2.0

        # --- 3D cells (positions recentred like the map extent)
        self.rho3d = np.log10(self.pp.cells["rho"])
        self.P3d = np.log10(self.pp.cells["P"])
        self.xy3d = self.pp.cells["pos"][:, :2] - 0.5
        self.rr3d = np.sqrt(np.sum((self.xy3d) ** 2, axis=1))
        self._reset_masks()

        # --- panels
        self.ax_fluct = plt.axes([0.05, 0.6, 0.4, 0.35])
        self.ax_gamma = plt.axes([0.05, 0.15, 0.4, 0.35])
        self.ax_pdf = plt.axes([0.55, 0.6, 0.4, 0.35])
        self.ax_profile = plt.axes([0.55, 0.15, 0.4, 0.35])

        # --- profile fit range, chosen with a span selector
        self.fit_prof_vmin = 1e-2
        self.fit_prof_vmax = 1e-1
        self.fit_profile_selector = SpanSelector(
            self.ax_profile,
            self.fit_selector,
            "horizontal",
            useblit=True,
            span_stays=True,
        )
        self.line_profile = DraggableLine(self, self.ax_fluct, [0, 0], [0.1, 0], 0.005)
        self.line_gamma = DraggableLine(self, self.ax_gamma, [0, 0], [1, 0.3], 0.05)

        # --- widgets
        ax_rmax = plt.axes([0.05, 0.07, 0.15, 0.02])
        self.srmax = Slider(
            ax_rmax, r"$r_{max}$", 0, 0.7 / pp_params.pymses.zoom, valinit=self.rmax
        )
        ax_rmin = plt.axes([0.05, 0.03, 0.15, 0.02])
        self.srmin = Slider(
            ax_rmin, r"$r_{min}$", 0, 0.7 / pp_params.pymses.zoom, valinit=self.rmin
        )
        ax_lasso = plt.axes([0.6, 0.07, 0.19, 0.02])
        ax_poly = plt.axes([0.8, 0.07, 0.19, 0.02])
        ax_gamma_3d = plt.axes([0.3, 0.01, 0.19, 0.09])
        # boxes: [0] 3D cells vs map pixels, [1] fit Gamma, [2] union selection
        self.gamma_3d_button = CheckButtons(
            ax_gamma_3d, [r"3D $\Gamma$", r"Fit $\Gamma$", "Union"], [True, False, False]
        )
        self.gamma_3d_button.on_clicked(
            lambda val: self.draw(first=False, update_map=True)
        )
        self.lasso = False
        self.lasso_gamma = False
        self.lasso_button = Button(ax_lasso, "Lasso selector")
        self.poly_button = Button(ax_poly, "Polygon selector")
        self.lasso_button.on_clicked(
            partial(
                self.clicked_select, selector=LassoSelector, button=self.lasso_button
            )
        )
        self.poly_button.on_clicked(
            partial(
                self.clicked_select, selector=PolygonSelector, button=self.poly_button
            )
        )
        ax_reset = plt.axes([0.6, 0.03, 0.19, 0.02])
        self.reset_button = Button(ax_reset, "Reset")
        self.reset_button.on_clicked(self.clicked_reset)
        self.srmin.on_changed(self.update_rmin)
        self.srmax.on_changed(self.update_rmax)

        self.draw(first=True, update_map=True)
        plt.show()