function [ varargout ] = samplestatetrajectory( ModelParams, ModelSpec, SpikeTrainData, PositionData, stateTrajectory, jumpTimes, trajectoryMatBinInds, updateStepTime, lastUpdateStepTime, maxNTransitions, nParticles )
    % This function performs one Gibbs iteration for sampling a Markov jump
    % process. I.e., conditioned on a sample path and parameters, sample a
    % new path.
    %
    % We return some or all of the following: new sample paths (for
    % multiple particles), the marginal likelihood up to the current SMC
    % step, the marginal likelihood up to a previous time.
    %
    % We assume that all particles have the same number of states (state
    % space dimension). Thus ModelParams is a scalar struct, not a struct
    % array.
    %
    % INPUTS
    %   ModelParams          - scalar struct. Required fields:
    %                          nTrueStatesThisBlock, nStatesThisBlock,
    %                          generatorMat, jumpRate, dominatingRate,
    %                          logTransitionMat (nParticles pages) and
    %                          logInitialStateDist (nParticles columns), plus
    %                          the observation-model fields checked below.
    %   ModelSpec            - struct with useAugStateModel,
    %                          spikeTrainModelIndicator and
    %                          positionModelIndicator (and the augmented-state
    %                          fields when useAugStateModel is true).
    %   SpikeTrainData       - struct; spikeTimesArray and nSpikeTrains are
    %                          required when spikeTrainModelIndicator > 0.
    %   PositionData         - struct; position fields are required when
    %                          positionModelIndicator > 0.
    %   stateTrajectory      - [maxNTransitions x nParticles] matrix of states.
    %   jumpTimes            - [nParticles x 1] cell array of jump-time
    %                          vectors. They are concatenated vertically
    %                          below, so they are assumed to be columns.
    %                          jumpTimes{1}(1) is used as the start time of
    %                          the likelihood computation.
    %   trajectoryMatBinInds - logical matrix, same size as stateTrajectory.
    %                          Each column must be true on a leading run of
    %                          rows and false afterwards; row 1 must be all
    %                          true.
    %   updateStepTime       - current SMC update time.
    %   lastUpdateStepTime   - previous SMC update time. The path is extended
    %                          from this time when updateStepTime is later.
    %   maxNTransitions      - (optional) defaults to size(stateTrajectory, 1).
    %   nParticles           - (optional) defaults to size(stateTrajectory, 2).
    %
    % OUTPUTS (selected by nargout; any other output count is an error)
    %   1: logMarginalLikelihood
    %   2: logMarginalLikelihood, logPartialMarginalLikelihood
    %   5: stateTrajectory, jumpTimes, interjumpIntervals,
    %      trajectoryMatBinInds, maxNTransitions
    %   6: the 5 path outputs, then logMarginalLikelihood
    %   7: the 5 path outputs, then logMarginalLikelihood and
    %      logPartialMarginalLikelihood
    %
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % This function contains calls to all other functions required in steps
    % of sampling a new MJP sample path conditioned on an old one via Rao &
    % Teh 2013's algorithm. This could be used for computing marginal
    % likelihoods and thence updating particle weights in the SMC
    % algorithm, and/or for sampling new MJP paths as part of the 'move'
    % kernel after resampling.
    %
    % There are 4 steps to sampling a new sample path: extend previous
    % process, sample dominating process arrival times, sample Markov
    % chain, discard self-transitions.
    %
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % CHANGELOG
    % 17.10.2014 - created. Brought code together from evolveparticles.m
    % and samplestatetrajectorygivenjumptimes.m (now redundant).
    % 28.01.2015 - refactored computelikelihood to computelikelihood.m.
    %
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

    %% Input validation.
    if nargin < 9,
        error('Insufficient arguments!')
    end
    if nargin < 10,
        maxNTransitions = size(stateTrajectory, 1);
    else
        if size(stateTrajectory, 1) ~= maxNTransitions,
            error('Size of stateTrajectory does not match maxNTransitions!')
        end
    end
    if nargin < 11,
        nParticles = size(stateTrajectory, 2);
    else
        if size(stateTrajectory, 2) ~= nParticles,
            error('Size of stateTrajectory does not match nParticles!')
        end
    end
    if ~isfield(ModelParams, 'nTrueStatesThisBlock'),
        error('nTrueStatesThisBlock must be a field of ModelParams!')
    end
    if ~isfield(ModelParams, 'nStatesThisBlock'),
        error('nStatesThisBlock must be a field of ModelParams!')
    end
    if ~isfield(ModelParams, 'generatorMat'),
        error('generatorMat must be a field of ModelParams!')
    end
    if ~isfield(ModelParams, 'jumpRate'),
        error('jumpRate must be a field of ModelParams!')
    end
    if ~isfield(ModelParams, 'dominatingRate'),
        error('dominatingRate must be a field of ModelParams!')
    end
    if ~isfield(ModelParams, 'logTransitionMat'),
        error('logTransitionMat must be a field of ModelParams!')
    end
    if ~isfield(ModelParams, 'logInitialStateDist'),
        error('logInitialStateDist must be a field of ModelParams!')
    end
    % islogical returns a scalar for the whole array.
    if ~islogical(trajectoryMatBinInds),
        error('trajectoryMatBinInds must be a logical matrix!')
    end
    % A true entry directly below a false one would mean a particle
    % re-enters after having been excluded.
    if any(any(trajectoryMatBinInds(2:end, :) > trajectoryMatBinInds(1:(end - 1), :))),
        error('Any particles excluded from a time point (row of trajectoryMatBinInds) must not be included in later time points!')
    end
    if ~all(trajectoryMatBinInds(1, :)),
        error('No particles must be excluded from EVERY time point!')
    end
    if ~isequal(size(trajectoryMatBinInds), size(stateTrajectory)),
        error('Size of trajectoryMatBinInds does not match that of stateTrajectory!')
    end
    % isequal compares the whole size vector. An element-wise ~= inside an
    % if would only fire when EVERY dimension differs.
    if ~iscell(jumpTimes) || ~isequal(size(jumpTimes), [nParticles, 1]),
        error('jumpTimes should be a column cell array with the same number of cells as trajectoryMatBinInds has columns!')
    end
    if ~isfield(ModelSpec, 'useAugStateModel'),
        error('useAugStateModel must be a field of ModelSpec!')
    end
    if ~isfield(ModelSpec, 'spikeTrainModelIndicator'),
        error('spikeTrainModelIndicator must be a field of ModelSpec!')
    end
    if ~isfield(ModelSpec, 'positionModelIndicator'),
        error('positionModelIndicator must be a field of ModelSpec!')
    end
    if ModelSpec.useAugStateModel,
        if ~isfield(ModelSpec, 'augStateComponents'),
            error('augStateComponents must be a field of ModelSpec!')
        end
        if ~isfield(ModelSpec, 'stateSpaceTransformation'),
            error('stateSpaceTransformation must be a field of ModelSpec!')
        end
        if ~isfield(ModelSpec, 'generatorMatNonzeroRowInds'),
            error('generatorMatNonzeroRowInds must be a field of ModelSpec!')
        end
        if ~isfield(ModelSpec, 'generatorMatNonzeroColInds'),
            error('generatorMatNonzeroColInds must be a field of ModelSpec!')
        end
        if ~isfield(ModelSpec, 'nNonzerosAugGeneratorMat'),
            error('nNonzerosAugGeneratorMat must be a field of ModelSpec!')
        end
    end
    if size(ModelParams.logTransitionMat, 3) ~= nParticles,
        error('n pages of logTransitionMat must match that of other data structures!')
    end
    if size(ModelParams.logInitialStateDist, 2) ~= nParticles,
        error('n columns of logInitialStateDist must match the number of pages (particles) of other data structures!')
    end
    if ModelSpec.spikeTrainModelIndicator > 0,
        if ~isfield(ModelParams, 'logSpikeRate'),
            error('logSpikeRate must be a field of ModelParams!')
        end
        if size(ModelParams.logSpikeRate, 3) ~= nParticles,
            error('n pages of logSpikeRate must match that of other data structures!')
        end
        if size(ModelParams.logSpikeRate, 1) ~= ModelParams.nTrueStatesThisBlock,
            error('n rows of logSpikeRate must be the number of true states!')
        end
        if ~isfield(SpikeTrainData, 'spikeTimesArray'),
            error('spikeTimesArray must be a field of SpikeTrainData!')
        end
        if ~isfield(SpikeTrainData, 'nSpikeTrains'),
            error('nSpikeTrains must be a field of SpikeTrainData!')
        end
    end
    if ModelSpec.positionModelIndicator > 0,
        if ~isfield(ModelParams, 'driftRate'),
            error('driftRate must be a field of ModelParams!')
        end
        if size(ModelParams.driftRate, 2) ~= nParticles,
            error('n columns of driftRate must match that of other data structures!')
        end
        if size(ModelParams.driftRate, 1) ~= ModelParams.nTrueStatesThisBlock,
            error('n rows of driftRate must be the number of true states!')
        end
        if ~isfield(ModelParams, 'asympMeanPosInd'),
            error('asympMeanPosInd must be a field of ModelParams!')
        end
        if size(ModelParams.asympMeanPosInd, 2) ~= nParticles,
            error('n columns of asympMeanPosInd must match that of other data structures!')
        end
        if size(ModelParams.asympMeanPosInd, 1) ~= ModelParams.nTrueStatesThisBlock,
            error('n rows of asympMeanPosInd must be the number of true states!')
        end
        if ~isfield(ModelParams, 'posCov'),
            error('posCov must be a field of ModelParams!')
        end
        if size(ModelParams.posCov, 3) ~= nParticles,
            error('n pages of posCov must match that of other data structures!')
        end
        if size(ModelParams.posCov, 2) ~= ModelParams.nTrueStatesThisBlock,
            error('n columns of posCov must be the number of true states!')
        end
        if ~isfield(PositionData, 'coordsTrajectory'),
            error('coordsTrajectory must be a field of PositionData!')
        end
        if ~isfield(PositionData, 'linearisedPositionTrajectory'),
            error('linearisedPositionTrajectory must be a field of PositionData!')
        end
        if ~isfield(PositionData, 'gridCentreCoordsTrajectory'),
            error('gridCentreCoordsTrajectory must be a field of PositionData!')
        end
        if ~isfield(PositionData, 'spaceTransformationFactor'),
            error('spaceTransformationFactor must be a field of PositionData!')
        end
        if ~isfield(PositionData, 'positionIndToGridCentreCoordsMap'),
            error('positionIndToGridCentreCoordsMap must be a field of PositionData!')
        end
        if ~isfield(PositionData, 'nValidDiscretePositions'),
            error('nValidDiscretePositions must be a field of PositionData!')
        end
        if size(PositionData.coordsTrajectory, 2) ~= 3,
            error('PositionData.coordsTrajectory must have three columns (time, yCoord, xCoord)!')
        end
    end

    %% Decide what to compute from the number of requested outputs.
    switch nargout
        case 1
            doReturnSamplePaths = false;
            doComputeMarginalLikelihood = true;
            doComputePartialMarginalLikelihood = false;
        case 2
            doReturnSamplePaths = false;
            doComputeMarginalLikelihood = true;
            doComputePartialMarginalLikelihood = true;
        case 5
            doReturnSamplePaths = true;
            doComputeMarginalLikelihood = false;
            doComputePartialMarginalLikelihood = false;
        case 6
            doReturnSamplePaths = true;
            doComputeMarginalLikelihood = true;
            doComputePartialMarginalLikelihood = false;
        case 7
            doReturnSamplePaths = true;
            doComputeMarginalLikelihood = true;
            doComputePartialMarginalLikelihood = true;
        otherwise
            % Previously an unsupported count crashed later on an undefined
            % flag variable; fail early with a clear message instead.
            error('samplestatetrajectory supports 1, 2, 5, 6 or 7 outputs (requested %d)!', nargout)
    end

    %% Step 1: extend the previous path to the current update time.
    % Simulate the Markov jump process from the last update time to the
    % current update time and combine it with the previously sampled path.
    if updateStepTime > lastUpdateStepTime,
        [ stateTrajectory, jumpTimes, trajectoryMatBinInds, interjumpIntervals, maxNTransitions ] = extendmarkovjumpprocess( ModelParams, stateTrajectory, jumpTimes, trajectoryMatBinInds, lastUpdateStepTime, updateStepTime, maxNTransitions, ModelParams.nStatesThisBlock, nParticles );
    else
        % Intervals between consecutive jumps, plus the final interval
        % from the last jump up to updateStepTime.
        interjumpIntervals = cellfun(@(jumpTimesVec) [jumpTimesVec(2:end) - jumpTimesVec(1:(end - 1)); updateStepTime - jumpTimesVec(end)], jumpTimes, 'UniformOutput', false);
    end

    %% Step 2: sample the dominating process.
    % Sample virtual jump times and add them to the jump times we already
    % have - this forms the dominating process jump times.
    [ dominatingProcessJumpTimes, dominatingProcessBinInds, maxNDominatingTransitions ] = sampledominatingprocessjumptimes( updateStepTime, ModelParams.jumpRate, ModelParams.dominatingRate, stateTrajectory, jumpTimes, trajectoryMatBinInds, interjumpIntervals, maxNTransitions, ModelParams.nStatesThisBlock, nParticles );

    % TODO: optionally thin dominating processes whose number of arrivals
    % exceeds a limit (e.g. AlgorithmSpec.maxNJumpsLimit), always keeping
    % the first arrival (the start time). Not yet implemented.

    dominatingProcessInterjumpIntervalsArray = cellfun(@(jumpTimesVec) [jumpTimesVec(2:end) - jumpTimesVec(1:(end - 1)); updateStepTime - jumpTimesVec(end)], dominatingProcessJumpTimes, 'UniformOutput', false);

    %% Step 3: forward filtering on the dominating process times.
    % Preallocate logForwardProbMat and the (new) stateTrajectory.
    logForwardProbMat = -inf([maxNDominatingTransitions, ModelParams.nStatesThisBlock, nParticles]);
    stateTrajectory = nan([maxNDominatingTransitions, nParticles]);

    % Compute likelihood.
    logLikelihood = computelikelihood(ModelSpec, SpikeTrainData, PositionData, ModelParams, false, dominatingProcessJumpTimes, dominatingProcessBinInds, jumpTimes{1}(1), updateStepTime, ModelParams.nTrueStatesThisBlock, maxNDominatingTransitions, nParticles, dominatingProcessInterjumpIntervalsArray);

    % Compute forward probabilities.
    logForwardProbMat = forwardfiltering( ModelSpec, logForwardProbMat, logLikelihood, ModelParams.logTransitionMat, ModelParams.logInitialStateDist, dominatingProcessBinInds, maxNDominatingTransitions, ModelParams.nStatesThisBlock, nParticles );
    clear logLikelihood;

    % We can also return the marginal likelihoods up to the current SMC
    % step time, for updating particle weights.
    if doComputeMarginalLikelihood,
        % Need the forward probabilities for the last forward algorithm
        % step of each particle. The mask is true in row r of column p when
        % row r is the last true row of dominatingProcessBinInds(:, p),
        % i.e. r is true and r + 1 is false, or r is the final row and true.
        lastStepMask = [dominatingProcessBinInds(1:(end - 1), :) & ~dominatingProcessBinInds(2:end, :); dominatingProcessBinInds(end, :)];
        % Broadcast the mask across the state dimension (dim 2).
        lastStepMask3D = repmat(permute(lastStepMask, [1, 3, 2]), [1, ModelParams.nStatesThisBlock, 1]);
        % Logical indexing returns values in column-major order (state
        % fastest, then particle), so reshape to [nStates, nParticles].
        logEndForwardProbMat = reshape(logForwardProbMat(lastStepMask3D), [ModelParams.nStatesThisBlock, nParticles]);
        % Log-sum over states.
        logMarginalLikelihood = sumlogprobmat(logEndForwardProbMat, 1);
    end

    % Similarly, the marginal likelihoods up to the last SMC step time, for
    % updating particle weights.
    if doComputePartialMarginalLikelihood,
        logPartialMarginalLikelihood = computepartialmarginallikelihood( ModelSpec, SpikeTrainData, PositionData, ModelParams, logForwardProbMat, dominatingProcessJumpTimes, lastUpdateStepTime, maxNDominatingTransitions, ModelParams.nStatesThisBlock, nParticles );
    end

    %% Step 4: backward sampling and removal of self-transitions.
    if doReturnSamplePaths,
        stateTrajectory = backwardsampling( ModelSpec, logForwardProbMat, stateTrajectory, ModelParams.logTransitionMat, dominatingProcessBinInds, maxNDominatingTransitions, ModelParams.nStatesThisBlock, nParticles );

        % Remove self-transitions (virtual jumps of newly sampled process).
        [ stateTrajectory, jumpTimes, interjumpIntervals, trajectoryMatBinInds, maxNTransitions ] = removeselftransitions( stateTrajectory, dominatingProcessJumpTimes, dominatingProcessInterjumpIntervalsArray, dominatingProcessBinInds, updateStepTime, nParticles );
    end

    %% Pack outputs (nargout was validated above).
    switch nargout
        case 1
            varargout = {logMarginalLikelihood};
        case 2
            varargout = {logMarginalLikelihood, logPartialMarginalLikelihood};
        case 5
            varargout = {stateTrajectory, jumpTimes, interjumpIntervals, trajectoryMatBinInds, maxNTransitions};
        case 6
            varargout = {stateTrajectory, jumpTimes, interjumpIntervals, trajectoryMatBinInds, maxNTransitions, logMarginalLikelihood};
        case 7
            varargout = {stateTrajectory, jumpTimes, interjumpIntervals, trajectoryMatBinInds, maxNTransitions, logMarginalLikelihood, logPartialMarginalLikelihood};
    end

end