Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions demo/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,21 @@


class VisualizationDemo(object):
def __init__(self, cfg, instance_mode=ColorMode.IMAGE, parallel=False):
def __init__(self, cfg, instance_mode=ColorMode.IMAGE, parallel=False, confidence_threshold=0.5):
"""
Args:
cfg (CfgNode):
instance_mode (ColorMode):
parallel (bool): whether to run the model in different processes from visualization.
Useful since the visualization logic can be slow.
confidence_threshold (float): minimum score for instance predictions to be shown
"""
self.metadata = MetadataCatalog.get(
cfg.DATASETS.TEST[0] if len(cfg.DATASETS.TEST) else "__unused"
)
self.cpu_device = torch.device("cpu")
self.instance_mode = instance_mode
self.confidence_threshold = confidence_threshold

self.parallel = parallel
if parallel:
Expand Down Expand Up @@ -60,9 +62,16 @@ def run_on_image(self, image):
vis_output = visualizer.draw_sem_seg(
predictions["sem_seg"].argmax(dim=0).to(self.cpu_device)
)
if "instances" in predictions:
instances = predictions["instances"].to(self.cpu_device)
vis_output = visualizer.draw_instance_predictions(predictions=instances)
if "instances" in predictions:
instances = predictions["instances"].to(self.cpu_device)

# Filter instances by confidence threshold
if self.confidence_threshold > 0:
scores = instances.scores
keep = scores > self.confidence_threshold
instances = instances[keep]

vis_output = visualizer.draw_instance_predictions(predictions=instances)

return predictions, vis_output

Expand Down Expand Up @@ -94,6 +103,13 @@ def process_predictions(frame, predictions):
)
elif "instances" in predictions:
predictions = predictions["instances"].to(self.cpu_device)

# Filter instances by confidence threshold
if self.confidence_threshold > 0:
scores = predictions.scores
keep = scores > self.confidence_threshold
predictions = predictions[keep]

vis_frame = video_visualizer.draw_instance_predictions(frame, predictions)
elif "sem_seg" in predictions:
vis_frame = video_visualizer.draw_sem_seg(
Expand Down