WGD inference using reversible-jump MCMC and gene count data

Note

All analyses below use a small test data set and short MCMC runs. This is of course not recommended in practice; it only serves to make generating this page feasible in reasonable time.

Load Beluga and required packages:

using Beluga, CSV, Distributions, Random
Random.seed!(23031964)

Or, if you are running a Julia session with multiple processes (started with the -p option, or with workers added manually using addprocs; see the Julia docs), run:

using Distributed  # provides @everywhere and addprocs
using CSV, Distributions
@everywhere using Beluga  # if julia is running in a parallel environment

Then get some data (for instance from the example directory of the git repo):

nw = readline(joinpath(@__DIR__, "../../example/9dicots/9dicots.nw"))
# Note: newer CSV.jl releases require a sink argument, e.g. CSV.read(path, DataFrame)
df = CSV.read(joinpath(@__DIR__, "../../example/9dicots/9dicots-f01-25.csv"))
model, data = DLWGD(nw, df, 1.0, 1.2, 0.9)
(model = DLWGD{Float64,NewickTree.TreeNode{Beluga.Branch{Float64}}}(17), data = PArray{Float64}(16))

model now refers to the duplication-loss and WGD model (with no WGDs for now), and data refers to the phylogenetic profile matrix. The model was initialized with all duplication rates set to 1 and all loss rates set to 1.2. You can check this easily:

getrates(model)
2×17 Array{Float64,2}:
 1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  …  1.0  1.0  1.0  1.0  1.0  1.0  1.0
 1.2  1.2  1.2  1.2  1.2  1.2  1.2  1.2     1.2  1.2  1.2  1.2  1.2  1.2  1.2

or to get the full parameter vector:

asvector(model)
35-element Array{Float64,1}:
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 ⋮
 1.2
 1.2
 1.2
 1.2
 1.2
 1.2
 1.2
 1.2
 0.9
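
The length of 35 is consistent with the 2×17 rate matrix plus one extra parameter. The sketch below rebuilds a vector of that shape from stand-in values. It assumes that the layout is all duplication rates, then all loss rates, then the trailing 0.9 passed to DLWGD (presumably η); this matches the printed output but is not taken from Beluga's source.

```julia
# Hypothetical stand-ins mirroring the values printed above (not Beluga objects)
rates = vcat(fill(1.0, 1, 17), fill(1.2, 1, 17))  # 2×17: row 1 = λ, row 2 = μ
η = 0.9                                           # assumed to be the last entry

# Assumed layout: 17 duplication rates, 17 loss rates, then η
x = vcat(rates[1, :], rates[2, :], η)
println(length(x))
```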

Now you can easily compute log-likelihoods (and gradients thereof):

logpdf!(model, data)
-251.0358357105553

so we can do likelihood-based inference (either maximum-likelihood or Bayesian). For the kind of problems tackled here, however, Bayesian inference is the only viable option.

We proceed by specifying the hierarchical prior on the gene family evolutionary process. There is no DSL available (à la Turing.jl, Mamba.jl, Soss.jl or Stan), but we use a fairly flexible prior struct. Here is an example for the (recommended) bivariate independent rates (IR) prior:

prior = IRRevJumpPrior(
    Ψ=[1 0. ; 0. 1],
    X₀=MvNormal([0., 0.], [1 0. ; 0. 1]),
    πK=DiscreteUniform(0,20),
    πq=Beta(1,1),
    πη=Beta(3,1),
    Tl=treelength(model))
IRRevJumpPrior{Distributions.MvNormal{Float64,PDMats.PDMat{Float64,Array{Float64,2}},Array{Float64,1}},Distributions.Beta{Float64},Distributions.Beta{Float64},Distributions.DiscreteUniform,Nothing}
  Ψ: Array{Float64}((2, 2)) [1.0 0.0; 0.0 1.0]
  X₀: Distributions.MvNormal{Float64,PDMats.PDMat{Float64,Array{Float64,2}},Array{Float64,1}}
  πη: Distributions.Beta{Float64}
  πq: Distributions.Beta{Float64}
  πK: Distributions.DiscreteUniform
  πE: Nothing nothing
  equal: Bool false
  Tl: Float64 0.8565633076700001

Ψ is the prior covariance matrix for the Inverse-Wishart distribution. X₀ is the multivariate Normal prior on the mean duplication and loss rates. πK is the prior on the number of WGDs (i.e. the model indicator). πq is the Beta prior on the retention rates (iid). πη is the hyperprior on the parameter of the geometric distribution on the number of ancestral lineages at the root. Tl is the tree length (and is used for the prior on the WGD ages).
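
To get a feel for what the scalar priors above imply, one can simulate from them. The following is a minimal sketch using only the Julia standard library (inverse-CDF sampling instead of Distributions.jl). The variable names are made up for illustration, and the sketch does not touch the Inverse-Wishart or multivariate Normal parts of the prior.

```julia
using Random, Statistics

rng = MersenneTwister(42)
n = 100_000

# πK = DiscreteUniform(0, 20): every number of WGDs from 0 to 20 is equally likely
k = rand(rng, 0:20, n)

# πq = Beta(1, 1) is the uniform distribution on [0, 1]
q = rand(rng, n)

# πη = Beta(3, 1) has CDF x^3, so inverse-CDF sampling gives η = u^(1/3)
η = rand(rng, n) .^ (1 / 3)

println((mean(k), mean(q), mean(η)))
```

Under these priors the expected number of WGDs is 10, the expected retention rate is 0.5 and the expected η is 0.75, so the simulated means should land close to those values.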

To sample across model-space (i.e. where we infer the number and locations of WGDs), we need the reversible jump algorithm. There are several reversible-jump kernels implemented. Each reversible jump kernel introduces a WGD with probability 1/2 (forward jump) or removes a WGD with probability 1/2 (reverse jump). The simplest is the aptly named SimpleKernel, which introduces new WGDs with a random retention rate drawn from a Beta distribution.

kernel = SimpleKernel(qkernel=Beta(1,3))
SimpleKernel
  qkernel: Distributions.Beta{Float64}
  accepted: Int64 0

The kernels differ in the parameter updates that are applied during these jumps. Kernel 1 (SimpleKernel) only introduces/removes a WGD and does not modify the duplication and loss rates on the associated branch. Kernel 2 (DropKernel) introduces/removes a WGD and concomitantly decreases/increases the duplication rate by a random amount. Kernel 3 (BranchKernel) introduces/removes a WGD, decreases/increases the duplication rate and increases/decreases the loss rate (randomly). In contrast with ordinary Metropolis-Hastings MCMC proposals, the reversible-jump kernel cannot be tuned automatically (or at least, I'm unsure how it should be done). It can therefore be worthwhile to do some short pilot runs on a small data set to find a parameterization that results in a reasonably high jump probability (pjmp in the output). Usually, using the same distribution for the qkernel as the prior on q is a good idea.
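
The forward/reverse jump mechanics and the effect of the jump kernel on the acceptance rate can be illustrated on a toy problem. The sketch below is not Beluga's implementation. The per-WGD weight w, the function names, and the rule that a reverse jump always removes the most recently added WGD are all assumptions made for illustration. Because each new retention rate is drawn directly from the kernel density g, the Jacobian of the move is 1 and the acceptance ratio of a forward jump reduces to w(q)/g(q).

```julia
using Random

# Toy target (NOT Beluga's likelihood): each WGD with retention rate q
# contributes a factor w(q). With a uniform prior on k in 0:Kmax and a
# Beta(1,1) prior on q, the posterior on k is proportional to 0.8^k,
# because w integrates to 0.8 over [0, 1].
w(q) = 0.8 * 12 * q * (1 - q)^2
g(q) = 3 * (1 - q)^2                       # density of the Beta(1,3) jump kernel
rand_g(rng) = 1 - (1 - rand(rng))^(1 / 3)  # inverse-CDF draw from Beta(1,3)

function toy_rjmcmc(rng, n; Kmax=20)
    q = Float64[]                 # current retention rates, length(q) == k
    ks = Vector{Int}(undef, n)
    naccept = 0
    for i in 1:n
        if rand(rng) < 0.5        # forward jump: propose adding a WGD
            if length(q) < Kmax
                u = rand_g(rng)
                if rand(rng) < w(u) / g(u)
                    push!(q, u)
                    naccept += 1
                end
            end
        else                      # reverse jump: propose removing the last WGD
            if !isempty(q) && rand(rng) < g(q[end]) / w(q[end])
                pop!(q)
                naccept += 1
            end
        end
        ks[i] = length(q)
    end
    return ks, naccept / n
end

ks, pjmp = toy_rjmcmc(MersenneTwister(1), 200_000)
println("fraction of accepted jumps: ", pjmp)
```

The fraction of accepted jumps plays the role of pjmp here. Swapping g and rand_g for another density changes that fraction without changing the target, which is the kind of tuning a pilot run is meant to explore.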

We can now construct a chain object bundling everything

chain = RevJumpChain(data=data, model=model, prior=prior, kernel=kernel)
init!(chain)

and sample from it:

@time rjmcmc!(chain, 1000, show=50)
|    gen, pjmp,  k,k2,k3,    λ1,    λ2,    λ3,    μ1,    μ2,    μ3 …       logp,    logπ,     η
|     50, 0.08,  0, 0, 0, 1.913, 0.978, 0.765, 1.114, 4.488, 0.720 ⋯   -207.388, -49.787, 0.945
|    100, 0.10,  2, 0, 1, 0.865, 1.226, 0.170, 0.570, 0.959, 0.171 ⋯   -198.193, -57.517, 0.862
|    150, 0.12,  2, 0, 0, 2.201, 4.167, 0.649, 0.838, 0.749, 0.303 ⋯   -189.533, -41.007, 0.846
|    200, 0.10,  2, 0, 0, 1.371, 2.991, 1.561, 0.916, 0.712, 0.529 ⋯   -182.993, -39.238, 0.972
|    250, 0.12,  1, 0, 0, 1.076, 0.778, 1.488, 0.920, 0.822, 0.849 ⋯   -190.102, -41.240, 0.959
|    300, 0.11,  2, 0, 0, 1.375, 2.026, 1.070, 1.273, 0.630, 0.797 ⋯   -184.317, -39.010, 0.964
|    350, 0.12,  2, 0, 0, 0.825, 1.253, 0.524, 0.714, 0.902, 0.189 ⋯   -181.899, -47.138, 0.865
|    400, 0.13,  1, 0, 0, 1.158, 0.772, 1.603, 0.723, 1.048, 0.657 ⋯   -180.693, -42.574, 0.967
|    450, 0.12,  1, 0, 0, 1.500, 1.644, 0.334, 0.495, 0.283, 0.304 ⋯   -180.586, -55.389, 0.965
|    500, 0.12,  3, 0, 0, 0.806,10.844, 0.389, 0.569, 0.449, 0.209 ⋯   -175.986, -52.798, 0.956
|    550, 0.12,  2, 0, 0, 0.701,12.669, 0.168, 0.420, 0.703, 0.205 ⋯   -179.517, -68.020, 0.928
|    600, 0.12,  2, 0, 0, 0.925, 2.706, 0.668, 0.467, 0.100, 0.047 ⋯   -181.683, -67.677, 0.831
|    650, 0.12,  2, 0, 1, 1.086, 8.474, 0.981, 0.585, 0.714, 0.304 ⋯   -178.284, -59.941, 0.980
|    700, 0.13,  2, 0, 0, 1.458, 0.840, 0.580, 0.716, 0.756, 0.712 ⋯   -186.357, -36.212, 0.841
|    750, 0.13,  4, 0, 0, 1.480, 1.659, 0.875, 0.633, 0.101, 0.073 ⋯   -184.229, -47.909, 0.992
|    800, 0.13,  2, 0, 0, 1.834, 2.763, 1.373, 0.771, 0.307, 0.222 ⋯   -181.864, -51.067, 0.969
|    850, 0.13,  3, 0, 0, 1.286, 0.698, 0.959, 0.360, 0.309, 0.203 ⋯   -181.870, -48.371, 0.953
|    900, 0.13,  3, 0, 0, 1.755, 1.710, 1.236, 0.619, 0.192, 0.248 ⋯   -185.061, -41.456, 0.959
|    950, 0.13,  5, 0, 0, 1.282, 2.278, 1.050, 1.081, 0.856, 1.163 ⋯   -186.281, -27.840, 0.965
|   1000, 0.13,  4, 0, 0, 1.552, 9.759, 1.205, 0.897, 0.585, 0.629 ⋯   -182.143, -30.877, 0.898
 36.664107 seconds (58.34 M allocations: 10.376 GiB, 6.67% gc time)

This logs part of the trace to stdout every show iterations, so that we can monitor whether everything looks sensible. Of course, in a real analysis you would sample much longer than n=1000 iterations, but since this page has to be generated in decent time using a single CPU I'll keep it to 1000 here.

Now the computer has done Bayesian inference, and we have to do our part. We can analyze the trace (in chain.trace), write it to a file, compute statistics, etc. Here are some trace plots:

using Plots, DataFrames, LaTeXStrings
burnin=100
p1 = plot(chain.trace[burnin:end,:λ1], label=L"\lambda_1")
plot!(chain.trace[burnin:end,:μ1], label=L"\mu_1")
p2 = plot(chain.trace[burnin:end,:λ8], label=L"\lambda_8")
plot!(chain.trace[burnin:end,:μ8], label=L"\mu_8")
p3 = plot(chain.trace[!,:k], label=L"k")
p4 = plot(chain.trace[!,:η1], label=L"\eta")
plot(p1,p2,p3,p4, grid=false, layout=(2,2))

Clearly, we should sample way longer to get decent estimates for the duplication rates (λ), loss rates (μ) and number of WGDs (k). Note how η is quite well-sampled already.
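
Besides plots, numerical summaries of the trace are often useful. As a minimal sketch, the λ1 values printed in the log above (one per 50 iterations) can be summarized with the Statistics standard library. This thinned subsample is only for illustration; in practice one would presumably use the full column, e.g. chain.trace[burnin:end,:λ1] as in the plotting code.

```julia
using Statistics

# λ1 values copied from the log output above (generations 50, 100, ..., 1000)
λ1 = [1.913, 0.865, 2.201, 1.371, 1.076, 1.375, 0.825, 1.158, 1.500, 0.806,
      0.701, 0.925, 1.086, 1.458, 1.480, 1.834, 1.286, 1.755, 1.282, 1.552]

burnin_rows = 2                   # drop generations 50 and 100 (burn-in of 100)
kept = λ1[burnin_rows+1:end]

post_mean = mean(kept)                     # posterior mean estimate
ci95 = quantile(kept, [0.025, 0.975])      # equal-tailed 95% interval
println((post_mean, ci95))
```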

We can also check the effective sample size (ESS) of the model indicator variable. For that we will use the method of Heck et al. (2019) implemented in the module DiscreteMarkovFit.jl:

using DiscreteMarkovFit

We'll discard a burn-in of 100 iterations:

d = ObservedBirthDeathChain(Array(chain.trace[100:end,:k]))
out = DiscreteMarkovFit.sample(d, 10000)
ESS = 51.50191747476654
  π = 
 ⋅2 => (mean = 0.276, std = 0.072, q025 = 0.147, q0975 = 0.43)
 ⋅3 => (mean = 0.299, std = 0.053, q025 = 0.198, q0975 = 0.406)
 ⋅4 => (mean = 0.238, std = 0.052, q025 = 0.145, q0975 = 0.345)
 ⋅5 => (mean = 0.146, std = 0.051, q025 = 0.066, q0975 = 0.264)
 ⋅6 => (mean = 0.023, std = 0.016, q025 = 0.005, q0975 = 0.063)
 ⋅7 => (mean = 0.018, std = 0.027, q025 = 0.001, q0975 = 0.09)

This shows the effective sample size for the number of WGDs and the associated posterior probabilities. The maximum a posteriori (MAP) number of WGDs here is three. When doing a serious analysis, one should of course aim for higher ESS values. Note that if one is interested in WGDs on a specific branch, it is also relevant to look at the corresponding variable, for instance for the poplar branch (ptr, branch 10):

d = ObservedBirthDeathChain(Array(chain.trace[100:end,:k10]))
out = DiscreteMarkovFit.sample(d, 10000)
ESS = 31.45826167701062
  π = 
 ⋅1 => (mean = 0.671, std = 0.102, q025 = 0.447, q0975 = 0.849)
 ⋅2 => (mean = 0.312, std = 0.097, q025 = 0.143, q0975 = 0.521)
 ⋅3 => (mean = 0.017, std = 0.012, q025 = 0.004, q0975 = 0.049)
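
As a quick sanity check on such output, the posterior probabilities can be compared with the raw visit frequencies of the indicator trace. The sketch below does this for a short, made-up sequence of k values standing in for Array(chain.trace[100:end,:k]). It is a plain frequency count, not the Markov model fit of DiscreteMarkovFit.jl, and it gives no uncertainty estimates.

```julia
# Hypothetical stand-in for Array(chain.trace[100:end, :k])
ks = [2, 2, 3, 3, 3, 4, 3, 2, 3, 4, 5, 4, 3, 3, 2]

states = sort(unique(ks))                               # visited values of k
p = [count(==(s), ks) / length(ks) for s in states]     # visit frequencies
kmap = states[argmax(p)]                                # most frequently visited k

# Number of iterations in which the model indicator actually changed
njumps = count(i -> ks[i] != ks[i-1], 2:length(ks))

for (s, ps) in zip(states, p)
    println("k = ", s, ": ", round(ps, digits=3))
end
```

A trace with very few changes in k relative to its length is a hint that the ESS of the model indicator will be low.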

We can also compute Bayes factors to get an idea of the number of WGDs for each branch in the species tree.

bfs = bayesfactors(chain, burnin=100);
🌲  2: (vvi,cpa,ath,ptr,mtr)

______________________________________________________________________________
🌲  3: (vvi)
[ 1 vs.  0] K = (0.02/0.98) ÷ (0.29/0.35) =    0.021 [log₁₀(K) =   -1.684]

[≥1 vs.  0] K = (0.02/0.98) ÷ (0.65/0.35) =    0.009 [log₁₀(K) =   -2.040]
______________________________________________________________________________
🌲  4: (cpa,ath,ptr,mtr)
[ 1 vs.  0] K = (0.03/0.97) ÷ (0.04/0.95) =    0.721 [log₁₀(K) =   -0.142]

[≥1 vs.  0] K = (0.03/0.97) ÷ (0.05/0.95) =    0.700 [log₁₀(K) =   -0.155]
______________________________________________________________________________
🌲  5: (cpa,ath)

______________________________________________________________________________
🌲  6: (ath)
[ 1 vs.  0] K = (0.06/0.94) ÷ (0.30/0.49) =    0.102 [log₁₀(K) =   -0.993]

[≥1 vs.  0] K = (0.06/0.94) ÷ (0.51/0.49) =    0.061 [log₁₀(K) =   -1.216]
______________________________________________________________________________
🌲  7: (cpa)
[ 1 vs.  0] K = (0.01/0.99) ÷ (0.30/0.49) =    0.013 [log₁₀(K) =   -1.895]

[≥1 vs.  0] K = (0.01/0.99) ÷ (0.51/0.49) =    0.008 [log₁₀(K) =   -2.118]
______________________________________________________________________________
🌲  8: (ptr,mtr)
[ 1 vs.  0] K = (0.23/0.77) ÷ (0.06/0.93) =    4.422 [log₁₀(K) =    0.646] *

[≥1 vs.  0] K = (0.23/0.77) ÷ (0.07/0.93) =    4.230 [log₁₀(K) =    0.626] *
______________________________________________________________________________
🌲  9: (mtr)
[ 1 vs.  0] K = (0.13/0.86) ÷ (0.29/0.38) =    0.191 [log₁₀(K) =   -0.719]
[ 2 vs.  1] K = (0.01/0.13) ÷ (0.18/0.29) =    0.181 [log₁₀(K) =   -0.743]

[≥1 vs.  0] K = (0.14/0.86) ÷ (0.62/0.38) =    0.101 [log₁₀(K) =   -0.998]
[≥2 vs.  1] K = (0.01/0.13) ÷ (0.33/0.29) =    0.101 [log₁₀(K) =   -0.994]
______________________________________________________________________________
🌲 10: (ptr)
[ 1 vs.  0] K = (0.32/0.66) ÷ (0.29/0.38) =    0.630 [log₁₀(K) =   -0.201]
[ 2 vs.  1] K = (0.02/0.32) ÷ (0.18/0.29) =    0.077 [log₁₀(K) =   -1.115]

[≥1 vs.  0] K = (0.34/0.66) ÷ (0.62/0.38) =    0.312 [log₁₀(K) =   -0.505]
[≥2 vs.  1] K = (0.02/0.32) ÷ (0.33/0.29) =    0.043 [log₁₀(K) =   -1.366]
______________________________________________________________________________
🌲 11: (ugi,sly,bvu,cqu)

______________________________________________________________________________
🌲 12: (bvu,cqu)
[ 1 vs.  0] K = (0.00/1.00) ÷ (0.30/0.49) =    0.007 [log₁₀(K) =   -2.139]

[≥1 vs.  0] K = (0.00/1.00) ÷ (0.51/0.49) =    0.004 [log₁₀(K) =   -2.362]
______________________________________________________________________________
🌲 13: (bvu)
[ 1 vs.  0] K = (0.03/0.97) ÷ (0.27/0.63) =    0.069 [log₁₀(K) =   -1.164]

[≥1 vs.  0] K = (0.03/0.97) ÷ (0.37/0.63) =    0.050 [log₁₀(K) =   -1.305]
______________________________________________________________________________
🌲 14: (cqu)
[ 1 vs.  0] K = (0.92/0.05) ÷ (0.27/0.63) =   40.872 [log₁₀(K) =    1.611] **
[ 2 vs.  1] K = (0.02/0.92) ÷ (0.08/0.27) =    0.087 [log₁₀(K) =   -1.059]

[≥1 vs.  0] K = (0.95/0.05) ÷ (0.37/0.63) =   30.294 [log₁₀(K) =    1.481] **
[≥2 vs.  1] K = (0.02/0.92) ÷ (0.10/0.27) =    0.069 [log₁₀(K) =   -1.163]
______________________________________________________________________________
🌲 15: (ugi,sly)
[ 1 vs.  0] K = (0.00/1.00) ÷ (0.22/0.73) =    0.011 [log₁₀(K) =   -1.949]

[≥1 vs.  0] K = (0.00/1.00) ÷ (0.27/0.73) =    0.009 [log₁₀(K) =   -2.040]
______________________________________________________________________________
🌲 16: (ugi)
[ 1 vs.  0] K = (0.06/0.94) ÷ (0.30/0.43) =    0.091 [log₁₀(K) =   -1.042]

[≥1 vs.  0] K = (0.06/0.94) ÷ (0.57/0.43) =    0.048 [log₁₀(K) =   -1.315]
______________________________________________________________________________
🌲 17: (sly)
[ 1 vs.  0] K = (0.29/0.68) ÷ (0.30/0.43) =    0.602 [log₁₀(K) =   -0.220]
[ 2 vs.  1] K = (0.04/0.29) ÷ (0.16/0.30) =    0.229 [log₁₀(K) =   -0.641]
[ 3 vs.  2] K = (0.00/0.04) ÷ (0.07/0.16) =    0.074 [log₁₀(K) =   -1.131]

[≥1 vs.  0] K = (0.32/0.68) ÷ (0.57/0.43) =    0.363 [log₁₀(K) =   -0.441]
[≥2 vs.  1] K = (0.04/0.29) ÷ (0.26/0.30) =    0.146 [log₁₀(K) =   -0.834]
[≥3 vs.  2] K = (0.00/0.04) ÷ (0.10/0.16) =    0.051 [log₁₀(K) =   -1.292]
______________________________________________________________________________

This suggests strong support for a WGD in quinoa (cqu), for which we know the genome shows a strong signature of an ancestral WGD. Note that we already detect this WGD using a mere 25 gene families as data!
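
Each line of the output above appears to report K as a ratio of two odds, which I read as posterior odds divided by prior odds (the standard definition of a Bayes factor). A minimal sketch of that arithmetic, with a hypothetical helper name, applied to the rounded probabilities printed for the cqu branch:

```julia
# Bayes factor as posterior odds divided by prior odds
# (bayesfactor here is a hypothetical helper, not Beluga's bayesfactors function)
bayesfactor(post1, post0, prior1, prior0) = (post1 / post0) / (prior1 / prior0)

# Rounded values printed for branch 14 (cqu), "1 vs. 0"
K = bayesfactor(0.92, 0.05, 0.27, 0.63)
log10K = log10(K)
println((K, log10K))
```

Because the printed probabilities are rounded to two decimals, recomputing K from them gives a value close to, but not identical with, the one reported in the output.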

A plot of the posterior probabilities for the number of WGDs on each branch is a nice way to summarize the rjMCMC output:

plots = [bar(g[!,:k], g[!,:p1], color=:white,
            title=join(string.(g[1,:clade]), ", "))
            for g in groupby(bfs, :branch)]
xlabel!.(plots[end-3:end], L"k")
ylabel!.(plots[1:4:end], L"P(k|X)")
plot(plots..., grid=false, legend=false,
    ylim=(0,1), xlim=(-0.5,3.5),
    yticks=[0, 0.5, 1], xticks=0:3,
    title_loc=:left, titlefont=8)

This page was generated using Literate.jl.