TRAINRBF Trains an rbfnet given training samples
Input:
  net: the net object
  X: the inputs [input_dimension x number of samples]
  Y: the targets [output_dimension x number of samples]
  options: struct with the currently allowed fields minerr (default 0),
           maxiter (default inf), displayerr (default 0) and initializev (default 0)
Output:
  Trained net
Misc:
  Currently this method modifies only the weight matrices; the centers and
  widths are not modified. This will be added after the bias has been taken care of.
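When a field is missing from options, trainrbf falls back to minerr = 0, maxiter = inf, displayerr = 0 and initializev = 0. Below is a minimal sketch of those same defaults written as a table merged into the caller's struct. The `defaults` struct is a hypothetical name used only for illustration and does not appear in trainrbf.

```matlab
% Hypothetical defaults table mirroring the fallback values used by trainrbf
defaults = struct('minerr', 0, 'maxiter', inf, 'displayerr', 0, 'initializev', 0);

% Caller sets only minerr; every other option should take its default
options = struct('minerr', 0.01);

names = fieldnames(defaults);
for k = 1:numel(names)
    if ~isfield(options, names{k})
        % Copy the default only when the caller did not supply the field
        options.(names{k}) = defaults.(names{k});
    end
end
```

With maxiter left at inf, training ends only when the minerr criterion is met, so setting a finite maxiter is a sensible safeguard.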
function net = trainrbf(net,X,Y,options)
% TRAINRBF Trains an rbfnet given training samples
% Input:
%   net: the net object
%   X: the inputs [input_dimension x number of samples]
%   Y: the targets [output_dimension x number of samples]
%   options: Currently allowed options - minerr, maxiter, displayerr, initializev
%
% Output:
%   Trained net
%
% Misc:
%   Currently this method modifies only the weight matrices;
%   the centers and widths are not modified. This will be added
%   after the bias has been taken care of.


% Satrajit Ghosh, SpeechLab, Boston University. (c)2001
% $Header: /mnt/localhd/cvsdir/MODELLING/NEWDIVA/@ahrbf/trainrbf.m,v 1.1.1.1 2006/10/06 18:20:23 brumberg Exp $

% $NoKeywords: $

% Treat a missing options argument as an empty options struct
if nargin < 4,
    options = struct();
end;

% Determine what options have been set
% Currently allowed options
%   minerr, maxiter, displayerr, initializev
if ~isfield(options,'minerr'),
    minerr = 0;
else,
    minerr = options.minerr;
end;
if ~isfield(options,'maxiter'),
    maxiter = inf;
else,
    maxiter = options.maxiter;
end;
if ~isfield(options,'displayerr'),
    bDisplayError = 0;  % By default no error displayed
else,
    bDisplayError = options.displayerr;
end;
if ~isfield(options,'initializev'),
    bInitializeV = 0;   % By default net.v is not re-initialized
else,
    bInitializeV = options.initializev;
end;

% Initialize the weights net.v by averaging pseudo-inverse solutions
% computed on random subsets (up to 500 samples each) of the training set
if bInitializeV,
    numiter = 50;
    hproc = waitbar(0,'Initializing weights');
    net.v = zeros(size(net.v));
    for i=1:numiter,
        % Random sample indices, drawn with replacement
        idx = ceil(size(X,2)*rand(1,min([size(X,2) 500])));
        [Yhat,H] = simrbf(net,X(:,idx));
        net.v = net.v+Y(:,idx)*pinv(H);
        waitbar(i/numiter,hproc);
    end;
    close(hproc);
    net.v = net.v/numiter;
end;

if bDisplayError,
    figure(10);hold on;
end;

% Train to get net.w
E = inf;
count = 0;
while (mean(max(E,0))>0) && count<maxiter,
    count = count+1;
    % Select a random ordering of the training set
    idx = randperm(size(X,2));
    Err = 0;
    % Do one pass of online training
    for i=1:length(idx),
        % Simulate
        net = netout(net,X(:,idx(i)));
        % Calculate error
        err = Y(:,idx(i))-net.out;
        % Update weights
        net = updatevw(net,err);
        % cerr = mean(abs(err./(1+Y(:,idx(i)))));
        % Running mean of the squared error for each output
        Err = (Err*(i-1)+err.^2)/i;
    end
    % Per-output RMS error minus minerr, printed each epoch (no semicolon);
    % every entry drops to zero or below once minerr is reached
    E = sqrt(Err)-minerr

    if bDisplayError,
        plot(count,mean(E),'o');
        drawnow;
    end;
end;
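The initializev branch sets net.v to the average of Y(:,idx)*pinv(H) over 50 random subsets, where Y*pinv(H) is the minimum-norm least-squares solution of net.v*H = Y. Below is a self-contained sketch of that step. Because simrbf is not shown, it assumes Gaussian hidden units with hypothetical centers `C` and a common width `sigma`. The names `Vtrue`, `V` and `relResidual` are also illustrative only.

```matlab
% ASSUMPTION: Gaussian hidden units (centers C, width sigma) stand in for simrbf
nIn = 2; nHidden = 5; nOut = 3; nSamples = 1000;
X = rand(nIn, nSamples);              % inputs [input_dimension x samples]
C = rand(nIn, nHidden);               % hypothetical fixed centers
sigma = 0.25;                         % hypothetical common width

% Hidden activations H(j,n) = exp(-||x_n - c_j||^2 / (2*sigma^2))
H = zeros(nHidden, nSamples);
for j = 1:nHidden
    delta = X - repmat(C(:, j), 1, nSamples);
    H(j, :) = exp(-sum(delta.^2, 1) / (2*sigma^2));
end

% Hypothetical targets that the hidden layer can represent exactly
Vtrue = randn(nOut, nHidden);
Y = Vtrue * H;                        % targets [output_dimension x samples]

numiter = 50;
V = zeros(nOut, nHidden);
for i = 1:numiter
    % Up to 500 indices drawn with replacement, as in trainrbf
    idx = ceil(nSamples * rand(1, min([nSamples 500])));
    % Least-squares output weights for this subset, accumulated for averaging
    V = V + Y(:, idx) * pinv(H(:, idx));
end
V = V / numiter;

% Relative fit error of the averaged weights on the full data set
relResidual = norm(Y - V*H, 'fro') / norm(Y, 'fro');
```

In this sketch the targets lie in the span of the hidden activations, so every subset solve recovers the same weights. With real, noisy targets the subset solutions differ, and dividing by numiter averages them.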
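During each epoch the loop keeps a running mean of the squared error for each output, Err = (Err*(i-1)+err.^2)/i. It repeats epochs while mean(max(E,0)) > 0, that is, until every output's RMS error is at most minerr or maxiter is reached. The check below exercises both pieces. Its per-sample errors are made-up values standing in for Y - net.out.

```matlab
% Made-up per-sample errors: 2 outputs x 4 samples
errs = [0.5 -1.0 0.25 2.0;
        0.1  0.2 -0.3 0.4];

% Incremental mean of squared errors, updated one sample at a time
Err = 0;
for i = 1:size(errs, 2)
    err = errs(:, i);
    Err = (Err*(i-1) + err.^2)/i;
end

% Stopping test used by trainrbf: continue while any output's RMS error exceeds minerr
minerr = 0.3;
E = sqrt(Err) - minerr;
keepTraining = mean(max(E, 0)) > 0;        % output 1 is still above 0.3

% With a looser target both outputs are within tolerance, so training would stop
stopLoose = ~(mean(max(sqrt(Err) - 1.2, 0)) > 0);
```

Here Err works out to [1.328125; 0.075], which equals mean(errs.^2, 2). Because max(E,0) clips outputs that already meet the target, one output that is still above minerr is enough to keep training going.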