Quentin Gallouédec commited on
Commit
b5f1eb6
1 Parent(s): c8398fa
Files changed (2) hide show
  1. requirements.txt +0 -1
  2. src/evaluation.py +1 -4
requirements.txt CHANGED
@@ -14,7 +14,6 @@ pandas==2.0.0
14
  python-dateutil==2.8.2
15
  requests==2.28.2
16
  rliable==1.0.8
17
- --extra-index-url https://download.pytorch.org/whl/cu113
18
  torch==2.2.2
19
  tqdm==4.65.0
20
 
 
14
  python-dateutil==2.8.2
15
  requests==2.28.2
16
  rliable==1.0.8
 
17
  torch==2.2.2
18
  tqdm==4.65.0
19
 
src/evaluation.py CHANGED
@@ -15,9 +15,6 @@ logger = setup_logger(__name__)
15
 
16
  API = HfApi(token=os.environ.get("TOKEN"))
17
 
18
- logger.info(f"Is CUDA available: {torch.cuda.is_available()}")
19
- logger.info(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
20
-
21
 
22
  ALL_ENV_IDS = [
23
  "AdventureNoFrameskip-v4",
@@ -353,7 +350,7 @@ def evaluate(model_id, revision):
353
  observations, _ = envs.reset()
354
  episodic_returns = []
355
  while len(episodic_returns) < 10:
356
- actions = agent(torch.tensor(observations, device="cuda")).cpu().numpy()
357
  observations, _, _, _, infos = envs.step(actions)
358
  if "final_info" in infos:
359
  for info in infos["final_info"]:
 
15
 
16
  API = HfApi(token=os.environ.get("TOKEN"))
17
 
 
 
 
18
 
19
  ALL_ENV_IDS = [
20
  "AdventureNoFrameskip-v4",
 
350
  observations, _ = envs.reset()
351
  episodic_returns = []
352
  while len(episodic_returns) < 10:
353
+ actions = agent(torch.tensor(observations)).numpy()
354
  observations, _, _, _, infos = envs.step(actions)
355
  if "final_info" in infos:
356
  for info in infos["final_info"]: