Sunday 28 October 2018

mnistdeepauto analysis 8

%CG_MNIST.m
% Version 1.000
%
% Code provided by Ruslan Salakhutdinov and Geoff Hinton
%
% Permission is granted for anyone to copy, use, modify, or distribute this
% program and accompanying programs and documents for any purpose, provided
% this copyright notice is retained and prominently displayed, along with
% a note saying that the original programs are available from our
% web page.
% The programs and documents are distributed without any warranty, express or
% implied.  As the programs were written for research purposes only, they have
% not been tested to the degree that would be advisable in any important
% application.  All use of these programs is entirely at the user's own risk.

function [f, df] = CG_MNIST(VV,Dim,XX)
% CG_MNIST  Reconstruction cross-entropy of an 8-layer autoencoder and its gradient.
%
%   [f, df] = CG_MNIST(VV, Dim, XX)
%
%   Inputs
%     VV  - all network parameters as one vector: the eight weight matrices
%           w1..w8 (each with its bias as an extra last row), flattened
%           column-wise and concatenated in layer order.
%     Dim - layer sizes [l1 l2 ... l9]; entries 1..9 are read.
%     XX  - N x l1 data matrix, one case per row. It is used both as the
%           network input and as the reconstruction target.
%
%   Outputs
%     f   - scalar: cross-entropy between XX and its reconstruction,
%           summed over pixels and averaged over the N cases.
%     df  - column vector of partial derivatives of f, laid out exactly
%           like VV.
%
%   Layers 1-3 and 5-8 are logistic; layer 4 (the code layer) is linear.
%
%   The sizes quoted in trailing comments are for
%   Dim = [784 1000 500 250 30 250 500 1000 784] and N = 1000.

l1 = Dim(1); % 784
l2 = Dim(2); % 1000
l3 = Dim(3); % 500
l4 = Dim(4); % 250
l5 = Dim(5); % 30
l6 = Dim(6); % 250
l7 = Dim(7); % 500
l8 = Dim(8); % 1000
l9 = Dim(9); % 784
N = size(XX,1); % 1000

% Unpack the flat parameter vector into the per-layer weight matrices.
% Each matrix has (fan_in + 1) rows; the extra last row holds the biases.
w1 = reshape(VV(1:(l1+1)*l2),l1+1,l2); % 785x1000
xxx = (l1+1)*l2; % 785000
w2 = reshape(VV(xxx+1:xxx+(l2+1)*l3),l2+1,l3); % 1001x500
xxx = xxx+(l2+1)*l3; % 1285500
w3 = reshape(VV(xxx+1:xxx+(l3+1)*l4),l3+1,l4); % 501x250
xxx = xxx+(l3+1)*l4;
w4 = reshape(VV(xxx+1:xxx+(l4+1)*l5),l4+1,l5); % 251x30
xxx = xxx+(l4+1)*l5;
w5 = reshape(VV(xxx+1:xxx+(l5+1)*l6),l5+1,l6); % 31x250
xxx = xxx+(l5+1)*l6;
w6 = reshape(VV(xxx+1:xxx+(l6+1)*l7),l6+1,l7); % 251x500
xxx = xxx+(l6+1)*l7;
w7 = reshape(VV(xxx+1:xxx+(l7+1)*l8),l7+1,l8); % 501x1000
xxx = xxx+(l7+1)*l8;
w8 = reshape(VV(xxx+1:xxx+(l8+1)*l9),l8+1,l9); % 1001x784

% Forward pass. A column of ones is appended to every layer's output so
% that the bias row of the next weight matrix is applied by the product.
XX = [XX ones(N,1)]; % 1000x785
w1probs = 1./(1 + exp(-XX*w1)); w1probs = [w1probs ones(N,1)]; % 1000x785 * 785x1000 = 1000x1000 -> 1000x1001
w2probs = 1./(1 + exp(-w1probs*w2)); w2probs = [w2probs ones(N,1)]; % 1000x1001 * 1001x500 = 1000x500 -> 1000x501
w3probs = 1./(1 + exp(-w2probs*w3)); w3probs = [w3probs ones(N,1)]; % 1000x501 * 501x250 -> 1000x251
w4probs = w3probs*w4; w4probs = [w4probs ones(N,1)]; % linear code layer: 1000x251 * 251x30 -> 1000x31
w5probs = 1./(1 + exp(-w4probs*w5)); w5probs = [w5probs ones(N,1)]; % 1000x31 * 31x250 -> 1000x251
w6probs = 1./(1 + exp(-w5probs*w6)); w6probs = [w6probs ones(N,1)]; % 1000x251 * 251x500 -> 1000x501
w7probs = 1./(1 + exp(-w6probs*w7)); w7probs = [w7probs ones(N,1)]; % 1000x501 * 501x1000 -> 1000x1001
outAct = w7probs*w8; % output-layer logits: 1000x1001 * 1001x784 = 1000x784
XXout = 1./(1 + exp(-outAct)); % 1000x784

targets = XX(:,1:end-1); % the data without the appended bias column, 1000x784

% Cross-entropy evaluated directly from the logits z with targets t:
%   -[t*log(s(z)) + (1-t)*log(1-s(z))] = max(z,0) - t*z + log(1 + exp(-|z|))
% This is algebraically the same as the log(XXout) form but stays finite
% when the sigmoid saturates to exactly 0 or 1 in floating point; the
% log(XXout) form evaluates 0*log(0) = NaN in that situation.
f = 1/N*sum(sum( max(outAct,0) - targets.*outAct + log1p(exp(-abs(outAct))) )); % 1x1

% Backward pass. For a logistic output with cross-entropy loss the error
% signal at the output logits is simply (output - target).
IO = 1/N*(XXout-targets); % 1000x784
Ix8 = IO; % 1000x784
dw8 = w7probs'*Ix8; % 1001x1000 * 1000x784 = 1001x784

% For every lower layer: propagate through the weights, multiply by the
% logistic derivative p.*(1-p), then drop the last column, which belongs
% to the constant bias unit and has no incoming weights.
Ix7 = (Ix8*w8').*w7probs.*(1-w7probs); % 1000x784 * 784x1001 .* 1000x1001 = 1000x1001
Ix7 = Ix7(:,1:end-1); % 1000x1000
dw7 = w6probs'*Ix7; % 501x1000 * 1000x1000 = 501x1000

Ix6 = (Ix7*w7').*w6probs.*(1-w6probs); % 1000x1000 * 1000x501 = 1000x501
Ix6 = Ix6(:,1:end-1); % 1000x500
dw6 = w5probs'*Ix6; % 251x500

Ix5 = (Ix6*w6').*w5probs.*(1-w5probs); % 1000x500 * 500x251 = 1000x251
Ix5 = Ix5(:,1:end-1); % 1000x250
dw5 = w4probs'*Ix5; % 31x250

Ix4 = (Ix5*w5'); % code layer is linear, so no p.*(1-p) factor: 1000x250 * 250x31 = 1000x31
Ix4 = Ix4(:,1:end-1); % 1000x30
dw4 = w3probs'*Ix4; % 251x1000 * 1000x30 = 251x30

Ix3 = (Ix4*w4').*w3probs.*(1-w3probs); % 1000x30 * 30x251 .* 1000x251 = 1000x251
Ix3 = Ix3(:,1:end-1); % 1000x250
dw3 = w2probs'*Ix3; % 501x250

Ix2 = (Ix3*w3').*w2probs.*(1-w2probs); % 1000x250 * 250x501 .* 1000x501 = 1000x501
Ix2 = Ix2(:,1:end-1); % 1000x500
dw2 = w1probs'*Ix2; % 1001x500

Ix1 = (Ix2*w2').*w1probs.*(1-w1probs); % 1000x500 * 500x1001 = 1000x1001
Ix1 = Ix1(:,1:end-1); % 1000x1000
dw1 = XX'*Ix1; % 785x1000

% Flatten the gradients in the same order and layout as VV.
df = [dw1(:)' dw2(:)' dw3(:)' dw4(:)' dw5(:)' dw6(:)' dw7(:)' dw8(:)']'; % 2837314x1