Jax model weights

#16
by bmazoure - opened

When running this chunk of code:

from transformers import AutoTokenizer, FlaxGemmaModel

model = FlaxGemmaModel.from_pretrained("google/gemma-2b")

I get the error:

Support for sharded checkpoints using safetensors is coming soon!

which I assume means that the currently provided checkpoints do not work with JAX models?

Switch it to this:
model = FlaxGemmaModel.from_pretrained("google/gemma-2b", revision="flax")
The JAX weights are on the 'flax' branch of the repo, not on the default branch.
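For reference, here is a minimal sketch of loading the Flax weights from that branch and running a forward pass. It assumes a transformers version that ships FlaxGemmaModel with JAX/Flax installed, that you have accepted the Gemma license and logged in (e.g. `huggingface-cli login`) since the repo is gated, and that the tokenizer files are available on the default branch. The helper name `load_flax_gemma` is hypothetical, used only for illustration.

```python
def load_flax_gemma(repo_id: str = "google/gemma-2b", revision: str = "flax"):
    """Hypothetical helper: load the tokenizer and the Flax Gemma model.

    Assumes transformers with Flax support is installed and the gated repo
    is accessible to the current Hugging Face login.
    """
    # Imported inside the function so defining it does not require transformers.
    from transformers import AutoTokenizer, FlaxGemmaModel

    # Assumption: tokenizer files are on the default branch, so no revision here.
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    # The Flax weights live on the "flax" branch, selected via `revision`.
    model = FlaxGemmaModel.from_pretrained(repo_id, revision=revision)
    return tokenizer, model


if __name__ == "__main__":
    tokenizer, model = load_flax_gemma()
    # return_tensors="np" gives NumPy arrays, which Flax models accept directly.
    inputs = tokenizer("Hello, JAX!", return_tensors="np")
    outputs = model(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
    )
    # Shape is (batch, sequence_length, hidden_size).
    print(outputs.last_hidden_state.shape)
```

Passing `revision` only to the model call keeps the tokenizer on the default branch; if the tokenizer files are also present on the 'flax' branch, passing the same revision to both would work as well.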

Google org

@bmazoure Did this end up working?

Yes, this worked, thanks!

bmazoure changed discussion status to closed
