mischeiwiller commited on
Commit
be91971
1 Parent(s): b368542

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -1
app.py CHANGED
@@ -7,23 +7,41 @@ import torch
7
  import numpy as np
8
 
9
  def preprocess_image(img):
 
 
 
10
  # Convert numpy array to Tensor and ensure correct shape
11
  if isinstance(img, np.ndarray):
12
  img = K.image_to_tensor(img, keepdim=False).float() / 255.0
 
 
 
 
 
 
 
 
13
 
14
  # Ensure 3D tensor (C, H, W)
15
  if img.ndim == 2:
16
  img = img.unsqueeze(0)
 
 
 
 
17
 
18
  # Ensure 3 channel image
19
  if img.shape[0] == 1:
20
- img = img.repeat(3, 1, 1)
21
  elif img.shape[0] > 3:
22
  img = img[:3] # Take only the first 3 channels if more than 3
23
 
 
 
24
  # Add batch dimension
25
  img = img.unsqueeze(0)
26
 
 
27
  return img
28
 
29
  def inference(img_1, img_2):
 
7
  import numpy as np
8
 
9
  def preprocess_image(img):
10
+ print(f"Input image type: {type(img)}")
11
+ print(f"Input image shape: {img.shape if hasattr(img, 'shape') else 'No shape attribute'}")
12
+
13
  # Convert numpy array to Tensor and ensure correct shape
14
  if isinstance(img, np.ndarray):
15
  img = K.image_to_tensor(img, keepdim=False).float() / 255.0
16
+ elif isinstance(img, torch.Tensor):
17
+ img = img.float()
18
+ if img.max() > 1.0:
19
+ img = img / 255.0
20
+ else:
21
+ raise ValueError(f"Unsupported image type: {type(img)}")
22
+
23
+ print(f"After conversion to tensor - shape: {img.shape}")
24
 
25
  # Ensure 3D tensor (C, H, W)
26
  if img.ndim == 2:
27
  img = img.unsqueeze(0)
28
+ elif img.ndim == 3 and img.shape[0] not in [1, 3]:
29
+ img = img.permute(2, 0, 1)
30
+
31
+ print(f"After ensuring 3D - shape: {img.shape}")
32
 
33
  # Ensure 3 channel image
34
  if img.shape[0] == 1:
35
+ img = img.expand(3, -1, -1)
36
  elif img.shape[0] > 3:
37
  img = img[:3] # Take only the first 3 channels if more than 3
38
 
39
+ print(f"After ensuring 3 channels - shape: {img.shape}")
40
+
41
  # Add batch dimension
42
  img = img.unsqueeze(0)
43
 
44
+ print(f"Final tensor shape: {img.shape}")
45
  return img
46
 
47
  def inference(img_1, img_2):