TensorFlow实现线性回归

xiaoxiao2021-02-28  70

注:本文来自《面向机器智能的TensorFlow实践》一书,请去购买正版书:https://item.jd.com/12176592.html

在有监督学习问题中,线性回归是一种最简单的建模手段。给定一个数据点集合作为训练集,线性回归的目标是找到一个与这些数据最为吻合的线性函数。对于2D数据,这样的函数对应一条直线。

上图展示了一个2D情形下的线性回归模型。图中的点代表训练数据,而直线代表模型的推断结果。

下面运用少量数学公式解释线性回归模型的基本原理。线性函数的一般表达式为:

其矩阵(或张量)形式为:

·Y为待预测的值。

·x1,x2,…,xk是一组独立的预测变量;在使用模型对新样本进行预测时,需要提供这些值。若采用矩阵形式,可一次性提供多个样本,其中每行对应一个样本。

·w1,w2,…,wk为模型从训练数据中学习到的参数,或赋予每个变量的“权值”。

·b也是一个学习到的参数,这个线性函数中的常量也称为模型的偏置(bias)。

下面用代码来表示这种模型。这里没有使用权值的转置,而是将它们定义为单个列向量:

# 初始化变量或模型参数 W = tf.Variable(tf.zeros([2, 1]), name="weights") b = tf.Variable(0., name="bias") def inference(X): return tf.matmul(X, W) + b

接下来需要定义如何计算损失。对于这种简单的模型,将采用总平方误差,即模型对每个训练样本的预测值与期望输出之差的平方的总和。从代数角度看,这个损失函数实际上是预测的输出向量与期望向量之间欧氏距离的平方。对于2D数据集,总平方误差对应于每个数据点在垂直方向上到所预测的回归直线的距离的平方总和。这种损失函数也称为L2范数或L2损失函数。这里之所以采用平方,是为了避免计算平方根,因为对于最小化损失这个目标,有无平方并无本质区别,但有平方可以节省一定的计算量。

我们需要遍历i来求和,其中i为数据样本的索引。该函数的实现如下:

def loss(X, Y): Y_predicted = inference(X) return tf.reduce_sum(tf.squared_difference(Y, Y_predicted))

接下来便可用数据实际训练模型。例如,将准备使用一个将年龄、体重(单位:千克)与血液脂肪含量关联的数据集(http://people.sc.fsu.edu/~jburkardt/datasets/regression/x09.txt)。 由于这个数据集规模很小,下面直接将其嵌入在代码中。下一节将演示如何像实际应用场景中那样从文件中读取训练数据。

def inputs(): # Data from http://people.sc.fsu.edu/~jburkardt/datasets/regression/x09.txt weight_age = [[84, 46], [73, 20], [65, 52], [70, 30], [76, 57], [69, 25], [63, 28], [72, 36], [79, 57], [75, 44], [27, 24], [89, 31], [65, 52], [57, 23], [59, 60], [69, 48], [60, 34], [79, 51], [75, 50], [82, 34], [59, 46], [67, 23], [85, 37], [55, 40], [63, 30]] blood_fat_content = [354, 190, 405, 263, 451, 302, 288, 385, 402, 365, 209, 290, 346, 254, 395, 434, 220, 374, 308, 220, 311, 181, 274, 303, 244] return tf.to_float(weight_age), tf.to_float(blood_fat_content)

下面定义模型的训练运算。我们将采用梯度下降算法对模型参数进行优化

def train(total_loss): learning_rate = 0.0000001 return tf.train.GradientDescentOptimizer(learning_rate).minimize(total_loss)

运行上述代码时,将看到损失函数的值随训练步数的增加呈现逐渐减小的趋势。

模型训练完毕后,便需要对其进行评估。下面计算一个年龄25岁、体重80千克的人的血液脂肪含量,这个数据并未在训练集中出现过,但可将预测结果与同年龄的、体重65千克的人进行比较:

def evaluate(sess, X, Y): print (sess.run(inference([[80., 25.]]))) # ~ 303 print (sess.run(inference([[65., 25.]]))) # ~ 256

作为一种快速评估方法,可验证该模型学习到了血液脂肪含量随体重下降的衰减情况,且输出值介于原始数据训练值的边界之间。

运行结果:

loss: [7608772.5] loss: [5352849.5] loss: [5350043.5] loss: [5347918.5] loss: [5346300.5] loss: [5345062.0] loss: [5344106.0] loss: [5343361.0] loss: [5342774.5] loss: [5342306.0] loss: [5341925.5] loss: [5341611.5] loss: [5341345.0] loss: [5341115.5] loss: [5340913.0] loss: [5340733.0] loss: [5340566.5] loss: [5340413.5] loss: [5340268.0] loss: [5340128.0] loss: [5339993.0] loss: [5339860.5] loss: [5339733.5] loss: [5339606.0] loss: [5339481.5] loss: [5339358.0] loss: [5339234.5] loss: [5339112.0] loss: [5338989.5] loss: [5338869.0] loss: [5338747.0] loss: [5338624.5] loss: [5338504.5] loss: [5338384.0] loss: [5338262.0] loss: [5338141.0] loss: [5338022.5] loss: [5337900.5] loss: [5337780.0] loss: [5337661.0] loss: [5337538.5] loss: [5337418.5] loss: [5337297.5] loss: [5337177.5] loss: [5337056.5] loss: [5336936.5] loss: [5336815.0] loss: [5336695.5] loss: [5336575.0] loss: [5336455.0] loss: [5336334.0] loss: [5336213.0] loss: [5336092.5] loss: [5335973.0] loss: [5335852.5] loss: [5335732.5] loss: [5335611.5] loss: [5335491.5] loss: [5335370.5] loss: [5335250.0] loss: [5335129.5] loss: [5335009.0] loss: [5334889.0] loss: [5334768.5] loss: [5334647.0] loss: [5334526.5] loss: [5334409.0] loss: [5334287.0] loss: [5334166.5] loss: [5334047.5] loss: [5333926.0] loss: [5333806.0] loss: [5333685.5] loss: [5333565.0] loss: [5333446.0] loss: [5333325.0] loss: [5333204.5] loss: [5333085.0] loss: [5332965.0] loss: [5332844.5] loss: [5332724.0] loss: [5332603.5] loss: [5332485.0] loss: [5332363.0] loss: [5332243.5] loss: [5332123.5] loss: [5332002.5] loss: [5331883.5] loss: [5331763.0] loss: [5331642.0] loss: [5331524.0] loss: [5331403.5] loss: [5331282.5] loss: [5331162.0] loss: [5331042.0] loss: [5330922.5] loss: [5330802.5] loss: [5330681.5] loss: [5330562.5] loss: [5330442.5] [[ 320.64968872]] [[ 267.78182983]]
转载请注明原文地址: https://www.6miu.com/read-27635.html

最新回复(0)