Introduction to Transfer Learning¶
In this lab, you will explore how to utilize a pre-trained model to achieve high-performance results even when working with a relatively small dataset. This technique, known as transfer learning, allows you to leverage the knowledge embedded in the trained layers of an existing model and adapt it to your specific application. Rather than training a deep neural network from scratch—a process that requires extensive computational resources and a large amount of data—you can fine-tune a pre-trained model to your needs.
The fundamental approach to transfer learning involves:
- Extracting the convolutional layers of a pre-trained model.
- Appending fully connected (dense) layers tailored to your specific task.
- Training only the newly added dense layers while keeping the convolutional layers frozen.
- Evaluating the results and fine-tuning further if needed.
By adopting transfer learning, you significantly reduce training time and improve model performance, as the pre-trained model has already learned useful feature representations from a large dataset. Instead of learning everything from scratch, you simply reuse and adapt the learned features for your dataset.
import os
import tensorflow as tf
import matplotlib.pyplot as plt
import logging
tf.get_logger().setLevel(logging.ERROR)
Preparing the Pretrained Model: InceptionV3¶
For this exercise, we will use InceptionV3, a powerful convolutional neural network architecture originally trained on ImageNet, as the base model. InceptionV3 is particularly effective at extracting features from images due to its deep hierarchical structure and efficient use of computational resources.
Steps to Configure the Model:
- Define the input shape: The model must accept an input shape suitable for your application. For this lab, we use 150x150x3 as the input dimensions.
- Freeze the convolutional base: Since the convolutional layers of InceptionV3 have already learned valuable features, we freeze them to retain their pre-trained weights.
- Attach a new classifier: We append a set of dense layers on top of the frozen base to tailor the model to our specific task.
local_weights_file = 'inception_v3_model/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5'
# Set the input shape and remove the dense layers
pretrained_model = tf.keras.applications.inception_v3.InceptionV3(  # Loads the InceptionV3 model architecture
    input_shape=(150, 150, 3),  # Sets the input size for images as 150x150 pixels with 3 color channels (RGB)
    include_top=False,  # Removes the fully connected (dense) layers from the original model
    weights=None  # The model is not loaded with pre-trained weights. If we wanted pre-trained ImageNet weights, we would use weights='imagenet'
)
# Load the weights from a previously downloaded trained model,
# so the convolutional base starts from pre-learned ImageNet features
pretrained_model.load_weights(local_weights_file)
# Loops through all layers in the pretrained model
for layer in pretrained_model.layers:
    layer.trainable = False  # Freezes the layer so it is not updated during training. This retains the knowledge already present in the pre-trained model
pretrained_model.summary()
Model: "inception_v3"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape          Param #    Connected to
==================================================================================================
 input_1 (InputLayer)           [(None, 150, 150, 3)] 0          []
 conv2d (Conv2D)                (None, 74, 74, 32)    864        ['input_1[0][0]']
 ...
 (output truncated: the full InceptionV3 graph of Conv2D, BatchNormalization,
  Activation, pooling, and Concatenate layers, through mixed0 ... mixed10)
 ...
 mixed10 (Concatenate)          (None, 3, 3, 2048)    0          ['activation_85[0][0]',
                                                                  'mixed9_1[0][0]',
                                                                  'concatenate_1[0][0]',
                                                                  'activation_93[0][0]']
==================================================================================================
Total params: 21,802,784
Trainable params: 0
Non-trainable params: 21,802,784
__________________________________________________________________________________________________
Removing the Fully Connected Layer¶
To adapt InceptionV3 to our problem, we remove the fully connected (dense) output layer of the original model, which was designed for classifying 1,000 ImageNet categories. Instead, we replace it with our own set of layers that are trained on our dataset.
The base model retains its feature extraction capabilities, and we customize only the classification head. We also select an intermediate layer (e.g., mixed7) as the cutoff point. This ensures that we retain generalized feature representations while discarding layers that might be too specialized for ImageNet categories.
last_layer = pretrained_model.get_layer('mixed7')
last_output = last_layer.output
print('Output shape of last layer : ', last_layer.output.shape)
Output shape of last layer : (None, 7, 7, 768)
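As a quick sanity check (a back-of-envelope calculation, not part of the lab code), the `(None, 7, 7, 768)` output of `mixed7` flattens to a 37,632-element vector, which is why the 1024-unit dense layer added in the next section carries tens of millions of parameters:

```python
# Back-of-envelope check, assuming the mixed7 output shape printed above
h, w, c = 7, 7, 768                      # spatial dims and channels of mixed7
flat_size = h * w * c                    # length of the vector after Flatten()
dense_params = flat_size * 1024 + 1024   # weights + biases of a 1024-unit dense layer
print(flat_size)     # 37632
print(dense_params)  # 38536192
```

Choosing a cutoff with a smaller spatial footprint (or adding pooling before the dense head) would shrink this parameter count substantially.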
Adding Custom Dense Layers¶
Once the convolutional layers are in place, we introduce additional fully connected layers to classify the images. These layers will be responsible for learning new patterns specific to our dataset, which consists of images of cats and dogs.
Key components of the classifier:
- Flatten layer: Converts the 3D feature maps into a 1D vector.
- Dense layers: Fully connected layers to learn higher-level features.
- Dropout layer: Prevents overfitting by randomly deactivating a portion of neurons during training.
- Output layer: Uses a sigmoid activation function for binary classification (cats vs. dogs).
# last_output represents the output of the last convolutional layer
x = tf.keras.layers.Flatten()(last_output) # Flatten() converts the feature maps from the convolutional layers into a 1D vector
x = tf.keras.layers.Dense(1024, activation='relu')(x) # Adds a fully connected (dense) layer with 1024 neurons
x = tf.keras.layers.Dropout(0.2)(x) # Adds a Dropout layer to randomly deactivate 20% of neurons during training for reducing overfitting
x = tf.keras.layers.Dense(1, activation='sigmoid')(x) # 1 neuron: Because it's a binary classification
# Creates a new model by specifying:
# pretrained_model.input: Uses InceptionV3’s input layer
# x: The final custom dense layers
# This model will use InceptionV3 for feature extraction and train only the dense layers
model = tf.keras.Model(pretrained_model.input, x)
model.summary()
Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape           Param #    Connected to
==================================================================================================
 input_1 (InputLayer)           [(None, 150, 150, 3)]  0          []

 ... (frozen InceptionV3 convolutional and inception blocks, conv2d through mixed7, omitted
 here for brevity) ...

 mixed7 (Concatenate)           (None, 7, 7, 768)      0          ['activation_60[0][0]',
                                                                   'activation_63[0][0]',
                                                                   'activation_68[0][0]',
                                                                   'activation_69[0][0]']

 flatten (Flatten)              (None, 37632)          0          ['mixed7[0][0]']

 dense (Dense)                  (None, 1024)           38536192   ['flatten[0][0]']

 dropout (Dropout)              (None, 1024)           0          ['dense[0][0]']

 dense_1 (Dense)                (None, 1)              1025       ['dropout[0][0]']

==================================================================================================
Total params: 47,512,481
Trainable params: 38,537,217
Non-trainable params: 8,975,264
__________________________________________________________________________________________________
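Note that only the dense head (38,537,217 parameters) is trainable; the frozen InceptionV3 base accounts for the 8,975,264 non-trainable parameters. For reference, a model matching this summary can be assembled roughly as follows. This is a sketch rather than the exact cell used earlier in the lab: cutting the base at `mixed7`, the 1024-unit dense layer, and the sigmoid output are read off the summary above, while the dropout rate of 0.2 is an assumed value.

```python
import tensorflow as tf
from tensorflow.keras.applications.inception_v3 import InceptionV3

def build_transfer_model(weights='imagenet'):
    """Sketch of the transfer-learning model summarized above.
    Pass weights='imagenet' for the pre-trained convolutional base
    (downloads the weights on first use) or weights=None for random init."""
    base = InceptionV3(input_shape=(150, 150, 3),
                       include_top=False,   # drop the ImageNet classifier head
                       weights=weights)
    base.trainable = False                  # freeze the convolutional base

    # Cut the network at the mixed7 block (7x7x768, per the summary)
    last_output = base.get_layer('mixed7').output

    x = tf.keras.layers.Flatten()(last_output)
    x = tf.keras.layers.Dense(1024, activation='relu')(x)
    x = tf.keras.layers.Dropout(0.2)(x)     # rate is an assumed value
    outputs = tf.keras.layers.Dense(1, activation='sigmoid')(x)
    return tf.keras.Model(base.input, outputs)
```

Because the base is frozen before the new model is built, only the `Flatten` → `Dense` → `Dropout` → `Dense` head contributes trainable parameters, matching the totals in the summary.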
Preprocessing Input Data¶
Different pre-trained models require different preprocessing techniques. InceptionV3, for instance, expects input values to be scaled between -1 and 1. The TensorFlow Keras library provides a built-in method to handle this:
from tensorflow.keras.applications.inception_v3 import preprocess_input
You should apply this preprocessing function to your dataset to ensure compatibility with the model’s training regime. This is crucial for maintaining consistency and maximizing performance.
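For InceptionV3, `preprocess_input` simply rescales pixel values from the [0, 255] range to [-1, 1] (dividing by 127.5 and subtracting 1). A quick sanity check:

```python
import numpy as np
from tensorflow.keras.applications.inception_v3 import preprocess_input

# Three pixel values spanning the 0-255 range
pixels = np.array([0.0, 127.5, 255.0], dtype="float32")
scaled = preprocess_input(pixels)
print(scaled)  # values are now in [-1, 1]: -1.0, 0.0, 1.0
```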
Preparing the dataset¶
BASE_DIR = 'cats_and_dogs_filtered'
train_dir = os.path.join(BASE_DIR, 'train')
validation_dir = os.path.join(BASE_DIR, 'validation')
# Directory with training cat/dog pictures
train_cats_dir = os.path.join(train_dir, 'cats')
train_dogs_dir = os.path.join(train_dir, 'dogs')
# Directory with validation cat/dog pictures
validation_cats_dir = os.path.join(validation_dir, 'cats')
validation_dogs_dir = os.path.join(validation_dir, 'dogs')
# Instantiate the training dataset
train_dataset = tf.keras.utils.image_dataset_from_directory(
train_dir,
image_size=(150, 150),
batch_size=20,
label_mode='binary'
)
# Instantiate the validation dataset
validation_dataset = tf.keras.utils.image_dataset_from_directory(
validation_dir,
image_size=(150, 150),
batch_size=20,
label_mode='binary'
)
Found 3000 files belonging to 2 classes.
Found 1000 files belonging to 2 classes.
def preprocess(image, label):
image = tf.keras.applications.inception_v3.preprocess_input(image)
return image, label
train_dataset_scaled = train_dataset.map(preprocess)
validation_dataset_scaled = validation_dataset.map(preprocess)
# Optimize the datasets for training
SHUFFLE_BUFFER_SIZE = 1000
PREFETCH_BUFFER_SIZE = tf.data.AUTOTUNE
train_dataset_final = (train_dataset_scaled
.cache()
.shuffle(SHUFFLE_BUFFER_SIZE)
.prefetch(PREFETCH_BUFFER_SIZE)
)
validation_dataset_final = (validation_dataset_scaled
.cache()
.prefetch(PREFETCH_BUFFER_SIZE)
)
Augmentation model and final model building¶
data_augmentation = tf.keras.models.Sequential([
tf.keras.layers.RandomFlip("horizontal"),
tf.keras.layers.RandomRotation(0.4),
tf.keras.layers.RandomTranslation(0.2, 0.2),
tf.keras.layers.RandomContrast(0.4),
tf.keras.layers.RandomZoom(0.2)
])
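One detail worth knowing: these `Random*` preprocessing layers are only active when called with `training=True`, which is what happens inside `fit()`. At inference time they pass images through unchanged, so validation and prediction are never augmented. A quick check, using `RandomFlip` as the example:

```python
import tensorflow as tf

aug = tf.keras.layers.RandomFlip("horizontal")

# A dummy batch containing one 2x2 RGB image
img = tf.reshape(tf.range(12, dtype=tf.float32), (1, 2, 2, 3))

# With training=False the layer is an identity mapping
out = aug(img, training=False)
print(bool(tf.reduce_all(tf.equal(out, img))))  # -> True
```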
inputs = tf.keras.Input(shape=(150,150,3)) # Defines a new input layer for the augmented model
x = data_augmentation(inputs) # Applies data augmentation to inputs
x = model(x) # Passes the augmented images through the previously defined model
# Creates a new model (model_with_aug) that includes:
# Data augmentation as the first step
# Pretrained InceptionV3 for feature extraction
# Custom dense layers for classification
model_with_aug = tf.keras.Model(inputs, x)
model_with_aug.compile(
optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.0001),
loss='binary_crossentropy', metrics=['accuracy']
)
Model training and accuracy check¶
history = model_with_aug.fit(train_dataset_final, verbose=2, epochs=20, validation_data=validation_dataset_final)
Epoch 1/20
150/150 - 132s - loss: 0.6522 - accuracy: 0.6693 - val_loss: 0.1692 - val_accuracy: 0.9320 - 132s/epoch - 877ms/step
Epoch 2/20
150/150 - 66s - loss: 0.5357 - accuracy: 0.7437 - val_loss: 0.1431 - val_accuracy: 0.9430 - 66s/epoch - 443ms/step
Epoch 3/20
150/150 - 68s - loss: 0.5302 - accuracy: 0.7407 - val_loss: 0.2039 - val_accuracy: 0.9280 - 68s/epoch - 452ms/step
Epoch 4/20
150/150 - 67s - loss: 0.4891 - accuracy: 0.7653 - val_loss: 0.1636 - val_accuracy: 0.9490 - 67s/epoch - 450ms/step
Epoch 5/20
150/150 - 70s - loss: 0.5020 - accuracy: 0.7623 - val_loss: 0.1336 - val_accuracy: 0.9530 - 70s/epoch - 470ms/step
Epoch 6/20
150/150 - 67s - loss: 0.4692 - accuracy: 0.7743 - val_loss: 0.1659 - val_accuracy: 0.9530 - 67s/epoch - 448ms/step
Epoch 7/20
150/150 - 67s - loss: 0.4642 - accuracy: 0.7893 - val_loss: 0.1984 - val_accuracy: 0.9500 - 67s/epoch - 450ms/step
Epoch 8/20
150/150 - 69s - loss: 0.4725 - accuracy: 0.7887 - val_loss: 0.1421 - val_accuracy: 0.9660 - 69s/epoch - 460ms/step
Epoch 9/20
150/150 - 68s - loss: 0.4557 - accuracy: 0.7830 - val_loss: 0.1761 - val_accuracy: 0.9600 - 68s/epoch - 450ms/step
Epoch 10/20
150/150 - 67s - loss: 0.4736 - accuracy: 0.7830 - val_loss: 0.1713 - val_accuracy: 0.9560 - 67s/epoch - 447ms/step
Epoch 11/20
150/150 - 67s - loss: 0.4498 - accuracy: 0.7893 - val_loss: 0.1866 - val_accuracy: 0.9540 - 67s/epoch - 449ms/step
Epoch 12/20
150/150 - 67s - loss: 0.4634 - accuracy: 0.7913 - val_loss: 0.1624 - val_accuracy: 0.9590 - 67s/epoch - 448ms/step
Epoch 13/20
150/150 - 65s - loss: 0.4456 - accuracy: 0.7960 - val_loss: 0.1975 - val_accuracy: 0.9560 - 65s/epoch - 433ms/step
Epoch 14/20
150/150 - 65s - loss: 0.4559 - accuracy: 0.8010 - val_loss: 0.2033 - val_accuracy: 0.9510 - 65s/epoch - 431ms/step
Epoch 15/20
150/150 - 66s - loss: 0.4378 - accuracy: 0.7963 - val_loss: 0.1605 - val_accuracy: 0.9560 - 66s/epoch - 442ms/step
Epoch 16/20
150/150 - 69s - loss: 0.4404 - accuracy: 0.8057 - val_loss: 0.2752 - val_accuracy: 0.9450 - 69s/epoch - 461ms/step
Epoch 17/20
150/150 - 68s - loss: 0.4303 - accuracy: 0.8010 - val_loss: 0.1953 - val_accuracy: 0.9570 - 68s/epoch - 451ms/step
Epoch 18/20
150/150 - 66s - loss: 0.4403 - accuracy: 0.8017 - val_loss: 0.2010 - val_accuracy: 0.9580 - 66s/epoch - 438ms/step
Epoch 19/20
150/150 - 67s - loss: 0.4236 - accuracy: 0.8053 - val_loss: 0.2239 - val_accuracy: 0.9580 - 67s/epoch - 444ms/step
Epoch 20/20
150/150 - 70s - loss: 0.4431 - accuracy: 0.7950 - val_loss: 0.2148 - val_accuracy: 0.9600 - 70s/epoch - 467ms/step
def plot_loss_acc(history):
'''Plots the training and validation loss and accuracy from a history object'''
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs = range(len(acc))
fig, ax = plt.subplots(1,2, figsize=(12, 6))
ax[0].plot(epochs, acc, 'bo', label='Training accuracy')
ax[0].plot(epochs, val_acc, 'b', label='Validation accuracy')
ax[0].set_title('Training and validation accuracy')
ax[0].set_xlabel('epochs')
ax[0].set_ylabel('accuracy')
ax[0].legend()
ax[1].plot(epochs, loss, 'bo', label='Training Loss')
ax[1].plot(epochs, val_loss, 'b', label='Validation Loss')
ax[1].set_title('Training and validation loss')
ax[1].set_xlabel('epochs')
ax[1].set_ylabel('loss')
ax[1].legend()
plt.show()
# Plot training results
plot_loss_acc(history)
Validation accuracy climbs above 93% within a single epoch and settles around 95–96%, while training accuracy stays lower (around 80%) because augmentation and dropout make the training examples harder; both are disabled during evaluation. This is transfer learning doing its job: strong results on a small dataset at a fraction of the compute needed to train the network from scratch.
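If you want to squeeze out more performance, the usual next step is fine-tuning: unfreeze the top portion of the convolutional base and continue training with a much lower learning rate, so the pre-trained features are adjusted only gently. A minimal sketch follows; the layer count and learning rate are illustrative choices, not values prescribed by this lab.

```python
import tensorflow as tf

def unfreeze_top(base_model, n_layers=30, learning_rate=1e-5):
    """Unfreeze the last n_layers of a frozen base model and return a
    low-learning-rate optimizer suitable for fine-tuning. Large updates
    would destroy the pre-trained features, hence the tiny learning rate.
    (In practice you may also want to keep BatchNormalization layers frozen.)"""
    base_model.trainable = True
    for layer in base_model.layers[:-n_layers]:
        layer.trainable = False            # keep the lower layers frozen
    return tf.keras.optimizers.RMSprop(learning_rate=learning_rate)
```

After applying this to the InceptionV3 base, you would re-compile `model_with_aug` with the returned optimizer and call `fit()` again for a few more epochs.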