Tensorflow Dataset의 window 메소드 사용법
Tensorflow Dataset의 window 메소드 사용법을 파악해보려고 테스트해 본 내용입니다.
2021. 9. 14 - 최초작성
range 메소드를 사용하여 0 ~ 9 까지 값을 갖는 Dataset을 생성합니다.
as_numpy_iterator 메소드는 Dataset의 모든 요소를 numpy로 변환하는 iterator를 리턴하는데 이것을 리스트에 담아서 출력할 수 있습니다.
import tensorflow as tf
ds = tf.data.Dataset.range(10)
print(list(ds.as_numpy_iterator()))
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9] |
첫번째 window 예제
window 메소드는 원본 Dataset의 요소를 3개씩 묶어서 서브 Dataset을 생성합니다.
window 메소드의 리턴값이 저장된 ds를 for문에 사용하면 서브 Dataset 하나씩 접근할 수 있습니다.
서브 Dataset d를 출력하기 위해 앞에서 한 것처럼 as_numpy_iterator 메소드를 사용합니다.
import tensorflow as tf
ds = tf.data.Dataset.range(10)
ds = ds.window(3)
for d in ds:
print(list(d.as_numpy_iterator()))
[0, 1, 2] [3, 4, 5] [6, 7, 8] [9] |
window 메소드에 3으로 지정할시 기본값은 다음과 같습니다.
각각의 의미는 계속 읽어보면 알 수 있습니다. 그러면 지금 예제의 결과 의미를 알 수 있습니다.
shift=3
stride=1
drop_remainder=Flase
drop_remainder
windows 메소드를 사용하여 크기 3인 서브 Dataset을 생성하도록 하고 drop_remainder=True로 지정하면 크기 3 미만으로 남은 요소를 버립니다.
그래서 출력결과에 9가 없습니다.
drop_remainder를 지정하지 않은 경우 디폴트값은 False입니다.
그래서 첫번째 window 메소드 예제에서 크기 3미만인 9가 출력되었습니다.
import tensorflow as tf
ds = tf.data.Dataset.range(10)
ds = ds.window(3, drop_remainder=True)
for d in ds:
print(list(d.as_numpy_iterator()))
[0, 1, 2] [3, 4, 5] [6, 7, 8] |
shift
window 메소드에서 크기 3으로 서브 Dataset을 생성하도록 하고 shift=1을 지정하면 서브 Dataset의 첫번째 요소의 값이 1씩 증가합니다. 그래서 실행결과 인접한 서브 Dataset의 첫번째 요소 간에 차이가 1입니다.
첫번째 window 예제에서는 shift를 지정하지 않았기 때문에 윈도우 크기인 3과 동일한 값을 디폴트로 사용합니다. 그래서 실행결과 인접한 서브 Dataset의 첫번째 요소 간에 차이가 3입니다.
import tensorflow as tf
ds = tf.data.Dataset.range(10)
ds = ds.window(3, shift=1, drop_remainder=True)
for d in ds:
print(list(d.as_numpy_iterator()))
[0, 1, 2] [1, 2, 3] [2, 3, 4] [3, 4, 5] [4, 5, 6] [5, 6, 7] [6, 7, 8] [7, 8, 9] |
stride
window 메소드에서 크기 3으로 서브 Dataset을 생성하도록 하고 stride=2로 지정하면 서브 Dataset의 요소 간격이 2가 됩니다. 그래서 첫번째 서브 Dataset의 요소는 2간격으로 0, 2, 4가 됩니다.
첫번째 window 예제에서는 stride를 지정하지 않았기 때문에 디폴트값으로 stride=1이 됩니다. 그래서 실행결과 인접한 서브 Dataset의 첫번째 요소 간에 차이가 1입니다.
import tensorflow as tf
ds = tf.data.Dataset.range(10)
ds = ds.window(3, stride=2, drop_remainder=True)
for d in ds:
print(list(d.as_numpy_iterator()))
[0, 2, 4] [3, 5, 7] |
함께 사용
window 메소드에서 크기를 3으로 지정했기 때문에 서브 Dataset의 크기는 3입니다.
shift=2로 지정했기 때문에 서브 Dataset의 첫번째 요소가 2씩 증가합니다.
stride=1로 지정했기 때문에 서브 Dataset의 요소들이 1씩 증가합니다.
import tensorflow as tf
ds= tf.data.Dataset.range(10)
ds = ds.window(3, shift=2, stride=1, drop_remainder=True)
for window in dataset:
print(list(window.as_numpy_iterator()))
[0, 1, 2] [2, 3, 4] [4, 5, 6] [6, 7, 8] |
NumPy 배열로 바꾸기
뒤에서 사용한 flat_map 메소드를 사용하지 않고 as_numpy_iterator 메소드를 사용할 수 있지만 Dataset 크기가 클 경우 상대적으로 처리속도가 느립니다.
import tensorflow as tf
import numpy as np
ds = tf.data.Dataset.range(100)
window_size = 10
ds = ds.window(window_size, shift=5, stride=3, drop_remainder=True)
list_d = []
for d in ds:
l = list(d.as_numpy_iterator())
list_d.append(l)
array_d = np.array(list_d)
print(array_d.shape)
print(array_d)
(15, 10) [[ 0 3 6 9 12 15 18 21 24 27] [ 5 8 11 14 17 20 23 26 29 32] [10 13 16 19 22 25 28 31 34 37] [15 18 21 24 27 30 33 36 39 42] [20 23 26 29 32 35 38 41 44 47] [25 28 31 34 37 40 43 46 49 52] [30 33 36 39 42 45 48 51 54 57] [35 38 41 44 47 50 53 56 59 62] [40 43 46 49 52 55 58 61 64 67] [45 48 51 54 57 60 63 66 69 72] [50 53 56 59 62 65 68 71 74 77] [55 58 61 64 67 70 73 76 79 82] [60 63 66 69 72 75 78 81 84 87] [65 68 71 74 77 80 83 86 89 92] [70 73 76 79 82 85 88 91 94 97]] |
flat_map 메소드를 사용한 경우입니다.
import tensorflow as tf
import numpy as np
ds = tf.data.Dataset.range(100)
window_size = 10
ds = ds.window(window_size, shift=5, stride=3, drop_remainder=True)
ds = ds.flat_map(lambda w: w.batch(window_size))
list_d = []
for d in ds:
list_d.append(d)
array_d = np.array(list_d)
print(array_d.shape)
print(array_d)
(15, 10) [[ 0 3 6 9 12 15 18 21 24 27] [ 5 8 11 14 17 20 23 26 29 32] [10 13 16 19 22 25 28 31 34 37] [15 18 21 24 27 30 33 36 39 42] [20 23 26 29 32 35 38 41 44 47] [25 28 31 34 37 40 43 46 49 52] [30 33 36 39 42 45 48 51 54 57] [35 38 41 44 47 50 53 56 59 62] [40 43 46 49 52 55 58 61 64 67] [45 48 51 54 57 60 63 66 69 72] [50 53 56 59 62 65 68 71 74 77] [55 58 61 64 67 70 73 76 79 82] [60 63 66 69 72 75 78 81 84 87] [65 68 71 74 77 80 83 86 89 92] [70 73 76 79 82 85 88 91 94 97]] |
참고
https://www.tensorflow.org/api_docs/python/tf/data/Dataset#window