TRSegCVUpdateFair
This commit is contained in:
parent
f50a442350
commit
82e2cd4496
2 changed files with 42 additions and 4 deletions
41
src/TR.jl
41
src/TR.jl
|
|
@ -49,8 +49,8 @@ for i = 1:n
|
|||
ys = ydata .- my;
|
||||
U, s, V = svd(Xs, full=false);
|
||||
|
||||
denom = broadcast(.+, broadcast(./, lambdasu, s'), s')'; # denom = lambda/s + lambda = (lambda + s's) / lambda
|
||||
denom2 = broadcast(.+, ones(n-1), broadcast(./, lambdasu', s.^2)) # denom2 = 1 + lambda/(s's) = (s's + lambda) / (s's)
|
||||
denom = broadcast(.+, broadcast(./, lambdasu, s'), s')';
|
||||
denom2 = broadcast(.+, ones(n-1), broadcast(./, lambdasu', s.^2))
|
||||
|
||||
# Calculating regression coefficients and residual
|
||||
bcoeffs = V * broadcast(./, (U' * ys), denom) .+ bOld .- V * broadcast(./, V' * bOld, denom2);
|
||||
|
|
@ -525,3 +525,40 @@ rmsecvman = sqrt.(1/n .* rmsecvman);
|
|||
return rmsecvman
|
||||
end
|
||||
|
||||
|
||||
"""
|
||||
K-fold CV for the Ridge regression problem, using the 'SVD-trick' for calculating the regression coefficients.
|
||||
"""
|
||||
function TRSegCVUpdateFair(X, y, lambdas, cv, bOld)
|
||||
|
||||
cv = cvfolds
|
||||
|
||||
n, p = size(X);
|
||||
rmsecvman = zeros(length(lambdasu));
|
||||
nfolds = length(unique(cv));
|
||||
|
||||
for i = 1:nfolds
|
||||
inds = (cv .== i);
|
||||
Xdata = X[vec(.!inds),:];
|
||||
ydata = y[vec(.!inds)];
|
||||
|
||||
mX = mean(Xdata, dims=1);
|
||||
my = mean(ydata);
|
||||
Xs = Xdata .- mX;
|
||||
ys = ydata .- my;
|
||||
|
||||
U, s, V = svd(Xs, full=false);
|
||||
|
||||
denom = broadcast(.+, broadcast(./, lambdasu, s'), s')';
|
||||
denom2 = broadcast(.+, ones(n-sum(inds)), broadcast(./, lambdasu', s.^2));
|
||||
|
||||
# Calculating regression coefficients
|
||||
bcoeffs = V * broadcast(./, (U' * ys), denom) .+ bOld .- V * broadcast(./, V' * bOld, denom2);
|
||||
rmsecvman += sum((y[vec(inds)] .- ((X[vec(inds),:] .- mX) * bcoeffs .+ my)).^2, dims=1)';
|
||||
|
||||
end
|
||||
|
||||
rmsecvman = sqrt.(1/n .* rmsecvman);
|
||||
|
||||
return rmsecvman
|
||||
end
|
||||
|
|
@ -26,6 +26,7 @@ export plegendre
|
|||
export TRLooCVUpdateFair
|
||||
export TRLooCVUpdateNaive
|
||||
export TRSegCVUpdateNaive
|
||||
export TRSegCVUpdateFair
|
||||
|
||||
include("convenience.jl")
|
||||
include("TR.jl")
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue