# Snakemake template for building reusable and scalable machine learning pipelines with mikropml
Snakemake is a workflow manager that enables massively parallel and reproducible analyses. Snakemake is a good fit when you can break a workflow down into discrete steps, each with its own input and output files.
mikropml is an R package for supervised machine learning pipelines. We provide this example workflow as a template for running mikropml with Snakemake; we hope you will then customize the code to meet the needs of your particular ML task.
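At its core, each training job in the workflow wraps a call like the following (a minimal sketch based on mikropml's introductory documentation; `otu_small` is an example dataset bundled with mikropml whose outcome column is `dx`):

```r
library(mikropml)

# Train and evaluate one model: run_ml() splits the data, tunes
# hyperparameters with cross-validation, and reports performance
# on the held-out test set.
results <- run_ml(otu_small,
  method = "glmnet",
  outcome_colname = "dx",
  seed = 2019
)
results$performance # AUROC and other metrics for this seed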
For more details on these tools, see the Snakemake tutorial and read the mikropml docs.
## The Workflow
The `Snakefile` contains rules which define the output files we want and how to make them. Snakemake automatically builds a directed acyclic graph (DAG) of jobs to figure out each rule's dependencies and the order in which to run them.
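For example, a rule in this style might look like the following sketch (the file names are hypothetical, not copied from this workflow's `Snakefile`):

```python
# Hypothetical rule for illustration: when results/processed.Rds is requested,
# Snakemake runs the R script, exposing `input`, `output`, and any `params`
# to it via the `snakemake` S4 object.
rule preprocess_data:
    input:
        csv="data/example_dataset.csv",
    output:
        rds="results/processed.Rds",
    script:
        "scripts/preproc.R"
```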
This workflow preprocesses the example dataset, calls `mikropml::run_ml()` for each seed and ML method set in the config file, combines the results files, plots performance results (cross-validation and test AUROCs, hyperparameter AUROCs from cross-validation, and benchmark performance), and renders a simple R Markdown report as a GitHub-flavored markdown file (see example here).
The DAG shows how calls to `run_ml` can run in parallel if Snakemake is allowed to run more than one job at a time. If we use 100 seeds and 4 ML methods, Snakemake would call `run_ml` 400 times.
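That fan-out is typically written with Snakemake's `expand()` helper, along the lines of this sketch (the target file pattern and the method/seed values are illustrative, not necessarily the ones this workflow uses):

```python
# Hypothetical target rule: expand() generates one run_ml output file per
# (method, seed) combination, so 4 methods x 100 seeds = 400 parallel jobs.
rule targets:
    input:
        expand(
            "results/runs/{method}_{seed}_performance.csv",
            method=["glmnet", "rf", "svmRadial", "xgbTree"],  # example methods
            seed=range(100),  # example: seeds 0-99
        ),
```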
Here's a small example DAG if we were to use only 2 seeds and 1 ML method: [example DAG figure]
## Usage
Full usage instructions recommended by Snakemake are available in the snakemake workflow catalog.
The Snakemake developers recommend `snakedeploy` for using this workflow as a module in your own project.
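A deployment command might look like the following sketch (the repository URL, destination directory, and branch are assumptions; check the workflow catalog entry for the exact command and a current release tag):

```sh
# Illustrative snakedeploy invocation: pull this workflow into a new project
# directory as a module. The URL and branch here are assumptions.
snakedeploy deploy-workflow \
    https://github.com/SchlossLab/mikropml-snakemake-workflow \
    my-mikropml-project --branch main
```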
Alternatively, you can download this repo and modify the code directly to suit your needs. See instructions here.
## Help & Contributing
If you come across a bug, open an issue and include a minimal reproducible example.
If you have questions, create a new post in Discussions.
If you’d like to contribute, see our guidelines here.
## Code of Conduct
Please note that the mikropml-snakemake-workflow is released with a Contributor Code of Conduct. By contributing to this project, you agree to abide by its terms.
## Code Snippets
Excerpts from the workflow's rule files; each rule delegates to an R script, the R Markdown report, or a shell command:

```python
script: "../scripts/combine_results.R"

script: "../scripts/combine_hp_perf.R"

script: "../scripts/mutate_benchmark.R"

shell:
    """
    for f in {input.figs}; do
        cp $f {params.outdir}
    done
    """

script: "../scripts/report.Rmd"

script: "../scripts/preproc.R"

script: "../scripts/train_ml.R"

script: "../scripts/find_feature_importance.R"

script: "../scripts/calc_model_sensspec.R"

script: "../scripts/plot_performance.R"

script: "../scripts/plot_feature_importance.R"

script: "../scripts/make_blank_plot.R"

script: "../scripts/plot_hp_perf.R"

script: "../scripts/plot_benchmarks.R"

script: "../scripts/plot_roc_curves.R"

shell:
    """
    snakemake --{wildcards.cmd} --configfile {params.config_path} 2> {log} > {output.dot}
    """

shell:
    """
    cat {input.dot} | dot -T png 2> {log} > {output.png}
    """
```
`scripts/calc_model_sensspec.R`:

```r
schtools::log_snakemake()
library(tidyverse)

model <- read_rds(snakemake@input[["model"]])
test_dat <- read_csv(snakemake@input[["test"]])
outcome_colname <- snakemake@params[["outcome_colname"]]

# calculate sensitivity & specificity for the trained model on the test set
mikropml::calc_model_sensspec(
  model,
  test_dat,
  outcome_colname
) %>%
  bind_cols(schtools::get_wildcards_tbl()) %>%
  write_csv(snakemake@output[["csv"]])
```
`scripts/combine_hp_perf.R`:

```r
schtools::log_snakemake()

# combine hyperparameter performance across all seeds for one ML method
models <- lapply(snakemake@input[["rds"]], function(x) readRDS(x))
hp_perf <- mikropml::combine_hp_performance(models)
hp_perf$method <- snakemake@wildcards[["method"]]
saveRDS(hp_perf, file = snakemake@output[["rds"]])
```
`scripts/combine_results.R`:

```r
schtools::log_snakemake()
library(dplyr)

# concatenate the per-run CSV files into one
snakemake@input[["csv"]] %>%
  purrr::map_dfr(readr::read_csv) %>%
  readr::write_csv(snakemake@output[["csv"]])
```
`scripts/find_feature_importance.R`:

```r
schtools::log_snakemake()
library(mikropml)
library(dplyr)
library(readr)

# parallelize over multiple cores if available
doFuture::registerDoFuture()
future::plan(future::multicore, workers = snakemake@threads)
message(paste("# workers: ", foreach::getDoParWorkers()))

model <- readRDS(snakemake@input[["model"]])
outcome_colname <- snakemake@params[["outcome_colname"]]

# recover the training data from the caret model object
train_dat <- model$trainingData
names(train_dat)[names(train_dat) == ".outcome"] <- outcome_colname

test_dat <- read_csv(snakemake@input[["test"]])
method <- snakemake@params[["method"]]
seed <- as.numeric(snakemake@params[["seed"]])

# choose the performance metric based on the outcome type
outcome_type <- get_outcome_type(c(
  train_dat %>% pull(outcome_colname),
  test_dat %>% pull(outcome_colname)
))
class_probs <- outcome_type != "continuous"
perf_metric_function <- get_perf_metric_fn(outcome_type)
perf_metric_name <- get_perf_metric_name(outcome_type)

if (!is.na(seed)) {
  set.seed(seed)
}

feat_imp <- mikropml::get_feature_importance(
  trained_model = model,
  test_data = test_dat,
  outcome_colname = outcome_colname,
  perf_metric_function = perf_metric_function,
  perf_metric_name = perf_metric_name,
  class_probs = class_probs,
  method = method,
  seed = seed
)

# attach the wildcard values (method, seed) and write out
wildcards <- schtools::get_wildcards_tbl()
readr::write_csv(
  feat_imp %>% inner_join(wildcards, by = c("method", "seed")),
  snakemake@output[["feat"]]
)
```
`scripts/make_blank_plot.R`:

```r
schtools::log_snakemake()
library(ggplot2)

# create a tiny empty plot as a placeholder output file
message("making a blank plot")
ggsave(
  filename = snakemake@output[["plot"]],
  plot = ggplot() + theme_void(),
  height = 0.1,
  width = 0.1,
  device = "png"
)
```
`scripts/mutate_benchmark.R`:

```r
schtools::log_snakemake()
library(tidyverse)

# convert a snakemake benchmark TSV to CSV, tagged with the wildcard values
wildcards <- schtools::get_wildcards_tbl()
read_tsv(snakemake@input[["tsv"]]) %>%
  bind_cols(wildcards) %>%
  write_csv(snakemake@output[["csv"]])
```
`scripts/plot_benchmarks.R`:

```r
schtools::log_snakemake()
library(tidyverse)

# read the combined benchmark results, declaring column types up front
dat <- read_csv(snakemake@input[["csv"]],
  col_types = cols(
    s = col_double(),
    `h:m:s` = col_time(format = "%H:%M:%S"),
    max_rss = col_double(),
    max_vms = col_double(),
    max_uss = col_double(),
    max_pss = col_double(),
    io_in = col_double(),
    io_out = col_double(),
    mean_load = col_double(),
    cpu_time = col_double(),
    method = col_character(),
    seed = col_double()
  )
) %>%
  mutate(
    runtime_mins = s / 60,
    memory_gb = max_rss / 1024
  ) %>%
  select(method, runtime_mins, memory_gb) %>%
  pivot_longer(-method, names_to = "metric") %>%
  mutate(value = round(value, 2)) %>%
  group_by(method)

# boxplots of runtime and memory use per ML method
bench_plot <- dat %>%
  ggplot(aes(method, value)) +
  geom_boxplot() +
  facet_wrap(metric ~ ., scales = "free") +
  theme_classic() +
  labs(y = "", x = "") +
  coord_flip()

ggsave(snakemake@output[["plot"]], plot = bench_plot)
```
`scripts/plot_feature_importance.R`:

```r
schtools::log_snakemake()
library(dplyr)
library(ggplot2)
library(schtools)

feat_df <- readr::read_csv(snakemake@input[["csv"]])
top_n <- as.numeric(snakemake@params[["top_n"]])

# find the top features by median performance difference for each method
top_feats <- feat_df %>%
  group_by(method, names) %>%
  summarize(median_diff = median(perf_metric_diff)) %>%
  slice_max(order_by = median_diff, n = top_n)

feat_plot <- feat_df %>%
  right_join(top_feats, by = c("method", "names")) %>%
  mutate(features = factor(names, levels = unique(top_feats$names))) %>%
  ggplot(aes(x = perf_metric_diff, y = features, color = method)) +
  geom_boxplot() +
  facet_wrap(~method) +
  theme_sovacool()

ggsave(
  filename = snakemake@output[["plot"]],
  plot = feat_plot,
  device = "png"
)
```
`scripts/plot_hp_perf.R` (note: `ggsave()` is given the grid explicitly so the combined plot, not the last sub-plot, is saved):

```r
schtools::log_snakemake()

hp_perf <- readRDS(snakemake@input[["rds"]])

# one performance plot per hyperparameter, combined into a grid
hp_plot_list <- lapply(hp_perf$params, function(param) {
  mikropml::plot_hp_performance(
    hp_perf$dat,
    !!rlang::sym(param),
    !!rlang::sym(hp_perf$metric)
  ) +
    ggplot2::theme_classic() +
    ggplot2::scale_color_brewer(palette = "Dark2") +
    ggplot2::labs(title = unique(hp_perf$method))
})
hp_plot <- cowplot::plot_grid(plotlist = hp_plot_list)
ggplot2::ggsave(snakemake@output[["plot"]], plot = hp_plot)
```
`scripts/plot_performance.R`:

```r
schtools::log_snakemake()
library(tidyverse)

# plot cross-validation and test performance for each ML method
perf_plot <- snakemake@input[["csv"]] %>%
  read_csv() %>%
  mikropml::plot_model_performance() +
  theme_classic() +
  scale_color_brewer(palette = "Dark2") +
  coord_flip()

ggsave(snakemake@output[["plot"]], plot = perf_plot)
```
`scripts/plot_roc_curves.R`:

```r
schtools::log_snakemake()
library(patchwork)
library(tidyverse)

dat <- read_csv(snakemake@input[["csv"]])

# summarize a performance metric (mean +/- sd) across runs,
# optionally grouped by an additional variable such as method
calc_mean_perf <- function(sensspec_dat, group_var = specificity,
                           sum_var = sensitivity, custom_group_vars = NULL) {
  specificity <- sensitivity <- sd <- NULL
  dat_round <- sensspec_dat %>%
    dplyr::mutate({{ group_var }} := round({{ group_var }}, 2))
  if (!is.null(custom_group_vars)) {
    dat_grouped <- dat_round %>%
      dplyr::group_by({{ group_var }}, !!rlang::sym(custom_group_vars))
  } else {
    dat_grouped <- dat_round %>%
      dplyr::group_by({{ group_var }})
  }
  return(
    dat_grouped %>%
      dplyr::summarise(
        mean = mean({{ sum_var }}),
        sd = stats::sd({{ sum_var }})
      ) %>%
      dplyr::mutate(
        upper = mean + sd,
        lower = mean - sd,
        upper = dplyr::case_when(upper > 1 ~ 1, TRUE ~ upper),
        lower = dplyr::case_when(lower < 0 ~ 0, TRUE ~ lower)
      ) %>%
      dplyr::rename(
        "mean_{{ sum_var }}" := mean,
        "sd_{{ sum_var }}" := sd
      )
  )
}

calc_mean_roc <- function(sensspec_dat, custom_group_vars = NULL) {
  specificity <- sensitivity <- NULL
  return(calc_mean_perf(sensspec_dat,
    group_var = specificity, sum_var = sensitivity,
    custom_group_vars = custom_group_vars
  ))
}

calc_mean_prc <- function(sensspec_dat, custom_group_vars = NULL) {
  sensitivity <- recall <- precision <- NULL
  return(calc_mean_perf(
    sensspec_dat %>% dplyr::rename(recall = sensitivity),
    group_var = recall, sum_var = precision,
    custom_group_vars = custom_group_vars
  ))
}

# ggplot layers shared by the ROC and PRC plots
shared_ggprotos <- function(colorvar) {
  return(list(
    ggplot2::geom_ribbon(aes(fill = {{ colorvar }}), alpha = 0.5),
    ggplot2::geom_line(aes(color = {{ colorvar }})),
    ggplot2::coord_equal(),
    ggplot2::scale_y_continuous(expand = c(0, 0), limits = c(-0.01, 1.01)),
    ggplot2::theme_bw(),
    ggplot2::theme(legend.title = ggplot2::element_blank())
  ))
}

plot_mean_roc <- function(dat) {
  specificity <- mean_sensitivity <- lower <- upper <- NULL
  dat %>%
    ggplot2::ggplot(ggplot2::aes(
      x = specificity, y = mean_sensitivity, ymin = lower, ymax = upper
    )) +
    shared_ggprotos(colorvar = method) +
    ggplot2::geom_abline(
      intercept = 1, slope = 1,
      linetype = "dashed", color = "grey50"
    ) +
    ggplot2::scale_x_reverse(expand = c(0, 0), limits = c(1.01, -0.01)) +
    ggplot2::labs(x = "Specificity", y = "Mean Sensitivity")
}

plot_mean_prc <- function(dat, baseline_precision = NULL) {
  recall <- mean_precision <- lower <- upper <- NULL
  prc_plot <- dat %>%
    ggplot2::ggplot(ggplot2::aes(
      x = recall, y = mean_precision, ymin = lower, ymax = upper
    )) +
    shared_ggprotos(colorvar = method) +
    ggplot2::scale_x_continuous(expand = c(0, 0), limits = c(-0.01, 1.01)) +
    ggplot2::labs(x = "Recall", y = "Mean Precision")
  if (!is.null(baseline_precision)) {
    prc_plot <- prc_plot +
      ggplot2::geom_hline(
        yintercept = baseline_precision,
        linetype = "dashed", color = "grey50"
      )
  }
  return(prc_plot)
}

# combine the mean ROC and PRC plots side-by-side with patchwork
p <- (dat %>% calc_mean_roc(custom_group_vars = "method") %>% plot_mean_roc()) +
  (dat %>% calc_mean_prc(custom_group_vars = "method") %>% plot_mean_prc() +
    theme(legend.position = "none"))

ggsave(
  filename = snakemake@output[["plot"]],
  plot = p, device = "png", height = 4, width = 6
)
```
`scripts/preproc.R`:

```r
schtools::log_snakemake()
library(mikropml)

# parallelize preprocessing if multiple threads are available
doFuture::registerDoFuture()
future::plan(future::multicore, workers = snakemake@threads)

data_raw <- readr::read_csv(snakemake@input[["csv"]])
data_processed <- preprocess_data(data_raw,
  outcome_colname = snakemake@params[["outcome_colname"]]
)
saveRDS(data_processed, file = snakemake@output[["rds"]])
```
`scripts/report.Rmd` (code chunks):

```r
schtools::set_knitr_opts()
library(knitr)

include_graphics(snakemake@input[['rulegraph']])
include_graphics(snakemake@input[['perf_plot']])
include_graphics(snakemake@input[['roc_plot']])
include_graphics(snakemake@input[['hp_plot']])

# only include the feature importance section if it was computed
if (isTRUE(snakemake@params[['find_feature_importance']])) {
  cat("## Feature Importance")
}
include_graphics(snakemake@input[['feat_plot']])

include_graphics(snakemake@input[['bench_plot']])
```
`scripts/train_ml.R`:

```r
schtools::log_snakemake()
library(dplyr)

# parallelize model training if multiple threads are available
doFuture::registerDoFuture()
future::plan(future::multicore, workers = snakemake@threads)

method <- snakemake@params[["method"]]
seed <- as.numeric(snakemake@params[["seed"]])
hyperparams <- snakemake@params[["hyperparams"]][[method]]
data_processed <- readRDS(snakemake@input[["rds"]])$dat_transformed

# train & evaluate one model for this method/seed combination
ml_results <- mikropml::run_ml(
  dataset = data_processed,
  method = method,
  outcome_colname = snakemake@params[["outcome_colname"]],
  find_feature_importance = FALSE,
  kfold = as.numeric(snakemake@params[["kfold"]]),
  seed = seed,
  hyperparameters = hyperparams
)

# write the performance results, test data, and trained model
wildcards <- schtools::get_wildcards_tbl()
readr::write_csv(
  ml_results$performance %>% inner_join(wildcards, by = c("method", "seed")),
  snakemake@output[["perf"]]
)
readr::write_csv(ml_results$test_data, snakemake@output[["test"]])
saveRDS(ml_results$trained_model, file = snakemake@output[["model"]])
```
From the `Snakefile`:

```python
script: "scripts/report.Rmd"

shell:
    """
    zip -r {output} {input} 2> {log}
    """
```