저수준 API로 작성된 MNIST 코드에서 사용하는 tf.argmax 함수에 대해 살펴봅니다. 


2018. 8.29 최초작성 

2020. 8. 1 내용 확인 및 Tensorflow 2.x에 맞게 수정



tf.argmax 두번째 인자값의 범위는 [-rank(input), rank(input))로 한정되어 있습니다. 텐서플로우에서 rank는 텐서의 원소 하나에 접근하기 위해 필요한 인덱스의 개수입니다. 


1차원 배열의 경우  최대 인덱스 개수는 1(=rank가 1 ) 이기 때문에 두번째 인자로 0만 사용할 수 있습니다.

(음의 범위는 논외로 합니다. ) 

한 방향으로 (1차원의 경우 열,행 구분이 없습니다.) 최대값을 찾아 인덱스 값을 찾을 수 있습니다. 


import tensorflow as tf

a = tf.constant([3, 10, 1])

print('a:\n', a.numpy())
print('인덱스의 개수 = ', tf.rank(a).numpy() )
print('tf.argmax(a, 0): 인덱스 ', tf.argmax(a, 0).numpy(), '이 가장 큽니다.')



a:

 [ 3 10  1]

인덱스의 개수 =  1

tf.argmax(a, 0): 인덱스  1 이 가장 큽니다.




tf.argmax(a, 0)는 1차원 배열에서 가장 큰 값을 찾아 인덱스를 리턴합니다.

10이 가장 크기 때문에 결과는 1이 됩니다. 


index            0   1  2   

      
                [ 3 10  1]




2차원 배열의 경우  최대 인덱스 개수는 2(=rank가 2 ) 이기 때문에 두번째 인자로 0과 1을 사용할 수 있습니다.

(음의 범위는 논외로 합니다. ) 

행과 열 방향으로 각각 최대값을 찾아서 인덱스 값을 찾을 수 있습니다. 


import tensorflow as tf

a = tf.constant([[5, 10, 17],[4, 50, 6]])

print('a:\n', a.numpy())
print('인덱스의 개수 = ', tf.rank(a).numpy())
print('tf.argmax(a, 0): 인덱스 ', tf.argmax(a, 0).numpy(), '가 가장 큽니다.')
print('tf.argmax(a, 1): 인덱스 ', tf.argmax(a, 1).numpy(), '가 가장 큽니다.')


a:

 [[ 5 10 17]

 [ 4 50  6]]

인덱스의 개수 =  2

tf.argmax(a, 0): 인덱스  [0 1 0] 가 가장 큽니다.

tf.argmax(a, 1): 인덱스  [2 1] 가 가장 큽니다.



tf.argmax(a, 0)는 2차원 배열의 각 열에서 가장 큰 값을 찾아 인덱스를 반환합니다.

첫번째 열에서 5, 두번째 열에서 50, 세번째 열에서 17을 찾았기 때문에 실행결과가 [0 1 0]입니다. 


index           

   0           [[ 5  10  17]
  1            [ 4  50   6]]



tf.argmax(a, 1)는 2차원 배열의 각 행에서 가장 큰 값을 찾아  인덱스를 반환합니다.. 

첫번째 행에서 17, 두번째 행에서 50을 찾았기 때문에 실행결과가 [2 1]입니다. 


index             0   1   2   

      
              [[ 5  10  17]
                [ 4  50   6]]




MNIST 코드에서는 one hot 벡터로 표현한 라벨이 의미하는 숫자를 찾기 위해 tf.argmax 함수를 사용됩니다. 


예를 들어 다음처럼 10개의 숫자중 세번째 인덱스의 값이 1이라면 이 라벨은 숫자 2를 의미합니다. 

이때 세번째 인덱스가 가장 큰 값임을 빨리 찾기위해 tf.argmax 함수를 사용합니다.

 

[ 0 0 1 0 0 0 0 0 0 0]



라벨이 1차원 벡터인데 실제 코드에서 보면 tf.argmax 함수의 두번째 인자로 1을 사용하고 있습니다. 

이것은 pred와 y의 shape를 출력해보면 알 수 있습니다. (?, 10) 로 출력됩니다. 첫번째 차원은 라벨의 갯수를 표현하기 위해 사용되므로 크기가 정해져 있지 않으며 두번째 차원은 0~9까지 10개의 숫자를 위한 라벨로 사용하기 때문에 10입니다.


두번째 차원을 라벨로 사용하기 때문에 0이 아닌 1을 사용하게 됩니다. 즉 각 행에서 최대값을 찾습니다. 


correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))



포스트 작성시에는 문제 없었지만 이후 문제가 생길 수 있습니다.
댓글로 알려주시면 빠른 시일내에 답변을 드리겠습니다.

여러분의 응원으로 좋은 컨텐츠가 만들어집니다.
지금 본 내용이 도움이 되었다면 유튜브 구독 부탁드립니다. 감사합니다 : )

유튜브 구독하기


제가 쓴 책도 한번 검토해보세요.

  1. 서호아빠 2020.08.01 13:36

    tf.argmax의 두번째 parm은 행렬의 axis를 이야기 합니다. 0 이면 행방향(높이)에서 가장 큰 값의 인덱스를 리턴하고, 1 이면 열방향(폭)에서 가장 큰 값의 인덱스를 리턴하는 것입니다.

    • Favicon of https://webnautes.tistory.com BlogIcon webnautes 2020.08.01 15:50 신고

      행과 열을 바꿔썼나보군요. 확인해보겠습니다.

    • Favicon of https://webnautes.tistory.com BlogIcon webnautes 2020.08.01 17:56 신고

      tf.argmax의 두번째 parm은 행렬의 axis를 이야기 합니다. 0 이면 행방향(높이)에서 가장 큰 값의 인덱스를 리턴하고,
      1 이면 열방향(폭)에서 가장 큰 값의 인덱스를 리턴하는 것입니다.

      올리신 댓글을 아래처럼 다시 확인해보았는데
      argmax의 두번째 인자가 0일때 해당 열에서 높은 값의 인덱스를 반환하는게 맞는 듯합니다.

      행방향(높이)를 세로 한줄의 의미로 사용하신건가요?

      a:
      [[ 3 10 1]
      [ 4 5 6]
      [ 0 8 7]]
      인덱스의 개수 = 2
      tf.argmax(a, 0): 인덱스 [1 0 2] 가 가장 큽니다.

      a:
      [[ 3 10 15]
      [ 4 5 6]
      [ 0 8 7]]
      인덱스의 개수 = 2
      tf.argmax(a, 0): 인덱스 [1 0 0] 가 가장 큽니다.

      행렬 [0,3]의 1을 15로 수정하면
      tf.argmax(a, 0) 실행결과에서 3번째의 값이 0으로 바뀝니다.
      세번째 열에서 최대 값이 7에서 15로 변경된 것입니다.

+ Recent posts