#!/usr/bin/env Rscript
library(tidyverse)
library(brms)

# Command-line interface: Rscript <script> <inputfile> <outputfile>
#   inputfile  - path of the .rds data set loaded with readRDS()
#   outputfile - path the resulting table is written to with saveRDS()
args        <- commandArgs(trailingOnly = TRUE)

# fail early with a usage hint instead of a bare "subscript out of bounds"
if (length(args) < 2) {
  stop(
    "expected two arguments: <inputfile> <outputfile>",
    call. = FALSE
  )
}

inputfile   <- args[[1]]
outputfile  <- args[[2]]

# settings used further down (config$stan for the sampler, config$t_out for
# the prediction time points); the path is relative to the working directory
config      <- yaml::read_yaml("config.yml")

# Load the input table and fix up column types:
#  * the GOSE outcome becomes an ordered factor whose levels are the observed
#    values sorted in increasing numeric order,
#  * every column that is still character afterwards becomes a plain factor.
df <- readRDS(inputfile) %>%
  mutate(
    Outcomes.DerivedCompositeGOSE = factor(
      Outcomes.DerivedCompositeGOSE,
      levels = as.character(
        sort(as.numeric(unique(Outcomes.DerivedCompositeGOSE)))
      ),
      ordered = TRUE
    )
  ) %>%
  mutate_if(is.character, factor)

# Derive a reproducible random seed from the data itself: take the first
# 5 hex digits of the SHA-1 digest of `df` and read them as a base-16 integer
# (at most 0xFFFFF, so it always fits into an R integer).
seed <- digest::digest(df, algo = "sha1") %>%
  substr(1, 5) %>%
  strtoi(base = 16)
# terminate the log line with a single newline only, so that no stray
# carriage return ends up at the start of the following output line
cat(sprintf("seed: %i\n", seed))

# Model formula passed to brms::brm():
#  * response: the (ordered factor) GOSE outcome,
#  * population level: a smooth term s() of days post injury,
#  * group level: per-`gupi` intercept plus linear and quadratic effects of
#    days post injury.
# NOTE(review): `gupi` is presumably the subject identifier -- confirm.
# The order of the terms is left untouched on purpose, since it determines the
# order of the parameters in the fitted model.
formula <- Outcomes.DerivedCompositeGOSE ~
  s(Outcomes.DerivedCompositeGOSEDaysPostInjury) +
  (I(Outcomes.DerivedCompositeGOSEDaysPostInjury^2) + Outcomes.DerivedCompositeGOSEDaysPostInjury + 1 | gupi)

# Fit the cumulative-logit model with flexible thresholds. All sampler
# settings come from the `stan` section of the configuration file.
stan_cfg <- config$stan

mdl <- brms::brm(
  formula,
  data        = df,
  family      = brms::cumulative("logit", threshold = "flexible"),
  seed        = seed,
  # as many cores as chains
  chains      = stan_cfg$chains,
  cores       = stan_cfg$chains,
  # the configured `iter` is added on top of the warmup iterations
  warmup      = stan_cfg$warmup,
  iter        = stan_cfg$warmup + stan_cfg$iter,
  save_warmup = FALSE,
  # report sampling progress after every iteration
  refresh     = 1,
  control     = list(
    max_treedepth = stan_cfg$max_treedepth,
    adapt_delta   = stan_cfg$adapt_delta
  )
)

# One row per gupi holding the first value of every column other than the
# outcome and its time stamp.
df_first_by_gupi <- df %>%
  select(
    -Outcomes.DerivedCompositeGOSE,
    -Outcomes.DerivedCompositeGOSEDaysPostInjury
  ) %>%
  group_by(gupi) %>%
  summarize_all(first) %>%
  ungroup()

# Prediction grid: every gupi crossed with every requested time point
# (config$t_out), sorted by gupi and time, with the per-gupi columns attached.
df_new_data <- expand.grid(
  gupi = unique(df$gupi),
  Outcomes.DerivedCompositeGOSEDaysPostInjury = config$t_out
) %>%
  as_tibble() %>%
  arrange(gupi, Outcomes.DerivedCompositeGOSEDaysPostInjury) %>%
  left_join(df_first_by_gupi, by = "gupi")

# Predicted category probabilities for every row of df_new_data, reshaped to
# long format with columns gupi, t, GOSE (integer 1..8) and probability.
# predict(..., summary = TRUE) is expected to return one "P(Y = k)" column per
# fitted category; gupi and the time point are copied over from df_new_data.
df_posteriors <- predict(mdl, df_new_data, summary = TRUE) %>%
  as_tibble() %>%
  mutate(
    gupi = df_new_data$gupi,
    Outcomes.DerivedCompositeGOSEDaysPostInjury =
      df_new_data$Outcomes.DerivedCompositeGOSEDaysPostInjury
  ) %>% {
    # not all GOSE levels are necessarily observed/fitted; add the missing
    # categories as columns with probability zero
    cols         <- sprintf("P(Y = %s)", 1:8)
    missing_cols <- setdiff(cols, names(.))
    res <- .
    for (newcol in missing_cols) {
      res <- mutate(res, !!newcol := 0.0)
    }
    # The value of this block is its last expression. return() must not be
    # used here: a piped `{` block is not a function under magrittr >= 2.0,
    # so return() at script top level has nothing to return from.
    select(res, gupi, Outcomes.DerivedCompositeGOSEDaysPostInjury, everything())
  } %>%
  rename(t = Outcomes.DerivedCompositeGOSEDaysPostInjury) %>%
  # wide -> long: one row per (gupi, t, category)
  gather(GOSE, probability, starts_with("P(Y =")) %>%
  mutate(
    # map the "P(Y = k)" column names to the category labels 1..8
    GOSE = factor(GOSE, labels = 1:8, levels = sprintf("P(Y = %s)", 1:8))
  ) %>%
  arrange(gupi, t, GOSE) %>%
  # store plain types: factors -> character, GOSE -> integer
  mutate_if(is.factor, as.character) %>%
  mutate(
    GOSE = as.integer(GOSE),
    gupi = as.character(gupi)
  )

# write the long-format table to the path given as second CLI argument
saveRDS(df_posteriors, outputfile)
