- Thu 26 April 2018
- python
- mani3
- #python tensorflow
TensorFlow で行列操作や計算用の関数の使い方がよくわからなくなることがあるので、メモしておく。
tf.gather
# tf.gather with axis=1 selects columns by index: here column 1 of the 2x2
# matrix, returned with the column kept as a (2, 1) tensor (see output below).
col_indices = [1]
column = tf.gather(tf.constant([[1, 2], [3, 4]]), col_indices, axis=1)
with tf.Session() as sess:
    # No tf.Variable is defined in this graph, so the initializer does no
    # work; kept only as the usual TF 1.x session boilerplate.
    sess.run(tf.global_variables_initializer())
    # Evaluate with the managed session rather than opening a second,
    # never-closed tf.Session().
    print(sess.run(column))
# [[2]
#  [4]]
tf.where
# tf.where(cond, x, y) picks elementwise from x where cond is True and from y
# elsewhere: entries greater than 2 become x - 1, all other entries become 0.
a = np.array([[1, 2], [3, 4]], dtype=np.int32)  # match the placeholder dtype
x = tf.placeholder(tf.int32, shape=(2, 2), name='x')
# zeros_like ties the "else" branch to x's own shape and dtype instead of
# repeating the placeholder's hard-coded [2, 2] shape.
masked = tf.where(x > 2, tf.subtract(x, 1), tf.zeros_like(x))
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    b = sess.run(masked, feed_dict={x: a})
    print(b)
# [[0 0]
#  [2 3]]
tf.argmax
# tf.argmax(t, 1) returns, for every row, the column index of its largest value.
scores = tf.constant([[0.5, 0.5], [0.1, 0.9]])
labels = tf.constant([[0, 1], [0, 1]])
row_argmaxes = [tf.argmax(scores, 1), tf.argmax(labels, 1)]
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    result = sess.run(row_argmaxes)
    print(result)
# [array([0, 1]), array([1, 1])]
# NOTE(review): row 0 of `scores` is a tie (0.5 vs 0.5); the 0 shown above is
# the first index, but check the tf.argmax docs on whether tie-breaking is
# guaranteed before relying on it.
tf.reshape
# Reshape a flat vector into a single column: the -1 lets TensorFlow infer the
# number of rows (4 here) from the fixed width of 1.
flat = tf.constant([0.5, 0.5, 0.1, 0.9])
as_column = tf.reshape(flat, [-1, 1])
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    result = sess.run(as_column)
    print(result)
# [[0.5]
#  [0.5]
#  [0.1]
#  [0.9]]
# 27 values reshaped to [-1, 3, 9]: the -1 resolves to 27 / (3 * 9) = 1,
# giving a single block of 3 rows with 9 columns each.
row_pattern = [0.5] * 3 + [0.4] * 3 + [0.3] * 3
flat = tf.constant(row_pattern * 3)
as_cube = tf.reshape(flat, [-1, 3, 9])
with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    sess.run(tf.global_variables_initializer())
    result = sess.run(as_cube)
    print(result)
# [[[0.5 0.5 0.5 0.4 0.4 0.4 0.3 0.3 0.3]
#   [0.5 0.5 0.5 0.4 0.4 0.4 0.3 0.3 0.3]
#   [0.5 0.5 0.5 0.4 0.4 0.4 0.3 0.3 0.3]]]
IoU（Intersection over Union）っぽい計算
# IoU-like ratio on column 1 of two matrices: threshold column 1 of `a` into a
# bool mask, turn column 1 of `b` into a bool mask (nonzero -> True), then
# divide the count of positions True in both by the count True in either.
THRESHOLD = 0.5  # strict ">" threshold applied to `a`
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    a = tf.constant([[0.5, 0.5], [0.1, 0.9]])
    b = tf.constant([[0, 1], [0, 1]])
    # tf.greater already yields a bool tensor, so no extra cast is needed.
    y1 = tf.greater(tf.gather(a, indices=[1], axis=1), THRESHOLD)
    y2 = tf.cast(tf.gather(b, indices=[1], axis=1), tf.bool)
    intersection = tf.reduce_sum(tf.cast(tf.logical_and(y1, y2), tf.int32))
    union = tf.reduce_sum(tf.cast(tf.logical_or(y1, y2), tf.int32))
    # NOTE(review): if both masks are all False, union is 0 and this divides
    # 0 by 0 (nan for the float result shown below) — guard before using it
    # as a real metric.
    iou = tf.divide(intersection, union)
    result = sess.run([y1, y2, intersection, union, iou])
    print(result)
# [array([[False],
#        [ True]]), array([[ True],
#        [ True]]), 1, 2, 0.5]