Why is the custom_st.py dict error visible only in sentence-transformers and not in plain transformers?

#51
by sleeping4cat - opened

I tried loading the JINA v3 model on my 3090 node. When I used the transformers library to load the model and generate embeddings, it worked fine even though I'm on Python 3.8.10. But when I try loading the model with the sentence-transformers library, it raises an error. The fix is trivial: I have to modify one line in the custom_st.py file to make it compatible with my Python version.
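For context, a minimal sketch of the two loading paths being compared, following the usage shown on the jinaai/jina-embeddings-v3 model card (the task name is one of the adapters the card lists):

from transformers import AutoModel
from sentence_transformers import SentenceTransformer

texts = ["Follow the white rabbit."]

# Path 1: plain transformers. The repo's custom modeling code provides
# encode(); custom_st.py is never imported on this path.
model = AutoModel.from_pretrained("jinaai/jina-embeddings-v3", trust_remote_code=True)
embeddings = model.encode(texts, task="text-matching")

# Path 2: sentence-transformers. This path downloads and imports
# custom_st.py from the repo, which is where the Python 3.8
# incompatibility surfaces.
st_model = SentenceTransformer("jinaai/jina-embeddings-v3", trust_remote_code=True)
st_embeddings = st_model.encode(texts)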

But I am puzzled: why does this error happen with sentence-transformers only?

Hi @sleeping4cat, it's because custom_st.py is only used by sentence-transformers, not by transformers. Feel free to open a PR changing that one line in custom_st.py; it would be appreciated!
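For readers hitting the same thing: the "dict" in the title suggests the usual Python 3.8 pitfall, where subscripting built-in types in annotations (PEP 585, Python 3.9+) fails with TypeError: 'type' object is not subscriptable. A hedged illustration of that kind of one-line fix follows; the actual offending line in custom_st.py may well be different:

from typing import Dict
import torch

# Hypothetical example, not the real custom_st.py code. On Python 3.8
# this definition raises TypeError at import time, because built-in
# generics like dict[...] only became subscriptable in Python 3.9:
#
#     def tokenize(texts) -> dict[str, torch.Tensor]: ...
#
# The 3.8-compatible spelling uses typing.Dict instead:
def tokenize(texts) -> Dict[str, torch.Tensor]:
    return {"input_ids": torch.tensor([[0]])}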

@jupyterjazz thanks! I don't think it would be a good idea to change the code, since hardly anyone uses Python 3.8 anymore. I installed Python 3.9 and then received a PyTorch error. I'm not sure how to fix it; it feels like I would need to load the model with PyTorch directly and patch it there. I didn't investigate too much, but I am posting the traceback below:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[3], line 9
      5 if __name__ == '__main__':
      6     
      7     # model = SentenceTransformer('all-MiniLM-L6-v2')
      8     model = SentenceTransformer("jinaai/jina-embeddings-v3", trust_remote_code=True)
----> 9     pool = model.start_multi_process_pool()
     10     embed = model.encode_multi_process(sentences, pool=pool, batch_size=32, show_progress_bar=True)
     11     print('Embeddings computed. Shape:', embed.shape)

File /mnt/raid/gpu/lib/python3.9/site-packages/sentence_transformers/SentenceTransformer.py:851, in SentenceTransformer.start_multi_process_pool(self, target_devices)
    845 for device_id in target_devices:
    846     p = ctx.Process(
    847         target=SentenceTransformer._encode_multi_process_worker,
    848         args=(device_id, self, input_queue, output_queue),
    849         daemon=True,
    850     )
--> 851     p.start()
    852     processes.append(p)
    854 return {"input": input_queue, "output": output_queue, "processes": processes}

File /usr/lib/python3.9/multiprocessing/process.py:121, in BaseProcess.start(self)
    118 assert not _current_process._config.get('daemon'), \
    119        'daemonic processes are not allowed to have children'
    120 _cleanup()
--> 121 self._popen = self._Popen(self)
    122 self._sentinel = self._popen.sentinel
    123 # Avoid a refcycle if the target function holds an indirect
    124 # reference to the process object (see bpo-30775)

File /usr/lib/python3.9/multiprocessing/context.py:284, in SpawnProcess._Popen(process_obj)
    281 @staticmethod
    282 def _Popen(process_obj):
    283     from .popen_spawn_posix import Popen
--> 284     return Popen(process_obj)

File /usr/lib/python3.9/multiprocessing/popen_spawn_posix.py:32, in Popen.__init__(self, process_obj)
     30 def __init__(self, process_obj):
     31     self._fds = []
---> 32     super().__init__(process_obj)

File /usr/lib/python3.9/multiprocessing/popen_fork.py:19, in Popen.__init__(self, process_obj)
     17 self.returncode = None
     18 self.finalizer = None
---> 19 self._launch(process_obj)

File /usr/lib/python3.9/multiprocessing/popen_spawn_posix.py:47, in Popen._launch(self, process_obj)
     45 try:
     46     reduction.dump(prep_data, fp)
---> 47     reduction.dump(process_obj, fp)
     48 finally:
     49     set_spawning_popen(None)

File /usr/lib/python3.9/multiprocessing/reduction.py:60, in dump(obj, file, protocol)
     58 def dump(obj, file, protocol=None):
     59     '''Replacement for pickle.dump() using ForkingPickler.'''
---> 60     ForkingPickler(file, protocol).dump(obj)

File /mnt/raid/gpu/lib/python3.9/site-packages/torch/nn/utils/parametrize.py:340, in _inject_new_class.<locals>.getstate(self)
    339 def getstate(self):
--> 340     raise RuntimeError(
    341         "Serialization of parametrized modules is only "
    342         "supported through state_dict(). See:\n"
    343         "https://pytorch.org/tutorials/beginner/saving_loading_models.html"
    344         "#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training"
    345     )

RuntimeError: Serialization of parametrized modules is only supported through state_dict(). See:
https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training

I received this error when I loaded the model using sentence-transformers.
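For what it's worth, the RuntimeError at the bottom can be reproduced without the model at all. As the traceback shows, start_multi_process_pool uses the spawn start method, which pickles the whole model into each worker process, and PyTorch deliberately refuses to pickle modules that carry parametrizations (jina-embeddings-v3 evidently contains parametrized modules; its task-specific LoRA adapters are the likely source). A minimal sketch:

import pickle
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

# Any parametrization will do; this one just scales the weight by two.
class Double(nn.Module):
    def forward(self, weight):
        return 2 * weight

layer = nn.Linear(4, 4)
parametrize.register_parametrization(layer, "weight", Double())

try:
    pickle.dumps(layer)  # what reduction.dump() in the traceback boils down to
except RuntimeError as err:
    print(err)  # "Serialization of parametrized modules is only supported through state_dict()..."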

@jupyterjazz I have identified the issue. sentence-transformers loads the model through the custom_st.py file. Since that code is not part of the main model, it causes a serialization error under multiprocessing. I tried loading the model without the multi-process function and it worked fine. But I want encode_multi_process to work, because I am going to run a big job on a supercluster where I'll use (4x A100s) x 16, and I can use all the performance I can get. I tried ColBERTv2 and it works with multi-process on GPUs.
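One workaround sketch in the meantime: skip start_multi_process_pool and spawn the workers yourself, loading a fresh copy of the model inside each worker so that nothing parametrized ever crosses a process boundary. Everything below (encode_on_gpus, _worker, the round-robin split) is illustrative, not part of sentence-transformers:

import torch.multiprocessing as mp
from sentence_transformers import SentenceTransformer

def _worker(device, chunk, queue):
    # Build the model inside the worker: no parametrized modules get pickled.
    model = SentenceTransformer("jinaai/jina-embeddings-v3",
                                trust_remote_code=True, device=device)
    queue.put((device, model.encode(chunk, batch_size=32)))

def encode_on_gpus(sentences, devices):
    ctx = mp.get_context("spawn")
    queue = ctx.Queue()
    # Round-robin split; only plain strings are sent to each worker.
    chunks = [sentences[i::len(devices)] for i in range(len(devices))]
    procs = [ctx.Process(target=_worker, args=(dev, chunk, queue))
             for dev, chunk in zip(devices, chunks)]
    for p in procs:
        p.start()
    # Drain the queue before join() to avoid deadlocks on large payloads.
    results = dict(queue.get() for _ in procs)
    for p in procs:
        p.join()
    return results

if __name__ == "__main__":
    embeddings = encode_on_gpus(["an example sentence"] * 1024, ["cuda:0", "cuda:1"])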

If you guys can help resolve this for the Jina v3 embedding model, it'll mean a lot.
