Tensoflw.js - 03 - 实战训练-拟合曲线

xiaoxiao2022-06-11  32

Tensoflw.js - 03 - 实战训练-拟合曲线

参考 W3Cschool 文档:https://www.w3cschool.cn/tensorflowjs/ 本文主要翻译一些英文注释,添加通俗的注释,记录新手使用遇到的小问题,去除不必要的部分,帮助新手快速入门

上一篇介绍了 模型与内存管理 本篇是实战训练

Tensorflow.js 拟合曲线

这篇文章中,我们将使用 TensorFlow.js 来根据数据拟合曲线。即使用多项式产生数据然后再改变其中某些数据(点),然后我们会训练模型来找到用于产生这些数据的多项式的系数。简单的说,就是给一些在二维坐标中的散点图,然后我们建立一个系数未知的多项式,通过TensorFlow.js来训练模型,最终找到这些未知的系数,让这个多项式和散点图拟合。

先决条件 本教程假定您熟悉核心概念中介绍的TensorFlow.js的基本构建块:张量,变量和操作。 我们建议在完成本教程之前先完成核心概念的学习。

运行代码 这篇文章关注的是创建模型以及学习模型的系数,完整的代码在这里可以找到。为了在本地运行,如下所示:

$ git clone https://github.com/tensorflow/tfjs-examples $ cd tfjs-examples/polynomial-regression-core $ yarn $ yarn watch

即首先将核心代码下载到本地,然后进入polynomial-regression-core(即多项式回归核心)部分,最后进行yarn安装并运行。

输入数据 我们的数据集由x坐标和y坐标组成,当绘制在笛卡尔平面上时,其坐标如下所示:

该数据是由三次方程 y = ax^{3}+ bx^{2} + cx + d 生成的。

我们的任务是学习这个函数的a,b,c和d系数以最好地拟合数据。 我们来看看如何使用TensorFlow.js操作来学习这些值。

学习步骤 第1步:设置变量 首先,我们需要创建一些变量。即开始我们是不知道a、b、c、d的值的,所以先给他们一个随机数,入戏所示:

const a = tf.variable(tf.scalar(Math.random())); const b = tf.variable(tf.scalar(Math.random())); const c = tf.variable(tf.scalar(Math.random())); const d = tf.variable(tf.scalar(Math.random()));

第2步:建立模型 我们可以通过TensorFlow.js中的链式调用操作来实现这个多项式方程 y = ax3 + bx2 + cx + d,下面的代码就创建了一个 predict 函数,这个函数将x作为输入,y作为输出:

function predict(x) { // y = a * x ^ 3 + b * x ^ 2 + c * x + d return tf.tidy(() => { return a.mul(x.pow(tf.scalar(3))) // a * x^3 .add(b.mul(x.square())) // + b * x ^ 2 .add(c.mul(x)) // + c * x .add(d); // + d }); }

其中,在上一篇文章中,我们讲到tf.tify函数用来清除中间张量,其他的都很好理解。

接着,让我们把这个多项式函数的系数使用之前得到的随机数,可以看到,得到的图应该是这样:

因为开始时,我们使用的系数是随机数,所以这个函数和给定的数据匹配的非常差,而我们写的模型就是为了通过学习得到更精确的系数值。

第3步:训练模型 最后一步就是要训练这个模型使得系数和这些散点更加匹配,而为了训练模型,我们需要定义下面的三样东西:

损失函数(loss function):这个损失函数代表了给定多项式和数据的匹配程度。 损失函数值越小,那么这个多项式和数据就跟匹配。 优化器(optimizer):这个优化器实现了一个算法,它会基于损失函数的输出来修正系数值。所以优化器的目的就是尽可能的减小损失函数的值。 训练迭代器(traing loop):即它会不断地运行这个优化器来减少损失函数。 所以,上面这三样东西的 关系就非常清楚了: 训练迭代器使得优化器不断运行,使得损失函数的值不断减小,以达到多项式和数据尽可能匹配的目的。这样,最终我们就可以得到a、b、c、d较为精确的值了。

定义损失函数 这篇文章中,我们使用MSE(均方误差,mean squared error)作为我们的损失函数。MSE的计算非常简单,就是先根据给定的x得到实际的y值与预测得到的y值之差 的平方,然后在对这些差的平方求平均数即可。 于是,我们可以这样定义MSE损失函数:

function loss(predictions, labels) { // 将labels(实际的值)进行抽象 // 然后获取平均数. const meanSquareError = predictions.sub(labels).square().mean(); return meanSquareError; }
转载请注明原文地址: https://www.6miu.com/read-4931361.html

最新回复(0)