habib's blog

Advanced Keras Techniques

In this post, we’ll explore some advanced Keras techniques that give you more control over training than model.fit() alone and can help improve your model’s performance, flexibility, and efficiency. We’ll cover custom training loops, accuracy metrics, and custom callbacks, with easy-to-understand explanations and code snippets.

Custom Training Loops:

Why Use Custom Training Loops?

By default, Keras provides the model.fit() method, which handles the entire training process for you. However, sometimes you need more control over how your model trains. This is where custom training loops come in. They allow you to:

  1. Implement custom logic (e.g., custom loss functions, multi-task learning).

  2. Debug and inspect intermediate values (e.g., gradients, predictions).

  3. Handle advanced use cases like GANs or reinforcement learning.
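
As a small illustration of point 2, the sketch below uses tf.GradientTape on a single variable to show how a gradient can be inspected directly. The variable x and the function y = x * x are illustrative only and are not part of the model built in this post.

```python
import tensorflow as tf

# Illustrative variable; not part of the model built in this post
x = tf.Variable(3.0)

with tf.GradientTape() as tape:
    y = x * x  # operations on trainable variables are recorded by the tape

# dy/dx = 2x, so the gradient at x = 3.0 is 6.0
grad = tape.gradient(y, x)
print(grad.numpy())
```

The same tape.gradient() call is what the training loops later in this post use, only with the loss and the model's trainable weights in place of y and x.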

Let’s start with a basic neural network and then build a custom training loop around it.

Define the model, loss function, and optimizer:

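import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, Flatten

model = Sequential([
    Flatten(input_shape=(28, 28)),  # Flatten each 28x28 input to a 1D vector
    Dense(128, activation='relu'),
    Dense(10)  # Raw scores (logits), one per class
])
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam()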

Implementing a custom training loop:

epochs = 2
# x_train and y_train are assumed to be loaded already (28x28 images, integer labels 0-9).
# No .repeat() is needed: the epoch loop below iterates over the dataset once per epoch.
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(32)

for epoch in range(epochs):
    print(f'Start of epoch {epoch + 1}')

    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        with tf.GradientTape() as tape:
            # Forward pass: logits holds the model's predictions for this batch
            logits = model(x_batch_train, training=True)
            # Compute the loss for this batch
            loss_value = loss_fn(y_batch_train, logits)

        # Compute the gradients of the loss with respect to the trainable weights
        grads = tape.gradient(loss_value, model.trainable_weights)
        # Apply the gradients to update the weights
        optimizer.apply_gradients(zip(grads, model.trainable_weights))

        # Log the loss every 200 steps
        if step % 200 == 0:
            print(f'Epoch {epoch + 1} Step {step}: Loss = {loss_value.numpy()}')

What's happening here?

  1. Forward Pass: The model makes predictions (logits) for the current batch.
  2. Loss Calculation: The loss is computed between the predictions and true labels.
  3. Gradient Computation: Gradients of the loss with respect to the model’s weights are calculated using tape.gradient().
  4. Weight Update: The optimizer applies the gradients to update the model’s weights.
  5. Logging: The loss is printed every 200 steps to monitor progress.
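
The five steps above can also be packaged into a reusable function. The following is a minimal sketch, assuming a small synthetic dataset in place of real images; the names train_step, demo_model, x_demo, and y_demo are hypothetical and introduced only for illustration. Decorating the step with tf.function asks TensorFlow to compile it into a graph, which is typically faster than running it eagerly.

```python
import numpy as np
import tensorflow as tf

tf.random.set_seed(0)
rng = np.random.default_rng(0)

# Synthetic stand-in data (assumption): 64 random 28x28 "images", 10 classes
x_demo = tf.constant(rng.random((64, 28, 28)).astype("float32"))
y_demo = tf.constant(rng.integers(0, 10, size=64))

demo_model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(32, activation="relu"),
    tf.keras.layers.Dense(10),
])
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)

@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = demo_model(x, training=True)  # 1. forward pass
        loss = loss_fn(y, logits)              # 2. loss calculation
    # 3. gradient computation
    grads = tape.gradient(loss, demo_model.trainable_weights)
    # 4. weight update
    optimizer.apply_gradients(zip(grads, demo_model.trainable_weights))
    return loss

# Run 30 steps on the same batch and keep the loss from each step
losses = [float(train_step(x_demo, y_demo)) for _ in range(30)]
# 5. logging
print(f"first loss: {losses[0]:.4f}, last loss: {losses[-1]:.4f}")
```

Because the same batch is reused for every step, the loss here only shows that the update mechanics work; it says nothing about how well a model would generalize.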

Enhancing the loop with an accuracy metric:

While the above loop works, it only tracks the loss. Let’s enhance it by adding an accuracy metric to monitor model performance.

Define the model:

model = Sequential([ 
    Flatten(input_shape=(28, 28)),  # Flatten the input to a 1D vector
    Dense(128, activation='relu'),  # First hidden layer with 128 neurons and ReLU activation
    Dense(10)  # Output layer with 10 neurons for the 10 classes (digits 0-9)
])

Define loss function, optimizer and metric:

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam() 
accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy()  # Metric to track accuracy during training

Implementation of custom training loop with accuracy:

epochs = 5  # Number of epochs for training

for epoch in range(epochs):
    print(f'Start of epoch {epoch + 1}')
    
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        with tf.GradientTape() as tape:
            # Forward pass: Compute predictions
            logits = model(x_batch_train, training=True)
            # Compute loss
            loss_value = loss_fn(y_batch_train, logits)
         
        # Compute gradients
        grads = tape.gradient(loss_value, model.trainable_weights)
        # Apply gradients to update model weights
        optimizer.apply_gradients(zip(grads, model.trainable_weights))
        
        # Update the accuracy metric
        accuracy_metric.update_state(y_batch_train, logits)

        # Log the loss and accuracy every 200 steps
        if step % 200 == 0:
            print(f'Epoch {epoch + 1} Step {step}: Loss = {loss_value.numpy()} Accuracy = {accuracy_metric.result().numpy()}')
    
    # Reset the metric at the end of each epoch
    accuracy_metric.reset_state()

What's new here?

  1. Accuracy Tracking: accuracy_metric.update_state() adds each batch's results to a running total, so accuracy_metric.result() returns the accuracy over all batches seen since the last reset.
  2. Logging: The latest batch loss and the running accuracy are printed every 200 steps.
  3. Reset Metric: The accuracy metric is reset at the end of each epoch to avoid mixing results across epochs.
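
To make the metric's behavior concrete, here is a small sketch with hand-written logits (illustrative values, not model output). It shows that result() reports the running accuracy over every batch seen since the last reset, not the accuracy of the latest batch alone.

```python
import tensorflow as tf

metric = tf.keras.metrics.SparseCategoricalAccuracy()

# Batch 1: predicted classes are [1, 0], labels are [1, 1] -> 1 of 2 correct
metric.update_state(tf.constant([1, 1]),
                    tf.constant([[0.1, 0.9], [0.8, 0.2]]))
after_first = float(metric.result())

# Batch 2: predicted class is [1], label is [1] -> running total 2 of 3 correct
metric.update_state(tf.constant([1]), tf.constant([[0.2, 0.8]]))
after_second = float(metric.result())

# Resetting clears the accumulated counts
metric.reset_state()
after_reset = float(metric.result())

print(after_first, after_second, after_reset)
```

This is why the training loop resets the metric once per epoch: without the reset, the reported accuracy would keep averaging in batches from earlier epochs.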

Custom Callbacks:

Create a custom callback to log additional metrics at the end of each epoch.

from tensorflow.keras.callbacks import Callback

# Implement the custom callback
class CustomCallback(Callback):
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        print(f'End of epoch {epoch + 1}, loss: {logs.get("loss")}, accuracy: {logs.get("accuracy")}')

Implementation of custom training loop with custom callback:

# Custom training loop that invokes the custom callback

epochs = 2
custom_callback = CustomCallback()  # Initialize the custom callback

for epoch in range(epochs):
    print(f'Start of epoch {epoch + 1}')
    
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        with tf.GradientTape() as tape:
            # Forward pass: Compute predictions
            logits = model(x_batch_train, training=True)
            # Compute loss
            loss_value = loss_fn(y_batch_train, logits)
        
        # Compute gradients
        grads = tape.gradient(loss_value, model.trainable_weights)
        # Apply gradients to update model weights
        optimizer.apply_gradients(zip(grads, model.trainable_weights))
        
        # Update the accuracy metric
        accuracy_metric.update_state(y_batch_train, logits)

        # Log the loss and accuracy every 200 steps
        if step % 200 == 0:
            print(f'Epoch {epoch + 1} Step {step}: Loss = {loss_value.numpy()} Accuracy = {accuracy_metric.result().numpy()}')
    
    # Call the custom callback at the end of each epoch.
    # Note: loss_value is the loss of the final batch, not an average over the epoch.
    custom_callback.on_epoch_end(epoch, logs={'loss': loss_value.numpy(), 'accuracy': accuracy_metric.result().numpy()})
    
    # Reset the metric at the end of each epoch
    accuracy_metric.reset_state()  # reset_states() is the older, deprecated name

What's new here?

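  1. Custom Callback: The CustomCallback class logs the final batch's loss and the epoch's running accuracy at the end of each epoch.

  2. Integration: Keras calls callbacks automatically only inside model.fit(). In a custom loop, the callback is invoked explicitly at the end of each epoch using custom_callback.on_epoch_end().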
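
One caveat in the loop above: the loss passed to the callback is the loss of the final batch only. A minimal sketch of one way to report an epoch-level average instead uses tf.keras.metrics.Mean. The callback name EpochLogger, its history attribute, and the hard-coded per-batch loss values are hypothetical stand-ins for a real training run.

```python
import tensorflow as tf

class EpochLogger(tf.keras.callbacks.Callback):
    """Hypothetical callback that prints and stores the logs of each epoch."""

    def __init__(self):
        super().__init__()
        self.history = []

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        self.history.append(dict(logs))
        print(f"End of epoch {epoch + 1}, mean loss: {logs.get('loss')}")

epoch_loss = tf.keras.metrics.Mean()
logger = EpochLogger()

# Stand-in per-batch losses (assumption); a real loop would use loss_value
fake_batch_losses = [[1.0, 0.5], [0.4, 0.2]]

for epoch, batch_losses in enumerate(fake_batch_losses):
    for loss_value in batch_losses:
        epoch_loss.update_state(loss_value)  # accumulate the batch loss
    # Report the average over the epoch, then start fresh for the next one
    logger.on_epoch_end(epoch, logs={"loss": float(epoch_loss.result())})
    epoch_loss.reset_state()
```

In a real loop, epoch_loss.update_state(loss_value) would sit next to accuracy_metric.update_state(...), and both metrics would be reset together at the end of each epoch.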