|
|
|
@ -123,7 +123,6 @@ def ddim_sample_images( |
|
|
|
t = torch.full((batch_size,), timesteps[i], device=device, dtype=torch.long) |
|
|
|
x_before = x.clone() |
|
|
|
x = ddim_sample(model, x, t, params) |
|
|
|
print(f"Step {i}, max diff: {(x - x_before).abs().max().item()}") |
|
|
|
|
|
|
|
if x.isnan().any(): |
|
|
|
raise ValueError(f"NaN detected at timestep {timesteps[i]}") |
|
|
|
|