Check
Step 4: Check
Now that we are done with the first three steps, let’s begin again from the top. I am going to wrap some of the earlier work in functions to make this easier.
def assign_labels(X, colors):
    """Return, for every row of ``X``, the index of the nearest row of ``colors``.

    Nearness is the Euclidean distance between a row of ``X`` and a row of
    ``colors``. When two colors are equally close, the lower index wins
    (``argmin`` returns the first minimum).

    Args:
        X: 2-D array with one sample per row (presumably pixels as rows and
            color channels as columns -- confirm against the caller).
        colors: Iterable of candidate colors, each broadcastable against a
            row of ``X``.

    Returns:
        1-D integer array with one label per row of ``X``.
    """
    def distance_to(center):
        # Euclidean distance from every sample to this single center
        squared_diff = (X - center) ** 2
        return squared_diff.sum(axis=1) ** .5

    # One row of distances per color, one column per sample
    per_color = np.vstack(list(map(distance_to, colors)))
    # For each sample (column), pick the color (row) with the smallest distance
    return per_color.argmin(axis=0)
def compute_centroids(X, colors, labels):
    """Recompute each centroid as the mean of the samples assigned to it.

    A cluster that received no samples keeps its previous centroid. Taking
    the mean of an empty selection would instead yield a row of NaN (with a
    NumPy RuntimeWarning), and a NaN centroid can never attract samples
    again in later iterations.

    Args:
        X: 2-D array with one sample per row.
        colors: Current centroids, one per row; row ``i`` belongs to the
            cluster labelled ``i``. Only its length and the rows of empty
            clusters are used.
        labels: 1-D array giving the cluster index of each row of ``X``.

    Returns:
        2-D array of updated centroids, one row per entry of ``colors``.
    """
    centroids = []
    for index, previous in enumerate(colors):
        members = X[labels == index]
        if len(members):
            centroids.append(members.mean(axis=0))
        else:
            # Empty cluster: there is nothing to average, so leave it in place
            centroids.append(previous)
    # Stack the per-cluster rows back into a single array
    return np.vstack(centroids)
def plot_clusters(colors, labels, img_shape):
    """Show the image obtained by painting every sample with its cluster's color.

    Args:
        colors: Array indexed by cluster label; row ``i`` is the color used
            for cluster ``i``.
        labels: Cluster label of each sample, used to index ``colors``
            (presumably the integer output of ``assign_labels``).
        img_shape: Shape the recolored samples are reshaped to before display.
    """
    # Look up each sample's color, then fold the flat result back into image form
    recolored = colors[labels].reshape(img_shape)
    plt.imshow(recolored)
    plt.show()
We will choose a different seed this time.
# Fix the RNG state so the random starting palette is reproducible
np.random.seed(1)
# Draw 4 starting colors: a 4x3 array sampled uniformly from [0, 1)
colors = np.random.uniform(low=0, high=1, size=(4, 3))
Let's do a first iteration.
# Assignment step: label every sample with its nearest current color
print("Image after labeling:")
labels = assign_labels(X=X, colors=colors)
plot_clusters(colors=colors, labels=labels, img_shape=img_shape)
print()
# Update step: move each color to the mean of the samples assigned to it
colors = compute_centroids(X=X, colors=colors, labels=labels)
print("Image after re-computing centroids:")
plot_clusters(colors=colors, labels=labels, img_shape=img_shape)
And a second iteration.
# Assignment step: label every sample with its nearest current color
print("Image after labeling:")
labels = assign_labels(X=X, colors=colors)
plot_clusters(colors=colors, labels=labels, img_shape=img_shape)
print()
# Update step: move each color to the mean of the samples assigned to it
colors = compute_centroids(X=X, colors=colors, labels=labels)
print("Image after re-computing centroids:")
plot_clusters(colors=colors, labels=labels, img_shape=img_shape)
One more time before we introduce the loop.
# Assignment step: label every sample with its nearest current color
print("Image after labeling:")
labels = assign_labels(X=X, colors=colors)
plot_clusters(colors=colors, labels=labels, img_shape=img_shape)
print()
# Update step: move each color to the mean of the samples assigned to it
colors = compute_centroids(X=X, colors=colors, labels=labels)
print("Image after re-computing centroids:")
plot_clusters(colors=colors, labels=labels, img_shape=img_shape)
Now for our loop, we are going to set a maximum number of iterations and also check whether the labels have stopped changing (once no label changes, the centroids stop moving too). If either condition is true, we stop!
# Re-seed so the loop starts from a reproducible random palette
np.random.seed(1)
# Draw the 4 initial colors (one row per color, three values each)
colors = np.random.uniform(0, 1, (4, 3))
max_iter = 5
# Placeholder labels of -1 ("no label yet"), one per sample
labels = -np.ones(len(X))
for _ in range(max_iter):
    # Snapshot the previous assignment so we can tell what changed
    old_labels = labels.copy()

    print("Image after labeling:")
    labels = assign_labels(X, colors)
    plot_clusters(colors, labels, img_shape)
    print()

    colors = compute_centroids(X, colors, labels)
    print("Image after re-computing centroids:")
    plot_clusters(colors, labels, img_shape)

    # Stop early once the assignment is stable; otherwise report the churn
    num_changed = (labels != old_labels).sum()
    if num_changed == 0:
        break
    print("{} Labels Changed".format(num_changed))
With a cutoff of 5 iterations, we notice that some labels are still changing. Because of that, we might want to increase the number of iterations in the future. Let's also try this with 10 colors.
# Seed the RNG so the starting palette is reproducible
np.random.seed(0)
# Draw the 10 initial colors (one row per color, three values each)
colors = np.random.uniform(0, 1, (10, 3))
max_iter = 5
# Placeholder labels of -1 ("no label yet"), one per sample
labels = -np.ones(len(X))
for _ in range(max_iter):
    # Snapshot the previous assignment so we can tell what changed
    old_labels = labels.copy()

    print("Image after labeling:")
    labels = assign_labels(X, colors)
    plot_clusters(colors, labels, img_shape)
    print()

    colors = compute_centroids(X, colors, labels)
    print("Image after re-computing centroids:")
    plot_clusters(colors, labels, img_shape)

    # Stop early once the assignment is stable; otherwise report the churn
    num_changed = (labels != old_labels).sum()
    if num_changed == 0:
        break
    print("{} Labels Changed".format(num_changed))
If we get rid of all the plotting and run this for up to 1000 iterations with 5 colors, we can see how long it takes to converge as well as the final picture.
# Seed the RNG so the starting palette is reproducible
np.random.seed(0)
# Draw the 5 initial colors (one row per color, three values each)
colors = np.random.uniform(0, 1, (5, 3))
max_iter = 1000
# Placeholder labels of -1 ("no label yet"), one per sample
labels = np.ones(len(X)) * -1
# Number of iterations actually run (stays 0 only if max_iter is 0)
num_iter = 0
for num_iter in range(1, max_iter + 1):
    # Hold onto the old labels so we can detect a stable assignment
    old_labels = labels.copy()
    labels = assign_labels(X, colors)
    colors = compute_centroids(X, colors, labels)
    # No label changed, so the assignment is stable: we have converged
    if (labels == old_labels).all():
        print("Converged after {} iterations.".format(num_iter))
        break
else:
    # The loop ran out of iterations without ever hitting the break above,
    # so do not claim convergence
    print("Stopped after {} iterations without converging.".format(num_iter))
plot_clusters(colors, labels, img_shape)