Spaces:

aksell commited on
Commit
1b13d7f
1 Parent(s): b3a37f0

Make model selector more compact

Browse files
Files changed (1) hide show
  1. hexviz/app.py +13 -12
hexviz/app.py CHANGED
@@ -18,10 +18,6 @@ models = [
18
  Model(name=ModelType.ZymCTRL, layers=36, heads=16),
19
  ]
20
 
21
- selected_model_name = st.selectbox("Select a model", [model.name.value for model in models], index=0)
22
- selected_model = next((model for model in models if model.name.value == selected_model_name), None)
23
-
24
-
25
  st.sidebar.markdown(
26
  """
27
  Select Protein
@@ -58,17 +54,15 @@ n_pairs = st.sidebar.number_input("Num attention pairs labeled", value=2, min_va
58
  label_highest = st.sidebar.checkbox("Label highest attention pairs", value=True)
59
  # TODO add avg or max attention as params
60
 
61
- if selected_model.name == ModelType.ZymCTRL:
62
- try:
63
- ec_class = structure.header["compound"]["1"]["ec"]
64
- except KeyError:
65
- ec_class = None
66
- if ec_class and selected_model.name == ModelType.ZymCTRL:
67
- ec_class = st.sidebar.text_input("Enzyme classification number fetched from PDB", ec_class)
68
 
69
 
70
- left, right = st.columns(2)
 
 
71
  with left:
 
 
 
72
  layer_one = st.number_input("Layer", value=10, min_value=1, max_value=selected_model.layers)
73
  layer = layer_one - 1
74
  with right:
@@ -76,6 +70,13 @@ with right:
76
  head = head_one - 1
77
 
78
 
 
 
 
 
 
 
 
79
 
80
  attention_pairs = get_attention_pairs(pdb_id, chain_ids=selected_chains, layer=layer, head=head, threshold=min_attn, model_type=selected_model.name)
81
 
 
18
  Model(name=ModelType.ZymCTRL, layers=36, heads=16),
19
  ]
20
 
 
 
 
 
21
  st.sidebar.markdown(
22
  """
23
  Select Protein
 
54
  label_highest = st.sidebar.checkbox("Label highest attention pairs", value=True)
55
  # TODO add avg or max attention as params
56
 
 
 
 
 
 
 
 
57
 
58
 
59
+
60
+
61
+ left, mid, right = st.columns(3)
62
  with left:
63
+ selected_model_name = st.selectbox("Select a model", [model.name.value for model in models], index=0)
64
+ selected_model = next((model for model in models if model.name.value == selected_model_name), None)
65
+ with mid:
66
  layer_one = st.number_input("Layer", value=10, min_value=1, max_value=selected_model.layers)
67
  layer = layer_one - 1
68
  with right:
 
70
  head = head_one - 1
71
 
72
 
73
+ if selected_model.name == ModelType.ZymCTRL:
74
+ try:
75
+ ec_class = structure.header["compound"]["1"]["ec"]
76
+ except KeyError:
77
+ ec_class = None
78
+ if ec_class and selected_model.name == ModelType.ZymCTRL:
79
+ ec_class = st.sidebar.text_input("Enzyme classification number fetched from PDB", ec_class)
80
 
81
  attention_pairs = get_attention_pairs(pdb_id, chain_ids=selected_chains, layer=layer, head=head, threshold=min_attn, model_type=selected_model.name)
82