TIL | GAN - NLP Text Generation Model

[ NLP ] GAN : Generative Adversarial Nets

๋ณธ ํฌ์ŠคํŠธ์—์„œ๋Š” ๋ฌธ์žฅ ์ƒ์„ฑ ๋ชจ๋ธ์— ๊ด€ํ•œ ํ”„๋กœ์ ํŠธ๋ฅผ ์ง„ํ–‰ํ•˜๊ธฐ ์œ„ํ•ด ๊ณต๋ถ€ํ•œ ๋‚ด์šฉ์„ ์ •๋ฆฌํ•˜์˜€๋‹ค. GAN์— ๋Œ€ํ•ด ์•Œ์•„๋ณด์ž.

The Goal of a Generative Model

์‹œ๊ฐ„์ด ์ง€๋‚˜๋ฉด์„œ ์ƒ์„ฑ ๋ชจ๋ธ์ด ์›๋ณธ ๋ฐ์ดํ„ฐ์˜ ๋ถ„ํฌ๋ฅผ ํ•™์Šตํ•œ๋‹ค.

ํ•™์Šต์ด ์ž˜ ๋˜์—ˆ๋‹ค๋ฉด ํ†ต๊ณ„์ ์œผ๋กœ ํ‰๊ท ์ ์ธ ํŠน์ง•์„ ๊ฐ€์ง€๋Š” ๋ฐ์ดํ„ฐ๋ฅผ ์‰ฝ๊ฒŒ ์ƒ์„ฑํ•  ์ˆ˜ ์žˆ๋‹ค. ์ด๋•Œ ํŒ๋ณ„ ๋ชจ๋ธ์€ ๋” ์ด์ƒ ์ง„์งœ ์ด๋ฏธ์ง€์™€ ๊ฐ€์งœ ์ด๋ฏธ์ง€๋ฅผ ๊ตฌ๋ถ„ํ•  ์ˆ˜ ์—†๊ธฐ ๋•Œ๋ฌธ์— ๋ถ„ํฌ๊ฐ€ 1/2 ๋กœ ์ˆ˜๋ ดํ•œ๋‹ค.

GAN์ด๋ž€

Generative Adversarial Networks์˜ ์•ฝ์ž๋กœ, ์‹ค์ œ๋กœ ์กด์žฌํ•˜์ง€ ์•Š์ง€๋งŒ ์žˆ์„๋ฒ•ํ•œ ์ด๋ฏธ์ง€๋‚˜ ํ…์ŠคํŠธ ๋“ฑ ์—ฌ๋Ÿฌ ๋ฐ์ดํ„ฐ๋ฅผ ์ƒ์„ฑํ•ด ๋‚ด๋Š” ๋ชจ๋ธ์ด๋‹ค.

์ด๋ฆ„์—์„œ ์•Œ ์ˆ˜ ์žˆ๋“ฏ์ด GAN์€ ์„œ๋กœ ๋‹ค๋ฅธ ๋‘ ๊ฐœ์˜ ๋„คํŠธ์›Œํฌ๋ฅผ ์ ๋Œ€์ ์œผ๋กœ(adversarial) ํ•™์Šต์‹œํ‚ค๋ฉฐ ์‹ค์ œ ๋ฐ์ดํ„ฐ์™€ ๋น„์Šทํ•œ ๋ฐ์ดํ„ฐ๋ฅผ ์ƒ์„ฑ(generative)ํ•ด๋‚ด๋Š” ๋ชจ๋ธ์ด๋ฉฐ ์ด๋ ‡๊ฒŒ ์ƒ์„ฑ๋œ ๋ฐ์ดํ„ฐ์— ์ •ํ•ด์ง„ label๊ฐ’์ด ์—†๊ธฐ ๋•Œ๋ฌธ์— ๋น„์ง€๋„ ํ•™์Šต ๊ธฐ๋ฐ˜ ์ƒ์„ฑ๋ชจ๋ธ๋กœ ๋ถ„๋ฅ˜๋œ๋‹ค.

The two networks are the generator and the discriminator. After training is finished, only the generator is kept as the model; the discriminator exists mainly to help the generator learn well. By training these two networks together, we end up with a trained generative model!


[ Objective Function ]

์ƒ์„ฑ์ž G ๋Š” ํ•จ์ˆ˜ V ์˜ ๊ฐ’์„ ๋‚ฎ์ถ”๊ณ ์ž ๋…ธ๋ ฅํ•˜๊ณ  ํŒ๋ณ„์ž D ๋Š” ๋†’์ด๋ ค๊ณ  ๋…ธ๋ ฅํ•œ๋‹ค. ์ด ํ•จ์ˆ˜๋ฅผ ํ†ตํ•ด ์ƒ์„ฑ์ž๋Š” ์ด๋ฏธ์ง€ ๋ถ„ํฌ๋ฅผ ํ•™์Šตํ•  ์ˆ˜ ์žˆ๋‹ค.

1๏ธโƒฃ ์ด๋ฏธ์ง€
๋ฐ์ดํ„ฐ์—์„œ ์—ฌ๋Ÿฌ ๊ฐœ์˜ ๋ฐ์ดํ„ฐ (์ด๋ฏธ์ง€) ๋ฅผ ๋ฝ‘์•„์™€์„œ log ๋ณ€ํ™˜ํ•ด์ค€ ํ›„์˜ ๊ธฐ๋Œ“๊ฐ’, ์ฆ‰ ํ‰๊ท ๊ฐ’

2๏ธโƒฃ ๋…ธ์ด์ฆˆ
๋…ธ์ด์ฆˆ๋ฅผ ์ƒ˜ํ”Œ๋งํ•ด์™€์„œ ์ƒ์„ฑ์ž G์— ๋„ฃ์–ด์„œ ๊ฐ€์งœ ์ด๋ฏธ์ง€๋ฅผ ๋งŒ๋“  ํ›„ ํŒ๋ณ„์ž D์— ๋„ฃ์€ ๊ฐ’์„ 1์—์„œ ๋นผ์„œ Log๋ฅผ ์ทจํ•œ ๊ฐ’์˜ ํ‰๊ท ๊ฐ’
์ƒ์„ฑ์ž์— ๋Œ€ํ•œ ๊ฐœ๋…์ด ํฌํ•จ. ๊ธฐ๋ณธ์ ์œผ๋กœ ์ƒ์„ฑ์ž๋Š” ๋…ธ์ด์ฆˆ ๋ฒกํ„ฐ๋กœ๋ถ€ํ„ฐ ์ด๋ฏธ์ง€๋ฅผ ๋ฐ›์•„์™€์„œ ์ƒˆ๋กœ์šด ์ด๋ฏธ์ง€๋ฅผ ๋งŒ๋“ค ์ˆ˜ ์žˆ๋‹ค. ์ง„์งœ ์ด๋ฏธ์ง€๋Š” 1, ๊ฐ€์งœ ์ด๋ฏธ์ง€๋Š” 0.

๋ชจ๋ธ ํ•™์Šต

In every epoch the discriminator is trained first, followed by the generator.
The discriminator is updated in the direction that increases its objective (stochastic gradient ascent), while the generator is updated in the direction that decreases its objective (gradient descent).
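Written out as in Algorithm 1 of the original paper (m is the minibatch size and η the learning rate; note the opposite signs on the two updates):

```latex
\theta_d \leftarrow \theta_d + \eta\, \nabla_{\theta_d} \frac{1}{m} \sum_{i=1}^{m}
  \Big[ \log D\big(x^{(i)}\big) + \log\Big(1 - D\big(G(z^{(i)})\big)\Big) \Big]
\qquad
\theta_g \leftarrow \theta_g - \eta\, \nabla_{\theta_g} \frac{1}{m} \sum_{i=1}^{m}
  \log\Big(1 - D\big(G(z^{(i)})\big)\Big)
```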

Characteristics of GANs

  • Not cherry-picked
    • The displayed samples are drawn at random, not hand-selected
  • Not memorized the training set
    • The model does not simply memorize the training data
  • Competitive with the better generative models
    • Performance holds up well against earlier generative models
  • Sharp images
    • Generated images are fairly crisp rather than blurry
  • The longer a generated sentence gets, the worse its quality becomes

💻 Code Practice - GAN

Source: https://github.com/ndb796/Deep-Learning-Paper-Review-and-Practice/blob/master/code_practices/GAN_for_MNIST_Tutorial.ipynb

This code trains the most basic GAN model on the MNIST dataset.

์ƒ์„ฑ์ž Generator ๋ฐ ํŒ๋ณ„์ž Discriminator ๋ชจ๋ธ ์ •์˜

import torch
import torch.nn as nn

latent_dim = 100

# ์ƒ์„ฑ์ž(Generator) ํด๋ž˜์Šค ์ •์˜
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        # ํ•˜๋‚˜์˜ ๋ธ”๋ก(block) ์ •์˜
        def block(input_dim, output_dim, normalize=True):
            layers = [nn.Linear(input_dim, output_dim)]
            if normalize:
                # Batch normalization (dimensionality unchanged); 0.8 is the
                # momentum -- passed by keyword, since positionally it would be eps
                layers.append(nn.BatchNorm1d(output_dim, momentum=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        # ์ƒ์„ฑ์ž ๋ชจ๋ธ์€ ์—ฐ์†์ ์ธ ์—ฌ๋Ÿฌ ๊ฐœ์˜ ๋ธ”๋ก์„ ๊ฐ€์ง
        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, 1 * 28 * 28),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), 1, 28, 28)
        return img
# ํŒ๋ณ„์ž(Discriminator) ํด๋ž˜์Šค ์ •์˜
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(1 * 28 * 28, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    # Return the discriminator's output for an image
    def forward(self, img):
        flattened = img.view(img.size(0), -1)
        output = self.model(flattened)

        return output
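As a quick sanity check of the shapes flowing between the two models, the sketch below uses simplified single-layer stand-ins (not the full networks above): noise of shape (batch, latent_dim) becomes an image of shape (batch, 1, 28, 28), which the discriminator flattens back into a vector.

```python
import torch
import torch.nn as nn

latent_dim = 100

# Simplified stand-ins mirroring only the input/output shapes of the
# Generator and Discriminator defined above
g = nn.Sequential(nn.Linear(latent_dim, 1 * 28 * 28), nn.Tanh())
d = nn.Sequential(nn.Linear(1 * 28 * 28, 1), nn.Sigmoid())

z = torch.randn(4, latent_dim)              # a batch of 4 noise vectors
imgs = g(z).view(-1, 1, 28, 28)             # flat output reshaped to image form
scores = d(imgs.view(imgs.size(0), -1))     # flatten back before the MLP

print(tuple(imgs.shape))    # (4, 1, 28, 28)
print(tuple(scores.shape))  # (4, 1)
```

The same view/flatten pattern appears in the forward methods above: the generator reshapes its flat output into images, and the discriminator undoes it.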

ํ•™์Šต ๋ฐ์ดํ„ฐ์…‹ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ

from torchvision import datasets, transforms

transforms_train = transforms.Compose([
    transforms.Resize(28),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

train_dataset = datasets.MNIST(root="./dataset", train=True, download=True, transform=transforms_train)
dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)
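Note that Normalize([0.5], [0.5]) rescales pixel values from [0, 1] to [-1, 1], which matches the Tanh output range of the generator. A plain-Python sketch of the same per-pixel arithmetic:

```python
# transforms.Normalize(mean, std) computes (pixel - mean) / std per channel
def normalize(pixel, mean=0.5, std=0.5):
    return (pixel - mean) / std

print(normalize(0.0), normalize(0.5), normalize(1.0))  # -1.0 0.0 1.0
```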

๋ชจ๋ธ ํ•™์Šต ๋ฐ ์ƒ˜ํ”Œ๋ง

# ์ƒ์„ฑ์ž(generator)์™€ ํŒ๋ณ„์ž(discriminator) ์ดˆ๊ธฐํ™”
generator = Generator()
discriminator = Discriminator()

generator.cuda()
discriminator.cuda()

# Loss function
adversarial_loss = nn.BCELoss()
adversarial_loss.cuda()

# Set the learning rate
lr = 0.0002

# ์ƒ์„ฑ์ž์™€ ํŒ๋ณ„์ž๋ฅผ ์œ„ํ•œ ์ตœ์ ํ™” ํ•จ์ˆ˜
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
import time

n_epochs = 200 # number of training epochs
sample_interval = 2000 # how many batches between saved sample images
start_time = time.time()

for epoch in range(n_epochs):
    for i, (imgs, _) in enumerate(dataloader):

        # Ground-truth labels for real and fake images
        real = torch.cuda.FloatTensor(imgs.size(0), 1).fill_(1.0) # real: 1
        fake = torch.cuda.FloatTensor(imgs.size(0), 1).fill_(0.0) # fake: 0

        real_imgs = imgs.cuda()

        """ ์ƒ์„ฑ์ž(generator)๋ฅผ ํ•™์Šตํ•ฉ๋‹ˆ๋‹ค. """
        optimizer_G.zero_grad()

        # ๋žœ๋ค ๋…ธ์ด์ฆˆ(noise) ์ƒ˜ํ”Œ๋ง
        z = torch.normal(mean=0, std=1, size=(imgs.shape[0], latent_dim)).cuda()

        # ์ด๋ฏธ์ง€ ์ƒ์„ฑ
        generated_imgs = generator(z)

        # ์ƒ์„ฑ์ž(generator)์˜ ์†์‹ค(loss) ๊ฐ’ ๊ณ„์‚ฐ
        g_loss = adversarial_loss(discriminator(generated_imgs), real)

        # ์ƒ์„ฑ์ž(generator) ์—…๋ฐ์ดํŠธ
        g_loss.backward()
        optimizer_G.step()

        """ ํŒ๋ณ„์ž(discriminator)๋ฅผ ํ•™์Šตํ•ฉ๋‹ˆ๋‹ค. """
        optimizer_D.zero_grad()

        # ํŒ๋ณ„์ž(discriminator)์˜ ์†์‹ค(loss) ๊ฐ’ ๊ณ„์‚ฐ
        real_loss = adversarial_loss(discriminator(real_imgs), real)
        fake_loss = adversarial_loss(discriminator(generated_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        # ํŒ๋ณ„์ž(discriminator) ์—…๋ฐ์ดํŠธ
        d_loss.backward()
        optimizer_D.step()

        done = epoch * len(dataloader) + i
        if done % sample_interval == 0:
            # ์ƒ์„ฑ๋œ ์ด๋ฏธ์ง€ ์ค‘์—์„œ 25๊ฐœ๋งŒ ์„ ํƒํ•˜์—ฌ 5 X 5 ๊ฒฉ์ž ์ด๋ฏธ์ง€์— ์ถœ๋ ฅ
            save_image(generated_imgs.data[:25], f"{done}.png", nrow=5, normalize=True)

    # Print a log every 10 epochs
    if epoch % 10 == 0:
        print(f"[Epoch {epoch}/{n_epochs}] [D loss: {d_loss.item():.6f}] [G loss: {g_loss.item():.6f}] [Elapsed time: {time.time() - start_time:.2f}s]")

์ƒ์„ฑ๋œ ์ด๋ฏธ์ง€ ์ถœ๋ ฅ

from IPython.display import Image

Image('92000.png')

References

[YouTube] GAN:Generative Adversarial Networks

Make GANs Training Easier for Everyone : Generate Images Following a Sketch

Paper: Generative Adversarial Networks (NIPS 2014)

 
