ECG-Kit 1.0
(6,924 bytes)
function [model, top, errors] = rbmBB(X, numhid, varargin)
%RBMBB Learn an RBM with Bernoulli hidden and visible units.
%This is not meant to be applied to image data
%code by Andrej Karpathy
%based on implementation of Kevin Swersky and Ruslan Salakhutdinov
%
%INPUTS:
%X          ... data, one training case per row. Should be binary, or in
%           ... [0,1] to be interpreted as probabilities
%numhid     ... number of hidden units
%additional inputs (specified as name value pairs or in struct)
%method     ... CD or SML. NOTE: accepted for compatibility, but the
%           ... training loop below always performs one step of
%           ... contrastive divergence; this option has no effect
%eta        ... learning rate
%momentum   ... momentum for smoothness and to prevent overfitting.
%           ... From epoch 6 onwards 1.8*momentum is used, so values
%           ... above 1/1.8 give an effective momentum greater than 1
%maxepoch   ... # of epochs: each is a full pass through train data
%avglast    ... NOTE: accepted for compatibility, but parameter
%           ... averaging is disabled in this version; the returned
%           ... parameters are those after the last update
%penalty    ... weight decay factor
%batchsize  ... The number of training instances per batch
%verbose    ... For printing progress
%anneal     ... Flag. If set true, the penalty is annealed linearly
%           ... through epochs to 10% of its original value
%
%OUTPUTS:
%model.type ... Type of RBM (i.e. type of its visible and hidden units)
%model.W    ... The weights of the connections
%model.b    ... The biases of the hidden layer
%model.c    ... The biases of the visible layer
%top        ... The activity of the top (hidden) layer for X, to be used
%           ... when training DBN's
%errors     ... (optional) summed squared reconstruction error of every
%           ... epoch, as a 1-by-maxepoch vector

%Process options
%varargin is always a cell array, so it is handed to prepareArgs as is
%(the former isstruct(varargin) test could never be true).
args = prepareArgs(varargin);
[   method    ...
    eta       ...
    momentum  ...
    maxepoch  ...
    avglast   ...
    penalty   ...
    batchsize ...
    verbose   ...
    anneal    ...
    ] = process_options(args , ...
    'method'    , 'CD'  , ...
    'eta'       , 0.1   , ...
    'momentum'  , 0.5   , ...
    'maxepoch'  , 50    , ...
    'avglast'   , 5     , ...
    'penalty'   , 2e-4  , ...
    'batchsize' , 100   , ...
    'verbose'   , false , ...
    'anneal'    , false);

%remember the initial weight decay so annealing can be computed from it
oldpenalty = penalty;
[N, d] = size(X);

if (verbose)
    fprintf('Preprocessing data...\n');
end

%Create batches: assign every row of X a batch label 1..numbatches in
%round-robin order, then shuffle the labels so batches are random
numbatches = ceil(N/batchsize);
groups = repmat(1:numbatches, 1, batchsize);
groups = groups(1:N);
perm = randperm(N);
groups = groups(perm);
batchdata = cell(1, numbatches);
for i = 1:numbatches
    batchdata{i} = X(groups==i, :);
end

%train RBM
W = 0.1*randn(d, numhid);   %visible-to-hidden weights
c = zeros(1, d);            %visible biases
b = zeros(1, numhid);       %hidden biases
Winc = zeros(d, numhid);    %momentum-smoothed parameter increments
binc = zeros(1, numhid);
cinc = zeros(1, d);
errors = zeros(1, maxepoch);

for epoch = 1:maxepoch
    errsum = 0;

    if (anneal)
        %apply linear weight penalty decay
        penalty = oldpenalty - 0.9*epoch/maxepoch*oldpenalty;
    end

    %the momentum is increased after the first five epochs; it does not
    %depend on the batch, so it is chosen once per epoch
    if epoch > 5
        this_momentum = 1.8*momentum;
    else
        this_momentum = momentum;
    end

    for batch = 1:numbatches
        data = batchdata{batch};
        numcases = size(data, 1);

        %positive phase: hidden probabilities driven by the data
        poshidprobs = 1./(1 + exp(-data*W - repmat(b, numcases, 1)));
        posprods  = data' * poshidprobs;
        poshidact = sum(poshidprobs);
        posvisact = sum(data);

        %sample binary hidden states to drive the reconstruction
        poshidstates = poshidprobs > rand(numcases, numhid);

        %negative phase: reconstruct the visibles (kept as probabilities,
        %not sampled) and recompute the hidden probabilities from them
        negdata     = 1./(1 + exp(-poshidstates*W' - repmat(c, numcases, 1)));
        neghidprobs = 1./(1 + exp(-negdata*W - repmat(b, numcases, 1)));
        negprods  = negdata' * neghidprobs;
        neghidact = sum(neghidprobs);
        negvisact = sum(negdata);

        %update weights and biases; weight decay is applied to W only
        Winc = this_momentum*Winc + eta*( (posprods-negprods)/numcases - penalty*W);
        binc = this_momentum*binc + (eta/numcases)*(poshidact-neghidact);
        cinc = this_momentum*cinc + (eta/numcases)*(posvisact-negvisact);

        W = W + Winc;
        c = c + cinc;
        b = b + binc;

        %accumulate reconstruction error (once per batch)
        err = sum(sum( (data-negdata).^2 ));
        errsum = err + errsum;
    end

    errors(epoch) = errsum;
    if (verbose)
        fprintf('Ended epoch %i/%i. Reconstruction error is %f\n', ...
            epoch, maxepoch, errsum);
    end
end

model.type = 'BB';
%hidden-layer activity for the whole data set under the final parameters
top = logistic(X*W + repmat(b, N, 1));
model.W = W;
model.b = b;
model.c = c;