tf.keras utilities
Utilities for using the Keras deep learning framework bundled with TensorFlow (>= 2.0).
Model
dcase_util.tfkeras.model.*
Create a sequential Keras model.
Model summary in a formatted string, similar to the Keras model summary function.
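The idea of returning a model summary as a string instead of printing it can be sketched without TensorFlow. The helper `format_model_summary` below is hypothetical and not part of dcase_util. It takes a plain list of `(name, output_shape, param_count)` tuples, which are assumed to have been read from a model beforehand. With tf.keras itself, you can capture a similar string by passing a `print_fn` callable to `Model.summary()`.

```python
from typing import List, Tuple


def format_model_summary(layers: List[Tuple[str, tuple, int]]) -> str:
    """Return a Keras-like summary table as a single string (hypothetical helper)."""
    header = f"{'Layer':<20}{'Output shape':<20}{'Params':>10}"
    rule = "-" * len(header)
    lines = [header, rule]
    for name, shape, params in layers:
        # Thousands separator in the parameter column, as Keras does
        lines.append(f"{name:<20}{str(shape):<20}{params:>10,}")
    lines.append(rule)
    total = sum(params for _, _, params in layers)
    lines.append(f"Total params: {total:,}")
    return "\n".join(lines)


# Dense layer parameters = inputs * units (weights) + units (biases)
layers = [
    ("dense_1", (None, 50), 40 * 50 + 50),
    ("dense_2", (None, 10), 50 * 10 + 10),
]
print(format_model_summary(layers))
```

Because the function returns a string, the caller can send it to a log file or a logging handler instead of stdout.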
Callbacks
Usage example showing how to use external metrics with the callback classes provided by dcase_util:
import dcase_util

# keras_model (a compiled tf.keras model) and the arrays training_X, training_Y,
# validation_X and validation_Y are assumed to be defined beforehand.

epochs = 100
batch_size = 256
loss = 'categorical_crossentropy'
metrics = ['categorical_accuracy']
processing_interval = 1
manual_update = True
external_metric_labels = {'ER': 'Error rate'}

callback_list = [
    dcase_util.tfkeras.ProgressLoggerCallback(
        epochs=epochs,
        metric=metrics,
        loss=loss,
        manual_update=manual_update,
        manual_update_interval=processing_interval,
        external_metric_labels=external_metric_labels
    ),
    dcase_util.tfkeras.ProgressPlotterCallback(
        epochs=epochs,
        metric=metrics,
        save=False,
        manual_update=manual_update,
        manual_update_interval=processing_interval,
        external_metric_labels=external_metric_labels
    ),
    dcase_util.tfkeras.StopperCallback(
        epochs=epochs,
        monitor=metrics[0],
        manual_update=manual_update,
    ),
    dcase_util.tfkeras.StasherCallback(
        epochs=epochs,
        monitor=metrics[0],
        manual_update=manual_update,
    )
]

for epoch_start in range(0, epochs, processing_interval):
    epoch_end = epoch_start + processing_interval

    # Make sure we train only the specified number of epochs
    if epoch_end > epochs:
        epoch_end = epochs

    # Train model for the current interval
    keras_model.fit(
        x=training_X,
        y=training_Y,
        validation_data=(validation_X, validation_Y),
        callbacks=callback_list,
        verbose=0,
        initial_epoch=epoch_start,
        epochs=epoch_end,
        batch_size=batch_size,
        shuffle=True
    )

    # Calculate external metrics (placeholder value; compute e.g. the
    # error rate of the model outputs on the validation data here)
    ER = 0.0

    # Inject external metric values into the callbacks
    for callback in callback_list:
        if hasattr(callback, 'set_external_metric_value'):
            callback.set_external_metric_value(
                metric_label='ER',
                metric_value=ER
            )

    # Manually update callbacks
    for callback in callback_list:
        if hasattr(callback, 'update'):
            callback.update()

    # Check whether we need to stop training
    stop_training = False
    for callback in callback_list:
        if hasattr(callback, 'stop'):
            if callback.stop():
                stop_training = True

    if stop_training:
        # Stop the training loop
        break
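The loop above relies only on duck typing. It calls `set_external_metric_value`, `update` and `stop` on whichever callbacks define them. The sketch below reproduces that control flow without TensorFlow, so you can see how the training intervals and the external metric interact. `ThresholdCallback` and `run` are hypothetical. They use a made-up error-rate sequence in place of real training, and a simple "stop once ER reaches a target" rule that is not the rule the dcase_util callbacks use.

```python
class ThresholdCallback:
    """Hypothetical callback: records an external metric, stops once it reaches a target."""

    def __init__(self, target):
        self.target = target
        self.latest = None
        self.history = []

    def set_external_metric_value(self, metric_label, metric_value):
        if metric_label == "ER":
            self.latest = metric_value

    def update(self):
        self.history.append(self.latest)

    def stop(self):
        return self.latest is not None and self.latest <= self.target


def run(error_rates, epochs, processing_interval, target):
    """Mimic the chunked loop; return the (initial_epoch, epochs) pairs passed to fit()."""
    callback_list = [ThresholdCallback(target)]
    fit_calls = []
    for step, epoch_start in enumerate(range(0, epochs, processing_interval)):
        # Clamp the last interval so no more than `epochs` epochs are trained
        epoch_end = min(epoch_start + processing_interval, epochs)
        fit_calls.append((epoch_start, epoch_end))  # keras_model.fit(...) would run here
        ER = error_rates[step]  # stands in for a metric computed outside Keras
        for callback in callback_list:
            if hasattr(callback, "set_external_metric_value"):
                callback.set_external_metric_value(metric_label="ER", metric_value=ER)
        for callback in callback_list:
            if hasattr(callback, "update"):
                callback.update()
        if any(cb.stop() for cb in callback_list if hasattr(cb, "stop")):
            break
    return fit_calls, callback_list[0].history


print(run([0.5, 0.4, 0.3, 0.2], epochs=10, processing_interval=3, target=0.3))
# prints ([(0, 3), (3, 6), (6, 9)], [0.5, 0.4, 0.3])
```

The last interval is clamped in the same way as in the usage example. Training stops after the interval in which the external metric first satisfies the stop condition.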
ProgressLoggerCallback
dcase_util.tfkeras.ProgressLoggerCallback
Keras callback to store metrics and show them with a tqdm progress bar or through the logging interface. Implements the Keras Callback API.
This callback is very similar to the standard Keras ProgbarLogger callback, but it adds support for the
logging interface and for external metrics (metrics calculated outside the Keras training process).
Keras callback to show metrics in the logging interface.
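Here is a minimal sketch of logging per-epoch Keras metrics together with an external metric through Python's standard `logging` module. The function `format_epoch_line` and the line format are hypothetical, and the real callback's output may look different. The `labels` argument plays the same role as `external_metric_labels` in the usage example, mapping a metric key such as `'ER'` to a readable name.

```python
import logging


def format_epoch_line(epoch, epochs, metrics, external_metrics, labels):
    """Build one progress line; external metrics are shown under their readable labels."""
    parts = [f"[{epoch:>3}/{epochs}]"]
    parts += [f"{name}: {value:.4f}" for name, value in metrics.items()]
    # Fall back to the raw key when no label is given
    parts += [f"{labels.get(name, name)}: {value:.4f}" for name, value in external_metrics.items()]
    return "  ".join(parts)


logging.basicConfig(level=logging.INFO, format="%(message)s")
logger = logging.getLogger(__name__)

line = format_epoch_line(
    epoch=5,
    epochs=100,
    metrics={"loss": 0.41234, "categorical_accuracy": 0.875},
    external_metrics={"ER": 0.25},
    labels={"ER": "Error rate"},
)
logger.info(line)
```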
ProgressPlotterCallback
dcase_util.tfkeras.ProgressPlotterCallback
Keras callback to plot progress during the training process and save the final progress into a figure. Implements the Keras Callback API.
StopperCallback
dcase_util.tfkeras.StopperCallback
Keras callback to stop training when no improvement has been seen within a specified number of epochs. Implements the Keras Callback API.
This callback is very similar to the standard Keras EarlyStopping callback, but it adds support for
external metrics (metrics calculated outside the Keras training process).
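The patience rule described above can be sketched in plain Python. `PatienceStopper` is hypothetical and is not the dcase_util class. It assumes EarlyStopping-like semantics: a `mode` of `'min'` or `'max'`, an optional `min_delta`, and a `patience` counted in epochs without improvement. Whether StopperCallback uses these exact parameter names is not stated here.

```python
class PatienceStopper:
    """Hypothetical patience-based stopper for any metric, internal or external."""

    def __init__(self, patience, mode="max", min_delta=0.0):
        self.patience = patience
        self.mode = mode
        self.min_delta = min_delta
        self.best = None
        self.wait = 0

    def _improved(self, value):
        if self.best is None:
            return True
        if self.mode == "max":
            return value > self.best + self.min_delta
        return value < self.best - self.min_delta

    def update(self, value):
        """Register the metric value for one epoch; return True if training should stop."""
        if self._improved(value):
            self.best = value
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience


stopper = PatienceStopper(patience=2, mode="min")  # e.g. error rate: lower is better
for epoch, er in enumerate([0.50, 0.40, 0.42, 0.41, 0.39]):
    if stopper.update(er):
        print(f"Stopping after epoch {epoch}, best ER {stopper.best}")
        break
```

Because `update` takes a plain value, the same rule works whether the value comes from Keras logs or from an externally computed metric.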
StasherCallback
dcase_util.tfkeras.StasherCallback
Keras callback to monitor the training process and store the best model. Implements the Keras Callback API.
This callback is very similar to the standard Keras ModelCheckpoint callback, but it adds support for
external metrics (metrics calculated outside the Keras training process).
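The core bookkeeping of "store the best model" can be sketched with a hypothetical `BestStash` class. Whenever the monitored value improves, it deep-copies whatever model state it is given. For a tf.keras model that state would typically be the list returned by `Model.get_weights()`, which can later be restored with `Model.set_weights()`. This is an assumption about how such a stash might work, not a description of StasherCallback internals.

```python
import copy


class BestStash:
    """Hypothetical stash that keeps a copy of the state with the best monitored value."""

    def __init__(self, mode="max"):
        self.mode = mode
        self.best_value = None
        self.best_epoch = None
        self.best_state = None

    def update(self, epoch, value, state):
        better = (
            self.best_value is None
            or (self.mode == "max" and value > self.best_value)
            or (self.mode == "min" and value < self.best_value)
        )
        if better:
            self.best_value = value
            self.best_epoch = epoch
            # Copy so that later in-place training updates do not alter the stash
            self.best_state = copy.deepcopy(state)
        return better


weights = [0.0, 0.0]  # stand-in for model weights, mutated in place by "training"
stash = BestStash(mode="max")
for epoch, acc in enumerate([0.60, 0.72, 0.70]):
    weights[0] += 1.0  # pretend training changed the weights
    stash.update(epoch, acc, weights)
print(stash.best_epoch, stash.best_value, stash.best_state)
```

The deep copy matters. Without it, the stash would hold a reference to the live weights and silently track the latest epoch instead of the best one.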
BaseCallback
dcase_util.tfkeras.BaseCallback
Base class for callbacks.
Data processing
KerasDataSequence
dcase_util.tfkeras.get_keras_data_sequence
The KerasDataSequence class should be accessed through the getter method, to avoid importing Keras when dcase_util itself is imported. This mechanism lets the user decide when Keras is imported, and set random seeds before that happens.
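The getter pattern can be sketched with the standard library alone. The class is defined inside a function, so its dependency is imported only on the first call, after the caller has had a chance to seed random number generators. `get_data_sequence_class` and `DataSequence` are hypothetical names, and `json` stands in for the deferred dependency (Keras in dcase_util's case). The class mirrors the `__len__`/`__getitem__` interface of `tf.keras.utils.Sequence`.

```python
import random

_cached_class = None


def get_data_sequence_class():
    """Hypothetical getter: import the dependency and build the class on first use only."""
    global _cached_class
    if _cached_class is None:
        import json  # stand-in for a heavy import such as tensorflow.keras

        class DataSequence:
            def __init__(self, items, batch_size):
                self.items = items
                self.batch_size = batch_size

            def __len__(self):
                # Number of batches, counting a final partial batch
                return (len(self.items) + self.batch_size - 1) // self.batch_size

            def __getitem__(self, index):
                start = index * self.batch_size
                return self.items[start:start + self.batch_size]

            def describe(self):
                return json.dumps({"batches": len(self)})

        _cached_class = DataSequence
    return _cached_class


random.seed(42)  # seeds can be set before the deferred import happens
DataSequence = get_data_sequence_class()
seq = DataSequence(list(range(10)), batch_size=4)
print(len(seq), seq[2], seq.describe())
```

Caching the class means that repeated getter calls return the same type, so `isinstance` checks keep working across calls.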
data_collector
dcase_util.tfkeras.data_collector
Utils
dcase_util.tfkeras.utils.*
Decorator class to allow only one execution.
Create a tf.keras optimizer.
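A decorator class that allows only one execution can be sketched as follows. `RunOnce` and `setup_backend` are hypothetical names. The sketch assumes that later calls should return the first call's result without re-running the function. The source does not specify this behavior, and an implementation could just as well return `None` or raise an error on repeated calls.

```python
import functools


class RunOnce:
    """Hypothetical decorator class: the wrapped function body runs at most once."""

    def __init__(self, func):
        functools.update_wrapper(self, func)  # keep __name__, __doc__ of the wrapped function
        self.func = func
        self.has_run = False
        self.result = None

    def __call__(self, *args, **kwargs):
        if not self.has_run:
            self.result = self.func(*args, **kwargs)
            self.has_run = True
        return self.result


calls = []


@RunOnce
def setup_backend(name):
    calls.append(name)
    return f"configured {name}"


print(setup_backend("tensorflow"))  # runs the body
print(setup_backend("other"))       # returns the cached first result
print(calls)
```

Keeping the state on the decorator instance means each decorated function has its own "already executed" flag.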