ReML estimation of covariance components from y*y'
FORMAT [C,h,Ph,F] = pr_spm_reml(YY,X,Q,N,[OPT]);

YY  - (m x m) sample covariance matrix Y*Y'  {Y = (m x N) data matrix}
X   - (m x p) design matrix
Q   - {1 x q} covariance components
N   - number of samples (default 1)

OPT = 1 : log-normal hyper-parameterisation (with hyperpriors); default 0

C   - (m x m) estimated error covariance = h(1)*Q{1} + h(2)*Q{2} + ...
h   - (q x 1) ReML hyperparameters h
Ph  - (q x q) conditional precision of h [or of log(h), if OPT = 1]

F   - negative free energy F = log evidence = ln p(Y|X,Q) = ReML objective

Performs a Fisher-scoring ascent on F to find ReML variance parameter
estimates.
__________________________________________________________________________
Copyright (C) 2005 Wellcome Department of Imaging Neuroscience
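The Fisher-scoring ascent described above can be sketched for the linear case (OPT = 0, no hyperpriors) in a few lines. This is a minimal illustration under stated assumptions, not the SPM code: the toy problem (two non-overlapping diagonal variance components, a hypothetical constant-plus-trend design, simulated data, the true values in htrue) is invented here, and the starting values and stopping rule are simplified. The gradient, expected curvature and update step follow the same formulas as pr_spm_reml.

```matlab
% Minimal ReML-by-Fisher-scoring sketch (linear hyperparameters, no hyperpriors).
% Assumed toy problem: two diagonal variance components, m = 40, N = 200.
randn('state', 0);                        % reproducible draws (legacy syntax, MATLAB and Octave)
m = 40;
N = 200;
X = [ones(m,1), (1:m)'/m];                % hypothetical design: constant + linear trend
Q = {diag([ones(m/2,1); zeros(m/2,1)]), ...
     diag([zeros(m/2,1); ones(m/2,1)])};
htrue = [2; 0.5];                         % assumed true hyperparameters for the simulation

% simulate Y (m x N): each column has mean X*beta and covariance Ctrue
Ctrue = htrue(1)*Q{1} + htrue(2)*Q{2};
Y     = X*randn(2,N) + sqrtm(Ctrue)*randn(m,N);
YY    = Y*Y';

q  = numel(Q);
h  = ones(q,1);                           % simple starting values
X0 = orth(X);                             % ortho-normalised design
for k = 1:32
    % current covariance estimate C = sum_i h(i)*Q{i}
    C = zeros(m);
    for i = 1:q
        C = C + h(i)*Q{i};
    end
    iC  = inv(C);
    iCX = iC*X0;
    P   = iC - iCX*((X0'*iCX)\iCX');      % ReML projector P
    U   = eye(m) - P*YY/N;

    PQ = cell(1,q);
    g  = zeros(q,1);                      % gradient dF/dh
    H  = zeros(q,q);                      % expected curvature E{d2F/dh2}
    for i = 1:q
        PQ{i} = P*Q{i};
        g(i)  = -trace(PQ{i}*U)*N/2;
    end
    for i = 1:q
        for j = i:q
            H(i,j) = -trace(PQ{i}*PQ{j})*N/2;
            H(j,i) = H(i,j);
        end
    end

    dh = -H\g;                            % Fisher-scoring step
    h  = h + dh;
    if g'*dh < 1e-6, break, end           % stop when the predicted gain is negligible
end
disp(h')
```

With this much data the estimates in h should lie close to htrue; pr_spm_reml itself additionally supports the log-normal parameterisation, hyperpriors and step clamping.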
function [C,h,Ph,F] = pr_spm_reml(YY,X,Q,N,OPT)
% ReML estimation of covariance components from y*y'
% FORMAT [C,h,Ph,F] = pr_spm_reml(YY,X,Q,N,[OPT]);
%
% YY  - (m x m) sample covariance matrix Y*Y'  {Y = (m x N) data matrix}
% X   - (m x p) design matrix
% Q   - {1 x q} covariance components
% N   - number of samples (default 1)
%
% OPT = 1 : log-normal hyper-parameterisation (with hyperpriors); default 0
%
% C   - (m x m) estimated error covariance = h(1)*Q{1} + h(2)*Q{2} + ...
% h   - (q x 1) ReML hyperparameters h
% Ph  - (q x q) conditional precision of h [or of log(h), if OPT = 1]
%
% F   - negative free energy F = log evidence = ln p(Y|X,Q) = ReML objective
%
% Performs a Fisher-scoring ascent on F to find ReML variance parameter
% estimates.
%__________________________________________________________________________
% Copyright (C) 2005 Wellcome Department of Imaging Neuroscience

% John Ashburner & Karl Friston
% $Id: spm_reml.m 456 2006-02-22 18:46:29Z karl $

% assume a single sample if not specified
%--------------------------------------------------------------------------
if nargin < 4 || isempty(N)
    N = 1;
end

% assume OPT = 0 if not specified
%--------------------------------------------------------------------------
if nargin < 5 || isempty(OPT)
    OPT = 0;
end

% ortho-normalise X
%--------------------------------------------------------------------------
if isempty(X)
    X = sparse(length(Q{1}),1);
end
X      = orth(full(X));
[n, p] = size(X);

% initialise h
%--------------------------------------------------------------------------
m     = length(Q);
h     = zeros(m,1);
dh    = zeros(m,1);
dFdh  = zeros(m,1);
dFdhh = zeros(m,m);


% initialise and specify hyperpriors
%--------------------------------------------------------------------------
if OPT
    hP = eye(m,m)/32;
    hE = h - 32;
    for i = 1:m
        h(i) = -log(normest(Q{i}));
    end
else
    hE = zeros(m,1);
    hP = zeros(m,m);
    for i = 1:m
        h(i) = any(diag(Q{i}));
    end
end


% ReML (EM/VB)
%--------------------------------------------------------------------------
for k = 1:64

    % compute current estimate of covariance
    %----------------------------------------------------------------------
    C = sparse(n,n);
    for i = 1:m
        if OPT
            C = C + Q{i}*exp(h(i));
        else
            C = C + Q{i}*h(i);
        end
    end
    iC = inv(C);

    % E-step: conditional covariance cov(B|y) {Cq}
    %======================================================================
    iCX   = iC*X;
    Cq    = pinv(X'*iCX);
    XCXiC = X*Cq*iCX';

    % M-step: ReML estimate of hyperparameters
    %======================================================================

    % Gradient dF/dh (first derivatives)
    %----------------------------------------------------------------------
    P  = iC - iC*XCXiC;
    U  = speye(n) - P*YY/N;
    PQ = cell(1,m);
    for i = 1:m

        % dF/dh = -trace(dF/diC*iC*Q{i}*iC)
        %------------------------------------------------------------------
        PQ{i} = P*Q{i};
        if OPT
            PQ{i} = PQ{i}*exp(h(i));
        end
        dFdh(i) = -trace(PQ{i}*U)*N/2;

    end

    % Expected curvature E{dF/dhh} (second derivatives)
    %----------------------------------------------------------------------
    for i = 1:m
        for j = i:m

            % dF/dhh = -trace{P*Q{i}*P*Q{j}}
            %--------------------------------------------------------------
            dFdhh(i,j) = -trace(PQ{i}*PQ{j})*N/2;
            dFdhh(j,i) = dFdhh(i,j);

        end
    end

    % add hyperpriors
    %----------------------------------------------------------------------
    dFdh  = dFdh  - hP*(h - hE);
    dFdhh = dFdhh - hP;

    % Fisher scoring: update dh = -inv(ddF/dhh)*dF/dh
    %----------------------------------------------------------------------
    Ph = -dFdhh;
    dh = -pinv(dFdhh)*dFdh;

    % preclude numerical overflow
    %----------------------------------------------------------------------
    if OPT
        dh = min(dh, 8);
        dh = max(dh,-8);
    end
    h = h + dh;

    % Convergence (expected change in log-evidence, dF/dh'*dh, below 1e-2)
    %======================================================================
    w = dFdh'*dh;
    fprintf('%-30s: %i %30s%e\n',' ReML Iteration',k,'...',full(w));
    if w < 1e-2, break, end

end

% log evidence = ln p(y|X,Q) = ReML objective = F = trace(R'*iC*R*YY)/2 ...
%--------------------------------------------------------------------------
if nargout > 3

    F = - trace(C*P*YY*P)/2 ...
        - N*n*log(2*pi)/2 ...
        - N*pr_spm_logdet(C)/2 ...
        + N*pr_spm_logdet(Cq)/2 ...
        - pr_spm_logdet(Ph)/2 ...
        + pr_spm_logdet(hP)/2;
end

% return exp(h) if log-normal hyperpriors
%--------------------------------------------------------------------------
if OPT
    h = exp(h);
end
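One way to check the gradient line dFdh(i) = -trace(PQ{i}*U)*N/2 is to compare it with central finite differences of the ReML objective F(h) = -N/2*ln|C| - N/2*ln|X'*iC*X| - trace(P*YY)/2 + const. This agrees with the F computed above up to the 2*pi constant and the hyperprior terms, because P*C*P = P (so trace(C*P*YY*P) = trace(P*YY)) and ln|Cq| = -ln|X'*iC*X|. The sketch below assumes a small random problem; Cof, Pof and Fof are hypothetical helpers written as anonymous functions, not SPM functions. It also checks the log-normal case (OPT = 1), where the chain rule gives dF/dlog(h) = h.*dF/dh, which is what the PQ{i}*exp(h(i)) scaling implements.

```matlab
% Finite-difference check of the ReML gradient formula (assumed toy problem).
randn('state', 1);
n = 12;  N = 5;  q = 2;
X  = orth(randn(n,2));                   % ortho-normalised design
A  = randn(n);
Q  = {eye(n), A*A'/n};                   % two positive definite components
Y  = randn(n,N);
YY = Y*Y';
h  = [1.5; 0.7];                         % arbitrary evaluation point

% hypothetical helpers: C(h), ReML projector P(C), and F(h) up to a constant
Cof = @(h) h(1)*Q{1} + h(2)*Q{2};
Pof = @(C) inv(C) - (C\X)*((X'*(C\X))\(C\X)');
Fof = @(h) -N*log(det(Cof(h)))/2 - N*log(det(X'*(Cof(h)\X)))/2 - trace(Pof(Cof(h))*YY)/2;

% analytic gradient, same expression as pr_spm_reml: -trace(P*Q{i}*U)*N/2
P = Pof(Cof(h));
U = eye(n) - P*YY/N;
g = zeros(q,1);
for i = 1:q
    g(i) = -trace(P*Q{i}*U)*N/2;
end

% central finite differences of F with respect to h
e  = 1e-6;
gn = zeros(q,1);
for i = 1:q
    d     = zeros(q,1);
    d(i)  = e;
    gn(i) = (Fof(h + d) - Fof(h - d))/(2*e);
end

% log-normal parameterisation: h = exp(lam), so dF/dlam = h.*dF/dh
lam = log(h);
gl  = g.*h;
gln = zeros(q,1);
for i = 1:q
    d      = zeros(q,1);
    d(i)   = e;
    gln(i) = (Fof(exp(lam + d)) - Fof(exp(lam - d)))/(2*e);
end

disp([g gn gl gln])
```

Columns one and two (and three and four) should agree to several significant digits; a mismatch would point to an error in the projector or in the trace expression.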