Fast implementation of kdpp
Details about kdpp can be found in (http://enginius.tistory.com/494)
Implementation of kdpp using matrix inversion lemma to reduce the computational complexity.
The speedup increases as the number of points increases (>100)
Starting kdpp where k is 50.
Starting kdpp_fast where k is 50.
k: 50 / toc_kdpp: 4.547e-02 sec / toc_kdpp_fast: 9.981e-03 sec (4.6 times faster)
One important rule of thumb is to lower the length parameter! So that the kernel function is peaky! kdpp_fast utilizes matrix inversion lemma and computing the inverse of a kernel matrix might be numerically unstable if the kernel function is smooth.
Result
code: main
%%
clc; clear all; close all;
%% Y를 만들자.
N = 1E4;
Y = zeros(N, 2);
xmax = 10;
ymax = 10;
for i = 1:N
Y(i, :) = [xmax*rand ymax*rand];
end
%% DPP와 unif로 뽑아보자.
k = 1000;
kfun = @(X1, X2)(kernel_se(X1, X2));
fprintf('Starting kdpp where k is %d. \n', k);
tic;
kdpp_struct = kdpp(Y, kfun, k);
toc1 = toc;
fprintf('Starting kdpp_fast where k is %d. \n', k);
tic;
kdpp_fast_struct = kdpp_fast(Y, kfun, k);
toc2 = toc;
fprintf('k: %d / toc_kdpp: %.3e sec / toc_kdpp_fast: %.3e sec (%.1f times faster) \n', k, toc1, toc2, toc1/toc2);
%% Uniform하게 뽑는다.
randIndices = randperm(N);
unif_set = Y(randIndices(1:k), :);
%% 그림을 그려서 확인한다.
fig = figure(1); set(fig, 'Position', [400 500 1500 400]);
subplot(1,3,1);
hold on;
plot(Y(:, 1), Y(:, 2), 'o', 'Color', 0.5*ones(1, 3));
plot(unif_set(:, 1), unif_set(:, 2), 'bo', 'LineWidth', 3, 'MarkerSize', 15);
hold off; grid on; axis([0 10 0 10]); axis equal;title(sprintf('[Uniform] N: %d k: %d', N, k), 'FontSize', 20);
subplot(1,3,2);
hold on;
plot(Y(:, 1), Y(:, 2), 'o', 'Color', 0.5*ones(1, 3));
plot(kdpp_struct.kdata(:, 1), kdpp_struct.kdata(:, 2), 'ro', 'LineWidth', 3, 'MarkerSize', 15);
for i = 1:kdpp_struct.k
text(kdpp_struct.kdata(i, 1)+0.3, kdpp_struct.kdata(i, 2), sprintf('%d', i), 'FontSize', 15 ...
, 'BackgroundColor',[.7 .9 .7]);
end
hold off; grid on; axis([0 10 0 10]); axis equal; title(sprintf('[DPP] N: %d k: %d', N, k), 'FontSize', 20);
subplot(1,3,3);
hold on;
plot(Y(:, 1), Y(:, 2), 'o', 'Color', 0.5*ones(1, 3));
plot(kdpp_fast_struct.kdata(:, 1), kdpp_fast_struct.kdata(:, 2), 'ro', 'LineWidth', 3, 'MarkerSize', 15);
for i = 1:kdpp_fast_struct.k
text(kdpp_fast_struct.kdata(i, 1)+0.3, kdpp_fast_struct.kdata(i, 2), sprintf('%d', i), 'FontSize', 15 ...
, 'BackgroundColor',[.7 .9 .7]);
end
hold off; grid on; axis([0 10 0 10]); axis equal; title(sprintf('[DPP] N: %d k: %d', N, k), 'FontSize', 20);
code: kernel
function K = kernel_se(X1, X2, hyp)
%
%
%
if nargin == 2
hyp.sig2w = 1E-2;
hyp.sig2f = 1;
hyp.sig2x = 1;
end
nr_X1 = size(X1, 1);
nr_X2 = size(X2, 1);
K = hyp.sig2f*exp(-sqdist(X1', X2', 1/hyp.sig2x));
if nr_X1 == nr_X2
K = K + hyp.sig2w*eye(nr_X1, nr_X1);
end
function m = sqdist(p, q, A)
% SQDIST Squared Euclidean or Mahalanobis distance.
% SQDIST(p,q) returns m(i,j) = (p(:,i) - q(:,j))'*(p(:,i) - q(:,j)).
% SQDIST(p,q,A) returns m(i,j) = (p(:,i) - q(:,j))'*A*(p(:,i) - q(:,j)).
% Written by Tom Minka
[d, pn] = size(p);
[d, qn] = size(q);
if pn == 0 || qn == 0
m = zeros(pn,qn);
return
end
if nargin == 2
pmag = col_sum(p .* p);
qmag = col_sum(q .* q);
m = repmat(qmag, pn, 1) + repmat(pmag', 1, qn) - 2*p'*q;
%m = ones(pn,1)*qmag + pmag'*ones(1,qn) - 2*p'*q;
else
Ap = A*p;
Aq = A*q;
pmag = col_sum(p .* Ap);
qmag = col_sum(q .* Aq);
m = repmat(qmag, pn, 1) + repmat(pmag', 1, qn) - 2*p'*Aq;
end
function Y = col_sum(X)
Y = sum(X, 1);
code: kdpp
function kdppStruct = kdpp(data, kernelFunction, k)
% *********************************************************************************
% data가 주어지고, kernelFunction으로 corr이 정의될 때 k개의 데이터를 뽑는 함수이다.
% data: 데이터는 column 방향으로 들어있다.
% kernelFunction: corr 정의
% k: 뽑을 데이터의 수
% *********************************************************************************
kdppStruct.k = k;
kdppStruct.totalData = data;
kdppStruct.kernelFunction = kernelFunction;
kdppStruct.kdata = zeros(k, size(data, 2));
kdppStruct.kIdx = zeros(k, 1);
tic;
% 이걸 저장하면 용량이 너무 커진다.
% kdppStruct.Kmtx = kdppStruct.kernelFunction(kdppStruct.totalData, kdppStruct.totalData);
Kmtx = kdppStruct.kernelFunction(kdppStruct.totalData, kdppStruct.totalData);
tocSec = toc;
% fprintf('Kernel Matrix computed (%2.2f sec) \n', tocSec);
nrData = size(data, 1); % 데이터의 수
indices = 1:nrData;
% DPP에서 처음은 랜덤하게 뽑는다.
randIdx = randi([1 nrData]);
kdppStruct.kdata(1, :) = kdppStruct.totalData(indices(randIdx), :);
kdppStruct.kIdx(1) = indices(randIdx);
indices(randIdx) = []; % 뽑은 데이터에 해당하는 idx는 없에버린다.
% 두 번째부터는 det를 구해서 뽑는다.
for i = 2:k
% 남아있는 length(indices)만큼의 데이터를 하나씩 추가해보면서 p를 구해본다.
tic;
p = zeros(length(indices), 1);
for j = 1:length(indices)
currIdx = indices(j);
tempIdx = [kdppStruct.kIdx(1:i-1, :); currIdx];
K = Kmtx(tempIdx, tempIdx);
p(j) = det(K);
end
[~, maxIdx] = max(p);
% maxIdx번째 데이터를 추가한다.
kdppStruct.kdata(i, :) = kdppStruct.totalData(indices(maxIdx), :);
kdppStruct.kIdx(i) = indices(maxIdx);
% 추가한 데이터를 indices에선 뺀다.
indices(maxIdx) = [];
currToc = toc;
% fprintf('[%d/%d] %.4f sec \n', i, k, currToc);
end
% 예외 처리
if k == 0
kdppStruct.kIdx = [];
end
code: kdpp_fast
function kdppStruct = kdpp_fast(data, kernelFunction, k)
% *********************************************************************************
% data가 주어지고, kernelFunction으로 corr이 정의될 때 k개의 데이터를 뽑는 함수이다.
% data: 데이터는 column 방향으로 들어있다.
% kernelFunction: corr 정의
% k: 뽑을 데이터의 수
% *********************************************************************************
kdppStruct.k = k;
kdppStruct.totalData = data;
kdppStruct.kernelFunction = kernelFunction;
kdppStruct.kdata = zeros(k, size(data, 2));
kdppStruct.kIdx = zeros(k, 1);
tic;
Kmtx = kdppStruct.kernelFunction(kdppStruct.totalData, kdppStruct.totalData);
tocSec = toc;
% fprintf('Kernel Matrix computed (%2.2f sec) \n', tocSec);
nrData = size(data, 1); % 데이터의 수
indices = 1:nrData;
% DPP에서 처음은 랜덤하게 뽑는다.
randIdx = randi([1 nrData]);
kdppStruct.kdata(1, :) = kdppStruct.totalData(indices(randIdx), :);
kdppStruct.kIdx(1) = indices(randIdx);
Kinv = 1/Kmtx(kdppStruct.kIdx(1), kdppStruct.kIdx(1));
detK = det(Kmtx(kdppStruct.kIdx(1), kdppStruct.kIdx(1)));
indices(randIdx) = []; % 뽑은 데이터에 해당하는 idx는 없에버린다.
% 두 번째부터는 det를 구해서 뽑는다.
for i = 2:k
tic;
% 남아있는 length(indices)만큼의 데이터를 하나씩 추가해보면서 p를 구해본다.
p = zeros(length(indices), 1);
cumIdx = kdppStruct.kIdx(1:i-1, :);
for j = 1:length(indices)
currIdx = indices(j);
detKnew = detK*(Kmtx(currIdx, currIdx) ...
- Kmtx(currIdx, cumIdx)*Kinv*Kmtx(cumIdx, currIdx));
p(j) = detKnew;
end
[~, maxIdx] = max(p);
% 빠른 연산을 위해 Kinv와 detK를 저장한다.
S = Kmtx(maxIdx, maxIdx) ...
- Kmtx(maxIdx, cumIdx)*Kinv*Kmtx(cumIdx, currIdx);
P = Kinv*Kmtx(cumIdx, currIdx);
Kinv = [Kinv+P*P'/S -P/S ; P'/S 1/S];
detK = p(maxIdx);
% maxIdx번째 데이터를 추가한다.
kdppStruct.kdata(i, :) = kdppStruct.totalData(indices(maxIdx), :);
kdppStruct.kIdx(i) = indices(maxIdx);
% 추가한 데이터를 indices에선 뺀다.
indices(maxIdx) = [];
currToc = toc;
% fprintf('[%d/%d] %.4f sec \n', i, k, currToc);
end