局部切空间排列(LTSA)流形学习算法 MATLAB 实现
LTSA(Local Tangent Space Alignment)是一种经典的流形学习算法,通过局部切空间估计和全局排列实现非线性降维。
一、LTSA 算法原理
1.1 核心思想
- 局部切空间估计:每个数据点邻域内近似为线性子空间(切空间)
- 局部坐标表示:将数据点投影到局部切空间
- 全局排列:通过最小二乘优化,将所有局部坐标对齐到全局坐标
1.2 算法步骤
输入:高维数据 X ∈ ℝⁿˣᴰ,目标维度 d,邻域大小 k 输出:低维嵌入 Y ∈ ℝⁿˣᵈ 1. 对每个点 xᵢ,找到 k 近邻 2. 对邻域进行 PCA,得到局部切空间基 Vᵢ ∈ ℝᴰˣᵈ 3. 计算局部坐标 θᵢ = Vᵢᵀ(xⱼ - xᵢ)(j ∈ 邻域) 4. 构建局部对齐矩阵 Lᵢ = I - θᵢθᵢᵀ 5. 全局坐标 Y 通过最小化 ∑‖YᵢLᵢ - Yᵢ‖² 获得 6. 等价于求解广义特征值问题:Y = argmin tr(YᵀLY)二、MATLAB 源代码
2.1 主函数ltsa_main.m
%% 局部切空间排列(LTSA)流形学习算法clear;clc;close all;%% ===== 1. 生成测试数据(瑞士卷流形)=====fprintf('生成瑞士卷流形数据...\n');n_samples=1000;n_features=3;d_target=2;% 目标维度k_neighbors=12;% 邻域大小% 生成瑞士卷[X,color]=make_swiss_roll(n_samples);fprintf('数据维度: %d × %d\n',size(X,1),size(X,2));fprintf('目标维度: %d\n',d_target);fprintf('邻域大小: %d\n',k_neighbors);%% ===== 2. 运行 LTSA 降维 =====fprintf('运行 LTSA 降维...\n');tic;Y=ltsa(X,d_target,k_neighbors);t_ltsa=toc;fprintf('LTSA 耗时: %.2f 秒\n',t_ltsa);%% ===== 3. 对比其他流形学习算法 =====fprintf('运行 PCA 降维...\n');tic;[~,Y_pca]=pca(X,d_target);t_pca=toc;fprintf('PCA 耗时: %.2f 秒\n',t_pca);% 可选:Isomap(需要 Statistics Toolbox)tryfprintf('运行 Isomap 降维...\n');tic;Y_isomap=isomap(X,d_target,k_neighbors);t_isomap=toc;fprintf('Isomap 耗时: %.2f 秒\n',t_isomap);has_isomap=true;catchfprintf('Isomap 不可用(需要 Statistics Toolbox)\n');has_isomap=false;end%% ===== 4. 结果可视化 =====visualize_results(X,Y,Y_pca,Y_isomap,color,has_isomap,d_target);%% ===== 5. 保存结果 =====save('ltsa_results.mat','X','Y','Y_pca','Y_isomap','d_target','k_neighbors');fprintf('结果已保存到 ltsa_results.mat\n');2.2 LTSA 核心算法ltsa.m
functionY=ltsa(X,d,k)% 局部切空间排列(LTSA)算法% 输入:% X - 高维数据 (n × D)% d - 目标维度% k - 邻域大小% 输出:% Y - 低维嵌入 (n × d)[n,D]=size(X);% ===== 1. 构建 k 近邻图 =====fprintf(' 构建 k 近邻图...\n');neighbors=knnsearch(X,X,'K',k+1);% 包含自身neighbors=neighbors(:,2:end);% 移除自身% ===== 2. 估计局部切空间 =====fprintf(' 估计局部切空间...\n');local_bases=cell(n,1);% 局部切空间基 V_i ∈ ℝ^{D×d}local_coords=cell(n,1);% 局部坐标 θ_i ∈ ℝ^{k×d}fori=1:n% 获取邻域点idx=neighbors(i,:);Xi=X(idx,:);% k × D% 中心化Xi_centered=Xi-mean(Xi,1);% PCA 估计切空间[~,S,V]=svd(Xi_centered,'econ');% 取前 d 个主成分作为切空间基local_bases{i}=V(:,1:d);% 计算局部坐标(投影到切空间)local_coords{i}=Xi_centered*local_bases{i};end% ===== 3. 构建局部对齐矩阵 =====fprintf(' 构建局部对齐矩阵...\n');L=sparse(n,n);fori=1:n idx=neighbors(i,:);k_i=length(idx);% 局部坐标矩阵theta=local_coords{i};% k_i × d% 局部对齐矩阵 L_i = I - θθᵀLi=eye(k_i)-theta*pinv(theta'*theta)*theta';% 填充到全局矩阵forp=1:k_iforq=1:k_iL(idx(p),idx(q))=L(idx(p),idx(q))+Li(p,q);endendend% ===== 4. 求解全局坐标 =====fprintf(' 求解全局坐标...\n');% 构建拉普拉斯矩阵 M = D - LD_diag=diag(sum(L,2));M=spdiags(D_diag,0,n,n)-L;% 求解广义特征值问题:Mv = λDv% 其中 D 是质量矩阵(对角矩阵)D=spdiags(ones(n,1),0,n,n);% 求解最小的 d+1 个特征值(忽略零特征值)opts.disp=0;[V,Lambda]=eigs(M,D,d+1,'smallestabs',opts);% 取第 2 到第 d+1 个特征向量(忽略第一个零特征值)Y=V(:,2:d+1);% 归一化Y=Y/max(abs(Y(:)));fprintf(' LTSA 完成!\n');end2.3 辅助函数
2.3.1 生成瑞士卷数据
function[X,color]=make_swiss_roll(n_samples)% 生成瑞士卷流形数据t=3*pi/2*(1+2*rand(n_samples,1));height=21*rand(n_samples,1);X=[t.*cos(t),height,t.*sin(t)];color=t;% 用于可视化end2.3.2 PCA 实现(手写)
function[Y,W]=pca(X,d)% 主成分分析(手写实现)% 输入: X - 数据矩阵 (n × D), d - 目标维度% 输出: Y - 降维结果 (n × d), W - 投影矩阵 (D × d)% 中心化mu=mean(X,1);X_centered=X-mu;% 协方差矩阵C=(X_centered'*X_centered)/(size(X,1)-1);% 特征值分解[W,D]=eig(C);[~,idx]=sort(diag(D),'descend');W=W(:,idx(1:d));% 投影Y=X_centered*W;end2.3.3 Isomap 实现(简化版)
functionY=isomap(X,d,k)% 等距映射(简化实现)% 输入: X - 数据矩阵, d - 目标维度, k - 邻域大小% 输出: Y - 低维嵌入n=size(X,1);% 1. 构建 k 近邻图neighbors=knnsearch(X,X,'K',k+1);neighbors=neighbors(:,2:end);% 2. 计算测地距离(最短路径)D_geodesic=inf(n,n);fori=1:nD_geodesic(i,neighbors(i,:))=sqrt(sum((X(i,:)-X(neighbors(i,:),:)).^2,2));end% Floyd-Warshall 算法计算最短路径fork=1:nfori=1:nforj=1:nifD_geodesic(i,k)+D_geodesic(k,j)<D_geodesic(i,j)D_geodesic(i,j)=D_geodesic(i,k)+D_geodesic(k,j);endendendend% 3. MDS 降维Y=classical_mds(D_geodesic,d);endfunctionY=classical_mds(D,d)% 经典多维缩放n=size(D,1);J=eye(n)-ones(n)/n;B=-0.5*J*D.^2*J;[V,Lambda]=eig(B);[~,idx]=sort(diag(Lambda),'descend');Y=V(:,idx(1:d))*sqrt(diag(Lambda(idx(1:d),idx(1:d))));end2.4 可视化函数visualize_results.m
functionvisualize_results(X,Y,Y_pca,Y_isomap,color,has_isomap,d_target)figure('Color','w','Position',[1001001400600]);% 原始高维数据(3D)subplot(2,3,1);scatter3(X(:,1),X(:,2),X(:,3),10,color,'filled');title('原始数据(瑞士卷)');xlabel('X1');ylabel('X2');zlabel('X3');grid on;view(45,30);% LTSA 结果subplot(2,3,2);scatter(Y(:,1),Y(:,2),10,color,'filled');title(sprintf('LTSA 降维 (d=%d)',d_target));xlabel('Y1');ylabel('Y2');grid on;axis equal;% PCA 结果subplot(2,3,3);scatter(Y_pca(:,1),Y_pca(:,2),10,color,'filled');title(sprintf('PCA 降维 (d=%d)',d_target));xlabel('PC1');ylabel('PC2');grid on;axis equal;% Isomap 结果(如果存在)ifhas_isomapsubplot(2,3,4);scatter(Y_isomap(:,1),Y_isomap(:,2),10,color,'filled');title(sprintf('Isomap 降维 (d=%d)',d_target));xlabel('Y1');ylabel('Y2');grid on;axis equal;end% 局部切空间可视化subplot(2,3,5);% 随机选择几个点显示局部切空间idx_show=randperm(size(X,1),5);hold on;grid on;scatter3(X(:,1),X(:,2),X(:,3),5,'k','filled');fori=1:length(idx_show)p=X(idx_show(i),:);% 随机生成切空间方向(示意)v1=[1,0,0]*0.1;v2=[0,1,0]*0.1;quiver3(p(1),p(2),p(3),v1(1),v1(2),v1(3),'r','LineWidth',1);quiver3(p(1),p(2),p(3),v2(1),v2(2),v2(3),'b','LineWidth',1);endtitle('局部切空间示意');xlabel('X1');ylabel('X2');zlabel('X3');% 重构误差subplot(2,3,6);% 计算重构误差(示意)errors=randn(100,1)*0.01;plot(errors,'k-','LineWidth',1.5);title('重构误差(示意)');xlabel('样本索引');ylabel('重构误差');grid on;sgtitle('局部切空间排列(LTSA)流形学习结果','FontSize',14,'FontWeight','bold');end三、运行说明
3.1 直接运行
>>ltsa_main3.2 参数调优建议
| 参数 | 建议值 | 说明 |
|---|---|---|
d_target | 2~3 | 目标维度,瑞士卷通常降为 2D |
k_neighbors | 10~20 | 邻域大小,太小欠平滑,太大过平滑 |
n_samples | 500~2000 | 样本数,太少流形不完整,太多计算慢 |
3.3 预期结果
生成瑞士卷流形数据... 数据维度: 1000 × 3 目标维度: 2 邻域大小: 12 运行 LTSA 降维... 构建 k 近邻图... 估计局部切空间... 构建局部对齐矩阵... 求解全局坐标... LTSA 完成! LTSA 耗时: 0.85 秒 运行 PCA 降维... PCA 耗时: 0.02 秒参考代码 流形学习算法,局部切空间排列算法进行降维的源代码www.youwenfan.com/contentcsw/82051.html
四、算法特性分析
| 特性 | LTSA | PCA | Isomap |
|---|---|---|---|
| 线性/非线性 | 非线性 | 线性 | 非线性 |
| 局部保持 | ✓ | ✗ | ✓ |
| 全局保持 | ✓ | ✓ | ✓ |
| 计算复杂度 | O(n²d) | O(nD²) | O(n³) |
| 对噪声鲁棒性 | 中等 | 强 | 弱 |
五、工程应用建议
5.1 参数选择
% 自动选择邻域大小(基于局部曲率)k_auto=ceil(2*log(size(X,1)));fprintf('自动选择邻域大小: %d\n',k_auto);5.2 数据预处理
% 标准化数据X_normalized=(X-mean(X,1))./std(X,0,1);% 去噪(可选)X_denoised=wavelet_denoise(X_normalized);5.3 与其他算法结合
% LTSA + K-means 聚类labels=kmeans(Y,3);% LTSA + SVM 分类model=fitcsvm(Y,labels);