Hopefully finished with general functions
This commit is contained in:
parent
3a4293b572
commit
1c0c6edf0f
2 changed files with 43 additions and 6 deletions
43
src/TR.jl
43
src/TR.jl
|
|
@ -470,7 +470,7 @@ The LS problem is solved explicitly and no shortcuts are used.
|
||||||
function TRSegCVNaive(X, y, lambdas, cvfolds)
|
function TRSegCVNaive(X, y, lambdas, cvfolds)
|
||||||
|
|
||||||
n, p = size(X);
|
n, p = size(X);
|
||||||
rmsecvman = zeros(length(lambdas));
|
rmsecv = zeros(length(lambdas));
|
||||||
nfolds = length(unique(cvfolds));
|
nfolds = length(unique(cvfolds));
|
||||||
|
|
||||||
for j = 1:length(lambdas)
|
for j = 1:length(lambdas)
|
||||||
|
|
@ -489,9 +489,9 @@ for j = 1:length(lambdas)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
rmsecvman = sqrt.(1/n .* rmsecvman);
|
rmsecv = sqrt.(1/n .* rmsecv);
|
||||||
|
|
||||||
return rmsecvman
|
return rmsecv
|
||||||
end
|
end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
@ -527,7 +527,7 @@ end
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
K-fold CV for the Ridge regression problem, using the 'SVD-trick' for calculating the regression coefficients.
|
K-fold CV for the Ridge regression update problem, using the 'SVD-trick' for calculating the regression coefficients.
|
||||||
"""
|
"""
|
||||||
function TRSegCVUpdateFair(X, y, lambdas, cv, bOld)
|
function TRSegCVUpdateFair(X, y, lambdas, cv, bOld)
|
||||||
|
|
||||||
|
|
@ -560,3 +560,38 @@ rmsecvman = sqrt.(1/n .* rmsecvman);
|
||||||
|
|
||||||
return rmsecvman
|
return rmsecvman
|
||||||
end
|
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
K-fold CV for the Ridge regression problem, using the 'SVD-trick' for calculating the regression coefficients.
|
||||||
|
"""
|
||||||
|
function TRSegCVFair(X, y, lambdas, cv)
|
||||||
|
|
||||||
|
n, p = size(X);
|
||||||
|
rmsecv = zeros(length(lambdas));
|
||||||
|
nfolds = length(unique(cv));
|
||||||
|
|
||||||
|
for i = 1:nfolds
|
||||||
|
inds = (cv .== i);
|
||||||
|
Xdata = X[vec(.!inds),:];
|
||||||
|
ydata = y[vec(.!inds)];
|
||||||
|
|
||||||
|
mX = mean(Xdata, dims=1);
|
||||||
|
my = mean(ydata);
|
||||||
|
Xs = Xdata .- mX;
|
||||||
|
ys = ydata .- my;
|
||||||
|
|
||||||
|
U, s, V = svd(Xs, full=false);
|
||||||
|
|
||||||
|
denom = broadcast(.+, broadcast(./, lambdas, s'), s')';
|
||||||
|
denom2 = broadcast(.+, ones(n-sum(inds)), broadcast(./, lambdas', s.^2));
|
||||||
|
|
||||||
|
# Calculating regression coefficients
|
||||||
|
bcoeffs = V * broadcast(./, (U' * ys), denom);
|
||||||
|
rmsecv += sum((y[vec(inds)] .- ((X[vec(inds),:] .- mX) * bcoeffs .+ my)).^2, dims=1)';
|
||||||
|
|
||||||
|
end
|
||||||
|
|
||||||
|
rmsecv = sqrt.(1/n .* rmsecv);
|
||||||
|
|
||||||
|
return rmsecv
|
||||||
|
end
|
||||||
|
|
@ -27,6 +27,8 @@ export TRLooCVUpdateFair
|
||||||
export TRLooCVUpdateNaive
|
export TRLooCVUpdateNaive
|
||||||
export TRSegCVUpdateNaive
|
export TRSegCVUpdateNaive
|
||||||
export TRSegCVUpdateFair
|
export TRSegCVUpdateFair
|
||||||
|
export TRSegCVNaive
|
||||||
|
export TRSegCVFair
|
||||||
|
|
||||||
include("convenience.jl")
|
include("convenience.jl")
|
||||||
include("TR.jl")
|
include("TR.jl")
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue