"""
Library functions and command-line access to the simulation driver.
"""
import os, pickle, pathlib
from typing import NamedTuple, Dict
from logging import getLogger
from sailfish.event import Recurrence, RecurringEvent, ParseRecurrenceError
from sailfish.setup_base import SetupBase, SetupError
from sailfish.solver_base import SolverBase
from sailfish.solvers import (
SolverInitializationError,
register_solver_extension,
make_solver,
)
logger = getLogger(__name__)
user_build_config = dict()
[docs]class ConfigurationError(Exception):
"""An invalid runtime configuration"""
[docs]class ExtensionError(Exception):
"""An invalid extension was specified"""
[docs]def keyed_event(item):
"""
Return a key, val pair where the value string describes a recurrence rule.
"""
key, val = item.split("=")
return key, Recurrence.from_str(val)
[docs]def keyed_value(item):
"""
Return a key, val pair from a "key=val" string.
The value string is python-eval'd so it must be a valid Python expression.
"""
try:
key, val = item.split("=")
return key, eval(val)
except NameError:
return key, val
except SyntaxError:
raise ConfigurationError(f"badly formed model parameter value {val} in {item}")
except ValueError:
raise ConfigurationError(f"badly formed model parameter {item}")
[docs]def first_not_none(*args):
for arg in args:
if arg is not None:
return arg
[docs]def update_dict_where_none(new_dict, old_dict, frozen=[]):
"""
Like `dict.update`, except `key=value` pairs in `old_dict` are only used
to add / overwrite values in `new_dict` if they are `None` or missing.
"""
for key in old_dict:
old_val = old_dict.get(key)
new_val = new_dict.get(key)
if type(new_val) is dict and type(old_val) is dict:
update_dict_where_none(new_val, old_val)
elif old_val is not None:
if new_val is None:
new_dict[key] = old_val
elif key in frozen and new_val != old_val:
raise ConfigurationError(f"{key} cannot be changed")
[docs]def update_where_none(new, old, frozen=[]):
"""
Same as `update_dict_where_none`, except operates on (immutable) named tuple
instances and returns a new named tuple.
"""
new_dict = new._asdict()
old_dict = old._asdict()
update_dict_where_none(new_dict, old_dict, frozen)
return type(new)(**new_dict)
# The functions below were written to allow state to be written in terms of
# builtin Python objects (no sailfish application classes). That would be good
# practice because then pickle files can be opened on systems that don't have
# sailfish installed, so these functions should possibly be restored at some
# point. However in practice it's more convenient for post-processing to have
# immediate access to the sailfish objects after unpickling. It's also tedious
# to ensure that all sailfish objects have been removed in `asdict`, and are
# properly restored in `fromdict`.
# def asdict(t):
# """
# Convert named tuple instances to dictionaries.
# This function operates recursively on the data members of a dictionary or
# named tuple. Each object that is a named tuple is mapped to its dictionary
# representation, with an additional `_type` key to indicate the named tuple
# subclass. This mapping is applied to the simulation state before pickling,
# so that `sailfish` module is not required to unpickle the checkpoint
# files.
# """
# if type(t) is dict:
# return {k: asdict(v) for k, v in t.items()}
# if isinstance(t, tuple):
# d = {k: asdict(v) for k, v in t._asdict().items()}
# d["_type"] = ".".join([type(t).__module__, type(t).__name__])
# return d
# return t
# def fromdict(d):
# """
# Convert from dictionaries to named tuples.
# This function performs the inverse of the `asdict` method, and is applied
# to pickled simulation states.
# """
# import sailfish
# if type(d) is dict:
# if "_type" in d:
# cls = eval(d["_type"])
# del d["_type"]
# return cls(**{k: fromdict(v) for k, v in d.items()})
# else:
# return {k: fromdict(v) for k, v in d.items()}
# else:
# return d
[docs]def write_checkpoint(number, outdir, state):
"""
Write the simulation state to a file, as a pickle.
"""
if type(number) is int:
filename = f"chkpt.{number:04d}.pk"
elif type(number) is str:
filename = f"chkpt.{number}.pk"
else:
raise ValueError("number arg must be int or str")
if outdir is not None:
pathlib.Path(outdir).mkdir(parents=True, exist_ok=True)
filename = os.path.join(outdir, filename)
state_checkpoint_dict = dict(
iteration=state.iteration,
time=state.solver.time,
timestep_dt=state.timestep_dt,
cfl_number=state.cfl_number,
solution=state.solver.solution,
primitive=state.solver.primitive,
timeseries=state.timeseries,
solver=state.setup.solver,
solver_options=state.solver.options,
event_states=state.event_states,
driver=state.driver,
model_parameters=state.setup.model_parameter_dict(),
setup_name=state.setup.dash_case_class_name(),
mesh=state.mesh,
**state.setup.checkpoint_diagnostics(state.solver.time),
)
with open(filename, "wb") as chkpt:
logger.info(f"write checkpoint {chkpt.name}")
pickle.dump(state_checkpoint_dict, chkpt)
[docs]def load_checkpoint(chkpt_file):
"""
Load the simulation state from a pickle file.
"""
try:
with open(chkpt_file, "rb") as file:
return pickle.load(file)
except FileNotFoundError:
raise ConfigurationError(f"could not open checkpoint file {chkpt_file}")
[docs]def newest_chkpt_in_directory(directory_name):
import re
expr = re.compile("chkpt\.([0-9]+)\.pk")
list_of_matches = list(
filter(None, (expr.search(f) for f in os.listdir(directory_name)))
)
list_of_matches.sort(key=lambda l: int(l.groups()[0]))
for match in reversed(list_of_matches):
try:
path = os.path.join(directory_name, match.group())
load_checkpoint(path) # exception if checkpoint is corrupted
return path
except:
logger.warning(f"skipping corrupt checkpoint file {path}")
raise ConfigurationError("the specified directory did not have a usable checkpoint")
[docs]def append_timeseries(state):
"""
Append to the driver state timeseries for post-processing.
"""
reductions = state.solver.reductions()
if reductions:
state.timeseries.append(reductions)
logger.info(f"record timeseries event {len(state.timeseries)}")
else:
logger.warning(
"timeseries event ignored because solver does not provide reductions"
)
[docs]class DriverArgs(NamedTuple):
"""
Contains data used by the driver.
"""
setup_name: str = None
chkpt_file: str = None
model_parameters: dict = None
solver_options: dict = None
cfl_number: float = None
end_time: float = None
execution_mode: str = None
fold: int = None
resolution: int = None
num_patches: int = None
events: Dict[str, Recurrence] = dict()
new_timestep_cadence: int = None
verbose_output: str = ""
[docs] def from_namespace(args):
"""
Construct an instance from an argparse-type namespace object.
"""
driver = DriverArgs(
**{k: w for k, w in vars(args).items() if k in DriverArgs._fields}
)
parts = args.command.split(":")
if args.restart_dir:
setup_name = None
chkpt_file = newest_chkpt_in_directory(parts[0])
elif parts[0].endswith(".pk"):
setup_name = None
chkpt_file = parts[0]
else:
setup_name = parts[0]
chkpt_file = None
try:
model_parameters = dict(keyed_value(a) for a in parts[1:])
except IndexError:
model_parameters = dict()
model_parameters.update(args.model_parameters)
return driver._replace(
setup_name=setup_name,
chkpt_file=chkpt_file,
model_parameters=model_parameters,
)
[docs]class DriverState(NamedTuple):
"""
Contains the stateful variables in use by the `simulate` function.
An instance of this class is yielded by `simulate` each time an event takes
place.
"""
iteration: int
driver: DriverArgs
mesh: object
timeseries: list
event_states: list
solver: SolverBase
setup: SetupBase
cfl_number: float
timestep_dt: float
[docs]def simulate(driver):
"""
Main generator for running simulations.
If invoked with a `DriverArgs` instance in `driver`, the other arguments
are ignored. Otherwise, the driver is created from the setup name, model
paramters, and keyword arguments.
This function is a generator: it yields its state at a sequence of
pause points, defined by the `events` dictionary.
"""
from time import perf_counter
from sailfish import __version__ as version
from sailfish.kernel.system import configure_build, log_system_info, measure_time
from sailfish.event import Recurrence
from sailfish import solvers
main_logger = getLogger("main_logger")
main_logger.info(f"\nsailfish {version}\n")
if driver.setup_name:
"""
Generate an initial driver state from command line arguments, model
parametrs, and a setup instance.
"""
logger.info(f"start new simulation with setup {driver.setup_name}")
setup = SetupBase.find_setup_class(driver.setup_name)(
**driver.model_parameters or dict()
)
driver = driver._replace(
resolution=driver.resolution or setup.default_resolution,
)
iteration = 0
time = setup.start_time
event_states = {name: RecurringEvent() for name in driver.events}
solution = None
timeseries = list()
dt = None
elif driver.chkpt_file:
"""
Load driver state from a checkpoint file. The setup model parameters
are updated with any items given on the command line after the setup
name. All command line arguments are also restorted from the previous
session, but are updated with the command line argument given for this
session, except for "frozen" arguments.
"""
logger.info(f"load checkpoint {driver.chkpt_file}")
chkpt = load_checkpoint(driver.chkpt_file)
setup_class = SetupBase.find_setup_class(chkpt["setup_name"])
driver = update_where_none(driver, chkpt["driver"], frozen=["resolution"])
update_dict_where_none(
driver.model_parameters,
chkpt["model_parameters"],
frozen=list(setup_class.immutable_parameter_keys()),
)
update_dict_where_none(
driver.solver_options,
chkpt["solver_options"],
)
setup = setup_class(**driver.model_parameters)
iteration = chkpt["iteration"]
time = chkpt["time"]
event_states = chkpt["event_states"]
solution = chkpt["solution"]
try:
dt = chkpt["timestep_dt"]
except KeyError:
# Forgive missing timestep_dt in the checkpoint, this key was
# added recently (JZ 4-25-22). Prior to this change, timestep_dt
# was not stored in the checkpoint file, and a restarted
# simulation could end up different from a continuous one, when:
# (1) new_timestep_cadence > 1, and (2) a new dt was not computed
# just before the checkpoint was written. The differences would be
# due to a slightly different timestep used, after it's recomputed
# following the restart, and they would be minor. Still, restarted
# runs are supposed to be bitwise identical to continuous ones.
# Older checkpoints will still work, but they will not have this
# guarantee.
logger.warning(
"timestep_dt not in checkpoint, will recompute it on first iteration"
)
dt = None
try:
timeseries = chkpt["timeseries"]
except KeyError:
logger.warning("older checkpoint version: no timeseries")
for event in driver.events:
if event not in event_states:
event_states[event] = RecurringEvent()
else:
raise ConfigurationError("driver args must specify setup_name or chkpt_file")
"""
This line ensures that if a checkpoint event is present, then it is
emitted last, ensuring that any modifications to the driver state (e.g.
time series sample) happening in response to other events triggered in the
same iteration, are reflected in the checkpoint file that is written.
Note: The Python 3.7+ specifications guarantee to that dictionary
iteration order reflects the insertion order. This behavior is also
present in the CPython implementation of Python 3.6
"""
if "checkpoint" in event_states:
event_states["checkpoint"] = event_states.pop("checkpoint")
"""
Initialize and log state in the system module. The build system influences
JIT-compiled module code. Currently the build parameters are inferred from
the platform (Linux or MacOS), but in the future these should also be
extensible by a system-specific rc-style configuration file.
"""
configure_build(**user_build_config, execution_mode=driver.execution_mode)
log_system_info(driver.execution_mode or "cpu")
mode = driver.execution_mode or "cpu"
fold = driver.fold or 10
mesh = setup.mesh(driver.resolution)
end_time = first_not_none(driver.end_time, setup.default_end_time, float("inf"))
reference_time = setup.reference_time_scale
new_timestep_cadence = driver.new_timestep_cadence or 1
dt = None
if "physics" in driver.verbose_output:
logger.info(f"physics struct (setup -> solver) {setup.physics}")
if (
"options" in driver.verbose_output
or "solver" in driver.verbose_output
or "solver-options" in driver.verbose_output
):
logger.info(f"options struct (cmdline -> solver) {driver.solver_options}")
solver = make_solver(
setup.solver,
setup.physics,
driver.solver_options,
setup=setup,
mesh=mesh,
time=time,
solution=solution,
num_patches=driver.num_patches or 1,
mode=mode,
)
if driver.cfl_number is not None and driver.cfl_number > solver.maximum_cfl:
raise ConfigurationError(
f"cfl number {driver.cfl_number} "
f"is greater than {solver.maximum_cfl}, "
f"max allowed by solver {setup.solver}"
)
cfl_number = driver.cfl_number or solver.recommended_cfl
for name, event in driver.events.items():
logger.info(f"recurrence for {name} event is {event}")
logger.info(f"run until t={end_time}")
logger.info(f"CFL number is {cfl_number}")
logger.info(f"simulation time / user time is {reference_time:0.4f}")
logger.info(f"recompute dt every {new_timestep_cadence} iterations")
setup.print_model_parameters(newlines=True, logger=main_logger)
def grab_state():
"""
Collect items from the driver and solver state, as well as run
details, sufficient for restarts and post processing.
"""
return DriverState(
iteration=iteration,
driver=driver,
mesh=mesh,
timeseries=timeseries,
event_states=event_states,
solver=solver,
setup=setup,
cfl_number=cfl_number,
timestep_dt=dt,
)
while True:
siml_time = solver.time
user_time = siml_time / reference_time
"""
Run the main simulation loop. Iterations are grouped according the
the fold parameter. Side effects including the iteration message are
performed between fold boundaries.
"""
for name in event_states:
event = driver.events[name]
state = event_states[name]
if event_states[name].is_due(user_time, event):
event_states[name] = state.next(user_time, event)
yield name, state.number, grab_state()
if end_time is not None and user_time >= end_time:
break
with measure_time(mode) as fold_time:
for _ in range(fold):
if dt is None or (iteration % new_timestep_cadence == 0):
dx = mesh.min_spacing(siml_time)
dt = dx / solver.maximum_wavespeed() * cfl_number
solver.advance(dt)
iteration += 1
Mzps = mesh.num_total_zones / fold_time() * 1e-6 * fold
main_logger.info(
f"[{iteration:04d}] t={user_time:0.3f} dt={dt:.3e} Mzps={Mzps:.3f}"
)
yield "end", None, grab_state()
[docs]def run(setup_name, quiet=True, **kwargs):
"""
Run a simulation with no side-effects, and return the final state.
This function is intended for use by scripts that run a simulation and
inspect the output in-memory, or otherwise handle archiving the final
result themselves. Event monitoring is not supported. If `quiet=True`
(default) then logging is suppressed.
"""
import sailfish.setups
if "events" in kwargs:
raise ValueError("events are not supported")
driver = DriverArgs(setup_name=setup_name, **kwargs)
if not quiet:
init_logging()
load_user_config()
return next(simulate(driver))[2]
[docs]def init_logging():
"""
Convenience method to enable logging to standard output.
This function is called from the `main` entry point (i.e. when sailfish is
used as a command line tool). However when sailfish is used as a library,
logging is not enabled by default (Python's `logging` module recommends
that libraries should not install any event handlers on the root logger).
This function enables a sensible logging configuration, so if the calling
application or script is not particular about how logging should take
place, but it doesn't want the driver to be silent, then invoking this
function will do it for you. Note this function is also invoked by the
`run` function if :code:`quiet=False` is passed to it.
"""
from sys import stdout
from logging import StreamHandler, Formatter, getLogger, INFO
class RunFormatter(Formatter):
def format(self, record):
name = record.name.replace("sailfish.", "")
if name == "main_logger":
return f"{record.msg}"
if record.levelno <= 20:
return f"[{name}] {record.msg}"
else:
return f"[{name}:{record.levelname.lower()}] {record.msg}"
handler = StreamHandler(stdout)
handler.setFormatter(RunFormatter())
root_logger = getLogger()
root_logger.addHandler(handler)
root_logger.setLevel(INFO)
[docs]def load_user_config():
"""
Initialize user extensions: setups and solvers outside the main codebase.
This function is called by the `main` entry point and the `run` API function
to load custom setups provided by the user. Extensions are defined in the
`extensions` section of the .sailfish file. The .sailfish file is loaded
from the current working directory.
"""
from configparser import ConfigParser, ParsingError
from importlib import import_module
try:
config = ConfigParser()
config.read(".sailfish")
try:
for setup_extension in config["extensions"]["setups"].split():
import_module(setup_extension)
except KeyError:
pass
try:
for solver_extension in config["extensions"]["solvers"].split():
register_solver_extension(solver_extension)
except KeyError:
pass
try:
for key, val in config["build"].items():
user_build_config[key] = val
except KeyError:
pass
except ModuleNotFoundError as e:
raise ExtensionError(e)
except ParsingError as e:
raise ConfigurationError(e)
[docs]def main():
"""
General-purpose command line interface.
"""
import argparse
import sailfish
import sailfish.setups
class MakeDict(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
setattr(namespace, self.dest, dict(values))
def add_dict_entry(key):
class AddDictEntry(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
getattr(namespace, self.dest)[key] = values
return AddDictEntry
parser = argparse.ArgumentParser(
prog="sailfish",
usage=argparse.SUPPRESS,
description="sailfish is a GPU-accelerated astrophysical gasdynamics code",
)
parser.add_argument(
"--version",
action="version",
version=f"%(prog)s {sailfish.__version__}",
)
parser.add_argument(
"command",
nargs="?",
help="setup name or restart file (if directory, then load newest checkpoint)",
)
parser.add_argument(
"--describe",
action="store_true",
help="print a description of the setup and exit",
)
parser.add_argument(
"--resolution",
"-n",
metavar="N",
type=int,
help="grid resolution",
)
parser.add_argument(
"--patches",
metavar="N",
type=int,
dest="num_patches",
help="number of patches for domain decomposition",
)
parser.add_argument(
"--cfl",
dest="cfl_number",
metavar="C",
type=float,
help="CFL parameter",
)
parser.add_argument(
"--fold",
"-f",
metavar="F",
type=int,
help="iterations between messages and side effects",
)
parser.add_argument(
"--new-timestep-cadence",
metavar="C",
type=int,
help="iterations between recomputing the timestep dt",
)
parser.add_argument(
"--events",
nargs="*",
metavar="E=V",
type=keyed_event,
action=MakeDict,
default=dict(),
help="a sequence of events and recurrence rules to be emitted",
)
parser.add_argument(
"--restart-dir",
action="store_true",
help="the command argument is a directory; restart from newest checkpoint therein",
)
parser.add_argument(
"--final-chkpt",
action="store_true",
help="write chkpt.final.pk on exit",
)
parser.add_argument(
"--checkpoint",
"-c",
metavar="C",
type=Recurrence.from_str,
action=add_dict_entry("checkpoint"),
dest="events",
help="checkpoint recurrence [<delta>|<log:mul>]",
)
parser.add_argument(
"--timeseries",
"-t",
metavar="T",
type=Recurrence.from_str,
action=add_dict_entry("timeseries"),
dest="events",
help="timeseries recurrence [<delta>|<log:mul>]",
)
parser.add_argument(
"--model",
nargs="*",
metavar="K=V",
type=keyed_value,
action=MakeDict,
default=dict(),
dest="model_parameters",
help="key-value pairs given as models parameters to the setup",
)
parser.add_argument(
"--solver",
nargs="*",
metavar="K=V",
type=keyed_value,
action=MakeDict,
default=dict(),
dest="solver_options",
help="key-value pairs passed as options to the solver",
)
parser.add_argument(
"--outdir",
"-o",
metavar="D",
type=str,
dest="output_directory",
help="directory where checkpoints are written",
)
parser.add_argument(
"--end-time",
"-e",
metavar="T",
type=float,
help="when to end the simulation",
)
parser.add_argument(
"--event-handlers-file",
metavar="F",
type=str,
help="path to a module defining a get_event_handlers function",
)
parser.add_argument(
"--verbose-output",
metavar="P",
type=str,
default="",
help="detailed print solver structs [physics,options]",
)
exec_group = parser.add_mutually_exclusive_group()
exec_group.add_argument(
"--mode",
dest="execution_mode",
choices=["cpu", "omp", "gpu"],
help="execution mode",
)
exec_group.add_argument(
"--use-omp",
"-p",
dest="execution_mode",
action="store_const",
const="omp",
help="multi-core with OpenMP",
)
exec_group.add_argument(
"--use-gpu",
"-g",
dest="execution_mode",
action="store_const",
const="gpu",
help="gpu acceleration",
)
try:
init_logging()
load_user_config()
args = parser.parse_args()
if args.describe and args.command is not None:
setup_name = args.command.split(":")[0]
SetupBase.find_setup_class(setup_name).describe_class()
elif args.command is None:
print("specify setup:")
for setup in SetupBase.__subclasses__():
print(f" {setup.dash_case_class_name()}")
else:
driver = DriverArgs.from_namespace(args)
outdir = (
args.output_directory
or (driver.chkpt_file and os.path.dirname(driver.chkpt_file))
or "."
)
if args.event_handlers_file is not None:
import importlib.util
spec = importlib.util.spec_from_file_location(
"events_handler_module", args.event_handlers_file
)
events_handler_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(events_handler_module)
events_dict = events_handler_module.get_event_handlers()
else:
events_dict = dict()
for name, number, state in simulate(driver):
if name == "timeseries":
append_timeseries(state)
elif name == "checkpoint":
write_checkpoint(number, outdir, state)
elif name == "end":
if args.final_chkpt:
write_checkpoint("final", outdir, state)
elif name in events_dict:
events_dict[name](number, outdir, state, logger)
else:
logger.warning(f"unrecognized event {name}")
except ConfigurationError as e:
print(f"bad configuration: {e}")
except ExtensionError as e:
print(f"bad extension: {e}")
except SetupError as e:
print(f"setup error: {e}")
except ParseRecurrenceError as e:
print(f"parse error: {e}")
except SolverInitializationError as e:
print(f"solver initialization error: {e}")
except OSError as e:
print(f"file system error: {e}")
except ModuleNotFoundError as e:
print(f"unsatisfied dependency: {e}")
except KeyboardInterrupt:
print("")