@@ -1,3 +1,4 @@
+from typing import Dict, Callable
 import torch
 from torchvision import datasets, transforms
 from torch.utils.data import DataLoader
@@ -20,7 +21,6 @@ def count_parameters(model):
     return sum(p.numel() for p in model.parameters() if p.requires_grad)
 
-# Data loading and preprocessing
 transform = transforms.Compose(
     [
         transforms.Resize(IMAGE_SIZE),
@@ -30,7 +30,6 @@ transform = transforms.Compose(
     ]
 )
 
-# Download and load CIFAR-10
 train_dataset = datasets.CIFAR10(
     root="./data", train=True, download=True, transform=transform
 )
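Note: the transform pipeline is truncated by the hunk context. Diffusion models are conventionally trained on inputs scaled to [-1, 1]; a minimal sketch of a complete pipeline, where the ToTensor/Normalize lines are an assumption and not part of this diff:

    transform = transforms.Compose(
        [
            transforms.Resize(IMAGE_SIZE),
            transforms.ToTensor(),  # uint8 HWC in [0, 255] -> float CHW in [0, 1]
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # shift/scale to [-1, 1]
        ]
    )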
@@ -40,13 +39,14 @@ train_loader = DataLoader(
 )
 
-def linear_beta_schedule(timesteps):
-    beta_start = 0.0001
-    beta_end = 0.02
-    return torch.linspace(beta_start, beta_end, timesteps)
-
-def get_diffusion_params(timesteps, device):
+def get_diffusion_params(
+    timesteps: int, device: torch.device
+) -> Dict[str, torch.Tensor]:
+    def linear_beta_schedule(timesteps):
+        beta_start = 0.0001
+        beta_end = 0.02
+        return torch.linspace(beta_start, beta_end, timesteps)
+
     betas = linear_beta_schedule(timesteps)
     alphas = 1.0 - betas
     alphas_cumprod = torch.cumprod(alphas, dim=0)
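Note: these precomputed tensors implement the closed-form DDPM forward process, x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, where alpha_bar_t is the cumulative product of the alphas. A minimal q_sample sketch built on them; the dict keys sqrt_alphas_cumprod and sqrt_one_minus_alphas_cumprod are assumed names, since the returned dict's contents are elided from the hunk:

    def q_sample(x_0, t, params, noise):
        # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
        sqrt_ac = extract(params["sqrt_alphas_cumprod"], t, x_0.shape)
        sqrt_omac = extract(params["sqrt_one_minus_alphas_cumprod"], t, x_0.shape)
        return sqrt_ac * x_0 + sqrt_omac * noise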
@@ -68,14 +68,14 @@ def get_diffusion_params(timesteps, device):
     }
 
-def extract(a, t, x_shape):
+def extract(a: torch.Tensor, t: torch.Tensor, x_shape: torch.Size):
     """Extract coefficients at specified timesteps t"""
     batch_size = t.shape[0]
     out = a.gather(-1, t)
     return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
 
-def get_loss_fn(model, params):
+def get_loss_fn(model: torch.nn.Module, params: Dict[str, torch.Tensor]) -> Callable:
     def loss_fn(x_0):
         batch_size = x_0.shape[0]
         t = torch.randint(0, TIMESTEPS, (batch_size,), device=device)
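Note: extract is a gather-and-broadcast helper. It selects one coefficient per sample at that sample's timestep and reshapes the result to [B, 1, ..., 1] so it broadcasts over image dimensions. A shape illustration:

    a = torch.linspace(0.0, 1.0, 1000)      # one coefficient per timestep, shape [1000]
    t = torch.randint(0, 1000, (8,))        # one timestep per sample, shape [8]
    coeff = extract(a, t, (8, 3, 32, 32))   # shape [8, 1, 1, 1], broadcastable over images

The body of loss_fn is elided past the timestep sampling; the standard DDPM objective at this point would noise the batch with q_sample and regress the model onto the noise, i.e. an MSE between predicted and true eps.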
@@ -95,7 +95,9 @@ def get_loss_fn(model, params):
 # Training loop template
-def train_epoch(model, optimizer, train_loader, loss_fn):
+def train_epoch(
+    model: torch.nn.Module, optimizer: torch.optim.Optimizer, train_loader: DataLoader, loss_fn: Callable
+) -> float:
     model.train()
     total_loss = 0
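Note: the rest of train_epoch is elided. A conventional completion, sketched under the assumption that loss_fn returns a scalar tensor and that the epoch's mean loss is the returned float:

    for batch, _ in train_loader:       # CIFAR-10 yields (image, label); labels are unused
        batch = batch.to(device)
        optimizer.zero_grad()
        loss = loss_fn(batch)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(train_loader)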
@@ -118,9 +120,14 @@ if __name__ == "__main__":
     model = UNet(32, TIMESTEPS).to(device)
     nb_params = count_parameters(model)
     print(f"Total number of parameters: {nb_params}")
     optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, betas=(0.9, 0.95))
     params = get_diffusion_params(TIMESTEPS, device)
     loss_fn = get_loss_fn(model, params)
+
+    # Main training loop
     for e in tqdm(range(NUM_EPOCHS)):
         train_epoch(model, optimizer, train_loader, loss_fn)
+
+    # Save model after training
     torch.save(model.state_dict(), "model.pkl")
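Note: despite the .pkl extension, this is an ordinary torch.save checkpoint of the state dict; restoring it requires rebuilding the architecture first:

    model = UNet(32, TIMESTEPS).to(device)
    model.load_state_dict(torch.load("model.pkl", map_location=device))
    model.eval()  # inference mode for sampling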