Sampling from multimodal distributions
Let's say you have a multimodal distribution, for example a mixture of two Gaussians. DifferentialEvolutionMetropolis implements differential evolution MCMC samplers (including deMCzs and DREAMz) that are designed to sample from such distributions efficiently. Roughly, these samplers build new proposals from the differences between the states of many separate chains (or between states drawn from a history of past samples). When two chains sit in different modes, their difference is roughly the distance between those modes. In principle, this lets the sampler jump between modes easily.
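To make the idea concrete, here is a minimal sketch of a basic differential evolution MCMC step in the style of ter Braak (2006). It is written in Base Julia and is not the package's implementation. Each chain proposes `x_i + γ (x_r1 - x_r2)` plus a little noise, using two other chains `r1` and `r2`. It occasionally sets `γ = 1`, so that a difference between chains in different modes carries a chain across the whole gap. The names `demc` and `bimodal_logp`, the toy target, and the settings (noise scale `b`, the 10% chance of `γ = 1`, 8 chains) are our own illustrative choices.

```julia
import Random, Statistics

# Numerically stable log(exp(a) + exp(b)) for finite a, b
logaddexp(a, b) = max(a, b) + log1p(exp(-abs(a - b)))

# Toy target: unnormalised log density of an equal mixture of N(-5, 1) and N(5, 1)
bimodal_logp(x) = logaddexp(-0.5 * (x[1] - 5)^2, -0.5 * (x[1] + 5)^2)

# Basic DE-MC sketch: chain i proposes x_i + γ (x_r1 - x_r2) + small noise, where
# r1 and r2 are two other distinct chains, and accepts with the Metropolis rule.
function demc(logp, X0::Matrix{Float64}, n_iter::Int; rng = Random.default_rng(), b = 1e-4)
    d, n_chains = size(X0)
    n_chains >= 3 || throw(ArgumentError("DE-MC needs at least 3 chains"))
    X = copy(X0)
    lp = [logp(X[:, i]) for i in 1:n_chains]
    γ_default = 2.38 / sqrt(2d)
    draws = Array{Float64}(undef, d, n_chains, n_iter)
    for t in 1:n_iter
        for i in 1:n_chains
            r1 = rand(rng, setdiff(1:n_chains, [i]))
            r2 = rand(rng, setdiff(1:n_chains, [i, r1]))
            # with probability 0.1 use γ = 1, so a difference between two chains
            # in different modes moves chain i across the full gap between them
            γ = rand(rng) < 0.1 ? 1.0 : γ_default
            x_prop = X[:, i] .+ γ .* (X[:, r1] .- X[:, r2]) .+ b .* randn(rng, d)
            lp_prop = logp(x_prop)
            if log(rand(rng)) < lp_prop - lp[i]
                X[:, i] = x_prop
                lp[i] = lp_prop
            end
        end
        draws[:, :, t] = X
    end
    return draws   # d × n_chains × n_iter
end

rng = Random.MersenneTwister(42)
X0 = reshape(collect(range(-8.0, 8.0; length = 8)), 1, 8)  # 8 chains in 1-D, spread over both modes
draws = demc(bimodal_logp, X0, 5_000; rng = rng)
xs = vec(draws[1, :, 1001:end])                            # drop the first 1000 iterations
println("fraction of draws in the positive mode: ", Statistics.mean(xs .> 0))
```

A plain random-walk Metropolis chain started in one of these modes would almost never cross the low-density gap near zero. The DE jumps move chains between modes routinely, so every chain visits both modes.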
Multimodal Distributions
First we need to implement a multimodal problem. We will assume that our data are generated by a mixture of two univariate Gaussians with means 5 and -5, a common standard deviation of 2, and a 65% weight on the positive component. Swapping the two component labels while replacing the weight p by 1 - p leaves the likelihood unchanged. As a result, the posterior over the means is multimodal and strongly correlated with the mixture weight. We can easily implement this using the Distributions package, which underpins most of the functionality in DifferentialEvolutionMetropolis.
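The label-switching symmetry can be checked numerically. The sketch below uses only Base Julia and is our own illustration: a hand-written normal log-density stands in for Distributions, and the data are freshly simulated rather than the `data` generated below. It confirms that swapping the means together with `p -> 1 - p` leaves the log-likelihood unchanged. The helper names `normal_logpdf`, `mixture_loglik` and `toy_data` are hypothetical.

```julia
import Random

# Hand-written log pdf of Normal(μ, σ), so no packages are needed
normal_logpdf(x, μ, σ) = -0.5 * ((x - μ) / σ)^2 - log(σ) - 0.5 * log(2π)

# Numerically stable log(exp(a) + exp(b)) for finite a, b
logaddexp(a, b) = max(a, b) + log1p(exp(-abs(a - b)))

# Log-likelihood of `xs` under the mixture p * N(μ1, σ) + (1 - p) * N(μ2, σ)
function mixture_loglik(xs, μ1, μ2, p, σ)
    return sum(logaddexp(log(p) + normal_logpdf(x, μ1, σ),
                         log1p(-p) + normal_logpdf(x, μ2, σ)) for x in xs)
end

# Simulated data in the spirit of the example: 65% from N(5, 2), 35% from N(-5, 2)
rng = Random.MersenneTwister(1234)
toy_data = [rand(rng) < 0.65 ? 5.0 + 2.0 * randn(rng) : -5.0 + 2.0 * randn(rng) for _ in 1:1000]

ll_original   = mixture_loglik(toy_data, 5.0, -5.0, 0.65, 2.0)
ll_swapped    = mixture_loglik(toy_data, -5.0, 5.0, 0.35, 2.0)  # labels swapped, p -> 1 - p
ll_means_only = mixture_loglik(toy_data, -5.0, 5.0, 0.65, 2.0)  # means swapped, p kept
println((ll_original, ll_swapped, ll_means_only))
```

The first two values agree up to rounding. The third is much lower, so the two labelings form two separate, equally likely modes, each tied to its own value of p.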
using Distributions, Random
mean = 5.0
std = 2.0
p = 0.65
dist = MixtureModel([
Normal(mean, std),
Normal(-mean, std)
], [p, 1 - p]);
#generate our data
Random.seed!(1234)
data = rand(dist, 1000)
# set up our log density
struct MixtureNormal
#data
data::Vector{Float64}
#prior parameters
prior_μ_σ::Float64
prior_σ_σ::Float64
end
#initialize the model
model = MixtureNormal(data, 10.0, 10.0)
# implement the log density function
function (problem::MixtureNormal)(θ)
(; μ, p, σ) = θ
ld =
#priors
logpdf(Normal(0, problem.prior_σ_σ), σ) + #half-normal up to a constant, since σ > 0 is enforced by the transformation below
logpdf(Uniform(0, 1), p) #prior on mixture weight
μ_prior = Normal(0, problem.prior_μ_σ)
for μ_ in μ
ld += logpdf(μ_prior, μ_)
end
#likelihood
ll_model = MixtureModel(
[
Normal(μ[1], σ),
Normal(μ[2], σ)
], [p, 1 - p]
)
for i in axes(problem.data, 1)
ld += logpdf(ll_model, problem.data[i])
end
return ld
end
# to visualize our data
using Plots
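plot(histogram(data, bins = 30, xlabel = "Value", ylabel = "Frequency", title = "Histogram of Generated Data"))
We can also transform our log density function so that it accepts unconstrained, real-valued inputs. TransformVariables passes the first two inputs through unchanged as the means, maps p to (0, 1), and maps σ to the positive reals. TransformedLogDensity then adds the log-Jacobian of this map. Unconstrained inputs are much easier for the sampler to work with.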
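Conceptually, such a transformation maps each unconstrained coordinate to its constrained parameter and adds the log-Jacobian of that map to the log density. The hand-written sketch below uses Base Julia only. It assumes that `as𝕀` is a logistic map and `asℝ₊` an exponential map, which is our understanding of TransformVariables rather than something stated here. `inv_logit`, `to_constrained` and `unconstrained_logdensity` are hypothetical names. In practice, the TransformVariables code that follows does this for you.

```julia
# Logistic map ℝ → (0, 1)
inv_logit(y) = 1 / (1 + exp(-y))

# Map an unconstrained y ∈ ℝ⁴ to (μ, p, σ) and return the log absolute Jacobian
function to_constrained(y::AbstractVector{<:Real})
    μ = y[1:2]                   # identity map: contributes nothing to the Jacobian
    p = inv_logit(y[3])          # ℝ → (0, 1), with dp/dy = p (1 - p)
    σ = exp(y[4])                # ℝ → (0, ∞), with dσ/dy = σ
    logjac = log(p) + log1p(-p) + y[4]
    return (μ = μ, p = p, σ = σ), logjac
end

# Log density on the unconstrained scale: original log density plus log-Jacobian
function unconstrained_logdensity(logdensity, y)
    θ, logjac = to_constrained(y)
    return logdensity(θ) + logjac
end

# Usage with a toy log density of (μ, p, σ): an Exponential(1) term on σ only
toy_ld(θ) = -θ.σ
y = [5.0, -5.0, log(0.65 / 0.35), log(2.0)]
θ, lj = to_constrained(y)
println(θ, "  log-Jacobian = ", lj)
println("unconstrained log density = ", unconstrained_logdensity(toy_ld, y))
```

Without the Jacobian term, a sampler working on the unconstrained scale would target the wrong distribution for p and σ.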
using TransformedLogDensities, TransformVariables
transformation = as((μ = as(Array, 2), p = as𝕀, σ = asℝ₊))
transformed_ld = TransformedLogDensity(transformation, model)
TransformedLogDensity of dimension 4
Sampling with DifferentialEvolutionMetropolis
Now let's use DifferentialEvolutionMetropolis to sample from this multimodal distribution. Here we use the DREAMz sampler, which is well-suited for exploring complex, multimodal spaces. We increase the number of chains to allow the sampler to explore the distribution more effectively.
using DifferentialEvolutionMetropolis, AbstractMCMC
model = AbstractMCMC.LogDensityModel(transformed_ld)
# Sample using DREAMz with adaptive stopping based on convergence
dreamz = DREAMz(model, 10000; n_chains = 6, progress = true);
Other implementations of the differential evolution MCMC algorithm are available in DifferentialEvolutionMetropolis.jl, such as deMC and deMCzs, which can be used similarly.
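The sketch below illustrates the subspace sampling and the memory archive used by DREAMz-style samplers. As we understand the general DREAM(ZS) idea, only a random subset of coordinates is updated, and the difference vector comes from two states drawn from an archive of past samples. This is an illustration, not the package's code. `subspace_proposal`, the crossover probability `cr` and the jitter scale `b` are our own hypothetical names and defaults.

```julia
import Random

# One DREAM(ZS)-style subspace proposal for a single chain (illustrative sketch).
#   x  : current state of the chain (length d)
#   Z  : archive of past states, one per column (d × m, m ≥ 2)
#   cr : crossover probability, i.e. the chance that each coordinate is updated
#   b  : scale of the small Gaussian jitter added to the updated coordinates
function subspace_proposal(rng, x::Vector{Float64}, Z::Matrix{Float64}; cr = 0.5, b = 1e-6)
    d, m = size(Z)
    # choose the subspace: each coordinate is updated with probability cr, at least one always
    mask = rand(rng, d) .< cr
    if !any(mask)
        mask[rand(rng, 1:d)] = true
    end
    idx = findall(mask)
    # two distinct past states from the archive supply the difference vector
    r1 = rand(rng, 1:m)
    r2 = rand(rng, 1:m)
    while r2 == r1
        r2 = rand(rng, 1:m)
    end
    # the jump scale depends on how many coordinates are actually being updated
    γ = 2.38 / sqrt(2 * length(idx))
    x_prop = copy(x)
    x_prop[idx] = x[idx] .+ γ .* (Z[idx, r1] .- Z[idx, r2]) .+ b .* randn(rng, length(idx))
    return x_prop
end

rng = Random.MersenneTwister(1)
Z_archive = randn(rng, 4, 50)   # stand-in archive of 50 past 4-D states
x_now = zeros(4)
println(subspace_proposal(rng, x_now, Z_archive; cr = 0.5))
```

A full sampler would accept or reject this proposal with the Metropolis rule and periodically add the chains' current states to the archive. Both steps are omitted here. Drawing differences from the archive rather than from other live chains is what allows this family of samplers to work with relatively few chains.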
Custom Scheme
DREAMz can be further customized. For example, we could include snooker updates alongside the DREAMz-like subspace sampling.
You can also modify aspects of the built-in samplers. For example, you can tell DREAMz to use non-memory-based sampling with DREAMz(...; memory = false), or you can define your own sampler scheme for finer control over the sampling process.
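The snooker update can be sketched as follows, after ter Braak and Vrugt (2008) as we understand it; this is not the package's code. A reference point `z` from the archive defines a line through the current state. The chain moves along that line by the projected difference of two other archive members, and the acceptance ratio gains a factor `(‖x_prop - z‖ / ‖x - z‖)^(d - 1)`. We assume that `setup_snooker_update(deterministic_γ = false)` corresponds to drawing γ at random, here from Uniform(1.2, 2.2). `snooker_step`, `run_snooker` and the toy standard-normal target are hypothetical.

```julia
import Random, Statistics
using LinearAlgebra: dot, norm

# One snooker update for a single chain (sketch).
#   logp : target log density,  x : current state,  Z : archive of past states (d × m, m ≥ 3)
function snooker_step(rng, logp, x::Vector{Float64}, Z::Matrix{Float64})
    d, m = size(Z)
    # three distinct archive members: z fixes the snooker axis, z1 and z2 set the jump length
    pick = Random.randperm(rng, m)[1:3]
    z, z1, z2 = Z[:, pick[1]], Z[:, pick[2]], Z[:, pick[3]]
    axis = x .- z
    dist_now = norm(axis)
    if dist_now == 0
        return x
    end
    e = axis ./ dist_now
    γ = 1.2 + rand(rng)   # γ ~ Uniform(1.2, 2.2)
    # move along the line through x and z by the projected difference of z1 and z2
    x_prop = x .+ (γ * dot(z1 .- z2, e)) .* e
    # the acceptance ratio carries the factor (‖x_prop - z‖ / ‖x - z‖)^(d - 1)
    log_ratio = logp(x_prop) - logp(x) + (d - 1) * (log(norm(x_prop .- z)) - log(dist_now))
    return log(rand(rng)) < log_ratio ? x_prop : x
end

# Run n snooker steps from x0 with a fixed archive Z and collect the states
function run_snooker(rng, logp, x0::Vector{Float64}, Z::Matrix{Float64}, n::Int)
    trace = Matrix{Float64}(undef, length(x0), n)
    x = copy(x0)
    for t in 1:n
        x = snooker_step(rng, logp, x, Z)
        trace[:, t] = x
    end
    return trace
end

rng = Random.MersenneTwister(7)
std_normal_logp(x) = -0.5 * sum(abs2, x)
Z_past = randn(rng, 2, 200)   # stand-in archive of past 2-D states
trace = run_snooker(rng, std_normal_logp, [3.0, -3.0], Z_past, 20_000)
kept = trace[1, 5001:end]
println("mean ≈ ", Statistics.mean(kept), ", variance ≈ ", Statistics.var(kept))
```

Because `z`, `z1` and `z2` come from the archive and not from the current chain, each step is a valid Metropolis-Hastings move for a fixed archive. The run above should therefore recover a mean near 0 and a variance near 1.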
This time we will use MCMCChains.jl to handle the output.
using MCMCChains
# Create a custom sampler scheme combining different update types
custom_sampler = setup_sampler_scheme(
setup_subspace_sampling(), # a DREAM-like sampler that uses subspace sampling
setup_snooker_update(deterministic_γ = false), # a snooker update for better exploration
setup_de_update(); # standard DE update
w = [0.6, 0.2, 0.2] # weights for each update type
);
# Sample using AbstractMCMC.sample with custom stopping criteria
custom_results = sample(
model,
custom_sampler,
r̂_stopping_criteria;
check_every = 10000,
maximum_R̂ = 1.05,
n_chains = 4,
N₀ = 100, # want good exploration of the space
memory = true,
parallel = true,
annealing = true,
num_warmup = 10000,
memory_size = 5000,
memory_refill = true,
chain_type = MCMCChains.Chains
);
Chains MCMC chain (9999×5×4 Array{Float64, 3}):
Iterations = 1:1:9999
Number of chains = 4
Samples per chain = 9999
parameters = param_1, param_2, param_3, param_4, ld
Use `describe(chains)` for summary statistics and quantiles.
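The arguments `r̂_stopping_criteria`, `check_every` and `maximum_R̂` suggest a rule of the form "extend the chains, check R̂ every so often, and stop once it is small enough". The sketch below implements such a rule under that assumption; we have not checked how the package implements it. To keep it short, the classic (non-split) Gelman-Rubin formula stands in for whatever R̂ variant the package uses, and a random-walk Metropolis update replaces the DE updates. `gelman_rubin`, `run_until_converged` and `make_rwm_step` are hypothetical names.

```julia
import Random, Statistics

# Classic (non-split) Gelman-Rubin R̂ for one parameter; `x` is n_iter × n_chains
function gelman_rubin(x::AbstractMatrix{<:Real})
    n = size(x, 1)
    W = Statistics.mean(Statistics.var(x; dims = 1))            # mean within-chain variance
    B = n * Statistics.var(vec(Statistics.mean(x; dims = 1)))   # between-chain variance
    var_plus = (n - 1) / n * W + B / n
    return sqrt(var_plus / W)
end

# Hypothetical stopping loop: extend all chains by `check_every` iterations at a
# time and stop once every parameter has R̂ ≤ maximum_R̂ (or after max_batches).
function run_until_converged(step!, states::Matrix{Float64};
                             check_every = 1_000, maximum_R̂ = 1.05, max_batches = 50)
    d, m = size(states)                      # parameters × chains
    draws = Array{Float64}(undef, 0, m, d)   # iterations × chains × parameters
    r̂ = Inf
    for _ in 1:max_batches
        batch = Array{Float64}(undef, check_every, m, d)
        for t in 1:check_every
            step!(states)
            for c in 1:m, j in 1:d
                batch[t, c, j] = states[j, c]
            end
        end
        draws = cat(draws, batch; dims = 1)
        # judge convergence on the second half of everything collected so far
        recent = draws[(size(draws, 1) ÷ 2 + 1):end, :, :]
        r̂ = maximum(gelman_rubin(recent[:, :, j]) for j in 1:d)
        r̂ <= maximum_R̂ && break
    end
    return draws, r̂
end

# Toy update: one random-walk Metropolis step per chain on a standard normal target
function make_rwm_step(rng; scale = 1.0)
    logp(x) = -0.5 * sum(abs2, x)
    return function (states)
        for c in axes(states, 2)
            x = states[:, c]
            x_prop = x .+ scale .* randn(rng, length(x))
            if log(rand(rng)) < logp(x_prop) - logp(x)
                states[:, c] = x_prop
            end
        end
        return states
    end
end

rng = Random.MersenneTwister(3)
init = 10 .* randn(rng, 2, 4)   # 2 parameters, 4 dispersed chains
draws, r̂ = run_until_converged(make_rwm_step(rng), init; check_every = 1_000, maximum_R̂ = 1.05)
println("stopped after ", size(draws, 1), " iterations, max R̂ = ", r̂)
```

Checking only every `check_every` iterations keeps the cost of computing R̂ small relative to sampling. Discarding the first half of the draws at each check is a simple way to drop the initial transient.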
You can also define your own samplers for more specialized use cases by extending the abstract types.
Interpreting Results
After running the sampler, you will have a collection of samples from the target distribution. These samples can be used to estimate summary statistics, credible intervals, and to assess the quality of your sampling.
Assessing Sampler Performance: ESS and R-hat
To evaluate how well your sampler is performing, you can compute the effective sample size (ESS) and the R-hat diagnostic. These metrics help you determine if your chains have mixed well and if your estimates are reliable.
- Effective Sample Size (ESS): the number of independent draws that would give estimates as precise as your autocorrelated chains. Higher ESS values indicate more reliable estimates.
- R-hat Diagnostic: also known as the Gelman-Rubin statistic, R-hat compares the variance within each chain to the variance between chains. Values close to 1 suggest good mixing and convergence. Values noticeably above 1 indicate potential problems; commonly used cut-offs are 1.01 or, more leniently, 1.05, the value passed as maximum_R̂ above.
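To show what ESS measures, here is a simplified single-chain estimator. It computes `ESS = n / τ`, where the integrated autocorrelation time `τ` sums Geyer's pair sums of autocorrelations for as long as they stay positive. MCMCDiagnosticTools uses more careful multi-chain estimators, so treat this as an illustration only. `autocor_lag` and `ess_single` are hypothetical helpers, and the AR(1) test chain is our own example.

```julia
import Random, Statistics

# Lag-k autocorrelation of one chain (biased estimator, normalised by the lag-0 term)
function autocor_lag(x::AbstractVector{<:Real}, k::Integer)
    n = length(x)
    μ = Statistics.mean(x)
    c0 = sum(abs2, x .- μ) / n
    ck = sum((x[t] - μ) * (x[t + k] - μ) for t in 1:(n - k)) / n
    return ck / c0
end

# Single-chain ESS = n / τ, with τ = -1 + 2 Σ_k Γ_k and Γ_k = ρ_{2k} + ρ_{2k+1},
# summing pairs only while they stay positive (Geyer's initial positive sequence)
function ess_single(x::AbstractVector{<:Real})
    n = length(x)
    τ = -1.0
    k = 0
    while 2k + 1 < n
        Γ = autocor_lag(x, 2k) + autocor_lag(x, 2k + 1)
        Γ <= 0 && break
        τ += 2Γ
        k += 1
    end
    return n / τ
end

# An AR(1) chain with φ = 0.9 has integrated autocorrelation time (1 + φ) / (1 - φ) = 19
rng = Random.MersenneTwister(11)
φ = 0.9
ar1 = zeros(20_000)
for t in 2:length(ar1)
    ar1[t] = φ * ar1[t - 1] + randn(rng)
end
println("estimated ESS: ", ess_single(ar1), "  (theory ≈ ", length(ar1) / 19, ")")
```

Dividing the ESS by the number of iterations, as in the diagnostics below, gives a rough measure of sampling efficiency that can be compared across samplers.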
Below is an example of how to compute these diagnostics for the DifferentialEvolutionMetropolis samplers using MCMCDiagnosticTools:
using Statistics, MCMCDiagnosticTools
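# Compute diagnostics for DREAMz results
# dreamz.samples is an iterations × chains × parameters array, the layout MCMCDiagnosticTools expects
ess_val = ess(dreamz.samples) ./ size(dreamz.samples, 1)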
rhat_val = maximum(rhat(dreamz.samples))
println("DREAMz diagnostics:")
println(" ESS per iteration: $ess_val")
println(" R-hat: $rhat_val")
# example trace plot
plot(dreamz.samples[:, :, 1], xlabel = "Iteration", ylabel = "Value", title = "Trace Plot for Mean 1 (DREAMz)")
Summarizing Posterior Samples
Once you have confirmed good mixing and convergence, you can summarize your posterior samples. For each parameter, you may want to compute the median and a credible interval (such as the 90% interval):
#transform the samples back to the original space; there is currently no nice way to do this
#first pool the DREAMz chains into one (iterations · chains) × parameters matrix (this reuses the name custom_results from the previous section)
custom_results = cat([dreamz.samples[:, c, :] for c in axes(dreamz.samples, 2)]...; dims = 1)
for it in axes(custom_results, 1)
transformed_values = transform(transformation, custom_results[it, :])
custom_results[it, :] .= [transformed_values.μ..., transformed_values.p, transformed_values.σ]
end
#posterior median and 90% credible interval for σ (a standard deviation, not a variance)
med = median(custom_results[:, 4])
q05, q95 = quantile(custom_results[:, 4], [0.05, 0.95])
println("Standard deviation: posterior median = $med, 90% CI = ($q05, $q95)")
println("True standard deviation: $std")
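#the same summary is less meaningful for p or the means: label switching makes their posteriors bimodal, so the median can fall between the modes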
med = median(custom_results[:, 1])
q05, q95 = quantile(custom_results[:, 1], [0.05, 0.95])
println("Mean 1: posterior median = $med, 90% CI = ($q05, $q95)")
println("True mean: $(mean)")
#visualize the posteriors, expect multimodal distributions
plot(histogram(custom_results[:, 1], bins = 30, xlabel = "Mean 1", ylabel = "Frequency"), histogram(custom_results[:, 2], bins = 30, xlabel = "Mean 2", ylabel = "Frequency"), histogram(custom_results[:, 3], bins = 30, xlabel = "Probability", ylabel = "Frequency"), layout = (3, 1))
#joint posteriors: correlation between the means and with the mixture weight
plot(scatter(custom_results[:, 1], custom_results[:, 3], xlabel = "Mean 1", ylabel = "Probability of Normal 1"), scatter(custom_results[:, 1], custom_results[:, 2], xlabel = "Mean 1", ylabel = "Mean 2"), layout = (2, 1))
For more details, see the DifferentialEvolutionMetropolis Documentation and Customizing your sampler.
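The bimodal posteriors for the means and p are a symptom of label switching. A common post-processing remedy, offered here as a suggestion and not as a package feature, is to impose an ordering such as μ1 < μ2 on each draw and flip p to 1 - p whenever the labels are swapped. The hypothetical `relabel!` below assumes the column order (μ1, μ2, p, σ) used for `custom_results` above. When the modes are well separated, as they are here, `relabel!(custom_results)` should give unimodal marginals.

```julia
# Relabel draws stored as rows (μ1, μ2, p, σ) so that μ1 < μ2 in every row.
# Swapping the two components also swaps their weights, so p becomes 1 - p.
function relabel!(draws::AbstractMatrix{<:Real})
    for i in axes(draws, 1)
        if draws[i, 1] > draws[i, 2]
            draws[i, 1], draws[i, 2] = draws[i, 2], draws[i, 1]
            draws[i, 3] = 1 - draws[i, 3]
        end
    end
    return draws
end

# Two hand-made draws describing the same mode under different labelings
toy_draws = [ 5.1 -4.9 0.66 2.0;
             -5.0  5.2 0.34 2.1]
relabel!(toy_draws)
println(toy_draws)
```

After relabeling, both rows describe the same mode: the negative component comes first, with a weight of about 0.34.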