pbotsaris committed on
Commit
ceb1cc0
1 Parent(s): fb92df5

added params parser to handler

Browse files
Files changed (1) hide show
  1. handler.py +37 -7
handler.py CHANGED
@@ -2,6 +2,32 @@ from typing import Dict, List, Any
2
  from transformers import AutoProcessor, MusicgenForConditionalGeneration
3
  import torch
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  class EndpointHandler:
6
  def __init__(self, path="pbotsaris/musicgen-small"):
7
  # load model and processor
@@ -25,17 +51,21 @@ class EndpointHandler:
25
  return_tensors="pt"
26
  ).to('cuda')
27
 
28
- if params is not None:
29
- with torch.cuda.amp.autocast():
30
- outputs = self.model.generate(**inputs, **params)
31
- else:
32
- with torch.cuda.amp.autocast():
33
- outputs = self.model.generate(**inputs)
34
 
35
  pred = outputs[0].cpu().numpy().tolist()
 
 
 
 
36
 
37
- return [{"audio": pred, "sr": self.model.config.sampling_rate}]
 
38
 
 
39
 
40
 
41
  if __name__ == "__main__":
 
2
  from transformers import AutoProcessor, MusicgenForConditionalGeneration
3
  import torch
4
 
5
def create_params(params, fr):
    """Build the keyword arguments passed to ``model.generate``.

    Starts from a fixed set of defaults and overlays the caller's values
    for the keys that appear in those defaults; any other key in
    ``params`` is ignored. A ``'duration'`` entry, when present, is
    converted to a token budget and takes precedence over an explicit
    ``'max_new_tokens'``.

    Args:
        params: Mapping of request parameters, or ``None`` to use the
            defaults only. Recognised keys: ``do_sample``,
            ``guidance_scale``, ``max_new_tokens`` and ``duration``.
        fr: Frame rate of the audio encoder; ``duration * fr`` gives the
            number of tokens to generate (presumably duration is in
            seconds and fr in frames per second -- confirm with caller).

    Returns:
        A new dict with the keys ``do_sample``, ``guidance_scale`` and
        ``max_new_tokens``.

    Raises:
        ValueError, TypeError: If ``duration`` cannot be interpreted as a
            number.
    """
    # Defaults; these keys also act as the whitelist of overridable options.
    out = {
        "do_sample": True,
        "guidance_scale": 3,
        "max_new_tokens": 256,
    }

    if params is None:
        return out

    # Copy over only the options we know about.
    for key, value in params.items():
        if key in out:
            out[key] = value

    # Applied last so that 'duration' wins over an explicit 'max_new_tokens'.
    if 'duration' in params:
        # The request payload may carry a fractional duration (or a numeric
        # string); a token count has to be a whole number, so convert via
        # float and round to the nearest integer instead of passing the raw
        # product through.
        out['max_new_tokens'] = round(float(params['duration']) * fr)

    return out
29
+
30
+
31
  class EndpointHandler:
32
  def __init__(self, path="pbotsaris/musicgen-small"):
33
  # load model and processor
 
51
  return_tensors="pt"
52
  ).to('cuda')
53
 
54
+ params = create_params(params, self.model.config.audio_encoder.frame_rate)
55
+
56
+ with torch.cuda.amp.autocast():
57
+ outputs = self.model.generate(**inputs, **params)
 
 
58
 
59
  pred = outputs[0].cpu().numpy().tolist()
60
+ sr = 32000
61
+
62
+ try:
63
+ sr = self.model.config.audio_encoder.sampling_rate
64
 
65
+ except:
66
+ sr = 32000
67
 
68
+ return [{"audio": pred, "sr":sr}]
69
 
70
 
71
  if __name__ == "__main__":