单层感知机在Mnist上的实现

单层感知机

原理博客:统计学习方法|K近邻原理剖析及实现 | Dodo (pkudodo.com)

数据集:Statistical-Learning-Method_Code/Mnist at master · Dod-o/Statistical-Learning-Method_Code (github.com)

读取数据

数据格式:

mnist数据集,第一列为label,后面为data

image-20221015185327165

完整代码

import numpy as np
import matplotlib.pyplot as pltdef read_data(addr):# 打开文件data_csv = np.loadtxt(open(addr, 'r'), delimiter=',')# 便签数据labels = data_csv[:, 0]flag = labels > 5labels[flag] = 1labels[~flag] = -1# 数据data = data_csv[:, 1:] / 255labels = labels.reshape(labels.shape[0], 1)return data, labelsdef perceptron(data, labels, epochs=50, lr=0.01):# 初始化参数 w bw_num = data.shape[-1]w = np.zeros([1, w_num])w = np.random.normal(size=[1, w_num])b = 0.loss = 0.data_num = data.shape[0]for epoch in range(epochs):for image, label in zip(data, labels):loss = -1 * label * (np.matmul(image, w.T) + b)# 计算距离超平面的距离if loss >= 0:w = w + label * image * lrb = b + label * lrprint("epoch:{},loss:{}".format(epoch, loss))return w, bdef test(w, b, data, labels):err_count = 0sum_count = data.shape[0]for image, label in zip(data, labels):output = -1 * label * (np.matmul(image, w.T) + b)err_count = err_count + 1 if output >= 0 else err_countprint("准确率为:{}".format((sum_count - err_count) / sum_count))if __name__ == '__main__':addr = '../Mnist/mnist_train.csv'test_addr = '../Mnist/mnist_test.csv'data, lables = read_data(addr)w, b = perceptron(data, lables, epochs=10)test_data, test_labels = read_data(test_addr)test(w, b, test_data, test_labels)# image = data[5].reshape(28, 28)# plt.imshow(image)# plt.show()


本文来自互联网用户投稿,文章观点仅代表作者本人,不代表本站立场,不承担相关法律责任。如若转载,请注明出处。 如若内容造成侵权/违法违规/事实不符,请点击【内容举报】进行投诉反馈!

相关文章

立即
投稿

微信公众账号

微信扫一扫加关注

返回
顶部