
The Simplest Three-Layer Neural Network, Implemented in Matlab

Artificial neural networks (Artificial Neural Network, ANN) [1] have a very wide range of applications. In theory, a multi-layer neural network can approximate any continuous function, regardless of whether that function is smooth. After SVMs appeared, ANNs were all but forgotten, but in recent years the rise of deep learning has drawn attention back to neural networks. Here I demonstrate what is, in my view, the simplest possible three-layer neural network. It consists of an input layer (the data itself), a middle layer (also called the hidden layer), and an output layer. The size of the input layer equals the dimensionality of the data, the size of the hidden layer is adjustable, and the output layer contains a single neuron: when its output is close to 1 the sample is judged to belong to the first class, and when it is 0 to the second class. All activations use the sigmoid function:

s(z) = 1 / (1 + exp(-z))
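In Matlab terms, a forward pass through such a network (a minimal sketch, assuming a column-vector sample x and the parameter names W, b, wo, bo used in the training code below) looks like:

y = 1 ./ (1 + exp(-(W'*x + b)));    % hidden-layer activations, H x 1
o = 1 / (1 + exp(-(wo'*y + bo)));   % output neuron, a value in (0, 1)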
The most common method for training a neural network is error back-propagation (back-propagation, BP), which is essentially gradient descent. Its derivation uses the chain rule from calculus; it is not repeated here, and only the resulting formulas appear in the code. If anything is unclear, feel free to email me. The training code is as follows:
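For reference, writing o = s(wo'*y + bo) for the network output and y(j) = s(W(:,j)'*x + b(j)) for the hidden activations of a training pair (x, t), the per-sample gradients accumulated in the loop below are (the constant factor of 2 from the squared error is dropped, since the step size is chosen by a line search anyway):

dE/dwo     = (o - t) * o*(1-o) * y
dE/dbo     = (o - t) * o*(1-o)
dE/dW(:,j) = (o - t) * o*(1-o) * wo(j) * y(j)*(1-y(j)) * x
dE/db(j)   = (o - t) * o*(1-o) * wo(j) * y(j)*(1-y(j))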
%function output = ANN_Training(X, t, H, max_iter, tol)
%Train ANN with one hidden layer and one output unit for classification
 
 
%input :
%X:  attributes. Every column represents a sample.
%t:  target. Should be 0 or 1.  length(t) == size(X,2) is assumed.
%H:  size of hidden layer.
%max_iter: maximum number of iterations
%tol: convergence tolerance
 
 
%output:
% output: a structure containing all network parameters.
 
 
%Created by Ranch Y.Q. Lai on Mar 22, 2011
%ranchlai@163.com
 
 
function output = ANN_Training(X, t, H, max_iter,tol)
[n,N] = size(X);
if N~= length(t)
    error('inconsistent sample size');
end
 
W = randn(n,H); % weight for hidden layer, W(:,i) is the weight vector for unit i
b = randn(H,1); % bias for hidden layer
wo = randn(H,1); %weight for output layer
bo = randn(1,1); %bias, output
y = zeros(H,1); %output of hidden layer
iter = 0;
cost_v = [inf];
fprintf('###################################\n');
fprintf('ANN TRAINING STARTED\n');
fprintf('###################################\n');
fprintf('Iterate \t training error\n');
 
 
while iter < max_iter
    
    
    delta_wo = zeros(H,1);
    delta_bo = 0;
    delta_W = zeros(n,H);
    delta_b = zeros(H,1);
    for i=1:N
        for j=1:H
            y(j) = s(X(:,i),W(:,j),b(j));
        end
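        % accumulate the gradient of the squared error w.r.t. the output-layer weights and bias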
        delta_wo = delta_wo + y*(s(y,wo,bo)-t(i))*s(y,wo,bo)*(1-s(y,wo,bo));
        delta_bo = delta_bo + (s(y,wo,bo)-t(i))*s(y,wo,bo)*(1-s(y,wo,bo));
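        % accumulate the gradient w.r.t. the hidden-layer weights and biases (chain rule through the output unit)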
        
        for j=1:H
            delta_W(:,j) = delta_W(:,j) + X(:,i)*(s(y,wo,bo)-t(i))*s(y,wo,bo)*(1-s(y,wo,bo))*wo(j)*s(X(:,i),W(:,j),b(j))*(1-s(X(:,i),W(:,j),b(j)));
            delta_b(j) = delta_b(j) + (s(y,wo,bo)-t(i))*s(y,wo,bo)*(1-s(y,wo,bo))*wo(j)*s(X(:,i),W(:,j),b(j))*(1-s(X(:,i),W(:,j),b(j)));            
        end
    end
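    % backtracking line search: start from step = 1 and shrink by a factor of 10 until the training error decreases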
    
    step = 1;
    cost =  training_error(N,H,X,t,W,b,wo,bo);
    while step > 1e-10
        
     
        wo1 = wo - step*delta_wo;
        bo1 = bo - step*delta_bo;
        W1 = W - step.*delta_W;
        b1 = b - step.*delta_b;
        cost1 =  training_error(N,H,X,t,W1,b1,wo1,bo1);
        if cost1 < cost
            break;
        end
        step = step * 0.1;
        
    end
    if step <=1e-10
        disp('cannot descend anymore');
        break;
    end
    % update wo,bo,W,b
    wo = wo1;
    bo = bo1;
    W = W1;
    b = b1;
    cost  = cost1;
    
    if iter>0
        % erase the previously printed progress line so the display updates in place
        fprintf('\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b');
    end
    fprintf('%4d\t%1.10f\n',iter,cost);
    cost_v = [cost_v;cost];
    
    
    if abs(cost - cost_v(end-1))< tol
        break;
    end
    
    iter = iter  + 1;
end
 
 
 
 
output.W = W;
output.b= b;
output.wo = wo;
output.bo =bo;
output.cost = cost_v;
 
 
function cost = training_error(N,H,X,t,W,b,wo,bo)
%mean squared training error over the N samples
y = zeros(H,1);
cost = 0;
for i=1:N
    for j=1:H
        y(j) = s(X(:,i),W(:,j),b(j));
    end
    
    cost = cost + ( s(y,wo,bo) - t(i)) ^2;
end
cost = cost / N;
 
The function s used above is defined as follows:
 
 
%%sigmoid function applied to the affine combination w'*x + b
function val = s(x,w,b)
val = 1/(1+exp(-w'*x-b));
 
 
 
Given a sample set, the following function counts how many samples the network classifies correctly:
%function num_correct = test_correct_classification(training_output,X,y)
%Given a test set and a trained ANN, compute how many samples are
%correctly classified.
 
%input : 
%training_output:  the result of calling ANN_Training
%X:  attributes. Every column represents a sample. 
%y:  target. should be 0 or 1.  length(y) == size(X,2) is assumed.
 
 
%output:
% num_correct:  number of samples that are correctly classified.
function num_correct = test_correct_classification(training_output,X,y)
 
 
num_correct = 0;
 
 
H = length(training_output.wo);
N = length(y);
 yh = zeros(H,1);
for i=1:N
    for j=1:H
        yh(j) = s(X(:,i),training_output.W(:,j),training_output.b(j));
    end
    num_correct = num_correct + ((s(yh,training_output.wo,training_output.bo)>0.5) == y(i));
end
end
 
 
Below, the code above is used to train and test the network on a sample set. The example uses the well-known ...
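As an illustration, a minimal usage sketch on a made-up two-class Gaussian dataset (the data and parameter choices here are hypothetical, not the dataset referred to above) could look like this:

% generate two Gaussian clouds in 2-D, labelled 0 and 1
N = 200;
X = [randn(2,N/2) - 1, randn(2,N/2) + 1];   % 2 x N, one column per sample
t = [zeros(1,N/2), ones(1,N/2)];            % targets in {0, 1}

% train with 5 hidden units, at most 1000 iterations, tolerance 1e-6
out = ANN_Training(X, t, 5, 1000, 1e-6);

% count how many training samples are classified correctly
num_correct = test_correct_classification(out, X, t);
fprintf('correctly classified: %d / %d\n', num_correct, N);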