REML estimation of covariance components from Cov{y}

FORMAT [Ce,h,W,u] = spm_reml(Cy,X,Q,TOL);

Cy  - (m x m) data covariance matrix y*y' {y = (m x n) data matrix}
X   - (m x p) design matrix
Q   - {1 x q} covariance components
TOL - Tolerance {default = 1e-6}

Ce  - (m x m) estimated errors = h(1)*Q{1} + h(2)*Q{2} + ...
h   - (q x 1) hyperparameters
W   - (q x q) W*n = precision of hyperparameter estimates
u   - {1 x p} estimable components C{i} = u(1,i)*Q{1} + u(2,i)*Q{2} +...
___________________________________________________________________________
@(#)spm_reml.m 2.22 John Ashburner, Karl Friston 03/03/26
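The help text states that Ce is a weighted sum of the components Q{i} and that u encodes the estimable combinations C{i}. The following is a minimal MATLAB/Octave sketch on a hypothetical toy problem (four scans, one regressor, three components of which one is redundant). It shows how a redundant component appears as a null direction of the estimability matrix W, and why mapping hyperparameters back with h = u*h reproduces the same Ce. The eigenvalue-threshold step is an assumed stand-in for pr_spm_svd, whose source is not shown here; the names hr, CeRot and CeOrig are illustrative only.

```matlab
% Hypothetical toy setup (not from spm_reml): n = 4 scans, one regressor,
% and q = 3 components where Q{1} = Q{2} + Q{3} is redundant.
n = 4;
X = orth(ones(n,1));                      % orthonormal design, as spm_reml uses
Q = {eye(n), diag([1 1 0 0]), diag([0 0 1 1])};
q = numel(Q);

% Estimability matrix as in spm_reml: W(i,j) = trace(RQ{i}*RQ{j})
RQ = cell(1,q);
for i = 1:q
    RQ{i} = Q{i} - X*(X'*Q{i});           % remove the part explained by X
end
W = zeros(q);
for i = 1:q
    for j = 1:q
        W(i,j) = trace(RQ{i}*RQ{j});
    end
end

% Assumed stand-in for pr_spm_svd (its code is not shown): keep eigenvectors
% of W whose eigenvalues, normalized by the largest, exceed 1e-6
[V, D] = eig((W + W')/2);
d = diag(D);
u = V(:, d/max(d) > 1e-6);                % (q x r) estimable combinations
r = size(u,2);

% Rotated components C{k} = u(1,k)*Q{1} + u(2,k)*Q{2} + ...
C = cell(1,r);
for k = 1:r
    C{k} = zeros(n);
    for j = 1:q
        C{k} = C{k} + u(j,k)*Q{j};
    end
end

% Illustrative hyperparameters on the rotated components ...
hr = [0.5; 2];
CeRot = hr(1)*C{1} + hr(2)*C{2};

% ... map back to the original components through h = u*hr
h = u*hr;
CeOrig = zeros(n);
for j = 1:q
    CeOrig = CeOrig + h(j)*Q{j};
end
fprintf('estimable components: %d of %d\n', r, q);
```

Here W has a zero eigenvalue along [1; -1; -1] because RQ{1} = RQ{2} + RQ{3}, so only two combinations survive, and CeRot equals CeOrig up to rounding error for any choice of hr.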
function [Ce,h,W,u] = spm_reml(Cy,X,Q,TOL);
% REML estimation of covariance components from Cov{y}
% FORMAT [Ce,h,W,u] = spm_reml(Cy,X,Q,TOL);
%
% Cy  - (m x m) data covariance matrix y*y' {y = (m x n) data matrix}
% X   - (m x p) design matrix
% Q   - {1 x q} covariance components
% TOL - Tolerance {default = 1e-6}
%
% Ce  - (m x m) estimated errors = h(1)*Q{1} + h(2)*Q{2} + ...
% h   - (q x 1) hyperparameters
% W   - (q x q) W*n = precision of hyperparameter estimates
% u   - {1 x p} estimable components C{i} = u(1,i)*Q{1} + u(2,i)*Q{2} +...
%___________________________________________________________________________
% @(#)spm_reml.m 2.22 John Ashburner, Karl Friston 03/03/26

% set tolerance if not specified
%---------------------------------------------------------------------------
if nargin < 4, TOL = 1e-6; end

% ensure X is not rank deficient
%---------------------------------------------------------------------------
X = full(X);
X = orth(X);
X = sparse(X);

% find estimable components (encoded in the precision matrix W)
%---------------------------------------------------------------------------
m = length(Q);
n = length(Cy);
W = zeros(m,m);
for i = 1:m
    RQ{i} = Q{i} - X*(X'*Q{i});
end
for i = 1:m
    for j = i:m
        dFdhh  = sum(sum(RQ{i}.*RQ{j}'));
        W(i,j) = dFdhh;
        W(j,i) = dFdhh;
    end
end

% eliminate inestimable components
% NB: The threshold for normalized eigenvalues is 1e-6 in spm_svd
%---------------------------------------------------------------------------
u = pr_spm_svd(W);
for i = 1:size(u,2)
    C{i} = sparse(n,n);
    for j = 1:m
        C{i} = C{i} + Q{j}*u(j,i);
    end
end
Q = C;

% initialize hyperparameters (assuming Cov{e} = I)
%---------------------------------------------------------------------------
m    = length(Q);
dFdh = zeros(m,1);
W    = zeros(m,m);
C    = [];
for i = 1:m
    C = [C Q{i}(:)];
end
I = speye(n,n);
h = (C'*C)\(C'*I(:));

% Iterative EM
%---------------------------------------------------------------------------
for k = 1:32

    % Q are variance components
    %------------------------------------------------------------------
    Ce = sparse(n,n);
    for i = 1:m
        Ce = Ce + h(i)*Q{i};
    end
    iCe = inv(Ce);

    % E-step: conditional covariance cov(B|y) {Cby}
    %===================================================================
    iCeX = iCe*X;
    Cby  = inv(X'*iCeX);

    % M-step: ReML estimate of hyperparameters
    %===================================================================

    % Gradient dF/dh (first derivatives)
    %-------------------------------------------------------------------
    P   = iCe - iCeX*Cby*iCeX';
    PCy = Cy*P' - speye(n,n);
    for i = 1:m

        % dF/dh = -trace(dF/diCe*iCe*Q{i}*iCe) = trace((P*Cy - I)*P*Q{i})/2
        %---------------------------------------------------
        PQ{i}   = P*Q{i};
        dFdh(i) = sum(sum(PCy.*PQ{i}))/2;

    end

    % Expected curvature E{ddF/dhh} (second derivatives)
    %-------------------------------------------------------------------
    for i = 1:m
        for j = i:m

            % ddF/dhh = -trace{P*Q{i}*P*Q{j}}
            % NB: sum(sum(A.*B')) = trace(A*B), so PQ{j} is transposed
            %---------------------------------------------------
            dFdhh  = sum(sum(PQ{i}.*PQ{j}'))/2;
            W(i,j) = dFdhh;
            W(j,i) = dFdhh;
        end
    end

    % Fisher scoring: update dh = -inv(ddF/dhh)*dF/dh
    %-------------------------------------------------------------------
    dh = W\dFdh;
    h  = h + dh;

    % Convergence (or break if there is only one hyperparameter)
    %===================================================================
    w = dFdh'*dFdh;
    if w < TOL || m == 1, break, end
    fprintf('%-30s: %i %30s%e\n',' ReML Iteration',k,'...',full(w));
end

% estimate of cov{e}
%---------------------------------------------------------------------------
Ce = sparse(n,n);
for i = 1:m
    Ce = Ce + h(i)*Q{i};
end

% rotate hyperparameter estimates and precision back
%---------------------------------------------------------------------------
h = u*h;
W = u*W*u';
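To make the M-step concrete, here is a minimal dense MATLAB/Octave sketch of the same Fisher-scoring loop on a hypothetical toy problem: eight scans in two groups with different error variances and a design containing only a common mean. It is not a drop-in replacement for spm_reml: it skips the orthogonalization of X, the pr_spm_svd reduction and sparse storage, and the names g and FI are illustrative. The gradient is trace(P*Q{i}*P*Cy)/2 - trace(P*Q{i})/2 and the expected information is trace(P*Q{i}*P*Q{j})/2, matching the ddF/dhh comment in the listing.

```matlab
% Hypothetical toy problem (not part of spm_reml): 8 scans in two groups of
% 4 with different error variances, and a design X holding a common mean.
n    = 8;
X    = ones(n,1);
Q    = cell(1,2);
Q{1} = diag([1 1 1 1 0 0 0 0]);           % component 1: group-1 variance
Q{2} = diag([0 0 0 0 1 1 1 1]);           % component 2: group-2 variance
Cy   = 2*Q{1} + 0.5*Q{2};                 % exact covariance, true h = [2; 0.5]

m = numel(Q);
h = [1; 1];                               % corresponds to Cov{e} = I
for k = 1:32
    % current error covariance and the residual-forming precision P
    Ce = zeros(n);
    for i = 1:m
        Ce = Ce + h(i)*Q{i};
    end
    iCe  = inv(Ce);
    iCeX = iCe*X;
    P    = iCe - iCeX*((X'*iCeX)\iCeX');

    % gradient g = dF/dh and expected information FI = -E{ddF/dhh}
    g  = zeros(m,1);
    FI = zeros(m,m);
    for i = 1:m
        PQi  = P*Q{i};
        g(i) = (trace(PQi*P*Cy) - trace(PQi))/2;
        for j = 1:m
            FI(i,j) = trace(PQi*P*Q{j})/2;
        end
    end

    % Fisher scoring step; stop once the gradient vanishes
    h = h + FI\g;
    if g'*g < 1e-12, break, end
end
disp(h')
```

On this toy problem the least-squares initialization used by spm_reml also yields h = [1; 1]. A single scoring step lands on the true values [2; 0.5], and the second iteration finds a zero gradient (when Cy equals the modeled covariance exactly, P*X = 0 makes the true h a fixed point), so the loop stops.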