Tensorflow Dataset의 window 메소드 사용법을 파악해보려고 테스트해 본 내용입니다.
2021. 9. 14 - 최초작성
range 메소드를 사용하여 0 ~ 9 까지 값을 갖는 Dataset을 생성합니다.
as_numpy_iterator 메소드는 Dataset의 모든 요소를 numpy로 변환하는 iterator를 리턴하는데 이것을 리스트에 담아서 출력할 수 있습니다.
import tensorflow as tf
ds = tf.data.Dataset.range(10)
print(list(ds.as_numpy_iterator()))
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9] |
첫번째 window 예제
window 메소드는 원본 Dataset의 요소를 3개씩 묶어서 서브 Dataset을 생성합니다.
window 메소드의 리턴값이 저장된 ds를 for문에 사용하면 서브 Dataset 하나씩 접근할 수 있습니다.
서브 Dataset d를 출력하기 위해 앞에서 한 것처럼 as_numpy_iterator 메소드를 사용합니다.
import tensorflow as tf
ds = tf.data.Dataset.range(10)
ds = ds.window(3)
for d in ds:
print(list(d.as_numpy_iterator()))
[0, 1, 2] [3, 4, 5] [6, 7, 8] [9] |
window 메소드에 3으로 지정할시 기본값은 다음과 같습니다.
각각의 의미는 계속 읽어보면 알 수 있습니다. 그러면 지금 예제의 결과 의미를 알 수 있습니다.
shift=3
stride=1
drop_remainder=Flase
drop_remainder
windows 메소드를 사용하여 크기 3인 서브 Dataset을 생성하도록 하고 drop_remainder=True로 지정하면 크기 3 미만으로 남은 요소를 버립니다.
그래서 출력결과에 9가 없습니다.
drop_remainder를 지정하지 않은 경우 디폴트값은 False입니다.
그래서 첫번째 window 메소드 예제에서 크기 3미만인 9가 출력되었습니다.
import tensorflow as tf
ds = tf.data.Dataset.range(10)
ds = ds.window(3, drop_remainder=True)
for d in ds:
print(list(d.as_numpy_iterator()))
[0, 1, 2] [3, 4, 5] [6, 7, 8] |
shift
window 메소드에서 크기 3으로 서브 Dataset을 생성하도록 하고 shift=1을 지정하면 서브 Dataset의 첫번째 요소의 값이 1씩 증가합니다. 그래서 실행결과 인접한 서브 Dataset의 첫번째 요소 간에 차이가 1입니다.
첫번째 window 예제에서는 shift를 지정하지 않았기 때문에 윈도우 크기인 3과 동일한 값을 디폴트로 사용합니다. 그래서 실행결과 인접한 서브 Dataset의 첫번째 요소 간에 차이가 3입니다.
import tensorflow as tf
ds = tf.data.Dataset.range(10)
ds = ds.window(3, shift=1, drop_remainder=True)
for d in ds:
print(list(d.as_numpy_iterator()))
[0, 1, 2] [1, 2, 3] [2, 3, 4] [3, 4, 5] [4, 5, 6] [5, 6, 7] [6, 7, 8] [7, 8, 9] |
stride
window 메소드에서 크기 3으로 서브 Dataset을 생성하도록 하고 stride=2로 지정하면 서브 Dataset의 요소 간격이 2가 됩니다. 그래서 첫번째 서브 Dataset의 요소는 2간격으로 0, 2, 4가 됩니다.
첫번째 window 예제에서는 stride를 지정하지 않았기 때문에 디폴트값으로 stride=1이 됩니다. 그래서 실행결과 인접한 서브 Dataset의 첫번째 요소 간에 차이가 1입니다.
import tensorflow as tf
ds = tf.data.Dataset.range(10)
ds = ds.window(3, stride=2, drop_remainder=True)
for d in ds:
print(list(d.as_numpy_iterator()))
[0, 2, 4] [3, 5, 7] |
함께 사용
window 메소드에서 크기를 3으로 지정했기 때문에 서브 Dataset의 크기는 3입니다.
shift=2로 지정했기 때문에 서브 Dataset의 첫번째 요소가 2씩 증가합니다.
stride=1로 지정했기 때문에 서브 Dataset의 요소들이 1씩 증가합니다.
import tensorflow as tf
ds= tf.data.Dataset.range(10)
ds = ds.window(3, shift=2, stride=1, drop_remainder=True)
for window in dataset:
print(list(window.as_numpy_iterator()))
[0, 1, 2] [2, 3, 4] [4, 5, 6] [6, 7, 8] |
NumPy 배열로 바꾸기
뒤에서 사용한 flat_map 메소드를 사용하지 않고 as_numpy_iterator 메소드를 사용할 수 있지만 Dataset 크기가 클 경우 상대적으로 처리속도가 느립니다.
import tensorflow as tf
import numpy as np
ds = tf.data.Dataset.range(100)
window_size = 10
ds = ds.window(window_size, shift=5, stride=3, drop_remainder=True)
list_d = []
for d in ds:
l = list(d.as_numpy_iterator())
list_d.append(l)
array_d = np.array(list_d)
print(array_d.shape)
print(array_d)
(15, 10) [[ 0 3 6 9 12 15 18 21 24 27] [ 5 8 11 14 17 20 23 26 29 32] [10 13 16 19 22 25 28 31 34 37] [15 18 21 24 27 30 33 36 39 42] [20 23 26 29 32 35 38 41 44 47] [25 28 31 34 37 40 43 46 49 52] [30 33 36 39 42 45 48 51 54 57] [35 38 41 44 47 50 53 56 59 62] [40 43 46 49 52 55 58 61 64 67] [45 48 51 54 57 60 63 66 69 72] [50 53 56 59 62 65 68 71 74 77] [55 58 61 64 67 70 73 76 79 82] [60 63 66 69 72 75 78 81 84 87] [65 68 71 74 77 80 83 86 89 92] [70 73 76 79 82 85 88 91 94 97]] |
flat_map 메소드를 사용한 경우입니다.
import tensorflow as tf
import numpy as np
ds = tf.data.Dataset.range(100)
window_size = 10
ds = ds.window(window_size, shift=5, stride=3, drop_remainder=True)
ds = ds.flat_map(lambda w: w.batch(window_size))
list_d = []
for d in ds:
list_d.append(d)
array_d = np.array(list_d)
print(array_d.shape)
print(array_d)
(15, 10) [[ 0 3 6 9 12 15 18 21 24 27] [ 5 8 11 14 17 20 23 26 29 32] [10 13 16 19 22 25 28 31 34 37] [15 18 21 24 27 30 33 36 39 42] [20 23 26 29 32 35 38 41 44 47] [25 28 31 34 37 40 43 46 49 52] [30 33 36 39 42 45 48 51 54 57] [35 38 41 44 47 50 53 56 59 62] [40 43 46 49 52 55 58 61 64 67] [45 48 51 54 57 60 63 66 69 72] [50 53 56 59 62 65 68 71 74 77] [55 58 61 64 67 70 73 76 79 82] [60 63 66 69 72 75 78 81 84 87] [65 68 71 74 77 80 83 86 89 92] [70 73 76 79 82 85 88 91 94 97]] |
참고
https://www.tensorflow.org/api_docs/python/tf/data/Dataset#window
'Deep Learning & Machine Learning > 강좌&예제 코드' 카테고리의 다른 글
tfds build 시 에러 AssertionError: Unrecognized instruction format ( 또는 Unrecognized split format ) (0) | 2021.11.12 |
---|---|
Tensorflow 디버깅 정보 메시지 안보이게 하기 (0) | 2021.09.14 |
LSTM 레이어 사용시 cuDNN 관련 에러 나는 경우 해결방법 (0) | 2021.09.09 |
Unknown: OSError: cannot identify image file (0) | 2021.08.21 |
손글씨 숫자 인식하여 세븐 세그먼트에 출력하기 (0) | 2021.06.16 |
시간날때마다 틈틈이 이것저것 해보며 블로그에 글을 남깁니다.
블로그의 문서는 종종 최신 버전으로 업데이트됩니다.
여유 시간이 날때 진행하는 거라 언제 진행될지는 알 수 없습니다.
영화,책, 생각등을 올리는 블로그도 운영하고 있습니다.
https://freewriting2024.tistory.com
제가 쓴 책도 한번 검토해보세요 ^^
그렇게 천천히 걸으면서도 그렇게 빨리 앞으로 나갈 수 있다는 건.
포스팅이 좋았다면 "좋아요❤️" 또는 "구독👍🏻" 해주세요!