Skip to content
Open
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
data
alignnet_model.pth

.venv/
1 change: 1 addition & 0 deletions .python-version
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
3.10.14
18 changes: 13 additions & 5 deletions alignit/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class ModelConfig:
metadata={"help": "Path to save/load trained model"},
)
use_depth_input: bool = field(
default=True, metadata={"help": "Whether to use depth input for the model"}
default=False, metadata={"help": "Whether to use depth input for the model"}
)
depth_hidden_dim: int = field(
default=128, metadata={"help": "Output dimension of depth CNN"}
Expand Down Expand Up @@ -92,6 +92,10 @@ class RecordConfig:

dataset: DatasetConfig = field(default_factory=DatasetConfig)
trajectory: TrajectoryConfig = field(default_factory=TrajectoryConfig)
robot_type: str = field(
default="sim",
metadata={"help": "Robot type: 'sim' for simulation or 'real' for real xArm robot"},
)
episodes: int = field(default=10, metadata={"help": "Number of episodes to record"})
lin_tol_alignment: float = field(
default=0.015, metadata={"help": "Linear tolerance for alignment servo"}
Expand Down Expand Up @@ -147,20 +151,20 @@ class InferConfig:
metadata={"help": "Starting pose RPY angles"},
)
lin_tolerance: float = field(
default=5e-3, metadata={"help": "Linear tolerance for convergence (meters)"}
default=2e-3, metadata={"help": "Linear tolerance for convergence (meters)"}
)
ang_tolerance: float = field(
default=5, metadata={"help": "Angular tolerance for convergence (degrees)"}
default=4, metadata={"help": "Angular tolerance for convergence (degrees)"}
)
max_iterations: Optional[int] = field(
default=20,
default=5,
metadata={"help": "Maximum iterations before stopping (None = infinite)"},
)
debug_output: bool = field(
default=True, metadata={"help": "Print debug information during inference"}
)
debouncing_count: int = field(
default=20,
default=5,
metadata={"help": "Number of iterations within tolerance before stopping"},
)
rotation_matrix_multiplier: int = field(
Expand All @@ -169,6 +173,10 @@ class InferConfig:
"help": "Number of times to multiply the rotation matrix of relative action in order to speed up convergence"
},
)
translation_multiplier: float = field(
default=1.0,
metadata={"help": "Multiplier for the translation relative action to adjust convergence speed"}
)
manual_height: float = field(
default=0.08, metadata={"help": "Height above surface for manual movement"}
)
Expand Down
217 changes: 129 additions & 88 deletions alignit/infere.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import time

import torch
import transforms3d as t3d
import numpy as np
Expand All @@ -10,9 +9,14 @@
from alignit.utils.zhou import sixd_se3
from alignit.utils.tfs import print_pose, are_tfs_close
from alignit.robots.xarmsim import XarmSim
from alignit.robots.xarm import Xarm


Xarm = None
try:
from alignit.robots.xarm import Xarm
except ImportError:
pass

@draccus.wrap()
def main(cfg: InferConfig):
"""Run inference/alignment using configuration parameters."""
Expand All @@ -33,99 +37,136 @@ def main(cfg: InferConfig):
net.to(device)
net.eval()

robot = XarmSim()

start_pose = t3d.affines.compose(
[0.23, 0, 0.25], t3d.euler.euler2mat(np.pi, 0, 0), [1, 1, 1]
)
robot.servo_to_pose(start_pose, lin_tol=1e-2)
iteration = 0
iterations_within_tolerance = 0
robot = Xarm()

num_alignments = getattr(cfg, 'num_alignments', 5)
ang_tol_rad = np.deg2rad(cfg.ang_tolerance)
try:
while True:
observation = robot.get_observation()
rgb_image = observation["rgb"].astype(np.float32) / 255.0
depth_image = observation["depth"].astype(np.float32)
print(
"Min/Max depth,mean (raw):",
observation["depth"].min(),
observation["depth"].max(),
observation["depth"].mean(),
)
print(
"Min/Max depth,mean (scaled):",
depth_image.min(),
depth_image.max(),
depth_image.mean(),
)
rgb_image_tensor = (
torch.from_numpy(np.array(rgb_image))
.permute(2, 0, 1) # (H, W, C) -> (C, H, W)
.unsqueeze(0)
.to(device)
)

depth_image_tensor = (
torch.from_numpy(np.array(depth_image))
.unsqueeze(0) # Add channel dimension: (1, H, W)
.unsqueeze(0) # Add batch dimension: (1, 1, H, W)
.to(device)
)
rgb_images_batch = rgb_image_tensor.unsqueeze(1)
depth_images_batch = depth_image_tensor.unsqueeze(1)

with torch.no_grad():
relative_action = net(rgb_images_batch, depth_images=depth_images_batch)
relative_action = relative_action.squeeze(0).cpu().numpy()
relative_action = sixd_se3(relative_action)

if cfg.debug_output:
print_pose(relative_action)

relative_action[:3, :3] = np.linalg.matrix_power(
relative_action[:3, :3], cfg.rotation_matrix_multiplier
)
if are_tfs_close(
relative_action, lin_tol=cfg.lin_tolerance, ang_tol=ang_tol_rad
):
iterations_within_tolerance += 1
else:
iterations_within_tolerance = 0

print(relative_action)
target_pose = robot.pose() @ relative_action
iteration += 1
action = {
"pose": target_pose,
"gripper.pos": 1.0,
}
robot.send_action(action)
if iterations_within_tolerance >= cfg.max_iterations:
print(f"Reached maximum iterations ({cfg.max_iterations}) - stopping.")
print("Moving robot to final pose.")
alignment_results = []

MAX_TOTAL_STEPS = 1000

print(f"\nRunning {num_alignments} alignment trials...\n")

for alignment_trial in range(num_alignments):
print(f"\n{'='*60}")
print(f"Alignment Trial {alignment_trial + 1}/{num_alignments}")
print(f"{'='*60}")

start_pose = t3d.affines.compose(
[0.225, 0.0, 0.275],
t3d.euler.euler2mat(np.pi, 0.0, 0.0),
[1, 1, 1]
)
robot.servo_to_pose(start_pose, lin_tol=1e-2, ang_tol=0.1)

iteration = 0
iterations_within_tolerance = 0
trial_data = []

try:
while True:

observation = robot.get_observation()
rgb_np = observation["rgb"].astype(np.float32) / 255.0

if rgb_np.ndim == 2:
rgb_np = np.expand_dims(rgb_np, axis=-1)
if rgb_np.shape[-1] == 1:
rgb_np = np.repeat(rgb_np, 3, axis=-1)

rgb_images_batch = (
torch.from_numpy(rgb_np)
.permute(2, 0, 1)
.unsqueeze(0).unsqueeze(0)
.to(device)
)

with torch.no_grad():
raw_model_output = net(rgb_images_batch)

relative_action_np = raw_model_output.squeeze(0).cpu().numpy()
relative_action = sixd_se3(relative_action_np)

rot_mat = np.array(relative_action[:3, :3], dtype=np.float64)
try:
euler_rad = t3d.euler.mat2euler(rot_mat)
euler_deg = np.degrees(euler_rad)
except Exception:
euler_deg = [0.0, 0.0, 0.0]

trans = relative_action[:3, 3]

print(f"\n--- Model Prediction (Relative to EE) ---")
print(f"Translation (m): X: {trans[0]:.4f}, Y: {trans[1]:.4f}, Z: {trans[2]:.4f}")
print(f"Rotation (deg): R: {euler_deg[0]:.2f}, P: {euler_deg[1]:.2f}, Y: {euler_deg[2]:.2f}")
print(f"-----------------------------------------\n")

error_magnitude = np.linalg.norm(relative_action[:3, 3])

if are_tfs_close(
relative_action, lin_tol=cfg.lin_tolerance, ang_tol=ang_tol_rad
):
iterations_within_tolerance += 1
print(f"Step {iteration}: Within Tol ({iterations_within_tolerance}/{cfg.debouncing_count}) [error: {error_magnitude:.6f}]")
else:
iterations_within_tolerance = 0
print(f"Step {iteration}: Adjusting... [error: {error_magnitude:.6f}]")

scaled_action = relative_action.copy()
translation_mult = getattr(cfg, 'translation_multiplier', 1.0)
scaled_action[:3, 3] *= translation_mult
scaled_action[:3, :3] = np.linalg.matrix_power(
scaled_action[:3, :3], int(cfg.rotation_matrix_multiplier)
)

current_pose = robot.pose()
gripper_z_offset = np.array(
[
target_pose = current_pose @ scaled_action
iteration += 1

input("Hold ENTER to move robot (release to continue)...")
robot.send_action({"pose": target_pose, "gripper.pos": 1.0})


if iterations_within_tolerance >= cfg.max_iterations:
print(f"✓ Converged after {iteration} total steps.")

gripper_z_offset = np.array([
[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 1, cfg.manual_height],
[0, 0, 0, 1],
]
)
offset_pose = current_pose @ gripper_z_offset
robot.servo_to_pose(pose=offset_pose)
robot.close_gripper()
robot.gripper_off()

break

time.sleep(10.0)
except KeyboardInterrupt:
print("\nExiting...")

])
robot.servo_to_pose(pose=robot.pose() @ gripper_z_offset)

alignment_results.append({
"trial": alignment_trial + 1,
"success": True,
"iterations": iteration,
})
break

if iteration >= MAX_TOTAL_STEPS:
print(f"✗ Failed: Timeout reached ({MAX_TOTAL_STEPS} steps).")
alignment_results.append({
"trial": alignment_trial + 1,
"success": False,
"iterations": iteration,
})
break

except KeyboardInterrupt:
print("\nTrial interrupted by user.")
break

print(f"\n{'='*60}")
print(f"INFERENCE SUMMARY")
print(f"{'='*60}")
successful = sum(1 for r in alignment_results if r["success"])
print(f"Success Rate: {successful}/{len(alignment_results)} ({successful*100//max(1, len(alignment_results))}%)")
if alignment_results:
print(f"Avg Steps to Converge: {np.mean([r['iterations'] for r in alignment_results]):.1f}")

robot.disconnect()


if __name__ == "__main__":
main()
21 changes: 21 additions & 0 deletions alignit/losses.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import torch
import torch.nn as nn

class InversePredictionWeightedLoss(nn.Module):
    """MSE-style loss that up-weights position errors where the prediction is small.

    The input tensors are split into a position part (first 3 columns) and a
    rotation part (remaining columns). The squared position error is scaled by
    the reciprocal of the predicted position magnitude, so components the model
    predicts as near-zero are penalized more strongly; the rotation part uses a
    plain mean-squared error. The two terms are summed.

    Args:
        epsilon: Small constant added to the absolute prediction before taking
            the reciprocal, to avoid division by zero.
    """

    def __init__(self, epsilon: float = 1e-6):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, pred, target):
        # Split both tensors into translation (x, y, z) and rotation columns.
        position_pred, rotation_pred = pred[:, :3], pred[:, 3:]
        position_tgt, rotation_tgt = target[:, :3], target[:, 3:]

        # Inverse-magnitude weights: small predicted components get large weights.
        inv_weights = (position_pred.abs() + self.epsilon).reciprocal()
        position_loss = (inv_weights * (position_pred - position_tgt).pow(2)).mean()

        # Rotation term is an ordinary (unweighted) MSE.
        rotation_loss = (rotation_pred - rotation_tgt).pow(2).mean()

        return position_loss + rotation_loss
Loading
Loading