Objective¶
We will be working with the MNIST dataset, a classic and widely recognized dataset in the machine learning community. This dataset comprises 70,000 grayscale images of handwritten digits ranging from 0 to 9. These digits were collected from a variety of sources, including high school students and employees of the U.S. Census Bureau, ensuring a diverse and realistic representation of human handwriting. Each image is annotated with the correct label (i.e., the digit it represents), providing a supervised learning dataset for classification tasks.
Due to its widespread use in educational tutorials, research papers, and benchmarking machine learning models, MNIST is often referred to as the “Hello World” of machine learning. It serves as a foundational stepping stone for anyone learning how to build, train, and evaluate models for image recognition and classification tasks.
Importing Dataset¶
To make things easier for practitioners, Scikit-Learn, one of Python’s most popular machine learning libraries, offers built-in functions to fetch standard datasets, including MNIST. With a simple call, users can download and load the dataset directly into their environment, bypassing the need for manual downloading or preprocessing.
from sklearn.datasets import fetch_openml
import numpy as np
import warnings
warnings.filterwarnings(action='ignore') # To not show warnings on screen
mnist = fetch_openml('mnist_784', version = 1)
#mnist (uncomment to view the data)
X, y = mnist['data'], mnist['target'] # Training features and Training Labels
X.shape, y.shape
((70000, 784), (70000,))
There are 70,000 images, one per row, each described by 784 features. Each image in MNIST is a 28×28 pixel grid, and when flattened into a one-dimensional array for machine learning purposes, this results in 784 features per image (28 multiplied by 28). Each feature corresponds to the intensity value of a single pixel, ranging from 0 (white) to 255 (black). This format enables vectorized operations and makes the data compatible with many algorithms that expect flat feature inputs.
Before diving into model building, it’s helpful to visually inspect some examples from the dataset. Viewing a few images gives insight into the dataset’s complexity and helps us appreciate the variation in handwriting styles.
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
any_digit = X.iloc[2]
print('Shape before reshaping : ', any_digit.shape)
any_digit_image = any_digit.values.reshape(28,28)
print('Shape after reshaping : ', any_digit_image.shape)
Shape before reshaping :  (784,)
Shape after reshaping :  (28, 28)
plt.imshow(any_digit_image, cmap = mpl.cm.binary, interpolation= "nearest")
plt.axis("off")
plt.show()
y[2]
'4'
Note that the label is a string. We prefer numbers, so let’s cast y to integers:
y = y.astype(np.uint8)
Separating test dataset and shuffling¶
The training set provided is already randomly shuffled, which is beneficial. This ensures that when we perform K-fold cross-validation, each fold contains a representative distribution of the digits. Without this shuffling, certain folds might lack specific digits entirely, leading to biased or uninformative evaluations.
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]
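If the data were not already shuffled, you could shuffle the training portion yourself after the split. Here is a minimal sketch, assuming X_train and y_train are pandas objects (hence the .iloc indexing); the _shuffled names are just illustrative:
# Hypothetical re-shuffle (not needed here, since MNIST is already shuffled)
shuffle_index = np.random.permutation(len(X_train))
X_train_shuffled = X_train.iloc[shuffle_index]
y_train_shuffled = y_train.iloc[shuffle_index]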
Binary Classifier¶
To simplify our exploration, we’ll narrow our focus to a binary classification task—specifically, identifying whether a given digit is a ‘5’ or not a ‘5’. This task, referred to as the “5-detector”, exemplifies a binary classifier, which is a model trained to distinguish between exactly two classes: the positive class (digit 5) and the negative class (all other digits).
# This will return true for all 5 and False for others
y_train_5 = (y_train == 5)
y_test_5 = (y_test == 5)
Model Build¶
A suitable starting algorithm for this kind of task is the Stochastic Gradient Descent (SGD) classifier, provided in Scikit-Learn via the SGDClassifier class. SGD is particularly effective for large-scale datasets because it processes training instances one at a time, which makes it memory-efficient and well suited for online learning, where data arrives as a stream or in mini-batches.
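As an aside, the online-learning claim can be illustrated with SGDClassifier’s partial_fit() method, which updates the model incrementally on mini-batches. The sketch below is illustrative only; the batch size and variable names are assumptions, not part of the workflow used in this section:
from sklearn.linear_model import SGDClassifier

online_clf = SGDClassifier(random_state=42)
batch_size = 1000  # illustrative mini-batch size
for start in range(0, len(X_train), batch_size):
    X_batch = X_train.iloc[start:start + batch_size]
    y_batch = y_train_5.iloc[start:start + batch_size]
    # classes must be supplied so the model knows all possible labels up front
    online_clf.partial_fit(X_batch, y_batch, classes=[False, True])
For the 5-detector itself, a single call to fit() on the full training set is all we need: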
from sklearn.linear_model import SGDClassifier
sgd_clf = SGDClassifier(random_state= 42)
sgd_clf.fit(X_train, y_train_5)
SGDClassifier(random_state=42)
sgd_clf.predict([any_digit])
array([False])
any_other_digit = X.iloc[0]
sgd_clf.predict([any_other_digit])
array([ True])
Performance Measure¶
Evaluating a classification model is often more nuanced than evaluating a regression model. While regression performance can be summarized using straightforward metrics like RMSE or MAE, classification offers multiple evaluation metrics, each revealing a different aspect of performance.
Measuring Accuracy Using Cross-Validation¶
To evaluate our SGD classifier’s performance, we can use Scikit-Learn’s cross_val_score() function, which performs K-fold cross-validation—in this case, using 3 folds. This process splits the dataset into three parts, trains on two, and tests on the third, cycling through all combinations.
The scoring metric used here is accuracy.
from sklearn.model_selection import cross_val_score
cross_val_score(sgd_clf, X_train, y_train_5, cv=3, scoring="accuracy")
array([0.95035, 0.96035, 0.9604 ])
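For intuition, here is roughly what cross_val_score() does under the hood. This is a minimal sketch using StratifiedKFold and is not necessarily identical to the library’s internal implementation:
from sklearn.model_selection import StratifiedKFold
from sklearn.base import clone

skfolds = StratifiedKFold(n_splits=3)
for train_index, test_index in skfolds.split(X_train, y_train_5):
    clone_clf = clone(sgd_clf)                  # fresh, untrained copy of the classifier
    X_train_folds = X_train.iloc[train_index]   # two folds used for training
    y_train_folds = y_train_5.iloc[train_index]
    X_test_fold = X_train.iloc[test_index]      # remaining fold held out for evaluation
    y_test_fold = y_train_5.iloc[test_index]
    clone_clf.fit(X_train_folds, y_train_folds)
    y_pred = clone_clf.predict(X_test_fold)
    print(np.mean(y_pred == y_test_fold))       # accuracy on the held-out fold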
Accuracy is generally not the preferred performance measure for classifiers, especially when you are dealing with skewed datasets. Only about 10% of the images are 5s, so if you always guess that an image is not a 5, you will be right about 90% of the time.
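To make this concrete, consider a classifier that always predicts “not a 5”. The class below is a hypothetical sketch (the name Never5Classifier is just illustrative); its cross-validated accuracy should come out around 90% even though it never detects a single 5:
from sklearn.base import BaseEstimator

class Never5Classifier(BaseEstimator):
    def fit(self, X, y=None):
        return self
    def predict(self, X):
        # Always predict the negative class ("not a 5")
        return np.zeros((len(X),), dtype=bool)

never_5_clf = Never5Classifier()
cross_val_score(never_5_clf, X_train, y_train_5, cv=3, scoring="accuracy")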
Confusion Matrix¶
A better alternative is the confusion matrix, which provides a comprehensive snapshot of prediction outcomes. It tells us how many samples were correctly or incorrectly classified, broken down by actual and predicted labels.
To generate such insights, we use cross_val_predict(), which performs K-fold cross-validation like cross_val_score(), but instead of returning accuracy scores, it returns predictions for each instance made during the validation phase. This ensures that each prediction is made on a fold the model hasn’t seen during training, allowing for an unbiased assessment.
from sklearn.model_selection import cross_val_predict
y_train_predict = cross_val_predict(sgd_clf, X_train, y_train_5, cv = 3)
from sklearn.metrics import confusion_matrix
confusion_matrix(y_train_5, y_train_predict)
array([[53892,   687],
       [ 1891,  3530]])
A confusion matrix is a square table. Rows represent actual classes, and columns represent predicted classes. For our binary classifier:
- The first row shows how many non-5 digits were predicted as non-5s (true negatives) and how many were incorrectly predicted as 5s (false positives).
- The second row captures how many actual 5s were missed (false negatives) and how many were correctly identified (true positives).
For example:
- True negatives (TN): 53,892 non-5 images correctly predicted as non-5s.
- False positives (FP): 687 non-5 images wrongly predicted as 5s.
- False negatives (FN): 1,891 5-images missed by the model.
- True positives (TP): 3,530 images correctly predicted as 5s.
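For a binary problem, these four counts can also be unpacked directly from the matrix; a small sketch using NumPy’s ravel() (the printed values should match the numbers listed above):
tn, fp, fn, tp = confusion_matrix(y_train_5, y_train_predict).ravel()
print("TN:", tn, "FP:", fp, "FN:", fn, "TP:", tp)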
An ideal classifier would have all correct predictions, resulting in a matrix with non-zero values only on the diagonal from top-left to bottom-right.
y_train_perfect_predictions = y_train_5 # pretend we reached perfection
confusion_matrix(y_train_5, y_train_perfect_predictions)
array([[54579,     0],
       [    0,  5421]])
Precision and Recall¶
To delve deeper, we introduce precision—the proportion of positive predictions that were actually correct:
$$ \text{Precision} = \frac{TP}{TP + FP} $$
We also consider recall, also known as sensitivity or true positive rate (TPR)—the proportion of actual positives that were successfully identified:
$$ \text{Recall} = \frac{TP}{TP + FN} $$
from sklearn.metrics import precision_score, recall_score
precision_score(y_train_5, y_train_predict)
np.float64(0.8370879772350012)
recall_score(y_train_5, y_train_predict)
np.float64(0.6511713705958311)
From our example:
- Precision ≈ 83.7%
- Recall ≈ 65.1%
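You can verify these figures by plugging the confusion-matrix counts into the formulas directly:
tp, fp, fn = 3530, 687, 1891      # counts from the confusion matrix above
precision = tp / (tp + fp)        # 3530 / 4217 ≈ 0.837
recall = tp / (tp + fn)           # 3530 / 5421 ≈ 0.651
print(precision, recall)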
While these values may seem decent, they highlight the limitations of relying on accuracy alone. A more balanced metric is the F1 Score, which is the harmonic mean of precision and recall:
$$ F1 = 2 \times \frac{\text{precision} \times \text{recall}}{\text{precision} + \text{recall}} = \frac{TP}{TP + \frac{(FN + FP)}{2}} $$
from sklearn.metrics import f1_score
f1_score(y_train_5, y_train_predict)
np.float64(0.7325171197343847)
The F1 score is particularly useful when you need a balance between precision and recall, but it also reflects a tradeoff: increasing one typically leads to a decrease in the other.
To control this tradeoff, you can use the model’s decision function, which returns a confidence score (or raw margin) for each prediction rather than a binary output. You can then set a custom threshold to decide what confidence level qualifies as a positive prediction.
y_scores = sgd_clf.decision_function([any_other_digit])
y_scores
array([2164.22030239])
threshold = 8000
y_some_digit_pred = (y_scores > threshold)
y_some_digit_pred
array([False])
By default, SGDClassifier uses a threshold of 0. But suppose you increase it to 8,000—this makes the classifier more conservative in predicting 5s, which reduces false positives (increasing precision) but also increases false negatives (reducing recall).
To systematically explore the impact of different thresholds, we use cross_val_predict() again, this time requesting decision scores. With these scores, we can compute precision and recall at every possible threshold using the precision_recall_curve() function.
y_scores = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3,
method="decision_function")
y_scores
array([ 1200.93051237, -26883.79202424, -33072.03475406, ..., 13272.12718981, -7258.47203373, -16877.50840447])
from sklearn.metrics import precision_recall_curve
precisions, recalls, thresholds = precision_recall_curve(y_train_5, y_scores)
def plot_precision_recall_vs_threshold(precisions, recalls, thresholds):
plt.plot(thresholds, precisions[:-1], "b--", label="Precision", linewidth=2)
plt.plot(thresholds, recalls[:-1], "g-", label="Recall", linewidth=2)
plt.legend(loc="center right", fontsize=16)
plt.xlabel("Threshold", fontsize=16)
plt.grid(True)
plt.axis([-50000, 50000, 0, 1])
recall_90_precision = recalls[np.argmax(precisions >= 0.90)]
threshold_90_precision = thresholds[np.argmax(precisions >= 0.90)]
plt.figure(figsize=(8, 4))
plot_precision_recall_vs_threshold(precisions, recalls, thresholds)
plt.plot([threshold_90_precision, threshold_90_precision], [0., 0.9], "r:")
plt.plot([-50000, threshold_90_precision], [0.9, 0.9], "r:")
plt.plot([-50000, threshold_90_precision], [recall_90_precision, recall_90_precision], "r:")
plt.plot([threshold_90_precision], [0.9], "ro")
plt.plot([threshold_90_precision], [recall_90_precision], "ro")
plt.show()
When you plot this curve, you may notice that precision drops sharply after a certain point. This helps you choose a good balance—for instance, you might aim for 90% precision. You then determine the lowest threshold that achieves this goal. However, remember: high precision with very low recall might render your model ineffective, especially if the cost of missing positives is high.
(y_train_predict == (y_scores > 0)).all()
np.True_
def plot_precision_vs_recall(precisions, recalls):
plt.plot(recalls, precisions, "b-", linewidth=2)
plt.xlabel("Recall", fontsize=16)
plt.ylabel("Precision", fontsize=16)
plt.axis([0, 1, 0, 1])
plt.grid(True)
plt.figure(figsize=(8, 6))
plot_precision_vs_recall(precisions, recalls)
plt.plot([recall_90_precision, recall_90_precision], [0., 0.9], "r:")
plt.plot([0.0, recall_90_precision], [0.9, 0.9], "r:")
plt.plot([recall_90_precision], [0.9], "ro")
plt.show()
You can see that precision really starts to fall sharply around 80% recall. You will probably want to select a precision/recall tradeoff just before that drop—for example, at around 60% recall. But of course the choice depends on your project.
So let’s suppose you decide to aim for 90% precision. You look up the first plot and find that you need to use a threshold of about 8,000. To be more precise, you can search for the lowest threshold that gives you at least 90% precision:
threshold_90_precision = thresholds[np.argmax(precisions >= 0.90)]
y_train_pred_90 = (y_scores >= threshold_90_precision)
precision_score(y_train_5, y_train_pred_90)
np.float64(0.9000345901072293)
recall_score(y_train_5, y_train_pred_90)
np.float64(0.4799852425751706)
threshold_precision = thresholds[np.argmax((precisions >= 0.75) & (precisions <= 0.90))]
y_train_pred = (y_scores >= threshold_precision)
precision_score(y_train_5, y_train_pred)
np.float64(0.75)
recall_score(y_train_5, y_train_pred)
np.float64(0.7659103486441616)
A high-precision classifier is not very useful if its recall is too low!
The ROC Curve¶
Another powerful evaluation tool is the Receiver Operating Characteristic (ROC) curve, which plots recall (TPR) against the false positive rate (FPR). Since FPR = 1 - specificity, the ROC curve effectively shows how well the model distinguishes between classes across thresholds.
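As a quick numeric sanity check, here are the TPR, FPR, and specificity implied by the earlier 5-detector confusion matrix (the counts are taken from that matrix, i.e. at that particular decision threshold):
tn, fp, fn, tp = 53892, 687, 1891, 3530   # counts from the earlier confusion matrix
tpr = tp / (tp + fn)           # recall / true positive rate ≈ 0.651
fpr = fp / (fp + tn)           # false positive rate ≈ 0.013
specificity = 1 - fpr          # true negative rate ≈ 0.987
print(tpr, fpr, specificity)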
Models like RandomForestClassifier don’t use a decision function but instead offer predict_proba(), which returns the estimated probability that each instance belongs to each class. These probabilities can be used similarly to scores for plotting an ROC curve.
from sklearn.metrics import roc_curve
fpr, tpr, thresholds = roc_curve(y_train_5, y_scores)
def plot_roc_curve(fpr, tpr, label=None):
plt.plot(fpr, tpr, linewidth=2, label=label)
plt.plot([0, 1], [0, 1], 'k--') # dashed diagonal
plt.axis([0, 1, 0, 1])
plt.xlabel('False Positive Rate (Fall-Out)', fontsize=16)
plt.ylabel('True Positive Rate (Recall)', fontsize=16)
plt.grid(True)
plt.figure(figsize=(8, 6))
plot_roc_curve(fpr, tpr)
fpr_90 = fpr[np.argmax(tpr >= recall_90_precision)]
plt.plot([fpr_90, fpr_90], [0., recall_90_precision], "r:")
plt.plot([0.0, fpr_90], [recall_90_precision, recall_90_precision], "r:")
plt.plot([fpr_90], [recall_90_precision], "ro")
plt.show()
from sklearn.metrics import roc_auc_score
roc_auc_score(y_train_5, y_scores)
np.float64(0.9604938554008616)
Comparison between Classifiers¶
When comparing ROC curves of different models, the closer the curve approaches the top-left corner, the better the model’s performance. In many cases, Random Forests outperform SGD on ROC curves, offering better classification decisions across a range of thresholds.
from sklearn.ensemble import RandomForestClassifier
forest_clf = RandomForestClassifier(n_estimators=100, random_state=42)
y_probas_forest = cross_val_predict(forest_clf, X_train, y_train_5, cv=3,
method="predict_proba")
The predict_proba() method returns an array containing a row per instance and a column per class, each containing the probability that the given instance belongs to the given class. But to plot a ROC curve, you need scores, not probabilities.
y_scores_forest = y_probas_forest[:, 1] # score = proba of positive class
fpr_forest, tpr_forest, thresholds_forest = roc_curve(y_train_5, y_scores_forest)
recall_for_forest = tpr_forest[np.argmax(fpr_forest >= fpr_90)]
plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, "b:", linewidth=2, label="SGD")
plot_roc_curve(fpr_forest, tpr_forest, "Random Forest")
plt.plot([fpr_90, fpr_90], [0., recall_90_precision], "r:")
plt.plot([0.0, fpr_90], [recall_90_precision, recall_90_precision], "r:")
plt.plot([fpr_90], [recall_90_precision], "ro")
plt.plot([fpr_90, fpr_90], [0., recall_for_forest], "r:")
plt.plot([fpr_90], [recall_for_forest], "ro")
plt.grid(True)
plt.legend(loc="lower right", fontsize=16)
plt.show()
The RandomForestClassifier’s ROC curve looks much better than the SGDClassifier’s: it comes much closer to the top-left corner.
roc_auc_score(y_train_5, y_scores_forest)
np.float64(0.9983436731328145)
y_train_pred_forest = cross_val_predict(forest_clf, X_train, y_train_5, cv=3)
precision_score(y_train_5, y_train_pred_forest)
np.float64(0.9905083315756169)
recall_score(y_train_5, y_train_pred_forest)
np.float64(0.8662608374838591)
Now you know how to train binary classifiers, choose the appropriate metric for your task, evaluate your classifiers using cross-validation, select the precision/recall tradeoff that fits your needs, and compare various models using ROC curves and ROC AUC scores.
Multiclass Classification¶
In contrast to binary classification, where a model predicts between two distinct categories, multiclass classification involves distinguishing among three or more distinct classes. This makes the task significantly more complex, as the classifier must not only separate one class from another but handle multiple inter-class boundaries simultaneously.
Some machine learning algorithms are natively multiclass-capable. For instance, Random Forest classifiers and Naive Bayes classifiers inherently support multiclass classification without any additional modification. They can directly learn from training data with multiple labels and internally handle the complexity of differentiating among all possible classes.
However, other algorithms, such as Support Vector Machines (SVMs) and linear classifiers like Stochastic Gradient Descent (SGD), are inherently designed for binary classification. This means they can only distinguish between two categories by default. To adapt these binary classifiers for multiclass classification, strategy-based approaches are used—primarily One-vs-All (OvA) and One-vs-One (OvO).
- One-vs-All (OvA) strategy involves training a separate binary classifier for each class. Each classifier learns to distinguish its assigned class from all other classes combined. For N classes, you train N classifiers.
- One-vs-One (OvO) strategy involves training a separate binary classifier for every possible pair of classes. For N classes, this results in N × (N – 1) / 2 classifiers. So, with 10 classes (as in MNIST), OvO requires 45 separate classifiers. Each classifier learns to distinguish between two classes only.
Scikit-Learn automates these strategies:
- When using a binary classifier for a multiclass task, it automatically applies the OvA strategy by default.
- For SVM classifiers specifically, it instead uses the OvO strategy, as SVMs scale poorly with large datasets—training several smaller pairwise classifiers is computationally more efficient.
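If you want to force the OvA strategy explicitly rather than relying on the default, you can wrap the estimator in Scikit-Learn’s OneVsRestClassifier, mirroring the OvO wrapper shown later in this section. A minimal sketch (note that training ten full classifiers can take a while):
from sklearn.multiclass import OneVsRestClassifier

ovr_clf = OneVsRestClassifier(SGDClassifier(random_state=42))
ovr_clf.fit(X_train, y_train)
len(ovr_clf.estimators_)   # one binary classifier per digit class, i.e. 10
Letting Scikit-Learn handle the strategy automatically is as simple as fitting the binary classifier on the full multiclass labels: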
sgd_clf.fit(X_train, y_train)
SGDClassifier(random_state=42)
sgd_clf.predict([any_digit])
array([4], dtype=uint8)
Internally, when a model such as SGDClassifier is trained on the MNIST dataset (which has 10 digit classes), Scikit-Learn constructs 10 separate binary classifiers in an OvA manner. Each classifier returns a decision score that reflects the confidence in predicting its associated class. These scores are compared, and the class with the highest score is selected as the final prediction. Calling the .decision_function() method confirms this: instead of returning a single value, it yields an array of 10 scores per instance, one for each digit class.
Moreover, the model stores a classes_ attribute, which lists all the unique target classes it has seen during training. In many datasets like MNIST, the index of each class in this array directly corresponds to the class label (e.g., index 5 refers to digit 5), simplifying the interpretation of outputs.
some_digit_scores = sgd_clf.decision_function([any_digit])
some_digit_scores
array([[-34143.40703505, -21942.13780869, -4018.29275037, -2239.19313075, 43.09419826, -15058.88052383, -33653.31059893, -8277.80610963, -7460.52016321, -14180.15338984]])
np.argmax(some_digit_scores)
np.int64(4)
sgd_clf.classes_
array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=uint8)
While SGDClassifier uses OvA by default, you can manually configure it to use OvO by wrapping it inside Scikit-Learn’s OneVsOneClassifier.
from sklearn.multiclass import OneVsOneClassifier
ovo_clf = OneVsOneClassifier(SGDClassifier(random_state=42))
ovo_clf.fit(X_train, y_train)
ovo_clf.predict([any_digit])
array([4], dtype=uint8)
The one-versus-one (OvO) strategy trains a binary classifier for every pair of digits: one to distinguish 0s from 1s, another to distinguish 0s from 2s, another for 1s and 2s, and so on. If there are N classes, you need to train N × (N – 1) / 2 classifiers. For the MNIST problem, this means training 45 binary classifiers! When you want to classify an image, you have to run it through all 45 classifiers.
len(ovo_clf.estimators_)
45
forest_clf.fit(X_train, y_train)
forest_clf.predict([any_digit])
array([4], dtype=uint8)
Random Forest classifiers, in contrast, natively support multiclass classification. There's no need to convert them using OvA or OvO wrappers. You can call .predict_proba() on such models to get the predicted probabilities for each class directly. This method returns a 2D array where each row corresponds to an instance, and each column corresponds to the model's estimated probability for a particular class.
forest_clf.predict_proba([any_digit])
array([[0. , 0. , 0.02, 0. , 0.95, 0. , 0. , 0.01, 0.01, 0.01]])
cross_val_score(sgd_clf, X_train, y_train, cv=3, scoring="accuracy")
array([0.87365, 0.85835, 0.8689 ])
cross_val_score(forest_clf, X_train, y_train, cv=3, scoring="accuracy")
array([0.9646 , 0.96255, 0.9666 ])
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train.astype(np.float64))
cross_val_score(sgd_clf, X_train_scaled, y_train, cv=3, scoring="accuracy")
array([0.8983, 0.891 , 0.9018])
Error Analysis¶
Once a working model is built and performs reasonably well, the next logical step is error analysis—a systematic approach to understand where and why the model makes mistakes. This helps in identifying weaknesses and guiding the next steps for improvement.
A common starting point is to generate a confusion matrix, which provides a detailed breakdown of actual vs. predicted class labels. Each row in the matrix corresponds to an actual class, while each column represents a predicted class. The diagonal entries show correctly classified instances, while the off-diagonal entries reveal specific misclassifications.
To compute the confusion matrix:
- First, use cross_val_predict() to generate cross-validated predictions across the dataset.
- Then, pass the true and predicted labels to confusion_matrix().
y_train_pred = cross_val_predict(sgd_clf, X_train_scaled, y_train, cv=3)
conf_mx = confusion_matrix(y_train, y_train_pred)
conf_mx
array([[5577,    0,   22,    5,    8,   43,   36,    6,  225,    1],
       [   0, 6400,   37,   24,    4,   44,    4,    7,  212,   10],
       [  27,   27, 5220,   92,   73,   27,   67,   36,  378,   11],
       [  22,   17,  117, 5227,    2,  203,   27,   40,  403,   73],
       [  12,   14,   41,    9, 5182,   12,   34,   27,  347,  164],
       [  27,   15,   30,  168,   53, 4444,   75,   14,  535,   60],
       [  30,   15,   42,    3,   44,   97, 5552,    3,  131,    1],
       [  21,   10,   51,   30,   49,   12,    3, 5684,  195,  210],
       [  17,   63,   48,   86,    3,  126,   25,   10, 5429,   44],
       [  25,   18,   30,   64,  118,   36,    1,  179,  371, 5107]])
plt.matshow(conf_mx, cmap=plt.cm.gray)
plt.show()
A confusion matrix where most values lie on the diagonal generally indicates a well-performing model. However, darker (or lighter) areas off the diagonal show where the model struggles.
For example, if the row corresponding to digit ‘5’ appears dimmer than other digits, this may indicate:
- The dataset has fewer samples of the digit '5', making it harder to learn.
- The model finds it inherently more difficult to distinguish '5' from other digits due to visual similarity.
Let’s focus the plot on the errors. To compare class-wise performance fairly, divide each value in the confusion matrix by the total number of images in the corresponding actual class (i.e., normalize each row). This yields error rates rather than absolute error counts, revealing misclassifications relative to class frequency.
row_sums = conf_mx.sum(axis=1, keepdims=True)
norm_conf_mx = conf_mx / row_sums
np.fill_diagonal(norm_conf_mx, 0)
plt.matshow(norm_conf_mx, cmap=plt.cm.gray)
plt.show()
Further examination may show, for instance, that many images are misclassified as 8s (a bright column), but real 8s are usually classified correctly (a dark row). This implies the classifier tends to overpredict 8, possibly due to its shape being visually similar to multiple digits. Solutions might include:
- Collecting more training examples of similar digits that are commonly misclassified as 8.
- Feature engineering, such as creating features that count closed loops in digits (e.g., 8 has two, 6 has one, 5 has none).
- Preprocessing images with image processing libraries like Scikit-Image or OpenCV to enhance key patterns.
def plot_digits(instances, images_per_row=10, **options):
size = 28
images_per_row = min(len(instances), images_per_row)
# This is equivalent to n_rows = ceil(len(instances) / images_per_row):
n_rows = (len(instances) - 1) // images_per_row + 1
# Append empty images to fill the end of the grid, if needed:
n_empty = n_rows * images_per_row - len(instances)
padded_instances = np.concatenate([instances, np.zeros((n_empty, size * size))], axis=0)
# Reshape the array so it's organized as a grid containing 28×28 images:
image_grid = padded_instances.reshape((n_rows, images_per_row, size, size))
# Combine axes 0 and 2 (vertical image grid axis, and vertical image axis),
# and axes 1 and 3 (horizontal axes). We first need to move the axes that we
# want to combine next to each other, using transpose(), and only then we
# can reshape:
big_image = image_grid.transpose(0, 2, 1, 3).reshape(n_rows * size,
images_per_row * size)
# Now that we have a big image, we just need to show it:
plt.imshow(big_image, cmap = mpl.cm.binary, **options)
plt.axis("off")
cl_a, cl_b = 3, 5
X_aa = X_train[(y_train == cl_a) & (y_train_pred == cl_a)]
X_ab = X_train[(y_train == cl_a) & (y_train_pred == cl_b)]
X_ba = X_train[(y_train == cl_b) & (y_train_pred == cl_a)]
X_bb = X_train[(y_train == cl_b) & (y_train_pred == cl_b)]
plt.figure(figsize=(8,8))
plt.subplot(221); plot_digits(X_aa[:25], images_per_row=5)
plt.subplot(222); plot_digits(X_ab[:25], images_per_row=5)
plt.subplot(223); plot_digits(X_ba[:25], images_per_row=5)
plt.subplot(224); plot_digits(X_bb[:25], images_per_row=5)
plt.show()
Multilabel Classification¶
In traditional classification, each instance is assigned to a single class label. But in multilabel classification, an instance may belong to multiple classes simultaneously.
Consider labeling handwritten digits with additional tags—beyond just the digit itself:
- One label might indicate whether the digit is “large” (e.g., 7, 8, 9).
- Another label might specify whether the digit is “odd” (e.g., 1, 3, 5, 7, 9).
Thus, each instance (digit image) receives a tuple of labels, both of which are binary (yes/no). A classifier like KNeighborsClassifier, which supports multilabel classification, can be trained using such a target array where each row has multiple binary labels.
from sklearn.neighbors import KNeighborsClassifier
y_train_large = (y_train >= 7)
y_train_odd = (y_train % 2 == 1)
y_multilabel = np.c_[y_train_large, y_train_odd]
knn_clf = KNeighborsClassifier()
knn_clf.fit(X_train, y_multilabel)
KNeighborsClassifier()
knn_clf.predict([any_digit])
array([[False, False]])
This code computes the average F1 score across all labels:
# Warning: the following cross-validation is very slow (it can take hours to run)
y_train_knn_pred = cross_val_predict(knn_clf, X_train, y_multilabel, cv=3)
f1_score(y_multilabel, y_train_knn_pred, average="macro")
np.float64(0.9764102655606048)
To assess the model's performance across multiple labels, we often compute the average F1 score—a harmonic mean of precision and recall—across all labels:
- By default, this average treats each label equally.
- But in imbalanced datasets (e.g., many more samples of label A than B), this may skew results. In such cases, setting average="weighted" weighs each label by its support (i.e., how many instances have that label), offering a fairer performance summary.
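A minimal sketch of the weighted variant; it reuses y_train_knn_pred from the slow cross-validation cell above, so it assumes that cell has actually been run:
f1_score(y_multilabel, y_train_knn_pred, average="weighted")  # each label weighted by its support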
Multioutput Classification¶
Multioutput classification is a generalization of multilabel classification where each label itself can take on multiple values, not just binary.
A classic example is image denoising using supervised learning. Suppose we want a model to clean noisy images of handwritten digits:
- Each noisy input image is represented as a 784-dimensional array (28×28 pixels).
- The output is a “cleaned” version of the same image—another 784-dimensional array with pixel intensities ranging from 0 to 255.
In this scenario:
- Each pixel is a separate output label.
- But instead of being binary (on/off), each label can take multiple intensity values.
- Therefore, the model’s output for each image is a multi-output, multi-class vector—a complex prediction task requiring both classification and regression intuition.
To build the dataset, you can synthetically add random noise to MNIST digit images using numpy functions like randint() and treat the original images as the target output. A model is then trained to learn the mapping from noisy to clean images, effectively learning a denoising transformation.
Let’s create the training and test sets by adding noise to the pixel intensities of the MNIST images using NumPy’s randint() function. The target images will be the original (clean) images:
noise = np.random.randint(0, 100, (len(X_train), 784))
X_train_mod = X_train + noise
noise = np.random.randint(0, 100, (len(X_test), 784))
X_test_mod = X_test + noise
y_train_mod = X_train
y_test_mod = X_test
def plot_digit(data):
image = data.values.reshape(28, 28)
plt.imshow(image, cmap = mpl.cm.binary,
interpolation="nearest")
plt.axis("off")
some_index = 0
plt.subplot(121); plot_digit(X_test_mod.iloc[some_index])
plt.subplot(122); plot_digit(y_test_mod.iloc[some_index])
plt.show()
knn_clf.fit(X_train_mod, y_train_mod)
KNeighborsClassifier()
# Redefined to accept a plain NumPy array (the model's prediction) rather than a pandas Series
def plot_digit(data):
image = data.reshape(28, 28)
plt.imshow(image, cmap = mpl.cm.binary,
interpolation="nearest")
plt.axis("off")
clean_digit = knn_clf.predict([X_test_mod.iloc[some_index]])
plot_digit(clean_digit)