问题 如何使用累积渐变更新模型参数?


我正在使用TensorFlow构建一个深度学习模型。 TensorFlow的新手。

由于某种原因,我的模型具有有限的批量大小,然后这种有限的批量大小将使模型具有高的方差。

所以,我想用一些技巧来增加批量。我的想法是存储每个小批量的梯度,例如64个小批量,然后将梯度相加,使用这64个小批量训练数据的平均梯度来更新模型的参数。

这意味着对于前63个小批量,不更新参数,并且在64迷你批次之后,仅更新模型的参数一次。

但是由于TensorFlow是基于图形的,有谁知道如何实现这个想要的功能?

非常感谢。


5509
2018-02-10 10:23


起源

是个 同步副本优化器 你在找什么? - Allen Lavoie
似乎我可以存储所有中间渐变,然后计算渐变的平均值,然后更新模型参数。 - weixsong
同步副本优化器似乎适用于多GPU并行训练。我会调查它,看看我是否可以利用它。 - weixsong


答案:


我在这里找到了解决方案: https://github.com/tensorflow/tensorflow/issues/3994#event-766328647

opt = tf.train.AdamOptimizer()
tvs = tf.trainable_variables()
accum_vars = [tf.Variable(tf.zeros_like(tv.initialized_value()), trainable=False) for tv in tvs]                                        
zero_ops = [tv.assign(tf.zeros_like(tv)) for tv in accum_vars]
gvs = opt.compute_gradients(rmse, tvs)
accum_ops = [accum_vars[i].assign_add(gv[0]) for i, gv in enumerate(gvs)]
train_step = opt.apply_gradients([(accum_vars[i], gv[1]) for i, gv in enumerate(gvs)])

在训练循环中:

while True:
    sess.run(zero_ops)
    for i in xrange(n_minibatches):
        sess.run(accum_ops, feed_dict=dict(X: Xs[i], y: ys[i]))
    sess.run(train_step)

但是这段代码看起来不是很干净漂亮,有谁知道如何优化这些代码?


7
2018-02-14 09:58



在keras这可能吗? - verystrongjoe


我有同样的问题,只是弄清楚了。

首先得到符号渐变,然后将累积的渐变定义为tf.Variables。 (看起来 tf.global_variables_initializer() 必须在定义之前运行 grads_accum。否则我会收到错误,不知道为什么。)

tvars = tf.trainable_variables()
optimizer = tf.train.GradientDescentOptimizer(lr)
grads = tf.gradients(cost, tvars)

# initialize
tf.local_variables_initializer().run()
tf.global_variables_initializer().run()

grads_accum = [tf.Variable(tf.zeros_like(v)) for v in grads] 
update_op = optimizer.apply_gradients(zip(grads_accum, tvars)) 

在训练中,你可以累积渐变(保存在 gradients_accum)在每个批次中,并在运行第64批后更新模型:

feed_dict = dict()
for i, _grads in enumerate(gradients_accum):
    feed_dict[grads_accum[i]] = _grads
sess.run(fetches=[update_op], feed_dict=feed_dict) 

你可以参考 tensorflow / tensorflow /蟒蛇/培训/ optimizer_test.py 例如用法,特别是这个功能: testGradientsAsVariables()

希望能帮助到你。


2
2018-06-17 00:23



我不认为这段代码与这个问题有关。客户在什么时候总结,即积累?此外,在您所指的示例中,渐变不会累积;它们是由w.r.t计算的。独立地输入两个输入。 - Giorgos Sfikas