In [1]:
import os
import pickle
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
In [9]:
BASE_DIR = 'cats_and_dogs_filtered'
train_dir = os.path.join(BASE_DIR, 'train')
# Directories with training cat and dog pictures
train_cats_dir = os.path.join(train_dir, 'cats')
train_dogs_dir = os.path.join(train_dir, 'dogs')
train_cat_fnames = os.listdir(train_cats_dir)
train_dog_fnames = os.listdir(train_dogs_dir)
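As a quick sanity check, the filename lists built above can be used to confirm how many training images each class contains; a minimal sketch:
In [ ]:
# Count the training images per class using the filename lists defined above
print(f"There are {len(train_cat_fnames)} cat images and {len(train_dog_fnames)} dog images for training.")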
In [10]:
# Build the full paths to the cat and dog images
cats_filenames = [os.path.join(train_cats_dir, filename) for filename in os.listdir(train_cats_dir)]
dogs_filenames = [os.path.join(train_dogs_dir, filename) for filename in os.listdir(train_dogs_dir)]
fig, axes = plt.subplots(2, 4, figsize=(14, 7))
fig.suptitle('Cat and Dog Images', fontsize=16)
# Plot the first 4 images of each class
for i, cat_image in enumerate(cats_filenames[:4]):
    img = tf.keras.utils.load_img(cat_image)
    axes[0, i].imshow(img)
    axes[0, i].set_title(f'Example Cat {i}')

for i, dog_image in enumerate(dogs_filenames[:4]):
    img = tf.keras.utils.load_img(dog_image)
    axes[1, i].imshow(img)
    axes[1, i].set_title(f'Example Dog {i}')
plt.show()
In [13]:
def train_val_datasets():
    """Creates datasets for training and validation.

    Returns:
        (tf.data.Dataset, tf.data.Dataset): Training and validation datasets.
    """
    training_dataset, validation_dataset = tf.keras.utils.image_dataset_from_directory(
        directory=BASE_DIR,
        image_size=(150, 150),
        batch_size=128,
        label_mode='binary',
        validation_split=0.10,
        subset='both',
        seed=42
    )

    return training_dataset, validation_dataset
In [14]:
# Create the datasets
training_dataset, validation_dataset = train_val_datasets()
Found 4000 files belonging to 2 classes.
Using 3600 files for training.
Using 400 files for validation.
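To verify which folder maps to which numeric label (with a cats/dogs layout and alphabetical ordering, cats should be 0 and dogs 1), the datasets returned by image_dataset_from_directory expose a class_names attribute; a minimal check might look like this:
In [ ]:
# Inspect the inferred class names and their label order (assumes a cats/dogs folder layout)
print(training_dataset.class_names)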
In [15]:
for images, labels in training_dataset.take(1):
    example_batch_images = images
    example_batch_labels = labels
print(f"Maximum pixel value of images: {np.max(example_batch_images)}\n")
print(f"Shape of batch of images: {example_batch_images.shape}")
print(f"Shape of batch of labels: {example_batch_labels.shape}")
Maximum pixel value of images: 255.0

Shape of batch of images: (128, 150, 150, 3)
Shape of batch of labels: (128, 1)
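It can also help to look at one image from the example batch together with its label; a minimal sketch (pixel values are still in the 0-255 range at this point, so they are cast to uint8 for display):
In [ ]:
# Show the first image of the example batch with its binary label
plt.imshow(example_batch_images[0].numpy().astype('uint8'))
plt.title(f"Label: {example_batch_labels[0].numpy()[0]:.0f}")
plt.axis('off')
plt.show()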
In [32]:
def create_model():
    """Creates the untrained model for classifying cats and dogs.

    Returns:
        tf.keras.Model: The model that will be trained to classify cats and dogs.
    """
    model = tf.keras.models.Sequential([
        tf.keras.Input(shape=(150, 150, 3)),
        tf.keras.layers.Rescaling(1./255),
        tf.keras.layers.Conv2D(32, (3, 3), activation='relu'),
        tf.keras.layers.MaxPooling2D(2, 2),
        tf.keras.layers.Conv2D(32, (3, 3), activation='relu'),
        tf.keras.layers.MaxPooling2D(2, 2),
        tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
        tf.keras.layers.MaxPooling2D(2, 2),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(256, activation='relu'),
        tf.keras.layers.Dense(1, activation='sigmoid')
    ])

    model.compile(
        optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.001),
        loss='binary_crossentropy',
        metrics=['accuracy']
    )

    return model
In [33]:
model = create_model()
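Before training, it is worth inspecting the architecture and parameter counts, which Keras exposes through summary():
In [ ]:
# Print layer output shapes and parameter counts for the freshly created model
model.summary()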
In [34]:
model.evaluate(example_batch_images, example_batch_labels, verbose=False)
Out[34]:
[0.6515199542045593, 0.7578125]
In [35]:
predictions = model.predict(example_batch_images, verbose=False)
print(f"predictions have shape: {predictions.shape}")
predictions have shape: (128, 1)
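Each prediction is the sigmoid output of the final unit, i.e. the estimated probability of the positive class (dogs, assuming alphabetical label ordering); thresholding at 0.5 turns these into hard class labels. A minimal sketch:
In [ ]:
# Convert sigmoid probabilities to hard 0/1 predictions with a 0.5 threshold
predicted_classes = (predictions >= 0.5).astype(int)
print(predicted_classes[:5].flatten())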
In [36]:
class EarlyStoppingCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        # Check if training accuracy is at least 0.95 and validation accuracy is at least 0.80
        if logs['accuracy'] >= 0.95 and logs['val_accuracy'] >= 0.80:
            self.model.stop_training = True
            print("\nReached 95% train accuracy and 80% validation accuracy, so cancelling training!")
In [37]:
# Train the model and save the training history (this may take some time)
history = model.fit(
    training_dataset,
    epochs=20,
    validation_data=validation_dataset,
    callbacks=[EarlyStoppingCallback()]
)
Epoch 1/20
29/29 ━━━━━━━━━━━━━━━━━━━━ 12s 358ms/step - accuracy: 0.7112 - loss: 1.1279 - val_accuracy: 0.7300 - val_loss: 0.5887
Epoch 2/20
29/29 ━━━━━━━━━━━━━━━━━━━━ 10s 351ms/step - accuracy: 0.7514 - loss: 0.5808 - val_accuracy: 0.7300 - val_loss: 0.5896
Epoch 3/20
29/29 ━━━━━━━━━━━━━━━━━━━━ 10s 349ms/step - accuracy: 0.7622 - loss: 0.5644 - val_accuracy: 0.7300 - val_loss: 0.6849
Epoch 4/20
29/29 ━━━━━━━━━━━━━━━━━━━━ 10s 348ms/step - accuracy: 0.7522 - loss: 0.5802 - val_accuracy: 0.7300 - val_loss: 0.6062
Epoch 5/20
29/29 ━━━━━━━━━━━━━━━━━━━━ 10s 349ms/step - accuracy: 0.7588 - loss: 0.5616 - val_accuracy: 0.7300 - val_loss: 0.5814
Epoch 6/20
29/29 ━━━━━━━━━━━━━━━━━━━━ 10s 346ms/step - accuracy: 0.7684 - loss: 0.5511 - val_accuracy: 0.7300 - val_loss: 0.5809
Epoch 7/20
29/29 ━━━━━━━━━━━━━━━━━━━━ 10s 346ms/step - accuracy: 0.7594 - loss: 0.5545 - val_accuracy: 0.7300 - val_loss: 0.5945
Epoch 8/20
29/29 ━━━━━━━━━━━━━━━━━━━━ 10s 345ms/step - accuracy: 0.7568 - loss: 0.5513 - val_accuracy: 0.7300 - val_loss: 0.5829
Epoch 9/20
29/29 ━━━━━━━━━━━━━━━━━━━━ 10s 348ms/step - accuracy: 0.7581 - loss: 0.5495 - val_accuracy: 0.7300 - val_loss: 0.5929
Epoch 10/20
29/29 ━━━━━━━━━━━━━━━━━━━━ 10s 347ms/step - accuracy: 0.7554 - loss: 0.5520 - val_accuracy: 0.7300 - val_loss: 0.5942
Epoch 11/20
29/29 ━━━━━━━━━━━━━━━━━━━━ 10s 345ms/step - accuracy: 0.7609 - loss: 0.5459 - val_accuracy: 0.7300 - val_loss: 0.6042
Epoch 12/20
29/29 ━━━━━━━━━━━━━━━━━━━━ 10s 348ms/step - accuracy: 0.7603 - loss: 0.5263 - val_accuracy: 0.7300 - val_loss: 0.6361
Epoch 13/20
29/29 ━━━━━━━━━━━━━━━━━━━━ 10s 346ms/step - accuracy: 0.7633 - loss: 0.5300 - val_accuracy: 0.7300 - val_loss: 0.6357
Epoch 14/20
29/29 ━━━━━━━━━━━━━━━━━━━━ 10s 348ms/step - accuracy: 0.7516 - loss: 0.5155 - val_accuracy: 0.7300 - val_loss: 0.6882
Epoch 15/20
29/29 ━━━━━━━━━━━━━━━━━━━━ 10s 347ms/step - accuracy: 0.7668 - loss: 0.5020 - val_accuracy: 0.7300 - val_loss: 0.6034
Epoch 16/20
29/29 ━━━━━━━━━━━━━━━━━━━━ 10s 344ms/step - accuracy: 0.7839 - loss: 0.4754 - val_accuracy: 0.7300 - val_loss: 0.7933
Epoch 17/20
29/29 ━━━━━━━━━━━━━━━━━━━━ 10s 347ms/step - accuracy: 0.7831 - loss: 0.4708 - val_accuracy: 0.7275 - val_loss: 0.6275
Epoch 18/20
29/29 ━━━━━━━━━━━━━━━━━━━━ 10s 349ms/step - accuracy: 0.7981 - loss: 0.4390 - val_accuracy: 0.7225 - val_loss: 0.6495
Epoch 19/20
29/29 ━━━━━━━━━━━━━━━━━━━━ 10s 348ms/step - accuracy: 0.8236 - loss: 0.3973 - val_accuracy: 0.7275 - val_loss: 0.8244
Epoch 20/20
29/29 ━━━━━━━━━━━━━━━━━━━━ 10s 354ms/step - accuracy: 0.8338 - loss: 0.3671 - val_accuracy: 0.7075 - val_loss: 0.9238
In [38]:
# Get training and validation accuracies
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
# Get number of epochs
epochs = range(len(acc))
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
fig.suptitle('Training and validation metrics')

for i, (data, label) in enumerate(zip([(acc, val_acc), (loss, val_loss)], ["Accuracy", "Loss"])):
    ax[i].plot(epochs, data[0], 'r', label="Training " + label)
    ax[i].plot(epochs, data[1], 'b', label="Validation " + label)
    ax[i].legend()
    ax[i].set_xlabel('epochs')

plt.show()
In [39]:
with open('history.pkl', 'wb') as f:
    pickle.dump(history.history, f)
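The pickled history can later be reloaded for further analysis without retraining; a minimal sketch:
In [ ]:
# Reload the saved training history from disk
with open('history.pkl', 'rb') as f:
    loaded_history = pickle.load(f)
print(loaded_history.keys())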
In [ ]: