学习 handson_ml 时,学到 LinearRegression,记录一下。其中用到了python的格式化输出,对于小数的格式化 formatspec,详细介绍了各种简写对应的格式。
# 绘制散点图
country_stats.plot(kind='scatter', x='GDP per capita', y='Life satisfaction')
# 设置坐标轴范围,先x后y
plt.axis([0, 60000, 0, 10])
# 准备数据
country_stats = prepare_country_stats(oecd_bli, gdp_per_captia)
# 按列堆叠1D数组,使其成为2D
X = np.c_[country_stats['GDP per capita']]
y = np.c_[country_stats['Life satisfaction']]
# 实例化线性模型,训练
model = LinearRegression()
model.fit(X, y)
# 预测新数据
X_new = [[22587]]
print(model.predict(X_new))
# 绘制线性模型
theta_0, theta_1= model.intercept_[0], model.coef_[0][0]
theta_0, theta_1
X = np.linspace(0, 60000, 1000)
plt.plot(X, theta_1 * X + theta_0, 'r-')
# g -> 通用格式,四舍五入,保留p个有效数字,会自动换科学计数法
plt.text(10000, 3.8, r'$y = \theta_0 + \theta_1 x$', fontsize=16, color='r')
plt.text(10000, 3, r'$\theta_0 = {:.4g}$'.format(theta_0), fontsize=14, color='r')
plt.text(10000, 2.3, r'$\theta_1 = {:.4g}$'.format(theta_1), fontsize=14, color='r')
plt.savefig('code_1-1.png')
plt.show()
output:
[[6.28653635]]