[scikit-learn]普通最小二乘法

返回上级目录

1.1.1普通最小二乘法

对于一般直线拟合,可以使用:

通过计算

$\sum_{i=1}^{n}{\left(y_i-a-bx_i\right)=0}$

$\sum_{i=1}^{n}{\left(y_i-a-bx_i\right)×x_i=0}$

​求得a,b,得到结果。

对于一般多项式最小二乘法的矩阵形式为

$Ax=b$

其中$A$为$x\times k$的矩阵,$x$为$k\times 1$的列向量

所以通过计算​法方程

$\left[\begin{matrix}m&\sum x_i&\sum x_i^2\\sum x_i&\sum x_i^2&\sum x_i^3\\sum x_i^2&\sum x_i^3&\sum x_i^4\\end{matrix}\right]\left[\begin{matrix}a_0\a_1\a_2\\end{matrix}\right]=\left[\begin{matrix}\sum y_i\\sum{x_iy}_i\\sum{x_i^2y}_i\\end{matrix}\right]$

解的系数a

对于sklearn中,实现方式只需要使用fit即可。最终得到的结果系数集合$A$(在官方文档中为$w$)被保存在coef_中,实现代码如下:

1.1.1.1示例

from sklearn import linear_model
reg = linear_model.LinearRegression()
reg.fit([[0, 0], [1, 1], [2, 2]], [0, 1, 2])
print(reg.coef_)

针对上面的简单例子,我们可以知道调用了LinearRegression()用于线性回归。取以下几点$(0,0,0),(1,1,1),(2,2,2)$,很容易看出,该公式一共有两个变量:

$y=w_0+w_1x_1+w_2x_2$

同样,我们很容易看出$w$为[0,0.5,0.5],而实际运算结果

PS E:\文本资料\sklearn> & 'C:\ProgramData\Anaconda3\python.exe' 'c:\Users\wren1\.vscode\extensions\ms-python.python-2019.9.34911\pythonFiles\ptvsd_launcher.py' '--default' '--client' '--host' 'localhost' '--port' '52819' 'e:\文本资料\sklearn\代码\code1.1.1.1.py'
[0.5 0.5]

1.1.1.2 示例

示例使用的是diabetes数据集第一个特征,尝试回归成一条直线

我将代码的注释进行了翻译并对一些内容进行了注解。如下:


import matplotlib.pyplot as plt import numpy as np from sklearn import datasets, linear_model from sklearn.metrics import mean_squared_error, r2_score # 加载diabetes数据集 diabetes = datasets.load_diabetes() # 只使用其中一列 diabetes_X = diabetes.data[:, np.newaxis, 2] # 将数据划分成训练集和验证集两个部分,其中验证集为后20个 diabetes_X_train = diabetes_X[:-20] diabetes_X_test = diabetes_X[-20:] # 相对应的结果也同样划分 diabetes_y_train = diabetes.target[:-20] diabetes_y_test = diabetes.target[-20:] # 实例化线性回归 regr = linear_model.LinearRegression() # 使用fit对训练集进行训练 regr.fit(diabetes_X_train, diabetes_y_train) # 再使用验证集对结果进行预测 diabetes_y_pred = regr.predict(diabetes_X_test) # 系数 print('系数: \n', regr.coef_) # 均方差 print("均方差: %.2f" % mean_squared_error(diabetes_y_test, diabetes_y_pred)) # 模型得分:1分最优 print('得分: %.2f / 1' % r2_score(diabetes_y_test, diabetes_y_pred)) # 使用PLT对结果输出 plt.scatter(diabetes_X_test, diabetes_y_test, color='black') plt.plot(diabetes_X_test, diabetes_y_pred, color='blue', linewidth=3) plt.xticks(()) plt.yticks(()) plt.show()

最终结果如图:

avatar

系数:
 [938.23786125]
均方差: 2548.07
得分: 0.47 / 1

发表评论

电子邮件地址不会被公开。