Commit 9016d184 authored by Kevin Kunzmann's avatar Kevin Kunzmann

added MI model

parent ddcf7606
......@@ -14,3 +14,5 @@ output*
*.png
*.sif
*.DS_Store
*.eps
*.html
library(tidyverse)
library(mice)
library(miceadds)
args <- commandArgs(trailingOnly = TRUE)
inputfile <- args[[1]] # "../output/v1.1/data/validation/df_train_mi_1_fold_1.rds"#
outputfile <- args[[2]]
config <- yaml::read_yaml("config.yml")
tbl_train <- readRDS(inputfile) %>%
as_tibble()
tbl_blank <- tbl_train %>%
select(
-Outcomes.DerivedCompositeGOSE,
-Outcomes.DerivedCompositeGOSEDaysPostInjury
) %>%
distinct() %>%
mutate(
Outcomes.DerivedCompositeGOSE = NA,
Outcomes.DerivedCompositeGOSEDaysPostInjury = 180
)
tbl_gose <- tbl_train %>%
bind_rows(tbl_blank) %>%
select(
gupi,
Outcomes.DerivedCompositeGOSEDaysPostInjury,
Outcomes.DerivedCompositeGOSE
)
tbl_baseline <- tbl_train %>%
bind_rows(tbl_blank) %>%
select(
-Outcomes.DerivedCompositeGOSEDaysPostInjury,
-Outcomes.DerivedCompositeGOSE
) %>%
group_by(gupi) %>%
summarize(
across(everything(), first),
.groups = "drop"
)
map_to_timepoint <- function(x,
timepoints = c(14, 90, 180, 365)
) {
xx <- matrix(
rep(x, length(timepoints)),
ncol = length(timepoints)
)
as.character(timepoints)[
abs(xx - matrix(
rep(timepoints, length(x)),
ncol = length(timepoints),
byrow = TRUE
)) %>%
apply(1, which.min)
]
}
tbl_combined <- expand_grid(
gupi = tbl_gose$gupi %>% unique,
timepoint = c("14", "90", "180", "365")
) %>%
left_join(
tbl_gose %>%
mutate(
gose = as.integer(as.character(Outcomes.DerivedCompositeGOSE)),
timepoint = map_to_timepoint(Outcomes.DerivedCompositeGOSEDaysPostInjury)
) %>%
group_by(gupi, timepoint) %>%
arrange(
gupi,
as.integer(timepoint),
abs(Outcomes.DerivedCompositeGOSEDaysPostInjury - as.integer(timepoint))
) %>%
select(
-Outcomes.DerivedCompositeGOSEDaysPostInjury,
-Outcomes.DerivedCompositeGOSE
) %>%
summarize(
gose = first(gose[!is.na(gose)]),
.groups = "drop"
),
by = c("gupi", "timepoint")
) %>%
mutate(
gose = factor(gose, levels = c(1, 3:8), ordered = TRUE)
) %>%
mutate(
timepoint = sprintf("gose_%s", timepoint)
) %>%
pivot_wider(
names_from = timepoint,
values_from = gose
) %>%
left_join(
tbl_baseline,
by = "gupi"
)
predmat <- mice::make.predictorMatrix(data = tbl_combined)
predmat[ , "gupi"] <- 0
method <- mice::make.method(data = tbl_combined)
tbl_imputed <- tbl_combined %>%
as.data.frame() %>%
mice::mice(
method = method,
predictorMatrix = predmat,
m = 100,
maxiter = 25,
seed = 42
)
tbl_imputed <- bind_rows(
complete(tbl_imputed, "long") %>%
as_tibble() ,
complete(tbl_imputed, 0) %>%
as_tibble %>%
mutate(.imp = 0, .id = row_number())
) %>%
arrange(.imp, .id) %>%
pivot_longer(
starts_with("gose_"),
names_to = "tmp",
values_to = "GOSE"
) %>%
separate(tmp, c("tmp", "timepoint")) %>%
select(.imp, gupi, timepoint, GOSE)
tbl_posteriors <- expand_grid(
gupi = unique(tbl_combined$gupi),
GOSE = c(1:8)
) %>%
left_join(
tbl_imputed %>%
filter(timepoint == 180, .imp > 0) %>%
group_by(gupi, timepoint, GOSE) %>%
dplyr::summarize(
t = 180,
n = dplyr::n(),
.groups = "drop"
) %>%
group_by(gupi) %>%
mutate(
probability = n / sum(n),
GOSE = as.integer(as.character(GOSE))
) %>%
ungroup() %>%
arrange(gupi, GOSE) %>%
select(gupi, GOSE, probability),
by = c("gupi", "GOSE")
) %>%
mutate(
t = 180L,
probability = if_else(is.na(probability), 0, probability)
) %>%
select(gupi, GOSE, probability, t)
saveRDS(tbl_posteriors, outputfile)
# tmp2 %>%
# filter(
# gupi %in% tbl_test$gupi
# ) %>%
# group_by(gupi) %>%
# filter(cur_group_id() <= 36) %>%
# ungroup() %>%
# {ggplot(filter(., .imp == 0, !is.na(gose))) +
# aes(days, gose) +
# geom_point(data = filter(., .imp > 0, timepoint == "180"), color = "red", alpha = 1/25) +
# geom_point() +
# facet_wrap(~gupi)
# }
library(tidyverse)
library(mice)
library(miceadds)
args <- commandArgs(trailingOnly = TRUE)
inputfile <- args[[1]] # "../output/v1.1/data/validation/df_train_mi_1_fold_1.rds"#
outputfile <- args[[2]]
config <- yaml::read_yaml("config.yml")
tbl_train <- readRDS(inputfile) %>%
as_tibble()
tbl_blank <- tbl_train %>%
select(
-Outcomes.DerivedCompositeGOSE,
-Outcomes.DerivedCompositeGOSEDaysPostInjury
) %>%
distinct() %>%
mutate(
Outcomes.DerivedCompositeGOSE = NA,
Outcomes.DerivedCompositeGOSEDaysPostInjury = 180
)
tbl_gose <- tbl_train %>%
bind_rows(tbl_blank) %>%
select(
gupi,
Outcomes.DerivedCompositeGOSEDaysPostInjury,
Outcomes.DerivedCompositeGOSE
)
map_to_timepoint <- function(x,
timepoints = c(14, 90, 180, 365)
) {
xx <- matrix(
rep(x, length(timepoints)),
ncol = length(timepoints)
)
as.character(timepoints)[
abs(xx - matrix(
rep(timepoints, length(x)),
ncol = length(timepoints),
byrow = TRUE
)) %>%
apply(1, which.min)
]
}
tbl_combined <- expand_grid(
gupi = tbl_gose$gupi %>% unique,
timepoint = c("14", "90", "180", "365")
) %>%
left_join(
tbl_gose %>%
mutate(
gose = as.integer(as.character(Outcomes.DerivedCompositeGOSE)),
timepoint = map_to_timepoint(Outcomes.DerivedCompositeGOSEDaysPostInjury)
) %>%
group_by(gupi, timepoint) %>%
arrange(
gupi,
as.integer(timepoint),
abs(Outcomes.DerivedCompositeGOSEDaysPostInjury - as.integer(timepoint))
) %>%
select(
-Outcomes.DerivedCompositeGOSEDaysPostInjury,
-Outcomes.DerivedCompositeGOSE
) %>%
summarize(
gose = first(gose[!is.na(gose)]),
.groups = "drop"
),
by = c("gupi", "timepoint")
) %>%
mutate(
gose = factor(gose, levels = c(1, 3:8), ordered = TRUE)
) %>%
mutate(
timepoint = sprintf("gose_%s", timepoint)
) %>%
pivot_wider(
names_from = timepoint,
values_from = gose
)
predmat <- mice::make.predictorMatrix(data = tbl_combined)
predmat[ , "gupi"] <- 0
method <- mice::make.method(data = tbl_combined)
tbl_imputed <- tbl_combined %>%
as.data.frame() %>%
mice::mice(
method = method,
predictorMatrix = predmat,
m = 100,
maxiter = 25,
seed = 42
)
tbl_imputed <- bind_rows(
complete(tbl_imputed, "long") %>%
as_tibble() ,
complete(tbl_imputed, 0) %>%
as_tibble %>%
mutate(.imp = 0, .id = row_number())
) %>%
arrange(.imp, .id) %>%
pivot_longer(
starts_with("gose_"),
names_to = "tmp",
values_to = "GOSE"
) %>%
separate(tmp, c("tmp", "timepoint")) %>%
select(.imp, gupi, timepoint, GOSE)
tbl_posteriors <- expand_grid(
gupi = unique(tbl_combined$gupi),
GOSE = c(1:8)
) %>%
left_join(
tbl_imputed %>%
filter(timepoint == 180, .imp > 0) %>%
group_by(gupi, timepoint, GOSE) %>%
dplyr::summarize(
t = 180,
n = dplyr::n(),
.groups = "drop"
) %>%
group_by(gupi) %>%
mutate(
probability = n / sum(n),
GOSE = as.integer(as.character(GOSE))
) %>%
ungroup() %>%
arrange(gupi, GOSE) %>%
select(gupi, GOSE, probability),
by = c("gupi", "GOSE")
) %>%
mutate(
t = 180L,
probability = if_else(is.na(probability), 0, probability)
) %>%
select(gupi, GOSE, probability, t)
saveRDS(tbl_posteriors, outputfile)
# tmp2 %>%
# filter(
# gupi %in% tbl_test$gupi
# ) %>%
# group_by(gupi) %>%
# filter(cur_group_id() <= 36) %>%
# ungroup() %>%
# {ggplot(filter(., .imp == 0, !is.na(gose))) +
# aes(days, gose) +
# geom_point(data = filter(., .imp > 0, timepoint == "180"), color = "red", alpha = 1/25) +
# geom_point() +
# facet_wrap(~gupi)
# }
......@@ -233,7 +233,7 @@ for (modelname in modelnames) {
```{r compute-map-predictions}
df_predictions <- df_model_posteriors %>%
group_by(fold, model, gupi, t, GOSE) %>%
summarise(probability = mean(probability)) %>%
summarise(probability = mean(probability), .groups = "drop") %>%
arrange(model, fold, gupi, t, GOSE) %>%
ungroup() %>%
filter(t == 180) %>%
......@@ -246,7 +246,7 @@ df_predictions <- df_model_posteriors %>%
ifelse(length(.) > 1, round(mean(.)), .), 8),
probability = .$probability
) %>%
unnest %>%
unnest(c(GOSE, prediction, probability)) %>%
spread(GOSE, probability) %>%
right_join(
df_ground_truth %>%
......@@ -303,9 +303,9 @@ df_average_confusion_matrices <- df_predictions %>%
mutate(`Predicted GOSE` = row_number() %>% as.character) %>%
gather(`True GOSE`, n, 1:8)
) %>%
unnest %>%
unnest(confusion_matrix) %>%
group_by(model, `Predicted GOSE`, `True GOSE`) %>%
summarize(n = mean(n)) %>%
summarize(n = mean(n), .groups = "drop") %>%
ungroup
```
......@@ -317,20 +317,20 @@ All values are averaged accross folds.
```{r confusion-matrix-locf, warning=FALSE, message=FALSE, fig.height=4.5, out.width=".9\\textwidth", fig.align='center'}
df_average_confusion_matrices %>%
group_by(model, `True GOSE`) %>%
ggplot(aes(`True GOSE`, `Predicted GOSE`, fill = n)) +
geom_raster() +
geom_hline(yintercept = c(2, 4, 6) + .5, color = "black") +
geom_vline(xintercept = c(2, 4, 6) + .5, color = "black") +
scale_fill_gradient(low = "white", high = "black") +
coord_fixed(expand = FALSE) +
labs(x = "true GOSE", y = "imputed GOSE", fill = "") +
theme_bw() +
theme(
panel.grid = element_blank()
) +
facet_wrap(~model) +
ggtitle("Average confusion matrix accross folds (absolute counts)")
group_by(model, `True GOSE`) %>%
ggplot(aes(`True GOSE`, `Predicted GOSE`, fill = n)) +
geom_raster() +
geom_hline(yintercept = c(2, 4, 6) + .5, color = "black") +
geom_vline(xintercept = c(2, 4, 6) + .5, color = "black") +
scale_fill_gradient(low = "white", high = "black") +
coord_fixed(expand = FALSE) +
labs(x = "true GOSE", y = "imputed GOSE", fill = "") +
theme_bw() +
theme(
panel.grid = element_blank()
) +
facet_wrap(~model) +
ggtitle("Average confusion matrix accross folds (absolute counts)")
ggsave(filename = "confusion_matrices_locf.pdf", width = 7, height = 7)
ggsave(filename = "confusion_matrices_locf.png", width = 7, height = 7)
......
......@@ -25,7 +25,7 @@ rule fit_validation_model:
rule fit_models_validation_v1_1:
input:
["output/v1.1/data/validation/posteriors/%s/df_posterior_mi_%i_fold_%i.rds" % (m, i, j)
for m in ("locf", "msm", "msm_age", "gp", "gp_nb", "mm", "mm_nb")
for m in ("locf", "msm", "msm_age", "gp", "gp_nb", "mm", "mm_nb", "mice", "mice_nb")
for i in range(1, config["mi_m"] + 1)
for j in range(1, config["folds"] + 1)
]
......@@ -199,8 +199,8 @@ for (modelname in modelnames) {
df_model_posteriors <- df_model_posteriors %>%
mutate(
model = factor(model,
levels = c("locf", "mm_nb", "mm", "gp_nb", "gp", "msm", "msm_age"),
labels = c("LOCF", "MM", "MM + cov", "GP", "GP + cov", "MSM", "MSM + age"),
levels = c("locf", "mice_nb", "mice", "mm_nb", "mm", "gp_nb", "gp", "msm", "msm_age"),
labels = c("LOCF", "MI", "MI + cov", "MM", "MM + cov", "GP", "GP + cov", "MSM", "MSM + age"),
) %>% as.character
)
......@@ -346,13 +346,13 @@ plot_confusion_matrices <- function(df_predictions, models, nrow = 2, legendpos,
plot_confusion_matrices(
df_predictions %>%
filter(!(gupi %in% idx)),
c("MSM", "GP + cov", "MM", "LOCF"),
c("MSM", "MI", "GP + cov", "MM", "LOCF"),
nrow = 1,
legendpos = "none",
scriptsize = 2.5
)
ggsave(filename = "confusion_matrices_locf.eps", width = 6, height = 3, colormodel = "cmyk")
ggsave(filename = "confusion_matrices_locf.eps", width = 6, height = 2.5, colormodel = "cmyk")
......@@ -416,7 +416,7 @@ plot_summary_measures_cond <- function(df_predictions, models, label) {
plot_summary_measures_cond(
df_predictions %>% filter(!(gupi %in% idx)),
c("MSM", "GP + cov", "MM", "LOCF"),
c("MSM", "MI", "GP + cov", "MM", "LOCF"),
"Summary measures by observed GOSe, LOCF subset"
)
......@@ -427,20 +427,20 @@ ggsave(filename = "errors_stratified_locf.eps", width = 6, height = 3.5, colormo
# figure 5: confusion all ==================================================
plot_confusion_matrices(
df_predictions,
c("MSM", "GP + cov", "MM"),
c("MSM", "MI", "GP + cov", "MM"),
nrow = 1,
legendpos = "none",
scriptsize = 3
)
ggsave(filename = "confusion_matrices_all.eps", width = 6, height = 3, colormodel = "cmyk")
ggsave(filename = "confusion_matrices_all.eps", width = 6, height = 2.5, colormodel = "cmyk")
# figure 6: errors overall ==================================================
plot_summary_measures_cond(
df_predictions %>% filter(!(gupi %in% idx)),
c("MSM", "GP + cov", "MM"),
c("MSM", "MI", "GP + cov", "MM"),
"Summary measures by observed GOSe, full test set"
)
......@@ -448,7 +448,7 @@ ggsave(filename = "imputation_error.eps", width = 6, height = 3.5, colormodel =
# figure 7:marginal gose per fold ==================================================
# figure 7: marginal gose per fold ==================================================
df_ground_truth %>%
ggplot(aes(Outcomes.DerivedCompositeGOSE)) +
geom_bar() +
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment