Fliw committed
Commit 78598be · 1 Parent(s): 341513e

chore(inference): add gradio model
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ model/*.model filter=lfs diff=lfs merge=lfs -text
+ model/*.pth filter=lfs diff=lfs merge=lfs -text
JupyterNotebook.ipynb ADDED
The diff for this file is too large to render.
 
app.py ADDED
@@ -0,0 +1,47 @@
+ import torch, os, gradio as gr, numpy as np
+ from torchvision import utils, transforms
+ from progan_modules import Generator
+
+ CHECKPOINT_DIR = "./model"
+ Z_DIM, CHANNEL_SIZE = 128, 128
+ DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ FIXED_STEP = 6
+ FIXED_ALPHA = 0.0  # alpha < 1 blends in the upscaled previous-resolution RGB (see Generator.output)
+
+ # keyword args: Generator's signature is (input_code_dim, in_channel); positional (CHANNEL_SIZE, Z_DIM) only worked because both are 128
+ g_running = Generator(input_code_dim=Z_DIM, in_channel=CHANNEL_SIZE, pixel_norm=False, tanh=False).to(DEVICE)
+ g_running.load_state_dict(torch.load(os.path.join(CHECKPOINT_DIR, "g.model"), map_location=DEVICE))
+ g_running.eval()
+
+ to_pil = transforms.ToPILImage()
+
+ @torch.inference_mode()
+ def sample_images(n_images: int = 50, seed: int | None = None):
+     n_images = int(n_images)  # Gradio sliders may deliver floats
+     if seed is not None and seed >= 0:
+         torch.manual_seed(seed)
+         np.random.seed(seed)
+     else:
+         torch.seed()
+
+     z = torch.randn(n_images, Z_DIM, device=DEVICE)
+     imgs = g_running(z, step=FIXED_STEP, alpha=FIXED_ALPHA).cpu()
+
+     grid = utils.make_grid(imgs, nrow=10, normalize=True, value_range=(-1, 1))
+     return to_pil(grid)
+
+ demo = gr.Interface(
+     fn=sample_images,
+     inputs=[
+         gr.Slider(10, 200, value=50, step=10, label="Number of Images (multiples of 10)"),  # minimum raised to 10 to match the step
+         gr.Number(value=-1, precision=0, label="Seed (-1 = random)"),
+     ],
+     outputs=gr.Image(type="pil", label="Result Grid"),
+     title="Progressive Growing Generative Adversarial Network",
+     description="example PGGAN implementation for an acne dataset",
+     allow_flagging="never",
+ )
+
+ if __name__ == "__main__":
+     demo.queue()
+     demo.launch(show_api=False, share=True)
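
A quick sanity check outside Gradio (a minimal sketch, assuming progan_modules.py and model/g.model are available locally; with alpha=1.0 at step 6 the Generator emits 256x256 RGB):

import torch
from progan_modules import Generator

g = Generator(input_code_dim=128, in_channel=128, pixel_norm=False, tanh=False)
g.load_state_dict(torch.load("model/g.model", map_location="cpu"))
g.eval()
with torch.no_grad():
    out = g(torch.randn(2, 128), step=6, alpha=1.0)
print(out.shape)  # expected: torch.Size([2, 3, 256, 256])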
model/d.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a1ff8ae7d55d9126ccf99e1177d9e63f6884ab9404dc9501ff62b5d5752628cd
+ size 6396418
model/d_optim.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3165c039f10cd0d541e88b3dfbf6dfe2a13e0aca6e7f771d6eaa1f610f451b04
+ size 12640648
model/g.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e14508dc41768a63d4a5021547215212ba708df99efe53cee3bbeefbe54e188b
+ size 6422598
model/g_optim.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:92601d2b20eaade8f3f94aea3077019fdafecd59799095ff13ed5236a9308c54
+ size 12694058
progan_modules.py ADDED
@@ -0,0 +1,251 @@
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+
+ from math import sqrt
+
+
+ class EqualLR:
+     def __init__(self, name):
+         self.name = name
+
+     def compute_weight(self, module):
+         weight = getattr(module, self.name + '_orig')
+         fan_in = weight.data.size(1) * weight.data[0][0].numel()
+
+         return weight * sqrt(2 / fan_in)
+
+     @staticmethod
+     def apply(module, name):
+         fn = EqualLR(name)
+
+         weight = getattr(module, name)
+         del module._parameters[name]
+         module.register_parameter(name + '_orig', nn.Parameter(weight.data))
+         module.register_forward_pre_hook(fn)
+
+         return fn
+
+     def __call__(self, module, input):
+         weight = self.compute_weight(module)
+         setattr(module, self.name, weight)
+
+
+ def equal_lr(module, name='weight'):
+     EqualLR.apply(module, name)
+
+     return module
+
+
+ class PixelNorm(nn.Module):
+     def __init__(self):
+         super().__init__()
+
+     def forward(self, input):
+         return input / torch.sqrt(torch.mean(input ** 2, dim=1, keepdim=True)
+                                   + 1e-8)
+
+
+ class EqualConv2d(nn.Module):
+     def __init__(self, *args, **kwargs):
+         super().__init__()
+
+         conv = nn.Conv2d(*args, **kwargs)
+         conv.weight.data.normal_()
+         conv.bias.data.zero_()
+         self.conv = equal_lr(conv)
+
+     def forward(self, input):
+         return self.conv(input)
+
+
+ class EqualConvTranspose2d(nn.Module):
+     ### additional module for OOGAN usage
+     def __init__(self, *args, **kwargs):
+         super().__init__()
+
+         conv = nn.ConvTranspose2d(*args, **kwargs)
+         conv.weight.data.normal_()
+         conv.bias.data.zero_()
+         self.conv = equal_lr(conv)
+
+     def forward(self, input):
+         return self.conv(input)
+
+ class EqualLinear(nn.Module):
+     def __init__(self, in_dim, out_dim):
+         super().__init__()
+
+         linear = nn.Linear(in_dim, out_dim)
+         linear.weight.data.normal_()
+         linear.bias.data.zero_()
+
+         self.linear = equal_lr(linear)
+
+     def forward(self, input):
+         return self.linear(input)
+
+
+ class ConvBlock(nn.Module):
+     def __init__(self, in_channel, out_channel, kernel_size, padding, kernel_size2=None, padding2=None, pixel_norm=True):
+         super().__init__()
+
+         pad1 = padding
+         pad2 = padding
+         if padding2 is not None:
+             pad2 = padding2
+
+         kernel1 = kernel_size
+         kernel2 = kernel_size
+         if kernel_size2 is not None:
+             kernel2 = kernel_size2
+
+         convs = [EqualConv2d(in_channel, out_channel, kernel1, padding=pad1)]
+         if pixel_norm:
+             convs.append(PixelNorm())
+         convs.append(nn.LeakyReLU(0.1))
+         convs.append(EqualConv2d(out_channel, out_channel, kernel2, padding=pad2))
+         if pixel_norm:
+             convs.append(PixelNorm())
+         convs.append(nn.LeakyReLU(0.1))
+
+         self.conv = nn.Sequential(*convs)
+
+     def forward(self, input):
+         out = self.conv(input)
+         return out
+
+
+ def upscale(feat):
+     return F.interpolate(feat, scale_factor=2, mode='bilinear', align_corners=False)
+
+ class Generator(nn.Module):
+     def __init__(self, input_code_dim=128, in_channel=128, pixel_norm=True, tanh=True):
+         super().__init__()
+         self.input_dim = input_code_dim
+         self.tanh = tanh
+         self.input_layer = nn.Sequential(
+             EqualConvTranspose2d(input_code_dim, in_channel, 4, 1, 0),
+             PixelNorm(),
+             nn.LeakyReLU(0.1))
+
+         self.progression_4 = ConvBlock(in_channel, in_channel, 3, 1, pixel_norm=pixel_norm)
+         self.progression_8 = ConvBlock(in_channel, in_channel, 3, 1, pixel_norm=pixel_norm)
+         self.progression_16 = ConvBlock(in_channel, in_channel, 3, 1, pixel_norm=pixel_norm)
+         self.progression_32 = ConvBlock(in_channel, in_channel, 3, 1, pixel_norm=pixel_norm)
+         self.progression_64 = ConvBlock(in_channel, in_channel//2, 3, 1, pixel_norm=pixel_norm)
+         self.progression_128 = ConvBlock(in_channel//2, in_channel//4, 3, 1, pixel_norm=pixel_norm)
+         self.progression_256 = ConvBlock(in_channel//4, in_channel//4, 3, 1, pixel_norm=pixel_norm)
+
+         self.to_rgb_8 = EqualConv2d(in_channel, 3, 1)
+         self.to_rgb_16 = EqualConv2d(in_channel, 3, 1)
+         self.to_rgb_32 = EqualConv2d(in_channel, 3, 1)
+         self.to_rgb_64 = EqualConv2d(in_channel//2, 3, 1)
+         self.to_rgb_128 = EqualConv2d(in_channel//4, 3, 1)
+         self.to_rgb_256 = EqualConv2d(in_channel//4, 3, 1)
+
+         self.max_step = 6
+
+     def progress(self, feat, module):
+         out = F.interpolate(feat, scale_factor=2, mode='bilinear', align_corners=False)
+         out = module(out)
+         return out
+
+     def output(self, feat1, feat2, module1, module2, alpha):
+         if 0 <= alpha < 1:
+             skip_rgb = upscale(module1(feat1))
+             out = (1 - alpha) * skip_rgb + alpha * module2(feat2)
+         else:
+             out = module2(feat2)
+         if self.tanh:
+             return torch.tanh(out)
+         return out
+
+     def forward(self, input, step=0, alpha=-1):
+         # step selects the output resolution (1 -> 8x8, ..., 6 -> 256x256); step < 1 returns None
+         if step > self.max_step:
+             step = self.max_step
+
+         out_4 = self.input_layer(input.view(-1, self.input_dim, 1, 1))
+         out_4 = self.progression_4(out_4)
+         out_8 = self.progress(out_4, self.progression_8)
+         if step == 1:
+             if self.tanh:
+                 return torch.tanh(self.to_rgb_8(out_8))
+             return self.to_rgb_8(out_8)
+
+         out_16 = self.progress(out_8, self.progression_16)
+         if step == 2:
+             return self.output(out_8, out_16, self.to_rgb_8, self.to_rgb_16, alpha)
+
+         out_32 = self.progress(out_16, self.progression_32)
+         if step == 3:
+             return self.output(out_16, out_32, self.to_rgb_16, self.to_rgb_32, alpha)
+
+         out_64 = self.progress(out_32, self.progression_64)
+         if step == 4:
+             return self.output(out_32, out_64, self.to_rgb_32, self.to_rgb_64, alpha)
+
+         out_128 = self.progress(out_64, self.progression_128)
+         if step == 5:
+             return self.output(out_64, out_128, self.to_rgb_64, self.to_rgb_128, alpha)
+
+         out_256 = self.progress(out_128, self.progression_256)
+         if step == 6:
+             return self.output(out_128, out_256, self.to_rgb_128, self.to_rgb_256, alpha)
+
+
+ class Discriminator(nn.Module):
+     def __init__(self, feat_dim=128):
+         super().__init__()
+
+         self.progression = nn.ModuleList([ConvBlock(feat_dim//4, feat_dim//4, 3, 1),
+                                           ConvBlock(feat_dim//4, feat_dim//2, 3, 1),
+                                           ConvBlock(feat_dim//2, feat_dim, 3, 1),
+                                           ConvBlock(feat_dim, feat_dim, 3, 1),
+                                           ConvBlock(feat_dim, feat_dim, 3, 1),
+                                           ConvBlock(feat_dim, feat_dim, 3, 1),
+                                           ConvBlock(feat_dim + 1, feat_dim, 3, 1, 4, 0)])
+
+         self.from_rgb = nn.ModuleList([EqualConv2d(3, feat_dim//4, 1),
+                                        EqualConv2d(3, feat_dim//4, 1),
+                                        EqualConv2d(3, feat_dim//2, 1),
+                                        EqualConv2d(3, feat_dim, 1),
+                                        EqualConv2d(3, feat_dim, 1),
+                                        EqualConv2d(3, feat_dim, 1),
+                                        EqualConv2d(3, feat_dim, 1)])
+
+         self.n_layer = len(self.progression)
+
+         self.linear = EqualLinear(feat_dim, 1)
+
+     def forward(self, input, step=0, alpha=-1):
+         for i in range(step, -1, -1):
+             index = self.n_layer - i - 1
+
+             if i == step:
+                 out = self.from_rgb[index](input)
+
+             if i == 0:
+                 # minibatch standard deviation, appended as an extra feature map
+                 out_std = torch.sqrt(out.var(0, unbiased=False) + 1e-8)
+                 mean_std = out_std.mean()
+                 mean_std = mean_std.expand(out.size(0), 1, 4, 4)
+                 out = torch.cat([out, mean_std], 1)
+
+             out = self.progression[index](out)
+
+             if i > 0:
+                 # out = F.avg_pool2d(out, 2)
+                 out = F.interpolate(out, scale_factor=0.5, mode='bilinear', align_corners=False)
+
+                 if i == step and 0 <= alpha < 1:
+                     # skip_rgb = F.avg_pool2d(input, 2)
+                     skip_rgb = F.interpolate(input, scale_factor=0.5, mode='bilinear', align_corners=False)
+                     skip_rgb = self.from_rgb[index + 1](skip_rgb)
+                     out = (1 - alpha) * skip_rgb + alpha * out
+
+         out = out.squeeze(2).squeeze(2)
+         out = self.linear(out)
+
+         return out
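
The step/alpha arguments drive Generator and Discriminator symmetrically; a minimal sketch (hypothetical feat_dim=64 to keep it light) confirms that the output of G at each step is a valid input for D at the same step:

import torch
from progan_modules import Generator, Discriminator

G = Generator(input_code_dim=64, in_channel=64)
D = Discriminator(feat_dim=64)
z = torch.randn(4, 64)
with torch.no_grad():
    for step in range(1, 7):
        img = G(z, step=step, alpha=0.5)      # resolution is 4 * 2 ** step
        score = D(img, step=step, alpha=0.5)  # one critic score per image
        print(step, tuple(img.shape), tuple(score.shape))
# e.g. step 6 prints: 6 (4, 3, 256, 256) (4, 1)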
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ torch==2.0.0
+ torchvision==0.15.0
+ numpy==1.23.4
+ Pillow==9.2.0
+ tqdm==4.64.1
+ pydantic==2.10.6
+ gradio==3.32.0
+ gradio_client>=0.13.0
train.py ADDED
@@ -0,0 +1,265 @@
+ from tqdm import tqdm
+ import numpy as np
+ from PIL import Image
+ import argparse
+ import random
+ import torch
+ import torch.nn.functional as F
+ import os
+ from torch import nn, optim
+ from torch.autograd import grad
+ from torch.utils.data import DataLoader
+ from torchvision import datasets, transforms, utils
+
+ from progan_modules import Generator, Discriminator
+
+
+ def accumulate(model1, model2, decay=0.999):
+     par1 = dict(model1.named_parameters())
+     par2 = dict(model2.named_parameters())
+
+     for k in par1.keys():
+         par1[k].data.mul_(decay).add_(par2[k].data, alpha=(1 - decay))
+
+
+ def imagefolder_loader(path):
+     def loader(transform):
+         data = datasets.ImageFolder(path, transform=transform)
+         data_loader = DataLoader(data, shuffle=True, batch_size=batch_size, num_workers=2)
+         return data_loader
+     return loader
+
+
+ def sample_data(dataloader, image_size=4):
+     transform = transforms.Compose([
+         transforms.Resize(image_size + int(image_size * 0.2) + 1),
+         transforms.RandomCrop(image_size),
+         transforms.RandomHorizontalFlip(),
+         transforms.ToTensor(),
+         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+     ])
+
+     loader = dataloader(transform)
+
+     return loader
+
+
+ def train(generator, discriminator, init_step, loader, total_iter=600000, start_iter=0):  # start_iter is currently unused
+     step = init_step  # resolution per step: 1 = 8, 2 = 16, 3 = 32, 4 = 64, 5 = 128, 6 = 256
+     data_loader = sample_data(loader, 4 * 2 ** step)
+     dataset = iter(data_loader)
+
+     #total_iter = 600000
+     total_iter_remain = total_iter - (total_iter//6)*(step-1)
+
+     pbar = tqdm(range(total_iter_remain))
+
+     disc_loss_val = 0
+     gen_loss_val = 0
+     grad_loss_val = 0
+
+     from datetime import datetime
+     date_time = datetime.now()
+     post_fix = '%s_%s_%d_%d.txt' % (trial_name, date_time.date(), date_time.hour, date_time.minute)
+     log_folder = 'trial_%s_%s_%d_%d' % (trial_name, date_time.date(), date_time.hour, date_time.minute)
+
+     os.mkdir(log_folder)
+     os.mkdir(log_folder + '/checkpoint')
+     os.mkdir(log_folder + '/sample')
+
+     config_file_name = os.path.join(log_folder, 'train_config_' + post_fix)
+     config_file = open(config_file_name, 'w')
+     config_file.write(str(args))
+     config_file.close()
+
+     log_file_name = os.path.join(log_folder, 'train_log_' + post_fix)
+     log_file = open(log_file_name, 'w')
+     log_file.write('g,d\n')  # two columns are written below: generator and discriminator loss
+     log_file.close()
+
+     from shutil import copy
+     copy('train.py', log_folder + '/train_%s.py' % post_fix)
+     copy('progan_modules.py', log_folder + '/model_%s.py' % post_fix)
+
+     alpha = 0
+     #one = torch.FloatTensor([1]).to(device)
+     one = torch.tensor(1, dtype=torch.float).to(device)
+     mone = one * -1
+     iteration = 0
+
+     for i in pbar:
+         discriminator.zero_grad()
+
+         alpha = min(1, (2 / (total_iter//6)) * iteration)
+
+         if iteration > total_iter//6:
+             alpha = 0
+             iteration = 0
+             step += 1
+
+             if step > 6:
+                 alpha = 1
+                 step = 6
+             data_loader = sample_data(loader, 4 * 2 ** step)
+             dataset = iter(data_loader)
+
+         try:
+             real_image, label = next(dataset)
+
+         except (OSError, StopIteration):
+             dataset = iter(data_loader)
+             real_image, label = next(dataset)
+
+         iteration += 1
+
+         ### 1. train Discriminator
+         b_size = real_image.size(0)
+         real_image = real_image.to(device)
+         label = label.to(device)
+         real_predict = discriminator(
+             real_image, step=step, alpha=alpha)
+         real_predict = real_predict.mean() \
+             - 0.001 * (real_predict ** 2).mean()
+         real_predict.backward(mone)
+
+         # sample input data: vector for Generator
+         gen_z = torch.randn(b_size, input_code_size).to(device)
+
+         fake_image = generator(gen_z, step=step, alpha=alpha)
+         fake_predict = discriminator(
+             fake_image.detach(), step=step, alpha=alpha)
+         fake_predict = fake_predict.mean()
+         fake_predict.backward(one)
+
+         ### gradient penalty for D
+         eps = torch.rand(b_size, 1, 1, 1).to(device)
+         x_hat = eps * real_image.data + (1 - eps) * fake_image.detach().data
+         x_hat.requires_grad = True
+         hat_predict = discriminator(x_hat, step=step, alpha=alpha)
+         grad_x_hat = grad(
+             outputs=hat_predict.sum(), inputs=x_hat, create_graph=True)[0]
+         grad_penalty = ((grad_x_hat.view(grad_x_hat.size(0), -1)
+                          .norm(2, dim=1) - 1)**2).mean()
+         grad_penalty = 10 * grad_penalty
+         grad_penalty.backward()
+         grad_loss_val += grad_penalty.item()
+         disc_loss_val += (real_predict - fake_predict).item()
+
+         d_optimizer.step()
+
+         ### 2. train Generator
+         if (i + 1) % n_critic == 0:
+             generator.zero_grad()
+             discriminator.zero_grad()
+
+             predict = discriminator(fake_image, step=step, alpha=alpha)
+
+             loss = -predict.mean()
+             gen_loss_val += loss.item()
+
+
+             loss.backward()
+             g_optimizer.step()
+             accumulate(g_running, generator)
+
+         if (i + 1) % 1000 == 0 or i == 0:
+             with torch.no_grad():
+                 images = g_running(torch.randn(5 * 10, input_code_size).to(device), step=step, alpha=alpha).data.cpu()
+
+                 utils.save_image(
+                     images,
+                     f'{log_folder}/sample/{str(i + 1).zfill(6)}.png',
+                     nrow=10,
+                     normalize=True)
+
+         if (i + 1) % 10000 == 0 or i == 0:
+             try:
+                 torch.save(g_running.state_dict(), f'{log_folder}/checkpoint/{str(i + 1).zfill(6)}_g.model')
+                 torch.save(discriminator.state_dict(), f'{log_folder}/checkpoint/{str(i + 1).zfill(6)}_d.model')
+                 torch.save(g_optimizer.state_dict(), os.path.join(log_folder, 'checkpoint', f'{str(i + 1).zfill(6)}_g_optim.pth'))
+                 torch.save(d_optimizer.state_dict(), os.path.join(log_folder, 'checkpoint', f'{str(i + 1).zfill(6)}_d_optim.pth'))
+             except Exception:  # a failed checkpoint save should not kill training
+                 pass
+
+         if (i + 1) % 500 == 0:
+             state_msg = (f'{i + 1}; G: {gen_loss_val/(500//n_critic):.3f}; D: {disc_loss_val/500:.3f};'
+                          f' Grad: {grad_loss_val/500:.3f}; Alpha: {alpha:.3f}')
+
+             log_file = open(log_file_name, 'a+')
+             new_line = "%.5f,%.5f\n" % (gen_loss_val/(500//n_critic), disc_loss_val/500)
+             log_file.write(new_line)
+             log_file.close()
+
+             disc_loss_val = 0
+             gen_loss_val = 0
+             grad_loss_val = 0
+
+             print(state_msg)
+             #pbar.set_description(state_msg)
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser(description='Progressive GAN: during training, the model learns to generate images at a low resolution, then progressively increases to higher resolutions')
+
+     parser.add_argument('--start_iter', type=int, default=0, help='starting iteration of training')
+     parser.add_argument('--checkpoint', type=str, default="/content/model/", help='path to a model checkpoint directory (pass an empty string to train from scratch)')
+     parser.add_argument('--path', type=str, default="/content/merged_dataset/Acne", help='path to the dataset; should be a folder containing one or more image subfolders')
+     parser.add_argument('--trial_name', type=str, default="test1", help='a brief description of the training trial')
+     parser.add_argument('--gpu_id', type=int, default=0, help='0 is the first gpu, 1 is the second gpu, etc.')
+     parser.add_argument('--lr', type=float, default=0.001, help='learning rate; the default of 1e-3 usually does not need changing, but you can try a larger value such as 2e-3')
+     parser.add_argument('--z_dim', type=int, default=128, help='dimension of the initial latent vector; can be smaller, such as 64, if the dataset is not diverse')
+     parser.add_argument('--channel', type=int, default=128, help='determines how big the model is; a smaller value means faster training but less model capacity')
+     parser.add_argument('--batch_size', type=int, default=4, help='how many images to train on per iteration')
+     parser.add_argument('--n_critic', type=int, default=1, help='how many times to train D for each G update')
+     parser.add_argument('--init_step', type=int, default=1, help='starting resolution: 1 means 8x8, 2 means 16x16, ..., 6 means 256x256')
+     parser.add_argument('--total_iter', type=int, default=300000, help='how many iterations to train in total; the value assumes init_step is 1')
+     parser.add_argument('--pixel_norm', default=False, action="store_true", help='enable pixel normalization inside the model; whether it helps depends on the dataset')
+     parser.add_argument('--tanh', default=False, action="store_true", help='apply a tanh non-linearity to the Generator output; whether it helps depends on the dataset')
+
+     args = parser.parse_args()
+
+     trial_name = args.trial_name
+     device = torch.device("cuda:%d" % (args.gpu_id))
+     input_code_size = args.z_dim
+     batch_size = args.batch_size
+     n_critic = args.n_critic
+
+     generator = Generator(in_channel=args.channel, input_code_dim=input_code_size, pixel_norm=args.pixel_norm, tanh=args.tanh).to(device)
+     discriminator = Discriminator(feat_dim=args.channel).to(device)
+     g_running = Generator(in_channel=args.channel, input_code_dim=input_code_size, pixel_norm=args.pixel_norm, tanh=args.tanh).to(device)
+
+     ## you can directly load a pretrained model here
+     if args.checkpoint:
+         generator_path = os.path.join(args.checkpoint, "g.model")
+         discriminator_path = os.path.join(args.checkpoint, "d.model")
+
+         if os.path.exists(generator_path) and os.path.exists(discriminator_path):
+             print(f"Loading checkpoints from {args.checkpoint}...")
+             generator.load_state_dict(torch.load(generator_path))
+             g_running.load_state_dict(torch.load(generator_path))
+             discriminator.load_state_dict(torch.load(discriminator_path))
+         else:
+             print(f"Warning: Checkpoint not found at {args.checkpoint}. Training from scratch!")
+     else:
+         print("No checkpoint provided, training from scratch.")
+
+     g_running.train(False)
+
+     g_optimizer = optim.Adam(generator.parameters(), lr=args.lr, betas=(0.0, 0.99))
+     d_optimizer = optim.Adam(discriminator.parameters(), lr=args.lr, betas=(0.0, 0.99))
+
+     optimizer_g_path = os.path.join(args.checkpoint, "g_optim.pth")
+     optimizer_d_path = os.path.join(args.checkpoint, "d_optim.pth")
+
+     if os.path.exists(optimizer_g_path) and os.path.exists(optimizer_d_path):
+         g_optimizer.load_state_dict(torch.load(optimizer_g_path))
+         d_optimizer.load_state_dict(torch.load(optimizer_d_path))
+         print("Optimizers loaded successfully!")
+     else:
+         print("Warning: Optimizer checkpoint not found. Using new optimizers!")
+
+     accumulate(g_running, generator, 0)
+
+     loader = imagefolder_loader(args.path)
+
+     train(generator, discriminator, args.init_step, loader, args.total_iter, args.start_iter)
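
The progressive schedule is implicit in how train() updates alpha and step; a minimal sketch of the same arithmetic, using the default total_iter=300000, shows how resolution and fade-in evolve:

# Each of the 6 stages lasts total_iter // 6 iterations; within a stage,
# alpha = min(1, (2 / stage_len) * iteration), so the fade-in finishes
# halfway through the stage, after which the resolution doubles.
total_iter = 300000
stage_len = total_iter // 6
for step in range(1, 7):
    resolution = 4 * 2 ** step
    print(f"step {step}: {resolution}x{resolution}, alpha reaches 1.0 after {stage_len // 2} iterations")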