File size: 1,423 Bytes
7779efa e3192e0 7779efa e3192e0 7779efa e3192e0 7779efa e3192e0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
using System;
using TorchSharp;
/// <summary>
/// Iterative sampler that walks a <see cref="DDPM"/> from timestep <c>TimeSteps - 1</c> down towards 0
/// in evenly spaced jumps, applying classifier-free guidance at every step.
/// </summary>
public class DDIMSampler
{
    /// <summary>Length of the noise schedule; sampled timesteps lie in [0, TimeSteps).</summary>
    private const int TimeSteps = 1000;

    private readonly DDPM _model;

    // Device on which the per-step timestep tensors are allocated; copied from the model.
    private readonly torch.Device _device;

    /// <summary>Creates a sampler bound to <paramref name="model"/>.</summary>
    /// <param name="model">
    /// Model whose <c>DiffusionModel</c>, <c>PredictEPSFromZANDV</c>, <c>PredictStartFromZANDV</c>
    /// and <c>QSample</c> members drive each sampling step.
    /// </param>
    /// <param name="scale">
    /// Not used. Kept only for source compatibility; pass the guidance scale to
    /// <see cref="Sample"/> instead.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="model"/> is null.</exception>
    public DDIMSampler(DDPM model, float scale = 9.0f)
    {
        _model = model ?? throw new ArgumentNullException(nameof(model));
        _device = model.Device;
    }

    /// <summary>
    /// Runs the guided sampling loop starting from <paramref name="img"/> and returns the final latent.
    /// </summary>
    /// <param name="img">
    /// Starting latent. Its first dimension is treated as the batch size. This tensor is not
    /// modified or disposed by this method.
    /// </param>
    /// <param name="condition">Conditioning input, forwarded unchanged to <c>DiffusionModel</c>.</param>
    /// <param name="unconditional_condition">Unconditional input, forwarded unchanged to <c>DiffusionModel</c>.</param>
    /// <param name="steps">
    /// Requested number of denoising steps, in [1, 1000]. The timestep stride is
    /// <c>1000 / steps</c> (integer division), so when 1000 is not divisible by
    /// <paramref name="steps"/> one extra step may run (e.g. 30 requested steps run 31 times).
    /// </param>
    /// <param name="scale">
    /// Classifier-free guidance weight: the model output is <c>uncond + scale * (cond - uncond)</c>.
    /// </param>
    /// <returns>A new tensor holding the latent produced by the last step.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="img"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <paramref name="steps"/> is below 1 (previously a divide-by-zero or an endless loop) or above
    /// 1000 (previously a zero stride, i.e. an endless loop).
    /// </exception>
    public torch.Tensor Sample(torch.Tensor img, torch.Tensor condition, torch.Tensor unconditional_condition, int steps = 50, float scale = 9.0f)
    {
        if (img is null)
        {
            throw new ArgumentNullException(nameof(img));
        }
        if (steps < 1 || steps > TimeSteps)
        {
            throw new ArgumentOutOfRangeException(nameof(steps), steps, $"steps must be between 1 and {TimeSteps}.");
        }

        var gap = TimeSteps / steps;
        var batch = img.shape[0];
        var current = img;

        // Inference only: no autograd graph is recorded for the sampling loop.
        using (torch.enable_grad(false))
        {
            for (var t = TimeSteps - 1; t >= 0; t -= gap)
            {
                torch.Tensor next;

                // Every tensor created in this iteration (timestep tensors, model outputs, x0
                // estimate) is released when the scope ends; only the new latent is moved out.
                // NOTE(review): assumes DDPM does not lazily cache tensors it creates during these
                // calls — confirm, or such caches would be disposed here.
                using (torch.NewDisposeScope())
                {
                    var tCur = torch.full(batch, t, dtype: torch.ScalarType.Int64, device: _device);
                    var tPrev = torch.full(batch, Math.Max(t - gap, 0), dtype: torch.ScalarType.Int64, device: _device);

                    var (epsUncond, epsCond) = _model.DiffusionModel(current, condition, unconditional_condition, tCur);
                    var modelOutput = epsUncond + scale * (epsCond - epsUncond);

                    var eps = _model.PredictEPSFromZANDV(current, tCur, modelOutput);
                    var predX0 = _model.PredictStartFromZANDV(current, tCur, modelOutput);

                    next = _model.QSample(predX0, tPrev, eps).MoveToOuterDisposeScope();
                }

                // Free the superseded intermediate latent, but never the caller's input tensor.
                if (!ReferenceEquals(current, img) && !ReferenceEquals(current, next))
                {
                    current.Dispose();
                }
                current = next;
            }
        }

        return current;
    }
}