반응형

multiclass classification에서 클래스별로 ROC curve를 그리는 예제 코드입니다.



2022. 3. 14  최초작성



# https://stackoverflow.com/questions/45332410/roc-for-multiclass-classification
# https://moons08.github.io/datascience/classification_score_roc_auc/

from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score, auc, roc_auc_score, roc_curve
from sklearn.preprocessing import label_binarize
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
import matplotlib.pyplot as plt
import pandas as pd


iris = load_iris()

iris_data = iris.data
iris_label = iris.target

# print(iris_data.shape) 
# print(iris_label.shape)
# (150, 4)
# (150,)


iris_df = pd.DataFrame(data = iris_data, columns = iris.feature_names)

# print(iris_df.head())
#    sepal length (cm)  sepal width (cm)  petal length (cm)  petal width (cm)
# 0                5.1               3.5                1.4               0.2
# 1                4.9               3.0                1.4               0.2
# 2                4.7               3.2                1.3               0.2
# 3                4.6               3.1                1.5               0.2
# 4                5.0               3.6                1.4               0.2

# label 컬럼 추가
iris_df['label'] = iris.target

# print(iris_df.head())
#    sepal length (cm)  sepal width (cm)  petal length (cm)  petal width (cm)  label
# 0                5.1               3.5                1.4               0.2      0
# 1                4.9               3.0                1.4               0.2      0
# 2                4.7               3.2                1.3               0.2      0
# 3                4.6               3.1                1.5               0.2      0
# 4                5.0               3.6                1.4               0.2      0


# train 데이터 세트와 test 데이터 세트로 분리. 비율은 8 : 2
X_train, X_test, y_train, y_test = train_test_split(iris_data,
                                                    iris_label,
                                                    test_size=0.2,
                                                    random_state=7)

# print(X_train.shape, y_train.shape)
# print(X_test.shape, y_test.shape)
# (120, 4) (120,)
# (30, 4) (30,)


# 결정 트리 모델 분류 학습
decision_tree = DecisionTreeClassifier(random_state=32)
decision_tree.fit(X_train, y_train)


# 추론
y_pred = decision_tree.predict(X_test)

# 정확도
accuracy = accuracy_score(y_test, y_pred)
print('정확도 {:.4f}'.format(accuracy))
# 정확도 0.9000

# print(y_test.shape, y_pred.shape)
# (30,) (30,)

# print('y_test[:10]', y_test[:10])
# y_test[:10] [2 1 0 1 2 0 1 1 0 1]

# label_binarize를 사용하여 클래스별로 이진화를 합니다.
# 클래스별로 배열을 따로 만들어서 해당 클래스의 값이면 1, 아니면 0으로 표시합니다.
# label_binarize 전후의 배열값 변화를 확인해보세요
labels = [0, 1, 2]
y_test = label_binarize(y_test, classes=labels)
y_pred = label_binarize(y_pred, classes=labels)

# print(y_test.shape, y_pred.shape)
# (30, 3) (30, 3)

# print('y_test[:10, 0]', y_test[:10, 0])
# print('y_test[:10, 1]', y_test[:10, 1])
# print('y_test[:10, 2]', y_test[:10, 2])
# y_test[:10, 0] [0 0 1 0 0 1 0 0 1 0]
# y_test[:10, 1] [0 1 0 1 0 0 1 1 0 1]
# y_test[:10, 2] [1 0 0 0 1 0 0 0 0 0]


# 클래스별로 ROC curve를 그립니다.
n_classes = 3
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(y_test[:, i], y_pred[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])

# Plot of a ROC curve for a specific class
plt.figure(figsize=(15, 5))
for idx, i in enumerate(range(n_classes)):
    plt.subplot(131+idx)
    plt.plot(fpr[i], tpr[i], label='ROC curve (area = %0.2f)' % roc_auc[i])
    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Class %0.0f' % idx)
    plt.legend(loc="lower right")
plt.show()

print("roc_auc_score: ", roc_auc_score(y_test, y_pred, multi_class='raise'))
# roc_auc_score:  0.9302675881623251



 

반응형

문제 발생시 지나치지 마시고 댓글 남겨주시면 가능한 빨리 답장드립니다.

도움이 되셨다면 토스아이디로 후원해주세요.
https://toss.me/momo2024


제가 쓴 책도 한번 검토해보세요 ^^

+ Recent posts