%%%%%%%%%%% Classification baseline with KNN and DTW %%%%%%%%%%%%%%%%%%%%%%
% Leave-one-speaker-out evaluation of a k-nearest-neighbour classifier.
%
% For every test speaker and every distance metric:
%   1. all sequences of the test speaker form the test set, the sequences
%      of all remaining speakers form the training set;
%   2. k is selected by stratified k-fold cross-validation on the training
%      set: in every fold each k that reaches the minimal validation error
%      receives one vote, and the k with the most votes wins (on a tie the
%      smallest such k is taken, because max() returns the first maximum);
%   3. the test error for the selected k is computed, the label|prediction
%      pairs are saved to a .mat file and the summary is handed to
%      writeToDtwKnnResultsFile.
%
% classifyKnn and writeToDtwKnnResultsFile are defined outside this script.
% NOTE(review): the meaning of the corpus columns below is derived only from
% how this script uses them - confirm against the corpus preprocessing code.
clear;clc;

% (Re)load the training corpus under consideration:
corpus = load("../processedCorpusAllEx.mat");
processedCorpus = corpus.processedCorpus;

% Corpus layout as used by this script:
labelColumn = 3;            % column converted to the categorical class label
speakerColumn = 5;          % column compared against the test speaker number
seqsForClassification = 6;  % column passed to the classifier as input data

% Experiment configuration:
numSpeakers = 9;
distanceMetrics = {'absolute','euclidean','symmkl'};
numFolds = 5;
kMax = 8;

outputFileContent = struct(); % save all file io variables in here
outputFileContent.kMax = kMax;
outputFileContent.numFolds = numFolds;

% Make sure the output folder exists before any results file is written:
resultsDir = 'Results';
if(exist(resultsDir, 'dir') ~= 7)
    mkdir(resultsDir);
end

rng(1337) % fixed seed so that the cross-validation splits are reproducible

%% %%% Start the training & evaluation %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
for testSubjectNumber = 1:numSpeakers
    testSubject = "speaker" + num2str(testSubjectNumber);

    %% Allocate training and test sets:
    % The test set is a full speaker:
    testSetIndices = ([processedCorpus{:,speakerColumn}] == testSubjectNumber);
    YTest = categorical([processedCorpus{testSetIndices,labelColumn}])';
    XTest = processedCorpus(testSetIndices,seqsForClassification);
    % The training set consists of all remaining speakers:
    YTrain = categorical([processedCorpus{~testSetIndices,labelColumn}])';
    XTrain = processedCorpus(~testSetIndices,seqsForClassification);

    % Create the cv-splits (stratified by class label):
    cvPartitions = cvpartition(YTrain,'k',numFolds);

    % Sanity check: count how often each training sequence appears in a
    % validation set; a zero means the sequence is never validated on.
    foldCoverage = zeros(1,numel(YTrain));
    for foldIndex = 1:numFolds
        foldCoverage = foldCoverage + reshape(cvPartitions.test(foldIndex), ...
                                              size(foldCoverage));
    end
    if(~all(foldCoverage))
        fprintf("Error: not all sequences are part of a validation test set!\n");
        return
    end

    %% "Train" the Knn Classifier by optimizing k and the DTW metric %%%%%%
    % Name of the results file. fullfile inserts the platform's file
    % separator instead of a hard-coded backslash; string() keeps the
    % variable a string scalar as before.
    resultsFileName = string(fullfile(resultsDir, ...
        sprintf("%s_optResults_%dFoldCv.txt", testSubject, numFolds)));

    %%% Iterate over the distance metric %%%
    for metricIndex = 1:numel(distanceMetrics)
        metric = distanceMetrics{metricIndex};
        outputFileContent.metric = metric;

        %%% Iterate over the folds: %%%
        kOptOverFolds = zeros(1,kMax);                    % votes per k
        validationErrorsOverFolds = zeros(kMax,numFolds); % rows: k, cols: fold
        for foldIndex = 1:numFolds
            % assign validation and training sets for the knn "training"
            isValidation = cvPartitions.test(foldIndex);
            isTraining = cvPartitions.training(foldIndex);
            XValidation = XTrain(isValidation);
            YValidation = YTrain(isValidation);
            XTrainCv = XTrain(isTraining);
            YTrainCv = YTrain(isTraining);

            %%% Iterate over k: %%%
            % allocate validation error vector:
            validationErrors = ones(1,kMax);
            % calculate validation error for each k:
            parfor k = 1:kMax
                predictedClasses = classifyKnn(XTrainCv, YTrainCv, XValidation, ...
                                               k, metric);
                % The labels are reshaped to the orientation of the
                % predictions before the element-wise comparison.
                foldError = 1 - mean(predictedClasses == ...
                                   reshape(YValidation, size(predictedClasses)));
                validationErrors(k) = foldError;
                fprintf("Validation error k = %d: %1.4f\n", k, foldError);
            end
            validationErrorsOverFolds(:,foldIndex) = validationErrors;

            % Every k that reaches the minimal error of this fold gets a vote
            % (find returns unique indices, so the vectorised update is safe):
            kOpts = find(validationErrors == min(validationErrors));
            kOptOverFolds(kOpts) = kOptOverFolds(kOpts) + 1;
        end
        outputFileContent.validationErrorsOverFolds = validationErrorsOverFolds;
        outputFileContent.kOptOverFolds = kOptOverFolds;

        % Calculate the test error:
        % Note: the test error is calculated for an optimized k and
        % every distance metric for evaluation purposes. Reported are
        % still the test results for {kopt, metric} with the highest CV
        % accuracy, not with the highest overall test accuracy.
        [~, kOpt] = max(kOptOverFolds);
        outputFileContent.kOpt = kOpt;

        predictedClasses = classifyKnn(XTrain, YTrain, XTest, kOpt, metric);
        testError = 1 - mean(predictedClasses == ...
                             reshape(YTest, size(predictedClasses)));

        % Save the label|prediction pairs. Both vectors are forced into
        % column shape so that the concatenation does not depend on the
        % orientation returned by classifyKnn.
        testAndPredictions = [YTest(:), predictedClasses(:)];
        save(testSubject + "_labelsAndPredictions_" + ...
            "_k=" + num2str(kOpt) + "_" + metric + ...
            ".mat", "testAndPredictions");

        outputFileContent.testError = testError;
        % Write results to file:
        writeToDtwKnnResultsFile(resultsFileName, outputFileContent);
    end
end
