K-means clustering (previous discussion) is an unsupervised learning algorithm that assigns each point to one of K clusters, namely the cluster whose centroid is nearest to that point. The points may represent physical locations or embeddings in a high-dimensional vector space.

🌟 Check out the demo (in two dimensions) below. Centroids are colored white. 🌟

Note that the points are changing color only, not moving.

General algorithm

The basic K-means algorithm is fairly simple and has two steps, repeated until convergence (i.e. when no points change cluster):

  1. assign each point to the cluster corresponding to the closest centroid
  2. update each centroid's location to the mean of all points assigned to its cluster
import numpy as np

'''
General setup: we will demo 5000 points, each with two 
dimensions. They will be clustered into 5 clusters.
'''

N, k, d = 5000, 5, 2
points = np.random.random((N,d))
centroids = np.random.random((k,d))

while True:

  # assign each point to closest cluster
  clusters = assign_clusters(points, centroids)

  # recalculate centroids
  new_centroids = update_centroids(points, clusters, k)

  # if centroids haven't changed, we're done
  if centroids_equal(new_centroids, centroids):
    break

  # otherwise, update centroids
  centroids = new_centroids

Vectorizing the algorithm

Now that we’ve established the general algorithm, we can consider how to vectorize it. In essence, this means avoiding explicit for-loops in Python and expressing the computation as whole-array operations instead. In NumPy, the loops still happen, but they run inside precompiled C code rather than in the Python interpreter. For environments with a GPU, we can use PyTorch, JAX or TensorFlow to take advantage of the parallel processing that GPUs excel at.
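To make the idea concrete, here is a minimal sketch comparing an explicit Python loop with the equivalent whole-array expression. The array `x` and the sum-of-squares task are just illustrative choices, not part of the K-means code.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random(1000)

# Explicit Python loop: one interpreted iteration per element
total_loop = 0.0
for value in x:
    total_loop += value ** 2

# Vectorized: the loop over elements runs inside NumPy's compiled code
total_vec = (x ** 2).sum()

# Both approaches agree up to floating-point rounding
print(np.isclose(total_loop, total_vec))
```

Both versions compute the same quantity; the vectorized one simply moves the per-element work out of the interpreter.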

How broadcasting works

Before going further, it's worth pausing to introduce broadcasting. The NumPy guide is much more comprehensive, but the gist is that NumPy needs arrays to have compatible shapes before performing element-wise arithmetic on them, so it automatically tries to align their shapes. Importantly, this is done without making copies of the data. The rules for broadcasting are as follows:

  1. A scalar simply takes the shape of the array it’s being broadcast to.
  2. If one array has fewer dimensions, left-pad its shape with 1s.
  3. Starting from the right, check dimension compatibility. Two dimensions are compatible if they are equal or one of them is 1; a dimension of size 1 is stretched to match the other.

For example,

# scalars are the simplest
(np.ones((2,2)) * 3).shape # (2,2)

# (3,) is left padded to (1,3)
(np.ones((3)) * np.ones((2,3))).shape # (2,3)

# if dimensions equal or one of them is 1, they are compatible
(np.ones((2,3,1)) * np.ones((2,1,4))).shape # (2,3,4)

# error! [1,1] * [1,1,1] doesn't work
np.ones(2) * np.ones(3) # ValueError: operands could not be broadcast together
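The same rules can be checked without doing any arithmetic. The sketch below assumes NumPy 1.20 or later (for `np.broadcast_shapes`) and uses `np.broadcast_to` to illustrate the "no copies" point: the broadcast result is a view whose stride along the stretched dimension is 0.

```python
import numpy as np

# Shapes are compared from the right after left-padding with 1s
print(np.broadcast_shapes((3,), (2, 3)))          # (2, 3)
print(np.broadcast_shapes((2, 3, 1), (2, 1, 4)))  # (2, 3, 4)

# Broadcasting a (3,) array to (2, 3) yields a read-only view, not a copy
row = np.arange(3.0)
view = np.broadcast_to(row, (2, 3))
print(view.strides[0])              # 0: both rows reuse the same memory
print(np.shares_memory(row, view))  # True

# Incompatible shapes raise a ValueError
try:
    np.ones(2) * np.ones(3)
except ValueError as err:
    print("ValueError:", err)
```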

For more information, check out the Numpy guide: https://numpy.org/devdocs/user/basics.broadcasting.html.

assign_clusters()

Ok, back to K-means. Let’s start with assign_clusters(). A point is assigned to the cluster corresponding to the centroid which is closest.

The general idea is that for each point we’d like to know its distance to each centroid. We can then take the argmin of those distances to identify the correct cluster and return a result with shape (N,) (N = # of points). Distance will be computed as Euclidean distance, i.e. subtract, square, and sum.

First we expand the dimensions of the points from (N,d) to (N,1,d). The centroids need to have shape (1,k,d), but broadcasting left-pads their (k,d) shape automatically, so no explicit call is needed. When we take the difference, we are left with an array with shape (N,k,d). For each of the N points, this yields k rows (one per centroid) and d columns (one per coordinate).

Now that we have the differences, we square them and take the sum across the last dimension. This yields an array with shape (N,k).

For actual Euclidean distance we would then take the square root of the result, but since the square root is monotonically increasing, it doesn't change which centroid is closest. We only care about the relative distances, so we can skip this step.
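A quick sanity check of that claim, using random stand-in values rather than real K-means distances (the array `sq_dist` here is purely illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
sq_dist = rng.random((6, 3))   # stand-in for squared distances, shape (N, k)

# The square root is monotonically increasing, so the smallest squared
# distance in each row is also the smallest distance
same = np.array_equal(sq_dist.argmin(-1), np.sqrt(sq_dist).argmin(-1))
print(same)
```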

Finally, we just take the argmin of the last dimension, i.e. for each row, the index of the smallest distance. Now we have our (N,) result.

def assign_clusters(points, centroids):
  # (5000,2) -> (5000,1,2)
  points = np.expand_dims(points, 1)

  # (5000,1,2) - (5,2)
  # with broadcasting becomes (5000,5,2)
  diff = points - centroids

  # square each coordinate difference
  sq_diff = diff ** 2

  # sum over coordinates: (5000,5,2) -> (5000,5)
  sq_dist = sq_diff.sum(-1)

  # index of the nearest centroid: (5000,5) -> (5000,)
  return sq_dist.argmin(-1)
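To see the shapes on something small enough to check by hand, here is a sketch of the same computation on three points and two centroids. The values are made up for illustration, and the explicit `[None, :, :]` on the centroids is only there to show the (1,k,d) shape that broadcasting would otherwise supply automatically.

```python
import numpy as np

points = np.array([[0.0, 0.0], [1.0, 1.0], [0.9, 0.8]])   # (N=3, d=2)
centroids = np.array([[0.1, 0.1], [1.0, 0.9]])            # (k=2, d=2)

# (3,1,2) - (1,2,2) broadcasts to (3,2,2)
diff = points[:, None, :] - centroids[None, :, :]

# squared distances from every point to every centroid, shape (3,2)
sq_dist = (diff ** 2).sum(-1)

# index of the nearest centroid for each point, shape (3,)
clusters = sq_dist.argmin(-1)
print(clusters)  # [0 1 1]
```

The first point sits next to centroid 0, while the other two are nearest to centroid 1.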

update_centroids()

To update the centroids, we simply take the mean of the points within each cluster. To do this, we will create an array with shape (N,k,d) again, and mask out the points not in each cluster.

We start by creating an N x k mask, where 1s correspond to the n-th point being in the k-th cluster. We will also temporarily add a dimension at the end to make this N x k x 1.

Adding a dimension to the points vector for proper broadcasting, we can then create the N x k x d vector we planned above. Summing this across the first dimension yields a k x d vector, representing the sums of all of the points within each of the k clusters.

Now we just need to divide by the number of points in each cluster, which we can get by summing the N x k mask we created earlier across the first dimension, yielding k counts.

At this point, we can divide the sum by the counts to get the means.

def update_centroids(points, clusters, k):
    N = clusters.shape[0]

    # create an N x k mask, where each row contains a 1
    # in the position corresponding to its cluster
    mask = np.zeros((N, k))
    mask[np.arange(N), clusters] = 1

    # add a dimension in the middle so that 
    # multiplication will broadcast
    # (5000,2) -> (5000,1,2)
    points = np.expand_dims(points, 1)

    # (5000,5,1) * (5000,1,2) results in a (5000,5,2) vector
    # for each of the 5000 points, we have a 5x2 vector where only
    # one of the five rows will have non-zero values
    masked_points = np.expand_dims(mask, -1) * points

    # sum across points
    # (5000, 5, 2) -> (5, 2)
    sums = masked_points.sum(0)

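    # count the points in each cluster
    # (5000, 5) -> (5,)
    counts = mask.sum(0)

    # avoid divide by zero for empty clusters
    # (their sums are 0, so the centroid becomes the origin)
    counts[counts == 0] = 1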

    # divide sum by count to get means
    return sums / np.expand_dims(counts, -1)
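As an aside, the masked sum can also be written as a matrix product, which avoids materializing the (N,k,d) intermediate. This is an alternative formulation, not the approach used above, and the small inputs are made up so the result can be checked by hand.

```python
import numpy as np

points = np.array([[0.0, 0.0], [1.0, 1.0], [0.8, 0.6], [0.2, 0.0]])
clusters = np.array([0, 1, 1, 0])
k = 3  # cluster 2 is deliberately left empty

N = points.shape[0]
mask = np.zeros((N, k))
mask[np.arange(N), clusters] = 1

# (k,N) @ (N,d) -> (k,d): per-cluster sums without the (N,k,d) intermediate
sums = mask.T @ points
counts = mask.sum(0)
counts[counts == 0] = 1   # same guard as above: empty clusters divide by 1

means = sums / counts[:, None]
print(means)
```

Clusters 0 and 1 get the means of their two points each, (0.1, 0.0) and (0.9, 0.8), and the empty cluster 2 ends up at the origin because of the divide-by-zero guard.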

centroids_equal()

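Finally, we need to check whether the centroids have changed. For this we can just call np.array_equal() on the old and new centroid arrays. An exact comparison works here because once the cluster assignments stop changing, recomputing the means from the same points gives identical centroids.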

def centroids_equal(vec1, vec2):
  return np.array_equal(vec1, vec2)
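If you would rather stop once the centroids have moved by less than some small amount, a tolerance-based check is one option. The helper name `centroids_close` and the default `atol` below are hypothetical choices, not part of the original implementation.

```python
import numpy as np

def centroids_close(vec1, vec2, atol=1e-8):
    # hypothetical alternative to centroids_equal(): treat centroids as
    # unchanged when every coordinate moved by at most atol
    return np.allclose(vec1, vec2, rtol=0.0, atol=atol)

a = np.array([[0.5, 0.5]])
print(centroids_close(a, a + 1e-12))  # True
print(np.array_equal(a, a + 1e-12))   # False
```

With a tolerance the loop may stop an iteration or two earlier than the exact check, at the cost of centroids that are only approximately converged.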

Comparing vectorization vs loops

To compare the vectorized implementation with the non-vectorized implementation, I ran each 100 times and took the mean wall-clock time to converge. Note that because the points are random, K-means may take a different number of iterations to converge on each run.
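The benchmark harness itself isn't shown here, so the following is only a sketch of how such a measurement could be set up. `mean_runtime` and `dummy_workload` are hypothetical names, and the workload is a placeholder standing in for a full K-means run.

```python
import time
import numpy as np

def mean_runtime(fn, runs=100):
    # hypothetical helper: mean wall-clock seconds over `runs` calls of fn
    durations = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        durations.append(time.perf_counter() - start)
    return float(np.mean(durations))

def dummy_workload():
    # placeholder for a full K-means run on freshly sampled points
    pts = np.random.random((5000, 2))
    return (pts ** 2).sum(-1).argmin()

print(mean_runtime(dummy_workload, runs=10))
```

Passing each implementation as `fn` would give directly comparable means, keeping in mind that the iteration count varies from run to run.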

The vectorized implementation (green) runs in significantly less time, between 7 and 21 times faster.