data {

  // total number of observations
  int<lower=1> N;
  // number of patients
  int<lower=1> n_pat;
  // start / stop index (into t, y) of each patient's observations; the model
  // slices t[start:end] and y[start:end], so each patient's rows must be
  // contiguous and pat_start_ind[i] <= pat_end_ind[i]
  int<lower=1, upper=N> pat_start_ind[n_pat];
  int<lower=1, upper=N> pat_end_ind[n_pat];
  // actual data:
  real t[N]; // sampling times
  real y[N]; // outcomes
  int id[N]; // id as integers (NOTE(review): not referenced by the model blocks)
  int<lower=1> n_knots; // number of knots to model population mean
  real<lower=0.0> knots[n_knots]; // knot positions
  // number of output points; used as an array/matrix size, so must be >= 0
  int<lower=0> n_t_out;
  // output times, passed as a single-column 2-D array (flattened in transformed data)
  real t_out_[n_t_out, 1];
  // number of covariates (columns of modelmatrix); used as a size, so must be >= 0
  int<lower=0> p_cov;
  // covariate model matrix, one row per patient
  matrix[n_pat, p_cov] modelmatrix;

}


transformed data {

  // 1-D view of the output times: take the single column of t_out_
  // (the omnibus index ':' selects every row).
  real t_out[n_t_out] = t_out_[:, 1];

}


parameters {

  // marginal standard deviation (alpha) of the squared-exponential kernel used
  // to interpolate the population mean through the knots.
  // NOTE(review): s_pop_mean only enters as cov_exp_quad(., knots, s_pop_mean, l)
  // times the inverse of cov_exp_quad(knots, s_pop_mean, l); the alpha^2 factor
  // cancels in that product (up to any diagonal jitter), so the likelihood is
  // essentially flat in s_pop_mean and its posterior follows its prior — confirm
  // this is intended.
  real<lower=0> s_pop_mean;
  // lengthscale of the population mean process (disabled: the population mean
  // interpolation currently uses the shared lengthscale l)
  // real<lower=0> l_pop_mean;
  // population mean values at the knot positions; no explicit prior is given in
  // the model block, so the bounds imply a flat prior on [.5, 8.5]
  real<lower=.5, upper=8.5> knots_y[n_knots];

  // marginal standard deviation of the patient-specific deviation process
  real<lower=0> s;
  // lengthscale shared by the patient-specific process and the population mean
  // interpolation
  real<lower=30, upper=120> l;
  // nugget term standard deviation
  real<lower=.01> s_nugg;

  // covariate coefficients; modelmatrix * beta is a constant shift per patient
  vector[p_cov] beta;

}

transformed parameters {

  // Inverse of the knot covariance matrix, used to interpolate the population
  // mean from the knot values knots_y. A squared-exponential Gram matrix is
  // numerically near-singular when knots are close relative to the lengthscale
  // l, so a small diagonal jitter keeps it positive definite; inverse_spd
  // exploits (and preserves) the symmetry of the matrix.
  matrix[n_knots, n_knots] K_knots_inv = inverse_spd(
    cov_exp_quad(knots, s_pop_mean, l) + diag_matrix(rep_vector(1e-9, n_knots))
  );

  // per-patient linear predictor from the covariates (constant over time)
  vector[n_pat] lin_pred = modelmatrix * beta;

}

model {

  // Knot interpolation weights K_knots^{-1} * knots_y. They do not depend on
  // the patient, so they are computed once here instead of forming the
  // (n_sub x n_knots) * (n_knots x n_knots) product inside the loop.
  vector[n_knots] knots_alpha = K_knots_inv * to_vector(knots_y);

  // priors (knots_y: implied flat prior from its declared bounds)
  s          ~ normal(4, 3);
  s_pop_mean ~ normal(1, 1);
  l          ~ normal(60, 14);
  s_nugg     ~ normal(0, .1);
  beta       ~ normal(0, 1);

  // Likelihood: each patient's observations are jointly multivariate normal
  // with mean = interpolated population mean + lin_pred[i] and covariance =
  // squared-exponential kernel (s, l) + nugget s_nugg^2 on the diagonal.
  for (i in 1:n_pat) {

    int n_sub = pat_end_ind[i] - pat_start_ind[i] + 1;

    // this patient's sampling times and outcomes (contiguous slice)
    real t_sub[n_sub] = t[pat_start_ind[i]:pat_end_ind[i]];
    vector[n_sub] gose = to_vector(y[pat_start_ind[i]:pat_end_ind[i]]);

    // latent population mean at the patient's sampling times
    vector[n_sub] mu_sub = cov_exp_quad(t_sub, knots, s_pop_mean, l) * knots_alpha;

    matrix[n_sub, n_sub] Sigma = cov_exp_quad(t_sub, s, l)
                                 + diag_matrix(rep_vector(square(s_nugg), n_sub));

    gose ~ multi_normal(
      mu_sub + rep_vector(lin_pred[i], n_sub),
      Sigma
    );

  }

}

generated quantities {

  // Discretized draw per patient at the output times t_out: a Gaussian-process
  // conditional draw given the patient's observations, rounded and clamped to
  // the integer range 1..8.
  // NOTE(review): Sigma below is the posterior covariance of the noise-free
  // patient process plus a fixed .001 jitter; the nugget variance s_nugg^2 is
  // not added, so these are draws of the smoothed trajectory rather than of new
  // noisy observations — confirm intended.
  real gose_out[n_pat, n_t_out];
  // NOTE(review): tmp and Sigma are declared at block scope, so Stan writes them
  // to the output on every draw, yet after the loop they only hold the values for
  // the LAST patient. Kept only to preserve the output columns.
  row_vector[n_t_out] tmp;
  matrix[n_t_out, n_t_out] Sigma;

  {
    // patient-invariant quantities, computed once (local scope: not written out)
    vector[n_knots] knots_alpha = K_knots_inv * to_vector(knots_y);
    // prior covariance of the patient process at the output times
    matrix[n_t_out, n_t_out] K_nn = cov_exp_quad(t_out, s, l);
    // population mean (without covariate shift) at the output times
    vector[n_t_out] mu_pop_out = cov_exp_quad(t_out, knots, s_pop_mean, l) * knots_alpha;

    for (i in 1:n_pat) {

      int n_sub = pat_end_ind[i] - pat_start_ind[i] + 1;

      // this patient's sampling times and observed values
      real t_sub[n_sub] = t[pat_start_ind[i]:pat_end_ind[i]];
      vector[n_sub] gose = to_vector(y[pat_start_ind[i]:pat_end_ind[i]]);

      // Cholesky factor of the observed-data covariance (process + nugget);
      // triangular solves replace the explicit inverse used previously.
      matrix[n_sub, n_sub] L_oo = cholesky_decompose(
        cov_exp_quad(t_sub, s, l) + diag_matrix(rep_vector(square(s_nugg), n_sub))
      );
      // v = L_oo^{-1} * K_on, hence v' * v = K_no * K_oo^{-1} * K_on
      matrix[n_sub, n_t_out] v = mdivide_left_tri_low(L_oo, cov_exp_quad(t_sub, t_out, s, l));

      // prior means at the output and at the observed times
      vector[n_t_out] mu_out = rep_vector(lin_pred[i], n_t_out) + mu_pop_out;
      vector[n_sub] mu_sub = rep_vector(lin_pred[i], n_sub)
                             + cov_exp_quad(t_sub, knots, s_pop_mean, l) * knots_alpha;

      // posterior covariance; crossprod returns an exactly symmetric matrix,
      // which multi_normal_rng requires, and .001 is a fixed diagonal jitter
      Sigma = K_nn - crossprod(v) + diag_matrix(rep_vector(.001, n_t_out));

      // posterior mean: mu_out + K_no * K_oo^{-1} * (gose - mu_sub)
      tmp = multi_normal_rng(
        mu_out + v' * mdivide_left_tri_low(L_oo, gose - mu_sub),
        Sigma
      )';

      // discretize: round, then clamp to [1, 8]
      for (j in 1:n_t_out) {
        gose_out[i, j] = round(tmp[j]);
        if (gose_out[i, j] > 8) {
          gose_out[i, j] = 8;
        }
        if (gose_out[i, j] < 1) {
          gose_out[i, j] = 1;
        }
      }

    }
  }

}
