机器人与人工智能爱好者论坛

 找回密码
 立即注册
查看: 15055|回复: 0
打印 上一主题 下一主题

K近邻算法

[复制链接]

115

主题

116

帖子

392

积分

中级会员

Rank: 3Rank: 3

积分
392
跳转到指定楼层
楼主
发表于 2019-1-16 14:46:35 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
算法原理
k近邻(k-Nearest Neighbor,kNN),应该是最简单的传统机器学习模型,给定一个训练数据集,对新的输入实例,在训练数据集中找到与该实例最邻近的k个实例,这k个实例中的大多数属于哪个类别,就把该输入实例划分到这个类别。

k近邻算法没有显示的训练过程,在“训练阶段”仅仅是把样本保存起来,训练时间开销为零,待收到测试样本后在进行计算处理。
这个k实际上是一个超参数,k值的选择会对k近邻法的结果产生重大影响。如果选择较小的k值,意味着只有与输入实例较近的(相似的)训练实例才会对预测结果起作用,预测结果会对近邻的实例点非常敏感,如果近邻的实例点恰巧是噪声点,预测就会出错;如果选择较大的k值,就意味着与输入实例较远的(不相似的)训练实例也会对预测起作用,这样预测也会出错。在实际应用中,k值一般取一个比较小的数值,并且通常采用交叉验证法来选取最优的k值。如上图的k=5。
模型训练代码地址:https://github.com/qianshuang/ml-exp
def train():
   print("start training...")
   # 处理训练数据
   train_feature, train_target = process_file(train_dir, word_to_id, cat_to_id)
   # 模型训练
   model.fit(train_feature, train_target)
def test():
   print("start testing...")
   # 处理测试数据
   test_feature, test_target = process_file(test_dir, word_to_id, cat_to_id)
   # test_predict = model.predict(test_feature)  # 返回预测类别
   test_predict_proba = model.predict_proba(test_feature)    # 返回属于各个类别的概率
   test_predict = np.argmax(test_predict_proba, 1)  # 返回概率最大的类别标签
   # accuracy
   true_false = (test_predict == test_target)
   accuracy = np.count_nonzero(true_false) / float(len(test_target))
   print()
   print("accuracy is %f" % accuracy)
   # precision    recall  f1-score
   print()
   print(metrics.classification_report(test_target, test_predict, target_names=categories))
   # 混淆矩阵
   print("Confusion Matrix...")
   print(metrics.confusion_matrix(test_target, test_predict))
if not os.path.exists(vocab_dir):
   # 构建词典表
   build_vocab(train_dir, vocab_dir)
categories, cat_to_id = read_category()
words, word_to_id = read_vocab(vocab_dir)
# kNN
model = neighbors.KNeighborsClassifier()
train()
test()运行结果:
read_category...
read_vocab...
start training...
start testing...
accuracy is 0.820000
            precision    recall  f1-score   support
        时政       0.65      0.85      0.74        94
        财经       0.81      0.94      0.87       115
        科技       0.96      0.97      0.96        94
        游戏       0.99      0.74      0.85       104
        娱乐       0.99      0.75      0.85        89
        时尚       0.88      0.67      0.76        91
        家居       0.44      0.78      0.56        89
        房产       0.93      0.82      0.87       104
        体育       1.00      0.98      0.99       116
        教育       0.96      0.65      0.78       104
avg / total       0.87      0.82      0.83      1000
Confusion Matrix...
[[ 80   4   0   0   0   0   6   3   0   1]
[  1 108   0   0   0   0   6   0   0   0]
[  0   0  91   0   0   0   3   0   0   0]
[  4   0   1  77   0   3  18   0   0   1]
[  4   3   0   1  67   4  10   0   0   0]
[  0   0   0   0   1  61  29   0   0   0]
[  9   5   2   0   0   0  69   3   0   1]
[  9   3   0   0   0   0   7  85   0   0]
[  2   0   0   0   0   0   0   0 114   0]
[ 14  10   1   0   0   1  10   0   0  68]]
社群:

公众号:

了解更多干货关注小程序八斗问答
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

关闭

站长推荐上一条 /1 下一条

QQ|Archiver|手机版|小黑屋|陕ICP备15012670号-1    

GMT+8, 2024-11-25 15:29 , Processed in 0.058358 second(s), 24 queries .

Powered by Discuz! X3.2

© 2001-2013 Comsenz Inc.

快速回复 返回顶部 返回列表