littlelittlecloud committed on
Commit
e3192e0
1 Parent(s): db2ecf2

refactor and add DDPM and DDIMSampler

Browse files
Files changed (6) hide show
  1. .gitattributes +2 -0
  2. DDIMSampler.cs +34 -0
  3. DDPM.cs +48 -0
  4. Program.cs +11 -29
  5. cat.png +2 -2
  6. ddim_v_sampler.ckpt +2 -2
.gitattributes CHANGED
@@ -22,6 +22,7 @@
22
  *.pt filter=lfs diff=lfs merge=lfs -text
23
  *.pth filter=lfs diff=lfs merge=lfs -text
24
  *.rar filter=lfs diff=lfs merge=lfs -text
 
25
  *.safetensors filter=lfs diff=lfs merge=lfs -text
26
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
  *.tar.* filter=lfs diff=lfs merge=lfs -text
@@ -32,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
22
  *.pt filter=lfs diff=lfs merge=lfs -text
23
  *.pth filter=lfs diff=lfs merge=lfs -text
24
  *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.png filter=lfs diff=lfs merge=lfs -text
26
  *.safetensors filter=lfs diff=lfs merge=lfs -text
27
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
28
  *.tar.* filter=lfs diff=lfs merge=lfs -text
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ cat.png filter=lfs diff=lfs merge=lfs -text
DDIMSampler.cs ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ using TorchSharp;
2
+
3
+ public class DDIMSampler
4
+ {
5
+ private readonly DDPM _model;
6
+ private const int TIME_STEPS = 1000;
7
+ private readonly torch.Device _device;
8
+
9
+ public DDIMSampler(DDPM model, float scale = 9.0f)
10
+ {
11
+ _model = model;
12
+ _device = model.Device;
13
+ }
14
+
15
+ public torch.Tensor Sample(torch.Tensor img, torch.Tensor condition, torch.Tensor unconditional_condition, int steps = 50, float scale = 9.0f)
16
+ {
17
+ var gap = DDIMSampler.TIME_STEPS / steps;
18
+ using(var context = torch.enable_grad(false))
19
+ {
20
+ for(var i = DDIMSampler.TIME_STEPS-1; i >=0; i -= gap)
21
+ {
22
+ var t_cur = torch.full(1, i, dtype: torch.ScalarType.Int64, device: _device);
23
+ var t_prev = torch.full(1, i - gap >= 0? i - gap: 0, dtype: torch.ScalarType.Int64, device: _device);
24
+ (var e_t_uncond, var e_t) = _model.DiffusionModel(img, condition, unconditional_condition, t_cur);
25
+ var model_output = e_t_uncond + scale * (e_t - e_t_uncond);
26
+ e_t = _model.PredictEPSFromZANDV(img, t_cur, model_output);
27
+ var pred_x0 = _model.PredictStartFromZANDV(img, t_cur, model_output);
28
+ img = _model.QSample(pred_x0, t_prev, e_t);
29
+ }
30
+
31
+ return img;
32
+ }
33
+ }
34
+ }
DDPM.cs ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ using TorchSharp;
2
+
3
+ public class DDPM
4
+ {
5
+ private readonly torch.jit.ScriptModule _model;
6
+ public torch.Device Device {get;}
7
+ public DDPM(string modelPath, torch.Device device)
8
+ {
9
+ _model = TorchSharp.torch.jit.load(modelPath);
10
+ Device = device;
11
+ _model.to(Device);
12
+ _model.eval();
13
+ }
14
+
15
+ public (torch.Tensor e_T_Uncondition, torch.Tensor e_T) DiffusionModel(torch.Tensor img, torch.Tensor condition, torch.Tensor unconditional_condition, torch.Tensor t)
16
+ {
17
+ var x_in = torch.cat(new[] { img, img });
18
+ var condition_in = torch.cat(new[] { unconditional_condition, condition });
19
+ var t_in = torch.cat(new[] { t, t });
20
+ var res = _model.invoke<torch.Tensor>("diffusion_model", x_in, t_in, condition_in).chunk(2);
21
+ return (res[0], res[1]);
22
+ }
23
+
24
+ public torch.Tensor DecodeImage(torch.Tensor img)
25
+ {
26
+ return _model.invoke<torch.Tensor>("decode_image", img);
27
+ }
28
+
29
+ public torch.Tensor ClipEncoder(torch.Tensor tokenTensor)
30
+ {
31
+ return _model.invoke<torch.Tensor>("clip_encoder", tokenTensor);
32
+ }
33
+
34
+ public torch.Tensor QSample(torch.Tensor z, torch.Tensor t, torch.Tensor v)
35
+ {
36
+ return _model.invoke<torch.Tensor>("q_sample",z, t, v);
37
+ }
38
+
39
+ public torch.Tensor PredictEPSFromZANDV(torch.Tensor z, torch.Tensor t, torch.Tensor v)
40
+ {
41
+ return _model.invoke<torch.Tensor>("predict_eps_from_z_and_v", z, t, v);
42
+ }
43
+
44
+ public torch.Tensor PredictStartFromZANDV(torch.Tensor z, torch.Tensor t, torch.Tensor v)
45
+ {
46
+ return _model.invoke<torch.Tensor>("predict_start_from_z_and_v", z, t, v);
47
+ }
48
+ }
Program.cs CHANGED
@@ -3,11 +3,11 @@ using System.Collections.Generic;
3
  using System.IO;
4
  using System.Linq;
5
  using TorchSharp;
 
6
  torchvision.io.DefaultImager = new torchvision.io.SkiaImager();
7
  var device = TorchSharp.torch.device("cuda:0");
8
- var ddpm_v_sampler = TorchSharp.torch.jit.load("ddim_v_sampler.ckpt");
9
- ddpm_v_sampler.to(device);
10
- ddpm_v_sampler.eval();
11
 
12
  var start_token = 49406;
13
  var end_token = 49407;
@@ -34,30 +34,12 @@ var unconditional_tokenTensor = torch.tensor(uncontional_tokens.ToArray(), dtype
34
  unconditional_tokenTensor = unconditional_tokenTensor.reshape((long)batch, -1);
35
  var img = torch.randn(batch, 4, 96, 96, dtype: torch.ScalarType.Float32, device: device);
36
  var t = torch.ones(batch, dtype: torch.ScalarType.Int32, device: device);
37
- var condition = ddpm_v_sampler.invoke("clip_encoder", tokenTensor);
38
- var unconditional_condition = ddpm_v_sampler.invoke("clip_encoder", unconditional_tokenTensor);
39
- Console.WriteLine(condition);
40
- var timesteps = 1000;
41
  var ddim_steps = 50;
42
- int gap = timesteps / ddim_steps;
43
- using(var context = torch.enable_grad(false))
44
- {
45
- for(var i = timesteps-1; i >=0; i -= gap)
46
- {
47
- var t_cur = torch.full(batch, i, dtype: torch.ScalarType.Int64, device: device);
48
- var t_prev = torch.full(batch, i - gap >= 0? i - gap: 0, dtype: torch.ScalarType.Int64, device: device);
49
- img = (torch.Tensor)ddpm_v_sampler.invoke("ddim_sampler", img, condition, unconditional_condition, t_cur, t_prev);
50
- Console.WriteLine($"step {i}");
51
- }
52
-
53
- var decoded_images = (torch.Tensor)ddpm_v_sampler.invoke("decode_image", img);
54
- decoded_images = torch.clamp((decoded_images + 1.0) / 2.0, 0.0, 1.0);
55
-
56
- for(int i = 0; i!= batch; ++i)
57
- {
58
- // c * h * w
59
- var image = decoded_images[i];
60
- image = (image * 255.0).to(torch.ScalarType.Byte).cpu();
61
- torchvision.io.write_image(image, $"{i}.png", torchvision.ImageFormat.Png);
62
- }
63
- }
 
3
  using System.IO;
4
  using System.Linq;
5
  using TorchSharp;
6
+
7
  torchvision.io.DefaultImager = new torchvision.io.SkiaImager();
8
  var device = TorchSharp.torch.device("cuda:0");
9
+ var ddpm = new DDPM("ddim_v_sampler.ckpt", device);
10
+ var ddimSampler = new DDIMSampler(ddpm);
 
11
 
12
  var start_token = 49406;
13
  var end_token = 49407;
 
34
  unconditional_tokenTensor = unconditional_tokenTensor.reshape((long)batch, -1);
35
  var img = torch.randn(batch, 4, 96, 96, dtype: torch.ScalarType.Float32, device: device);
36
  var t = torch.ones(batch, dtype: torch.ScalarType.Int32, device: device);
37
+ var condition = ddpm.ClipEncoder(tokenTensor);
38
+ var unconditional_condition = ddpm.ClipEncoder(unconditional_tokenTensor);
 
 
39
  var ddim_steps = 50;
40
+ img = ddimSampler.Sample(img, condition, unconditional_condition, ddim_steps);
41
+ var decoded_images = (torch.Tensor)ddpm.DecodeImage(img);
42
+ decoded_images = torch.clamp((decoded_images + 1.0) / 2.0, 0.0, 1.0);
43
+ var image = decoded_images[0];
44
+ image = (image * 255.0).to(torch.ScalarType.Byte).cpu();
45
+ torchvision.io.write_image(image, $"0.png", torchvision.ImageFormat.Png);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cat.png CHANGED

Git LFS Details

  • SHA256: 8e7a7b2e0ca4d8ea9a8cc490fc944e2ab628d261513d1115357976eda76f489c
  • Pointer size: 132 Bytes
  • Size of remote file: 1.76 MB

Git LFS Details

  • SHA256: db0eec932ec5f5e907f3a6addfb10b85e45ce94a4af72faddb01abf860144cac
  • Pointer size: 132 Bytes
  • Size of remote file: 1.09 MB
ddim_v_sampler.ckpt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:d32dd77b16c5ab4b584f6827037039c6d484caa6d11e541b7f6c8b93bb30c8cc
3
- size 5216916628
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:22b16b2fc18c3b20c0eb74ed49a8f1834388fbfd84a49110340943f22fd30fa1
3
+ size 5216915007