# Files > pipeline/fragdisk/gui.py
# T, 2023-01-30 12:12:14 +01:00
# 562 lines, 19 KiB, Python

import matplotlib as mpl
import numpy as np
from functools import partial
import matplotlib.patches as patches
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.path import Path
from matplotlib.widgets import (
Button,
CheckButtons,
LassoSelector,
PolygonSelector,
Slider,
SpanSelector,
)
from scipy.stats import linregress
from skimage.draw import line
from snapshotprocessor import SnapshotProcessor
from utils.params import default_params
class DraggablePoint:
    """Ellipse handle that can be dragged with the mouse.

    Adapted from
    http://stackoverflow.com/questions/21654008/matplotlib-drag-overlapping-points-interactively

    Handles come in pairs held by ``parent.list_points``. The first handle is
    green; the second is red and owns ``self.line``, the ``Line2D`` joining
    the pair. When a drag starts, the handle and the line are marked animated.
    The canvas is then rendered once without them and that background is
    stored. On each mouse motion only the handle and the line are redrawn on
    top of it and blitted.

    Parameters
    ----------
    parent : object
        Must expose ``ax`` (axes the artists are added to), ``list_points``
        (handles already created for this pair) and ``draw()`` (called once
        a drag ends).
    x, y : float
        Initial centre, in data coordinates.
    size : float
        Ellipse width; its height is ``3 * size``.
    """

    # Handle currently being dragged; only one can be animated at a time.
    lock = None

    def __init__(self, parent, x=0.1, y=0.1, size=0.1):
        self.parent = parent
        # DraggableLine appends each handle only after constructing it, so a
        # non-empty list means this is the second handle of the pair.
        is_second = bool(self.parent.list_points)
        color = "r" if is_second else "g"
        self.point = patches.Ellipse(
            (x, y), size, size * 3, fc=color, alpha=0.5, edgecolor=color
        )
        self.x = x
        self.y = y
        self.parent.ax.add_patch(self.point)
        self.press = None
        self.background = None
        self.connect()
        if is_second:
            first = self.parent.list_points[0]
            self.line = Line2D(
                [first.x, self.x], [first.y, self.y], color="k", alpha=0.5
            )
            self.parent.ax.add_line(self.line)

    def _segment(self):
        """Return the ``Line2D`` of the pair (owned by the second handle)."""
        return self.parent.list_points[1].line

    def _update_segment(self):
        """Make the pair's line join the current positions of both handles."""
        first, second = self.parent.list_points[0], self.parent.list_points[1]
        second.line.set_data([first.x, second.x], [first.y, second.y])

    def connect(self):
        """Connect the press, release and motion callbacks to the canvas."""
        canvas = self.point.figure.canvas
        self.cidpress = canvas.mpl_connect("button_press_event", self.on_press)
        self.cidrelease = canvas.mpl_connect(
            "button_release_event", self.on_release
        )
        self.cidmotion = canvas.mpl_connect("motion_notify_event", self.on_motion)

    def on_press(self, event):
        """Start dragging if the click hits this handle and no drag is active."""
        if event.inaxes != self.point.axes:
            return
        if DraggablePoint.lock is not None:
            return
        contains, _ = self.point.contains(event)
        if not contains:
            return
        # Centre and mouse position at press time; motion is applied
        # relative to them.
        self.press = self.point.center, event.xdata, event.ydata
        DraggablePoint.lock = self
        canvas = self.point.figure.canvas
        axes = self.point.axes
        segment = self._segment()
        # canvas.draw() skips animated artists, so the stored background
        # contains everything except the handle and the line.
        self.point.set_animated(True)
        segment.set_animated(True)
        canvas.draw()
        self.background = canvas.copy_from_bbox(axes.bbox)
        # Put both animated artists back on screen; otherwise the line would
        # stay invisible until the first motion event.
        axes.draw_artist(self.point)
        axes.draw_artist(segment)
        canvas.blit(axes.bbox)

    def on_motion(self, event):
        """Move the dragged handle with the mouse and blit it with the line."""
        if DraggablePoint.lock is not self:
            return
        if event.inaxes != self.point.axes:
            return
        (x0, y0), xpress, ypress = self.press
        dx = event.xdata - xpress
        dy = event.ydata - ypress
        self.point.center = (x0 + dx, y0 + dy)
        self.x, self.y = self.point.center
        # Update the line BEFORE drawing it, otherwise it is rendered with
        # the previous position and lags one motion event behind.
        self._update_segment()
        canvas = self.point.figure.canvas
        axes = self.point.axes
        canvas.restore_region(self.background)
        axes.draw_artist(self.point)
        axes.draw_artist(self._segment())
        canvas.blit(axes.bbox)

    def on_release(self, event):
        """End the drag, restore normal rendering and notify the parent."""
        if DraggablePoint.lock is not self:
            return
        self.press = None
        DraggablePoint.lock = None
        self.point.set_animated(False)
        self._segment().set_animated(False)
        self.background = None
        self.point.figure.canvas.draw()
        self.x, self.y = self.point.center
        self.parent.draw()

    def disconnect(self):
        """Disconnect all the stored callback ids from the canvas."""
        canvas = self.point.figure.canvas
        canvas.mpl_disconnect(self.cidpress)
        canvas.mpl_disconnect(self.cidrelease)
        canvas.mpl_disconnect(self.cidmotion)
class DraggableLine:
    """Two ``DraggablePoint`` handles joined by a line drawn on ``ax``.

    Parameters
    ----------
    parent : object
        Owner whose ``draw(update_map=False)`` is called when a handle is
        released.
    ax : matplotlib Axes
        Axes the handles and the line are added to.
    xy1, xy2 : sequence of float
        Initial data coordinates of the first (green) and second (red) handle.
    size : float
        Ellipse width of each handle (its height is three times larger).
    """

    def __init__(self, parent, ax, xy1, xy2, size):
        self.parent = parent
        self.ax = ax
        # DraggablePoint reads list_points while being built, so the handles
        # must be appended one at a time.
        self.list_points = []
        for xy in (xy1, xy2):
            self.list_points.append(DraggablePoint(self, xy[0], xy[1], size))

    def draw(self):
        """Ask the parent to redraw with ``update_map=False``."""
        self.parent.draw(update_map=False)

    def clear(self):
        """Disconnect both handles and remove them and the line from the axes."""
        for p in self.list_points:
            # Disconnect before removing: disconnect() reaches the canvas
            # through p.point.figure, which matplotlib detaches on remove().
            # Without this the handles' callbacks stay registered.
            p.disconnect()
            p.point.remove()
        self.list_points[1].line.remove()
class InteractiveGUI:
    """
    This is a matplotlib interactive session to restrain analysis to a specific area

    Four panels are shown:

    * top left: the map ``datamap_key`` restricted to the current selection,
      with a draggable segment defining the profile path;
    * bottom left: a 2D histogram used to estimate Gamma. In "3D Gamma" mode
      it shows log P vs log rho of the 3D cells; otherwise it shows
      log c_s vs log rho fluctuations from the maps. A draggable (optionally
      fitted) line gives the slope;
    * top right: the PDF of log10 of the selected map values, with a
      power-law fit;
    * bottom right: the log rho profile along the segment and along z at its
      first end point, with a fit over the span selected with the mouse.

    The selection starts as the annulus ``rmin <= r <= rmax`` (sliders). Each
    lasso/polygon selection is intersected with it, or united with it when
    "Union" is checked.
    """

    def update_rmin(self, val):
        """Slider callback: set the inner radius and reset the selection."""
        self.rmin = self.srmin.val
        self.clicked_reset(None)

    def update_rmax(self, val):
        """Slider callback: set the outer radius and reset the selection."""
        self.rmax = self.srmax.val
        self.clicked_reset(None)

    @staticmethod
    def _fit_line(x, y):
        """Least-squares fit ``y = a * x + b``.

        Returns ``(a, b, r)``. If scipy cannot fit (empty input, all x
        identical), a warning is printed and ``(nan, nan, nan)`` is returned.
        """
        try:
            a, b, r, _, _ = linregress(x, y)
        except ValueError as e:
            print("Warning in linregress : {}".format(e))
            return np.nan, np.nan, np.nan
        return a, b, r

    @staticmethod
    def _distance_along(xp, yp, xa, ya):
        """Euclidean distance of points ``(xp, yp)`` from ``(xa, ya)``,
        shifted so the first point sits at 0."""
        dist = np.hypot(
            np.asarray(xp, dtype=float) - xa, np.asarray(yp, dtype=float) - ya
        )
        return dist - dist[0]

    def _reset_masks(self):
        """Reset the 2D and 3D selections to the annulus rmin <= r <= rmax."""
        self.mask = (self.rr >= self.rmin) & (self.rr <= self.rmax)
        self.mask_flat = self.mask.flatten()
        # Cells whose log rho or log P is not finite are never selected.
        self.mask3d = np.isfinite(self.rho3d) & np.isfinite(self.P3d)
        self.mask3d = self.mask3d & (self.rr3d >= self.rmin) & (self.rr3d <= self.rmax)

    def draw(self, first=False, update_map=True):
        """Redraw all four panels, then deactivate any lasso/polygon selector.

        Parameters
        ----------
        first : bool
            Create the artists; only done by the initial call in ``__init__``.
        update_map : bool
            Recompute what depends on the selection masks: the masked map,
            the Gamma histogram, the PDF and the optional Gamma fit. It is
            False when a draggable handle was released and only lines moved.
        """
        self._draw_map(first, update_map)
        self._draw_gamma(first, update_map)
        if first or update_map:
            self._draw_pdf(first)
        self._draw_profile(first)
        plt.draw()
        self.unselect()

    def _draw_map(self, first, update_map):
        """Map panel: datamap with pixels outside the selection set to NaN."""
        if first or update_map:
            self.fmap = np.copy(self.datamap)
            self.frho = np.copy(self.frho_map)
            self.fcs = np.copy(self.fcs_map)
            self.fmap[np.logical_not(self.mask)] = np.nan
        plt.sca(self.ax_fluct)
        if first:
            # Colour limits go into the norm: matplotlib >= 3.5 rejects
            # vmin/vmax passed together with a Normalize instance.
            self.im = plt.imshow(
                self.fmap,
                origin="lower",
                cmap="RdBu_r",
                extent=self.im_extent,
                norm=mpl.colors.LogNorm(vmin=1e-2, vmax=1e2),
            )
            plt.title("Fluctuations")
        else:
            self.im.set_data(self.fmap)

    def _draw_gamma(self, first, update_map):
        """Gamma panel: 2D histogram plus the slope of the Gamma line."""
        plt.sca(self.ax_gamma)
        use_3d, fit_gamma, _ = self.gamma_3d_button.get_status()
        if use_3d:
            # Gamma is taken as the slope of log P vs log rho (P ~ rho**Gamma).
            self.get_gamma = lambda a: a
        else:
            # Slope a of log c_s vs log rho, with c_s ~ rho**((Gamma - 1) / 2).
            self.get_gamma = lambda a: 2 * a + 1
        slope = None
        if first or update_map:
            if use_3d:
                rho, cs = self.rho3d[self.mask3d], self.P3d[self.mask3d]
                plt.xlabel(r"$\log(\rho)$")
                plt.ylabel(r"$\log(P)$")
            else:
                rho = self.frho_map[self.mask]
                cs = self.fcs_map[self.mask]
                plt.xlabel(r"$\log(\rho / \bar{\rho})$")
                plt.ylabel(r"$\log(c_s / \bar{c_s})$")
            if not first:
                self.gamma_hist.remove()
            _, _, _, self.gamma_hist = plt.hist2d(
                rho,
                cs,
                bins=100,
                norm=mpl.colors.LogNorm(),
                cmap=plt.get_cmap("plasma"),
            )
            if update_map and fit_gamma:
                slope = self._refit_gamma_line(rho, cs)
        if slope is None:
            lps = self.line_gamma.list_points
            run = lps[1].x - lps[0].x
            slope = (lps[1].y - lps[0].y) / run if run != 0 else np.nan
        self.ax_gamma.set_title(r"$\Gamma$ = {:.3g}".format(self.get_gamma(slope)))

    def _refit_gamma_line(self, rho, cs):
        """Fit ``cs`` vs ``rho`` over ``rho > 2`` and replace the Gamma line.

        Returns the fitted slope. Returns None when the fit is impossible; the
        current draggable line is then kept.
        """
        # NOTE(review): hard-coded threshold. rho is log(rho / mean) in 2D
        # mode but log(rho) in 3D mode; confirm the same cut is meant.
        dense = rho > 2
        try:
            lr = linregress(rho[dense], cs[dense])
        except ValueError as e:
            print("Warning in linregress : {}".format(e))
            return None
        a, b = lr[0], lr[1]
        self.line_gamma.clear()
        rhomin, rhomax = np.min(rho), np.max(rho)
        self.line_gamma = DraggableLine(
            self,
            self.ax_gamma,
            [rhomin, a * rhomin + b],
            [rhomax, a * rhomax + b],
            0.05,
        )
        print("Gamma linregress : {}".format(lr))
        return a

    def _draw_pdf(self, first):
        """PDF panel: weighted histogram of log10(map) and its power-law fit."""
        plt.sca(self.ax_pdf)
        nb_cells = np.sum(self.mask_flat)
        if first:
            # Every later PDF is weighted by the initial cell count, so a
            # smaller selection yields a proportionally lower PDF.
            self.std_nb_cells = nb_cells
        log_values = np.log10(self.fmap.flatten()[self.mask_flat])
        weights = np.ones(nb_cells) / self.std_nb_cells
        if first:
            values, self.edges = np.histogram(
                log_values, self.pp.params.pdf.nb_bin, weights=weights
            )
            plt.xlabel(r"$\log(\Sigma / \bar{\Sigma})$")
            plt.ylabel(r"$\mathcal{P}_\Sigma$")
        else:
            values, _ = np.histogram(log_values, self.edges, weights=weights)
        edges = self.edges
        centers = 0.5 * (edges[1:] + edges[:-1])
        in_fit = (
            (centers > self.pp.params.pdf.xmin_fit)
            & (centers < self.pp.params.pdf.xmax_fit)
            & (values > 0)
        )
        a, b, r = self._fit_line(centers[in_fit], np.log10(values[in_fit]))
        fit = 10 ** (a * centers + b)
        label = r"a = {:.3g}, $R^2$ = {:.3g}".format(a, r**2)
        if first:
            # Faint copy of the initial PDF, never updated: a reference curve.
            plt.step(centers, values, where="mid", alpha=0.3)
            (self.step,) = plt.step(centers, values, where="mid")
            (self.fit,) = plt.plot(
                centers, fit, "--", color="navy", label=label
            )
            plt.yscale("log")
            plt.title("PDF")
        else:
            self.step.set_ydata(values)
            self.fit.set_ydata(fit)
            self.fit.set_label(label)
        plt.legend()

    def _draw_profile(self, first):
        """Profile panel: log rho along the segment and along z, with a fit."""
        plt.sca(self.ax_profile)
        lps = self.line_profile.list_points
        xa, ya, xb, yb = lps[0].x, lps[0].y, lps[1].x, lps[1].y
        zoom = self.pp.params.pymses.zoom
        # Rows follow y and columns follow x (imshow origin="lower").
        ny, nx = self.shape
        size = np.array([nx, ny, nx, ny])
        # Data -> pixel indices, assuming the map spans [-0.5, 0.5] / zoom on
        # both axes. Clip so an end point dragged past the edge stays inside.
        pix = (
            (np.array([xa, ya, xb, yb], dtype=float) * zoom + 0.5) * size
        ).astype(int)
        pix = np.clip(pix, 0, size - 1)
        xpp, ypp = line(*pix)
        rho_prof = self.rho_map[ypp, xpp]
        xp = (xpp / float(nx) - 0.5) / zoom
        yp = (ypp / float(ny) - 0.5) / zoom
        dist = self._distance_along(xp, yp, xa, ya)
        in_fit = (dist >= self.fit_prof_vmin) & (dist <= self.fit_prof_vmax)
        fit_x = dist[in_fit]
        log_fit_x = np.log10(fit_x)
        a, b, r = self._fit_line(log_fit_x, rho_prof[in_fit])
        fit_y = a * log_fit_x + b
        label = r"a = {:.3g}, $R^2$ = {:.3g}".format(a, r**2)

        # Vertical profile through the cells whose (x, y) footprint contains
        # the first end point. NOTE(review): this treats ``pos`` as the
        # lower-left corner of a cell (pos <= x < pos + dx); confirm it is
        # not the cell centre. z_vert is centred on 0, so its negative half
        # cannot appear on the log-x axis.
        xv, yv = float(xa), float(ya)
        cell_dx = self.pp.cells["dx"]
        in_column = (
            (xv >= self.xy3d[:, 0])
            & (xv < self.xy3d[:, 0] + cell_dx)
            & (yv >= self.xy3d[:, 1])
            & (yv < self.xy3d[:, 1] + cell_dx)
        )
        rho_vert = self.rho3d[in_column]
        z_vert = self.pp.cells["pos"][in_column][:, 2] - 0.5
        order = np.argsort(z_vert)
        rho_vert = rho_vert[order]
        z_vert = z_vert[order]

        if first:
            (self.prof,) = plt.semilogx(dist, rho_prof)
            (self.prof_z,) = plt.semilogx(z_vert, rho_vert)
            plt.title("Profil")
            (self.fit_prof,) = plt.plot(
                fit_x, fit_y, "--", color="navy", label=label
            )
            self.ax_profile.set_xlabel(r"$r$")
            self.ax_profile.set_ylabel(r"$\log(\rho)$")
        else:
            self.prof.set_data(dist, rho_prof)
            self.prof_z.set_data(z_vert, rho_vert)
            self.fit_prof.set_data(fit_x, fit_y)
            self.fit_prof.set_label(label)
        plt.legend()

    def onselect(self, verts, select_data):
        """Selector callback: combine the drawn path with the current selection.

        ``select_data`` holds one (x, y) pair per map pixel. These are map
        coordinates, or (log rho, log c_s) fluctuation pairs for the Gamma
        panel. Pixels inside ``verts`` are intersected with the current mask,
        or united with it when "Union" is checked. In 3D Gamma mode the 3D
        cell mask is updated the same way from the cells' (x, y) positions.
        """
        path = Path(verts)
        status = self.gamma_3d_button.get_status()
        use_3d, union = status[0], status[2]
        combine = np.logical_or if union else np.logical_and
        self.mask_flat = combine(
            self.mask.flatten(), path.contains_points(select_data)
        )
        self.mask = self.mask_flat.reshape(self.shape)
        if use_3d:
            self.mask3d = combine(self.mask3d, path.contains_points(self.xy3d))
        self.draw()

    def clicked_select(self, val, selector, button):
        """Lasso/Polygon button callback: toggle selectors of type ``selector``.

        If no selector is active, one is attached to the map and, in 2D Gamma
        mode, another to the Gamma histogram, and ``button`` is darkened.
        Otherwise the active selectors are deactivated.
        """
        if self.lasso:
            self.unselect()
            return
        self.lasso = selector(
            self.ax_fluct, partial(self.onselect, select_data=self.xy)
        )
        if not self.gamma_3d_button.get_status()[0]:
            self.lasso_gamma = selector(
                self.ax_gamma, partial(self.onselect, select_data=self.rhocs)
            )
        button.color = "0.55"

    def unselect(self):
        """Deactivate any lasso/polygon selector and restore the button colours."""
        for selector in (self.lasso, self.lasso_gamma):
            if selector:
                # Disconnect explicitly: dropping the reference alone may
                # leave the selector's canvas callbacks alive until the
                # object is garbage collected.
                selector.disconnect_events()
        self.lasso = False
        self.lasso_gamma = False
        self.lasso_button.color = "0.85"
        self.poly_button.color = "0.85"

    def clicked_reset(self, val):
        """Reset button callback: back to the annulus selection and redraw."""
        self.unselect()
        self._reset_masks()
        self.draw()

    def fit_selector(self, vmin, vmax):
        """SpanSelector callback: set the profile fit range and redraw."""
        self.fit_prof_vmin = vmin
        self.fit_prof_vmax = vmax
        self.draw()

    def __init__(self, num, path="./", params=None, datamap_key="fluct_coldens_z"):
        """
        Interactive plotting

        Loads snapshot ``num`` from ``path`` through ``SnapshotProcessor``,
        builds the figure and its widgets, draws it and calls ``plt.show()``.

        Parameters
        ----------
        num : snapshot number handed to ``SnapshotProcessor``.
        path : str
            Directory handed to ``SnapshotProcessor``.
        params : parameter object or None
            If None, ``default_params()`` with the GUI overrides below is used.
        datamap_key : str
            Map under ``/maps`` displayed in the top-left panel.
        """
        if params is None:
            # GUI defaults; presumably only meant for the default parameter
            # set, not for caller-supplied params — TODO confirm.
            params = default_params()
            params.input.nml_filename = "disk.nml"
            params.out.interactive = True
            params.pymses.map_size = 4096
            params.pymses.zoom = 4
            params.pymses.variables = ["rho", "vel", "P"]
            params.disk.enable = True
            params.disk.nb_bin = 200
            params.pdf.nb_bin = 100
        self.fig = plt.figure(figsize=(10, 8))
        self.pp = SnapshotProcessor(path, num, params=params, tag="interactive")
        self.pp.pdf_coldens("z")
        self.pp.pdf_rho("z")
        self.pp.pdf_T("z")
        self.pp.load_cells()

        # Maps. The -0.5 shift measures positions from the box centre
        # (assumes box units in [0, 1], as for the cell positions below).
        self.im_extent = np.array(self.pp.get_attribute("/maps", "im_extent")) - 0.5
        self.datamap = self.pp.get_value("/maps/" + datamap_key)
        self.rho_map = np.log10(self.pp.get_value("/maps/rho_z"))
        self.frho_map = np.log10(self.pp.get_value("/maps/fluct_rho_z"))
        self.T_map = self.pp.get_value("/maps/T_z")
        self.fcs_map = np.log10(np.sqrt(self.pp.get_value("/maps/fluct_T_z")))
        self.shape = self.datamap.shape

        # Coordinates of every pixel. Rows follow y and columns follow x
        # (imshow origin="lower"), so the grids get the map's own shape.
        ny, nx = self.shape
        x = np.linspace(self.im_extent[0], self.im_extent[1], nx)
        y = np.linspace(self.im_extent[2], self.im_extent[3], ny)
        self.xx, self.yy = np.meshgrid(x, y)
        self.xy = np.column_stack((self.xx.flatten(), self.yy.flatten()))
        self.rr = np.sqrt(self.xx**2 + self.yy**2)
        self.rhocs = np.column_stack((self.frho_map.flatten(), self.fcs_map.flatten()))
        self.rmin = self.pp.params.disk.rmin_pdf / 2.0
        self.rmax = self.pp.params.disk.rmax_pdf / 2.0

        # 3D cells.
        self.rho3d = np.log10(self.pp.cells["rho"])
        self.P3d = np.log10(self.pp.cells["P"])
        self.xy3d = self.pp.cells["pos"][:, :2] - 0.5
        self.rr3d = np.sqrt(np.sum((self.xy3d) ** 2, axis=1))
        self._reset_masks()

        self.ax_fluct = plt.axes([0.05, 0.6, 0.4, 0.35])
        self.ax_gamma = plt.axes([0.05, 0.15, 0.4, 0.35])
        self.ax_pdf = plt.axes([0.55, 0.6, 0.4, 0.35])
        self.ax_profile = plt.axes([0.55, 0.15, 0.4, 0.35])

        self.fit_prof_vmin = 1e-2
        self.fit_prof_vmax = 1e-1
        try:
            self.fit_profile_selector = SpanSelector(
                self.ax_profile,
                self.fit_selector,
                "horizontal",
                useblit=True,
                interactive=True,
            )
        except TypeError:
            # matplotlib < 3.5 has no ``interactive``; ``span_stays`` was its
            # predecessor (removed in matplotlib 3.7).
            self.fit_profile_selector = SpanSelector(
                self.ax_profile,
                self.fit_selector,
                "horizontal",
                useblit=True,
                span_stays=True,
            )
        self.line_profile = DraggableLine(self, self.ax_fluct, [0, 0], [0.1, 0], 0.005)
        self.line_gamma = DraggableLine(self, self.ax_gamma, [0, 0], [1, 0.3], 0.05)

        ax_rmax = plt.axes([0.05, 0.07, 0.15, 0.02])
        self.srmax = Slider(
            ax_rmax, r"$r_{max}$", 0, 0.7 / params.pymses.zoom, valinit=self.rmax
        )
        ax_rmin = plt.axes([0.05, 0.03, 0.15, 0.02])
        self.srmin = Slider(
            ax_rmin, r"$r_{min}$", 0, 0.7 / params.pymses.zoom, valinit=self.rmin
        )
        ax_lasso = plt.axes([0.6, 0.07, 0.19, 0.02])
        ax_poly = plt.axes([0.8, 0.07, 0.19, 0.02])
        ax_gamma_3d = plt.axes([0.3, 0.01, 0.19, 0.09])
        self.gamma_3d_button = CheckButtons(
            ax_gamma_3d, [r"3D $\Gamma$", r"Fit $\Gamma$", "Union"], [True, False, False]
        )
        self.gamma_3d_button.on_clicked(
            lambda val: self.draw(first=False, update_map=True)
        )
        self.lasso = False
        self.lasso_gamma = False
        self.lasso_button = Button(ax_lasso, "Lasso selector")
        self.poly_button = Button(ax_poly, "Polygon selector")
        self.lasso_button.on_clicked(
            partial(
                self.clicked_select, selector=LassoSelector, button=self.lasso_button
            )
        )
        self.poly_button.on_clicked(
            partial(
                self.clicked_select, selector=PolygonSelector, button=self.poly_button
            )
        )
        ax_reset = plt.axes([0.6, 0.03, 0.19, 0.02])
        self.reset_button = Button(ax_reset, "Reset")
        self.reset_button.on_clicked(self.clicked_reset)
        self.srmin.on_changed(self.update_rmin)
        self.srmax.on_changed(self.update_rmax)
        self.draw(first=True, update_map=True)
        plt.show()