Conversation
long8v left a comment

Read through the code, focusing mainly on:
- bipartite matching
- model I/O
- architecture
```
Params:
    outputs: This is a dict that contains at least these entries:
        "pred_logits": Tensor of dim [batch_size, num_entities, num_entity_classes] with the entity classification logits
        "pred_boxes": Tensor of dim [batch_size, num_entities, 4] with the predicted box coordinates
        "sub_logits": Tensor of dim [batch_size, num_triplets, num_entity_classes] with the subject classification logits
        "sub_boxes": Tensor of dim [batch_size, num_triplets, 4] with the predicted subject box coordinates
        "obj_logits": Tensor of dim [batch_size, num_triplets, num_entity_classes] with the object classification logits
        "obj_boxes": Tensor of dim [batch_size, num_triplets, 4] with the predicted object box coordinates
        "rel_logits": Tensor of dim [batch_size, num_triplets, num_predicate_classes] with the predicate classification logits
```
The model outputs are as follows.
For entities: class logits and bbox coords.
Aha! For triplets, the subject comes out as [batch_size, num_triplets, num_entity_classes] and [batch_size, num_triplets, 4], and the object comes out in exactly the same shapes.
So given the n-th triplet, the model predicts the entity classes and bbox coordinates (for subject/object) plus the relation.
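The output dict above can be sketched with dummy tensors. The sizes below (query counts, class counts) are illustrative placeholders, not necessarily RelTR's actual defaults:

```python
import torch

# Illustrative sizes only -- placeholders, not RelTR's real hyperparameters
batch_size, num_entities, num_triplets = 2, 100, 200
num_entity_classes, num_predicate_classes = 151, 51

outputs = {
    # entity (OD) branch: one row per entity query
    "pred_logits": torch.randn(batch_size, num_entities, num_entity_classes),
    "pred_boxes":  torch.rand(batch_size, num_entities, 4),
    # triplet branch: subject/object/predicate share the num_triplets dim
    "sub_logits":  torch.randn(batch_size, num_triplets, num_entity_classes),
    "sub_boxes":   torch.rand(batch_size, num_triplets, 4),
    "obj_logits":  torch.randn(batch_size, num_triplets, num_entity_classes),
    "obj_boxes":   torch.rand(batch_size, num_triplets, 4),
    "rel_logits":  torch.randn(batch_size, num_triplets, num_predicate_classes),
}
# the n-th triplet query yields a subject, an object, and a predicate together
assert outputs["sub_logits"].shape[:2] == outputs["rel_logits"].shape[:2]
```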
Looking at the params going in again:
- pred_boxes: boxes from the OD (entity) branch, hence the num_entities dimension
- pred_logits: class logits from the OD branch
- sub_boxes: subject boxes from the triplet decoder
- sub_logits: likewise, subject class logits from the triplet decoder
```
targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
    "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
              objects in the target) containing the class labels
    "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates
    "image_id": Image index
    "orig_size": Tensor of dim [2] with the height and width
    "size": Tensor of dim [2] with the height and width after transformation
    "rel_annotations": Tensor of dim [num_gt_triplet, 3] with the subject index/object index/predicate class
```
Each target contains:
- labels: class labels of the target boxes
- boxes: coordinates of the target boxes
- rel_annotations: [num_gt_triplets, 3], where the three columns are (subject index, object index, predicate class) -- the first two are indices into the target's boxes/labels
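A minimal sketch of one target dict under that description (all values made up for illustration):

```python
import torch

num_target_boxes, num_gt_triplets = 5, 3
target = {
    "labels": torch.randint(0, 150, (num_target_boxes,)),
    "boxes": torch.rand(num_target_boxes, 4),   # box coords, one row per GT object
    "image_id": torch.tensor([42]),
    "orig_size": torch.tensor([480, 640]),      # (height, width) before transform
    "size": torch.tensor([600, 800]),           # (height, width) after transform
    # each row: (subject box index, object box index, predicate class)
    "rel_annotations": torch.tensor([[0, 1, 7], [2, 0, 31], [3, 4, 7]]),
}
# the subject/object columns index into "labels"/"boxes"
assert target["rel_annotations"][:, :2].max() < num_target_boxes
```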
```
A list of size batch_size, containing tuples of (index_i, index_j) where:
    - index_i is the indices of the selected entity predictions (in order)
    - index_j is the indices of the corresponding selected entity targets (in order)
A list of size batch_size, containing tuples of (index_i, index_j) where:
    - index_i is the indices of the selected triplet predictions (in order)
    - index_j is the indices of the corresponding selected triplet targets (in order)
Subject loss weight (Type: bool) to determine if back propagation should be conducted
Object loss weight (Type: bool) to determine if back propagation should be conducted
```
So the matcher returns (selected entity prediction indices, selected entity target indices) and (selected triplet prediction indices, selected triplet target indices).
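Those (index_i, index_j) pairs come from Hungarian matching over a prediction-vs-target cost matrix. A toy example with `scipy.optimize.linear_sum_assignment`, which DETR-style matchers use for this step:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

# Toy cost matrix: rows are predictions, columns are ground-truth targets
cost = np.array([[0.9, 0.1, 0.5],
                 [0.2, 0.8, 0.4]])
index_i, index_j = linear_sum_assignment(cost)  # minimizes the total cost
# prediction 0 -> target 1 (cost 0.1), prediction 1 -> target 0 (cost 0.2)
```

The real matcher builds one such cost matrix per image (class cost + box L1 + GIoU for entities, plus the triplet version) and collects the per-image index pairs into the batch-sized list described above.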
```python
self.so_mask_conv = nn.Sequential(torch.nn.Upsample(size=(28, 28)),
                                  nn.Conv2d(2, 64, kernel_size=3, stride=2, padding=3, bias=True),
                                  nn.ReLU(inplace=True),
                                  nn.BatchNorm2d(64),
                                  nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
                                  nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1, bias=True),
                                  nn.ReLU(inplace=True),
                                  nn.BatchNorm2d(32))
self.so_mask_fc = nn.Sequential(nn.Linear(2048, 512),
                                nn.ReLU(inplace=True),
                                nn.Linear(512, 128))

# predicate classification
self.rel_class_embed = MLP(hidden_dim*2+128, hidden_dim, num_rel_classes + 1, 2)
```
When predicting the relation, the subject/object attention maps extracted earlier (the so_masks) are fed through these CNN layers.
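Tracing the spatial shapes through `so_mask_conv` shows why `so_mask_fc` starts with `nn.Linear(2048, ...)`: 28 → conv(k=3, s=2, p=3) → 16 → maxpool(k=3, s=2, p=1) → 8 → conv(s=1, p=1) → 8, and 32 channels × 8 × 8 = 2048. A standalone check (input H/W are arbitrary, since the Upsample normalizes them to 28×28):

```python
import torch
import torch.nn as nn

so_mask_conv = nn.Sequential(
    nn.Upsample(size=(28, 28)),                                   # any H,W -> 28x28
    nn.Conv2d(2, 64, kernel_size=3, stride=2, padding=3, bias=True),  # 28 -> 16
    nn.ReLU(inplace=True),
    nn.BatchNorm2d(64),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),             # 16 -> 8
    nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1, bias=True),  # 8 -> 8
    nn.ReLU(inplace=True),
    nn.BatchNorm2d(32),
)
x = torch.randn(4, 2, 25, 19)   # (N, 2 masks per triplet, H, W)
y = so_mask_conv(x)
assert y.flatten(1).shape == (4, 2048)  # matches nn.Linear(2048, 512)
```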
```python
if isinstance(samples, (list, torch.Tensor)):
    samples = nested_tensor_from_tensor_list(samples)
features, pos = self.backbone(samples)

hs, hs_t, so_masks, _ = self.transformer(self.input_proj(src), mask, self.entity_embed.weight,
                                         self.triplet_embed.weight, pos[-1], self.so_embed.weight)
```
Passes through the transformer encoder/decoder.
```python
so_masks = so_masks.detach()
so_masks = self.so_mask_conv(so_masks.view(-1, 2, src.shape[-2], src.shape[-1])).view(hs_t.shape[0], hs_t.shape[1], hs_t.shape[2], -1)
so_masks = self.so_mask_fc(so_masks)

outputs_class = self.entity_class_embed(hs)
outputs_coord = self.entity_bbox_embed(hs).sigmoid()

outputs_class_sub = self.sub_class_embed(hs_sub)
outputs_coord_sub = self.sub_bbox_embed(hs_sub).sigmoid()

outputs_class_obj = self.obj_class_embed(hs_obj)
outputs_coord_obj = self.obj_bbox_embed(hs_obj).sigmoid()

outputs_class_rel = self.rel_class_embed(torch.cat((hs_sub, hs_obj, so_masks), dim=-1))

out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1],
       'sub_logits': outputs_class_sub[-1], 'sub_boxes': outputs_coord_sub[-1],
       'obj_logits': outputs_class_obj[-1], 'obj_boxes': outputs_coord_obj[-1],
       'rel_logits': outputs_class_rel[-1]}
if self.aux_loss:
    out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord, outputs_class_sub, outputs_coord_sub,
                                            outputs_class_obj, outputs_coord_obj, outputs_class_rel)
```
```python
outputs = model(img)

# keep only predictions with 0.3+ confidence
probas = outputs['rel_logits'].softmax(-1)[0, :, :-1]
probas_sub = outputs['sub_logits'].softmax(-1)[0, :, :-1]
probas_obj = outputs['obj_logits'].softmax(-1)[0, :, :-1]
keep = torch.logical_and(probas.max(-1).values > 0.3, torch.logical_and(probas_sub.max(-1).values > 0.3,
                                                                        probas_obj.max(-1).values > 0.3))

# convert boxes from [0; 1] to image scales
sub_bboxes_scaled = rescale_bboxes(outputs['sub_boxes'][0, keep], im.size)
obj_bboxes_scaled = rescale_bboxes(outputs['obj_boxes'][0, keep], im.size)

topk = 10
keep_queries = torch.nonzero(keep, as_tuple=True)[0]
indices = torch.argsort(-probas[keep_queries].max(-1)[0] * probas_sub[keep_queries].max(-1)[0] * probas_obj[keep_queries].max(-1)[0])[:topk]
keep_queries = keep_queries[indices]
```
So inference just thresholds on the probabilities, sorts, and takes the top-k.
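That filtering step, isolated as a self-contained sketch: drop the last (background / "no relation") column, threshold each of sub/obj/rel at 0.3, then rank the surviving queries by the product of their three max probabilities. Sizes here are toy values:

```python
import torch

torch.manual_seed(0)
num_triplets = 6
rel_logits = torch.randn(1, num_triplets, 5)   # toy: 4 predicates + background
sub_logits = torch.randn(1, num_triplets, 4)   # toy: 3 entity classes + background
obj_logits = torch.randn(1, num_triplets, 4)

# softmax, drop the background column
probas     = rel_logits.softmax(-1)[0, :, :-1]
probas_sub = sub_logits.softmax(-1)[0, :, :-1]
probas_obj = obj_logits.softmax(-1)[0, :, :-1]

# a triplet survives only if all three branches are confident
keep = (probas.max(-1).values > 0.3) & (probas_sub.max(-1).values > 0.3) \
       & (probas_obj.max(-1).values > 0.3)

# rank survivors by the product of the three max probabilities, take top-k
keep_queries = torch.nonzero(keep, as_tuple=True)[0]
score = probas[keep_queries].max(-1)[0] * probas_sub[keep_queries].max(-1)[0] \
        * probas_obj[keep_queries].max(-1)[0]
top_queries = keep_queries[torch.argsort(-score)[:3]]
```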
long8v left a comment

Re-read how the rel cost is defined.
```
Params:
    outputs: This is a dict that contains at least these entries:
        "pred_logits": Tensor of dim [batch_size, num_entities, num_entity_classes] with the entity classification logits
        "pred_boxes": Tensor of dim [batch_size, num_entities, 4] with the predicted box coordinates
        "sub_logits": Tensor of dim [batch_size, num_triplets, num_entity_classes] with the subject classification logits
        "sub_boxes": Tensor of dim [batch_size, num_triplets, 4] with the predicted subject box coordinates
        "obj_logits": Tensor of dim [batch_size, num_triplets, num_entity_classes] with the object classification logits
        "obj_boxes": Tensor of dim [batch_size, num_triplets, 4] with the predicted object box coordinates
        "rel_logits": Tensor of dim [batch_size, num_triplets, num_predicate_classes] with the predicate classification logits
```
```python
bs, num_queries = outputs["pred_logits"].shape[:2]
num_queries_rel = outputs["rel_logits"].shape[1]
```
num_queries is the num_queries fed into the DETR decoder.
num_queries_rel is presumably num_sub = num_obj = num_pred, i.e. the number of triplet queries.
```python
# Concat the subject/object/predicate labels and subject/object boxes
sub_tgt_bbox = torch.cat([v['boxes'][v['rel_annotations'][:, 0]] for v in targets])
sub_tgt_ids = torch.cat([v['labels'][v['rel_annotations'][:, 0]] for v in targets])
obj_tgt_bbox = torch.cat([v['boxes'][v['rel_annotations'][:, 1]] for v in targets])
obj_tgt_ids = torch.cat([v['labels'][v['rel_annotations'][:, 1]] for v in targets])
rel_tgt_ids = torch.cat([v["rel_annotations"][:, 2] for v in targets])
```
Since RelTR outputs subject/object/predicate separately, the GT is built by indexing the targets' boxes/labels with the indices stored in rel_annotations, producing triplet-level tensors.
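A small worked example of that indexing, with made-up values: columns 0/1 of rel_annotations index into the target's labels/boxes, column 2 is the predicate class itself.

```python
import torch

target = {
    "labels": torch.tensor([3, 7, 12]),            # classes of 3 GT boxes
    "boxes": torch.rand(3, 4),
    # (subject idx, object idx, predicate class) per GT triplet
    "rel_annotations": torch.tensor([[0, 2, 5], [1, 0, 9]]),
}
targets = [target]  # batch of one image

sub_tgt_ids = torch.cat([v["labels"][v["rel_annotations"][:, 0]] for v in targets])
obj_tgt_ids = torch.cat([v["labels"][v["rel_annotations"][:, 1]] for v in targets])
rel_tgt_ids = torch.cat([v["rel_annotations"][:, 2] for v in targets])
# sub_tgt_ids == [3, 7], obj_tgt_ids == [12, 3], rel_tgt_ids == [5, 9]
```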
```python
sub_prob = outputs["sub_logits"].flatten(0, 1).sigmoid()
sub_bbox = outputs["sub_boxes"].flatten(0, 1)
obj_prob = outputs["obj_logits"].flatten(0, 1).sigmoid()
obj_bbox = outputs["obj_boxes"].flatten(0, 1)
rel_prob = outputs["rel_logits"].flatten(0, 1).sigmoid()
```
Flattens dims 0-1 of the outputs.
For the logits: [batch_size, num_entities, num_entity_classes] -> [batch_size * num_entities, num_entity_classes].
```python
# Compute the subject matching cost based on class and box.
neg_cost_class_sub = (1 - alpha) * (sub_prob ** gamma) * (-(1 - sub_prob + 1e-8).log())
pos_cost_class_sub = alpha * ((1 - sub_prob) ** gamma) * (-(sub_prob + 1e-8).log())
cost_sub_class = pos_cost_class_sub[:, sub_tgt_ids] - neg_cost_class_sub[:, sub_tgt_ids]
cost_sub_bbox = torch.cdist(sub_bbox, sub_tgt_bbox, p=1)
cost_sub_giou = -generalized_box_iou(box_cxcywh_to_xyxy(sub_bbox), box_cxcywh_to_xyxy(sub_tgt_bbox))
```
Not sure what neg_cost_class_sub is doing here.
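The pos/neg pair looks like the focal-loss-style class cost used in Deformable DETR's matcher (an interpretation, not confirmed by this repo): pos_cost is the focal loss if the query's probability for the target class is treated as a positive, neg_cost the focal loss if treated as a negative, and their difference rewards confident correct predictions while penalizing confident wrong ones. A minimal sketch with toy shapes:

```python
import torch

alpha, gamma = 0.25, 2.0                 # standard focal-loss hyperparameters
sub_prob = torch.rand(4, 6)              # [num_queries, num_classes], sigmoid probs
sub_tgt_ids = torch.tensor([2, 5])       # GT subject classes for 2 GT triplets

# focal loss if the query were a negative / a positive for each class
neg_cost = (1 - alpha) * (sub_prob ** gamma) * (-(1 - sub_prob + 1e-8).log())
pos_cost = alpha * ((1 - sub_prob) ** gamma) * (-(sub_prob + 1e-8).log())

# gather the columns for the GT classes: one row per query, one column per GT
cost_sub_class = pos_cost[:, sub_tgt_ids] - neg_cost[:, sub_tgt_ids]
assert cost_sub_class.shape == (4, 2)
```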
```python
# Compute the object matching cost only based on class.
neg_cost_class_rel = (1 - alpha) * (rel_prob ** gamma) * (-(1 - rel_prob + 1e-8).log())
pos_cost_class_rel = alpha * ((1 - rel_prob) ** gamma) * (-(rel_prob + 1e-8).log())
cost_rel_class = pos_cost_class_rel[:, rel_tgt_ids] - neg_cost_class_rel[:, rel_tgt_ids]
```
The rel cost doesn't use a separate cdist; it's computed the same way as the OD class cost -> so it's a 100 x 100 matrix.


RelTR code reading
https://github.com/yrcong/RelTR
#40