Files
pipeline/gui.py
T
2020-12-14 16:46:54 +01:00

458 lines
16 KiB
Python

from functools import partial

import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, Button, RadioButtons, LassoSelector, SpanSelector
import matplotlib.patches as patches
from matplotlib.lines import Line2D
from matplotlib.path import Path
import numpy as np
from scipy.stats import linregress
from postprocessor import *
from skimage.draw import line
class DraggablePoint:
    """
    Draggable elliptical handle, used in pairs by DraggableLine.

    Adapted from
    http://stackoverflow.com/questions/21654008/matplotlib-drag-overlapping-points-interactively

    The first handle added to a parent is green. The second one is red and
    owns ``self.line``, the black Line2D joining it to the first handle.
    Dragging uses blitting: the axes are rendered once without the animated
    handle and line, that pixel buffer is cached, and only those two artists
    are repainted on each mouse motion.
    """

    lock = None  # handle currently being dragged; shared by all instances

    def __init__(self, parent, x=0.1, y=0.1, size=0.1):
        """
        Create the handle on ``parent.ax`` and connect its mouse callbacks.

        parent: owner exposing ``ax``, ``list_points`` and ``draw()``.
        x, y: initial centre, in data coordinates of ``parent.ax``.
        size: ellipse width; the height is ``3 * size``.
        """
        self.parent = parent
        # The parent appends handles in order, so a non-empty list means this
        # is the second handle, which carries the connecting line.
        is_second = bool(self.parent.list_points)
        color = "r" if is_second else "g"
        self.point = patches.Ellipse(
            (x, y), size, size * 3, fc=color, alpha=0.5, edgecolor=color
        )
        self.x = x
        self.y = y
        self.parent.ax.add_patch(self.point)
        self.press = None
        self.background = None
        self.connect()
        if is_second:
            first = self.parent.list_points[0]
            self.line = Line2D(
                [first.x, self.x], [first.y, self.y], color="k", alpha=0.5
            )
            self.parent.ax.add_line(self.line)

    def _segment(self):
        """Return the Line2D joining the parent's two handles (owned by the second)."""
        return self.parent.list_points[1].line

    def _sync_segment(self):
        """Move the connecting line's ends to the handles' current positions."""
        first, second = self.parent.list_points[0], self.parent.list_points[1]
        self._segment().set_data([first.x, second.x], [first.y, second.y])

    def connect(self):
        "connect to all the events we need"
        canvas = self.point.figure.canvas
        self.cidpress = canvas.mpl_connect("button_press_event", self.on_press)
        self.cidrelease = canvas.mpl_connect("button_release_event", self.on_release)
        self.cidmotion = canvas.mpl_connect("motion_notify_event", self.on_motion)

    def on_press(self, event):
        """Start a drag if the click hit this handle and no other drag is active."""
        if event.inaxes != self.point.axes:
            return
        if DraggablePoint.lock is not None:
            return
        contains, _ = self.point.contains(event)
        if not contains:
            return
        self.press = self.point.center, event.xdata, event.ydata
        DraggablePoint.lock = self
        canvas = self.point.figure.canvas
        axes = self.point.axes
        segment = self._segment()
        # Render everything except the animated artists and cache that buffer.
        self.point.set_animated(True)
        segment.set_animated(True)
        canvas.draw()
        self.background = canvas.copy_from_bbox(axes.bbox)
        # canvas.draw() skips animated artists: paint the handle AND its line
        # back, otherwise the line disappears until the first motion event.
        axes.draw_artist(self.point)
        axes.draw_artist(segment)
        canvas.blit(axes.bbox)

    def on_motion(self, event):
        """Move the handle with the mouse and repaint it and its line by blitting."""
        if DraggablePoint.lock is not self:
            return
        if event.inaxes != self.point.axes:
            return
        center, xpress, ypress = self.press
        self.point.center = (
            center[0] + event.xdata - xpress,
            center[1] + event.ydata - ypress,
        )
        self.x, self.y = self.point.center
        # Update the line BEFORE drawing it so it does not lag one event behind.
        self._sync_segment()
        canvas = self.point.figure.canvas
        axes = self.point.axes
        canvas.restore_region(self.background)
        axes.draw_artist(self.point)
        axes.draw_artist(self._segment())
        canvas.blit(axes.bbox)

    def on_release(self, event):
        """End the drag: stop animating, redraw the figure and notify the parent."""
        if DraggablePoint.lock is not self:
            return
        self.press = None
        DraggablePoint.lock = None
        self.point.set_animated(False)
        self._segment().set_animated(False)
        self.background = None
        self.point.figure.canvas.draw()
        self.x, self.y = self.point.center
        self.parent.draw()

    def disconnect(self):
        "disconnect all the stored connection ids"
        canvas = self.point.figure.canvas
        canvas.mpl_disconnect(self.cidpress)
        canvas.mpl_disconnect(self.cidrelease)
        canvas.mpl_disconnect(self.cidmotion)
class DraggableLine:
    """
    Segment on ``ax`` whose two ends are DraggablePoint handles.

    parent: owner whose ``draw(update_map=False)`` is invoked through
        ``draw()`` when a handle is released.
    ax: axes the handles and the connecting line are added to.
    xy1, xy2: initial (x, y) positions of the two ends, in data coordinates.
    size: handle size forwarded to DraggablePoint.
    """

    def __init__(self, parent, ax, xy1, xy2, size):
        self.parent = parent
        self.ax = ax
        # DraggablePoint inspects list_points while it is being built (to pick
        # its colour and link to the first handle), so the list must exist and
        # be filled one handle at a time.
        self.list_points = []
        for end in (xy1, xy2):
            handle = DraggablePoint(self, end[0], end[1], size)
            self.list_points.append(handle)

    def draw(self):
        """Ask the owner to refresh its plots without recomputing the masked map."""
        self.parent.draw(update_map=False)
class InteractiveGUI:
    """
    Matplotlib interactive session to restrict the analysis to a specific area.

    Four panels are drawn:

    * top left: the map ``datamap`` with cells outside the selection set to NaN;
    * bottom left: 2D histogram of log10(fluct_rho) vs log10(sqrt(fluct_T)) of
      the selected cells, with a draggable line whose slope ``a`` is shown as
      Gamma = 2a + 1;
    * top right: histogram of log10(map values) of the selected cells, with a
      linear fit of log10(PDF) for ``pdf.xmin_fit < x < pdf.xmax_fit``;
    * bottom right: log10(fluct_rho) sampled along a draggable line on the map,
      fitted against log10(r) over a range picked with a SpanSelector.

    The selection is ``rmin < rr < rmax`` (sliders) intersected with every
    lasso drawn since the last reset.
    """

    # ------------------------------------------------------------------ mask
    def _update_mask(self):
        """Rebuild ``mask``/``mask_flat`` from the radial cut and the lasso selection."""
        radial = (self.rr > self.rmin) & (self.rr < self.rmax)
        self.mask = radial & self.selection_mask
        self.mask_flat = self.mask.flatten()

    def update_rmin(self, val):
        """Slider callback: apply the new minimum radius and redraw."""
        self.rmin = self.srmin.val
        # Rebuild the mask now; draw() only reads it.
        self._update_mask()
        self.draw()

    def update_rmax(self, val):
        """Slider callback: apply the new maximum radius and redraw."""
        self.rmax = self.srmax.val
        self._update_mask()
        self.draw()

    # --------------------------------------------------------------- drawing
    def draw(self, first=False, update_map=True):
        """
        Draw (``first=True``) or refresh all panels.

        first: create the artists instead of updating existing ones.
        update_map: recompute everything that depends on the mask; the
            draggable lines pass False since only their own panels change.
        """
        if first or update_map:
            self.fmap = np.copy(self.datamap)
            self.frho = np.copy(self.frho_map)
            self.fcs = np.copy(self.fcs_map)
            self.fmap[np.logical_not(self.mask)] = np.nan
            self._draw_map(first)
            # The histogram depends only on the mask, so it is not rebuilt
            # when a draggable handle moves.
            self._draw_gamma_hist(first)
            self._draw_pdf(first)
        self._draw_gamma_title()
        self._draw_profile(first)
        plt.draw()
        self.unselect_lasso()

    def _draw_map(self, first):
        """Show the masked map with a logarithmic colour scale."""
        plt.sca(self.ax_fluct)
        if first:
            # vmin/vmax go into the norm: passing them next to ``norm`` is
            # deprecated in matplotlib 3.3 and an error in later versions.
            self.im = plt.imshow(
                self.fmap,
                origin="lower",
                cmap="RdBu_r",
                norm=mpl.colors.LogNorm(vmin=1e-2, vmax=1e2),
            )
            plt.title("Fluctuations")
        else:
            self.im.set_data(self.fmap)

    def _draw_gamma_hist(self, first):
        """(Re)build the log(rho)-log(c_s) 2D histogram of the selected cells."""
        plt.sca(self.ax_gamma)
        if not first:
            self.gamma_hist.remove()
        _, _, _, self.gamma_hist = plt.hist2d(
            self.frho_map[self.mask],
            self.fcs_map[self.mask],
            bins=100,
            norm=mpl.colors.LogNorm(),
            cmap=plt.get_cmap("plasma"),
        )
        if first:
            plt.xlabel(r"$\log(\rho / \bar{\rho})$")
            plt.ylabel(r"$\log(c_s / \bar{c_s})$")

    def _draw_gamma_title(self):
        """Show Gamma = 2a + 1, with ``a`` the slope of the gamma line."""
        lps = self.line_gamma.list_points
        a = self._slope(lps[0].x, lps[0].y, lps[1].x, lps[1].y)
        # Presumably from c_s^2 ∝ rho^(Gamma - 1), i.e. a = (Gamma - 1) / 2.
        self.ax_gamma.set_title(r"$\Gamma$ = {:.3g}".format(2 * a + 1))

    def _draw_pdf(self, first):
        """Histogram log10 of the selected map values and fit a power law."""
        plt.sca(self.ax_pdf)
        nb_cells = np.count_nonzero(self.mask_flat)
        log_values = np.log10(self.fmap.flatten()[self.mask_flat])
        if first:
            # Later histograms are divided by the initial cell count, so a
            # sub-selection shows its share of the initial selection.
            self.std_nb_cells = nb_cells
            bins = self.pp.pp_params.pdf.nb_bin
        else:
            bins = self.edges
        values, edges = np.histogram(
            log_values, bins, weights=np.ones(nb_cells) / self.std_nb_cells
        )
        if first:
            self.edges = edges
            plt.xlabel(r"$\log(\Sigma / \bar{\Sigma})$")
            plt.ylabel(r"$\mathcal{P}_\Sigma$")
        centers = 0.5 * (edges[1:] + edges[:-1])
        pdf_params = self.pp.pp_params.pdf
        mask_fit = (
            (centers > pdf_params.xmin_fit)
            & (centers < pdf_params.xmax_fit)
            & (values > 0)
        )
        a, b, r = self._safe_linregress(centers[mask_fit], np.log10(values[mask_fit]))
        fit_values = 10 ** (a * centers + b)
        label = r"a = {:.3g}, $R^2$ = {:.3g}".format(a, r ** 2)
        if first:
            # Faint copy of the initial PDF, kept as a reference (never updated).
            plt.step(centers, values, where="mid", alpha=0.3)
            (self.step,) = plt.step(centers, values, where="mid")
            (self.fit,) = plt.plot(
                centers, fit_values, "--", color="navy", label=label
            )
            plt.yscale("log")
            plt.title("PDF")
        else:
            self.step.set_ydata(values)
            self.fit.set_ydata(fit_values)
            self.fit.set_label(label)
        plt.legend()

    def _draw_profile(self, first):
        """Plot log10(fluct_rho) along the profile line and fit it against log10(r)."""
        plt.sca(self.ax_profile)
        lps = self.line_profile.list_points
        xa, ya, xb, yb = lps[0].x, lps[0].y, lps[1].x, lps[1].y
        xp, yp = line(int(xa), int(ya), int(xb), int(yb))
        # Distance from the first pixel of the line, in units of the map size.
        x = self._profile_positions(xp, yp, self.shape[0])
        # Handles can be dragged off the image: drop pixels outside the map
        # (negative indices would otherwise wrap around silently).
        inside = (xp >= 0) & (xp < self.shape[1]) & (yp >= 0) & (yp < self.shape[0])
        xp, yp, x = xp[inside], yp[inside], x[inside]
        rho_prof = self.frho_map[yp, xp]
        mask_fit_prof = (x >= self.fit_prof_vmin) & (x <= self.fit_prof_vmax)
        x_fit = x[mask_fit_prof]
        a, b, r = self._safe_linregress(np.log10(x_fit), rho_prof[mask_fit_prof])
        y_fit = a * np.log10(x_fit) + b
        label = r"a = {:.3g}, $R^2$ = {:.3g}".format(a, r ** 2)
        if first:
            (self.prof,) = plt.semilogx(x, rho_prof)
            plt.title("Profil")
            (self.fit_prof,) = plt.plot(x_fit, y_fit, "--", color="navy", label=label)
            self.ax_profile.set_xlabel(r"$r$")
            self.ax_profile.set_ylabel(r"$\log(\rho)$")
        else:
            self.prof.set_data(x, rho_prof)
            self.fit_prof.set_data(x_fit, y_fit)
            self.fit_prof.set_label(label)
            # set_data() does not touch the data limits; rescale to the new profile.
            self.ax_profile.relim()
            self.ax_profile.autoscale_view()
        plt.legend()

    # --------------------------------------------------------- pure helpers
    @staticmethod
    def _slope(xa, ya, xb, yb):
        """Slope of the segment (xa, ya)-(xb, yb); NaN when it is vertical."""
        dx = xb - xa
        if dx == 0:
            return np.nan
        return (yb - ya) / dx

    @staticmethod
    def _safe_linregress(x, y):
        """Return (slope, intercept, r) of ``linregress(x, y)``, or NaNs if it fails."""
        try:
            a, b, r, _, _ = linregress(x, y)
        except ValueError as e:  # e.g. empty selection or constant x
            print("Warning in linregress : {}".format(e))
            return np.nan, np.nan, np.nan
        return a, b, r

    @staticmethod
    def _profile_positions(xp, yp, scale):
        """Euclidean distance of each pixel from the first one, divided by ``scale``."""
        xp = np.asarray(xp, dtype=float)
        yp = np.asarray(yp, dtype=float)
        if xp.size == 0:
            return xp
        return np.hypot(xp - xp[0], yp - yp[0]) / scale

    # ------------------------------------------------------------- widgets
    def onselect(self, verts, select_data):
        """Lasso callback: keep only the cells whose ``select_data`` point is inside."""
        inside = Path(verts).contains_points(select_data).reshape(self.shape)
        self.selection_mask &= inside
        self._update_mask()
        self.draw()

    def clicked_select(self, val):
        """Toggle lasso selection on the map and on the gamma histogram."""
        if self.lasso:
            self.unselect_lasso()
            return
        self.lasso = LassoSelector(
            self.ax_fluct, partial(self.onselect, select_data=self.points)
        )
        self.lasso_gamma = LassoSelector(
            self.ax_gamma, partial(self.onselect, select_data=self.rhocs)
        )
        self.lasso_button.color = "0.55"

    def unselect_lasso(self):
        """Disconnect any active lasso selectors and restore the button colour."""
        for selector in (self.lasso, self.lasso_gamma):
            if selector:
                # Disconnect explicitly rather than relying on garbage collection.
                selector.disconnect_events()
        self.lasso = False
        self.lasso_gamma = False
        self.lasso_button.color = "0.85"

    def clicked_reset(self, val):
        """Drop all lasso selections, keeping only the radial cut."""
        self.selection_mask = np.ones(self.shape, dtype=bool)
        self._update_mask()
        self.unselect_lasso()
        self.draw()

    def fit_selector(self, vmin, vmax):
        """SpanSelector callback: set the profile fit range and redraw."""
        self.fit_prof_vmin = vmin
        self.fit_prof_vmax = vmax
        self.draw()

    def __init__(self, num, path="./", pp_params=None, datamap_key="fluct_coldens_z"):
        """
        Build the interactive figure for output ``num`` and show it.

        num: output number, passed to ``PostProcessor``.
        path: simulation directory, passed to ``PostProcessor``.
        pp_params: post-processing parameters; when None, ``default_params()``
            with the interactive settings below is used.
        datamap_key: name of the map under "/maps/" shown in the top-left panel.

        Ends with ``plt.show()``.
        """
        if pp_params is None:
            # NOTE(review): confirm whether these overrides should also apply
            # to caller-supplied pp_params.
            pp_params = default_params()
            pp_params.input.nml_filename = "disk.nml"
            pp_params.out.interactive = True
            pp_params.pymses.map_size = 4096
            pp_params.pymses.zoom = 4
            pp_params.pymses.variables = ["rho", "vel", "P"]
            pp_params.disk.enable = True
            pp_params.disk.nb_bin = 200
            pp_params.pdf.nb_bin = 100
        self.fig = plt.figure(figsize=(10, 8))
        self.pp = PostProcessor(path, num, pp_params=pp_params, tag="interactive")
        self.pp.pdf_coldens("z")
        self.pp.pdf_rho("z")
        self.pp.pdf_T("z")

        self.datamap = self.pp.get_value("/maps/" + datamap_key)
        self.frho_map = np.log10(self.pp.get_value("/maps/fluct_rho_z"))
        self.T_map = self.pp.get_value("/maps/T_z")
        self.fcs_map = np.log10(np.sqrt(self.pp.get_value("/maps/fluct_T_z")))
        self.rr = self.pp.get_value("/maps/rr_z")
        self.shape = self.datamap.shape

        # (x = column, y = row) of every pixel in row-major order, matching both
        # mask.flatten() and imshow's data coordinates (also for non-square maps).
        self.xx, self.yy = np.meshgrid(
            np.arange(self.shape[1]), np.arange(self.shape[0])
        )
        self.points = np.column_stack((self.xx.flatten(), self.yy.flatten()))
        self.rhocs = np.column_stack((self.frho_map.flatten(), self.fcs_map.flatten()))

        self.rmin = self.pp.pp_params.disk.rmin_pdf
        self.rmax = self.pp.pp_params.disk.rmax_pdf

        self.ax_fluct = plt.axes([0.05, 0.6, 0.4, 0.35])
        self.ax_gamma = plt.axes([0.05, 0.15, 0.4, 0.35])
        self.ax_pdf = plt.axes([0.55, 0.6, 0.4, 0.35])
        self.ax_profile = plt.axes([0.55, 0.15, 0.4, 0.35])

        self.fit_prof_vmin = 1e-2
        self.fit_prof_vmax = 1e-1
        self.fit_profile_selector = SpanSelector(
            self.ax_profile,
            self.fit_selector,
            "horizontal",
            useblit=True,
            span_stays=True,
        )
        # Initial handle positions are in pixels; presumably sized for the
        # default 4096-px map — TODO confirm for other map sizes.
        self.line_profile = DraggableLine(
            self, self.ax_fluct, [100, 2000], [2000, 2000], 50
        )
        self.line_gamma = DraggableLine(self, self.ax_gamma, [0, 0], [1, 0.3], 0.03)

        # Lasso selections accumulate here; the radial cut is applied on top.
        self.selection_mask = np.ones(self.shape, dtype=bool)
        self._update_mask()

        ax_rmax = plt.axes([0.1, 0.03, 0.4, 0.02])
        self.srmax = Slider(ax_rmax, r"$r_{max}$", 0, 0.5, valinit=self.rmax)
        ax_rmin = plt.axes([0.1, 0.07, 0.4, 0.02])
        self.srmin = Slider(ax_rmin, r"$r_{min}$", 0, 0.5, valinit=self.rmin)

        ax_lasso = plt.axes([0.6, 0.07, 0.2, 0.02])
        self.lasso = False
        self.lasso_gamma = False
        self.lasso_button = Button(ax_lasso, "Select region")
        self.lasso_button.on_clicked(self.clicked_select)

        ax_reset = plt.axes([0.6, 0.03, 0.2, 0.02])
        self.reset_button = Button(ax_reset, "Reset")
        self.reset_button.on_clicked(self.clicked_reset)

        self.srmin.on_changed(self.update_rmin)
        self.srmax.on_changed(self.update_rmax)
        self.draw(True)
        plt.show()