Why is the custom_st.py dict error visible only in sentence-transformers and not in plain transformers?
I tried loading the JINA v3 model on my 3090 node. When I used the transformers library to load the model and generate embeddings, it worked fine despite my Python version being 3.8.10. But when I try loading the model with the sentence-transformers library, it throws an error. The solution is trivial (I have to modify one line in the custom_st.py file to make it compatible with my Python version).
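For reference, my guess is that the incompatible line uses built-in generic annotations such as dict[str, torch.Tensor], which only work on Python 3.9+. This is a hypothetical illustration of the kind of change involved, not the exact line from custom_st.py:

```python
# Hypothetical illustration of the one-line change; the actual line in
# custom_st.py may differ. On Python 3.8, subscripting the built-in dict
# in an annotation raises "TypeError: 'type' object is not subscriptable"
# at import time, because built-in generics need Python 3.9+.
from typing import Dict

import torch

# before (Python 3.9+ only):
# def tokenize(self, texts: list) -> dict[str, torch.Tensor]: ...

# after (also works on Python 3.8):
def tokenize(self, texts: list) -> Dict[str, torch.Tensor]:
    ...
```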
But I am puzzled: why does this error happen only in sentence-transformers?
Hi @sleeping4cat, it's because custom_st.py is only used by sentence-transformers and not by transformers. Feel free to open a PR changing that one line in custom_st.py, it would be appreciated!
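To illustrate, the two load paths look roughly like this; only the second one imports custom_st.py (sentence-transformers discovers it through the modules.json in the repo), so an incompatibility in that file surfaces only there:

```python
# transformers path: uses only the modeling code referenced in config.json;
# custom_st.py is never imported, so its Python 3.8 issue never triggers
from transformers import AutoModel

model = AutoModel.from_pretrained("jinaai/jina-embeddings-v3", trust_remote_code=True)

# sentence-transformers path: modules.json points at custom_st.py, which is
# imported at load time, so the error appears here
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("jinaai/jina-embeddings-v3", trust_remote_code=True)
```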
@jupyterjazz thanks! I don't think it would be a good idea to change the code, since hardly anyone uses Python 3.8 anymore. I installed Python 3.9 and now I receive a PyTorch error instead. I am not sure how to fix it, because it feels like I would need to load the model with plain PyTorch and patch it there. I didn't investigate too much, but I am posting the traceback below:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[3], line 9
5 if __name__ == '__main__':
6
7 # model = SentenceTransformer('all-MiniLM-L6-v2')
8 model = SentenceTransformer("jinaai/jina-embeddings-v3", trust_remote_code=True)
----> 9 pool = model.start_multi_process_pool()
10 embed = model.encode_multi_process(sentences, pool=pool, batch_size=32, show_progress_bar=True)
11 print('Embeddings computed. Shape:', embed.shape)
File /mnt/raid/gpu/lib/python3.9/site-packages/sentence_transformers/SentenceTransformer.py:851, in SentenceTransformer.start_multi_process_pool(self, target_devices)
845 for device_id in target_devices:
846 p = ctx.Process(
847 target=SentenceTransformer._encode_multi_process_worker,
848 args=(device_id, self, input_queue, output_queue),
849 daemon=True,
850 )
--> 851 p.start()
852 processes.append(p)
854 return {"input": input_queue, "output": output_queue, "processes": processes}
File /usr/lib/python3.9/multiprocessing/process.py:121, in BaseProcess.start(self)
118 assert not _current_process._config.get('daemon'), \
119 'daemonic processes are not allowed to have children'
120 _cleanup()
--> 121 self._popen = self._Popen(self)
122 self._sentinel = self._popen.sentinel
123 # Avoid a refcycle if the target function holds an indirect
124 # reference to the process object (see bpo-30775)
File /usr/lib/python3.9/multiprocessing/context.py:284, in SpawnProcess._Popen(process_obj)
281 @staticmethod
282 def _Popen(process_obj):
283 from .popen_spawn_posix import Popen
--> 284 return Popen(process_obj)
File /usr/lib/python3.9/multiprocessing/popen_spawn_posix.py:32, in Popen.__init__(self, process_obj)
30 def __init__(self, process_obj):
31 self._fds = []
---> 32 super().__init__(process_obj)
File /usr/lib/python3.9/multiprocessing/popen_fork.py:19, in Popen.__init__(self, process_obj)
17 self.returncode = None
18 self.finalizer = None
---> 19 self._launch(process_obj)
File /usr/lib/python3.9/multiprocessing/popen_spawn_posix.py:47, in Popen._launch(self, process_obj)
45 try:
46 reduction.dump(prep_data, fp)
---> 47 reduction.dump(process_obj, fp)
48 finally:
49 set_spawning_popen(None)
File /usr/lib/python3.9/multiprocessing/reduction.py:60, in dump(obj, file, protocol)
58 def dump(obj, file, protocol=None):
59 '''Replacement for pickle.dump() using ForkingPickler.'''
---> 60 ForkingPickler(file, protocol).dump(obj)
File /mnt/raid/gpu/lib/python3.9/site-packages/torch/nn/utils/parametrize.py:340, in _inject_new_class.<locals>.getstate(self)
339 def getstate(self):
--> 340 raise RuntimeError(
341 "Serialization of parametrized modules is only "
342 "supported through state_dict(). See:\n"
343 "https://pytorch.org/tutorials/beginner/saving_loading_models.html"
344 "#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training"
345 )
RuntimeError: Serialization of parametrized modules is only supported through state_dict(). See:
https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training
I received this error when I loaded the model using SentenceTransformers.
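The limitation itself is easy to reproduce outside sentence-transformers: any module with a registered torch parametrization refuses to be pickled. A minimal repro, unrelated to the Jina code:

```python
import pickle

import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

# Any parametrization triggers the same __getstate__ guard seen in the
# traceback above; Identity is just the simplest one.
layer = nn.Linear(4, 4)
parametrize.register_parametrization(layer, "weight", nn.Identity())

try:
    pickle.dumps(layer)
except RuntimeError as e:
    print(e)  # "Serialization of parametrized modules is only supported ..."
```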
@jupyterjazz I have identified the issue. SentenceTransformers loads the model through the custom_st.py file. Since that code is not part of the main model file, multiprocessing causes a serialization error when it tries to pickle the model for each worker. I tried loading the model without the multi_process function and it worked fine. But I want the multi_process function to work, because I am going to run a big job on a supercluster where I'll use (4x A100s) x 16, and I can use all the performance I can get. I tried ColBERTv2 and it works with multi-process on GPUs.
If you guys can help resolve this for the Jina-v3 embedding model, it'll mean a lot.
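In the meantime, the workaround I am experimenting with is to spawn one process per GPU and load a fresh copy of the model inside each worker, so nothing parametrized ever goes through the pickler. This is not the library's encode_multi_process; the sharding and queue handling below are my own sketch and may need tuning:

```python
import math

import numpy as np
import torch
import torch.multiprocessing as mp
from sentence_transformers import SentenceTransformer

MODEL_NAME = "jinaai/jina-embeddings-v3"

def worker(rank, shard, queue):
    # Each subprocess loads its own copy of the model, so the parametrized
    # modules are created locally instead of being pickled across processes.
    model = SentenceTransformer(MODEL_NAME, trust_remote_code=True,
                                device=f"cuda:{rank}")
    queue.put((rank, model.encode(shard, batch_size=32)))

if __name__ == "__main__":
    sentences = ["..."]  # replace with the real corpus
    n = torch.cuda.device_count()
    chunk = math.ceil(len(sentences) / n)
    shards = [sentences[i * chunk:(i + 1) * chunk] for i in range(n)]

    ctx = mp.get_context("spawn")
    queue = ctx.Queue()
    procs = [ctx.Process(target=worker, args=(r, shards[r], queue))
             for r in range(n)]
    for p in procs:
        p.start()
    results = dict(queue.get() for _ in range(n))  # rank -> embeddings
    for p in procs:
        p.join()

    embeddings = np.concatenate([results[r] for r in range(n)])
    print("Embeddings computed. Shape:", embeddings.shape)
```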