Tensorflow:tf.assign()函数的使用方法及易错点

引言:
当大家在使用tf.assign()这个函数时,如果不是很了解这个函数的用法,很容易出错,而且似乎对应不同的tf版本其操作结果也会有细微的差别,本文是基于1.9.0版本的tf进行描述的,对于更新的版本而言应该结论是一样的,但对于比较旧的版本,可能就会有细微差别。


首先我们看一下源码中的返回值说明:

update = tf.assign(ref, new_value)    # 平时的使用写法
--------------------------------------------------------------------
Returns:A `Tensor` that will hold the new value of 'ref' afterthe assignment has completed.

也就是说,只有当这个赋值被完成时,该旧值ref才会被修改成new_value。不过这样描述还是太抽象了,那到底什么叫赋值被完成呢?下面我给大家放两个简单的例子,帮助大家理解

import tensorflow as tf ref_a = tf.Variable(tf.constant(1))
ref_b = tf.Variable(tf.constant(2))
update = tf.assign(ref_a, 10)
ref_sum = tf.add(ref_a, ref_b)with tf.Session() as sess:sess.run(tf.global_variables_initializer())print(sess.run(ref_sum))
------------------------------------------------------
输出结果:3

然后你就会感到奇怪,这里与往常的直觉不一样,理论上ref_a应该已经被修改为10了?带着疑问,我们看第二个例子

import tensorflow as tf ref_a = tf.Variable(tf.constant(1))
ref_b = tf.Variable(tf.constant(2))
update = tf.assign(ref_a, 10)
ref_sum = tf.add(ref_a, ref_b)with tf.Session() as sess:sess.run(tf.global_variables_initializer())sess.run(update)  # 唯一修改的地方print(sess.run(ref_sum))
------------------------------------------------------
输出结果:12

看到这里,是不是大家就明白了。所谓的赋值被完成其实指得是需要对tf.assign()函数的返回值执行一下sess.run()操作后,才能保证正常更新。

在明白了这个易错的地方后,我再介绍两种方法,来达到同样的目的。


方法一:采用ref_a = tf.assign(ref_a, 10)操作,我们看一下代码和运行结果

import tensorflow as tf ref_a = tf.Variable(tf.constant(1))
ref_b = tf.Variable(tf.constant(2))
ref_a = tf.assign(ref_a, 10)
ref_sum = tf.add(ref_a, ref_b)with tf.Session() as sess:sess.run(tf.global_variables_initializer())print(sess.run(ref_sum))
------------------------------------------------------
输出结果:12

事实上,tf.assign(ref, new_value)函数返回的结果就是参数中的new_value,因此我们只需要用ref来接收返回值也可以达到直接更新的效果

方法二:使用tf.control_dependencies()函数,我们也同样来看一下代码和结果

import tensorflow as tf ref_a = tf.Variable(tf.constant(1))
ref_b = tf.Variable(tf.constant(2))
update = tf.assign(ref_a, 10)with tf.control_dependencies([update]):ref_sum = tf.add(ref_a, ref_b)with tf.Session() as sess:sess.run(tf.global_variables_initializer())print(sess.run(ref_sum))
------------------------------------------------------
输出结果:12

可以发现,结果也为我们预期想要达到的效果,该函数保证其辖域中的操作必须要在该函数所传递的参数中的操作完成后再进行。简单地说,就是实际在运行时,会先执行该函数传递的参数update,再执行其辖域中的操作ref_sum = tf.add(ref_a, ref_b)


如果觉得我有地方讲的不好的或者有错误的欢迎给我留言,谢谢大家阅读(点个赞我可是会很开心的哦)~


本文来自互联网用户投稿,文章观点仅代表作者本人,不代表本站立场,不承担相关法律责任。如若转载,请注明出处。 如若内容造成侵权/违法违规/事实不符,请点击【内容举报】进行投诉反馈!

相关文章

立即
投稿

微信公众账号

微信扫一扫加关注

返回
顶部