#!/usr/bin/env Rscript
library(tidyverse)
library(rstan)

# Command-line interface: <inputfile> <outputfile>.
#   inputfile  - RDS file read with readRDS(); later steps expect columns gupi,
#                Outcomes.DerivedCompositeGOSE,
#                Outcomes.DerivedCompositeGOSEDaysPostInjury and baseline
#                covariates (presumably one row per patient assessment -
#                confirm against the producer of this file).
#   outputfile - RDS path the long-format posterior table is written to.
args        <- commandArgs(trailingOnly = TRUE)
if (length(args) < 2) {
  # Fail with a readable message instead of "subscript out of bounds".
  stop(
    "usage: Rscript <this script> <inputfile> <outputfile>",
    call. = FALSE
  )
}

inputfile   <- args[[1]]
outputfile  <- args[[2]]
modelfile   <- "models/gp_nb/gp_nb.stan"

# Run configuration is read from config.yml relative to the working directory.
config      <- yaml::read_yaml("config.yml")

df          <- readRDS(inputfile)

# Baseline covariate design matrix: one row per patient, in the same sorted
# (gupi, days post injury) order used for the outcome rows, built from each
# patient's first (earliest) assessment row with the outcome columns and gupi
# excluded. `- 1` drops the intercept column.
model_matrix_baseline <- local({
  df_baseline <- df %>%
    arrange(gupi, Outcomes.DerivedCompositeGOSEDaysPostInjury) %>%
    select(
      -Outcomes.DerivedCompositeGOSEDaysPostInjury, -Outcomes.DerivedCompositeGOSE
    ) %>%
    # Keeps the first row per patient in arranged order - the same row that
    # group_by(gupi) + summarize_all(first) selected.
    distinct(gupi, .keep_all = TRUE)

  mm <- model.matrix(~ . - gupi - 1, data = df_baseline)

  # With the default na.action (na.omit), model.matrix() silently drops rows
  # containing NA covariates; every later patient would then be paired with
  # the wrong covariate row, so refuse to continue.
  if (nrow(mm) != nrow(df_baseline)) {
    stop(
      sprintf(
        "model.matrix() dropped %d of %d patients (missing baseline covariates?)",
        nrow(df_baseline) - nrow(mm), nrow(df_baseline)
      ),
      call. = FALSE
    )
  }
  mm
})

# First and last row number of each patient's block of rows after sorting by
# (gupi, days post injury). Because the rows are sorted by gupi, each patient's
# rows are contiguous. Patients are listed in order of first appearance in the
# sorted rows, derived directly from those rows rather than from group_by()'s
# own key ordering.
df_start_end_index <- local({
  gupi_sorted <- df %>%
    arrange(gupi, Outcomes.DerivedCompositeGOSEDaysPostInjury) %>%
    pull(gupi)
  is_first <- !duplicated(gupi_sorted)
  is_last  <- !duplicated(gupi_sorted, fromLast = TRUE)
  tibble(
    gupi      = gupi_sorted[is_first],
    start_row = which(is_first),
    end_row   = which(is_last)
  )
})

# Outcome rows for Stan, sorted by (gupi, days post injury), with gupi replaced
# by a 1-based numeric patient index.
df_gose <- df %>%
  arrange(gupi, Outcomes.DerivedCompositeGOSEDaysPostInjury) %>%
  transmute(
    # Index = position of the patient's first appearance in the sorted rows,
    # i.e. the exact inverse of `unique(sorted gupi)[id]`. factor() would
    # order levels with sort(), whose collation is locale-dependent and can
    # disagree with arrange() (C locale since dplyr 1.1) for mixed-case ids.
    # Wrapped in as.numeric() to keep the column double, as before.
    gupi = as.numeric(match(gupi, unique(gupi))),
    # NOTE(review): if GOSE is stored as a factor, as.numeric() yields level
    # codes rather than the labelled scores - confirm the input encoding.
    Outcomes.DerivedCompositeGOSE = as.numeric(Outcomes.DerivedCompositeGOSE),
    Outcomes.DerivedCompositeGOSEDaysPostInjury = Outcomes.DerivedCompositeGOSEDaysPostInjury
  )

# Data list handed to rstan::stan(). Element names and order matter twice:
# Stan matches them to the model's data block, and the whole list is hashed
# to derive the sampler seed.
stan_data <- local({
  t_out <- config$t_out
  knots <- config$gp$t_knots
  list(
    N             = nrow(df),
    n_pat         = length(unique(df_gose$gupi)),
    pat_start_ind = df_start_end_index[["start_row"]],
    pat_end_ind   = df_start_end_index[["end_row"]],
    t             = df_gose[["Outcomes.DerivedCompositeGOSEDaysPostInjury"]],
    y             = df_gose[["Outcomes.DerivedCompositeGOSE"]],
    id            = df_gose[["gupi"]],
    modelmatrix   = model_matrix_baseline,
    p_cov         = ncol(model_matrix_baseline),
    t_out_        = as.matrix(t_out),  # single-column matrix of output times
    n_t_out       = length(t_out),
    knots         = knots,
    n_knots       = length(knots)
  )
})

# Derive the sampler seed deterministically from the Stan data: the first five
# hex digits of its SHA-1 digest read as an integer in [0, 16^5 - 1], so
# identical input data always yields the same seed.
seed <- digest::digest(stan_data, algo = "sha1") %>%
  substr(1, 5) %>%
  strtoi(base = 16)
# Plain "\n": the former "\n\r" (reversed CRLF) wrote a stray carriage return.
cat(sprintf("seed: %i\n", seed))

# Let rstan save the compiled model to disk so later runs can reuse it.
rstan::rstan_options(auto_write = TRUE)

# Sample the model; all tuning knobs come from the `stan` section of config.yml.
mdl <- local({
  cfg <- config$stan
  # rstan's `iter` counts warmup too, so the configured number of post-warmup
  # iterations is added on top of the warmup length.
  n_iter_total <- cfg$warmup + cfg$iter
  rstan::stan(
    file    = modelfile,
    data    = stan_data,
    seed    = seed,
    chains  = cfg$chains,
    cores   = cfg$chains,  # one core per chain
    warmup  = cfg$warmup,
    iter    = n_iter_total,
    refresh = 1,
    control = list(
      max_treedepth = cfg$max_treedepth,
      adapt_delta   = cfg$adapt_delta
    )
  )
})

# Long-format posterior GOSE distribution: one row per (gupi, t, GOSE) with the
# share of posterior draws of `gose_out` falling into GOSE category 1..8 for
# that patient at that output time.
df_posteriors <- local({
  n_gose <- 8L

  # Only `gose_out` is needed; extracting just it avoids materialising every
  # model parameter.
  gose_draws <- rstan::extract(mdl, pars = "gose_out")$gose_out # (n_sample, n_pat, n_t)

  # Relative frequency of each category over the draws of every
  # (patient, time) cell. apply() stacks the length-8 results along a new
  # first dimension, giving an (8, n_pat, n_t) array.
  probs <- apply(gose_draws, MARGIN = c(2, 3), function(draws) {
    as.numeric(table(factor(draws, levels = seq_len(n_gose)))) / length(draws)
  })

  # Patient index k corresponds to the k-th distinct gupi in the sorted rows.
  gupi_levels <- df %>%
    arrange(gupi, Outcomes.DerivedCompositeGOSEDaysPostInjury) %>%
    pull(gupi) %>%
    unique()

  # expand.grid varies its first argument fastest, matching the column-major
  # layout of `probs`, so as.vector(probs) lines up row by row.
  expand.grid(
    GOSE  = seq_len(n_gose),
    pat   = seq_len(dim(probs)[2]),
    t_idx = seq_len(dim(probs)[3]),
    KEEP.OUT.ATTRS = FALSE
  ) %>%
    as_tibble() %>%
    mutate(
      probability = as.vector(probs),
      gupi        = gupi_levels[pat],
      t           = stan_data$t_out_[t_idx, 1]
    ) %>%
    select(gupi, t, GOSE, probability) %>%
    arrange(gupi, t)
})

# Persist the long-format posterior table to the path given on the command line.
saveRDS(object = df_posteriors, file = outputfile)
