The atsar package contains a set of bundled Stan files for fitting univariate state-space models in Stan, for those who want to try those methods. It also includes a number of options that we didn't talk about in class, but we can explore those in more detail here.
As a first example, we’ll use the MARSS gray whale dataset.
library(ggplot2)
library(shinystan)
## Loading required package: shiny
##
## This is shinystan version 2.6.0
library(rstan)
## Loading required package: StanHeaders
## rstan (Version 2.21.8, GitRev: 2e1f913d3ca3)
## For execution on a local, multicore CPU with excess RAM we recommend calling
## options(mc.cores = parallel::detectCores()).
## To avoid recompilation of unchanged Stan programs, we recommend calling
## rstan_options(auto_write = TRUE)
library(broom.mixed)
library(bayesplot)
## This is bayesplot version 1.10.0
## - Online documentation and vignettes at mc-stan.org/bayesplot
## - bayesplot theme set to bayesplot::theme_default()
## * Does _not_ affect other ggplot2 plots
## * See ?bayesplot_theme_set for details on theme setting
whale <- as.data.frame(MARSS::graywhales)
ggplot(whale, aes(Year,Count)) +
geom_point() + geom_smooth()
## `geom_smooth()` using method = 'loess' and formula = 'y ~ x'
## Warning: Removed 15 rows containing non-finite values (`stat_smooth()`).
## Warning: Removed 15 rows containing missing values (`geom_point()`).
An interesting question might be whether we can fit a Gompertz population model and estimate the long-term trend. To do this, we'll use the atsar package.
remotes::install_github("atsa-es/atsar")
It might be helpful to look at the documentation to see what kinds of univariate models are available:
?atsar::fit_stan
For the whale example, let's start with the ss_rw model, a univariate state-space random walk model. We will fit the model in log-space and estimate drift.
library(atsar)
We will start with a few hundred samples (normally you would want the burn-in period to be 1000-2000).
set.seed(123)
fit <- fit_stan(y = log(whale$Count),
                model_name = "ss_rw",
                est_drift = TRUE,
                mcmc_list = list(n_mcmc = 500,
                                 n_burn = 100,
                                 n_chain = 3))
Now let's look at the R-hat values. Generally we want these to be < 1.1. Does it look like the model converged OK?
fit
## Inference for Stan model: ss_rw.
## 3 chains, each with iter=1000; warmup=500; thin=1;
## post-warmup draws per chain=500, total post-warmup draws=1500.
##
## mean se_mean sd 2.5% 25% 50% 75% 97.5% n_eff Rhat
## sigma_process 0.14 0.00 0.04 0.06 0.11 0.14 0.16 0.23 149 1.01
## pred[1] 8.02 0.01 0.13 7.76 7.93 8.01 8.09 8.31 632 1.00
## pred[2] 8.11 0.01 0.14 7.84 8.01 8.12 8.20 8.40 792 1.00
## pred[3] 8.21 0.00 0.10 8.02 8.14 8.21 8.28 8.42 1363 1.00
## pred[4] 8.31 0.00 0.13 8.05 8.23 8.31 8.39 8.56 1081 1.00
## pred[5] 8.41 0.00 0.10 8.21 8.35 8.41 8.47 8.60 1759 1.00
## pred[6] 8.53 0.00 0.14 8.24 8.45 8.53 8.62 8.82 1038 1.00
## pred[7] 8.63 0.00 0.14 8.34 8.54 8.64 8.71 8.94 1195 1.00
## pred[8] 8.73 0.00 0.11 8.52 8.66 8.73 8.80 8.95 2089 1.00
## pred[9] 8.86 0.01 0.16 8.54 8.75 8.85 8.96 9.19 969 1.00
## pred[10] 8.99 0.01 0.19 8.62 8.86 8.98 9.11 9.38 964 1.00
## pred[11] 9.11 0.01 0.21 8.74 8.98 9.11 9.24 9.52 957 1.01
## pred[12] 9.24 0.01 0.21 8.84 9.09 9.24 9.37 9.65 867 1.01
## pred[13] 9.36 0.01 0.20 8.99 9.23 9.35 9.49 9.78 585 1.01
## pred[14] 9.49 0.01 0.19 9.14 9.36 9.48 9.62 9.87 433 1.01
## pred[15] 9.62 0.01 0.16 9.27 9.52 9.64 9.73 9.87 174 1.01
## pred[16] 9.54 0.01 0.14 9.25 9.44 9.54 9.63 9.80 367 1.00
## pred[17] 9.45 0.00 0.09 9.27 9.39 9.45 9.51 9.64 1342 1.00
## pred[18] 9.43 0.00 0.09 9.26 9.37 9.44 9.49 9.60 1241 1.00
## pred[19] 9.37 0.00 0.09 9.21 9.31 9.37 9.43 9.56 728 1.00
## pred[20] 9.36 0.01 0.10 9.17 9.29 9.36 9.44 9.57 285 1.01
## pred[21] 9.58 0.01 0.10 9.37 9.52 9.58 9.65 9.77 370 1.00
## pred[22] 9.58 0.00 0.09 9.40 9.52 9.58 9.63 9.75 1203 1.00
## pred[23] 9.54 0.00 0.09 9.37 9.49 9.55 9.60 9.72 1374 1.00
## pred[24] 9.60 0.00 0.08 9.44 9.55 9.60 9.65 9.77 2034 1.00
## pred[25] 9.65 0.00 0.09 9.48 9.60 9.65 9.71 9.81 1988 1.00
## pred[26] 9.68 0.00 0.09 9.50 9.63 9.68 9.73 9.85 1232 1.00
## pred[27] 9.61 0.00 0.09 9.44 9.54 9.61 9.67 9.81 580 1.00
## pred[28] 9.70 0.00 0.09 9.51 9.64 9.70 9.76 9.88 1734 1.00
## pred[29] 9.75 0.00 0.15 9.45 9.66 9.75 9.84 10.05 1021 1.00
## pred[30] 9.79 0.00 0.17 9.45 9.68 9.79 9.90 10.14 1323 1.00
## pred[31] 9.84 0.01 0.18 9.48 9.72 9.85 9.96 10.19 1055 1.00
## pred[32] 9.89 0.00 0.15 9.62 9.79 9.89 9.99 10.21 1052 1.00
## pred[33] 9.94 0.00 0.10 9.73 9.88 9.95 10.01 10.13 1189 1.00
## pred[34] 9.93 0.00 0.09 9.75 9.87 9.93 9.98 10.10 1874 1.00
## pred[35] 9.93 0.00 0.09 9.76 9.88 9.93 9.98 10.10 1671 1.00
## pred[36] 9.89 0.00 0.09 9.71 9.83 9.89 9.95 10.09 803 1.00
## pred[37] 10.01 0.00 0.09 9.82 9.96 10.01 10.06 10.19 1523 1.00
## pred[38] 10.06 0.00 0.09 9.88 10.01 10.06 10.11 10.23 1247 1.00
## pred[39] 10.16 0.00 0.10 9.94 10.09 10.16 10.22 10.36 1455 1.00
## sigma_obs 0.13 0.00 0.04 0.06 0.10 0.12 0.15 0.22 121 1.01
## mu[1] 0.06 0.00 0.02 0.01 0.04 0.06 0.07 0.11 388 1.01
## lp__ 16.07 0.86 8.24 1.07 10.19 15.85 21.54 32.93 91 1.02
##
## Samples were drawn using NUTS(diag_e) at Tue May 2 14:43:06 2023.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at
## convergence, Rhat=1).
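If you want to check convergence programmatically rather than scanning the printed table, the split R-hat values can also be pulled from the stanfit summary; a minimal sketch:
rhats <- summary(fit)$summary[, "Rhat"]  # one R-hat per parameter and state
max(rhats, na.rm = TRUE)                 # values much above 1.1 suggest poor mixing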
It might be good to look at the posterior plots with ShinyStan. Specifically, let's look at the posterior distributions and trace plots of the two variance parameters. Ask:
- are the posterior distributions multi-modal?
- do the MCMC trace plots look stationary?
library(shinystan)
shinystan::launch_shinystan(fit)
##
## Launching ShinyStan interface... for large models this may take some time.
##
## Listening on http://127.0.0.1:6809
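If you prefer static plots to the interactive ShinyStan interface, the same diagnostics can be drawn directly with rstan; a quick sketch for the two variance parameters:
rstan::traceplot(fit, pars = c("sigma_process", "sigma_obs"))  # trace plots by chain
rstan::stan_dens(fit, pars = c("sigma_process", "sigma_obs"))  # posterior densities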
Let's look at the posterior of the trend – does it generally overlap 0? Do you think the drift term should be kept in the model or not?
pars <- as.data.frame(rstan::extract(fit))
ggplot(pars, aes(mu)) +
geom_histogram()
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
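To put a number on this, we can also summarize the posterior of mu directly; a small sketch using the pars data frame extracted above:
quantile(pars$mu, probs = c(0.025, 0.5, 0.975))  # posterior median and 95% credible interval
mean(pars$mu > 0)                                # posterior probability that the drift is positive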
Density dependence is commonly represented by the autoregression term,
\[ x_t = \phi x_{t-1} + \mu + d_t, \quad d_t \sim N(0,\sigma) \]
We can estimate the autoregression parameter by changing our model to ss_ar:
set.seed(123)
fit2 <- fit_stan(y = log(whale$Count),
                 model_name = "ss_ar",
                 est_drift = TRUE,
                 mcmc_list = list(n_mcmc = 500,
                                  n_burn = 100,
                                  n_chain = 3))
Questions:
- Did this model appear to converge? Why or why not?
- Does the phi parameter appear well estimated?
- Extract the states, and plot them with the original data. For bonus points, include 95% credible intervals (one way to do this is sketched below).
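A minimal sketch of one way to approach the last item, assuming fit2 converged: pull the posterior draws of the states, summarize them, and overlay the observed (log) counts.
preds <- rstan::extract(fit2)$pred  # matrix of draws (rows) by time steps (columns)
df <- data.frame(Year = whale$Year,
                 obs = log(whale$Count),
                 mean = apply(preds, 2, mean),
                 lo = apply(preds, 2, quantile, 0.025),
                 hi = apply(preds, 2, quantile, 0.975))
ggplot(df, aes(Year, mean)) +
  geom_ribbon(aes(ymin = lo, ymax = hi), alpha = 0.3) +  # 95% credible interval
  geom_line() +                                          # posterior mean state
  geom_point(aes(y = obs))                               # observed log counts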
In the above model(s), and in all of MARSS, we have been modeling log-transformed data. Instead, the Stan code allows us to easily change the family. Given that the original whale dataset consists of counts, let's try changing the family to Poisson.
set.seed(123)
fit3 <- fit_stan(y = whale$Count,
                 model_name = "ss_rw",
                 est_drift = TRUE,
                 mcmc_list = list(n_mcmc = 500,
                                  n_burn = 100,
                                  n_chain = 3),
                 family = "poisson")
Questions:
- How do the estimates of the trend differ between this approach and the earlier model fit to the log-transformed counts?
- Calculate the correlation between the mean state estimates in this model versus the model we used initially (fit). The Poisson model uses a log-link, so the predicted states are in log-space. Is this relationship expected? One way to do this is sketched after the tidy() output below.
broom.mixed::tidy(fit, pars="pred")
## # A tibble: 39 × 3
## term estimate std.error
## <chr> <dbl> <dbl>
## 1 pred[1] 8.01 0.130
## 2 pred[2] 8.12 0.142
## 3 pred[3] 8.21 0.100
## 4 pred[4] 8.31 0.127
## 5 pred[5] 8.41 0.0952
## 6 pred[6] 8.53 0.144
## 7 pred[7] 8.64 0.144
## 8 pred[8] 8.73 0.107
## 9 pred[9] 8.85 0.164
## 10 pred[10] 8.98 0.191
## # ℹ 29 more rows
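A quick sketch of one way to compare the two sets of mean state estimates (assuming fit3 above converged). Because of the Poisson log-link, both sets of states are on the log scale, so we would expect them to be strongly, positively correlated:
pred_gauss <- apply(rstan::extract(fit)$pred, 2, mean)   # mean states, Gaussian model on log counts
pred_pois  <- apply(rstan::extract(fit3)$pred, 2, mean)  # mean states, Poisson model (log-link)
cor(pred_gauss, pred_pois)
plot(pred_gauss, pred_pois)  # should fall close to a 1:1 line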
Posterior distributions and MCMC diagnostics can also be plotted with the bayesplot package. For example,
bayesplot::mcmc_hist_by_chain(fit, pars = "sigma_obs")
## Warning: The `facets` argument of `facet_grid()` is deprecated as of ggplot2 2.2.0.
## ℹ Please use the `rows` argument instead.
## ℹ The deprecated feature was likely used in the bayesplot package.
## Please report the issue at <https://github.com/stan-dev/bayesplot/issues/>.
## This warning is displayed once every 8 hours.
## Call `lifecycle::last_lifecycle_warnings()` to see where this warning was
## generated.
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
The Stan code used to fit the model can also be printed:
cat(get_stancode(fit))  # cat makes this readable
## data {
## int<lower=0> N;
## int<lower=0> n_pos;
## vector[n_pos] y;
## int y_int[n_pos];
## int pos_indx[n_pos+2];
## int<lower=0> est_drift;
## int<lower=0> est_nu;
## int family; // 1 = normal, 2 = binomial, 3 = poisson, 4 = gamma, 5 = lognormal
## }
## parameters {
## real x0;
## real mu[est_drift];
## vector[N-1] pro_dev;
## real<lower=0> sigma_process;
## real<lower=0> sigma_obs;
## real<lower=2> nu[est_nu];
## }
## transformed parameters {
## vector[N] pred;
## real temp;
## pred[1] = x0;
## temp = 0;
## if(est_drift==1) {
## temp = mu[1];
## }
## for(i in 2:N) {
## pred[i] = pred[i-1] + temp + sigma_process*pro_dev[i-1];
## }
## }
## model {
## x0 ~ normal(0,10);
## if(est_drift==1) {
## mu ~ normal(0,2);
## }
## if(est_nu==1) {
## nu ~ student_t(3,2,2);
## }
## sigma_process ~ student_t(3,0,2);
## sigma_obs ~ student_t(3,0,2);
## if(est_nu==0) {
## pro_dev ~ std_normal();//normal(0, sigma_process);
## } else {
## pro_dev ~ student_t(nu,0,1);
## }
##
## if(family==1) {
## for(i in 1:(n_pos)) {
## //y ~ normal(pred, sigma_obs);
## y[i] ~ normal(pred[pos_indx[i]], sigma_obs);
## }
## }
## if(family==2) {
## for(i in 1:(n_pos)) {
## y_int[i] ~ bernoulli_logit(pred[pos_indx[i]]);
## }
## }
## if(family==3) {
## for(i in 1:(n_pos)) {
## y_int[i] ~ poisson_log(pred[pos_indx[i]]);
## }
## }
## if(family==4) {
## for(i in 1:(n_pos)) {
## y[i] ~ gamma(sigma_obs, sigma_obs ./ exp(pred[pos_indx[i]]));
## }
## }
## if(family==5) {
## for(i in 1:(n_pos)) {
## y[i] ~ lognormal(pred[pos_indx[i]], sigma_obs);
## }
## }
## }
## generated quantities {
## vector[n_pos] log_lik;
## // regresssion example in loo() package
## if(family==1) for (n in 1:n_pos) log_lik[n] = normal_lpdf(y[n] | pred[pos_indx[n]], sigma_obs);
## if(family==2) for (n in 1:n_pos) log_lik[n] = bernoulli_lpmf(y_int[n] | inv_logit(pred[pos_indx[n]]));
## if(family==3) for (n in 1:n_pos) log_lik[n] = poisson_lpmf(y_int[n] | exp(pred[pos_indx[n]]));
## if(family==4) for (n in 1:n_pos) log_lik[n] = gamma_lpdf(y[n] | sigma_obs, sigma_obs ./ exp(pred[pos_indx[n]]));
## if(family==5) for (n in 1:n_pos) log_lik[n] = lognormal_lpdf(y[n] | pred[pos_indx[n]], sigma_obs);
## }
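Note that the generated quantities block computes a pointwise log_lik, so the fits above can be compared with the loo package; a minimal sketch, assuming loo is installed:
library(loo)
loo1 <- loo(extract_log_lik(fit))   # random walk with drift
loo2 <- loo(extract_log_lik(fit2))  # AR(1) state-space model
loo_compare(loo1, loo2)             # ranked by expected log predictive density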