sess.run()详解
TensorFlow与我们正常的编程思维略有不同:
创建session对象
tf.compat.v1.Session(
target=’’, graph=None, config=None
)
会话可能拥有资源,当不再需要这些资源时,释放这些资源很重要。为此,要么调用tf.Session.close会话上的方法,要么将会话用作上下文管理器。下面两个例子是等价的:
# Using the `close()` method.
sess = tf.compat.v1.Session()
sess.run(...)
sess.close() #关闭此会话
# Using the context manager.
with tf.compat.v1.Session() as sess:
sess.run(...)
Session对象的方法run()
run(fetches, feed_dict=None, options=None, run_metadata=None)
参数:
返回值: sess.run()可以将tensor格式转成numpy格式,在python语言中,返回的tensor是numpy ndarray对象。
feed_dict 作用
feed_dict只在调用它的方法内有效,方法结束,feed_dict就会消失。
替换graph中的某个tensor
feed_dict使用一个值临时替换一个 op 的输出结果
a = tf.add(2, 5)
b = tf.multiply(a, 3)
with tf.Session() as sess:
sess.run(b, feed_dict = {a:15}) # 重新给a赋值为15 运行结果:45
sess.run(b) #feed_dict只在调用它的方法内有效,方法结束,feed_dict就会消失。 所以运行结果是:21
设置graph的输入值
feed_dict可以给使用placeholder创建出来的tensor赋值
# 使用tf.placeholder()创建占位符 ,在session.run()过程中再投递数据
x = tf.placeholder(tf.float32, shape=(1, 2)) #设置矩阵x,用占位符表示
w1 = tf.Variable(tf.random_normal([2, 3],stddev=1,seed=1))
w2 = tf.Variable(tf.random_normal([3, 1],stddev=1,seed=1))
# 创建op, 设置矩阵相乘
a = tf.matmul(x,w1)
y = tf.matmul(a,w2)
# 启动graph,运行op
with tf.Session() as sess:
# 对变量进行初始化,变量运行前必须做初始化操作
sess.run(tf.global_variables_initializer())
# 使用feed_dict将数据投入到y中
print(sess.run(y, feed_dict={x:[[0.7, 0.5]]})) #运行结果: [[3.0904665]]
注意:此时的x是一个占位符(placeholder)。我们定义了它的type和shape,但是并没有具体的值。在后面定义graph的代码中,placeholder看上去和普通的tensor对象一样。在运行程序的时候我们用feed_dict的方式把具体的值提供给placeholder,达到了给graph提供input的目的。
placeholder有点像在定义函数的时候用到的参数。我们在写函数内部代码的时候,虽然用到了参数,但并不知道参数所代表的值。只有在调用函数的时候,我们才把具体的值传递给参数。
来源:Jqlender