scikit-learn in Y minutes

Scikit-learn is an open source machine learning library for supervised and unsupervised learning. It provides various tools for model fitting, data preprocessing, model selection, model evaluation, and many other utilities.

This guide introduces the machine learning vocabulary that is used throughout scikit-learn and provides a simple learning example.

About machine learning · Loading a dataset · Learning and predicting · Evaluating results · Further reading

✨ This is an open source guide. Feel free to improve it!

About machine learning

In general, a learning problem considers a set of n samples of data and then tries to predict properties of unknown data. If each sample is more than a single number (e.g., a multi-dimensional entry such as an array), it is said to have multiple attributes or features.

One of the most common tasks in machine learning is classification. In a classification problem, the samples belong to two or more classes, and we want to learn from already labeled data how to predict the class of unlabeled data.

An example of a classification problem is handwritten digit recognition, where the goal is to assign each input vector (a digit image) to one of a finite number of discrete categories (an actual digit).

(Figure: an image of a handwritten digit, classified as 9.)

Machine learning is about learning some properties of a dataset and then testing those properties against another dataset. A common practice in machine learning is to evaluate an algorithm by splitting a dataset into two parts:

  • the training set, on which we learn some properties;
  • the testing set, on which we test the learned properties.

Scikit-learn provides dozens of built-in machine learning algorithms and models, called estimators. Each estimator can be fitted to some data using its fit method.
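As a minimal sketch of this fit/predict pattern, the example below trains a classifier on four made-up points. It swaps in KNeighborsClassifier (a 1-nearest-neighbor classifier, not used elsewhere in this guide) only because its predictions on such toy data are easy to verify by hand; the points and labels are hypothetical.

```python
from sklearn.neighbors import KNeighborsClassifier

# Hypothetical toy data: 4 samples with 2 features each, in two classes
X = [[0, 0], [1, 1], [5, 5], [6, 6]]
y = [0, 0, 1, 1]

# Classifiers are trained with fit(X, y), which returns the fitted estimator
clf = KNeighborsClassifier(n_neighbors=1)
fitted = clf.fit(X, y)

# After fitting, predict() assigns a class to each new sample
pred = clf.predict([[0.5, 0.5], [5.5, 5.5]])
print(pred)  # [0 1]
```

The digits example in this guide follows the same two steps with a different estimator: fit on labeled samples, then predict on new ones.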

Loading a dataset

Scikit-learn comes with a few standard datasets. In the following example, we load the digits dataset:

from sklearn import datasets
digits = datasets.load_digits()

The digits dataset consists of 8x8 pixel images of digits. The images attribute of the dataset stores 8x8 arrays of grayscale values for each image. We will use these arrays to visualize the first 4 images. The target attribute of the dataset stores the digit that each image represents (see the titles of the 4 plots below):

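import matplotlib.pyplot as plt

# Show the first 4 images, titled with their target digit
_, axes = plt.subplots(nrows=1, ncols=4, figsize=(6, 2))
for ax, image, label in zip(axes, digits.images, digits.target):
    ax.imshow(image, cmap=plt.cm.gray_r, interpolation="nearest")
    ax.set_title("Target: %i" % label)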
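The attributes described above can also be inspected directly. This short sketch prints their shapes and value range; the expected values in the comments refer to the copy of the digits dataset bundled with scikit-learn.

```python
from sklearn import datasets

digits = datasets.load_digits()

# 1797 images, each an 8x8 array of grayscale values
print(digits.images.shape)  # (1797, 8, 8)

# One target digit per image
print(digits.target.shape)  # (1797,)
print(digits.target[:4])    # [0 1 2 3]

# Pixel intensities are whole numbers from 0 to 16
print(digits.images.min(), digits.images.max())  # 0.0 16.0
```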

Learning and predicting

To apply a classifier on this data, we need to flatten the images by transforming each 2D array of grayscale values from the shape (8, 8) to the shape (64,). After that, the whole dataset will be of the shape (n_samples, n_features), where n_samples is the number of images and n_features is the number of pixels in each image (64).
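The flattening step can be checked in isolation. This sketch reshapes the images, confirms the (n_samples, 64) shape, and shows that NumPy's default row-major flattening puts pixel (row, col) at index row * 8 + col. It also compares the result with digits.data, the already-flattened copy that load_digits returns alongside digits.images.

```python
import numpy as np
from sklearn import datasets

digits = datasets.load_digits()

# Flatten each (8, 8) image into a (64,) row, keeping the number of samples
n_samples = len(digits.images)
data = digits.images.reshape((n_samples, -1))
print(data.shape)  # (1797, 64)

# Row-major flattening: pixel (row, col) lands at index row * 8 + col
row, col = 2, 3
assert data[0, row * 8 + col] == digits.images[0, row, col]

# The dataset also ships this flattened view as digits.data
print(np.array_equal(data, digits.data))  # True
```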

We can then split the data into training and test subsets, fit a support vector classifier on the training samples, and use the fitted classifier to predict the digit value for each sample in the test subset:

from sklearn import svm, metrics
from sklearn.model_selection import train_test_split

# Flatten the images
n_samples = len(digits.images)
data = digits.images.reshape((n_samples, -1))

# Create a classifier: a support vector classifier
clf = svm.SVC(gamma=0.001)

# Split data into 50% train and 50% test subsets
X_train, X_test, y_train, y_test = train_test_split(
    data, digits.target, test_size=0.5, shuffle=False
)

# Learn the digits on the train subset
clf.fit(X_train, y_train)

# Predict the value of the digit on the test subset
predicted = clf.predict(X_test)
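To see exactly what train_test_split did here, this sketch repeats the split and inspects it. With 1797 samples and test_size=0.5, the test subset gets the rounded-up half (899 images, the same total as the support column in the classification report) and the train subset gets the remaining 898. Because shuffle=False, the original order is kept, so the test subset is simply the last 899 samples.

```python
from sklearn import datasets
from sklearn.model_selection import train_test_split

digits = datasets.load_digits()
data = digits.images.reshape((len(digits.images), -1))

X_train, X_test, y_train, y_test = train_test_split(
    data, digits.target, test_size=0.5, shuffle=False
)

# 1797 samples: the test half is rounded up, the train half rounded down
print(len(X_train), len(X_test))  # 898 899

# With shuffle=False, the first 898 samples train and the remaining 899 test
print((y_test == digits.target[898:]).all())  # True
```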

Let's visualize the first 4 test samples and show their predicted digit value in the title:

_, axes = plt.subplots(nrows=1, ncols=4, figsize=(6, 2))
for ax, image, prediction in zip(axes, X_test, predicted):
    image = image.reshape(8, 8)
    ax.imshow(image, cmap=plt.cm.gray_r, interpolation="nearest")
    ax.set_title(f"Prediction: {prediction}")

Do you agree with the classifier?

Evaluating results

Classification metrics show how well our classifier does its job of predicting digits from images.

classification_report builds a text report showing the main classification metrics:

print(f"Classification report for classifier {clf}:")
print(metrics.classification_report(y_test, predicted))

Classification report for classifier SVC(gamma=0.001):
              precision    recall  f1-score   support

           0       1.00      0.99      0.99        88
           1       0.99      0.97      0.98        91
           2       0.99      0.99      0.99        86
           3       0.98      0.87      0.92        91
           4       0.99      0.96      0.97        92
           5       0.95      0.97      0.96        91
           6       0.99      0.99      0.99        91
           7       0.96      0.99      0.97        89
           8       0.94      1.00      0.97        88
           9       0.93      0.98      0.95        92

    accuracy                           0.97       899
   macro avg       0.97      0.97      0.97       899
 weighted avg       0.97      0.97      0.97       899

  • Precision is the ability of the classifier to not label a sample as positive when it is actually negative (e.g., never label a "3" image as "8"). It is calculated as the ratio tp / (tp + fp), where tp is the number of true positives and fp is the number of false positives. Ranges from 0 (worst) to 1 (best).

  • Recall is the ability of the classifier to find all the positive samples (e.g., to correctly label all "8" images as "8"). It is calculated as the ratio tp / (tp + fn), where tp is the number of true positives and fn is the number of false negatives. Ranges from 0 (worst) to 1 (best).

  • F1 score combines precision and recall with equal relative contribution. It is the harmonic mean of the precision and recall, calculated as 2 * (precision * recall) / (precision + recall). Ranges from 0 (worst) to 1 (best).

  • Support refers to the number of actual occurrences of each class in the evaluated data. In our case, it's the number of test images that belong to each digit class (899 test images in total).

  • Accuracy is the fraction of all samples whose predicted label matches the true label. For a multi-class classification problem like ours, it gives an overall indication of how often the classifier is correct across all classes.
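The formulas above can be checked by hand on a few hypothetical labels (not taken from the digits data). Treating "8" as the positive class, there are 2 true positives, 2 false positives, and 1 false negative, so precision = 2/4 = 0.5, recall = 2/3, F1 = 4/7 ≈ 0.571, and accuracy = 4/7 (4 of 7 labels correct). The sketch computes these directly and cross-checks them against scikit-learn's per-class metric functions.

```python
from sklearn import metrics

# Hypothetical toy labels (not taken from the digits data)
y_true = [8, 8, 8, 3, 3, 3, 1]
y_pred = [8, 8, 3, 8, 8, 3, 1]
positive = 8  # treat "8" as the positive class

pairs = list(zip(y_true, y_pred))
tp = sum(t == positive and p == positive for t, p in pairs)  # true positives
fp = sum(t != positive and p == positive for t, p in pairs)  # false positives
fn = sum(t == positive and p != positive for t, p in pairs)  # false negatives

precision = tp / (tp + fp)
recall = tp / (tp + fn)
f1 = 2 * (precision * recall) / (precision + recall)
support = sum(t == positive for t in y_true)
accuracy = sum(t == p for t, p in pairs) / len(pairs)

print(tp, fp, fn, support)  # 2 2 1 3

# Cross-check against scikit-learn's scores for label 8 only
print(metrics.precision_score(y_true, y_pred, labels=[8], average=None))
print(metrics.recall_score(y_true, y_pred, labels=[8], average=None))
print(metrics.f1_score(y_true, y_pred, labels=[8], average=None))
print(metrics.accuracy_score(y_true, y_pred))
```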

We can also plot a confusion matrix of the true digit values and the predicted digit values:

disp = metrics.ConfusionMatrixDisplay.from_predictions(y_test, predicted)
disp.figure_.suptitle("Confusion Matrix")

A confusion matrix is a table used to evaluate the performance of a classifier. It shows actual vs. predicted classes in a grid, where each row represents the instances of an actual class and each column represents the instances of a predicted class.

In our case, the matrix shows how well the model identifies each digit (0 through 9) by comparing the true labels of the test images against the labels predicted by the model. For each digit class, the diagonal entry counts the true positives, the rest of that class's row counts the false negatives, and the rest of its column counts the false positives; all remaining cells are true negatives for that class. This makes it easy to see which digits the model confuses with each other.
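How the four counts come out of a confusion matrix can be sketched on a few hypothetical labels (not from the digits data). The example uses metrics.confusion_matrix, which returns the same counts that ConfusionMatrixDisplay plots, and derives each class's counts with NumPy.

```python
import numpy as np
from sklearn import metrics

# Hypothetical toy labels (not taken from the digits data)
y_true = [8, 8, 8, 3, 3, 3, 1]
y_pred = [8, 8, 3, 8, 8, 3, 1]

# Rows are true classes, columns are predicted classes (sorted: 1, 3, 8)
cm = metrics.confusion_matrix(y_true, y_pred)
print(cm)

tp = np.diag(cm)              # correct predictions per class
fp = cm.sum(axis=0) - tp      # predicted as the class, but actually another
fn = cm.sum(axis=1) - tp      # actually the class, but predicted as another
tn = cm.sum() - tp - fp - fn  # neither actually nor predicted as the class

print(tp, fp, fn, tn)  # [1 1 2] [0 1 2] [0 2 1] [6 3 2]
```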

For example, from the above matrix we can see that:

  • Almost all "6" images were correctly labeled as "6", except for one image mislabeled as "1" (high recall).
  • Almost all "6" predictions are true positives, except for one "5" image incorrectly labeled as "6" (high precision).
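Both kinds of observation can also be read off programmatically: a class's recall is its diagonal entry divided by its row sum, and its precision is its diagonal entry divided by its column sum. This sketch refits the classifier from this guide so that it runs on its own. Exact counts may vary across scikit-learn versions, so no output is shown, but the computed values should agree with recall_score and precision_score.

```python
import numpy as np
from sklearn import datasets, metrics, svm
from sklearn.model_selection import train_test_split

# Reproduce the digits example from this guide
digits = datasets.load_digits()
data = digits.images.reshape((len(digits.images), -1))
X_train, X_test, y_train, y_test = train_test_split(
    data, digits.target, test_size=0.5, shuffle=False
)
clf = svm.SVC(gamma=0.001).fit(X_train, y_train)
predicted = clf.predict(X_test)

cm = metrics.confusion_matrix(y_test, predicted)

# Recall per class: diagonal entry / row sum (all images truly in that class)
recall = np.diag(cm) / cm.sum(axis=1)
# Precision per class: diagonal entry / column sum (all predictions of that class)
precision = np.diag(cm) / cm.sum(axis=0)

# Row and column for the digit "6": which classes it was confused with
print("true 6 predicted as:", cm[6])
print("predicted 6 was actually:", cm[:, 6])
print(f"recall(6) = {recall[6]:.2f}, precision(6) = {precision[6]:.2f}")
```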

Further reading

See the Tutorials for additional learning resources and the User Guide for details on all the features scikit-learn provides.

Scikit-learn developers + 1 others · original · CC-BY-SA-4.0 · 2024-02-29