[refactoring] removed star imports, renamed cst in U

and applied flake8 guidelines
This commit is contained in:
Noe Brucy
2020-12-15 14:14:31 +01:00
parent 03e8699b33
commit 53c551364e
9 changed files with 299 additions and 287 deletions
+7 -1
View File
@@ -1,5 +1,11 @@
# coding: utf-8
import numpy as np
from functools import partial
from mypool import MyPool from mypool import MyPool
from postprocessor import *
from postprocessor import PostProcessor
from run_selector import RunSelector
def _map_rule(rule, arg, overwrite, path, path_out, pp_params, run_num): def _map_rule(rule, arg, overwrite, path, path_out, pp_params, run_num):
+22 -59
View File
@@ -1,44 +1,18 @@
# coding: utf-8 # coding: utf-8
import copy import copy
import glob as glob
import os import os
import subprocess
import sys
import time import time
from abc import ABCMeta, abstractmethod from abc import ABCMeta
from functools import partial from functools import partial
import numpy as np import numpy as np
import pymses
import tables
from astrophysix.simdm import SimulationStudy
from astrophysix.simdm.experiment import (
AppliedAlgorithm,
ParameterSetting,
ParameterVisibility,
Simulation,
)
from astrophysix.simdm.results import GenericResult, Snapshot
from numpy.polynomial.polynomial import polyfit
from pymses.analysis import (
Camera,
FractionOperator,
MaxLevelOperator,
ScalarOperator,
raytracing,
slicing,
splatting,
)
from pymses.filters import CellsToPoints
from pymses.sources.hop.file_formats import *
from pymses.sources.ramses import output
from scipy.stats import linregress
from tables import HDF5ExtError
from ramses_astrophysix import ramses import tables
from run_selector import *
from units import * from tables import HDF5ExtError
from pp_params import default_params, load_params
from units import U
class Rule: class Rule:
@@ -51,7 +25,7 @@ class Rule:
dependencies=[], dependencies=[],
is_valid=lambda arg: True, is_valid=lambda arg: True,
kind="classic", kind="classic",
unit=cst.none, unit=U.none,
): ):
self.postproc = postproc self.postproc = postproc
self.process_fn = process self.process_fn = process
@@ -63,24 +37,11 @@ class Rule:
self.kind = kind self.kind = kind
def process(self, arg, **kwargs): def process(self, arg, **kwargs):
if not arg is None: if arg is not None:
return self.process_fn(arg, **kwargs) return self.process_fn(arg, **kwargs)
else: else:
return self.process_fn(**kwargs) return self.process_fn(**kwargs)
def is_valid(self, arg):
# save = self.postproc.save
# valid = True
# for dep in self.dependencies:
# if dep in self.postproc.rules:
# rule_dep = self.postproc.rules[dep]
# if not arg is None:
# valid = valid and rule_dep.group + '/' + dep + '_' + str(arg) in save
# else:
# valid = valid and rule_dep.group + '/' + dep in save
# return valid and self.is_valid_add(arg)
return self.is_valid_add(arg)
class BaseProcessor: class BaseProcessor:
""" """
@@ -134,7 +95,9 @@ class BaseProcessor:
) )
else: else:
self._log( self._log(
"{} is unknown, allowed rules are {}".format(name, self.rules.keys()), "{} is unknown, allowed rules are {}".format(
to_process, self.rules.keys()
),
"ERROR", "ERROR",
) )
@@ -176,13 +139,13 @@ class BaseProcessor:
return overwrite return overwrite
def _process_rule(self, name, rule, arg, overwrite=False, **kwargs): def _process_rule(self, name, rule, arg, overwrite=False, **kwargs):
if not arg is None: if arg is not None:
name_full = rule.group + "/" + name + "_" + str(arg) name_full = rule.group + "/" + name + "_" + str(arg)
else: else:
name_full = rule.group + "/" + name name_full = rule.group + "/" + name
if rule.is_valid(arg): if rule.is_valid(arg):
if not name_full in self.just_done: if name_full not in self.just_done:
if self._needs_computation(overwrite, name_full): if self._needs_computation(overwrite, name_full):
self._log("Processing {}".format(name_full)) self._log("Processing {}".format(name_full))
data = rule.process(arg, **kwargs) data = rule.process(arg, **kwargs)
@@ -255,7 +218,7 @@ class HDF5Container(BaseProcessor):
) )
else: else:
value = node.read() value = node.read()
if not (unit is None or unit_old is None or unit_old == cst.none): if not (unit is None or unit_old is None or unit_old == U.none):
value = value * unit_old.express(unit) value = value * unit_old.express(unit)
finally: finally:
if not open_before: if not open_before:
@@ -266,7 +229,7 @@ class HDF5Container(BaseProcessor):
""" """
Get real units from info files Get real units from info files
unit is either: unit is either:
1. An instance of cst.Unit (pymses unit class) 1. An instance of U.Unit (pymses unit class)
2. A string beginning by "unit_", referring to a code unit, 2. A string beginning by "unit_", referring to a code unit,
available in self.info available in self.info
3. A dict {unit1 : exp1, unit2: exp2, ...} with unitX as 2. 3. A dict {unit1 : exp1, unit2: exp2, ...} with unitX as 2.
@@ -276,10 +239,10 @@ class HDF5Container(BaseProcessor):
and unit the corresponding unit (on one on the above format) and unit the corresponding unit (on one on the above format)
Returns: Returns:
1-3. : a cst.Unit instance 1-3. : a U.Unit instance
4. : a dict {key: unit, ...} with same key as input and unit being cst.Unit instances 4. : a dict {key: unit, ...} with same key as input and unit being U.Unit instances
""" """
if isinstance(unit, cst.Unit): if isinstance(unit, U.Unit):
return unit return unit
if isinstance(unit, str) and unit[:5] == "unit_": if isinstance(unit, str) and unit[:5] == "unit_":
res = self.info[unit] res = self.info[unit]
@@ -287,13 +250,13 @@ class HDF5Container(BaseProcessor):
res = res / self.info["boxlen"] res = res / self.info["boxlen"]
return res return res
if list(unit)[0][:5] == "unit_": if list(unit)[0][:5] == "unit_":
new_unit = cst.none new_unit = U.none
for base_unit_str in unit: for base_unit_str in unit:
expo = unit[base_unit_str] expo = unit[base_unit_str]
base_unit = self._get_units(base_unit_str) base_unit = self._get_units(base_unit_str)
new_unit = new_unit * base_unit ** expo new_unit = new_unit * base_unit ** expo
return new_unit return new_unit
if (not data is None) and isinstance(data, dict) and list(unit)[0] in data: if (data is not None) and isinstance(data, dict) and list(unit)[0] in data:
for key in unit: for key in unit:
unit[key] = self._get_units(unit[key]) unit[key] = self._get_units(unit[key])
return unit return unit
@@ -364,7 +327,7 @@ class HDF5Container(BaseProcessor):
) )
self.save.get_node(name_full).attrs.unit = unit self.save.get_node(name_full).attrs.unit = unit
if not attrs is None: if attrs is not None:
for key in attrs: for key in attrs:
key = str(key) key = str(key)
self.save.get_node(name_full)._v_attrs[key] = attrs[key] self.save.get_node(name_full)._v_attrs[key] = attrs[key]
@@ -448,7 +411,7 @@ class HDF5Container(BaseProcessor):
fn, fn,
name_array_in, name_array_in,
name_array_out, name_array_out,
unit_out=cst.none, unit_out=U.none,
description="", description="",
recursive=True, recursive=True,
): ):
+38 -28
View File
@@ -1,6 +1,17 @@
# coding: utf-8 # coding: utf-8
from aggregator import * import os
import glob
import numpy as np
from functools import partial
from scipy.stats import linregress
from baseprocessor import Rule, HDF5Container
from aggregator import Aggregator
from postprocessor import PostProcessor
from run_selector import RunSelector
from pp_params import default_params
from units import U
class Comparator(Aggregator, HDF5Container): class Comparator(Aggregator, HDF5Container):
@@ -17,7 +28,7 @@ class Comparator(Aggregator, HDF5Container):
pp_params=default_params(), pp_params=default_params(),
selector=None, selector=None,
tag=None, tag=None,
unit_time=cst.year, unit_time=U.year,
**kwargs **kwargs
): ):
""" """
@@ -78,14 +89,14 @@ class Comparator(Aggregator, HDF5Container):
""" """
if overwrite or not (name_full in self.save): if overwrite or not (name_full in self.save):
return True return True
elif not "nums" in self.save.get_node(name_full)._v_attrs: elif "nums" not in self.save.get_node(name_full)._v_attrs:
return True return True
else: else:
saved_nums = self.save.get_node(name_full)._v_attrs.nums saved_nums = self.save.get_node(name_full)._v_attrs.nums
missing_runs = len([run for run in self.nums if not run in saved_nums]) > 0 missing_runs = len([run for run in self.nums if run not in saved_nums]) > 0
missing_nums = missing_runs or all( missing_nums = missing_runs or all(
[ [
len([num for num in self.nums[run] if not num in saved_nums[run]]) len([num for num in self.nums[run] if num not in saved_nums[run]])
> 0 > 0
for run in self.nums for run in self.nums
if run in saved_nums if run in saved_nums
@@ -182,18 +193,18 @@ class Comparator(Aggregator, HDF5Container):
def get_attr(self, attr_name, run, num, node_name="/", arg=None): def get_attr(self, attr_name, run, num, node_name="/", arg=None):
pp = self.pp[run][num] pp = self.pp[run][num]
if not arg is None: if arg is not None:
node_name = node_name + "_" + str(arg) node_name = node_name + "_" + str(arg)
return pp.get_attribute(node_name, attr_name) return pp.get_attribute(node_name, attr_name)
def get_pp_value(self, name, run, num, arg=None): def get_pp_value(self, name, run, num, arg=None):
pp = self.pp[run][num] pp = self.pp[run][num]
if not arg is None: if arg is not None:
name = name + "_" + str(arg) name = name + "_" + str(arg)
return pp.get_value(name) return pp.get_value(name)
def get_global(self, node_name, run, num, arg=None, unload_cells=False): def get_global(self, node_name, run, num, arg=None, unload_cells=False):
if not arg is None: if arg is not None:
node_name = node_name + "_" + str(arg) node_name = node_name + "_" + str(arg)
pp = self.pp[run][num] pp = self.pp[run][num]
if unload_cells: if unload_cells:
@@ -276,7 +287,6 @@ class Comparator(Aggregator, HDF5Container):
return series return series
def _from_log(self, keys, extractor): def _from_log(self, keys, extractor):
nums = self.nums
# Initialize series # Initialize series
series = {} series = {}
@@ -318,15 +328,15 @@ class Comparator(Aggregator, HDF5Container):
for run in self.runs: for run in self.runs:
# Surface of the box in pc^2 # Surface of the box in pc^2
info = self.pp[run][self.nums[run][0]].info info = self.pp[run][self.nums[run][0]].info
surface = (info["unit_length"].express(cst.pc)) ** 2 surface = (info["unit_length"].express(U.pc)) ** 2
# WARNING : We do not multiply by boxlen since already done in 'unit_length' (pymses) # WARNING : We do not multiply by boxlen since already done in 'unit_length' (pymses)
time = self.save.get_node("/series/sinks_from_log/time/" + run).read() time = self.save.get_node("/series/sinks_from_log/time/" + run).read()
time = time * time_unit.express(cst.year) time = time * time_unit.express(U.year)
mass_sink = self.save.get_node( mass_sink = self.save.get_node(
"/series/sinks_from_log/mass_sink/" + run "/series/sinks_from_log/mass_sink/" + run
).read() ).read()
mass_sink = mass_sink * mass_unit.express(cst.Msun) mass_sink = mass_sink * mass_unit.express(U.Msun)
if avg_window is None: if avg_window is None:
shift = 1 shift = 1
@@ -350,11 +360,11 @@ class Comparator(Aggregator, HDF5Container):
for run in self.runs: for run in self.runs:
# Surface of the box in pc^2 # Surface of the box in pc^2
info = self.pp[run][self.nums[run][0]].info info = self.pp[run][self.nums[run][0]].info
surface = (info["unit_length"].express(cst.pc)) ** 2 surface = (info["unit_length"].express(U.pc)) ** 2
mass_sink = self.save.get_node( mass_sink = self.save.get_node(
"/series/sinks_from_log/mass_sink/" + run "/series/sinks_from_log/mass_sink/" + run
).read() ).read()
mass_sink = mass_sink * mass_unit.express(cst.Msun) mass_sink = mass_sink * mass_unit.express(U.Msun)
ssm[run] = mass_sink / surface ssm[run] = mass_sink / surface
@@ -402,7 +412,7 @@ class Comparator(Aggregator, HDF5Container):
glob_group="/globals", glob_group="/globals",
subarray_name=None, subarray_name=None,
unload_cells=True, unload_cells=True,
unit=cst.none, unit=U.none,
description="", description="",
): ):
@@ -472,7 +482,7 @@ class Comparator(Aggregator, HDF5Container):
) )
def def_rules(self): def def_rules(self):
averageables = ["coldens", "rho", "T", "Q"]
self.rules = { self.rules = {
# Read from log # Read from log
"sinks_from_log": Rule( "sinks_from_log": Rule(
@@ -483,7 +493,7 @@ class Comparator(Aggregator, HDF5Container):
self._extract_sinks_from_log, self._extract_sinks_from_log,
), ),
group="/series", group="/series",
unit={"time": "unit_time", "mass_sink": cst.Msun, "nb_sink": cst.none}, unit={"time": "unit_time", "mass_sink": U.Msun, "nb_sink": U.none},
description={ description={
"time": "Time", "time": "Time",
"mass_sink": "Total mass of stars", "mass_sink": "Total mass of stars",
@@ -494,7 +504,7 @@ class Comparator(Aggregator, HDF5Container):
self, self,
self._ssfr_from_mass_sink, self._ssfr_from_mass_sink,
group="/series/sinks_from_log", group="/series/sinks_from_log",
unit=cst.ssfr, unit=U.ssfr,
description="Instantaneous surfacic star formation rate", description="Instantaneous surfacic star formation rate",
dependencies=["sinks_from_log"], dependencies=["sinks_from_log"],
), ),
@@ -502,7 +512,7 @@ class Comparator(Aggregator, HDF5Container):
self, self,
self._surfacic_sink_mass, self._surfacic_sink_mass,
group="/series/sinks_from_log", group="/series/sinks_from_log",
unit=cst.Msun / cst.pc ** 2, unit=U.Msun / U.pc ** 2,
description="Surfacic sink mass", description="Surfacic sink mass",
dependencies=["sinks_from_log"], dependencies=["sinks_from_log"],
), ),
@@ -510,7 +520,7 @@ class Comparator(Aggregator, HDF5Container):
self, self,
partial(self._from_log, ["time", "sfr"], self._extract_sfr_from_log), partial(self._from_log, ["time", "sfr"], self._extract_sfr_from_log),
group="/series", group="/series",
unit={"time": cst.year, "sfr": cst.ssfr}, unit={"time": U.year, "sfr": U.ssfr},
description={ description={
"time": "Time", "time": "Time",
"sfr": "Averaged surfacic star formation rate", "sfr": "Averaged surfacic star formation rate",
@@ -527,7 +537,7 @@ class Comparator(Aggregator, HDF5Container):
unit={ unit={
"time": "unit_time", "time": "unit_time",
"dt": "unit_time", "dt": "unit_time",
"turb_rms": cst.none, "turb_rms": U.none,
"turb_energy": { "turb_energy": {
"unit_length": 3, "unit_length": 3,
"unit_velocity": 2, "unit_velocity": 2,
@@ -551,13 +561,13 @@ class Comparator(Aggregator, HDF5Container):
group="/series", group="/series",
unit={ unit={
"time": "unit_time", "time": "unit_time",
"step": cst.none, "step": U.none,
"mcons": cst.none, "mcons": U.none,
"econs": cst.none, "econs": U.none,
"epot": cst.none, # TODO find unit "epot": U.none, # TODO find unit
"ekin": cst.none, "ekin": U.none,
"eint": cst.none, "eint": U.none,
"emag": cst.none, "emag": U.none,
}, },
), ),
"turb_power": Rule( "turb_power": Rule(
-2
View File
@@ -6,8 +6,6 @@ import multiprocessing
# We must import this explicitly, it is not imported by the top-level # We must import this explicitly, it is not imported by the top-level
# multiprocessing module. # multiprocessing module.
import multiprocessing.pool import multiprocessing.pool
import time
from random import randint
class NoDaemonProcess(multiprocessing.Process): class NoDaemonProcess(multiprocessing.Process):
+142 -126
View File
@@ -9,9 +9,7 @@ This is the plotter module.
""" """
import os import os
import sys
from functools import partial from functools import partial
import matplotlib as mpl import matplotlib as mpl
import numpy as np import numpy as np
import tables import tables
@@ -27,10 +25,24 @@ if os.environ.get("DISPLAY", "") == "":
mpl.use("Agg") mpl.use("Agg")
import datetime import datetime
import pylab as P import matplotlib.pyplot as plt
import pspec_read import pspec_read
from comparator import * from baseprocessor import Rule, BaseProcessor
from aggregator import Aggregator
from comparator import Comparator
from run_selector import RunSelector
from units import U, unit_str, convert_exp
from astrophysix.simdm.results import GenericResult
from astrophysix.simdm.experiment import (
ParameterSetting,
ParameterVisibility,
Simulation,
)
from ramses_astrophysix import ramses
filetype_from_ext = {ext: ft for ft in FileType for ext in ft.extension_list} filetype_from_ext = {ext: ft for ft in FileType for ext in ft.extension_list}
@@ -88,19 +100,18 @@ class Plotter(Aggregator, BaseProcessor):
"beta": "$\\beta$", "beta": "$\\beta$",
"beta_cool": "$\\beta$", "beta_cool": "$\\beta$",
"dens0": "$n_0$", "dens0": "$n_0$",
"coldens0": "$\Sigma_0$", "coldens0": "$\\Sigma_0$",
"sfr_avg_window": "window", "sfr_avg_window": "window",
"bx_bound": "$B_0$", "bx_bound": "$B_0$",
"levelmax": "$l_{\max}$", "levelmax": "$l_{\\max}$",
"levelmin": "$l_{\min}$", "levelmin": "$l_{\\min}$",
"comp_frac": "$1 - \\zeta$", "comp_frac": "$1 - \\zeta$",
} }
# Conversion table from namelist values (from amses config file) into LaTex strings # Conversion table from namelist values (from amses config file) into LaTex strings
value_convert = { value_convert = {
"sfr_avg_window": lambda x: "${:g}$ Myr".format(80 * x), "sfr_avg_window": lambda x: "${:g}$ Myr".format(80 * x),
#'comp_frac' : lambda x: "${:g}$".format(1 - x), "bx_bound": lambda x: "${:g}$ $\\mu G$".format(5.267501272979475 * x),
"bx_bound": lambda x: "${:g}$ $\mu G$".format(5.267501272979475 * x),
} }
def __init__( def __init__(
@@ -112,7 +123,7 @@ class Plotter(Aggregator, BaseProcessor):
pp_params=None, pp_params=None,
selector=None, selector=None,
tag=None, tag=None,
unit_time=cst.year, unit_time=U.year,
**kwargs, **kwargs,
): ):
@@ -124,10 +135,12 @@ class Plotter(Aggregator, BaseProcessor):
path : path to the main folder of the simulations (ex '~/simus/myproject') path : path to the main folder of the simulations (ex '~/simus/myproject')
in_runs : list of the runs to consider (ex ['run1', 'run2']) in_runs : list of the runs to consider (ex ['run1', 'run2'])
in_nums : list or dict of the outputs numbers to consider (ex [3, 5] or {'run1' : [3, 5], 'run2' : [4, 6]) in_nums : list or dict of the outputs numbers to consider (ex [3, 5]
or {'run1' : [3, 5], 'run2' : [4, 6])
path_out : Path where the plot will be saved. By default set to `path` path_out : Path where the plot will be saved. By default set to `path`
pp_params : Parameters for postprocessing. See pp_params module. pp_params : Parameters for postprocessing. See pp_params module.
selector : Existing instance of RunSelector, that selects runs and outputs. If set, in_runs and in_nums will be ignored selector : Existing instance of RunSelector, that selects runs and outputs. If set, in_runs and
in_nums will be ignored
tag : string to add in the output and data files. tag : string to add in the output and data files.
kwargs : Keyword arguments for RunSelector. kwargs : Keyword arguments for RunSelector.
""" """
@@ -216,7 +229,7 @@ class Plotter(Aggregator, BaseProcessor):
dep, dep_arg, overwrite, overwrite_dep=self.overwrite_dep dep, dep_arg, overwrite, overwrite_dep=self.overwrite_dep
) )
if result is not None: if result is not None:
self.just_done.append(done) self.just_done.append(result)
else: else:
super(Plotter, self)._not_self_dep(name, dep, dep_arg, overwrite, **kwargs) super(Plotter, self)._not_self_dep(name, dep, dep_arg, overwrite, **kwargs)
@@ -238,7 +251,7 @@ class Plotter(Aggregator, BaseProcessor):
""" """
# Set full name according to argument # Set full name according to argument
if not arg is None: if arg is not None:
name_full = ( name_full = (
name name
+ "_" + "_"
@@ -286,7 +299,7 @@ class Plotter(Aggregator, BaseProcessor):
real_ax = ax[i] real_ax = ax[i]
except TypeError as e: except TypeError as e:
if ax is None: if ax is None:
fig, real_ax = P.subplots(1, 1) fig, real_ax = plt.subplots(1, 1)
elif not_array_error(e): elif not_array_error(e):
real_ax = ax real_ax = ax
else: else:
@@ -341,15 +354,15 @@ class Plotter(Aggregator, BaseProcessor):
""" """
Once all dependencies are met, actually process the rule Once all dependencies are met, actually process the rule
""" """
P.sca(ax) plt.sca(ax)
if self._needs_computation(overwrite, plot_filename): if self._needs_computation(overwrite, plot_filename):
plot_info = rule.plot(save, arg, **kwargs) plot_info = rule.plot(save, arg, **kwargs)
if not self.pp_params.out.interactive: if not self.pp_params.out.interactive:
P.tight_layout(pad=1) plt.tight_layout(pad=1)
if self.pp_params.out.save: if self.pp_params.out.save:
P.savefig(plot_filename) plt.savefig(plot_filename)
self._log("{} plotted".format(plot_filename), "SUCCESS") self._log("{} plotted".format(plot_filename), "SUCCESS")
else: else:
self._log( self._log(
@@ -357,7 +370,7 @@ class Plotter(Aggregator, BaseProcessor):
) )
if not self.pp_params.out.interactive: if not self.pp_params.out.interactive:
P.close() plt.close()
return plot_info return plot_info
else: else:
self._log("Plot {} is already done, skipping...".format(plot_filename)) self._log("Plot {} is already done, skipping...".format(plot_filename))
@@ -372,9 +385,9 @@ class Plotter(Aggregator, BaseProcessor):
if not self.pp_params.out.tag == "": if not self.pp_params.out.tag == "":
tag_name = "_" + tag_name tag_name = "_" + tag_name
if not run is None and not num is None: if run is not None and num is not None:
fmt = "{out}/{run}/{name}{tag}_{run}_{num:05}{ext}" fmt = "{out}/{run}/{name}{tag}_{run}_{num:05}{ext}"
elif not run is None: elif run is not None:
fmt = "{out}/{run}/{name}{tag}_{run}{ext}" fmt = "{out}/{run}/{name}{tag}_{run}{ext}"
else: else:
fmt = "{out}/{name}{tag}{ext}" fmt = "{out}/{name}{tag}{ext}"
@@ -382,7 +395,7 @@ class Plotter(Aggregator, BaseProcessor):
fmt = self.pp_params.out.fmt fmt = self.pp_params.out.fmt
nml = None nml = None
if not run is None: if run is not None:
nml = self.comp.namelist[run] nml = self.comp.namelist[run]
return fmt.format( return fmt.format(
@@ -423,12 +436,12 @@ class Plotter(Aggregator, BaseProcessor):
label_run = r"{}".format(self.save.root._v_attrs.attrs[run].label) label_run = r"{}".format(self.save.root._v_attrs.attrs[run].label)
else: else:
label_run = run label_run = run
elif not nml_key is None: elif nml_key is not None:
if not type(nml_key) == list: if not type(nml_key) == list:
nml_key = [nml_key] nml_key = [nml_key]
label_run = ", ".join(map(get_label_nml, nml_key)) label_run = ", ".join(map(get_label_nml, nml_key))
if not label is None: if label is not None:
label_run = label + " (" + label_run + ")" label_run = label + " (" + label_run + ")"
else: else:
label_run = label label_run = label
@@ -451,7 +464,7 @@ class Plotter(Aggregator, BaseProcessor):
if "unit" in node._v_attrs: if "unit" in node._v_attrs:
unit_old = node._v_attrs.unit unit_old = node._v_attrs.unit
else: else:
unit_old = cst.none unit_old = U.none
if unit is None: if unit is None:
unit = unit_old unit = unit_old
@@ -466,7 +479,7 @@ class Plotter(Aggregator, BaseProcessor):
return label, unit_old, unit return label, unit_old, unit
def _snapshot_title(self, run, node, title, nml_key, put_time, unit_time=cst.Myr): def _snapshot_title(self, run, node, title, nml_key, put_time, unit_time=U.Myr):
title = self._label_run(run, node, title, nml_key) title = self._label_run(run, node, title, nml_key)
if put_time: if put_time:
@@ -497,9 +510,9 @@ class Plotter(Aggregator, BaseProcessor):
put_title=True, put_title=True,
nml_key=None, nml_key=None,
put_time=True, put_time=True,
unit_time=cst.Myr, unit_time=U.Myr,
put_units=True, put_units=True,
unit_space=cst.pc, unit_space=U.pc,
cmap="plasma", cmap="plasma",
norm="log", norm="log",
put_cbar=True, put_cbar=True,
@@ -535,14 +548,14 @@ class Plotter(Aggregator, BaseProcessor):
elif norm == "linear": elif norm == "linear":
norm = mpl.colors.NoNorm() norm = mpl.colors.NoNorm()
if autoscale and not norm is None: if autoscale and norm is not None:
norm.autoscale(dmap) norm.autoscale(dmap)
im = P.imshow( im = plt.imshow(
dmap, extent=im_extent, origin="lower", norm=norm, cmap=cmap, **kwargs dmap, extent=im_extent, origin="lower", norm=norm, cmap=cmap, **kwargs
) )
P.locator_params(axis="both", nbins=self.pp_params.plot.ntick) plt.locator_params(axis="both", nbins=self.pp_params.plot.ntick)
if xlabel is None: if xlabel is None:
xlabel = self._ax_title[ax_h] xlabel = self._ax_title[ax_h]
@@ -551,19 +564,19 @@ class Plotter(Aggregator, BaseProcessor):
if put_units: if put_units:
xlabel = xlabel + unit_str(unit_space) xlabel = xlabel + unit_str(unit_space)
ylabel = ylabel + unit_str(unit_space) ylabel = ylabel + unit_str(unit_space)
P.xlabel(xlabel) plt.xlabel(xlabel)
P.ylabel(ylabel) plt.ylabel(ylabel)
try: try:
cbar = P.colorbar(im, cax=P.gca().cax) cbar = plt.colorbar(im, cax=plt.gca().cax)
except AttributeError: except AttributeError:
cbar = P.colorbar() cbar = plt.colorbar()
if put_title: if put_title:
title = self._snapshot_title(run, node, title, nml_key, put_time, unit_time) title = self._snapshot_title(run, node, title, nml_key, put_time, unit_time)
P.title(title) plt.title(title)
if not label is None: if label is not None:
cbar.set_label(label) cbar.set_label(label)
for i, plot_overlay in enumerate(overlays): for i, plot_overlay in enumerate(overlays):
@@ -627,7 +640,7 @@ class Plotter(Aggregator, BaseProcessor):
) )
lw[lvl_array < lvl_th] = 1.0 lw[lvl_array < lvl_th] = 1.0
cont = P.contour( cont = plt.contour(
map_contour, map_contour,
extent=im_extent, extent=im_extent,
origin="lower", origin="lower",
@@ -639,7 +652,7 @@ class Plotter(Aggregator, BaseProcessor):
lvls = np.array(cont.levels) + lvl_offset lvls = np.array(cont.levels) + lvl_offset
cont.levels = lvls cont.levels = lvls
P.clabel( plt.clabel(
cont, cont,
lvls[np.array(lvls) < lvl_max_lbl], lvls[np.array(lvls) < lvl_max_lbl],
inline=1, inline=1,
@@ -663,13 +676,11 @@ class Plotter(Aggregator, BaseProcessor):
) )
def _overlay_speed( def _overlay_speed(
self, ax_los, im_extent, unit=cst.km_s, unit_coeff=1.0, key_v=None, **kwargs self, ax_los, im_extent, unit=U.km_s, unit_coeff=1.0, key_v=None, **kwargs
): ):
""" """
Add an overlay : velocity vector field Add an overlay : velocity vector field
""" """
ax_h = self._axes_h[ax_los]
ax_v = self._axes_v[ax_los]
dmap_vh_node = self.save.get_node("/maps/speed_h_{}".format(ax_los)) dmap_vh_node = self.save.get_node("/maps/speed_h_{}".format(ax_los))
dmap_vh = dmap_vh_node.read() dmap_vh = dmap_vh_node.read()
dmap_vv = self.save.get_node("/maps/speed_v_{}".format(ax_los)).read() dmap_vv = self.save.get_node("/maps/speed_v_{}".format(ax_los)).read()
@@ -701,11 +712,11 @@ class Plotter(Aggregator, BaseProcessor):
max_v = np.max(norm_v) max_v = np.max(norm_v)
min_v = np.min(norm_v) min_v = np.min(norm_v)
Q = P.quiver(hh, vv, map_vh_red, map_vv_red, units="width", **kwargs) Q = plt.quiver(hh, vv, map_vh_red, map_vv_red, units="width", **kwargs)
if key_v is None: if key_v is None:
key_v = (max_v + min_v) / 2.0 key_v = (max_v + min_v) / 2.0
P.quiverkey( plt.quiverkey(
Q, Q,
0.6, 0.6,
0.98, 0.98,
@@ -719,8 +730,6 @@ class Plotter(Aggregator, BaseProcessor):
""" """
Add an overlay : magnetic streamlines Add an overlay : magnetic streamlines
""" """
ax_h = self._axes_h[ax_los]
ax_v = self._axes_v[ax_los]
dmap_Bh_node = self.save.get_node("/maps/B_h_{}".format(ax_los)) dmap_Bh_node = self.save.get_node("/maps/B_h_{}".format(ax_los))
dmap_Bh = dmap_Bh_node.read() dmap_Bh = dmap_Bh_node.read()
dmap_Bv = self.save.get_node("/maps/B_v_{}".format(ax_los)).read() dmap_Bv = self.save.get_node("/maps/B_v_{}".format(ax_los)).read()
@@ -744,7 +753,7 @@ class Plotter(Aggregator, BaseProcessor):
) * lbox ) * lbox
hh, vv = np.meshgrid(vec_h, vec_v) hh, vv = np.meshgrid(vec_h, vec_v)
P.streamplot(hh, vv, map_Bh_red, map_Bv_red, **kwargs) plt.streamplot(hh, vv, map_Bh_red, map_Bv_red, **kwargs)
def _plot_radial( def _plot_radial(
self, self,
@@ -760,7 +769,7 @@ class Plotter(Aggregator, BaseProcessor):
nml_key=None, nml_key=None,
put_title=True, put_title=True,
put_time=True, put_time=True,
unit_time=cst.Myr, unit_time=U.Myr,
**kwargs, **kwargs,
): ):
""" """
@@ -775,23 +784,23 @@ class Plotter(Aggregator, BaseProcessor):
if ytransform is not None: if ytransform is not None:
mean_bin = ytransform(mean_bin) mean_bin = ytransform(mean_bin)
P.xlabel(r"$r$") plt.xlabel(r"$r$")
if xlog: if xlog:
P.xscale("log") plt.xscale("log")
if ylog: if ylog:
P.yscale("log") plt.yscale("log")
if not ylabel is None: if ylabel is not None:
P.ylabel(ylabel) plt.ylabel(ylabel)
title = self._snapshot_title(run, node, title, nml_key, put_time, unit_time) title = self._snapshot_title(run, node, title, nml_key, put_time, unit_time)
if put_title: if put_title:
P.title(title) plt.title(title)
if label == None: if label is None:
label = title label = title
P.plot(bin_centers, mean_bin, label=label, **kwargs) plt.plot(bin_centers, mean_bin, label=label, **kwargs)
def _plot_hist( def _plot_hist(
self, self,
@@ -808,11 +817,11 @@ class Plotter(Aggregator, BaseProcessor):
title=None, title=None,
nml_key=None, nml_key=None,
put_time=True, put_time=True,
unit_time=cst.Myr, unit_time=U.Myr,
xlog=None, xlog=None,
ylog=False, ylog=False,
kind="bar", kind="bar",
ylabel="$\mathcal{P}$", ylabel="$\\mathcal{P}$",
color=None, color=None,
colors=None, colors=None,
nml_color=None, nml_color=None,
@@ -824,7 +833,7 @@ class Plotter(Aggregator, BaseProcessor):
Plot an histogram (PDF, etc ...) Plot an histogram (PDF, etc ...)
""" """
# Get node # Get node
if not ax_los is None: if ax_los is not None:
name = name + "_" + ax_los name = name + "_" + ax_los
node = self.save.get_node(group + name) node = self.save.get_node(group + name)
if xlog is None: if xlog is None:
@@ -853,12 +862,12 @@ class Plotter(Aggregator, BaseProcessor):
# Set title # Set title
title = self._snapshot_title(run, node, title, nml_key, put_time, unit_time) title = self._snapshot_title(run, node, title, nml_key, put_time, unit_time)
if put_title: if put_title:
P.title(title) plt.title(title)
if label == None: if label is None:
label = title label = title
# Set colors # Set colors
if color is None and not colors is None: if color is None and colors is not None:
if nml_color is None: if nml_color is None:
color = colors[run] color = colors[run]
else: else:
@@ -870,25 +879,27 @@ class Plotter(Aggregator, BaseProcessor):
# Actual plot # Actual plot
if kind == "bar": if kind == "bar":
P.bar(centers, values, width, log=ylog, color=color, label=label, **kwargs) plt.bar(
centers, values, width, log=ylog, color=color, label=label, **kwargs
)
elif kind == "step": elif kind == "step":
if ylog: if ylog:
P.yscale("log") plt.yscale("log")
P.step(centers, values, where="mid", color=color, label=label, **kwargs) plt.step(centers, values, where="mid", color=color, label=label, **kwargs)
else: else:
raise ValueError("kind must be 'bar' or 'step'") raise ValueError("kind must be 'bar' or 'step'")
# put labels # put labels
if not label is None: if label is not None:
P.xlabel(xlabel) plt.xlabel(xlabel)
if not ylabel is None: if ylabel is not None:
P.ylabel(ylabel) plt.ylabel(ylabel)
# Also diplay fit, previously saved # Also diplay fit, previously saved
if ax_los is not None and "/hist/fit_" + name + "_" + ax_los in self.save: if ax_los is not None and "/hist/fit_" + name + "_" + ax_los in self.save:
slope = node.attrs.slope slope = node.attrs.slope
origin = node.attrs.origin origin = node.attrs.origin
P.plot( plt.plot(
centers, centers,
10 ** (slope * centers + origin), 10 ** (slope * centers + origin),
"--", "--",
@@ -896,7 +907,7 @@ class Plotter(Aggregator, BaseProcessor):
color="orange", color="orange",
) )
# or a new one # or a new one
if not fit is None: if fit is not None:
self._overlay_fit( self._overlay_fit(
centers, values, kind=fit, ls="--", lw=1.5, label=fitlabel centers, values, kind=fit, ls="--", lw=1.5, label=fitlabel
) )
@@ -912,7 +923,7 @@ class Plotter(Aggregator, BaseProcessor):
xaxis_label=xlabel, xaxis_label=xlabel,
yaxis_label=ylabel, yaxis_label=ylabel,
xaxis_unit=unit, xaxis_unit=unit,
yaxis_unit=cst.none, yaxis_unit=U.none,
plot_title=title, plot_title=title,
) )
@@ -938,7 +949,7 @@ class Plotter(Aggregator, BaseProcessor):
sigma_err=2.0, sigma_err=2.0,
grid=False, grid=False,
put_time=False, put_time=False,
unit_time=cst.Myr, unit_time=U.Myr,
colors=None, colors=None,
nml_color=None, nml_color=None,
legend=None, legend=None,
@@ -951,7 +962,7 @@ class Plotter(Aggregator, BaseProcessor):
""" """
# Get proper hdf5 names # Get proper hdf5 names
if not node_arg is None: if node_arg is not None:
name_x, name_y = name_x + "_" + node_arg, name_y + "_" + node_arg name_x, name_y = name_x + "_" + node_arg, name_y + "_" + node_arg
# Get hdf5 nodes # Get hdf5 nodes
@@ -1031,15 +1042,15 @@ class Plotter(Aggregator, BaseProcessor):
if smooth > 0: if smooth > 0:
y = gaussian_filter1d(y, sigma=smooth) y = gaussian_filter1d(y, sigma=smooth)
if not run is None: if run is not None:
label = self._label_run(run, node_y, label, nml_key) label = self._label_run(run, node_y, label, nml_key)
# Look if special colors method is used # Look if special colors method is used
if colors is None: if colors is None:
if yerr is None: if yerr is None:
(base_line,) = P.plot(x, y, label=label, **kwargs) (base_line,) = plt.plot(x, y, label=label, **kwargs)
else: else:
base_line, _, _ = P.errorbar(x, y, yerr=yerr, label=label, **kwargs) base_line, _, _ = plt.errorbar(x, y, yerr=yerr, label=label, **kwargs)
else: else:
if nml_color is None: if nml_color is None:
color = colors[run] color = colors[run]
@@ -1055,21 +1066,21 @@ class Plotter(Aggregator, BaseProcessor):
except: except:
color = colors(nml) color = colors(nml)
if yerr is None: if yerr is None:
(base_line,) = P.plot(x, y, label=label, color=color, **kwargs) (base_line,) = plt.plot(x, y, label=label, color=color, **kwargs)
else: else:
base_line, _, _ = P.errorbar( base_line, _, _ = plt.errorbar(
x, y, yerr=yerr, color=color, label=label, **kwargs x, y, yerr=yerr, color=color, label=label, **kwargs
) )
# Ax decorations # Ax decorations
P.xlabel(xlabel) plt.xlabel(xlabel)
P.ylabel(ylabel) plt.ylabel(ylabel)
if grid: if grid:
P.grid() plt.grid()
if legend: if legend:
P.legend() plt.legend()
if not fit is None: if fit is not None:
self._overlay_fit( self._overlay_fit(
x, x,
y, y,
@@ -1121,7 +1132,7 @@ class Plotter(Aggregator, BaseProcessor):
) )
if label is None: if label is None:
label = r"Linear fit with slope ${:.3g}$".format(a) label = r"Linear fit with slope ${:.3g}$".format(a)
P.plot(x, a * x + b, label=label, **kwargs) plt.plot(x, a * x + b, label=label, **kwargs)
elif kind == "power_law": elif kind == "power_law":
if yerr is None: if yerr is None:
(a, b, rho, _map_rule, stderr) = linregress(np.log10(x), np.log10(y)) (a, b, rho, _map_rule, stderr) = linregress(np.log10(x), np.log10(y))
@@ -1131,8 +1142,13 @@ class Plotter(Aggregator, BaseProcessor):
) )
) )
else: else:
fitfunc = lambda p, x: p[0] + p[1] * x
errfunc = lambda p, x, y, err: (y - fitfunc(p, x)) / err def fitfunc(p, x):
return p[0] + p[1] * x
def errfunc(p, x, y, err):
return (y - fitfunc(p, x)) / err
pinit = [1.0, -1.0] pinit = [1.0, -1.0]
out = optimize.leastsq( out = optimize.leastsq(
errfunc, errfunc,
@@ -1151,27 +1167,27 @@ class Plotter(Aggregator, BaseProcessor):
) )
if label is None: if label is None:
label = r"Power-law fit with index {:.1f}".format(a) label = r"Power-law fit with index {:.1f}".format(a)
P.plot(x, (10 ** b) * x ** a, label=label, **kwargs) plt.plot(x, (10 ** b) * x ** a, label=label, **kwargs)
def overlay_kennicutt(self, n0, step): def overlay_kennicutt(self, n0, step):
""" """
Add an overlay : Kennicutt mass accretion Add an overlay : Kennicutt mass accretion
""" """
P.grid(False) plt.grid(False)
ylim = P.ylim() ylim = plt.ylim()
(tmin, tmax) = P.xlim() (tmin, tmax) = plt.xlim()
tmax = tmax + 20 tmax = tmax + 20
ymax = P.ylim()[1] ymax = plt.ylim()[1]
ssfr_sun = 2.5e-9 ssfr_sun = 2.5e-9
ssfr_ken = ssfr_sun * n0 ** 1.4 ssfr_ken = ssfr_sun * n0 ** 1.4
coeff = ssfr_ken * 1e6 * (self.comp.info["unit_length"].express(cst.pc)) ** 2 coeff = ssfr_ken * 1e6 * (self.comp.info["unit_length"].express(U.pc)) ** 2
for i in np.arange(tmin, max(tmax, tmin + ymax / coeff), step): for i in np.arange(tmin, max(tmax, tmin + ymax / coeff), step):
t = np.linspace(0, tmax, 1000) t = np.linspace(0, tmax, 1000)
P.plot(t + i, t * coeff, ls="--", lw=0.9, color="grey") plt.plot(t + i, t * coeff, ls="--", lw=0.9, color="grey")
P.plot(t + tmin, (t + i - tmin) * coeff, ls="--", lw=0.9, color="grey") plt.plot(t + tmin, (t + i - tmin) * coeff, ls="--", lw=0.9, color="grey")
P.xlim(tmin, tmax) plt.xlim(tmin, tmax)
P.ylim(ylim) plt.ylim(ylim)
def _gen_from_log(self, logrule, name, description="Generated"): def _gen_from_log(self, logrule, name, description="Generated"):
self.rules[name] = PlotRule( self.rules[name] = PlotRule(
@@ -1180,7 +1196,7 @@ class Plotter(Aggregator, BaseProcessor):
self._plot, self._plot,
"/series/" + logrule + "/time", "/series/" + logrule + "/time",
"/series/" + logrule + "/" + name, "/series/" + logrule + "/" + name,
xunit=cst.Myr, xunit=U.Myr,
), ),
description=description, description=description,
kind="series", kind="series",
@@ -1202,7 +1218,7 @@ class Plotter(Aggregator, BaseProcessor):
self._plot_map, self._plot_map,
"coldens", "coldens",
label=r"$\Sigma$", label=r"$\Sigma$",
# unit=cst.coldens # unit=U.coldens
), ),
"Column density map", "Column density map",
dependencies=["coldens"], dependencies=["coldens"],
@@ -1235,7 +1251,7 @@ class Plotter(Aggregator, BaseProcessor):
self._plot_map, self._plot_map,
"vphi", "vphi",
label=r"$v_\phi$", label=r"$v_\phi$",
# unit=cst.km_s # unit=U.km_s
), ),
"Azimuthal speed", "Azimuthal speed",
dependencies=["vphi"], dependencies=["vphi"],
@@ -1246,7 +1262,7 @@ class Plotter(Aggregator, BaseProcessor):
self._plot_map, self._plot_map,
"vr", "vr",
label=r"$v_r$", label=r"$v_r$",
# unit=cst.km_s # unit=U.km_s
), ),
"Radial speed", "Radial speed",
dependencies=["vr"], dependencies=["vr"],
@@ -1257,7 +1273,7 @@ class Plotter(Aggregator, BaseProcessor):
self._plot_map, self._plot_map,
"rho", "rho",
label=r"$\rho$", label=r"$\rho$",
# unit=cst.Msun_pc3 # unit=U.Msun_pc3
), ),
"Density slice at s = 0, with s = x, y or z.", "Density slice at s = 0, with s = x, y or z.",
dependencies=["rho"], dependencies=["rho"],
@@ -1268,7 +1284,7 @@ class Plotter(Aggregator, BaseProcessor):
self._plot_map, self._plot_map,
"coldens", "coldens",
label=r"$\Sigma$", label=r"$\Sigma$",
unit=cst.coldens, unit=U.coldens,
overlays=[self._overlay_levels], overlays=[self._overlay_levels],
), ),
"Column density with level overlay", "Column density with level overlay",
@@ -1280,7 +1296,7 @@ class Plotter(Aggregator, BaseProcessor):
self._plot_map, self._plot_map,
"rho", "rho",
label=r"$\rho$", label=r"$\rho$",
unit=cst.Msun_pc3, unit=U.Msun_pc3,
overlays=[self._overlay_speed], overlays=[self._overlay_speed],
), ),
"Density slice with speed overlay", "Density slice with speed overlay",
@@ -1292,7 +1308,7 @@ class Plotter(Aggregator, BaseProcessor):
self._plot_map, self._plot_map,
"rho", "rho",
label=r"$\rho$", label=r"$\rho$",
unit=cst.Msun_pc3, unit=U.Msun_pc3,
overlays=[self._overlay_B], overlays=[self._overlay_B],
), ),
"Density slice with magnetic field overlay", "Density slice with magnetic field overlay",
@@ -1304,7 +1320,7 @@ class Plotter(Aggregator, BaseProcessor):
self._plot_map, self._plot_map,
"rho", "rho",
label=r"$\rho$", label=r"$\rho$",
unit=cst.Msun_pc3, unit=U.Msun_pc3,
overlays=[self._overlay_B, self._overlay_speed], overlays=[self._overlay_B, self._overlay_speed],
), ),
"Density slice with magnetic field and velocity overlay", "Density slice with magnetic field and velocity overlay",
@@ -1382,7 +1398,7 @@ class Plotter(Aggregator, BaseProcessor):
"B_int": PlotRule( "B_int": PlotRule(
self, self,
partial( partial(
self._plot_map, "B_int", label=r"$\mid \mathrm{B} \mid$", unit=cst.T self._plot_map, "B_int", label=r"$\mid \mathrm{B} \mid$", unit=U.T
), ),
"Magnetic intensity map", "Magnetic intensity map",
dependencies=["B_int"], dependencies=["B_int"],
@@ -1450,8 +1466,8 @@ class Plotter(Aggregator, BaseProcessor):
self._plot, self._plot,
"/series/sinks_from_log/time", "/series/sinks_from_log/time",
"/series/sinks_from_log/mass_sink", "/series/sinks_from_log/mass_sink",
xunit=cst.Myr, xunit=U.Myr,
yunit=cst.Msun, yunit=U.Msun,
), ),
"Mass of the sinks as a function of time", "Mass of the sinks as a function of time",
kind="series", kind="series",
@@ -1463,8 +1479,8 @@ class Plotter(Aggregator, BaseProcessor):
self._plot, self._plot,
"/series/sinks_from_log/time", "/series/sinks_from_log/time",
"/series/sinks_from_log/ssm", "/series/sinks_from_log/ssm",
xunit=cst.Myr, xunit=U.Myr,
yunit=cst.Msun / cst.pc ** 2, yunit=U.Msun / U.pc ** 2,
), ),
"Mass of the sinks as a function of time divided by surface", "Mass of the sinks as a function of time divided by surface",
kind="series", kind="series",
@@ -1477,8 +1493,8 @@ class Plotter(Aggregator, BaseProcessor):
"/series/sfr_from_log/time", "/series/sfr_from_log/time",
"/series/sfr_from_log/sfr", "/series/sfr_from_log/sfr",
ylabel="Averaged surfacic SFR", ylabel="Averaged surfacic SFR",
xunit=cst.Myr, xunit=U.Myr,
yunit=cst.ssfr, yunit=U.ssfr,
), ),
kind="series", kind="series",
dependencies=["sfr_from_log"], dependencies=["sfr_from_log"],
@@ -1490,8 +1506,8 @@ class Plotter(Aggregator, BaseProcessor):
"/series/sinks_from_log/time", "/series/sinks_from_log/time",
"/series/sinks_from_log/issfr", "/series/sinks_from_log/issfr",
ylabel="Surfacic SFR", ylabel="Surfacic SFR",
xunit=cst.Myr, xunit=U.Myr,
yunit=cst.ssfr, yunit=U.ssfr,
), ),
kind="series", kind="series",
dependencies=["issfr"], dependencies=["issfr"],
@@ -1502,7 +1518,7 @@ class Plotter(Aggregator, BaseProcessor):
self._plot, self._plot,
"/series/rms_from_log/time", "/series/rms_from_log/time",
"/series/rms_from_log/turb_rms", "/series/rms_from_log/turb_rms",
xunit=cst.Myr, xunit=U.Myr,
), ),
"Turbulent RMS", "Turbulent RMS",
kind="series", kind="series",
@@ -1514,7 +1530,7 @@ class Plotter(Aggregator, BaseProcessor):
self._plot, self._plot,
"/series/rms_from_log/time", "/series/rms_from_log/time",
"/series/rms_from_log/turb_energy", "/series/rms_from_log/turb_energy",
xunit=cst.Myr, xunit=U.Myr,
), ),
"Turbulent energy", "Turbulent energy",
kind="series", kind="series",
@@ -1526,7 +1542,7 @@ class Plotter(Aggregator, BaseProcessor):
self._plot, self._plot,
"/series/rms_from_log/time", "/series/rms_from_log/time",
"/series/rms_from_log/turb_power", "/series/rms_from_log/turb_power",
xunit=cst.Myr, xunit=U.Myr,
), ),
"Turbulent power", "Turbulent power",
kind="series", kind="series",
@@ -1539,8 +1555,8 @@ class Plotter(Aggregator, BaseProcessor):
"/series/time", "/series/time",
"/series/time_sigma", "/series/time_sigma",
ylabel="$\\sigma$", ylabel="$\\sigma$",
xunit=cst.Myr, xunit=U.Myr,
yunit=cst.km_s, yunit=U.km_s,
), ),
"Velocity dispersion", "Velocity dispersion",
kind="series", kind="series",
@@ -1552,8 +1568,8 @@ class Plotter(Aggregator, BaseProcessor):
self._plot, self._plot,
"/series/time", "/series/time",
"/series/time_mwa_B_int", "/series/time_mwa_B_int",
xunit=cst.Myr, xunit=U.Myr,
yunit=cst.uG, yunit=U.uG,
), ),
"Magnetic intensity average", "Magnetic intensity average",
kind="series", kind="series",
@@ -1565,8 +1581,8 @@ class Plotter(Aggregator, BaseProcessor):
self._plot, self._plot,
"/series/time", "/series/time",
"/series/time_mass", "/series/time_mass",
xunit=cst.Myr, xunit=U.Myr,
yunit=cst.Msun, yunit=U.Msun,
), ),
"Total mass in the box", "Total mass in the box",
kind="series", kind="series",
@@ -1578,8 +1594,8 @@ class Plotter(Aggregator, BaseProcessor):
self._plot, self._plot,
"/series/time", "/series/time",
"/series/time_max_fluct_coldens_z", "/series/time_max_fluct_coldens_z",
ylabel="$\\max(\Sigma/\overline{\Sigma})$", ylabel="$\\max(\\Sigma/\\overline{\\Sigma})$",
xunit=cst.Myr, xunit=U.Myr,
), ),
"Maximal fluctuation of the column density against time", "Maximal fluctuation of the column density against time",
kind="series", kind="series",
+71 -51
View File
@@ -1,15 +1,39 @@
# coding: utf-8 # coding: utf-8
import pickle
import numpy as np
import tables
import pickle
import astropy.units as u import astropy.units as u
import pandas as pd import pandas as pd
import pymses.utils.regions as reg
from fil_finder import FilFinder2D
from pymses.filters import RegionFilter
from skimage.morphology import medial_axis from skimage.morphology import medial_axis
import os
from functools import partial
from scipy.stats import linregress
from astrophysix.simdm.results import Snapshot
import pymses
import pymses.utils.regions as reg
from pymses.analysis import (
Camera,
FractionOperator,
MaxLevelOperator,
ScalarOperator,
raytracing,
slicing,
splatting,
)
from pymses.filters import CellsToPoints, RegionFilter
from fil_finder import FilFinder2D
import pspec_new import pspec_new
from baseprocessor import *
from units import U
from baseprocessor import HDF5Container, Rule, norm_getter, simple_getter, vect_getter
# Getters # Getters
@@ -167,7 +191,7 @@ class PostProcessor(HDF5Container):
path_out=None, path_out=None,
pp_params=None, pp_params=None,
tag=None, tag=None,
unit_time=cst.year, unit_time=U.year,
): ):
""" """
Creates the basic structures needed for the outputs Creates the basic structures needed for the outputs
@@ -246,21 +270,25 @@ class PostProcessor(HDF5Container):
if self.pp_params.pymses.filter: if self.pp_params.pymses.filter:
center = (self.max_coords + self.min_coords) / 2.0 center = (self.max_coords + self.min_coords) / 2.0
im_extent = [ im_extent = np.array(
self.min_coords[0], [
self.max_coords[0], self.min_coords[0],
self.min_coords[1], self.max_coords[0],
self.max_coords[1], self.min_coords[1],
] self.max_coords[1],
]
)
distance = (self.max_coords[2] - self.min_coords[2]) / 2.0 distance = (self.max_coords[2] - self.min_coords[2]) / 2.0
else: else:
center = self.pp_params.pymses.center center = self.pp_params.pymses.center
im_extent = [ im_extent = np.array(
(-self._radius + center[0]), [
(self._radius + center[0]), (-self._radius + center[0]),
(-self._radius + center[1]), (self._radius + center[0]),
(self._radius + center[1]), (-self._radius + center[1]),
] (self._radius + center[1]),
]
)
distance = self._radius distance = self._radius
# Get time # Get time
@@ -274,7 +302,7 @@ class PostProcessor(HDF5Container):
self.save.root._v_attrs.unit_length = self.info["unit_length"] self.save.root._v_attrs.unit_length = self.info["unit_length"]
self.save.root._v_attrs.time = time self.save.root._v_attrs.time = time
if not "/maps" in self.save: if "/maps" not in self.save:
self.save.create_group("/", "maps", "2D maps") self.save.create_group("/", "maps", "2D maps")
self.save.root.maps._v_attrs.center = center self.save.root.maps._v_attrs.center = center
self.save.root.maps._v_attrs.radius = self._radius self.save.root.maps._v_attrs.radius = self._radius
@@ -283,7 +311,6 @@ class PostProcessor(HDF5Container):
# Initialize cameras # Initialize cameras
self._cam = {} self._cam = {}
for ax_los in self._ax_nb: # los = line of sight for ax_los in self._ax_nb: # los = line of sight
ax_h = self._axes_h[ax_los]
ax_v = self._axes_v[ax_los] ax_v = self._axes_v[ax_los]
self._cam[ax_los] = Camera( self._cam[ax_los] = Camera(
@@ -306,13 +333,13 @@ class PostProcessor(HDF5Container):
else: else:
self.fil = None self.fil = None
time_in_right_unit = self.info["time"] * self.info["unit_time"].express(
unit_time
)
self.snapshot = Snapshot( self.snapshot = Snapshot(
name=str(self.num), name=str(self.num),
description="", description="",
time=( time=(time_in_right_unit, unit_time),
self.info["time"] * self.info["unit_time"].express(unit_time),
unit_time,
),
directory_path=self.path, directory_path=self.path,
data_reference="OUTPUT_{}".format(self.num), data_reference="OUTPUT_{}".format(self.num),
) )
@@ -323,7 +350,7 @@ class PostProcessor(HDF5Container):
""" """
Load all cells from the source file in the memory. Load all cells from the source file in the memory.
Cells will be accessible trough self.cells Cells will be accessible trough self.cells
(/!\ Long and memory heavy) (Long and memory heavy)
""" """
if not self.cells_loaded: if not self.cells_loaded:
if os.path.exists(self.cells_filename): if os.path.exists(self.cells_filename):
@@ -424,7 +451,7 @@ class PostProcessor(HDF5Container):
""" Azimuthal velocity """ """ Azimuthal velocity """
return self.oct_getter_vect_phi(dset, "vel") return self.oct_getter_vect_phi(dset, "vel")
def _slice(self, getter, ax_los="z", z=0.0, unit=cst.none): def _slice(self, getter, ax_los="z", z=0.0, unit=U.none):
""" """
Slice process function. Slice process function.
Return a slice of the source box. Return a slice of the source box.
@@ -440,7 +467,7 @@ class PostProcessor(HDF5Container):
z : float z : float
Coordinate of the slice on the ax_los axis Coordinate of the slice on the ax_los axis
unit : cst.Unit unit : U.Unit
Unit of the resulting dataset Unit of the resulting dataset
Returns Returns
@@ -452,9 +479,7 @@ class PostProcessor(HDF5Container):
datamap = slicing.SliceMap(self._amr, self._cam[ax_los], op, z=z) datamap = slicing.SliceMap(self._amr, self._cam[ax_los], op, z=z)
return datamap.map.T return datamap.map.T
def _ax_avg( def _ax_avg(self, getter, ax_los, unit=U.none, mass_weighted=True, surf_qty=False):
self, getter, ax_los, unit=cst.none, mass_weighted=True, surf_qty=False
):
""" """
Map of the average of a quantity (given by getter) along an axis (ax_los) Map of the average of a quantity (given by getter) along an axis (ax_los)
Returns 2D array if getter returns a scalar quantity Returns 2D array if getter returns a scalar quantity
@@ -505,7 +530,7 @@ class PostProcessor(HDF5Container):
self.load_cells() self.load_cells()
return np.sort(np.unique(self.cells["pos"][:, axis])) return np.sort(np.unique(self.cells["pos"][:, axis]))
def _plane_avg_uniform(self, getter, axis, unit=cst.none, mass_weighted=False): def _plane_avg_uniform(self, getter, axis, unit=U.none, mass_weighted=False):
""" """
Profile of the average of a quantity (given by getter) perpendicular to an axis Profile of the average of a quantity (given by getter) perpendicular to an axis
WARNING : This version only works on an uniform grid, need of a box version for AMR WARNING : This version only works on an uniform grid, need of a box version for AMR
@@ -622,17 +647,17 @@ class PostProcessor(HDF5Container):
""" """
self.load_cells() self.load_cells()
mean_speed = self.save.get_node("/globals/mwa_speed").read() mean_speed = self.save.get_node("/globals/mwa_speed").read()
mean_speed = mean_speed * self.info["unit_velocity"].express(cst.km_s) mean_speed = mean_speed * self.info["unit_velocity"].express(U.km_s)
vel_fluct = (self.cells)["vel"] * self.info["unit_velocity"].express( vel_fluct = (self.cells)["vel"] * self.info["unit_velocity"].express(
cst.km_s U.km_s
) - mean_speed ) - mean_speed
B_norm = getter_B_int(self.cells) B_norm = getter_B_int(self.cells)
B_norm = B_norm * self.info["unit_mag"].express(cst.T) B_norm = B_norm * self.info["unit_mag"].express(U.T)
v_norm = np.sqrt( v_norm = np.sqrt(
np.sum((vel_fluct * 10 ** (3)) ** 2, axis=1) np.sum((vel_fluct * 10 ** (3)) ** 2, axis=1)
) # v_norm [m/s] et vel_fluct [km/s] ) # v_norm [m/s] et vel_fluct [km/s]
rho = getter_rho(self.cells) rho = getter_rho(self.cells)
rho_kg_m3 = rho * self.info["unit_density"].express(cst.kg_m3) rho_kg_m3 = rho * self.info["unit_density"].express(U.kg_m3)
eb = 0.5 * (B_norm) ** 2 / (4 * np.pi * 10 ** (-7)) # mettre le bon mu eb = 0.5 * (B_norm) ** 2 / (4 * np.pi * 10 ** (-7)) # mettre le bon mu
ek = 0.5 * v_norm ** 2 * rho_kg_m3 ek = 0.5 * v_norm ** 2 * rho_kg_m3
rapport = ek / eb rapport = ek / eb
@@ -1024,8 +1049,7 @@ class PostProcessor(HDF5Container):
return alpha return alpha
alpha_f = ( alpha_f = (
self._ax_avg(getter_alpha_num, "z", unit=cst.none, mass_weighted=True) self._ax_avg(getter_alpha_num, "z", unit=U.none, mass_weighted=True) / T_avg
/ T_avg
) )
# alpha # alpha
@@ -1047,7 +1071,7 @@ class PostProcessor(HDF5Container):
gphi = self.oct_getter_vect_phi(dset, "g") gphi = self.oct_getter_vect_phi(dset, "g")
return gr * gphi / (4 * np.pi * self.G) return gr * gphi / (4 * np.pi * self.G)
alpha_g = self._ax_avg(getter_alpha_grav, "z", unit=cst.none, surf_qty=True) / ( alpha_g = self._ax_avg(getter_alpha_grav, "z", unit=U.none, surf_qty=True) / (
coldens * T_avg coldens * T_avg
) )
@@ -1166,10 +1190,6 @@ class PostProcessor(HDF5Container):
GM = self.G * self.pp_params.disk.mass_star # Mass parameter GM = self.G * self.pp_params.disk.mass_star # Mass parameter
# Get mask for filaments
fil = self.fil
mask_fil = np.asarray(fil.mask.copy(), dtype=bool)
# Find center of filaments # Find center of filaments
i_center, j_center = self._filaments_center() i_center, j_center = self._filaments_center()
@@ -1283,7 +1303,7 @@ class PostProcessor(HDF5Container):
self._alpha_disk, self._alpha_disk,
"Map of the Shakura&Sunaev alpha parameter for disks", "Map of the Shakura&Sunaev alpha parameter for disks",
"/maps", "/maps",
unit=cst.none, unit=U.none,
dependencies=[ dependencies=[
"avg_map_rho_avg", "avg_map_rho_avg",
"avg_map_T_mwavg", "avg_map_T_mwavg",
@@ -1297,7 +1317,7 @@ class PostProcessor(HDF5Container):
"Map of the graviational contrib to\ "Map of the graviational contrib to\
Shakura&Sunaev alpha parameter for disks", Shakura&Sunaev alpha parameter for disks",
"/maps", "/maps",
unit=cst.none, unit=U.none,
dependencies=["avg_map_coldens", "avg_map_T_mwavg"], dependencies=["avg_map_coldens", "avg_map_T_mwavg"],
), ),
"rho": Rule( "rho": Rule(
@@ -1373,9 +1393,9 @@ class PostProcessor(HDF5Container):
self._sinks, self._sinks,
group="/datasets", group="/datasets",
unit={ unit={
"Id": cst.none, "Id": U.none,
"M": cst.Msun, "M": U.Msun,
"dmf": cst.Msun, "dmf": U.Msun,
"x": "", "x": "",
"y": "", "y": "",
"z": "", "z": "",
@@ -1388,9 +1408,9 @@ class PostProcessor(HDF5Container):
"lz": "|l|", "lz": "|l|",
"acc_rate": "[Msol/y]", "acc_rate": "[Msol/y]",
"acc_lum": "[Lsol]", "acc_lum": "[Lsol]",
"age": cst.year, "age": U.year,
"int_lum": "[Lsol]", "int_lum": "[Lsol]",
"Teff": cst.K, "Teff": U.K,
}, },
), ),
"pspec": Rule(self, self._pspec, "Power spectrum", "/hdf5"), "pspec": Rule(self, self._pspec, "Power spectrum", "/hdf5"),
@@ -1473,7 +1493,7 @@ class PostProcessor(HDF5Container):
"Global cos fluctuation-PDF", "Global cos fluctuation-PDF",
"/hist", "/hist",
dependencies=["mwa_speed"], dependencies=["mwa_speed"],
unit=cst.none, unit=U.none,
), ),
"Brho": Rule( "Brho": Rule(
self, self,
@@ -1488,7 +1508,7 @@ class PostProcessor(HDF5Container):
"Average of Ek/Eb as a function of rho", "Average of Ek/Eb as a function of rho",
"/datasets", "/datasets",
dependencies=["mwa_speed"], dependencies=["mwa_speed"],
unit={"rho": self.info["unit_density"], "Ek_Eb_rho": cst.none}, unit={"rho": self.info["unit_density"], "Ek_Eb_rho": U.none},
), ),
# Profiles # Profiles
"axis": Rule( "axis": Rule(
-1
View File
@@ -4,7 +4,6 @@ import os
import re import re
import munch import munch
import numpy as np
import yaml import yaml
_dir_path = os.path.dirname(os.path.realpath(__file__)) _dir_path = os.path.dirname(os.path.realpath(__file__))
+4 -2
View File
@@ -5,10 +5,12 @@
import glob import glob
import os import os
from functools import partial from functools import partial
import numpy as np
import yaml
import f90nml import f90nml
from pp_params import * from pp_params import default_params
class NamelistRecursive: class NamelistRecursive:
@@ -247,7 +249,7 @@ class RunSelector:
runs = self.nml_select(runs, filter_nml) runs = self.nml_select(runs, filter_nml)
# Sort by the value in the namelist of sort_run_by # Sort by the value in the namelist of sort_run_by
if not sort_run_by is None: if sort_run_by is not None:
if type(sort_run_by) == str: if type(sort_run_by) == str:
sort_run_by = [sort_run_by] sort_run_by = [sort_run_by]
for nml_key in reversed(sort_run_by): for nml_key in reversed(sort_run_by):
+15 -17
View File
@@ -1,13 +1,13 @@
# coding: utf-8 # coding: utf-8
import astrophysix.units as cst import astrophysix.units as U
create_unit = cst.Unit.create_unit create_unit = U.Unit.create_unit
def parse_exp_unit(u): def parse_exp_unit(u):
splitted = u.split("^") splitted = u.split("^")
name_u = cst.Unit.from_name(splitted[0]).latex.replace("text", "math") name_u = U.Unit.from_name(splitted[0]).latex.replace("text", "math")
exp = "" exp = ""
if len(splitted) > 1: if len(splitted) > 1:
exp = "^{" + str(splitted[1]) + "}" exp = "^{" + str(splitted[1]) + "}"
@@ -39,7 +39,7 @@ def unit_str(unit, base=None, prefix="", format=" [{unit}]"):
prefix : str to put befor the unit prefix : str to put befor the unit
format : str with the {unit} key, to put external decoration format : str with the {unit} key, to put external decoration
""" """
if unit == cst.none: if unit == U.none:
return "" return ""
elif not base is None: elif not base is None:
coeff = unit.express(base) coeff = unit.express(base)
@@ -64,33 +64,31 @@ def unit_str(unit, base=None, prefix="", format=" [{unit}]"):
return format.format(unit=u_str) return format.format(unit=u_str)
cst.coldens = create_unit( U.coldens = create_unit(
"Msun.pc^-2", base_unit=cst.Msun / cst.pc ** 2, descr="Column density" "Msun.pc^-2", base_unit=U.Msun / U.pc ** 2, descr="Column density"
) )
cst.km_s = create_unit("km.s^-1", base_unit=cst.km / cst.s, descr="Speed") U.km_s = create_unit("km.s^-1", base_unit=U.km / U.s, descr="Speed")
cst.Msun_pc3 = create_unit( U.Msun_pc3 = create_unit("Msun.pc^-3", base_unit=U.Msun / U.pc ** 3, descr="Density")
"Msun.pc^-3", base_unit=cst.Msun / cst.pc ** 3, descr="Density"
)
cst.kg_m3 = create_unit("kg.m^-3", base_unit=cst.kg / cst.m ** 3, descr="Density") U.kg_m3 = create_unit("kg.m^-3", base_unit=U.kg / U.m ** 3, descr="Density")
cst.ssfr = create_unit( U.ssfr = create_unit(
"Msun.year^-1.pc^-2", "Msun.year^-1.pc^-2",
base_unit=cst.Msun / cst.year / cst.pc ** 2, base_unit=U.Msun / U.year / U.pc ** 2,
descr="Surfacic SFR", descr="Surfacic SFR",
) )
# latex='M$_{\odot}$.yr$^{-1}$.pc$^{-2}$') # latex='M$_{\odot}$.yr$^{-1}$.pc$^{-2}$')
cst.ssfrG = create_unit( U.ssfrG = create_unit(
"Msun.Gyr^-1.pc^-2", "Msun.Gyr^-1.pc^-2",
base_unit=1e-9 * cst.Msun / cst.year / cst.pc ** 2, base_unit=1e-9 * U.Msun / U.year / U.pc ** 2,
descr="Surfacic SFR", descr="Surfacic SFR",
latex="\mathrm{M}_{\odot}.\mathrm{Gyr}^{-1}.\mathrm{pc}^{-2}", latex="\mathrm{M}_{\odot}.\mathrm{Gyr}^{-1}.\mathrm{pc}^{-2}",
) )
cst.uG = create_unit( U.uG = create_unit(
"μG", base_unit=1e-10 * cst.T, descr="Micro Gauss", latex="\\mu\\mathrm{G}" "μG", base_unit=1e-10 * U.T, descr="Micro Gauss", latex="\\mu\\mathrm{G}"
) )