import os
import json
import numpy
from matplotlib import pyplot
from typing import Dict, Callable, Union, Sequence, List
from csep.core.catalogs import CSEPCatalog
from csep.core.forecasts import GriddedForecast
from csep.models import EvaluationResult
from floatcsep.model import Model
from floatcsep.utils import parse_csep_func, timewindow2str
from floatcsep.registry import PathTree
[docs]
class Evaluation:
"""
Class representing a Scoring Test, which wraps the evaluation function,
its arguments, parameters and hyper-parameters.
Args:
name (str): Name of the Test
func (str, ~typing.Callable): Test function/callable
func_kwargs (dict): Keyword arguments of the test function
ref_model (str): String of the reference model, if any
plot_func (str, ~typing.Callable): Test's plotting function
plot_args (list,dict): Positional arguments of the plotting function
plot_kwargs (list,dict): Keyword arguments of the plotting function
"""
_TYPES = {
'number_test': 'consistency',
'spatial_test': 'consistency',
'magnitude_test': 'consistency',
'likelihood_test': 'consistency',
'conditional_likelihood_test': 'consistency',
'negative_binomial_number_test': 'consistency',
'binary_spatial_test': 'consistency',
'binomial_spatial_test': 'consistency',
'brier_score': 'consistency',
'binary_conditional_likelihood_test': 'consistency',
'paired_t_test': 'comparative',
'paired_ttest_point_process': 'comparative',
'w_test': 'comparative',
'binary_paired_t_test': 'comparative',
'vector_poisson_t_w_test': 'batch',
'sequential_likelihood': 'sequential',
'sequential_information_gain': 'sequential_comparative'
}
[docs]
def __init__(self, name: str, func: Union[str, Callable],
func_kwargs: Dict = None,
ref_model: (str, Model) = None,
plot_func: Callable = None,
plot_args: Sequence = None,
plot_kwargs: Dict = None,
markdown: str = '') -> None:
self.name = name
self.func = parse_csep_func(func)
self.func_kwargs = func_kwargs or {} # todo set default args from exp?
self.ref_model = ref_model
self.plot_func = None
self.plot_args = None
self.plot_kwargs = None
self.parse_plots(plot_func, plot_args, plot_kwargs)
self.markdown = markdown
self.type = Evaluation._TYPES.get(self.func.__name__)
@property
def type(self):
"""
Returns the type of the test, mapped from the class attribute
Evaluation._TYPES
"""
return self._type
@type.setter
def type(self, type_list: Union[str, Sequence[str]]):
if isinstance(type_list, Sequence):
if ('Comparative' in type_list) and (self.ref_model is None):
raise TypeError('A comparative-type test should have a'
' reference model assigned')
self._type = type_list
def parse_plots(self, plot_func, plot_args, plot_kwargs):
if isinstance(plot_func, str):
self.plot_func = [parse_csep_func(plot_func)]
self.plot_args = [plot_args] if plot_args else [{}]
self.plot_kwargs = [plot_kwargs] if plot_kwargs else [{}]
elif isinstance(plot_func, (list, dict)):
if isinstance(plot_func, dict):
plot_func = [{i: j} for i, j in plot_func.items()]
if plot_args is not None or plot_kwargs is not None:
raise ValueError('If multiple plot functions are passed,'
'each func should be a dictionary with '
'plot_args and plot_kwargs passed as '
'dictionaries beneath each func.')
func_names = [list(i.keys())[0] for i in plot_func]
self.plot_func = [parse_csep_func(func) for func in func_names]
self.plot_args = [i[j].get('plot_args', {})
for i, j in zip(plot_func, func_names)]
self.plot_kwargs = [i[j].get('plot_kwargs', {})
for i, j in zip(plot_func, func_names)]
[docs]
def prepare_args(self,
timewindow: Union[str, list],
catpath: Union[str, list],
model: Union[Model, Sequence[Model]],
ref_model: Union[Model, Sequence] = None,
region = None) -> tuple:
"""
Prepares the positional argument for the Evaluation function.
Args:
timewindow (str, list): Time window string (or list of str)
formatted from :meth:`floatcsep.utils.timewindow2str`
catpath (str,list): Path(s) pointing to the filtered catalog(s)
model (:class:`floatcsep:model.Model`): Model to be evaluated
ref_model (:class:`floatcsep:model.Model`, list): Reference model (or
models) reference for the evaluation.
Returns:
A tuple of the positional arguments required by the evaluation
function :meth:`Evaluation.func`.
"""
# Subtasks
# ========
# Get forecast from model
# Read Catalog
# Share forecast region with catalog
# Check if ref_model is None, Model or List[Model]
# Prepare argument tuple
forecast = model.get_forecast(timewindow, region)
catalog = self.get_catalog(catpath, forecast)
if isinstance(ref_model, Model):
# Args: (Fc, RFc, Cat)
ref_forecast = ref_model.get_forecast(timewindow, region)
test_args = (forecast, ref_forecast, catalog)
elif isinstance(ref_model, list):
# Args: (Fc, [RFc], Cat)
ref_forecasts = [i.get_forecast(timewindow, region)
for i in ref_model]
test_args = (forecast, ref_forecasts, catalog)
else:
# Args: (Fc, Cat)
test_args = (forecast, catalog)
return test_args
[docs]
@staticmethod
def get_catalog(
catalog_path: Union[str, Sequence[str]],
forecast: Union[GriddedForecast, Sequence[GriddedForecast]]
) -> Union[CSEPCatalog, List[CSEPCatalog]]:
"""
Reads the catalog(s) from the given path(s). References the catalog
region to the forecast region.
Args:
catalog_path (str, list(str)): Path to the existing catalog
forecast (:class:`~csep.core.forecasts.GriddedForecast`): Forecast
object, onto which the catalog will be confronted for testing.
Returns:
"""
if isinstance(catalog_path, str):
eval_cat = CSEPCatalog.load_json(catalog_path)
eval_cat.region = getattr(forecast, 'region')
else:
eval_cat = [CSEPCatalog.load_json(i) for i in catalog_path]
if (len(forecast) != len(eval_cat)) or (not isinstance(forecast,
Sequence)):
raise IndexError('Amount of passed catalogs and forecats must '
'be the same')
for cat, fc in zip(eval_cat, forecast):
cat.region = getattr(fc, 'region', None)
return eval_cat
[docs]
def compute(self,
timewindow: Union[str, list],
catalog: str,
model: Model,
path: str,
ref_model: Union[Model, Sequence[Model]] = None,
region=None) -> None:
"""
Runs the test, structuring the arguments according to the
test-typology/function-signature
Args:
timewindow (list[~datetime.datetime, ~datetime.datetime]): Pair of
datetime objects representing the testing time span
catalog (str): Path to the filtered catalog
model (Model, list[Model]): Model(s) to be evaluated
ref_model: Model to be used as reference
path: Path to store the Evaluation result
region: region to filter a catalog forecast.
Returns:
"""
test_args = self.prepare_args(timewindow,
catpath=catalog,
model=model,
ref_model=ref_model,
region=region)
evaluation_result = self.func(*test_args, **self.func_kwargs)
self.write_result(evaluation_result, path)
[docs]
@staticmethod
def write_result(result: EvaluationResult,
path: str) -> None:
"""
Dumps a test result into a json file.
"""
class NumpyEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, numpy.integer):
return int(obj)
if isinstance(obj, numpy.floating):
return float(obj)
if isinstance(obj, numpy.ndarray):
return obj.tolist()
return json.JSONEncoder.default(self, obj)
with open(path, 'w') as _file:
json.dump(result.to_dict(), _file, indent=4, cls=NumpyEncoder)
def read_results(self, window: str, models: List[Model],
tree: PathTree) -> List:
"""
Reads an Evaluation result for a given time window and returns a list
of the results for all tested models.
"""
test_results = []
if not isinstance(window, str):
wstr_ = timewindow2str(window)
else:
wstr_ = window
for i in models:
eval_path = tree(wstr_, 'evaluations', self, i.name)
with open(eval_path, 'r') as file_:
model_eval = EvaluationResult.from_dict(json.load(file_))
test_results.append(model_eval)
return test_results
def plot_results(self,
timewindow: Union[str, List],
models: List[Model],
tree: PathTree,
dpi: int = 300,
show: bool = False) -> None:
"""
Plots all evaluation results
Args:
dpi: Figure resolution with which to save
show: show in runtime
"""
if isinstance(timewindow, str):
timewindow = [timewindow]
for func, fargs, fkwargs in zip(self.plot_func, self.plot_args,
self.plot_kwargs):
if self.type in ['consistency', 'comparative']:
try:
for time_str in timewindow:
fig_path = tree(time_str, 'figures', self.name)
results = self.read_results(time_str, models, tree)
ax = func(results, plot_args=fargs, **fkwargs)
if 'code' in fargs:
exec(fargs['code'])
pyplot.savefig(fig_path, dpi=dpi)
if show:
pyplot.show()
except AttributeError as msg:
if self.type in ['consistency', 'comparative']:
for time_str in timewindow:
results = self.read_results(time_str, models, tree)
for result, model in zip(results, models):
fig_name = f'{self.name}_{model.name}'
tree.paths[time_str]['figures'][fig_name] =\
os.path.join(time_str, 'figures', fig_name)
fig_path = tree(time_str, 'figures', fig_name)
ax = func(result, plot_args=fargs, **fkwargs,
show=False)
if 'code' in fargs:
exec(fargs['code'])
pyplot.savefig(fig_path, dpi=dpi)
if show:
pyplot.show()
elif self.type in ['sequential', 'sequential_comparative', 'batch']:
fig_path = tree(timewindow[-1], 'figures', self.name)
results = self.read_results(timewindow[-1], models, tree)
ax = func(results, plot_args=fargs, **fkwargs)
if 'code' in fargs:
exec(fargs['code'])
pyplot.savefig(fig_path, dpi=dpi)
if show:
pyplot.show()
def as_dict(self) -> dict:
"""
Represents an Evaluation instance as a dictionary, which can be
serialized and then parsed
"""
out = {}
included = ['model', 'ref_model', 'func_kwargs']
for k, v in self.__dict__.items():
if k in included and v:
out[k] = v
func_str = f'{self.func.__module__}.{self.func.__name__}'
plot_func_str = []
for i, j, k in zip(self.plot_func, self.plot_args, self.plot_kwargs):
pfunc = {f'{i.__module__}.{i.__name__}': {'plot_args': j,
'plot_kwargs': k}}
plot_func_str.append(pfunc)
return {self.name: {**out,
'func': func_str,
'plot_func': plot_func_str}}
def __str__(self):
return (
f"name: {self.name}\n"
f"function: {self.func.__name__}\n"
f"reference model: {self.ref_model}\n"
f"kwargs: {self.func_kwargs}\n"
)
[docs]
@classmethod
def from_dict(cls, record):
"""
Parses a dictionary and re-instantiate an Evaluation object
"""
if len(record) != 1:
raise IndexError('A single test has not been passed')
name = next(iter(record))
return cls(name=name, **record[name])