增加了PINN训练网络
This commit is contained in:
parent
a31a31c967
commit
6789a7011a
21
README.md
21
README.md
@ -31,7 +31,7 @@
|
||||
|
||||
## 📂 项目结构
|
||||
|
||||
```text
|
||||
```bash
|
||||
src/soft_arm_sim/
|
||||
├── config/ # [配置] 机器人物理参数与仿真频率配置
|
||||
│ └── config.yaml
|
||||
@ -51,3 +51,22 @@ src/soft_arm_sim/
|
||||
│ └── *.pth / *.csv # 训练好的模型权重与数据集
|
||||
└── urdf/ # [描述] 机器人的基础 TF 树描述
|
||||
```
|
||||
|
||||
## 运行方式
|
||||
|
||||
```bash
|
||||
# 编译哦
|
||||
colcon build --symlink-install
|
||||
|
||||
# 终端1(运行rviz)
|
||||
source install/setup.bash
|
||||
ros2 launch soft_arm_sim simulate.launch.py
|
||||
|
||||
# 终端2 (运行demo)
|
||||
source install/setup.bash
|
||||
ros2 run soft_arm_sim sine_controller
|
||||
```
|
||||
|
||||
## 开发日志Logs
|
||||
|
||||
[Log.2026.2.5](log/log0205.md)
|
||||
|
||||
37
log/log0205.md
Normal file
37
log/log0205.md
Normal file
@ -0,0 +1,37 @@
|
||||
# Log
|
||||
|
||||
## 2026.2.5
|
||||
|
||||
- pcc_kinematics.py 中的数学公式用 PyTorch 张量重写,确保梯度可以从 Cartesian 坐标反向传播回 Joint 角度;
|
||||
`src/soft_arm_sim/soft_arm_sim/deeplearning/differentiable_pcc.py`
|
||||
|
||||
## 解决 deeplearning 文件夹混乱的问题,按照功能模块进行拆分
|
||||
|
||||
```bash
|
||||
src/soft_arm_sim/soft_arm_sim/deeplearning/
|
||||
├── __init__.py
|
||||
├── checkpoints/ # [存储] 存放训练好的 .pth 模型权重
|
||||
│ └── .gitkeep
|
||||
├── data/ # [数据] 数据生成器 (PINN 主要是随机采样,但也可能需要验证集)
|
||||
│ ├── __init__.py
|
||||
│ └── workspace_sampler.py # 用于在机械臂工作空间内随机采样目标点
|
||||
├── layers/ # [物理层] 这里存放可微的运动学公式 (PINN的核心)
|
||||
│ ├── __init__.py
|
||||
│ └── differentiable_pcc.py # == 核心:PyTorch 版的 PCC 正运动学 ==
|
||||
├── models/ # [网络] 神经网络架构定义
|
||||
│ ├── __init__.py
|
||||
│ └── mlp_network.py # 简单的全连接网络
|
||||
├── training/ # [训练] 训练脚本
|
||||
│ ├── __init__.py
|
||||
│ └── train_pinn_basic.py # PINN 训练主脚本
|
||||
└── inference/ # [推理] ROS 接口相关
|
||||
├── __init__.py
|
||||
└── inference_node.py # 加载模型并控制机器人的 ROS 节点
|
||||
```
|
||||
|
||||
```bash
|
||||
# 运行PINN训练文件
|
||||
cd src/soft_arm_sim/soft_arm_sim/deeplearning/training
|
||||
export PYTHONPATH=$PYTHONPATH:$(pwd)/src/soft_arm_sim
|
||||
python3 -m soft_arm_sim.deeplearning.training.train_pinn_basic
|
||||
```
|
||||
@ -48,6 +48,9 @@ setup(
|
||||
|
||||
# 推理控制节点
|
||||
'pinn_controller = soft_arm_sim.deeplearning.inference_node:main',
|
||||
|
||||
# 注意路径变化:soft_arm_sim.deeplearning.inference.inference_node
|
||||
'inference_node = soft_arm_sim.deeplearning.inference.inference_node:main',
|
||||
],
|
||||
},
|
||||
)
|
||||
Binary file not shown.
|
Can't render this file because it is too large.
|
|
Can't render this file because it is too large.
|
|
Can't render this file because it is too large.
|
Binary file not shown.
Binary file not shown.
@ -0,0 +1,170 @@
|
||||
#!/usr/bin/env python3
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
import torch
|
||||
import numpy as np
|
||||
import time
|
||||
import math
|
||||
import sys
|
||||
import os
|
||||
|
||||
# 路径修复
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
package_root = os.path.abspath(os.path.join(current_dir, "../../../"))
|
||||
if package_root not in sys.path:
|
||||
sys.path.append(package_root)
|
||||
|
||||
from soft_arm_sim.deeplearning.models.mlp_network import SimpleIKNet
|
||||
from soft_arm_sim.deeplearning.layers.differentiable_pcc import DifferentiablePCC
|
||||
|
||||
# 引入 MarkerArray
|
||||
from visualization_msgs.msg import Marker, MarkerArray
|
||||
from geometry_msgs.msg import Point
|
||||
|
||||
class InferenceNode(Node):
|
||||
def __init__(self):
|
||||
super().__init__('inference_node')
|
||||
self.get_logger().info("Initializing PINN Inference Node (MarkerArray Mode)...")
|
||||
|
||||
# 1. 自动寻找模型路径
|
||||
node_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
self.model_path = os.path.join(node_dir, "../checkpoints/pinn_basic.pth")
|
||||
|
||||
self.device = torch.device("cpu")
|
||||
|
||||
# 2. 初始化网络和FK层
|
||||
self.model = SimpleIKNet(input_dim=3, output_dim=6).to(self.device)
|
||||
self.fk_layer = DifferentiablePCC().to(self.device) # 用于计算显示的坐标
|
||||
|
||||
self.load_model()
|
||||
|
||||
# 3. 发布者
|
||||
# 改用 MarkerArray 发布整个机械臂形态
|
||||
self.viz_pub = self.create_publisher(MarkerArray, 'soft_arm_viz', 10)
|
||||
self.marker_pub = self.create_publisher(Marker, 'target_marker', 10)
|
||||
|
||||
self.timer = self.create_timer(0.02, self.control_loop)
|
||||
self.start_time = time.time()
|
||||
|
||||
def load_model(self):
|
||||
try:
|
||||
self.model.load_state_dict(torch.load(self.model_path, map_location=self.device))
|
||||
self.model.eval()
|
||||
self.get_logger().info(f"Loaded model from: {self.model_path}")
|
||||
except Exception as e:
|
||||
self.get_logger().error(f"Error loading model: {e}")
|
||||
|
||||
def control_loop(self):
|
||||
# 1. 生成画圆轨迹
|
||||
t = time.time() - self.start_time
|
||||
radius = 0.15
|
||||
center_z = 0.50
|
||||
freq = 1.0
|
||||
|
||||
target_pos = [
|
||||
radius * math.cos(freq * t),
|
||||
radius * math.sin(freq * t),
|
||||
center_z
|
||||
]
|
||||
|
||||
# 2. 推理
|
||||
input_tensor = torch.tensor([target_pos], dtype=torch.float32).to(self.device)
|
||||
with torch.no_grad():
|
||||
pred_joints = self.model(input_tensor) # 得到角度
|
||||
|
||||
# 3. 通过 FK 层计算骨架坐标 (用于可视化)
|
||||
# pred_joints: (1, 6)
|
||||
# points shape: (1, N_Points, 3)
|
||||
_, points_tensor = self.fk_layer(pred_joints)
|
||||
|
||||
# 转为 numpy 用于 ROS 消息
|
||||
# points_np shape: (N_Points, 3)
|
||||
points_np = points_tensor.cpu().numpy()[0]
|
||||
|
||||
# 4. 发布可视化
|
||||
self.publish_arm_viz(points_np)
|
||||
self.publish_target(target_pos)
|
||||
|
||||
def publish_arm_viz(self, points):
|
||||
"""
|
||||
根据骨架点生成 MarkerArray
|
||||
points: List of [x, y, z]
|
||||
"""
|
||||
array_msg = MarkerArray()
|
||||
timestamp = self.get_clock().now().to_msg()
|
||||
|
||||
# --- A. 骨架连线 (LINE_STRIP) ---
|
||||
line_marker = Marker()
|
||||
line_marker.header.frame_id = "base_link" # 确保 Rviz 里 Fixed Frame 也是这个
|
||||
line_marker.header.stamp = timestamp
|
||||
line_marker.ns = "backbone"
|
||||
line_marker.id = 0
|
||||
line_marker.type = Marker.LINE_STRIP
|
||||
line_marker.action = Marker.ADD
|
||||
line_marker.scale.x = 0.02 # 线条粗细
|
||||
line_marker.color.a = 1.0
|
||||
line_marker.color.r = 0.0
|
||||
line_marker.color.g = 0.8
|
||||
line_marker.color.b = 1.0 # 青色
|
||||
|
||||
for p in points:
|
||||
pt = Point()
|
||||
pt.x, pt.y, pt.z = float(p[0]), float(p[1]), float(p[2])
|
||||
line_marker.points.append(pt)
|
||||
|
||||
array_msg.markers.append(line_marker)
|
||||
|
||||
# --- B. 关节点/圆盘 (SPHERE) ---
|
||||
for i, p in enumerate(points):
|
||||
disk_marker = Marker()
|
||||
disk_marker.header.frame_id = "base_link"
|
||||
disk_marker.header.stamp = timestamp
|
||||
disk_marker.ns = "disks"
|
||||
disk_marker.id = i + 1 # ID 不能重复
|
||||
disk_marker.type = Marker.SPHERE
|
||||
disk_marker.action = Marker.ADD
|
||||
|
||||
disk_marker.pose.position.x = float(p[0])
|
||||
disk_marker.pose.position.y = float(p[1])
|
||||
disk_marker.pose.position.z = float(p[2])
|
||||
|
||||
# 只有中间的点代表磁盘,原点和末端可以小一点
|
||||
disk_marker.scale.x = 0.08
|
||||
disk_marker.scale.y = 0.08
|
||||
disk_marker.scale.z = 0.02 # 扁的圆盘
|
||||
|
||||
disk_marker.color.a = 1.0
|
||||
disk_marker.color.r = 1.0
|
||||
disk_marker.color.g = 0.6
|
||||
disk_marker.color.b = 0.0 # 橙色
|
||||
|
||||
array_msg.markers.append(disk_marker)
|
||||
|
||||
self.viz_pub.publish(array_msg)
|
||||
|
||||
def publish_target(self, pos):
|
||||
marker = Marker()
|
||||
marker.header.frame_id = "base_link"
|
||||
marker.header.stamp = self.get_clock().now().to_msg()
|
||||
marker.ns = "target"
|
||||
marker.id = 999
|
||||
marker.type = Marker.SPHERE
|
||||
marker.action = Marker.ADD
|
||||
marker.pose.position.x, marker.pose.position.y, marker.pose.position.z = pos
|
||||
marker.scale.x = marker.scale.y = marker.scale.z = 0.05
|
||||
marker.color.a = 1.0; marker.color.r = 1.0 # 红球
|
||||
self.marker_pub.publish(marker)
|
||||
|
||||
def main(args=None):
|
||||
rclpy.init(args=args)
|
||||
node = InferenceNode()
|
||||
try:
|
||||
rclpy.spin(node)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
node.destroy_node()
|
||||
rclpy.shutdown()
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@ -1,174 +0,0 @@
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from geometry_msgs.msg import PoseStamped, Point
|
||||
from std_msgs.msg import Float64MultiArray
|
||||
from visualization_msgs.msg import Marker
|
||||
import torch
|
||||
import numpy as np
|
||||
import time
|
||||
import threading
|
||||
import os
|
||||
import sys
|
||||
|
||||
# --- 路径修复 ---
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
if current_dir not in sys.path:
|
||||
sys.path.append(current_dir)
|
||||
project_root = os.path.abspath(os.path.join(current_dir, "../../"))
|
||||
if project_root not in sys.path:
|
||||
sys.path.append(project_root)
|
||||
|
||||
from pinn_model import PINN_IK
|
||||
from soft_arm_sim.model.pcc_kinematics import SoftArmKinematics
|
||||
|
||||
class PinnController(Node):
|
||||
def __init__(self):
|
||||
super().__init__('pinn_controller')
|
||||
|
||||
# 1. 加载模型
|
||||
self.device = torch.device("cpu")
|
||||
# 注意:这里 output_dim 必须是 9 (三角编码版)
|
||||
self.model = PINN_IK(output_dim=9).to(self.device)
|
||||
self.stats = None
|
||||
|
||||
# 确保这里的文件名和你 train.py 最后保存的文件名一致
|
||||
# 上一次如果是 best_model_smooth.pth 就用这个
|
||||
model_path = os.path.join(current_dir, "best_model_smooth.pth")
|
||||
|
||||
if os.path.exists(model_path):
|
||||
# =======================================================
|
||||
# 🔧 修复点:添加 weights_only=False
|
||||
# =======================================================
|
||||
checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
|
||||
|
||||
self.model.load_state_dict(checkpoint['model_state'])
|
||||
self.stats = checkpoint['stats']
|
||||
self.model.eval()
|
||||
self.get_logger().info(f"✅ 模型加载成功: {model_path}")
|
||||
else:
|
||||
self.get_logger().error(f"❌ 找不到模型: {model_path}")
|
||||
|
||||
# 验证用的 FK 工具
|
||||
self.fk_solver = SoftArmKinematics(3, 0.24, 3, 0.033)
|
||||
|
||||
# 2. 通信接口
|
||||
self.cmd_pub = self.create_publisher(Float64MultiArray, 'soft_arm/command', 10)
|
||||
self.traj_pub = self.create_publisher(Marker, 'planned_trajectory', 10)
|
||||
self.target_marker_pub = self.create_publisher(Marker, 'target_marker', 10)
|
||||
|
||||
self.create_subscription(PoseStamped, '/goal_pose', self.goal_callback, 10)
|
||||
|
||||
self.current_pos = np.array([0.0, 0.0, 0.72])
|
||||
self.get_logger().info("等待目标指令... (请在 Rviz 中点击)")
|
||||
|
||||
def decode_output(self, raw_output):
|
||||
"""
|
||||
解码: [t1, cos1, sin1, ...] -> [t1, phi1, ...]
|
||||
"""
|
||||
# 反归一化
|
||||
y_mean = self.stats['y_mean']
|
||||
y_std = self.stats['y_std']
|
||||
real_vals = raw_output * y_std + y_mean
|
||||
|
||||
final_angles = []
|
||||
for i in range(3):
|
||||
theta = real_vals[i*3 + 0]
|
||||
cos_v = real_vals[i*3 + 1]
|
||||
sin_v = real_vals[i*3 + 2]
|
||||
|
||||
# 三角解码 phi = atan2(sin, cos)
|
||||
phi = np.arctan2(sin_v, cos_v)
|
||||
|
||||
final_angles.append(theta)
|
||||
final_angles.append(phi)
|
||||
|
||||
return np.array(final_angles)
|
||||
|
||||
def goal_callback(self, msg):
|
||||
# 强制 Z=0.5 进行测试,防止点到地面
|
||||
target_pos = np.array([msg.pose.position.x, msg.pose.position.y, 0.5])
|
||||
|
||||
dist = np.linalg.norm(target_pos)
|
||||
if dist > 0.75:
|
||||
self.get_logger().warn(f"目标太远 ({dist:.2f}m),忽略")
|
||||
return
|
||||
|
||||
self.get_logger().info(f"目标: {target_pos},开始规划")
|
||||
self.visualize_target(target_pos)
|
||||
threading.Thread(target=self.execute_move, args=(target_pos,)).start()
|
||||
|
||||
def execute_move(self, target_pos):
|
||||
# 插值规划
|
||||
steps = 50
|
||||
trajectory = []
|
||||
for t in np.linspace(0, 1, steps):
|
||||
pt = self.current_pos + t * (target_pos - self.current_pos)
|
||||
trajectory.append(pt)
|
||||
self.visualize_trajectory(trajectory)
|
||||
|
||||
final_angles_decoded = None
|
||||
|
||||
for pt in trajectory:
|
||||
with torch.no_grad():
|
||||
# 预处理输入 (归一化)
|
||||
X_mean = self.stats['X_mean']; X_std = self.stats['X_std']
|
||||
norm_input = (pt - X_mean) / X_std
|
||||
tensor_input = torch.FloatTensor(norm_input).to(self.device)
|
||||
|
||||
# 推理
|
||||
norm_output = self.model(tensor_input).numpy()
|
||||
|
||||
# 解码输出
|
||||
real_angles = self.decode_output(norm_output)
|
||||
final_angles_decoded = real_angles
|
||||
|
||||
# 发送指令
|
||||
msg = Float64MultiArray()
|
||||
msg.data = real_angles.tolist()
|
||||
self.cmd_pub.publish(msg)
|
||||
time.sleep(0.04) # 控制速度
|
||||
|
||||
self.current_pos = target_pos
|
||||
|
||||
# 验证精度
|
||||
if final_angles_decoded is not None:
|
||||
self.check_accuracy(target_pos, final_angles_decoded)
|
||||
|
||||
def check_accuracy(self, target_pos, angles):
|
||||
q_input = []
|
||||
for i in range(3):
|
||||
q_input.append((angles[i*2], angles[i*2+1], 0.24))
|
||||
transforms, _ = self.fk_solver.forward(q_input)
|
||||
actual_pos = transforms[-1][:3, 3]
|
||||
error = np.linalg.norm(target_pos - actual_pos)
|
||||
|
||||
print(f"\n>>> 精度验证 <<<\n目标: {target_pos}\n实际: {actual_pos}\n误差: {error:.4f} m (Loss较低时此误差应很小)\n")
|
||||
|
||||
def visualize_target(self, pos):
|
||||
marker = Marker()
|
||||
marker.header.frame_id = "base_link"
|
||||
marker.type = Marker.SPHERE; marker.action = Marker.ADD
|
||||
marker.scale.x = 0.05; marker.scale.y = 0.05; marker.scale.z = 0.05
|
||||
marker.color.a = 1.0; marker.color.r = 1.0; marker.color.g = 0.0; marker.color.b = 0.0
|
||||
marker.pose.position.x = float(pos[0]); marker.pose.position.y = float(pos[1]); marker.pose.position.z = float(pos[2])
|
||||
self.target_marker_pub.publish(marker)
|
||||
|
||||
def visualize_trajectory(self, points):
|
||||
marker = Marker()
|
||||
marker.header.frame_id = "base_link"
|
||||
marker.type = Marker.SPHERE_LIST; marker.action = Marker.ADD
|
||||
marker.scale.x = 0.01; marker.scale.y = 0.01; marker.scale.z = 0.01
|
||||
marker.color.a = 1.0; marker.color.r = 1.0; marker.color.g = 1.0
|
||||
for p in points:
|
||||
pt = Point(); pt.x, pt.y, pt.z = float(p[0]), float(p[1]), float(p[2]); marker.points.append(pt)
|
||||
self.traj_pub.publish(marker)
|
||||
|
||||
def main(args=None):
|
||||
rclpy.init(args=args)
|
||||
node = PinnController()
|
||||
rclpy.spin(node)
|
||||
node.destroy_node()
|
||||
rclpy.shutdown()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Binary file not shown.
Binary file not shown.
@ -0,0 +1,84 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class DifferentiablePCC(nn.Module):
|
||||
def __init__(self, segment_length=0.24, num_segments=3):
|
||||
super().__init__()
|
||||
self.register_buffer('L', torch.tensor(segment_length))
|
||||
self.n_seg = num_segments
|
||||
|
||||
def forward(self, q):
|
||||
"""
|
||||
q: (Batch, 6)
|
||||
Returns:
|
||||
tip_pos: (Batch, 3) 末端位置
|
||||
backbone_points: (Batch, N_Points, 3) 骨架上所有关键点的坐标
|
||||
"""
|
||||
batch_size = q.shape[0]
|
||||
device = q.device
|
||||
|
||||
# T 初始化为单位矩阵
|
||||
T = torch.eye(4, device=device).unsqueeze(0).repeat(batch_size, 1, 1)
|
||||
|
||||
# === 关键修改:我们需要记录原点 (0,0,0) ===
|
||||
points = []
|
||||
# 先把基座原点放进去 (从 T 中提取位置)
|
||||
points.append(T[:, :3, 3])
|
||||
|
||||
for i in range(self.n_seg):
|
||||
theta = q[:, 2*i]
|
||||
phi = q[:, 2*i+1]
|
||||
|
||||
T_i = self.pcc_transformation(theta, phi)
|
||||
T = torch.bmm(T, T_i)
|
||||
|
||||
# 记录每一段的末端位置
|
||||
points.append(T[:, :3, 3])
|
||||
|
||||
# 堆叠成 (Batch, N+1, 3)
|
||||
backbone_points = torch.stack(points, dim=1)
|
||||
|
||||
# 返回末端 和 所有点
|
||||
return T[:, :3, 3], backbone_points
|
||||
|
||||
def pcc_transformation(self, theta, phi):
|
||||
# ... (保持之前的代码不变,防止 Singular 处理逻辑丢失) ...
|
||||
# 如果之前的代码丢了,请告诉我,我再发一遍完整的
|
||||
# 这里简略,假设你保留了之前的 pcc_transformation 实现
|
||||
mask_straight = torch.abs(theta) < 1e-6
|
||||
theta_safe = torch.where(mask_straight, torch.ones_like(theta) * 1e-6, theta)
|
||||
|
||||
c_phi = torch.cos(phi)
|
||||
s_phi = torch.sin(phi)
|
||||
c_theta = torch.cos(theta_safe)
|
||||
s_theta = torch.sin(theta_safe)
|
||||
|
||||
dx = self.L * (1 - c_theta) / theta_safe
|
||||
dz = self.L * s_theta / theta_safe
|
||||
|
||||
dx = torch.where(mask_straight, torch.zeros_like(dx), dx)
|
||||
dz = torch.where(mask_straight, self.L, dz)
|
||||
|
||||
batch_size = theta.shape[0]
|
||||
T = torch.zeros((batch_size, 4, 4), device=theta.device)
|
||||
|
||||
# Rotation
|
||||
T[:, 0, 0] = c_phi**2 * (c_theta - 1) + 1
|
||||
T[:, 0, 1] = s_phi * c_phi * (c_theta - 1)
|
||||
T[:, 0, 2] = c_phi * s_theta
|
||||
|
||||
T[:, 1, 0] = s_phi * c_phi * (c_theta - 1)
|
||||
T[:, 1, 1] = s_phi**2 * (c_theta - 1) + 1
|
||||
T[:, 1, 2] = s_phi * s_theta
|
||||
|
||||
T[:, 2, 0] = -c_phi * s_theta
|
||||
T[:, 2, 1] = -s_phi * s_theta
|
||||
T[:, 2, 2] = c_theta
|
||||
|
||||
# Translation
|
||||
T[:, 0, 3] = dx * c_phi
|
||||
T[:, 1, 3] = dx * s_phi
|
||||
T[:, 2, 3] = dz
|
||||
T[:, 3, 3] = 1.0
|
||||
|
||||
return T
|
||||
Binary file not shown.
Binary file not shown.
21
soft_arm_sim/soft_arm_sim/deeplearning/models/mlp_network.py
Normal file
21
soft_arm_sim/soft_arm_sim/deeplearning/models/mlp_network.py
Normal file
@ -0,0 +1,21 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class SimpleIKNet(nn.Module):
|
||||
def __init__(self, input_dim=3, output_dim=6):
|
||||
super().__init__()
|
||||
# Input: Target [x, y, z]
|
||||
# Output: [theta1, phi1, theta2, phi2, theta3, phi3]
|
||||
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(input_dim, 128),
|
||||
nn.Tanh(), # Tanh 在物理回归问题中通常比 ReLU 更平滑
|
||||
nn.Linear(128, 256),
|
||||
nn.Tanh(),
|
||||
nn.Linear(256, 128),
|
||||
nn.Tanh(),
|
||||
nn.Linear(128, output_dim)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
Binary file not shown.
Binary file not shown.
@ -0,0 +1,84 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import os
|
||||
from ..models.mlp_network import SimpleIKNet
|
||||
from ..layers.differentiable_pcc import DifferentiablePCC
|
||||
|
||||
def train():
|
||||
# 1. 配置
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
print(f"Running on {device}")
|
||||
|
||||
# 初始化模型
|
||||
ik_model = SimpleIKNet().to(device)
|
||||
# 初始化物理层 (不需要优化参数,只用于计算)
|
||||
fk_layer = DifferentiablePCC().to(device)
|
||||
|
||||
optimizer = torch.optim.Adam(ik_model.parameters(), lr=1e-4)
|
||||
|
||||
# 2. 训练循环
|
||||
epochs = 20000
|
||||
batch_size = 64
|
||||
|
||||
# 定义关节限制 (Constraint)
|
||||
# 假设 theta 范围 [0, pi/2], phi 范围 [-pi, pi]
|
||||
theta_limit = 3.14159 / 2
|
||||
|
||||
for epoch in range(epochs):
|
||||
# --- A. 随机采样目标点 (Unsupervised) ---
|
||||
# 我们需要在机械臂可达的工作空间内采样
|
||||
# 简单起见,我们在一个大致的半球内采样,或者随机生成关节角求正解来得到目标点
|
||||
# 方法:Teacher Forcing (用随机关节角生成的真实 Pos 作为 Target)
|
||||
# 这样能保证 Target 一定是可达的,训练效率最高
|
||||
|
||||
with torch.no_grad():
|
||||
# 随机生成合法的关节角作为"答案"
|
||||
random_thetas = torch.rand(batch_size, 3).to(device) * theta_limit
|
||||
random_phis = (torch.rand(batch_size, 3).to(device) * 2 - 1) * torch.pi
|
||||
|
||||
# 拼成 (Batch, 6)
|
||||
gt_joints = torch.zeros(batch_size, 6).to(device)
|
||||
gt_joints[:, 0::2] = random_thetas
|
||||
gt_joints[:, 1::2] = random_phis
|
||||
|
||||
# 计算对应的 Target Position
|
||||
target_pos = fk_layer(gt_joints)
|
||||
|
||||
# --- B. 前向传播 ---
|
||||
# 输入:Target Position
|
||||
# 输出:Predicted Joints
|
||||
pred_joints = ik_model(target_pos)
|
||||
|
||||
# --- C. 物理一致性 Loss (PINN Loss) ---
|
||||
# 将预测的关节角代入物理层,看算出的位置在哪里
|
||||
est_pos = fk_layer(pred_joints)
|
||||
|
||||
# 1. 位置误差 (Task Loss)
|
||||
loss_pos = torch.mean((est_pos - target_pos)**2)
|
||||
|
||||
# 2. 关节限制 Loss (Physical Constraint)
|
||||
# 惩罚 theta < 0 或 theta > limit
|
||||
# 提取 theta: 0, 2, 4 列
|
||||
pred_thetas = pred_joints[:, 0::2]
|
||||
loss_limits = torch.mean(torch.relu(-pred_thetas)**2) + \
|
||||
torch.mean(torch.relu(pred_thetas - theta_limit)**2)
|
||||
|
||||
# 总 Loss
|
||||
total_loss = loss_pos + 0.1 * loss_limits
|
||||
|
||||
# --- D. 反向传播 ---
|
||||
optimizer.zero_grad()
|
||||
total_loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
if epoch % 1000 == 0:
|
||||
print(f"Epoch {epoch} | Pos Loss: {loss_pos.item():.6f} | Limit Loss: {loss_limits.item():.6f}")
|
||||
|
||||
# 3. 保存
|
||||
save_path = "../checkpoints/pinn_basic.pth"
|
||||
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
||||
torch.save(ik_model.state_dict(), save_path)
|
||||
print(f"Model saved to {save_path}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
train()
|
||||
Loading…
Reference in New Issue
Block a user