Fliw
committed on
Commit · 78598be
1 Parent(s): 341513e
chore(inference): add gradio model
Browse files
- .gitattributes +2 -0
- JupyterNotebook.ipynb +0 -0
- app.py +44 -0
- model/d.model +3 -0
- model/d_optim.pth +3 -0
- model/g.model +3 -0
- model/g_optim.pth +3 -0
- progan_modules.py +250 -0
- requirements.txt +8 -0
- train.py +281 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+model/*.model filter=lfs diff=lfs merge=lfs -text
+model/*.pth filter=lfs diff=lfs merge=lfs -text
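These two patterns route the new checkpoint files below through Git LFS, so a plain clone contains only small pointer stubs and `git lfs pull` fetches the actual weights.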
JupyterNotebook.ipynb
ADDED
The diff for this file is too large to render. See raw diff.
app.py
ADDED
@@ -0,0 +1,44 @@
+import torch, os, gradio as gr, numpy as np
+from torchvision import utils, transforms
+from progan_modules import Generator
+
+CHECKPOINT_DIR = "./model"
+Z_DIM, CHANNEL_SIZE = 128, 128
+DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+FIXED_STEP = 6
+FIXED_ALPHA = 0.0
+
+g_running = Generator(CHANNEL_SIZE, Z_DIM, pixel_norm=False, tanh=False).to(DEVICE)
+g_running.load_state_dict(torch.load(os.path.join(CHECKPOINT_DIR, "g.model"), map_location=DEVICE))
+g_running.eval()
+
+to_pil = transforms.ToPILImage()
+
+@torch.inference_mode()
+def sample_images(n_images: int = 50, seed: int | None = None):
+    if seed is not None and seed >= 0:
+        torch.manual_seed(seed); np.random.seed(seed)
+    else:
+        torch.seed()
+
+    z = torch.randn(n_images, Z_DIM, device=DEVICE)
+    imgs = g_running(z, step=FIXED_STEP, alpha=FIXED_ALPHA).cpu()
+
+    grid = utils.make_grid(imgs, nrow=10, normalize=True, value_range=(-1, 1))
+    return to_pil(grid)
+
+demo = gr.Interface(
+    fn=sample_images,
+    inputs=[
+        gr.Slider(1, 200, value=50, step=10, label="Number of Images (multiples of 10)"),
+        gr.Number(value=-1, precision=0, label="Seed (-1 = random)"),
+    ],
+    outputs=gr.Image(type="pil", label="Result Grid"),
+    title="Progressive Growing Generative Adversarial Network",
+    description="example PGGAN implementation for an acne dataset",
+    allow_flagging="never",
+)
+
+if __name__ == "__main__":
+    demo.queue()
+    demo.launch(show_api=False, share=True)
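As a usage note, a minimal headless sketch (illustrative, not part of the commit) that reuses sample_images without launching the UI; importing app loads the checkpoint at module import time, and the filename and seed below are arbitrary:

# Generate one 10-wide grid outside Gradio and save it to disk.
from app import sample_images

grid = sample_images(n_images=100, seed=42)  # returns a PIL.Image grid
grid.save("samples_seed42.png")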
model/d.model
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a1ff8ae7d55d9126ccf99e1177d9e63f6884ab9404dc9501ff62b5d5752628cd
+size 6396418
model/d_optim.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3165c039f10cd0d541e88b3dfbf6dfe2a13e0aca6e7f771d6eaa1f610f451b04
+size 12640648
model/g.model
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e14508dc41768a63d4a5021547215212ba708df99efe53cee3bbeefbe54e188b
+size 6422598
model/g_optim.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:92601d2b20eaade8f3f94aea3077019fdafecd59799095ff13ed5236a9308c54
+size 12694058
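The four files above are Git LFS pointer stubs, not the weights themselves. A minimal sketch (hypothetical helper, assuming the repo was cloned without running `git lfs pull`) that reads such a stub:

# Parse a Git LFS pointer stub into a dict, e.g. model/g.model before
# `git lfs pull` replaces it with the real checkpoint.
def read_lfs_pointer(path):
    with open(path) as f:
        return dict(line.strip().split(" ", 1) for line in f if line.strip())

info = read_lfs_pointer("model/g.model")
print(info["oid"], info["size"])  # sha256:e14508dc... 6422598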
progan_modules.py
ADDED
@@ -0,0 +1,250 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from math import sqrt
+
+
+class EqualLR:
+    def __init__(self, name):
+        self.name = name
+
+    def compute_weight(self, module):
+        weight = getattr(module, self.name + '_orig')
+        fan_in = weight.data.size(1) * weight.data[0][0].numel()
+
+        return weight * sqrt(2 / fan_in)
+
+    @staticmethod
+    def apply(module, name):
+        fn = EqualLR(name)
+
+        weight = getattr(module, name)
+        del module._parameters[name]
+        module.register_parameter(name + '_orig', nn.Parameter(weight.data))
+        module.register_forward_pre_hook(fn)
+
+        return fn
+
+    def __call__(self, module, input):
+        weight = self.compute_weight(module)
+        setattr(module, self.name, weight)
+
+
+def equal_lr(module, name='weight'):
+    EqualLR.apply(module, name)
+
+    return module
+
+
+class PixelNorm(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, input):
+        return input / torch.sqrt(torch.mean(input ** 2, dim=1, keepdim=True)
+                                  + 1e-8)
+
+
+class EqualConv2d(nn.Module):
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+
+        conv = nn.Conv2d(*args, **kwargs)
+        conv.weight.data.normal_()
+        conv.bias.data.zero_()
+        self.conv = equal_lr(conv)
+
+    def forward(self, input):
+        return self.conv(input)
+
+
+class EqualConvTranspose2d(nn.Module):
+    ### additional module for OOGAN usage
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+
+        conv = nn.ConvTranspose2d(*args, **kwargs)
+        conv.weight.data.normal_()
+        conv.bias.data.zero_()
+        self.conv = equal_lr(conv)
+
+    def forward(self, input):
+        return self.conv(input)
+
+
+class EqualLinear(nn.Module):
+    def __init__(self, in_dim, out_dim):
+        super().__init__()
+
+        linear = nn.Linear(in_dim, out_dim)
+        linear.weight.data.normal_()
+        linear.bias.data.zero_()
+
+        self.linear = equal_lr(linear)
+
+    def forward(self, input):
+        return self.linear(input)
+
+
+class ConvBlock(nn.Module):
+    def __init__(self, in_channel, out_channel, kernel_size, padding, kernel_size2=None, padding2=None, pixel_norm=True):
+        super().__init__()
+
+        pad1 = padding
+        pad2 = padding
+        if padding2 is not None:
+            pad2 = padding2
+
+        kernel1 = kernel_size
+        kernel2 = kernel_size
+        if kernel_size2 is not None:
+            kernel2 = kernel_size2
+
+        convs = [EqualConv2d(in_channel, out_channel, kernel1, padding=pad1)]
+        if pixel_norm:
+            convs.append(PixelNorm())
+        convs.append(nn.LeakyReLU(0.1))
+        convs.append(EqualConv2d(out_channel, out_channel, kernel2, padding=pad2))
+        if pixel_norm:
+            convs.append(PixelNorm())
+        convs.append(nn.LeakyReLU(0.1))
+
+        self.conv = nn.Sequential(*convs)
+
+    def forward(self, input):
+        out = self.conv(input)
+        return out
+
+
+def upscale(feat):
+    return F.interpolate(feat, scale_factor=2, mode='bilinear', align_corners=False)
+
+
+class Generator(nn.Module):
+    def __init__(self, input_code_dim=128, in_channel=128, pixel_norm=True, tanh=True):
+        super().__init__()
+        self.input_dim = input_code_dim
+        self.tanh = tanh
+        self.input_layer = nn.Sequential(
+            EqualConvTranspose2d(input_code_dim, in_channel, 4, 1, 0),
+            PixelNorm(),
+            nn.LeakyReLU(0.1))
+
+        self.progression_4 = ConvBlock(in_channel, in_channel, 3, 1, pixel_norm=pixel_norm)
+        self.progression_8 = ConvBlock(in_channel, in_channel, 3, 1, pixel_norm=pixel_norm)
+        self.progression_16 = ConvBlock(in_channel, in_channel, 3, 1, pixel_norm=pixel_norm)
+        self.progression_32 = ConvBlock(in_channel, in_channel, 3, 1, pixel_norm=pixel_norm)
+        self.progression_64 = ConvBlock(in_channel, in_channel//2, 3, 1, pixel_norm=pixel_norm)
+        self.progression_128 = ConvBlock(in_channel//2, in_channel//4, 3, 1, pixel_norm=pixel_norm)
+        self.progression_256 = ConvBlock(in_channel//4, in_channel//4, 3, 1, pixel_norm=pixel_norm)
+
+        self.to_rgb_8 = EqualConv2d(in_channel, 3, 1)
+        self.to_rgb_16 = EqualConv2d(in_channel, 3, 1)
+        self.to_rgb_32 = EqualConv2d(in_channel, 3, 1)
+        self.to_rgb_64 = EqualConv2d(in_channel//2, 3, 1)
+        self.to_rgb_128 = EqualConv2d(in_channel//4, 3, 1)
+        self.to_rgb_256 = EqualConv2d(in_channel//4, 3, 1)
+
+        self.max_step = 6
+
+    def progress(self, feat, module):
+        out = F.interpolate(feat, scale_factor=2, mode='bilinear', align_corners=False)
+        out = module(out)
+        return out
+
+    def output(self, feat1, feat2, module1, module2, alpha):
+        if 0 <= alpha < 1:
+            skip_rgb = upscale(module1(feat1))
+            out = (1 - alpha) * skip_rgb + alpha * module2(feat2)
+        else:
+            out = module2(feat2)
+        if self.tanh:
+            return torch.tanh(out)
+        return out
+
+    def forward(self, input, step=0, alpha=-1):
+        if step > self.max_step:
+            step = self.max_step
+
+        out_4 = self.input_layer(input.view(-1, self.input_dim, 1, 1))
+        out_4 = self.progression_4(out_4)
+        out_8 = self.progress(out_4, self.progression_8)
+        if step == 1:
+            if self.tanh:
+                return torch.tanh(self.to_rgb_8(out_8))
+            return self.to_rgb_8(out_8)
+
+        out_16 = self.progress(out_8, self.progression_16)
+        if step == 2:
+            return self.output(out_8, out_16, self.to_rgb_8, self.to_rgb_16, alpha)
+
+        out_32 = self.progress(out_16, self.progression_32)
+        if step == 3:
+            return self.output(out_16, out_32, self.to_rgb_16, self.to_rgb_32, alpha)
+
+        out_64 = self.progress(out_32, self.progression_64)
+        if step == 4:
+            return self.output(out_32, out_64, self.to_rgb_32, self.to_rgb_64, alpha)
+
+        out_128 = self.progress(out_64, self.progression_128)
+        if step == 5:
+            return self.output(out_64, out_128, self.to_rgb_64, self.to_rgb_128, alpha)
+
+        out_256 = self.progress(out_128, self.progression_256)
+        if step == 6:
+            return self.output(out_128, out_256, self.to_rgb_128, self.to_rgb_256, alpha)
+
+
+class Discriminator(nn.Module):
+    def __init__(self, feat_dim=128):
+        super().__init__()
+
+        self.progression = nn.ModuleList([ConvBlock(feat_dim//4, feat_dim//4, 3, 1),
+                                          ConvBlock(feat_dim//4, feat_dim//2, 3, 1),
+                                          ConvBlock(feat_dim//2, feat_dim, 3, 1),
+                                          ConvBlock(feat_dim, feat_dim, 3, 1),
+                                          ConvBlock(feat_dim, feat_dim, 3, 1),
+                                          ConvBlock(feat_dim, feat_dim, 3, 1),
+                                          ConvBlock(feat_dim + 1, feat_dim, 3, 1, 4, 0)])
+
+        self.from_rgb = nn.ModuleList([EqualConv2d(3, feat_dim//4, 1),
+                                       EqualConv2d(3, feat_dim//4, 1),
+                                       EqualConv2d(3, feat_dim//2, 1),
+                                       EqualConv2d(3, feat_dim, 1),
+                                       EqualConv2d(3, feat_dim, 1),
+                                       EqualConv2d(3, feat_dim, 1),
+                                       EqualConv2d(3, feat_dim, 1)])
+
+        self.n_layer = len(self.progression)
+
+        self.linear = EqualLinear(feat_dim, 1)
+
+    def forward(self, input, step=0, alpha=-1):
+        for i in range(step, -1, -1):
+            index = self.n_layer - i - 1
+
+            if i == step:
+                out = self.from_rgb[index](input)
+
+            if i == 0:
+                out_std = torch.sqrt(out.var(0, unbiased=False) + 1e-8)
+                mean_std = out_std.mean()
+                mean_std = mean_std.expand(out.size(0), 1, 4, 4)
+                out = torch.cat([out, mean_std], 1)
+
+            out = self.progression[index](out)
+
+            if i > 0:
+                # out = F.avg_pool2d(out, 2)
+                out = F.interpolate(out, scale_factor=0.5, mode='bilinear', align_corners=False)
+
+            if i == step and 0 <= alpha < 1:
+                # skip_rgb = F.avg_pool2d(input, 2)
+                skip_rgb = F.interpolate(input, scale_factor=0.5, mode='bilinear', align_corners=False)
+                skip_rgb = self.from_rgb[index + 1](skip_rgb)
+                out = (1 - alpha) * skip_rgb + alpha * out
+
+        out = out.squeeze(2).squeeze(2)
+        # print(input.size(), out.size(), step)
+        out = self.linear(out)
+
+        return out
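For orientation, a small sketch (illustrative, not part of the commit) of the progressive interface these modules expose: step s yields 4 * 2**s pixel images (1 -> 8x8 ... 6 -> 256x256), and alpha in [0, 1) fades between adjacent resolutions:

import torch
from progan_modules import Generator, Discriminator

G = Generator(input_code_dim=128, in_channel=128, pixel_norm=False, tanh=False)
D = Discriminator(feat_dim=128)

z = torch.randn(2, 128)
for step in range(1, 7):
    img = G(z, step=step, alpha=1)      # alpha=1: no cross-resolution blend
    score = D(img, step=step, alpha=1)  # one critic score per image
    print(step, tuple(img.shape), tuple(score.shape))
# step 1 -> (2, 3, 8, 8) ... step 6 -> (2, 3, 256, 256); scores are (2, 1)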
requirements.txt
ADDED
@@ -0,0 +1,8 @@
+torch==2.0.0
+torchvision==0.15.0
+numpy==1.23.4
+Pillow==9.2.0
+tqdm==4.64.1
+pydantic==2.10.6
+gradio==3.32.0
+gradio_client>=0.13.0
train.py
ADDED
@@ -0,0 +1,281 @@
+from tqdm import tqdm
+import numpy as np
+from PIL import Image
+import argparse
+import random
+import torch
+import torch.nn.functional as F
+import os
+from torch import nn, optim
+from torch.autograd import Variable, grad
+from torch.utils.data import DataLoader
+from torchvision import datasets, transforms, utils
+
+from progan_modules import Generator, Discriminator
+
+
+def accumulate(model1, model2, decay=0.999):
+    par1 = dict(model1.named_parameters())
+    par2 = dict(model2.named_parameters())
+
+    for k in par1.keys():
+        par1[k].data.mul_(decay).add_(par2[k].data, alpha=(1 - decay))
+
+
+def imagefolder_loader(path):
+    def loader(transform):
+        data = datasets.ImageFolder(path, transform=transform)
+        data_loader = DataLoader(data, shuffle=True, batch_size=batch_size, num_workers=2)
+        return data_loader
+    return loader
+
+
+def sample_data(dataloader, image_size=4):
+    transform = transforms.Compose([
+        transforms.Resize(image_size + int(image_size * 0.2) + 1),
+        transforms.RandomCrop(image_size),
+        transforms.RandomHorizontalFlip(),
+        transforms.ToTensor(),
+        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+    ])
+
+    loader = dataloader(transform)
+
+    return loader
+
+
+def train(generator, discriminator, init_step, loader, total_iter=600000, start_iter=0):
+    step = init_step  # 1 = 8x8, 2 = 16x16, 3 = 32x32, 4 = 64x64, 5 = 128x128, 6 = 256x256
+    data_loader = sample_data(loader, 4 * 2 ** step)
+    dataset = iter(data_loader)
+
+    total_iter_remain = total_iter - (total_iter // 6) * (step - 1)
+
+    pbar = tqdm(range(total_iter_remain))
+
+    disc_loss_val = 0
+    gen_loss_val = 0
+    grad_loss_val = 0
+
+    from datetime import datetime
+    date_time = datetime.now()
+    post_fix = '%s_%s_%d_%d.txt' % (trial_name, date_time.date(), date_time.hour, date_time.minute)
+    log_folder = 'trial_%s_%s_%d_%d' % (trial_name, date_time.date(), date_time.hour, date_time.minute)
+
+    os.mkdir(log_folder)
+    os.mkdir(log_folder + '/checkpoint')
+    os.mkdir(log_folder + '/sample')
+
+    config_file_name = os.path.join(log_folder, 'train_config_' + post_fix)
+    config_file = open(config_file_name, 'w')
+    config_file.write(str(args))
+    config_file.close()
+
+    log_file_name = os.path.join(log_folder, 'train_log_' + post_fix)
+    log_file = open(log_file_name, 'w')
+    log_file.write('g,d,nll,onehot\n')
+    log_file.close()
+
+    from shutil import copy
+    copy('train.py', log_folder + '/train_%s.py' % post_fix)
+    copy('progan_modules.py', log_folder + '/model_%s.py' % post_fix)
+
+    alpha = 0
+    # one = torch.FloatTensor([1]).to(device)
+    one = torch.tensor(1, dtype=torch.float).to(device)
+    mone = one * -1
+    iteration = 0
+
+    for i in pbar:
+        discriminator.zero_grad()
+
+        alpha = min(1, (2 / (total_iter // 6)) * iteration)
+
+        if iteration > total_iter // 6:
+            alpha = 0
+            iteration = 0
+            step += 1
+
+            if step > 6:
+                alpha = 1
+                step = 6
+            data_loader = sample_data(loader, 4 * 2 ** step)
+            dataset = iter(data_loader)
+
+        try:
+            real_image, label = next(dataset)
+        except (OSError, StopIteration):
+            dataset = iter(data_loader)
+            real_image, label = next(dataset)
+
+        iteration += 1
+
+        ### 1. train Discriminator
+        b_size = real_image.size(0)
+        real_image = real_image.to(device)
+        label = label.to(device)
+        real_predict = discriminator(real_image, step=step, alpha=alpha)
+        real_predict = real_predict.mean() - 0.001 * (real_predict ** 2).mean()
+        real_predict.backward(mone)
+
+        # sample input data: vector for Generator
+        gen_z = torch.randn(b_size, input_code_size).to(device)
+
+        fake_image = generator(gen_z, step=step, alpha=alpha)
+        fake_predict = discriminator(fake_image.detach(), step=step, alpha=alpha)
+        fake_predict = fake_predict.mean()
+        fake_predict.backward(one)
+
+        ### gradient penalty for D
+        eps = torch.rand(b_size, 1, 1, 1).to(device)
+        x_hat = eps * real_image.data + (1 - eps) * fake_image.detach().data
+        x_hat.requires_grad = True
+        hat_predict = discriminator(x_hat, step=step, alpha=alpha)
+        grad_x_hat = grad(outputs=hat_predict.sum(), inputs=x_hat, create_graph=True)[0]
+        grad_penalty = ((grad_x_hat.view(grad_x_hat.size(0), -1).norm(2, dim=1) - 1) ** 2).mean()
+        grad_penalty = 10 * grad_penalty
+        grad_penalty.backward()
+        grad_loss_val += grad_penalty.item()
+        disc_loss_val += (real_predict - fake_predict).item()
+
+        d_optimizer.step()
+
+        ### 2. train Generator
+        if (i + 1) % n_critic == 0:
+            generator.zero_grad()
+            discriminator.zero_grad()
+
+            predict = discriminator(fake_image, step=step, alpha=alpha)
+
+            loss = -predict.mean()
+            gen_loss_val += loss.item()
+
+            loss.backward()
+            g_optimizer.step()
+            accumulate(g_running, generator)
+
+        if (i + 1) % 1000 == 0 or i == 0:
+            with torch.no_grad():
+                images = g_running(torch.randn(5 * 10, input_code_size).to(device), step=step, alpha=alpha).data.cpu()
+
+                utils.save_image(
+                    images,
+                    f'{log_folder}/sample/{str(i + 1).zfill(6)}.png',
+                    nrow=10,
+                    normalize=True)
+
+        if (i + 1) % 10000 == 0 or i == 0:
+            try:
+                torch.save(g_running.state_dict(), f'{log_folder}/checkpoint/{str(i + 1).zfill(6)}_g.model')
+                torch.save(discriminator.state_dict(), f'{log_folder}/checkpoint/{str(i + 1).zfill(6)}_d.model')
+                torch.save(g_optimizer.state_dict(), os.path.join(log_folder, 'checkpoint', f'{str(i + 1).zfill(6)}_g_optim.pth'))
+                torch.save(d_optimizer.state_dict(), os.path.join(log_folder, 'checkpoint', f'{str(i + 1).zfill(6)}_d_optim.pth'))
+            except:
+                pass
+
+        if (i + 1) % 500 == 0:
+            state_msg = (f'{i + 1}; G: {gen_loss_val/(500//n_critic):.3f}; D: {disc_loss_val/500:.3f};'
+                         f' Grad: {grad_loss_val/500:.3f}; Alpha: {alpha:.3f}')
+
+            log_file = open(log_file_name, 'a+')
+            new_line = "%.5f,%.5f\n" % (gen_loss_val/(500//n_critic), disc_loss_val/500)
+            log_file.write(new_line)
+            log_file.close()
+
+            disc_loss_val = 0
+            gen_loss_val = 0
+            grad_loss_val = 0
+
+            print(state_msg)
+            # pbar.set_description(state_msg)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Progressive GAN: during training the model learns to generate images at a low resolution, then progressively grows to higher resolutions')
+
+    parser.add_argument('--start_iter', type=int, default=0, help='starting iteration of the training run')
+    parser.add_argument('--checkpoint', type=str, default="/content/model/", help='path to a model checkpoint directory; if the files are missing, training starts from scratch')
+    parser.add_argument('--path', type=str, default="/content/merged_dataset/Acne", help='path of the dataset; should be a folder that has one or many sub image folders inside')
+    parser.add_argument('--trial_name', type=str, default="test1", help='a brief description of the training trial')
+    parser.add_argument('--gpu_id', type=int, default=0, help='0 is the first gpu, 1 is the second gpu, etc.')
+    parser.add_argument('--lr', type=float, default=0.001, help='learning rate; the default of 1e-3 usually does not need changing, but you can try a larger value such as 2e-3')
+    parser.add_argument('--z_dim', type=int, default=128, help="the initial latent vector's dimension; can be smaller, such as 64, if the dataset is not diverse")
+    parser.add_argument('--channel', type=int, default=128, help='determines how big the model is; a smaller value means faster training but less model capacity')
+    parser.add_argument('--batch_size', type=int, default=4, help='how many images to train together at one iteration')
+    parser.add_argument('--n_critic', type=int, default=1, help='how many times to train D for each G update')
+    parser.add_argument('--init_step', type=int, default=1, help='start from what resolution: 1 means 8x8, 2 means 16x16, ..., 6 means 256x256')
+    parser.add_argument('--total_iter', type=int, default=300000, help='how many iterations to train in total; the value assumes init step is 1')
+    parser.add_argument('--pixel_norm', default=False, action="store_true", help='a normalization method inside the model; try it or not depending on the dataset')
+    parser.add_argument('--tanh', default=False, action="store_true", help='an output non-linearity on the output of the Generator; try it or not depending on the dataset')
+
+    args = parser.parse_args()
+
+    trial_name = args.trial_name
+    device = torch.device("cuda:%d" % (args.gpu_id))
+    input_code_size = args.z_dim
+    batch_size = args.batch_size
+    n_critic = args.n_critic
+
+    generator = Generator(in_channel=args.channel, input_code_dim=input_code_size, pixel_norm=args.pixel_norm, tanh=args.tanh).to(device)
+    discriminator = Discriminator(feat_dim=args.channel).to(device)
+    g_running = Generator(in_channel=args.channel, input_code_dim=input_code_size, pixel_norm=args.pixel_norm, tanh=args.tanh).to(device)
+
+    ## you can directly load a pretrained model here
+    if args.checkpoint:
+        generator_path = os.path.join(args.checkpoint, "g.model")
+        discriminator_path = os.path.join(args.checkpoint, "d.model")
+
+        if os.path.exists(generator_path) and os.path.exists(discriminator_path):
+            print(f"Loading checkpoints from {args.checkpoint}...")
+            generator.load_state_dict(torch.load(generator_path))
+            g_running.load_state_dict(torch.load(generator_path))
+            discriminator.load_state_dict(torch.load(discriminator_path))
+        else:
+            print(f"Warning: Checkpoint not found at {args.checkpoint}. Training from scratch!")
+    else:
+        print("No checkpoint provided, training from scratch.")
+
+    g_running.train(False)
+
+    g_optimizer = optim.Adam(generator.parameters(), lr=args.lr, betas=(0.0, 0.99))
+    d_optimizer = optim.Adam(discriminator.parameters(), lr=args.lr, betas=(0.0, 0.99))
+
+    optimizer_g_path = os.path.join(args.checkpoint, "g_optim.pth")
+    optimizer_d_path = os.path.join(args.checkpoint, "d_optim.pth")
+
+    if os.path.exists(optimizer_g_path) and os.path.exists(optimizer_d_path):
+        g_optimizer.load_state_dict(torch.load(optimizer_g_path))
+        d_optimizer.load_state_dict(torch.load(optimizer_d_path))
+        print("Optimizers loaded successfully!")
+    else:
+        print("Warning: Optimizer checkpoint not found. Using new optimizers!")
+    accumulate(g_running, generator, 0)
+
+    loader = imagefolder_loader(args.path)
+
+    train(generator, discriminator, args.init_step, loader, args.total_iter, args.start_iter)
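To make the schedule in train() concrete, a short sketch (values illustrative, assuming the default --total_iter of 300000) of how step and alpha evolve:

# Illustrative re-derivation of train()'s growth schedule: each resolution
# stage lasts total_iter // 6 iterations, and alpha = min(1, (2/stage)*iteration)
# ramps from 0 to 1 over the first half of each stage.
total_iter = 300000
stage = total_iter // 6                  # 50,000 iterations per stage
for step in range(1, 7):
    res = 4 * 2 ** step
    print(f"step {step}: {res}x{res}, alpha reaches 1.0 after {stage // 2} iterations")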