Soft K-Means Clustering

Introduction / Background

Soft K-Means Clustering is an extension of K-Means Clustering. Instead of assigning each point to exactly one cluster, it gives each point a probability of belonging to every cluster, computed with the softmax function:

$$\text{softmax}(x_i) = \frac{e^{x_i}}{\sum_{j=1}^{n} e^{x_j}}$$
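
As a quick sketch (illustrative NumPy code, not part of the original implementation), the softmax of a vector can be computed as:

import numpy as np

def softmax(x):
    # exponentiate each entry, then normalize so the entries sum to 1
    exps = np.exp(x)
    return exps / np.sum(exps)

print(softmax(np.array([1.0, 2.0, 3.0])))  # ~[0.09, 0.24, 0.67]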

Implementation

Algorithm Inputs

def soft_k_means(X, k=3, max_iterations=3, beta=1.0):

Here is the function definition. Its inputs are X, an N×d matrix of data points (one row per point); k, the number of clusters; max_iterations, the maximum number of iterations of the main loop; and beta, which controls how sharp the soft assignments are:

$$X = \begin{bmatrix} x_{1,1} & \cdots & x_{1,d} \\ x_{2,1} & \cdots & x_{2,d} \\ \vdots & \ddots & \vdots \\ x_{N,1} & \cdots & x_{N,d} \end{bmatrix}$$
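
For illustration (a small synthetic example, not from the original post), X might look like:

import numpy as np

# N = 6 points in d = 2 dimensions, loosely grouped around two centers
X = np.array([
    [0.0, 0.1], [0.2, -0.1], [-0.1, 0.0],   # near (0, 0)
    [5.0, 5.1], [5.2, 4.9], [4.8, 5.0],     # near (5, 5)
])
print(X.shape)  # (6, 2)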

Setup

This algorithm follows the same setup as that of K-Means Clustering: read N and the dimensionality from X, create an empty list to collect the per-iteration results, and initialize the k means as k distinct points chosen at random from X.
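
For reference, these are the setup lines (taken from the full implementation at the end of this post):

N, dim = X.shape
ret = []
# choose k distinct rows of X as the initial means
mus = X[np.random.choice(N, size=(k,), replace=False)]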

Main Loop

for i in range(max_iterations):
    # 1. calculate the distance between each point and each of the k means
    # 2. calculate the probability (responsibility) of each point belonging to each mean
    # (optionally calculate the loss)
    # 3. if the hard labels have not changed since the previous iteration, stop early
    # 4. update each mean to be the responsibility-weighted average of all points

Loop Step 1

dists = euclidean(X, mus)
$$\begin{bmatrix} \text{dist}_{1,1} & \cdots & \text{dist}_{1,k} \\ \text{dist}_{2,1} & \cdots & \text{dist}_{2,k} \\ \vdots & \ddots & \vdots \\ \text{dist}_{N,1} & \cdots & \text{dist}_{N,k} \end{bmatrix}$$
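
The euclidean helper is not defined in this post; a minimal sketch of what it is assumed to do (return the N×k matrix of distances between every point and every mean) is:

import numpy as np

def euclidean(X, mus):
    # pairwise Euclidean distances between the N rows of X and the k rows of mus
    # broadcasting gives an (N, k, d) array of differences; the norm is taken over the last axis
    return np.linalg.norm(X[:, None, :] - mus[None, :, :], axis=2)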

Loop Step 2

exps = np.exp(-beta*dists) # Nxk
r = exps/np.sum(exps, axis=1, keepdims=True) # Nxk

Note that, as in the full implementation below, the distances are negated and scaled by beta before exponentiating, so that nearer means receive higher probability.

$$\begin{bmatrix} e^{-\beta\,\text{dist}_{1,1}} & \cdots & e^{-\beta\,\text{dist}_{1,k}} \\ e^{-\beta\,\text{dist}_{2,1}} & \cdots & e^{-\beta\,\text{dist}_{2,k}} \\ \vdots & \ddots & \vdots \\ e^{-\beta\,\text{dist}_{N,1}} & \cdots & e^{-\beta\,\text{dist}_{N,k}} \end{bmatrix}$$

$$\begin{bmatrix} e^{-\beta\,\text{dist}_{1,1}} / \sum_{m=1}^{k} e^{-\beta\,\text{dist}_{1,m}} & \cdots & e^{-\beta\,\text{dist}_{1,k}} / \sum_{m=1}^{k} e^{-\beta\,\text{dist}_{1,m}} \\ \vdots & \ddots & \vdots \\ e^{-\beta\,\text{dist}_{N,1}} / \sum_{m=1}^{k} e^{-\beta\,\text{dist}_{N,m}} & \cdots & e^{-\beta\,\text{dist}_{N,k}} / \sum_{m=1}^{k} e^{-\beta\,\text{dist}_{N,m}} \end{bmatrix}$$
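
Entry-wise, the responsibility of mean j for point i is the softmax of the negatively scaled distances:

$$r_{i,j} = \frac{e^{-\beta\,\text{dist}_{i,j}}}{\sum_{m=1}^{k} e^{-\beta\,\text{dist}_{i,m}}}$$

Larger values of beta make this distribution sharper, approaching the hard assignments of ordinary K-Means.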

Loop Step 3

labels = np.argmin(dists, axis=1)
if len(ret) > 0 and (ret[-1] == labels).all():
    print(f"Early stop at index {i}")
    break
ret.append(labels)
$$\begin{bmatrix} \text{label}_1 \\ \text{label}_2 \\ \vdots \\ \text{label}_N \end{bmatrix}$$

where $\text{label}_i \in [0, k)$ and $\text{label}_i \in \mathbb{Z}$, i.e. each label is one of $\{0, 1, \dots, k-1\}$.
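
The optional loss computed in the full implementation below sums, for each cluster, the distances from the points currently labeled with that cluster to its mean:

$$\text{loss} = \sum_{j=0}^{k-1} \; \sum_{i\,:\,\text{label}_i = j} \text{dist}_{i,j}$$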

Loop Step 4

for j in range(k):
    mus[j] = r[:,j].dot(X)/np.sum(r[:,j]) 
$$r_{:,j} = \begin{bmatrix} e^{-\beta\,\text{dist}_{1,j}} / \sum_{m=1}^{k} e^{-\beta\,\text{dist}_{1,m}} \\ e^{-\beta\,\text{dist}_{2,j}} / \sum_{m=1}^{k} e^{-\beta\,\text{dist}_{2,m}} \\ \vdots \\ e^{-\beta\,\text{dist}_{N,j}} / \sum_{m=1}^{k} e^{-\beta\,\text{dist}_{N,m}} \end{bmatrix}$$
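
In other words, each mean is moved to the responsibility-weighted average of all points:

$$\mu_j = \frac{\sum_{i=1}^{N} r_{i,j}\, x_i}{\sum_{i=1}^{N} r_{i,j}}$$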

Overall code

import numpy as np

# assumes the euclidean(X, mus) helper shown in Loop Step 1 is in scope
def soft_k_means(X, k=3, max_iterations=3, beta=1.0):
    N, dim = X.shape
    ret = []
    # initialize the k means as k distinct points chosen at random from X
    mus = X[np.random.choice(N, size=(k,), replace=False)]

    for i in range(max_iterations):
        # Step 1: distance between each point and each of the k means (N x k)
        dists = euclidean(X, mus)
        # Step 2: softmax over the negatively scaled distances -> responsibilities (N x k)
        exps = np.exp(-beta*dists)
        r = exps/np.sum(exps, axis=1, keepdims=True)
        # Step 3: hard labels, used for the early-stop check
        labels = np.argmin(dists, axis=1)
        # the line below is the optional loss calculation
        loss = sum([np.sum(dists[np.where(labels==j), j]) for j in range(k)])
        if len(ret) > 0 and (ret[-1][0] == labels).all():
            print(f"EARLY STOP AT {i}, max_iterations={max_iterations}")
            break
        ret.append((labels, loss))

        # Step 4: move each mean to the responsibility-weighted average of all points
        for j in range(k):
            mus[j] = r[:,j].dot(X)/np.sum(r[:,j])

    return ret
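
A quick usage sketch (illustrative data and parameter values, not from the original post):

import numpy as np

np.random.seed(0)
# two well-separated blobs of 2-D points; float data so the mean updates are not truncated
X = np.vstack([
    np.random.randn(50, 2),
    np.random.randn(50, 2) + np.array([8.0, 8.0]),
])

ret = soft_k_means(X, k=2, max_iterations=10, beta=1.0)
labels, loss = ret[-1]
print(labels[:5], labels[-5:], loss)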
