最简单的三层神经网络Matlab实现
人工神经网络(Artificial Neural Network, ANN)[1]具有很广的应用。理论上来说,采用多层的神经网络能够逼近任何的连续函数,而不管该函数是否平滑。自从SVM出现后,大家都快忘了还有ANN这种东东。但是近几年随着deep learning技术的兴起,大牛们又重新开始关注神经网络了。在这里,演示一种个人认为最简单的三层神经网络。它包含一个输入层(即数据本身),一个中间层(又称隐含层,hidden layer), 和一个输出层。其中,输入层的大小即为数据的维度,中间层大小可调,输出层只包括一个神经元。当该神经元输出接近1时,判断为第一类,输出为0时判断为第二类。所有激活函数均采用sigmoid函数:神经网络的训练最常用的方法是误差反传播(back-propagation, BP),本质上其实是梯度下降算法。其推导过程需要用到高数中的链式法则(chain rule),在此不作推导,只在代码中给出结果。有不明白的小盆友可以发email给我。训练部分的代码如下:
%function output = ANN_Training(X, t, H, max_iter)
%Train ANN with one hidden layer and one output unit for classification
%input :
%X: attributes. Every column represents a sample.
%y: target. should be 0 or 1. length(y) == size(X,2) is assumed.
%H: size of hidden layer.
%max_iter: maximum iterates
%tol: convergence tolerate
%output:
% output: a structure containing all network parameters.
%Created by Ranch Y.Q. Lai on Mar 22, 2011
%ranchlai@163.com
function output = ANN_Training(X, t, H, max_iter,tol)
[n,N] = size(X);
if N~= length(t)
error('inconsistent sample size');
end
W = randn(n,H); % weight for hidden layer, W(:,i) is the weight vector for unit i
b = randn(H,1); % bias for hidden layer
wo = randn(H,1); %weight for output layer
bo = randn(1,1); %bias, output
y = zeros(H,1); %output of hidden layer
iter = 0;
cost_v = [inf];
fprintf('###################################\n');
fprintf('ANN TRAINING STARTED\n');
fprintf('###################################\n');
fprintf('Iterate \t training error\n');
while iter < max_iter
delta_wo = zeros(H,1);
delta_bo = 0;
delta_W = zeros(n,H);
delta_b = zeros(H,1);
for i=1:N
for j=1:H
y(j) = s(X(:,i),W(:,j),b(j));
end
delta_wo = delta_wo + y*(s(y,wo,bo)-t(i))*s(y,wo,bo)*(1-s(y,wo,bo));
delta_bo = delta_bo + (s(y,wo,bo)-t(i))*s(y,wo,bo)*(1-s(y,wo,bo));
for j=1:H
delta_W(:,j) = delta_W(:,j) + X(:,i)*(s(y,wo,bo)-t(i))*s(y,wo,bo)*(1-s(y,wo,bo))*wo(j)*s(X(:,i),W(:,j),b(j))*(1-s(X(:,i),W(:,j),b(j)));
delta_b(j) = delta_b(j) + (s(y,wo,bo)-t(i))*s(y,wo,bo)*(1-s(y,wo,bo))*wo(j)*s(X(:,i),W(:,j),b(j))*(1-s(X(:,i),W(:,j),b(j)));
end
end
step = 1;
cost = training_error(N,H,X,t,W,b,wo,bo);
while step > 1e-10
wo1 = wo - step*delta_wo;
bo1 = bo - step*delta_bo;
W1 = W - step.*delta_W;
b1 = b - step.*delta_b;
cost1 = training_error(N,H,X,t,W1,b1,wo1,bo1);
if cost1 < cost
break;
end
step = step * 0.1;
end
if step <=1e-10
disp('cannot descend anymore');
break;
end
% update wo,bo,W,b
wo = wo1;
bo = bo1;
W = W1;
b = b1;
cost = cost1;
if iter>0
fprintf('\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b');
end
fprintf('%4d\t%1.10f\n',iter,cost);
cost_v = [cost_v;cost];
if abs(cost - cost_v(end-1))< tol
break;
end
iter = iter + 1;
end
output.W = W;
output.b= b;
output.wo = wo;
output.bo =bo;
output.cost = cost_v;
function cost = training_error(N,H,X,t,W,b,wo,bo)
%total cost
y = zeros(H,1);
cost = 0;
for i=1:N
for j=1:H
y(j) = s(X(:,i),W(:,j),b(j));
end
cost = cost + ( s(y,wo,bo) - t(i)) ^2;
end
cost = cost / N;
其中S函数的定义如下:
%%sigmoid function
function val = s(x,w,b)
val = 1/(1+exp(-w'*x-b));
给定样本集,统计网络的误差
% % function num_correct = test_correct_classification(training_output,X,y)
% Given a test set and a trained ANN, compute how many samples are
% correctly classifed.
%input :
%training_output: the result of calling ANN_Training
%X: attributes. Every column represents a sample.
%y: target. should be 0 or 1. length(y) == size(X,2) is assumed.
%output:
% num_correct: number of samples that are correctly classifed.
function num_correct = test_correct_classification(training_output,X,y)
num_correct = 0;
H = length(training_output.wo);
N = length(y);
yh = zeros(H,1);
for i=1:N
for j=1:H
yh(j) = s(X(:,i),training_output.W(:,j),training_output.b(j));
end
num_correct = num_correct + (s(yh,training_output.wo,training_output.bo)>0.5 == y(i));
end
end
下面根据以上代码对样本集进行训练和测试。这里用到了著名
补充:综合编程 , 其他综合 ,
- 更多Matlab疑问解答:
- matlab中信噪比与误码率曲线
- matlab的marker(比如"o"、"+")怎样才能变粗一点呢?
- 用matlab编出大直径墩柱上不规则的波浪力程序
- matlab语音识别系统设计中如何读取wav文件
- matlab编程求助: R(1,1)=y*(p^2+x^2) 、R(1,2)=Y*(P*X)解方程组 结果得到的是含有R的表达式 而不是值
- matlab 用后diff 数据长度变小 l
- 怎么在这段MATLAB程序导入的心电信号上添加一段50HZ的噪声?
- matlab gui 清除图形
- matlab在命令行中运行很好,存成m文件就不行,怎么办?
- 谁能帮我看下 这段程序到底哪出错了 关于Matlab牛顿环仿真的程序 高手救命啊
- matlab 2011a 下好了 怎么安装啊 没有setup.exe
- 帮我看一下下面的matlab程序,运行时提示有错误,谢谢
- matlab代码纠错
- matlab中如何将若干个大小相等的二位数组存入三维矩阵中(不用嵌套循环)?
- a(2:3,2:3)matlab元宝数组 这个是什么意思