[ NLP ] GAN : Generative Adversarial Nets
๋ณธ ํฌ์คํธ์์๋ ๋ฌธ์ฅ ์์ฑ ๋ชจ๋ธ์ ๊ดํ ํ๋ก์ ํธ๋ฅผ ์งํํ๊ธฐ ์ํด ๊ณต๋ถํ ๋ด์ฉ์ ์ ๋ฆฌํ์๋ค. GAN์ ๋ํด ์์๋ณด์.
Generative Model ์ ๋ชฉํ
์๊ฐ์ด ์ง๋๋ฉด์ ์์ฑ ๋ชจ๋ธ์ด ์๋ณธ ๋ฐ์ดํฐ์ ๋ถํฌ๋ฅผ ํ์ตํ๋ค.
ํ์ต์ด ์ ๋์๋ค๋ฉด ํต๊ณ์ ์ผ๋ก ํ๊ท ์ ์ธ ํน์ง์ ๊ฐ์ง๋ ๋ฐ์ดํฐ๋ฅผ ์ฝ๊ฒ ์์ฑํ ์ ์๋ค. ์ด๋ ํ๋ณ ๋ชจ๋ธ์ ๋ ์ด์ ์ง์ง ์ด๋ฏธ์ง์ ๊ฐ์ง ์ด๋ฏธ์ง๋ฅผ ๊ตฌ๋ถํ ์ ์๊ธฐ ๋๋ฌธ์ ๋ถํฌ๊ฐ 1/2 ๋ก ์๋ ดํ๋ค.
GAN์ด๋
Generative Adversarial Networks์ ์ฝ์๋ก, ์ค์ ๋ก ์กด์ฌํ์ง ์์ง๋ง ์์๋ฒํ ์ด๋ฏธ์ง๋ ํ ์คํธ ๋ฑ ์ฌ๋ฌ ๋ฐ์ดํฐ๋ฅผ ์์ฑํด ๋ด๋ ๋ชจ๋ธ์ด๋ค.
์ด๋ฆ์์ ์ ์ ์๋ฏ์ด GAN์ ์๋ก ๋ค๋ฅธ ๋ ๊ฐ์ ๋คํธ์ํฌ๋ฅผ ์ ๋์ ์ผ๋ก(adversarial) ํ์ต์ํค๋ฉฐ ์ค์ ๋ฐ์ดํฐ์ ๋น์ทํ ๋ฐ์ดํฐ๋ฅผ ์์ฑ(generative)ํด๋ด๋ ๋ชจ๋ธ์ด๋ฉฐ ์ด๋ ๊ฒ ์์ฑ๋ ๋ฐ์ดํฐ์ ์ ํด์ง label๊ฐ์ด ์๊ธฐ ๋๋ฌธ์ ๋น์ง๋ ํ์ต ๊ธฐ๋ฐ ์์ฑ๋ชจ๋ธ๋ก ๋ถ๋ฅ๋๋ค.
์ด๋ ๋ ๊ฐ์ ๋คํธ์ํฌ๋ ์์ฑ์ generator ์ ํ๋ณ์ discriminator ๋ฅผ ๋งํ๋ค. ํ์ต์ด ๋ค ๋ ์ดํ์, ๋ชจ๋ธ์ ์์ฑ์๋ผ๊ณ ํ ์ ์์ผ๋ฉฐ ํ๋ณ์๋ ์ด ์์ฑ์๊ฐ ์ ํ์ตํ ์ ์๋๋ก ๋์์ฃผ๊ธฐ ์ํ ๋ชฉ์ ์ผ๋ก ๋ง์ด ์ฌ์ฉํ๋ค. ์ด ๋ ๊ฐ์ ๋คํธ์ํฌ๋ฅผ ํจ๊ป ํ์ต์ํค๋ฉด์ ๊ฒฐ๊ณผ์ ์ผ๋ก ์์ฑ ๋ชจ๋ธ์ ํ์ตํ ์ ์๊ฒ ๋๋ค !
[ ๋ชฉ์ ํจ์ ]
์์ฑ์ G ๋ ํจ์ V ์ ๊ฐ์ ๋ฎ์ถ๊ณ ์ ๋ ธ๋ ฅํ๊ณ ํ๋ณ์ D ๋ ๋์ด๋ ค๊ณ ๋ ธ๋ ฅํ๋ค. ์ด ํจ์๋ฅผ ํตํด ์์ฑ์๋ ์ด๋ฏธ์ง ๋ถํฌ๋ฅผ ํ์ตํ ์ ์๋ค.
1๏ธโฃ ์ด๋ฏธ์ง
๋ฐ์ดํฐ์์ ์ฌ๋ฌ ๊ฐ์ ๋ฐ์ดํฐ (์ด๋ฏธ์ง) ๋ฅผ ๋ฝ์์์ log ๋ณํํด์ค ํ์ ๊ธฐ๋๊ฐ, ์ฆ ํ๊ท ๊ฐ
2๏ธโฃ ๋
ธ์ด์ฆ
๋
ธ์ด์ฆ๋ฅผ ์ํ๋งํด์์ ์์ฑ์ G์ ๋ฃ์ด์ ๊ฐ์ง ์ด๋ฏธ์ง๋ฅผ ๋ง๋ ํ ํ๋ณ์ D์ ๋ฃ์ ๊ฐ์ 1์์ ๋นผ์ Log๋ฅผ ์ทจํ ๊ฐ์ ํ๊ท ๊ฐ
์์ฑ์์ ๋ํ ๊ฐ๋
์ด ํฌํจ. ๊ธฐ๋ณธ์ ์ผ๋ก ์์ฑ์๋ ๋
ธ์ด์ฆ ๋ฒกํฐ๋ก๋ถํฐ ์ด๋ฏธ์ง๋ฅผ ๋ฐ์์์ ์๋ก์ด ์ด๋ฏธ์ง๋ฅผ ๋ง๋ค ์ ์๋ค. ์ง์ง ์ด๋ฏธ์ง๋ 1, ๊ฐ์ง ์ด๋ฏธ์ง๋ 0.
๋ชจ๋ธ ํ์ต
๋งค epoch๋น Descriminator๋ฅผ ๋จผ์ ํ์ตํ๊ณ , Generator์ ํ์ต์ด ์ด๋ฃจ์ด์ง๋ค.
Descriminator๋ ๊ธฐ์ธ๊ธฐ (Stochastic gradient) ๊ฐ ์ฆ๊ฐํ๋ ๋ฐฉํฅ์ผ๋ก ํ์ต๋๊ณ , generator๋ ๊ธฐ์ธ๊ธฐ๊ฐ ๊ฐ์ํ๋ ๋ฐฉํฅ์ผ๋ก ํ์ต๋๋ค.
GAN ํน์ง
- Not cherry-picked
- ์ด๋ฏธ์ง๋ฅผ ์ ๋ณํด์ ๋ฃ์ ๊ฒ ์๋๋ผ ๋๋คํ๊ฒ ๋ฃ์ด์ค
- Not memorized the training set
- ํ์ต ๋ฐ์ดํฐ๋ฅผ ๋จ์ํ ์๊ธฐํ ๊ฒ์ด ์๋
- Competitive with the better generative models
- ์ด์ ์ ๋ค๋ฅธ ์์ฑ ๋ชจ๋ธ๊ณผ ๋น๊ตํ์ ๋ ์ถฉ๋ถํ ์ข์ ์ฑ๋ฅ์ด ๋์ด
- Images represent sharp
- Blurryํ๊ฒ ์ด๋ฏธ์ง๋ฅผ ์์ฑํ์ง ์๊ณ ๊ฝค ์ ๋ช ํ๊ฒ ์์ฑ
- ๋ฌธ์ฅ์ด ๊ธธ์ด์ง์๋ก ์์ฑ๋ ๋ฌธ์ฅ์ ํ์ง์ด ์ ์ข์์ง
๐ป ์ฝ๋ ์ค์ต - GAN
ํด๋น ์ฝ๋๋ MNIST ๋ฐ์ดํฐ์ ์ ๊ฐ์ง๊ณ ๊ฐ์ฅ ๊ธฐ๋ณธ์ ์ธ GAN ๋ชจ๋ธ์ ํ์ตํ ์ฝ๋์ด๋ค.
์์ฑ์ Generator ๋ฐ ํ๋ณ์ Discriminator ๋ชจ๋ธ ์ ์
latent_dim = 100
# ์์ฑ์(Generator) ํด๋์ค ์ ์
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
# ํ๋์ ๋ธ๋ก(block) ์ ์
def block(input_dim, output_dim, normalize=True):
layers = [nn.Linear(input_dim, output_dim)]
if normalize:
# ๋ฐฐ์น ์ ๊ทํ(batch normalization) ์ํ(์ฐจ์ ๋์ผ)
layers.append(nn.BatchNorm1d(output_dim, 0.8))
layers.append(nn.LeakyReLU(0.2, inplace=True))
return layers
# ์์ฑ์ ๋ชจ๋ธ์ ์ฐ์์ ์ธ ์ฌ๋ฌ ๊ฐ์ ๋ธ๋ก์ ๊ฐ์ง
self.model = nn.Sequential(
*block(latent_dim, 128, normalize=False),
*block(128, 256),
*block(256, 512),
*block(512, 1024),
nn.Linear(1024, 1 * 28 * 28),
nn.Tanh()
)
def forward(self, z):
img = self.model(z)
img = img.view(img.size(0), 1, 28, 28)
return img
# ํ๋ณ์(Discriminator) ํด๋์ค ์ ์
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Linear(1 * 28 * 28, 512),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(512, 256),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(256, 1),
nn.Sigmoid(),
)
# ์ด๋ฏธ์ง์ ๋ํ ํ๋ณ ๊ฒฐ๊ณผ๋ฅผ ๋ฐํ
def forward(self, img):
flattened = img.view(img.size(0), -1)
output = self.model(flattened)
return output
ํ์ต ๋ฐ์ดํฐ์ ๋ถ๋ฌ์ค๊ธฐ
transforms_train = transforms.Compose([
transforms.Resize(28),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5])
])
train_dataset = datasets.MNIST(root="./dataset", train=True, download=True, transform=transforms_train)
dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)
๋ชจ๋ธ ํ์ต ๋ฐ ์ํ๋ง
# ์์ฑ์(generator)์ ํ๋ณ์(discriminator) ์ด๊ธฐํ
generator = Generator()
discriminator = Discriminator()
generator.cuda()
discriminator.cuda()
# ์์ค ํจ์(loss function)
adversarial_loss = nn.BCELoss()
adversarial_loss.cuda()
# ํ์ต๋ฅ (learning rate) ์ค์
lr = 0.0002
# ์์ฑ์์ ํ๋ณ์๋ฅผ ์ํ ์ต์ ํ ํจ์
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
import time
n_epochs = 200 # ํ์ต์ ํ์(epoch) ์ค์
sample_interval = 2000 # ๋ช ๋ฒ์ ๋ฐฐ์น(batch)๋ง๋ค ๊ฒฐ๊ณผ๋ฅผ ์ถ๋ ฅํ ๊ฒ์ธ์ง ์ค์
start_time = time.time()
for epoch in range(n_epochs):
for i, (imgs, _) in enumerate(dataloader):
# ์ง์ง(real) ์ด๋ฏธ์ง์ ๊ฐ์ง(fake) ์ด๋ฏธ์ง์ ๋ํ ์ ๋ต ๋ ์ด๋ธ ์์ฑ
real = torch.cuda.FloatTensor(imgs.size(0), 1).fill_(1.0) # ์ง์ง(real): 1
fake = torch.cuda.FloatTensor(imgs.size(0), 1).fill_(0.0) # ๊ฐ์ง(fake): 0
real_imgs = imgs.cuda()
""" ์์ฑ์(generator)๋ฅผ ํ์ตํฉ๋๋ค. """
optimizer_G.zero_grad()
# ๋๋ค ๋
ธ์ด์ฆ(noise) ์ํ๋ง
z = torch.normal(mean=0, std=1, size=(imgs.shape[0], latent_dim)).cuda()
# ์ด๋ฏธ์ง ์์ฑ
generated_imgs = generator(z)
# ์์ฑ์(generator)์ ์์ค(loss) ๊ฐ ๊ณ์ฐ
g_loss = adversarial_loss(discriminator(generated_imgs), real)
# ์์ฑ์(generator) ์
๋ฐ์ดํธ
g_loss.backward()
optimizer_G.step()
""" ํ๋ณ์(discriminator)๋ฅผ ํ์ตํฉ๋๋ค. """
optimizer_D.zero_grad()
# ํ๋ณ์(discriminator)์ ์์ค(loss) ๊ฐ ๊ณ์ฐ
real_loss = adversarial_loss(discriminator(real_imgs), real)
fake_loss = adversarial_loss(discriminator(generated_imgs.detach()), fake)
d_loss = (real_loss + fake_loss) / 2
# ํ๋ณ์(discriminator) ์
๋ฐ์ดํธ
d_loss.backward()
optimizer_D.step()
done = epoch * len(dataloader) + i
if done % sample_interval == 0:
# ์์ฑ๋ ์ด๋ฏธ์ง ์ค์์ 25๊ฐ๋ง ์ ํํ์ฌ 5 X 5 ๊ฒฉ์ ์ด๋ฏธ์ง์ ์ถ๋ ฅ
save_image(generated_imgs.data[:25], f"{done}.png", nrow=5, normalize=True)
# epoch 10๋ง๋ค ๋ก๊ทธ(log) ์ถ๋ ฅ
if epoch % 10 == 0:
print(f"[Epoch {epoch}/{n_epochs}] [D loss: {d_loss.item():.6f}] [G loss: {g_loss.item():.6f}] [Elapsed time: {time.time() - start_time:.2f}s]")
์์ฑ๋ ์ด๋ฏธ์ง ์ถ๋ ฅ
from IPython.display import Image
Image('92000.png')
์ฐธ๊ณ
[YouTube] GAN:Generative Adversarial Networks
Make GANs Training Easier for Everyone : Generate Images Following a Sketch
๋ ผ๋ฌธ Generative Adversarial Networks (NIPS 2014)
๐ You need to log in to GitHub to write comments. ๐
If you can't see comments, please refresh page(F5).