function ModelParams = samplemodelparametersfromprior( PriorParams, ModelSpec, DataSpec, nSamples )
    % Sample model parameters from their prior distributions.
    %
    % This function is used to sample parameter values, for initialisation
    % of the SMC algorithm (in which case nSamples indicates the number of
    % particles) or to simulate a system and, subsequently, data (in which
    % case nSamples should be 1).
    %
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % INPUTS
    % PriorParams - struct of prior hyperparameters. Which fields are
    %   required depends on ModelSpec; a missing required field raises an
    %   error of the form '<field> must be a field of PriorParams!'.
    % ModelSpec - struct selecting the model variant (modelIndicator,
    %   useAugStateModel, spikeTrainModelIndicator, positionModelIndicator,
    %   ...). Optional field fixedDriftRate (scalar): drift rate given to
    %   the non-random-walk states of the modelIndicator == 2 position
    %   model; defaults to 8.5966e-01 (the value previously hard-coded).
    % DataSpec - struct describing the data (nSpikeTrains for spike train
    %   models; posObsInterval and spatialBinWidth for position models).
    % nSamples - number of particles. Optional: if not supplied one sample
    %   is drawn and the state dimension is fixed at ModelSpec.maxNStates,
    %   even when ModelSpec.estimateNStates is true.
    %
    % OUTPUT
    % ModelParams - column struct array with one element per non-empty
    %   block, a block holding the particles that share a true state
    %   dimension. Every element has nParticlesThisBlock,
    %   nTrueStatesThisBlock and nStatesThisBlock; the remaining fields
    %   depend on ModelSpec. In the arrays built directly here particles
    %   index the last dimension (e.g. logInitialStateDist is
    %   nStatesThisBlock x nParticlesThisBlock); shapes returned by the
    %   sampling helpers are not checked here.
    %
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % CHANGELOG
    % 01.10.2014 - created.
    % 09.10.2014 - added data checks. Got working for spike train models
    % only. Corrected initial state distribution for fixed state space.
    % 15.10.2014 - changed ModelParams to be a struct array, indexed by
    % block, each element (block) having fields associated with particular
    % parameters (generatorMat, logSpikeRate etc). Also removed 'aug'
    % parameter name prefixes so that now the names are the same whether
    % using the augmented state model or not. Added number of states
    % (augmented if augmented model used) to each struct.
    % 16.10.2014 - changed dimensions of DataParams to (nStates, 1), i.e.
    % a column. Added nTrueStatesThisBlock to output.
    % 24.10.2014 - implemented position model: continuously observed,
    % independent jitter.
    % 26.10.2014 - added total covariance parameter, posCov, for position
    % model.
    % 05.11.2014 - changed the jitter prior to include a gamma distribution
    % (mixed with uniform) - the prior parameters were available already
    % but dormant. This prior makes sampling from the posterior easier.
    % Review - random-walk state now really gets drift rate 1; jitter gamma
    % hyperparameters only read for jitter models 1 and 2; randomWalkState
    % validated; drift rate overridable via ModelSpec.fixedDriftRate.
    %
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

    if nargin < 3,
        error('Insufficient arguments!');
    end
    if nargin < 4,
        oneSample = true;
        nSamples = 1;
    else
        oneSample = false;
    end

    % Input validation. Checks run in a fixed order and the first missing
    % field raises the error.
    requirefields(ModelSpec, 'ModelSpec', {'modelIndicator'});
    requirefields(PriorParams, 'PriorParams', {'stateDimensionDist', 'transitionRatesDirichletHyperparams'});
    if ModelSpec.modelIndicator == 2,
        requirefields(PriorParams, 'PriorParams', {'jumpRateGammaHyperparams'});
        requirefields(ModelSpec, 'ModelSpec', {'dominatingRateScalingFactor'});
    end
    requirefields(ModelSpec, 'ModelSpec', {'maxNStates', 'estimateNStates', 'useAugStateModel', 'spikeTrainModelIndicator', 'positionModelIndicator'});
    if ModelSpec.useAugStateModel,
        requirefields(ModelSpec, 'ModelSpec', {'maxNAugStates', 'stateSpaceTransformation', 'augStateComponents'});
    else
        requirefields(ModelSpec, 'ModelSpec', {'initialStateModelIndicator'});
        if ModelSpec.initialStateModelIndicator == 3,
            requirefields(PriorParams, 'PriorParams', {'initialStateDistDirichletHyperparams'});
        end
    end
    if ModelSpec.positionModelIndicator > 0,
        % randomWalkState is read by the position model below for every
        % modelIndicator, so it is required here.
        requirefields(ModelSpec, 'ModelSpec', {'positionJitterModelIndicator', 'positionObsErrorPriorIndicator', 'positionObsErrorTolerance', 'randomWalkState'});
        if ModelSpec.modelIndicator == 2,
            requirefields(PriorParams, 'PriorParams', {'driftRateUniformHyperparams'});
        end
        requirefields(PriorParams, 'PriorParams', {'asympMeanPriorDist'});
        if (ModelSpec.positionJitterModelIndicator == 1) || (ModelSpec.positionJitterModelIndicator == 2),
            requirefields(PriorParams, 'PriorParams', {'jitterPrecisionGammaHyperparams'});
        elseif (ModelSpec.positionJitterModelIndicator == 3) && (ModelSpec.positionModelIndicator == 2),
            requirefields(PriorParams, 'PriorParams', {'jitterCovPriorDof', 'jitterCovPriorScaleMat'});
        end
        requirefields(DataSpec, 'DataSpec', {'posObsInterval', 'spatialBinWidth'});
    end
    if ModelSpec.spikeTrainModelIndicator > 0,
        requirefields(PriorParams, 'PriorParams', {'spikeRateGammaHyperparams'});
        requirefields(DataSpec, 'DataSpec', {'nSpikeTrains'});
    end

    % Drift rate for the non-random-walk states of the modelIndicator == 2
    % position model. Sampling from PriorParams.driftRateUniformHyperparams
    % is currently disabled in favour of a fixed value (8.5966e-01 for
    % Ibsen; 8.6619e-01 was used for Jasper).
    if isfield(ModelSpec, 'fixedDriftRate') && ~isempty(ModelSpec.fixedDriftRate),
        fixedDriftRate = ModelSpec.fixedDriftRate;
    else
        fixedDriftRate = 8.5966e-01;
    end

    % Initialise parameter structures.
    ModelParams = struct([]);

    % Take the total number of particles and distribute them among the
    % different maximum state dimensions.
    if ModelSpec.estimateNStates && ~oneSample,
        % Multinomial allocation according to the prior on state dimension.
        nParticlesEachSubpop = mnrnd(nSamples, PriorParams.stateDimensionDist);
    else
        % State dimension is fixed: every particle goes in the last block.
        nParticlesEachSubpop = zeros([1, ModelSpec.maxNStates]);
        nParticlesEachSubpop(end) = nSamples;
    end

    % Loop over subpopulations (blocks) of particles.
    for iSubpop = 1:ModelSpec.maxNStates,
        nParticles = nParticlesEachSubpop(iSubpop);

        % Bookkeeping fields are recorded for every block; empty blocks are
        % removed at the end.
        ModelParams(iSubpop, 1).nParticlesThisBlock = nParticles;
        ModelParams(iSubpop, 1).nTrueStatesThisBlock = iSubpop;
        if nParticles == 0,
            continue;
        end

        % ---- State process parameters. ----
        % Number of states for this block (augmented count when the
        % augmented state model is used).
        if ModelSpec.useAugStateModel,
            ModelParams(iSubpop, 1).nStatesThisBlock = ModelSpec.stateSpaceTransformation(iSubpop);
        else
            ModelParams(iSubpop, 1).nStatesThisBlock = iSubpop;
        end

        % Generator and jump rate (modelIndicator == 2 only); sampled over
        % the iSubpop true states for both state models.
        if ModelSpec.modelIndicator == 2,
            [ModelParams(iSubpop, 1).generatorMat, ModelParams(iSubpop, 1).jumpRate] = samplegeneratormatrix( ModelSpec, PriorParams.transitionRatesDirichletHyperparams(1:(iSubpop - 1)), PriorParams.jumpRateGammaHyperparams, 0, iSubpop, nParticles );
        end

        % Initial state distribution (log scale).
        if ModelSpec.useAugStateModel,
            % All mass on augmented state 1: log-probability 0 there, -inf
            % for every other augmented state.
            ModelParams(iSubpop, 1).logInitialStateDist = -inf([ModelSpec.stateSpaceTransformation(iSubpop), nParticles]);
            ModelParams(iSubpop, 1).logInitialStateDist(1, :) = 0;
        elseif ModelSpec.initialStateModelIndicator == 1,
            % Uniform over the iSubpop states.
            ModelParams(iSubpop, 1).logInitialStateDist = -log(iSubpop) * ones([iSubpop, nParticles]);
        elseif ModelSpec.initialStateModelIndicator == 2,
            % Derived from the generator matrix (presumably its stationary
            % distribution); requires modelIndicator == 2 so generatorMat
            % exists.
            ModelParams(iSubpop, 1).logInitialStateDist = computemjpstationarydistribution( ModelParams(iSubpop, 1).generatorMat, iSubpop, nParticles );
        elseif ModelSpec.initialStateModelIndicator == 3,
            % Sampled using the Dirichlet hyperparameters.
            ModelParams(iSubpop, 1).logInitialStateDist = sampleinitialstatedistribution( PriorParams.initialStateDistDirichletHyperparams(1:iSubpop), 0, iSubpop, nParticles );
        else
            error('Invalid initialStateModelIndicator!')
        end

        if ModelSpec.modelIndicator == 2,
            % Dominating rate: largest absolute jump rate per particle,
            % scaled up by dominatingRateScalingFactor.
            ModelParams(iSubpop, 1).dominatingRate = max(abs(ModelParams(iSubpop, 1).jumpRate), [], 1)' * ModelSpec.dominatingRateScalingFactor;

            % Dominating process transition matrix.
            ModelParams(iSubpop, 1).logTransitionMat = formdominatingprocesstransitionmat( ModelParams(iSubpop, 1).generatorMat, ModelParams(iSubpop, 1).dominatingRate );
        elseif ModelSpec.modelIndicator == 1,
            ModelParams(iSubpop, 1).logTransitionMat = samplehmmtransitionmatrix( ModelSpec, zeros([iSubpop, iSubpop, nParticles]), false([iSubpop, iSubpop, nParticles]), PriorParams.transitionRatesDirichletHyperparams(1:iSubpop), iSubpop, nParticles );
        end

        % ---- Spike train model parameters. ----
        if ModelSpec.spikeTrainModelIndicator > 0,
            if ModelSpec.modelIndicator == 1,
                ModelParams(iSubpop, 1).logSpikeRate = samplediscretetimespikerate( ModelSpec, PriorParams.spikeRateGammaHyperparams, 0, DataSpec.nSpikeTrains, iSubpop, nParticles );
            elseif ModelSpec.modelIndicator == 2,
                ModelParams(iSubpop, 1).logSpikeRate = samplespikerate( PriorParams.spikeRateGammaHyperparams, 0, DataSpec.nSpikeTrains, iSubpop, nParticles );
            end
        end

        % ---- Position model parameters. ----
        if ModelSpec.positionModelIndicator > 0,

            % Drift rates in the AR(1) parameterisation.
            if ModelSpec.modelIndicator == 1,
                ModelParams(iSubpop, 1).driftRate = zeros([iSubpop, nParticles]);
            else
                driftRate = fixedDriftRate * ones([iSubpop, nParticles]);
                if ModelSpec.randomWalkState > 0,
                    % The random-walk state has drift rate 1. (Previously
                    % the initialisation to ones was commented out, so the
                    % row was never set and the indexed assignment could
                    % fail.)
                    driftRate(ModelSpec.randomWalkState, :) = 1;
                end
                ModelParams(iSubpop, 1).driftRate = driftRate;
            end

            % Asymptotic mean positions: linear coordinate of discrete
            % position, drawn per state and particle from asympMeanPriorDist.
            ModelParams(iSubpop, 1).asympMeanPosInd = samplefromcategoricaldistributions(repmat(permute(PriorParams.asympMeanPriorDist, [3, 2, 1]), [iSubpop, nParticles, 1]), 3);
            % The random-walk state's mean position is set to 1 arbitrarily.
            if ModelSpec.randomWalkState > 0,
                ModelParams(iSubpop, 1).asympMeanPosInd(ModelSpec.randomWalkState, :) = 1;
            end

            % 'Jitter' model.
            jitterModel = ModelSpec.positionJitterModelIndicator;
            if jitterModel == 0,
                % Fixed jitter variance proportional to the shortest maze
                % dimension. Independent dimensions: stored as 2 rows
                % (Y dim, X dim) rather than a 2x2 matrix.
                minDimSize = min([DataSpec.discreteYDim, DataSpec.discreteXDim]);
                jitterVar = ModelSpec.stateFieldSizeAsProportionOfShortestMazeDim * minDimSize;
                ModelParams(iSubpop, 1).jitterCov = repmat(jitterVar, [2, iSubpop, nParticles]);

            elseif (jitterModel == 1) || (jitterModel == 2),
                % Gamma prior on the jitter precision, converted to a
                % variance. Stored as 2 rows (Y dim, X dim); off-diagonals
                % are not needed. The gamma hyperparameters are read only
                % here because they are only validated for these models.
                priorHyperparams = repmat(permute(PriorParams.jitterPrecisionGammaHyperparams, [1, 4, 3, 2]), [2, iSubpop, nParticles, 1]);

                % Minimum variance 1 + driftRate^2, which gives the jitter
                % covariance a prior dependence on the drift rate.
                minJitterVar = permute(1 + ModelParams(iSubpop, 1).driftRate .^ 2, [3, 1, 2]);

                if jitterModel == 2,
                    % Anisotropic: one precision per spatial dimension.
                    jitterPrecision = gamrnd(priorHyperparams(:, :, :, 1), 1 ./ priorHyperparams(:, :, :, 2));
                    jitterCov = bsxfun(@plus, 1 ./ jitterPrecision, minJitterVar);
                else
                    % Isotropic: one precision shared by both dimensions.
                    jitterPrecision = repmat(gamrnd(priorHyperparams(1, :, :, 1), 1 ./ priorHyperparams(1, :, :, 2)), [2, 1, 1]);
                    jitterCov = 1 ./ jitterPrecision;
                    if ModelSpec.positionModelIndicator == 1,
                        jitterCov = bsxfun(@plus, jitterCov, minJitterVar);
                    end
                end
                ModelParams(iSubpop, 1).jitterCov = jitterCov;

            elseif (ModelSpec.modelIndicator == 1) && (jitterModel == 3),
                % Full covariance matrices sampled by the helper.
                ModelParams(iSubpop, 1).jitterCov = samplecovariancematrix( PriorParams, iSubpop, nParticles, [], [], [] );
            else
                error('Invalid jitter model!')
            end

            % Total covariance. The observation variance is not currently
            % incorporated (including for modelIndicator == 2).
            ModelParams(iSubpop, 1).posCov = ModelParams(iSubpop, 1).jitterCov;
        end
    end

    % Remove empty blocks.
    isNonEmptyBlock = [ModelParams.nParticlesThisBlock] > 0;
    ModelParams = ModelParams(isNonEmptyBlock);
end

function requirefields( S, structName, fieldNames )
    % Raise '<field> must be a field of <structName>!' for the first entry
    % of the cell array fieldNames that is not a field of S.
    for iField = 1:numel(fieldNames),
        if ~isfield(S, fieldNames{iField}),
            error([fieldNames{iField}, ' must be a field of ', structName, '!'])
        end
    end
end