Introdução ao Stan

Pedro Tótolo

Introdução

O que é o Stan?

O Stan (mc-stan.org) é uma plataforma para modelagem e computação estatística de alto desempenho.

O Stan faz interface com as linguagens de análise de dados mais populares (Python, R, etc) e é executado em todas as principais plataformas (Linux, Windows, Mac).

Pacotes adicionais fornecem interfaces de alto nível, ferramentas para análise da posteriori e validação cruzada leave-one-out.

Stanisław Ulam, conhecido pelos métodos de Monte Carlo. (Laboratório Nacional de Los Alamos, 1945)

Como funciona?

Os usuários especificam funções de densidade na linguagem de programação probabilística do Stan e, em seguida, ajustam os modelos aos dados usando:

  • inferência estatística bayesiana completa com amostragem MCMC: NUTS-HMC
  • inferência bayesiana aproximada com métodos variacionais: Pathfinder e ADVI
  • estimativa de máxima verossimilhança penalizada com otimização: Newton e quasi-Newton — (L-)BFGS

Quando usar o Stan?

Resposta: modelos hierárquicos, amostras pequenas.

A inferência clássica geralmente depende das propriedades assintóticas do EMV ou de métodos de reamostragem para estimação de incerteza das estimativas.

Com amostras pequenas, o bootstrap também pode falhar.

Quando usar o Stan?

Pensando em modelos hierárquicos, a inferência clássica enfrenta algumas dificuldades:

  • computacionais: problemas de otimização.

  • conceituais: modelos com variáveis latentes são complicados (verossimilhança restrita? H-verossimihança? quadratura?)

  • tamanho amostral: quando o número de grupos é pequeno e desequilibrado, inferência com base na verossimilhança subestima a variância dos estimadores. 1

Stan

Usando o pacote RStan

library(rstan)

modelo = stan_model(file = 'meu_modelo.stan')
            
fit = sampling(modelo, data = meus_dados)

fit

O arquivo .stan especifica a densidade a posteriori através da linguagem de modelagem do Stan. Com a função stan_model, ele é traduzido para C++ e compilado.

A função sampling amostra da densidade a posteriori usando MCMC, de acordo com os dados passados no argumento data.

A linguagem de modelagem do Stan

O arquivo .stan é organizado em blocos:

functions {
  // define funcoes para serem usadas em outros blocos
}
data {
  // declara os dados que precisamos para calcular a verossimilhança
}
transformed data {
  // define transformações dos dados
}
parameters {
  // declara os parâmetros que vão ser estimados
}
transformed parameters {
  // define transformações dos parâmetros
}
model {
  // especificação do modelo
}
generated quantities {
  // define quantidades que queremos na saída
}

Exemplo 1: distribuição normal

\[ y_1,\dots,y_n|\mu,\sigma \overset{c.i.}\sim N(\mu,\sigma^2)\\ \mu \sim N(1,4)\\ \sigma \sim \text{Gamma}(1,2) \]

data {
  int n;
  real y[n];
}
parameters {
  real mu;
  real<lower=0> sigma;
}
model {
  for (i in 1:n) {
    y[i] ~ normal(mu,sigma);
  }
  mu ~ normal(1,2);
  sigma ~ gamma(1,2);
}

Exemplo 1: distribuição normal

# dados simulados
y = rnorm(6, -0.5, 0.6)

modelo = stan_model(file = 'exemplo1.stan');

meus_dados = list(
  n = length(y),
  y = y
)

meus_dados
$n
[1] 6

$y
[1] -0.3084231 -1.7662650 -1.4452394 -0.3545800 -1.1809908 -1.0858136
fit = sampling(modelo, data = meus_dados, refresh=0)

fit
Inference for Stan model: exemplo1.
4 chains, each with iter=2000; warmup=1000; thin=1; 
post-warmup draws per chain=1000, total post-warmup draws=4000.

       mean se_mean   sd  2.5%   25%   50%   75% 97.5% n_eff Rhat
mu    -0.98    0.01 0.28 -1.52 -1.15 -0.99 -0.81 -0.41  1676    1
sigma  0.65    0.01 0.21  0.37  0.50  0.61  0.76  1.15  1392    1
lp__  -2.55    0.03 1.05 -5.29 -2.97 -2.25 -1.78 -1.48  1113    1

Samples were drawn using NUTS(diag_e) at Wed Oct 23 23:24:18 2024.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at 
convergence, Rhat=1).

Parênteses: como é feita essa amostragem?

O algoritmo de amostragem padrão do Stan é uma variação do Monte Carlo Hamiltoniano (HMC).

Nesse tipo de cadeia, a transição se dá por uma simulação análoga a um sistema físico: podemos imaginar a densidade a posteriori como uma superfície, e as amostras são obtidas chutando uma bola em direções aleatórias e vendo onde ela chega em um determinado tempo predefinido.

https://chi-feng.github.io/mcmc-demo/app.html

https://arxiv.org/pdf/1701.02434

Parênteses: como é feita essa amostragem?

O HMC é eficiente porque usa mais informação da densidade a posteriori para informar as transições: o gradiente.

No entanto, isso introduz uma limitação: ele não consegue amostrar parâmetros discretos.

Ainda assim, muitas vezes é possível integrar os parâmetros discretos, e o modelo resultante (marginal) permite inferência eficiente.

Vale lembrar que o método não é infalível: quando a geometria da densidade a posteriori varia radicalmente, teremos problemas.

Exemplo 2: modelo hierárquico normal

Experimentos paralelos foram aplicados em 8 escolas para avaliar o efeito de um treinamento especial. Podemos modelar os efeitos \(y_j\), supondo cada desvio padrão \(\sigma_j\) como conhecido, da seguinte forma:

\[ y_j|\theta_j \sim N(\theta_j,\sigma_j)\\ \theta_j|\mu \sim N(\mu,\tau) \\ p(\mu,\tau) \propto 1 \]

Exemplo 2: modelo hierárquico normal

data {
  int<lower=0> J;         // número de escolas
  real y[J];              // efeito estimado de treinamento
  real<lower=0> sigma[J]; // desvio padrão das estimativas 
}
parameters {
  real mu;                // efeito médio na população
  real<lower=0> tau;      // desvio padrão dos efeitos de treinamento
  vector[J] theta;          // efeito médio em cada escola
}
model {
  // verossimilhança
  target += normal_lpdf(theta | mu, tau);  
  target += normal_lpdf(y | theta, sigma);
}

modelo = stan_model(file = 'exemplo2.stan');

fit = sampling(modelo, data = schools_dat, refresh=0)

fit
Inference for Stan model: exemplo2.
4 chains, each with iter=2000; warmup=1000; thin=1; 
post-warmup draws per chain=1000, total post-warmup draws=4000.

           mean se_mean   sd   2.5%    25%    50%    75%  97.5% n_eff Rhat
mu         7.92    0.16 5.07  -2.01   4.86   7.82  11.18  17.73  1021 1.00
tau        7.15    0.24 5.39   1.39   3.18   5.73   9.61  20.76   507 1.00
theta[1]  11.76    0.24 8.43  -1.80   6.18  10.59  16.10  32.34  1233 1.00
theta[2]   7.86    0.17 6.23  -5.17   4.06   7.87  11.88  20.46  1405 1.00
theta[3]   5.81    0.23 8.04 -12.00   1.58   6.48  10.82  20.35  1212 1.00
theta[4]   7.52    0.15 6.68  -6.37   3.68   7.44  11.67  20.70  1866 1.00
theta[5]   4.93    0.20 6.47  -8.90   0.93   5.34   9.24  16.60  1051 1.00
theta[6]   5.91    0.20 6.99  -9.07   1.83   6.30  10.36  18.54  1214 1.00
theta[7]  10.96    0.18 6.90  -1.13   6.35  10.30  15.01  26.40  1537 1.00
theta[8]   8.56    0.19 7.80  -6.76   4.02   8.09  12.95  25.35  1721 1.00
lp__     -52.90    0.30 5.10 -62.27 -56.56 -53.11 -49.12 -42.87   295 1.01

Samples were drawn using NUTS(diag_e) at Wed Oct 23 23:24:19 2024.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at 
convergence, Rhat=1).
util$check_all_diagnostics(fit)
[1] "n_eff / iter looks reasonable for all parameters"
[1] "Rhat looks reasonable for all parameters"
[1] "64 of 4000 iterations ended with a divergence (1.6%)"
[1] "  Try running with larger adapt_delta to remove the divergences"
[1] "0 of 4000 iterations saturated the maximum tree depth of 10 (0%)"
[1] "E-BFMI indicated no pathological behavior"

Para \(\tau\) pequeno, transições divergem.

Se ajustarmos só 3 escolas, isso fica ainda mais acentuado.

schools_dat <- list(J = 3, 
                    y = c(28,  8, -3),
                    sigma = c(15, 10, 16))

fit = sampling(modelo, data = schools_dat, refresh=0)

fit
Inference for Stan model: exemplo2.
4 chains, each with iter=2000; warmup=1000; thin=1; 
post-warmup draws per chain=1000, total post-warmup draws=4000.

           mean se_mean     sd   2.5%    25%    50%    75%  97.5% n_eff Rhat
mu        18.30    6.05  80.78 -82.95   1.50  15.92  26.43 124.22   178 1.01
tau       85.37   22.63 494.68   3.69  10.66  23.89  53.66 416.51   478 1.01
theta[1]  23.90    0.61  13.53  -3.97  15.50  23.06  32.20  52.63   491 1.01
theta[2]  10.70    2.28  10.54 -10.11   3.24  10.89  18.85  28.21    21 1.21
theta[3]   4.45    2.42  15.17 -26.69  -6.11   5.63  17.56  27.22    39 1.14
lp__     -22.27    0.32   2.51 -28.05 -23.74 -21.92 -20.49 -17.98    62 1.09

Samples were drawn using NUTS(diag_e) at Wed Oct 23 23:24:20 2024.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at 
convergence, Rhat=1).

[1] "n_eff / iter looks reasonable for all parameters"
[1] "Rhat for parameter theta[2] is 1.21306766078134!"
[1] "Rhat for parameter theta[3] is 1.13668825354372!"
[1] "  Rhat above 1.1 indicates that the chains very likely have not mixed"
[1] "675 of 4000 iterations ended with a divergence (16.875%)"
[1] "  Try running with larger adapt_delta to remove the divergences"
[1] "0 of 4000 iterations saturated the maximum tree depth of 10 (0%)"
[1] "E-BFMI indicated no pathological behavior"

A solução é reparametrizar: amostramos de uma distribuição \(N(0,1)\) e transformamos em uma \(N(\mu,\sigma^2)\).

data {
  int<lower=0> J;         // número de escolas 
  real y[J];              // efeito estimado do treinamento
  real<lower=0> sigma[J]; // desvio padrão das estimativas 
}
parameters {
  real mu;                // efeito médio na população
  real<lower=0> tau;      // desvio padrão dos efeitos de treinamento
  vector[J] theta_raw;    // efeito médio em cada escola (sem escala)
}
transformed parameters {
  vector[J] theta = mu + tau * theta_raw;        // efeito médio em cada escola
}
model {
  // verossimilhança
  target += normal_lpdf(theta_raw | 0, 1);
  target += normal_lpdf(y | theta, sigma);
}

https://mc-stan.org/docs/stan-users-guide/reparameterization.html

[1] "n_eff / iter looks reasonable for all parameters"
[1] "Rhat looks reasonable for all parameters"
[1] "76 of 4000 iterations ended with a divergence (1.9%)"
[1] "  Try running with larger adapt_delta to remove the divergences"
[1] "0 of 4000 iterations saturated the maximum tree depth of 10 (0%)"
[1] "E-BFMI indicated no pathological behavior"

Aumentando o adapt_delta, como sugerido, de 0.8 para 0.99:

fit = sampling(exemplo2ncp, data = schools_dat, refresh=0,
               control = list(adapt_delta=0.99))
[1] "n_eff / iter looks reasonable for all parameters"
[1] "Rhat looks reasonable for all parameters"
[1] "2 of 4000 iterations ended with a divergence (0.05%)"
[1] "  Try running with larger adapt_delta to remove the divergences"
[1] "0 of 4000 iterations saturated the maximum tree depth of 10 (0%)"
[1] "E-BFMI indicated no pathological behavior"

Para as 8 escolas, a reparametrização já é suficiente:

schools_dat <- list(J = 8, 
                    y = c(28,  8, -3,  7, -1,  1, 18, 12),
                    sigma = c(15, 10, 16, 11,  9, 11, 10, 18))

fit = sampling(exemplo2ncp, data = schools_dat, refresh=0)
[1] "n_eff / iter looks reasonable for all parameters"
[1] "Rhat looks reasonable for all parameters"
[1] "2 of 4000 iterations ended with a divergence (0.05%)"
[1] "  Try running with larger adapt_delta to remove the divergences"
[1] "0 of 4000 iterations saturated the maximum tree depth of 10 (0%)"
[1] "E-BFMI indicated no pathological behavior"

Inference for Stan model: exemplo2_ncp.
4 chains, each with iter=2000; warmup=1000; thin=1; 
post-warmup draws per chain=1000, total post-warmup draws=4000.

               mean se_mean   sd   2.5%    25%    50%    75%  97.5% n_eff Rhat
mu             8.02    0.14 5.51  -2.72   4.45   7.92  11.28  19.70  1543 1.00
tau            6.73    0.15 5.64   0.29   2.57   5.33   9.32  21.25  1510 1.00
theta_raw[1]   0.38    0.02 0.91  -1.44  -0.21   0.40   1.00   2.13  3421 1.00
theta_raw[2]   0.01    0.01 0.86  -1.68  -0.56   0.01   0.59   1.71  3662 1.00
theta_raw[3]  -0.20    0.02 0.93  -2.03  -0.80  -0.21   0.44   1.60  3507 1.00
theta_raw[4]  -0.04    0.02 0.92  -1.80  -0.65  -0.05   0.56   1.85  3061 1.00
theta_raw[5]  -0.36    0.02 0.90  -2.07  -0.96  -0.38   0.23   1.45  3154 1.00
theta_raw[6]  -0.22    0.02 0.90  -2.05  -0.80  -0.21   0.37   1.59  3301 1.00
theta_raw[7]   0.34    0.02 0.91  -1.43  -0.24   0.33   0.95   2.12  3384 1.00
theta_raw[8]   0.06    0.01 0.94  -1.82  -0.55   0.05   0.68   1.90  4166 1.00
theta[1]      11.60    0.17 8.60  -2.26   5.97  10.39  15.69  32.69  2500 1.00
theta[2]       7.99    0.11 6.41  -4.82   4.04   7.92  11.95  21.23  3699 1.00
theta[3]       6.16    0.14 7.74 -10.62   1.94   6.58  10.96  20.58  3013 1.00
theta[4]       7.67    0.10 6.70  -5.85   3.81   7.66  11.61  21.27  4314 1.00
theta[5]       5.07    0.11 6.46  -9.05   1.32   5.48   9.31  16.78  3764 1.00
theta[6]       6.05    0.12 6.79  -9.02   2.29   6.34  10.48  18.49  3285 1.00
theta[7]      10.68    0.13 7.01  -1.41   5.98  10.05  14.59  26.47  3111 1.00
theta[8]       8.50    0.15 8.07  -7.26   3.77   8.11  12.71  26.42  2953 1.00
lp__         -39.60    0.08 2.68 -45.74 -41.19 -39.39 -37.75 -34.99  1211 1.01

Samples were drawn using NUTS(diag_e) at Wed Oct 23 23:24:23 2024.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at 
convergence, Rhat=1).

https://mc-stan.org/users/documentation/case-studies/divergences_and_bias.html

Pacotes adicionais: brms

O brms (bayesian regression models with stan) oferece uma interface de alto nível para modelagem de regressão. A sintaxe é muito semelhante à do pacote lme4.

library(brms)
schools_dat$J = NULL
schools_dat$escola = 1:8

schools_dat
$y
[1] 28  8 -3  7 -1  1 18 12

$sigma
[1] 15 10 16 11  9 11 10 18

$escola
[1] 1 2 3 4 5 6 7 8
fit2 = brm(y | resp_se(sigma) ~ 1 + (1|escola), data = schools_dat, refresh=0,
            prior = c(
              set_prior("", class = "sd"),
              set_prior("", class = 'Intercept')
            )
           )
fit2
 Family: gaussian 
  Links: mu = identity; sigma = identity 
Formula: y | resp_se(sigma) ~ 1 + (1 | escola) 
   Data: schools_dat (Number of observations: 8) 
  Draws: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
         total post-warmup draws = 4000

Group-Level Effects: 
~escola (Number of levels: 8) 
              Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sd(Intercept)     6.43      5.26     0.24    19.48 1.00     1366     1828

Population-Level Effects: 
          Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
Intercept     7.80      4.94    -2.22    17.46 1.00     2367     1539

Family Specific Parameters: 
      Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sigma     0.00      0.00     0.00     0.00   NA       NA       NA

Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).

Exemplo 4: regressão linear

Temperatura anual média do ar sobre o solo do planeta Terra.

data_temperature <- aida::data_WorldTemp

Exemplo 4: regressão linear

fit_temperature <- brm(formula = avg_temp ~ s(year),
  data = data_temperature, refresh=0
)

Exemplo 4: regressão linear com spline

fit_temperature <- brm(formula = avg_temp ~ s(year),
  data = data_temperature, refresh=0
)

Exemplo 5: análise de sobrevivência

Tempo de recorrência de infecção em pacientes com doença renal.

head(kidney, n = 3)
  time censored patient recur age    sex disease
1    8        0       1     1  28   male   other
2   23        0       2     1  48 female      GN
3   22        0       3     1  32   male   other
plot(survfit(Surv(time, censored) ~ 1, data = kidney), xlab='dias', main = 'Kaplan-Meier')

Exemplo 5: análise de sobrevivência

fit1 <- brm(
  formula = time | cens(censored) ~ age * sex + disease + (1 + age|patient),
  data = kidney, family = lognormal(),
  prior = c(set_prior("normal(0,5)", class = "b"),
            set_prior("cauchy(0,2)", class = "sd"),
            set_prior("lkj(2)", class = "cor")), warmup = 1000,
  iter = 2000, chains = 4, control = list(adapt_delta = 0.95),
  refresh=0)
fit1
 Family: lognormal 
  Links: mu = identity; sigma = identity 
Formula: time | cens(censored) ~ age * sex + disease + (1 + age | patient) 
   Data: kidney (Number of observations: 76) 
  Draws: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
         total post-warmup draws = 4000

Group-Level Effects: 
~patient (Number of levels: 38) 
                   Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sd(Intercept)          0.41      0.29     0.02     1.08 1.00     1026     1569
sd(age)                0.01      0.01     0.00     0.02 1.00      950     1553
cor(Intercept,age)    -0.15      0.46    -0.89     0.77 1.00     2405     2478

Population-Level Effects: 
              Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
Intercept         2.73      0.95     0.88     4.61 1.00     1817     2490
age               0.01      0.02    -0.03     0.06 1.00     1562     2149
sexfemale         2.45      1.12     0.30     4.67 1.00     1697     2180
diseaseGN        -0.41      0.51    -1.44     0.59 1.00     2597     2964
diseaseAN        -0.52      0.52    -1.50     0.51 1.00     2444     2681
diseasePKD        0.59      0.74    -0.87     2.04 1.00     2516     2860
age:sexfemale    -0.02      0.03    -0.07     0.03 1.00     1615     2138

Family Specific Parameters: 
      Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sigma     1.14      0.13     0.92     1.42 1.00     2364     2461

Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).

Pacote bayesplot

O bayesplot tem várias utilidades para diagnóstico do modelo.

Prophet

O Prophet é um pacote de previsão para séries temporais com base em um modelo aditivo em que as tendências não lineares são ajustadas com sazonalidade anual, semanal e diária, além de efeitos de feriados. Ele funciona melhor com séries temporais que têm fortes efeitos sazonais e várias temporadas de dados históricos. O Prophet é robusto em relação a dados ausentes e mudanças na tendência e, normalmente, lida bem com outliers.

O Prophet é usado em muitas aplicações no Facebook para produzir previsões confiáveis para planejamento e definição de metas. Descobrimos que ele tem um desempenho melhor do que qualquer outra abordagem na maioria dos casos. Ajustamos modelos no Stan para que você obtenha previsões em apenas alguns segundos.

Até janeiro de 2023, o pacote Python foi baixado mais de 16 milhões de vezes via PyPI e continua a receber 1 milhão de downloads por mês.

https://facebook.github.io/prophet/

Outros recursos: mais exemplo, aplicações, etc

https://mc-stan.org/users/documentation/case-studies