-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdemo.py
More file actions
153 lines (118 loc) · 5.56 KB
/
demo.py
File metadata and controls
153 lines (118 loc) · 5.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import os
import sys
import numpy as np
import argparse
import torch
import open3d as o3d
# Add sggnet to path
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, ROOT_DIR)
from sggnet.models.hggqnet import HGGQNet
from sggnet.skeleton.keypoint_generation import KeypointGeneration
from sggnet.utils.collision_detector import ModelFreeCollisionDetector
from sggnet.graspnetAPI.grasp import GraspGroup
# Command-line interface for the demo script.
# NOTE(review): parsing happens at import time, so importing this module
# without the required flags exits with an argparse error — acceptable for
# a standalone demo script, but worth knowing before reusing these helpers.
parser = argparse.ArgumentParser()
parser.add_argument('--checkpoint_path', required=True, help='Model checkpoint path')
parser.add_argument('--point_cloud', required=True, help='Path to point cloud file (.ply, .pcd, or .npy)')
parser.add_argument('--topk', type=int, default=50, help='Number of top grasps to visualize [default: 50]')
parser.add_argument('--collision_thresh', type=float, default=0.01, help='Collision Threshold [default: 0.01]')
parser.add_argument('--voxel_size', type=float, default=0.01, help='Voxel Size [default: 0.01]')
# Global config consumed by demo() below.
cfgs = parser.parse_args()
def load_point_cloud(file_path):
    """Load a point cloud from disk.

    Args:
        file_path: Path to a ``.ply``/``.pcd`` file (read via Open3D) or a
            ``.npy`` file (read via NumPy). The extension check is
            case-insensitive, so ``.PLY``/``.NPY`` etc. are accepted too.

    Returns:
        An ``(N, D)`` numpy array of points (D >= 3).

    Raises:
        ValueError: If the file extension is not one of the supported formats.
    """
    # splitext + lower() instead of chained endswith(): one dispatch point,
    # and uppercase extensions no longer fall through to the error branch.
    ext = os.path.splitext(file_path)[1].lower()
    if ext in ('.ply', '.pcd'):
        pcd = o3d.io.read_point_cloud(file_path)
        points = np.asarray(pcd.points)
    elif ext == '.npy':
        points = np.load(file_path)
    else:
        raise ValueError(f"Unsupported file format: {file_path}")
    return points
def generate_grasp_candidates(point_cloud):
    """Generate antipodal grasp candidates via the skeleton-based method.

    Samples three height slices of the cloud (middle, second-from-bottom,
    second-from-top) and collects grasp point pairs from each.

    Args:
        point_cloud: (N, >=3) array of points, passed through to
            KeypointGeneration.

    Returns:
        List of ``((x1, y1, z), (x2, y2, z))`` grasp point pairs; empty list
        when no slice produced poses.
    """
    keypoint_generator = KeypointGeneration(debugMode=False, considerGripperValidity=True)
    # Split point cloud by height
    selected_points, cross_heights = keypoint_generator.split_point_cloud_by_height(point_cloud)
    grasp_points = []
    n_slices = len(selected_points)
    for height_index in [n_slices // 2, 1, -2]:
        # Bug fix: the original guard (height_index >= len) never caught the
        # hard-coded negative index -2, so a cloud with a single height slice
        # raised IndexError on cross_heights[-2]. Check both bounds.
        if not -n_slices <= height_index < n_slices:
            continue
        final_grasp_poses = keypoint_generator.generate_grasp_poses(height_index, point_cloud)
        if final_grasp_poses:
            height = cross_heights[height_index]
            grasp_points += [((x1, y1, height), (x2, y2, height)) for ((x1, y1), (x2, y2)) in final_grasp_poses]
    return grasp_points
def convert_to_grasp_configs(grasp_points):
    """Convert grasp point pairs to flat grasp configuration vectors.

    Each config is a 14-vector: center (3) + row-major rotation matrix (9)
    + fixed depth 0.05 + gripper width (distance between the two points).

    Args:
        grasp_points: Iterable of ``(p1, p2)`` pairs, each point a 3-sequence.

    Returns:
        ``(K, 14)`` numpy array of grasp configurations (empty array for
        empty input).
    """
    grasp_configs = []
    for p1, p2 in grasp_points:
        p1 = np.asarray(p1, dtype=float)
        p2 = np.asarray(p2, dtype=float)
        center = (p1 + p2) / 2
        # Width computed once; the original also built a normalized direction
        # vector that was never used (and emitted a divide-by-zero warning for
        # coincident points) — dead code removed.
        width = np.linalg.norm(p2 - p1)
        # Identity rotation placeholder (simplified).
        # In practice, you'd compute proper rotation from normals.
        rotation_matrix = np.eye(3)
        grasp_config = np.concatenate([center, rotation_matrix.flatten(), [0.05], [width]])
        grasp_configs.append(grasp_config)
    return np.array(grasp_configs)
def demo():
    """Run end-to-end grasp inference on a point cloud file.

    Pipeline: load cloud -> skeleton-based candidate generation -> HGGQNet
    scoring -> optional collision filtering -> NMS + top-k selection ->
    Open3D visualization. All configuration comes from the module-level
    ``cfgs`` namespace.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    # Load point cloud
    print(f"Loading point cloud from {cfgs.point_cloud}...")
    point_cloud = load_point_cloud(cfgs.point_cloud)
    print(f"Loaded {len(point_cloud)} points")
    # Generate grasp candidates using skeleton method
    print("Generating grasp candidates using skeleton method...")
    grasp_points = generate_grasp_candidates(point_cloud)
    print(f"Generated {len(grasp_points)} grasp candidates")
    if len(grasp_points) == 0:
        print("No grasp candidates found!")
        return
    # Convert to grasp configurations
    grasp_configs = convert_to_grasp_configs(grasp_points)
    # Load model
    print(f"Loading model from {cfgs.checkpoint_path}...")
    model = HGGQNet()
    model.to(device)
    model.eval()
    # Bug fix: map_location lets a CUDA-saved checkpoint load on CPU-only
    # hosts; without it torch.load raises when CUDA is unavailable.
    checkpoint = torch.load(cfgs.checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Model loaded (epoch: {checkpoint.get('epoch', 'unknown')})")
    # Create graph from point cloud (torch_geometric imported lazily so the
    # helpers above stay usable without it installed)
    import torch_geometric
    pc_tensor = torch.tensor(point_cloud[:, :3], dtype=torch.float32).to(device)
    # Bug fix: np.empty fed *uninitialized memory* to the model as RGB
    # features, making scores nondeterministic run-to-run; zeros are a
    # deterministic "no color" placeholder.
    rgb_tensor = torch.zeros((len(point_cloud), 3), dtype=torch.float32, device=device)
    graph = torch_geometric.data.batch.Batch.from_data_list([
        torch_geometric.data.Data(xyz=pc_tensor, rgb=rgb_tensor)
    ], follow_batch=['xyz'])
    # Score grasps
    print("Scoring grasps...")
    grasp_configs_tensor = torch.tensor(grasp_configs, dtype=torch.float32).to(device)
    # All candidates belong to the single graph in the batch -> index 0.
    graph_indices = torch.zeros(len(grasp_configs), dtype=torch.int64).to(device)
    with torch.no_grad():
        scores, _, _ = model(grasp_configs_tensor, graph, graph_indices)
    scores_np = scores.cpu().numpy()
    # Create GraspGroup
    gg = GraspGroup(grasp_configs)
    gg.scores = scores_np
    # Collision detection
    if cfgs.collision_thresh > 0:
        print("Performing collision detection...")
        mfcdetector = ModelFreeCollisionDetector(point_cloud, voxel_size=cfgs.voxel_size)
        collision_mask = mfcdetector.detect(gg, approach_dist=0.05, collision_thresh=cfgs.collision_thresh)
        gg = gg[~collision_mask]
        print(f"After collision detection: {len(gg)} grasps remaining")
    # Sort and select top-k
    gg.nms()
    gg.sort_by_score()
    gg = gg[:cfgs.topk]
    print(f"Visualizing top {len(gg)} grasps...")
    # Visualize
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(point_cloud[:, :3])
    grippers = gg.to_open3d_geometry_list()
    o3d.visualization.draw_geometries([pcd, *grippers])
# Script entry point: run the demo only when executed directly,
# not when imported as a module.
if __name__ == '__main__':
    demo()