| | """Text-conditional U-Net for diffusion.""" |
| | import torch |
| | import torch.nn as nn |
| | import math |
| | import config |
| |
|
| |
|
class TimeEmbedding(nn.Module):
    """Sinusoidal timestep embedding.

    Maps a batch of scalar timesteps to ``dim``-dimensional vectors built
    from sine and cosine waves at geometrically spaced frequencies, as in
    the original transformer positional encoding.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        """Embed timesteps.

        Args:
            t: [B] tensor of timesteps.

        Returns:
            [B, dim] embedding — sin features followed by cos features.
        """
        half = self.dim // 2
        # Frequencies decay geometrically from 1 down to 1/10000.
        scale = -math.log(10000) / (half - 1)
        freqs = torch.exp(scale * torch.arange(half, device=t.device))
        angles = t.unsqueeze(1) * freqs.unsqueeze(0)
        return torch.cat((torch.sin(angles), torch.cos(angles)), dim=1)
| |
|
| |
|
class ResBlock(nn.Module):
    """Pre-activation residual block with additive time/text conditioning.

    Each conditioning embedding is projected to ``out_ch`` channels and
    added to the feature map between the two 3x3 convolutions. A 1x1
    convolution aligns the skip connection whenever the channel count
    changes; otherwise the input passes through unchanged.
    """

    def __init__(self, in_ch, out_ch, time_dim, text_dim=None):
        super().__init__()
        # Projections for the conditioning vectors (text is optional).
        self.time_mlp = nn.Linear(time_dim, out_ch)
        self.text_mlp = nn.Linear(text_dim, out_ch) if text_dim else None

        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        # At most 8 groups; fewer when the channel count is smaller.
        self.norm1 = nn.GroupNorm(min(8, in_ch), in_ch)
        self.norm2 = nn.GroupNorm(min(8, out_ch), out_ch)
        self.act = nn.SiLU()

        self.skip = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    def forward(self, x, t_emb, text_emb=None):
        """Apply the block.

        Args:
            x: [B, in_ch, H, W] feature map.
            t_emb: [B, time_dim] time embedding.
            text_emb: optional [B, text_dim] text embedding; ignored when
                the block was built without text conditioning.

        Returns:
            [B, out_ch, H, W] feature map.
        """
        out = self.conv1(self.act(self.norm1(x)))

        # Broadcast the [B, out_ch] projections across spatial dims.
        out = out + self.time_mlp(t_emb).unsqueeze(-1).unsqueeze(-1)
        if self.text_mlp is not None and text_emb is not None:
            out = out + self.text_mlp(text_emb).unsqueeze(-1).unsqueeze(-1)

        out = self.conv2(self.act(self.norm2(out)))
        return self.skip(x) + out
| |
|
| |
|
class TextConditionedUNet(nn.Module):
    """Encoder–decoder U-Net denoiser conditioned on CLIP text embeddings.

    Three ResBlock stages halve the spatial resolution with max-pooling,
    a bottleneck block runs at 1/8 resolution, and three decoder stages
    nearest-upsample and concatenate the matching encoder skip before
    refining. Time and text conditioning are injected into every block.
    """

    def __init__(self, text_dim=512):
        super().__init__()
        self.text_dim = text_dim

        # Sinusoidal timestep embedding followed by a small MLP.
        self.time_emb = TimeEmbedding(config.TIME_DIM)
        self.time_mlp = nn.Sequential(
            nn.Linear(config.TIME_DIM, config.TIME_DIM),
            nn.SiLU(),
            nn.Linear(config.TIME_DIM, config.TIME_DIM),
        )

        # Project raw CLIP embeddings into the conditioning space.
        self.text_proj = nn.Sequential(
            nn.Linear(text_dim, text_dim),
            nn.SiLU(),
            nn.Linear(text_dim, text_dim),
        )

        base = config.CHANNELS
        tdim = config.TIME_DIM

        # Encoder: 1 -> C -> 2C -> 4C channels.
        self.down1 = ResBlock(1, base, tdim, text_dim)
        self.down2 = ResBlock(base, base * 2, tdim, text_dim)
        self.down3 = ResBlock(base * 2, base * 4, tdim, text_dim)

        # Bottleneck at the coarsest resolution.
        self.mid = ResBlock(base * 4, base * 4, tdim, text_dim)

        # Decoder: each block sees upsampled features concatenated with
        # the matching encoder skip, hence the doubled input channels.
        self.up3 = ResBlock(base * 8, base * 2, tdim, text_dim)
        self.up2 = ResBlock(base * 4, base, tdim, text_dim)
        self.up1 = ResBlock(base * 2, base, tdim, text_dim)

        # Final 1x1 projection back to a single image channel.
        self.out = nn.Conv2d(base, 1, 1)

        self.pool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self, x, t, text_emb):
        """
        Args:
            x: [B, 1, H, W] noisy images (H and W must be divisible by 8
                so the pooled and upsampled feature maps line up).
            t: [B] timesteps
            text_emb: [B, text_dim] CLIP text embeddings

        Returns:
            [B, 1, H, W] predicted noise.
        """
        cond_t = self.time_mlp(self.time_emb(t))
        cond_txt = self.text_proj(text_emb)

        # Encoder path, keeping activations for the skip connections.
        skip1 = self.down1(x, cond_t, cond_txt)
        skip2 = self.down2(self.pool(skip1), cond_t, cond_txt)
        skip3 = self.down3(self.pool(skip2), cond_t, cond_txt)

        # Bottleneck.
        feat = self.mid(self.pool(skip3), cond_t, cond_txt)

        # Decoder path: upsample, concatenate the skip, refine.
        feat = self.up3(torch.cat([self.upsample(feat), skip3], dim=1), cond_t, cond_txt)
        feat = self.up2(torch.cat([self.upsample(feat), skip2], dim=1), cond_t, cond_txt)
        feat = self.up1(torch.cat([self.upsample(feat), skip1], dim=1), cond_t, cond_txt)

        return self.out(feat)
| |
|
| |
|
| | if __name__ == "__main__": |
| | |
| | print("Testing Text-Conditioned U-Net...") |
| | model = TextConditionedUNet(text_dim=512) |
| |
|
| | |
| | batch_size = 2 |
| | x = torch.randn(batch_size, 1, 64, 64) |
| | t = torch.randint(0, 1000, (batch_size,)) |
| | text_emb = torch.randn(batch_size, 512) |
| |
|
| | out = model(x, t, text_emb) |
| | print(f"Input shape: {x.shape}") |
| | print(f"Output shape: {out.shape}") |
| | print(f"✅ Model test passed!") |