lengyue233 commited on
Commit
1caffd8
1 Parent(s): a4dfb48

Better init event waiting

Browse files
Files changed (2) hide show
  1. app.py +0 -3
  2. tools/llama/generate.py +4 -4
app.py CHANGED
@@ -306,7 +306,6 @@ if __name__ == "__main__":
306
  args.vqgan_config_name = "vqgan_pretrain"
307
 
308
  logger.info("Loading Llama model...")
309
- init_event = threading.Event()
310
  llama_queue = launch_thread_safe_queue(
311
  config_name=args.llama_config_name,
312
  checkpoint_path=args.llama_checkpoint_path,
@@ -314,10 +313,8 @@ if __name__ == "__main__":
314
  precision=args.precision,
315
  max_length=args.max_length,
316
  compile=args.compile,
317
- init_event=init_event,
318
  )
319
  llama_tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
320
- init_event.wait()
321
  logger.info("Llama model loaded, loading VQ-GAN model...")
322
 
323
  vqgan_model = load_vqgan_model(
 
306
  args.vqgan_config_name = "vqgan_pretrain"
307
 
308
  logger.info("Loading Llama model...")
 
309
  llama_queue = launch_thread_safe_queue(
310
  config_name=args.llama_config_name,
311
  checkpoint_path=args.llama_checkpoint_path,
 
313
  precision=args.precision,
314
  max_length=args.max_length,
315
  compile=args.compile,
 
316
  )
317
  llama_tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
 
318
  logger.info("Llama model loaded, loading VQ-GAN model...")
319
 
320
  vqgan_model = load_vqgan_model(
tools/llama/generate.py CHANGED
@@ -600,6 +600,7 @@ def generate_long(
600
  yield all_codes
601
 
602
 
 
603
  def launch_thread_safe_queue(
604
  config_name,
605
  checkpoint_path,
@@ -607,17 +608,15 @@ def launch_thread_safe_queue(
607
  precision,
608
  max_length,
609
  compile=False,
610
- init_event=None,
611
  ):
612
  input_queue = queue.Queue()
 
613
 
614
  def worker():
615
  model, decode_one_token = load_model(
616
  config_name, checkpoint_path, device, precision, max_length, compile=compile
617
  )
618
-
619
- if init_event is not None:
620
- init_event.set()
621
 
622
  while True:
623
  item = input_queue.get()
@@ -641,6 +640,7 @@ def launch_thread_safe_queue(
641
  event.set()
642
 
643
  threading.Thread(target=worker, daemon=True).start()
 
644
 
645
  return input_queue
646
 
 
600
  yield all_codes
601
 
602
 
603
+
604
  def launch_thread_safe_queue(
605
  config_name,
606
  checkpoint_path,
 
608
  precision,
609
  max_length,
610
  compile=False,
 
611
  ):
612
  input_queue = queue.Queue()
613
+ init_event = threading.Event()
614
 
615
  def worker():
616
  model, decode_one_token = load_model(
617
  config_name, checkpoint_path, device, precision, max_length, compile=compile
618
  )
619
+ init_event.set()
 
 
620
 
621
  while True:
622
  item = input_queue.get()
 
640
  event.set()
641
 
642
  threading.Thread(target=worker, daemon=True).start()
643
+ init_event.wait()
644
 
645
  return input_queue
646