当前位置: 首页 > news >正文

ACT代码详解

一、用record_sim_episodes.py生成数据

import time
import os
import numpy as np
import argparse
import matplotlib.pyplot as plt
import h5py

from constants import PUPPET_GRIPPER_POSITION_NORMALIZE_FN, SIM_TASK_CONFIGS
from ee_sim_env import make_ee_sim_env
from sim_env import make_sim_env, BOX_POSE
from scripted_policy import PickAndTransferPolicy, InsertionPolicy

import IPython
e = IPython.embed


def main(args):
    """Generate scripted demonstration episodes in simulation and save them as HDF5.

    Every episode is produced in two stages:

    1. Roll out the scripted policy (defined in end-effector space) in
       ``ee_sim_env`` and collect the joint trajectory, replacing the gripper
       joint positions with the normalized commanded gripper positions.
    2. Replay that joint trajectory as the action sequence in ``sim_env`` and
       record all observations.

    One ``episode_<idx>.hdf5`` file is written to ``dataset_dir`` per episode.

    Args:
        args: dict with the keys ``task_name`` (str), ``dataset_dir`` (str),
            ``num_episodes`` (int or None) and ``onscreen_render`` (bool).
            When ``num_episodes`` is None the task's configured
            ``num_episodes`` is used.

    Raises:
        KeyError: if ``task_name`` is not present in ``SIM_TASK_CONFIGS``.
        NotImplementedError: if no scripted policy exists for ``task_name``.
    """
    task_name = args['task_name']
    dataset_dir = args['dataset_dir']
    num_episodes = args['num_episodes']
    onscreen_render = args['onscreen_render']
    inject_noise = False
    render_cam_name = 'angle'

    # exist_ok=True already tolerates an existing directory, so no isdir() check.
    os.makedirs(dataset_dir, exist_ok=True)

    task_config = SIM_TASK_CONFIGS[task_name]
    episode_len = task_config['episode_len']
    camera_names = task_config['camera_names']
    if num_episodes is None:
        # --num_episodes is optional on the command line; without a fallback
        # range(None) below would raise a TypeError.
        num_episodes = task_config['num_episodes']

    if task_name == 'sim_transfer_cube_scripted':
        policy_cls = PickAndTransferPolicy
    elif task_name == 'sim_insertion_scripted':
        policy_cls = InsertionPolicy
    else:
        raise NotImplementedError

    success = []
    for episode_idx in range(num_episodes):
        print(f'{episode_idx=}')
        print('Rollout out EE space scripted policy')

        # ---- Stage 1: roll out the scripted policy in the EE-space env ----
        env = make_ee_sim_env(task_name)
        ts = env.reset()
        episode = [ts]
        policy = policy_cls(inject_noise)
        # setup plotting
        if onscreen_render:
            ax = plt.subplot()
            plt_img = ax.imshow(ts.observation['images'][render_cam_name])
            plt.ion()
        for step in range(episode_len):
            action = policy(ts)
            ts = env.step(action)
            episode.append(ts)
            if onscreen_render:
                plt_img.set_data(ts.observation['images'][render_cam_name])
                plt.pause(0.002)
        plt.close()

        # The reward of the initial time step (index 0) is excluded.
        episode_return = np.sum([ts.reward for ts in episode[1:]])
        episode_max_reward = np.max([ts.reward for ts in episode[1:]])
        if episode_max_reward == env.task.max_reward:
            print(f"{episode_idx=} Successful, {episode_return=}")
        else:
            print(f"{episode_idx=} Failed")

        joint_traj = [ts.observation['qpos'] for ts in episode]
        # replace gripper pose with gripper control (written in place into
        # each qpos entry: index 6 for the left arm, 6+7 for the right arm)
        gripper_ctrl_traj = [ts.observation['gripper_ctrl'] for ts in episode]
        for joint, ctrl in zip(joint_traj, gripper_ctrl_traj):
            left_ctrl = PUPPET_GRIPPER_POSITION_NORMALIZE_FN(ctrl[0])
            right_ctrl = PUPPET_GRIPPER_POSITION_NORMALIZE_FN(ctrl[2])
            joint[6] = left_ctrl
            joint[6+7] = right_ctrl

        subtask_info = episode[0].observation['env_state'].copy()  # box pose at step 0

        # clear unused variables
        del env
        del episode
        del policy

        # ---- Stage 2: replay the joint trajectory in the joint-space env ----
        print('Replaying joint commands')
        env = make_sim_env(task_name)
        # make sure the sim_env has the same object configurations as ee_sim_env
        BOX_POSE[0] = subtask_info
        ts = env.reset()

        episode_replay = [ts]
        # setup plotting
        if onscreen_render:
            ax = plt.subplot()
            plt_img = ax.imshow(ts.observation['images'][render_cam_name])
            plt.ion()
        # note: replaying every recorded qpos increases the episode length by 1
        for action in joint_traj:
            ts = env.step(action)
            episode_replay.append(ts)
            if onscreen_render:
                plt_img.set_data(ts.observation['images'][render_cam_name])
                plt.pause(0.02)

        episode_return = np.sum([ts.reward for ts in episode_replay[1:]])
        episode_max_reward = np.max([ts.reward for ts in episode_replay[1:]])
        if episode_max_reward == env.task.max_reward:
            success.append(1)
            print(f"{episode_idx=} Successful, {episode_return=}")
        else:
            success.append(0)
            print(f"{episode_idx=} Failed")

        plt.close()

        # Layout recorded for each timestep:
        #
        # observations
        # - images
        #     - each_cam_name     (480, 640, 3) 'uint8'
        # - qpos                  (14,)         'float64'
        # - qvel                  (14,)         'float64'
        #
        # action                  (14,)         'float64'
        data_dict = {
            '/observations/qpos': [],
            '/observations/qvel': [],
            '/action': [],
        }
        for cam_name in camera_names:
            data_dict[f'/observations/images/{cam_name}'] = []

        # because of the replaying, there will be eps_len + 1 actions and
        # eps_len + 2 timesteps; truncate here to be consistent
        joint_traj = joint_traj[:-1]
        episode_replay = episode_replay[:-1]

        # len(joint_traj) i.e. actions: max_timesteps
        # len(episode_replay) i.e. time steps: max_timesteps + 1
        max_timesteps = len(joint_traj)
        # zip stops at the shorter list (the actions), pairing action t with
        # the observation at time step t -- same pairs as draining both lists
        # from the front, without the O(n^2) cost of list.pop(0).
        for action, ts in zip(joint_traj, episode_replay):
            data_dict['/observations/qpos'].append(ts.observation['qpos'])
            data_dict['/observations/qvel'].append(ts.observation['qvel'])
            data_dict['/action'].append(action)
            for cam_name in camera_names:
                data_dict[f'/observations/images/{cam_name}'].append(ts.observation['images'][cam_name])

        # HDF5
        t0 = time.time()
        dataset_path = os.path.join(dataset_dir, f'episode_{episode_idx}')
        with h5py.File(dataset_path + '.hdf5', 'w', rdcc_nbytes=1024 ** 2 * 2) as root:
            root.attrs['sim'] = True
            obs = root.create_group('observations')
            image = obs.create_group('images')
            for cam_name in camera_names:
                # one chunk per frame
                image.create_dataset(cam_name, (max_timesteps, 480, 640, 3), dtype='uint8',
                                     chunks=(1, 480, 640, 3))
            obs.create_dataset('qpos', (max_timesteps, 14))
            obs.create_dataset('qvel', (max_timesteps, 14))
            root.create_dataset('action', (max_timesteps, 14))

            # data_dict keys are absolute HDF5 paths of the datasets created above
            for name, array in data_dict.items():
                root[name][...] = array
        print(f'Saving: {time.time() - t0:.1f} secs\n')

    print(f'Saved to {dataset_dir}')
    print(f'Success: {np.sum(success)} / {len(success)}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--task_name', action='store', type=str, help='task_name', required=True)
    parser.add_argument('--dataset_dir', action='store', type=str, help='dataset saving dir', required=True)
    parser.add_argument('--num_episodes', action='store', type=int, help='num_episodes', required=False)
    parser.add_argument('--onscreen_render', action='store_true')

    main(vars(parser.parse_args()))

1.导入与全局配置

import time import os import numpy as np import argparse import matplotlib.pyplot as plt import h5py from constants import PUPPET_GRIPPER_POSITION_NORMALIZE_FN, SIM_TASK_CONFIGS from ee_sim_env import make_ee_sim_env from sim_env import make_sim_env, BOX_POSE from scripted_policy import PickAndTransferPolicy, InsertionPolicy
  • constants:包含任务配置SIM_TASK_CONFIGS(每个任务的回合长度、相机名称列表等)和夹持器归一化函数PUPPET_GRIPPER_POSITION_NORMALIZE_FN(用于将夹持器控制信号映射到关节位置)

  • ee_sim_env:提供末端执行器空间(end-effector space)仿真环境,其动作空间是末端执行器的位姿(如 delta 位移、旋转、夹爪开合),不涉及逆运动学

  • sim_env:提供完整关节空间(joint space)仿真环境,动作直接是关节角度

  • BOX_POSE:全局变量,用于指定盒子(物体)的初始位姿;在第二阶段重放前,需要把它设置为第一阶段第 0 步(回合开始时)记录的物体状态 env_state,以保证两个环境中物体的初始摆放一致

  • scripted_policy:包含两个脚本策略 PickAndTransferPolicy 和 InsertionPolicy,它们接收当前时间步 ts(含观测)并输出末端执行器空间的动作

2.参数解析与主入口

# Command-line entry point (simplified excerpt): parse the arguments and hand
# them to main() as a plain dict.
parser = argparse.ArgumentParser()
parser.add_argument('--task_name', required=True)
parser.add_argument('--dataset_dir', required=True)
# Optional and without a default, so the parsed value is None when omitted.
parser.add_argument('--num_episodes', type=int)
# Boolean flag: False unless the option is given.
parser.add_argument('--onscreen_render', action='store_true')
# vars() converts the argparse Namespace into a dict keyed by option name.
main(vars(parser.parse_args()))
  • task_name:必须指定,例如 sim_transfer_cube_scripted 或 sim_insertion_scripted,决定了使用哪个脚本策略和任务配置

  • dataset_dir:保存HDF5文件的目录。

  • num_episodes:生成的演示回合数。argparse 中该参数为可选(required=False)且未设置 default,不传时解析结果为 None;建议运行时始终显式指定,例如 --num_episodes 50。

  • onscreen_render:是否在屏幕上实时渲染仿真画面(用于调试或观察)。

3.main函数详细流程

# Annotated walkthrough of the body of main(args); the `def main(args):`
# header is omitted in this excerpt.
task_name = args['task_name']
dataset_dir = args['dataset_dir']
# NOTE(review): --num_episodes is optional with no default in the parser at
# the bottom, so this may be None; range(None) below would then raise.
num_episodes = args['num_episodes']
onscreen_render = args['onscreen_render']
inject_noise = False  # the scripted policy is created without noise injection
render_cam_name = 'angle'  # camera whose image is shown when rendering on screen
if not os.path.isdir(dataset_dir):
    os.makedirs(dataset_dir, exist_ok=True)
# Read the task-specific settings from SIM_TASK_CONFIGS: episode length and camera names
episode_len = SIM_TASK_CONFIGS[task_name]['episode_len']
camera_names = SIM_TASK_CONFIGS[task_name]['camera_names']
# Select the scripted policy class for the task; it is instantiated once per episode below
if task_name == 'sim_transfer_cube_scripted':
    policy_cls = PickAndTransferPolicy
elif task_name == 'sim_insertion_scripted':
    policy_cls = InsertionPolicy
else:
    raise NotImplementedError
success = []
for episode_idx in range(num_episodes):
    print(f'{episode_idx=}')
    print('Rollout out EE space scripted policy')
    # Create the end-effector-space environment
    env = make_ee_sim_env(task_name)
    ts = env.reset()  # reset to obtain the initial time step ts
    episode = [ts]  # every time step of the rollout (observation, reward, ...) is stored here
    policy = policy_cls(inject_noise)  # build the policy object (no noise injected)
    # setup plotting
    if onscreen_render:
        ax = plt.subplot()
        plt_img = ax.imshow(ts.observation['images'][render_cam_name])
        plt.ion()
    for step in range(episode_len):
        action = policy(ts)  # the policy maps the current time step ts to an action
        ts = env.step(action)
        episode.append(ts)
        if onscreen_render:
            plt_img.set_data(ts.observation['images'][render_cam_name])
            plt.pause(0.002)
    plt.close()
    # Total and maximum reward of the rollout (the initial time step is excluded)
    episode_return = np.sum([ts.reward for ts in episode[1:]])
    episode_max_reward = np.max([ts.reward for ts in episode[1:]])
    if episode_max_reward == env.task.max_reward:
        print(f"{episode_idx=} Successful, {episode_return=}")
    else:
        print(f"{episode_idx=} Failed")
    # Extract the joint positions and gripper control signals of every time step
    joint_traj = [ts.observation['qpos'] for ts in episode]
    gripper_ctrl_traj = [ts.observation['gripper_ctrl'] for ts in episode]
    # Overwrite the gripper entries of each qpos (index 6 and 6+7) in place
    # with the normalized commanded gripper values
    for joint, ctrl in zip(joint_traj, gripper_ctrl_traj):
        left_ctrl = PUPPET_GRIPPER_POSITION_NORMALIZE_FN(ctrl[0])
        right_ctrl = PUPPET_GRIPPER_POSITION_NORMALIZE_FN(ctrl[2])
        joint[6] = left_ctrl
        joint[6+7] = right_ctrl
    subtask_info = episode[0].observation['env_state'].copy() # box pose at step 0
    # Release the stage-one objects
    del env
    del episode
    del policy
    # setup the environment
    print('Replaying joint commands')
    env = make_sim_env(task_name)  # create the joint-space environment sim_env
    BOX_POSE[0] = subtask_info # make sure the sim_env has the same object configurations as ee_sim_env
    ts = env.reset()  # reset the environment to obtain the initial time step ts
    episode_replay = [ts]
    # setup plotting
    if onscreen_render:
        ax = plt.subplot()
        plt_img = ax.imshow(ts.observation['images'][render_cam_name])
        plt.ion()
    for t in range(len(joint_traj)): # note: this will increase episode length by 1
        action = joint_traj[t]
        ts = env.step(action)
        episode_replay.append(ts)
        if onscreen_render:
            plt_img.set_data(ts.observation['images'][render_cam_name])
            plt.pause(0.02)
    episode_return = np.sum([ts.reward for ts in episode_replay[1:]])
    episode_max_reward = np.max([ts.reward for ts in episode_replay[1:]])
    if episode_max_reward == env.task.max_reward:
        success.append(1)
        print(f"{episode_idx=} Successful, {episode_return=}")
    else:
        success.append(0)
        print(f"{episode_idx=} Failed")
    plt.close()
    # The bare string below is a no-op expression describing the per-timestep layout.
    """ For each timestep: observations - images - each_cam_name (480, 640, 3) 'uint8' - qpos (14,) 'float64' - qvel (14,) 'float64' action (14,) 'float64' """
    # Lists keyed by the absolute HDF5 path of the dataset they will fill
    data_dict = {
        '/observations/qpos': [],
        '/observations/qvel': [],
        '/action': [],
    }
    for cam_name in camera_names:
        data_dict[f'/observations/images/{cam_name}'] = []
    # because the replaying, there will be eps_len + 1 actions and eps_len + 2 timesteps
    # truncate here to be consistent
    joint_traj = joint_traj[:-1]
    episode_replay = episode_replay[:-1]
    # len(joint_traj) i.e. actions: max_timesteps
    # len(episode_replay) i.e. time steps: max_timesteps + 1
    max_timesteps = len(joint_traj)
    # Drain both lists from the front, pairing action t with the observation
    # at time step t; the loop ends when the (shorter) action list is empty.
    while joint_traj:
        action = joint_traj.pop(0)
        ts = episode_replay.pop(0)
        data_dict['/observations/qpos'].append(ts.observation['qpos'])
        data_dict['/observations/qvel'].append(ts.observation['qvel'])
        data_dict['/action'].append(action)
        for cam_name in camera_names:
            data_dict[f'/observations/images/{cam_name}'].append(ts.observation['images'][cam_name])
    # HDF5
    t0 = time.time()
    dataset_path = os.path.join(dataset_dir, f'episode_{episode_idx}')
    with h5py.File(dataset_path + '.hdf5', 'w', rdcc_nbytes=1024 ** 2 * 2) as root:
        root.attrs['sim'] = True
        obs = root.create_group('observations')
        image = obs.create_group('images')
        for cam_name in camera_names:
            # one chunk per frame
            _ = image.create_dataset(cam_name, (max_timesteps, 480, 640, 3), dtype='uint8', chunks=(1, 480, 640, 3), )
            # compression='gzip',compression_opts=2,)
            # compression=32001, compression_opts=(0, 0, 0, 0, 9, 1, 1), shuffle=False)
        qpos = obs.create_dataset('qpos', (max_timesteps, 14))
        qvel = obs.create_dataset('qvel', (max_timesteps, 14))
        action = root.create_dataset('action', (max_timesteps, 14))
        # Fill each dataset created above from the matching list
        for name, array in data_dict.items():
            root[name][...] = array
    print(f'Saving: {time.time() - t0:.1f} secs\n')
print(f'Saved to {dataset_dir}')
print(f'Success: {np.sum(success)} / {len(success)}')
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--task_name', action='store', type=str, help='task_name', required=True)
    parser.add_argument('--dataset_dir', action='store', type=str, help='dataset saving dir', required=True)
    # optional and without a default: parsed as None when omitted
    parser.add_argument('--num_episodes', action='store', type=int, help='num_episodes', required=False)
    parser.add_argument('--onscreen_render', action='store_true')
    main(vars(parser.parse_args()))

二、训练模型

imitate_episodes.py

import torch
import numpy as np
import os
import pickle
import argparse
import matplotlib.pyplot as plt
from copy import deepcopy
from tqdm import tqdm
from einops import rearrange

from constants import DT
from constants import PUPPET_GRIPPER_JOINT_OPEN
from utils import load_data # data functions
from utils import sample_box_pose, sample_insertion_pose # robot functions
from utils import compute_dict_mean, set_seed, detach_dict # helper functions
from policy import ACTPolicy, CNNMLPPolicy
from visualize_episodes import save_videos

from sim_env import BOX_POSE

import IPython
e = IPython.embed


def main(args):
    """Train a behavior-cloning policy, or evaluate a trained one.

    Builds the shared ``config`` dict from the command-line arguments and the
    task configuration. With ``args['eval']`` set, ``policy_best.ckpt`` from
    ``ckpt_dir`` is evaluated with ``eval_bc`` and the function returns.
    Otherwise the dataset is loaded, its statistics are pickled to
    ``dataset_stats.pkl``, the policy is trained with ``train_bc`` and the best
    checkpoint is saved as ``policy_best.ckpt``.

    Args:
        args: dict of parsed command-line arguments (see the parser at the
            bottom of this file for the available keys).

    Raises:
        NotImplementedError: if ``policy_class`` is neither 'ACT' nor 'CNNMLP'.
    """
    set_seed(1)
    # command line parameters
    is_eval = args['eval']  # evaluate only, do not train
    ckpt_dir = args['ckpt_dir']  # directory checkpoints are saved to / loaded from
    policy_class = args['policy_class']  # 'ACT' or 'CNNMLP'
    onscreen_render = args['onscreen_render']  # render on screen during evaluation
    task_name = args['task_name']
    batch_size_train = args['batch_size']
    batch_size_val = args['batch_size']  # training and validation share one batch size
    num_epochs = args['num_epochs']

    # get task parameters
    is_sim = task_name[:4] == 'sim_'
    if is_sim:
        from constants import SIM_TASK_CONFIGS
        task_config = SIM_TASK_CONFIGS[task_name]
    else:
        from aloha_scripts.constants import TASK_CONFIGS
        task_config = TASK_CONFIGS[task_name]
    dataset_dir = task_config['dataset_dir']
    num_episodes = task_config['num_episodes']
    episode_len = task_config['episode_len']
    camera_names = task_config['camera_names']

    # fixed parameters
    state_dim = 14
    lr_backbone = 1e-5
    backbone = 'resnet18'
    if policy_class == 'ACT':
        enc_layers = 4
        dec_layers = 7
        nheads = 8
        policy_config = {'lr': args['lr'],
                         'num_queries': args['chunk_size'],
                         'kl_weight': args['kl_weight'],
                         'hidden_dim': args['hidden_dim'],
                         'dim_feedforward': args['dim_feedforward'],
                         'lr_backbone': lr_backbone,
                         'backbone': backbone,
                         'enc_layers': enc_layers,
                         'dec_layers': dec_layers,
                         'nheads': nheads,
                         'camera_names': camera_names,
                         }
    elif policy_class == 'CNNMLP':
        policy_config = {'lr': args['lr'], 'lr_backbone': lr_backbone, 'backbone' : backbone, 'num_queries': 1,
                         'camera_names': camera_names,}
    else:
        raise NotImplementedError

    config = {
        'num_epochs': num_epochs,
        'ckpt_dir': ckpt_dir,
        'episode_len': episode_len,
        'state_dim': state_dim,
        'lr': args['lr'],
        'policy_class': policy_class,
        'onscreen_render': onscreen_render,
        'policy_config': policy_config,
        'task_name': task_name,
        'seed': args['seed'],
        'temporal_agg': args['temporal_agg'],
        'camera_names': camera_names,
        'real_robot': not is_sim
    }

    if is_eval:
        ckpt_names = ['policy_best.ckpt']
        results = []
        for ckpt_name in ckpt_names:
            success_rate, avg_return = eval_bc(config, ckpt_name, save_episode=True)
            results.append([ckpt_name, success_rate, avg_return])

        for ckpt_name, success_rate, avg_return in results:
            print(f'{ckpt_name}: {success_rate=} {avg_return=}')
        print()
        # Plain return instead of exit(): exit() is a helper injected by the
        # `site` module and is not guaranteed to exist in every interpreter.
        return

    train_dataloader, val_dataloader, stats, _ = load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_size_val)

    # save dataset stats
    os.makedirs(ckpt_dir, exist_ok=True)
    stats_path = os.path.join(ckpt_dir, 'dataset_stats.pkl')
    with open(stats_path, 'wb') as f:
        pickle.dump(stats, f)

    best_ckpt_info = train_bc(train_dataloader, val_dataloader, config)
    best_epoch, min_val_loss, best_state_dict = best_ckpt_info

    # save best checkpoint
    ckpt_path = os.path.join(ckpt_dir, 'policy_best.ckpt')
    torch.save(best_state_dict, ckpt_path)
    print(f'Best ckpt, val loss {min_val_loss:.6f} @ epoch{best_epoch}')


def make_policy(policy_class, policy_config):
    """Instantiate the policy named by ``policy_class`` ('ACT' or 'CNNMLP').

    Raises:
        NotImplementedError: for any other ``policy_class``.
    """
    if policy_class == 'ACT':
        policy = ACTPolicy(policy_config)
    elif policy_class == 'CNNMLP':
        policy = CNNMLPPolicy(policy_config)
    else:
        raise NotImplementedError
    return policy


def make_optimizer(policy_class, policy):
    """Return the optimizer that ``policy`` configures for itself.

    Raises:
        NotImplementedError: if ``policy_class`` is neither 'ACT' nor 'CNNMLP'.
    """
    # Both supported policy classes expose the same configure_optimizers() hook.
    if policy_class not in ('ACT', 'CNNMLP'):
        raise NotImplementedError
    return policy.configure_optimizers()


def get_image(ts, camera_names):
    """Stack the camera images of time step ``ts`` into one CUDA float tensor.

    Each image is rearranged from 'h w c' to 'c h w', the cameras are stacked
    along a new leading axis in ``camera_names`` order, pixel values are
    divided by 255.0 and a batch axis of size 1 is prepended.
    """
    curr_images = []
    for cam_name in camera_names:
        curr_image = rearrange(ts.observation['images'][cam_name], 'h w c -> c h w')
        curr_images.append(curr_image)
    curr_image = np.stack(curr_images, axis=0)
    curr_image = torch.from_numpy(curr_image / 255.0).float().cuda().unsqueeze(0)
    return curr_image


def eval_bc(config, ckpt_name, save_episode=True):
    """Roll out a trained policy for 50 episodes and report its performance.

    Loads ``ckpt_name`` and ``dataset_stats.pkl`` from ``config['ckpt_dir']``,
    runs the policy in the simulated or real environment, prints a summary and
    writes it to ``result_<ckpt stem>.txt`` in the checkpoint directory.

    Args:
        config: the config dict built in ``main``.
        ckpt_name: file name of the checkpoint inside ``ckpt_dir``.
        save_episode: when True, save a video of every rollout to ``ckpt_dir``.

    Returns:
        Tuple ``(success_rate, avg_return)``: the fraction of rollouts whose
        highest reward equals the environment's maximum reward, and the mean
        episode return.
    """
    set_seed(1000)
    ckpt_dir = config['ckpt_dir']
    state_dim = config['state_dim']
    real_robot = config['real_robot']
    policy_class = config['policy_class']
    onscreen_render = config['onscreen_render']
    policy_config = config['policy_config']
    camera_names = config['camera_names']
    max_timesteps = config['episode_len']
    task_name = config['task_name']
    temporal_agg = config['temporal_agg']
    onscreen_cam = 'angle'

    # load policy and stats
    ckpt_path = os.path.join(ckpt_dir, ckpt_name)
    policy = make_policy(policy_class, policy_config)
    loading_status = policy.load_state_dict(torch.load(ckpt_path))
    print(loading_status)
    policy.cuda()
    policy.eval()
    print(f'Loaded: {ckpt_path}')
    stats_path = os.path.join(ckpt_dir, 'dataset_stats.pkl')
    # NOTE: pickle.load executes arbitrary code from the file; only load
    # statistics files written by a trusted training run.
    with open(stats_path, 'rb') as f:
        stats = pickle.load(f)

    # Normalize qpos going in, de-normalize actions coming out, using the dataset stats.
    pre_process = lambda s_qpos: (s_qpos - stats['qpos_mean']) / stats['qpos_std']
    post_process = lambda a: a * stats['action_std'] + stats['action_mean']

    # load environment
    if real_robot:
        from aloha_scripts.robot_utils import move_grippers # requires aloha
        from aloha_scripts.real_env import make_real_env # requires aloha
        env = make_real_env(init_node=True)
        env_max_reward = 0
    else:
        from sim_env import make_sim_env
        env = make_sim_env(task_name)
        env_max_reward = env.task.max_reward

    # Without temporal aggregation the policy is queried once per chunk of
    # num_queries steps; with it, the policy is queried at every step.
    query_frequency = policy_config['num_queries']
    if temporal_agg:
        query_frequency = 1
        num_queries = policy_config['num_queries']

    max_timesteps = int(max_timesteps * 1) # may increase for real-world tasks

    num_rollouts = 50
    episode_returns = []
    highest_rewards = []
    for rollout_id in range(num_rollouts):
        ### set task
        if 'sim_transfer_cube' in task_name:
            BOX_POSE[0] = sample_box_pose() # used in sim reset
        elif 'sim_insertion' in task_name:
            BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset

        ts = env.reset()

        ### onscreen render
        if onscreen_render:
            ax = plt.subplot()
            plt_img = ax.imshow(env._physics.render(height=480, width=640, camera_id=onscreen_cam))
            plt.ion()

        ### evaluation loop
        if temporal_agg:
            # Row t holds the chunk predicted at step t, written at columns t..t+num_queries-1.
            all_time_actions = torch.zeros([max_timesteps, max_timesteps+num_queries, state_dim]).cuda()

        qpos_history = torch.zeros((1, max_timesteps, state_dim)).cuda()
        image_list = [] # for visualization
        qpos_list = []
        target_qpos_list = []
        rewards = []
        with torch.inference_mode():
            for t in range(max_timesteps):
                ### update onscreen render and wait for DT
                if onscreen_render:
                    image = env._physics.render(height=480, width=640, camera_id=onscreen_cam)
                    plt_img.set_data(image)
                    plt.pause(DT)

                ### process previous timestep to get qpos and image_list
                obs = ts.observation
                if 'images' in obs:
                    image_list.append(obs['images'])
                else:
                    image_list.append({'main': obs['image']})
                qpos_numpy = np.array(obs['qpos'])
                qpos = pre_process(qpos_numpy)
                qpos = torch.from_numpy(qpos).float().cuda().unsqueeze(0)
                qpos_history[:, t] = qpos
                curr_image = get_image(ts, camera_names)

                ### query policy
                if config['policy_class'] == "ACT":
                    if t % query_frequency == 0:
                        all_actions = policy(qpos, curr_image)
                    if temporal_agg:
                        all_time_actions[[t], t:t+num_queries] = all_actions
                        # Column t collects every prediction made so far for step t;
                        # all-zero rows are treated as "not yet populated".
                        actions_for_curr_step = all_time_actions[:, t]
                        actions_populated = torch.all(actions_for_curr_step != 0, axis=1)
                        actions_for_curr_step = actions_for_curr_step[actions_populated]
                        # Exponential weights exp(-k * i), normalized to sum to 1,
                        # so earlier rows (older predictions) weigh more.
                        k = 0.01
                        exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step)))
                        exp_weights = exp_weights / exp_weights.sum()
                        exp_weights = torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1)
                        raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True)
                    else:
                        raw_action = all_actions[:, t % query_frequency]
                elif config['policy_class'] == "CNNMLP":
                    raw_action = policy(qpos, curr_image)
                else:
                    raise NotImplementedError

                ### post-process actions
                raw_action = raw_action.squeeze(0).cpu().numpy()
                action = post_process(raw_action)
                target_qpos = action

                ### step the environment
                ts = env.step(target_qpos)

                ### for visualization
                qpos_list.append(qpos_numpy)
                target_qpos_list.append(target_qpos)

                rewards.append(ts.reward)

            plt.close()
        if real_robot:
            move_grippers([env.puppet_bot_left, env.puppet_bot_right], [PUPPET_GRIPPER_JOINT_OPEN] * 2, move_time=0.5)  # open

        rewards = np.array(rewards)
        # None rewards are excluded from the return
        episode_return = np.sum(rewards[rewards!=None])
        episode_returns.append(episode_return)
        episode_highest_reward = np.max(rewards)
        highest_rewards.append(episode_highest_reward)
        print(f'Rollout {rollout_id}\n{episode_return=}, {episode_highest_reward=}, {env_max_reward=}, Success: {episode_highest_reward==env_max_reward}')

        if save_episode:
            save_videos(image_list, DT, video_path=os.path.join(ckpt_dir, f'video{rollout_id}.mp4'))

    success_rate = np.mean(np.array(highest_rewards) == env_max_reward)
    avg_return = np.mean(episode_returns)
    summary_str = f'\nSuccess rate: {success_rate}\nAverage return: {avg_return}\n\n'
    for r in range(env_max_reward+1):
        more_or_equal_r = (np.array(highest_rewards) >= r).sum()
        more_or_equal_r_rate = more_or_equal_r / num_rollouts
        summary_str += f'Reward >= {r}: {more_or_equal_r}/{num_rollouts} = {more_or_equal_r_rate*100}%\n'

    print(summary_str)

    # save success rate to txt
    result_file_name = 'result_' + ckpt_name.split('.')[0] + '.txt'
    with open(os.path.join(ckpt_dir, result_file_name), 'w') as f:
        f.write(summary_str)
        f.write(repr(episode_returns))
        f.write('\n\n')
        f.write(repr(highest_rewards))

    return success_rate, avg_return


def forward_pass(data, policy):
    """Move one batch to the GPU and run it through ``policy``.

    ``data`` is unpacked as ``(image_data, qpos_data, action_data, is_pad)``.
    """
    image_data, qpos_data, action_data, is_pad = data
    image_data, qpos_data, action_data, is_pad = image_data.cuda(), qpos_data.cuda(), action_data.cuda(), is_pad.cuda()
    return policy(qpos_data, image_data, action_data, is_pad) # TODO remove None


def train_bc(train_dataloader, val_dataloader, config):
    """Train the policy, validating before each epoch's training pass.

    A checkpoint and the loss curves are saved every 100 epochs; after the
    last epoch ``policy_last.ckpt`` and the best-validation checkpoint are
    saved and the final curves are plotted.

    Args:
        train_dataloader: iterable of training batches.
        val_dataloader: iterable of validation batches.
        config: the config dict built in ``main``.

    Returns:
        Tuple ``(best_epoch, min_val_loss, best_state_dict)`` for the epoch
        with the lowest validation loss.
    """
    num_epochs = config['num_epochs']
    ckpt_dir = config['ckpt_dir']
    seed = config['seed']
    policy_class = config['policy_class']
    policy_config = config['policy_config']

    set_seed(seed)

    policy = make_policy(policy_class, policy_config)
    policy.cuda()
    optimizer = make_optimizer(policy_class, policy)

    train_history = []
    validation_history = []
    min_val_loss = np.inf
    best_ckpt_info = None
    for epoch in tqdm(range(num_epochs)):
        print(f'\nEpoch {epoch}')
        # validation
        with torch.inference_mode():
            policy.eval()
            epoch_dicts = []
            for batch_idx, data in enumerate(val_dataloader):
                forward_dict = forward_pass(data, policy)
                epoch_dicts.append(forward_dict)
            epoch_summary = compute_dict_mean(epoch_dicts)
            validation_history.append(epoch_summary)

            epoch_val_loss = epoch_summary['loss']
            if epoch_val_loss < min_val_loss:
                min_val_loss = epoch_val_loss
                # deepcopy: the live state dict keeps changing as training continues
                best_ckpt_info = (epoch, min_val_loss, deepcopy(policy.state_dict()))
        print(f'Val loss: {epoch_val_loss:.5f}')
        summary_string = ''.join(f'{k}: {v.item():.3f} ' for k, v in epoch_summary.items())
        print(summary_string)

        # training
        policy.train()
        optimizer.zero_grad()
        # Remember where this epoch's entries start so the epoch summary does
        # not depend on the loop variable surviving the loop.
        epoch_start = len(train_history)
        for batch_idx, data in enumerate(train_dataloader):
            forward_dict = forward_pass(data, policy)
            # backward
            loss = forward_dict['loss']
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            train_history.append(detach_dict(forward_dict))
        epoch_summary = compute_dict_mean(train_history[epoch_start:])
        epoch_train_loss = epoch_summary['loss']
        print(f'Train loss: {epoch_train_loss:.5f}')
        summary_string = ''.join(f'{k}: {v.item():.3f} ' for k, v in epoch_summary.items())
        print(summary_string)

        if epoch % 100 == 0:
            ckpt_path = os.path.join(ckpt_dir, f'policy_epoch_{epoch}_seed_{seed}.ckpt')
            torch.save(policy.state_dict(), ckpt_path)
            plot_history(train_history, validation_history, epoch, ckpt_dir, seed)

    ckpt_path = os.path.join(ckpt_dir, 'policy_last.ckpt')
    torch.save(policy.state_dict(), ckpt_path)

    best_epoch, min_val_loss, best_state_dict = best_ckpt_info
    ckpt_path = os.path.join(ckpt_dir, f'policy_epoch_{best_epoch}_seed_{seed}.ckpt')
    torch.save(best_state_dict, ckpt_path)
    print(f'Training finished:\nSeed {seed}, val loss {min_val_loss:.6f} at epoch {best_epoch}')

    # save training curves
    plot_history(train_history, validation_history, num_epochs, ckpt_dir, seed)

    return best_ckpt_info


def plot_history(train_history, validation_history, num_epochs, ckpt_dir, seed):
    """Save one train/validation curve image per metric key to ``ckpt_dir``.

    ``train_history`` holds one entry per training batch and
    ``validation_history`` one per epoch; both are spread over the x range
    ``[0, num_epochs - 1]``.
    """
    # save training curves
    for key in train_history[0]:
        plot_path = os.path.join(ckpt_dir, f'train_val_{key}_seed_{seed}.png')
        fig = plt.figure()
        train_values = [summary[key].item() for summary in train_history]
        val_values = [summary[key].item() for summary in validation_history]
        plt.plot(np.linspace(0, num_epochs-1, len(train_history)), train_values, label='train')
        plt.plot(np.linspace(0, num_epochs-1, len(validation_history)), val_values, label='validation')
        # plt.ylim([-0.1, 1])
        plt.tight_layout()
        plt.legend()
        plt.title(key)
        plt.savefig(plot_path)
        # This function runs periodically during training; close the figure so
        # open figures do not accumulate over a long run.
        plt.close(fig)
    print(f'Saved plots to {ckpt_dir}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--eval', action='store_true')
    parser.add_argument('--onscreen_render', action='store_true')
    parser.add_argument('--ckpt_dir', action='store', type=str, help='ckpt_dir', required=True)
    parser.add_argument('--policy_class', action='store', type=str, help='policy_class, capitalize', required=True)
    parser.add_argument('--task_name', action='store', type=str, help='task_name', required=True)
    parser.add_argument('--batch_size', action='store', type=int, help='batch_size', required=True)
    parser.add_argument('--seed', action='store', type=int, help='seed', required=True)
    parser.add_argument('--num_epochs', action='store', type=int, help='num_epochs', required=True)
    parser.add_argument('--lr', action='store', type=float, help='lr', required=True)

    # for ACT
    parser.add_argument('--kl_weight', action='store', type=int, help='KL Weight', required=False)
    parser.add_argument('--chunk_size', action='store', type=int, help='chunk_size', required=False)
    parser.add_argument('--hidden_dim', action='store', type=int, help='hidden_dim', required=False)
    parser.add_argument('--dim_feedforward', action='store', type=int, help='dim_feedforward', required=False)
    parser.add_argument('--temporal_agg', action='store_true')

    main(vars(parser.parse_args()))
http://www.jsqmd.com/news/580043/

相关文章:

  • Pixel Aurora Engine基础教程:8-BIT音效视觉化——将MIDI转像素动态图初探
  • Asian Beauty Z-Image Turbo快速上手:无需复杂配置,开箱即用的东方美学图像生成工具
  • 告别PPT体验!用UE5.3为你的手游打造‘丝滑’60帧:从合批、LOD到后处理的实战调优
  • 卷积神经网络(CNN)原理可视化解释:Phi-4-mini-reasoning担任AI讲师
  • 教育技术应用:集成cv_unet_image-colorization的在线作业批改系统——美术色彩作业
  • SEO_全面介绍SEO基础知识与核心概念指南
  • Qwen3-ASR-0.6B落地解析:高校智慧教室课堂语音→知识点自动标注
  • OpenClaw多模型切换:千问3.5-9B与本地LLM混合调用方案
  • 英语表达情绪日常口语
  • SAM 3作品集:看看AI如何精准分割图片中的每一个细节物体
  • SAM 3图像视频分割入门:上传图片视频,输入英文名称一键分割
  • Python无锁并发避坑清单(23个生产事故溯源):从引用计数竞争到缓存行伪共享,一文终结“线程安全幻觉”
  • Qwen3.5-9B-AWQ-4bit开源镜像解析:AWQ量化+双卡适配+supervisor自启机制
  • MTools全功能解析:从图像工坊到开发助手,一站式工具使用详解
  • 迭代器、生成器、装饰器面试题总结
  • 2025-2026年全球空气能热水器十大品牌评测:五款口碑产品推荐评价 - 品牌推荐
  • Pixel Aurora Engine部署教程:多用户共享部署+LoRA权限分级管理方案
  • Z-Image-GGUF提示词工程:从‘樱花寺庙’到‘电影级8K杰作’的结构化编写法
  • HTML 知识点
  • NaViL-9B效果展示:低质量模糊图片中的文字识别与语义补全能力
  • 算法训练之递归(一)
  • 2025-2026年全球空气能热水器十大品牌评测:五款口碑产品推荐评价知名 - 品牌推荐
  • 避开这3个坑,你的火山引擎SFT微调效果才能翻倍
  • 终结混淆:一文分清5G的“双流”与“双通道”
  • NCM格式转换技术解析:从加密限制到音频自由的技术实现
  • LiuJuan Z-Image Generator企业实操:私有化部署规避数据外泄风险
  • 7个高效技巧:BetterJoy实现Switch手柄全场景PC适配
  • 国内顶级的SEO技术网站有哪些
  • OpenClaw性能调优:Qwen3.5-9B任务响应速度提升50%的方法
  • LeaguePrank:英雄联盟段位修改与个性化展示完全指南