Fitting a univariate model in Stan

The atsar package bundles a set of Stan files for fitting univariate state space models, for those who want to try these methods. It also offers options that we didn’t cover in class; we can explore those in more detail here.

As a first example, we’ll use the MARSS gray whale dataset.

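library(ggplot2)
# Assumed source: the graywhales matrix shipped with MARSS (columns Year, Count)
data(graywhales, package = "MARSS")
whale <- as.data.frame(graywhales)

ggplot(whale, aes(Year, Count)) +
  geom_point() + geom_smooth()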

An interesting question is whether we can fit a Gompertz population model and estimate the long-term trend. To do this, we’ll use the atsar package.


It is helpful to look at the documentation for fit_stan (?fit_stan) to see what kinds of univariate models are available.


For the whale example, let’s start with the ss_rw model, a univariate state space random walk. We will fit the model in log-space and estimate drift.
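
As a rough sketch of what ss_rw assumes, the base R simulation below generates a latent random walk with drift and then adds observation error. The argument names mirror the parameters atsar reports (mu, sigma_process, sigma_obs), but the helper simulate_ss_rw and the values passed to it are made up for illustration and are not part of atsar.

```r
# Simulate a state space random walk with drift (illustrative values only).
simulate_ss_rw <- function(n, x0, mu, sigma_process, sigma_obs) {
  # Process model: x[t] = x[t-1] + mu + process error
  steps <- mu + rnorm(n - 1, mean = 0, sd = sigma_process)
  x <- x0 + c(0, cumsum(steps))
  # Observation model: y[t] = x[t] + observation error
  y <- x + rnorm(n, mean = 0, sd = sigma_obs)
  data.frame(t = seq_len(n), state = x, obs = y)
}

set.seed(1)
sim <- simulate_ss_rw(n = 39, x0 = 8, mu = 0.06,
                      sigma_process = 0.14, sigma_obs = 0.13)
head(sim)
```

Fitting the model reverses this process: given only the obs column, Stan estimates the latent state along with mu and the two standard deviations.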


We will start with a few hundred samples (normally you want a burn-in period of 1000-2000 iterations):

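library(atsar)

# The settings must be passed as mcmc_list. An unnamed list is matched to the
# second argument, x, and the defaults (1000 iterations, 500 warmup) are used,
# which is what the summary printed below reflects.
fit <- fit_stan(y = log(whale$Count),
                model_name = "ss_rw",
                est_drift = TRUE,
                mcmc_list = list(n_mcmc = 500,
                                 n_burn = 100,
                                 n_chain = 3))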

Now let’s look at the R-hat values in the model summary. Generally we want these to be below 1.1. Does it look like the model converges OK?

## Inference for Stan model: ss_rw.
## 3 chains, each with iter=1000; warmup=500; thin=1; 
## post-warmup draws per chain=500, total post-warmup draws=1500.
##                mean se_mean   sd 2.5%   25%   50%   75% 97.5% n_eff Rhat
## sigma_process  0.14    0.00 0.04 0.06  0.11  0.14  0.16  0.23   149 1.01
## pred[1]        8.02    0.01 0.13 7.76  7.93  8.01  8.09  8.31   632 1.00
## pred[2]        8.11    0.01 0.14 7.84  8.01  8.12  8.20  8.40   792 1.00
## pred[3]        8.21    0.00 0.10 8.02  8.14  8.21  8.28  8.42  1363 1.00
## pred[4]        8.31    0.00 0.13 8.05  8.23  8.31  8.39  8.56  1081 1.00
## pred[5]        8.41    0.00 0.10 8.21  8.35  8.41  8.47  8.60  1759 1.00
## pred[6]        8.53    0.00 0.14 8.24  8.45  8.53  8.62  8.82  1038 1.00
## pred[7]        8.63    0.00 0.14 8.34  8.54  8.64  8.71  8.94  1195 1.00
## pred[8]        8.73    0.00 0.11 8.52  8.66  8.73  8.80  8.95  2089 1.00
## pred[9]        8.86    0.01 0.16 8.54  8.75  8.85  8.96  9.19   969 1.00
## pred[10]       8.99    0.01 0.19 8.62  8.86  8.98  9.11  9.38   964 1.00
## pred[11]       9.11    0.01 0.21 8.74  8.98  9.11  9.24  9.52   957 1.01
## pred[12]       9.24    0.01 0.21 8.84  9.09  9.24  9.37  9.65   867 1.01
## pred[13]       9.36    0.01 0.20 8.99  9.23  9.35  9.49  9.78   585 1.01
## pred[14]       9.49    0.01 0.19 9.14  9.36  9.48  9.62  9.87   433 1.01
## pred[15]       9.62    0.01 0.16 9.27  9.52  9.64  9.73  9.87   174 1.01
## pred[16]       9.54    0.01 0.14 9.25  9.44  9.54  9.63  9.80   367 1.00
## pred[17]       9.45    0.00 0.09 9.27  9.39  9.45  9.51  9.64  1342 1.00
## pred[18]       9.43    0.00 0.09 9.26  9.37  9.44  9.49  9.60  1241 1.00
## pred[19]       9.37    0.00 0.09 9.21  9.31  9.37  9.43  9.56   728 1.00
## pred[20]       9.36    0.01 0.10 9.17  9.29  9.36  9.44  9.57   285 1.01
## pred[21]       9.58    0.01 0.10 9.37  9.52  9.58  9.65  9.77   370 1.00
## pred[22]       9.58    0.00 0.09 9.40  9.52  9.58  9.63  9.75  1203 1.00
## pred[23]       9.54    0.00 0.09 9.37  9.49  9.55  9.60  9.72  1374 1.00
## pred[24]       9.60    0.00 0.08 9.44  9.55  9.60  9.65  9.77  2034 1.00
## pred[25]       9.65    0.00 0.09 9.48  9.60  9.65  9.71  9.81  1988 1.00
## pred[26]       9.68    0.00 0.09 9.50  9.63  9.68  9.73  9.85  1232 1.00
## pred[27]       9.61    0.00 0.09 9.44  9.54  9.61  9.67  9.81   580 1.00
## pred[28]       9.70    0.00 0.09 9.51  9.64  9.70  9.76  9.88  1734 1.00
## pred[29]       9.75    0.00 0.15 9.45  9.66  9.75  9.84 10.05  1021 1.00
## pred[30]       9.79    0.00 0.17 9.45  9.68  9.79  9.90 10.14  1323 1.00
## pred[31]       9.84    0.01 0.18 9.48  9.72  9.85  9.96 10.19  1055 1.00
## pred[32]       9.89    0.00 0.15 9.62  9.79  9.89  9.99 10.21  1052 1.00
## pred[33]       9.94    0.00 0.10 9.73  9.88  9.95 10.01 10.13  1189 1.00
## pred[34]       9.93    0.00 0.09 9.75  9.87  9.93  9.98 10.10  1874 1.00
## pred[35]       9.93    0.00 0.09 9.76  9.88  9.93  9.98 10.10  1671 1.00
## pred[36]       9.89    0.00 0.09 9.71  9.83  9.89  9.95 10.09   803 1.00
## pred[37]      10.01    0.00 0.09 9.82  9.96 10.01 10.06 10.19  1523 1.00
## pred[38]      10.06    0.00 0.09 9.88 10.01 10.06 10.11 10.23  1247 1.00
## pred[39]      10.16    0.00 0.10 9.94 10.09 10.16 10.22 10.36  1455 1.00
## sigma_obs      0.13    0.00 0.04 0.06  0.10  0.12  0.15  0.22   121 1.01
## mu[1]          0.06    0.00 0.02 0.01  0.04  0.06  0.07  0.11   388 1.01
## lp__          16.07    0.86 8.24 1.07 10.19 15.85 21.54 32.93    91 1.02
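
To make the Rhat column less of a black box, here is a simplified version of the statistic in base R. It compares the variance between chain means with the variance within chains; rstan reports a more refined split version, so this sketch (with the made-up helper name rhat_basic and simulated chains) only shows the idea and will not reproduce the printed values exactly.

```r
# Basic (non-split) Gelman-Rubin R-hat for one parameter.
# draws: matrix with one row per iteration and one column per chain.
rhat_basic <- function(draws) {
  n <- nrow(draws)
  B <- n * var(colMeans(draws))        # between-chain variance
  W <- mean(apply(draws, 2, var))      # average within-chain variance
  var_plus <- (n - 1) / n * W + B / n  # pooled estimate of the posterior variance
  sqrt(var_plus / W)
}

set.seed(42)
mixed <- matrix(rnorm(1500), ncol = 3)           # three chains, same target
stuck <- sweep(mixed, 2, c(0, 3, 6), FUN = "+")  # chains centered far apart
rhat_basic(mixed)   # close to 1
rhat_basic(stuck)   # well above 1.1
```

For the fitted model, the reported values can be read directly with summary(fit)$summary[, "Rhat"], assuming fit is a stanfit object.
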

It is worth looking at the posterior plots with ShinyStan. Specifically, let’s look at the posterior distributions and trace plots of the two variance parameters, sigma_process and sigma_obs. Ask:

Let’s look at the posterior of the trend (the drift term, mu): does it overlap 0? Do you think the drift term should be kept in the model or not?

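pars <- data.frame(mu = as.vector(rstan::extract(fit, pars = "mu")$mu))  # assumes fit is a stanfit

ggplot(pars, aes(mu)) +
  geom_histogram(bins = 30) +  # assumed layer: histogram of the posterior draws
  geom_vline(xintercept = 0, linetype = "dashed")
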
Can we estimate density dependence?

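Density dependence is commonly interpreted through the autoregression coefficient, \(\phi\):

\[ x_t = \phi x_{t-1} + \mu + d_t, \quad d_t \sim N(0, \sigma) \]

With \(\phi = 1\) this reduces to the random walk with drift fit above; values of \(\phi\) below 1 indicate density dependence.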
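
One way to see why the autoregression coefficient captures density dependence: when it is below 1 in absolute value, the process is pulled toward a long-run mean of mu / (1 - phi), whereas with phi = 1 it drifts without bound. The sketch below iterates the deterministic part of the equation. The helper iterate_ar1 and the values phi = 0.8 and mu = 2 are made up for illustration; they are not estimates from the whale data.

```r
# Iterate x[t] = phi * x[t-1] + mu with no process error (illustrative values).
iterate_ar1 <- function(n, x0, phi, mu) {
  x <- numeric(n)
  x[1] <- x0
  for (t in 2:n) {
    x[t] <- phi * x[t - 1] + mu
  }
  x
}

phi <- 0.8
mu <- 2
x <- iterate_ar1(n = 200, x0 = 0, phi = phi, mu = mu)
equilibrium <- mu / (1 - phi)   # long-run mean implied by phi and mu
c(last_value = x[200], equilibrium = equilibrium)
```

On the log scale used here, that long-run mean plays the role of a log carrying capacity.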

We can estimate the autoregression term by changing the model to ss_ar:

fit2 <- fit_stan(y = log(whale$Count),
                 model_name = "ss_ar",
                 est_drift = TRUE,
                 mcmc_list = list(n_mcmc = 500,
                                  n_burn = 100,
                                  n_chain = 3))


  • Did this model appear to converge? Why or why not?

  • Does the phi parameter appear well estimated?

  • Extract the states and plot them with the original data. For bonus points, include 95% credible intervals.
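
For the last item, one possible approach is to summarise the posterior draws of the states column by column. The helper below (summarise_states, a made-up name) works on any draws-by-time matrix. It is exercised here on simulated draws, and the commented line shows how such a matrix might be pulled from the fit, assuming fit2 is a stanfit object with a pred parameter like the one in the summary printed earlier.

```r
# Posterior mean and 95% credible interval for each time step.
# draws: matrix with one row per posterior draw and one column per time step.
summarise_states <- function(draws) {
  data.frame(
    t     = seq_len(ncol(draws)),
    mean  = colMeans(draws),
    lower = apply(draws, 2, quantile, probs = 0.025),
    upper = apply(draws, 2, quantile, probs = 0.975)
  )
}

# With a fitted model, the matrix could come from (assumption: stanfit object):
# draws <- rstan::extract(fit2, pars = "pred")$pred

# Stand-in draws so the sketch runs on its own: 1000 draws for 5 time steps
set.seed(7)
draws <- sapply(c(8.0, 8.1, 8.2, 8.3, 8.4), function(m) rnorm(1000, m, 0.1))
states <- summarise_states(draws)
states
```

The resulting data frame can be lined up with whale row by row and drawn with geom_ribbon(aes(ymin = lower, ymax = upper)) underneath the observed log counts.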

Changing the family

In the models above, and throughout MARSS, we have been modeling log-transformed data. The Stan code lets us change the family instead. Since the original whale dataset consists of counts, let’s switch to a Poisson family:

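set.seed(123)
# Assumed completion: Poisson family, fit to the raw counts rather than log(Count)
fit3 <- fit_stan(y = whale$Count,
                 model_name = "ss_rw",
                 est_drift = TRUE,
                 mcmc_list = list(n_mcmc = 500,
                                  n_burn = 100,
                                  n_chain = 3),
                 family = "poisson")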


  • How do the estimates of the trend differ between this approach and the initial model fit to the log counts (fit)?

  • Calculate the correlation between the mean state estimates in this model and those in the model we used initially (fit). The Poisson model uses a log link, so the predicted states are in log-space. Is this relationship expected?
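
The correlation in the last item can be computed from the posterior means of the two sets of states. Below is a minimal sketch with stand-in matrices and a made-up helper name (state_correlation). With the real fits, the matrices would come from rstan::extract(fit, pars = "pred")$pred and the same call on fit3, assuming both are stanfit objects.

```r
# Correlation between the posterior mean states of two models.
# Each argument is a draws-by-time matrix with the same number of columns.
state_correlation <- function(draws_a, draws_b) {
  stopifnot(ncol(draws_a) == ncol(draws_b))
  cor(colMeans(draws_a), colMeans(draws_b))
}

# Stand-in draws: two models tracking the same rising trend on the log scale
set.seed(11)
trend <- seq(8, 10, length.out = 39)
draws_gaussian <- sapply(trend, function(m) rnorm(500, m, 0.10))
draws_poisson  <- sapply(trend, function(m) rnorm(500, m, 0.05))
state_correlation(draws_gaussian, draws_poisson)
```

Because both sets of states live on the log scale, no transformation is needed before comparing them.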


  1. Fit MARSS models to any of the above, and plot the state estimates versus the estimates from the Stan models. How similar are they?
  2. Play with broom.mixed, which is often a more efficient way to extract things like predicted values or parameter estimates:
broom.mixed::tidy(fit, pars = "pred")
## # A tibble: 39 × 3
##    term     estimate std.error
##    <chr>       <dbl>     <dbl>
##  1 pred[1]      8.01    0.130 
##  2 pred[2]      8.12    0.142 
##  3 pred[3]      8.21    0.100 
##  4 pred[4]      8.31    0.127 
##  5 pred[5]      8.41    0.0952
##  6 pred[6]      8.53    0.144 
##  7 pred[7]      8.64    0.144 
##  8 pred[8]      8.73    0.107 
##  9 pred[9]      8.85    0.164 
## 10 pred[10]     8.98    0.191 
## # ℹ 29 more rows
  3. Make some alternative plots of distributions and trace plots using bayesplot. For example:
bayesplot::mcmc_hist_by_chain(fit, pars = "sigma_obs")
  4. If you want to look at the Stan source code for any of the above models, you can do something like this:
cat(rstan::get_stancode(fit)) # cat makes this readable
## data {
##   int<lower=0> N;
##   int<lower=0> n_pos;
##   vector[n_pos] y;
##   int y_int[n_pos];  
##   int pos_indx[n_pos+2];
##   int<lower=0> est_drift;
##   int<lower=0> est_nu;    
##   int family; // 1 = normal, 2 = binomial, 3 = poisson, 4 = gamma, 5 = lognormal  
## }
## parameters {
##   real x0;
##   real mu[est_drift];
##   vector[N-1] pro_dev;
##   real<lower=0> sigma_process;
##   real<lower=0> sigma_obs;
##   real<lower=2> nu[est_nu];  
## }
## transformed parameters {
##   vector[N] pred;
##   real temp;
##   pred[1] = x0;
##   temp = 0;
##   if(est_drift==1) {
##     temp = mu[1];
##   }
##   for(i in 2:N) {
##     pred[i] = pred[i-1] + temp + sigma_process*pro_dev[i-1];
##   }
## }
## model {
##   x0 ~ normal(0,10);
##   if(est_drift==1) {
##     mu ~ normal(0,2);
##   }
##   if(est_nu==1) {
##     nu ~ student_t(3,2,2);
##   }       
##   sigma_process ~ student_t(3,0,2);
##   sigma_obs ~ student_t(3,0,2);
##   if(est_nu==0) {
##     pro_dev ~ std_normal();//normal(0, sigma_process);
##   } else {
##     pro_dev ~ student_t(nu,0,1);
##   }
##   if(family==1) {
##     for(i in 1:(n_pos)) {
##       //y ~ normal(pred, sigma_obs);
##       y[i] ~ normal(pred[pos_indx[i]], sigma_obs);
##     }
##   }
##   if(family==2) {
##     for(i in 1:(n_pos)) {
##       y_int[i] ~ bernoulli_logit(pred[pos_indx[i]]);
##     }
##   }
##   if(family==3) {
##     for(i in 1:(n_pos)) {
##       y_int[i] ~ poisson_log(pred[pos_indx[i]]);
##     }
##   }
##   if(family==4) {
##     for(i in 1:(n_pos)) {
##       y[i] ~ gamma(sigma_obs, sigma_obs ./ exp(pred[pos_indx[i]]));
##     }
##   }
##   if(family==5) {
##     for(i in 1:(n_pos)) {
##       y[i] ~ lognormal(pred[pos_indx[i]], sigma_obs);
##     }
##   }
## }
## generated quantities {
##   vector[n_pos] log_lik;
##   // regresssion example in loo() package
##   if(family==1) for (n in 1:n_pos) log_lik[n] = normal_lpdf(y[n] | pred[pos_indx[n]], sigma_obs);
##   if(family==2) for (n in 1:n_pos) log_lik[n] = bernoulli_lpmf(y_int[n] | inv_logit(pred[pos_indx[n]]));
##   if(family==3) for (n in 1:n_pos) log_lik[n] = poisson_lpmf(y_int[n] | exp(pred[pos_indx[n]]));
##   if(family==4) for (n in 1:n_pos) log_lik[n] = gamma_lpdf(y[n] | sigma_obs, sigma_obs ./ exp(pred[pos_indx[n]]));
##   if(family==5) for (n in 1:n_pos) log_lik[n] = lognormal_lpdf(y[n] | pred[pos_indx[n]], sigma_obs);
## }
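
In the gamma branch of the Stan code (family==4), sigma_obs is used as the shape and sigma_obs / exp(pred) as the rate. The mean is therefore shape / rate = exp(pred), and the coefficient of variation is 1 / sqrt(shape). A quick numerical check in R, using arbitrary values rather than anything estimated from the whale data:

```r
# Check the mean and CV implied by gamma(shape, shape / exp(pred)).
shape <- 4        # plays the role of sigma_obs in the Stan code
pred  <- 9.5      # a state on the log scale (arbitrary value)
rate  <- shape / exp(pred)

set.seed(3)
y <- rgamma(1e5, shape = shape, rate = rate)

c(sample_mean = mean(y), expected_mean = exp(pred))
c(sample_cv = sd(y) / mean(y), expected_cv = 1 / sqrt(shape))
```

So with the gamma family, sigma_obs is no longer a standard deviation: larger values mean less observation noise, the opposite of the Gaussian case.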