Files changed (2)
  1. README.md +13 -13
  2. app.py +164 -164
README.md CHANGED
@@ -1,13 +1,13 @@
- ---
- title: SigLIP Tagger
- emoji: 🧷
- colorFrom: green
- colorTo: blue
- sdk: gradio
- sdk_version: 4.16.0
- app_file: app.py
- pinned: true
- license: apache-2.0
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ---
+ title: SigLIP Tagger
+ emoji: 🧷
+ colorFrom: green
+ colorTo: blue
+ sdk: gradio
+ sdk_version: 4.43.0
+ app_file: app.py
+ pinned: true
+ license: apache-2.0
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,164 +1,164 @@
- import os
- from PIL import Image
-
- import numpy as np
- import torch
-
- from transformers import (
-     AutoImageProcessor,
- )
-
- import gradio as gr
-
- from modeling_siglip import SiglipForImageClassification
-
-
- HF_TOKEN = os.environ["HF_READ_TOKEN"]
-
- EXAMPLES = [["./images/sample.jpg"], ["./images/sample2.webp"]]
-
- model_maps: dict[str, dict] = {
-     "test2": {
-         "repo": "p1atdev/siglip-tagger-test-2",
-     },
-     "test3": {
-         "repo": "p1atdev/siglip-tagger-test-3",
-     },
-     # "test4": {
-     #     "repo": "p1atdev/siglip-tagger-test-4",
-     # },
- }
-
- for key in model_maps.keys():
-     model_maps[key]["model"] = SiglipForImageClassification.from_pretrained(
-         model_maps[key]["repo"], torch_dtype=torch.bfloat16, token=HF_TOKEN
-     )
-     model_maps[key]["processor"] = AutoImageProcessor.from_pretrained(
-         model_maps[key]["repo"], token=HF_TOKEN
-     )
-
- README_MD = (
-     f"""\
- ## SigLIP Tagger Test 3
- An experimental model for tagging danbooru tags of images using SigLIP.
-
- Model(s):
- """
-     + "\n".join(
-         f"- [{value['repo']}](https://huggingface.co/{value['repo']})"
-         for value in model_maps.values()
-     )
-     + "\n"
-     + """
- Example images by NovelAI and niji・journey.
- """
- )
-
-
- def compose_text(results: dict[str, float], threshold: float = 0.3):
-     return ", ".join(
-         [
-             key
-             for key, value in sorted(results.items(), key=lambda x: x[1], reverse=True)
-             if value > threshold
-         ]
-     )
-
-
- @torch.no_grad()
- def predict_tags(image: Image.Image, model_name: str, threshold: float):
-     if image is None:
-         return None, None
-
-     inputs = model_maps[model_name]["processor"](image, return_tensors="pt")
-
-     logits = (
-         model_maps[model_name]["model"](
-             **inputs.to(
-                 model_maps[model_name]["model"].device,
-                 model_maps[model_name]["model"].dtype,
-             )
-         )
-         .logits.detach()
-         .cpu()
-         .float()
-     )
-
-     logits = np.clip(logits, 0.0, 1.0)
-
-     results = {}
-
-     for prediction in logits:
-         for i, prob in enumerate(prediction):
-             if prob.item() > 0:
-                 results[model_maps[model_name]["model"].config.id2label[i]] = (
-                     prob.item()
-                 )
-
-     return compose_text(results, threshold), results
-
-
- css = """\
- .sticky {
-     position: sticky;
-     top: 16px;
- }
-
- .gradio-container {
-     overflow: clip;
- }
- """
-
-
- def demo():
-     with gr.Blocks(css=css) as ui:
-         gr.Markdown(README_MD)
-
-         with gr.Row():
-             with gr.Column():
-                 with gr.Row(elem_classes="sticky"):
-                     with gr.Column():
-                         input_img = gr.Image(
-                             label="Input image", type="pil", height=480
-                         )
-
-                         with gr.Group():
-                             model_name_radio = gr.Radio(
-                                 label="Model",
-                                 choices=list(model_maps.keys()),
-                                 value="test3",
-                             )
-                             tag_threshold_slider = gr.Slider(
-                                 label="Tags threshold",
-                                 minimum=0.0,
-                                 maximum=1.0,
-                                 value=0.3,
-                                 step=0.01,
-                             )
-
-                         start_btn = gr.Button(value="Start", variant="primary")
-
-                         gr.Examples(
-                             examples=EXAMPLES,
-                             inputs=[input_img],
-                             cache_examples=False,
-                         )
-
-             with gr.Column():
-                 output_tags = gr.Text(label="Output text", interactive=False)
-                 output_label = gr.Label(label="Output tags")
-
-         start_btn.click(
-             fn=predict_tags,
-             inputs=[input_img, model_name_radio, tag_threshold_slider],
-             outputs=[output_tags, output_label],
-         )
-
-     ui.launch(
-         debug=True,
-         # share=True
-     )
-
-
- if __name__ == "__main__":
-     demo()
+ import os
+ from PIL import Image
+
+ import numpy as np
+ import torch
+
+ from transformers import (
+     AutoImageProcessor,
+ )
+
+ import gradio as gr
+
+ from modeling_siglip import SiglipForImageClassification
+
+
+ HF_TOKEN = os.environ.get("HF_READ_TOKEN")
+
+ EXAMPLES = [["./images/sample.jpg"], ["./images/sample2.webp"]]
+
+ model_maps: dict[str, dict] = {
+     "test2": {
+         "repo": "p1atdev/siglip-tagger-test-2",
+     },
+     "test3": {
+         "repo": "p1atdev/siglip-tagger-test-3",
+     },
+     # "test4": {
+     #     "repo": "p1atdev/siglip-tagger-test-4",
+     # },
+ }
+
+ for key in model_maps.keys():
+     model_maps[key]["model"] = SiglipForImageClassification.from_pretrained(
+         model_maps[key]["repo"], torch_dtype=torch.bfloat16, token=HF_TOKEN
+     )
+     model_maps[key]["processor"] = AutoImageProcessor.from_pretrained(
+         model_maps[key]["repo"], token=HF_TOKEN
+     )
+
+ README_MD = (
+     f"""\
+ ## SigLIP Tagger Test 3
+ An experimental model for tagging danbooru tags of images using SigLIP.
+
+ Model(s):
+ """
+     + "\n".join(
+         f"- [{value['repo']}](https://huggingface.co/{value['repo']})"
+         for value in model_maps.values()
+     )
+     + "\n"
+     + """
+ Example images by NovelAI and niji・journey.
+ """
+ )
+
+
+ def compose_text(results: dict[str, float], threshold: float = 0.3):
+     return ", ".join(
+         [
+             key
+             for key, value in sorted(results.items(), key=lambda x: x[1], reverse=True)
+             if value > threshold
+         ]
+     )
+
+
+ @torch.no_grad()
+ def predict_tags(image: Image.Image, model_name: str, threshold: float):
+     if image is None:
+         return None, None
+
+     inputs = model_maps[model_name]["processor"](image, return_tensors="pt")
+
+     logits = (
+         model_maps[model_name]["model"](
+             **inputs.to(
+                 model_maps[model_name]["model"].device,
+                 model_maps[model_name]["model"].dtype,
+             )
+         )
+         .logits.detach()
+         .cpu()
+         .float()
+     )
+
+     logits = np.clip(logits, 0.0, 1.0)
+
+     results = {}
+
+     for prediction in logits:
+         for i, prob in enumerate(prediction):
+             if prob.item() > 0:
+                 results[model_maps[model_name]["model"].config.id2label[i]] = (
+                     prob.item()
+                 )
+
+     return compose_text(results, threshold), results
+
+
+ css = """\
+ .sticky {
+     position: sticky;
+     top: 16px;
+ }
+
+ .gradio-container {
+     overflow: clip;
+ }
+ """
+
+
+ def demo():
+     with gr.Blocks(css=css) as ui:
+         gr.Markdown(README_MD)
+
+         with gr.Row():
+             with gr.Column():
+                 with gr.Row(elem_classes="sticky"):
+                     with gr.Column():
+                         input_img = gr.Image(
+                             label="Input image", type="pil", height=480
+                         )
+
+                         with gr.Group():
+                             model_name_radio = gr.Radio(
+                                 label="Model",
+                                 choices=list(model_maps.keys()),
+                                 value="test3",
+                             )
+                             tag_threshold_slider = gr.Slider(
+                                 label="Tags threshold",
+                                 minimum=0.0,
+                                 maximum=1.0,
+                                 value=0.3,
+                                 step=0.01,
+                             )
+
+                         start_btn = gr.Button(value="Start", variant="primary")
+
+                         gr.Examples(
+                             examples=EXAMPLES,
+                             inputs=[input_img],
+                             cache_examples=False,
+                         )
+
+             with gr.Column():
+                 output_tags = gr.Text(label="Output text", interactive=False)
+                 output_label = gr.Label(label="Output tags")
+
+         start_btn.click(
+             fn=predict_tags,
+             inputs=[input_img, model_name_radio, tag_threshold_slider],
+             outputs=[output_tags, output_label],
+         )
+
+     ui.launch(
+         debug=True,
+         # share=True
+     )
+
+
+ if __name__ == "__main__":
+     demo()
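
Aside from the `sdk_version` bump in README.md, the only line of app.py that changes is the token lookup: `os.environ["HF_READ_TOKEN"]` raises a `KeyError` and crashes the Space at startup when the secret is missing, while `os.environ.get("HF_READ_TOKEN")` returns `None` and lets the app start. A minimal standalone sketch of that difference follows; it reuses the `AutoImageProcessor` call and repo id already present in the file, and the behavior when the repo actually requires authentication is an assumption, not something this diff guarantees.

```python
import os

from transformers import AutoImageProcessor

# Sketch only, not part of the Space itself.
# os.environ["HF_READ_TOKEN"] would raise KeyError here if the variable is
# unset; .get() returns None instead, and from_pretrained(token=None) falls
# back to any locally saved credentials or anonymous access.
token = os.environ.get("HF_READ_TOKEN")

if token is None:
    print("HF_READ_TOKEN not set; only repos reachable without it will load.")

# Repo id taken from model_maps above. If the repo is private or gated and no
# token is available, this call still fails, just later and with a clearer error.
processor = AutoImageProcessor.from_pretrained(
    "p1atdev/siglip-tagger-test-3", token=token
)
```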