![深入浅出Python机器学习](https://wfqqreader-1252317822.image.myqcloud.com/cover/94/44510094/b_44510094.jpg)
4.1.2 线性模型的图形表示
大家肯定还记得,我们在初中数学(也可能是小学数学)中学过,两个点可以确定一条直线。假设有两个点,它们的坐标是(1,3)和(4,5),那么我们可以画一条直线来穿过这两个点,并且计算出这条直线的方程。下面我们在Jupyter Notebook中输入代码如下:
![](https://epubservercos.yuewen.com/89A715/23721687209561106/epubprivate/OEBPS/Images/Figure-P62_22491.jpg?sign=1738953983-tzqfM6x1qiUSKuNw64fFl8APOM5vXvFa-0-49c326704f5abc94ee9211314fd2aa36)
运行代码,将会得到如图4-2所示的结果。
![](https://epubservercos.yuewen.com/89A715/23721687209561106/epubprivate/OEBPS/Images/Figure-P63_4683.jpg?sign=1738953983-o3TxBtgikfyppvupQirUdutDYgzXGw5g-0-d1e3486e2e662a0756dbbad2017049c3)
图4-2 穿过点(1,3)和(4,5)的直线
【结果分析】图4-2表示的就是穿过上述两个数据点的直线,现在我们可以确定这条直线的方程。
在Jupyter Notebook中输入代码如下:
print('\n\n\n直线方程为:') print('==========\n') #打印直线方程 print('y = {:.3f}'.format(lr.coef_[0]),'x','+ {:.3f}'.format(lr.intercept_)) print('\n==========') print('\n\n\n')
运行代码,将会得到如图4-3所示的结果。
![](https://epubservercos.yuewen.com/89A715/23721687209561106/epubprivate/OEBPS/Images/Figure-P63_4736.jpg?sign=1738953983-HVKyqakJAptxrjDQF800ZzfxMffctY94-0-45f2f9bf11f96491fc333e2204b763b4)
图4-3 程序计算出的直线方程
【结果分析】通过程序的计算,我们很容易就可以得到这条直线的方程为
y = 0.667 x + 2.333
这是数据中只有2个点的情况,那如果是3个点会是怎样的情况呢?我们来实验一下,假设现在有第3个点,坐标是(3,3),我们把这个点添加进去,看会得到怎样的结果。输入代码如下:
![](https://epubservercos.yuewen.com/89A715/23721687209561106/epubprivate/OEBPS/Images/Figure-P64_22561.jpg?sign=1738953983-q1h8E0RV5D5e7gVQjEUKIfIMjBLSjqZj-0-5ae1e7861ac70955b28ea7ad4d2d9a10)
运行代码,会得到如图4-4所示的结果。
![](https://epubservercos.yuewen.com/89A715/23721687209561106/epubprivate/OEBPS/Images/Figure-P64_22568.jpg?sign=1738953983-5qLIC9988eQctEni2BdfTMBBlNqEbXUJ-0-59b7d165f96c10ee1c3fca76109ac2ce)
图4-4 对3个点进行拟合的线性模型
【结果分析】从图4-4中我们可以看到,这次直线没有穿过任何一个点,而是位于一个和3个点的距离相加最小的位置。
下面我们可以在计算出这条直线的方程,输入代码如下:
print('\n\n\n新的直线方程为:') print('==========\n') #打印直线方程 print('y = {:.3f}'.format(lr.coef_[0]),'x','+ {:.3f}'.format(lr.intercept_)) print('\n==========') print('\n\n\n')
![](https://epubservercos.yuewen.com/89A715/23721687209561106/epubprivate/OEBPS/Images/Figure-P65_22761.jpg?sign=1738953983-FPQzQaYzy3iQ1I5Z0JsIn3ff7bvW0kdH-0-9cbd5b3e2b84465e8dac0b665fd63617)
图4-5 对3个点进行拟合的线性模型方程
运行代码,将会得到结果如图4-5所示。
【结果分析】从图4-5中我们可以看到,新的直线方程和只有2个数据点的直线方程已经发生了变化。线性模型让自己距离每个数据点的加和为最小值。这也就是线性回归模型的原理。
当然,在实际应用中,数据量要远远大于2个或是3个,下面我们就用数量更多的数据来进行实验。
现在我们以scikit-klearn生成的make_regression数据集为例,用Python语句绘制一条线性模型的预测线,更加清晰地反映出线性模型的原理。在jupyter notebook中输入代码如下:
![](https://epubservercos.yuewen.com/89A715/23721687209561106/epubprivate/OEBPS/Images/Figure-P65_39583.jpg?sign=1738953983-cRyrxaE9YOtX3xv8IJK4GHeX1jXnqXwW-0-94120657c984dbda57d681ae6aebca56)
按下shift+回车键后,会得到结果如图4-6所示的结果。
![](https://epubservercos.yuewen.com/89A715/23721687209561106/epubprivate/OEBPS/Images/Figure-P65_22670.jpg?sign=1738953983-0xkqVQuuAY3HSahV3pLCAgqznaJSrazZ-0-1af88ec963302c556a80a65d198a6d62)
图4-6 线性回归模型的预测线
【结果分析】从图4-1中我们可以看出,黑色直线是线性回归模型在make_regression数据集中生成的预测线。接下来我们来看一下这条直线所对应的斜率和截距。
输入代码如下:
print('\n\n\n代码运行结果:') print('==========\n') #打印直线的系数和截距 print('直线的系数是:{:.2f}'.format(reg.coef_[0])) print('直线的截距是:{:.2f}'.format(reg.intercept_)) print('\n==========') print('\n\n\n')
运行代码,会得到结果如图4-7所示。
![](https://epubservercos.yuewen.com/89A715/23721687209561106/epubprivate/OEBPS/Images/Figure-P66_22808.jpg?sign=1738953983-3BygQ5sH2pdgCAcwhhvAlT79wpYJ574i-0-37c6f8c03c9e67e02075dfcc6d8b057a)
图4-7 直线的系数和截距
【结果分析】从图4-7中我们可以看到,在我们手工生成的数据集中,线性模型的方程为
y = 79.52 x +10.92
而这条直线距离50个数据点的距离之和,是最小的。这便是一般线性模型的原理。
注意 细心的读者可能注意到coef_和intercept_这两个属性非常奇怪,它们都是以下划线_结尾。这是sciki-learn的一个特点,它总是用下划线作为来自训练数据集的属性的结尾,以便将它们与由用户设置的参数区分开。