StratoMod: Quantifying the Difficulty of Variant Calling in Genomic Context
A model-based tool to quantify the difficulty of calling a variant given genomic context.
Background
Intuitively we understand that accurately calling variants in a genome can be more or less difficult depending on the context of that variant. For example, many sequencing technologies have higher error rates in homopolymers, and this error rate generally increases as homopolymers get longer. However, precisely quantifying the relationship between these errors, the length of the homopolymer, and the impact on the resulting variant call remains challenging. Analogous arguments can be drawn for other "repetitive" regions in the genome, such as tandem repeats, segmental duplications, transposable elements, and difficult-to-map regions.
The solution we present here is to use an interpretable modeling framework called explainable boosting machines (EBMs) to predict variant-calling errors as a function of genomic features (e.g., whether or not the variant is in a tandem repeat, homopolymer, etc.). The interpretability of the model is important for allowing end users to understand the relationship each feature has to the prediction, which facilitates understanding, for example, at what homopolymer lengths the likelihood of incorrectly calling a variant increases drastically. This precision is an improvement over existing methods we have developed for stratifying the genome by difficulty into discrete bins. Furthermore, this modeling framework captures interactions between different genomic contexts, which is important as many repetitive characteristics do not exist in isolation.
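To illustrate the kind of model involved (a minimal, self-contained sketch using synthetic data, not StratoMod's actual training code), an EBM can be fit on per-variant context features and its learned shape functions inspected directly:

```python
# Minimal sketch: fit an explainable boosting machine on synthetic
# per-variant features and inspect the learned shape function for one feature.
import numpy as np
import pandas as pd
from interpret.glassbox import ExplainableBoostingClassifier

rng = np.random.default_rng(0)
n = 1000
X = pd.DataFrame(
    {
        "homopolymer_length": rng.integers(0, 30, n),
        "tandem_repeat_length": rng.integers(0, 200, n),
    }
)
# toy label: calls inside long homopolymers are more likely to be errors
y = (rng.random(n) > X["homopolymer_length"] / 40).astype(int)

ebm = ExplainableBoostingClassifier()
ebm.fit(X, y)

# the global explanation exposes each feature's additive score curve,
# which is the kind of profile StratoMod's reports plot
ebm_global = ebm.explain_global()
term = ebm_global.data(0)  # first feature (homopolymer_length)
print(term["names"][:5])   # bin edges
print(term["scores"][:5])  # log-odds contributions per bin
```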
We anticipate StratoMod will be useful both for method developers and for clinicians who wish to better understand variant-calling error modalities. In the case of method development, StratoMod can be used to accurately compare the error modalities of different sequencing technologies. For clinicians, it can be used to determine in which regions or genes (which may be clinically interesting for a given study) variant errors are likely to occur, which in turn may inform which technologies and/or other mitigation strategies should be employed.
Further information can be found in our preprint.
User Guide
Pipeline steps
- Compare the user-supplied query VCF with a GIAB benchmark VCF to produce labels (true positive, false positive, false negative). The labels comprise the dependent variable used in model training downstream.
- Intersect the comparison output labels with genomic features to produce the features (independent variables) used in model training.
- Train the EBM model with a random holdout for testing.
- If desired, test the model on other query VCFs (which may or may not also be labeled with a benchmark comparison).
- Inspect the output features (plots showing the profile of each feature and its effect on the label).
NOTE: currently only two labels can be compared at once since we use a binary classifier. This means either one of the three labels must be omitted, or two must be combined into one label.
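As an illustration (the column and label names here are hypothetical, not StratoMod's configuration), collapsing the three benchmark labels into a binary target could look like this:

```python
# Hypothetical illustration of reducing three benchmark labels to two classes.
import pandas as pd

calls = pd.DataFrame({"label": ["tp", "fp", "fn", "tp", "fn"]})

# option 1: omit one label entirely (e.g. keep only tp vs fp)
tp_vs_fp = calls[calls["label"].isin(["tp", "fp"])]

# option 2: combine two labels into one class (e.g. treat fp and fn as "error")
tp_vs_error = calls.assign(
    label=calls["label"].map({"tp": "tp", "fp": "error", "fn": "error"})
)
print(tp_vs_fp["label"].value_counts())
print(tp_vs_error["label"].value_counts())
```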
Data Inputs
The only mandatory user-supplied input is a query VCF. Optionally, one can supply other VCFs for testing the model.
Unless one is using esoteric references or benchmarks, the pipeline is preconfigured to retrieve commonly used data selected by flags in the configuration. This includes:
- a GIAB benchmark, including the VCF, BED, and reference FASTA
- reference-specific BED files which provide "contextual" features for each variant call (see the sketch after this list), including:
  - difficult-to-map regions (GIAB stratification BED file)
  - segmental duplications (UCSC superdups database)
  - tandem repeats (UCSC simple repeats database)
  - transposable elements (UCSC RepeatMasker)
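To give a sense of how such a BED file becomes per-variant features, each variant can be left-outer joined against the context intervals. The sketch below is illustrative only: the file names and columns are hypothetical, and this is not the pipeline's actual annotation code (which does something similar with pybedtools).

```python
# Illustrative only: annotate variant calls with tandem-repeat context via a
# left outer interval join.
from pybedtools import BedTool

variants = BedTool("variants.bed")              # chrom, start, end per variant call
tandem_repeats = BedTool("tandem_repeats.bed")  # chrom, start, end, period, copyNum, ...

# loj=True keeps every variant; variants outside any repeat get "."/-1 placeholders
annotated = variants.intersect(tandem_repeats, loj=True)
annotated.head()  # print the first few annotated records
```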
Installation
This assumes the user has a working `conda` or `mamba` installation.
Run the following to set up the runtime environment.
mamba env create -f env.yml
Configuration
A sample configuration file may be found in `config/dynamic-testing.yml` and may be copied as a starting point and modified to one's liking. This file is heavily annotated to explain all the options/flags and their purpose.
For a list of features which may be used, see `FEATURES.md`.
Running
Execute the pipeline using snakemake:
snakemake --use-conda -c <num_cores> --rerun-incomplete --configfile=config/<confname.yml> all
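For example, to run with the sample configuration mentioned above (the core count here is arbitrary):

```
snakemake --use-conda -c 8 --rerun-incomplete --configfile=config/dynamic-testing.yml all
```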
Output
Report
Each model has a report at `results/model/<model_key>-<filter_key>-<run_key>/summary.html` which contains model performance curves and feature plots (the latter allow model interpretation). Here `<model_key>` is a key under the `models` section in the config, `<filter_key>` is either `SNV` or `INDEL` depending on what was requested, and `<run_key>` is a key under the `models -> <model_key> -> runs` section in the config.
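As an illustration (the key names below are hypothetical and other required settings are elided), a config with the following structure would yield a report at `results/model/giab_hifi-SNV-default/summary.html` for the SNV filter:

```yaml
models:
  giab_hifi:        # <model_key>
    runs:
      default:      # <run_key>
        # ...run-specific settings...
```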
Train/test data
All raw data for the models will be saved alongside the model report (see above). This includes the input TSV used to train the EBM, a config YAML file recording all settings used to train the EBM (for reference), and Python pickles for the X/Y train/test datasets as well as a pickle of the final model itself.
Within the run directory there will also be a `test` directory containing all test runs (e.g., the results of the model test and the input data used for the test).
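These artifacts make it straightforward to reload a trained model outside the pipeline. The sketch below is hypothetical in its file and column names (check your run directory for the actual ones), but shows the general idea:

```python
# Sketch: reload a trained EBM and its held-out test split from a run directory.
# File and index-column names below are assumptions; inspect the run directory
# for the real ones.
import pickle
import pandas as pd

run_dir = "results/model/<model_key>-<filter_key>-<run_key>"

with open(f"{run_dir}/model.pickle", "rb") as f:
    ebm = pickle.load(f)

X_test = pd.read_table(f"{run_dir}/test_x.tsv")
# drop coordinate/index columns that are not model features before predicting
index_cols = ["chrom", "chromStart", "chromEnd", "variant_index"]
probs = ebm.predict_proba(X_test.drop(columns=index_cols, errors="ignore"))[:, 1]
print(probs[:10])
```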
Raw input data
In addition to the model data itself, the raw input data (that is, the master dataframe with all features for each query VCF prior to filtering/transformation) can be found in `results/annotated/{unlabeled,labeled}/<query_key>`, where `<query_key>` is a key under either `labeled_queries` or `unlabeled_queries` in the config.
Each of these directories contains the raw dataframe itself (for both SNVs and INDELs) as well as an HTML report summarizing the dataframe (statistics for each feature, distributions, correlations, etc.).
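If you prefer to explore the raw dataframe programmatically rather than through the HTML report, something like the following works (the exact file name inside the directory is an assumption; adjust to whatever your run produced):

```python
# Sketch: summarize one raw annotated dataframe with pandas.
import pandas as pd

df = pd.read_table("results/annotated/labeled/<query_key>/<snv_or_indel_file>.tsv")
print(df.describe())     # per-feature summary statistics
print(df.isna().mean())  # fraction of variants missing each feature
```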
Developer Guide
Environments
By convention, the conda environment specified by `env.yml` only has runtime dependencies for the pipeline itself.
To install development environments, run the following:
./setup_dev.sh
In addition to creating new environments, this script will update existing ones if they are changed during development.
Note that scripts in the pipeline are segregated by environment in order to prevent dependency hell while maintaining reproducible builds. When editing, one will need to switch between environments in the IDE in order to benefit from the features they provide. Further details on which environments correspond to which files can be found in `workflow/scripts`.
Note that `setup_dev.sh` will only install the environments necessary for running scripts (e.g., rules with a `script` directive).
Linting
All Python code should be error-free when finalizing any new features. Linting will be performed automatically as part of the CI/CD pipeline, but to run it manually, invoke the following:
./lint.sh
This assumes all development environments are installed (see above).
New Feature Workflow
There are two main development branches: `master` and `develop`. Make a new branch off of `develop` for the new feature, then merge it into `develop` when done (note the `--no-ff`).
git checkout develop
git branch <new_feature>
git checkout <new_feature>
# do a bunch of stuff...
git checkout develop
git merge --no-ff <new_feature>
After the feature(s) have been added and all tests have succeeded, update the changelog, add a tag, and merge into `master`. Use semantic versioning for tags.
# update changelog
vim CHANGELOG.md
git commit
git tag vX.Y.Z
git checkout master
git merge --no-ff vX.Y.Z
NOTE: do not add an experiment-specific configuration to `master` or `develop`. The YAML files in `config` on these branches are used for testing. See below for how to add an experiment.
Code Snippets
The workflow's rules either run shell commands directly:

- `curl -sS -L -o {output} {params.url}`
- `mkdir {output} && tar xzf {input} --directory {output} --strip-components=1`
- `make -C {input} > {log} && mv {input}/repseq {output}`
- `gunzip -c {input.ref} | {input.bin} 1 4 - 2> {log} | sed '/^#/d' | gzip -c > {output}`
- `samtools faidx {input} -o - 2> {log} | cut -f1,2 > {output}`
- `rtg format -o {output} {input} 2>&1 > {log}`
- `tabix -p vcf {input}`
- `bedtools subtract -a {input.bed} -b {input.mhc} | bgzip > {output}`
- `rm -rf {params.tmp_dir} && rtg RTG_MEM=$(({resources.mem_mb}*80/100))M vcfeval {params.extra} --threads={threads} -b {input.bench_vcf} -e {input.bench_bed} -c {input.query_vcf} -o {params.tmp_dir} -t {input.sdf} > {log} 2>&1 && mv {params.tmp_dir}/* {params.output_dir} && rm -r {params.tmp_dir}`

or dispatch to scripts via `script:` directives:

- `../../scripts/python/bio/get_homopoly_features.py`
- `../../scripts/python/bio/download_bed_or_vcf.py`
- `../../scripts/python/bio/get_mappability_features.py`
- `../../scripts/python/bio/get_repeat_masker_features.py`
- `../../scripts/python/bio/get_segdup_features.py`
- `../../scripts/python/bio/get_tandem_repeat_features.py`
- `../scripts/python/bio/download_ref.py`
- `../scripts/python/bio/filter_sort_ref.py`
- `../scripts/python/bio/write_bed.py`
- `../scripts/python/bio/download_bed_or_vcf.py`
- `../scripts/python/bio/standardize_vcf.py`
- `../scripts/python/bio/standardize_bed.py`
- `../scripts/python/bio/vcf_to_bed.py`
- `../scripts/python/bio/concat_tsv.py`
- `../scripts/python/bio/annotate_variants.py`
- `../scripts/rmarkdown/summary/input_summary.Rmd`
- `../scripts/python/bio/prepare_train.py`
- `../scripts/python/ebm/train_ebm.py`
- `../scripts/python/ebm/decompose_model.py`
- `../scripts/rmarkdown/summary/train_summary.Rmd`
- `../scripts/python/bio/prepare_test.py`
- `../scripts/python/ebm/test_ebm.py`
- `../scripts/rmarkdown/summary/test_summary.Rmd`
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 | from functools import reduce import pandas as pd from typing import Any, cast import numpy as np from common.tsv import write_tsv from pybedtools import BedTool as bt # type: ignore from common.io import setup_logging import common.config as cfg logger = setup_logging(snakemake.log[0]) # type: ignore def left_outer_intersect(left: pd.DataFrame, path: str) -> pd.DataFrame: logger.info("Adding annotations from %s", path) # Use bedtools to perform left-outer join of two bed/tsv files. Since # bedtools will join all columns from the two input files, keep track of the # width of the left input file so that the first three columns of the right # input (chr, chrStart, chrEnd, which are redundant) can be dropped. left_cols = left.columns.tolist() left_width = len(left_cols) right = pd.read_table(path) # ASSUME the first three columns are the bed index columns right_cols = ["_" + c if i < 3 else c for i, c in enumerate(right.columns.tolist())] right_bed = bt.from_dataframe(right) # prevent weird type errors when converted back to dataframe from bed dtypes = {right_cols[0]: str} # convert "." to NaN since "." is a string/object which will make pandas run # slower than an actual panda na_vals = {c: "." for c in left_cols + right_cols[3:]} new_df = cast( pd.DataFrame, bt.from_dataframe(left) .intersect(right_bed, loj=True) .to_dataframe(names=left_cols + right_cols, na_values=na_vals, dtype=dtypes), ) # Bedtools intersect will use -1 for NULL in the case of numeric columns. I # suppose this makes sense since any "real" bed columns (according to the # "spec") will always be positive integers or strings. Since -1 might be a # real value and not a missing one in my case, use the chr field to figure # out if a row is "missing" and fill NaNs accordingly new_cols = new_df.columns[left_width:] new_pky = new_cols[:3] new_chr = new_pky[0] new_data_cols = new_cols[3:] new_df.loc[:, new_data_cols] = new_df[new_data_cols].where( new_df[new_chr] != ".", np.nan ) logger.info("Annotations added: %s\n", ", ".join(new_data_cols.tolist())) return new_df.drop(columns=new_pky) def intersect_tsvs( config: cfg.StratoMod, ifile: str, ofile: str, tsv_paths: list[str], ) -> None: target_df = pd.read_table(ifile) new_df = reduce(left_outer_intersect, tsv_paths, target_df) new_df.insert(loc=0, column=cfg.VAR_IDX, value=new_df.index) write_tsv(ofile, new_df) def main(smk: Any, config: cfg.StratoMod) -> None: fs = smk.input.features vcf = smk.input.variants[0] logger.info("Adding annotations to %s\n", vcf) intersect_tsvs(config, vcf, smk.output[0], fs) main(snakemake, snakemake.config) # type: ignore |
1 2 3 4 5 6 7 | import pandas as pd from common.tsv import write_tsv from common.bed import sort_bed_numerically # use pandas here since it will more reliably account for headers df = pd.concat([pd.read_table(i, header=0) for i in snakemake.input]) # type: ignore write_tsv(snakemake.output[0], sort_bed_numerically(df, 3)) # type: ignore |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 | from pathlib import Path import subprocess as sp from typing import Callable from typing_extensions import assert_never from tempfile import NamedTemporaryFile as Tmp import common.config as cfg from common.io import get_md5, is_gzip, setup_logging # hacky curl/gzip wrapper; this exists because I got tired of writing # specialized rules to convert gzip/nozip files to bgzip and back :/ # Solution: force bgzip for references and gzip for bed log = setup_logging(snakemake.log[0]) # type: ignore GZIP = ["gzip", "-c"] CURL = ["curl", "-Ss", "-L", "-q"] def main(opath: Path, src: cfg.FileSrc | None) -> None: if isinstance(src, cfg.LocalSrc): # ASSUME these are already tested via the pydantic class for the # proper file format Path(opath).symlink_to(Path(src.filepath).resolve()) elif isinstance(src, cfg.HTTPSrc): curlcmd = [*CURL, src.url] # to test the format of downloaded files, sample the first 65000 bytes # (which should be enough to get one block of a bgzip file, which will # allow us to test for it) curltestcmd = [*CURL, "-r", "0-65000", src.url] with open(opath, "wb") as f, Tmp() as tf: def curl() -> None: sp.Popen(curlcmd, stdout=f).wait() def curl_test(testfun: Callable[[Path], bool]) -> bool: sp.Popen(curltestcmd, stdout=tf).wait() return testfun(Path(tf.name)) def curl_gzip(cmd: list[str]) -> None: p1 = sp.Popen(curlcmd, stdout=sp.PIPE) p2 = sp.Popen(cmd, stdin=p1.stdout, stdout=f) p2.wait() if curl_test(is_gzip): curl() else: curl_gzip(GZIP) elif src is None: assert False, "file src is null; this should not happen" else: assert_never(src) if src.md5 is not None and src.md5 != (actual := get_md5(opath)): log.error("md5s don't match; wanted %s, actual %s", src.md5, actual) exit(1) main(Path(snakemake.output[0]), snakemake.params.src) # type: ignore |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 | from pathlib import Path import subprocess as sp from typing import Callable, Any, cast from typing_extensions import assert_never from tempfile import NamedTemporaryFile as Tmp from common.io import is_gzip, setup_logging, get_md5, get_md5_dir from common.bed import is_bgzip import common.config as cfg GUNZIP = ["gunzip", "-c"] BGZIP = ["bgzip", "-c"] CURL = ["curl", "-Ss", "-L", "-q"] log = setup_logging(snakemake.log[0]) # type: ignore def main(smk: Any, params: Any) -> None: src = cast(cfg.FileSrc, params.src) opath = Path(smk.output[0]) is_fasta = smk.params.is_fasta if isinstance(src, cfg.LocalSrc): # ASSUME this is in the format we indicate (TODO be more paranoid) opath.symlink_to(Path(src.filepath).resolve()) elif isinstance(src, cfg.HTTPSrc): curlcmd = [*CURL, src.url] if is_fasta: # to test the format of downloaded files, sample the first 65000 bytes # (which should be enough to get one block of a bgzip file, which will # allow us to test for it) curltestcmd = [*CURL, "-r", "0-65000", src.url] with open(opath, "wb") as f, Tmp() as tf: def curl() -> None: sp.Popen(curlcmd, stdout=f).wait() def curl_test(testfun: Callable[[Path], bool]) -> bool: sp.Popen(curltestcmd, stdout=tf).wait() return testfun(Path(tf.name)) def curl_gzip(cmd: list[str]) -> None: p1 = sp.Popen(curlcmd, stdout=sp.PIPE) p2 = sp.Popen(cmd, stdin=p1.stdout, stdout=f) p2.wait() if curl_test(is_bgzip): curl() elif curl_test(is_gzip): p1 = sp.Popen(curlcmd, stdout=sp.PIPE) p2 = sp.Popen(GUNZIP, stdin=p1.stdout, stdout=sp.PIPE) p3 = sp.Popen(BGZIP, stdin=p2.stdout, stdout=f) p3.wait() else: curl_gzip(BGZIP) else: tarcmd = [ *["bsdtar", "-xf", "-"], *["--directory", str(opath)], "--strip-component=1", ] opath.mkdir(parents=True) p1 = sp.Popen(curlcmd, stdout=sp.PIPE) p2 = sp.Popen(tarcmd, stdin=p1.stdout) p2.wait() else: assert_never(src) if src.md5 is not None: actual = get_md5(opath) if is_fasta else get_md5_dir(opath) if actual != src.md5: log.error("md5s don't match; wanted %s, actual %s", src.md5, actual) exit(1) main(snakemake, snakemake.params) # type: ignore |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 | import re from typing import Any import subprocess as sp import common.config as cfg from Bio import bgzf # type: ignore from common.io import setup_logging logger = setup_logging(snakemake.log[0]) # type: ignore def stream_fasta(ipath: str, chr_names: list[str]) -> sp.Popen[bytes]: return sp.Popen( ["samtools", "faidx", ipath, *chr_names], stdout=sp.PIPE, stderr=sp.PIPE, ) def stream_sdf(ipath: str, chr_names: list[str]) -> sp.Popen[bytes]: return sp.Popen( [ *["rtg", "sdf2fasta", "--no-gzip", "--line-length=70"], *["--input", ipath], *["--output", "-"], *["--names", *chr_names], ], stdout=sp.PIPE, stderr=sp.PIPE, ) def main(smk: Any, sconf: cfg.StratoMod) -> None: rsk = cfg.RefsetKey(smk.wildcards["refset_key"]) cs = sconf.refsetkey_to_chr_indices(rsk) prefix = sconf.refsetkey_to_ref(rsk).sdf.chr_prefix chr_mapper = {c.chr_name_full(prefix): c.value for c in cs} chr_names = [*chr_mapper] # Read from a fasta or sdf depending on what we were given; in either # case, read only the chromosomes we want in sorted order and return a # fasta text stream def choose_input(i: Any) -> sp.Popen[bytes]: try: return stream_fasta(i.fasta[0], chr_names) except AttributeError: try: return stream_sdf(i.sdf[0], chr_names) except AttributeError: assert False, "unknown input key, this should not happen" p = choose_input(smk.input) if p.stdout is not None: # Stream the fasta and replace the chromosome names in the header with # its integer index with bgzf.open(smk.output[0], "w") as f: for i in p.stdout: if i.startswith(b">"): m = re.match(">([^ \n]+)", i.decode()) if m is None: logger.error("could get chrom name from FASTA header") exit(1) try: f.write(f">{chr_mapper[m[1]]}\n") except KeyError: assert False, ( "could not convert '%s' to index, this should not happen" % m[1] ) else: f.write(i) else: assert False, "stdout not a pipe, this should not happen" p.wait() if p.returncode != 0: logger.error(p.stderr) exit(1) main(snakemake, snakemake.config) # type: ignore |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 | from pathlib import Path import pandas as pd import common.config as cfg from typing import Any, cast from pybedtools import BedTool as bt # type: ignore from pybedtools import cleanup from common.tsv import write_tsv from common.io import setup_logging from common.bed import read_bed logger = setup_logging(snakemake.log[0]) # type: ignore # temporary columns used for dataframe processing BASE_COL = "_base" PFCT_LEN_COL = "_perfect_length" SLOP = 1 def read_input(path: Path) -> pd.DataFrame: logger.info("Reading dataframe from %s", path) return read_bed(path, more={3: BASE_COL}) def merge_base( config: cfg.StratoMod, df: pd.DataFrame, base: cfg.Base, genome: str, ) -> pd.DataFrame: logger.info("Filtering bed file for %ss", base) _df = df[df[BASE_COL] == f"unit={base.value}"].drop(columns=[BASE_COL]) logger.info("Merging %s rows for %ss", len(_df), base) # Calculate the length of each "pure" homopolymer (eg just "AAAAAAAA"). # Note that this is summed in the merge below, and the final length based # on start/end won't necessarily be this sum because of the -d 1 parameter _df[PFCT_LEN_COL] = _df[cfg.BED_END] - _df[cfg.BED_START] merged = cast( pd.DataFrame, bt.from_dataframe(_df) .merge(d=1, c=[4], o=["sum"]) .slop(b=SLOP, g=genome) .to_dataframe(names=[*cfg.BED_COLS, PFCT_LEN_COL]), ) # these files are huge; now that we have a dataframe, remove all the bed # files from tmpfs to prevent a run on downloadmoreram.com cleanup() hgroup = config.feature_definitions.homopolymers length_col = hgroup.fmt_name(base, lambda x: x.len) frac_col = hgroup.fmt_name(base, lambda x: x.imp_frac) merged[length_col] = merged[cfg.BED_END] - merged[cfg.BED_START] - SLOP * 2 merged[frac_col] = 1 - (merged[PFCT_LEN_COL] / merged[length_col]) return merged.drop(columns=[PFCT_LEN_COL]) def main(smk: Any, config: cfg.StratoMod) -> None: # ASSUME this file is already sorted simreps = read_input(smk.input["bed"][0]) merged = merge_base( config, simreps, cfg.Base(smk.wildcards["base"]), smk.input["genome"][0], ) write_tsv(smk.output[0], merged, header=True) main(snakemake, snakemake.config) # type: ignore |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 | from pathlib import Path import pandas as pd import common.config as cfg from typing import Any from pybedtools import BedTool as bt # type: ignore from common.tsv import write_tsv from common.bed import read_bed from common.io import setup_logging logger = setup_logging(snakemake.log[0]) # type: ignore def main(smk: Any, config: cfg.StratoMod) -> None: rsk = cfg.RefsetKey(smk.wildcards["refset_key"]) cs = config.refsetkey_to_chr_indices(rsk) mapconf = config.refsetkey_to_ref(rsk).feature_data.mappability mapmeta = config.feature_definitions.mappability def read_map_bed(p: Path, ps: cfg.BedFileParams, col: str) -> pd.DataFrame: logger.info("Reading mappability feature: %s", col) df = read_bed(p, ps, {}, cs) df[col] = 1 return df high = read_map_bed(smk.input["high"][0], mapconf.high.params, mapmeta.high) low = read_map_bed(smk.input["low"][0], mapconf.low.params, mapmeta.low) # subtract high from low (since the former is a subset of the latter) new_low = ( bt.from_dataframe(low) .subtract(bt.from_dataframe(high)) .to_dataframe(names=low.columns.tolist()) ) write_tsv(smk.output["high"], high) write_tsv(smk.output["low"], new_low) main(snakemake, snakemake.config) # type: ignore |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 | import pandas as pd from pathlib import Path from typing import Optional, Any import common.config as cfg from pybedtools import BedTool as bt # type: ignore from common.tsv import write_tsv from common.io import setup_logging from common.bed import read_bed # The repeat masker database is documented here: # https://genome.ucsc.edu/cgi-bin/hgTables?db=hg38&hgta_group=rep&hgta_track=rmsk&hgta_table=rmsk&hgta_doSchema=describe+table+schema logger = setup_logging(snakemake.log[0]) # type: ignore # both of these columns are temporary and used to make processing easier CLASSCOL = "_repClass" FAMCOL = "_repFamily" def main(smk: Any, config: cfg.StratoMod) -> None: rsk = cfg.RefsetKey(smk.wildcards["refset_key"]) rk = config.refsetkey_to_refkey(rsk) src = config.references[rk].feature_data.repeat_masker cs = config.refsetkey_to_chr_indices(rsk) def read_rmsk_df(path: Path) -> pd.DataFrame: cols = {11: CLASSCOL, 12: FAMCOL} return read_bed(path, src.params, cols, cs) def merge_and_write_group( df: pd.DataFrame, groupcol: str, clsname: str, famname: Optional[str] = None, ) -> None: groupname = clsname if famname is None else famname dropped = df[df[groupcol] == groupname].drop(columns=[groupcol]) merged = bt.from_dataframe(dropped).merge().to_dataframe(names=cfg.BED_COLS) col = config.feature_definitions.repeat_masker.fmt_name(src, clsname, famname) merged[col] = merged[cfg.BED_END] - merged[cfg.BED_START] write_tsv(smk.output[0], merged, header=True) cls = smk.wildcards.rmsk_class df = read_rmsk_df(smk.input[0]) try: fam = smk.wildcards.rmsk_family logger.info("Filtering/merging rmsk family %s/class %s", fam, cls) merge_and_write_group(df, FAMCOL, cls, fam) except AttributeError: logger.info("Filtering/merging rmsk class %s", cls) merge_and_write_group(df, CLASSCOL, cls) main(snakemake, snakemake.config) # type: ignore |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 | import pandas as pd from pathlib import Path from typing import Any, cast import common.config as cfg from common.tsv import write_tsv from common.bed import read_bed, merge_and_apply_stats from common.io import setup_logging # This database is documented here: # http://genome.ucsc.edu/cgi-bin/hgTables?hgta_doSchemaDb=hg38&hgta_doSchemaTable=genomicSuperDups # ASSUME segdups dataframe is fed into this script with the chromosome column # standardized. The column numbers below are dictionary values, and the # corresponding feature names are the dictionary keys. Note that many feature # names don't match the original column names in the database. logger = setup_logging(snakemake.log[0]) # type: ignore def read_segdups( smk: Any, config: cfg.StratoMod, path: Path, fconf: cfg.SegDupsGroup, ) -> pd.DataFrame: rsk = cfg.RefsetKey(smk.wildcards["refset_key"]) rk = config.refsetkey_to_refkey(rsk) s = config.references[rk].feature_data.segdups ocs = s.other_cols feature_cols = { ocs.align_L: str(fconf.fmt_col(lambda x: x.alignL)[0]), ocs.frac_match_indel: str(fconf.fmt_col(lambda x: x.fracMatchIndel)[0]), } cs = config.refsetkey_to_chr_indices(rsk) return read_bed(path, s.params, feature_cols, cs) def merge_segdups( df: pd.DataFrame, fconf: cfg.SegDupsGroup, ) -> pd.DataFrame: bed, names = merge_and_apply_stats(fconf, df) return cast(pd.DataFrame, bed.to_dataframe(names=names)) def main(smk: Any, config: cfg.StratoMod) -> None: fconf = config.feature_definitions.segdups repeat_df = read_segdups(smk, config, smk.input[0], fconf) merged_df = merge_segdups(repeat_df, fconf) write_tsv(smk.output[0], merged_df, header=True) # TODO make a stub so I don't need to keep repeating this main(snakemake, snakemake.config) # type: ignore |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 | from pathlib import Path import pandas as pd from typing import Any, cast import common.config as cfg from common.tsv import write_tsv from common.bed import read_bed, merge_and_apply_stats from common.io import setup_logging # Input dataframe documented here: # https://genome.ucsc.edu/cgi-bin/hgTables?db=hg38&hgta_group=rep&hgta_track=simpleRepeat&hgta_table=simpleRepeat&hgta_doSchema=describe+table+schema # # ASSUME this dataframe is fed into this script as-is. The column numbers below # are dictionary values, and the corresponding feature names are the dictionary # keys. Note that many feature names don't match the original column names in # the database. logger = setup_logging(snakemake.log[0]) # type: ignore SLOP = 5 def read_tandem_repeats( smk: Any, path: Path, fconf: cfg.TandemRepeatGroup, sconf: cfg.StratoMod, ) -> pd.DataFrame: rsk = cfg.RefsetKey(smk.wildcards["refset_key"]) rk = sconf.refsetkey_to_refkey(rsk) ss = sconf.references[rk].feature_data.tandem_repeats ocs = ss.other_cols fmt_col = fconf.fmt_col perc_a_col = str(fconf.A[0]) perc_t_col = str(fconf.T[0]) perc_c_col = str(fconf.C[0]) perc_g_col = str(fconf.G[0]) unit_size_col = fmt_col(lambda x: x.period)[0] feature_cols = { ocs.period: unit_size_col, ocs.copy_num: fmt_col(lambda x: x.copyNum)[0], ocs.per_match: fmt_col(lambda x: x.perMatch)[0], ocs.per_indel: fmt_col(lambda x: x.perIndel)[0], ocs.score: fmt_col(lambda x: x.score)[0], ocs.per_A: perc_a_col, ocs.per_C: perc_c_col, ocs.per_G: perc_g_col, ocs.per_T: perc_t_col, } cs = sconf.refsetkey_to_chr_indices(rsk) df = read_bed(path, ss.params, feature_cols, cs) base_groups = [ (fconf.AT[0], perc_a_col, perc_t_col), (fconf.AG[0], perc_a_col, perc_g_col), (fconf.CT[0], perc_c_col, perc_t_col), (fconf.GC[0], perc_c_col, perc_g_col), ] for double, single1, single2 in base_groups: df[double] = df[single1] + df[single2] # Filter out all TRs that have period == 1, since those by definition are # homopolymers. NOTE, there is a difference between period and consensusSize # in this database; however, it turns out that at least for GRCh38 that the # sets of TRs where either == 1 are identical, so just use period here # since I can easily refer to it. logger.info("Removing TRs with unitsize == 1") return df[df[unit_size_col] > 1] def merge_tandem_repeats( gfile: str, df: pd.DataFrame, fconf: cfg.TandemRepeatGroup, ) -> pd.DataFrame: bed, names = merge_and_apply_stats(fconf, df) merged_df = cast(pd.DataFrame, bed.slop(b=SLOP, g=gfile).to_dataframe(names=names)) len_col = fconf.length[0] merged_df[len_col] = merged_df[cfg.BED_END] - merged_df[cfg.BED_START] - SLOP * 2 return merged_df def main(smk: Any, sconf: cfg.StratoMod) -> None: i = smk.input fconf = sconf.feature_definitions.tandem_repeats repeat_df = read_tandem_repeats(smk, Path(i.src[0]), fconf, sconf) merged_df = merge_tandem_repeats(i.genome[0], repeat_df, fconf) write_tsv(smk.output[0], merged_df, header=True) main(snakemake, snakemake.config) # type: ignore |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 | import pandas as pd from typing import Any from common.tsv import write_tsv from common.io import setup_logging import common.config as cfg from common.prepare import process_labeled_data, process_unlabeled_data logger = setup_logging(snakemake.log[0]) # type: ignore def write_labeled( xpath: str, ypath: str, sconf: cfg.StratoMod, rconf: cfg.Model, df: pd.DataFrame, ) -> None: filter_col = sconf.feature_definitions.vcf.filter label_col = sconf.feature_definitions.label_name processed = process_labeled_data( rconf.features, rconf.error_labels, rconf.filtered_are_candidates, [cfg.FeatureKey(c) for c in cfg.IDX_COLS], filter_col, cfg.FeatureKey(label_col), df, ) write_tsv(xpath, processed.drop([label_col], axis=1)) write_tsv(ypath, processed[label_col].to_frame()) def write_unlabeled( xpath: str, sconf: cfg.StratoMod, rconf: cfg.Model, df: pd.DataFrame, ) -> None: processed = process_unlabeled_data( rconf.features, [cfg.FeatureKey(c) for c in cfg.IDX_COLS], df, ) write_tsv(xpath, processed) def main(smk: Any, sconf: cfg.StratoMod) -> None: sin = smk.input sout = smk.output wcs = smk.wildcards variables = sconf.testkey_to_variables( cfg.ModelKey(wcs["model_key"]), cfg.TestKey(wcs["test_key"]), ) df = pd.read_table(sin["annotated"][0]).assign( **{str(k): v for k, v in variables.items()} ) rconf = sconf.models[cfg.ModelKey(wcs.model_key)] if "test_y" in dict(sout): write_labeled( sout["test_x"], sout["test_y"], sconf, rconf, df, ) else: write_unlabeled(sout["test_x"], sconf, rconf, df) main(snakemake, snakemake.config) # type: ignore |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 | import pandas as pd import common.config as cfg from typing import Any from common.tsv import write_tsv from common.io import setup_logging from common.prepare import process_labeled_data logger = setup_logging(snakemake.log[0]) # type: ignore def read_query( config: cfg.StratoMod, path: str, key: cfg.LabeledQueryKey ) -> pd.DataFrame: variables = config.querykey_to_variables(key) return pd.read_table(path).assign(**{str(k): v for k, v in variables.items()}) def read_queries( config: cfg.StratoMod, paths: dict[cfg.LabeledQueryKey, str], ) -> pd.DataFrame: # TODO this is weird, why do I need the [0] here? return pd.concat([read_query(config, path[0], key) for key, path in paths.items()]) def main(smk: Any, sconf: cfg.StratoMod) -> None: rconf = sconf.models[cfg.ModelKey(cfg.ModelKey(smk.wildcards.model_key))] fconf = sconf.feature_definitions raw_df = read_queries(sconf, smk.input) processed = process_labeled_data( rconf.features, rconf.error_labels, rconf.filtered_are_candidates, [cfg.FeatureKey(c) for c in cfg.IDX_COLS], cfg.FeatureKey(fconf.vcf.filter), cfg.FeatureKey(fconf.label_name), raw_df, ) write_tsv(smk.output["df"], processed) main(snakemake, snakemake.config) # type: ignore |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 | from typing import Any, TextIO from common.config import StratoMod, RefsetKey from common.io import with_gzip_maybe def filter_file(smk: Any, config: StratoMod, fi: TextIO, fo: TextIO) -> None: rsk = RefsetKey(smk.wildcards["refset_key"]) chr_prefix = smk.params.chr_prefix cs = config.refsetkey_to_chr_indices(rsk) chr_mapper = {c.chr_name_full(chr_prefix): c.value for c in cs} for ln in fi: if ln.startswith("#"): fo.write(ln) else: ls = ln.rstrip().split("\t") try: ls[0] = str(chr_mapper[ls[0]]) fo.write("\t".join(ls) + "\n") except KeyError: pass def main(smk: Any, config: StratoMod) -> None: with_gzip_maybe( lambda i, o: filter_file(smk, config, i, o), str(smk.input[0]), str(smk.output[0]), ) main(snakemake, snakemake.config) # type: ignore |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 | from typing import Any, cast, TextIO from common.config import StratoMod, RefsetKey, VCFFile from common.bed import with_bgzip_maybe def fix_DV_refcall(filter_col: str, sample_col: str) -> str: return ( sample_col.replace("./.", "0/1").replace("0/0", "0/1") if filter_col == "RefCall" else sample_col ) def strip_format_fields( fields: set[str], format_col: str, sample_col: str, ) -> tuple[str, str]: f, s = zip( *[ (f, s) for f, s in zip(format_col.split(":"), sample_col.split(":")) if f not in fields ] ) return (":".join(f), ":".join(s)) def filter_file(smk: Any, config: StratoMod, fi: TextIO, fo: TextIO) -> None: rsk = RefsetKey(smk.wildcards["refset_key"]) vcf = cast(VCFFile, smk.params.vcf) chr_prefix = vcf.chr_prefix cs = config.refsetkey_to_chr_indices(rsk) chr_mapper = {c.chr_name_full(chr_prefix): c.value for c in cs} for ln in fi: if ln.startswith("#"): fo.write(ln) else: ls = ln.rstrip().split("\t")[:10] # CHROM = 0 # POS = 1 # ID = 2 # REF = 3 # ALT = 4 # QUAL = 5 # FILTER = 6 # INFO = 7 # FORMAT = 8 # SAMPLE = 9 try: ls[0] = str(chr_mapper[ls[0]]) if vcf.corrections.fix_refcall_gt: ls[9] = fix_DV_refcall(ls[6], ls[9]) if len(vcf.corrections.strip_format_fields) > 0: ls[8], ls[9] = strip_format_fields( vcf.corrections.strip_format_fields, ls[8], ls[9], ) fo.write("\t".join(ls) + "\n") except KeyError: pass def main(smk: Any, config: StratoMod) -> None: with_bgzip_maybe( lambda i, o: filter_file(smk, config, i, o), str(smk.input[0]), str(smk.output[0]), ) main(snakemake, snakemake.config) # type: ignore |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 | from typing import Any, TextIO import common.config as cfg from common.io import with_gzip_maybe, setup_logging logger = setup_logging(snakemake.log[0]) # type: ignore def is_real(s: str) -> bool: return s.removeprefix("-").replace(".", "", 1).isdigit() def dot_to_blank(s: str) -> str: return "" if s == "." else s def none_to_blank(s: str | None) -> str: return "" if s is None else s def write_row( fo: TextIO, chrom: str, start: str, end: str, qual: str, filt: str, info: str, indel_length: str, parse_fields: list[str], const_fields: list[str], label: str | None, ) -> None: const_cols = [chrom, start, end, qual, info, filt, indel_length] label_col = [] if label is None else [label] cols = [*const_cols, *parse_fields, *const_fields, *label_col] fo.write("\t".join(cols) + "\n") def lookup_field(f: cfg.FormatField, d: dict[str, str]) -> str: try: v = d[f.name] if len(f.mapper) == 0: return v if is_real(v) else "" try: return str(f.mapper[v]) except KeyError: return "" except KeyError: return none_to_blank(f.missing) def line_to_bed_row( fo: TextIO, ls: list[str], vcf: cfg.UnlabeledVCFQuery, vtk: cfg.VartypeKey, parse_fields: list[cfg.FormatField], const_field_values: list[str], label: str | None, ) -> bool: # CHROM = 0 # POS = 1 # ID = 2 # REF = 3 # ALT = 4 # QUAL = 5 # FILTER = 6 # INFO = 7 # FORMAT = 8 # SAMPLE = 9 chrom = int(ls[0]) start = int(ls[1]) - 1 # bed's are 0-indexed and vcf's are 1-indexed # remove cases where ref and alt are equal (which is what "." means) if ls[4] == "." 
or ls[3] == ls[4]: logger.info("Skipping equal variant at %s, %s", chrom, start) return False # remove multiallelics if "," in ls[4]: logger.info("Skipping multiallelic variant at %s, %s", chrom, start) return False # remove anything that doesn't pass out length filters ref_len = len(ls[3]) alt_len = len(ls[4]) if len(ls[3]) > vcf.max_ref or len(ls[4]) > vcf.max_alt: logger.info("Skipping oversized variant at %s, %s", chrom, start) return False # keep only the variant type we care about is_snv = ref_len == alt_len == 1 if is_snv and vtk is cfg.VartypeKey.SNV: indel_length = 0 elif not is_snv and ref_len != alt_len and vtk is cfg.VartypeKey.INDEL: indel_length = alt_len - ref_len else: return False # parse the format/sample columns if desired if len(parse_fields) > 0: fmt_col = ls[8].split(":") smpl_col = ls[9].split(":") # ASSUME any FORMAT/SAMPLE columns with different lengths are screwed # up in some way if len(fmt_col) != len(smpl_col): logger.error("FORMAT/SAMPLE have different lengths at %s, %s", chrom, start) return True d = dict(zip(fmt_col, smpl_col)) parsed_field_values = [lookup_field(f, d) for f in parse_fields] else: parsed_field_values = [] write_row( fo, str(chrom), str(start), str(start + ref_len), dot_to_blank(ls[5]), dot_to_blank(ls[6]), dot_to_blank(ls[7]), str(indel_length), parsed_field_values, list(const_field_values), label, ) return False def parse(smk: Any, sconf: cfg.StratoMod, fi: TextIO, fo: TextIO) -> None: defs = sconf.feature_definitions vcf = sconf.querykey_to_vcf(cfg.LabeledQueryKey(smk.params.query_key)) vtk = cfg.VartypeKey(smk.wildcards.vartype_key) found_error = False try: label = str(smk.wildcards.label) except AttributeError: label = None fields = [(str(defs.vcf.fmt_feature(k)), v) for k, v in vcf.format_fields.items()] parse_fields = [(k, v) for k, v in fields if isinstance(v, cfg.FormatField)] const_fields = [ (k, none_to_blank(v)) for k, v in fields if not isinstance(v, cfg.FormatField) ] # write header write_row( fo, cfg.BED_CHROM, cfg.BED_START, cfg.BED_END, defs.vcf.qual[0], defs.vcf.filter, defs.vcf.info, defs.vcf.indel_length[0], [f[0] for f in parse_fields], [f[0] for f in const_fields], None if label is None else defs.label_name, ) for ln in fi: if ln.startswith("#"): continue err = line_to_bed_row( fo, ln.rstrip().split("\t"), vcf, vtk, [f[1] for f in parse_fields], [f[1] for f in const_fields], label, ) found_error = err or found_error if found_error is True: exit(1) def main(smk: Any, config: cfg.StratoMod) -> None: with_gzip_maybe( lambda i, o: parse(smk, config, i, o), str(smk.input[0]), str(smk.output[0]), ) main(snakemake, snakemake.config) # type: ignore |
1 2 3 4 5 6 7 8 9 10 11 | import common.config as cfg from Bio import bgzf # type: ignore def main(opath: str, regions: list[cfg.BedRegion]) -> None: with bgzf.open(opath, "w") as f: for r in sorted(regions): f.write(r.fmt() + "\n") main(snakemake.output[0], snakemake.params.regions) # type: ignore |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 | import json import pandas as pd import numpy as np from numpy.typing import NDArray from typing import Any, Hashable, cast, TypedDict from common.io import setup_logging from common.ebm import read_model from common.tsv import write_tsv import common.config as cfg from interpret.glassbox import ExplainableBoostingClassifier # type: ignore from enum import Enum setup_logging(snakemake.log[0]) # type: ignore IndexedVectors = dict[int, NDArray[np.float64]] NamedVectors = dict[str, NDArray[np.float64]] EBMUniData = TypedDict( "EBMUniData", {"type": str, "names": list[float | str], "scores": NDArray[np.float64]}, ) class VarType(Enum): INT = "interaction" CNT = "continuous" CAT = "categorical" AllFeatures = dict[str, tuple[VarType, int]] Variable = TypedDict("Variable", {"name": str, "type": str}) BivariateData = TypedDict( "BivariateData", { "left": Variable, "right": Variable, "df": dict[Hashable, float], }, ) GlobalScoreData = TypedDict( "GlobalScoreData", { "variable": list[str], "score": list[float], }, ) UnivariateDF = TypedDict( "UnivariateDF", { "value": list[str | float], "score": list[float], "stdev": list[float], }, ) UnivariateData = TypedDict( "UnivariateData", { "name": str, "vartype": str, "df": UnivariateDF, }, ) ModelData = TypedDict( "ModelData", { "global_scores": GlobalScoreData, "intercept": float, "univariate": list[UnivariateData], "bivariate": list[BivariateData], }, ) # TODO there is no reason this can't be done immediately after training # just to avoid the pickle thing def array_to_list(arr: NDArray[np.float64], repeat_last: bool) -> list[float]: # cast needed since this can return a nested list depending on number of dims al = cast(list[float], arr.tolist()) return al + [al[-1]] if repeat_last else al def get_univariate_df( continuous: bool, feature_data: EBMUniData, stdev: NDArray[np.float64], ) -> UnivariateDF: def proc_scores(scores: NDArray[np.float64]) -> list[float]: return array_to_list(scores, continuous) return UnivariateDF( value=feature_data["names"], score=proc_scores(feature_data["scores"]), # For some reason, the standard deviations array has an extra 0 in # in the front and thus is one longer than the scores array. 
stdev=proc_scores(stdev[1:]), ) def build_scores_array( arr: NDArray[np.float64], left_type: VarType, right_type: VarType, ) -> NDArray[np.float64]: # any continuous dimension is going to be one less than the names length, # so copy the last row/column to the end in these cases if left_type == VarType.CNT: arr = np.vstack((arr, arr[-1, :])) if right_type == VarType.CNT: arr = np.column_stack((arr, arr[:, -1])) return arr def get_bivariate_df( all_features: AllFeatures, ebm_global: ExplainableBoostingClassifier, name: str, data_index: int, stdevs: IndexedVectors, ) -> BivariateData: def lookup_feature_type(name: str) -> VarType: return all_features[name][0] feature_data = ebm_global.data(data_index) # left is first dimension, right is second left_name, right_name = tuple(name.split(" x ")) left_type = lookup_feature_type(left_name) right_type = lookup_feature_type(right_name) left_index = pd.Index(feature_data["left_names"], name="left_value") right_index = pd.Index(feature_data["right_names"], name="right_value") def stack_array(arr: NDArray[np.float64], name: str) -> "pd.Series[float]": return cast( "pd.Series[float]", pd.DataFrame( build_scores_array(arr, left_type, right_type), index=left_index, columns=right_index, ).stack(), ).rename(name) # the standard deviations are in an array that has 1 larger shape than the # scores array in both directions where the first row/column is all zeros. # Not sure why it is all zeros, but in order to make it line up with the # scores array we need to shave off the first row/column. return BivariateData( left=Variable(name=left_name, type=left_type.value), right=Variable(name=right_name, type=right_type.value), df=pd.concat( [ stack_array(feature_data["scores"], "score"), stack_array(stdevs[data_index][1:, 1:], "stdev"), ], axis=1, ) .reset_index() .to_dict(orient="list"), ) def get_global_scores(ebm_global: ExplainableBoostingClassifier) -> GlobalScoreData: glob = ebm_global.data() return GlobalScoreData(variable=glob["names"], score=glob["scores"]) def get_univariate_list( ebm_global: ExplainableBoostingClassifier, all_features: AllFeatures, stdevs: IndexedVectors, ) -> list[UnivariateData]: return [ UnivariateData( name=name, vartype=vartype.value, df=get_univariate_df( vartype == VarType.CNT, ebm_global.data(i), stdevs[i], ), ) for name, (vartype, i) in all_features.items() if vartype in [VarType.CNT, VarType.CAT] ] def get_bivariate_list( ebm_global: ExplainableBoostingClassifier, all_features: AllFeatures, stdevs: IndexedVectors, ) -> list[BivariateData]: return [ get_bivariate_df(all_features, ebm_global, name, i, stdevs) for name, (vartype, i) in all_features.items() if vartype == VarType.INT ] def get_model(ebm: ExplainableBoostingClassifier) -> ModelData: ebm_global = ebm.explain_global() stdevs = cast(IndexedVectors, ebm.term_standard_deviations_) all_features = { cast(str, n): (VarType(t), i) for i, (n, t) in enumerate( map(tuple, ebm_global.selector[["Name", "Type"]].to_numpy()) ) } return ModelData( global_scores=get_global_scores(ebm_global), intercept=ebm.intercept_[0], univariate=get_univariate_list(ebm_global, all_features, stdevs), bivariate=get_bivariate_list(ebm_global, all_features, stdevs), ) def write_model_json(path: str, ebm: ExplainableBoostingClassifier) -> None: with open(path, "w") as f: json.dump(get_model(ebm), f) def main(smk: Any, sconf: cfg.StratoMod) -> None: sin = smk.input sout = smk.output ebm = read_model(sin["model"]) write_model_json(sout["model"], ebm) label = sconf.feature_definitions.label def 
write_predictions(xpath: str, ypath: str, out_path: str) -> None: X = pd.read_table(xpath).drop(columns=cfg.IDX_COLS) y = pd.read_table(ypath) y_pred = pd.DataFrame( { "prob": ebm.predict_proba(X)[::, 1], "label": y[label], } ) write_tsv(out_path, y_pred) write_predictions(sin["train_x"], sin["train_y"], sout["train_predictions"]) write_predictions(sin["test_x"], sin["test_y"], sout["predictions"]) main(snakemake, snakemake.config) # type: ignore |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 | import pandas as pd from typing import Any from common.io import setup_logging from common.tsv import write_tsv from common.ebm import read_model import common.config as cfg from interpret.glassbox import ExplainableBoostingClassifier # type: ignore setup_logging(snakemake.log[0]) # type: ignore def _write_tsv(path: str, df: pd.DataFrame) -> None: write_tsv(path, df, header=True) def predict_from_x( ebm: ExplainableBoostingClassifier, df: pd.DataFrame, ) -> tuple[pd.DataFrame, pd.DataFrame]: probs, explanations = ebm.predict_and_contrib(df) return pd.DataFrame(probs), pd.DataFrame(explanations, columns=ebm.feature_names) def main(smk: Any, sconf: cfg.StratoMod) -> None: sin = smk.input sout = smk.output ebm = read_model(sin["model"]) predict_x = pd.read_table(sin["test_x"]).drop(columns=cfg.IDX_COLS) ps, xs = predict_from_x(ebm, predict_x) _write_tsv(sout["predictions"], ps) _write_tsv(sout["explanations"], xs) main(snakemake, snakemake.config) # type: ignore |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 | import pandas as pd import yaml from typing import Any from more_itertools import flatten from sklearn.model_selection import train_test_split # type: ignore from interpret.glassbox import ExplainableBoostingClassifier # type: ignore from common.tsv import write_tsv from common.io import setup_logging from common.ebm import write_model import common.config as cfg logger = setup_logging(snakemake.log[0]) # type: ignore def _write_tsv(smk: Any, key: str, df: pd.DataFrame) -> None: write_tsv(smk.output[key], df, header=True) def dump_config(smk: Any, config: cfg.Model) -> None: with open(smk.output["config"], "w") as f: yaml.dump(config, f) def get_interactions( df_columns: list[cfg.FeatureKey], iconfig: int | cfg.InteractionSpec, ) -> int | list[list[int]]: def expand_interactions(i: cfg.InteractionSpec_) -> list[list[int]]: if isinstance(i, str): return [ [df_columns.index(i), c] for c, f in enumerate(df_columns) if f != i ] else: return [[df_columns.index(i.f1), df_columns.index(i.f2)]] if isinstance(iconfig, int): return iconfig else: return [*flatten(expand_interactions(i) for i in iconfig)] def train_ebm( smk: Any, sconf: cfg.StratoMod, rconf: cfg.Model, df: pd.DataFrame, ) -> None: label = sconf.feature_definitions.label def strip_coords(df: pd.DataFrame) -> pd.DataFrame: return df.drop(columns=cfg.IDX_COLS) features = rconf.features feature_names = [ k if v.alt_name is None else v.alt_name for k, v in features.items() ] misc_params = rconf.ebm_settings.misc_parameters if misc_params.downsample is not None: df = df.sample(frac=misc_params.downsample) train_cols = [c for c in df.columns if c != label] X = df[train_cols] y = df[label] X_train, X_test, y_train, y_test = train_test_split( X, y, **rconf.ebm_settings.split_parameters.dict(), ) cores = smk.threads logger.info( "Training EBM with %d features and %d cores", len(features), cores, ) ebm = ExplainableBoostingClassifier( # NOTE the EBM docs show them explicitly adding interactions here like # 'F1 x F2' but it appears to work when I specify them separately via # the 'interactions' parameter feature_names=feature_names, feature_types=[f.feature_type.value for f in features.values()], interactions=get_interactions(feature_names, rconf.interactions), n_jobs=cores, **rconf.ebm_settings.classifier_parameters.mapping, ) ebm.fit(strip_coords(X_train), y_train) write_model(smk.output["model"], ebm) _write_tsv(smk, "train_x", X_train) _write_tsv(smk, "train_y", y_train) _write_tsv(smk, "test_x", X_test) _write_tsv(smk, "test_y", y_test) def main(smk: Any, sconf: cfg.StratoMod) -> None: rconf = sconf.models[smk.wildcards.model_key] df = pd.read_table(smk.input[0]) train_ebm(smk, sconf, rconf, df) dump_config(smk, rconf) main(snakemake, snakemake.config) # type: ignore |
7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 | library(tidyverse) library(infotheo) # from blablabla import nukes `:=` = rlang::`:=` `!!` = rlang::`!!` root = snakemake@params$lib_path source(file.path(root, "colocalization.r")) source(file.path(root, "plots.r")) format_perc <- function(x) { sprintf("%.4f", x) } format_exp <- function(x) { sprintf("%.1e", x) } make_stats_table <- function(df) { N <- nrow(df) gather(df, factor_key = TRUE) %>% group_by(key) %>% summarize(n_present = sum(!is.na(value)), perc_present = 100 * n_present / N, # prevent NULL error for zero length vectors in min/max min = ifelse(n_present == 0, NA, min(value, na.rm = TRUE)), max = ifelse(n_present == 0, NA, max(value, na.rm = TRUE)), med = median(value, na.rm = TRUE), mean = mean(value, na.rm = TRUE), stdev = sd(value, na.rm = TRUE), range = max - min) %>% rename(feature = key) %>% mutate(perc_present = format_perc(perc_present)) %>% mutate(across(c(min, max, med, mean, stdev, range), format_exp)) %>% arrange(desc(as.numeric(perc_present))) %>% knitr::kable(align = "r") } ## TODO wet..... make_feature_distribution <- function(x, labels) { infer_transform(x) %>% mutate(label = labels) %>% gather(-label, key = key, value = value) %>% ggplot() + aes(value, color = label) + geom_density() + xlab(NULL) + ylab(NULL) + facet_wrap(~key, scales = "free") } make_unlabeled_feature_distribution <- function(x) { infer_transform(x) %>% gather(key = key, value = value) %>% ggplot() + aes(value) + geom_density() + xlab(NULL) + ylab(NULL) + facet_wrap(~key, scales = "free") } label_summary_table <- function(y) { tibble(label = y) %>% group_by(label) %>% summarize(n = n(), proportion = format_perc(n / N)) %>% knitr::kable() } columns <- snakemake@params[["columns"]] query_key <- snakemake@params[["query_key"]] label_col <- snakemake@params[["label_col"]] path <- snakemake@input[[1]] has_label <- !is.null(label_col) all_columns <- if (is.null(label_col)) { columns } else { c(columns, label_col) } x_col_types <- rep("-", length(all_columns)) %>% as.list() %>% setNames(all_columns) %>% c(list(".default"="d")) %>% do.call(cols, .) df_x <- readr::read_tsv(path, col_types = x_col_types) features <- names(df_x) N <- nrow(df_x) |
99 100 101 102 103 104 | df_y <- readr::read_tsv(path, col_types = cols( !!label_col := "c", .default = "-")) %>% pull(!!label_col) y_labels <- unique(df_y) |
117 | label_summary_table(df_y) |
127 128 | df_x %>% make_stats_table() |
134 135 136 137 138 139 140 141 142 143 144 145 | print_label_tbl <- function(label) { cat(sprintf("## %s Label\n\n", label)) df_x %>% filter(df_y == label) %>% make_stats_table() %>% print() cat("\n\n") } walk(as.list(y_labels), print_label_tbl) |
162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 | information <- function(df, var1, var2) { .df <- df %>% select(all_of(c(var1, var2))) %>% drop_na() n <- nrow(.df) if (n > 0) { nbreaks <- sqrt(n) %>% ifelse(. > 2, ., 2) mi <- .df %>% ## the discretize function doesn't seem to work the way I want, so ## just use 'cut' since I know what that does ## discretize() %>% mutate(across(everything(), ~ cut(.x, breaks = nbreaks, labels = FALSE) %>% as.vector())) %>% mutinformation() H1 <- mi[1, 1] H2 <- mi[2, 2] I <- mi[1, 2] } else { H1 <- NA H2 <- NA I <- NA } list(H1 = H1, H2 = H2, ## mutual information I = I, ## mutual information normalized to joint entropy Inorm = I / (H1 + H2 - I), ## mutual information normalized to the first feature I_H1 = I / H1, ## variation of information (if a metric is needed) VI = H1 + H2 - 2 * I) } info_df <- function(features, df_info) { features %>% as.list() %>% map(~ information(df_info, "label", .x)) %>% tibble(i = ., param = features) %>% unnest_wider(i) %>% drop_na() } info_plot <- function(df) { ggplot(df, aes(reorder(param, desc(I_H1)), I_H1)) + geom_col() + xlab(NULL) + ylab("Mutual Inf. (Normalized to Label)") + theme(axis.text.x = element_text(hjust = 1, vjust = 0.5, angle = 90)) } print_info_plot <- function(df_info, features) { info_df(features, df_info) %>% info_plot() %>% print() } df_info_na <- df_x %>% mutate(label = as.integer(df_y == "tp")) df_info_filled <- df_info_na %>% mutate(across(everything(), ~ if_else(is.na(.x), 0.0, as.double(.x)))) |
```r
print_info_plot(df_info_filled, features)
```
```r
print_info_plot(df_info_na, features)
```
```r
print_coloc <- function(df_bool, comb_df) {
  mutate(comb_df, asymm_jaccard = ajaccard(df_bool, var.x, var.y)) %>%
    make_xy_tile_plot("var.x", "var.y", "asymm_jaccard",
                      "starting set", "overlapping set") %>%
    print()
  cat("\n\n")
}

combinations <- df_x %>%
  names() %>%
  cross_tibble()

df_x_bool <- to_binary(df_x)
```
```r
cat("## All labels\n\n")

print_coloc(df_x_bool, combinations)
```
```r
print_label_coloc <- function(label) {
  cat(sprintf("## %s only\n\n", label))
  df_x_bool <- filter(df_x, df_y == label) %>%
    to_binary()
  print_coloc(df_x_bool, combinations)
  cat("\n\n")
}

walk(as.list(y_labels), print_label_coloc)
```
```r
## ASSUME these will be the same for TP/FP/both
perfect_overlaps <- df_x_bool %>%
  perfect_overlapping_sets(combinations, "var.x", "var.y")

print_subset_cor_plot <- function(df, subset) {
  df %>%
    select(all_of(subset)) %>%
    drop_na() %>%
    make_cor_plot() %>%
    print()
}

print_cor_plots <- function(df, subsets) {
  cat(sprintf("number of rows: %s\n\n", nrow(df)))
  walk(as.list(subsets), ~ print_subset_cor_plot(df, .x))
  cat("\n\n")
}
```
```r
cat("## All labels\n\n")

print_cor_plots(df_x, perfect_overlaps)
```
```r
print_label_cor <- function(label) {
  cat(sprintf("## %s only\n\n", label))
  filter(df_x, df_y == label) %>%
    print_cor_plots(perfect_overlaps)
  cat("\n\n")
}

walk(as.list(y_labels), print_label_cor)
```
```r
print_labeled_plot <- function(x, name) {
  cat(sprintf("## %s\n\n", name))
  .df <- tibble(x = x, y = df_y) %>%
    filter(!is.na(x))
  label_summary_table(.df$y) %>%
    print()
  cat("\n\n")
  .x <- .df$x
  .y <- .df$y
  if (length(.x) == 0) {
    cat("Feature has no values")
  } else if (max(.x) - min(.x) == 0) {
    cat(sprintf("Feature has one value: %.1f", max(.x)))
  } else {
    print(make_feature_distribution(.x, .y))
  }
  cat("\n\n")
}

print_unlabeled_plot <- function(x, name) {
  cat(sprintf("## %s\n\n", name))
  .df <- tibble(x = x) %>%
    filter(!is.na(x))
  .x <- .df$x
  if (length(.x) == 0) {
    cat("Feature has no values")
  } else if (max(.x) - min(.x) == 0) {
    cat(sprintf("Feature has one value: %.1f", max(.x)))
  } else {
    print(make_unlabeled_feature_distribution(.x))
  }
  cat("\n\n")
}

if (has_label) {
  iwalk(df_x, ~ print_labeled_plot(.x, .y))
} else {
  iwalk(df_x, ~ print_unlabeled_plot(.x, .y))
}
```
```r
library(tidyverse)
library(ggpubr)

root <- snakemake@params$lib_path

source(file.path(root, "plots.r"))

read_df <- function(path) {
  readr::read_tsv(path, col_types = cols(.default = "d"))
}

has_label <- "truth_y" %in% names(snakemake@input)

pred_y <- read_df(snakemake@input[["predictions"]])
explain_x <- read_df(snakemake@input[["explanations"]])
query_key <- snakemake@params[["query_key"]]
```
```r
truth_y <- readr::read_tsv(snakemake@input[["truth_y"]],
                           col_types = cols(chrom = "-",
                                            chromStart = "-",
                                            chromEnd = "-",
                                            variant_index = "-",
                                            .default = "d"))

y <- tibble(label = truth_y$label, prob = pred_y$`1`)
```
```r
cat(sprintf("* N: %i\n", nrow(pred_y)))
cat(sprintf("* Perc. Pos: %f\n\n", sum(y$label) / nrow(y)))
```
```r
if (has_label) {
  ggplot(y, aes(prob, color = factor(label))) +
    geom_density() +
    xlab("probability") +
    scale_color_discrete(name = "label")
} else {
  ggplot(pred_y, aes(x = `1`)) +
    geom_density() +
    xlab("probability")
}
```
```r
AllP <- sum(y$label)
AllN <- nrow(y) - AllP

roc <- y %>%
  arrange(prob) %>%
  mutate(thresh_FN = cumsum(label),
         thresh_TN = row_number() - thresh_FN,
         thresh_TP = AllP - thresh_FN,
         thresh_FP = AllN - thresh_TN,
         TPR = thresh_TP / AllP,
         TNR = thresh_TN / AllN,
         FPR = 1 - TNR,
         precision = thresh_TP / (thresh_TP + thresh_FP))
```
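The chunk above treats every observed probability as a candidate threshold: after sorting ascending, each row and everything below it is called negative at that cutoff, so `cumsum(label)` is the running false-negative count. A self-contained toy example with made-up labels and probabilities:

```r
# Made-up data illustrating the cumulative-sum threshold sweep used above.
library(dplyr)
library(tibble)

toy <- tibble(label = c(0, 0, 1, 0, 1, 1),
              prob  = c(0.1, 0.2, 0.3, 0.6, 0.8, 0.9))
AllP <- sum(toy$label)
AllN <- nrow(toy) - AllP

toy %>%
  arrange(prob) %>%
  mutate(thresh_FN = cumsum(label),            # positives pushed below the cutoff
         thresh_TN = row_number() - thresh_FN, # negatives correctly below it
         TPR = (AllP - thresh_FN) / AllP,
         FPR = 1 - thresh_TN / AllN)
```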
```r
cat("## Calibration\n\n")

nbins <- 10

y %>%
  mutate(bin = cut(prob, nbins, labels = FALSE) / nbins) %>%
  group_by(bin) %>%
  summarize(mean_pred = mean(prob), frac_pos = mean(label)) %>%
  ggplot(aes(mean_pred, frac_pos)) +
  geom_point() +
  geom_line() +
  geom_abline(linetype = "dotted", color = "red") +
  xlim(0, 1) +
  ylim(0, 1)

cat("## ROC Curves\n\n")

roc %>%
  arrange(FPR, TPR) %>%
  ggplot(aes(FPR, TPR)) +
  geom_line()

roc %>%
  filter(!is.na(precision)) %>%
  arrange(TPR, desc(precision)) %>%
  ggplot(aes(TPR, precision)) +
  geom_line()
```
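A quick, made-up illustration of the binning behind the calibration curve above: `cut()` splits the observed probability range into `nbins` equal-width bins, and `labels = FALSE` returns the bin index, so dividing by `nbins` places each bin on an approximate [0, 1] scale when the predictions span most of that interval.

```r
# Made-up probabilities; cut() bins over their observed range and returns the
# bin index when labels = FALSE.
probs <- c(0.02, 0.15, 0.48, 0.51, 0.97)
cut(probs, breaks = 10, labels = FALSE) / 10
```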
```r
library(tidyverse)
library(ggpubr)

root <- snakemake@params$lib_path

source(file.path(root, "plots.r"))

to_tibble <- function(lst) {
  do.call(tibble, lst)
}

read_model <- function(path) {
  jsonlite::read_json(path, simplifyVector = TRUE, simplifyDataFrame = FALSE)
}

read_predictions <- function(path) {
  readr::read_tsv(path, col_types = cols(.default = "d"))
}

lookup_input_path <- function(mapping, k) {
  pluck(mapping, as.character(as.integer(k)))
}

to_univariate <- function(model) {
  model$univariate %>%
    map(~ list(meta = .x[c("name", "vartype")],
               df = to_tibble(.x[["df"]])))
}

to_bivariate <- function(model) {
  model$bivariate %>%
    map(~ list(left = .x$left,
               right = .x$right,
               df = to_tibble(.x$df)))
}

run_features <- snakemake@params[["features"]]
error_labels <- snakemake@params[["error_labels"]]

mod <- read_model(snakemake@input[["model"]])
test_pred <- read_predictions(snakemake@input[["predictions"]])
train_pred <- read_predictions(snakemake@input[["train_predictions"]])
train_x <- readr::read_tsv(snakemake@input[["train_x"]],
                           col_types = cols(chrom = "-",
                                            chromStart = "-",
                                            chromEnd = "-",
                                            variant_index = "-",
                                            .default = "d"))
train_y <- readr::read_tsv(snakemake@input[["train_y"]],
                           col_types = cols(.default = "d"))

# use the training-set positive rate as the probability cutoff for hard calls
threshold <- train_pred %>%
  pull(label) %>%
  mean()

alltrain <- bind_cols(train_x, train_pred) %>%
  mutate(pred = prob > threshold)

VCF_input_name <- "VCF_input"

global_df <- to_tibble(mod$global_scores)
univariate <- to_univariate(mod)
bivariate <- to_bivariate(mod)
```
```r
cat("\n")
```
```r
ggplot(test_pred, aes(prob, color = factor(label))) +
  geom_density() +
  xlab("probability") +
  scale_color_discrete(name = "label")
```
```r
nbins <- 10

test_pred %>%
  mutate(bin = cut(prob, nbins, labels = FALSE) / nbins) %>%
  group_by(bin) %>%
  summarize(mean_pred = mean(prob), frac_pos = mean(label)) %>%
  ggplot(aes(mean_pred, frac_pos)) +
  geom_point() +
  geom_line() +
  geom_abline(linetype = "dotted", color = "red") +
  xlim(0, 1) +
  ylim(0, 1)
```
```r
AllP <- sum(test_pred$label)
AllN <- nrow(test_pred) - AllP

roc <- test_pred %>%
  arrange(prob) %>%
  mutate(thresh_FN = cumsum(label),
         thresh_TN = row_number() - thresh_FN,
         thresh_TP = AllP - thresh_FN,
         thresh_FP = AllN - thresh_TN,
         TPR = thresh_TP / AllP,
         TNR = thresh_TN / AllN,
         FPR = 1 - TNR,
         precision = thresh_TP / (thresh_TP + thresh_FP))
```
```r
roc %>%
  arrange(FPR, TPR) %>%
  ggplot(aes(FPR, TPR)) +
  geom_line()

roc %>%
  filter(!is.na(precision)) %>%
  arrange(TPR, desc(precision)) %>%
  ggplot(aes(TPR, precision)) +
  geom_line()
```
```r
ggplot(global_df, aes(score, reorder(variable, score))) +
  geom_col() +
  xlab("Score") +
  ylab(NULL)
```
```r
tibble(x = "intercept", y = mod$intercept) %>%
  ggplot(aes(x, y)) +
  geom_col() +
  xlab(NULL) +
  ylab("score")
```
```r
get_truncation <- function(s) {
  run_features[[s]][["visualization"]][["truncate"]]
}

get_split_missing <- function(s) {
  run_features[[s]][["visualization"]][["split_missing"]]
}

get_fill_na <- function(s) {
  run_features[[s]][["fill_na"]]
}

get_plot_type <- function(s) {
  run_features[[s]][["visualization"]][["plot_type"]]
}

truncate_maybe <- function(df, name) {
  t <- get_truncation(name)
  lower <- t[["lower"]]
  upper <- t[["upper"]]
  caption <- if (!is.null(lower) && !is.null(upper)) {
    sprintf("Truncated from %d to %d", lower, upper)
  } else if (!is.null(lower)) {
    sprintf("Truncated from %d to -Inf", lower)
  } else if (!is.null(upper)) {
    sprintf("Truncated from -Inf to %d", upper)
  }
  .df <- if (is.null(lower) && is.null(upper)) {
    df
  } else {
    .lower <- if (is.null(lower)) min(df$value) else lower
    .upper <- if (is.null(upper)) max(df$value) else upper
    filter(df, .lower <= value & value <= .upper)
  }
  list(df = .df, lower = lower, upper = upper, caption = caption)
}

null2alt <- function(default, x) {
  if (is.null(x)) default else x
}

null2na <- function(x) {
  null2alt(NA, x)
}

make_integer_plot <- function(df, name, lower = NULL, upper = NULL, ylab = "Score") {
  fill_cols <- c("score", "stdev")
  .lower <- null2alt(min(df$value), lower)
  .upper <- null2alt(max(df$value), upper)
  .join <- tibble(value = .lower:.upper)
  mutate(df, value = ceiling(value)) %>%
    right_join(.join, by = "value") %>%
    arrange(value) %>%
    fill(all_of(fill_cols), .direction = "downup") %>%
    ggplot(aes(value, score)) +
    geom_col() +
    xlab(name) +
    ylab(ylab) +
    geom_errorbar(aes(ymin = score - stdev, ymax = score + stdev))
}

## TODO use inverse logit here?
make_fraction_plot <- function(df) {
  df %>%
    group_by(value) %>%
    summarize(frac = mean(label),
              stderr = sqrt(frac * (1 - frac) / n())) %>%
    ggplot(aes(value, frac)) +
    geom_point() +
    geom_errorbar(aes(ymin = frac - stderr, ymax = frac + stderr), width = 0.1)
}

make_integer_fraction_plot <- function(df, name, lower = NULL, upper = NULL) {
  .name <- sym(name)
  df %>%
    mutate(value = ceiling({{ .name }})) %>%
    make_fraction_plot() +
    xlab(name) +
    ylab("Frac(TP)") +
    coord_trans(xlim = c(null2na(lower), null2na(upper)))
}

make_continuous_plot <- function(df, name, ylab = "Score") {
  ggplot(df, aes(value, score)) +
    geom_step(aes(y = score + stdev), color = "red") +
    geom_step(aes(y = score - stdev), color = "red") +
    geom_step() +
    xlab(name) +
    ylab(ylab)
}

make_continuous_fraction_plot <- function(df, name) {
  .name <- sym(name)
  df %>%
    ## TODO average the bin ends to make the axis cleaner
    mutate(value = cut({{ .name }}, 20)) %>%
    make_fraction_plot() +
    xlab(name) +
    ylab("Frac(TP)") +
    theme(axis.text.x = element_text(angle = 90, hjust = 1, vjust = 0.5))
}

make_categorical_plot <- function(df, name, ylab = "Score") {
  df %>%
    ggplot(aes(factor(value), score)) +
    geom_col() +
    geom_errorbar(aes(ymin = score - stdev, ymax = score + stdev), width = 0.1) +
    xlab(name) +
    ylab(ylab)
}

make_categorical_fraction_plot <- function(df, name) {
  .name <- sym(name)
  df %>%
    mutate(value = factor({{ .name }})) %>%
    make_fraction_plot() +
    xlab(name) +
    ylab("Frac(TP)")
}

standardize_y_axes <- function(ps) {
  lims <- map(ps, ~ layer_scales(.x)[["y"]]$get_limits()) %>%
    do.call(cbind, .)
  new <- c(min(lims[1, ]), max(lims[2, ]))
  map(ps, ~ .x + ylim(new))
}

make_split_plot <- function(df, name, bound, fun) {
  missing_val <- get_fill_na(name)
  missing <- filter(df, value == missing_val) %>%
    mutate(value = "Missing")
  nonmissing <- filter(df, value != missing_val) %>%
    mutate(value = if_else(value < bound, bound, value))
  bar <- ggplot(missing, aes(factor(value), score)) +
    geom_col() +
    geom_errorbar(aes(ymax = score + stdev, ymin = score - stdev), width = 0.2) +
    xlab(NULL)
  step <- fun(nonmissing, NULL) +
    ylab(NULL) +
    theme(axis.text.y = element_blank(),
          axis.ticks.y = element_blank())
  list(bar, step) %>%
    standardize_y_axes() %>%
    ggarrange(plotlist = ., ncol = 2, widths = c(1, 5)) %>%
    annotate_figure(bottom = text_grob(name))
}

make_split_fraction_plot <- function(df, name, bound, fun) {
  .name <- sym(name)
  missing_val <- get_fill_na(name)
  missing <- filter(df, {{ .name }} == missing_val) %>%
    mutate({{ .name }} := "Missing")
  nonmissing <- filter(df, {{ .name }} != missing_val)
  bar <- make_categorical_fraction_plot(missing, name) +
    xlab(NULL)
  step <- fun(nonmissing, name) +
    xlab(NULL) +
    ylab(NULL) +
    theme(axis.text.y = element_blank(),
          axis.ticks.y = element_blank())
  list(bar, step) %>%
    standardize_y_axes() %>%
    ggarrange(plotlist = ., ncol = 2, widths = c(1, 5), align = "h") %>%
    annotate_figure(bottom = text_grob(name))
}

wrap_split_maybe <- function(name, split_f, f) {
  s <- get_split_missing(name)
  if (is.null(s)) f else partial(split_f, fun = f, bound = s)
}

print_uv_plot <- function(vartype, df, name) {
  r <- if (vartype == "continuous") {
    tr <- truncate_maybe(df, name)
    t <- get_plot_type(name)
    fs <- if (t == "step") {
      list(make_continuous_plot,
           make_continuous_fraction_plot)
    } else if (t == "bar") {
      list(partial(make_integer_plot,
                   lower = tr[["lower"]],
                   upper = tr[["upper"]]),
           partial(make_integer_fraction_plot,
                   lower = tr[["lower"]],
                   upper = tr[["upper"]]))
    } else {
      stop(sprintf("wrong type, dummy; got %s", t))
    }
    ## TODO only continuous plots can be split (for now)
    list(feat_f = wrap_split_maybe(name, make_split_plot, fs[[1]]),
         frac_f = wrap_split_maybe(name, make_split_fraction_plot, fs[[2]]),
         df = tr[["df"]],
         caption = tr[["caption"]])
  } else if (vartype == "categorical") {
    list(feat_f = make_categorical_plot,
         frac_f = make_categorical_fraction_plot,
         df = df,
         caption = NULL)
  } else {
    stop(sprintf("wrong plot type, dummy; got %s", vartype))
  }
  p0 <- r$feat_f(r$df, name)
  p1 <- r$frac_f(alltrain, name)
  cat(sprintf("## %s\n", name))
  print(p0)
  cat("\n\n")
  print(p1)
  cat("\n\n")
  if (!is.null(r$caption)) {
    cat(sprintf("%s\n\n", r$caption))
  }
}

walk(univariate, ~ print_uv_plot(.x$meta$vartype, .x$df, .x$meta$name))
```
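Several of the plotting helpers above are composed with `purrr::partial()`, which pre-fills arguments (truncation bounds, the split-plot builder) so that every variant can later be called with the same small signature. A standalone toy illustration (the helper name is hypothetical):

```r
# Minimal illustration of purrr::partial() as used by print_uv_plot():
# pre-filling arguments yields a new function with a smaller signature.
library(purrr)

scale_by <- function(x, factor) x * factor  # hypothetical helper
double_it <- partial(scale_by, factor = 2)
double_it(21)  # 42
```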
```r
cont_cont_plot <- function(df, yvar, left_name, right_name) {
  # poor man's 2d step heatmap plot thing
  .yvar <- sym(yvar)
  df %>%
    group_by(right_value) %>%
    mutate(left_upper = lead(left_value)) %>%
    ungroup() %>%
    group_by(left_value) %>%
    mutate(right_upper = lead(right_value)) %>%
    ungroup() %>%
    filter(!is.na(left_upper)) %>%
    filter(!is.na(right_upper)) %>%
    ggplot() +
    geom_rect(aes(xmin = left_value,
                  xmax = left_upper,
                  ymin = right_value,
                  ymax = right_upper,
                  fill = {{ .yvar }})) +
    xlab(left_name) +
    ylab(right_name)
}

print_cont_cont_plot <- function(df, left_name, right_name) {
  x_tr <- df %>%
    rename(value = left_value) %>%
    truncate_maybe(left_name)
  y_tr <- df %>%
    rename(value = right_value) %>%
    truncate_maybe(right_name)
  lims <- coord_trans(xlim = c(null2na(x_tr$lower), null2na(x_tr$upper)),
                      ylim = c(null2na(y_tr$lower), null2na(y_tr$upper)))
  if (!is.null(x_tr$caption)) {
    cat(sprintf("%s: %s\n\n", left_name, x_tr$caption))
  }
  if (!is.null(y_tr$caption)) {
    cat(sprintf("%s: %s\n\n", right_name, y_tr$caption))
  }
  cat("### Scores\n\n")
  p0 <- cont_cont_plot(df, "score", left_name, right_name) +
    scale_fill_gradient2(midpoint = 0) +
    lims
  print(p0)
  cat("\n\n")
  cat("### Stdevs\n\n")
  p1 <- cont_cont_plot(df, "stdev", left_name, right_name) +
    scale_fill_gradient() +
    lims
  print(p1)
}

print_cont_cat_plot_inner <- function(df, cat_name, cont_name) {
  cat_val <- as.character(df$c[[1]])
  tr <- truncate_maybe(df, cont_name)
  t <- get_plot_type(cont_name)
  f <- if (t == "step") {
    partial(make_continuous_plot)
  } else if (t == "bar") {
    partial(make_integer_plot,
            lower = tr[["lower"]],
            upper = tr[["upper"]])
  } else {
    stop(sprintf("wrong plot type: got %s", t))
  }
  p <- wrap_split_maybe(cont_name, make_split_plot, f)(tr[["df"]], cont_name)
  cat(sprintf("### %s = %s\n\n", cat_name, cat_val))
  cat(sprintf("%s\n\n", tr[["caption"]]))
  print(p)
  cat("\n\n")
}

print_cont_cat_plot <- function(df, cat, cont, cat_name, cont_name) {
  ## this 'all_of' thing is needed to silence a weird warning about
  ## using vectors to select things (I disagree with it, but whatever)
  .df <- rename(df, value = all_of(cont), c = all_of(cat)) %>%
    mutate(lower = score - stdev,
           upper = score + stdev) %>%
    group_split(c) %>%
    walk(~ print_cont_cat_plot_inner(.x, cat_name, cont_name))
}

print_cat_cat_plot <- function(df, left_name, right_name) {
  p <- ggplot(df, aes(factor(left_value), score, fill = factor(right_value))) +
    geom_col(position = "dodge") +
    geom_errorbar(aes(ymin = score - stdev, ymax = score + stdev),
                  width = 0.1,
                  position = position_dodge(0.9)) +
    xlab(left_name) +
    scale_fill_discrete(name = right_name)
  print(p)
}

print_bv_plot_inner <- function(L, R, df) {
  if (L$type == "continuous" && R$type == "continuous") {
    print_cont_cont_plot(df, L$name, R$name)
  } else if (L$type == "categorical" && R$type == "continuous") {
    print_cont_cat_plot(df, "left_value", "right_value", L$name, R$name)
  } else if (L$type == "continuous" && R$type == "categorical") {
    print_cont_cat_plot(df, "right_value", "left_value", R$name, L$name)
  } else if (L$type == "categorical" && R$type == "categorical") {
    print_cat_cat_plot(df, L$name, R$name)
  } else {
    sprintf("Types are wrong, dummy: %s and/or %s", L$type, R$type)
  }
}

print_bv_plot <- function(L, R, df) {
  cat(sprintf("## %s x %s\n\n", L$name, R$name))
  print_bv_plot_inner(L, R, df)
  cat("\n\n")
}

if (length(bivariate) == 0) {
  cat("None\n\n")
} else {
  walk(bivariate, ~ print_bv_plot(.x$left, .x$right, .x$df))
}
```
```r
logErr <- function(p) {
  -log10(1 - p)
}

make_perf_df <- function(df) {
  df %>%
    group_by(value) %>%
    summarize(precision = sum(label & pred) / sum(pred),
              recall = sum(label & pred) / sum(label),
              f1 = 2 * (precision * recall) / (precision + recall)) %>%
    pivot_longer(cols = c(precision, recall, f1),
                 names_to = "metric",
                 values_to = "mvalue") %>%
    mutate(mvalue = logErr(mvalue))
  ## ggplot(aes(value, mvalue, color = metric)) +
  ## labs(x = "Feature Value",
  ##      y = "-log10(metric)")
}

make_perf_plot <- function(df) {
  df %>%
    group_by(value) %>%
    summarize(precision = sum(label & pred) / sum(pred),
              recall = sum(label & pred) / sum(label),
              f1 = 2 * (precision * recall) / (precision + recall)) %>%
    pivot_longer(cols = c(precision, recall, f1),
                 names_to = "metric",
                 values_to = "mvalue") %>%
    mutate(mvalue = logErr(mvalue)) %>%
    ggplot(aes(value, mvalue, color = metric)) +
    xlab(NULL) +
    ylab("-log10(metric)")
}

make_integer_perf_plot <- function(df, name, lower = NULL, upper = NULL) {
  .n <- sym(name)
  df %>%
    mutate(value = ceiling({{ .n }})) %>%
    make_perf_plot() +
    geom_point() +
    geom_line() +
    coord_trans(xlim = c(null2na(lower), null2na(upper)))
}

make_continuous_perf_plot <- function(df, name) {
  .n <- sym(name)
  df %>%
    mutate(value = cut({{ .n }}, 20)) %>%
    make_perf_plot() +
    geom_point() +
    theme(axis.text.x = element_text(angle = 90, hjust = 1, vjust = 0.5))
}

make_categorical_perf_plot <- function(df, name) {
  .n <- sym(name)
  df %>%
    mutate(value = factor({{ .n }})) %>%
    make_perf_plot() +
    geom_col(position = "dodge", aes(fill = metric))
}

make_perf_split_plot <- function(df, name, bound, fun) {
  .n <- sym(name)
  missing_val <- get_fill_na(name)
  missing <- filter(df, {{ .n }} == missing_val) %>%
    mutate({{ .n }} := "Missing")
  nonmissing <- filter(df, {{ .n }} != missing_val) %>%
    mutate({{ .n }} := if_else({{ .n }} < bound, bound, {{ .n }}))
  bar <- make_categorical_perf_plot(missing, name) +
    xlab(NULL)
  step <- fun(nonmissing, name) +
    ylab(NULL) +
    xlab(NULL) +
    theme(axis.text.y = element_blank(),
          axis.ticks.y = element_blank())
  list(bar, step) %>%
    standardize_y_axes() %>%
    ggarrange(plotlist = ., ncol = 2, widths = c(1, 5),
              common.legend = TRUE, legend = "right", align = "h") %>%
    annotate_figure(bottom = text_grob(name))
}

print_perf_profile_plot <- function(vartype, name) {
  r <- if (vartype == "continuous") {
    tr <- get_truncation(name)
    t <- get_plot_type(name)
    f <- if (t == "step") {
      make_continuous_perf_plot
    } else if (t == "bar") {
      partial(make_integer_perf_plot,
              lower = tr[["lower"]],
              upper = tr[["upper"]])
    } else {
      stop(sprintf("wrong type, dummy; got %s", t))
    }
    list(feat_f = wrap_split_maybe(name, make_perf_split_plot, f),
         caption = tr[["caption"]])
  } else if (vartype == "categorical") {
    list(feat_f = make_categorical_perf_plot,
         caption = NULL)
  } else {
    stop(sprintf("wrong plot type, dummy; got %s", vartype))
  }
  cat(sprintf("## %s\n", name))
  print(r$feat_f(alltrain, name))
  cat("\n\n")
  if (!is.null(r$caption)) {
    cat(sprintf("%s\n\n", r$caption))
  }
}

walk(univariate, ~ print_perf_profile_plot(.x$meta$vartype, .x$meta$name))
```
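For reference, the `-log10(1 - metric)` transform applied by `logErr` compresses values near 1 so that each unit on the transformed scale corresponds to another "nine" of precision, recall, or F1:

```r
# Illustration: 0.9 -> 1, 0.99 -> 2, 0.999 -> 3 (up to floating-point rounding).
logErr <- function(p) -log10(1 - p)
logErr(c(0.9, 0.99, 0.999))
```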
Support
- Future updates