|
using System; |
|
using TorchSharp; |
|
|
|
/// <summary>
/// Iteratively denoises a latent with DDIM-style updates and classifier-free guidance,
/// querying a <see cref="DDPM"/> model at each step.
/// </summary>
public class DDIMSampler
{
    /// <summary>Length of the timestep schedule this sampler walks (timesteps 999 down to 0).</summary>
    private const int TimeSteps = 1000;

    private readonly DDPM _model;

    // Captured once at construction; the per-step timestep tensors are allocated on this device.
    private readonly torch.Device _device;

    /// <summary>
    /// Creates a sampler bound to <paramref name="model"/>.
    /// </summary>
    /// <param name="model">Diffusion model queried at every sampling step.</param>
    /// <param name="scale">
    /// Currently unused: the guidance scale is supplied per call to <see cref="Sample"/>.
    /// Kept so existing callers that pass it continue to compile.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="model"/> is null.</exception>
    public DDIMSampler(DDPM model, float scale = 9.0f)
    {
        _model = model ?? throw new ArgumentNullException(nameof(model));
        _device = model.Device;
    }

    /// <summary>
    /// Runs the sampling loop starting from <paramref name="img"/> and returns the final latent.
    /// No random noise is drawn in this method.
    /// </summary>
    /// <remarks>
    /// With gap = 1000 / <paramref name="steps"/> (integer division), the loop visits timesteps
    /// 999, 999 - gap, 999 - 2*gap, ... down to the last non-negative value. That is
    /// floor(999 / gap) + 1 model evaluations: exactly <paramref name="steps"/> when 1000 is
    /// divisible by it, possibly more otherwise. The final update targets timestep 0.
    /// </remarks>
    /// <param name="img">Starting latent; its first dimension is used as the batch size.</param>
    /// <param name="condition">Conditioning forwarded to the model.</param>
    /// <param name="unconditional_condition">Unconditional conditioning forwarded to the model.</param>
    /// <param name="steps">Requested number of denoising steps, in the range [1, 1000].</param>
    /// <param name="scale">
    /// Guidance scale: the combined output is uncond + scale * (cond - uncond), so 1 yields the
    /// model's conditional output unchanged.
    /// </param>
    /// <returns>The latent after the final update.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="img"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="steps"/> is outside [1, 1000].</exception>
    public torch.Tensor Sample(torch.Tensor img, torch.Tensor condition, torch.Tensor unconditional_condition, int steps = 50, float scale = 9.0f)
    {
        if (img is null)
        {
            throw new ArgumentNullException(nameof(img));
        }

        // steps > TimeSteps would make gap 0 (the loop would never advance), steps == 0 would
        // divide by zero, and negative steps would make the loop run backwards forever.
        if (steps < 1 || steps > TimeSteps)
        {
            throw new ArgumentOutOfRangeException(nameof(steps), steps, $"steps must be between 1 and {TimeSteps}.");
        }

        var gap = TimeSteps / steps;
        var batch = img.shape[0];

        // Inference only: disable autograd tracking for the whole loop.
        using (var noGrad = torch.enable_grad(false))
        {
            for (var t = TimeSteps - 1; t >= 0; t -= gap)
            {
                var prevT = Math.Max(t - gap, 0);

                // The timestep tensors are created and consumed within this step, so release their
                // native memory deterministically instead of waiting for finalizers.
                using (var tCur = torch.full(batch, t, dtype: torch.ScalarType.Int64, device: _device))
                using (var tPrev = torch.full(batch, prevT, dtype: torch.ScalarType.Int64, device: _device))
                {
                    var (epsUncond, epsCond) = _model.DiffusionModel(img, condition, unconditional_condition, tCur);

                    // Classifier-free guidance: push the unconditional output toward the conditional one.
                    // The intermediates are owned by this method and disposed at the end of the step.
                    using (var guidance = epsCond - epsUncond)
                    using (var scaledGuidance = scale * guidance)
                    using (var modelOutput = epsUncond + scaledGuidance)
                    {
                        // NOTE(review): helper names suggest the model output is a v-prediction that is
                        // converted to noise (eps) and a clean-sample estimate (x0) here — confirm in DDPM.
                        var eps = _model.PredictEPSFromZANDV(img, tCur, modelOutput);
                        var predX0 = _model.PredictStartFromZANDV(img, tCur, modelOutput);

                        // Re-noise the x0 estimate to the previous timestep using the predicted noise.
                        // Presumably QSample computes sqrt(a_prev) * x0 + sqrt(1 - a_prev) * eps, which
                        // makes this a deterministic (eta = 0) DDIM step — verify against DDPM.QSample.
                        img = _model.QSample(predX0, tPrev, eps);
                    }
                }
            }

            return img;
        }
    }
}