improvements

This commit is contained in:
Noe Brucy
2020-03-14 18:09:04 +01:00
parent 1c2750a7bd
commit c2f866b37f
5 changed files with 97 additions and 37 deletions
+10 -11
View File
@@ -92,21 +92,19 @@ class Comparator(Aggregator, HDF5Container):
"Get real units from info files" "Get real units from info files"
if isinstance(unit, cst.Unit): if isinstance(unit, cst.Unit):
return unit return unit
elif isinstance(unit, str): if isinstance(unit, str):
# assert(not run is None) res = self.info[unit]
return self.info[unit] # [run][unit] if unit == "unit_length":
# elif unit.keys()[0] in self.runs: res = res / self.info["boxlen"]
# for run in unit: return res
# unit[run] = self._get_units(unit[run], run=run) if unit.keys()[0] in self.info:
# return unit
elif unit.keys()[0] in self.info:
new_unit = cst.none new_unit = cst.none
for base_unit_str in unit: for base_unit_str in unit:
expo = unit[base_unit_str] expo = unit[base_unit_str]
base_unit = self._get_units(base_unit_str) base_unit = self._get_units(base_unit_str)
new_unit = new_unit * base_unit ** expo new_unit = new_unit * base_unit ** expo
return new_unit return new_unit
elif (not data is None) and isinstance(data, dict) and unit.keys()[0] in data: if (not data is None) and isinstance(data, dict) and unit.keys()[0] in data:
for key in unit: for key in unit:
unit[key] = self._get_units(unit[key]) unit[key] = self._get_units(unit[key])
return unit return unit
@@ -124,8 +122,9 @@ class Comparator(Aggregator, HDF5Container):
for run in self.runs: for run in self.runs:
series[run] = [] series[run] = []
for i, num in enumerate(self.nums[run]): for i, num in enumerate(self.nums[run]):
series[run].apend(getter(run, num, arg=arg)) series[run].append(getter(run, num, arg=arg))
return np.array(series) series[run] = np.array(series[run])
return series
def _comp(self, getter, use_num=True): def _comp(self, getter, use_num=True):
prop = np.zeros(len(self.runs)) prop = np.zeros(len(self.runs))
+59 -10
View File
@@ -119,7 +119,9 @@ class Plotter(Aggregator, BaseProcessor):
or not os.path.exists(plot_filename) or not os.path.exists(plot_filename)
) )
def _process_rule(self, name, rule, arg, overwrite=False, ax=None, **kwargs): def _process_rule(
self, name, rule, arg, overwrite=False, ax=None, movie=False, **kwargs
):
if not arg is None: if not arg is None:
name_full = name + "_" + str(arg) name_full = name + "_" + str(arg)
else: else:
@@ -133,11 +135,9 @@ class Plotter(Aggregator, BaseProcessor):
runs = runs.runs runs = runs.runs
except KeyError: except KeyError:
runs = self.runs runs = self.runs
if ax is None:
ax = [P.subplots(1, 1)[1] for run in runs for num in self.nums[run]]
i = 0 i = 0
for run in runs: for run in runs:
files = []
for num in self.nums[run]: for num in self.nums[run]:
plot_filename = self._find_filename(name_full, run, num) plot_filename = self._find_filename(name_full, run, num)
save = tables.open_file(self.pp[run][num].filename) save = tables.open_file(self.pp[run][num].filename)
@@ -153,11 +153,10 @@ class Plotter(Aggregator, BaseProcessor):
**kwargs **kwargs
) )
except TypeError as e: except TypeError as e:
if ( if str(e) in [
str(e) "'LocatableAxes' object does not support indexing",
!= "'LocatableAxes' object does not support indexing" "'AxesSubplot' object does not support indexing",
): ]:
raise
self._plot_rule( self._plot_rule(
rule, rule,
save, save,
@@ -168,9 +167,24 @@ class Plotter(Aggregator, BaseProcessor):
run=run, run=run,
**kwargs **kwargs
) )
elif ax is None:
fig = P.figure()
self._plot_rule(
rule,
save,
arg,
plot_filename,
overwrite,
ax=P.gca(),
run=run,
**kwargs
)
else:
raise
finally: finally:
save.close() save.close()
i = i + 1 i = i + 1
files.append(plot_filename)
else: else:
if ax is None: if ax is None:
ax = P.gca() ax = P.gca()
@@ -479,6 +493,7 @@ class Plotter(Aggregator, BaseProcessor):
xlog=None, xlog=None,
ylog=False, ylog=False,
kind="bar", kind="bar",
ylabel="$\mathcal{P}$",
color=None, color=None,
colors=None, colors=None,
nml_color=None, nml_color=None,
@@ -505,7 +520,7 @@ class Plotter(Aggregator, BaseProcessor):
if put_time: if put_time:
time = self.save.root._v_attrs.time * self.comp.info["unit_time"] time = self.save.root._v_attrs.time * self.comp.info["unit_time"]
time_str = self.pp_params.out.time_fmt.format( time_str = self.pp_params.plot.time_fmt.format(
time.express(time_unit), time_unit.latex time.express(time_unit), time_unit.latex
) )
if len(title) > 0: if len(title) > 0:
@@ -538,6 +553,8 @@ class Plotter(Aggregator, BaseProcessor):
if not label is None: if not label is None:
P.xlabel(label) P.xlabel(label)
if not ylabel is None:
P.ylabel(ylabel)
if not ax_los is None and "/hist/fit_" + name + "_" + ax_los in self.save: if not ax_los is None and "/hist/fit_" + name + "_" + ax_los in self.save:
slope = node.attrs.slope slope = node.attrs.slope
@@ -573,8 +590,11 @@ class Plotter(Aggregator, BaseProcessor):
yerr_kind="std", yerr_kind="std",
sigma_err=2.0, sigma_err=2.0,
grid=True, grid=True,
put_time=False,
time_unit=cst.Myr,
colors=None, colors=None,
nml_color=None, nml_color=None,
legend=None,
**kwargs **kwargs
): ):
@@ -597,6 +617,16 @@ class Plotter(Aggregator, BaseProcessor):
if grid: if grid:
P.grid() P.grid()
if put_time:
time = self.save.root._v_attrs.time * self.comp.info["unit_time"]
time_str = self.pp_params.plot.time_fmt.format(
time.express(time_unit), time_unit.latex
)
if len(label) > 0:
label = label + " | " + time_str
else:
label = time_str
yerr = None yerr = None
if node_y._v_attrs.CLASS == "ARRAY": if node_y._v_attrs.CLASS == "ARRAY":
x = node_x.read() * xunit_old.express(xunit) x = node_x.read() * xunit_old.express(xunit)
@@ -607,7 +637,23 @@ class Plotter(Aggregator, BaseProcessor):
y = gaussian_filter1d(y, sigma=smooth) y = gaussian_filter1d(y, sigma=smooth)
if not run is None: if not run is None:
label = self._label_run(run, node_y, label, nml_key) label = self._label_run(run, node_y, label, nml_key)
if colors is None:
(base_line,) = P.plot(x, y, label=label, **kwargs) (base_line,) = P.plot(x, y, label=label, **kwargs)
else:
if nml_color is None:
color = colors[run]
elif nml_color == "time":
time = (
self.save.root._v_attrs.time * self.comp.info["unit_time"]
).express(time_unit)
color = colors(time)
else:
nml = self.comp.get_nml(nml_color, run)
try:
color = colors[nml]
except:
color = colors(nml)
(base_line,) = P.plot(x, y, label=label, color=color, **kwargs)
elif "mean" in node_y: elif "mean" in node_y:
x = node_x.read() * xunit_old.express(xunit) x = node_x.read() * xunit_old.express(xunit)
y = node_y.mean.read() * yunit_old.express(yunit) y = node_y.mean.read() * yunit_old.express(yunit)
@@ -664,7 +710,10 @@ class Plotter(Aggregator, BaseProcessor):
except: except:
color = colors(nml) color = colors(nml)
(base_line,) = P.plot(x, y, label=label_run, color=color, **kwargs) (base_line,) = P.plot(x, y, label=label_run, color=color, **kwargs)
if legend is None:
legend = True
if legend:
P.legend() P.legend()
if not fit is None: if not fit is None:
+1 -3
View File
@@ -226,9 +226,7 @@ class PostProcessor(HDF5Container):
self.load_cells() self.load_cells()
return np.sort(np.unique(self.cells["pos"][:, axis])) return np.sort(np.unique(self.cells["pos"][:, axis]))
def _plane_avg_uniform( def _plane_avg_uniform(self, getter, axis, unit=cst.none, mass_weighted=False):
self, getter, axis, unit=cst.none, mass_weighted=True, surf_qty=False
):
""" """
Profile of the average of a quantity (given by getter) perpendicular to an axis Profile of the average of a quantity (given by getter) perpendicular to an axis
WARNING : This version only works on an uniform grid, need of a box version for AMR WARNING : This version only works on an uniform grid, need of a box version for AMR
+9 -2
View File
@@ -44,6 +44,7 @@ class RunSelector:
sort_run_by=None, sort_run_by=None,
time_min=None, time_min=None,
time_max=None, time_max=None,
time=None,
): ):
self.path_in = path_in self.path_in = path_in
self.pp_params = pp_params self.pp_params = pp_params
@@ -66,7 +67,7 @@ class RunSelector:
in_nums[run] = nums_temp in_nums[run] = nums_temp
for i, run in enumerate(self.runs): for i, run in enumerate(self.runs):
self.nums[run] = self.get_nums(run, in_nums[run], time_min, time_max) self.nums[run] = self.get_nums(run, in_nums[run], time_min, time_max, time)
def load_namelist(self, run): def load_namelist(self, run):
path_run = self.path_in + "/" + run path_run = self.path_in + "/" + run
@@ -147,7 +148,7 @@ class RunSelector:
info_file.close() info_file.close()
return info return info
def get_nums(self, run, in_nums=None, time_min=None, time_max=None): def get_nums(self, run, in_nums=None, time_min=None, time_max=None, time=None):
def try_load_info(num): def try_load_info(num):
try: try:
self.info[run][num] = self.load_info(run, num) self.info[run][num] = self.load_info(run, num)
@@ -191,4 +192,10 @@ class RunSelector:
nums = filter(lambda n: self.info[run][n]["time"] >= time_min, nums) nums = filter(lambda n: self.info[run][n]["time"] >= time_min, nums)
if not time_max is None: if not time_max is None:
nums = filter(lambda n: self.info[run][n]["time"] <= time_max, nums) nums = filter(lambda n: self.info[run][n]["time"] <= time_max, nums)
if not time is None:
times = np.asarray([[self.info[run][n]["time"], n] for n in nums])
idx = (np.abs(times[:, 0] - time)).argmin()
nums = [int(times[idx, 1])]
return nums return nums
+7
View File
@@ -68,3 +68,10 @@ cst.ssfr = cst.create_unit(
descr="Surfacic SFR", descr="Surfacic SFR",
latex="M$_{\odot}$.yr$^{-1}$.pc$^{-2}$", latex="M$_{\odot}$.yr$^{-1}$.pc$^{-2}$",
) )
cst.ssfrG = cst.create_unit(
"Msun.Gyr^-1.pc^-2",
base_unit=1e-9 * cst.Msun / cst.year / cst.pc ** 2,
descr="Surfacic SFR",
latex="M$_{\odot}$.Gyr$^{-1}$.pc$^{-2}$",
)