tf.scan() explained in detail
Reference: the official TensorFlow API documentation for tf.scan().
definition
tf.scan(
fn,
elems,
initializer=None,
parallel_iterations=10,
back_prop=True,
swap_memory=False,
infer_shape=True,
name=None
)
examples
import tensorflow as tf
import numpy as np
elems = np.array([1, 2, 3, 4, 5, 6])
sum = tf.scan(lambda a, x: a + x, elems)
# sum == [1, 3, 6, 10, 15, 21]
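Iteration steps (no initializer is given, so the first element of elems seeds the accumulator a and the function is first applied at index 2; "#" marks that no x is consumed at index 1):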
index | a | x | sum |
---|---|---|---|
1 | 1 | # | 1 |
2 | 1 | 2 | 3 |
3 | 3 | 3 | 6 |
4 | 6 | 4 | 10 |
5 | 10 | 5 | 15 |
6 | 15 | 6 | 21 |
output: sum = [1, 3, 6, 10, 15, 21]
import tensorflow as tf
import numpy as np
elems = np.array([1, 2, 3, 4, 5, 6])
initializer = np.array(0)
sum_one = tf.scan(lambda a, x: x[0] - x[1] + a, (elems + 1, elems), initializer)
# sum_one == [1, 2, 3, 4, 5, 6]
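Iteration steps (at each index x is the pair [x[0], x[1]] taken from elems + 1 and elems; a starts at the initializer 0):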
index | a | x | sum |
---|---|---|---|
1 | 0 | [2, 1] | 2-1+0=1 |
2 | 1 | [3, 2] | 3-2+1=2 |
3 | 2 | [4, 3] | 4-3+2=3 |
4 | 3 | [5, 4] | 5-4+3=4 |
5 | 4 | [6, 5] | 6-5+4=5 |
6 | 5 | [7, 6] | 7-6+5=6 |
output: sum_one = [1, 2, 3, 4, 5, 6]
import tensorflow as tf
import numpy as np
elems = np.array([1, 0, 0, 0, 0, 0])
initializer = (np.array(0), np.array(1))
fibonaccis = tf.scan(lambda a, _: (a[1], a[0] + a[1]), elems, initializer)
# fibonaccis == ([1, 1, 2, 3, 5, 8], [1, 2, 3, 5, 8, 13])
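Iteration steps (x is ignored by the function; the last column shows the two output sequences accumulated so far):
index | a | x | fibonaccis (accumulated) |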
---|---|---|---|
1 | [0, 1] | 1 | [1], [1] |
2 | [1, 1] | 0 | [1, 1], [1, 2] |
3 | [1, 2] | 0 | [1, 1, 2], [1, 2, 3] |
4 | [2, 3] | 0 | [1, 1, 2, 3], [1, 2, 3, 5] |
5 | [3, 5] | 0 | [1, 1, 2, 3, 5], [1, 2, 3, 5, 8] |
6 | [5, 8] | 0 | [1, 1, 2, 3, 5, 8], [1, 2, 3, 5, 8, 13] |
output: fibonaccis = ([1, 1, 2, 3, 5, 8], [1, 2, 3, 5, 8, 13])