前言

大家好,之前写多了自动化办公的内容,现在换个机器学习的专题跟大家交流学习,作为一个眼科研究生后面也希望后面多通过一些眼科案例顺带普及下眼科知识!在眼科中AI的一项应用就是利用卷积神经网络实现图像识别。今天先从一个虚构的冠心病数据集说说python如何实现简单的有监督学习

数据说明

因文章以分享技术为目的,疾病数据集不含有现实意义,且出于保护目的将四个特征指标以S1-S4替代


400+多位病人的数据,包含年龄、性别(1为男性,2为女性),S1-S4为4个冠心病检测指标,Results是冠心病高相关性的定量指标,也是我们本次设计模型需要预测的指标

有监督学习是指有目标变量或预测目标的机器学习方法,包括分类和回归

本例中需要预测的是连续的定量指标,属于回归问题。作为入门介绍就简单利用scikit-learn库中的LinearRegression()实现

初步代码实现

首先导入需要的库并设置好文件路径

import pandas as pd
# 分隔训练集和测试集,本例用7:3
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
import os

# 利用函数定位到桌面文件夹,个人习惯。可指定绝对路径
def GetDesktopPath():
    return os.path.join(os.path.expanduser("~"), 'Desktop')

dat_path = f'{GetDesktopPath()}\\data\\冠心病.csv'

指定特征列(需要纳入预测模型的指标)

features = ['Age', 'Sex', 'S1', 'S2', 'S3', 'S4']

读取数据集并分隔

CAD_data = pd.read_csv(dat_path)
X = CAD_data[features].values
y = CAD_data['Results'].values

# 分割数据集,本例训练集和测试集分割比例为7:3
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=3/10, random_state=10)

random_state参数设置的好处:

  • 读入数据集如果是有序整理好的,如果不随机则模型构建效果大打折

  • 类似于R语言的set.seed()函数,设定生成随机数的种子,让结果能够重现

现在建立线性回归模型,并进行训练及验证

linear_model = LinearRegression()
linear_model.fit(X_train, y_train)
R2 = linear_model.score(X_test, y_test)
print('基础线性模型的R2值为:{:.4f}'.format(R2))

# 基础线性模型的R2值为:0.4100

模型优化

上述基础线性回归模型存在几个问题:

  • 不同的数值变量所处的范围不同,可以考虑归一化,消除量纲或者其他因素可能引入的偏差,影响模型精度。有一种常用方法是将数值线性缩放到 [-1, 1] 或 [0, 1]

  • 性别是分类变量,男女彼此没有高低之分。抽象来说就是离散特征的取值之间没有大小的意义,但用 1 和 0 代替分类变量进入模型中会引入数值大小的区别。数据预处理针对这类变量可以考虑使用独热编码 (One-Hot Encode),又称一位有效编码,其方法是使用N位状态寄存器来对N个状态进行编码。独热编码在各类算法中运用广泛,这里只是非常简单的运用。简单理解就是男性 -> [0, 1],女性 -> [1, 0]

综上,对特征进行分类后分别进行相应的处理,有时会使模型性能提升。这属于特征工程的范畴,仅类别特征可再细分为:

有兴趣的读者可以自行了解。下面是优化模型的代码。


首先引入所需库

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
import os
# 引入numpy库做多维数据合并
import numpy as np 
# 数据预处理用到的独热编码和最大最小归一化
from sklearn.preprocessing import OneHotEncoder, MinMaxScaler

def GetDesktopPath():
    return os.path.join(os.path.expanduser("~"), 'Desktop')

dat_path = f'{GetDesktopPath()}\\data\\冠心病.csv'

特征分类

numeric_features = ['Age''S1''S2''S3''S4']
category_features = ['Sex']numeric_features = ['Age''S1''S2''S3''S4']
category_features = ['Sex']

读取数据

CAD_data = pd.read_csv(dat_path)

X = CAD_data[numeric_features + category_features]
y = CAD_data['Results']

数据预处理,注意训练集和测试集的特征都需要预处理,因此可以考虑封装成函数方便调用

def preprocessing(train, test):
    # 独热编码处理分类变量
    encoder = OneHotEncoder(sparse=False)
    encoded_train = encoder.fit_transform(train[category_features])
    encoded_test = encoder.transform(test[category_features])
    # 归一化处理数值变量
    scaler = MinMaxScaler()
    scaled_train = scaler.fit_transform(train[numeric_features])
    scaled_test = scaler.transform(test[numeric_features])
    # 横向合并
    train_new = np.hstack((encoded_train, scaled_train))
    test_new = np.hstack((encoded_test, scaled_test))
    # 返回数据
    return train_new, test_new

分隔数据集并对特征预处理

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=3/10, random_state=10)
X_train_new, X_test_new = preprocessing(X_train, X_test)

最后构建新模型,训练并验证

linear_model_new = LinearRegression()
linear_model_new.fit(X_train_new, y_train)
R2_new = linear_model_new.score(X_test_new, y_test)
print('优化线性模型的R2值为:{:.4f}'.format(R2_new))

我们看一下结果,并跟旧模型进行比对

结束语

其实这个结果并不出乎意料,样本量小也是特征预处理不能发挥出显著优化作用的一个原因。另外,针对模型优化可以再指出的一点是,如果特征较多时往往也不会全部纳入模型中拟合,也要考虑相关性做适当舍弃剪裁。例如本例中实际上去掉年龄Age特征后模型的R值上升会比直接预处理更明显!


当然,本实例的目的不是为了将模型优化的多好,而是希望通过这个简单的案例能够吸引更多的人学习Python,学习人工智能,并用于现实世界,产生新的思想并创造价值!


©著作权归作者所有:来自51CTO博客作者mb5fe18e32e4691的原创作品,如需转载,请注明出处,否则将追究法律责任

更多相关文章

  1. js中基础数据结构数组去重问题
  2. 你真的懂网络分层模型吗?
  3. mysql group_concat 获取一对多的数据
  4. 【工具】历史文章分类汇总-V4 | Python数据之道
  5. 【工具】历史文章分类汇总-V5 | Python数据之道
  6. 福布斯系列之数据清洗(3) | Python数据分析项目实战
  7. 福布斯系列之数据清洗(5) | Python数据分析项目实战
  8. 福布斯系列之数据清洗(2) | Python数据分析项目实战
  9. 福布斯系列之补充数据收集 | Python数据分析项目实战

随机推荐

  1. Android Content Providers(二)——Contact
  2. WebView的使用之Android与JS通过WebView
  3. 【Android】背景知识
  4. Android中的shape中的属性大全
  5. gif in android
  6. 【Android】开源项目汇总
  7. Android ui基础——gravity 与 layout_gr
  8. Android SDCard Mount 流程分析(一)
  9. Android inputType ,软键盘输入类型
  10. android EditText中的inputType