Added fast k-fold cv
This commit is contained in:
parent
41fd0b24e5
commit
7325edec01
2 changed files with 91 additions and 2 deletions
91
src/TR.jl
91
src/TR.jl
|
|
@ -64,7 +64,7 @@ end
|
||||||
"""
|
"""
|
||||||
function TRLooCV
|
function TRLooCV
|
||||||
|
|
||||||
bpress, bgcv, press, GCV, idminPRESS, lambdaPRESS, idminGCV, lambdaGCV = TRLooCV(X, y, lambdas, regType="L2", regParam1=1, regParam2=1)
|
bpress, bgcv, rmsecv, GCV, idminPRESS, lambdaPRESS, idminGCV, lambdaGCV = TRLooCV(X, y, lambdas, regType="L2", regParam1=1, regParam2=1)
|
||||||
|
|
||||||
regType: 'bc', 'legendre', 'L2', 'std', 'GL'
|
regType: 'bc', 'legendre', 'L2', 'std', 'GL'
|
||||||
"""
|
"""
|
||||||
|
|
@ -109,6 +109,7 @@ H = broadcast(.+, U.^2 * broadcast(./, s, denom), 1/n);
|
||||||
resid = broadcast(.-, y, U * broadcast(./, s .* (U'*y), denom));
|
resid = broadcast(.-, y, U * broadcast(./, s .* (U'*y), denom));
|
||||||
rescv = broadcast(./, resid, broadcast(.-, 1, H));
|
rescv = broadcast(./, resid, broadcast(.-, 1, H));
|
||||||
press = vec(sum(rescv.^2, dims=1));
|
press = vec(sum(rescv.^2, dims=1));
|
||||||
|
rmsecv = sqrt.(1/n .* press);
|
||||||
GCV = vec(broadcast(./, sum(resid.^2, dims=1), mean(broadcast(.-, 1, H), dims=1).^2));
|
GCV = vec(broadcast(./, sum(resid.^2, dims=1), mean(broadcast(.-, 1, H), dims=1).^2));
|
||||||
|
|
||||||
idminPRESS = findmin(press)[2][1]; # First index selects coordinates, second selects '1st coordinate'
|
idminPRESS = findmin(press)[2][1]; # First index selects coordinates, second selects '1st coordinate'
|
||||||
|
|
@ -126,5 +127,91 @@ end
|
||||||
bpress = bcoeffs[:, idminPRESS]
|
bpress = bcoeffs[:, idminPRESS]
|
||||||
bgcv = bcoeffs[:, idminGCV]
|
bgcv = bcoeffs[:, idminGCV]
|
||||||
|
|
||||||
return bpress, bgcv, press, GCV, idminPRESS, lambdaPRESS, idminGCV, lambdaGCV
|
return bpress, bgcv, rmsecv, GCV, idminPRESS, lambdaPRESS, idminGCV, lambdaGCV
|
||||||
|
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
function TRSegCV(X, y, lambdas, cv, regType="L2", regParam1=0, regParam2=1e-14)
|
||||||
|
|
||||||
|
Segmented cross-validation based on the Sherman-Morrison-Woodbury updating formula.
|
||||||
|
Inputs:
|
||||||
|
- X : Data matrix
|
||||||
|
- y : Response vector
|
||||||
|
- lambdas : Vector of regularization parameter values
|
||||||
|
- cv : Vector of length n indicating segment membership for each sample
|
||||||
|
- regType, regParam1, regParam2 : Inputs to regularizationMatrix function
|
||||||
|
|
||||||
|
Outputs: b_lambda_min, rmsecv, lambda_min, lambda_min_ind
|
||||||
|
"""
|
||||||
|
function TRSegCV(X, y, lambdas, cv, regType="L2", regParam1=0, regParam2=1e-14)
|
||||||
|
|
||||||
|
n, p = size(X);
|
||||||
|
mX = mean(X, dims=1);
|
||||||
|
X = X .- mX;
|
||||||
|
my = mean(y);
|
||||||
|
y = y .- my;
|
||||||
|
|
||||||
|
if regType == "bc"
|
||||||
|
regMat = [I(p); zeros(regParam1,p)]; for i = 1:regParam1 regMat = diff(regMat, dims = 1); end
|
||||||
|
elseif regType == "legendre"
|
||||||
|
regMat = [I(p); zeros(regParam1,p)]; for i = 1:regParam1 regMat = diff(regMat, dims = 1); end
|
||||||
|
P, _ = plegendre(regParam-1, p);
|
||||||
|
regMat[end-regParam1+1:end,:] = sqrt(regParam2) * P;
|
||||||
|
elseif regType == "L2"
|
||||||
|
regMat = I(p);
|
||||||
|
elseif regType == "std"
|
||||||
|
regMat = Diagonal(vec(std(X, dims=1)));
|
||||||
|
elseif regType == "GL" # Grünwald-Letnikov fractional derivative regulariztion
|
||||||
|
# regParam1 is alpha (order of fractional derivative)
|
||||||
|
C = ones(p)*1.0;
|
||||||
|
for k in 2:p
|
||||||
|
C[k] = (1-(regParam1+1)/(k-1)) * C[k-1];
|
||||||
|
end
|
||||||
|
|
||||||
|
regMat = zeros(p,p);
|
||||||
|
|
||||||
|
for i in 1:p
|
||||||
|
regMat[i:end, i] = C[1:end-i+1];
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
|
||||||
|
X = X / regMat;
|
||||||
|
U, s, V = svd(X, full=false);
|
||||||
|
|
||||||
|
n_seg = maximum(cv);
|
||||||
|
n_lambdas = length(lambdas);
|
||||||
|
my = mean(y);
|
||||||
|
y = y .- my;
|
||||||
|
|
||||||
|
denom = broadcast(.+, broadcast(./, lambdas, s'), s')';
|
||||||
|
resid = broadcast(.-, y, U * broadcast(./, s .* (U'*y), denom));
|
||||||
|
rescv = zeros(n, n_lambdas);
|
||||||
|
sdenom = sqrt.(broadcast(./, s, denom))';
|
||||||
|
|
||||||
|
for seg in 1:n_seg
|
||||||
|
|
||||||
|
Useg = U[vec(cv .== seg),:];
|
||||||
|
Id = 1.0 * I(size(Useg,1)) .- 1/n;
|
||||||
|
|
||||||
|
for k in 1:n_lambdas
|
||||||
|
Uk = Useg .* sdenom[k,:]';
|
||||||
|
rescv[vec(cv .== seg), k] = (Id - Uk * Uk') \ resid[vec(cv .== seg), k];
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
press = sum(rescv.^2, dims=1)';
|
||||||
|
rmsecv = sqrt.(1/n .* press);
|
||||||
|
bcoeffs = V * broadcast(./, (U' * y), denom);
|
||||||
|
bcoeffs = regMat \ bcoeffs;
|
||||||
|
|
||||||
|
if my != 0
|
||||||
|
bcoeffs = [my .- mX*bcoeffs; bcoeffs];
|
||||||
|
end
|
||||||
|
|
||||||
|
lambda_min, lambda_min_ind = findmin(rmsecv)
|
||||||
|
lambda_min_ind = lambda_min_ind[1]
|
||||||
|
b_lambda_min = bcoeffs[:,lambda_min_ind]
|
||||||
|
|
||||||
|
return b_lambda_min, rmsecv, lambda_min, lambda_min_ind
|
||||||
end
|
end
|
||||||
|
|
@ -18,6 +18,8 @@ export importData
|
||||||
|
|
||||||
# From "TR.jl"
|
# From "TR.jl"
|
||||||
export TRLooCV
|
export TRLooCV
|
||||||
|
export TRSegCV
|
||||||
|
export regularizationMatrix
|
||||||
|
|
||||||
include("convenience.jl")
|
include("convenience.jl")
|
||||||
include("TR.jl")
|
include("TR.jl")
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue