CNN代码理解.doc
《CNN代码理解.doc》由会员分享,可在线阅读,更多相关《CNN代码理解.doc(26页珍藏版)》请在咨信网上搜索。
1、CNN代码理解 下面是自己对代码的注释:cnnexamples.mclear all; close all; clc;addpath(./data);addpath(./util);load mnist_uint8; train_x = double(reshape(train_x,28,28,60000)/255;test_x = double(reshape(test_x,28,28,10000)/255;train_y = double(train_y);test_y = double(test_y); % ex1 %will run 1 epoch in about 200 secon
2、d and get around 11% error. %With 100 epochs youll get around 1.2% error cnn.layers = struct(type, i) %input layer struct(type, c, outputmaps, 6, kernelsize, 5) %convolution layer struct(type, s, scale, 2) %subsampling layer struct(type, c, outputmaps, 12,kernelsize, 5) %convolution layer struct(typ
3、e, s, scale, 2)%subsampling layer; % 这里把cnn的设置给cnnsetup,它会据此构建一个完整的CNN网络,并返回cnn = cnnsetup(cnn, train_x, train_y); % 学习率opts.alpha = 1;% 每次挑出一个batchsize的batch来训练,也就是每用batchsize个样本就调整一次权值,而不是% 把所有样本都输入了,计算所有样本的误差了才调整一次权值opts.batchsize = 50; % 训练次数,用同样的样本集。我训练的时候:% 1的时候 11.41% error% 5的时候 4.2% error%
4、10的时候 2.73% erroropts.numepochs = 10; % 然后开始把训练样本给它,开始训练这个CNN网络cnn = cnntrain(cnn, train_x, train_y, opts); % 然后就用测试样本来测试er, bad = cnntest(cnn, test_x, test_y); %plot mean squared errorplot(cnn.rL);%show test errordisp(num2str(er*100) % error);cnnsetup.mfunction net = cnnsetup(net, x, y) inputmaps =
5、 1; % B=squeeze(A) 返回和矩阵A相同元素但所有单一维都移除的矩阵B,单一维是满足size(A,dim)=1的维。 % train_x中图像的存放方式是三维的reshape(train_x,28,28,60000),前面两维表示图像的行与列, % 第三维就表示有多少个图像。这样squeeze(x(:, :, 1)就相当于取第一个图像样本后,再把第三维 % 移除,就变成了28x28的矩阵,也就是得到一幅图像,再size一下就得到了训练样本图像的行数与列数了 mapsize =size(squeeze(x(:, :, 1); % 下面通过传入net这个结构体来逐层构建CNN网络 %
6、 n = numel(A)返回数组A中元素个数 % net.layers中有五个struct类型的元素,实际上就表示CNN共有五层,这里范围的是5 for l = 1 :numel(net.layers) % layer ifstrcmp(net.layersl.type, s) % 如果这层是 子采样层 %subsampling层的mapsize,最开始mapsize是每张图的大小28*28 % 这里除以scale=2,就是pooling之后图的大小,pooling域之间没有重叠,所以pooling后的图像为14*14 % 注意这里的右边的mapsize保存的都是上一层每张特征map的大小,
7、它会随着循环进行不断更新 mapsize =floor(mapsize / net.layersl.scale); for j =1 : inputmaps % inputmap就是上一层有多少张特征图 net.layersl.bj = 0; % 将偏置初始化为0 end end ifstrcmp(net.layersl.type, c) % 如果这层是 卷积层 % 旧的mapsize保存的是上一层的特征map的大小,那么如果卷积核的移动步长是1,那用 %kernelsize*kernelsize大小的卷积核卷积上一层的特征map后,得到的新的map的大小就是下面这样 mapsize =map
8、size - net.layersl.kernelsize + 1; % 该层需要学习的参数个数。每张特征map是一个(后层特征图数量)*(用来卷积的patch图的大小) % 因为是通过用一个核窗口在上一个特征map层中移动(核窗口每次移动1个像素),遍历上一个特征map % 层的每个神经元。核窗口由kernelsize*kernelsize个元素组成,每个元素是一个独立的权值,所以 % 就有kernelsize*kernelsize个需要学习的权值,再加一个偏置值。另外,由于是权值共享,也就是 % 说同一个特征map层是用同一个具有相同权值元素的kernelsize*kernelsize的核
9、窗口去感受输入上一 % 个特征map层的每个神经元得到的,所以同一个特征map,它的权值是一样的,共享的,权值只取决于 % 核窗口。然后,不同的特征map提取输入上一个特征map层不同的特征,所以采用的核窗口不一样,也 % 就是权值不一样,所以outputmaps个特征map就有(kernelsize*kernelsize+1)* outputmaps那么多的权值了 % 但这里fan_out只保存卷积核的权值W,偏置b在下面独立保存 fan_out= net.layersl.outputmaps * net.layersl.kernelsize 2; for j =1 : net.layers
10、l.outputmaps % output map %fan_out保存的是对于上一层的一张特征map,我在这一层需要对这一张特征map提取outputmaps种特征, % 提取每种特征用到的卷积核不同,所以fan_out保存的是这一层输出新的特征需要学习的参数个数 % 而,fan_in保存的是,我在这一层,要连接到上一层中所有的特征map,然后用fan_out保存的提取特征 % 的权值来提取他们的特征。也即是对于每一个当前层特征图,有多少个参数链到前层 fan_in =inputmaps * net.layersl.kernelsize 2; fori = 1 : inputmaps % i
11、nput map % 随机初始化权值,也就是共有outputmaps个卷积核,对上层的每个特征map,都需要用这么多个卷积核 % 去卷积提取特征。 % rand(n)是产生nn的 0-1之间均匀取值的数值的矩阵,再减去0.5就相当于产生-0.5到0.5之间的随机数 % 再 *2 就放大到 -1, 1。然后再乘以后面那一数,why? % 反正就是将卷积核每个元素初始化为-sqrt(6 / (fan_in + fan_out), sqrt(6 / (fan_in +fan_out) % 之间的随机数。因为这里是权值共享的,也就是对于一张特征map,所有感受野位置的卷积核都是一样的 % 所以只需要保
12、存的是 inputmaps * outputmaps 个卷积核。 net.layersl.kij = (rand(net.layersl.kernelsize) - 0.5) * 2 *sqrt(6 / (fan_in + fan_out); end net.layersl.bj = 0; % 将偏置初始化为0 end % 只有在卷积层的时候才会改变特征map的个数,pooling的时候不会改变个数。这层输出的特征map个数就是 % 输入到下一层的特征map个数 inputmaps = net.layersl.outputmaps; end end % fvnum 是输出层的前面一层的神经元个
13、数。 % 这一层的上一层是经过pooling后的层,包含有inputmaps个特征map。每个特征map的大小是mapsize。 % 所以,该层的神经元个数是 inputmaps * (每个特征map的大小) % prod: Productof elements. % For vectors,prod(X) is the product of the elements of X % 在这里 mapsize = 特征map的行数 特征map的列数,所以prod后就是 特征map的行*列 fvnum =prod(mapsize) * inputmaps; % onum 是标签的个数,也就是输出层神
14、经元的个数。你要分多少个类,自然就有多少个输出神经元 onum = size(y,1); % 这里是最后一层神经网络的设定 % ffb 是输出层每个神经元对应的基biases net.ffb =zeros(onum, 1); % ffW 输出层前一层 与 输出层 连接的权值,这两层之间是全连接的 net.ffW =(rand(onum, fvnum) - 0.5) * 2 * sqrt(6 / (onum + fvnum);endcnntrain.mfunction net = cnntrain(net, x, y, opts) m = size(x, 3); % m 保存的是 训练样本个数
15、numbatches = m/ opts.batchsize; % rem: Remainder after division. rem(x,y) is x -n.*y 相当于求余 % rem(numbatches,1) 就相当于取其小数部分,如果为0,就是整数 ifrem(numbatches, 1) = 0 error(numbatches not integer); end net.rL = ; for i = 1 : opts.numepochs % disp(X) 打印数组元素。如果X是个字符串,那就打印这个字符串 disp(epoch num2str(i) / num2str(op
16、ts.numepochs); % tic 和 toc 是用来计时的,计算这两条语句之间所耗的时间 tic; % P =randperm(N) 返回1, N之间所有整数的一个随机的序列,例如 % randperm(6) 可能会返回 2 4 5 6 1 3 % 这样就相当于把原来的样本排列打乱,再挑出一些样本来训练 kk =randperm(m); forl = 1 : numbatches % 取出打乱顺序后的batchsize个样本和对应的标签 batch_x= x(:, :, kk(l - 1) *opts.batchsize + 1 : l * opts.batchsize); batch
17、_y= y(:, kk(l - 1) * opts.batchsize + 1 : l * opts.batchsize); % 在当前的网络权值和网络输入下计算网络的输出 net =cnnff(net, batch_x); % Feedforward % 得到上面的网络输出后,通过对应的样本标签用bp算法来得到误差对网络权值 %(也就是那些卷积核的元素)的导数 net =cnnbp(net, batch_y); % Backpropagation % 得到误差对权值的导数后,就通过权值更新方法去更新权值 net =cnnapplygrads(net, opts); ifisempty(net
- 配套讲稿:
如PPT文件的首页显示word图标,表示该PPT已包含配套word讲稿。双击word图标可打开word文档。
- 特殊限制:
部分文档作品中含有的国旗、国徽等图片,仅作为作品整体效果示例展示,禁止商用。设计者仅对作品中独创性部分享有著作权。
- 关 键 词:
- 完整 word CNN 代码 理解
1、咨信平台为文档C2C交易模式,即用户上传的文档直接被用户下载,收益归上传人(含作者)所有;本站仅是提供信息存储空间和展示预览,仅对用户上传内容的表现方式做保护处理,对上载内容不做任何修改或编辑。所展示的作品文档包括内容和图片全部来源于网络用户和作者上传投稿,我们不确定上传用户享有完全著作权,根据《信息网络传播权保护条例》,如果侵犯了您的版权、权益或隐私,请联系我们,核实后会尽快下架及时删除,并可随时和客服了解处理情况,尊重保护知识产权我们共同努力。
2、文档的总页数、文档格式和文档大小以系统显示为准(内容中显示的页数不一定正确),网站客服只以系统显示的页数、文件格式、文档大小作为仲裁依据,个别因单元格分列造成显示页码不一将协商解决,平台无法对文档的真实性、完整性、权威性、准确性、专业性及其观点立场做任何保证或承诺,下载前须认真查看,确认无误后再购买,务必慎重购买;若有违法违纪将进行移交司法处理,若涉侵权平台将进行基本处罚并下架。
3、本站所有内容均由用户上传,付费前请自行鉴别,如您付费,意味着您已接受本站规则且自行承担风险,本站不进行额外附加服务,虚拟产品一经售出概不退款(未进行购买下载可退充值款),文档一经付费(服务费)、不意味着购买了该文档的版权,仅供个人/单位学习、研究之用,不得用于商业用途,未经授权,严禁复制、发行、汇编、翻译或者网络传播等,侵权必究。
4、如你看到网页展示的文档有www.zixin.com.cn水印,是因预览和防盗链等技术需要对页面进行转换压缩成图而已,我们并不对上传的文档进行任何编辑或修改,文档下载后都不会有水印标识(原文档上传前个别存留的除外),下载后原文更清晰;试题试卷类文档,如果标题没有明确说明有答案则都视为没有答案,请知晓;PPT和DOC文档可被视为“模板”,允许上传人保留章节、目录结构的情况下删减部份的内容;PDF文档不管是原文档转换或图片扫描而得,本站不作要求视为允许,下载前自行私信或留言给上传者【精****】。
5、本文档所展示的图片、画像、字体、音乐的版权可能需版权方额外授权,请谨慎使用;网站提供的党政主题相关内容(国旗、国徽、党徽--等)目的在于配合国家政策宣传,仅限个人学习分享使用,禁止用于任何广告和商用目的。
6、文档遇到问题,请及时私信或留言给本站上传会员【精****】,需本站解决可联系【 微信客服】、【 QQ客服】,若有其他问题请点击或扫码反馈【 服务填表】;文档侵犯商业秘密、侵犯著作权、侵犯人身权等,请点击“【 版权申诉】”(推荐),意见反馈和侵权处理邮箱:1219186828@qq.com;也可以拔打客服电话:4008-655-100;投诉/维权电话:4009-655-100。