Huge refactoring

This commit is contained in:
Noe Brucy
2021-06-22 12:23:03 +02:00
parent a286c08b72
commit 0dc9e8fc7b
12 changed files with 343 additions and 361 deletions
+45 -47
View File
@@ -30,7 +30,7 @@ import matplotlib.pyplot as plt
import pspec_read
from baseprocessor import Rule, BaseProcessor
from aggregator import Aggregator
from comparator import Comparator
from studyprocessor import StudyProcessor
from run_selector import RunSelector
from units import U, unit_str, convert_exp
@@ -121,7 +121,7 @@ class Plotter(Aggregator, BaseProcessor):
in_runs=None,
in_nums=None,
path_out=None,
pp_params=None,
params=None,
selector=None,
tag=None,
unit_time=U.year,
@@ -139,19 +139,19 @@ class Plotter(Aggregator, BaseProcessor):
in_nums : list or dict of the outputs numbers to consider (ex [3, 5]
or {'run1' : [3, 5], 'run2' : [4, 6])
path_out : Path where the plot will be saved. By default set to `path`
pp_params : Parameters for postprocessing. See pp_params module.
params : Parameters for postprocessing. See params module.
selector : Existing instance of RunSelector, that selects runs and outputs. If set, in_runs and
in_nums will be ignored
tag : string to add in the output and data files.
kwargs : Keyword arguments for RunSelector.
"""
super(Plotter, self).__init__(path, path_out, pp_params, tag)
super(Plotter, self).__init__(path, path_out, params, tag)
# Select runs
if selector is None:
self.selector = RunSelector(
path, in_runs, in_nums, self.pp_params.input.nml_filename, **kwargs
path, in_runs, in_nums, self.params.input.nml_filename, **kwargs
)
else:
self.selector = selector
@@ -161,22 +161,22 @@ class Plotter(Aggregator, BaseProcessor):
self.runs = self.selector.runs
self.nums = self.selector.nums
# Get comparator object
self.comp = Comparator(
# Get studyprocessor object
self.study = StudyProcessor(
path,
self.runs,
self.nums,
path_out,
self.pp_params,
self.params,
unit_time=unit_time,
selector=self.selector,
)
# Get postprocesor objets for each run
self.pp = self.comp.pp
self.snaps = self.study.snaps
# Define log prefix
self.log_id = "[plot {}] ".format(self.pp_params.out.tag)
self.log_id = "[plot {}] ".format(self.params.out.tag)
# Define rules
self.def_rules()
@@ -188,12 +188,12 @@ class Plotter(Aggregator, BaseProcessor):
def gen_simus(self):
self.simulations = {}
simu_fmt = self.pp_params.astrophysix.simu_fmt
descr_fmt = self.pp_params.astrophysix.descr_fmt
tag = self.pp_params.out.tag
simu_fmt = self.params.astrophysix.simu_fmt
descr_fmt = self.params.astrophysix.descr_fmt
tag = self.params.out.tag
for run in self.runs:
pp = self.pp[run][self.nums[run][0]]
nml = self.comp.namelist[run]
pp = self.snaps[run][self.nums[run][0]]
nml = self.study.namelist[run]
name = simu_fmt.format(run=run, tag=tag, nml=nml)
exec_time = str(datetime.datetime.fromtimestamp(os.stat(pp.path).st_ctime))
exec_time = exec_time.split(".")[0]
@@ -210,7 +210,7 @@ class Plotter(Aggregator, BaseProcessor):
for param in ramses.input_parameters:
value = None
try:
value = self.comp.get_nml(param.key, run)
value = self.study.get_nml(param.key, run)
except KeyError as e:
self._log("key {} not found".format(e), "WARNING")
@@ -237,8 +237,8 @@ class Plotter(Aggregator, BaseProcessor):
"""
Check if the dependency belongs to the plotter object or to another one (comp, pp, ..)
"""
if dep in self.comp.rules:
result = self.comp.process(
if dep in self.study.rules:
result = self.study.process(
dep, dep_arg, overwrite, self.overwrite_dep, select
)
if result is not None:
@@ -251,7 +251,7 @@ class Plotter(Aggregator, BaseProcessor):
Returns true if the plot needs to be redone
"""
return (
self.pp_params.out.interactive
self.params.out.interactive
or overwrite
or not os.path.exists(plot_filename)
)
@@ -288,7 +288,7 @@ class Plotter(Aggregator, BaseProcessor):
name_full = name
# get filetype of the output
filetype = filetype_from_ext[self.pp_params.out.ext]
filetype = filetype_from_ext[self.params.out.ext]
# Select runs and nums
if select is not None:
@@ -330,14 +330,14 @@ class Plotter(Aggregator, BaseProcessor):
# Find plot save
if from_cells or rule.kind == "cells":
if not os.path.exists(self.pp[run][num].cells_filename):
if not os.exists(self.pp[run][num].cells_filename):
self.pp[run][num].load_cells()
self.pp[run][num].unload_cells()
save = tables.open_file(self.pp[run][num].cells_filename)
elif rule.kind == "snapshot":
save = tables.open_file(self.pp[run][num].filename)
save = tables.open_file(self.snaps[run][num].filename)
else:
save = tables.open_file(self.comp.filename, "r")
save = tables.open_file(self.study.filename, "r")
# Call plot routine
try:
@@ -362,7 +362,7 @@ class Plotter(Aggregator, BaseProcessor):
if plot_info is not None:
df.plot_info = plot_info
if num is not None:
snap = self.pp[run][num].snapshot
snap = self.snaps[run][num].snapshot
if overwrite and df.name in snap.datafiles:
del snap.datafiles[df.name]
@@ -385,10 +385,10 @@ class Plotter(Aggregator, BaseProcessor):
if self._needs_computation(overwrite, plot_filename):
plot_info = rule.plot(save, arg, **kwargs)
if not self.pp_params.out.interactive and close:
if not self.params.out.interactive and close:
plt.tight_layout(pad=1)
if self.pp_params.out.save:
if self.params.out.save:
plt.savefig(plot_filename)
self._log("{} plotted".format(plot_filename), "SUCCESS")
else:
@@ -396,7 +396,7 @@ class Plotter(Aggregator, BaseProcessor):
"{} plotted".format(os.path.basename(plot_filename)), "SUCCESS"
)
if not self.pp_params.out.interactive and close:
if not self.params.out.interactive and close:
plt.close()
return plot_info
else:
@@ -406,10 +406,10 @@ class Plotter(Aggregator, BaseProcessor):
"""
Determine a filename based on rule name, run, output and parameters
"""
tag_name = self.pp_params.out.tag
tag_name = self.params.out.tag
if fmt is None and self.pp_params.out.fmt == "":
if not self.pp_params.out.tag == "":
if fmt is None and self.params.out.fmt == "":
if not self.params.out.tag == "":
tag_name = "_" + tag_name
if run is not None and num is not None:
@@ -419,11 +419,11 @@ class Plotter(Aggregator, BaseProcessor):
else:
fmt = "{out}/{name}{tag}{ext}"
elif fmt is None:
fmt = self.pp_params.out.fmt
fmt = self.params.out.fmt
nml = None
if run is not None:
nml = self.comp.namelist[run]
nml = self.study.namelist[run]
return fmt.format(
run=run,
@@ -432,7 +432,7 @@ class Plotter(Aggregator, BaseProcessor):
num=num,
nml=nml,
out=self.path_out,
ext=self.pp_params.out.ext,
ext=self.params.out.ext,
)
def get_label_run(self, run, label=None, nml_key=None, time=None):
@@ -514,11 +514,9 @@ class Plotter(Aggregator, BaseProcessor):
title = self.get_label_run(run, title, nml_key)
if put_time:
time = self.save.root._v_attrs.time * self.comp.info["unit_time"]
time = self.save.root._v_attrs.time * self.study.info["unit_time"]
u_str = unit_str(unit_time, format="{unit}")
time_str = self.pp_params.plot.time_fmt.format(
time.express(unit_time), u_str
)
time_str = self.params.plot.time_fmt.format(time.express(unit_time), u_str)
if len(title) > 0:
title = title + " | " + time_str
else:
@@ -593,7 +591,7 @@ class Plotter(Aggregator, BaseProcessor):
dmap, extent=im_extent, origin="lower", norm=norm, cmap=cmap, **kwargs
)
plt.locator_params(axis="both", nbins=self.pp_params.plot.ntick)
plt.locator_params(axis="both", nbins=self.params.plot.ntick)
if xlabel is None:
xlabel = self._ax_title[ax_h]
@@ -781,7 +779,7 @@ class Plotter(Aggregator, BaseProcessor):
label, unit_old, unit = self._ax_label_unit(dmap_vh_node, "", unit, unit_coeff)
vel_red = self.pp_params.plot.vel_red
vel_red = self.params.plot.vel_red
# take only a subset of velocities
map_vh_red = dmap_vh[::vel_red, ::vel_red] * unit_old.express(unit)
@@ -835,7 +833,7 @@ class Plotter(Aggregator, BaseProcessor):
# TODO : redo this with im_extent
vel_red = self.pp_params.plot.vel_red
vel_red = self.params.plot.vel_red
radius = self.save.root.maps._v_attrs.radius
center = self.save.root.maps._v_attrs.center
lbox = self.save.root._v_attrs.lbox
@@ -923,7 +921,7 @@ class Plotter(Aggregator, BaseProcessor):
if nml_color is None:
color = colors[run]
else:
nml = self.comp.get_nml(nml_color, run)
nml = self.study.get_nml(nml_color, run)
try:
color = colors[nml]
except TypeError:
@@ -1043,8 +1041,8 @@ class Plotter(Aggregator, BaseProcessor):
# If relevent, get time
if put_time:
time = self.save.root._v_attrs.time * self.comp.info["unit_time"]
time_str = self.pp_params.plot.time_fmt.format(
time = self.save.root._v_attrs.time * self.study.info["unit_time"]
time_str = self.params.plot.time_fmt.format(
time.express(unit_time), unit_time.latex.replace("text", "math")
)
time_str = f"${time_str}$"
@@ -1125,11 +1123,11 @@ class Plotter(Aggregator, BaseProcessor):
color = colors[run]
elif nml_color == "time":
time = (
self.save.root._v_attrs.time * self.comp.info["unit_time"]
self.save.root._v_attrs.time * self.study.info["unit_time"]
).express(unit_time)
color = colors(time)
else:
nml = self.comp.get_nml(nml_color, run)
nml = self.study.get_nml(nml_color, run)
try:
color = colors[nml]
except TypeError:
@@ -1254,7 +1252,7 @@ class Plotter(Aggregator, BaseProcessor):
ssfr_sun = 2.5e-9
ssfr_ken = ssfr_sun * n0 ** 1.4
coeff = ssfr_ken * 1e6 * (self.comp.info["unit_length"].express(U.pc)) ** 2
coeff = ssfr_ken * 1e6 * (self.study.info["unit_length"].express(U.pc)) ** 2
for i in np.arange(tmin, max(tmax, tmin + ymax / coeff), step):
t = np.linspace(0, tmax, 1000)
plt.plot(t + i, t * coeff, ls="--", lw=0.9, color="grey")
@@ -1661,7 +1659,7 @@ class Plotter(Aggregator, BaseProcessor):
]
# Generic rules directly from Ramses fields
for field in self.pp_params.pymses.variables:
for field in self.params.pymses.variables:
def generic_rule(name):