PyTorch Dataset과 DataLoader 활용법
PyTorch에서 데이터 처리는 중요한 요소 중 하나입니다. 특히, 대량의 데이터를 효과적으로 관리하고 모델 학습에 적절히 공급하기 위해 PyTorch에서는 Dataset
과 DataLoader
라는 두 가지 핵심 클래스를 제공합니다. 본 포스팅에서는 이 두 개념을 이해하고, 직접 활용하는 방법을 예제 코드와 함께 설명하겠습니다.
1. Dataset 클래스 개요
Dataset
클래스는 PyTorch에서 데이터를 로드하는 기본 단위입니다. PyTorch의 torch.utils.data.Dataset
클래스를 상속하여 사용자의 필요에 맞게 커스텀 데이터셋을 만들 수 있습니다.
1.1 Dataset 클래스의 주요 메서드
Dataset
을 커스텀하게 정의하기 위해서는 최소한 다음 세 가지 메서드를 구현해야 합니다.
__init__
: 데이터셋을 초기화하는 생성자입니다.__len__
: 데이터셋의 크기(길이)를 반환합니다.__getitem__
: 주어진 인덱스에 해당하는 데이터를 반환합니다.
1.2 기본적인 Dataset 구현 예제
아래는 간단한 CSV 파일에서 데이터를 로드하는 Dataset
클래스의 예제입니다.
import torch
from torch.utils.data import Dataset
import pandas as pd
class CustomDataset(Dataset):
def __init__(self, csv_file):
self.data = pd.read_csv(csv_file)
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
sample = torch.tensor(self.data.iloc[idx].values, dtype=torch.float32)
return sample
위의 CustomDataset
클래스는 CSV 파일을 로드하고, 각 행을 PyTorch Tensor로 변환하여 반환하는 기능을 합니다.
2. DataLoader 활용법
DataLoader
는 Dataset
을 감싸는 역할을 하며, 데이터를 일괄(batch) 처리하거나, 섞거나(shuffle), 다중 프로세스를 이용해 데이터를 불러오는 등의 기능을 제공합니다.
2.1 DataLoader의 주요 파라미터
torch.utils.data.DataLoader
를 사용할 때, 몇 가지 중요한 파라미터를 이해하는 것이 중요합니다.
dataset
: 사용할 데이터셋 객체 (Dataset
클래스의 인스턴스)batch_size
: 한 번에 불러올 데이터의 개수shuffle
: 데이터를 랜덤하게 섞을지 여부num_workers
: 데이터를 불러올 때 사용할 병렬 프로세스 개수
2.2 DataLoader 사용 예제
아래는 앞서 정의한 CustomDataset
을 DataLoader
에 적용하는 예제입니다.
from torch.utils.data import DataLoader
dataset = CustomDataset("data.csv")
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)
for batch in dataloader:
print(batch)
위 코드는 data.csv
파일을 데이터셋으로 변환한 후, 한 번에 4개씩 데이터를 불러와 출력합니다.
3. Dataset과 DataLoader의 실전 활용
3.1 이미지 데이터 처리 예제
이미지 데이터를 다룰 때도 Dataset
을 커스텀하여 사용할 수 있습니다. 아래는 이미지 데이터셋을 로드하는 예제입니다.
from torchvision import transforms
from PIL import Image
import os
class ImageDataset(Dataset):
def __init__(self, image_dir, transform=None):
self.image_dir = image_dir
self.image_filenames = os.listdir(image_dir)
self.transform = transform
def __len__(self):
return len(self.image_filenames)
def __getitem__(self, idx):
img_path = os.path.join(self.image_dir, self.image_filenames[idx])
image = Image.open(img_path).convert("RGB")
if self.transform:
image = self.transform(image)
return image
transform
을 활용하면 이미지 크기 조정, 정규화 등의 사전 처리를 자동화할 수 있습니다.
transform = transforms.Compose([
transforms.Resize((128, 128)),
transforms.ToTensor()
])
dataset = ImageDataset("images", transform=transform)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
4. Dataset과 DataLoader의 활용 시 주의할 점
- 데이터 로딩 속도 최적화:
num_workers
를 적절히 설정하여 데이터 로딩 속도를 최적화할 수 있습니다. - 메모리 관리: 너무 큰
batch_size
를 설정하면 메모리 부족 문제가 발생할 수 있습니다. - 데이터 변형(transform): 이미지나 텍스트 데이터를 처리할 때는
torchvision.transforms
또는torchtext
등의 라이브러리를 활용하는 것이 유용합니다.
5. 결론
PyTorch의 Dataset
과 DataLoader
는 데이터 처리를 효율적으로 관리하는 핵심 도구입니다.
Dataset
은 데이터셋을 로드하고 관리하는 역할을 하며, 커스텀 클래스를 작성하여 확장할 수 있습니다.DataLoader
는 배치 단위로 데이터를 제공하고, 데이터를 무작위로 섞거나 병렬 처리를 할 수 있도록 돕습니다.
위에서 제공한 예제를 활용하여 자신만의 데이터셋을 구성하고, 효과적으로 모델 학습을 진행해 보시기 바랍니다.
'PyTorch' 카테고리의 다른 글
PyTorch 데이터 변환 및 Augmentation (0) | 2025.04.08 |
---|---|
PyTorch 이미지 데이터 및 텍스트 데이터 로딩 (0) | 2025.04.07 |
PyTorch Autograd 소개 (자동 미분) (0) | 2025.04.05 |
PyTorch Tensor와 NumPy 배열 비교 (0) | 2025.04.04 |
PyTorch Tensor 생성 및 조작 (0) | 2025.04.03 |