Quentin Gallouédec committed on
Commit
0ef2585
•
1 Parent(s): 8e630b3

better handle refresh

Browse files
Files changed (1) hide show
  1. app.py +120 -100
app.py CHANGED
@@ -1,6 +1,5 @@
1
  import json
2
  import os
3
- import re
4
 
5
  import gradio as gr
6
  import numpy as np
@@ -9,7 +8,7 @@ from apscheduler.schedulers.background import BackgroundScheduler
9
  from huggingface_hub import HfApi
10
 
11
  from src.backend import backend_routine
12
-
13
  from src.logging import configure_root_logger, setup_logger
14
 
15
 
@@ -17,71 +16,72 @@ configure_root_logger()
17
  logger = setup_logger(__name__)
18
 
19
  API = HfApi(token=os.environ.get("TOKEN"))
20
- RESULTS_REPO = f"open-rl-leaderboard/results"
 
21
  ALL_ENV_IDS = {
22
  "Atari": [
23
- "Adventure",
24
- "AirRaid",
25
- "Alien",
26
- "Amidar",
27
- "Assault",
28
- "Asterix",
29
- "Asteroids",
30
- "Atlantis",
31
- "BankHeist",
32
- "BattleZone",
33
- "BeamRider",
34
- "Berzerk",
35
- "Bowling",
36
- "Boxing",
37
- "Breakout",
38
- "Carnival",
39
- "Centipede",
40
- "ChopperCommand",
41
- "CrazyClimber",
42
- "Defender",
43
- "DemonAttack",
44
- "DoubleDunk",
45
- "ElevatorAction",
46
- "Enduro",
47
- "FishingDerby",
48
- "Freeway",
49
- "Frostbite",
50
- "Gopher",
51
- "Gravitar",
52
- "Hero",
53
- "IceHockey",
54
- "Jamesbond",
55
- "JourneyEscape",
56
- "Kangaroo",
57
- "Krull",
58
- "KungFuMaster",
59
- "MontezumaRevenge",
60
- "MsPacman",
61
- "NameThisGame",
62
- "Phoenix",
63
- "Pitfall",
64
- "Pong",
65
- "Pooyan",
66
- "PrivateEye",
67
- "Qbert",
68
- "Riverraid",
69
- "RoadRunner",
70
- "Robotank",
71
- "Seaquest",
72
- "Skiing",
73
- "Solaris",
74
- "SpaceInvaders",
75
- "StarGunner",
76
- "Tennis",
77
- "TimePilot",
78
- "Tutankham",
79
- "UpNDown",
80
- "Venture",
81
- "VideoPinball",
82
- "WizardOfWor",
83
- "YarsRevenge",
84
- "Zaxxon",
85
  ],
86
  "Box2D": [
87
  "BipedalWalker-v3",
@@ -120,18 +120,16 @@ ALL_ENV_IDS = {
120
 
121
 
122
  def get_leaderboard_df():
123
- # List all results files in results repo
124
- pattern = re.compile(r"^[^/]*/[^/]*/[^/]*results_[a-f0-9]+\.json$")
125
- filenames = API.list_repo_files(RESULTS_REPO, repo_type="dataset")
126
- filenames = [filename for filename in filenames if pattern.match(filename)]
127
 
128
  data = []
129
  for filename in filenames:
130
- path = API.hf_hub_download(repo_id=RESULTS_REPO, filename=filename, repo_type="dataset")
131
- with open(path) as fp:
132
  report = json.load(fp)
133
  user_id, model_id = report["config"]["model_id"].split("/")
134
- row = {"user_id": user_id, "model_id": model_id}
135
  if report["status"] == "DONE" and len(report["results"]) > 0:
136
  env_ids = list(report["results"].keys())
137
  assert len(env_ids) == 1, "Only one environment supported for the moment"
@@ -165,6 +163,29 @@ def format_df(df: pd.DataFrame):
165
  return df.values.tolist()
166
 
167
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  HEADING = """
169
  # 🥇 Open RL Leaderboard 🥇
170
 
@@ -243,54 +264,53 @@ If you encounter any issue, please [open an issue](https://huggingface.co/spaces
243
  ```
244
  """
245
 
246
-
247
  with gr.Blocks() as demo:
248
  gr.Markdown(HEADING)
249
  with gr.Tabs(elem_classes="tab-buttons") as tabs:
250
  with gr.TabItem("πŸ… Leaderboard"):
251
  df = get_leaderboard_df()
 
 
 
252
  for env_domain, env_ids in ALL_ENV_IDS.items():
253
  with gr.TabItem(env_domain):
254
  for env_id in env_ids:
255
- with gr.TabItem(env_id):
 
 
 
256
  with gr.Row(equal_height=False):
257
- if env_domain == "Atari":
258
- env_id = f"{env_id}NoFrameskip-v4"
259
- env_df = select_env(df, env_id)
260
- gr.components.Dataframe(
261
- value=format_df(env_df),
262
  headers=["πŸ†", "πŸ§‘ User", "πŸ€– Model id", "πŸ“Š Mean episodic return"],
263
  datatype=["number", "markdown", "markdown", "number"],
264
  row_count=(10, "fixed"),
265
  scale=3,
266
  )
267
- # Get the best model and
268
- if not env_df.empty:
269
- user_id = env_df.iloc[0]["user_id"]
270
- model_id = env_df.iloc[0]["model_id"]
271
- video_path = API.hf_hub_download(
272
- repo_id=f"{user_id}/{model_id}",
273
- filename="replay.mp4",
274
- revision="main",
275
- repo_type="model",
276
- )
277
- video = gr.PlayableVideo(
278
- video_path,
279
- label=model_id,
280
- scale=1,
281
- min_width=50,
282
- autoplay=True,
283
- show_download_button=False,
284
- show_share_button=False,
285
- )
286
- # Doesn't loop for the moment, see https://github.com/gradio-app/gradio/issues/7689
287
 
288
  with gr.TabItem("πŸ“ About"):
289
  gr.Markdown(ABOUT_TEXT)
290
 
 
 
 
291
 
292
  scheduler = BackgroundScheduler()
293
- scheduler.add_job(func=backend_routine, trigger="interval", seconds=10 * 60, max_instances=1)
294
  scheduler.start()
295
 
296
 
 
1
  import json
2
  import os
 
3
 
4
  import gradio as gr
5
  import numpy as np
 
8
  from huggingface_hub import HfApi
9
 
10
  from src.backend import backend_routine
11
+ import glob
12
  from src.logging import configure_root_logger, setup_logger
13
 
14
 
 
16
  logger = setup_logger(__name__)
17
 
18
  API = HfApi(token=os.environ.get("TOKEN"))
19
+ RESULTS_REPO = "open-rl-leaderboard/results"
20
+ REFRESH_RATE = 5 * 60 # 5 minutes
21
  ALL_ENV_IDS = {
22
  "Atari": [
23
+ "AdventureNoFrameskip-v4",
24
+ "AirRaidNoFrameskip-v4",
25
+ "AlienNoFrameskip-v4",
26
+ "AmidarNoFrameskip-v4",
27
+ "AssaultNoFrameskip-v4",
28
+ "AsterixNoFrameskip-v4",
29
+ "AsteroidsNoFrameskip-v4",
30
+ "AtlantisNoFrameskip-v4",
31
+ "BankHeistNoFrameskip-v4",
32
+ "BattleZoneNoFrameskip-v4",
33
+ "BeamRiderNoFrameskip-v4",
34
+ "BerzerkNoFrameskip-v4",
35
+ "BowlingNoFrameskip-v4",
36
+ "BoxingNoFrameskip-v4",
37
+ "BreakoutNoFrameskip-v4",
38
+ "CarnivalNoFrameskip-v4",
39
+ "CentipedeNoFrameskip-v4",
40
+ "ChopperCommandNoFrameskip-v4",
41
+ "CrazyClimberNoFrameskip-v4",
42
+ "DefenderNoFrameskip-v4",
43
+ "DemonAttackNoFrameskip-v4",
44
+ "DoubleDunkNoFrameskip-v4",
45
+ "ElevatorActionNoFrameskip-v4",
46
+ "EnduroNoFrameskip-v4",
47
+ "FishingDerbyNoFrameskip-v4",
48
+ "FreewayNoFrameskip-v4",
49
+ "FrostbiteNoFrameskip-v4",
50
+ "GopherNoFrameskip-v4",
51
+ "GravitarNoFrameskip-v4",
52
+ "HeroNoFrameskip-v4",
53
+ "IceHockeyNoFrameskip-v4",
54
+ "JamesbondNoFrameskip-v4",
55
+ "JourneyEscapeNoFrameskip-v4",
56
+ "KangarooNoFrameskip-v4",
57
+ "KrullNoFrameskip-v4",
58
+ "KungFuMasterNoFrameskip-v4",
59
+ "MontezumaRevengeNoFrameskip-v4",
60
+ "MsPacmanNoFrameskip-v4",
61
+ "NameThisGameNoFrameskip-v4",
62
+ "PhoenixNoFrameskip-v4",
63
+ "PitfallNoFrameskip-v4",
64
+ "PongNoFrameskip-v4",
65
+ "PooyanNoFrameskip-v4",
66
+ "PrivateEyeNoFrameskip-v4",
67
+ "QbertNoFrameskip-v4",
68
+ "RiverraidNoFrameskip-v4",
69
+ "RoadRunnerNoFrameskip-v4",
70
+ "RobotankNoFrameskip-v4",
71
+ "SeaquestNoFrameskip-v4",
72
+ "SkiingNoFrameskip-v4",
73
+ "SolarisNoFrameskip-v4",
74
+ "SpaceInvadersNoFrameskip-v4",
75
+ "StarGunnerNoFrameskip-v4",
76
+ "TennisNoFrameskip-v4",
77
+ "TimePilotNoFrameskip-v4",
78
+ "TutankhamNoFrameskip-v4",
79
+ "UpNDownNoFrameskip-v4",
80
+ "VentureNoFrameskip-v4",
81
+ "VideoPinballNoFrameskip-v4",
82
+ "WizardOfWorNoFrameskip-v4",
83
+ "YarsRevengeNoFrameskip-v4",
84
+ "ZaxxonNoFrameskip-v4",
85
  ],
86
  "Box2D": [
87
  "BipedalWalker-v3",
 
120
 
121
 
122
  def get_leaderboard_df():
123
+ dir_path = API.snapshot_download(repo_id=RESULTS_REPO, repo_type="dataset")
124
+ pattern = os.path.join(dir_path, "**", "results_*.json")
125
+ filenames = glob.glob(pattern, recursive=True)
 
126
 
127
  data = []
128
  for filename in filenames:
129
+ with open(filename) as fp:
 
130
  report = json.load(fp)
131
  user_id, model_id = report["config"]["model_id"].split("/")
132
+ row = {"user_id": user_id, "model_id": model_id, "model_sha": report["config"]["model_sha"]}
133
  if report["status"] == "DONE" and len(report["results"]) > 0:
134
  env_ids = list(report["results"].keys())
135
  assert len(env_ids) == 1, "Only one environment supported for the moment"
 
163
  return df.values.tolist()
164
 
165
 
def refresh_dataframes():
    """Reload the leaderboard and build one formatted table per environment.

    Returns:
        A list with one entry per id in the module-level ``all_env_ids``,
        in the same order. Each entry is the result of ``format_df`` applied
        to the rows ``select_env`` picks for that environment.
    """
    leaderboard = get_leaderboard_df()
    tables = []
    # Order must follow all_env_ids: the list is consumed positionally.
    for env_id in all_env_ids:
        env_rows = select_env(leaderboard, env_id)
        tables.append(format_df(env_rows))
    return tables
170
+
171
+
def refresh_videos():
    """Return one replay-video path (or ``None``) per environment.

    For every id in the module-level ``all_env_ids``, the first row returned
    by ``select_env`` is used (presumably the best-ranked model -- confirm
    against ``select_env``) and its ``replay.mp4`` is downloaded at the exact
    ``model_sha`` revision recorded in the leaderboard.

    Returns:
        A list aligned with ``all_env_ids``. Each entry is the local path of
        the downloaded video, or ``None`` when the environment has no rows or
        the download failed.
    """
    df = get_leaderboard_df()
    outputs = []
    for env_id in all_env_ids:
        env_df = select_env(df, env_id)
        if env_df.empty:
            outputs.append(None)
            continue

        best = env_df.iloc[0]
        repo_id = f"{best['user_id']}/{best['model_id']}"
        try:
            video_path = API.hf_hub_download(
                repo_id=repo_id,
                filename="replay.mp4",
                revision=best["model_sha"],
                repo_type="model",
            )
        except Exception:
            # This runs as a periodic UI refresh: one missing file or a
            # transient network error must not blank every other video.
            logger.exception("Could not fetch replay.mp4 for %s (%s)", repo_id, env_id)
            video_path = None
        outputs.append(video_path)
    return outputs
187
+
188
+
189
  HEADING = """
190
  # 🥇 Open RL Leaderboard 🥇
191
 
 
264
  ```
265
  """
266
 
 
# Build the UI: one tab per environment domain, one sub-tab per environment,
# each holding a leaderboard table and a replay video. The components are
# created empty and filled by the periodic refresh callbacks registered below.
with gr.Blocks() as demo:
    gr.Markdown(HEADING)
    with gr.Tabs(elem_classes="tab-buttons") as tabs:
        with gr.TabItem("🏅 Leaderboard"):
            # NOTE(review): `df` is no longer read in this block; the tables
            # are populated through `demo.load`. Kept to preserve start-up
            # behavior -- confirm whether it can be dropped.
            df = get_leaderboard_df()
            # Parallel lists, indexed identically; the refresh callbacks
            # return their outputs in this same order.
            all_env_ids = []
            all_gr_dfs = []
            all_gr_videos = []
            for env_domain, env_ids in ALL_ENV_IDS.items():
                with gr.TabItem(env_domain):
                    for env_id in env_ids:
                        # If the env_id ends with "NoFrameskip-v4", we remove it
                        tab_env_id = env_id[: -len("NoFrameskip-v4")] if env_id.endswith("NoFrameskip-v4") else env_id
                        with gr.TabItem(tab_env_id):
                            logger.info(f"Creating tab for {env_id}")
                            with gr.Row(equal_height=False):
                                # Display the leaderboard
                                gr_df = gr.components.Dataframe(
                                    headers=["🏆", "🧑 User", "🤖 Model id", "📊 Mean episodic return"],
                                    datatype=["number", "markdown", "markdown", "number"],
                                    row_count=(10, "fixed"),
                                    scale=3,
                                )

                                # Play the video of the best model
                                gr_video = gr.PlayableVideo(  # Doesn't loop for the moment, see https://github.com/gradio-app/gradio/issues/7689
                                    scale=1,
                                    min_width=50,
                                    autoplay=True,
                                    show_download_button=False,
                                    show_share_button=False,
                                    show_label=False,
                                )

                                all_env_ids.append(env_id)
                                all_gr_dfs.append(gr_df)
                                all_gr_videos.append(gr_video)

        with gr.TabItem("📝 About"):
            gr.Markdown(ABOUT_TEXT)

    # Fill the tables and videos on page load, then again every REFRESH_RATE seconds.
    demo.load(refresh_dataframes, outputs=all_gr_dfs, every=REFRESH_RATE)
    demo.load(refresh_videos, outputs=all_gr_videos, every=REFRESH_RATE)


# Run the evaluation backend in the background at the same cadence as the UI
# refresh; max_instances=1 keeps overlapping runs from piling up.
scheduler = BackgroundScheduler()
scheduler.add_job(func=backend_routine, trigger="interval", seconds=REFRESH_RATE, max_instances=1)
scheduler.start()
315
 
316