---
license: cc-by-nc-sa-4.0
library_name: transformers
tags:
- chemistry
- selfies
---
# chemfie-gpt-experiment-1
This model is part of my hands-on learning and experimentation in molecule generation, aimed at determining which type of model is best suited to SELFIES (GPT-2, T5, or a fill-mask approach).
It also serves as a baseline for future ablation and customization studies in model architecture, dataset augmentation(s), and training processes.
## Model Details
- **Model Type**: GPT-2
- **Architecture**: L8, A6, H384
- **Task**: Generation of SELFIES strings
- **Language**: N/A (Chemical representation)
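As a rough illustration of what the architecture shorthand implies, the sketch below assumes `L8, A6, H384` follows the common convention of 8 layers, 6 attention heads, and a hidden size of 384. It also assumes the standard GPT-2 block layout (fused QKV projection, 4x MLP expansion). Embeddings are left out because the vocabulary size is not stated here, so the figure is a lower bound and not the model's full parameter count.

```python
# Rough size estimate for a GPT-2 style stack, assuming "L8, A6, H384" means
# 8 layers, 6 attention heads, and a hidden size of 384.
N_LAYER, N_HEAD, HIDDEN = 8, 6, 384


def head_dim(hidden: int, n_head: int) -> int:
    """Per-head dimension; the hidden size must divide evenly across heads."""
    if hidden % n_head != 0:
        raise ValueError("hidden size must be divisible by the number of heads")
    return hidden // n_head


def block_params(hidden: int) -> int:
    """Weights and biases in one standard GPT-2 block (4x MLP expansion)."""
    attention = 4 * hidden * hidden + 4 * hidden  # fused QKV + output projection
    mlp = 8 * hidden * hidden + 5 * hidden        # up- and down-projection
    layer_norms = 4 * hidden                      # two LayerNorms (scale + bias)
    return attention + mlp + layer_norms


stack_params = N_LAYER * block_params(HIDDEN)
print(f"head dimension: {head_dim(HIDDEN, N_HEAD)}")
print(f"parameters in {N_LAYER} blocks (no embeddings): {stack_params:,}")
```

Under these assumptions each head works on 64 dimensions, and the eight blocks hold roughly 14.2M parameters before embeddings are added.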
## Personal Intended Use
- Hands-on learning, research and experimentation in molecular generation
- Baseline for ablation studies and comparisons with more advanced models
## Usage
### Direct Use
Since this model doesn't use a proper GPT-2 format tokenizer, the special tokens still need to be set up manually (the next experiment will use a proper one):
```python
from transformers import PreTrainedTokenizerFast, AutoModelForCausalLM
import torch

# Load the tokenizer from a local file and register the special tokens manually
tokenizer = PreTrainedTokenizerFast(
    tokenizer_file="gpt2_tokenizer.json",
    model_max_length=512,
    unk_token="<unk>",
    pad_token="<pad>",
    eos_token="</s>",
    bos_token="<s>",
    mask_token="<mask>",
)
model = AutoModelForCausalLM.from_pretrained("gbyuvd/chemfie-gpt-experiment-1")


# Generate some sample outputs
def generate_molecules(model, tokenizer, num_samples=5, max_length=100):
    model.eval()
    generated = []
    for _ in range(num_samples):
        # Start each sample from the BOS token alone
        input_ids = torch.tensor([[tokenizer.bos_token_id]]).to(model.device)
        output = model.generate(
            input_ids,
            max_length=max_length,
            num_return_sequences=1,
            do_sample=True,
        )
        generated.append(tokenizer.decode(output[0], skip_special_tokens=True))
    return generated


sample_molecules = generate_molecules(model, tokenizer)
print("Sample generated molecules:")
for i, mol in enumerate(sample_molecules, 1):
    print(f"{i}. {mol}")

"""
....
2. [C] [Branch1] [C] [Branch1] [C] [C] [=N] [C] [Branch1] [C] [=N] [Branch1] [C] [N] [Branch1] [C] [C]
3. [C] [Branch1] [C] [Branch1] [C] [C] [=N] [C] [Branch1] [C] [=N] [Branch1] [C] [N] [=C] [Ring1] [N]
4. [C] [Branch1] [C] [Branch1] [C] [C] [=N] [C] [Branch1] [C] [=N]
5. [C] [Branch1] [C] [Branch1] [C] [C] [=N] [C] [Branch1] [C] [=N] [Branch1] [C] [N] [Branch1] [C]
"""
```
**Tokenized SELFIES to SMILES:**
```python
import selfies as sf

test = "[C] [Branch1] [O] [=C] [C] [C] [C] [C] [C] [C] [C] [=Branch1] [=O] [O] [=C] [C] [C] [C] [Ring1]"
test = test.replace(' ', '')  # remove the spaces between tokens before decoding
print(sf.decoder(test))

"""
C(CCCCCCCCO)=CCC=C
"""
```
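Because the model emits space-separated bracketed tokens, it can help to check the shape of a generated string before handing it to the decoder. The following is a minimal sketch using only the standard library. The helper names `split_tokens` and `to_selfies` are hypothetical, and the check only verifies the bracket shape of each token, not whether it belongs to the SELFIES alphabet.

```python
import re

# One SELFIES symbol is a bracketed token such as [C], [=N] or [Branch1].
TOKEN_PATTERN = re.compile(r"\[[^\[\]\s]+\]")


def split_tokens(generated: str) -> list[str]:
    """Split a space-separated model output into bracketed tokens."""
    tokens = generated.split()
    bad = [t for t in tokens if not TOKEN_PATTERN.fullmatch(t)]
    if bad:
        raise ValueError(f"malformed tokens: {bad}")
    return tokens


def to_selfies(generated: str) -> str:
    """Join the tokens without spaces, ready to be passed to a SELFIES decoder."""
    return "".join(split_tokens(generated))


sample = "[C] [Branch1] [C] [Branch1] [C] [C] [=N]"
print(len(split_tokens(sample)), to_selfies(sample))
```

The joined string can then be passed to `sf.decoder` as in the example above.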
#### Generate with Different Temperature(s) and Visualization
```python
import torch
import selfies as sf
from rdkit import Chem
from rdkit.Chem import Draw
import matplotlib.pyplot as plt

# Reuses `tokenizer` and `model` from the Direct Use example above.


def generate_molecules(temperature, num_molecules=2):
    # Start from the BOS token alone
    inputs = torch.tensor([[tokenizer.bos_token_id]])
    gen = model.generate(
        inputs,
        do_sample=True,
        max_length=256,
        temperature=temperature,
        early_stopping=True,
        pad_token_id=tokenizer.pad_token_id,
        num_beams=5,
        num_return_sequences=num_molecules,
    )
    return tokenizer.batch_decode(gen, skip_special_tokens=True)


def selfies_to_smiles(selfies_str):
    selfies_str = selfies_str.replace(' ', '')
    try:
        return sf.decoder(selfies_str)
    except Exception:
        # Treat any decoding failure as an invalid SELFIES string
        return None


def visualize_molecules(temperatures):
    # The column count (2) must match num_molecules in generate_molecules;
    # change it if you want to generate more than 2 samples per temperature.
    fig, axs = plt.subplots(len(temperatures), 2, figsize=(20, 4 * len(temperatures)))
    fig.suptitle("Generated Molecules at Different Temperatures", fontsize=16)
    for i, temp in enumerate(temperatures):
        molecules = generate_molecules(temp)
        for j, mol in enumerate(molecules):
            smiles = selfies_to_smiles(mol)
            if smiles:
                rdkit_mol = Chem.MolFromSmiles(smiles)
                if rdkit_mol:
                    img = Draw.MolToImage(rdkit_mol)
                    axs[i, j].imshow(img)
                    axs[i, j].axis('off')
                    axs[i, j].set_title(f"Temp: {temp}", fontsize=10)
                else:
                    axs[i, j].text(0.5, 0.5, "Invalid\nMolecule", ha='center', va='center')
                    axs[i, j].axis('off')
            else:
                axs[i, j].text(0.5, 0.5, "Invalid\nSELFIES", ha='center', va='center')
                axs[i, j].axis('off')
    plt.tight_layout()
    plt.show()


# Generate and visualize molecules at different temperatures
temperatures = [0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5]
visualize_molecules(temperatures)
```
**Output example:**
![image/png](https://cdn-uploads.huggingface.co/production/uploads/667da868d653c0b02d6a2399/6Qxd4MgRD_isM9prx-XW3.png)
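To make the effect of the `temperature` argument concrete: during sampling, the logits are divided by the temperature before the softmax, so low values sharpen the distribution and high values flatten it. The sketch below shows this with plain Python. The three logits are made-up illustrative values, not outputs of this model.

```python
import math


def softmax_with_temperature(logits, temperature):
    """Convert logits to probabilities after dividing by the temperature."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical logits for three candidate next tokens (illustrative values only)
logits = [2.0, 1.0, 0.0]
for t in (0.5, 1.0, 2.0, 3.5):
    probs = softmax_with_temperature(logits, t)
    print(t, [round(p, 3) for p in probs])
```

At higher temperatures the less likely tokens are sampled more often, which is one reason more invalid or unusual molecules may appear in the lower rows of the grid.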
#### Generate using Starting Sequence with Different Temperature(s) and Visualization
```python
import torch
import selfies as sf
from rdkit import Chem
from rdkit.Chem import Draw
import matplotlib.pyplot as plt

# Reuses `tokenizer` and `model` from the Direct Use example above.


def generate_molecules(seed, temperature, num_molecules=5):
    # Tokenize the seed
    seed_tokens = tokenizer.encode(seed, add_special_tokens=False, return_tensors="pt")
    # Generate from the seed
    gen = model.generate(
        seed_tokens,
        do_sample=True,
        max_length=256,
        temperature=temperature,
        early_stopping=True,
        pad_token_id=tokenizer.pad_token_id,
        num_beams=5,
        num_return_sequences=num_molecules,
    )
    # Decode the generated sequences
    generated = tokenizer.batch_decode(gen, skip_special_tokens=True)
    # Combine seed with generated sequences
    return [seed + seq[len(seed):] for seq in generated]


def selfies_to_smiles(selfies_str):
    selfies_str = selfies_str.replace(' ', '')
    try:
        return sf.decoder(selfies_str)
    except Exception:
        # Treat any decoding failure as an invalid SELFIES string
        return None


def visualize_molecules(seed, temperatures):
    # The column count (5) must match num_molecules in generate_molecules
    fig, axs = plt.subplots(len(temperatures), 5, figsize=(20, 4 * len(temperatures)))
    fig.suptitle(f"Generated Molecules at Different Temperatures\nSeed: {seed}", fontsize=16)
    for i, temp in enumerate(temperatures):
        molecules = generate_molecules(seed, temp)
        for j, mol in enumerate(molecules):
            smiles = selfies_to_smiles(mol)
            if smiles:
                rdkit_mol = Chem.MolFromSmiles(smiles)
                if rdkit_mol:
                    img = Draw.MolToImage(rdkit_mol)
                    axs[i, j].imshow(img)
                    axs[i, j].axis('off')
                    axs[i, j].set_title(f"Temp: {temp}", fontsize=10)
                else:
                    axs[i, j].text(0.5, 0.5, "Invalid\nMolecule", ha='center', va='center')
                    axs[i, j].axis('off')
            else:
                axs[i, j].text(0.5, 0.5, "Invalid\nSELFIES", ha='center', va='center')
                axs[i, j].axis('off')
    plt.tight_layout()
    plt.show()


# Set the seed and temperatures
seed = "[C] [C] [=Branch1] [C] [=O] [O] [C] [C] [N+1]"
temperatures = [0.5, 1.0, 1.5, 2.0, 2.5]
# Generate and visualize molecules at different temperatures
visualize_molecules(seed, temperatures)
```
**Example output:**
![image/png](https://cdn-uploads.huggingface.co/production/uploads/667da868d653c0b02d6a2399/cHamzqHjBj4tNxDPgdZ-g.png)
## Training Data
- **Source**: Curated and merged from the COCONUTDB (Sorokina et al., 2021), ChEMBL34 (Zdrazil et al., 2023), and SuperNatural3 (Gallo et al., 2023) databases
- **Total**: 2,933,355 samples
- **Total Train**: 2,346,680 samples
- **Validation**: 293,336 samples
- **Per chunk**: 586,670 train, 73,334 validation, 73,334 test
- **Random seed for split**: 42
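The figures above can be cross-checked with a few lines of arithmetic. This sketch assumes four equal chunks, matching the four rows (I to IV) of the training log further down; the chunk count itself is not stated in this section. The test total is derived, not reported.

```python
TOTAL = 2_933_355
TRAIN_TOTAL = 2_346_680
VALID_TOTAL = 293_336
N_CHUNKS = 4  # assumption: one chunk per row (I-IV) of the training log
PER_CHUNK = {"train": 586_670, "validation": 73_334, "test": 73_334}

# The per-chunk counts should add up to the reported totals
assert PER_CHUNK["train"] * N_CHUNKS == TRAIN_TOTAL
assert PER_CHUNK["validation"] * N_CHUNKS == VALID_TOTAL

test_total = PER_CHUNK["test"] * N_CHUNKS
unassigned = TOTAL - (TRAIN_TOTAL + VALID_TOTAL + test_total)

print(f"train fraction:      {TRAIN_TOTAL / TOTAL:.4f}")
print(f"validation fraction: {VALID_TOTAL / TOTAL:.4f}")
print(f"implied test total:  {test_total:,}")
print(f"samples left over after equal chunking: {unassigned}")
```

The numbers are consistent with an approximately 80/10/10 split, with a handful of samples left over once the data is divided into equal chunks.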
## Training Procedure
- **Batch Size**: 64
- **Epochs per Chunk**: 1
- **Learning Rate**: 1.5e-5
- **Optimizer**: Ranger21 (MADGRAD-Lookahead-AdaBelief with gradient centralization, linear warm-up (22%), gradient clipping, and L2 weight decay)
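For a sense of scale, the settings above translate into optimizer steps as sketched below. Several things here are assumptions and not taken from the training setup: one optimizer step per batch (no gradient accumulation, single device), the last partial batch being kept, four chunks as in the training log, and the 22% warm-up being measured either per chunk or over the whole run (both are shown, since it is not stated which applies).

```python
import math

BATCH_SIZE = 64
TRAIN_PER_CHUNK = 586_670  # per-chunk training samples from the data section
N_CHUNKS = 4               # assumption: chunks I-IV in the training log
WARMUP_FRACTION = 0.22

# Assumption: one optimizer step per batch, last partial batch kept
steps_per_chunk = math.ceil(TRAIN_PER_CHUNK / BATCH_SIZE)
total_steps = steps_per_chunk * N_CHUNKS

# Assumption: warm-up length is a fraction of the steps in a run
warmup_per_chunk = int(WARMUP_FRACTION * steps_per_chunk)
warmup_overall = int(WARMUP_FRACTION * total_steps)

print(f"steps per chunk: {steps_per_chunk:,}")
print(f"steps over all chunks: {total_steps:,}")
print(f"warm-up steps if counted per chunk: {warmup_per_chunk:,}")
print(f"warm-up steps if counted over the whole run: {warmup_overall:,}")
```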
## Training Logs
| Chunk | Chunk's Training Loss | Chunk's Validation Loss | Status |
| :---: | :-------------------: | :---------------------: | :----: |
| I | 1.346400 | 1.065180 | Done |
| II | 1.123500 | 0.993118 | Done |
| III | 1.058300 | 0.948303 | Done |
| IV | 1.016600 | 0.921706 | Done |
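The validation losses in the table can be turned into perplexities for easier comparison between chunks. This sketch assumes the logged loss is the mean token-level cross-entropy in natural-log units, which is the usual convention for causal language model training but is not confirmed here.

```python
import math

# Validation loss per chunk, copied from the table above
validation_loss = {"I": 1.065180, "II": 0.993118, "III": 0.948303, "IV": 0.921706}

# Assumption: loss is mean cross-entropy in nats, so perplexity = exp(loss)
perplexity = {chunk: math.exp(loss) for chunk, loss in validation_loss.items()}

previous = None
for chunk, loss in validation_loss.items():
    change = "" if previous is None else f" ({loss - previous:+.4f} vs previous chunk)"
    print(f"chunk {chunk}: loss {loss:.4f}, perplexity {perplexity[chunk]:.3f}{change}")
    previous = loss
```

Under that assumption the per-token perplexity falls from about 2.9 after chunk I to about 2.5 after chunk IV, with the improvement per chunk getting smaller each time.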
## Evaluation Results
[To be filled after model evaluation]
## Limitations and Biases
- May generate unrealistic or synthetically inaccessible molecules
- Performance on complex, branched, and ringed molecules has yet to be evaluated
## Disclaimer & Ethical Considerations
- This model is in an early development stage and may not consistently generate valid outputs.
- It is intended for personal exploration, academic, and research purposes only.
- You should be aware of potential ethical concerns:
  - Possible generation of harmful substances if misused
  - Potential biases inherent in the training data
- The accuracy, completeness, and reliability of the model's outputs are not guaranteed.
- This model should not be used for any commercial or legal purposes.
- The information and model provided are for educational and research use only.
## Additional Information
- Part of the experimental chemfie-gpt/T5 project
- Serves as a baseline for future experiments with further curated datasets, training improvements, and architectural modifications
## Citation
### BibTeX
#### COCONUTDB
```bibtex
@article{sorokina2021coconut,
title={COCONUT online: Collection of Open Natural Products database},
author={Sorokina, Maria and Merseburger, Peter and Rajan, Kohulan and Yirik, Mehmet Aziz and Steinbeck, Christoph},
journal={Journal of Cheminformatics},
volume={13},
number={1},
pages={2},
year={2021},
doi={10.1186/s13321-020-00478-9}
}
```
#### ChEMBL34
```bibtex
@article{zdrazil2023chembl,
title={The ChEMBL Database in 2023: a drug discovery platform spanning multiple bioactivity data types and time periods},
author={Zdrazil, Barbara and Felix, Eloy and Hunter, Fiona and Manners, Emma J and Blackshaw, James and Corbett, Sybilla and de Veij, Marleen and Ioannidis, Harris and Lopez, David Mendez and Mosquera, Juan F and Magarinos, Maria Paula and Bosc, Nicolas and Arcila, Ricardo and Kizil{\"o}ren, Tevfik and Gaulton, Anna and Bento, A Patr{\'i}cia and Adasme, Melissa F and Monecke, Peter and Landrum, Gregory A and Leach, Andrew R},
journal={Nucleic Acids Research},
year={2023},
volume={gkad1004},
doi={10.1093/nar/gkad1004}
}
@misc{chembl34,
title={ChemBL34},
year={2023},
doi={10.6019/CHEMBL.database.34}
}
```
#### SuperNatural3
```bibtex
@article{Gallo2023,
author = {Gallo, K and Kemmler, E and Goede, A and Becker, F and Dunkel, M and Preissner, R and Banerjee, P},
title = {{SuperNatural 3.0-a database of natural products and natural product-based derivatives}},
journal = {Nucleic Acids Research},
year = {2023},
month = jan,
day = {6},
volume = {51},
number = {D1},
pages = {D654-D659},
doi = {10.1093/nar/gkac1008}
}
```