ECG-Kit 1.0

File: <base>/common/prtools_addins/CG_CLASSIFY.m (3,267 bytes)
% Version 1.000
%
% Code provided by Ruslan Salakhutdinov and Geoff Hinton
%
% Permission is granted for anyone to copy, use, modify, or distribute this
% program and accompanying programs and documents for any purpose, provided
% this copyright notice is retained and prominently displayed, along with
% a note saying that the original programs are available from our
% web page.
% The programs and documents are distributed without any warranty, express or
% implied.  As the programs were written for research purposes only, they have
% not been tested to the degree that would be advisable in any important
% application.  All use of these programs is entirely at the user's own risk.


function [f, df] = CG_CLASSIFY(VV,Dim,XX,target)
% CG_CLASSIFY  Cross-entropy cost and gradient of a sigmoid MLP with softmax output.
%
%   [f, df] = CG_CLASSIFY(VV, Dim, XX, target)
%
%   Returns the objective value and its gradient in the [f, df] form used by
%   gradient-based optimizers (presumably a conjugate-gradient routine, given
%   the name -- confirm against callers).
%
%   Inputs
%     VV     - vector of all weights.  Layer ii (ii = 1..length(Dim)-1) is a
%              (Dim(ii)+1) x Dim(ii+1) matrix stored column-major; its last
%              row holds the biases.  Layers are concatenated in order, with
%              the softmax (class) layer last.
%     Dim    - layer sizes [input, hidden_1, ..., hidden_K, n_classes].
%              length(Dim)-2 hidden layers; zero hidden layers is allowed.
%     XX     - N x Dim(1) data matrix, one sample per row.
%     target - N x Dim(end) target matrix (typically one-hot rows).
%
%   Outputs
%     f      - cross-entropy summed (not averaged) over samples:
%              -sum(sum(target .* log(softmax output))).
%     df     - column vector of partial derivatives of f w.r.t. VV, packed
%              in exactly the same order as VV.

N = size(XX,1);
hid_num = length(Dim)-2;

% Unpack VV into per-layer weight matrices.
W = cell(hid_num,1);
offset = 0;
for ii = 1:hid_num
    n_el = (Dim(ii)+1)*Dim(ii+1);
    W{ii} = reshape(VV(offset+1:offset+n_el), Dim(ii)+1, Dim(ii+1));
    offset = offset + n_el;
end
% Output (softmax) layer.  Index by hid_num rather than by the loop variable
% so this also works when there are no hidden layers.
n_el = (Dim(hid_num+1)+1)*Dim(hid_num+2);
w_class = reshape(VV(offset+1:offset+n_el), Dim(hid_num+1)+1, Dim(hid_num+2));

% Forward pass.  acts{1} is the input; acts{ii+1} is the sigmoid output of
% hidden layer ii.  Each has a trailing column of ones that multiplies the
% bias row of the next layer's weights.
acts = cell(hid_num+1,1);
acts{1} = [XX ones(N,1)];
for ii = 1:hid_num
    acts{ii+1} = [1./(1 + exp(-acts{ii} * W{ii})) ones(N,1)];
end

% Softmax via a max-shifted log-sum-exp.  This is mathematically identical to
% exp(z)./sum(exp(z)), but it cannot overflow to Inf/Inf = NaN.  It also keeps
% log-probabilities finite where the probability itself underflows to 0.
z = acts{hid_num+1} * w_class;
n_out = size(z,2);
z = z - repmat(max(z,[],2), 1, n_out);
log_targetout = z - repmat(log(sum(exp(z),2)), 1, n_out);
targetout = exp(log_targetout);
f = -sum(sum(target .* log_targetout));

% Backward pass.  Ix is the error signal at the pre-activations of the layer
% being processed.  For softmax + cross-entropy it starts as targetout - target.
Ix = targetout - target;
df = reshape(acts{hid_num+1}' * Ix, [], 1);      % gradient of w_class
W_above = w_class;
for ii = hid_num:-1:1
    % Propagate through the weights above and apply the sigmoid derivative.
    % Then drop the bias column, which has no incoming weights.
    Ix = (Ix * W_above') .* acts{ii+1} .* (1 - acts{ii+1});
    Ix = Ix(:,1:end-1);
    % Prepend, so that df ends up in the same layer order as VV.
    df = [reshape(acts{ii}' * Ix, [], 1); df];
    W_above = W{ii};
end