function [ varargout ] = samplehmmstatetrajectory( ModelParams, ModelSpec, SpikeTrainData, PositionData, logEndForwardProbMat, jumpTimes, trajectoryMatBinInds, updateStepTime, previousUpdateStepTime, maxNTransitions, nParticles )
    % This function performs one Gibbs iteration for sampling a Markov jump
    % process. I.e., conditioned on a sample path and parameters, sample a
    % new path.
    %
    % We return some or all of the following: new sample paths (for
    % multiple particles), the marginal likelihood up to the current SMC
    % step, the marginal likelihood up to a previous time.
    %
    % We assume that all particles have the same number of states (state
    % space dimension). Thus ModelParams is a scalar struct, not a struct
    % array.
    %
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % This function contains calls to all other functions required in steps
    % of sampling a new MJP sample path conditioned on an old one via Rao &
    % Teh 2013's algorithm. This could be used for computing marginal
    % likelihoods and thence updating particle weights in the SMC
    % algorithm, and/or for sampling new MJP paths as part of the 'move'
    % kernel after resampling.
    %
    % There are 4 steps to sampling a new sample path: extend previous
    % process, sample dominating process arrival times, sample Markov
    % chain, discard self-transitions.
    %
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % INPUTS
    % ModelParams            - scalar struct. Required fields:
    %                          nTrueStatesThisBlock, nStatesThisBlock,
    %                          logTransitionMat (size in dim 3 must be
    %                          nParticles), logInitialStateDist (size in
    %                          dim 2 must be nParticles). Further fields
    %                          are required depending on ModelSpec (see the
    %                          validation code below).
    % ModelSpec              - struct. Required fields: useAugStateModel,
    %                          spikeTrainModelIndicator,
    %                          positionModelIndicator; further fields are
    %                          required when the corresponding model
    %                          component is switched on.
    % SpikeTrainData         - struct; fields are only validated when
    %                          ModelSpec.spikeTrainModelIndicator > 0.
    % PositionData           - struct; fields are only validated when
    %                          ModelSpec.positionModelIndicator > 0.
    % logEndForwardProbMat   - if non-empty, a single forward-filtering
    %                          step is performed starting from it; if
    %                          empty, the full forward algorithm is run
    %                          from ModelParams.logInitialStateDist.
    % jumpTimes              - nParticles-by-1 cell array; each cell holds
    %                          the jump times of one particle (presumably a
    %                          column vector, since values are appended
    %                          with vertical concatenation).
    % trajectoryMatBinInds   - logical matrix, maxNTransitions-by-
    %                          nParticles. The first row must be all true,
    %                          and once an entry is false all later entries
    %                          in that column must be false.
    % updateStepTime         - current update step time.
    % previousUpdateStepTime - previous update step time. Trajectories are
    %                          extended only when updateStepTime exceeds
    %                          this value.
    % maxNTransitions        - (optional) number of rows of
    %                          trajectoryMatBinInds; derived from it when
    %                          omitted or empty.
    % nParticles             - (optional) number of columns of
    %                          trajectoryMatBinInds; derived from it when
    %                          omitted or empty.
    %
    % OUTPUTS (selected by the number of requested outputs)
    % 5 outputs: stateTrajectory, jumpTimes, interjumpIntervals,
    %            trajectoryMatBinInds, maxNTransitions.
    % 6 outputs: logEndForwardProbMat, jumpTimes, interjumpIntervals,
    %            trajectoryMatBinInds, maxNTransitions,
    %            logMarginalLikelihood.
    % 7 outputs: as for 6, plus logPartialMarginalLikelihood.
    % Any other number of outputs is an error.
    %
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % CHANGELOG
    % 17.10.2014 - created. Brought code together from evolveparticles.m
    % and samplestatetrajectorygivenjumptimes.m (now redundant).
    % 28.01.2015 - refactored computelikelihood to computelikelihood.m.
    %
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

    % The first nine arguments are mandatory; maxNTransitions (10th) and
    % nParticles (11th) default to the dimensions of trajectoryMatBinInds.
    if nargin < 9
        error('Insufficient arguments!')
    end
    if nargin < 10 || isempty(maxNTransitions)
        maxNTransitions = size(trajectoryMatBinInds, 1);
    elseif size(trajectoryMatBinInds, 1) ~= maxNTransitions
        error('Size of trajectoryMatBinInds does not match maxNTransitions!')
    end
    if nargin < 11 || isempty(nParticles)
        nParticles = size(trajectoryMatBinInds, 2);
    elseif size(trajectoryMatBinInds, 2) ~= nParticles
        error('Size of trajectoryMatBinInds does not match nParticles!')
    end
    if ~isfield(ModelParams, 'nTrueStatesThisBlock')
        error('nTrueStatesThisBlock must be a field of ModelParams!')
    end
    if ~isfield(ModelParams, 'nStatesThisBlock')
        error('nStatesThisBlock must be a field of ModelParams!')
    end
    if ~isfield(ModelParams, 'logTransitionMat')
        error('logTransitionMat must be a field of ModelParams!')
    end
    if ~isfield(ModelParams, 'logInitialStateDist')
        error('logInitialStateDist must be a field of ModelParams!')
    end
    % islogical returns a scalar, so no reduction over elements is needed.
    if ~islogical(trajectoryMatBinInds)
        error('trajectoryMatBinInds must be a logical matrix!')
    end
    if any(any(trajectoryMatBinInds(2:end, :) > trajectoryMatBinInds(1:(end - 1), :)))
        error('Any particles excluded from a time point (row of trajectoryMatBinInds) must not be included in later time points!')
    end
    if ~all(trajectoryMatBinInds(1, :))
        error('No particles must be excluded from EVERY time point!')
    end
    % Compare the whole size vector at once: an elementwise ~= inside an
    % if-condition is only true when EVERY dimension differs, which lets
    % single-dimension mismatches through.
    if ~iscell(jumpTimes) || ~isequal(size(jumpTimes), [nParticles, 1])
        error('jumpTimes should be a column cell array with the same number of cells as trajectoryMatBinInds has columns, in a column!')
    end
    if ~isfield(ModelSpec, 'useAugStateModel')
        error('useAugStateModel must be a field of ModelSpec!')
    end
    if ~isfield(ModelSpec, 'spikeTrainModelIndicator')
        error('spikeTrainModelIndicator must be a field of ModelSpec!')
    end
    if ~isfield(ModelSpec, 'positionModelIndicator')
        error('positionModelIndicator must be a field of ModelSpec!')
    end
    if ModelSpec.useAugStateModel
        if ~isfield(ModelSpec, 'augStateComponents')
            error('augStateComponents must be a field of ModelSpec!')
        end
        if ~isfield(ModelSpec, 'stateSpaceTransformation')
            error('stateSpaceTransformation must be a field of ModelSpec!')
        end
        if ~isfield(ModelSpec, 'generatorMatNonzeroRowInds')
            error('generatorMatNonzeroRowInds must be a field of ModelSpec!')
        end
        if ~isfield(ModelSpec, 'generatorMatNonzeroColInds')
            error('generatorMatNonzeroColInds must be a field of ModelSpec!')
        end
        if ~isfield(ModelSpec, 'nNonzerosAugGeneratorMat')
            error('nNonzerosAugGeneratorMat must be a field of ModelSpec!')
        end
    end
    if size(ModelParams.logTransitionMat, 3) ~= nParticles
        error('n pages of logTransitionMat must match that of other data structures!')
    end
    if size(ModelParams.logInitialStateDist, 2) ~= nParticles
        error('n columns of logInitialStateDist must match the number of pages (particles) of other data structures!')
    end
    if ModelSpec.spikeTrainModelIndicator > 0
        if ~isfield(ModelParams, 'logSpikeRate')
            error('logSpikeRate must be a field of ModelParams!')
        end
        if size(ModelParams.logSpikeRate, 3) ~= nParticles
            error('n pages of logSpikeRate must match that of other data structures!')
        end
        if size(ModelParams.logSpikeRate, 1) ~= ModelParams.nTrueStatesThisBlock
            error('n rows of logSpikeRate must be the number of true states!')
        end
        if ~isfield(SpikeTrainData, 'spikeTimesArray')
            error('spikeTimesArray must be a field of SpikeTrainData!')
        end
        if ~isfield(SpikeTrainData, 'nSpikeTrains')
            error('nSpikeTrains must be a field of SpikeTrainData!')
        end
    end
    if ModelSpec.positionModelIndicator > 0
        if ~isfield(ModelParams, 'driftRate')
            error('driftRate must be a field of ModelParams!')
        end
        if size(ModelParams.driftRate, 2) ~= nParticles
            error('n columns of driftRate must match that of other data structures!')
        end
        if size(ModelParams.driftRate, 1) ~= ModelParams.nTrueStatesThisBlock
            error('n rows of driftRate must be the number of true states!')
        end
        if ~isfield(ModelParams, 'asympMeanPosInd')
            error('asympMeanPosInd must be a field of ModelParams!')
        end
        if size(ModelParams.asympMeanPosInd, 2) ~= nParticles
            error('n columns of asympMeanPosInd must match that of other data structures!')
        end
        if size(ModelParams.asympMeanPosInd, 1) ~= ModelParams.nTrueStatesThisBlock
            error('n rows of asympMeanPosInd must be the number of true states!')
        end
        if ~isfield(ModelParams, 'posCov')
            error('posCov must be a field of ModelParams!')
        end
        % positionJitterModelIndicator is read just below, so validate it
        % like every other required field.
        if ~isfield(ModelSpec, 'positionJitterModelIndicator')
            error('positionJitterModelIndicator must be a field of ModelSpec!')
        end
        % The particle and state dimensions of posCov depend on the jitter
        % model: dims (3, 2) for indicators 1 and 2, dims (4, 3) for 3.
        if (ModelSpec.positionJitterModelIndicator == 1) || (ModelSpec.positionJitterModelIndicator == 2)
            if size(ModelParams.posCov, 3) ~= nParticles
                error('n pages of posCov must match that of other data structures!')
            end
            if size(ModelParams.posCov, 2) ~= ModelParams.nTrueStatesThisBlock
                error('n columns of posCov must be the number of true states!')
            end
        elseif ModelSpec.positionJitterModelIndicator == 3
            if size(ModelParams.posCov, 4) ~= nParticles
                error('n pages of posCov must match that of other data structures!')
            end
            if size(ModelParams.posCov, 3) ~= ModelParams.nTrueStatesThisBlock
                error('n columns of posCov must be the number of true states!')
            end
        end
        if ~isfield(PositionData, 'coordsTrajectory')
            error('coordsTrajectory must be a field of PositionData!')
        end
        if ~isfield(PositionData, 'linearisedPositionTrajectory')
            error('linearisedPositionTrajectory must be a field of PositionData!')
        end
        if ~isfield(PositionData, 'gridCentreCoordsTrajectory')
            error('gridCentreCoordsTrajectory must be a field of PositionData!')
        end
        if ~isfield(PositionData, 'spaceTransformationFactor')
            error('spaceTransformationFactor must be a field of PositionData!')
        end
        if ~isfield(PositionData, 'positionIndToGridCentreCoordsMap')
            error('positionIndToGridCentreCoordsMap must be a field of PositionData!')
        end
        if ~isfield(PositionData, 'nValidDiscretePositions')
            error('nValidDiscretePositions must be a field of PositionData!')
        end
        if size(PositionData.coordsTrajectory, 2) ~= 3
            error('PositionData.coordsTrajectory must have three columns (time, yCoord, xCoord)!')
        end
    end

    % Only extend the trajectories when time has actually advanced.
    doExtendTrajectories = updateStepTime > previousUpdateStepTime;

    % The number of requested outputs selects the mode of operation.
    if nargout == 5
        doReturnSamplePaths = true;
        doComputeMarginalLikelihood = false;
        doComputePartialMarginalLikelihood = false;
    elseif nargout == 6
        doReturnSamplePaths = false;
        doComputeMarginalLikelihood = true;
        doComputePartialMarginalLikelihood = false;
    elseif nargout == 7
        doReturnSamplePaths = false;
        doComputeMarginalLikelihood = true;
        doComputePartialMarginalLikelihood = true;
    else
        % Without this branch the mode flags would be undefined and the
        % function would fail later with an unhelpful message.
        error('samplehmmstatetrajectory must be called with 5, 6 or 7 output arguments!')
    end

    if doComputePartialMarginalLikelihood && ~isempty(logEndForwardProbMat)
        error('Trying to compute partial ML and perform single step of forward algorithm!')
    end
    % Backward sampling needs the full forward probability matrix, which is
    % only computed when logEndForwardProbMat is empty.
    if doReturnSamplePaths && ~isempty(logEndForwardProbMat)
        error('Trying to sample paths and perform single step of forward algorithm!')
    end

    % First extend jump times etc.
%     % The ith jump time is the END time for the (i-1)th interval. The first
%     % jump time is the start time, in which the process spends zero time.
    % Correction: the ith jump time is the START of the ith interval -
    % including the first jump time.
    if doExtendTrajectories
%         jumpTimes = cellfun(@(jumpTimesVec) [jumpTimesVec; updateStepTime], jumpTimes, 'UniformOutput', false);
        jumpTimes = cellfun(@(jumpTimesVec) [jumpTimesVec; previousUpdateStepTime], jumpTimes, 'UniformOutput', false);
        trajectoryMatBinInds = [trajectoryMatBinInds; true([1, nParticles])];
        maxNTransitions = maxNTransitions + 1;
    end

    % Zero time spent at initialisation: this really means the first jump
    % time / state in the state process does not correspond to any
    % interval.
    interjumpIntervals = cellfun(@(jumpTimesVec) [jumpTimesVec(2:end) - jumpTimesVec(1:(end - 1)); updateStepTime - jumpTimesVec(end)], jumpTimes, 'UniformOutput', false);

    % Compute likelihood and forward probabilities.
    if ~isempty(logEndForwardProbMat)
        % Single forward step: only the last jump time of each particle and
        % the last row of trajectoryMatBinInds are passed on.
        logLikelihood = computelikelihood(ModelSpec, SpikeTrainData, PositionData, ModelParams, false, cellfun(@(jumpTimesVec) jumpTimesVec(end), jumpTimes, 'UniformOutput', false), trajectoryMatBinInds(end, :), jumpTimes{1}(end), updateStepTime, ModelParams.nTrueStatesThisBlock, 1, nParticles, interjumpIntervals);
%         logForwardProbMat = -inf([1, ModelParams.nStatesThisBlock, nParticles]);
%         logForwardProbMat = forwardfiltering( ModelSpec, logForwardProbMat, logLikelihood, ModelParams.logTransitionMat, logEndForwardProbMat, trajectoryMatBinInds(end, :), 1, ModelParams.nStatesThisBlock, nParticles );
        if ModelSpec.useAugStateModel
            logEndForwardProbMat = forwardfilteringsinglestepaugmented( ModelSpec, logEndForwardProbMat, logLikelihood, ModelParams.logTransitionMat, ModelParams.nStatesThisBlock, nParticles );
        else
            logEndForwardProbMat = forwardfilteringsinglestepregular( logEndForwardProbMat, logLikelihood, ModelParams.logTransitionMat );
        end
%         logEndForwardProbMat = reshape(logEndForwardProbMat, [ModelParams.nStatesThisBlock, nParticles]);
    else
        % Full forward algorithm over all transitions.
        logForwardProbMat = -inf([maxNTransitions, ModelParams.nStatesThisBlock, nParticles]);
        stateTrajectory = nan([maxNTransitions, nParticles]);
        logLikelihood = computelikelihood(ModelSpec, SpikeTrainData, PositionData, ModelParams, false, jumpTimes, trajectoryMatBinInds, jumpTimes{1}(1), updateStepTime, ModelParams.nTrueStatesThisBlock, maxNTransitions, nParticles, interjumpIntervals);
        logForwardProbMat = forwardfiltering( ModelSpec, logForwardProbMat, logLikelihood, ModelParams.logTransitionMat, ModelParams.logInitialStateDist, trajectoryMatBinInds, maxNTransitions, ModelParams.nStatesThisBlock, nParticles );
%         logEndForwardProbMat = reshape(logForwardProbMat(end, :, :), [ModelParams.nStatesThisBlock, nParticles]);
        logEndForwardProbMat = logForwardProbMat(end, :, :);
    end
    clear logLikelihood;

    % We can also return the marginal likelihoods up to the current SMC
    % step time, for updating particle weights.
    if doComputeMarginalLikelihood
        % Need the forward probabilities for the last forward algorithm
        % step for each particle.
        logMarginalLikelihood = permute(sumlogprobmat(logEndForwardProbMat, 2), [1, 3, 2]);
    end

    % Similarly, the marginal likelihoods up to the last SMC step time, for
    % updating particle weights.
    if doComputePartialMarginalLikelihood
%         logPartialMarginalLikelihood = computepartialmarginallikelihood( ModelSpec, SpikeTrainData, PositionData, ModelParams, logForwardProbMat, jumpTimes, lastUpdateStepTime, maxNTransitions, ModelParams.nStatesThisBlock, nParticles );
        % Need the forward probabilities for the second last forward
        % algorithm step for each particle.
%         logEndForwardProbMat = reshape(logForwardProbMat(end - 1, :, :), [ModelParams.nStatesThisBlock, nParticles]);
        logPartialMarginalLikelihood = permute(sumlogprobmat(logForwardProbMat(end - 1, :, :), 2), [1, 3, 2]);
    end

    % Sample states, backwards.
    if doReturnSamplePaths
        stateTrajectory = backwardsampling( ModelSpec, logForwardProbMat, stateTrajectory, ModelParams.logTransitionMat, trajectoryMatBinInds, maxNTransitions, ModelParams.nStatesThisBlock, nParticles );
    end

    if nargout == 5
        varargout = {stateTrajectory, jumpTimes, interjumpIntervals, trajectoryMatBinInds, maxNTransitions};
    elseif nargout == 6
        varargout = {logEndForwardProbMat, jumpTimes, interjumpIntervals, trajectoryMatBinInds, maxNTransitions, logMarginalLikelihood};
    elseif nargout == 7
        varargout = {logEndForwardProbMat, jumpTimes, interjumpIntervals, trajectoryMatBinInds, maxNTransitions, logMarginalLikelihood, logPartialMarginalLikelihood};
    end

end