이번 포스팅은 분류와 회귀, classification과 regression중에서 regression 에 representer theorem 을 적용하는 것에 대한 것이다. 여기서 한가지 집고 넘어가고 싶은 것은 representer theorem 자체는 위의 두 문제 모두 적용될 수 있는 알고리즘이라는 것이다.
1. Problem Setting
먼저 우리가 갖고 있는 데이터가 다음과 같다고 하자.
$$ (x_1, y_1), \ (x_2, y_2), \ \cdots \ , \ (x_m, y_m) \in X \times R $$
그리고 '$X$'의 두 원소의 product space에서 정의된 kernel이 있다고 하자.
$$ k: X \times X \to R \ \ \ \ (x, x') \mapsto k(x, x') $$
이 kernel 함수를 통해서 우리는 Gram matrix 혹은 kernel matrix를 만들 수 있다.
$$ K := (k(x_i, x_j))_{ij} \in R^{m \times m} $$
Associated Feature Space
'$C^X$'를 '$X \to C$'인 함수들이 살고 있는 공간이라고 하자. 그리고 '$X \to C^X$'의 mapping을 '$\phi$'라 하자. 헛갈리지말자. '$\phi$'가 살고 있는 공간이 '$C^X$'가 아니다. '$X$'를 '$X$'에서 정의된 함수로 연결해주는 함수이다.
$$ \phi: X \to C^X, x \mapsto k(\cdot, x) $$
이 함수들이 살고 있는 공간을 일종의 linear space라 한다면, 우리는 이 공간에서 살고 있는 임의의 함수 '$f$'와 '$g$'를 다음과 같이 표현할 수 있다.
$$ f(\cdot) = \sum_{i=1}^m \alpha_i k(\cdot, x) , \\ g(\cdot) = \sum_{j=1}^{m'} \beta_j k(\cdot, x'_j)$$
그리고 이 두 함수, 혹은 원소의 내적 (dot product)를 다음과 같이 정의한다.
$$ <f, g> := \sum_{i=1}^m \sum_{j=1}^{m'} \alpha_i \beta_j k(x_i, x'_j). $$
이렇게 정의할 경우 몇 가지 재밌는 성질들이 유도된다.
1. '$ <k(\cdot, x), f> = f(x) $'
여기서 '$k(\cdot, x)$'를 representer of evaluation 이라 한다.
2. '$ <k(\cdot, x), k(\cdot, x')> = k(x, x') $'
3. '$ k(x, x') = <\phi(x), \phi(x')> $'
위의 성질 때문에, 임의의 kernel로 형성되는 함수들의 공간을 reproducing kernel Hilbert space (RKHS) 라고 한다. Hilbert space는 원소들과 이 원소들의 내적이 정의된 공간을 의미한다.
2. The Representer Theorem
우리에게 임의의 nonempty set '$X$'가 있고, positive definite한 kernel '$k(\cdot)$'이 있다고 하자. 그리고 학습 데이터가 다음과 같이 주어졌다고 하자.
$$ (x_1, y_1), \ (x_2, y_2), \ \cdots \ , \ (x_m, y_m) \in X \times R $$
또한 strictly increasing function '$g(\cdot)$'과 cost function '$C: (X \times R^2)^m \to R \cup \{\infty\}$'이 있다고 하자. 그리고 우리가 찾고자 하는 함수가 다음의 꼴로 표현된다고 하자.
$$ F = \{ f \in R^X | f(\cdot) = \sum_{i=1}^\infty \beta_i k(\cdot, z_i) \} \\ \| \sum_{i=1}^{\infty} \beta_i k(\cdot, z_i) \|^2_2 = \sum_{i=1, j=1}^{\infty} \beta_i \beta_j k(z_i, z_j) $$
모든 가능한 조합이라고 생각하면 된다. 혹은 vector space를 생각해보면 선형 조합에 닫혀있으므로 모든 원소라고 생각해도 된다. 그 뒤의 식인 이 공간에서 정의된 내적의 정의이다.
자 여기서부터가 중요하다.
임의의 '$f \in F$'에 대해서 아래 최적화 문제를 푼다고 생각해보자.
$$ c((x_1, y_1, f(x_1), \ \cdots \, \ (x_m, y_m, f(x_m)) ) + g(\| f \|_2^2) $$
그러면 이 문제의 solution은 다음의 꼴로 표현될 수 있다.
$$ f(\cdot) = \sum_{i=1}^m \alpha_i k(\cdot, x_i) $$
>> The signifi cance of the theorem is that it shows that a whole range of learn ing algorithms have optimal solutions that can be expressed as expansions in terms of the training examples .
3. Applications of Representer Theorem
구슬이 서말이래도 꿰어야 보배 라고, 이제 실제로 써먹어보자. 두 가지 문제를 풀어볼 것이다.
첫 번째는 일반적인 regression 문제이다. 즉 입력과 출력에 해당하는 학습 데이터를 주고, 이 데이터들을 잘 표현하는 함수를 찾는 것이다. 두 번째는 문제는 조금 복잡한 문제로, 이 함수가 특정 점들과는 가깝고, 다른 점들과는 멀어지게 해보자. 결론부터 써보면 아래와 같다.
Representer theorem이 멋진 이유는 비선형 함수를 찾는 문제를 Convex quadratic programming으로 바꾼다는 것이다! 그럼 기존의 convex solver를 쓸 수 있다. 또한 negative data를 고려하는 문제가 indefinite QP이므로 이는 sequential QP와 Newton KKT로 풀수 있다. (모두 아래 매트랩 파일에 있다.)
동영상은 다음과 같다.
VIDEO
Matlab code
1. main.m
더보기 접기
%%
%
% Learning from negative data
% Here, we used warped error instead of squared loss for negative data
%
ccc
%% LfND From mouse clicks
ccc
global click_pos flag key_pressed
key_pressed = '';
flag = 'plot'; ppos = zeros(1E3, 2); nrppos = 0; npos = zeros(1E3, 2); nrnpos = 0;
fig = figure(1);
set(fig, 'WindowButtonDownFcn', @buttonDownCallback,'KeyPressFcn', @keyDownListener, 'Position', [200 200 1300 700] ...
, 'NumberTitle', 'off', 'Name', 'Regression with both positive and negative data');
nrtest = 100;
xtest = linspace(0, 10, nrtest)';
ytest = zeros(nrtest, 1);
curv = 0; totalmsec = 0;
while ~isequal(flag, 'terminate')
switch flag % mouse handler
case 'normal' % left click
nrppos = nrppos + 1; ppos(nrppos, :) = click_pos;
flag = 'update';
case 'alt' % right click
nrnpos = nrnpos + 1; npos(nrnpos, :) = click_pos;
flag = 'update';
end
% Do regression with both positive and negative data
if isequal(flag, 'update') || isequal(key_pressed, 'space') % update
xp = ppos(1:nrppos, 1); yp = ppos(1:nrppos, 2); % Positive data
xn = npos(1:nrnpos, 1); yn = npos(1:nrnpos, 2); % Negative data
x = [xp ; xn]; y = [yp ; yn];
hyp = struct('g2', 1, 'A', 1/(2)^2, 'w2', 1E-8, 'mzflag', 0);
if nrnpos == 0 % <= With no negative data, analytic solution exists
tic;
K = kernel_se(x, x, hyp);
alphabar = [K(1:nrppos, 1:nrppos) \ (y(1:nrppos)) ; zeros(nrnpos, 1)];
totalmsec = toc*1000;
else
opt = struct('maxiter', 50, 'ploteach', 0, 'xrange', [0 ; 10] ...
, 'alpharange', [-100 100], 'yrange', [-5 5] ...
, 'eta', 1, 'lambda', 0.5, 'gamma', 0E+0);
[alphabar, totalmsec, alphainit, K] = lfnd(xp, yp, xn, yn, hyp, opt);
end
curv = alphabar'*K*alphabar;
% Check result
Ktest = kernel_se(xtest, [xp; xn], hyp);
ytest = Ktest*alphabar;
flag = 'plot';
end
if isequal(flag, 'plot') % plot
flag = 'waiting';
clf; hold on;
hp = plot(ppos(1:nrppos, 1), ppos(1:nrppos, 2), 'bo', 'MarkerSize', 15, 'LineWidth', 2);
hn = plot(npos(1:nrnpos, 1), npos(1:nrnpos, 2), 'rx', 'MarkerSize', 15, 'LineWidth', 2);
plot(xtest, ytest, 'k-', 'LineWidth', 3);
title(sprintf('curvature: %.1e / msec: %.1ems', curv, totalmsec), 'FontSize', 13);
axis equal; axis([0 10 -3 3]);
grid on;
if ~isempty(hp) && ~isempty(hn), hleg = legend([hp, hn], 'Positive data', 'Negative data');
elseif ~isempty(hp) && isempty(hn), hleg = legend([hp], 'Positive data');
elseif isempty(hp) && ~isempty(hn), hleg = legend([hn], 'Negative data');
else hleg = legend('');
end
set(hleg, 'FontSize', 13);
xlabel('x (input)', 'FontSize', 13); ylabel('f(x) (output)', 'FontSize', 13);
drawnow;
end
if isequal(key_pressed, 'q')
flag = 'terminate';
else
key_pressed = '';
end
if ~ishandle(fig)
fprintf('Figure closed. \n');
flag = 'terminate';
end
pause(1E-10);
end
fprintf(2, 'Simulation terminated. \n');
if ishandle(fig)
title('Terminated', 'Color', 'r', 'FontSize', 13);
end
접기
2. buttonDownCallback.m
더보기 접기
function buttonDownCallback(~, ~)
global click_pos flag ;
p = get(gca,'CurrentPoint');
p = p(1,1:2);
click_pos = p;
flag = get(gcbf, 'SelectionType');
% disp(flag);
end
접기
3.keyDownListener.m
더보기 접기
function keyDownListener(~, event)
global key_pressed;
key_pressed = event.Key;
disp(key_pressed);
접기
4. lfnd.m
더보기 접기
function [alphabar, totalmsec, alphainit, K] = lfnd(xp, yp, xn, yn, hyp, opt)
%
% Learning from negative data
%
% sungjoon.choi@cpslab.snu.ac.kr
%
if nargin == 5
opt.maxiter = 1E3;
opt.sig2n = 1E-2;
opt.ploteach = 0;
opt.xrange = [0 ; 10];
opt.eta = 1;
opt.lambda = 0.5;
opt.gamma = 0;
opt.alpharange = [-1000 1000];
opt.yrange = [-100 100];
end
% Parser
if ~isfield(opt, 'maxiter')
opt.maxiter = 1E3;
end
if ~isfield(opt, 'sig2n')
opt.sig2n = 1E-2;
end
if ~isfield(opt, 'ploteach')
opt.ploteach = 0;
end
if ~isfield(opt, 'xrange')
opt.xrange = [0 ; 10];
end
if ~isfield(opt, 'eta')
opt.eta = 5;
end
if ~isfield(opt, 'lambda')
opt.lambda = 1;
end
if ~isfield(opt, 'gamma')
opt.gamma = 0;
end
if ~isfield(opt, 'alpharange')
opt.alpharange = [-1000 1000];
end
if ~isfield(opt, 'yrange')
opt.yrange = [-100 100];
end
maxiter = opt.maxiter; % Maximum iteration
ploteach = opt.ploteach; % Debugging flag
xrange = opt.xrange; % This is also used for debugging
eta = opt.eta; % Negative fitting
lambda = opt.lambda; % Hilber norm regularizer
gamma = opt.gamma; % Additional regularizer
alpharange = opt.alpharange;
yrange = opt.yrange;
alphamin = alpharange(1);
alpahmax = alpharange(2);
ymin = yrange(1);
ymax = yrange(2);
% Train
nrp = size(xp, 1);
nrn = size(xn, 1);
x = [xp ; xn]; y = [yp ; yn];
nr = nrp + nrn;
% Normalize outputscale
maxy = max(abs(y));
yp = yp / maxy;
yn = yn / maxy;
y = y / maxy;
% Precompute kernel matrices
K = kernel_se(x, x, hyp);
Kdp = kernel_se(x, xp, hyp); Kpd = Kdp';
Kdn = kernel_se(x, xn, hyp); Knd = Kdn';
Kpp = kernel_se(xp, xp, hyp); Knn = kernel_se(xn, xn, hyp);
% Initial alphabar
alphainit = [(Kpp+1E-4*eye(nrp)) \ yp ; zeros(nrn, 1)] + 0E-1*randn(nr, 1);
% alphainit = [(Kpp+1E-6*eye(nrp)) \ yp ; (Knn+1E-2*eye(nrn)) \ yn] + 1E-1*randn(nr, 1);
alphabar = alphainit;
totalmsec = 0;
for iter = 1:maxiter
iclk = clock;
ep = Kpd*alphabar - yp;
en = Knd*alphabar - yn;
Q = Kdp*Kpd - eta*Kdn*diag(eta*g_warp(en, 2))*Knd + lambda*(K')*K + gamma*eye(nr);
c = Kdp*ep - eta*Kdn*g_warp(en, 1) + lambda*K'*K*alphabar + gamma*alphabar;
Aprime = zeros(2*nrn+2*nr, nr);
bprime = zeros(2*nrn+2*nr, 1);
for i = 1:nrn
ktemp = kernel_se(x, xn(i, :), hyp);
Aprime(2*i-1, :) = ktemp';
Aprime(2*i, :) = -ktemp';
bprime(2*i-1) = ymax;
bprime(2*i) = -ymin;
end
for i = 2*nrn+1:2*nrn+nr
vtemp = zeros(1, nr);
vtemp(i-2*nrn) = 1;
Aprime(i, :) = vtemp;
bprime(i) = alpahmax;
end
for i = 2*nrn+nr+1:2*nrn+2*nr
vtemp = zeros(1, nr);
vtemp(i-2*nrn-nr) = -1;
Aprime(i, :) = vtemp;
bprime(i) = -alphamin;
end
A = Aprime;
b = bprime - A*alphabar;
dalpha0 = 0E-1*randn(nr, 1);
opts.step1 = 0;
opts.maxit = 1E3;
dalpha = NewtonKKTqp(Q, c, A, b, dalpha0, opts);
maxdalpha = max(abs(dalpha));
msec = etime(clock, iclk)*1000;
totalmsec = totalmsec + msec;
% Plot
if ploteach
xt = linspace(xrange(1), xrange(2), 1E2)';
yt = kernel_se(xt, x, hyp)*alphabar*maxy;
figure(99); clf; hold on;
ms = 'MarkerSize'; lw = 'LineWidth';
hp = plot(xp, yp*maxy, 'bo', ms, 15, lw, 2); hn = plot(xn, yn*maxy, 'rx', ms, 15, lw, 2);
plot(xt, yt, 'k--', lw, 3);
hleg = legend([hp hn], 'Positive data', 'Negative data'); set(hleg, 'FontSize', 13);
grid on; axis([0 10 -5 5]);
title(sprintf('iter: %d / maxdalpha: %.2e', iter, maxdalpha));
drawnow;
pause();
end
% Update
maxth = 100;
if maxdalpha > maxth, dalpha = maxth*dalpha/maxdalpha; end;
stepsize = 1E-1;
alphabar = alphabar + stepsize*dalpha;
if maxdalpha < 0.1
break;
end
end
% Check contraint
ineqvec = A*alphabar - b; % This should be less than zero
maxineqvec = max(ineqvec); % This also should be less than zeros
fprintf('maxineqvec: %.2e \n', maxineqvec);
% Rescale
alphainit = alphainit * maxy;
alphabar = alphabar * maxy;
접기
5. kernel_se.m
더보기 접기
function K = kernel_se(X1, X2, hyp)
%
% Squared exponential kernel function
%
% sungjoon.choi@cpslab.snu.ac.kr
%
nr_X1 = size(X1, 1);
nr_X2 = size(X2, 1);
K = hyp.g2*exp(-sqdist(X1', X2', hyp.A));
if nr_X1 == nr_X2 && nr_X1 > 1
K = K + hyp.w2*eye(nr_X1, nr_X1);
end
function Y = col_sum(X)
Y = sum(X, 1);
function m = sqdist(p, q, A)
% SQDIST Squared Euclidean or Mahalanobis distance.
% SQDIST(p,q) returns m(i,j) = (p(:,i) - q(:,j))'*(p(:,i) - q(:,j)).
% SQDIST(p,q,A) returns m(i,j) = (p(:,i) - q(:,j))'*A*(p(:,i) - q(:,j)).
% Written by Tom Minka
[~, pn] = size(p);
[~, qn] = size(q);
if pn == 0 || qn == 0
m = zeros(pn,qn);
return
end
if nargin == 2
pmag = col_sum(p .* p);
qmag = col_sum(q .* q);
m = repmat(qmag, pn, 1) + repmat(pmag', 1, qn) - 2*p'*q;
%m = ones(pn,1)*qmag + pmag'*ones(1,qn) - 2*p'*q;
else
Ap = A*p;
Aq = A*q;
pmag = col_sum(p .* Ap);
qmag = col_sum(q .* Aq);
m = repmat(qmag, pn, 1) + repmat(pmag', 1, qn) - 2*p'*Aq;
end
접기