import json
import logging
import os
import shutil
from abc import ABC, abstractmethod
from datetime import datetime
from pathlib import Path
from typing import List, Callable, Union, Sequence
import yaml
from csep.core.forecasts import GriddedForecast, CatalogForecast
from floatcsep.infrastructure.environments import EnvironmentFactory
from floatcsep.infrastructure.registries import ModelRegistry
from floatcsep.infrastructure.repositories import ForecastRepository
from floatcsep.utils.accessors import from_zenodo, from_git
from floatcsep.utils.helpers import timewindow2str, str2timewindow, parse_nested_dicts
# Module-wide logger: every message from this module goes to the named "floatLogger" logger.
log = logging.getLogger(name="floatLogger")
class Model(ABC):
    """
    The Model class represents a forecast generating system. It can represent a source code, a
    collection or a single forecast, etc. A Model can be instantiated from either the filesystem
    or host repositories.

    Args:
        name (str): Name of the model
        model_path (str): Relative path of the model (file or code) to the work directory
            (taken by the concrete subclasses)
        zenodo_id (int): Zenodo ID or record of the Model
        giturl (str): Link to a git repository
        repo_hash (str): Specific commit/branch/tag hash.
        authors (list[str]): Authors' names metadata
        doi: Digital Object Identifier metadata
    """

    def __init__(
        self,
        name: str,
        zenodo_id: int = None,
        giturl: str = None,
        repo_hash: str = None,
        authors: List[str] = None,
        doi: str = None,
        **kwargs,
    ):
        self.name = name
        self.zenodo_id = zenodo_id
        self.giturl = giturl
        self.repo_hash = repo_hash
        self.authors = authors
        self.doi = doi

        # Concrete subclasses replace this with a ModelRegistry.
        self.registry = None
        self.forecasts = {}
        self.force_stage = False

        # Extra configuration keys become plain attributes (e.g. ``workdir``, ``prefix``).
        # Applied last, so a configuration may also override ``force_stage``.
        self.__dict__.update(**kwargs)

    @abstractmethod
    def stage(self, time_windows=None) -> None:
        """Prepares the stage for a model run."""
        pass

    @abstractmethod
    def get_forecast(self, tstring: str, region=None):
        """Retrieves the forecast based on a time window."""
        pass

    @abstractmethod
    def create_forecast(self, tstring: str, **kwargs) -> None:
        """Creates a forecast based on the model's logic."""
        pass

    @abstractmethod
    def get_source(self):
        """Retrieves the model from a web repository"""
        pass

    def as_dict(self, excluded=("name", "repository", "workdir", "environment")):
        """
        Serialize the model's configuration.

        Attributes whose name starts with ``_``, whose value is falsy, or which are listed in
        ``excluded`` are skipped. The registry object is replaced by a ``path`` entry computed
        with ``registry.rel`` and rendered as a POSIX path.

        Args:
            excluded: Attribute names to leave out of the output.

        Returns:
            Dictionary ``{name: {attribute: value}}``.

        Note:
            NOTE(review): the path is emitted under ``path`` while the subclass constructors
            take ``model_path``; confirm how callers re-instantiate from this dict.
        """
        list_walk = [
            (i, j) for i, j in sorted(self.__dict__.items()) if not i.startswith("_") and j
        ]
        dict_walk = {i: j for i, j in list_walk if i not in excluded}
        dict_walk["path"] = self.registry.rel(dict_walk.pop("registry").path).as_posix()
        return {self.name: parse_nested_dicts(dict_walk)}

    @classmethod
    def from_dict(cls, record: dict, **kwargs):
        """
        Returns a Model instance from a dictionary containing the required attributes. Can be
        used to quickly instantiate from a .yml file.

        Args:
            record (dict): Contains the keywords from the ``__init__`` method.
            **kwargs: Extra ``__init__`` keywords, applied in both record layouts.

        Note:
            Must have either an explicit key `name`, or it must have
            exactly one key with the model's name, whose values are
            the remaining ``__init__`` keywords.

        Returns:
            A Model instance

        Raises:
            IndexError: if there is no ``name`` key and the record does not hold exactly
                one model.
        """
        if "name" in record:
            # Flat layout: {"name": ..., <init keywords>}. Extra kwargs were previously
            # dropped here; forward them exactly as in the nested layout.
            return cls(**record, **kwargs)
        if len(record) != 1:
            raise IndexError("A single model has not been passed")
        # Nested layout: {<model name>: {<init keywords>}}
        name = next(iter(record))
        return cls(name=name, **record[name], **kwargs)

    @classmethod
    def factory(cls, model_cfg: dict) -> "Model":
        """
        Factory method. Instantiate first on any explicit option provided in the model
        configuration.

        The concrete class is chosen, in order, by:
          1. an explicit ``class`` key: ``"ti"``/``"time_independent"`` or
             ``"td"``/``"time_dependent"``;
          2. :class:`TimeIndependentModel` if ``model_path`` (joined to ``workdir``) is an
             existing file;
          3. :class:`TimeDependentModel` if a ``func`` key is present;
          4. :class:`TimeIndependentModel` otherwise.

        Args:
            model_cfg: Either ``{model_name: {...}}`` or a flat dict with a ``name`` key,
                the same layouts accepted by :meth:`from_dict`. Must contain ``model_path``.

        Returns:
            A TimeIndependentModel or TimeDependentModel instance.

        Raises:
            IndexError: if a nested configuration does not hold exactly one model.
        """
        # Locate the model's options once, supporting both from_dict layouts.
        if "name" in model_cfg:
            cfg = model_cfg
        elif len(model_cfg) == 1:
            cfg = next(iter(model_cfg.values()))
        else:
            raise IndexError("A single model has not been passed")

        model_path = cfg["model_path"]
        workdir = cfg.get("workdir", "")
        model_class = cfg.get("class", "")

        if model_class in ("ti", "time_independent"):
            return TimeIndependentModel.from_dict(model_cfg)
        if model_class in ("td", "time_dependent"):
            return TimeDependentModel.from_dict(model_cfg)

        # No explicit class: infer it from the filesystem and the presence of a command.
        if os.path.isfile(os.path.join(workdir, model_path)):
            return TimeIndependentModel.from_dict(model_cfg)
        if "func" in cfg:
            return TimeDependentModel.from_dict(model_cfg)
        return TimeIndependentModel.from_dict(model_cfg)
class TimeIndependentModel(Model):
    """
    A Model whose forecast is invariant in time. A TimeIndependentModel is commonly represented
    by a single forecast as static data.
    """

    def __init__(self, name: str, model_path: str, forecast_unit=1, store_db=False, **kwargs):
        """
        Args:
            name (str): The name of the model.
            model_path (str): The path to the model data.
            forecast_unit (float): The unit of time for the forecast.
            store_db (bool): flag to indicate whether to store the model in a database.
                NOTE(review): not referenced in this constructor and, being a named
                parameter, not forwarded to the repository factory either; confirm intent.
            **kwargs: Extra options; stored as attributes by :class:`Model`, ``workdir``
                (default: current directory) is used to locate the model, and all of them
                are passed on to ``ForecastRepository.factory``.
        """
        super().__init__(name, **kwargs)
        self.forecast_unit = forecast_unit
        self.registry = ModelRegistry.factory(
            model_name=name, workdir=kwargs.get("workdir", os.getcwd()), path=model_path
        )
        self.repository = ForecastRepository.factory(
            self.registry, model_class=self.__class__.__name__, **kwargs
        )

    def stage(self, time_windows: Sequence[Sequence[datetime]] = None, **kwargs) -> None:
        """
        Acquire the forecast data if it is not in the file system. Sets the paths internally
        (or database pointers) to the forecast data.

        Args:
            time_windows (list): time_windows that the forecast data represents.
            **kwargs: Accepted and ignored.
        """
        if self.force_stage or not self.registry.file_exists("path"):
            os.makedirs(self.registry.dir, exist_ok=True)
            self.get_source()
        self.registry.build_tree(time_windows=time_windows, model_class=self.__class__.__name__)

    def get_source(self) -> None:
        """
        Fetch a single-file forecast into the model directory.

        Nothing is fetched when the expected file already exists and staging is not forced.
        Otherwise the file is cloned from ``giturl`` or downloaded from ``zenodo_id`` (git
        takes precedence). With neither source configured, the file must already exist.

        Raises:
            FileNotFoundError: if the expected forecast file is missing after fetching.
        """
        container = self.registry.dir
        expected_file = self.registry.path
        os.makedirs(container, exist_ok=True)

        # Already in place: nothing to do unless a re-stage is forced.
        if expected_file.is_file() and not self.force_stage:
            return

        if self.giturl:
            from_git(self.giturl, str(container), branch=self.repo_hash, force=self.force_stage)
        elif self.zenodo_id:
            from_zenodo(
                self.zenodo_id,
                str(container),
                force=self.force_stage,
                keys=[expected_file.name],
            )

        if not expected_file.is_file():
            raise FileNotFoundError(
                f"Expected TI model file at: {expected_file}\nFetched into: {container}"
            )

    def get_forecast(
        self, tstring: Union[str, list] = None, region=None
    ) -> Union[GriddedForecast, List[GriddedForecast]]:
        """
        Wrapper that just returns a forecast when requested.

        Delegates to the forecast repository, passing the model's name and ``forecast_unit``.
        """
        return self.repository.load_forecast(
            tstring, name=self.name, region=region, forecast_unit=self.forecast_unit
        )

    def create_forecast(self, tstring: str, **kwargs) -> None:
        """
        Creates a forecast from the model source and a given time window.

        Note:
            Dummy function for this class, although eventually could also be a source
            code (e.g., a Smoothed-Seismicity-Model built from the input-catalog).
        """
        return
[docs]
class TimeDependentModel(Model):
"""
Model that creates varying forecasts depending on a time window. Requires either a
collection of Forecasts or a function/source code that returns a Forecast.
"""
def __init__(
    self,
    name: str,
    model_path: str,
    func: Union[str, Callable] = None,
    func_kwargs: dict = None,
    args_file: str = "args.txt",
    input_cat: str = "catalog.csv",
    fmt: str = "csv",
    **kwargs,
) -> None:
    """
    Build a time-dependent model and its registry/repository.

    Args:
        name: The name of the model.
        model_path: The path to either the source code, or the folder containing static
            forecasts.
        func: A function/command that runs the model. When given, a computational
            environment is obtained from ``EnvironmentFactory``.
        func_kwargs: The keyword arguments to run the model. They are usually (over)written
            into the file `{model_path}/input/{args_file}`. Defaults to an empty dict.
        args_file: Name of the arguments file that will be used to create forecasts.
        input_cat: Name of the file that will be used as input catalog to create forecasts.
        fmt: Handed to ``ModelRegistry.factory``; presumably the file format of the
            forecasts (default ``"csv"``) - confirm against the registry.
        **kwargs: Additional keyword parameters, such as a ``prefix`` (str) for the
            resulting forecast file paths, ``workdir`` (default: current directory),
            ``build`` (environment type) and ``force_build`` (bool). All of them are stored
            as attributes and forwarded to ``ForecastRepository.factory``.
    """
    super().__init__(name, **kwargs)

    self.func = func
    self.func_kwargs = func_kwargs or {}

    workdir = kwargs.get("workdir", os.getcwd())
    self.registry = ModelRegistry.factory(
        model_name=name,
        workdir=workdir,
        path=model_path,
        fmt=fmt,
        args_file=args_file,
        input_cat=input_cat,
    )
    model_class = type(self).__name__
    self.repository = ForecastRepository.factory(self.registry, model_class=model_class, **kwargs)

    self.build = kwargs.get("build")
    self.force_build = kwargs.get("force_build", False)

    # Only models with a run command need a computational environment.
    if self.func:
        source_dir = self.registry.path.as_posix()
        self.environment = EnvironmentFactory.get_env(self.build, self.name, source_dir)
def stage(
    self, time_windows=None, run_mode="sequential", stage_dir="results", run_id="run"
) -> None:
    """
    Retrieve model artifacts and set up its interface with the experiment.

    1) Fetch the model source (filesystem, Zenodo or Git) when staging is forced, the model
       path does not exist, or it is an empty directory.
    2) If the model has a computational environment (i.e. it was built with ``func``),
       create it, forcing a rebuild when ``force_build`` is set.
    3) Build the registry tree: filepaths/keys for existing forecasts and those to be
       generated, as well as the input catalog and arguments file.

    Args:
        time_windows: Time windows to register forecasts for.
        run_mode: Forwarded to ``registry.build_tree``.
        stage_dir: Forwarded to ``registry.build_tree``.
        run_id: Forwarded to ``registry.build_tree``.
    """
    if self.force_stage:
        fetch = True
    else:
        src = self.registry.path
        fetch = not src.exists() or (src.is_dir() and not any(src.iterdir()))

    if fetch:
        os.makedirs(self.registry.dir, exist_ok=True)
        self.get_source(self.zenodo_id, self.giturl, branch=self.repo_hash)

    if hasattr(self, "environment"):
        self.environment.create_environment(force=self.force_build)

    # Forecast files are prefixed with ``prefix`` if configured, else with the model name.
    file_prefix = self.__dict__.get("prefix", self.name)
    self.registry.build_tree(
        time_windows=time_windows,
        model_class=type(self).__name__,
        prefix=file_prefix,
        run_mode=run_mode,
        stage_dir=stage_dir,
        run_id=run_id,
    )
def get_source(self, zenodo_id: int = None, giturl: str = None, **kwargs) -> None:
    """
    Search, download or clone the model source into the model directory, from git or
    zenodo, respectively.

    Explicit arguments take precedence: if neither ``zenodo_id`` nor ``giturl`` is given, the
    model's own ``zenodo_id``/``giturl`` attributes are used. A git URL wins over a Zenodo id.
    When staging is forced, an existing model directory is removed first.

    Args:
        zenodo_id (int): Zenodo identifier of the repository. Usually as
            `https://zenodo.org/record/{zenodo_id}`
        giturl (str): git remote repository URL from which to clone the source
        **kwargs: ``branch`` selects the commit/branch/tag to clone (defaults to the
            model's ``repo_hash``). Other keys are currently ignored.

    Raises:
        FileNotFoundError: if the model directory is not present after fetching.
    """
    target_dir = self.registry.path  # TD expects a directory here

    # Previously the arguments were ignored in favour of self.*; honour them now, falling
    # back to the configured sources when none is passed.
    if zenodo_id is None and giturl is None:
        zenodo_id, giturl = self.zenodo_id, self.giturl
    branch = kwargs.get("branch", self.repo_hash)

    # If forced, start clean so clone/download won't fail on a non-empty directory.
    if self.force_stage and target_dir.exists():
        shutil.rmtree(target_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    if giturl:
        from_git(giturl, target_dir.as_posix(), branch=branch, force=False)
    elif zenodo_id:
        from_zenodo(zenodo_id, target_dir.as_posix(), force=self.force_stage)

    if not target_dir.is_dir():
        raise FileNotFoundError(f"Expected TD model directory at: {target_dir}")
def get_forecast(
    self, tstring: Union[str, list] = None, region=None
) -> Union[GriddedForecast, CatalogForecast, List[GriddedForecast], List[CatalogForecast]]:
    """
    Load this model's forecast(s) from its forecast repository.

    Note:
        The argument ``tstring`` follows the experiment's time-window string convention,
        see :func:`~floatcsep.utils.helpers.timewindow2str` and
        :func:`~floatcsep.utils.helpers.str2timewindow`.

    Args:
        tstring: String representing the start and end of the forecast,
            formatted as 'YY1-MM1-DD1_YY2-MM2-DD2' (the annotation also allows a list).
        region: Passed unchanged to the repository loader; presumably restricts the
            forecast to a region, None meaning no restriction - confirm against
            ``ForecastRepository.load_forecast``.

    Returns:
        Whatever ``self.repository.load_forecast`` returns for the requested window(s).
    """
    loader = self.repository.load_forecast
    return loader(tstring, name=self.name, region=region)
def create_forecast(self, tstring: str, **kwargs) -> None:
    """
    Creates a forecast from the model source and a given time window.

    If the registry already has a forecast for ``tstring`` it is skipped, unless
    ``force=True`` is passed. Otherwise the arguments file and extra inputs are prepared
    and ``func`` is run through the model's environment.

    Note:
        The argument ``tstring`` is formatted according to how the Experiment
        handles time_windows, specified in the functions
        :func:`~floatcsep.utils.helpers.timewindow2str` and
        :func:`~floatcsep.utils.helpers.str2timewindow`

    Args:
        tstring: String representing the start and end of the forecast,
            formatted as 'YY1-MM1-DD1_YY2-MM2-DD2'.
        **kwargs: ``force`` (bool) re-creates an existing forecast. The remaining keys are
            model arguments written to the args file by :meth:`prepare_args`; all keys
            (including ``force``) are also handed to :meth:`prepare_extra_input`.
    """
    start_date, end_date = str2timewindow(tstring)

    # Model src is a func or binary; nothing to do if the forecast already exists.
    if not kwargs.get("force") and self.registry.forecast_exists(tstring):
        log.info(f"Forecast for {tstring} of model {self.name} already exists")
        return

    # ``force`` controls this method only. prepare_args writes every keyword into the
    # model's args file, so keep the control flag out of it.
    model_args = {key: val for key, val in kwargs.items() if key != "force"}
    self.prepare_args(start_date, end_date, **model_args)
    self.prepare_extra_input(start_date, end_date, **kwargs)

    log.info(
        f"[Model] Running {self.name} using {self.environment.__class__.__name__}:"
        f" {timewindow2str([start_date, end_date])}"
    )
    input_dir = self.registry.get_input_dir(tstring)
    forecast_dir = self.registry.get_forecast_dir()
    run_label = f"{self.name}_{tstring}"
    self.environment.run_command(
        command=f"{self.func}",
        run_label=run_label,
        input_volume=input_dir,
        forecast_volume=forecast_dir,
    )
def prepare_args(self, start: datetime, end: datetime, **kwargs) -> None:
    """
    Write the arguments file of one forecast time window.

    When the model is a source code, the args file holds the required input arguments.
    At minimum, it consists of the start and end of the forecast time window, but it can
    also contain other arguments (e.g., minimum magnitude, number of simulations, cutoff
    learning magnitude, etc.)

    The registry's args template (``.txt``, ``.json``, ``.yml``/``.yaml``) is read if it
    exists. Then ``start_date``/``end_date`` are set, ``kwargs`` are merged in, and finally
    the model's ``func_kwargs`` are merged in. ``func_kwargs`` therefore take precedence in
    every format; previously the YAML branch applied them in the opposite order. The
    result is written to the args file registered for this time window.

    Args:
        start: start date of the forecast time window
        end: end date of the forecast time window
        **kwargs: additional model arguments (name/value pairs)

    Raises:
        ValueError: if the template has an unsupported suffix.
    """
    window_str = timewindow2str([start, end])
    dest_path = Path(self.registry.get_args_key(window_str))
    tpl_path = self.registry.get_args_template_path()
    suffix = tpl_path.suffix.lower()

    if suffix == ".txt":

        def load_kv(fp: Path) -> dict:
            """Parse ``key = value`` lines, skipping blank lines and ``#`` comments."""
            data = {}
            if fp.exists():
                with open(fp, "r") as f:
                    for raw in f:
                        line = raw.strip()
                        if not line or line.startswith("#"):
                            continue
                        if "=" in line:
                            k, v = line.split("=", 1)
                            data[k.strip()] = v.strip()
            return data

        def dump_kv(fp: Path, data: dict) -> None:
            """Write ``key = value`` lines: the dates first, remaining keys sorted."""
            dates = ("start_date", "end_date")
            ordered_keys = [k for k in dates if k in data]
            ordered_keys += sorted(k for k in data if k not in dates)
            with open(fp, "w") as f:
                for k in ordered_keys:
                    f.write(f"{k} = {data[k]}\n")

        data = load_kv(tpl_path)
        data["start_date"] = start.isoformat()
        data["end_date"] = end.isoformat()
        data.update(kwargs)
        data.update(self.func_kwargs or {})
        dump_kv(dest_path, data)

    elif suffix == ".json":
        base = {}
        if tpl_path.exists():
            with open(tpl_path, "r") as f:
                base = json.load(f) or {}
        base["start_date"] = start.isoformat()
        base["end_date"] = end.isoformat()
        base.update(kwargs)
        base.update(self.func_kwargs or {})
        with open(dest_path, "w") as f:
            json.dump(base, f, indent=2)

    elif suffix in (".yml", ".yaml"):
        if tpl_path.exists():
            with open(tpl_path, "r") as f:
                data = yaml.safe_load(f) or {}
        else:
            data = {}
        data["start_date"] = start.isoformat()
        data["end_date"] = end.isoformat()

        def nested_update(dest: dict, src: dict, max_depth: int = 3, _lvl: int = 1):
            """Recursively merge ``src`` into ``dest``; below ``max_depth`` values replace."""
            for key, val in (src or {}).items():
                if (
                    _lvl < max_depth
                    and key in dest
                    and isinstance(dest[key], dict)
                    and isinstance(val, dict)
                ):
                    nested_update(dest[key], val, max_depth, _lvl + 1)
                else:
                    dest[key] = val

        # Same precedence as the txt/json branches: call kwargs first, func_kwargs last.
        nested_update(data, kwargs)
        nested_update(data, self.func_kwargs or {})
        with open(dest_path, "w") as f:
            yaml.safe_dump(data, f, indent=2)

    else:
        raise ValueError(f"Unsupported args file format: {suffix}")