tf.keras utilities

Utilities for the Keras deep learning framework bundled with TensorFlow (>= 2.0).

Model

dcase_util.tfkeras.model.*

create_sequential_model(model_parameter_list)

Create a sequential Keras model from a list of layer parameters.
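The idea of building a model from a parameter list can be sketched without TensorFlow. This is a hypothetical illustration, not the dcase_util implementation: it assumes each list entry is a dict with class_name and config keys, and it uses the toy layer classes Scale and Shift and the registry LAYER_REGISTRY (all invented here) in place of Keras layers, applying them in order as a sequential model would.

```python
from typing import Callable, Dict, List


class Scale:
    """Toy stand-in for a Keras layer (hypothetical): multiplies its input."""

    def __init__(self, factor: float) -> None:
        self.factor = factor

    def __call__(self, x: float) -> float:
        return x * self.factor


class Shift:
    """Toy stand-in for a Keras layer (hypothetical): adds an offset."""

    def __init__(self, offset: float) -> None:
        self.offset = offset

    def __call__(self, x: float) -> float:
        return x + self.offset


# Hypothetical registry mapping class names to layer classes
LAYER_REGISTRY: Dict[str, type] = {"Scale": Scale, "Shift": Shift}


def build_sequential(model_parameter_list: List[dict]) -> Callable[[float], float]:
    """Instantiate each layer from its parameters and chain them in list order."""
    layers = [
        LAYER_REGISTRY[item["class_name"]](**item.get("config", {}))
        for item in model_parameter_list
    ]

    def model(x: float) -> float:
        # Feed the output of each layer into the next one
        for layer in layers:
            x = layer(x)
        return x

    return model


model = build_sequential([
    {"class_name": "Scale", "config": {"factor": 2.0}},
    {"class_name": "Shift", "config": {"offset": 1.0}},
])
print(model(3.0))  # 7.0
```

Keeping the layer definitions as plain data makes them easy to store in a configuration file and rebuild later.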

model_summary_string(keras_model[, mode, ...])

Return the model summary as a formatted string, similar to the Keras model summary function.
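As a rough sketch of what producing a formatted summary string involves, the hypothetical function summary_string below renders (layer name, output shape, parameter count) rows as a fixed-width table followed by a parameter total. The row format and column layout are assumptions for illustration; the actual output of model_summary_string may look different.

```python
from typing import List, Tuple

# (layer name, output shape, parameter count); a hypothetical row format
Row = Tuple[str, str, int]


def summary_string(rows: List[Row]) -> str:
    """Render layer rows as a fixed-width table followed by the parameter total."""
    header = ("Layer", "Output shape", "Param #")
    cells = [header] + [(name, shape, str(params)) for name, shape, params in rows]
    # Each column is as wide as its longest cell
    widths = [max(len(row[i]) for row in cells) for i in range(3)]
    lines = [
        "  ".join(cell.ljust(width) for cell, width in zip(row, widths)).rstrip()
        for row in cells
    ]
    total = sum(params for _, _, params in rows)
    # Header, separator, one line per layer, then the total
    table = [lines[0], "-" * (sum(widths) + 4)] + lines[1:]
    table.append(f"Total params: {total}")
    return "\n".join(table)


text = summary_string([
    ("dense", "(None, 50)", 2050),
    ("dense_1", "(None, 10)", 510),
])
print(text)
```

Returning a string rather than printing directly lets the caller route the summary to a log file or the logging interface.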

Callbacks

Usage example showing how to use external metrics with the callback classes provided by dcase_util:

import dcase_util

# keras_model is assumed to be built and compiled with the loss and metrics
# below; training_X, training_Y, validation_X and validation_Y are assumed
# to be prepared beforehand.

epochs = 100
batch_size = 256
loss = 'categorical_crossentropy'
metrics = ['categorical_accuracy']
processing_interval = 1
manual_update = True
external_metric_labels = {'ER': 'Error rate'}

callback_list = [
    dcase_util.tfkeras.ProgressLoggerCallback(
        epochs=epochs,
        metric=metrics,
        loss=loss,
        manual_update=manual_update,
        manual_update_interval=processing_interval,
        external_metric_labels=external_metric_labels
    ),
    dcase_util.tfkeras.ProgressPlotterCallback(
        epochs=epochs,
        metric=metrics,
        save=False,
        manual_update=manual_update,
        manual_update_interval=processing_interval,
        external_metric_labels=external_metric_labels
    ),
    dcase_util.tfkeras.StopperCallback(
        epochs=epochs,
        monitor=metrics[0],
        manual_update=manual_update,
    ),
    dcase_util.tfkeras.StasherCallback(
        epochs=epochs,
        monitor=metrics[0],
        manual_update=manual_update,
    )
]

# Train in chunks of processing_interval epochs so that external metrics
# can be calculated and injected into the callbacks between chunks
for epoch_start in range(0, epochs, processing_interval):
    # Make sure we train only the specified number of epochs
    epoch_end = min(epoch_start + processing_interval, epochs)

    # Train model for epochs epoch_start ... epoch_end - 1
    keras_model.fit(
        x=training_X,
        y=training_Y,
        validation_data=(validation_X, validation_Y),
        callbacks=callback_list,
        verbose=0,
        initial_epoch=epoch_start,
        epochs=epoch_end,
        batch_size=batch_size,
        shuffle=True
    )

    # Calculate external metrics, e.g. the error rate on the validation data
    # (a constant is used here as a placeholder)
    ER = 0.0

    # Inject external metric values into the callbacks
    for callback in callback_list:
        if hasattr(callback, 'set_external_metric_value'):
            callback.set_external_metric_value(
                metric_label='ER',
                metric_value=ER
            )

    # Manually update callbacks
    for callback in callback_list:
        if hasattr(callback, 'update'):
            callback.update()

    # Check whether we need to stop training
    stop_training = False
    for callback in callback_list:
        if hasattr(callback, 'stop') and callback.stop():
            stop_training = True

    if stop_training:
        # Stop the training loop
        break

ProgressLoggerCallback

dcase_util.tfkeras.ProgressLoggerCallback

Keras callback to store metrics and show them with a tqdm progress bar or through the logging interface. Implements the Keras Callback API.

This callback is very similar to the standard Keras ProgbarLogger callback; however, it adds support for the logging interface and for external metrics (metrics calculated outside the Keras training process).

ProgressLoggerCallback([manual_update, ...])

Keras callback to show metrics in the logging interface.
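To illustrate how external metrics can be merged with the values Keras reports at the end of each epoch, here is a minimal stdlib-only sketch. Only the method names set_external_metric_value and update are borrowed from the usage example above; the class ExternalMetricLogger and its output format are hypothetical and do not reproduce ProgressLoggerCallback.

```python
import logging
from typing import Dict, Optional


class ExternalMetricLogger:
    """Hypothetical sketch: merge Keras epoch logs with externally computed metrics."""

    def __init__(self, external_metric_labels: Dict[str, str]) -> None:
        self.external_metric_labels = external_metric_labels
        self.epoch: Optional[int] = None
        self.keras_logs: Dict[str, float] = {}
        self.external_values: Dict[str, float] = {}

    def on_epoch_end(self, epoch: int, logs: Dict[str, float]) -> None:
        # In Keras this hook is called automatically after each epoch
        self.epoch = epoch
        self.keras_logs = dict(logs)

    def set_external_metric_value(self, metric_label: str, metric_value: float) -> None:
        # Values calculated outside Keras are injected by the training loop
        if metric_label not in self.external_metric_labels:
            raise KeyError(f"Unknown external metric: {metric_label}")
        self.external_values[metric_label] = metric_value

    def update(self) -> str:
        # Combine Keras metrics and external metrics into one log line
        parts = [f"epoch {self.epoch}"]
        parts += [f"{key}={value:.4f}" for key, value in sorted(self.keras_logs.items())]
        parts += [
            f"{self.external_metric_labels[key]}={value:.4f}"
            for key, value in sorted(self.external_values.items())
        ]
        line = " | ".join(parts)
        logging.getLogger(__name__).info(line)
        return line


logger = ExternalMetricLogger({"ER": "Error rate"})
logger.on_epoch_end(0, {"loss": 0.5, "categorical_accuracy": 0.8})
logger.set_external_metric_value(metric_label="ER", metric_value=0.25)
line = logger.update()
print(line)  # epoch 0 | categorical_accuracy=0.8000 | loss=0.5000 | Error rate=0.2500
```

Because the external value is set explicitly before update is called, the log line can only be produced once both sources are available, which is why manual updating is used in the usage example.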

ProgressPlotterCallback

dcase_util.tfkeras.ProgressPlotterCallback

Keras callback to plot progress during the training process and save the final progress into a figure. Implements the Keras Callback API.

ProgressPlotterCallback([epochs, ...])

Keras callback to plot progress during the training process and save the final progress into a figure.

StopperCallback

dcase_util.tfkeras.StopperCallback

Keras callback to stop training when no improvement has been seen within a specified number of epochs. Implements the Keras Callback API.

This callback is very similar to the standard Keras EarlyStopping callback; however, it adds support for external metrics (metrics calculated outside the Keras training process).

StopperCallback([epochs, manual_update, ...])

Keras callback to stop training when no improvement has been seen within a specified number of epochs.
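The patience logic behind this kind of early stopping can be sketched as follows. The class PatienceStopper is hypothetical; the parameter names patience and min_delta are borrowed from the standard Keras EarlyStopping callback, and it is an assumption that StopperCallback counts epochs the same way. Here the monitored value is an external error rate, so lower is better.

```python
from typing import Optional


class PatienceStopper:
    """Hypothetical sketch of patience-based early stopping (lower values are better)."""

    def __init__(self, patience: int, min_delta: float = 0.0) -> None:
        self.patience = patience
        self.min_delta = min_delta
        self.best: Optional[float] = None
        self.wait = 0  # epochs since the last improvement

    def update(self, value: float) -> None:
        # An improvement must beat the best value by more than min_delta
        if self.best is None or value < self.best - self.min_delta:
            self.best = value
            self.wait = 0
        else:
            self.wait += 1

    def stop(self) -> bool:
        return self.wait >= self.patience


stopper = PatienceStopper(patience=3)
stopped_at = None
for epoch, error_rate in enumerate([0.40, 0.35, 0.36, 0.37, 0.38, 0.30]):
    stopper.update(error_rate)
    if stopper.stop():
        stopped_at = epoch
        break
print(stopped_at, stopper.best)  # 4 0.35
```

Training stops after three epochs without improvement, so the later value 0.30 is never reached; choosing the patience is a trade-off between training time and the chance of missing a late improvement.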

StasherCallback

dcase_util.tfkeras.StasherCallback

Keras callback to monitor the training process and store the best model. Implements the Keras Callback API.

This callback is very similar to the standard Keras ModelCheckpoint callback; however, it adds support for external metrics (metrics calculated outside the Keras training process).

StasherCallback([epochs, manual_update, ...])

Keras callback to monitor the training process and store the best model.
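A minimal sketch of the "keep the best model" idea, assuming the weights can be represented as a plain dict (a stand-in for Keras model weights). BestModelStasher and its update method are hypothetical; StasherCallback may store the model on disk or in another form.

```python
import copy
from typing import Any, Dict, Optional


class BestModelStasher:
    """Hypothetical sketch: keep an in-memory copy of the best weights seen so far."""

    def __init__(self, mode: str = "max") -> None:
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.mode = mode
        self.best_value: Optional[float] = None
        self.best_epoch: Optional[int] = None
        self.best_weights: Optional[Dict[str, Any]] = None

    def _is_better(self, value: float) -> bool:
        if self.best_value is None:
            return True
        return value > self.best_value if self.mode == "max" else value < self.best_value

    def update(self, epoch: int, value: float, weights: Dict[str, Any]) -> None:
        if self._is_better(value):
            self.best_value = value
            self.best_epoch = epoch
            # Deep copy so later in-place weight updates do not alter the snapshot
            self.best_weights = copy.deepcopy(weights)


stasher = BestModelStasher(mode="max")
weights = {"dense/kernel": [0.0, 0.0]}
for epoch, accuracy in enumerate([0.70, 0.82, 0.79]):
    weights["dense/kernel"][0] = float(epoch)  # simulate an in-place weight update
    stasher.update(epoch, accuracy, weights)
print(stasher.best_epoch, stasher.best_weights)  # 1 {'dense/kernel': [1.0, 0.0]}
```

The deep copy matters here: without it, the stored snapshot would share the list with the live weights and would silently follow every later update.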

BaseCallback

dcase_util.tfkeras.BaseCallback

BaseCallback([epochs, manual_update, ...])

Base class for the callbacks.

Data processing

KerasDataSequence

dcase_util.tfkeras.get_keras_data_sequence

The KerasDataSequence class should be accessed through a getter method to avoid importing Keras when dcase_util is imported. This mechanism lets the user decide when Keras is imported, and set random seeds before that.

get_keras_data_sequence_class()
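The deferred-import pattern can be sketched with the standard library. In this hypothetical example, get_data_sequence_class imports its dependency and defines the class only on the first call, and functools.lru_cache makes later calls return the same class object. The json module stands in for TensorFlow so the sketch runs anywhere; a real implementation would subclass tf.keras.utils.Sequence, which relies on the same __len__ and __getitem__ protocol.

```python
import functools
import importlib
import random


@functools.lru_cache(maxsize=None)
def get_data_sequence_class():
    """Hypothetical getter: import the heavy dependency and define the class on first call."""
    # Stand-in for importing tensorflow; nothing is imported until this runs
    importlib.import_module("json")

    class DataSequence:
        def __init__(self, items, batch_size):
            self.items = list(items)
            self.batch_size = batch_size

        def __len__(self):
            # Number of batches, rounding up for a final partial batch
            return -(-len(self.items) // self.batch_size)

        def __getitem__(self, index):
            start = index * self.batch_size
            return self.items[start:start + self.batch_size]

    return DataSequence


random.seed(42)  # seeds can be set before the deferred import happens
DataSequence = get_data_sequence_class()
sequence = DataSequence(range(10), batch_size=4)
print(len(sequence), sequence[2])  # 3 [8, 9]
```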

data_collector

dcase_util.tfkeras.data_collector

data_collector([item_list, ...])

Utils

dcase_util.tfkeras.utils.*

setup_keras

Decorator class to allow only one execution
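A decorator class that allows only one execution can be sketched like this. Once and setup are hypothetical names; the sketch simply caches the first result and ignores the arguments of later calls, which is one plausible reading of the description above, not necessarily how setup_keras behaves.

```python
import functools


class Once:
    """Hypothetical decorator class: the wrapped function runs only on the first call."""

    def __init__(self, func):
        functools.update_wrapper(self, func)
        self.func = func
        self.has_run = False
        self.result = None

    def __call__(self, *args, **kwargs):
        # Later calls return the first result without re-running the function
        if not self.has_run:
            self.result = self.func(*args, **kwargs)
            self.has_run = True
        return self.result


calls = []


@Once
def setup(seed=None):
    calls.append(seed)
    return f"configured with seed={seed}"


print(setup(seed=1))  # configured with seed=1
print(setup(seed=2))  # configured with seed=1
print(len(calls))     # 1
```

This guards one-time global configuration (such as backend options or random seeds) against being applied again when several modules call the setup function.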

create_optimizer(class_name[, config])

Create a tf.keras optimizer.
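A name-based optimizer factory can be sketched as below. The toy SGD and Adam dataclasses, the OPTIMIZERS registry and create_optimizer_sketch are hypothetical stand-ins, not tf.keras classes. In TensorFlow itself, tf.keras.optimizers.get also accepts a dict with class_name and config keys, which may be worth checking against your TensorFlow version as an alternative.

```python
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class SGD:
    """Toy optimizer configuration (hypothetical, not tf.keras.optimizers.SGD)."""
    learning_rate: float = 0.01
    momentum: float = 0.0


@dataclass
class Adam:
    """Toy optimizer configuration (hypothetical, not tf.keras.optimizers.Adam)."""
    learning_rate: float = 0.001
    beta_1: float = 0.9


# Hypothetical registry mapping class names to optimizer classes
OPTIMIZERS = {"SGD": SGD, "Adam": Adam}


def create_optimizer_sketch(class_name: str, config: Optional[Dict] = None):
    """Instantiate an optimizer by name, passing config entries as keyword arguments."""
    try:
        optimizer_class = OPTIMIZERS[class_name]
    except KeyError:
        raise ValueError(f"Unknown optimizer: {class_name}") from None
    return optimizer_class(**(config or {}))


optimizer = create_optimizer_sketch("Adam", {"learning_rate": 0.0005})
print(optimizer)  # Adam(learning_rate=0.0005, beta_1=0.9)
```

Omitting config falls back to the class defaults, so create_optimizer_sketch("SGD") returns SGD(learning_rate=0.01, momentum=0.0).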