Sequential模型

xiaoxiao2021-02-28  38

Keras的核心数据结构是model,一种组织网络层的方式,最简单的数据模型是Sequential模型,它是由多个网络层线性堆叠的栈,对于更复杂的结构,你应该使用Keras函数式,它允许构建任意的神经网络图。

sequential模型具体步骤如下所示:

from keras.models importSequential model = Sequential()

可以简单地使用 .add() 来堆叠模型:

from keras.layers import Dense model.add(Dense(units= 64, activation= 'relu', input_dim= 100)) model.add(Dense(units= 10, activation= 'softmax'))

也可以是一个完整的形式:

from keras.models import Sequential from keras.layers import Dense, Activation model = Sequential([ Dense(32, input_dim=784), Activation('relu'), Dense(10), Activation('softmax'),]) 指定输入数据的 shape

模型需要知道输入数据的shape,因此, Sequential 的第一层需要接受一个关于输入数据shape的参数,后面的各个层则可以自动的推导出中间数据的 shape ,因此不需要为每个层都指定这个参数。有几种方法来为第一层指定输入数据的 shape:

1 传递一个input_shape的关键字参数给第一层, input_shape是一个tuple类型的数据,其中也可以填入None,则表示此位置可能是任何正整数。数据的batch大小不应包含在其中。2 传递一个 batch_input_shape 的关键字参数给第一层,该参数包含数据的batch大小。该参数在指定固定大小 batch 时比较有用,例如在 stateful RNNs 中。事实上, Keras 在内部会通过添加一个 None 将 input_shape 转化为 batch_input_shape。

3 有些2D层,如Dense,支持通过指定其输入维度input_dim来隐含的指定输入数据shape。一些3D的时域层支持通过参数input_dim和input_dim input_length来指定输入shape。

下面的三个指定输入数据 shape 的方法是严格等价的:

[html]  view plain  copy model = Sequential()  model.add(Dense(32, <span style="background-color:rgb(255,153,255);">input_shape</span>=(784,)))    model = Sequential()  model.add(Dense(32, <span style="background-color:rgb(255,153,255);">batch_input_shape</span>=(None, 784)))  # note that batch dimension is "None" here,  # so the model will be able to process batches of any size.</pre>     model = Sequential()  model.add(Dense(32,<span style="background-color:rgb(255,153,255);"> input_dim</span>=784))  

下面三种方法也是严格等价的:

[html]  view plain  copy    [html]  view plain  copy model = Sequential()  model.add(LSTM(32, <span style="background-color:rgb(255,153,255);"><span style="color:#330033;">input_shape=(10, 64)</span></span>))     model = Sequential()  model.add(LSTM(32, <span style="background-color:rgb(255,153,255);">batch_input_shape=</span>(None, 10, 64)))     model = Sequential()  model.add(LSTM(32,<span style="background-color:rgb(255,153,255);"> input_length=10input_dim=64))</span>  
转载请注明原文地址: https://www.6miu.com/read-1644746.html

最新回复(0)