TensorFlow中的constant,Variable,placeholder

概况

介绍TensorFlow中常量(constant)、变量(Variable)和占位符(placeholder)的用法。

tf.constant

创建TensorFlow常量。

定义

1
2
3
4
5
6
7
tf.constant(
value,
dtype=None,
shape=None,
name='Const',
verify_shape=False
)

其中value时TensorFlow常量存储的值,dtype是value元素的类型,shape为该常量的形状,name是该常量的别名,verify_shape为布尔型表示是否验证形状。

例子

例子1

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import tensorflow as tf
a = tf.constant(2.0)
with tf.Session() as sess:
print(sess.run(a))
#output:
#2.0
```

例子2
```python
import tensorflow as tf
a = tf.constant(2.0)
b = tf.constant(5.0)
c = tf.add(a, b)
with tf.Session() as sess:
print(sess.run(c))
#output:
#7.0

tf.Variable

创建TensorFlow变量,可以用变量来存储和更新训练模型时的参数。变量包含张量(Tensor)存放于内存的缓冲区。建立时需要对它们进行明确的初始化,模型训练完后它们也必须被存储到磁盘中。

定义

1
2
3
4
tf.Variable(
initial_value,
name='Variable'
)

详细说明参见官网API文档

例子

例子1

1
2
3
4
5
6
7
8
import tensorflow as tf
a = tf.Variable(3.0)
init = tf.variables_initializer([a])
with tf.Session() as sess:
sess.run(init)
print(sess.run(a))
#output:
#3.0

变量定义后需要初始化才能使用。

例子2

1
2
3
4
5
6
7
8
9
10
11
12
13
14
import tensorflow as tf
a = tf.Variable([1,2,3,4,5,6])
init = tf.variables_initializer([a])
with tf.Session() as sess:
sess.run(init)
print(sess.run(a))
#output:
#[1 2 3 4 5 6]
update = tf.assign(a,[2,3,4,5,6,7])
with tf.Session() as sess:
sess.run(update)
print(sess.run(a))
#output:
#[2 3 4 5 6 7]

这里演示的是如何更新TensorFlow变量的值,需要使用tf.assign函数。

tf.placeholder

创建TensorFlow占位符。

定义

1
2
3
4
5
tf.placeholder(
dtype,
shape=None,
name=None
)

其中dtype是占位符值元素的类型,shape为占位符形状,name为别名。

例子

1
2
3
4
5
6
7
8
9
10
import tensorflow as tf
import numpy as np
a = tf.placeholder(tf.float32, shape=(2,3))
b = tf.placeholder(tf.float32, shape=(3,2))
y = tf.matmul(a, b)
with tf.Session() as sess:
print(sess.run(y, feed_dict={a:[[1,2,3],[3,4,5]], b:[[1,2],[2,3],[3,4]]}))
#output:
#[[14. 20.]
# [26. 38.]]

存在占位符的操作在sess.run的时候需要“喂数据”才能正确的完成操作运算。

总结

constant、Variable和placeholder是TensorFlow中的数据类型,它们都属于张量(Tensor)数据类型。

参考

0%