jinwei12 commited on
Commit
2a2fe0b
1 Parent(s): b2347e6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +134 -0
app.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import plotly.graph_objects as go
3
+ import torch
4
+ from transformers import AutoModelForTokenClassification, AutoTokenizer
5
+ import requests
6
+
7
+ def search_geonames(location):
8
+ api_endpoint = "http://api.geonames.org/searchJSON"
9
+ username = "zekun"
10
+
11
+ params = {
12
+ 'q': location,
13
+ 'username': username,
14
+ 'maxRows': 5
15
+ }
16
+
17
+ response = requests.get(api_endpoint, params=params)
18
+ data = response.json()
19
+
20
+ if 'geonames' in data:
21
+ fig = go.Figure()
22
+ for place_info in data['geonames']:
23
+ latitude = float(place_info.get('lat', 0.0))
24
+ longitude = float(place_info.get('lng', 0.0))
25
+
26
+ fig.add_trace(go.Scattermapbox(
27
+ lat=[latitude],
28
+ lon=[longitude],
29
+ mode='markers',
30
+ marker=go.scattermapbox.Marker(
31
+ size=10,
32
+ color='orange',
33
+ ),
34
+ text=[f'Location: {location}'],
35
+ hoverinfo="text",
36
+ hovertemplate='<b>Location</b>: %{text}',
37
+ ))
38
+
39
+ fig.update_layout(
40
+ mapbox_style="open-street-map",
41
+ hovermode='closest',
42
+ mapbox=dict(
43
+ bearing=0,
44
+ center=go.layout.mapbox.Center(
45
+ lat=latitude,
46
+ lon=longitude
47
+ ),
48
+ pitch=0,
49
+ zoom=2
50
+ ))
51
+
52
+ st.plotly_chart(fig)
53
+
54
+ # Return an empty figure
55
+ return go.Figure()
56
+
57
+
58
+ def mapping(location):
59
+ st.write(f"Mapping location: {location}")
60
+
61
+ search_geonames(location)
62
+
63
+
64
+
65
+ def generate_human_readable(tokens,labels):
66
+ ret = []
67
+ for t,lab in zip(tokens,labels):
68
+ if t == '[SEP]':
69
+ continue
70
+
71
+ if t.startswith("##") :
72
+ assert len(ret) > 0
73
+ ret[-1] = ret[-1] + t.strip('##')
74
+
75
+ elif lab==2:
76
+ assert len(ret) > 0
77
+ ret[-1] = ret[-1] + " "+ t.strip('##')
78
+ else:
79
+ ret.append(t)
80
+
81
+ return ret
82
+
83
+
84
+
85
+ def showOnMap(input_sentence):
86
+ # get the location names:
87
+
88
+ model_name = "zekun-li/geolm-base-toponym-recognition"
89
+
90
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
91
+ model = AutoModelForTokenClassification.from_pretrained(model_name)
92
+
93
+ tokens = tokenizer.encode(input_sentence, return_tensors="pt")
94
+
95
+ outputs = model(tokens)
96
+
97
+ predicted_labels = torch.argmax(outputs.logits, dim=2)
98
+
99
+ predicted_labels = predicted_labels.detach().cpu().numpy()
100
+
101
+ # "id2label": { "0": "O", "1": "B-Topo", "2": "I-Topo" }
102
+
103
+ predicted_labels = [model.config.id2label[label] for label in predicted_labels[0]]
104
+
105
+ predicted_labels = torch.argmax(outputs.logits, dim=2)
106
+
107
+ query_tokens = tokens[0][torch.where(predicted_labels[0] != 0)[0]]
108
+
109
+ query_labels = predicted_labels[0][torch.where(predicted_labels[0] != 0)[0]]
110
+
111
+ human_readable = generate_human_readable(tokenizer.convert_ids_to_tokens(query_tokens), query_labels)
112
+ #['Los Angeles', 'L . A .', 'California', 'U . S .', 'Southern California', 'Los Angeles', 'United States', 'New York City']
113
+
114
+ return human_readable
115
+
116
+
117
+
118
+
119
+
120
+ def show_on_map():
121
+
122
+ input = st.text_area("Enter a sentence:", height=200)
123
+
124
+ st.button("Submit")
125
+
126
+ places = showOnMap(input)
127
+
128
+ selected_place = st.selectbox("Select a location:", places)
129
+ mapping(selected_place)
130
+
131
+
132
+
133
+ if __name__ == "__main__":
134
+ show_on_map()