% Data preparation: run exactly ONE of the calls below, choosing the one that
% matches the dataset currently loaded in the workspace.

% Dataset TM1 (Jasper), HPC only.
[SpikeTrainDataRest1, PositionDataRest1, ...
 SpikeTrainDataRun1, PositionDataRun1, ...
 SpikeTrainDataRest2, PositionDataRest2] = ...
    setupdata( hpc_PN, positiondataclean, 'setup.m' );

% Dataset TM1 (Jasper), HPC, PC and PFC together.
[SpikeTrainDataRest1, PositionDataRest1, ...
 SpikeTrainDataRun1, PositionDataRun1, ...
 SpikeTrainDataRest2, PositionDataRest2] = ...
    setupdata( {hpc_PN{:}, pc_PN{:}, pfc_PN{:}}, positiondataclean, 'setup.m' );

% Dataset LT1 (Kenan), HPC data (CA1 and CA3), both position recordings stacked.
[SpikeTrainDataRest1, PositionDataRest1, ...
 SpikeTrainDataRun1, PositionDataRun1, ...
 SpikeTrainDataRest2, PositionDataRest2, ...
 SpikeTrainDataRun2, PositionDataRun2, ...
 SpikeTrainDataRest3, PositionDataRest3] = ...
    setupdata( {CA1{:}, CA3{:}}, [positiondataT1clean; positiondataT2clean], 'setup.m' );

% Sanity check: display the matrix of valid discrete position bins for Run1.
% If the shape of the maze does not look right (e.g. it has holes), rerun the
% data preparation above with a different smoothingWidth parameter.
image(PositionDataRun1.validDiscretePositionsBinMat)


% OP model fitting (on the Run1 spike trains and positions).
% Configuration: model/data specs and prior, starting parameters, algorithm setup.
[ModelSpec, DataSpec, PriorParams] = setupmodel( ...
    'setup.m', SpikeTrainDataRun1, PositionDataRun1 );
ModelParams = samplemodelparametersfromprior( ...
    PriorParams, ModelSpec, DataSpec, ModelSpec.nParticles );
[AlgorithmParams, AlgorithmSpec] = setupalgorithm( ...
    ModelParams, ModelSpec, DataSpec, 'setup.m' );
% SMC algorithm: updates the algorithm and model parameters, returns diagnostics.
[AlgorithmParams, ModelParams, Diagnostics] = evolveparticles( ...
    AlgorithmParams, ModelParams, PriorParams, ...
    ModelSpec, DataSpec, AlgorithmSpec, ...
    SpikeTrainDataRun1, PositionDataRun1 );


% Estimated number of states: column index of the largest entry in the last
% row of the estimated log posterior over the state dimension.
[~, estNStates] = max( ...
    Diagnostics.estLogStateDimensionPosterior(end, :), [], 2);
% Parameter estimates for that number of states.
EstModelParams = estimateparametersfromsubpop( ...
    ModelParams, AlgorithmParams, ModelSpec, estNStates, DataSpec );



% State posteriors (decodestate called with mode argument 1).
[mapStateTrajectory, logSmoothedPost] = decodestate( ...
    EstModelParams, ModelSpec, AlgorithmSpec, ...
    SpikeTrainDataRun1, PositionDataRun1, DataSpec, ...
    1, DataSpec.startTime, DataSpec.endTime, ...
    [], DataSpec.nUpdateSteps, ...
    EstModelParams.logInitialStateDist, EstModelParams.logTransitionMat );
% Viterbi path (mode argument 3); the first column of the MAP state
% trajectory is passed in place of the empty argument used above.
viterbiStatePath = decodestate( ...
    EstModelParams, ModelSpec, AlgorithmSpec, ...
    SpikeTrainDataRun1, PositionDataRun1, DataSpec, ...
    3, DataSpec.startTime, DataSpec.endTime, ...
    mapStateTrajectory(:, 1), DataSpec.nUpdateSteps, ...
    EstModelParams.logInitialStateDist, EstModelParams.logTransitionMat );



% Position decoding with the fitted model.
% Setup: build test specs from a copy of the Run1 position data, then flag the
% copy as carrying no position data and switch the position model indicator off.
TestPositionData = PositionDataRun1;
[TestModelSpec, TestDataSpec, ~] = setupmodel( ...
    'setup.m', SpikeTrainDataRun1, TestPositionData );
[~, AlgorithmSpec] = setupalgorithm( ...
    ModelParams, TestModelSpec, TestDataSpec, 'setup.m' );
TestPositionData.havePositionData = false;
TestModelSpec.positionModelIndicator = 0;

% Decode.
[logPosPost, logYPosPost, logXPosPost, mapPosTrajectory] = computepositionposterior( ...
    EstModelParams, TestModelSpec, AlgorithmSpec, ...
    SpikeTrainDataRun1, TestPositionData, TestDataSpec, ...
    AlgorithmSpec.updateStepTimes(1:(end - 1)), SpikeTrainDataRun1.nUpdateSteps, ...
    EstModelParams.logInitialStateDist, EstModelParams.logTransitionMat, ...
    estNStates, TestPositionData.startTime, TestPositionData.endTime, ...
    logSmoothedPost );
% Coordinates of MAP trajectory: per-row argmax of each marginal log posterior.
[~, mapYPosTrajectory] = max(logYPosPost, [], 2);
[~, mapXPosTrajectory] = max(logXPosPost, [], 2);


% NW: 'timeVec' used to be undefined at its first use in the BD section below;
% that section now defines it (only if it is absent from the workspace).


% Comparisons:
% Decoding using BD.
% Note: this section overwrites logPosPost, mapPosTrajectory, logYPosPost,
% logXPosPost, mapYPosTrajectory and mapXPosTrajectory.

% Per-position occupancy counts and per-cell mean spike count per update step.
spikeRateMap = zeros([PositionDataRun1.nValidDiscretePositions, SpikeTrainDataRun1.nSpikeTrains]);
positionCounts = histc(PositionDataRun1.linearisedPositionTrajectory, 1:PositionDataRun1.nValidDiscretePositions);
meanFiringRates = sum(SpikeTrainDataRun1.discreteSpikeTrainsMat, 1) / SpikeTrainDataRun1.nUpdateSteps;
% Log occupancy frequencies, used as the position prior (-Inf where unvisited).
logPositionProbs = log(positionCounts) - log(length(PositionDataRun1.posTimeVec));
% Sum each cell's spike counts over the samples spent at each position ...
for i = 1:SpikeTrainDataRun1.nSpikeTrains,
    spikeRateMap(:, i) = accumarray( ...
        PositionDataRun1.linearisedPositionTrajectory, ...
        SpikeTrainDataRun1.discreteSpikeTrainsMat(:, i), ...
        [PositionDataRun1.nValidDiscretePositions, 1]);
end
% ... and divide by occupancy. Unvisited positions (0/0 = NaN) get the cell's
% overall mean instead.
spikeRateMap = bsxfun(@rdivide, spikeRateMap, positionCounts);
spikeRateMap(positionCounts == 0, :) = repmat(meanFiringRates, [sum(positionCounts == 0), 1]);

% Spike-count term n * log(interval * rate), laid out as time x cell x position.
intermediary = bsxfun(@times, ...
    SpikeTrainDataRun1.discreteSpikeTrainsMat, ...
    permute(log(TestPositionData.posObsInterval .* spikeRateMap), [3, 2, 1]));
% 0 * log(0) produces NaN; treat it as contributing nothing.
intermediary(isnan(intermediary)) = 0;
% Sum over cells to get a time x position log likelihood.
% NOTE(review): the -interval*rate term is indexed by the observed trajectory,
% so it is the same for every candidate position and presumably cancels in the
% normalisation below. A Poisson decoder would normally use the rate at each
% candidate position (permute(spikeRateMap, [3, 2, 1])) — confirm intent.
logLikelihood = permute(sum(bsxfun(@plus, ...
    -TestPositionData.posObsInterval .* spikeRateMap(TestPositionData.linearisedPositionTrajectory, :) ...
        - logfactorial(SpikeTrainDataRun1.discreteSpikeTrainsMat), ...
    intermediary), 2), [1, 3, 2]);
logPosPost = normaliselogdistributionsmatrix(bsxfun(@plus, logLikelihood, logPositionProbs'), 2);
[~, mapPosTrajectory] = max(logPosPost, [], 2);

% timeVec supplies the time column of the [time, position] MAP trajectory.
% It was never defined in this script, so unless the workspace already holds
% one, fall back to the position time stamps as a column vector.
% NOTE(review): assumes posTimeVec has one entry per row of logPosPost (L60
% above already normalises the occupancy counts by its length) — confirm.
if ~exist('timeVec', 'var')
    timeVec = PositionDataRun1.posTimeVec(:);
end
mapPosTrajectory = [timeVec, mapPosTrajectory];
[ logYPosPost, logXPosPost ] = marginalisepositionlogposteriors( logPosPost, TestPositionData );
[~, mapYPosTrajectory] = max(logYPosPost, [], 2);
[~, mapXPosTrajectory] = max(logXPosPost, [], 2);

% Decoding using LP.
% Note: this section reuses spikeRateMap and logPositionProbs from the BD
% section, sets ModelSpec.useAugStateModel to false, and overwrites
% logPosPost, mapPosTrajectory and the Y/X marginals and MAP coordinates.

% Count observed transitions between consecutive discrete positions.
transitionIndexingMat = [PositionDataRun1.linearisedPositionTrajectory(1:(end - 1)), PositionDataRun1.linearisedPositionTrajectory(2:end)];
transitionCountsMat = accumarray(transitionIndexingMat, 1, ...
    [PositionDataRun1.nValidDiscretePositions, PositionDataRun1.nValidDiscretePositions]);
% Root-mean-square transition distance per origin position.
distanceMat = PositionDataRun1.constrainedDistanceMat ./ PositionDataRun1.spaceTransformationFactor;
sumSquaredDistancesVec = sum(distanceMat .^ 2 .* transitionCountsMat, 2);
sigmaHats = sqrt(sumSquaredDistancesVec ./ sum(transitionCountsMat, 2));
% Log transition kernel per origin row:
%   -log(sigma) + log(2)/2 - log(pi)/2 - (d / sigma)^2 / 2
% i.e. the log of sqrt(2/pi)/sigma * exp(-d^2 / (2 sigma^2)).
logTransitionMat = repmat(-log(sigmaHats) + 1 / 2 * log(2) - 1 / 2 * log(pi), ...
    [1, PositionDataRun1.nValidDiscretePositions]);
logTransitionMat = logTransitionMat - 1 / 2 * (bsxfun(@rdivide, distanceMat, sigmaHats)) .^ 2;
% Origins with no observed transitions (0/0 = NaN sigma) get a uniform row.
% NOTE(review): an origin whose sigma is exactly 0 (only zero-distance
% transitions) is not caught here and would produce NaN/Inf entries — confirm
% whether that can occur in the data.
logTransitionMat(isnan(sigmaHats), :) = log(1 / PositionDataRun1.nValidDiscretePositions);
logTransitionMat = normaliselogdistributionsmatrix( logTransitionMat, 2 );

% Observation log likelihood (time x position); same computation as in the
% BD section, repeated so this section does not depend on its intermediates.
intermediary = bsxfun(@times, ...
    SpikeTrainDataRun1.discreteSpikeTrainsMat, ...
    permute(log(TestPositionData.posObsInterval .* spikeRateMap), [3, 2, 1]));
% 0 * log(0) produces NaN; treat it as contributing nothing.
intermediary(isnan(intermediary)) = 0;
logLikelihood = permute(sum(bsxfun(@plus, ...
    -TestPositionData.posObsInterval .* spikeRateMap(TestPositionData.linearisedPositionTrajectory, :) ...
        - logfactorial(SpikeTrainDataRun1.discreteSpikeTrainsMat), ...
    intermediary), 2), [1, 3, 2]);

% Forward-backward pass over positions, then normalise each time step.
ModelSpec.useAugStateModel = false;
logForwardProbMat = forwardfiltering( ModelSpec, [], logLikelihood, logTransitionMat, logPositionProbs );
logBackwardProbMat = backwardsmoothing( ModelSpec, [], logLikelihood, logTransitionMat );
logPosPost = logForwardProbMat + logBackwardProbMat;
logPosPost = normaliselogdistributionsmatrix( logPosPost, 2 );
[~, mapPosTrajectory] = max(logPosPost, [], 2);

% timeVec supplies the time column of the [time, position] MAP trajectory.
% It was never defined in this script, so unless the workspace already holds
% one, fall back to the position time stamps as a column vector.
% NOTE(review): assumes posTimeVec has one entry per row of logPosPost — confirm.
if ~exist('timeVec', 'var')
    timeVec = PositionDataRun1.posTimeVec(:);
end
mapPosTrajectory = [timeVec, mapPosTrajectory];
[ logYPosPost, logXPosPost ] = marginalisepositionlogposteriors( logPosPost, TestPositionData );
[~, mapYPosTrajectory] = max(logYPosPost, [], 2);
[~, mapXPosTrajectory] = max(logXPosPost, [], 2);