|
from torch.functional import Tensor |
|
from general_utils import log |
|
from collections import defaultdict |
|
import numpy as np |
|
|
|
import torch |
|
from torch.nn import functional as nnf |
|
|
|
|
|
class BaseMetric(object):
    """Base class for metrics that accumulate predictions batch by batch.

    Subclasses implement ``add`` (consume one batch) and ``value`` (aggregate
    everything added so far). ``scores`` returns ``value()`` directly if it is
    a dict, otherwise pairs each value with the corresponding metric name.
    """

    def __init__(self, metric_names, pred_range=None, gt_index=0, pred_index=0, eval_intermediate=True,
                 eval_validation=True):
        """
        Args:
            metric_names: iterable of metric names; stored as a tuple.
            pred_range: optional ``(start, stop)`` pair; ``_get_pred_gt``
                slices dimension 1 of the selected prediction with it.
            gt_index: index of the ground-truth entry used by ``_get_pred_gt``.
            pred_index: index of the prediction entry used by ``_get_pred_gt``.
            eval_intermediate: flag returned by ``eval_intermediate()``.
            eval_validation: flag returned by ``eval_validation()``.
        """
        self._names = tuple(metric_names)
        self._eval_intermediate = eval_intermediate
        self._eval_validation = eval_validation

        self._pred_range = pred_range
        self._pred_index = pred_index
        self._gt_index = gt_index

        # Not read by this base class; kept as public attributes for subclasses/callers.
        self.predictions = []
        self.ground_truths = []

    def eval_intermediate(self):
        """Return the ``eval_intermediate`` flag given at construction."""
        return self._eval_intermediate

    def eval_validation(self):
        """Return the ``eval_validation`` flag given at construction."""
        return self._eval_validation

    def names(self):
        """Return the metric names as a tuple."""
        return self._names

    def add(self, predictions, ground_truth):
        """Consume one batch; must be implemented by subclasses."""
        raise NotImplementedError

    def value(self):
        """Aggregate all added batches; must be implemented by subclasses."""
        raise NotImplementedError

    def scores(self):
        """Return the metric values, named.

        ``value()`` is evaluated exactly once. A dict result is returned as is;
        a list/tuple result is zipped with ``names()`` into ``(name, value)``
        pairs (extra entries on either side are dropped by ``zip``).
        """
        value = self.value()
        if isinstance(value, dict):
            return value
        assert isinstance(value, (list, tuple))
        # Reuse the already-computed value instead of calling value() again.
        return list(zip(self.names(), value))

    def _get_pred_gt(self, predictions, ground_truth):
        """Select the configured prediction / ground-truth entries.

        Returns ``(predictions[pred_index], ground_truth[gt_index])``; if
        ``pred_range`` is set, the prediction is sliced as
        ``pred[:, start:stop]`` (assumes a 2+-dimensional, sliceable array).
        """
        pred = predictions[self._pred_index]
        gt = ground_truth[self._gt_index]

        if self._pred_range is not None:
            pred = pred[:, self._pred_range[0]: self._pred_range[1]]

        return pred, gt
|
|
|
|
|
class FixedIntervalMetrics(BaseMetric):
    """Segmentation metrics evaluated on a fixed grid of thresholds.

    For every sample passed to ``add``, the counts tp / fp / fn / tn are
    recorded at each threshold in ``threshold_values``. ``value`` sums them
    and derives average precision, foreground IoU, binary IoU and, when class
    ids were supplied, class-averaged IoU.
    """

    def __init__(self, sigmoid=False, ignore_mask=False, resize_to=None,
                 resize_pred=None, n_values=51, custom_threshold=None):
        """
        Args:
            sigmoid: apply ``torch.sigmoid`` to predictions before thresholding.
            ignore_mask: if True, never use the optional mask in ``gt[1]``.
            resize_to: if not None, the whole prediction batch (bilinear) and
                ground-truth batch (nearest) are resized to this size.
            resize_pred: if truthy, each prediction is resized (bilinear) to
                the spatial size of its ground truth.
            n_values: number of points of ``np.linspace(0, 1, n_values)``; the
                endpoints 0 and 1 are dropped, leaving ``n_values - 2``
                thresholds. The grid must contain 0.1, 0.2, 0.3 and 0.5
                (up to float tolerance) for ``value`` to succeed.
            custom_threshold: optional extra grid threshold reported as ``*_ct``.
        """
        super().__init__(('ap', 'best_fgiou', 'best_miou', 'fgiou0.5', 'fgiou0.1', 'mean_iou_0p5',
                          'mean_iou_0p1', 'best_biniou', 'biniou_0.5', 'fgiou_thresh'))

        # The following containers are not read by this class; kept for callers that may inspect them.
        self.intersections = []
        self.unions = []

        self.sigmoid = sigmoid
        self.resize_to = resize_to
        self.resize_pred = resize_pred
        self.class_count = defaultdict(lambda: 0)
        self.per_class = defaultdict(lambda: [0, 0])
        self.ignore_mask = ignore_mask
        self.custom_threshold = custom_threshold

        self.scores_ap = []
        self.scores_iou = []
        self.gts, self.preds = [], []

        # One class id (or None) per sample added.
        self.classes = []

        # Interior points of [0, 1]; 0 and 1 themselves are excluded.
        self.threshold_values = np.linspace(0, 1, n_values)[1:-1]

        # Per-sample lists of counts, one entry per threshold.
        self.metrics = dict(tp=[], fp=[], fn=[], tn=[])

    def add(self, pred, gt):
        """Record tp/fp/fn/tn at every threshold for each sample of a batch.

        Args:
            pred: sequence whose first element is the prediction batch.
            gt: sequence whose first element is the ground-truth batch.
                ``gt[1]``, if present, non-empty and ``ignore_mask`` is False,
                is a batch of masks selecting the pixels to evaluate.
                ``gt[2]``, if present, holds one class id per sample.

        Counting uses ``g.sum()``, so ground truth is presumably binary {0, 1}
        — TODO confirm with the data pipeline.
        """
        pred_batch = pred[0].cpu()

        if self.sigmoid:
            pred_batch = torch.sigmoid(pred_batch)

        gt_batch = gt[0].cpu()
        mask_batch = gt[1] if len(gt) > 1 and not self.ignore_mask and gt[1].numel() > 0 else ([None] * len(pred_batch))
        cls_batch = gt[2] if len(gt) > 2 else [None] * len(pred_batch)

        if self.resize_to is not None:
            # bilinear interpolation presumably requires (N, C, H, W) inputs here
            gt_batch = nnf.interpolate(gt_batch, self.resize_to, mode='nearest')
            pred_batch = nnf.interpolate(pred_batch, self.resize_to, mode='bilinear', align_corners=False)

        if isinstance(cls_batch, torch.Tensor):
            cls_batch = cls_batch.cpu().numpy().tolist()

        assert len(gt_batch) == len(pred_batch) == len(cls_batch), f'{len(gt_batch)} {len(pred_batch)} {len(cls_batch)}'

        for predictions, ground_truth, mask, cls in zip(pred_batch, gt_batch, mask_batch, cls_batch):

            if self.resize_pred:
                predictions = nnf.interpolate(predictions.unsqueeze(0).float(), size=ground_truth.size()[-2:],
                                              mode='bilinear', align_corners=True)

            p = predictions.flatten()
            g = ground_truth.flatten()

            assert len(p) == len(g)

            if mask is not None:
                # The mask batch is never moved off its device above, while p and g are on the CPU.
                m = mask.flatten().bool().cpu()
                p = p[m]
                g = g[m]

            # Sorting ascending makes each threshold split the pixels into a
            # prefix (predicted negative) and a suffix (predicted positive).
            p, order = p.sort()
            g = g[order]
            n_total = len(g)

            tps, fps, fns, tns = [], [], [], []
            for thresh in self.threshold_values:
                above = torch.where(p > thresh)[0]
                n = int(above[0]) if len(above) > 0 else n_total  # number of predicted negatives

                fn = int(g[:n].sum())
                tp = int(g[n:].sum())
                fns.append(fn)
                tns.append(n - fn)
                tps.append(tp)
                fps.append(n_total - n - tp)

            self.metrics['tp'].append(tps)
            self.metrics['fp'].append(fps)
            self.metrics['fn'].append(fns)
            self.metrics['tn'].append(tns)

            self.classes.append(cls.item() if isinstance(cls, torch.Tensor) else cls)

    def _sum_over_samples(self, rows):
        """Element-wise sum of per-sample count lists -> list of ints, one per threshold."""
        counts = np.asarray(rows, dtype=np.int64).reshape(len(rows), len(self.threshold_values))
        return counts.sum(axis=0).tolist()

    def _threshold_index(self, value):
        """Index of ``value`` in ``threshold_values``, tolerant to float rounding.

        ``linspace`` computes ``i * step``, so grid values such as 0.35 are not
        always bit-identical to the literal; an exact ``list.index`` would miss them.

        Raises:
            ValueError: if no grid threshold is close to ``value``.
        """
        thresholds = np.asarray(self.threshold_values)
        if len(thresholds) > 0:
            idx = int(np.argmin(np.abs(thresholds - value)))
            if np.isclose(thresholds[idx], value):
                return idx
        raise ValueError(f'{value} is not in list')

    def value(self):
        """Aggregate all samples added so far into a dict of scores.

        Ratios add 1 to their denominator, so an empty denominator yields 0
        instead of raising ZeroDivisionError.

        Returns:
            dict with average precision ('ap'), foreground IoU ('fgiou_*'),
            binary IoU ('biniou_*', mean of foreground and background IoU),
            values at ``custom_threshold`` ('*_ct'), per-threshold score lists,
            the precision/recall points, the summed counts, and class-averaged
            IoU ('miou_*', 'mean_iou_*'). When no class ids were provided, the
            class-dependent entries and 'summed_by_cls_statistics' are None.

        Raises:
            AssertionError: (unless run with -O) if no pixel was predicted
                positive at any threshold.
            ValueError: if 0.1, 0.2, 0.3, 0.5 or ``custom_threshold`` is not on
                the threshold grid.
        """
        import time

        t_start = time.time()

        if set(self.classes) == {None}:
            all_classes = None
            log.warning('classes were not provided, cannot compute mIoU')
        else:
            all_classes = set(int(c) for c in self.classes)

        summed = {k: self._sum_over_samples(rows) for k, rows in self.metrics.items()}

        summed_by_cls = None
        if all_classes is not None:
            assert len(self.classes) == len(self.metrics['tp']) == len(self.metrics['fn'])

            rows_by_cls = {c: {k: [] for k in self.metrics} for c in all_classes}
            for i, cls in enumerate(self.classes):
                for k, rows in self.metrics.items():
                    # keys of rows_by_cls are int-converted, so look up with int() as well
                    rows_by_cls[int(cls)][k].append(rows[i])

            summed_by_cls = {k: {c: np.array(rows_by_cls[c][k]).sum(0).tolist() for c in all_classes}
                             for k in self.metrics}

        assert (np.array(summed['fp']) + np.array(summed['tp'])).sum(), 'no predictions is made'

        tp, fp, fn, tn = summed['tp'], summed['fp'], summed['fn'], summed['tn']
        thresh_idx = range(len(self.threshold_values))

        # Precision/recall only at thresholds where something was predicted positive.
        with_pred = [j for j in thresh_idx if tp[j] + fp[j] > 0]
        precisions = [tp[j] / (1 + tp[j] + fp[j]) for j in with_pred]
        recalls = [tp[j] / (1 + tp[j] + fn[j]) for j in with_pred]

        pr_points = set(zip(recalls, precisions))
        recalls, precisions = zip(*sorted(pr_points, key=lambda x: x[0]))

        try:
            from scipy.integrate import simpson
        except ImportError:  # SciPy < 1.6 only provides the old name (removed in SciPy 1.14)
            from scipy.integrate import simps as simpson

        ap = simpson(precisions, x=recalls)

        fgiou_scores = [tp[j] / (1 + tp[j] + fp[j] + fn[j]) for j in thresh_idx]

        biniou_scores = [
            0.5 * (tp[j] / (1 + tp[j] + fp[j] + fn[j])) +
            0.5 * (tn[j] / (1 + tn[j] + fn[j] + fp[j]))
            for j in thresh_idx
        ]

        index_0p5 = self._threshold_index(0.5)
        index_0p1 = self._threshold_index(0.1)
        index_0p2 = self._threshold_index(0.2)
        index_0p3 = self._threshold_index(0.3)
        index_ct = None if self.custom_threshold is None else self._threshold_index(self.custom_threshold)

        if all_classes is not None:
            tp_c, fp_c, fn_c = summed_by_cls['tp'], summed_by_cls['fp'], summed_by_cls['fn']
            mean_ious = [np.mean([tp_c[c][j] / (1 + tp_c[c][j] + fp_c[c][j] + fn_c[c][j])
                                  for c in all_classes])
                         for j in thresh_idx]

            mean_iou_dict = {
                'miou_best': max(mean_ious),
                'miou_0.5': mean_ious[index_0p5],
                'miou_0.1': mean_ious[index_0p1],
                'miou_0.2': mean_ious[index_0p2],
                'miou_0.3': mean_ious[index_0p3],
                'miou_best_t': self.threshold_values[np.argmax(mean_ious)],
                'mean_iou_ct': mean_ious[index_ct] if index_ct is not None else None,
                'mean_iou_scores': mean_ious,
            }
        else:
            # Without class ids there is no mIoU; report None instead of failing.
            mean_iou_dict = {
                'miou_best': None,
                'miou_0.5': None,
                'miou_0.1': None,
                'miou_0.2': None,
                'miou_0.3': None,
                'miou_best_t': None,
                'mean_iou_ct': None,
                'mean_iou_scores': None,
            }

        print(f'metric computation on {(len(all_classes) if all_classes is not None else "no")} classes took {time.time() - t_start:.1f}s')

        return {
            'ap': ap,

            'fgiou_best': max(fgiou_scores),
            'fgiou_0.5': fgiou_scores[index_0p5],
            'fgiou_0.1': fgiou_scores[index_0p1],
            'fgiou_0.2': fgiou_scores[index_0p2],
            'fgiou_0.3': fgiou_scores[index_0p3],
            'fgiou_best_t': self.threshold_values[np.argmax(fgiou_scores)],

            'biniou_best': max(biniou_scores),
            'biniou_0.5': biniou_scores[index_0p5],
            'biniou_0.1': biniou_scores[index_0p1],
            'biniou_0.2': biniou_scores[index_0p2],
            'biniou_0.3': biniou_scores[index_0p3],
            'biniou_best_t': self.threshold_values[np.argmax(biniou_scores)],

            'fgiou_ct': fgiou_scores[index_ct] if index_ct is not None else None,
            'biniou_ct': biniou_scores[index_ct] if index_ct is not None else None,
            'ct': self.custom_threshold,

            'fgiou_scores': fgiou_scores,
            'biniou_scores': biniou_scores,
            'precision_recall_curve': sorted(pr_points),
            'summed_statistics': summed,
            'summed_by_cls_statistics': summed_by_cls,

            **mean_iou_dict
        }
|
|
|
|
|
|
|
|
|
|
|
|