Deep Learning & Machine Learning/강좌&예제 코드

데이터셋을 일정비율로 분리하는 : train_test_split 예제

webnautes 2023. 10. 17. 22:41
반응형

train_test_split를 사용하여 데이터셋을 Train 데이터셋과 Test 데이터셋으로 분리하는 예제 코드입니다.

 

 

2023. 6. 11  최초작성



아래 링크에 있는 데이터셋의 Train 데이터셋을 분리하여 Train 데이터셋과 Test 데이터셋으로 분리했습니다.

https://www.kaggle.com/datasets/andrewmvd/animal-faces



다음 처럼 train 데이터셋과 validation 데이터셋에 cat,dog,wild 클래스별 하위 디렉토리가 위치하고 각각의 파일에 이미지 파일이 포함되어 있습니다. 



코드를 실행하면 train 디렉토리에 있는 클래스별 이미지 파일을 지정한 비율로 나누어서 dataset 하위 디렉토리에 train, test 디렉토리에 나누어서 저장합니다. 



전체 소스코드입니다.

 

import os
import shutil
from sklearn.model_selection import train_test_split


# 원본 데이터셋을 지정합니다.
src_dirs_root = 'afhq/train'

# 원본 데이터셋의 하위 디렉토리를 저장합니다.
src_dirs = []
for entry in os.listdir(src_dirs_root):
    entry_path = os.path.join(src_dirs_root, entry)
    if os.path.isdir(entry_path):
        src_dirs.append(entry_path)

print(src_dirs)
print()
# ['afhq/train/cat', 'afhq/train/dog', 'afhq/train/wild']


# 원본 데이터셋을 복사할 디렉토리로 dataset 아래에 train 디렉토리와 test 디렉토리로 지정합니다.
target_parent_dir = 'dataset/'
train_dir = os.path.join(target_parent_dir, 'train/')
test_dir = os.path.join(target_parent_dir, 'test/')

# 디렉토리가 생성되어 있지 않으면 새로 생성합니다.
os.makedirs(train_dir, exist_ok=True)
os.makedirs(test_dir, exist_ok=True)


# 원본 데이터셋의 하위 디렉토리 별로 처리합니다.
for src_dir in src_dirs:
   
    # 클래스 이름을 추출합니다.
    class_name = os.path.basename(src_dir)
   
    # train 디렉토리와 test 디렉토리에 클래스 이름으로 하위 디렉토리를 생성합니다.
    os.makedirs(os.path.join(train_dir, class_name), exist_ok=True)
    os.makedirs(os.path.join(test_dir, class_name), exist_ok=True)

    # 원본 데이터셋의 하위 디렉토리에 있는 모든 파일의 리스트를 가져옵니다.
    files = os.listdir(src_dir)

    # 파일 리스트를 train,test로 나눕니다.
    train_files, test_files = train_test_split(files, test_size=0.2, random_state=42# 80%/20% split

    # train 디렉토리로 파일을 복사합니다.
    for file_name in train_files:
        shutil.copy(os.path.join(src_dir, file_name), os.path.join(train_dir, class_name, file_name))

    # test 디렉토리로 파일을 복사합니다.
    for file_name in test_files:
        shutil.copy(os.path.join(src_dir, file_name), os.path.join(test_dir, class_name, file_name))




반응형