Kmeans++ 对图像聚类
kmeans算法是较为常见的聚类算法,不仅可以对二维的坐标点进行聚类,还可以对高维的图像信息进行聚类。Kmeans算法对初始质心的选择比较敏感,Kmeans++算法针对初始质心的选择做了改进,使得几个初始质心尽可能的远。
在使用kmeans算法对二维坐标进行聚类时,聚类的依据是坐标点与质心之间的距离;同样,对于高维度的图像信息,可以将像素点之间的差异看作距离,这样得到的每个簇,都是像素点差异较小的图,简单来说,每个簇内是图像相似度较高的图像。
这里使用kmeans++算法,对CIFAR10数据集进行聚类。CIFAR10是一个用于图像分类的数据集,共有10个类别,每张图像的大小为32*32*3。程序在CIFAR10数据集内挑选出了200张图像,并对这200张图像进行聚类,k值设为10。
聚类结果:

簇1

簇2

簇3
簇4

簇5

簇6
簇7
簇8
簇9
簇10
其中,avg.jpg是质心,是簇内所有图像的平均值,是对簇内信息的抽象反映。
可以看出,聚类效果还是不错的,把一些较为相似的图像放在了一个簇内。
代码实现:
import torchvision
import torch
import numpy as np
from torch.utils.data import DataLoader
import os
import cv2
import randomdef load_data():transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])train = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=transform)train_loader = torch.utils.data.DataLoader(train, batch_size=1)for i, (img, _) in enumerate(train_loader):images.append(np.transpose(img[0].numpy(), (1, 2, 0)))if i > 200:breakdef distance(img, centroids):return np.array([np.sum(centroid - img) ** 2 for centroid in centroids])def kmeans_plus(k):centroids = []idx = random.randint(0, len(images) - 1)centroids.append(images.pop(idx))for _ in range(k - 1):sum = 0dx = np.zeros((len(images),))for i, img in enumerate(images):dx[i] = np.min([(img - centroid) ** 2 for centroid in centroids])sum += dx.sum()p = np.array(dx) / summax_idx = np.argmax(p)centroids.append(images.pop(max_idx))print("finish")return centroidsdef kmeans(k=10):centroids = kmeans_plus(k)clu = dict()for epoch in range(100):for i in range(k):clu[i] = []for img in images:index = distance(img, centroids).argmin()clu[index].append(img)for i in range(k):sum = np.zeros_like(img)for img in clu[i]:sum += imgmean = sum / len(clu[i])centroids[i] = meanprint(epoch)return clu, centroidsif __name__ == '__main__':images = []load_data()clu, centroids = kmeans()for i in range(len(centroids)):os.mkdir(f"./{i}")ims = clu[i]sum = np.zeros_like(images[0].shape)for idx, im in enumerate(ims):cv2.imwrite(f"./{i}/{idx}.jpg", cv2.resize(im * 255, (320, 320)))sum = sum + imsum = sum / (idx + 1)cv2.imwrite(f"./{i}/avg.jpg", cv2.resize((sum * 255), (320, 320)))
本文来自互联网用户投稿,文章观点仅代表作者本人,不代表本站立场,不承担相关法律责任。如若转载,请注明出处。 如若内容造成侵权/违法违规/事实不符,请点击【内容举报】进行投诉反馈!
