File size: 1,423 Bytes
7779efa
e3192e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7779efa
 
e3192e0
 
 
 
7779efa
 
e3192e0
 
 
 
 
7779efa
e3192e0
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
using System;
using TorchSharp;

public class DDIMSampler
{
    // Length of the model's noise schedule; sampling visits a strided subset of these timesteps.
    private const int TimeSteps = 1000;

    private readonly DDPM _model;
    private readonly torch.Device _device;

    /// <summary>
    /// Creates a DDIM-style sampler that drives the given diffusion model.
    /// </summary>
    /// <param name="model">
    /// Model providing the denoiser (<c>DiffusionModel</c>) and the schedule helpers used by
    /// <see cref="Sample"/>. Its <c>Device</c> is used for the timestep tensors.
    /// </param>
    /// <param name="scale">
    /// Currently ignored: the guidance scale is taken from the <c>scale</c> argument of
    /// <see cref="Sample"/>. Kept so existing callers continue to compile.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="model"/> is null.</exception>
    public DDIMSampler(DDPM model, float scale = 9.0f)
    {
        _model = model ?? throw new ArgumentNullException(nameof(model));
        _device = _model.Device;
    }

    /// <summary>
    /// Runs the reverse diffusion loop from timestep <c>TimeSteps - 1</c> down toward 0,
    /// in strides of <c>TimeSteps / steps</c>, applying guidance at every step.
    /// </summary>
    /// <param name="img">Starting latent/sample; its first dimension is used as the batch size. Not disposed.</param>
    /// <param name="condition">Conditioning passed through to the model's denoiser.</param>
    /// <param name="unconditional_condition">Unconditional conditioning passed through to the model's denoiser.</param>
    /// <param name="steps">
    /// Requested number of denoising steps, 1..1000. Because the stride is integer division,
    /// the loop runs <c>999 / (1000 / steps) + 1</c> iterations, which can exceed
    /// <paramref name="steps"/> when it does not divide 1000 (e.g. 30 requested gives 31).
    /// </param>
    /// <param name="scale">Guidance weight: output = uncond + scale * (cond - uncond).</param>
    /// <returns>The sample produced by the final step. The caller owns (and should dispose) it.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="img"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="steps"/> is not in 1..1000.</exception>
    public torch.Tensor Sample(torch.Tensor img, torch.Tensor condition, torch.Tensor unconditional_condition, int steps = 50, float scale = 9.0f)
    {
        if (img is null)
        {
            throw new ArgumentNullException(nameof(img));
        }

        // steps == 0 would divide by zero, a negative steps would walk t upward forever,
        // and steps > TimeSteps makes the stride 0 so the loop never terminates.
        if (steps <= 0 || steps > TimeSteps)
        {
            throw new ArgumentOutOfRangeException(nameof(steps), steps, $"steps must be in the range 1..{TimeSteps}.");
        }

        var gap = TimeSteps / steps;
        var batch = img.shape[0];
        var current = img;

        using (torch.enable_grad(false))
        {
            for (var t = TimeSteps - 1; t >= 0; t -= gap)
            {
                torch.Tensor next;

                // Every tensor created during this step (timesteps, guidance intermediates and
                // whatever the model allocates internally) is released when the scope ends;
                // only the new sample is moved out so it survives into the next step.
                using (torch.NewDisposeScope())
                {
                    var tCur = torch.full(batch, t, dtype: torch.ScalarType.Int64, device: _device);
                    var tPrev = torch.full(batch, Math.Max(t - gap, 0), dtype: torch.ScalarType.Int64, device: _device);

                    var (uncondOut, condOut) = _model.DiffusionModel(current, condition, unconditional_condition, tCur);

                    // Guidance: move the unconditional prediction toward the conditional one by `scale`.
                    var modelOutput = uncondOut + scale * (condOut - uncondOut);

                    // NOTE(review): the helper names suggest modelOutput is a v-prediction that is
                    // converted to noise (eps) and a predicted clean sample (x0) — confirm in DDPM.
                    var eps = _model.PredictEPSFromZANDV(current, tCur, modelOutput);
                    var predX0 = _model.PredictStartFromZANDV(current, tCur, modelOutput);

                    next = _model.QSample(predX0, tPrev, eps).MoveToOuterDisposeScope();
                }

                // The previous sample was produced by an earlier iteration and is no longer needed.
                // The caller's original input is left alone.
                if (!ReferenceEquals(current, img))
                {
                    current.Dispose();
                }

                current = next;
            }
        }

        return current;
    }
}