Wednesday 7 November 2018

mnistclassify analysis 4

CG_CLASSIFY.m
% Version 1.000
%
% Code provided by Ruslan Salakhutdinov and Geoff Hinton
%
% Permission is granted for anyone to copy, use, modify, or distribute this
% program and accompanying programs and documents for any purpose, provided
% this copyright notice is retained and prominently displayed, along with
% a note saying that the original programs are available from our
% web page.
% The programs and documents are distributed without any warranty, express or
% implied.  As the programs were written for research purposes only, they have
% not been tested to the degree that would be advisable in any important
% application.  All use of these programs is entirely at the user's own risk.


function [f, df] = CG_CLASSIFY(VV,Dim,XX,target)
% CG_CLASSIFY  Cross-entropy loss and its gradient for a network with three
% logistic (sigmoid) hidden layers and a softmax output layer.
%
% The [f, df] = fun(params, ...) signature is the form expected by a
% gradient-based minimiser; the CG_ prefix suggests it is used with a
% conjugate-gradient routine (presumably minimize.m -- confirm at the caller).
%
% Inputs:
%   VV     - all weights packed into one vector, in the order
%            w1, w2, w3, w_class. Each matrix is stored column-major and has
%            one extra (last) row holding the bias of the layer it feeds.
%   Dim    - layer sizes [l1 l2 l3 l4 l5]: input, three hidden layers,
%            output. The inline size comments below use the MNIST setting
%            this was written for (784 250 250 50 10).
%   XX     - N x l1 data matrix, one example per row.
%   target - N x l5 target matrix. The gradient (targetout - target) is the
%            exact cross-entropy gradient only when every row of target
%            sums to 1 (e.g. one-hot labels); this is not checked here.
%
% Outputs:
%   f  - scalar cross-entropy, summed over all N examples.
%   df - column vector, gradient of f w.r.t. VV, laid out exactly like VV.

l1 = Dim(1);   % input size            (784)
l2 = Dim(2);   % hidden layer 1 size   (250)
l3 = Dim(3);   % hidden layer 2 size   (250)
l4 = Dim(4);   % hidden layer 3 size   (50)
l5 = Dim(5);   % number of classes     (10)
N  = size(XX,1);   % number of examples

% Unpack the flat parameter vector into the four weight matrices.
w1 = reshape(VV(1:(l1+1)*l2),l1+1,l2);                      % (l1+1) x l2
xxx = (l1+1)*l2;
w2 = reshape(VV(xxx+1:xxx+(l2+1)*l3),l2+1,l3);              % (l2+1) x l3
xxx = xxx+(l2+1)*l3;
w3 = reshape(VV(xxx+1:xxx+(l3+1)*l4),l3+1,l4);              % (l3+1) x l4
xxx = xxx+(l3+1)*l4;
w_class = reshape(VV(xxx+1:xxx+(l4+1)*l5),l4+1,l5);         % (l4+1) x l5

% Forward pass. A column of ones is appended after each layer so the last
% weight row acts as that layer's bias.
XX = [XX ones(N,1)];                                          % N x (l1+1)
w1probs = 1./(1 + exp(-XX*w1));      w1probs = [w1probs ones(N,1)];  % N x (l2+1)
w2probs = 1./(1 + exp(-w1probs*w2)); w2probs = [w2probs ones(N,1)];  % N x (l3+1)
w3probs = 1./(1 + exp(-w2probs*w3)); w3probs = [w3probs ones(N,1)];  % N x (l4+1)

% Softmax output, computed in log space (log-sum-exp).
% - Subtracting each row's maximum logit before exp() prevents overflow,
%   which would otherwise give Inf/Inf = NaN.
% - Working with log-probabilities directly avoids 0*log(0) = NaN when a
%   probability underflows to zero.
% bsxfun is used (instead of implicit expansion) for compatibility with
% older MATLAB releases.
z = w3probs*w_class;                                          % N x l5 logits
z = bsxfun(@minus, z, max(z,[],2));
logtargetout = bsxfun(@minus, z, log(sum(exp(z),2)));         % N x l5 log-probs
targetout = exp(logtargetout);                                % rows sum to 1
f = -sum(sum(target.*logtargetout));                          % cross-entropy

% Backward pass. For softmax + cross-entropy (rows of target summing to 1),
% dE/dlogits = targetout - target.
Ix_class = targetout - target;                                % N x l5
dw_class = w3probs'*Ix_class;                                 % (l4+1) x l5

% Sigmoid derivative is p.*(1-p). The bias column is dropped because it
% is constant and has no incoming weights.
Ix3 = (Ix_class*w_class').*w3probs.*(1-w3probs);              % N x (l4+1)
Ix3 = Ix3(:,1:end-1);                                         % N x l4
dw3 = w2probs'*Ix3;                                           % (l3+1) x l4

Ix2 = (Ix3*w3').*w2probs.*(1-w2probs);                        % N x (l3+1)
Ix2 = Ix2(:,1:end-1);                                         % N x l3
dw2 = w1probs'*Ix2;                                           % (l2+1) x l3

Ix1 = (Ix2*w2').*w1probs.*(1-w1probs);                        % N x (l2+1)
Ix1 = Ix1(:,1:end-1);                                         % N x l2
dw1 = XX'*Ix1;                                                % (l1+1) x l2

% Repack the gradients in the same order and layout as VV.
df = [dw1(:); dw2(:); dw3(:); dw_class(:)];