Source code for dcase_util.tfkeras.callbacks

#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import print_function, absolute_import
from six import iteritems
import numpy
import os
import collections
import logging
import datetime

from dcase_util.ui import FancyStringifier, FancyHTMLStringifier, FancyLogger, FancyPrinter, FancyHTMLPrinter
from dcase_util.utils import Timer, setup_logging, is_jupyter


class BaseCallback(object):
    """Base class for Callbacks
    """
    def __init__(self, epochs=None, manual_update=False, external_metric_labels=None, **kwargs):
        self.params = None
        self.model = None

        self.manual_update = manual_update

        self.epochs = epochs
        self.epoch = 0

        if external_metric_labels is None:
            external_metric_labels = collections.OrderedDict()

        self.external_metric_labels = external_metric_labels
        self.external_metric = collections.OrderedDict()

        self.keras_metrics = [
            'binary_accuracy',
            'categorical_accuracy',
            'sparse_categorical_accuracy',
            'top_k_categorical_accuracy'
        ]

    @property
    def logger(self):
        logger = logging.getLogger(__name__)
        if not logger.handlers:
            setup_logging()

        return logger

    def set_model(self, model):
        self.model = model

    def set_params(self, params):
        self.params = params

    def on_train_begin(self, logs=None):
        pass

    def on_train_end(self, logs=None):
        pass

    def on_epoch_begin(self, epoch, logs=None):
        pass

    def on_epoch_end(self, epoch, logs=None):
        pass

    def on_test_begin(self, logs=None):
        pass

    def on_test_end(self, logs=None):
        pass

    def on_predict_begin(self, logs=None):
        pass

    def on_predict_end(self, logs=None):
        pass

    def on_train_batch_begin(self, batch, logs=None):
        pass

    def on_train_batch_end(self, batch, logs=None):
        pass

    def on_test_batch_begin(self, batch, logs=None):
        pass

    def on_test_batch_end(self, batch, logs=None):
        pass

    def on_predict_batch_begin(self, batch, logs=None):
        pass

    def on_predict_batch_end(self, batch, logs=None):
        pass

    def update(self):
        pass

    def add_external_metric(self, metric_label):
        pass

    def set_external_metric_value(self, metric_label, metric_value):
        pass

    def get_operator(self, metric):
        metric = metric.lower()
        if metric.endswith('error_rate') or metric.endswith('er'):
            return numpy.less

        elif (metric.endswith('f_measure') or metric.endswith('fmeasure') or
              metric.endswith('fscore') or metric.endswith('f-score')):
            return numpy.greater

        elif metric.endswith('accuracy') or metric.endswith('acc'):
            return numpy.greater

        else:
            return numpy.less

    def _implements_train_batch_hooks(self):
        return False

    def _implements_test_batch_hooks(self):
        return False

    def _implements_predict_batch_hooks(self):
        return False
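
# --- Illustrative sketch (not part of the original module) -----------------
# BaseCallback.get_operator() infers the comparison direction from the metric
# name: error-rate style metrics are minimized (numpy.less), while f-score and
# accuracy style metrics are maximized (numpy.greater); unknown names fall back
# to minimization. A minimal demonstration using only the class defined above:
def _example_get_operator():
    cb = BaseCallback(epochs=10)

    assert cb.get_operator('val_error_rate') is numpy.less
    assert cb.get_operator('val_categorical_accuracy') is numpy.greater
    assert cb.get_operator('f_measure') is numpy.greater

    # Unknown metric names fall back to minimization
    assert cb.get_operator('val_loss') is numpy.less
# ----------------------------------------------------------------------------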
class ProgressLoggerCallback(BaseCallback):
    """Keras callback to show metrics in logging interface. Implements Keras Callback API.

    This callback is very similar to the standard ``ProgbarLogger`` Keras callback, however it adds
    support for the logging interface and for external metrics (metrics calculated outside the Keras
    training process).

    """
    def __init__(self,
                 manual_update=False, epochs=None, external_metric_labels=None,
                 metric=None, loss=None, manual_update_interval=1,
                 output_type='logging', show_timing=True,
                 **kwargs):
        """Constructor

        Parameters
        ----------
        epochs : int
            Total amount of epochs
            Default value None

        metric : str
            Metric name
            Default value None

        manual_update : bool
            Manually update callback, use this when injecting external metrics
            Default value False

        manual_update_interval : int
            Epoch interval for manual update, used to anticipate updates
            Default value 1

        output_type : str
            Output type, either 'logging', 'console', or 'notebook'
            Default value 'logging'

        show_timing : bool
            Show per epoch time and estimated time remaining
            Default value True

        external_metric_labels : dict or OrderedDict
            Dictionary with {'metric_label': 'metric_name'}
            Default value None

        """
        kwargs.update({
            'manual_update': manual_update,
            'epochs': epochs,
            'external_metric_labels': external_metric_labels,
        })

        super(ProgressLoggerCallback, self).__init__(**kwargs)

        if isinstance(metric, str):
            self.metric = metric

        elif callable(metric):
            self.metric = metric.__name__

        self.loss = loss

        self.manual_update_interval = manual_update_interval

        self.output_type = output_type
        self.show_timing = show_timing

        self.timer = Timer()

        self.ui = FancyStringifier()

        if self.output_type == 'logging':
            self.output_target = FancyLogger()

        elif self.output_type == 'console':
            self.output_target = FancyPrinter()

        elif self.output_type == 'notebook':
            self.output_target = FancyHTMLPrinter()
            self.ui = FancyHTMLStringifier()

        self.seen = 0
        self.log_values = []

        self.most_recent_values = collections.OrderedDict()
        self.most_recent_values['l_tra'] = None
        self.most_recent_values['l_val'] = None
        self.most_recent_values['m_tra'] = None
        self.most_recent_values['m_val'] = None

        self.data = {
            'l_tra': numpy.empty((self.epochs,)),
            'l_val': numpy.empty((self.epochs,)),
            'm_tra': numpy.empty((self.epochs,)),
            'm_val': numpy.empty((self.epochs,)),
        }
        self.data['l_tra'][:] = numpy.nan
        self.data['l_val'][:] = numpy.nan
        self.data['m_tra'][:] = numpy.nan
        self.data['m_val'][:] = numpy.nan

        for metric_label in self.external_metric_labels:
            self.data[metric_label] = numpy.empty((self.epochs,))
            self.data[metric_label][:] = numpy.nan

        self.header_shown = False
        self.last_update_epoch = 0
        self.target = None
        self.first_epoch = None
        self.total_time = 0

    def on_train_begin(self, logs=None):
        if self.epochs is None:
            self.epochs = self.params['epochs']

        if not self.header_shown:
            output = ''
            output += self.ui.line('Training') + '\n'

            if self.external_metric_labels:
                output += self.ui.row(
                    '', 'Loss', 'Metric', 'Ext. metrics', '',
                    widths=[10, 26, 26, len(self.external_metric_labels)*15, 17]
                ) + '\n'

                header2 = ['', self.loss, self.metric]
                header3 = ['Epoch', 'Train', 'Val', 'Train', 'Val']
                widths = [10, 13, 13, 13, 13]

                for metric_label, metric_name in iteritems(self.external_metric_labels):
                    header2.append('')
                    header3.append(metric_name)
                    widths.append(15)

                if self.show_timing:
                    header2.append('')
                    header3.append('Step time')
                    widths.append(20)

                    header2.append('')
                    header3.append('Remaining time')
                    widths.append(20)

                output += self.ui.row(*header2) + '\n'
                output += self.ui.row(*header3, widths=widths) + '\n'
                output += self.ui.row_sep()

            else:
                if self.show_timing:
                    output += self.ui.row('', 'Loss', 'Metric', '', '', widths=[10, 36, 36, 16, 15]) + '\n'
                    output += self.ui.row('', self.loss, self.metric, '') + '\n'
                    output += self.ui.row(
                        'Epoch', 'Train', 'Val', 'Train', 'Val', 'Step', 'Remaining',
                        widths=[10, 18, 18, 18, 18, 16, 15]
                    ) + '\n'

                else:
                    output += self.ui.row('', 'Loss', 'Metric', widths=[10, 36, 36]) + '\n'
                    output += self.ui.row('', self.loss, self.metric) + '\n'
                    output += self.ui.row(
                        'Epoch', 'Train', 'Val', 'Train', 'Val',
                        widths=[10, 18, 18, 18, 18]
                    ) + '\n'

                output += self.ui.row_sep()

            self.output_target.line(output)

            # Show header only once
            self.header_shown = True

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch = epoch + 1

        if 'steps' in self.params:
            self.target = self.params['steps']

        elif 'samples' in self.params:
            self.target = self.params['samples']

        self.seen = 0

        self.timer.start()

    def on_train_batch_begin(self, batch, logs=None):
        if self.target and self.seen < self.target:
            self.log_values = []

    def on_train_batch_end(self, batch, logs=None):
        logs = logs or {}
        batch_size = logs.get('size', 0)
        self.seen += batch_size

        if self.metric and self.metric in logs:
            self.log_values.append((self.metric, logs[self.metric]))

        metrics_list = list(logs.keys())
        if 'metrics' in self.params:
            metrics_list = self.params['metrics']

        for k in metrics_list:
            if k in logs:
                self.log_values.append((k, logs[k]))

    def on_epoch_end(self, epoch, logs=None):
        self.timer.stop()
        self.total_time += self.timer.elapsed()

        self.epoch = epoch

        if self.first_epoch is None:
            # Store first epoch number
            self.first_epoch = self.epoch

        logs = logs or {}

        # Reset values
        self.most_recent_values['l_tra'] = None
        self.most_recent_values['l_val'] = None
        self.most_recent_values['m_tra'] = None
        self.most_recent_values['m_val'] = None

        # Collect values
        metrics_list = list(logs.keys())
        if 'metrics' in self.params:
            metrics_list = self.params['metrics']

        for k in metrics_list:
            if k in logs:
                self.log_values.append((k, logs[k]))

                if k == 'loss':
                    self.data['l_tra'][self.epoch] = logs[k]
                    self.most_recent_values['l_tra'] = '{:4.3f}'.format(logs[k])

                elif k == 'val_loss':
                    self.data['l_val'][self.epoch] = logs[k]
                    self.most_recent_values['l_val'] = '{:4.3f}'.format(logs[k])

                elif self.metric and k.endswith(self.metric):
                    if k.startswith('val_'):
                        self.data['m_val'][self.epoch] = logs[k]
                        self.most_recent_values['m_val'] = '{:4.3f}'.format(logs[k])

                    else:
                        self.data['m_tra'][self.epoch] = logs[k]
                        self.most_recent_values['m_tra'] = '{:4.3f}'.format(logs[k])

        for metric_label in self.external_metric_labels:
            if metric_label in self.external_metric:
                metric_name = self.external_metric_labels[metric_label]
                value = self.external_metric[metric_label]

                if metric_name.endswith('f_measure') or metric_name.endswith('f_score'):
                    self.most_recent_values[metric_label] = '{:3.1f}'.format(value * 100)

                else:
                    self.most_recent_values[metric_label] = '{:4.3f}'.format(value)

        if (not self.manual_update or
           (self.epoch - self.last_update_epoch > 0 and (self.epoch + 1) % self.manual_update_interval)):
            # Update logged progress
            self.update_progress_log()

    def update(self):
        """Update
        """
        self.update_progress_log()
        self.last_update_epoch = self.epoch

    def update_progress_log(self):
        """Update progress to logging interface
        """
        if self.epoch - self.last_update_epoch:
            data = [
                self.epoch,
                self.data['l_tra'][self.epoch],
                self.data['l_val'][self.epoch] if 'l_val' in self.most_recent_values else '-',
                self.data['m_tra'][self.epoch],
                self.data['m_val'][self.epoch] if self.most_recent_values['m_val'] else '-'
            ]
            types = [
                'int',
                'float4',
                'float4',
                'float4',
                'float4'
            ]

            for metric_label in self.external_metric_labels:
                if metric_label in self.external_metric:
                    value = self.data[metric_label][self.epoch]

                    if numpy.isnan(value):
                        value = ' ' * 10

                    else:
                        if self.external_metric_labels[metric_label].endswith('f_measure') or \
                           self.external_metric_labels[metric_label].endswith('f_score'):
                            value = float(value) * 100
                            types.append('float2')

                        else:
                            value = float(value)
                            types.append('float4')

                    data.append(value)

                else:
                    data.append('')

            if self.show_timing:
                # Add step time
                step_time = datetime.timedelta(seconds=self.timer.elapsed())
                data.append(str(step_time)[:-3])
                types.append('str')

                # Add remaining time
                average_time_per_epoch = self.total_time / float(self.epoch - (self.first_epoch - 1))
                remaining_time_seconds = datetime.timedelta(
                    seconds=(self.epochs - 1 - self.epoch) * average_time_per_epoch
                )
                data.append('-' + str(remaining_time_seconds).split('.', 2)[0])
                types.append('str')

            output = self.ui.row(*data, types=types)

            self.output_target.line(output)

    def add_external_metric(self, metric_id):
        """Add external metric to be monitored

        Parameters
        ----------
        metric_id : str
            Metric name

        """
        if metric_id not in self.external_metric_labels:
            self.external_metric_labels[metric_id] = metric_id

        if metric_id not in self.data:
            self.data[metric_id] = numpy.empty((self.epochs,))
            self.data[metric_id][:] = numpy.nan

    def set_external_metric_value(self, metric_label, metric_value):
        """Add external metric value

        Parameters
        ----------
        metric_label : str
            Metric label

        metric_value : numeric
            Metric value

        """
        self.external_metric[metric_label] = metric_value
        self.data[metric_label][self.epoch] = metric_value

    def _implements_train_batch_hooks(self):
        return True
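
# --- Illustrative usage sketch (not part of the original module) -----------
# ProgressLoggerCallback is passed to model.fit() like any Keras callback and
# replaces the standard progress bar output. The model, data, and metric/loss
# names below are assumptions made only for this example; metric and loss
# should match what the model was compiled with.
def _example_progress_logger():
    import tensorflow as tf

    model = tf.keras.Sequential([
        tf.keras.layers.Dense(16, activation='relu', input_shape=(40,)),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['categorical_accuracy'])

    x = numpy.random.rand(128, 40)
    y = tf.keras.utils.to_categorical(numpy.random.randint(10, size=128), 10)

    callback = ProgressLoggerCallback(
        epochs=5,
        metric='categorical_accuracy',
        loss='categorical_crossentropy',
        output_type='console'
    )

    # verbose=0 since the callback takes over the per-epoch reporting
    model.fit(x, y, validation_split=0.25, epochs=5, batch_size=16,
              verbose=0, callbacks=[callback])
# ----------------------------------------------------------------------------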
class ProgressPlotterCallback(ProgressLoggerCallback):
    """Keras callback to plot progress during the training process and save the final progress into a figure.
    Implements Keras Callback API.

    """
    def __init__(self,
                 epochs=None, manual_update=False, external_metric_labels=None,
                 metric=None, loss=None,
                 filename=None, plotting_rate=10, interactive=True, save=False, focus_span=10,
                 **kwargs):
        """Constructor

        Parameters
        ----------
        epochs : int
            Total amount of epochs
            Default value None

        metric : str
            Metric name
            Default value None

        manual_update : bool
            Manually update callback, use this when injecting external metrics
            Default value False

        interactive : bool
            Show plot during the training and update with plotting rate
            Default value True

        plotting_rate : int
            Plot update rate in seconds
            Default value 10

        save : bool
            Save plot on disk, plotting rate applies
            Default value False

        filename : str
            Filename of figure
            Default value None

        focus_span : int
            Epoch amount to highlight, and show separately in the plot.
            Default value 10

        """
        kwargs.update({
            'manual_update': manual_update,
            'epochs': epochs,
            'external_metric_labels': external_metric_labels,
            'metric': metric,
            'loss': loss,
        })

        super(ProgressPlotterCallback, self).__init__(**kwargs)

        self.filename = filename

        if filename is not None:
            # Get file format for the output plot
            file_extension = os.path.splitext(self.filename)[1]

            if file_extension == '.eps':
                self.format = 'eps'

            elif file_extension == '.svg':
                self.format = 'svg'

            elif file_extension == '.pdf':
                self.format = 'pdf'

            elif file_extension == '.png':
                self.format = 'png'

        self.plotting_rate = plotting_rate
        self.interactive = interactive
        self.save = save
        self.focus_span = focus_span

        if self.focus_span > self.epochs:
            self.focus_span = self.epochs

        self.timer.start()

        self.data = {
            'l_tra': numpy.empty((self.epochs,)),
            'l_val': numpy.empty((self.epochs,)),
            'm_tra': numpy.empty((self.epochs,)),
            'm_val': numpy.empty((self.epochs,)),
        }
        self.data['l_tra'][:] = numpy.nan
        self.data['l_val'][:] = numpy.nan
        self.data['m_tra'][:] = numpy.nan
        self.data['m_val'][:] = numpy.nan

        for metric_label in self.external_metric_labels:
            self.data[metric_label] = numpy.empty((self.epochs,))
            self.data[metric_label][:] = numpy.nan

        self.ax1_1 = None
        self.ax1_2 = None
        self.ax2_1 = None
        self.ax2_2 = None

        self.extra_main = {}
        self.extra_highlight = {}

        import matplotlib.pyplot as plt
        import warnings
        import matplotlib.cbook
        warnings.filterwarnings("ignore", category=matplotlib.cbook.mplDeprecation)

        figure_height = 8
        if len(self.external_metric_labels) > 2:
            figure_height = 8 + len(self.external_metric_labels)

        self.figure = plt.figure(num=None, figsize=(18, figure_height), dpi=80, facecolor='w', edgecolor='k')

        self.draw()

        if self.interactive:
            plt.show(block=False)
            plt.pause(0.1)

    def draw(self):
        """Draw plot
        """
        import matplotlib.patches as patches
        import matplotlib.pyplot as plt

        plt.figure(self.figure.number)

        row_count = 2 + len(self.external_metric_labels)

        self.ax1_1 = plt.subplot2grid((row_count, 4), (0, 0), rowspan=1, colspan=3)
        self.ax1_2 = plt.subplot2grid((row_count, 4), (0, 3), rowspan=1, colspan=1)

        self.ax2_1 = plt.subplot2grid((row_count, 4), (1, 0), rowspan=1, colspan=3)
        self.ax2_2 = plt.subplot2grid((row_count, 4), (1, 3), rowspan=1, colspan=1)

        self.extra_main = {}
        self.extra_highlight = {}

        row_id = 2
        for metric_label in self.external_metric_labels:
            self.extra_main[metric_label] = plt.subplot2grid((row_count, 4), (row_id, 0), rowspan=1, colspan=3)
            self.extra_highlight[metric_label] = plt.subplot2grid((row_count, 4), (row_id, 3), rowspan=1, colspan=1)
            row_id += 1

        span = [self.epoch - self.focus_span, self.epoch]
        if span[0] < 0:
            span[0] = 0

        # PLOT 1 / Main
        self.ax1_1.cla()
        self.ax1_1.set_title('Loss')
        self.ax1_1.set_ylabel('Model Loss')

        self.ax1_1.plot(
            numpy.arange(self.epochs),
            self.data['l_tra'],
            lw=3,
            color='red',
        )
        self.ax1_1.plot(
            numpy.arange(self.epochs),
            self.data['l_val'],
            lw=3,
            color='green',
        )

        self.ax1_1.add_patch(
            patches.Rectangle(
                (span[0], self.ax1_1.get_ylim()[0]),  # (x,y)
                width=span[1] - span[0],
                height=self.ax1_1.get_ylim()[1],
                facecolor="#000000",
                alpha=0.05
            )
        )

        # Horizontal lines
        if not numpy.all(numpy.isnan(self.data['l_tra'])):
            self.ax1_1.axhline(y=numpy.nanmin(self.data['l_tra']), lw=1, color='red', linestyle='--')
            self.ax1_1.axhline(y=numpy.nanmin(self.data['l_val']), lw=1, color='green', linestyle='--')

        self.ax1_1.legend(['Train', 'Validation'], loc='upper right')
        self.ax1_1.set_xlim([0, self.epochs - 1])
        self.ax1_1.set_xticklabels([])
        self.ax1_1.grid(True)

        # PLOT 1 / Highlighted area
        self.ax1_2.cla()
        self.ax1_2.set_title('Loss / Highlighted area')
        self.ax1_2.set_ylabel('Model Loss')

        self.ax1_2.plot(
            numpy.arange(span[0], span[1]),
            self.data['l_tra'][span[0]:span[1]],
            lw=3,
            color='red',
        )
        self.ax1_2.plot(
            numpy.arange(span[0], span[1]),
            self.data['l_val'][span[0]:span[1]],
            lw=3,
            color='green',
        )

        self.ax1_2.set_xticklabels([])
        self.ax1_2.grid(True)
        self.ax1_2.yaxis.tick_right()
        self.ax1_2.yaxis.set_label_position("right")

        # PLOT 2 / Main
        self.ax2_1.cla()
        self.ax2_1.set_title('Metric')
        self.ax2_1.set_ylabel(self.metric)

        # Plots
        self.ax2_1.plot(
            numpy.arange(self.epochs),
            self.data['m_tra'],
            lw=3,
            color='red',
        )
        self.ax2_1.plot(
            numpy.arange(self.epochs),
            self.data['m_val'],
            lw=3,
            color='green',
        )

        # Horizontal lines
        if not numpy.all(numpy.isnan(self.data['m_tra'])):
            if self.get_operator(metric=self.metric) == numpy.greater:
                h_tra_line_y = numpy.nanmax(self.data['m_tra'])
                h_val_line_y = numpy.nanmax(self.data['m_val'])

            else:
                h_tra_line_y = numpy.nanmin(self.data['m_tra'])
                h_val_line_y = numpy.nanmin(self.data['m_tra'])

            self.ax2_1.axhline(y=h_tra_line_y, lw=1, color='red', linestyle='--')
            self.ax2_1.axhline(y=h_val_line_y, lw=1, color='green', linestyle='--')

        self.ax2_1.add_patch(
            patches.Rectangle(
                (span[0], self.ax2_1.get_ylim()[0]),  # (x,y)
                width=span[1] - span[0],
                height=self.ax2_1.get_ylim()[1],
                facecolor="#000000",
                alpha=0.05
            )
        )

        if self.get_operator(metric=self.metric) == numpy.greater:
            legend_location = 'lower right'

        else:
            legend_location = 'upper right'

        self.ax2_1.legend(['Train', 'Validation'], loc=legend_location)
        self.ax2_1.set_xlim([0, self.epochs - 1])

        if self.external_metric_labels:
            self.ax2_1.set_xticklabels([])

        self.ax2_1.grid(True)

        # PLOT 2 / Highlighted area
        self.ax2_2.cla()
        self.ax2_2.set_title('Metric / Highlighted area')
        self.ax2_2.set_ylabel(self.metric)

        self.ax2_2.plot(
            numpy.arange(span[0], span[1]),
            self.data['m_tra'][span[0]:span[1]],
            lw=3,
            color='red',
        )
        self.ax2_2.plot(
            numpy.arange(span[0], span[1]),
            self.data['m_val'][span[0]:span[1]],
            lw=3,
            color='green',
        )

        self.ax2_2.set_xticklabels([])
        self.ax2_2.grid(True)
        self.ax2_2.yaxis.tick_right()
        self.ax2_2.yaxis.set_label_position("right")

        for mid, metric_label in enumerate(self.external_metric_labels):
            metric_name = self.external_metric_labels[metric_label]

            if metric_name.endswith('f_measure') or metric_name.endswith('f_score'):
                factor = 100

            else:
                factor = 1

            # PLOT 3 / Main
            self.extra_main[metric_label].cla()
            self.extra_main[metric_label].set_title('External metric')
            self.extra_main[metric_label].set_ylabel(str(metric_label))

            mask = numpy.isfinite(self.data[metric_label])

            self.extra_main[metric_label].plot(
                numpy.arange(self.epochs)[mask],
                self.data[metric_label][mask] * factor,
                lw=3,
                color='green',
                marker='o',
            )

            self.extra_main[metric_label].add_patch(
                patches.Rectangle(
                    (span[0], self.extra_main[metric_label].get_ylim()[0]),  # (x,y)
                    width=span[1] - span[0],
                    height=self.extra_main[metric_label].get_ylim()[1],
                    facecolor="#000000",
                    alpha=0.05
                )
            )

            # Horizontal lines
            if not numpy.all(numpy.isnan(self.data[metric_label][mask])):
                if self.get_operator(metric=str(metric_label)) == numpy.greater:
                    h_extra_line_y = numpy.nanmax(self.data[metric_label][mask] * factor)

                else:
                    h_extra_line_y = numpy.nanmin(self.data[metric_label][mask] * factor)

                self.extra_main[metric_label].axhline(y=h_extra_line_y, lw=1, color='blue', linestyle='--')

            if self.get_operator(metric=self.metric) == numpy.greater:
                legend_location = 'lower right'

            else:
                legend_location = 'upper right'

            self.extra_main[metric_label].legend(['Validation'], loc=legend_location)
            self.extra_main[metric_label].set_xlim([0, self.epochs - 1])

            if (mid + 1) < len(self.external_metric_labels):
                self.extra_main[metric_label].set_xticklabels([])

            else:
                self.extra_main[metric_label].set_xlabel('Epochs')

            self.extra_main[metric_label].grid(True)

            # PLOT 3 / Highlighted area
            self.extra_highlight[metric_label].cla()
            self.extra_highlight[metric_label].set_title('External metric / Highlighted area')
            self.extra_highlight[metric_label].set_ylabel(str(metric_label))

            highlight_data = self.data[metric_label][span[0]:span[1]] * factor
            mask = numpy.isfinite(highlight_data)

            self.extra_highlight[metric_label].plot(
                numpy.arange(span[0], span[1])[mask],
                highlight_data[mask],
                lw=3,
                color='green',
                marker='o',
            )

            if (mid + 1) < len(self.external_metric_labels):
                self.extra_highlight[metric_label].set_xticklabels([])

            else:
                self.extra_highlight[metric_label].set_xlabel('Epochs')

            self.extra_highlight[metric_label].yaxis.tick_right()
            self.extra_highlight[metric_label].yaxis.set_label_position("right")
            self.extra_highlight[metric_label].grid(True)

        plt.subplots_adjust(left=0.05, right=0.95, top=0.9, bottom=0.1, wspace=0.02, hspace=0.2)

    def on_train_begin(self, logs=None):
        if self.epochs is None:
            self.epochs = self.params['epochs']

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch = epoch + 1
        self.seen = 0

    def on_epoch_end(self, epoch, logs=None):
        self.epoch = epoch

        logs = logs or {}

        # Collect values
        metrics_list = list(logs.keys())
        if 'metrics' in self.params:
            metrics_list = self.params['metrics']

        for k in metrics_list:
            if k in logs:
                self.log_values.append((k, logs[k]))

                if k == 'loss':
                    self.data['l_tra'][self.epoch] = logs[k]

                elif k == 'val_loss':
                    self.data['l_val'][self.epoch] = logs[k]

                elif self.metric and k.endswith(self.metric):
                    if k.startswith('val_'):
                        self.data['m_val'][self.epoch] = logs[k]

                    else:
                        self.data['m_tra'][self.epoch] = logs[k]

        if not self.manual_update:
            # Update logged progress
            self.update()

    def update(self):
        """Update
        """
        import matplotlib.pyplot as plt

        if self.timer.elapsed() > self.plotting_rate:
            self.draw()
            self.figure.canvas.flush_events()

            if self.interactive:
                plt.pause(0.01)

            if self.save:
                plt.savefig(self.filename, bbox_inches='tight', format=self.format, dpi=1000)

            self.timer.start()

    def add_external_metric(self, metric_label):
        """Add external metric to be monitored

        Parameters
        ----------
        metric_label : str
            Metric label

        """
        if metric_label not in self.external_metric_labels:
            self.external_metric_labels[metric_label] = metric_label

        if metric_label not in self.data:
            self.data[metric_label] = numpy.empty((self.epochs,))
            self.data[metric_label][:] = numpy.nan

    def set_external_metric_value(self, metric_label, metric_value):
        """Add external metric value

        Parameters
        ----------
        metric_label : str
            Metric label

        metric_value : numeric
            Metric value

        """
        self.external_metric[metric_label] = metric_value
        self.data[metric_label][self.epoch] = metric_value

    def close(self):
        """Manually close progress logging
        """
        import matplotlib.pyplot as plt

        if self.save:
            self.draw()
            plt.savefig(self.filename, bbox_inches='tight', format=self.format, dpi=1000)

        plt.close(self.figure)
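
# --- Illustrative usage sketch (not part of the original module) -----------
# ProgressPlotterCallback draws the loss/metric curves during training and can
# write the figure to disk; close() should be called once training has ended.
# The model, data, epoch count and output filename are assumptions made only
# for this example.
def _example_progress_plotter(model, x, y):
    callback = ProgressPlotterCallback(
        epochs=50,
        metric='categorical_accuracy',
        loss='categorical_crossentropy',
        filename='training_progress.png',  # hypothetical output path
        interactive=False,
        save=True
    )

    model.fit(x, y, validation_split=0.25, epochs=50, batch_size=16,
              verbose=0, callbacks=[callback])

    # Draw the final state and write the figure to disk
    callback.close()
# ----------------------------------------------------------------------------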
class StopperCallback(BaseCallback):
    """Keras callback to stop training when improvement has not been seen in a specified amount of epochs.
    Implements Keras Callback API.

    This callback is very similar to the standard ``EarlyStopping`` Keras callback, however it adds support
    for external metrics (calculated outside the Keras training process).

    """
    def __init__(self, epochs=None, manual_update=False, monitor='val_loss',
                 patience=0, min_delta=0, initial_delay=10, **kwargs):
        """Constructor

        Parameters
        ----------
        epochs : int
            Total amount of epochs

        manual_update : bool
            Manually update callback, use this when injecting external metrics

        monitor : str
            Metric value to be monitored

        patience : int
            Number of epochs with no improvement after which training will be stopped.

        min_delta : float
            Minimum change in the monitored quantity to qualify as an improvement.

        initial_delay : int
            Amount of epochs to wait at the beginning before quantity is monitored.

        """
        kwargs.update({
            'epochs': epochs,
            'manual_update': manual_update,
        })

        super(StopperCallback, self).__init__(**kwargs)

        self.monitor = monitor
        self.patience = patience
        self.min_delta = min_delta
        self.initial_delay = initial_delay

        self.wait = None
        self.stopped_epoch = None

        self.model = None
        self.params = None

        self.last_update_epoch = 0
        self.stopped = False

        self.metric_data = {
            self.monitor: numpy.empty((self.epochs,))
        }
        self.metric_data[self.monitor][:] = numpy.nan

        mode = kwargs.get('mode', 'auto')
        if mode not in ['min', 'max', 'auto']:
            mode = 'auto'

        if mode == 'min':
            self.monitor_op = numpy.less

        elif mode == 'max':
            self.monitor_op = numpy.greater

        else:
            self.monitor_op = self.get_operator(metric=self.monitor)

        self.best = numpy.Inf if self.monitor_op == numpy.less else -numpy.Inf

        if self.monitor_op == numpy.greater:
            self.min_delta *= 1

        else:
            self.min_delta *= -1

    def on_train_begin(self, logs=None):
        if self.epochs is None:
            self.epochs = self.params['epochs']

        if self.wait is None:
            self.wait = 0

        if self.stopped_epoch is None:
            self.stopped_epoch = 0

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch = epoch + 1

    def on_epoch_end(self, epoch, logs=None):
        self.epoch = epoch

        if self.monitor in logs:
            self.metric_data[self.monitor][self.epoch] = logs.get(self.monitor)

        if not self.manual_update:
            self.update()

    def set_external_metric_value(self, metric_label, metric_value):
        """Add external metric value

        Parameters
        ----------
        metric_label : str
            Metric label

        metric_value : numeric
            Metric value

        """
        if metric_label not in self.metric_data:
            self.metric_data[metric_label] = numpy.empty((self.epochs,))
            self.metric_data[metric_label][:] = numpy.nan

        self.metric_data[metric_label][self.epoch] = metric_value

    def stop(self):
        return self.stopped

    def update(self):
        if self.epoch > self.initial_delay:
            # Get current metric value
            current = self.metric_data[self.monitor][self.epoch]

            if numpy.isnan(current):
                message = '{name}: Metric to monitor is Nan, metric:[{metric}]'.format(
                    name=self.__class__.__name__,
                    metric=self.monitor
                )
                self.logger.exception(message)
                raise ValueError(message)

            if self.monitor_op(current - self.min_delta, self.best):
                # New best value found
                self.best = current
                self.wait = 0

            else:
                if self.wait >= self.patience:
                    # Stopping criteria met => stop the training
                    self.stopped_epoch = self.epoch
                    self.model.stop_training = True

                    self.logger.info('  Stopping criteria met at epoch[{epoch:d}]'.format(
                        epoch=self.epoch,
                    ))
                    self.logger.info('    metric[{metric}], patience[{patience:d}]'.format(
                        metric=self.monitor,
                        current='{:4.4f}'.format(current),
                        patience=self.patience
                    ))
                    self.logger.info('  ')

                    self.stopped = True
                    return self.stopped

                # Increase waiting counter
                self.wait += self.epoch - self.last_update_epoch

        self.last_update_epoch = self.epoch

        return self.stopped
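
# --- Illustrative usage sketch (not part of the original module) -----------
# With manual_update=True the training loop is driven one epoch at a time, an
# externally computed metric is injected with set_external_metric_value(), and
# stop() is polled to decide when to end training. The metric label and the
# evaluate_externally() helper are hypothetical names used only here.
def _example_stopper(model, x, y, epochs=50):
    stopper = StopperCallback(
        epochs=epochs,
        manual_update=True,
        monitor='ext_f_measure',  # hypothetical external metric label
        patience=10,
        initial_delay=5
    )

    for epoch in range(epochs):
        model.fit(x, y, validation_split=0.25,
                  initial_epoch=epoch, epochs=epoch + 1,
                  batch_size=16, verbose=0, callbacks=[stopper])

        external_value = evaluate_externally(model)  # hypothetical helper
        stopper.set_external_metric_value('ext_f_measure', external_value)
        stopper.update()

        if stopper.stop():
            break
# ----------------------------------------------------------------------------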
class StasherCallback(BaseCallback):
    """Keras callback to monitor the training process and store the best model. Implements Keras Callback API.

    This callback is very similar to the standard ``ModelCheckpoint`` Keras callback, however it adds
    support for external metrics (metrics calculated outside the Keras training process).

    """
    def __init__(self, epochs=None, manual_update=False, monitor='val_loss', mode='auto',
                 period=1, initial_delay=10, save_weights=False, file_path=None, **kwargs):
        """Constructor

        Parameters
        ----------
        epochs : int
            Total amount of epochs
            Default value None

        manual_update : bool
            Manually update callback, use this when injecting external metrics
            Default value False

        monitor : str
            Metric to monitor
            Default value 'val_loss'

        mode : str
            Which way metric is interpreted, values {auto, min, max}
            Default value 'auto'

        period : int
            Save only after every Nth epoch
            Default value 1

        initial_delay : int
            Amount of epochs to wait at the beginning before quantity is monitored.
            Default value 10

        save_weights : bool
            Save weights to the disk
            Default value False

        file_path : str
            File name for model weights
            Default value None

        """
        kwargs.update({
            'epochs': epochs,
            'manual_update': manual_update,
        })

        super(StasherCallback, self).__init__(**kwargs)

        self.monitor = monitor
        self.period = period
        self.initial_delay = initial_delay

        self.save_weights = save_weights
        self.file_path = file_path

        self.epochs_since_last_save = 0

        if mode not in ['auto', 'min', 'max']:
            mode = 'auto'

        if mode == 'min':
            self.monitor_op = numpy.less

        elif mode == 'max':
            self.monitor_op = numpy.greater

        else:
            self.monitor_op = self.get_operator(metric=self.monitor)

        self.best = numpy.Inf if self.monitor_op == numpy.less else -numpy.Inf

        self.metric_data = {
            self.monitor: numpy.empty((self.epochs,))
        }
        self.metric_data[self.monitor][:] = numpy.nan

        self.best_model_weights = None
        self.best_model_epoch = 0

        self.last_logs = None

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch = epoch + 1

    def on_epoch_end(self, epoch, logs=None):
        self.epoch = epoch

        if self.monitor in logs:
            self.metric_data[self.monitor][self.epoch] = logs.get(self.monitor)

        self.last_logs = logs

        if not self.manual_update:
            self.update()

    def update(self):
        if self.epoch > self.initial_delay:
            self.epochs_since_last_save += 1

            if self.epochs_since_last_save >= self.period:
                self.epochs_since_last_save = 0

                current = self.metric_data[self.monitor][self.epoch]

                if numpy.isnan(current):
                    message = '{name}: Metric to monitor is Nan, metric:[{metric}]'.format(
                        name=self.__class__.__name__,
                        metric=self.monitor
                    )
                    self.logger.exception(message)
                    raise ValueError(message)

                else:
                    if self.monitor_op(current, self.best):
                        # Store the best
                        self.best = current
                        self.best_model_weights = self.model.get_weights()
                        self.best_model_epoch = self.epoch

                        if self.save_weights and self.file_path:
                            # Save weights on disk
                            logs = self.last_logs
                            if self.monitor not in logs:
                                logs[self.monitor] = current

                            file_path = self.file_path.format(epoch=self.epoch, **self.last_logs)
                            self.model.save_weights(file_path, overwrite=True)

    def set_external_metric_value(self, metric_label, metric_value):
        """Add external metric value

        Parameters
        ----------
        metric_label : str
            Metric label

        metric_value : numeric
            Metric value

        """
        if metric_label not in self.metric_data:
            self.metric_data[metric_label] = numpy.empty((self.epochs,))
            self.metric_data[metric_label][:] = numpy.nan

        self.metric_data[metric_label][self.epoch] = metric_value

    def get_best(self):
        """Return best model seen

        Returns
        -------
        dict
            Dictionary with keys 'weights', 'epoch', 'metric_value'

        """
        return {
            'epoch': self.best_model_epoch,
            'weights': self.best_model_weights,
            'metric_value': self.best,
        }

    def to_string(self, ui=None, indent=0):
        """Get information in a string

        Parameters
        ----------
        ui : FancyStringifier or FancyHTMLStringifier
            Stringifier class
            Default value FancyStringifier

        indent : int
            Amount of indent
            Default value 0

        Returns
        -------
        str

        """
        if ui is None:
            ui = FancyStringifier()

        output = ''
        output += ui.class_name(self.__class__.__name__, indent=indent) + '\n'
        output += ui.data(field='Best model weights at epoch', value=self.best_model_epoch, indent=indent) + '\n'
        output += ui.data(field='Metric type', value=self.monitor, indent=indent + 2) + '\n'
        output += ui.data(field='Metric value', value=self.best, indent=indent + 2) + '\n'
        output += ui.line()

        return output

    def to_html(self, indent=0):
        """Get information in a HTML formatted string

        Parameters
        ----------
        indent : int
            Amount of indent
            Default value 0

        Returns
        -------
        str

        """
        return self.to_string(ui=FancyHTMLStringifier(), indent=indent)

    def log(self):
        """Print information about the best model into logging interface
        """
        lines = self.to_string(indent=2).split('\n')
        for line in lines:
            self.logger.info(line)

        self.logger.info('  ')

    def show(self, mode='auto', indent=0):
        """Print information about the best model

        If called inside a Jupyter notebook, an HTML formatted version is shown.

        Parameters
        ----------
        mode : str
            Output type, possible values ['auto', 'print', 'html']. 'html' will work only in Jupyter notebook
            Default value 'auto'

        indent : int
            Amount of indent
            Default value 0

        Returns
        -------
        Nothing

        """
        if mode == 'auto':
            if is_jupyter():
                mode = 'html'

            else:
                mode = 'print'

        if mode not in ['html', 'print']:
            # Unknown mode given
            message = '{name}: Unknown mode [{mode}]'.format(name=self.__class__.__name__, mode=mode)
            self.logger.exception(message)
            raise ValueError(message)

        if mode == 'html':
            from IPython.core.display import display, HTML
            display(
                HTML(
                    self.to_html(indent=indent)
                )
            )

        elif mode == 'print':
            print(self.to_string(indent=indent))
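
# --- Illustrative usage sketch (not part of the original module) -----------
# StasherCallback keeps the best weights in memory during training; afterwards
# they can be restored via get_best(). The model and data are assumed to exist
# already and are not part of the original module.
def _example_stasher(model, x, y, epochs=50):
    stasher = StasherCallback(
        epochs=epochs,
        monitor='val_loss',
        initial_delay=5
    )

    model.fit(x, y, validation_split=0.25, epochs=epochs, batch_size=16,
              verbose=0, callbacks=[stasher])

    # Roll the model back to the best epoch seen during training
    best = stasher.get_best()
    if best['weights'] is not None:
        model.set_weights(best['weights'])

    # Log a summary of the best model
    stasher.log()
# ----------------------------------------------------------------------------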