PyTorch

PyTorch Dataset과 DataLoader 활용법

PyExplorer 2025. 4. 6. 09:56
728x90

PyTorch Dataset과 DataLoader 활용법

PyTorch에서 데이터 처리는 중요한 요소 중 하나입니다. 특히, 대량의 데이터를 효과적으로 관리하고 모델 학습에 적절히 공급하기 위해 PyTorch에서는 DatasetDataLoader라는 두 가지 핵심 클래스를 제공합니다. 본 포스팅에서는 이 두 개념을 이해하고, 직접 활용하는 방법을 예제 코드와 함께 설명하겠습니다.

1. Dataset 클래스 개요

Dataset 클래스는 PyTorch에서 데이터를 로드하는 기본 단위입니다. PyTorch의 torch.utils.data.Dataset 클래스를 상속하여 사용자의 필요에 맞게 커스텀 데이터셋을 만들 수 있습니다.

1.1 Dataset 클래스의 주요 메서드

Dataset을 커스텀하게 정의하기 위해서는 최소한 다음 세 가지 메서드를 구현해야 합니다.

  • __init__: 데이터셋을 초기화하는 생성자입니다.
  • __len__: 데이터셋의 크기(길이)를 반환합니다.
  • __getitem__: 주어진 인덱스에 해당하는 데이터를 반환합니다.

1.2 기본적인 Dataset 구현 예제

아래는 간단한 CSV 파일에서 데이터를 로드하는 Dataset 클래스의 예제입니다.

import torch
from torch.utils.data import Dataset
import pandas as pd

class CustomDataset(Dataset):
    def __init__(self, csv_file):
        self.data = pd.read_csv(csv_file)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = torch.tensor(self.data.iloc[idx].values, dtype=torch.float32)
        return sample

위의 CustomDataset 클래스는 CSV 파일을 로드하고, 각 행을 PyTorch Tensor로 변환하여 반환하는 기능을 합니다.

2. DataLoader 활용법

DataLoaderDataset을 감싸는 역할을 하며, 데이터를 일괄(batch) 처리하거나, 섞거나(shuffle), 다중 프로세스를 이용해 데이터를 불러오는 등의 기능을 제공합니다.

2.1 DataLoader의 주요 파라미터

torch.utils.data.DataLoader를 사용할 때, 몇 가지 중요한 파라미터를 이해하는 것이 중요합니다.

  • dataset: 사용할 데이터셋 객체 (Dataset 클래스의 인스턴스)
  • batch_size: 한 번에 불러올 데이터의 개수
  • shuffle: 데이터를 랜덤하게 섞을지 여부
  • num_workers: 데이터를 불러올 때 사용할 병렬 프로세스 개수

2.2 DataLoader 사용 예제

아래는 앞서 정의한 CustomDatasetDataLoader에 적용하는 예제입니다.

from torch.utils.data import DataLoader

dataset = CustomDataset("data.csv")
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

for batch in dataloader:
    print(batch)

위 코드는 data.csv 파일을 데이터셋으로 변환한 후, 한 번에 4개씩 데이터를 불러와 출력합니다.

3. Dataset과 DataLoader의 실전 활용

3.1 이미지 데이터 처리 예제

이미지 데이터를 다룰 때도 Dataset을 커스텀하여 사용할 수 있습니다. 아래는 이미지 데이터셋을 로드하는 예제입니다.

from torchvision import transforms
from PIL import Image
import os

class ImageDataset(Dataset):
    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.image_filenames = os.listdir(image_dir)
        self.transform = transform

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.image_filenames[idx])
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image

transform을 활용하면 이미지 크기 조정, 정규화 등의 사전 처리를 자동화할 수 있습니다.

transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor()
])

dataset = ImageDataset("images", transform=transform)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

4. Dataset과 DataLoader의 활용 시 주의할 점

  • 데이터 로딩 속도 최적화: num_workers를 적절히 설정하여 데이터 로딩 속도를 최적화할 수 있습니다.
  • 메모리 관리: 너무 큰 batch_size를 설정하면 메모리 부족 문제가 발생할 수 있습니다.
  • 데이터 변형(transform): 이미지나 텍스트 데이터를 처리할 때는 torchvision.transforms 또는 torchtext 등의 라이브러리를 활용하는 것이 유용합니다.

5. 결론

PyTorch의 DatasetDataLoader는 데이터 처리를 효율적으로 관리하는 핵심 도구입니다.

  • Dataset은 데이터셋을 로드하고 관리하는 역할을 하며, 커스텀 클래스를 작성하여 확장할 수 있습니다.
  • DataLoader는 배치 단위로 데이터를 제공하고, 데이터를 무작위로 섞거나 병렬 처리를 할 수 있도록 돕습니다.

위에서 제공한 예제를 활용하여 자신만의 데이터셋을 구성하고, 효과적으로 모델 학습을 진행해 보시기 바랍니다.

728x90