Fitting a univariate model in Stan

The atsar package contains a number of bundled Stan files for fitting univariate state space models, for those who want to try fitting these models in Stan. It also has many options that we didn't cover in class; you can explore those in more detail here.

As a first example, we’ll use the MARSS gray whale dataset.

library(ggplot2)
library(shinystan)
## Loading required package: shiny
## 
## This is shinystan version 2.6.0
library(rstan)
## Loading required package: StanHeaders
## rstan (Version 2.21.8, GitRev: 2e1f913d3ca3)
## For execution on a local, multicore CPU with excess RAM we recommend calling
## options(mc.cores = parallel::detectCores()).
## To avoid recompilation of unchanged Stan programs, we recommend calling
## rstan_options(auto_write = TRUE)
library(broom.mixed)
library(bayesplot)
## This is bayesplot version 1.10.0
## - Online documentation and vignettes at mc-stan.org/bayesplot
## - bayesplot theme set to bayesplot::theme_default()
##    * Does _not_ affect other ggplot2 plots
##    * See ?bayesplot_theme_set for details on theme setting
whale <- as.data.frame(MARSS::graywhales)

ggplot(whale, aes(Year,Count)) + 
  geom_point() + geom_smooth()
## `geom_smooth()` using method = 'loess' and formula = 'y ~ x'
## Warning: Removed 15 rows containing non-finite values (`stat_smooth()`).
## Warning: Removed 15 rows containing missing values (`geom_point()`).

An interesting question is whether we can fit a Gompertz population model and estimate the long-term trend. To do this, we'll use the atsar package.

remotes::install_github("atsa-es/atsar")

It may help to look at the documentation to see which univariate models are available:

?atsar::fit_stan

For the whale example, let's start with the ss_rw model, a univariate state space random walk. We will fit the model to log counts and estimate a drift term.
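Written out, and assuming we are reading the Stan source printed at the end of this section (get_stancode(fit)) correctly, the ss_rw model with est_drift = TRUE and the default Gaussian family is roughly the following. As in Stan, N(a, b) takes a mean and a standard deviation, and x_t is the log abundance in year t:

\[ x_1 = x_0, \qquad x_t = x_{t-1} + \mu + \sigma_{process}\,\varepsilon_t, \quad \varepsilon_t \sim N(0, 1), \quad t = 2, \dots, N \]

\[ y_t \sim N(x_t, \sigma_{obs}) \quad \text{for each year } t \text{ with an observed count} \]

Years with missing counts contribute no likelihood term; their states are still estimated through the process equation.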

library(atsar)

We will start with a few hundred iterations (normally you would want a burn-in period of 1000-2000 iterations).

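set.seed(123)
fit <- fit_stan(y = log(whale$Count),
                model_name = "ss_rw",
                est_drift = TRUE,
                # MCMC settings go in the named mcmc_list argument;
                # an unnamed list is matched positionally to a different argument
                mcmc_list = list(n_mcmc = 500,
                                 n_burn = 100,
                                 n_chain = 3))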
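A quick way to check how many posterior draws a run keeps: the header of the summary printed below (3 chains, iter=1000, warmup=500) corresponds to 3 × (1000 - 500) = 1500 post-warmup draws. The small helper below does that arithmetic. It assumes n_mcmc maps to Stan's iter (total iterations per chain, warmup included), n_burn to warmup, and no thinning; the name post_warmup_draws is hypothetical and not part of atsar.

```r
# Hypothetical helper (not part of atsar): post-warmup draws kept across all
# chains, assuming n_mcmc = total iterations per chain (warmup included),
# n_burn = warmup iterations, and no thinning.
post_warmup_draws <- function(n_mcmc, n_burn, n_chain) {
  stopifnot(n_mcmc > n_burn, n_chain >= 1)
  n_chain * (n_mcmc - n_burn)
}

post_warmup_draws(n_mcmc = 1000, n_burn = 500, n_chain = 3) # 1500
post_warmup_draws(n_mcmc = 500, n_burn = 100, n_chain = 3)  # 1200
```

Fewer kept draws generally means lower n_eff, which is worth keeping in mind when reading the n_eff column.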

Now let’s look at the R-hat values. Generally we want these to be < 1.1. Does it look like the model converges ok?

fit
## Inference for Stan model: ss_rw.
## 3 chains, each with iter=1000; warmup=500; thin=1; 
## post-warmup draws per chain=500, total post-warmup draws=1500.
## 
##                mean se_mean   sd 2.5%   25%   50%   75% 97.5% n_eff Rhat
## sigma_process  0.14    0.00 0.04 0.06  0.11  0.14  0.16  0.23   149 1.01
## pred[1]        8.02    0.01 0.13 7.76  7.93  8.01  8.09  8.31   632 1.00
## pred[2]        8.11    0.01 0.14 7.84  8.01  8.12  8.20  8.40   792 1.00
## pred[3]        8.21    0.00 0.10 8.02  8.14  8.21  8.28  8.42  1363 1.00
## pred[4]        8.31    0.00 0.13 8.05  8.23  8.31  8.39  8.56  1081 1.00
## pred[5]        8.41    0.00 0.10 8.21  8.35  8.41  8.47  8.60  1759 1.00
## pred[6]        8.53    0.00 0.14 8.24  8.45  8.53  8.62  8.82  1038 1.00
## pred[7]        8.63    0.00 0.14 8.34  8.54  8.64  8.71  8.94  1195 1.00
## pred[8]        8.73    0.00 0.11 8.52  8.66  8.73  8.80  8.95  2089 1.00
## pred[9]        8.86    0.01 0.16 8.54  8.75  8.85  8.96  9.19   969 1.00
## pred[10]       8.99    0.01 0.19 8.62  8.86  8.98  9.11  9.38   964 1.00
## pred[11]       9.11    0.01 0.21 8.74  8.98  9.11  9.24  9.52   957 1.01
## pred[12]       9.24    0.01 0.21 8.84  9.09  9.24  9.37  9.65   867 1.01
## pred[13]       9.36    0.01 0.20 8.99  9.23  9.35  9.49  9.78   585 1.01
## pred[14]       9.49    0.01 0.19 9.14  9.36  9.48  9.62  9.87   433 1.01
## pred[15]       9.62    0.01 0.16 9.27  9.52  9.64  9.73  9.87   174 1.01
## pred[16]       9.54    0.01 0.14 9.25  9.44  9.54  9.63  9.80   367 1.00
## pred[17]       9.45    0.00 0.09 9.27  9.39  9.45  9.51  9.64  1342 1.00
## pred[18]       9.43    0.00 0.09 9.26  9.37  9.44  9.49  9.60  1241 1.00
## pred[19]       9.37    0.00 0.09 9.21  9.31  9.37  9.43  9.56   728 1.00
## pred[20]       9.36    0.01 0.10 9.17  9.29  9.36  9.44  9.57   285 1.01
## pred[21]       9.58    0.01 0.10 9.37  9.52  9.58  9.65  9.77   370 1.00
## pred[22]       9.58    0.00 0.09 9.40  9.52  9.58  9.63  9.75  1203 1.00
## pred[23]       9.54    0.00 0.09 9.37  9.49  9.55  9.60  9.72  1374 1.00
## pred[24]       9.60    0.00 0.08 9.44  9.55  9.60  9.65  9.77  2034 1.00
## pred[25]       9.65    0.00 0.09 9.48  9.60  9.65  9.71  9.81  1988 1.00
## pred[26]       9.68    0.00 0.09 9.50  9.63  9.68  9.73  9.85  1232 1.00
## pred[27]       9.61    0.00 0.09 9.44  9.54  9.61  9.67  9.81   580 1.00
## pred[28]       9.70    0.00 0.09 9.51  9.64  9.70  9.76  9.88  1734 1.00
## pred[29]       9.75    0.00 0.15 9.45  9.66  9.75  9.84 10.05  1021 1.00
## pred[30]       9.79    0.00 0.17 9.45  9.68  9.79  9.90 10.14  1323 1.00
## pred[31]       9.84    0.01 0.18 9.48  9.72  9.85  9.96 10.19  1055 1.00
## pred[32]       9.89    0.00 0.15 9.62  9.79  9.89  9.99 10.21  1052 1.00
## pred[33]       9.94    0.00 0.10 9.73  9.88  9.95 10.01 10.13  1189 1.00
## pred[34]       9.93    0.00 0.09 9.75  9.87  9.93  9.98 10.10  1874 1.00
## pred[35]       9.93    0.00 0.09 9.76  9.88  9.93  9.98 10.10  1671 1.00
## pred[36]       9.89    0.00 0.09 9.71  9.83  9.89  9.95 10.09   803 1.00
## pred[37]      10.01    0.00 0.09 9.82  9.96 10.01 10.06 10.19  1523 1.00
## pred[38]      10.06    0.00 0.09 9.88 10.01 10.06 10.11 10.23  1247 1.00
## pred[39]      10.16    0.00 0.10 9.94 10.09 10.16 10.22 10.36  1455 1.00
## sigma_obs      0.13    0.00 0.04 0.06  0.10  0.12  0.15  0.22   121 1.01
## mu[1]          0.06    0.00 0.02 0.01  0.04  0.06  0.07  0.11   388 1.01
## lp__          16.07    0.86 8.24 1.07 10.19 15.85 21.54 32.93    91 1.02
## 
## Samples were drawn using NUTS(diag_e) at Tue May  2 14:43:06 2023.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at 
## convergence, Rhat=1).
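With more than 40 rows, it is easy to miss a single large R-hat by eye. A minimal sketch for pulling out the offending parameters, assuming summary(fit)$summary returns rstan's usual matrix with an "Rhat" column; the helper name flag_rhat is hypothetical.

```r
# Hypothetical helper: return the R-hat values at or above a threshold.
# `summ` is expected to be a matrix with an "Rhat" column and parameter
# names as row names, such as summary(fit)$summary from rstan.
flag_rhat <- function(summ, threshold = 1.1) {
  rhat <- summ[, "Rhat"]
  bad <- !is.na(rhat) & rhat >= threshold
  rhat[bad]
}
```

For example, flag_rhat(summary(fit)$summary) returns an empty vector when every R-hat is below 1.1. In the summary printed above, the largest value is 1.02 (for lp__).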

It can be helpful to examine the posterior with ShinyStan. Specifically, let's look at the posterior distributions and trace plots of the two variance parameters (sigma_process and sigma_obs, which the Stan code parameterizes as standard deviations). Ask:

library(shinystan)
shinystan::launch_shinystan(fit)
## 
## Launching ShinyStan interface... for large models this  may take some time.
## 
## Listening on http://127.0.0.1:6809

Let's look at the posterior of the trend (mu). Does it generally overlap 0? Do you think the drift term should be kept in the model?

pars <- as.data.frame(rstan::extract(fit))

ggplot(pars, aes(mu)) + 
  geom_histogram()
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
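The histogram answers the question visually; a numeric summary can help too. A sketch that reports the posterior mean, a central 95% credible interval, and the proportion of draws above 0. The helper name summarize_draws is hypothetical.

```r
# Hypothetical helper: mean, central credible interval, and the share of
# posterior draws above zero for a single parameter.
summarize_draws <- function(draws, prob = 0.95) {
  draws <- as.numeric(draws) # extract() returns mu as a one-column matrix
  alpha <- (1 - prob) / 2
  q <- stats::quantile(draws, probs = c(alpha, 1 - alpha), names = FALSE)
  c(mean = mean(draws), lower = q[1], upper = q[2],
    p_positive = mean(draws > 0))
}
```

For example, summarize_draws(pars$mu). A p_positive close to 1 (or 0) means nearly all of the posterior mass for the drift lies on one side of 0.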

Can we estimate density dependence?

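In log space, density dependence in a Gompertz model shows up as the autoregressive coefficient phi on the previous state:

\[ x_t = \phi x_{t-1} + \mu + d_t, \qquad d_t \sim N(0, \sigma) \]

where sigma is the standard deviation of the process errors. With phi = 1 this reduces to the random walk with drift fit above; phi < 1 indicates density dependence.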
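To build intuition for phi, the sketch below simulates the process equation above without observation error. With |phi| < 1 the log abundance fluctuates around a stationary mean mu / (1 - phi) instead of trending; with phi = 1 and no process noise it reduces to a straight line with slope mu. The function name and parameter values are illustrative assumptions, not estimates for the whale data.

```r
# Simulate x_t = phi * x_{t-1} + mu + d_t, d_t ~ N(0, sigma) (process only).
# x1 defaults to the stationary mean mu / (1 - phi), which is only defined
# when phi != 1; supply x1 explicitly for a random walk (phi = 1).
sim_gompertz <- function(n, phi, mu, sigma, x1 = mu / (1 - phi)) {
  x <- numeric(n)
  x[1] <- x1
  for (t in 2:n) {
    x[t] <- phi * x[t - 1] + mu + rnorm(1, mean = 0, sd = sigma)
  }
  x
}

set.seed(42)
x <- sim_gompertz(n = 5000, phi = 0.8, mu = 1.8, sigma = 0.1)
mean(x) # close to 1.8 / (1 - 0.8) = 9
```

Plotting x for a few values of phi (say 0.5, 0.9, 1) is a quick way to see why phi near 1 is hard to distinguish from a random walk in a short series.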

We can estimate the autoregressive term by changing our model to ss_ar:

set.seed(123)
fit2 <- fit_stan(y = log(whale$Count),
                 model_name = "ss_ar",
                 est_drift = TRUE,
                 mcmc_list = list(n_mcmc = 500,
                                  n_burn = 100,
                                  n_chain = 3))

Questions:

  • Did this model appear to converge? Why or why not?

  • Does the phi parameter appear well estimated?

  • Extract the states and plot them with the original data. For bonus points, include 95% credible intervals.
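For the last question, one approach is to summarize the posterior draws of the states one time step at a time. A sketch, assuming the states are stored in a parameter called pred (as in the ss_rw code printed at the end of this section; check that ss_ar uses the same name) and that rstan::extract() returns them as an iterations × time matrix. The helper state_intervals is hypothetical.

```r
# Hypothetical helper: per-time-step posterior median and central credible
# interval from a matrix of draws (rows = iterations, columns = time steps).
state_intervals <- function(draws, prob = 0.95) {
  stopifnot(is.matrix(draws), prob > 0, prob < 1)
  alpha <- (1 - prob) / 2
  data.frame(
    time   = seq_len(ncol(draws)),
    median = apply(draws, 2, stats::median),
    lower  = apply(draws, 2, stats::quantile, probs = alpha, names = FALSE),
    upper  = apply(draws, 2, stats::quantile, probs = 1 - alpha, names = FALSE)
  )
}
```

For example, states <- state_intervals(rstan::extract(fit2)$pred) followed by states$Year <- whale$Year gives a data frame that can be drawn with geom_ribbon(aes(ymin = lower, ymax = upper)) and geom_line(aes(y = median)), with log(whale$Count) overlaid using geom_point().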

Changing the family

In the model(s) above, and in all of our MARSS work, we have been modeling log-transformed data. The Stan code instead makes it easy to change the observation family. Since the original whale dataset consists of counts, let's try a Poisson family.

set.seed(123)
fit3 <- fit_stan(y = whale$Count, # raw counts: the Poisson family models integer counts with a log link
                 model_name = "ss_rw",
                 est_drift = TRUE,
                 mcmc_list = list(n_mcmc = 500,
                                  n_burn = 100,
                                  n_chain = 3),
                 family = "poisson")
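In the Stan code printed at the end of this section, the Poisson family (family == 3) uses y_int ~ poisson_log(pred), so each count is Poisson with mean exp(x_t) and the state x_t stays on the log scale. A small simulation of that observation model on top of a random walk with drift; the function name and parameter values are illustrative assumptions.

```r
# Simulate a random walk with drift on the log scale, observed as Poisson
# counts with a log link: y_t ~ Poisson(exp(x_t)).
sim_ss_poisson <- function(n, x0, mu, sigma_process) {
  x <- numeric(n)
  x[1] <- x0
  for (t in 2:n) {
    x[t] <- x[t - 1] + mu + rnorm(1, mean = 0, sd = sigma_process)
  }
  y <- rpois(n, lambda = exp(x))
  list(x = x, y = y)
}

set.seed(7)
sim <- sim_ss_poisson(n = 39, x0 = 8, mu = 0.06, sigma_process = 0.14)
cor(sim$x, log(sim$y)) # log counts track the latent state closely
```

Because the Poisson variance equals its mean, counts in the thousands carry little observation noise on the log scale; there is no separate sigma_obs in the family == 3 likelihood.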

Questions:

  • How do the estimates of the trend differ between this approach and the Gaussian model on log counts (fit)?

  • Calculate the correlation between the mean state estimates from this model and those from the model we fit initially (fit). The Poisson model uses a log link, so its predicted states are also in log space. Is this relationship expected?
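For the correlation question, a sketch that compares the posterior-mean states of two fits, assuming both store their states in pred and cover the same years. The helper name compare_state_means is hypothetical.

```r
# Hypothetical helper: correlation between the posterior-mean state
# trajectories of two models. Each input is an iterations x time matrix of
# draws, e.g. rstan::extract(fit)$pred; iteration counts may differ.
compare_state_means <- function(draws_a, draws_b) {
  stopifnot(ncol(draws_a) == ncol(draws_b))
  stats::cor(colMeans(draws_a), colMeans(draws_b))
}
```

For example, compare_state_means(rstan::extract(fit)$pred, rstan::extract(fit3)$pred).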

Bonus

  1. Fit MARSS models to any of the above, and plot the state estimates versus the estimates from the Stan models. How similar are they?
  2. Play with broom.mixed. This is often a more efficient way to extract things like predicted values or parameter estimates:
broom.mixed::tidy(fit, pars="pred")
## # A tibble: 39 × 3
##    term     estimate std.error
##    <chr>       <dbl>     <dbl>
##  1 pred[1]      8.01    0.130 
##  2 pred[2]      8.12    0.142 
##  3 pred[3]      8.21    0.100 
##  4 pred[4]      8.31    0.127 
##  5 pred[5]      8.41    0.0952
##  6 pred[6]      8.53    0.144 
##  7 pred[7]      8.64    0.144 
##  8 pred[8]      8.73    0.107 
##  9 pred[9]      8.85    0.164 
## 10 pred[10]     8.98    0.191 
## # ℹ 29 more rows
  3. Make some alternative plots of distributions and trace plots using bayesplot. For example:
bayesplot::mcmc_hist_by_chain(fit, pars = "sigma_obs")
## Warning: The `facets` argument of `facet_grid()` is deprecated as of ggplot2 2.2.0.
## ℹ Please use the `rows` argument instead.
## ℹ The deprecated feature was likely used in the bayesplot package.
##   Please report the issue at <https://github.com/stan-dev/bayesplot/issues/>.
## This warning is displayed once every 8 hours.
## Call `lifecycle::last_lifecycle_warnings()` to see where this warning was
## generated.
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

  4. If you want to look at the Stan source code for any of the above models, you can do something like this:
cat( get_stancode(fit) ) # cat makes this readable
## data {
##   int<lower=0> N;
##   int<lower=0> n_pos;
##   vector[n_pos] y;
##   int y_int[n_pos];  
##   int pos_indx[n_pos+2];
##   int<lower=0> est_drift;
##   int<lower=0> est_nu;    
##   int family; // 1 = normal, 2 = binomial, 3 = poisson, 4 = gamma, 5 = lognormal  
## }
## parameters {
##   real x0;
##   real mu[est_drift];
##   vector[N-1] pro_dev;
##   real<lower=0> sigma_process;
##   real<lower=0> sigma_obs;
##   real<lower=2> nu[est_nu];  
## }
## transformed parameters {
##   vector[N] pred;
##   real temp;
##   pred[1] = x0;
##   temp = 0;
##   if(est_drift==1) {
##     temp = mu[1];
##   }
##   for(i in 2:N) {
##     pred[i] = pred[i-1] + temp + sigma_process*pro_dev[i-1];
##   }
## }
## model {
##   x0 ~ normal(0,10);
##   if(est_drift==1) {
##     mu ~ normal(0,2);
##   }
##   if(est_nu==1) {
##     nu ~ student_t(3,2,2);
##   }       
##   sigma_process ~ student_t(3,0,2);
##   sigma_obs ~ student_t(3,0,2);
##   if(est_nu==0) {
##     pro_dev ~ std_normal();//normal(0, sigma_process);
##   } else {
##     pro_dev ~ student_t(nu,0,1);
##   }
## 
##   if(family==1) {
##     for(i in 1:(n_pos)) {
##       //y ~ normal(pred, sigma_obs);
##       y[i] ~ normal(pred[pos_indx[i]], sigma_obs);
##     }
##   }
##   if(family==2) {
##     for(i in 1:(n_pos)) {
##       y_int[i] ~ bernoulli_logit(pred[pos_indx[i]]);
##     }
##   }
##   if(family==3) {
##     for(i in 1:(n_pos)) {
##       y_int[i] ~ poisson_log(pred[pos_indx[i]]);
##     }
##   }
##   if(family==4) {
##     for(i in 1:(n_pos)) {
##       y[i] ~ gamma(sigma_obs, sigma_obs ./ exp(pred[pos_indx[i]]));
##     }
##   }
##   if(family==5) {
##     for(i in 1:(n_pos)) {
##       y[i] ~ lognormal(pred[pos_indx[i]], sigma_obs);
##     }
##   }
## }
## generated quantities {
##   vector[n_pos] log_lik;
##   // regresssion example in loo() package
##   if(family==1) for (n in 1:n_pos) log_lik[n] = normal_lpdf(y[n] | pred[pos_indx[n]], sigma_obs);
##   if(family==2) for (n in 1:n_pos) log_lik[n] = bernoulli_lpmf(y_int[n] | inv_logit(pred[pos_indx[n]]));
##   if(family==3) for (n in 1:n_pos) log_lik[n] = poisson_lpmf(y_int[n] | exp(pred[pos_indx[n]]));
##   if(family==4) for (n in 1:n_pos) log_lik[n] = gamma_lpdf(y[n] | sigma_obs, sigma_obs ./ exp(pred[pos_indx[n]]));
##   if(family==5) for (n in 1:n_pos) log_lik[n] = lognormal_lpdf(y[n] | pred[pos_indx[n]], sigma_obs);
## }
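For the first bonus item, a sketch of the MARSS side of the comparison. As far as we understand MARSS's defaults, a single time series with no model list gets a random walk with an estimated drift (U) and process variance (Q), observed with Gaussian error (R), which matches the structure of ss_rw with est_drift = TRUE. The variable names are illustrative.

```r
library(MARSS)

whale <- as.data.frame(MARSS::graywhales)
# MARSS expects time across columns, one row per time series;
# missing years stay as NA.
y <- matrix(log(whale$Count), nrow = 1)

# Default model: random walk with drift plus observation error.
marss_fit <- MARSS(y, silent = TRUE)

# Smoothed state estimates, one per year.
marss_states <- data.frame(
  Year = whale$Year,
  estimate = as.numeric(marss_fit$states)
)
head(marss_states)
```

The estimate column can then be plotted against, or correlated with, colMeans(rstan::extract(fit)$pred).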