function net = rbfnet(varargin)
%RBFNET Create and initialise a radial basis function (RBF) network struct.
%
%   NET = RBFNET(M, K, NOUT, LRATE, IRANGE, VERS)
%
%   Inputs (positional, read from VARARGIN):
%     M       number of network inputs (net.nin).
%     K       number of basis functions per input dimension.  A scalar is
%             replicated for all M dimensions; otherwise K must have exactly
%             M elements.  The hidden layer is the full grid of prod(K)
%             units (net.nhid); K itself is kept in net.hiddist.
%     NOUT    number of network outputs (net.nout).
%     LRATE   LRATE(1) is stored as net.alpha; LRATE(2), if present, as
%             net.beta (0 otherwise).
%     IRANGE  input range; net.min = min(IRANGE), net.max = max(IRANGE).
%             NOTE(review): one range is shared by every input dimension,
%             so IRANGE is presumably a 2-element [lo hi] vector - confirm.
%     VERS    how centres (net.mu) and widths (net.sg) are initialised:
%               1  evenly spaced grid             (getcenter1)
%               2  per-bin statistics of IN_DATA  (getcenter2)
%               3  gmem fit to IN_DATA histogram  (getcenter3)
%               4  k-means clustering of IN_DATA; column j of net.sg holds
%                  the per-dimension variance of cluster j
%
%   Output:
%     NET     struct holding the sizes above, learning rates alpha/beta,
%             centres mu and widths sg (both nin x nhid), output weights v
%             (nout x nhid, uniform random in [0, 0.001)), and
%             zero-initialised c, h, w, out, in.
%
%   Versions 2-4 read the global IN_DATA, indexed as in_data(dim, sample).

global in_data out_data;

if nargin < 6,
  error('rbfnet: expected 6 inputs: m, k, nout, lrate, irange, vers');
end;

m = varargin{1};
k = varargin{2};

% A scalar k means "the same number of bases in every dimension".
if numel(k) == 1,
  k = k*ones(1,m);
end;

if numel(k) ~= m,
  error('rbfnet: no. of bases per dimension not specified');
end;
k = k(:)';

net.nin = m;
net.hiddist = k;
net.nhid = prod(k);
net.nout = varargin{3};
lrate = varargin{4};
irange = varargin{5};
vers = varargin{6};

% Validate the centre-placement scheme once, up front: a bad value is then
% rejected even when m == 0 (the per-dimension loop would not run).
if ~(isscalar(vers) && any(vers == 1:4)),
  error('Incorrect version');
end;

net.alpha = lrate(1);
if numel(lrate) > 1,
  net.beta = lrate(2);
else,
  net.beta = 0;
end;

net.c = zeros(net.nin,net.nhid);
net.h = zeros(net.nhid,1);

% Small random initial output weights.
net.v = 0.001.*rand(net.nout,net.nhid);
net.w = zeros(net.nin,net.nhid,net.nout);

net.min = min(irange);
net.max = max(irange);
net.range = net.max-net.min;

net.out = zeros(net.nout,1);
net.in = zeros(net.nin,1);

net.mu = zeros(net.nin,net.nhid);
net.sg = zeros(net.nin,net.nhid);

if vers < 4,
  % 1-D centres a{i} and widths sg{i} for each input dimension.
  a = cell(net.nin,1);
  sg = cell(net.nin,1);
  for i=1:net.nin,
    switch vers,
      case 1,
        [a{i}, sg{i}] = getcenter1(net.min,net.max,k(i),i);
      case 2,
        [a{i}, sg{i}] = getcenter2(net.min,net.max,k(i),i);
      case 3,
        [a{i}, sg{i}] = getcenter3(net.min,net.max,k(i),i);
    end;
  end;

  % Expand to the full tensor grid: hidden unit j receives the j-th
  % combination (ndgrid / column-major order, first dimension fastest) of
  % the per-dimension centres and widths.
  b = cell(net.nin,1);
  c = cell(net.nin,1);
  if net.nin == 1,
    b{1} = a{1};
    c{1} = sg{1};
  else,
    [b{:}] = ndgrid(a{:});
    [c{:}] = ndgrid(sg{:});
  end;

  for i=1:net.nin,
    net.mu(i,:) = b{i}(:)';
    net.sg(i,:) = c{i}(:)';
  end;
else,
  disp('Using k-means to determine center');
  [idx,C] = kmeans(in_data',net.nhid);
  net.mu = C';
  for i=1:net.nhid,
    % Variance along dim 2 of the nin x count slice gives one value per input
    % dimension.  The previous var(X') collapsed a single-member cluster
    % (a 1 x nin row) to one scalar taken ACROSS dimensions.
    net.sg(:,i) = var(in_data(:,idx==i),0,2);
  end;
end;

function [pos, sg] = getcenter1(netmin,netmax,num_hid,dim_num)
%GETCENTER1 Evenly spaced centres along one input dimension.
%   [POS, SG] = GETCENTER1(NETMIN, NETMAX, NUM_HID, DIM_NUM) splits the
%   interval [NETMIN, NETMAX] into NUM_HID equal bins and returns the bin
%   midpoints as the 1 x NUM_HID row POS.  SG is a NUM_HID x 1 column whose
%   entries all equal (NETMAX-NETMIN)/(NUM_HID-1)/2.
%   With a single centre, POS is the interval midpoint and SG is half the
%   interval width.  DIM_NUM is accepted for parity with the other
%   getcenter* helpers and is not used.

span = netmax - netmin;

if num_hid == 1,
  pos = (netmin + netmax)/2;
  sg = span/2;
  return;
end;

% 2*num_hid+1 equally spaced points: every second one is a bin midpoint.
grid_pts = linspace(netmin, netmax, 2*num_hid + 1);
pos = grid_pts(2:2:end);
sg = repmat((span/(num_hid - 1))/2, num_hid, 1);

function [pos, sg]= getcenter2(netmin,netmax,num_hid,dim_num)
%GETCENTER2 Data-driven centres along one input dimension.
%   [POS, SG] = GETCENTER2(NETMIN, NETMAX, NUM_HID, DIM_NUM) uses row
%   DIM_NUM of the global IN_DATA.
%     NUM_HID == 1: POS is the mean and SG the standard deviation of the
%                   whole row.
%     otherwise:    [NETMIN, NETMAX] is split into NUM_HID equal bins.  For
%                   each bin, POS is the mean and SG twice the standard
%                   deviation of the samples in it.  Bin edges are inclusive
%                   on both sides, so an edge sample counts in both
%                   neighbouring bins.
%   Any NaN in SG (e.g. from an empty bin) is replaced by 10.  Any NaN in
%   POS is replaced by the corresponding bin midpoint (the range midpoint
%   when NUM_HID == 1).  POS and SG are 1 x NUM_HID rows.
%   NOTE(review): the single-bin case uses 1*std while the binned case uses
%   2*std - confirm this is intended.

global in_data out_data

samples = in_data(dim_num,:);

if (num_hid == 1),
  % Fallback centre: midpoint of the whole range.  It was previously left
  % undefined here, so a NaN mean (e.g. empty IN_DATA) raised an error below.
  pos_i = (netmin+netmax)/2;
  pos = mean(samples);
  sg = std(samples);
else
  % Bin midpoints, used as fallback centres for empty bins.
  pos_i = linspace(netmin,netmax,2*num_hid+1);
  pos_i = pos_i(2:2:end);

  pos = zeros(1,num_hid);
  sg = zeros(1,num_hid);
  for ii=1:num_hid
    lo = netmin+(ii-1)*(netmax-netmin)/num_hid;
    hi = netmin+(ii)*(netmax-netmin)/num_hid;
    inbin = samples >= lo & samples <= hi;
    pos(ii) = mean(samples(inbin));
    sg(ii) = 2*std(samples(inbin));
  end
end

% Replace undefined (NaN) statistics with fixed fallbacks.
bad = isnan(sg);
sg(bad) = 10;

bad = isnan(pos);
pos(bad) = pos_i(bad);

function [pos, sg]= getcenter3(netmin,netmax,num_hid,dim_num)
%GETCENTER3 Centres along one input dimension from a mixture fit (GMEM).
%   [POS, SG] = GETCENTER3(NETMIN, NETMAX, NUM_HID, DIM_NUM) builds a
%   512-bin histogram of row DIM_NUM of the global IN_DATA and smooths it
%   with a normalised 14-point Hamming window.  It then calls GMEM on the
%   bin centres and smoothed counts, starting from
%   BETA0 = [1/NUM_HID, linspace(NETMIN,NETMAX,NUM_HID)', 4] per row.
%   POS is column 2 of the GMEM result.  Presumably the columns are
%   [weight centre width] - confirm against gmem.
%   SG is NOT taken from the fit.  It uses the same rule as GETCENTER1:
%   (NETMAX-NETMIN)/2 for one centre, otherwise a NUM_HID x 1 column of
%   (NETMAX-NETMIN)/(NUM_HID-1)/2.

global in_data out_data

Nb = num_hid;
[h,x] = hist(in_data(dim_num,:),512);

% Smooth the histogram.  hist returns a row while hamming returns a column,
% and convn(h, hamming(14), 'same') on that pair convolves along orthogonal
% dimensions: h was merely scaled by one window tap (no smoothing at all).
% Convolve row with row, and normalise the window so f keeps (roughly) the
% same total mass as h.
win = hamming(14);
win = win(:)'/sum(win);
f = convn(h(:)',win,'same');

% Initial guess: equal weights, evenly spaced centres, third column fixed
% at 4 (hard-coded; presumably an initial width - confirm).
beta0 = [ones(Nb,1)/Nb,linspace(netmin,netmax,Nb)',4*ones(Nb,1)];
beta = gmem(x,f,beta0);
pos = beta(:,2);

if (num_hid == 1),
  sg = (netmax-netmin)/2;
else
  sg = ((netmax-netmin)/(num_hid-1))/2*ones(num_hid,1);
end;