Huge refactoring
This commit is contained in:
+45
-47
@@ -30,7 +30,7 @@ import matplotlib.pyplot as plt
|
||||
import pspec_read
|
||||
from baseprocessor import Rule, BaseProcessor
|
||||
from aggregator import Aggregator
|
||||
from comparator import Comparator
|
||||
from studyprocessor import StudyProcessor
|
||||
from run_selector import RunSelector
|
||||
from units import U, unit_str, convert_exp
|
||||
|
||||
@@ -121,7 +121,7 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
in_runs=None,
|
||||
in_nums=None,
|
||||
path_out=None,
|
||||
pp_params=None,
|
||||
params=None,
|
||||
selector=None,
|
||||
tag=None,
|
||||
unit_time=U.year,
|
||||
@@ -139,19 +139,19 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
in_nums : list or dict of the outputs numbers to consider (ex [3, 5]
|
||||
or {'run1' : [3, 5], 'run2' : [4, 6])
|
||||
path_out : Path where the plot will be saved. By default set to `path`
|
||||
pp_params : Parameters for postprocessing. See pp_params module.
|
||||
params : Parameters for postprocessing. See params module.
|
||||
selector : Existing instance of RunSelector, that selects runs and outputs. If set, in_runs and
|
||||
in_nums will be ignored
|
||||
tag : string to add in the output and data files.
|
||||
kwargs : Keyword arguments for RunSelector.
|
||||
"""
|
||||
|
||||
super(Plotter, self).__init__(path, path_out, pp_params, tag)
|
||||
super(Plotter, self).__init__(path, path_out, params, tag)
|
||||
|
||||
# Select runs
|
||||
if selector is None:
|
||||
self.selector = RunSelector(
|
||||
path, in_runs, in_nums, self.pp_params.input.nml_filename, **kwargs
|
||||
path, in_runs, in_nums, self.params.input.nml_filename, **kwargs
|
||||
)
|
||||
else:
|
||||
self.selector = selector
|
||||
@@ -161,22 +161,22 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
self.runs = self.selector.runs
|
||||
self.nums = self.selector.nums
|
||||
|
||||
# Get comparator object
|
||||
self.comp = Comparator(
|
||||
# Get studyprocessor object
|
||||
self.study = StudyProcessor(
|
||||
path,
|
||||
self.runs,
|
||||
self.nums,
|
||||
path_out,
|
||||
self.pp_params,
|
||||
self.params,
|
||||
unit_time=unit_time,
|
||||
selector=self.selector,
|
||||
)
|
||||
|
||||
# Get postprocesor objets for each run
|
||||
self.pp = self.comp.pp
|
||||
self.snaps = self.study.snaps
|
||||
|
||||
# Define log prefix
|
||||
self.log_id = "[plot {}] ".format(self.pp_params.out.tag)
|
||||
self.log_id = "[plot {}] ".format(self.params.out.tag)
|
||||
|
||||
# Define rules
|
||||
self.def_rules()
|
||||
@@ -188,12 +188,12 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
|
||||
def gen_simus(self):
|
||||
self.simulations = {}
|
||||
simu_fmt = self.pp_params.astrophysix.simu_fmt
|
||||
descr_fmt = self.pp_params.astrophysix.descr_fmt
|
||||
tag = self.pp_params.out.tag
|
||||
simu_fmt = self.params.astrophysix.simu_fmt
|
||||
descr_fmt = self.params.astrophysix.descr_fmt
|
||||
tag = self.params.out.tag
|
||||
for run in self.runs:
|
||||
pp = self.pp[run][self.nums[run][0]]
|
||||
nml = self.comp.namelist[run]
|
||||
pp = self.snaps[run][self.nums[run][0]]
|
||||
nml = self.study.namelist[run]
|
||||
name = simu_fmt.format(run=run, tag=tag, nml=nml)
|
||||
exec_time = str(datetime.datetime.fromtimestamp(os.stat(pp.path).st_ctime))
|
||||
exec_time = exec_time.split(".")[0]
|
||||
@@ -210,7 +210,7 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
for param in ramses.input_parameters:
|
||||
value = None
|
||||
try:
|
||||
value = self.comp.get_nml(param.key, run)
|
||||
value = self.study.get_nml(param.key, run)
|
||||
except KeyError as e:
|
||||
self._log("key {} not found".format(e), "WARNING")
|
||||
|
||||
@@ -237,8 +237,8 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
"""
|
||||
Check if the dependency belongs to the plotter object or to another one (comp, pp, ..)
|
||||
"""
|
||||
if dep in self.comp.rules:
|
||||
result = self.comp.process(
|
||||
if dep in self.study.rules:
|
||||
result = self.study.process(
|
||||
dep, dep_arg, overwrite, self.overwrite_dep, select
|
||||
)
|
||||
if result is not None:
|
||||
@@ -251,7 +251,7 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
Returns true if the plot needs to be redone
|
||||
"""
|
||||
return (
|
||||
self.pp_params.out.interactive
|
||||
self.params.out.interactive
|
||||
or overwrite
|
||||
or not os.path.exists(plot_filename)
|
||||
)
|
||||
@@ -288,7 +288,7 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
name_full = name
|
||||
|
||||
# get filetype of the output
|
||||
filetype = filetype_from_ext[self.pp_params.out.ext]
|
||||
filetype = filetype_from_ext[self.params.out.ext]
|
||||
|
||||
# Select runs and nums
|
||||
if select is not None:
|
||||
@@ -330,14 +330,14 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
|
||||
# Find plot save
|
||||
if from_cells or rule.kind == "cells":
|
||||
if not os.path.exists(self.pp[run][num].cells_filename):
|
||||
if not os.exists(self.pp[run][num].cells_filename):
|
||||
self.pp[run][num].load_cells()
|
||||
self.pp[run][num].unload_cells()
|
||||
save = tables.open_file(self.pp[run][num].cells_filename)
|
||||
elif rule.kind == "snapshot":
|
||||
save = tables.open_file(self.pp[run][num].filename)
|
||||
save = tables.open_file(self.snaps[run][num].filename)
|
||||
else:
|
||||
save = tables.open_file(self.comp.filename, "r")
|
||||
save = tables.open_file(self.study.filename, "r")
|
||||
|
||||
# Call plot routine
|
||||
try:
|
||||
@@ -362,7 +362,7 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
if plot_info is not None:
|
||||
df.plot_info = plot_info
|
||||
if num is not None:
|
||||
snap = self.pp[run][num].snapshot
|
||||
snap = self.snaps[run][num].snapshot
|
||||
|
||||
if overwrite and df.name in snap.datafiles:
|
||||
del snap.datafiles[df.name]
|
||||
@@ -385,10 +385,10 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
if self._needs_computation(overwrite, plot_filename):
|
||||
plot_info = rule.plot(save, arg, **kwargs)
|
||||
|
||||
if not self.pp_params.out.interactive and close:
|
||||
if not self.params.out.interactive and close:
|
||||
plt.tight_layout(pad=1)
|
||||
|
||||
if self.pp_params.out.save:
|
||||
if self.params.out.save:
|
||||
plt.savefig(plot_filename)
|
||||
self._log("{} plotted".format(plot_filename), "SUCCESS")
|
||||
else:
|
||||
@@ -396,7 +396,7 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
"{} plotted".format(os.path.basename(plot_filename)), "SUCCESS"
|
||||
)
|
||||
|
||||
if not self.pp_params.out.interactive and close:
|
||||
if not self.params.out.interactive and close:
|
||||
plt.close()
|
||||
return plot_info
|
||||
else:
|
||||
@@ -406,10 +406,10 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
"""
|
||||
Determine a filename based on rule name, run, output and parameters
|
||||
"""
|
||||
tag_name = self.pp_params.out.tag
|
||||
tag_name = self.params.out.tag
|
||||
|
||||
if fmt is None and self.pp_params.out.fmt == "":
|
||||
if not self.pp_params.out.tag == "":
|
||||
if fmt is None and self.params.out.fmt == "":
|
||||
if not self.params.out.tag == "":
|
||||
tag_name = "_" + tag_name
|
||||
|
||||
if run is not None and num is not None:
|
||||
@@ -419,11 +419,11 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
else:
|
||||
fmt = "{out}/{name}{tag}{ext}"
|
||||
elif fmt is None:
|
||||
fmt = self.pp_params.out.fmt
|
||||
fmt = self.params.out.fmt
|
||||
|
||||
nml = None
|
||||
if run is not None:
|
||||
nml = self.comp.namelist[run]
|
||||
nml = self.study.namelist[run]
|
||||
|
||||
return fmt.format(
|
||||
run=run,
|
||||
@@ -432,7 +432,7 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
num=num,
|
||||
nml=nml,
|
||||
out=self.path_out,
|
||||
ext=self.pp_params.out.ext,
|
||||
ext=self.params.out.ext,
|
||||
)
|
||||
|
||||
def get_label_run(self, run, label=None, nml_key=None, time=None):
|
||||
@@ -514,11 +514,9 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
title = self.get_label_run(run, title, nml_key)
|
||||
|
||||
if put_time:
|
||||
time = self.save.root._v_attrs.time * self.comp.info["unit_time"]
|
||||
time = self.save.root._v_attrs.time * self.study.info["unit_time"]
|
||||
u_str = unit_str(unit_time, format="{unit}")
|
||||
time_str = self.pp_params.plot.time_fmt.format(
|
||||
time.express(unit_time), u_str
|
||||
)
|
||||
time_str = self.params.plot.time_fmt.format(time.express(unit_time), u_str)
|
||||
if len(title) > 0:
|
||||
title = title + " | " + time_str
|
||||
else:
|
||||
@@ -593,7 +591,7 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
dmap, extent=im_extent, origin="lower", norm=norm, cmap=cmap, **kwargs
|
||||
)
|
||||
|
||||
plt.locator_params(axis="both", nbins=self.pp_params.plot.ntick)
|
||||
plt.locator_params(axis="both", nbins=self.params.plot.ntick)
|
||||
|
||||
if xlabel is None:
|
||||
xlabel = self._ax_title[ax_h]
|
||||
@@ -781,7 +779,7 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
|
||||
label, unit_old, unit = self._ax_label_unit(dmap_vh_node, "", unit, unit_coeff)
|
||||
|
||||
vel_red = self.pp_params.plot.vel_red
|
||||
vel_red = self.params.plot.vel_red
|
||||
|
||||
# take only a subset of velocities
|
||||
map_vh_red = dmap_vh[::vel_red, ::vel_red] * unit_old.express(unit)
|
||||
@@ -835,7 +833,7 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
|
||||
# TODO : redo this with im_extent
|
||||
|
||||
vel_red = self.pp_params.plot.vel_red
|
||||
vel_red = self.params.plot.vel_red
|
||||
radius = self.save.root.maps._v_attrs.radius
|
||||
center = self.save.root.maps._v_attrs.center
|
||||
lbox = self.save.root._v_attrs.lbox
|
||||
@@ -923,7 +921,7 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
if nml_color is None:
|
||||
color = colors[run]
|
||||
else:
|
||||
nml = self.comp.get_nml(nml_color, run)
|
||||
nml = self.study.get_nml(nml_color, run)
|
||||
try:
|
||||
color = colors[nml]
|
||||
except TypeError:
|
||||
@@ -1043,8 +1041,8 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
|
||||
# If relevent, get time
|
||||
if put_time:
|
||||
time = self.save.root._v_attrs.time * self.comp.info["unit_time"]
|
||||
time_str = self.pp_params.plot.time_fmt.format(
|
||||
time = self.save.root._v_attrs.time * self.study.info["unit_time"]
|
||||
time_str = self.params.plot.time_fmt.format(
|
||||
time.express(unit_time), unit_time.latex.replace("text", "math")
|
||||
)
|
||||
time_str = f"${time_str}$"
|
||||
@@ -1125,11 +1123,11 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
color = colors[run]
|
||||
elif nml_color == "time":
|
||||
time = (
|
||||
self.save.root._v_attrs.time * self.comp.info["unit_time"]
|
||||
self.save.root._v_attrs.time * self.study.info["unit_time"]
|
||||
).express(unit_time)
|
||||
color = colors(time)
|
||||
else:
|
||||
nml = self.comp.get_nml(nml_color, run)
|
||||
nml = self.study.get_nml(nml_color, run)
|
||||
try:
|
||||
color = colors[nml]
|
||||
except TypeError:
|
||||
@@ -1254,7 +1252,7 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
ssfr_sun = 2.5e-9
|
||||
ssfr_ken = ssfr_sun * n0 ** 1.4
|
||||
|
||||
coeff = ssfr_ken * 1e6 * (self.comp.info["unit_length"].express(U.pc)) ** 2
|
||||
coeff = ssfr_ken * 1e6 * (self.study.info["unit_length"].express(U.pc)) ** 2
|
||||
for i in np.arange(tmin, max(tmax, tmin + ymax / coeff), step):
|
||||
t = np.linspace(0, tmax, 1000)
|
||||
plt.plot(t + i, t * coeff, ls="--", lw=0.9, color="grey")
|
||||
@@ -1661,7 +1659,7 @@ class Plotter(Aggregator, BaseProcessor):
|
||||
]
|
||||
|
||||
# Generic rules directly from Ramses fields
|
||||
for field in self.pp_params.pymses.variables:
|
||||
for field in self.params.pymses.variables:
|
||||
|
||||
def generic_rule(name):
|
||||
|
||||
|
||||
Reference in New Issue
Block a user