반응형


저수준 API로 작성된 MNIST 코드에서 사용하는 tf.argmax 함수에 대해 살펴봅니다. 


2018. 8.29 최초작성 

2020. 8. 1 내용 확인 및 Tensorflow 2.x에 맞게 수정



tf.argmax 두번째 인자값의 범위는 [-rank(input), rank(input))로 한정되어 있습니다. 텐서플로우에서 rank는 텐서의 원소 하나에 접근하기 위해 필요한 인덱스의 개수입니다. 


1차원 배열의 경우  최대 인덱스 개수는 1(=rank가 1 ) 이기 때문에 두번째 인자로 0만 사용할 수 있습니다.

(음의 범위는 논외로 합니다. ) 

한 방향으로 (1차원의 경우 열,행 구분이 없습니다.) 최대값을 찾아 인덱스 값을 찾을 수 있습니다. 


import tensorflow as tf

a = tf.constant([3, 10, 1])

print('a:\n', a.numpy())
print('인덱스의 개수 = ', tf.rank(a).numpy() )
print('tf.argmax(a, 0): 인덱스 ', tf.argmax(a, 0).numpy(), '이 가장 큽니다.')



a:

 [ 3 10  1]

인덱스의 개수 =  1

tf.argmax(a, 0): 인덱스  1 이 가장 큽니다.




tf.argmax(a, 0)는 1차원 배열에서 가장 큰 값을 찾아 인덱스를 리턴합니다.

10이 가장 크기 때문에 결과는 1이 됩니다. 


index            0   1  2   

      
                [ 3 10  1]




2차원 배열의 경우  최대 인덱스 개수는 2(=rank가 2 ) 이기 때문에 두번째 인자로 0과 1을 사용할 수 있습니다.

(음의 범위는 논외로 합니다. ) 

행과 열 방향으로 각각 최대값을 찾아서 인덱스 값을 찾을 수 있습니다. 


import tensorflow as tf

a = tf.constant([[5, 10, 17],[4, 50, 6]])

print('a:\n', a.numpy())
print('인덱스의 개수 = ', tf.rank(a).numpy())
print('tf.argmax(a, 0): 인덱스 ', tf.argmax(a, 0).numpy(), '가 가장 큽니다.')
print('tf.argmax(a, 1): 인덱스 ', tf.argmax(a, 1).numpy(), '가 가장 큽니다.')


a:

 [[ 5 10 17]

 [ 4 50  6]]

인덱스의 개수 =  2

tf.argmax(a, 0): 인덱스  [0 1 0] 가 가장 큽니다.

tf.argmax(a, 1): 인덱스  [2 1] 가 가장 큽니다.



tf.argmax(a, 0)는 2차원 배열의 각 열에서 가장 큰 값을 찾아 인덱스를 반환합니다.

첫번째 열에서 5, 두번째 열에서 50, 세번째 열에서 17을 찾았기 때문에 실행결과가 [0 1 0]입니다. 


index           

   0           [[ 5  10  17]
  1            [ 4  50   6]]



tf.argmax(a, 1)는 2차원 배열의 각 행에서 가장 큰 값을 찾아  인덱스를 반환합니다.. 

첫번째 행에서 17, 두번째 행에서 50을 찾았기 때문에 실행결과가 [2 1]입니다. 


index             0   1   2   

      
              [[ 5  10  17]
                [ 4  50   6]]




MNIST 코드에서는 one hot 벡터로 표현한 라벨이 의미하는 숫자를 찾기 위해 tf.argmax 함수를 사용됩니다. 


예를 들어 다음처럼 10개의 숫자중 세번째 인덱스의 값이 1이라면 이 라벨은 숫자 2를 의미합니다. 

이때 세번째 인덱스가 가장 큰 값임을 빨리 찾기위해 tf.argmax 함수를 사용합니다.

 

[ 0 0 1 0 0 0 0 0 0 0]



라벨이 1차원 벡터인데 실제 코드에서 보면 tf.argmax 함수의 두번째 인자로 1을 사용하고 있습니다. 

이것은 pred와 y의 shape를 출력해보면 알 수 있습니다. (?, 10) 로 출력됩니다. 첫번째 차원은 라벨의 갯수를 표현하기 위해 사용되므로 크기가 정해져 있지 않으며 두번째 차원은 0~9까지 10개의 숫자를 위한 라벨로 사용하기 때문에 10입니다.


두번째 차원을 라벨로 사용하기 때문에 0이 아닌 1을 사용하게 됩니다. 즉 각 행에서 최대값을 찾습니다. 


correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))



반응형

문제 발생시 지나치지 마시고 댓글 남겨주시면 가능한 빨리 답장드립니다.

도움이 되셨다면 토스아이디로 후원해주세요.
https://toss.me/momo2024


제가 쓴 책도 한번 검토해보세요 ^^

+ Recent posts