KMeans(n_clusters=8, *, init='k-means++', n_init=10, max_iter=300, tol=0.0001, precompute_distances='auto', verbose=0, random_state=None, copy_x=True, n_jobs='deprecated', algorithm='auto')
Here's a brief explanation of the main parameters:
n_clusters: The number of clusters to form, which is also the number of centroids to generate.
init: The method used to initialize the centroids. 'k-means++' (the default) spreads the initial centroids apart, which usually speeds up convergence; 'random' chooses random observations as the initial centroids.
n_init: The number of times the algorithm is run with different centroid seeds; the run with the lowest inertia is kept.
max_iter: The maximum number of iterations for a single run.
tol: The tolerance on the change in the cluster centers between two consecutive iterations used to declare convergence.
random_state: The seed for the centroid initialization, which makes the results reproducible.
precompute_distances: Whether to precompute distances (faster but requires more memory) or to calculate distances on the fly (slower but less memory-intensive). Note that this parameter and n_jobs were deprecated in scikit-learn 0.23 and removed in 1.0, so the signature above reflects an older release; newer versions also accept n_init='auto', which we use later in this article.
There are also several other parameters you can specify to further customize the KMeans object. Once you have created the KMeans object with the desired parameters, you can fit it to your data using the fit method and use the predict method to assign new data points to their closest clusters based on the learned centroids.
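For example, here is a minimal, self-contained sketch of that workflow; the toy array below is purely illustrative and is not part of the Mall Customers example used later.
import numpy as np
from sklearn.cluster import KMeans
# six toy points forming two obvious groups
X_toy = np.array([[1, 2], [1, 4], [1, 0], [10, 2], [10, 4], [10, 0]])
kmeans = KMeans(n_clusters=2, n_init=10, random_state=0).fit(X_toy)  # learn the centroids
print(kmeans.predict([[0, 0], [12, 3]]))  # the two new points land in different clusters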
Visual inspection of the clusters can be a useful method to evaluate clustering, especially when the number of clusters is small. It can provide insights into the quality of the clustering and identify any issues or anomalies.
In addition, scikit-learn provides several cluster quality metrics that can be used to evaluate the performance of K-means clustering. Here are some commonly used metrics (a short usage sketch follows the list):
Silhouette Score: Computes the mean silhouette coefficient of all samples. This metric measures the similarity of a sample to its own cluster compared to other clusters. It ranges from -1 to 1, where higher values indicate better clustering.
Inertia: Measures the sum of squared distances of all samples to their closest cluster center. This metric measures how compact the clusters are; lower values indicate denser clusters. Note that inertia always decreases as the number of clusters grows, so it is mainly used to compare runs with the same number of clusters or as input to the elbow method discussed later.
Calinski-Harabasz Index: Computes the ratio of the between-cluster variance to the within-cluster variance. This metric measures how well the clusters are separated from each other. Higher values indicate better clustering results.
Davies-Bouldin Index: Computes the average similarity between each cluster and its most similar cluster. This metric measures how well the clusters are separated from each other. Lower values indicate better clustering results.
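As referenced above, here is a brief sketch of how these metrics are computed with scikit-learn; it assumes X is a feature matrix and kmeans is a KMeans model already fitted on it.
from sklearn.metrics import silhouette_score, calinski_harabasz_score, davies_bouldin_score
labels = kmeans.labels_
print(silhouette_score(X, labels))         # higher is better, in [-1, 1]
print(kmeans.inertia_)                     # lower is better
print(calinski_harabasz_score(X, labels))  # higher is better
print(davies_bouldin_score(X, labels))     # lower is better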
It is important to note that no single method can determine the best clustering for all datasets, and a combination of methods may be necessary to obtain a comprehensive evaluation.
You will probably meet the following UserWarning message when implementing KMeans clustering in Section 5 and Section 6.
C:\Users\shouk\anaconda3\lib\site-packages\sklearn\cluster\_kmeans.py:1382: UserWarning: KMeans is known to have a memory leak on Windows with MKL, when there are less chunks than available threads. You can avoid it by setting the environment variable OMP_NUM_THREADS=1.
warnings.warn(
To avoid this UserWarning message, the easiest way is to set the environment variable by adding the following two lines before importing the required packages. If you are interested in studying where and how the UserWarning is generated, you can comment out these two lines first and then run the rest of the code.
import os
os.environ["OMP_NUM_THREADS"] = '1'
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn import preprocessing
from sklearn.cluster import KMeans
We use a very popular and easy-to-understand dataset on mall customers, which appears online under different names. You can download it from many places, such as Mall_Customers.csv from Kaggle or customers.csv from GitHub. Then put it into a folder, say the data folder in your current working directory. I downloaded the one named 'customers.csv' from GitHub, so I can read it in the following way.
data = pd.read_csv('./data/customers.csv')
data.head()
 | CustomerID | Gender | Age | Annual Income (k$) | Spending Score (1-100) |
---|---|---|---|---|---|
0 | 1 | Male | 19 | 15 | 39 |
1 | 2 | Male | 21 | 15 | 81 |
2 | 3 | Female | 20 | 16 | 6 |
3 | 4 | Female | 23 | 16 | 77 |
4 | 5 | Female | 31 | 17 | 40 |
If you do not want a copy on your local computer, you can read it directly from GitHub.
url = 'https://raw.githubusercontent.com/jeffprosise/Applied-Machine-Learning/main/Chapter%201/Data/customers.csv'
data = pd.read_csv(url)
data.head()
 | CustomerID | Gender | Age | Annual Income (k$) | Spending Score (1-100) |
---|---|---|---|---|---|
0 | 1 | Male | 19 | 15 | 39 |
1 | 2 | Male | 21 | 15 | 81 |
2 | 3 | Female | 20 | 16 | 6 |
3 | 4 | Female | 23 | 16 | 77 |
4 | 5 | Female | 31 | 17 | 40 |
data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 200 entries, 0 to 199
Data columns (total 5 columns):
 #   Column                  Non-Null Count  Dtype
---  ------                  --------------  -----
 0   CustomerID              200 non-null    int64
 1   Gender                  200 non-null    object
 2   Age                     200 non-null    int64
 3   Annual Income (k$)      200 non-null    int64
 4   Spending Score (1-100)  200 non-null    int64
dtypes: int64(4), object(1)
memory usage: 7.9+ KB
data.describe()
 | CustomerID | Age | Annual Income (k$) | Spending Score (1-100) |
---|---|---|---|---|
count | 200.000000 | 200.000000 | 200.000000 | 200.000000 |
mean | 100.500000 | 38.850000 | 60.560000 | 50.200000 |
std | 57.879185 | 13.969007 | 26.264721 | 25.823522 |
min | 1.000000 | 18.000000 | 15.000000 | 1.000000 |
25% | 50.750000 | 28.750000 | 41.500000 | 34.750000 |
50% | 100.500000 | 36.000000 | 61.500000 | 50.000000 |
75% | 150.250000 | 49.000000 | 78.000000 | 73.000000 |
max | 200.000000 | 70.000000 | 137.000000 | 99.000000 |
We encode the categorical variable 'Gender', where "Male" and "Female" are encoded as 1 and 0, respectively. There are several methods to encode categorical or string variables, which have been discussed in a previous article. You can easily use any of those methods, but here we will use the LabelEncoder() class from scikit-learn.
data_encode = data.copy()
le = preprocessing.LabelEncoder()
data_encode['Gender'] = le.fit_transform(data_encode['Gender'])
data_encode
 | CustomerID | Gender | Age | Annual Income (k$) | Spending Score (1-100) |
---|---|---|---|---|---|
0 | 1 | 1 | 19 | 15 | 39 |
1 | 2 | 1 | 21 | 15 | 81 |
2 | 3 | 0 | 20 | 16 | 6 |
3 | 4 | 0 | 23 | 16 | 77 |
4 | 5 | 0 | 31 | 17 | 40 |
... | ... | ... | ... | ... | ... |
195 | 196 | 0 | 35 | 120 | 79 |
196 | 197 | 0 | 45 | 126 | 28 |
197 | 198 | 1 | 32 | 126 | 74 |
198 | 199 | 1 | 32 | 137 | 18 |
199 | 200 | 1 | 30 | 137 | 83 |
200 rows × 5 columns
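As an aside, the same 0/1 encoding could be obtained with a plain pandas mapping. This is just an illustrative alternative, not the method used above, and it assumes 'Male' and 'Female' are the only values in the column.
# equivalent one-liner with pandas (assumes only 'Male'/'Female' appear)
data_encode['Gender'] = data['Gender'].map({'Male': 1, 'Female': 0})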
First, we consider a simple two-dimensional case, describing the customers in terms of 'Annual Income (k$)' and 'Spending Score (1-100)'. We make a copy of the DataFrame to keep the original DataFrame unchanged.
df = data_encode.copy()
x = df['Annual Income (k$)']
y = df['Spending Score (1-100)']
# set the plotting style to 'ggplot'
plt.style.use('ggplot')
plt.scatter(x, y, s=35, alpha=0.9)
plt.xlabel('Annual Income (k$)')
plt.ylabel('Spending Score')
Text(0, 0.5, 'Spending Score')
The above visualization shows that the data points fall into roughly five clusters.
An easy way to create the input X is to select the columns by name and then convert the result to a NumPy array using '.values', which avoids a UserWarning in Section 5.8.
X = df[['Annual Income (k$)','Spending Score (1-100)']].values
kmeans = KMeans(n_clusters=5, init='k-means++', n_init='auto', random_state=0).fit(X)
To predict the labels of K-means clustering in scikit-learn, you can use the predict method of the KMeans object.
cluster_labels = kmeans.predict(X)
print(cluster_labels)
[4 3 4 3 4 3 4 3 4 3 4 3 4 3 4 3 4 3 4 3 4 3 4 3 4 3 4 3 4 3 4 3 4 3 4 3 4 3 4 3 4 3 4 1 4 3 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 0 2 1 2 0 2 0 2 1 2 0 2 0 2 0 2 0 2 1 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2]
Or use the labels_ attribute.
cluster_labels = kmeans.labels_
print(cluster_labels)
[4 3 4 3 4 3 4 3 4 3 4 3 4 3 4 3 4 3 4 3 4 3 4 3 4 3 4 3 4 3 4 3 4 3 4 3 4 3 4 3 4 3 4 1 4 3 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 0 2 1 2 0 2 0 2 1 2 0 2 0 2 0 2 0 2 1 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2]
To obtain the center points of the clusters, we can use the cluster_centers_ attribute of the KMeans object.
centers = kmeans.cluster_centers_
print(centers)
[[88.2        17.11428571]
 [55.2962963  49.51851852]
 [86.53846154 82.12820513]
 [25.72727273 79.36363636]
 [26.30434783 20.91304348]]
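To make the centers easier to read, we can label each centroid coordinate with its feature name; a small convenience sketch, not part of the original output.
pd.DataFrame(centers, columns=['Annual Income (k$)', 'Spending Score (1-100)'])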
For two-dimensional clustering, it is easy to visualize the results with a scatter plot.
plt.scatter(x, y, c=cluster_labels, s=35, alpha=0.9, cmap='jet')
plt.xlabel('Annual Income (k$)')
plt.ylabel('Spending Score')
plt.scatter(centers[:, 0], centers[:, 1], c='red', s=70)
<matplotlib.collections.PathCollection at 0x15add15a890>
We can add the cluster labels to the DataFrame, which makes it very convenient to query the customers' information by cluster.
df['Labels'] = cluster_labels
df
 | CustomerID | Gender | Age | Annual Income (k$) | Spending Score (1-100) | Labels |
---|---|---|---|---|---|---|
0 | 1 | 1 | 19 | 15 | 39 | 4 |
1 | 2 | 1 | 21 | 15 | 81 | 3 |
2 | 3 | 0 | 20 | 16 | 6 | 4 |
3 | 4 | 0 | 23 | 16 | 77 | 3 |
4 | 5 | 0 | 31 | 17 | 40 | 4 |
... | ... | ... | ... | ... | ... | ... |
195 | 196 | 0 | 35 | 120 | 79 | 2 |
196 | 197 | 0 | 45 | 126 | 28 | 0 |
197 | 198 | 1 | 32 | 126 | 74 | 2 |
198 | 199 | 1 | 32 | 137 | 18 | 0 |
199 | 200 | 1 | 30 | 137 | 83 | 2 |
200 rows × 6 columns
For new data points, we can easily predict the cluster label(s).
kmeans.predict([[120,79]])[0]
2
kmeans.predict([[150,30]])[0]
0
df[df['Labels']==2]
 | CustomerID | Gender | Age | Annual Income (k$) | Spending Score (1-100) | Labels |
---|---|---|---|---|---|---|
123 | 124 | 1 | 39 | 69 | 91 | 2 |
125 | 126 | 0 | 31 | 70 | 77 | 2 |
127 | 128 | 1 | 40 | 71 | 95 | 2 |
129 | 130 | 1 | 38 | 71 | 75 | 2 |
131 | 132 | 1 | 39 | 71 | 75 | 2 |
133 | 134 | 0 | 31 | 72 | 71 | 2 |
135 | 136 | 0 | 29 | 73 | 88 | 2 |
137 | 138 | 1 | 32 | 73 | 73 | 2 |
139 | 140 | 0 | 35 | 74 | 72 | 2 |
141 | 142 | 1 | 32 | 75 | 93 | 2 |
143 | 144 | 0 | 32 | 76 | 87 | 2 |
145 | 146 | 1 | 28 | 77 | 97 | 2 |
147 | 148 | 0 | 32 | 77 | 74 | 2 |
149 | 150 | 1 | 34 | 78 | 90 | 2 |
151 | 152 | 1 | 39 | 78 | 88 | 2 |
153 | 154 | 0 | 38 | 78 | 76 | 2 |
155 | 156 | 0 | 27 | 78 | 89 | 2 |
157 | 158 | 0 | 30 | 78 | 78 | 2 |
159 | 160 | 0 | 30 | 78 | 73 | 2 |
161 | 162 | 0 | 29 | 79 | 83 | 2 |
163 | 164 | 0 | 31 | 81 | 93 | 2 |
165 | 166 | 0 | 36 | 85 | 75 | 2 |
167 | 168 | 0 | 33 | 86 | 95 | 2 |
169 | 170 | 1 | 32 | 87 | 63 | 2 |
171 | 172 | 1 | 28 | 87 | 75 | 2 |
173 | 174 | 1 | 36 | 87 | 92 | 2 |
175 | 176 | 0 | 30 | 88 | 86 | 2 |
177 | 178 | 1 | 27 | 88 | 69 | 2 |
179 | 180 | 1 | 35 | 93 | 90 | 2 |
181 | 182 | 0 | 32 | 97 | 86 | 2 |
183 | 184 | 0 | 29 | 98 | 88 | 2 |
185 | 186 | 1 | 30 | 99 | 97 | 2 |
187 | 188 | 1 | 28 | 101 | 68 | 2 |
189 | 190 | 0 | 36 | 103 | 85 | 2 |
191 | 192 | 0 | 32 | 103 | 69 | 2 |
193 | 194 | 0 | 38 | 113 | 91 | 2 |
195 | 196 | 0 | 35 | 120 | 79 | 2 |
197 | 198 | 1 | 32 | 126 | 74 | 2 |
199 | 200 | 1 | 30 | 137 | 83 | 2 |
df[df['CustomerID']==125].values
array([[125, 0, 23, 70, 29, 0]], dtype=int64)
In this example, we segment the customers using all the variables except for the 'CustomerID' column.
df = data_encode.copy()
X = df.drop(['CustomerID'],axis=1).values
X
array([[  1,  19,  15,  39],
       [  1,  21,  15,  81],
       [  0,  20,  16,   6],
       ...,
       [  1,  32, 126,  74],
       [  1,  32, 137,  18],
       [  1,  30, 137,  83]], dtype=int64)
If the data has more than three dimensions, it is challenging to find the optimal number of clusters with plotting methods. Here, let's use some of the methods introduced in Section 1.2.
The elbow method is a heuristic used to determine the optimal number of clusters in K-means clustering. The basic idea is to plot the sum of squared distances between the data points and their assigned cluster centroids as a function of the number of clusters, K.
The goal is then to identify the "elbow point" on the plot, which corresponds to the optimal number of clusters. This point is usually determined by visually inspecting the plot and selecting the value of K where the rate of decrease in the sum of squared distances starts to level off. The sum of squared distances is easily obtained from the KMeans.inertia_ attribute in scikit-learn.
inertias = []
for i in range(1, 10):
    kmeans = KMeans(n_clusters=i, init='k-means++', max_iter=300, n_init='auto', random_state=0)
    kmeans.fit(X)
    inertias.append(kmeans.inertia_)
plt.plot(range(1, 10), inertias)
plt.xlabel('Number of clusters')
plt.ylabel('Inertia')
Text(0, 0.5, 'Inertia')
The resulting plot will have the number of clusters on the x-axis and the sum of squared distances on the y-axis. The "elbow point" corresponds to the value of K where the rate of decrease in the sum of squared distances starts to level off. In this example, the elbow point appears to be at K=6, suggesting that 6 clusters might be the optimal choice for this dataset. However, the optimal number of clusters ultimately depends on the specific dataset and problem at hand, so the elbow method should be used as a heuristic rather than a definitive solution.
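If you prefer to locate the elbow programmatically rather than by eye, the third-party kneed package provides a KneeLocator class. The sketch below is optional and assumes kneed has been installed (pip install kneed); it is not needed for the rest of the article.
from kneed import KneeLocator
# 'convex' and 'decreasing' describe the shape of a typical inertia curve
kl = KneeLocator(range(1, 10), inertias, curve='convex', direction='decreasing')
print(kl.elbow)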
The silhouette coefficient of a single sample is computed as s = (p - q) / max(p, q), where:
p: is the mean distance to the points in the nearest cluster that the data point is not a part of;
q: is the mean intra-cluster distance, i.e., the mean distance to all the points in its own cluster.
The silhouette score ranges from -1 to 1, where a higher value indicates better clustering results. A value of 0 indicates that the sample is on or very close to the decision boundary between two neighboring clusters.
from sklearn.metrics import silhouette_score
# Silhouette score analysis to find the ideal number of clusters for K-means clustering
score = []
range_n_clusters = range(2, 10)
for num_clusters in range_n_clusters:
    # initialize kmeans
    kmeans = KMeans(n_clusters=num_clusters, init='k-means++', random_state=0, max_iter=300, n_init='auto')
    kmeans.fit(X)
    cluster_labels = kmeans.labels_
    # silhouette score
    silhouette_avg = silhouette_score(X, cluster_labels)
    score.append(silhouette_avg)
    print("For n_clusters={0}, the silhouette score is {1}".format(num_clusters, silhouette_avg))
For n_clusters=2, the silhouette score is 0.32323687252392846
For n_clusters=3, the silhouette score is 0.383798873822341
For n_clusters=4, the silhouette score is 0.4052954330641215
For n_clusters=5, the silhouette score is 0.37688936241822546
For n_clusters=6, the silhouette score is 0.4506609653808789
For n_clusters=7, the silhouette score is 0.403956517241377
For n_clusters=8, the silhouette score is 0.37726715689435
For n_clusters=9, the silhouette score is 0.3787881296338692
We can also plot the silhouette score and easily observe its maximum value.
plt.plot(range_n_clusters,score,'r*-')
plt.xlabel('Number of clusters')
plt.ylabel('silhouette Scores')
Text(0, 0.5, 'silhouette Scores')
The above results reveal that the silhouette score reaches its maximum of about 0.4507 when the number of clusters is 6, which is consistent with the result of the elbow method.
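Rather than reading the maximum off the plot, we can also pick it out programmatically; a small sketch reusing the score and range_n_clusters variables from above.
import numpy as np
# the index of the largest silhouette score maps back to the cluster count
best_k = range_n_clusters[int(np.argmax(score))]
print(best_k)  # 6 for this dataset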
Based on the above result, we specify n_clusters=6 to group the customers into 6 clusters.
kmeans = KMeans(n_clusters=6, init='k-means++', random_state=0, max_iter = 300, n_init='auto')
kmeans.fit(X)
df['Cluster'] = kmeans.predict(X)
df.head()
 | CustomerID | Gender | Age | Annual Income (k$) | Spending Score (1-100) | Cluster |
---|---|---|---|---|---|---|
0 | 1 | 1 | 19 | 15 | 39 | 2 |
1 | 2 | 1 | 21 | 15 | 81 | 3 |
2 | 3 | 0 | 20 | 16 | 6 | 2 |
3 | 4 | 0 | 23 | 16 | 77 | 3 |
4 | 5 | 0 | 31 | 17 | 40 | 2 |
Visualizing multidimensional clustering results can be challenging since we cannot directly visualize more than three dimensions in a 2D or 3D plot. However, there are some techniques that can help us visualize the clusters and understand the relationships between the features. Here are some common approaches.
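One simple approach is to project the data onto two dimensions and color the points by cluster. The sketch below uses PCA for the projection; it is an illustrative addition of mine, reusing X and the cluster labels from the previous step, and is not part of the original walkthrough.
from sklearn.decomposition import PCA
# project the 4-D feature matrix onto its first two principal components
X_2d = PCA(n_components=2).fit_transform(X)
plt.scatter(X_2d[:, 0], X_2d[:, 1], c=df['Cluster'], s=35, cmap='jet')
plt.xlabel('PC 1')
plt.ylabel('PC 2')
plt.show()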
Another approach is a scatter matrix plot: we use seaborn's pairplot() to create it, excluding the 'CustomerID' column.
results = df.drop(['CustomerID'], axis=1)
sns.pairplot(results, hue="Cluster",palette="rainbow")
<seaborn.axisgrid.PairGrid at 0x15adc90bf10>
The above plot allows us to visualize the relationships between all pairs of features in the dataset and how they relate to the clustering results.
Pandas provides an easy method to create parallel coordinate plots, which can be used to visualize multidimensional data.
plt.figure(figsize=(15,8))
pd.plotting.parallel_coordinates(results,'Cluster',alpha=0.90)
plt.xticks(rotation=45)
plt.show()
This plot makes it easy to interpret the results in terms of the variables or the clusters. To take two examples: from the Age axis, we can see that the customers in cluster 5 are comparatively older than those in the other clusters; in terms of the clusters, cluster 0 has the lowest spending scores.
To create a heatmap showing the relation between each cluster and each variable, we need to reshape the data using pandas melt as follows.
results_melt = pd.melt(results, id_vars=['Cluster'],
                       value_vars=['Gender', 'Age', 'Annual Income (k$)', 'Spending Score (1-100)'],
                       var_name='Variables',
                       value_name='Values')
results_melt.head()
 | Cluster | Variables | Values |
---|---|---|---|
0 | 2 | Gender | 1 |
1 | 3 | Gender | 1 |
2 | 2 | Gender | 0 |
3 | 3 | Gender | 0 |
4 | 2 | Gender | 0 |
Then we pivot the melted DataFrame into a cluster-by-variable table.
results_pivot = results_melt.pivot_table(index="Cluster",columns="Variables", values="Values")
results_pivot.head()
Cluster | Age | Annual Income (k$) | Gender | Spending Score (1-100) |
---|---|---|---|---|
0 | 41.647059 | 88.735294 | 0.558824 | 16.764706 |
1 | 27.315789 | 57.500000 | 0.368421 | 48.447368 |
2 | 44.318182 | 25.772727 | 0.409091 | 20.272727 |
3 | 25.521739 | 26.304348 | 0.391304 | 78.565217 |
4 | 32.692308 | 86.538462 | 0.461538 | 82.128205 |
We can then use the heatmap method provided by the seaborn data visualization library to create the heatmap.
sns.heatmap(results_pivot)
<Axes: xlabel='Variables', ylabel='Cluster'>
The above heatmap cannot display Gender well because its values are much smaller than those of the other variables, so it is helpful to display the cell values as text by setting annot=True.
sns.heatmap(results_pivot,annot=True)
<Axes: xlabel='Variables', ylabel='Cluster'>
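Another option, which I am adding as a suggestion rather than something from the original walkthrough, is to rescale each column to [0, 1] so that all variables share one color scale.
# min-max normalize each column of the pivot table before plotting
normalized = (results_pivot - results_pivot.min()) / (results_pivot.max() - results_pivot.min())
sns.heatmap(normalized, annot=True)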
The total numbers of males and females, together with the average values of the remaining variables, are very helpful for gaining insight into the clustering results.
col_names = ['Cluster', 'Average Age', 'Average Income', 'Average Spending Index',
             'Number of Females', 'Number of Males']
mean_results = pd.DataFrame(columns=col_names)
for i, center in enumerate(kmeans.cluster_centers_):
    # averages/means of age, income and spending score from the cluster center
    mean_age = center[1]
    mean_income = center[2]
    mean_spend = center[3]
    # numbers of females and males in this cluster
    clusters_df = df[df['Cluster'] == i]
    n_females = clusters_df[clusters_df['Gender'] == 0].shape[0]
    n_males = clusters_df[clusters_df['Gender'] == 1].shape[0]
    mean_results.loc[i] = [i, mean_age, mean_income, mean_spend, n_females, n_males]
mean_results.head()
 | Cluster | Average Age | Average Income | Average Spending Index | Number of Females | Number of Males |
---|---|---|---|---|---|---|
0 | 0.0 | 41.647059 | 88.735294 | 16.764706 | 15.0 | 19.0 |
1 | 1.0 | 27.315789 | 57.500000 | 48.447368 | 24.0 | 14.0 |
2 | 2.0 | 44.318182 | 25.772727 | 20.272727 | 13.0 | 9.0 |
3 | 3.0 | 25.521739 | 26.304348 | 78.565217 | 14.0 | 9.0 |
4 | 4.0 | 32.692308 | 86.538462 | 82.128205 | 21.0 | 18.0 |
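For what it's worth, an equivalent summary can be computed directly with pandas groupby; a sketch that derives the means from the data itself rather than from the fitted cluster centers.
# per-cluster means of the numeric variables
print(df.groupby('Cluster')[['Age', 'Annual Income (k$)', 'Spending Score (1-100)']].mean())
# per-cluster counts of each gender (0 = Female, 1 = Male)
print(df.groupby(['Cluster', 'Gender']).size())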
It is also helpful to visualize the above result table with a parallel coordinate plot.
pd.plotting.parallel_coordinates(mean_results,'Cluster',sort_labels=True)
plt.xticks(rotation=45)
plt.show()
This article demonstrates how to implement K-means clustering with the scikit-learn library and how to visualize the results using pandas, Matplotlib and seaborn. It covers the following main topics: (1) essential concepts of K-means clustering and its applications; (2) a two-dimensional clustering example; (3) a multidimensional clustering example; and (4) how to display and visualize multidimensional clustering results.