ECG-Kit 1.0

File: <base>/common/prtools_addins/rbmBB.m (6,924 bytes)
function [model, top] = rbmBB(X, numhid, varargin)
%Learn RBM with Bernoulli hidden and visible units
%This is not meant to be applied to image data
%code by Andrej Karpathy
%based on implementation of Kevin Swersky and Ruslan Salakhutdinov

%INPUTS:
%X              ... data, one training case per row. Should be binary, or
%               ... in [0,1] to be interpreted as probabilities
%numhid         ... number of hidden units

%additional inputs (specified as name value pairs or in struct)
%method         ... CD or SML (default 'CD')
%               ... NOTE: the option is parsed but the training loop below
%               ... always performs one-step contrastive divergence
%eta            ... learning rate (default 0.1)
%momentum       ... momentum for smoothness and to prevent overfitting
%               ... (default 0.5). After epoch 5 the value actually used
%               ... is 1.8*momentum
%               ... NOTE: momentum is not recommended with SML
%maxepoch       ... # of epochs: each is a full pass through train data
%               ... (default 50)
%avglast        ... how many epochs before maxepoch to start averaging
%               ... (default 5). Procedure suggested for faster
%               ... convergence by Kevin Swersky in his MSc thesis
%               ... NOTE: the option is parsed but parameter averaging is
%               ... disabled in this version, so it has no effect
%penalty        ... weight decay factor (default 2e-4)
%batchsize      ... The number of training instances per batch
%               ... (default 100)
%verbose        ... For printing progress (default false)
%anneal         ... Flag. If set true, the penalty is annealed linearly
%               ... through epochs to 10% of its original value
%               ... (default false)

%OUTPUTS:
%model.type     ... Type of RBM (i.e. type of its visible and hidden units)
%model.W        ... The weights of the connections (numdims x numhid)
%model.b        ... The biases of the hidden layer (1 x numhid)
%model.c        ... The biases of the visible layer (1 x numdims)
%top            ... The activity of the top (hidden) layer for every row of
%               ... X, to be used when training DBN's
%
%The summed squared reconstruction error of each epoch is only printed
%(when verbose is true); it is not returned.

%Process options
%varargin itself is always a cell array, so an options struct has to be
%detected by inspecting its first element.
if (~isempty(varargin) && isstruct(varargin{1}))
    args = prepareArgs(varargin{1});
else
    args = prepareArgs(varargin);
end
%method and avglast are still accepted so that existing callers keep
%working, even though the training loop does not use them.
[   method        ...
    eta           ...
    momentum      ...
    maxepoch      ...
    avglast       ...
    penalty       ...
    batchsize     ...
    verbose       ...
    anneal        ...
    ] = process_options(args    , ...
    'method'        ,  'CD'     , ...
    'eta'           ,  0.1      , ...
    'momentum'      ,  0.5      , ...
    'maxepoch'      ,  50       , ...
    'avglast'       ,  5        , ...
    'penalty'       , 2e-4      , ...
    'batchsize'     , 100       , ...
    'verbose'       , false     , ...
    'anneal'        , false); %#ok<ASGLU>
oldpenalty = penalty;
[N, d] = size(X);

if (verbose)
    fprintf('Preprocessing data...\n');
end

%Create batches: every row of X gets a batch label, labels are shuffled so
%that the batches are random subsets of (almost) equal size.
numdims = d;
numbatches = ceil(N/batchsize);
groups = repmat(1:numbatches, 1, batchsize);
groups = groups(1:N);
perm = randperm(N);
groups = groups(perm);
batchdata = cell(1, numbatches);
for i = 1:numbatches
    batchdata{i} = X(groups==i, :);
end

%train RBM
W = 0.1*randn(numdims, numhid);   %small random initial weights
c = zeros(1, numdims);            %visible biases
b = zeros(1, numhid);             %hidden biases
Winc = zeros(numdims, numhid);    %momentum-smoothed update of W
binc = zeros(1, numhid);          %momentum-smoothed update of b
cinc = zeros(1, numdims);         %momentum-smoothed update of c

for epoch = 1:maxepoch

    errsum = 0;
    if (anneal)
        %apply linear weight penalty decay (reaches 10% at the last epoch)
        penalty = oldpenalty - 0.9*epoch/maxepoch*oldpenalty;
    end

    %the momentum is increased once the first epochs are over
    if epoch > 5
        this_momentum = 1.8*momentum;
    else
        this_momentum = momentum;
    end

    for batch = 1:numbatches
        data = batchdata{batch};
        numcases = size(data, 1);

        %%%%%%%%% POSITIVE PHASE %%%%%%%%%
        %hidden probabilities given the data, and the data-driven statistics
        poshidprobs = 1./(1 + exp(-data*W - repmat(b,numcases,1)));
        posprods    = data' * poshidprobs;
        poshidact   = sum(poshidprobs);
        posvisact   = sum(data);

        %sample binary hidden states used to drive the reconstruction
        poshidstates = poshidprobs > rand(numcases,numhid);

        %%%%%%%%% NEGATIVE PHASE %%%%%%%%%
        %one reconstruction step; probabilities (not samples) are used for
        %the reconstruction and for the second hidden pass
        negdata     = 1./(1 + exp(-poshidstates*W' - repmat(c,numcases,1)));
        neghidprobs = 1./(1 + exp(-negdata*W - repmat(b,numcases,1)));
        negprods    = negdata'*neghidprobs;
        neghidact   = sum(neghidprobs);
        negvisact   = sum(negdata);

        %accumulate reconstruction error (once per batch)
        err = sum(sum( (data-negdata).^2 ));
        errsum = err + errsum;

        %%%%%%%%% UPDATE WEIGHTS AND BIASES %%%%%%%%%
        %weight decay is applied to W only, not to the biases
        Winc = this_momentum*Winc + eta*( (posprods-negprods)/numcases - penalty*W);
        binc = this_momentum*binc + (eta/numcases)*(poshidact-neghidact);
        cinc = this_momentum*cinc + (eta/numcases)*(posvisact-negvisact);

        W = W + Winc;
        c = c + cinc;
        b = b + binc;
    end

    if (verbose)
        fprintf('Ended epoch %i/%i. Reconstruction error is %f\n', ...
            epoch, maxepoch, errsum);
    end
end

%Parameter averaging over the last epochs is disabled, so the parameters
%reached after the final update are returned as they are.
model.type = 'BB';
top = logistic(X*W + repmat(b,N,1));   %logistic is an external helper
model.W = W;
model.b = b;
model.c = c;