@LucasFromGoogle recommended that i posted about the issue on the forum:
Currently, AI studio has issues with formatting code and it doesn’t seem to be happening when using 3.1 Pro in the Gemini app, only in AI studio.
The highlighted part of my code got omitted when asking 3.1 Pro for some fixes. it then kept omitting it multiple times. I’m not sure whether this is an AI studio issue or 3.1 pro issue as i haven’t seen it happen in the Gemini app yet:
Here is the issue happening multiple times in a conversation:
right:
if len(loss_history) > 20:
smoothed = [sum(loss_history)/20 for i in range(20, len(loss_history))]
ax.plot(step_history, smoothed, color=“red”, linewidth=2, label=“Trend (Moving Avg)”)
wrong:
if len(loss_history) > 20:
smoothed =)/20 for i in range(20, len(loss_history))]
ax.plot(step_history, smoothed, color=“red”, linewidth=2, label=“Trend (Moving Avg)”)
Just for fun here’s some code that was previously error-free and working 100% that i have given to 3.1 Pro and pasted back into this place, now with 29 problems (20 errors and 9 warnings):
import json
import math
import os
import time
from pathlib import Path
from typing import Dict, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from torchvision.utils import save_image
from tqdm import tqdm
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer
from models import AutoencoderKL, VAEConfig
CONFIG: Dict = {
"image_dir": r"C:\Users\jplw0\Desktop\creative\programming projects\image generation\diffusion model 01\dataset\cats", # Testing by overfitting, there's just 1 image here
"caption_dir": r"C:\Users\jplw0\Desktop\creative\programming projects\image generation\diffusion model 01\dataset\cats_captions",
"image_size": 128,
"batch_size": 4, # Increased for stability
"num_workers": 0,
"vae_path": "vae_results/vae_model.pt",
"output_dir": "results_latent",
# Model Params
"latent_dim": 8,
"patch_size": 2,
"dim": 768,
"depth": 12,
"heads": 12,
"mlp_ratio": 4.0,
# Optimization
"lr": 1e-4, # FIXED: 4e-3 was exploding gradients
"weight_decay": 0.01,
"epochs": 2000,
"grad_clip": 1.0,
"grad_accum_steps": 8, # Accumulate to simulate larger batch size
"mixed_precision": "bf16",
# Flow Matching / CFG
"cfg_drop_prob": 0.1, # Drops caption 10% of the time for CFG
"cfg_scale": 4.0, # Guidance scale during inference
"shift_scale": 1.0, # Timestep shifting (Flux/SD3 style)
"use_ema": True,
"ema_decay": 0.999,
"sample_every_epochs": 50,
"sample_steps": 30, # 30 steps is enough for Euler with Flow Matching
"n_sample_images": 4,
"compile": False,
"seed": 1234,
}
def get_2d_sincos_pos_embed(embed_dim, grid_size):
"""Generates 2D Sine-Cosine positional embeddings."""
grid_h = np.arange(grid_size, dtype=np.float32)
grid_w = np.arange(grid_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h)
grid = np.stack(grid, axis=0)
grid = grid.reshape()
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid)
return np.concatenate(, axis=1)
def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
class RMSNorm(nn.Module):
def __init__(self, dim):
super().__init__()
self.scale = dim ** 0.5
self.g = nn.Parameter(torch.ones(dim))
def forward(self, x):
return F.normalize(x, dim=-1) * self.scale * self.g
class TimestepEmbedder(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.freq_dim = frequency_embedding_size
def forward(self, t):
half = self.freq_dim // 2
freqs = torch.exp(-math.log(10000) * torch.arange(half, device=t.device) / half)
args = t.float() * freqs
embedding = torch.cat(, dim=-1)
return self.mlp(embedding)
class SwiGLU(nn.Module):
def __init__(self, in_features, hidden_features):
super().__init__()
self.fc1 = nn.Linear(in_features, hidden_features * 2)
self.fc2 = nn.Linear(hidden_features, in_features)
def forward(self, x):
x, gate = self.fc1(x).chunk(2, dim=-1)
return self.fc2(x * F.silu(gate))
class ModernDiTBlock(nn.Module):
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0):
super().__init__()
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
self.q_norm = RMSNorm(hidden_size)
self.k_norm = RMSNorm(hidden_size)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.cross_attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
self.q_norm_cross = RMSNorm(hidden_size)
self.k_norm_cross = RMSNorm(hidden_size)
self.norm3 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.mlp = SwiGLU(hidden_size, int(hidden_size * mlp_ratio))
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
# AdaLN-Zero Initialization (Crucial for preventing muddy generations)
nn.init.constant_(self.adaLN_modulation.weight, 0)
nn.init.constant_(self.adaLN_modulation.bias, 0)
def forward(self, x, c, text_ctx):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
# Self Attention
x_norm = modulate(self.norm1(x), shift_msa, scale_msa)
q = self.q_norm(x_norm)
k = self.k_norm(x_norm)
attn_out, _ = self.attn(q, k, x_norm, need_weights=False)
x = x + gate_msa.unsqueeze(1) * attn_out
# Cross Attention (Text)
x_norm_cross = self.norm2(x)
q_cross = self.q_norm_cross(x_norm_cross)
k_cross = self.k_norm_cross(text_ctx)
cross_out, _ = self.cross_attn(q_cross, k_cross, text_ctx, need_weights=False)
x = x + cross_out
# SwiGLU MLP
x_norm_mlp = modulate(self.norm3(x), shift_mlp, scale_mlp)
x = x + gate_mlp.unsqueeze(1) * self.mlp(x_norm_mlp)
return x
class StateOfTheArtDiT(nn.Module):
def __init__(self, latent_size=16, latent_channels=8, patch_size=2, dim=768, depth=12, heads=12):
super().__init__()
self.patch_size = patch_size
self.latent_channels = latent_channels
seq_len = (latent_size // patch_size) ** 2
self.x_embedder = nn.Linear(latent_channels * patch_size**2, dim)
self.t_embedder = TimestepEmbedder(dim)
self.text_embedder = nn.Linear(768, dim)
self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, dim))
self.blocks = nn.ModuleList()
self.final_layer = nn.Sequential(
nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6),
nn.Linear(dim, latent_channels * patch_size**2, bias=True)
)
# Zero-out final layer to ensure identity function at start
nn.init.constant_(self.final_layer.weight, 0)
nn.init.constant_(self.final_layer.bias, 0)
def unpatchify(self, x, latent_size):
c = self.latent_channels
p = self.patch_size
h = w = latent_size // p
x = x.reshape(shape=(x.shape, h, w, p, p, c))
x = torch.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(shape=(x.shape, c, h * p, w * p))
return imgs
def forward(self, x, t, context):
b, c, h, w = x.shape
p = self.patch_size
# Patchify
x = x.reshape(b, c, h//p, p, w//p, p)
x = torch.einsum('nchpwq->nhwpqc', x)
x = x.reshape(b, (h//p) * (w//p), c * p**2)
x = self.x_embedder(x) + self.pos_embed
t_emb = self.t_embedder(t)
ctx_emb = self.text_embedder(context)
for block in self.blocks:
x = block(x, t_emb, ctx_emb)
# Final projection and unpatchify
shift, scale = self.t_embedder.mlp.bias.chunk(2, dim=-1) # Simplified AdaLN for final
x = self.final_layer(x)
x = x * (1 + scale) + shift
x = self.final_layer(x)
return self.unpatchify(x, h)
class EMA:
def __init__(self, model: nn.Module, decay: float):
self.decay = decay
self.shadow = {k: v.detach().clone() for k, v in model.state_dict().items()}
@torch.no_grad()
def update(self, model: nn.Module) -> None:
for k, v in model.state_dict().items():
if k in self.shadow:
self.shadow.mul_(self.decay).add_(v.detach(), alpha=1.0 - self.decay)
def copy_to(self, model: nn.Module) -> None:
model.load_state_dict(self.shadow, strict=True)
class TextImageDataset(Dataset):
def __init__(self, img_dir, cap_dir, transform):
self.img_dir = Path(img_dir)
self.cap_dir = Path(cap_dir)
self.transform = transform
self.img_files =
def __len__(self):
return len(self.img_files)
def __getitem__(self, idx):
img_name = self.img_files
img_path = self.img_dir / img_name
base_name = os.path.splitext(img_name)
cap_path = self.cap_dir / f"{base_name}.txt"
try:
image = self.transform(Image.open(img_path).convert('RGB'))
except Exception:
image = torch.zeros((3, int(CONFIG), int(CONFIG)))
caption = ""
if cap_path.exists():
with open(cap_path, 'r', encoding='utf-8') as f:
caption = f.read().strip()
return image, caption
@torch.no_grad()
def sample_cfg_euler(model, z, steps, context_cond, context_uncond, cfg_scale, device):
t_steps = torch.linspace(0.0, 1.0, steps + 1, device=device)
for i in range(steps):
t_cur = t_steps.expand(z.shape)
dt = (t_steps - t_steps).item()
# Batch conditional and unconditional passes together
z_combined = torch.cat(, dim=0)
t_combined = torch.cat(, dim=0)
ctx_combined = torch.cat(, dim=0)
v_pred = model(z_combined, t_combined, ctx_combined)
v_cond, v_uncond = v_pred.chunk(2, dim=0)
# CFG Formula
v_final = v_uncond + cfg_scale * (v_cond - v_uncond)
z = z + v_final * dt
return z
def train_diffusion() -> None:
torch.manual_seed(CONFIG)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cuda.matmul.allow_tf32 = True
out_dir = Path(CONFIG)
out_dir.mkdir(parents=True, exist_ok=True)
transform = transforms.Compose(), int(CONFIG))),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(,),
])
dataset = TextImageDataset(CONFIG, CONFIG, transform)
dataloader = DataLoader(dataset, batch_size=CONFIG, shuffle=True, drop_last=True)
print("Loading VAE & Text Encoder...")
payload = torch.load(CONFIG, map_location=device)
vae = AutoencoderKL(VAEConfig(**payload)).to(device)
vae.load_state_dict(payload, strict=True)
vae.eval().requires_grad_(False)
vae_scale = float(payload.get("latent_scale", 1.0))
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32").to(device).eval()
text_encoder.requires_grad_(False)
# Calculate latent resolution
latent_size = CONFIG // 8
model = StateOfTheArtDiT(
latent_size=latent_size,
latent_channels=CONFIG,
patch_size=CONFIG,
dim=CONFIG,
depth=CONFIG,
heads=CONFIG
).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=CONFIG, weight_decay=CONFIG)
scaler = torch.cuda.amp.GradScaler(enabled=(CONFIG == "fp16"))
ema = EMA(model, decay=CONFIG)
null_inputs = tokenizer(, padding="max_length", max_length=77, truncation=True, return_tensors="pt").to(device)
null_embeds = text_encoder(null_inputs.input_ids)
for epoch in range(CONFIG):
model.train()
pbar = tqdm(dataloader, desc=f"Epoch {epoch}")
for i, (x, captions) in enumerate(pbar):
x = x.to(device)
# 1. Extract Latents & Text Embeds
with torch.no_grad():
latents = vae.encode(x, sample=False) * vae_scale
# Apply 10% CFG Dropout
do_drop = torch.rand(1).item() < CONFIG
if do_drop:
text_embeds = null_embeds.expand(x.size(0), -1, -1)
else:
text_inputs = tokenizer(captions, padding="max_length", max_length=77, truncation=True, return_tensors="pt").to(device)
text_embeds = text_encoder(text_inputs.input_ids)
# 2. Flow Matching target & integration
noise = torch.randn_like(latents)
t = torch.rand(x.size(0), device=device)
# Shifted Timesteps (Pushes focus towards early stages to build structure)
s = CONFIG
t = (t * s) / (1 + (s - 1) * t)
t_exp = t.view(-1, 1, 1, 1)
x_t = (1.0 - t_exp) * noise + t_exp * latents
v_target = latents - noise # The Rectified Flow vector field
# 3. Model Forward
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
v_pred = model(x_t, t, text_embeds)
loss = F.mse_loss(v_pred, v_target) / CONFIG
scaler.scale(loss).backward()
# 4. Step
if (i + 1) % CONFIG == 0:
scaler.unscale_(opt)
torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG)
scaler.step(opt)
scaler.update()
opt.zero_grad(set_to_none=True)
ema.update(model)
pbar.set_postfix({"loss": loss.item() * CONFIG})
if (epoch % CONFIG == 0) or (epoch == CONFIG - 1):
eval_model = StateOfTheArtDiT(latent_size, CONFIG, CONFIG, CONFIG, CONFIG, CONFIG).to(device)
ema.copy_to(eval_model)
eval_model.eval()
with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16):
n = CONFIG
z = torch.randn(n, CONFIG, latent_size, latent_size, device=device)
test_prompts = * n
cond_in = tokenizer(test_prompts, padding="max_length", max_length=77, truncation=True, return_tensors="pt").to(device)
cond_embed = text_encoder(cond_in.input_ids)
uncond_embed = null_embeds.expand(n, -1, -1)
z_out = sample_cfg_euler(eval_model, z, CONFIG, cond_embed, uncond_embed, CONFIG, device)
decoded = vae.decode(z_out / vae_scale)
decoded = ((decoded + 1.0) * 0.5).clamp(0, 1)
save_image(decoded, out_dir / f"epoch_{epoch:04d}.png", nrow=int(math.sqrt(n)))
torch.save({"model": eval_model.state_dict(), "ema": ema.shadow}, out_dir / "sota_dit.pt")
print("Saved Checkpoint & Image.")
if __name__ == "__main__":
train_diffusion()
Unfortunately due to the restrictions to new users on here i’m not able to add multiple embeds or links, this is what i can do for now. Thank you





