728x90
PyTorch로 GAN 구현하기
1. 서론
이번 포스팅에서는 PyTorch를 활용하여 GAN(Generative Adversarial Network)을 구현하는 방법을 설명하겠습니다. GAN은 생성자(Generator)와 판별자(Discriminator)가 서로 경쟁하며 학습하는 모델로, 주어진 데이터 분포를 학습하여 새로운 데이터를 생성할 수 있습니다.
이 글에서는 간단한 GAN을 구현하여 MNIST 데이터를 생성하는 과정을 살펴보겠습니다.
2. GAN의 기본 개념
GAN은 다음과 같은 두 개의 신경망으로 구성됩니다.
- 생성자(Generator): 랜덤 노이즈를 입력받아 실제 데이터와 유사한 가짜 데이터를 생성합니다.
- 판별자(Discriminator): 입력받은 데이터가 실제 데이터인지, 생성자가 만든 가짜 데이터인지 판별합니다.
두 네트워크는 서로 경쟁하면서 점점 더 정교한 가짜 데이터를 생성하도록 학습됩니다.
GAN의 손실 함수는 다음과 같습니다.
$$
\min_G \max_D V(D, G) = \mathbb{E}{x \sim P{data}}[\log D(x)] + \mathbb{E}_{z \sim P_z}[\log(1 - D(G(z)))]
$$
즉, 판별자는 실제 데이터에 대해 1을 출력하고, 생성자가 만든 가짜 데이터에 대해 0을 출력하도록 학습되며, 생성자는 판별자를 속여 1을 출력하도록 학습됩니다.
3. PyTorch를 활용한 GAN 구현
3.1 라이브러리 임포트
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
3.2 데이터셋 준비
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
batch_size = 64
dataset = datasets.MNIST(root='data', train=True, transform=transform, download=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
3.3 생성자(Generator) 정의
class Generator(nn.Module):
def __init__(self, noise_dim=100):
super(Generator, self).__init__()
self.model = nn.Sequential(
nn.Linear(noise_dim, 256),
nn.ReLU(),
nn.Linear(256, 512),
nn.ReLU(),
nn.Linear(512, 1024),
nn.ReLU(),
nn.Linear(1024, 28*28),
nn.Tanh()
)
def forward(self, z):
return self.model(z).view(-1, 1, 28, 28)
3.4 판별자(Discriminator) 정의
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Linear(28*28, 1024),
nn.LeakyReLU(0.2),
nn.Linear(1024, 512),
nn.LeakyReLU(0.2),
nn.Linear(512, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 1),
nn.Sigmoid()
)
def forward(self, x):
return self.model(x.view(-1, 28*28))
3.5 모델 학습
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 하이퍼파라미터 설정
noise_dim = 100
lr = 0.0002
epochs = 50
# 모델 초기화
generator = Generator(noise_dim).to(device)
discriminator = Discriminator().to(device)
# 손실 함수 및 최적화 기법
criterion = nn.BCELoss()
g_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
for epoch in range(epochs):
for real_imgs, _ in dataloader:
real_imgs = real_imgs.to(device)
batch_size = real_imgs.size(0)
# 판별자 학습
z = torch.randn(batch_size, noise_dim).to(device)
fake_imgs = generator(z)
real_labels = torch.ones(batch_size, 1).to(device)
fake_labels = torch.zeros(batch_size, 1).to(device)
real_loss = criterion(discriminator(real_imgs), real_labels)
fake_loss = criterion(discriminator(fake_imgs.detach()), fake_labels)
d_loss = real_loss + fake_loss
d_optimizer.zero_grad()
d_loss.backward()
d_optimizer.step()
# 생성자 학습
g_loss = criterion(discriminator(fake_imgs), real_labels)
g_optimizer.zero_grad()
g_loss.backward()
g_optimizer.step()
print(f"Epoch [{epoch+1}/{epochs}] - D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}")
3.6 생성된 이미지 시각화
def generate_and_show_images(generator, noise_dim, num_images=16):
generator.eval()
z = torch.randn(num_images, noise_dim).to(device)
fake_images = generator(z).cpu().detach().numpy()
fig, axes = plt.subplots(4, 4, figsize=(6, 6))
for i, ax in enumerate(axes.flat):
ax.imshow(fake_images[i].squeeze(), cmap='gray')
ax.axis('off')
plt.show()
generate_and_show_images(generator, noise_dim)
4. 결론
이번 포스팅에서는 PyTorch를 활용하여 간단한 GAN 모델을 구현하는 과정을 살펴보았습니다. GAN을 활용하면 이미지 생성뿐만 아니라 다양한 응용이 가능합니다. 앞으로 DCGAN, WGAN 등의 변형 모델을 실습하면서 더 발전된 생성 모델을 학습할 수 있도록 하겠습니다.
728x90
'PyTorch' 카테고리의 다른 글
Pretrained Model 소개 (ResNet, VGG) (0) | 2025.05.01 |
---|---|
PyTorch 데이터 생성 실험 - GAN을 활용한 이미지 생성 (0) | 2025.04.30 |
GAN(Generative Adversarial Network) 개념 및 구조 (0) | 2025.04.28 |
간단한 BERT 모델 Fine-Tuning (0) | 2025.04.27 |
Hugging Face와 PyTorch 활용 (0) | 2025.04.26 |