Integrating genotypes and phenotypes improves long-term forecasts of seasonal influenza A/H3N2 evolution
John Huddleston1,2, John R. Barnes3, Thomas Rowe3, Xiyan Xu3, Rebecca Kondor3, David E. Wentworth3, Lynne Whittaker4, Burcu Ermetal4, Rodney S. Daniels4, John W. McCauley4, Seiichiro Fujisaki5, Kazuya Nakamura5, Noriko Kishida5, Shinji Watanabe5, Hideki Hasegawa5, Ian Barr6, Kanta Subbarao6, Pierre Barrat-Charlaix7,8, Richard A. Neher7,8 & Trevor Bedford1
1Vaccine and Infectious Disease Division, Fred Hutchinson Cancer Research Center, Seattle, WA, USA, 2Molecular and Cell Biology, University of Washington, Seattle, WA, USA, 3Virology Surveillance and Diagnosis Branch, Influenza Division, National Center for Immunization and Respiratory Diseases (NCIRD), Centers for Disease Control and Prevention (CDC), 1600 Clifton Road, Atlanta, GA 30333, USA, 4WHO Collaborating Centre for Reference and Research on Influenza, Crick Worldwide Influenza Centre, The Francis Crick Institute, London, UK., 5Influenza Virus Research Center, National Institute of Infectious Diseases, Tokyo, Japan, 6The WHO Collaborating Centre for Reference and Research on Influenza, The Peter Doherty Institute for Infection and Immunity, Melbourne, VIC, Australia; Department of Microbiology and Immunology, The University of Melbourne, The Peter Doherty Institute for Infection and Immunity, Melbourne, VIC, Australia., 7Biozentrum, University of Basel, Basel, Switzerland, 8Swiss Institute of Bioinformatics, Basel, Switzerland
DOI: https://doi.org/10.7554/eLife.60067
Abstract
Seasonal influenza virus A/H3N2 is a major cause of death globally. Vaccination remains the most effective preventative. Rapid mutation of hemagglutinin allows viruses to escape adaptive immunity. This antigenic drift necessitates regular vaccine updates. Effective vaccine strains need to represent H3N2 populations circulating one year after strain selection. Experts select strains based on experimental measurements of antigenic drift and predictions made by models from hemagglutinin sequences. We developed a novel influenza forecasting framework that integrates phenotypic measures of antigenic drift and functional constraint with previously published sequence-only fitness estimates. Forecasts informed by phenotypic measures of antigenic drift consistently outperformed previous sequence-only estimates, while sequence-only estimates of functional constraint surpassed more comprehensive experimentally-informed estimates. Importantly, the best models integrated estimates of both functional constraint and either antigenic drift phenotypes or recent population growth.
Installation
Install miniconda. Clone the forecasting repository.
git clone https://github.com/blab/flu-forecasting.git
cd flu-forecasting
Create and activate a conda environment for the pipeline.
conda env create -f envs/anaconda.python3.yaml
conda activate flu_forecasting
Quickstart
Run the pipeline for sparse simulated data. This will first simulate influenza-like populations and then fit models to those populations. Inspect all steps to be executed by the pipeline with a dry run.
snakemake --dryrun --use-conda --config active_builds='simulated_sample_1'
Run the pipeline locally with four jobs (or cores) at once.
snakemake --use-conda --config active_builds='simulated_sample_1' -j 4
Always specify a value for -j to limit the number of cores available to the simulator.
If no limit is provided, the Java-based simulator will attempt to use all available cores and may cause headaches for you or your cluster's system administrator.
Configuration
Analyses are parameterized by the contents of config/config.json.
Models are fit to annotated data frames created for one or more "builds" from one or more "datasets".
Datasets and builds are decoupled to allow multiple builds from a single dataset.
Builds are split into "simulated" and "natural" such that each entry in one of these categories is a dictionary of build settings indexed by a build name.
The list of active builds is determined by the space-delimited values in the active_builds top-level key of the configuration.
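As a quick orientation, a short script like the following can list which builds a configuration declares and which are active. The top-level "builds" key and the layout of individual build settings shown here are assumptions based on the description above; adjust the key names to match your copy of config/config.json.

import json

# Load the pipeline configuration described above.
with open("config/config.json") as fh:
    config = json.load(fh)

# "active_builds" is a space-delimited string of build names.
active_builds = set(config["active_builds"].split())

# Builds are assumed to live under a top-level "builds" key, split into
# "simulated" and "natural" dictionaries indexed by build name.
for category in ("simulated", "natural"):
    for build_name, build_settings in config.get("builds", {}).get(category, {}).items():
        status = "active" if build_name in active_builds else "inactive"
        print(f"{category}/{build_name} ({status}): {sorted(build_settings)}")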
Workflow structure
Workflow
The analyses for this paper were produced using a workflow written with Snakemake. The complete graph of the workflow is available as a PDF. This PDF was created with the following Snakemake command.
snakemake --forceall --dag manuscript/flu_forecasting.pdf | dot -Tpdf > full_dag.pdf
Below is a subset of the complete workflow showing how tip attributes are created for a single timepoint (2015-10-01) from the natural populations analysis. This image was created with the following Snakemake command.
snakemake --forceall --dag \
results/builds/natural/natural_sample_1_with_90_vpm_sliding/timepoints/2015-10-01/tip_attributes.tsv | \
dot -Tpng > example_dag.png
Inputs
Both simulated and natural population builds depend on the configuration file, config/config.json, described above.
Simulated populations are generated by SANTA-SIM as part of the workflow.
SANTA-SIM XML configuration files determine the parameters of the simulations and can be found in the corresponding data directory for a given simulated sample.
For example, the densely sampled simulated populations configuration file is data/simulated/simulated_sample_3/influenza_h3n2_ha.xml.
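To check which parameters a given simulated sample uses before running the workflow, you can walk the XML directly. The sketch below makes no assumptions about SANTA-SIM's element names; it simply prints every element that carries a value.

import xml.etree.ElementTree as ET

# Print every parameter value defined in the SANTA-SIM configuration.
tree = ET.parse("data/simulated/simulated_sample_3/influenza_h3n2_ha.xml")
for element in tree.getroot().iter():
    text = (element.text or "").strip()
    if text:
        print(element.tag, "=", text)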
Natural populations are represented by FASTA sequences that are freely available through GISAID.
See the instructions on how to download these sequences below.
The full analysis for this paper also depends on raw hemagglutination inhibition (HI) and focus-reduction assay (FRA) titer measurements.
Although these measurements are not publicly available due to existing data sharing agreements, we provide imputed log2 titer values for each strain, produced by the phylogenetic model of Neher et al. 2016.
These values are available in the results files named tip_attributes_with_weighted_distances.tsv.
For example, the complete set of tip attributes, including imputed titer drops, for the validation period of natural populations is available in results/builds/natural/natural_sample_1_with_90_vpm_sliding/tip_attributes_with_weighted_distances.tsv.
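These files are tab-delimited with one row per strain per timepoint and can be loaded directly with pandas. Beyond the timepoint and strain columns used throughout the pipeline's scripts, check the file's header for the available columns rather than assuming specific names.

import pandas as pd

# Load the tab-delimited tip attributes for the validation period and
# report the timepoints and columns it contains.
tips = pd.read_csv(
    "results/builds/natural/natural_sample_1_with_90_vpm_sliding/tip_attributes_with_weighted_distances.tsv",
    sep="\t",
    parse_dates=["timepoint"],
)
print(tips["timepoint"].min(), "to", tips["timepoint"].max())
print(tips.columns.tolist())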
Outputs
The primary outputs of this workflow are the tables of tip attributes per population that are used to fit models (tip_attributes_with_weighted_distances.tsv), the tables of resulting model coefficients (distance_model_coefficients.tsv), and the distances to the future (distance_model_errors.tsv).
Data for validation figures (e.g., Figures 4 and 7) can be found in validation_figure_clades.tsv and validation_figure_ranks.tsv.
Additional outputs include the mapping of individual strains to clades (tips_to_clades.tsv) for the creation of model validation figures (e.g., comparison of estimated and observed clade frequency fold changes and absolute forecasting errors).
The following outputs are included in this repository and are also created by running the full analysis pipeline.
- results/
  - distance_model_errors.tsv
  - distance_model_coefficients.tsv
  - validation_figure_clades.tsv
  - validation_figure_ranks.tsv
  - builds/
    - natural/
      - natural_sample_1_with_90_vpm_sliding/
        - tip_attributes_with_weighted_distances.tsv
      - natural_sample_1_with_90_vpm_sliding_test_tree/
        - tip_attributes_with_weighted_distances.tsv
    - simulated/
      - simulated_sample_3/
        - tip_attributes_with_weighted_distances.tsv
      - simulated_sample_3_test_tree/
        - tip_attributes_with_weighted_distances.tsv
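For example, the collected errors table can be summarized directly with pandas. The sketch below assumes the table keeps the type, sample, and error_type annotations added by scripts/annotate_model_tables.py; if your copy differs, group by whichever annotation columns are present.

import pandas as pd

# Summarize forecasting errors across models and timepoints.
errors = pd.read_csv("results/distance_model_errors.tsv", sep="\t")
print(errors.columns.tolist())

# Average the numeric columns within each error type (e.g., validation vs. test).
numeric_columns = errors.select_dtypes("number").columns
print(errors.groupby("error_type")[numeric_columns].mean())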
The manuscript and most of the figures and tables within it are also automatically generated by the full analysis workflow. These files can be found in the following paths.
- manuscript/
  - flu_forecasting.pdf
  - figures/
  - tables/
Full analysis
Inspect sequences for simulated populations
Each SANTA-SIM run and subsequent subsampling of the resulting sequences will produce a different random collection of sequences for the workflow. To ensure reproducibility of results, we have included the specific simulated sequences used for analyses in the manuscript. These sequences and their corresponding metadata are available at the following paths:
- data/simulated/simulated_sample_3/filtered_sequences.fasta
- data/simulated/simulated_sample_3/filtered_metadata.tsv
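A quick way to confirm that these bundled inputs are intact is to check that every sequence has a matching metadata record. This check is only illustrative and is not part of the published workflow.

import Bio.SeqIO
import pandas as pd

# Compare the bundled simulated sequences against their metadata table.
sequences = list(Bio.SeqIO.parse("data/simulated/simulated_sample_3/filtered_sequences.fasta", "fasta"))
metadata = pd.read_csv("data/simulated/simulated_sample_3/filtered_metadata.tsv", sep="\t")
print(len(sequences), "sequences,", len(metadata), "metadata records")

# Every sequence should appear in the metadata's strain column.
assert set(record.id for record in sequences) <= set(metadata["strain"])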
Download sequences for natural populations
All hemagglutinin sequences for natural populations are available through the GISAID database. To get access to the database, register for a free GISAID account. After logging into GISAID, select the "EpiFlu" tab from the navigation bar.
Downloading sequences from GISAID requires manually searching for specific accessions (i.e., sequence identifiers) and downloading the corresponding sequences. The maximum length of the GISAID search field is 1,000 characters, so you cannot search for all 20,000+ sequences at once. To facilitate the download process, we have created batches of accessions no longer than 1,000 characters in the file data/gisaid_batches.csv. Each of the 216 batches has its own id and expected number of sequences, to help you track your progress. Copy and paste the list of accessions from each batch into the "Search patterns" field of the GISAID search and select the "Search" button. An example search is shown below.
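If you prefer to copy each batch from the command line instead of opening the CSV in a spreadsheet, a loop like the following works. The column names used here (batch, accessions, and expected_sequences) are hypothetical; check the header of data/gisaid_batches.csv and substitute the actual names.

import pandas as pd

# Print one batch of accessions at a time for pasting into the GISAID
# "Search patterns" field. The column names below are assumptions; see
# the header of data/gisaid_batches.csv for the actual names.
batches = pd.read_csv("data/gisaid_batches.csv")
for _, batch in batches.iterrows():
    print(f"Batch {batch['batch']} (expect {batch['expected_sequences']} sequences):")
    print(batch["accessions"])
    input("Press Enter for the next batch...")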
From the search results display, select the checkbox in the top-left of the search display (above the checkbox for the first row of results). This will select all matching sequences to be downloaded. Click the "Download" button. An example of these search results is shown below.
When the download dialog appears, select the "Sequences (DNA) as FASTA" radio button. Click the checkbox near "HA" to only download hemagglutinin sequences. Delete the contents of the "FASTA Header" text field and paste in the following line instead:
Isolate name | Isolate ID | Collection date | Passage details/history | Submitting lab
Leave all other fields at their default values. The download interface should look like the following screenshot.
Click the "Download" button and name the resulting FASTA file with the same id as your current batch (e.g.,
gisaid_downloads/gisaid_epiflu_sequence_001.fasta
).
This file naming convention will make tracking your progress easier.
After the download completes, click the "Go back" button on the download dialog and then again from the search results display.
Copy and paste the next batch of ids into the search field and repeat these steps until you have downloaded all batches.
When you have downloaded sequences for all batches, concatenate them together into a single file.
cat gisaid_epiflu_sequences/gisaid_epiflu_sequence_*.fasta > gisaid_downloads.fasta
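Before moving on, it is worth counting the concatenated records to confirm the download looks complete (the full analysis uses 20,000+ sequences, as noted above). A minimal sketch:

import Bio.SeqIO

# Count records in the concatenated GISAID download.
count = sum(1 for _ in Bio.SeqIO.parse("gisaid_downloads.fasta", "fasta"))
print(count, "sequences in gisaid_downloads.fasta")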
Some strain names contain characters that IQ-TREE does not allow and which it will convert to underscores in its output trees. For example, the apostrophe in the name "Cote d'Ivoire" will be replaced with an underscore. To avoid mismatches between strain names caused by this IQ-TREE replacement, we replace those characters in the initial FASTA file at the beginning of the analysis using seqkit's replace command.
# Install seqkit. Optionally, use "mamba install" instead of "conda install".
conda install -c conda-forge -c bioconda seqkit
# Replace apostrophes with underscores in the FASTA record names.
seqkit replace -p "(')" -r "_" gisaid_downloads.fasta > gisaid_downloads.renamed.fasta
Use augur to parse out the metadata and sequences into separate files. Store these files in a directory with the same name as the natural samples in this analysis.
# Write out sequences and metadata for the validation sample.
mkdir -p data/natural/natural_sample_1_with_90_vpm
augur parse \
--sequences gisaid_downloads.renamed.fasta \
--output-sequences data/natural/natural_sample_1_with_90_vpm/filtered_sequences.fasta \
--output-metadata data/natural/natural_sample_1_with_90_vpm/strains_metadata.tsv \
--fields strain accession collection_date passage_category submitting_lab
# Copy the resulting sequences and metadata into the test sample directory.
mkdir -p data/natural/natural_sample_1_with_90_vpm_test_tree
cp data/natural/natural_sample_1_with_90_vpm/*.{fasta,tsv} data/natural/natural_sample_1_with_90_vpm_test_tree/
Now, you should be able to run the pipeline from start to finish. Confirm this is true by running snakemake in dry run mode.
snakemake --dryrun
Inspect derived titer data for natural populations
Due to existing data sharing agreements, we cannot publicly distribute raw titer measurements for hemagglutination inhibition (HI) assays and focus reduction assays (FRAs).
As an alternative, we provide the derived titer models produced by the augur titers command using the algorithms described in Neher et al. 2016.
For HI assays, we provide two different model files per analysis timepoint for the titer "tree model" and "substitution model".
These files are named titers-tree-model.json and titers-sub-model.json, respectively.
For FRAs, we provide model files for the tree model with names like fra-titers-tree-model.json.
Example paths for each of these files are listed below for a single timepoint in the analysis of the most recent A/H3N2 sequences.
- results/builds/natural/natural_sample_20191001/timepoints/2015-10-01/titers-tree-model.json
- results/builds/natural/natural_sample_20191001/timepoints/2015-10-01/titers-sub-model.json
- results/builds/natural/natural_sample_20191001/timepoints/2015-10-01/fra-titers-tree-model.json
These model files contain all information required to fit the HI- and FRA-based forecasting models described in the manuscript.
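These are standard JSON files produced by augur titers, so they can be inspected directly. Because the exact schema is defined by augur, the sketch below only lists the top-level keys rather than assuming particular field names.

import json

# Inspect one of the derived titer tree models for a single timepoint.
path = "results/builds/natural/natural_sample_20191001/timepoints/2015-10-01/titers-tree-model.json"
with open(path) as fh:
    model = json.load(fh)
print(sorted(model.keys()))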
Run the full analysis
Run the entire pipeline locally with four simultaneous jobs.
snakemake --use-conda -j 4
To confirm that your environment is configured properly, you can also run just one of the natural builds as follows.
snakemake --use-conda --config active_builds='natural_sample_1_with_90_vpm_sliding' -j 4
Alternatively, follow the Snakemake documentation to distribute the entire pipeline to your cloud or cluster accounts. The following is an example of how to distribute the pipeline on a SLURM-based cluster using a Snakemake profile.
snakemake --profile profiles/slurm-drmaa
Code Snippets
shell:
    """
    python3 scripts/partition_strains_by_timepoint.py \
        {input.metadata} \
        {wildcards.timepoint} \
        {output} \
        --years-back {params.years_back} \
        {params.reference_strains}
    """

shell:
    """
    python3 scripts/extract_sequences.py \
        --sequences {input.sequences} \
        --samples {input.strains} \
        --output {output}
    """

shell:
    """
    augur align \
        --sequences {input.sequences} \
        --reference-sequence {input.reference} \
        --output {output.alignment} \
        --remove-reference \
        --fill-gaps \
        --nthreads {threads}
    """

shell:
    """
    augur tree \
        --alignment {input.alignment} \
        --output {output.tree} \
        --method iqtree \
        --nthreads {threads} &> {log}
    """

shell:
    """
    augur refine \
        --tree {input.tree} \
        --alignment {input.alignment} \
        --metadata {input.metadata} \
        --output-tree {output.tree} \
        --output-node-data {output.node_data} \
        --timetree \
        --no-covariance \
        {params.clock_rate} \
        {params.clock_std_dev} \
        --coalescent {params.coalescent} \
        --date-confidence \
        --date-inference {params.date_inference} &> {log}
    """

shell:
    """
    augur frequencies \
        --method kde \
        --tree {input.tree} \
        --metadata {input.metadata} \
        --narrow-bandwidth {params.narrow_bandwidth} \
        --wide-bandwidth {params.wide_bandwidth} \
        --proportion-wide {params.proportion_wide} \
        --min-date {params.min_date} \
        --max-date {params.max_date} \
        --pivot-interval {params.pivot_frequency} \
        --output {output}
    """

shell:
    """python3 scripts/frequencies.py {input.tree} {input.metadata} {output} \
        --narrow-bandwidth {params.narrow_bandwidth} \
        --wide-bandwidth {params.wide_bandwidth} \
        --proportion-wide {params.proportion_wide} \
        --pivot-frequency {params.pivot_frequency} \
        --start-date {params.start_date} \
        --end-date {wildcards.timepoint} \
        --include-internal-nodes &> {log}"""

shell:
    """augur frequencies \
        --method diffusion \
        --tree {input.tree} \
        --metadata {input.metadata} \
        --output {output} \
        --include-internal-nodes \
        --stiffness {params.stiffness} \
        --inertia {params.inertia} \
        --pivot-interval {params.pivot_frequency} \
        --min-date {params.min_date} \
        --max-date {params.max_date} &> {log}"""

shell:
    """
    python3 scripts/frequencies_to_table.py \
        --tree {input.tree} \
        --frequencies {input.frequencies} \
        --method {params.method} \
        --output {output} \
        --annotations timepoint={wildcards.timepoint}
    """

shell:
    """
    python3 scripts/frequencies_to_table.py \
        --tree {input.tree} \
        --frequencies {input.frequencies} \
        --method {params.method} \
        --output {output} \
        --annotations timepoint={wildcards.timepoint}
    """
shell:
    """
    augur ancestral \
        --tree {input.tree} \
        --alignment {input.alignment} \
        --output {output.node_data} \
        --inference {params.inference} &> {log}
    """

shell:
    """
    augur translate \
        --tree {input.tree} \
        --ancestral-sequences {input.node_data} \
        --reference-sequence {input.reference} \
        --output {output.node_data} &> {log}
    """

shell:
    """
    augur reconstruct-sequences \
        --tree {input.tree} \
        --mutations {input.node_data} \
        --gene {wildcards.gene} \
        --output {output.aa_alignment} \
        --internal-nodes &> {log}
    """

shell:
    """
    python3 scripts/convert_translations_to_json.py \
        --tree {input.tree} \
        --alignment {input.translations} \
        --gene-names {params.gene_names} \
        --output {output.translations}
    """

shell:
    """
    python3 scripts/nonoverlapping_clades.py \
        --tree {input.tree} \
        --translations {input.translations} \
        --gene-names {params.gene_names} \
        --annotations timepoint={wildcards.timepoint} \
        --output {output.clades} \
        --output-tip-clade-table {output.tip_clade_table} &> {log}
    """

shell:
    """
    python3 scripts/calculate_delta_frequency.py \
        --tree {input.tree} \
        --frequencies {input.frequencies} \
        --frequency-method {params.method} \
        --delta-pivots {params.delta_pivots} \
        --output {output.delta_frequency} &> {log}
    """

shell:
    """
    augur traits \
        --tree {input.tree} \
        --metadata {input.metadata} \
        --output {output.node_data} \
        --columns {params.columns} \
        --confidence
    """
shell:
    """
    augur distance \
        --tree {input.tree} \
        --alignment {input.alignments} \
        --gene-names {params.genes} \
        --compare-to {params.comparisons} \
        --attribute-name {params.attribute_names} \
        --map {input.distance_maps} \
        --date-annotations {input.date_annotations} \
        --earliest-date {params.earliest_date} \
        --latest-date {params.latest_date} \
        --output {output}
    """

shell:
    """
    python3 scripts/pairwise_distances.py \
        --tree {input.tree} \
        --frequencies {input.frequencies} \
        --alignment {input.alignments} \
        --gene-names {params.genes} \
        --attribute-name {params.attribute_names} \
        --map {input.distance_maps} \
        --date-annotations {input.date_annotations} \
        --years-back-to-compare {params.years_back_to_compare} \
        --output {output} &> {log}
    """

shell:
    """
    python3 src/cross_immunity.py \
        --frequencies {input.frequencies} \
        --distances {input.distances} \
        --date-annotations {input.date_annotations} \
        --distance-attributes {params.distance_attributes} \
        --immunity-attributes {params.immunity_attributes} \
        --decay-factors {params.decay_factors} \
        --output {output}
    """

shell:
    """
    augur lbi \
        --tree {input.tree} \
        --branch-lengths {input.branch_lengths} \
        --output {output} \
        --attribute-names {params.names} \
        --tau {params.tau} \
        --window {params.window}
    """

shell:
    """
    augur lbi \
        --tree {input.tree} \
        --branch-lengths {input.branch_lengths} \
        --output {output} \
        --attribute-names {params.names} \
        --tau {params.tau} \
        --window {params.window} \
        --no-normalization
    """
shell:
    """
    augur titers sub \
        --titers {input.titers} \
        --alignment {input.alignments} \
        --tree {input.tree} \
        --gene-names {params.genes} \
        --allow-empty-model \
        --output {output.titers_model} &> {log}
    """

shell:
    """
    augur titers tree \
        --titers {input.titers} \
        --tree {input.tree} \
        --allow-empty-model \
        --output {output.titers_model} &> {log}
    """

shell:
    """
    augur titers tree \
        --titers {input.titers} \
        --tree {input.tree} \
        --allow-empty-model \
        --output {output.titers_model} &> {log}
    """

shell:
    """
    python3 scripts/rename_fields_in_fra_titer_models.py \
        --titers-model {input.titers_model} \
        --output {output.titers_model}
    """

shell:
    """
    python3 scripts/titer_model_to_distance_map.py \
        --model {input.model} \
        --output {output}
    """
shell:
    """
    python3 scripts/pairwise_distances.py \
        --tree {input.tree} \
        --frequencies {input.frequencies} \
        --alignment {input.alignments} \
        --gene-names {params.genes} \
        --attribute-name {params.attribute_names} \
        --map {input.distance_maps} \
        --date-annotations {input.date_annotations} \
        --years-back-to-compare {params.years_back_to_compare} \
        --output {output} &> {log}
    """

shell:
    """
    python3 scripts/pairwise_titer_tree_distances.py \
        --tree {input.tree} \
        --frequencies {input.frequencies} \
        --model {input.model} \
        --attribute-name {params.attribute_names} \
        --date-annotations {input.date_annotations} \
        --months-back-for-current-samples {params.months_back_for_current_samples} \
        --years-back-to-compare {params.years_back_to_compare} \
        --output {output} &> {log}
    """

shell:
    """
    python3 scripts/pairwise_titer_tree_distances.py \
        --tree {input.tree} \
        --frequencies {input.frequencies} \
        --model {input.model} \
        --model-attribute-name {params.model_attribute_name} \
        --attribute-name {params.attribute_names} \
        --date-annotations {input.date_annotations} \
        --months-back-for-current-samples {params.months_back_for_current_samples} \
        --years-back-to-compare {params.years_back_to_compare} \
        --output {output} &> {log}
    """

shell:
    """
    python3 src/cross_immunity.py \
        --frequencies {input.frequencies} \
        --distances {input.distances} \
        --date-annotations {input.date_annotations} \
        --distance-attributes {params.distance_attributes} \
        --immunity-attributes {params.immunity_attributes} \
        --decay-factors {params.decay_factors} \
        --output {output}
    """

shell:
    """
    python3 src/cross_immunity.py \
        --frequencies {input.frequencies} \
        --distances {input.distances} \
        --date-annotations {input.date_annotations} \
        --distance-attributes {params.distance_attributes} \
        --immunity-attributes {params.immunity_attributes} \
        --decay-factors {params.decay_factors} \
        --output {output}
    """

shell:
    """
    python3 src/cross_immunity.py \
        --frequencies {input.frequencies} \
        --distances {input.distances} \
        --date-annotations {input.date_annotations} \
        --distance-attributes {params.distance_attributes} \
        --immunity-attributes {params.immunity_attributes} \
        --decay-factors {params.decay_factors} \
        --output {output}
    """
shell:
    """
    python3 scripts/normalize_fitness.py \
        --metadata {input.metadata} \
        --frequencies-table {input.frequencies} \
        --frequency-method {params.preferred_frequency_method} \
        --output {output.fitness}
    """

shell:
    """
    python3 scripts/distance_from_consensus.py \
        --sequences {input.sequences} \
        --frequencies {input.frequencies} \
        --output {output.distances}
    """

shell:
    """
    python3 scripts/node_data_to_table.py \
        --tree {input.tree} \
        --metadata {input.metadata} \
        --jsons {input.node_data} \
        --output {output} \
        {params.excluded_fields_arg} \
        --annotations timepoint={wildcards.timepoint} \
        lineage={params.lineage} \
        segment={params.segment}
    """

shell:
    """
    python3 scripts/merge_node_data_and_frequencies.py \
        --node-data {input.node_data} \
        --kde-frequencies {input.kde_frequencies} \
        --diffusion-frequencies {input.diffusion_frequencies} \
        --preferred-frequency-method {params.preferred_frequency_method} \
        --output {output.table}
    """

shell:
    """
    python3 scripts/collect_tables.py \
        --tables {input} \
        --output {output.attributes}
    """

shell:
    """
    python3 scripts/annotate_naive_tip_attribute.py \
        --tip-attributes {input.attributes} \
        --output {output.attributes}
    """

shell:
    """
    python3 scripts/calculate_target_distances.py \
        --tip-attributes {input.attributes} \
        --delta-months {params.delta_months} \
        --sequence-attribute-name {params.sequence_attribute_name} \
        --output {output}
    """

shell:
    """
    python3 src/weighted_distances.py \
        --tip-attributes {input.attributes} \
        --distances {input.distances} \
        --delta-months {params.delta_months} \
        --output {output}
    """
shell:
    """
    python3 src/fit_model.py \
        --tip-attributes {input.attributes} \
        --training-window {params.training_window} \
        --delta-months {params.delta_months} \
        --predictors {params.predictors} \
        --cost-function {params.cost_function} \
        --l1-lambda {params.l1_lambda} \
        --target distances \
        --distances {input.distances} \
        --errors-by-timepoint {output.errors} \
        --coefficients-by-timepoint {output.coefficients} \
        --include-scores \
        --output {output.model} &> {log}
    """

shell:
    """
    python3 scripts/extract_minimal_models_by_distances.py \
        --model {input.model} \
        --output {output.model}
    """

shell:
    """
    python3 scripts/annotate_model_tables.py \
        --tip-attributes {input.attributes} \
        --model {input.model} \
        --errors-by-timepoint {input.errors} \
        --coefficients-by-timepoint {input.coefficients} \
        --annotated-errors-by-timepoint {output.errors} \
        --annotated-coefficients-by-timepoint {output.coefficients} \
        --delta-months {params.delta_months} \
        --annotations type="{wildcards.type}" sample="{wildcards.sample}" error_type="{params.error_type}"
    """

shell:
    """
    python3 scripts/collect_tables.py \
        --tables {input} \
        --output {output.tip_clade_table}
    """

shell:
    """
    python3 scripts/select_clades.py \
        --tip-attributes {input.attributes} \
        --tips-to-clades {input.tips_to_clades} \
        --delta-months {params.delta_months} \
        --output {output} &> {log}
    """

shell:
    """
    python3 src/fit_model.py \
        --tip-attributes {input.attributes} \
        --final-clade-frequencies {input.final_clade_frequencies} \
        --training-window {params.training_window} \
        --delta-months {params.delta_months} \
        --predictors {params.predictors} \
        --cost-function {params.cost_function} \
        --l1-lambda {params.l1_lambda} \
        --pseudocount {params.pseudocount} \
        --target clades \
        --output {output} &> {log}
    """

shell:
    """
    python3 scripts/plot_tree.py {input} {output} &> {log}
    """

shell: "gs -dBATCH -dNOPAUSE -q -sDEVICE=pdfwrite -sOutputFile={output} {input}"
shell:
    """
    python3 scripts/calculate_target_distances.py \
        --tip-attributes {input.attributes} \
        --delta-months {params.delta_months} \
        --output {output}
    """

shell:
    """
    python3 src/forecast_model.py \
        --tip-attributes {input.attributes} \
        --distances {input.distances} \
        --frequencies {input.frequencies} \
        --model {input.model} \
        --delta-months {params.delta_months} \
        --output-node-data {output.node_data} \
        --output-frequencies {output.frequencies}
    """

shell:
    """
    augur export \
        --tree {input.tree} \
        --metadata {input.metadata} \
        --node-data {input.node_data} {input.forecasts} \
        --colors {input.colors} \
        --auspice-config {input.auspice_config} \
        --output-tree {output.auspice_tree} \
        --output-meta {output.auspice_metadata} \
        --panels {params.panels} \
        --minify-json
    """

shell:
    """
    python3 src/forecast_model.py \
        --tip-attributes {input.attributes} \
        --distances {input.distances} \
        --model {input.model} \
        --delta-months {params.delta_months} \
        --output-table {output.table}
    """

shell:
    """
    python3 src/fit_model.py \
        --tip-attributes {input.attributes} \
        --target distances \
        --distances {input.distances} \
        --fixed-model {input.model} \
        --errors-by-timepoint {output.errors} \
        --coefficients-by-timepoint {output.coefficients} \
        --include-scores \
        --output {output.model} &> {log}
    """

shell:
    """
    python3 scripts/annotate_model_tables.py \
        --tip-attributes {input.attributes} \
        --model {input.model} \
        --errors-by-timepoint {input.errors} \
        --coefficients-by-timepoint {input.coefficients} \
        --annotated-errors-by-timepoint {output.errors} \
        --annotated-coefficients-by-timepoint {output.coefficients} \
        --delta-months {params.delta_months} \
        --annotations type="{wildcards.type}" sample="{params.sample}" error_type="test"
    """

shell:
    """
    python3 scripts/plot_validation_figure_by_population.py \
        --tip-attributes {input.attributes} \
        --tips-to-clades {input.tips_to_clades} \
        --forecasts {input.forecasts} \
        --model-errors {input.model_errors} \
        --population {wildcards.type} \
        --sample {wildcards.sample} \
        --predictors {wildcards.predictors} \
        --output {output.figure} \
        --output-clades-table {output.clades} \
        --output-ranks-table {output.ranks}
    """
shell:
    """
    cd data/simulated/{wildcards.sample} && java -jar {SNAKEMAKE_DIR}/dist/santa.jar -seed={params.seed} {SNAKEMAKE_DIR}/{input.simulation_config}
    """

shell:
    """
    augur parse \
        --sequences {input.sequences} \
        --output-sequences {output.sequences} \
        --output-metadata {output.metadata} \
        --fields {params.fasta_fields}
    """

shell:
    """
    python3 scripts/standardize_simulated_sequence_dates.py \
        --metadata {input.metadata} \
        --start-year {params.start_year} \
        --generations-per-year {params.generations_per_year} \
        --output {output.metadata}
    """

shell:
    """
    augur filter \
        --sequences {input.sequences} \
        --metadata {input.metadata} \
        --min-date {params.min_date} \
        --group-by {params.group_by} \
        --sequences-per-group {params.viruses_per_month} \
        --output {output}
    """

shell:
    """
    python3 scripts/filter_simulated_metadata.py \
        --sequences {input.sequences} \
        --metadata {input.metadata} \
        --output {output.metadata}
    """
shell:
    """
    python3 {path_to_fauna}/vdb/download.py \
        --database vdb \
        --virus flu \
        --fasta_fields {params.fasta_fields} \
        --resolve_method split_passage \
        --select locus:{params.segment} lineage:seasonal_{params.lineage} \
        --path data/natural/{wildcards.sample} \
        --fstem original_sequences
    """

shell:
    """
    python3 {path_to_fauna}/tdb/download.py \
        --database {params.databases} \
        --virus flu \
        --subtype {params.lineage} \
        --select assay_type:{params.assay} \
        --path data/natural/{wildcards.sample} \
        --fstem complete \
        --ftype json
    """

shell:
    """
    python3 {path_to_fauna}/tdb/download.py \
        --database {params.databases} \
        --virus flu \
        --subtype {params.lineage} \
        --select assay_type:{params.assay} \
        --path data/natural/{wildcards.sample} \
        --fstem complete_fra \
        --ftype json
    """

shell:
    """
    python3 scripts/get_titers_by_passage.py \
        --titers {input.titers} \
        --passage-type {params.passage} \
        --output {output.titers}
    """

shell:
    """
    python3 scripts/get_titers_by_passage.py \
        --titers {input.titers} \
        --passage-type {params.passage} \
        --output {output.titers}
    """
shell:
    """
    augur parse \
        --sequences {input.sequences} \
        --output-sequences {output.sequences} \
        --output-metadata {output.metadata} \
        --fields {params.fasta_fields}
    """

shell:
    """
    augur filter \
        --sequences {input.sequences} \
        --metadata {input.metadata} \
        --min-length {params.min_length} \
        --exclude {params.exclude} \
        --exclude-where country=? region=? passage=egg \
        --output {output}
    """

shell:
    """
    python3 scripts/filter_strains_with_ambiguous_dates.py \
        --metadata {input.metadata} \
        --output {output.metadata}
    """

shell:
    """
    python3 scripts/select_strains.py \
        --sequences {input.sequences} \
        --metadata {input.metadata} \
        --segments {params.segment} \
        --include {input.include} \
        --lineage {params.lineage} \
        --time-interval {params.start_date} {params.end_date} \
        --viruses_per_month {params.viruses_per_month} \
        --titers {input.titers} \
        --output {output.strains}
    """

shell:
    """
    python3 scripts/filter_metadata_by_strains.py \
        --metadata {input.metadata} \
        --strains {input.strains} \
        --output {output.metadata}
    """
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 | import argparse import json import pandas as pd if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--tip-attributes", required=True, help="tab-delimited file describing tip attributes at all timepoints with standardized predictors and weighted distances to the future") parser.add_argument("--model", required=True, help="JSON representing the model fit with training and cross-validation results, beta coefficients for predictors, and summary statistics") parser.add_argument("--errors-by-timepoint", help="data frame of cross-validation errors by validation timepoint") parser.add_argument("--coefficients-by-timepoint", help="data frame of coefficients by validation timepoint") parser.add_argument("--annotated-errors-by-timepoint", help="annotated model errors by timepoint") parser.add_argument("--annotated-coefficients-by-timepoint", help="annotated model coefficients by timepoint") parser.add_argument("--delta-months", type=int, help="number of months to project clade frequencies into the future") parser.add_argument("--annotations", nargs="+", help="additional annotations to add to the output table in the format of 'key=value' pairs") args = parser.parse_args() # Load tip attributes to calculate within-timepoint diversity. tips = pd.read_csv(args.tip_attributes, sep="\t", parse_dates=["timepoint"]) # Calculate weighted distance of each timepoint to itself. tips["average_distance_to_present"] = tips["weighted_distance_to_present"] * tips["frequency"] distances_to_present_per_timepoint = tips.groupby("timepoint")["average_distance_to_present"].sum().reset_index() print(distances_to_present_per_timepoint.head()) # Load the model JSON to get access to projected frequencies for tips. with open(args.model, "r") as fh: model = json.load(fh) # Collect all projected frequencies and weighted distances, to enable # calculation of weighted average distances within and between seasons. df = pd.concat([ pd.DataFrame(scores["validation_data"]["y_hat"]) for scores in model["scores"] ]) # Prepare to calculate weighted distance of each timepoint's projected # future to its observed future timepoint. df["average_distance_to_future"] = df["weighted_distance_to_future"] * df["projected_frequency"] # Sum the scaled weighted distances to get average distances per timepoint. distances_per_timepoint = df.groupby("timepoint").aggregate({ "average_distance_to_future": "sum" }).reset_index() # Prepare timepoint for joins with model errors. distances_per_timepoint["timepoint"] = pd.to_datetime(distances_per_timepoint["timepoint"]) # Load the original model table output for validation/test errors. errors = pd.read_csv(args.errors_by_timepoint, sep="\t", parse_dates=["validation_timepoint"]) errors["future_timepoint"] = errors["validation_timepoint"] + pd.DateOffset(months=args.delta_months) # Annotate information about the present's estimate of the future. 
print("Errors: %s" % str(errors.shape)) errors = errors.merge( distances_per_timepoint.loc[:, ["timepoint", "average_distance_to_future"]].copy(), left_on=["validation_timepoint"], right_on=["timepoint"] ).drop(columns=["timepoint"]) print("Errors with distance to future: %s" % str(errors.shape)) # Annotate information about the future's distance to itself for the # present timepoints. errors = errors.merge( distances_to_present_per_timepoint, left_on=["future_timepoint"], right_on=["timepoint"] ).rename(columns={ "average_distance_to_present": "average_diversity_in_future" }) #drop(columns=["timepoint", "future_timepoint"]). print("Errors with diversity in future: %s" % str(errors.shape)) # Load coefficients to which annotations will be added. coefficients = pd.read_csv(args.coefficients_by_timepoint, sep="\t") # Add any additional annotations requested by the user in the format of # "key=value" pairs where each key becomes a new column with the given # value. if args.annotations: for annotation in args.annotations: key, value = annotation.split("=") errors[key] = value coefficients[key] = value # Save annotated tables. errors.to_csv(args.annotated_errors_by_timepoint, sep="\t", header=True, index=False) coefficients.to_csv(args.annotated_coefficients_by_timepoint, sep="\t", header=True, index=False) |
import argparse

import pandas as pd


if __name__ == "__main__":
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--tip-attributes", required=True, help="table of tip attributes from one or more timepoints")
    parser.add_argument("--output", required=True, help="table of tip attributes annotated with a 'naive' predictor")
    args = parser.parse_args()

    # Annotate a predictor for a naive model with no growth.
    df = pd.read_csv(args.tip_attributes, sep="\t")
    df["naive"] = 0.0
    df.to_csv(args.output, sep="\t", index=False)
3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 | import argparse from augur.frequency_estimators import TreeKdeFrequencies from augur.utils import write_json import Bio.Phylo from collections import defaultdict import json import numpy as np if __name__ == "__main__": parser = argparse.ArgumentParser( description="Calculate the change in frequency for clades over time", formatter_class=argparse.ArgumentDefaultsHelpFormatter ) parser.add_argument("--tree", required=True, help="Newick tree") parser.add_argument("--frequencies", required=True, help="frequencies JSON") parser.add_argument("--frequency-method", required=True, choices=["kde", "diffusion"], help="method used to estimate frequencies") parser.add_argument("--clades", help="JSON of clade annotations for nodes in the given tree") parser.add_argument("--delta-pivots", type=int, default=1, help="number of frequency pivots to look back in time for change in frequency calculation") parser.add_argument("--output", required=True, help="JSON of delta frequency annotations for nodes in the given tree") args = parser.parse_args() # Load the tree. tree = Bio.Phylo.read(args.tree, "newick") # Load frequencies. with open(args.frequencies, "r") as fh: frequencies_json = json.load(fh) if args.frequency_method == "kde": kde_frequencies = TreeKdeFrequencies.from_json(frequencies_json) frequencies = kde_frequencies.frequencies # Load clades. with open(args.clades, "r") as fh: clades_json = json.load(fh) clades_by_node = { key: value["clade_membership"] for key, value in clades_json["nodes"].items() } # Calculate the total frequency per clade at the most recent timepoint and # requested timepoint in the past using non-zero tip frequencies. current_clade_frequencies = defaultdict(float) previous_clade_frequencies = defaultdict(float) for tip in tree.find_clades(terminal=True): # Add tip to current clade frequencies. current_clade_frequencies[clades_by_node[tip.name]] += frequencies[tip.name][-1] # Add tip to previous clade frequencies. previous_clade_frequencies[clades_by_node[tip.name]] += frequencies[tip.name][-(args.delta_pivots + 1)] # Determine the total time that elapsed between the current and past timepoint. delta_time = kde_frequencies.pivots[-1] - kde_frequencies.pivots[-(args.delta_pivots + 1)] # Calculate the change in frequency over time elapsed for each clade. delta_frequency_by_clade = {} for clade, current_frequency in current_clade_frequencies.items(): # If the current clade was not observed in the previous timepoint, it # will have a zero frequency. delta_frequency_by_clade[clade] = (current_frequency - previous_clade_frequencies.get(clade, 0.0)) / delta_time # Assign clade delta frequencies to all corresponding tips and internal nodes. delta_frequency = {} for node in tree.find_clades(terminal=True): delta_frequency[node.name] = { "delta_frequency": delta_frequency_by_clade.get(clades_by_node[node.name], 0.0) } else: frequencies = frequencies_json # Determine the total time that elapsed between the current and past timepoint. 
delta_time = frequencies["pivots"][-1] - frequencies["pivots"][-(args.delta_pivots + 1)] delta_frequency = {} for node in tree.find_clades(terminal=True): delta_frequency[node.name] = { "delta_frequency": (frequencies[node.name]["global"][-1] - frequencies[node.name]["global"][-(args.delta_pivots + 1)]) / delta_time } # Write out the node annotations. write_json({"nodes": delta_frequency}, args.output) |
4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 | import argparse import numpy as np import pandas as pd if __name__ == '__main__': parser = argparse.ArgumentParser( description="Calculate pairwise distances between samples at adjacent timepoints (t and t - delta months)", formatter_class=argparse.ArgumentDefaultsHelpFormatter ) parser.add_argument("--tip-attributes", required=True, help="a tab-delimited file describing tip attributes at one or more timepoints") parser.add_argument("--delta-months", required=True, nargs="+", type=int, help="number of months between timepoints to be compared") parser.add_argument("--output", help="tab-delimited file of pairwise distances between tips in timepoints separate by the given delta time", required=True) parser.add_argument("--sequence-attribute-name", default="aa_sequence", help="attribute name of sequences to compare") args = parser.parse_args() # Load tip attributes. tips = pd.read_csv(args.tip_attributes, sep="\t", parse_dates=["timepoint"]) # Open output file handle to enable streaming distances to disk instead of # storing them in memory. output_handle = open(args.output, "w") output_handle.write("\t".join(["sample", "other_sample", "distance"]) + "\n") # Calculate pairwise distances between all tips within a timepoint and at # the next timepoint as defined by the given delta months. for timepoint, timepoint_df in tips.groupby("timepoint"): current_tips = [ tuple(values) for values in timepoint_df.loc[:, ["strain", args.sequence_attribute_name]].values.tolist() ] comparison_tips = current_tips for delta_month in args.delta_months: future_timepoint_df = tips[tips["timepoint"] == (timepoint + pd.DateOffset(months=delta_month))] future_tips = [ tuple(values) for values in future_timepoint_df.loc[:, ["strain", args.sequence_attribute_name]].values.tolist() ] comparison_tips = comparison_tips + future_tips comparison_tips = list(set(comparison_tips)) for current_tip, current_tip_sequence in current_tips: current_tip_sequence_array = np.frombuffer(current_tip_sequence.encode(), dtype="S1") for future_tip, future_tip_sequence in comparison_tips: future_tip_sequence_array = np.frombuffer(future_tip_sequence.encode(), dtype="S1") distance = (current_tip_sequence_array != future_tip_sequence_array).sum() output_handle.write("\t".join([current_tip, future_tip, str(distance)]) + "\n") # Close output. output_handle.close() |
7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 | import argparse import numpy as np import pandas as pd import re import sys class BadTransformException(Exception): pass def parse_transform(transform): """Parse a transform string into its corresponding new column name, Python function, and original column name. Return `None` for the Python function if the requested function is not valid. Parameters ---------- transform : str transformation definition string (e.g., "log_lbi=log(lbi)") Returns ------- str, callable, str new column name, transformation function, and original column name >>> parse_transform("log_lbi=log(lbi)") ('log_lbi', <ufunc 'log'>, 'lbi') >>> parse_transform("fake_col=fake(col)") Traceback (most recent call last): ... collect_tables.BadTransformException: the requested function was invalid >>> parse_transform("bad_transform") Traceback (most recent call last): ... collect_tables.BadTransformException: the requested transform was malformed """ match = re.match(r"(?P<new_column>\w+)=(?P<function>\w+)\((?P<column>\w+)\)", transform) if match is None: raise BadTransformException("the requested transform was malformed") new_column, function_string, column = match.groups() function = getattr(np, function_string, None) if function is None: raise BadTransformException("the requested function was invalid") return new_column, function, column if __name__ == '__main__': parser = argparse.ArgumentParser( description="Collect two or more data frame tables", formatter_class=argparse.ArgumentDefaultsHelpFormatter ) parser.add_argument("--tables", nargs="+", required=True, help="tab-delimited files with the same columns to be collected into a single file") parser.add_argument("--transforms", nargs="+", help="a list of new columns to create by transformation of existing columns (e.g., 'log_lbi=log(lbi)')") parser.add_argument("--output", required=True, help="tab-delimited output file collecting the given input tables") args = parser.parse_args() # Concatenate tip attributes across all timepoints. df = pd.concat([pd.read_table(table) for table in args.tables], ignore_index=True) # Apply transformations. if args.transforms: for transform in args.transforms: try: new_column, transform_function, column = parse_transform(transform) df[new_column] = transform_function(df[column]) except BadTransformException as e: print(f"Error: Could not apply transformation '{transform}' because {e}", file=sys.stderr) except Exception as e: print(f"Error: Failed to apply transformation '{transform}' ({e})", file=sys.stderr) df.to_csv(args.output, sep="\t", index=False) |
import argparse
import pandas as pd


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--tables", nargs="+", help="tables to concatenate")
    parser.add_argument("--separator", default="\t", help="separator between columns in the given tables")
    parser.add_argument("--output", help="concatenated table")
    args = parser.parse_args()

    # Concatenate tables.
    df = pd.concat([
        pd.read_csv(table_file, sep=args.separator)
        for table_file in args.tables
    ], ignore_index=True, sort=True)

    df.to_csv(args.output, sep=args.separator, header=True, index=False)
4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 | import argparse from augur.reconstruct_sequences import load_alignments from augur.utils import write_json import Bio.Phylo if __name__ == '__main__': parser = argparse.ArgumentParser( description="Convert translation FASTA to a node data JSON", formatter_class=argparse.ArgumentDefaultsHelpFormatter ) parser.add_argument("--tree", required=True, help="Newick file for the tree used to construct the given node data JSONs") parser.add_argument("--alignment", nargs="+", help="sequence(s) to be used, supplied as FASTA files", required=True) parser.add_argument('--gene-names', nargs="+", type=str, help="names of the sequences in the alignment, same order assumed", required=True) parser.add_argument("--output", help="JSON file with translated sequences by node", required=True) parser.add_argument("--include-internal-nodes", action="store_true", help="include data associated with internal nodes in the output JSON") parser.add_argument("--attribute-name", default="aa_sequence", help="name of attribute to store the complete amino acid sequence of each node") args = parser.parse_args() # Load tree. tree = Bio.Phylo.read(args.tree, "newick") # Load sequences. alignments = load_alignments(args.alignment, args.gene_names) # Concatenate translated sequences into a single sequence indexed by sample name. is_node_terminal = {node.name: node.is_terminal() for node in tree.find_clades()} translations = {} for gene in args.gene_names: alignment = alignments[gene] for record in alignment: if is_node_terminal[record.name] or args.include_internal_nodes: # Initialize new samples by name with an empty string. if record.name not in translations: translations[record.name] = {args.attribute_name: ""} # Append the current gene's amino acid sequence to the current # string for this sample. translations[record.name][args.attribute_name] += str(record.seq) # Write out the node annotations. write_json({"nodes": translations}, args.output) |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 | import argparse from augur.utils import read_node_data, write_json import Bio.Align from Bio.Seq import Seq from Bio.SeqRecord import SeqRecord from collections import Counter import numpy as np import pandas as pd if __name__ == "__main__": parser = argparse.ArgumentParser( description="Calculate consensus sequence for all non-zero frequency strains and the distance of each strain from the resulting consensus.", formatter_class=argparse.ArgumentDefaultsHelpFormatter ) parser.add_argument("--sequences", required=True, help="node data JSON containing sequences to find a consensus for") parser.add_argument("--frequencies", required=True, help="table of strain frequencies at the current timepoint") parser.add_argument("--sequence-attribute", default="aa_sequence", help="attribute in node data JSON containing the sequence data to use") parser.add_argument("--frequency-attribute", default="kde_frequency", help="attribute in frequency table representing the frequency data to use") parser.add_argument("--output", required=True, help="node data JSON with consensus sequence and distances from the consensus per strain") args = parser.parse_args() # Load sequence data from a node data JSON file. node_sequences = read_node_data(args.sequences) # Load frequency data. frequencies = pd.read_csv(args.frequencies, sep="\t") # Select names of strains with non-zero frequencies. strains = set(frequencies.query(f"{args.frequency_attribute} > 0.0")["strain"].values) # Select sequences for strains with non-zero frequencies. sequences = Bio.Align.MultipleSeqAlignment([ SeqRecord( Seq( record[args.sequence_attribute], ), id=strain ) for strain, record in node_sequences["nodes"].items() if strain in strains ]) # Output will store the consensus sequence and the distance of each strain # to the consensus. output = { "nodes": {} } # Calculate the consensus sequence using a majority-rule approach where we # take the most common value in each column. consensus = "".join( Counter(sequences[:, i]).most_common(1)[0][0] for i in range(sequences.get_alignment_length()) ) output["consensus"] = consensus # Calculate the distance of each strain sequence from the consensus. consensus_array = np.frombuffer(consensus.encode(), dtype="S1") for sequence in sequences: sequence_array = np.frombuffer(str(sequence.seq).encode(), dtype="S1") distance = int((consensus_array != sequence_array).sum()) output["nodes"][sequence.id] = { "distance_from_consensus": distance } # Output the results. write_json(output, args.output) |
import argparse
import json


if __name__ == "__main__":
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--model", required=True, help="JSON for a complete fitness model output")
    parser.add_argument("--output", required=True, help="JSON for a minimal fitness model (coefficients only)")
    args = parser.parse_args()

    with open(args.model, "r") as fh:
        model = json.load(fh)

    if "scores" in model:
        del model["scores"]

    with open(args.output, "w") as oh:
        json.dump(model, oh, indent=1)
import argparse
import Bio
import Bio.SeqIO


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Extract sample sequences by name",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("--sequences", required=True, help="FASTA file of all sample sequences")
    parser.add_argument("--samples", required=True, help="text file of samples names with one name per line")
    parser.add_argument("--output", required=True, help="FASTA file of extracted sample sequences")
    args = parser.parse_args()

    with open(args.samples) as infile:
        samples = set([line.strip() for line in infile])

    with open(args.output, 'w') as outfile:
        for seq in Bio.SeqIO.parse(args.sequences, 'fasta'):
            if seq.name in samples:
                Bio.SeqIO.write(seq, outfile, 'fasta')
import argparse
import pandas as pd


if __name__ == "__main__":
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--metadata", required=True, help="table of metadata to be filtered based on a date column")
    parser.add_argument("--strains", required=True, help="text file with one strain per line that should be included in the output")
    parser.add_argument("--output", required=True, help="table of filtered metadata")
    args = parser.parse_args()

    metadata = pd.read_table(args.metadata)
    strains = pd.read_table(args.strains, header=None, names=["strain"])
    selected_metadata = strains.merge(metadata, how="left", on="strain")
    selected_metadata.to_csv(args.output, sep="\t", index=False)
import argparse
import Bio.SeqIO
import pandas as pd


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--sequences", help="simulated sequences that have already been filtered")
    parser.add_argument("--metadata", help="original metadata table for simulated sequences")
    parser.add_argument("--output", help="filtered metadata where only samples present in the given sequences are included")
    args = parser.parse_args()

    # Get a list of all samples that passed the sequence filtering step.
    sequences = Bio.SeqIO.parse(args.sequences, "fasta")
    sample_ids = [sequence.id for sequence in sequences]

    # Load all metadata.
    metadata = pd.read_csv(args.metadata, sep="\t")
    filtered_metadata = metadata[metadata["strain"].isin(sample_ids)].copy()

    # Save only the metadata records that have entries in the filtered sequences.
    filtered_metadata.to_csv(args.output, sep="\t", header=True, index=False)
import argparse
import pandas as pd


if __name__ == "__main__":
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--metadata", required=True, help="table of metadata to be filtered based on a date column")
    parser.add_argument("--date-field", default="date", help="name of date column in the metadata")
    parser.add_argument("--output", required=True, help="table of filtered metadata")
    args = parser.parse_args()

    df = pd.read_csv(args.metadata, sep="\t")

    # Exclude strains with ambiguous collection dates.
    df[~df[args.date_field].str.contains("XX")].to_csv(args.output, sep="\t", header=True, index=False)
14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 | import argparse from augur.utils import read_metadata, get_numerical_dates from augur.frequencies import TreeKdeFrequencies import Bio import Bio.Phylo import datetime import json import numpy as np import os import sys def get_time_interval_as_floats(time_interval): """ Converts the given datetime interval to start and end floats. Returns: start_date (float): the start of the given time interval end_date (float): the end of the given time interval """ start_date = time_interval[1].year + (time_interval[1].month - 1) / 12.0 end_date = time_interval[0].year + (time_interval[0].month - 1) / 12.0 return start_date, end_date if __name__ == "__main__": parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter ) parser.add_argument("tree", help="Newick tree") parser.add_argument("metadata", help="tab-delimited metadata for tips in the given tree including a date field") parser.add_argument("frequencies", help="JSON with frequencies estimated from the given tree and used to estimate the given parameters") parser.add_argument("--narrow-bandwidth", type=float, default=1 / 12.0, help="the bandwidth for the narrow KDE") parser.add_argument("--wide-bandwidth", type=float, default=3 / 12.0, help="the bandwidth for the wide KDE") parser.add_argument("--proportion-wide", type=float, default=0.2, help="the proportion of the wide bandwidth to use in the KDE mixture model") parser.add_argument("--pivot-frequency", type=int, default=1, help="number of months between pivots") parser.add_argument("--start-date", help="the start of the interval to estimate frequencies across") parser.add_argument("--end-date", help="the end of the interval to estimate frequencies across") parser.add_argument("--include-internal-nodes", action="store_true", help="calculate frequencies for internal nodes as well as tips") parser.add_argument("--weights", help="a dictionary of key/value mappings in JSON format used to weight tip frequencies") parser.add_argument("--weights-attribute", help="name of the attribute on each tip whose values map to the given weights dictionary") parser.add_argument("--precision", type=int, default=6, help="number of decimal places to retain in frequency estimates") parser.add_argument("--censored", action="store_true", help="calculate censored frequencies at each pivot") args = parser.parse_args() # Load tree. tree = Bio.Phylo.read(args.tree, "newick") # Load metadata. metadata, columns = read_metadata(args.metadata) dates = get_numerical_dates(metadata, fmt='%Y-%m-%d') # Annotate tree with dates and other metadata. for tip in tree.find_clades(terminal=True): tip.attr = {"num_date": np.mean(dates[tip.name])} # Annotate tips with metadata to enable filtering and weighting of # frequencies by metadata attributes. for key, value in metadata[tip.name].items(): tip.attr[key] = value # Convert start and end dates to floats from time interval format. if args.start_date is not None and args.end_date is not None: # Convert the string time interval to a datetime instance and then to floats. 
time_interval = [ datetime.datetime.strptime(time, "%Y-%m-%d") for time in (args.end_date, args.start_date) ] start_date, end_date = get_time_interval_as_floats(time_interval) else: start_date = end_date = None # Load weights if they have been provided. if args.weights: with open(args.weights, "r") as fh: weights = json.load(fh) weights_attribute = args.weights_attribute else: weights = None weights_attribute = None # Estimate frequencies. frequencies = TreeKdeFrequencies( sigma_narrow=args.narrow_bandwidth, sigma_wide=args.wide_bandwidth, proportion_wide=args.proportion_wide, pivot_frequency=args.pivot_frequency, start_date=start_date, end_date=end_date, weights=weights, weights_attribute=weights_attribute, include_internal_nodes=args.include_internal_nodes, censored=args.censored ) frequencies.estimate(tree) # Export frequencies to JSON. json_frequencies = frequencies.to_json() # Set precision of frequency estimates. for clade in json_frequencies["data"]["frequencies"]: json_frequencies["data"]["frequencies"][clade] = np.around( np.array( json_frequencies["data"]["frequencies"][clade] ), args.precision ).tolist() with open(args.frequencies, "w") as oh: json.dump(json_frequencies, oh, indent=1, sort_keys=True) |
4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 | import argparse import Bio.Phylo import json import pandas as pd if __name__ == '__main__': parser = argparse.ArgumentParser( description="Convert frequencies JSON to a data frame", formatter_class=argparse.ArgumentDefaultsHelpFormatter ) parser.add_argument("--tree", required=True, help="Newick file for the tree used to estimate the given frequencies") parser.add_argument("--frequencies", required=True, help="frequencies JSON") parser.add_argument("--method", required=True, choices=["kde", "diffusion"], help="method used to estimate frequencies") parser.add_argument("--annotations", nargs="+", help="additional annotations to add to the output table in the format of 'key=value' pairs") parser.add_argument("--output", required=True, help="tab-delimited file with frequency per node at the last available timepoint") parser.add_argument("--include-internal-nodes", action="store_true", help="include data associated with internal nodes in the output table") parser.add_argument("--minimum-frequency", type=float, default=1e-5, help="minimum frequency to keep below which values will be zeroed and all others renormalized to sum to one") args = parser.parse_args() # Load tree. tree = Bio.Phylo.read(args.tree, "newick") # Load frequencies. with open(args.frequencies, "r") as fh: frequencies_json = json.load(fh) if args.method == "kde": frequencies = frequencies_json["data"]["frequencies"] else: frequencies = { node_name: region_frequencies["global"] for node_name, region_frequencies in frequencies_json.items() if node_name not in ["pivots", "counts", "generated_by"] } # Collect the last frequency for each node keeping only terminal nodes # (tips) unless internal nodes are also requested. frequency_key = "%s_frequency" % args.method records = [ { "strain": node.name, frequency_key: float(frequencies[node.name][-1]), "is_terminal": node.is_terminal() } for node in tree.find_clades() if args.include_internal_nodes or node.is_terminal() ] # Convert frequencies data into a data frame. df = pd.DataFrame(records) # Replace records whose frequency values are below the requested minimum # with zeros and renormalize the remaining records to sum to one. to_zero = df[frequency_key] < args.minimum_frequency not_to_zero = ~to_zero df.loc[to_zero, frequency_key] = 0.0 df.loc[not_to_zero, frequency_key] = df.loc[not_to_zero, frequency_key] / df.loc[not_to_zero, frequency_key].sum() # Add any additional annotations requested by the user in the format of # "key=value" pairs where each key becomes a new column with the given # value. if args.annotations: for annotation in args.annotations: key, value = annotation.split("=") df[key] = value # Save the table. df.to_csv(args.output, sep="\t", float_format="%.6f", index=False, header=True) |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 | import argparse import pandas as pd if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--titers", required=True, help="JSON of complete titer records to be filtered by passage type") parser.add_argument("--passage-type", required=True, help="type of passage for viruses used in titer assays") parser.add_argument("--output", required=True, help="table of filtered titer records by passage type") args = parser.parse_args() df = pd.read_json(args.titers) passaged = (df["serum_passage_category"] == args.passage_type) tdb_passaged = df["index"].apply(lambda index: isinstance(index, list) and args.passage_type in index) tsv_fields = [ "virus_strain", "serum_strain", "serum_id", "source", "titer", "assay_type" ] titers_df = df.loc[(passaged | tdb_passaged), tsv_fields] titers_df.to_csv(args.output, sep="\t", header=False, index=False) |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 | import argparse import pandas as pd if __name__ == "__main__": parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument("--node-data", required=True, help="table of node data from one or more timepoints") parser.add_argument("--kde-frequencies", required=True, help="table of KDE frequencies by strain name and timepoint") parser.add_argument("--diffusion-frequencies", required=True, help="table of diffusion frequencies by strain name and timepoint") parser.add_argument("--preferred-frequency-method", choices=["kde", "diffusion"], help="specify which frequency method should be used for the primary frequency column") parser.add_argument("--output", required=True, help="table of merged node data and frequencies") args = parser.parse_args() node_data = pd.read_table(args.node_data) kde_frequencies = pd.read_table(args.kde_frequencies) diffusion_frequencies = pd.read_table(args.diffusion_frequencies) df = node_data.merge( kde_frequencies, how="inner", on=["strain", "timepoint", "is_terminal"] ).merge( diffusion_frequencies, how="inner", on=["strain", "timepoint", "is_terminal"] ) # Annotate frequency by the preferred method if there isn't already a # frequency column defined. if "frequency" not in df.columns: df["frequency"] = df["%s_frequency" % args.preferred_frequency_method] df = df[df["frequency"] > 0.0].copy() df.to_csv(args.output, sep="\t", index=False, header=True) |
4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 | import argparse from augur.utils import read_node_data import Bio.Phylo import pandas as pd if __name__ == '__main__': parser = argparse.ArgumentParser( description="Convert node data JSONs to a data frame", formatter_class=argparse.ArgumentDefaultsHelpFormatter ) parser.add_argument("--tree", required=True, help="Newick file for the tree used to construct the given node data JSONs") parser.add_argument("--metadata", help="file with metadata associated with viral sequences, one for each segment") parser.add_argument("--jsons", nargs="+", required=True, help="node data JSON(s) from augur") parser.add_argument("--annotations", nargs="+", help="additional annotations to add to the output table in the format of 'key=value' pairs") parser.add_argument("--excluded-fields", nargs="+", help="names of columns to omit from output table") parser.add_argument("--output", required=True, help="tab-delimited file collecting all given node data") parser.add_argument("--include-internal-nodes", action="store_true", help="include data associated with internal nodes in the output table") args = parser.parse_args() # Load tree. tree = Bio.Phylo.read(args.tree, "newick") # Load metadata for samples. metadata = pd.read_csv(args.metadata, sep="\t") # Load one or more node data JSONs into a single dictionary indexed by node name. node_data = read_node_data(args.jsons) # Convert node data into a data frame. # Data are initially loaded with one column per node. # Transposition converts the table to the expected one row per node format. df = pd.DataFrame(node_data["nodes"]).T.rename_axis("strain").reset_index() # Annotate node data with per sample metadata. df = df.merge(metadata, on="strain", suffixes=["", "_metadata"]) # Remove excluded fields if they are in the data frame. df = df.drop(columns=[field for field in args.excluded_fields if field in df.columns]) # Annotate the tip/internal status of each node using the tree. node_terminal_status_by_name = {node.name: node.is_terminal() for node in tree.find_clades()} df["is_terminal"] = df["strain"].map(node_terminal_status_by_name) # Eliminate internal nodes if they have not been requested. if not args.include_internal_nodes: df = df[df["is_terminal"]].copy() # Add any additional annotations requested by the user in the format of # "key=value" pairs where each key becomes a new column with the given # value. if args.annotations: for annotation in args.annotations: key, value = annotation.split("=") df[key] = value # Save the table. df.to_csv(args.output, sep="\t", index=False, header=True) |
3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 | import argparse from augur.frequency_estimators import TreeKdeFrequencies from augur.reconstruct_sequences import load_alignments from augur.utils import annotate_parents_for_tree, write_json import Bio.Phylo import Bio.SeqIO import hashlib import json import pandas as pd # Magic number of maximum length of SHA hash to keep for each clade. MAX_HASH_LENGTH = 7 if __name__ == "__main__": parser = argparse.ArgumentParser( description="Find clades in a tree by distinct amino acid haplotypes", formatter_class=argparse.ArgumentDefaultsHelpFormatter ) parser.add_argument("--tree", required=True, help="Newick tree to identify clades in") parser.add_argument("--translations", required=True, nargs="+", help="FASTA file(s) of amino acid sequences per node") parser.add_argument("--gene-names", required=True, nargs="+", help="gene names corresponding to translations provided") parser.add_argument("--output", required=True, help="JSON of clade annotations for nodes in the given tree") parser.add_argument("--output-tip-clade-table", help="optional table of all clades per tip in the tree") parser.add_argument("--annotations", nargs="+", help="additional annotations to add to the tip clade output table in the format of 'key=value' pairs") args = parser.parse_args() # Load the tree. tree = Bio.Phylo.read(args.tree, "newick") tree = annotate_parents_for_tree(tree) # Load translations for nodes in the given tree and index them by gene name and node name. translations = load_alignments(args.translations, args.gene_names) translations_by_gene_name = {} for gene in translations: translations_by_gene_name[gene] = {} for seq in translations[gene]: translations_by_gene_name[gene][seq.name] = str(seq.seq) clades = {} for node in tree.find_clades(order="preorder", terminal=False): # Assign the current node a clade id based on the hash of its # full-length amino acid sequence. node_sequence = "".join([translations_by_gene_name[gene][node.name] for gene in args.gene_names]) clades[node.name] = {"clade_membership": hashlib.sha256(node_sequence.encode()).hexdigest()[:MAX_HASH_LENGTH]} # Assign the current node's clade id to all of its terminal children. for child in node.clades: if child.is_terminal(): clades[child.name] = clades[node.name] # Count unique clade groups. distinct_clades = {clade["clade_membership"] for clade in clades.values()} print("Found %i distinct clades" % len(distinct_clades)) # Write out the node annotations. write_json({"nodes": clades}, args.output) # Output the optional tip-to-clade table, if requested. if args.output_tip_clade_table: records = [] for tip in tree.find_clades(terminal=True): # Note the tip's own clade assignment which may be distinct from its # parent's. depth = 0 records.append([tip.name, clades[tip.name]["clade_membership"], depth]) parent = tip.parent depth += 1 while True: records.append([tip.name, clades[parent.name]["clade_membership"], depth]) if parent == tree.root: break parent = parent.parent depth += 1 df = pd.DataFrame(records, columns=["tip", "clade_membership", "depth"]) df = df.drop_duplicates(subset=["tip", "clade_membership"]) # Add any additional annotations requested by the user in the format of # "key=value" pairs where each key becomes a new column with the given # value. 
if args.annotations: for annotation in args.annotations: key, value = annotation.split("=") df[key] = value df.to_csv(args.output_tip_clade_table, sep="\t", index=False) |
3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 | import argparse from augur.utils import write_json import pandas as pd if __name__ == "__main__": parser = argparse.ArgumentParser( description="Normalize fitness by timepoint frequencies for samples in simulated populations.", formatter_class=argparse.ArgumentDefaultsHelpFormatter ) parser.add_argument("--metadata", required=True, help="file with metadata associated with viral sequences, one for each segment") parser.add_argument("--frequencies-table", required=True, help="frequencies table for the current timepoint") parser.add_argument("--frequency-method", required=True, choices=["kde", "diffusion"], help="method used to estimate frequencies") parser.add_argument("--output", required=True, help="JSON of normalized fitness per sample") args = parser.parse_args() # Load metadata. metadata = pd.read_csv(args.metadata, sep="\t") # Load frequencies. frequencies = pd.read_csv(args.frequencies_table, sep="\t") # Filter samples to those with nonzero frequencies at the current timepoint. nonzero_frequencies = frequencies[frequencies["%s_frequency" % args.frequency_method] > 0].copy() # Merge extent sample frequencies with metadata containing fitnesses. nonzero_metadata = nonzero_frequencies.merge( metadata, on="strain" ) # Normalize fitness by maximum fitness. nonzero_metadata["normalized_fitness"] = nonzero_metadata["fitness"] / nonzero_metadata["fitness"].max() # Prepare dictionary of normalized fitnesses by sample. normalized_fitness = { strain: {"normalized_fitness": fitness} for strain, fitness in nonzero_metadata.loc[:, ["strain", "normalized_fitness"]].values } print("Raw fitness: %.2f +/- %.2f" % (nonzero_metadata["fitness"].mean(), nonzero_metadata["fitness"].std())) print("Normalized fitness: %.2f +/- %.2f" % (nonzero_metadata["normalized_fitness"].mean(), nonzero_metadata["normalized_fitness"].std())) # Save normalized fitness as a node data JSON. write_json({"nodes": normalized_fitness}, args.output) |
3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 | import argparse from augur.distance import read_distance_map, get_distance_between_nodes from augur.frequency_estimators import TreeKdeFrequencies from augur.reconstruct_sequences import load_alignments from augur.utils import annotate_parents_for_tree, read_node_data, write_json import Bio.Phylo from collections import defaultdict import json if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--tree", help="Newick tree", required=True) parser.add_argument("--frequencies", help="frequencies JSON", required=True) parser.add_argument("--alignment", nargs="+", help="sequence(s) to be used, supplied as FASTA files", required=True) parser.add_argument('--gene-names', nargs="+", type=str, help="names of the sequences in the alignment, same order assumed", required=True) parser.add_argument("--attribute-name", nargs="+", help="name to store distances associated with the given distance map; multiple attribute names are linked to corresponding positional comparison method and distance map arguments", required=True) parser.add_argument("--map", nargs="+", help="JSON providing the distance map between sites and, optionally, sequences present at those sites; the distance map JSON minimally requires a 'default' field defining a default numeric distance and a 'map' field defining a dictionary of genes and one-based coordinates", required=True) parser.add_argument("--date-annotations", help="JSON of branch lengths and date annotations from augur refine for samples in the given tree; required for comparisons to earliest or latest date", required=True) parser.add_argument("--years-back-to-compare", type=int, help="number of years prior to the current season to search for samples to calculate pairwise comparisons with", required=True) parser.add_argument("--output", help="JSON file with calculated distances stored by node name and attribute name", required=True) args = parser.parse_args() # Load tree and annotate parents. tree = Bio.Phylo.read(args.tree, "newick") tree = annotate_parents_for_tree(tree) # Load frequencies. with open(args.frequencies, "r") as fh: frequencies_json = json.load(fh) frequencies = TreeKdeFrequencies.from_json(frequencies_json) pivots = frequencies.pivots # Identify pivots that belong within our search window for past samples. past_pivot_indices = (pivots < pivots[-1]) & (pivots >= pivots[-1] - args.years_back_to_compare) # Load sequences. alignments = load_alignments(args.alignment, args.gene_names) # Index sequences by node name and gene. sequences_by_node_and_gene = defaultdict(dict) for gene, alignment in alignments.items(): for record in alignment: sequences_by_node_and_gene[record.name][gene] = str(record.seq) # Load date annotations and annotate tree with them. date_annotations = read_node_data(args.date_annotations) for node in tree.find_clades(): node.attr = date_annotations["nodes"][node.name] node.attr["num_date"] = node.attr["numdate"] # Identify samples to compare including those in the current timepoint # (pivot) and those in previous timepoints. 
current_samples = [] past_samples = [] date_by_sample = {} for tip in tree.find_clades(terminal=True): # Samples with nonzero frequencies in the last timepoint are current # samples. Those with one or more nonzero frequencies in the search # window of the past timepoints are past samples. if frequencies.frequencies[tip.name][-1] > 0: current_samples.append(tip.name) elif (frequencies.frequencies[tip.name][past_pivot_indices] > 0).sum() > 0: past_samples.append(tip.name) date_by_sample[tip.name] = tip.attr["numdate"] print("Expecting %i comparisons" % (len(current_samples) * len(past_samples) * len(args.attribute_name))) distances_by_node = {} distance_map_names = [] comparisons = 0 for attribute, distance_map_file in zip(args.attribute_name, args.map): # Load the given distance map. distance_map = read_distance_map(distance_map_file) distance_map_names.append(distance_map.get("name", distance_map_file)) for current_sample in current_samples: if not current_sample in distances_by_node: distances_by_node[current_sample] = {} if not attribute in distances_by_node[current_sample]: distances_by_node[current_sample][attribute] = {} for past_sample in past_samples: # The past is in the past. comparisons += 1 if date_by_sample[past_sample] < date_by_sample[current_sample]: distances_by_node[current_sample][attribute][past_sample] = get_distance_between_nodes( sequences_by_node_and_gene[past_sample], sequences_by_node_and_gene[current_sample], distance_map ) print("Calculated %i comparisons" % comparisons) # Prepare params for export. params = { "attribute": args.attribute_name, "map_name": distance_map_names, "years_back_to_compare": args.years_back_to_compare } # Export distances to JSON. write_json({"params": params, "nodes": distances_by_node}, args.output) |
3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 | import argparse from augur.frequency_estimators import float_to_datestring, timestamp_to_float from augur.utils import annotate_parents_for_tree, read_node_data, write_json import Bio.Phylo import json import numpy as np import pandas as pd def get_titer_distance_between_nodes(tree, past_node, current_node, titer_attr="dTiter"): # Find MRCA of tips from one tip up. Sum the titer attribute of interest # while walking up to the MRCA, to avoid an additional pass later. The loop # below stops when the past node is found in the list of the candidate # MRCA's terminals. This test should always evaluate to true when the MRCA # is the root node, so we should not have to worry about trying to find the # parent of the root. current_node_branch_sum = 0.0 mrca = current_node while past_node.name not in mrca.terminals: current_node_branch_sum += mrca.attr[titer_attr] mrca = mrca.parent # Sum the node weights for the other tip from the bottom up until we reach # the MRCA. The value of the MRCA is intentionally excluded here, as it # would represent the branch leading to the MRCA and would be outside the # path between the two tips. past_node_branch_sum = 0.0 current_node = past_node while current_node != mrca: past_node_branch_sum += current_node.attr[titer_attr] current_node = current_node.parent final_sum = past_node_branch_sum + current_node_branch_sum return final_sum if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--tree", help="Newick tree", required=True) parser.add_argument("--frequencies", help="frequencies JSON", required=True) parser.add_argument("--model-attribute-name", help="name of attribute to use from titer model file", default="dTiter") parser.add_argument("--attribute-name", help="name to store distances", required=True) parser.add_argument("--model", help="JSON providing the titer tree model", required=True) parser.add_argument("--date-annotations", help="JSON of branch lengths and date annotations from augur refine for samples in the given tree; required for comparisons to earliest or latest date", required=True) parser.add_argument("--months-back-for-current-samples", type=int, help="number of months prior to the last date with estimated frequencies to include samples as current", required=True) parser.add_argument("--years-back-to-compare", type=int, help="number of years prior to the current season to search for samples to calculate pairwise comparisons with", required=True) parser.add_argument("--max-past-samples", type=int, default=200, help="maximum number of past samples to randomly select for comparison to current samples") parser.add_argument("--min-frequency", type=float, default=0.0, help="minimum frequency to consider a sample alive") parser.add_argument("--output", help="JSON file with calculated distances stored by node name and attribute name", required=True) args = parser.parse_args() # Load tree and annotate parents. 
tree = Bio.Phylo.read(args.tree, "newick") tree = annotate_parents_for_tree(tree) # Make a single pass through the tree in postorder to store a set of all # terminals descending from each node. This uses more memory, but it allows # faster identification of MRCAs between any pair of tips in the tree and # speeds up pairwise distance calculations by orders of magnitude. for node in tree.find_clades(order="postorder"): node.terminals = set() for child in node.clades: if child.is_terminal(): node.terminals.add(child.name) else: node.terminals.update(child.terminals) # Load frequencies. with open(args.frequencies, "r") as fh: frequencies = json.load(fh) pivots = np.array(frequencies.pop("pivots")) # Identify pivots that belong within our search window for past samples. # First, calculate dates associated with the interval for current samples # based on the number of months back requested. Then, calculate interval for # past samples with an upper bound based on the earliest current samples and # a lower bound based on the years back requested. last_pivot_datetime = pd.to_datetime(float_to_datestring(pivots[-1])) last_current_datetime = last_pivot_datetime - pd.DateOffset(months=args.months_back_for_current_samples) last_past_datetime = last_pivot_datetime - pd.DateOffset(years=args.years_back_to_compare) # Find the pivot indices that correspond to the current and past pivots. current_pivot_indices = np.array([ pd.to_datetime(float_to_datestring(pivot)) > last_current_datetime for pivot in pivots ]) past_pivot_indices = np.array([ ((pd.to_datetime(float_to_datestring(pivot)) >= last_past_datetime) & (pd.to_datetime(float_to_datestring(pivot)) <= last_current_datetime)) for pivot in pivots ]) # Load date and titer model annotations and annotate tree with them. annotations = read_node_data([args.date_annotations, args.model]) for node in tree.find_clades(): node.attr = annotations["nodes"][node.name] node.attr["num_date"] = node.attr["numdate"] # Identify samples to compare including those in the current timepoint # (pivot) and those in previous timepoints. current_samples = [] past_samples = [] date_by_sample = {} tips_by_sample = {} for tip in tree.find_clades(terminal=True): # Samples with nonzero frequencies in the last timepoint are current # samples. Those with one or more nonzero frequencies in the search # window of the past timepoints are past samples. frequencies[tip.name]["frequencies"] = np.array(frequencies[tip.name]["frequencies"]) if (frequencies[tip.name]["frequencies"][current_pivot_indices] > args.min_frequency).sum() > 0: current_samples.append(tip.name) tips_by_sample[tip.name] = tip elif (frequencies[tip.name]["frequencies"][past_pivot_indices] > args.min_frequency).sum() > 0: past_samples.append(tip.name) tips_by_sample[tip.name] = tip date_by_sample[tip.name] = tip.attr["numdate"] print("Expecting %i comparisons for %i current and %i past samples" % (len(current_samples) * len(past_samples), len(current_samples), len(past_samples))) distances_by_node = {} comparisons = 0 for current_sample in current_samples: if not current_sample in distances_by_node: distances_by_node[current_sample] = {} if not args.attribute_name in distances_by_node[current_sample]: distances_by_node[current_sample][args.attribute_name] = {} for past_sample in past_samples: # The past is in the past. 
if date_by_sample[past_sample] < date_by_sample[current_sample]: distances_by_node[current_sample][args.attribute_name][past_sample] = np.around(get_titer_distance_between_nodes( tree, tips_by_sample[past_sample], tips_by_sample[current_sample], args.model_attribute_name ), 4) comparisons += 1 if comparisons % 10000 == 0: print("Completed", comparisons, "comparisons, with last distance of", distances_by_node[current_sample][args.attribute_name][past_sample], flush=True) print("Calculated %i comparisons" % comparisons) # Prepare params for export. params = { "attribute": args.attribute_name, "years_back_to_compare": args.years_back_to_compare } # Export distances to JSON. write_json({"params": params, "nodes": distances_by_node}, args.output, indent=None) |
3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 | import argparse from augur.utils import get_numerical_dates, read_metadata import numpy as np import pandas as pd from treetime.utils import numeric_date from select_strains import read_strain_list if __name__ == '__main__': parser = argparse.ArgumentParser( description="Partition strains into timepoints", formatter_class=argparse.ArgumentDefaultsHelpFormatter ) parser.add_argument("metadata", help="tab-delimited metadata with columns for strain and date") parser.add_argument("timepoint", help="date for which strains should be partitioned") parser.add_argument("output", help="text file into which strains should be written for the given timepoint") parser.add_argument("--years-back", type=int, help="Number of years prior to the given timepoint to limit strains to") parser.add_argument("--additional-years-back-for-references", type=int, default=5, help="Additional number of years prior to the given timepoint to allow reference strains") parser.add_argument("--reference-strains", help="text file containing list of reference strains that should be included from the original strains even if they were sampled prior to the minimum date determined by the requested number of years before the given timepoint") args = parser.parse_args() # Convert date string to a datetime instance. timepoint = pd.to_datetime(args.timepoint) numeric_timepoint = np.around(numeric_date(timepoint), 2) # Load metadata with strain names and dates. metadata, columns = read_metadata(args.metadata) # Convert string dates with potential ambiguity (e.g., 2010-05-XX) into # floating point dates. dates = get_numerical_dates(metadata, fmt="%Y-%m-%d") # Setup reference strains. if args.reference_strains: reference_strains = read_strain_list(args.reference_strains) else: reference_strains = [] # If a given number of years back has been requested, determine what the # earliest date to accept for strains is. if args.years_back is not None: earliest_timepoint = timepoint - pd.DateOffset(years=args.years_back) numeric_earliest_timepoint = np.around(numeric_date(earliest_timepoint), 2) # If reference strains are provided, calculate the earliest date to # accept those strains. if len(reference_strains) > 0: earliest_reference_timepoint = earliest_timepoint - pd.DateOffset(years=args.additional_years_back_for_references) numeric_earliest_reference_timepoint = np.around(numeric_date(earliest_reference_timepoint), 2) # Find strains sampled prior to the current timepoint. Strains may have # multiple numerical dates, so we filter on the latest (maximum) observed # date per strain. If a requested number of years back is provided, use the # corresponding earliest dates for non-reference and reference strains to # determine whether they are included in the current timepoint. timepoint_strains = [] for strain, strain_dates in dates.items(): strain_date = np.max(strain_dates) if (strain_date <= numeric_timepoint and ((args.years_back is None) or (strain_date >= numeric_earliest_timepoint) or (strain in reference_strains and strain_date >= numeric_earliest_reference_timepoint))): timepoint_strains.append(strain) timepoint_strains = sorted(timepoint_strains) # Write sorted list of strains to disk. with open(args.output, "w") as oh: for strain in timepoint_strains: oh.write(f"{strain}\n") |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 | import argparse from augur.utils import json_to_tree, read_tree import Bio.Phylo import json import matplotlib as mpl mpl.use("Agg") from matplotlib import gridspec import matplotlib.pyplot as plt from matplotlib.collections import LineCollection import numpy as np import pandas as pd import sys def timestamp_to_float(time): """Convert a pandas timestamp to a floating point date. """ return time.year + ((time.month - 1) / 12.0) def plot_tree(tree, figure_name, color_by_trait, initial_branch_width, tip_size, start_date, end_date, include_color_bar): """Plot a BioPython Phylo tree in the BALTIC-style. """ # Plot H3N2 tree in BALTIC style from Bio.Phylo tree. mpl.rcParams['savefig.dpi'] = 120 mpl.rcParams['figure.dpi'] = 100 mpl.rcParams['font.weight']=300 mpl.rcParams['axes.labelweight']=300 mpl.rcParams['font.size']=14 yvalues = [node.yvalue for node in tree.find_clades()] y_span = max(yvalues) y_unit = y_span / float(len(yvalues)) # Setup colors. trait_name = color_by_trait traits = [k.attr[trait_name] for k in tree.find_clades()] norm = mpl.colors.Normalize(min(traits), max(traits)) cmap = mpl.cm.viridis # # Setup the figure grid. # if include_color_bar: fig = plt.figure(figsize=(8, 6), facecolor='w') gs = gridspec.GridSpec(2, 1, height_ratios=[14, 1], width_ratios=[1], hspace=0.1, wspace=0.1) ax = fig.add_subplot(gs[0]) colorbar_ax = fig.add_subplot(gs[1]) else: fig = plt.figure(figsize=(8, 4), facecolor='w') gs = gridspec.GridSpec(1, 1) ax = fig.add_subplot(gs[0]) L=len([k for k in tree.find_clades() if k.is_terminal()]) # Setup arrays for tip and internal node coordinates. tip_circles_x = [] tip_circles_y = [] tip_circles_color = [] tip_circle_sizes = [] node_circles_x = [] node_circles_y = [] node_circles_color = [] node_line_widths = [] node_line_segments = [] node_line_colors = [] branch_line_segments = [] branch_line_widths = [] branch_line_colors = [] branch_line_labels = [] for k in tree.find_clades(): ## iterate over objects in tree x=k.attr["num_date"] ## or from x position determined earlier y=k.yvalue ## get y position from .drawTree that was run earlier, but could be anything else if k.parent is None: xp = None else: xp=k.parent.attr["num_date"] ## get x position of current object's parent if x==None: ## matplotlib won't plot Nones, like root x=0.0 if xp==None: xp=x c = 'k' if trait_name in k.attr: c = cmap(norm(k.attr[trait_name])) branchWidth=2 if k.is_terminal(): ## if leaf... s = tip_size ## tip size can be fixed tip_circle_sizes.append(s) tip_circles_x.append(x) tip_circles_y.append(y) tip_circles_color.append(c) else: ## if node... k_leaves = [child for child in k.find_clades() if child.is_terminal()] # Scale branch widths by the number of tips. 
branchWidth += initial_branch_width * len(k_leaves) / float(L) if len(k.clades)==1: node_circles_x.append(x) node_circles_y.append(y) node_circles_color.append(c) ax.plot([x,x],[k.clades[-1].yvalue, k.clades[0].yvalue], lw=branchWidth, color=c, ls='-', zorder=9, solid_capstyle='round') branch_line_segments.append([(xp, y), (x, y)]) branch_line_widths.append(branchWidth) branch_line_colors.append(c) branch_lc = LineCollection(branch_line_segments, zorder=9) branch_lc.set_color(branch_line_colors) branch_lc.set_linewidth(branch_line_widths) branch_lc.set_label(branch_line_labels) branch_lc.set_linestyle("-") ax.add_collection(branch_lc) # Add circles for tips and internal nodes. tip_circle_sizes = np.array(tip_circle_sizes) ax.scatter(tip_circles_x, tip_circles_y, s=tip_circle_sizes, facecolor=tip_circles_color, edgecolor='none',zorder=11) ## plot circle for every tip ax.scatter(tip_circles_x, tip_circles_y, s=tip_circle_sizes*2, facecolor='k', edgecolor='none', zorder=10) ## plot black circle underneath ax.scatter(node_circles_x, node_circles_y, facecolor=node_circles_color, s=50, edgecolor='none', zorder=10, lw=2, marker='|') ## mark every node in the tree to highlight that it's a multitype tree #ax.set_ylim(-10, y_span - 300) ax.spines['top'].set_visible(False) ## no axes ax.spines['right'].set_visible(False) ax.spines['left'].set_visible(False) ax.grid(axis='x',ls='-',color='grey') ax.tick_params(axis='y',size=0) ax.set_yticklabels([]) if start_date: # Always add a buffer to the left edge of the plot so data up to the # given end date can be clearly seen. ax.set_xlim(left=timestamp_to_float(pd.to_datetime(start_date)) - 2.0) if end_date: # Always add a buffer of 3 months to the right edge of the plot so data # up to the given end date can be clearly seen. ax.set_xlim(right=timestamp_to_float(pd.to_datetime(end_date)) + 0.25) if include_color_bar: cb1 = mpl.colorbar.ColorbarBase( colorbar_ax, cmap=cmap, norm=norm, orientation='horizontal' ) cb1.set_label(color_by_trait) gs.tight_layout(fig) plt.savefig(figure_name) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("tree", help="auspice tree JSON or Newick tree") parser.add_argument("output", help="plotted tree figure") parser.add_argument("--colorby", help="trait in tree to color by", default="num_date") parser.add_argument("--branch_width", help="initial branch width", type=int, default=10) parser.add_argument("--tip_size", help="tip size", type=int, default=10) parser.add_argument("--start-date", help="earliest date to show on the x-axis") parser.add_argument("--end-date", help="latest date to show on the x-axis") parser.add_argument("--include-color-bar", action="store_true", help="display a color bar for the color by option at the bottom of the plot") args = parser.parse_args() if args.tree.endswith(".json"): with open(args.tree, "r") as json_fh: json_tree = json.load(json_fh) # Convert JSON tree layout to a Biopython Clade instance. tree = json_to_tree(json_tree) # Plot the tree. plot_tree( tree, args.output, args.colorby, args.branch_width, args.tip_size, args.start_date, args.end_date, args.include_color_bar ) else: tree = read_tree(args.tree) tree.ladderize() fig, ax = plt.subplots(1, 1, figsize=(8, 8)) Bio.Phylo.draw(tree, axes=ax, label_func=lambda node: "", show_confidence=False) plt.savefig(args.output) |
4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 | import argparse import matplotlib as mpl import matplotlib.pyplot as plt import matplotlib.gridspec as gridspec import numpy as np import pandas as pd from scipy.stats import pearsonr, spearmanr import seaborn as sns import statsmodels.api as sm np.random.seed(314159) PLOT_THEME_ATTRIBUTES = { "axes.labelsize": 14, "font.size": 18, "legend.fontsize": 12, "xtick.labelsize": 14, "ytick.labelsize": 14, "figure.figsize": [6.0, 4.0], "savefig.dpi": 200, "figure.dpi": 200, "axes.spines.top": False, "axes.spines.right": False, "text.usetex": False } def matthews_correlation_coefficient(tp, tn, fp, fn): """Return Matthews correlation coefficient for values from a confusion matrix. Implementation is based on the definition from wikipedia: https://en.wikipedia.org/wiki/Matthews_correlation_coefficient """ numerator = (tp * tn) - (fp * fn) denominator = np.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)) if denominator == 0: denominator = 1 return float(numerator) / denominator def get_matthews_correlation_coefficient_for_data_frame(freq_df, return_confusion_matrix=False): """Calculate Matthew's correlation coefficient from a given pandas data frame with columns for initial, observed, and predicted frequencies. 
""" observed_growth = (freq_df["frequency_final"] > freq_df["frequency"]) predicted_growth = (freq_df["projected_frequency"] > freq_df["frequency"]) true_positives = ((observed_growth) & (predicted_growth)).sum() false_positives= ((~observed_growth) & (predicted_growth)).sum() observed_decline = (freq_df["frequency_final"] < freq_df["frequency"]) predicted_decline = (freq_df["projected_frequency"] < freq_df["frequency"]) true_negatives = ((observed_decline) & (predicted_decline)).sum() false_negatives = ((~observed_decline) & (predicted_decline)).sum() mcc = matthews_correlation_coefficient( true_positives, true_negatives, false_positives, false_negatives ) if return_confusion_matrix: confusion_matrix = { "tp": true_positives, "tn": true_negatives, "fp": false_positives, "fn": false_negatives } return mcc, confusion_matrix else: return mcc if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--tip-attributes", required=True, help="tab-delimited file describing tip attributes at all timepoints with standardized predictors and weighted distances to the future") parser.add_argument("--tips-to-clades", required=True, help="tab-delimited file of all clades per tip and timepoint from a single tree that includes all tips in the given tip attributes table") parser.add_argument("--forecasts", required=True, help="table of forecasts for the given tips") parser.add_argument("--model-errors", required=True, help="annotated validation errors for the model used to make the given forecasts") parser.add_argument("--bootstrap-samples", type=int, default=100, help="number of bootstrap samples to generate for confidence intervals around absolute forecast errors") parser.add_argument("--population", help="the population being analyzed (e.g., simulated or natural)") parser.add_argument("--sample", help="sample name for population being analyzed") parser.add_argument("--predictors", help="predictors being analyzed") parser.add_argument("--output", required=True, help="validation figure") parser.add_argument("--output-clades-table", help="table of clade frequencies used in left panels") parser.add_argument("--output-ranks-table", help="table of strain ranks used in right panels") args = parser.parse_args() # Define constants for frequency analyses below. min_clade_frequency = 0.15 precision = 4 pseudofrequency = 0.001 number_of_bootstrap_samples = args.bootstrap_samples sns.set_style("white") mpl.rcParams.update(PLOT_THEME_ATTRIBUTES) # Load validation errors for the model used to produce the given forecasts # table. These errors are used to identify the first validation timepoint. model_errors = pd.read_csv( args.model_errors, sep="\t", parse_dates=["validation_timepoint"] ) first_validation_timepoint = model_errors["validation_timepoint"].min().strftime("%Y-%m-%d") # Load tip attributes to be associated with clades and used to calculate # clade frequencies. tips = pd.read_csv( args.tip_attributes, sep="\t", parse_dates=["timepoint"], usecols=["strain", "timepoint", "frequency", "aa_sequence"] ) tips = tips.query("timepoint >= '%s'" % first_validation_timepoint).copy() distinct_tips_with_sequence = tips.groupby(["timepoint", "aa_sequence"]).first().reset_index() # Load mapping of tips to clades based on a single tree that included all of # the tips in the given tip attributes table. 
tips_to_clades = pd.read_csv( args.tips_to_clades, sep="\t", usecols=["tip", "clade_membership", "depth"] ) tips_to_clades = tips_to_clades.rename(columns={"tip": "strain"}) # Load forecasts for all tips by the model associated with the given model # errors. First, load only a subset of the forecast information to simplify # downstream data frames. forecasts = pd.read_csv( args.forecasts, sep="\t", parse_dates=["timepoint"], usecols=["timepoint", "strain", "frequency", "projected_frequency"] ) # Next, load the complete forecasts data frame for ranking of estimated and # observed closest strains. full_forecasts = pd.read_csv( args.forecasts, sep="\t", parse_dates=["timepoint", "future_timepoint"] ) full_forecasts = full_forecasts.query("timepoint >= '%s'" % first_validation_timepoint).copy() # Map tip attributes to all corresponding clades. clade_tip_initial_frequencies = tips_to_clades.merge( tips, on=["strain"] ) clade_tip_initial_frequencies["future_timepoint"] = clade_tip_initial_frequencies["timepoint"] + pd.DateOffset(months=12) # Calculate the initial frequency of each clade per timepoint. initial_clade_frequencies = clade_tip_initial_frequencies.groupby([ "timepoint", "future_timepoint", "clade_membership" ])["frequency"].sum().reset_index() # Merge clade frequencies between adjacent years. initial_and_observed_clade_frequencies = initial_clade_frequencies.merge( initial_clade_frequencies, left_on=["future_timepoint", "clade_membership"], right_on=["timepoint", "clade_membership"], suffixes=["", "_final"] ).groupby(["timepoint", "clade_membership", "frequency"])["frequency_final"].sum().reset_index() # Select clades with an initial frequency above the defined threshold. large_clades = initial_and_observed_clade_frequencies.query("frequency > %s" % min_clade_frequency).copy() # Find estimated future frequencies of large clades. clade_tip_estimated_frequencies = tips_to_clades.merge( forecasts, on=["strain"] ) estimated_clade_frequencies = clade_tip_estimated_frequencies.groupby( ["timepoint", "clade_membership"] ).aggregate({"projected_frequency": "sum"}).reset_index() # Annotate initial and observed clade frequencies with the estimated future # values. complete_clade_frequencies = large_clades.merge( estimated_clade_frequencies, on=["timepoint", "clade_membership"], suffixes=["", "_other"] ) # Reduce precision of frequency estimates to a reasonable value and # eliminate entries where the clade frequency did not change between the # initial and final timepoints (these are primarily clades that have already # fixed at 100%). complete_clade_frequencies = np.round(complete_clade_frequencies, 2) complete_clade_frequencies = complete_clade_frequencies.query("frequency != frequency_final").copy() # Calculate accuracy of growth and decline classifications. mcc, confusion_matrix = get_matthews_correlation_coefficient_for_data_frame(complete_clade_frequencies, True) growth_accuracy = confusion_matrix["tp"] / float(confusion_matrix["tp"] + confusion_matrix["fp"]) decline_accuracy = confusion_matrix["tn"] / float(confusion_matrix["tn"] + confusion_matrix["fn"]) # Calculate the observed and estimated log growth rates for all clades. 
complete_clade_frequencies["log_observed_growth_rate"] = ( np.log10((complete_clade_frequencies["frequency_final"] + pseudofrequency) / (complete_clade_frequencies["frequency"] + pseudofrequency)) ) complete_clade_frequencies["log_estimated_growth_rate"] = ( np.log10((complete_clade_frequencies["projected_frequency"] + pseudofrequency) / (complete_clade_frequencies["frequency"] + pseudofrequency)) ) # Calculate the bounds for the clade growth rate display based on values in # observed and estimated rates. log_lower_limit = complete_clade_frequencies.loc[:, ["log_observed_growth_rate", "log_estimated_growth_rate"]].min().min() - 0.1 log_upper_limit = np.ceil(complete_clade_frequencies.loc[:, ["log_observed_growth_rate", "log_estimated_growth_rate"]].max().max()) + 0.1 # Calculate the Pearson's correlation between observed and estimated log # growth rates. r, p = pearsonr( complete_clade_frequencies["log_observed_growth_rate"], complete_clade_frequencies["log_estimated_growth_rate"] ) # Use observed forecasting errors to inspect the accuracy of one-year # lookaheads based on the initial frequency of each clade. complete_clade_frequencies["clade_error"] = complete_clade_frequencies["frequency_final"] - complete_clade_frequencies["projected_frequency"] complete_clade_frequencies["absolute_clade_error"] = np.abs(complete_clade_frequencies["clade_error"]) # Estimate uncertainty of the mean absolute clade error by initial clade # frequency with LOESS fits to bootstraps from the complete data frame. bootstrap_samples = [] for i in range(number_of_bootstrap_samples): complete_clade_frequencies_sample = complete_clade_frequencies.sample(frac=1.0, replace=True).copy() z = sm.nonparametric.lowess( complete_clade_frequencies_sample["absolute_clade_error"].values * 100, complete_clade_frequencies_sample["frequency"].values * 100 ) # Track both the initial frequency and the LOESS fits for each bootstrap # sample. This ensures that the summary statistics calculated downstream # per initial frequency are based on the correct LOESS values. bootstrap_samples.append( pd.DataFrame({ "initial_frequency": z[:, 0], "loess": z[:, 1]} ) ) bootstrap_df = pd.concat(bootstrap_samples) # Calculate the mean and 95% CIs from bootstraps. bootstrap_summary = bootstrap_df.groupby("initial_frequency")["loess"].agg( lower=lambda group: np.percentile(group, 2.5), mean=np.mean, upper=lambda group: np.percentile(group, 97.5) ).reset_index() initial_frequency = bootstrap_summary["initial_frequency"].values mean_lowess_fit = bootstrap_summary["mean"].values upper_lowess_fit = bootstrap_summary["upper"].values lower_lowess_fit = bootstrap_summary["lower"].values # For each timepoint, calculate the percentile rank of each strain based on # both its observed and estimated distance to the future. sorted_df = full_forecasts.dropna().sort_values( ["timepoint"] ).copy() # Filter sorted records by strains with distinct amino acid sequences. sorted_df = sorted_df.merge( distinct_tips_with_sequence, on=["timepoint", "strain"] ) # First, calculate the rank per strain by observed distance to the future. sorted_df["timepoint_rank"] = sorted_df.groupby("timepoint")["weighted_distance_to_future"].rank(pct=True) # Then, calculate the rank by estimated distance to the future. sorted_df["timepoint_estimated_rank"] = sorted_df.groupby("timepoint")["y"].rank(pct=True) # Calculate the Spearman correlation of ranks, to get a measure of the model # fit. 
rank_rho, rank_p = spearmanr( sorted_df["timepoint_rank"], sorted_df["timepoint_estimated_rank"] ) # Select the observed rank of the estimated closest strain to the future per # timepoint. best_fitness_rank_by_timepoint_df = sorted_df.sort_values( ["timepoint", "y"], ascending=True ).groupby("timepoint")["timepoint_rank"].first().reset_index() # # Summarize model fit by clade frequencies and strain ranks. # fig = plt.figure(figsize=(10, 10), facecolor='w') gs = gridspec.GridSpec(2, 2, width_ratios=[1, 1], height_ratios=[1, 1], wspace=0.1) ticks = np.array([0, 0.2, 0.4, 0.6, 0.8, 1.0]) # # Top-left: Clade growth rate correlations # clade_ax = fig.add_subplot(gs[0]) clade_ax.plot( complete_clade_frequencies["log_observed_growth_rate"], complete_clade_frequencies["log_estimated_growth_rate"], "o", alpha=0.4 ) clade_ax.axhline(color="#cccccc", zorder=-5) clade_ax.axvline(color="#cccccc", zorder=-5) if p < 0.001: p_value = "$p value$ < 0.001" else: p_value = "$p$ = %.3f" % p clade_ax.text( 0.02, 0.15, "Growth accuracy = %.2f\nDecline accuracy = %.2f\nPearson $R^2$ = %.2f\nN = %s" % ( growth_accuracy, decline_accuracy, r ** 2, complete_clade_frequencies.shape[0] ), fontsize=12, horizontalalignment="left", verticalalignment="center", transform=clade_ax.transAxes ) clade_ax.set_xlabel("Observed $log_{10}$ fold change") clade_ax.set_ylabel("Estimated $log_{10}$ fold change") growth_rate_ticks = np.arange(-6, 4, 1) clade_ax.set_xticks(growth_rate_ticks) clade_ax.set_yticks(growth_rate_ticks) clade_ax.set_xlim(log_lower_limit, log_upper_limit) clade_ax.set_ylim(log_lower_limit, log_upper_limit) clade_ax.set_aspect("equal") # # Top-right: Estimated closest strain to the future ranking # rank_ax = fig.add_subplot(gs[1]) median_best_rank = best_fitness_rank_by_timepoint_df["timepoint_rank"].median() rank_ax.hist(best_fitness_rank_by_timepoint_df["timepoint_rank"], bins=np.arange(0, 1.01, 0.05), label=None) rank_ax.axvline( median_best_rank, color="orange", label="median = %i%%" % round(median_best_rank * 100, 0) ) rank_ax.set_xticks(ticks) rank_ax.set_xticklabels(['{:3.0f}%'.format(x*100) for x in ticks]) rank_ax.set_xlim(0, 1) rank_ax.legend( frameon=False ) rank_ax.set_xlabel("Percentile rank by distance\nfor estimated closest strain") rank_ax.set_ylabel("Number of timepoints") # # Bottom-left: Absolute clade forecast errors with uncertainty. # forecast_error_ax = fig.add_subplot(gs[2]) forecast_error_ax.plot( complete_clade_frequencies["frequency"].values * 100, complete_clade_frequencies["absolute_clade_error"].values * 100, "o", alpha=0.2 ) forecast_error_ax.fill_between( initial_frequency, lower_lowess_fit, upper_lowess_fit, alpha=0.1, color="black" ) forecast_error_ax.plot( initial_frequency, mean_lowess_fit, alpha=0.75, color="black" ) forecast_error_ax.set_xlabel("Initial clade frequency") forecast_error_ax.set_ylabel("Absolute forecast error") forecast_error_ax.set_xticks(ticks * 100) forecast_error_ax.set_yticks(ticks * 100) forecast_error_ax.set_xticklabels(['{:3.0f}%'.format(x * 100) for x in ticks]) forecast_error_ax.set_yticklabels(['{:3.0f}%'.format(x * 100) for x in ticks]) forecast_error_ax.set_aspect("equal") # # Bottom-right: Observed vs. estimated percentile rank for all strains at all timepoints. 
# all_rank_ax = fig.add_subplot(gs[3]) if rank_p < 0.001: rank_p_value = "$p$ < 0.001" else: rank_p_value = "$p$ = %.3f" % rank_p all_rank_ax.plot( sorted_df["timepoint_rank"], sorted_df["timepoint_estimated_rank"], "o", alpha=0.05 ) all_rank_ax.text( 0.45, 0.05, "Spearman $\\rho^2$ = %.2f" % (rank_rho ** 2,), fontsize=12, horizontalalignment="left", verticalalignment="center", transform=all_rank_ax.transAxes ) all_rank_ax.set_xticks(ticks) all_rank_ax.set_yticks(ticks) all_rank_ax.set_xticklabels(['{:3.0f}%'.format(x * 100) for x in ticks]) all_rank_ax.set_yticklabels(['{:3.0f}%'.format(x * 100) for x in ticks]) all_rank_ax.set_xlabel("Observed percentile rank") all_rank_ax.set_ylabel("Estimated percentile rank") all_rank_ax.set_aspect("equal") # Annotate panel labels. panel_labels_dict = { "weight": "bold", "size": 14 } plt.figtext(0.0, 0.97, "A", **panel_labels_dict) plt.figtext(0.5, 0.97, "B", **panel_labels_dict) plt.figtext(0.0, 0.47, "C", **panel_labels_dict) plt.figtext(0.5, 0.47, "D", **panel_labels_dict) gs.tight_layout(fig) plt.savefig(args.output) timepoints_better_than_20th_percentile = (best_fitness_rank_by_timepoint_df["timepoint_rank"] <= 0.2).sum() total_timepoints = best_fitness_rank_by_timepoint_df.shape[0] print( "Estimated strain was in the top 20th percentile at %s of %s (%s%%) timepoints" % ( timepoints_better_than_20th_percentile, total_timepoints, int(np.round((timepoints_better_than_20th_percentile / float(total_timepoints)) * 100)) ) ) if args.output_clades_table: complete_clade_frequencies = complete_clade_frequencies.rename(columns={ "frequency": "initial_frequency", "frequency_final": "observed_future_frequency", "projected_frequency": "estimated_future_frequency" }) complete_clade_frequencies["population"] = args.population complete_clade_frequencies["predictors"] = args.predictors complete_clade_frequencies["error_type"] = "test" if "test" in args.sample else "validation" complete_clade_frequencies.to_csv( args.output_clades_table, sep="\t", header=True, index=False ) if args.output_ranks_table: sorted_df["observed_distance_to_future"] = sorted_df["weighted_distance_to_future"] sorted_df["estimated_distance_to_future"] = sorted_df["y"] sorted_df["observed_rank"] = sorted_df["timepoint_rank"] sorted_df["estimated_rank"] = sorted_df["timepoint_estimated_rank"] sorted_df["population"] = args.population sorted_df["sample"] = args.sample sorted_df["predictors"] = args.predictors sorted_df["error_type"] = "test" if "test" in args.sample else "validation" sorted_df = np.around(sorted_df, 2) sorted_df.to_csv( args.output_ranks_table, sep="\t", header=True, index=False, columns=[ "population", "error_type", "predictors", "timepoint", "strain", "observed_distance_to_future", "estimated_distance_to_future", "observed_rank", "estimated_rank" ] ) |
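The clade growth rates compared in panel A are log10 fold changes computed with a small pseudofrequency so that clades starting or ending near zero frequency do not produce undefined values. A minimal sketch of that calculation for a single hypothetical clade (the frequencies and pseudofrequency below are made-up values):

import numpy as np

# Hypothetical clade that grows from 5% to 20% over the forecast horizon,
# with a model-projected future frequency of 15%.
pseudofrequency = 0.001  # illustrative; the script receives this as an input
initial, observed_final, projected_final = 0.05, 0.20, 0.15

log_observed_growth_rate = np.log10((observed_final + pseudofrequency) / (initial + pseudofrequency))
log_estimated_growth_rate = np.log10((projected_final + pseudofrequency) / (initial + pseudofrequency))

print(round(log_observed_growth_rate, 2), round(log_estimated_growth_rate, 2))
# 0.6 0.47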
import argparse
import json


if __name__ == "__main__":
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--titers-model", required=True, help="titer model JSON from augur titers tree")
    parser.add_argument("--output", required=True, help="titer model JSON with renamed fields for FRA data")
    args = parser.parse_args()

    with open(args.titers_model, "r") as fh:
        titers_json = json.load(fh)

    for sample in titers_json["nodes"].keys():
        titers_json["nodes"][sample]["fra_cTiter"] = titers_json["nodes"][sample]["cTiter"]
        titers_json["nodes"][sample]["fra_dTiter"] = titers_json["nodes"][sample]["dTiter"]
        del titers_json["nodes"][sample]["cTiter"]
        del titers_json["nodes"][sample]["dTiter"]

    with open(args.output, "w") as oh:
        json.dump(titers_json, oh, indent=1)
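The effect of this script is just a field rename within the node-data JSON, so FRA-derived titer model attributes do not collide with HI-derived ones downstream. A minimal sketch of the same transformation on a toy nodes dictionary (the strain name and values are made up):

# Toy example of the renaming performed above; values are illustrative only.
nodes = {"A/Example/1/2018": {"cTiter": 2.1, "dTiter": 0.3}}

for sample, attributes in nodes.items():
    attributes["fra_cTiter"] = attributes.pop("cTiter")
    attributes["fra_dTiter"] = attributes.pop("dTiter")

print(nodes)
# {'A/Example/1/2018': {'fra_cTiter': 2.1, 'fra_dTiter': 0.3}}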
3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 | import argparse import numpy as np import pandas as pd if __name__ == '__main__': parser = argparse.ArgumentParser( description="Standardize predictors", formatter_class=argparse.ArgumentDefaultsHelpFormatter ) parser.add_argument("--tip-attributes", required=True, help="tab-delimited file describing tip attributes at all timepoints") parser.add_argument("--tips-to-clades", required=True, help="tab-delimited file of all clades per tip and timepoint") parser.add_argument("--delta-months", required=True, type=int, help="number of months to project clade frequencies into the future") parser.add_argument("--output", required=True, help="tab-delimited file of clades per timepoint and their corresponding tips and tip frequencies at the given delta time in the future") args = parser.parse_args() delta_time_offset = pd.DateOffset(months=args.delta_months) # Load tip attributes, subsetting to relevant frequency and time information. tips = pd.read_csv(args.tip_attributes, sep="\t", parse_dates=["timepoint"]) tips = tips.loc[:, ["strain", "clade_membership", "timepoint", "frequency"]].copy() # Confirm tip frequencies sum to 1 per timepoint. summed_tip_frequencies = tips.groupby("timepoint")["frequency"].sum() print(summed_tip_frequencies) assert all([ np.isclose(total, 1.0, atol=1e-3) for total in summed_tip_frequencies ]) # Identify distinct clades per timepoint. clades = tips.loc[:, ["timepoint", "clade_membership"]].drop_duplicates().copy() clades = clades.rename(columns={"timepoint": "initial_timepoint"}) # Annotate future timepoint. clades["final_timepoint"] = clades["initial_timepoint"] + delta_time_offset # Load mapping of tips to all possible clades at each timepoint. tips_to_clades = pd.read_csv(args.tips_to_clades, sep="\t", parse_dates=["timepoint"]) tips_to_clades = tips_to_clades.loc[:, ["tip", "clade_membership", "depth", "timepoint"]].copy() # Get all tip-clade combinations by timepoint for the distinct clades. future_tips_by_clades = clades.merge( tips_to_clades, how="inner", left_on=["final_timepoint", "clade_membership"], right_on=["timepoint", "clade_membership"] ) # Drop redundant columns. future_tips_by_clades = future_tips_by_clades.drop( columns=["timepoint"] ) # Get the closest clade to each tip by timepoint. This relies on records # being sorted by depth of clade from tip. future_tips_by_clades = future_tips_by_clades.sort_values(["initial_timepoint", "tip", "depth"]).groupby(["initial_timepoint", "tip"]).first().reset_index() # Get frequencies of future tips associated with current clades. future_clade_frequencies = future_tips_by_clades.merge(tips, how="inner", left_on=["tip", "final_timepoint"], right_on=["strain", "timepoint"], suffixes=["", "_tip"]) future_clade_frequencies = future_clade_frequencies.drop( columns=[ "tip", "depth", "clade_membership_tip", "timepoint" ] ) # Confirm that future frequencies sum to 1. print(future_clade_frequencies.groupby("initial_timepoint")["frequency"].sum()) # Confirm the future frequencies of individual clades. print(future_clade_frequencies.groupby(["initial_timepoint", "clade_membership"])["frequency"].sum()) # Left join original clades table with the future tip frequencies to enable # assessment of all current clades including those without future tips. 
final_clade_frequencies = clades.merge( future_clade_frequencies, how="left", on=["initial_timepoint", "final_timepoint", "clade_membership"] ) # Fill frequency of clades without any future tips with zeros to enable a # simple groupby in the future to get observed future frequencies of all # clades. final_clade_frequencies["frequency"] = final_clade_frequencies["frequency"].fillna(0.0) # Confirm that future frequencies sum to 1. print(final_clade_frequencies.groupby("initial_timepoint")["frequency"].sum()) # Save clade future tip frequencies by timepoint. final_clade_frequencies.to_csv(args.output, sep="\t", na_rep="N/A", index=False) |
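The core of this script is a time-shifted join: clades observed at one timepoint are matched to the tips, and tip frequencies, circulating the requested number of months later. The following simplified pandas sketch (with made-up clade and strain names) shows the shape of that join; the real script additionally assigns each future tip to its closest clade by depth before merging.

import pandas as pd

# Clades observed at the initial timepoint (toy data).
clades = pd.DataFrame({
    "initial_timepoint": pd.to_datetime(["2015-10-01"]),
    "clade_membership": ["A1b"],
})
clades["final_timepoint"] = clades["initial_timepoint"] + pd.DateOffset(months=12)

# Tips and frequencies observed at the future timepoint (toy data).
future_tips = pd.DataFrame({
    "timepoint": pd.to_datetime(["2016-10-01", "2016-10-01"]),
    "clade_membership": ["A1b", "A1b"],
    "strain": ["A/Example/1/2016", "A/Example/2/2016"],
    "frequency": [0.4, 0.1],
})

# A left join keeps clades that die out (no future tips); their missing
# frequencies are filled with zeros, as in the script above.
future_clade_frequencies = clades.merge(
    future_tips,
    how="left",
    left_on=["final_timepoint", "clade_membership"],
    right_on=["timepoint", "clade_membership"],
)
future_clade_frequencies["frequency"] = future_clade_frequencies["frequency"].fillna(0.0)
print(future_clade_frequencies.groupby("clade_membership")["frequency"].sum())
# A1b    0.5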
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 | import argparse, sys, os from augur.utils import read_metadata, get_numerical_dates import Bio import Bio.SeqIO from collections import defaultdict from datetime import datetime, timedelta, date import numpy as np from treetime.utils import numeric_date vpm_dict = { 2: 3, 3: 2, 6: 2, 12: 1, } regions = [ ('africa', "", 1.02), ('europe', "EU", 0.74), ('north_america', "NA", 0.54), ('china', "AS", 1.36), ('south_asia', "AS", 1.45), ('japan_korea', "AS", 0.20), ('oceania', "OC", 0.04), ('south_america', "SA", 0.41), ('southeast_asia', "AS", 0.62), ('west_asia', "AS", 0.75) ] subcats = [r[0] for r in regions] def read_strain_list(fname): """ read strain names from a file assuming there is one strain name per line Parameters: ----------- fname : str file name Returns: -------- strain_list : list strain names """ if os.path.isfile(fname): with open(fname, 'r') as fh: strain_list = [x.strip() for x in fh.readlines() if x[0]!='#'] else: print("ERROR: file %s containing strain list not found"%fname) sys.exit(1) return strain_list def count_titer_measurements(fname): """ read how many titer measurements exist for each virus Parameters: ----------- fname : str file name Returns: -------- titer_count : defaultdict(int) dictionary with titer count for each strain """ titer_count = defaultdict(int) if os.path.isfile(fname): with open(fname, 'r') as fh: for line in fh: titer_count[line.split()[0]] += 1 else: print("ERROR: file %s containing strain list not found"%fname) sys.exit(1) return titer_count def populate_categories(metadata): super_category = lambda x: (x['year'], x['month']) category = lambda x: (x['region'], x['year'], x['month']) virus_by_category = defaultdict(list) virus_by_super_category = defaultdict(list) for v in metadata: virus_by_category[category(metadata[v])].append(v) virus_by_super_category[super_category(metadata[v])].append(v) return virus_by_super_category, virus_by_category def flu_subsampling(metadata, viruses_per_month, time_interval, titer_fname=None): # Filter metadata by date using the given time interval. Using numeric dates # here allows users to define time intervals to the day and filter viruses # at that same level of precision. 
time_interval_start = round(numeric_date(time_interval[1]), 2) time_interval_end = round(numeric_date(time_interval[0]), 2) metadata = { strain: record for strain, record in metadata.items() if time_interval_start <= record["num_date"] <= time_interval_end } #### DEFINE THE PRIORITY if titer_fname: HI_titer_count = count_titer_measurements(titer_fname) def priority(strain): return HI_titer_count[strain] else: print("No titer counts provided - using random priorities") def priority(strain): return np.random.random() subcat_threshold = int(np.ceil(1.0*viruses_per_month/len(subcats))) virus_by_super_category, virus_by_category = populate_categories(metadata) def threshold_fn(x): #x is the subsampling category, in this case a tuple of (region, year, month) # if there are not enough viruses by super category, take everything if len(virus_by_super_category[x[1:]]) < viruses_per_month: return viruses_per_month # otherwise, sort sub categories by strain count sub_counts = sorted([(r, virus_by_super_category[(r, x[1], x[2])]) for r in subcats], key=lambda y:len(y[1])) # if all (the smallest) subcat has more strains than the threshold, return threshold if len(sub_counts[0][1]) > subcat_threshold: return subcat_threshold strains_selected = 0 tmp_subcat_threshold = subcat_threshold for ri, (r, strains) in enumerate(sub_counts): current_threshold = int(np.ceil(1.0*(viruses_per_month-strains_selected)/(len(subcats)-ri))) if r==x[0]: return current_threshold else: strains_selected += min(len(strains), current_threshold) return subcat_threshold selected_strains = [] for cat, val in virus_by_category.items(): val.sort(key=priority, reverse=True) selected_strains.extend(val[:threshold_fn(cat)]) return selected_strains def determine_time_interval(time_interval, resolution): # determine date range to include strains from if time_interval: # explicitly specified datetime_interval = sorted([datetime.strptime(x, '%Y-%m-%d').date() for x in args.time_interval], reverse=True) else: # derived from resolution arguments (explicit takes precedence) if resolution: years_back = int(resolution[:-1]) else: years_back = 3 datetime_interval = [datetime.today().date(), (datetime.today() - timedelta(days=365.25 * years_back)).date()] return datetime_interval def parse_metadata(segments, metadata_files): metadata = {} for segment, fname in zip(segments, metadata_files): tmp_meta, columns = read_metadata(fname) numerical_dates = get_numerical_dates(tmp_meta, fmt='%Y-%m-%d') for x in tmp_meta: tmp_meta[x]['num_date'] = np.mean(numerical_dates[x]) tmp_meta[x]['year'] = int(tmp_meta[x]['num_date']) # Extract month values starting at January == 1 for comparison with # datetime objects. tmp_meta[x]['month'] = int((tmp_meta[x]['num_date'] % 1) * 12) + 1 metadata[segment] = tmp_meta return metadata def parse_sequences(segments, sequence_files): """Load sequence names into a dictionary of sets indexed by segment. """ sequences = {} for segment, filename in zip(segments, sequence_files): sequence_set = Bio.SeqIO.parse(filename, "fasta") sequences[segment] = set() for seq in sequence_set: sequences[segment].add(seq.name) return sequences if __name__ == '__main__': parser = argparse.ArgumentParser( description="Select strains for downstream analysis", formatter_class=argparse.ArgumentDefaultsHelpFormatter ) parser.add_argument('-v', '--viruses_per_month', type = int, default=15, help='Subsample x viruses per country per month. 
Set to 0 to disable subsampling.') parser.add_argument('--sequences', nargs='+', help="FASTA file with viral sequences, one for each segment") parser.add_argument('--metadata', nargs='+', help="file with metadata associated with viral sequences, one for each segment") parser.add_argument('--output', help="name of the file to write selected strains to") parser.add_argument('--verbose', action="store_true", help="turn on verbose reporting") parser.add_argument('-l', '--lineage', choices=['h3n2', 'h1n1pdm', 'vic', 'yam'], default='h3n2', type=str, help="single lineage to include (default: h3n2)") parser.add_argument('-r', '--resolution',default='3y', type = str, help = "single resolution to include (default: 3y)") parser.add_argument('-s', '--segments', default=['ha'], nargs='+', type = str, help = "list of segments to include (default: ha)") parser.add_argument('--sampling', default = 'even', type=str, help='sample evenly over regions (even) (default), or prioritize one region (region name), otherwise sample randomly') parser.add_argument('--time-interval', nargs=2, help="explicit time interval to use -- overrides resolutions" " expects YYYY-MM-DD YYYY-MM-DD") parser.add_argument('--titers', help="a text file titers. this will only read in how many titer measurements are available for a each virus" " and use this count as a priority for inclusion during subsampling.") parser.add_argument('--include', help="a text file containing strains (one per line) that will be included regardless of subsampling") parser.add_argument('--max-include-range', type=float, default=5, help="number of years prior to the lower date limit for reference strain inclusion") parser.add_argument('--exclude', help="a text file containing strains (one per line) that will be excluded") args = parser.parse_args() time_interval = determine_time_interval(args.time_interval, args.resolution) # derive additional lower inclusion date for "force-included strains" lower_reference_cutoff = date(year = time_interval[1].year - args.max_include_range, month=1, day=1) upper_reference_cutoff = time_interval[0] # read strains to exclude excluded_strains = read_strain_list(args.exclude) if args.exclude else [] # read strains to include included_strains = read_strain_list(args.include) if args.include else [] # read in sequence names to determine which sequences already passed upstream filters sequence_names_by_segment = parse_sequences(args.segments, args.sequences) # read in meta data, parse numeric dates metadata = parse_metadata(args.segments, args.metadata) # eliminate all metadata entries that do not have sequences filtered_metadata = {} for segment in metadata: filtered_metadata[segment] = {} for name in metadata[segment]: if name in sequence_names_by_segment[segment]: filtered_metadata[segment][name] = metadata[segment][name] # filter down to strains with sequences for all required segments guide_segment = args.segments[0] strains_with_all_segments = set.intersection(*(set(filtered_metadata[x].keys()) for x in args.segments)) # exclude outlier strains strains_with_all_segments.difference_update(set(excluded_strains)) # subsample by region, month, year selected_strains = flu_subsampling({x:filtered_metadata[guide_segment][x] for x in strains_with_all_segments}, args.viruses_per_month, time_interval, titer_fname=args.titers) # add strains that need to be included for strain in included_strains: if strain in strains_with_all_segments and strain not in selected_strains: # Do not include strains sampled too far in the past or strains # 
sampled from the future relative to the requested build interval. if (filtered_metadata[guide_segment][strain]['year'] >= lower_reference_cutoff.year and filtered_metadata[guide_segment][strain]['num_date'] <= numeric_date(upper_reference_cutoff)): selected_strains.append(strain) # Confirm that none of the selected strains were sampled outside of the # requested interval. for strain in selected_strains: assert filtered_metadata[guide_segment][strain]['num_date'] <= numeric_date(upper_reference_cutoff) # write the list of selected strains to file with open(args.output, 'w') as ofile: ofile.write('\n'.join(selected_strains)) |
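Subsampling works on (region, year, month) categories: within each category, strains are ranked by a priority, which is the number of available titer measurements when a titers file is given and a random value otherwise, and only the top strains up to a per-category threshold are kept. A simplified sketch with hypothetical strain names and counts (the real threshold_fn also rebalances thresholds across regions):

from collections import defaultdict

# Hypothetical titer measurement counts per strain; without a titers file
# the script falls back to random priorities.
titer_count = defaultdict(int, {"A/StrainA/2017": 8, "A/StrainB/2017": 2})

# Strains in one (region, year, month) category, sorted by priority.
category_strains = ["A/StrainB/2017", "A/StrainC/2017", "A/StrainA/2017"]
category_strains.sort(key=lambda strain: titer_count[strain], reverse=True)

threshold = 2  # toy per-category threshold
print(category_strains[:threshold])
# ['A/StrainA/2017', 'A/StrainB/2017']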
import argparse
import datetime
import pandas as pd


def float_to_datestring(time):
    """Convert a floating point date from TreeTime `numeric_date` to a date string
    """
    # Extract the year and remainder from the floating point date.
    year = int(time)
    remainder = time - year

    # Calculate the day of the year (out of 365 + 0.25 for leap years).
    tm_yday = int(remainder * 365.25)
    if tm_yday == 0:
        tm_yday = 1

    # Construct a date object from the year and day of the year.
    date = datetime.datetime.strptime("%s-%s" % (year, tm_yday), "%Y-%j")

    # Build the date string with zero-padded months and days.
    date_string = "%s-%.2i-%.2i" % (date.year, date.month, date.day)

    return date_string


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--metadata", help="metadata for simulated sequences")
    parser.add_argument("--start-year", default=2000.0, type=float, help="year to start simulated dates from")
    parser.add_argument("--generations-per-year", default=200.0, type=float, help="number of generations to map to a single year")
    parser.add_argument("--output", help="metadata with standardized dates and nonzero fitness records")
    args = parser.parse_args()

    df = pd.read_csv(args.metadata, sep="\t")
    df["num_date"] = args.start_year + (df["generation"] / args.generations_per_year)
    df["date"] = df["num_date"].apply(float_to_datestring)
    df["year"] = pd.to_datetime(df["date"]).dt.year
    df["month"] = pd.to_datetime(df["date"]).dt.month

    # Omit records with a fitness of zero.
    df[df["fitness"] > 0].to_csv(args.output, header=True, index=False, sep="\t")
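Because the simulator only reports generation numbers, calendar dates are reconstructed by mapping generations onto years. For example, with the default 200 generations per year and a start year of 2000, generation 100 corresponds to numeric date 2000.5; the sketch below repeats the conversion performed by float_to_datestring for that value:

import datetime

# Generation 100 at 200 generations per year from a start year of 2000.
num_date = 2000.0 + (100 / 200.0)           # 2000.5
year = int(num_date)
tm_yday = int((num_date - year) * 365.25)   # 182
print(datetime.datetime.strptime("%s-%s" % (year, tm_yday), "%Y-%j").date())
# 2000-06-30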
import argparse
import json


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Convert titer substitution model to distance map",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("--model", required=True, help="JSON from titer substitution model")
    parser.add_argument("--output", required=True, help="distance map JSON")
    args = parser.parse_args()

    # Load titer model.
    with open(args.model, "r") as fh:
        model = json.load(fh)

    # Prepare a distance map for the model.
    distance_map = {
        "name": "titer_substitution_model",
        "default": 0.0,
        "map": {}
    }

    # Convert values like:
    # "HA1:E173K": 0.4656
    # to distance map format.
    for substitution, weight in model["substitution"].items():
        gene, mutation = substitution.split(":")
        ancestral = mutation[0]
        derived = mutation[-1]
        position = mutation[1:-1]

        if ancestral != "X" and derived != "X":
            if gene not in distance_map["map"]:
                distance_map["map"][gene] = {}

            if position not in distance_map["map"][gene]:
                distance_map["map"][gene][position] = []

            distance_map["map"][gene][position].append({
                "from": ancestral,
                "to": derived,
                "weight": weight
            })

    # Save the distance map.
    with open(args.output, "w") as oh:
        json.dump(distance_map, oh, sort_keys=True, indent=1)
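For the example substitution in the comments above, "HA1:E173K" with weight 0.4656, the resulting distance map entry looks like the following (positions are keyed as strings, and only the gene, position, from, to, and weight fields are produced):

# Distance map entry produced for "HA1:E173K": 0.4656.
distance_map = {
    "name": "titer_substitution_model",
    "default": 0.0,
    "map": {
        "HA1": {
            "173": [
                {"from": "E", "to": "K", "weight": 0.4656}
            ]
        }
    }
}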
shell: "echo Environment built"
shell:
    """
    python3 scripts/concatenate_tables.py \
        --tables {input.errors} \
        --output {output.errors}
    """
shell:
    """
    python3 scripts/concatenate_tables.py \
        --tables {input.coefficients} \
        --output {output.coefficients}
    """
shell:
    """
    python3 scripts/collect_tables.py \
        --tables {input} \
        --output {output.clades}
    """
shell:
    """
    python3 scripts/collect_tables.py \
        --tables {input} \
        --output {output.ranks}
    """
shell: "gs -dBATCH -dNOPAUSE -q -sDEVICE=pdfwrite -sOutputFile={output} {input}"
shell:
    """
    while read original_name new_name
    do
        ln manuscript/figures/$original_name manuscript/$new_name
    done < {input.figure_names}
    """
shell:
    """
    cd manuscript
    pdflatex -draftmode {params.title}
    bibtex {params.title}
    pdflatex -draftmode {params.title}
    pdflatex {params.title}
    """
import argparse
from augur.utils import read_node_data
import json

from forecast.fitness_predictors import inverse_cross_immunity_amplitude, cross_immunity_cost


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Calculate cross-immunity",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("--frequencies", required=True, help="JSON of frequencies per sample")
    parser.add_argument("--distances", required=True, help="JSON of distances between samples")
    parser.add_argument("--date-annotations", required=True, help="JSON of branch lengths and date annotations from augur refine for samples in the given tree")
    parser.add_argument("--distance-attributes", nargs="+", required=True, help="names of attributes to use from the given distances JSON")
    parser.add_argument("--immunity-attributes", nargs="+", required=True, help="names of attributes to use for the calculated cross-immunities")
    parser.add_argument("--decay-factors", nargs="+", required=True, type=float, help="list of decay factors (d_0) for each given immunity attribute")
    parser.add_argument("--years-to-wane", type=int, help="number of years after which immunity wanes completely")
    parser.add_argument("--output", required=True, help="cross-immunities calculated from the given distances and frequencies")
    args = parser.parse_args()

    # Load frequencies.
    with open(args.frequencies, "r") as fh:
        frequencies = json.load(fh)

    # Identify maximum frequency per sample.
    max_frequency_per_sample = {
        sample: float(max(sample_frequencies["frequencies"]))
        for sample, sample_frequencies in frequencies.items()
        if sample not in ["pivots", "generated_by"] and not sample.startswith("count")
    }

    current_timepoint = frequencies["pivots"][-1]

    # Load distances.
    with open(args.distances, "r") as fh:
        distances = json.load(fh)

    distances = distances["nodes"]

    # Load date annotations and annotate tree with them.
    date_annotations = read_node_data(args.date_annotations)
    date_by_node_name = {}
    for node, annotations in date_annotations["nodes"].items():
        date_by_node_name[node] = annotations["numdate"]

    """
    "A/Acre/15093/2010": {
      "ep": 9,
      "ne": 8,
      "rb": 3
    },
    """

    if args.years_to_wane is not None:
        print("Waning effect with max years of %i" % args.years_to_wane)
    else:
        print("No waning effect")

    # Calculate cross-immunity for distances defined by the given attributes.
    cross_immunities = {}
    for sample, sample_distances in distances.items():
        for distance_attribute, immunity_attribute, decay_factor in zip(args.distance_attributes, args.immunity_attributes, args.decay_factors):
            if distance_attribute not in sample_distances:
                continue

            if sample not in cross_immunities:
                cross_immunities[sample] = {}

            # Calculate cross-immunity cost from all distances to the current
            # sample. This negative value increases for samples that are
            # increasingly distant from previous samples.
            cross_immunity = 0.0
            for past_sample, distance in sample_distances[distance_attribute].items():
                # Calculate effect of waning immunity.
                if args.years_to_wane is not None:
                    waning_effect = max(1 - ((current_timepoint - date_by_node_name[past_sample]) / args.years_to_wane), 0)
                else:
                    waning_effect = 1.0

                # Calculate cost of cross-immunity with waning.
                if waning_effect > 0:
                    cross_immunity += waning_effect * max_frequency_per_sample[past_sample] * cross_immunity_cost(
                        distance,
                        decay_factor
                    )

            cross_immunities[sample][immunity_attribute] = -1 * cross_immunity

    # Export cross-immunities to JSON.
    with open(args.output, "w") as oh:
        json.dump({"nodes": cross_immunities}, oh, indent=1, sort_keys=True)
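Each strain's cross-immunity is the sum, over all past strains, of a cost term scaled by the past strain's maximum historical frequency and by a linear waning factor, and the total is stored as a negative value so that more antigenically novel strains score higher. A minimal sketch for a single past strain, assuming an exponential cost of the form exp(-d / d_0); the actual cross_immunity_cost is imported from forecast.fitness_predictors and is not shown here, and all numbers below are made up:

import numpy as np

def cross_immunity_cost(distance, decay_factor):
    # Assumed exponential form exp(-d / d_0); the real implementation in
    # forecast.fitness_predictors may differ in detail.
    return np.exp(-distance / decay_factor)

# One past strain at antigenic distance 4, sampled 3 years before the current
# timepoint, with a maximum historical frequency of 0.6, waning over 5 years.
current_timepoint, past_date = 2018.0, 2015.0
years_to_wane, decay_factor = 5, 14.0

waning_effect = max(1 - (current_timepoint - past_date) / years_to_wane, 0)  # 0.4
contribution = waning_effect * 0.6 * cross_immunity_cost(4, decay_factor)
print(round(-1 * contribution, 3))  # approximately -0.18, the stored (negative) cross-immunity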
917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 | import argparse import csv import cv2 import json import numpy as np import pandas as pd from scipy.optimize import minimize import sys import time from forecast.fitness_model import get_train_validate_timepoints from forecast.metrics import add_pseudocounts_to_frequencies, negative_information_gain from forecast.metrics import mean_absolute_error, sum_of_squared_errors, root_mean_square_error from weighted_distances import get_distances_by_sample_names, get_distance_matrix_by_sample_names MAX_PROJECTED_FREQUENCY = 1e3 FREQUENCY_TOLERANCE = 1e-3 np.random.seed(314159) def sum_of_differences(observed, estimated, y_diff, **kwargs): """ Calculates the sum of squared errors for observed and estimated values. Parameters ---------- observed : numpy.ndarray observed values estimated : numpy.ndarray estimated values y_diff : numpy.ndarray differences between observed and estimated values Returns ------- float : sum of differences between estimated and observed future values """ return np.sum(y_diff) class ExponentialGrowthModel(object): def __init__(self, predictors, delta_time, l1_lambda, cost_function): """Construct an empty exponential growth model instance. Parameters ---------- predictors : list a list of predictors to estimate coefficients for delta_time : float number of years into the future to project frequencies l1_lambda : float hyperparameter to scale L1 regularization penalty for non-zero coefficients cost_function : callable function returning the error to be minimized between observed and estimated values Returns ------- ExponentialGrowthModel """ self.predictors = predictors self.delta_time = delta_time self.l1_lambda = l1_lambda self.cost_function = cost_function def calculate_mean_stds(self, X, predictors): """Calculate mean standard deviations of predictors by timepoints prior to fitting. 
Parameters ---------- X : pandas.DataFrame standardized tip attributes by timepoint predictors : list names of predictors for which mean standard deviations should be calculated Returns ------- ndarray : mean standard deviation per predictor across all timepoints """ # Note that the pandas standard deviation method ignores missing data # whereas numpy requires the use of specific NaN-aware functions (nanstd). return X.loc[:, ["timepoint"] + predictors].groupby("timepoint").std().mean().values def standardize_predictors(self, predictors, mean_stds, initial_frequencies): """Standardize the values for the given predictors by centering on the mean of each predictor and scaling by the mean standard deviation provided. Parameters ---------- predictors : ndarray matrix of values per sample (rows) and predictor (columns) mean_stds : ndarray mean standard deviations of predictors across all training timepoints initial_frequencies : ndarray initial frequencies of samples corresponding to each row of the given predictors Returns ------- ndarray : standardized predictor values """ means = np.average(predictors, weights=initial_frequencies, axis=0) variances = np.average((predictors - means) ** 2, weights=initial_frequencies, axis=0) stds = np.sqrt(variances) nonzero_stds = np.where(stds)[0] if len(nonzero_stds) == 0: return predictors standardized_predictors = predictors standardized_predictors[:, nonzero_stds] = (predictors[:, nonzero_stds] - means[nonzero_stds]) / stds[nonzero_stds] return standardized_predictors def get_fitnesses(self, coefficients, predictors): """Apply the coefficients to the predictors and sum them to get strain fitnesses. Parameters ---------- coefficients : ndarray or list coefficients for given predictors predictors : ndarray predictor values per sample (n x p matrix for p predictors and n samples) Returns ------- ndarray : fitnesses per sample """ return np.sum(predictors * coefficients, axis=-1) def project_frequencies(self, initial_frequencies, fitnesses, delta_time): """Project the given initial frequencies into the future by the given delta time based on the given fitnesses. Returns the projected frequencies normalized to sum to 1. Parameters ---------- initial_frequencies : ndarray floating point frequencies for all samples in a timepoint fitnesses : ndarray floating point fitnesses for all samples in same order as given frequencies delta_time : float number of years to project into the future Returns ------- ndarray : projected and normalized frequencies """ # Exponentiate the fitnesses and multiply them by strain frequencies. projected_frequencies = initial_frequencies * np.exp(fitnesses * self.delta_time) # Replace infinite values a very large number that can still be summed # across all timepoints. This addresses the case of buffer overflows in # exponentiation which can produce both of these problematic values. projected_frequencies[np.isinf(projected_frequencies)] = MAX_PROJECTED_FREQUENCY # Sum the projected frequencies. total_projected_frequencies = projected_frequencies.sum() # Normalize the projected frequencies. projected_frequencies = projected_frequencies / total_projected_frequencies # Confirm that projected frequencies sum to 1. assert np.isclose(projected_frequencies.sum(), np.ones(1), atol=FREQUENCY_TOLERANCE) # Confirm that all projected frequencies are proper numbers. 
assert np.isnan(projected_frequencies).sum() == 0 return projected_frequencies def _fit(self, coefficients, X, y, use_l1_penalty=True): """Calculate the error between observed and estimated values for the given parameters and data. Parameters ---------- coefficients : ndarray coefficients for each of the model's predictors X : pandas.DataFrame standardized tip attributes by timepoint y : pandas.DataFrame final clade frequencies at delta time in the future from each timepoint in the given tip attributes table Returns ------- float : error between estimated values using the given coefficients and input data and the observed values """ # Estimate final frequencies. y_hat = self.predict(X, coefficients) # Merge estimated and observed frequencies. The left join enables # tracking of clades that die in the future and are therefore not # observed in the future frequencies data frame. frequencies = y_hat.merge( y, how="left", on=["timepoint", "clade_membership"], suffixes=["_estimated", "_observed"] ) frequencies["frequency_observed"] = frequencies["frequency_observed"].fillna(0.0) # Calculate initial frequencies for use by cost function. initial_frequencies = X.groupby([ "timepoint", "clade_membership" ])["frequency"].sum().reset_index() # Annotate future frequencies with initial frequencies. frequencies = frequencies.merge( initial_frequencies, how="inner", on=["timepoint", "clade_membership"] ) # Calculate the error between the observed and estimated frequencies. error = self.cost_function( frequencies["frequency_observed"], frequencies["frequency_estimated"], initial=frequencies["frequency"] ) if use_l1_penalty: l1_penalty = self.l1_lambda * np.abs(coefficients).sum() else: l1_penalty = 0.0 return error + l1_penalty def fit(self, X, y): """Fit a model to the given input data, producing beta coefficients for each of the model's predictors. Coefficients are stored in the `coef_` attribute, after the pattern of scikit-learn models. Parameters ---------- X : pandas.DataFrame standardized tip attributes by timepoint y : pandas.DataFrame final clade frequencies at delta time in the future from each timepoint in the given tip attributes table Returns ------- float : model training error """ # Calculate mean standard deviations of predictors by timepoints prior # to fitting. self.mean_stds_ = self.calculate_mean_stds(X, self.predictors) # Find coefficients that minimize the model's cost function. if hasattr(self, "coef_"): # Use the previous coefficients +/- a small random offset (+/- 0.05) # to prevent getting stuck in local minima. initial_coefficients = self.coef_ + (0.1 * np.random.random(len(self.predictors)) - 0.05) else: # If no previous coefficients exist, sample random values between -0.5 and 0.5. initial_coefficients = np.random.random(len(self.predictors)) - 0.5 results = minimize( self._fit, initial_coefficients, args=(X, y), method="Nelder-Mead", options={"disp": False} ) self.coef_ = results.x training_error = self.score(X, y) return training_error def predict(self, X, coefficients=None, mean_stds=None): """Calculate the estimate final frequencies of all clades in the given tip attributes data frame using previously calculated beta coefficients. 
Parameters ---------- X : pandas.DataFrame standardized tip attributes by timepoint coefficients : ndarray optional coefficients to use for each of the model's predictors instead of the model's currently defined coefficients mean_stds : ndarray optional mean standard deviations of predictors across all training timepoints Returns ------- pandas.DataFrame estimated final clade frequencies at delta time in the future for each clade from each timepoint in the given tip attributes table """ # Use model coefficients, if none are provided. if coefficients is None: coefficients = self.coef_ if mean_stds is None: mean_stds = self.mean_stds_ estimated_frequencies = [] for timepoint, timepoint_df in X.groupby("timepoint"): # Select frequencies from timepoint. initial_frequencies = timepoint_df["frequency"].values # Select predictors from the timepoint. predictors = timepoint_df.loc[:, self.predictors].values # Standardize predictors by timepoint centering by means at # timepoint and mean standard deviation provided. standardized_predictors = self.standardize_predictors(predictors, mean_stds, initial_frequencies) # Calculate fitnesses. fitnesses = self.get_fitnesses(coefficients, standardized_predictors) # Project frequencies. projected_frequencies = self.project_frequencies( initial_frequencies, fitnesses, self.delta_time ) # Sum the estimated frequencies by clade. projected_timepoint_df = timepoint_df[["timepoint", "clade_membership"]].copy() projected_timepoint_df["frequency"] = projected_frequencies projected_clade_frequencies = projected_timepoint_df.groupby([ "timepoint", "clade_membership" ])["frequency"].sum().reset_index() estimated_frequencies.append(projected_clade_frequencies) # Collect all estimated frequencies by timepoint. estimated_frequencies = pd.concat(estimated_frequencies) return estimated_frequencies def score(self, X, y): """Calculate model error between the estimated final clade frequencies for the given tip attributes, `X`, and the observed final clade frequencies in `y`. Parameters ---------- X : pandas.DataFrame standardized tip attributes by timepoint y : pandas.DataFrame final clade frequencies at delta time in the future from each timepoint in the given tip attributes table Returns ------- float : model error """ return self._fit(self.coef_, X, y, use_l1_penalty=False) class DistanceExponentialGrowthModel(ExponentialGrowthModel): def __init__(self, predictors, delta_time, l1_lambda, cost_function, distances): super().__init__(predictors, delta_time, l1_lambda, cost_function) self.distances = distances def _fit(self, coefficients, X, y, use_l1_penalty=True, calculate_optimal_distance=False): """Calculate the error between observed and estimated values for the given parameters and data. Parameters ---------- coefficients : ndarray coefficients for each of the model's predictors X : pandas.DataFrame standardized tip attributes by timepoint y : pandas.DataFrame final weighted distances at delta time in the future from each timepoint in the given tip attributes table Returns ------- float : error between estimated values using the given coefficients and input data and the observed values """ # Estimate target values. y_hat = self.predict(X, coefficients) # Calculate EMD for each timepoint in the estimated values and sum that # distance across all timepoints. 
error = 0.0 count = 0 for timepoint, timepoint_df in y_hat.groupby("timepoint"): samples_a = timepoint_df["strain"] sample_a_initial_frequencies = timepoint_df["frequency"].values.astype(np.float32) sample_a_frequencies = timepoint_df["projected_frequency"].values.astype(np.float32) future_timepoint_df = y[y["timepoint"] == timepoint] assert future_timepoint_df.shape[0] > 0 samples_b = future_timepoint_df["strain"] sample_b_frequencies = future_timepoint_df["frequency"].values.astype(np.float32) distance_matrix = get_distance_matrix_by_sample_names( samples_a, samples_b, self.distances ).astype(np.float32) # Calculate the optimal distance to the future timepoint by mapping # the frequency of each future strain to the closest strain in the # current timepoint. if calculate_optimal_distance: # For each strain in the future timepoint, identify the closest # strain in the current timepoint. This is an array of current # strain indices (one index per future strain). closest_strain_to_future = np.argmin(distance_matrix, axis=0) # Sum the frequencies of the future strains across each closest # strain in the current timepoint. This can and will often # result in a few current strains accuring most of the future # frequencies. estimated_frequencies = np.zeros_like(sample_a_frequencies) for i in range(sample_b_frequencies.shape[0]): estimated_frequencies[closest_strain_to_future[i]] += sample_b_frequencies[i] # Calculate earth mover's distance to the future based on this # optimal (or, at least, greedy) mapping of strains between # timepoints. The resulting EMD value should be the best any # model can hope to perform and establishes a lower bound for # all models. self.optimal_model_emd, _, optimal_model_flow = cv2.EMD( estimated_frequencies, sample_b_frequencies, cv2.DIST_USER, cost=distance_matrix ) # Estimate the distance between the model's estimated future and the # observed future populations. model_emd, _, self.model_flow = cv2.EMD( sample_a_frequencies, sample_b_frequencies, cv2.DIST_USER, cost=distance_matrix ) error += model_emd count += 1 error = error / float(count) if use_l1_penalty: l1_penalty = self.l1_lambda * np.abs(coefficients).sum() else: l1_penalty = 0.0 return error + l1_penalty def _fit_distance(self, coefficients, X, y, use_l1_penalty=True): """Calculate the error between observed and estimated values for the given parameters and data. Parameters ---------- coefficients : ndarray coefficients for each of the model's predictors X : pandas.DataFrame standardized tip attributes by timepoint y : pandas.DataFrame final weighted distances at delta time in the future from each timepoint in the given tip attributes table Returns ------- float : error between estimated values using the given coefficients and input data and the observed values """ # Estimate target values. y_hat = self.predict(X, coefficients) # Calculate weighted distance to the future for each timepoint in the # estimated values and sum that distance across all timepoints. 
error = 0.0 null_error = 0.0 count = 0 for timepoint, timepoint_df in y_hat.groupby("timepoint"): samples_a = timepoint_df["strain"] sample_a_initial_frequencies = timepoint_df["frequency"].values sample_a_frequencies = timepoint_df["projected_frequency"].values sample_a_weighted_distance_to_future = timepoint_df["weighted_distance_to_future"].values future_timepoint_df = y[y["timepoint"] == timepoint] assert future_timepoint_df.shape[0] > 0 samples_b = future_timepoint_df["strain"] sample_b_frequencies = future_timepoint_df["frequency"].values sample_b_weighted_distance_to_present = future_timepoint_df["weighted_distance_to_present"].values d_t_u = (sample_a_initial_frequencies * sample_a_weighted_distance_to_future).sum() d_u_hat_u = (sample_a_frequencies * sample_a_weighted_distance_to_future).sum() d_u_u = (sample_b_frequencies * sample_b_weighted_distance_to_present).sum() null_error += d_t_u error += (d_u_hat_u - d_u_u) / d_t_u count += 1 null_error = null_error / float(count) error = error / float(count) if use_l1_penalty: l1_penalty = self.l1_lambda * np.abs(coefficients).sum() else: l1_penalty = 0.0 return error + l1_penalty def predict(self, X, coefficients=None, mean_stds=None): """Calculate the estimated final weighted distance between tips at each timepoint and at that timepoint plus delta months in the future. Parameters ---------- X : pandas.DataFrame standardized tip attributes by timepoint coefficients : ndarray optional coefficients to use for each of the model's predictors instead of the model's currently defined coefficients mean_stds : ndarray optional mean standard deviations of predictors across all training timepoints Returns ------- pandas.DataFrame estimated weighted distances at delta time in the future for each tip from each timepoint in the given tip attributes table """ # Use model coefficients, if none are provided. if coefficients is None: coefficients = self.coef_ model_is_fit = True else: model_is_fit = False if mean_stds is None: mean_stds = self.mean_stds_ estimated_targets = [] for timepoint, timepoint_df in X.groupby("timepoint"): # Select frequencies from timepoint. initial_frequencies = timepoint_df["frequency"].values # Select predictors from the timepoint. predictors = timepoint_df.loc[:, self.predictors].values # Standardize predictors by timepoint centering by means at # timepoint and mean standard deviation provided. mean_stds = timepoint_df.loc[:, self.predictors].std().values standardized_predictors = self.standardize_predictors(predictors, mean_stds, initial_frequencies) # Calculate fitnesses. fitnesses = self.get_fitnesses(coefficients, standardized_predictors) # Project frequencies. projected_frequencies = self.project_frequencies( initial_frequencies, fitnesses, self.delta_time ) # Calculate observed distance between current tips and the future # using projected frequencies and weighted distances to the future. columns_to_extract = ["timepoint", "strain", "frequency"] optional_columns = ["weighted_distance_to_present", "weighted_distance_to_future"] for column in optional_columns: if column in timepoint_df.columns: columns_to_extract.append(column) projected_timepoint_df = timepoint_df[columns_to_extract].copy() projected_timepoint_df["fitness"] = fitnesses projected_timepoint_df["projected_frequency"] = projected_frequencies if model_is_fit: # Calculate estimate distance between current tips and future tips # based on projections of current tips. 
estimated_weighted_distance_to_future = [] for current_tip, current_tip_frequency in projected_timepoint_df.loc[:, ["strain", "frequency"]].values: weighted_distance_to_future = 0.0 for other_tip, other_tip_projected_frequency in projected_timepoint_df.loc[:, ["strain", "projected_frequency"]].values: weighted_distance_to_future += other_tip_projected_frequency * self.distances[current_tip][other_tip] estimated_weighted_distance_to_future.append(weighted_distance_to_future) projected_timepoint_df["y"] = np.array(estimated_weighted_distance_to_future) else: projected_timepoint_df["y"] = np.nan estimated_targets.append(projected_timepoint_df) # Collect all estimated targets by timepoint. estimated_targets = pd.concat(estimated_targets, ignore_index=True) return estimated_targets def cross_validate(model_class, model_kwargs, data, targets, train_validate_timepoints, coefficients=None, group_by="clade_membership", include_attributes=False): """Calculate cross-validation scores for the given data and targets across the given train/validate timepoints. Parameters ---------- model : ExponentialGrowthModel an instance of a model with defined hyperparameters including a list of predictors to use for fitting data : pandas.DataFrame standardized input attributes to use for model fitting targets : pandas.DataFrame observed outputs to fit the model to train_validate_timepoints : list a list of dictionaries of lists indexed by "train" and "validate" keys and containing timepoints to use for model training and validation, respectively coefficients : ndarray an optional array of fixed coefficients for the given model's predictors to use when calculating cross-validation error for specific models (e.g., naive forecasts) group_by : string column of the tip attributes by which they should be grouped to calculate the total number of samples in the model (e.g., group by clade or strain) include_attributes : boolean specifies whether tip attribute data used to train/validate models should be included in the output per training window Returns ------- list a list of dictionaries containing cross-validation results with scores, training and validation results, and beta coefficients per timepoint """ results = [] differences_of_model_and_naive_errors = [] previous_coefficients = None for timepoints in train_validate_timepoints: model = model_class(**model_kwargs) if previous_coefficients is not None: model.coef_ = previous_coefficients # Get training and validation timepoints. training_timepoints = pd.to_datetime(timepoints["train"]) validation_timepoint = pd.to_datetime(timepoints["validate"]) # Get training data by timepoints. training_X = data[data["timepoint"].isin(training_timepoints)].copy() training_y = targets[targets["timepoint"].isin(training_timepoints)].copy() # Fit a model to the training data. if coefficients is None: start_time = time.time() training_error = model.fit(training_X, training_y) end_time = time.time() previous_coefficients = model.coef_ null_training_error = model._fit(np.zeros_like(model.coef_), training_X, training_y) else: start_time = end_time = time.time() model.coef_ = coefficients model.mean_stds_ = model.calculate_mean_stds(training_X, model.predictors) training_error = model.score(training_X, training_y) null_training_error = training_error # Get validation data by timepoints. validation_X = data[data["timepoint"] == validation_timepoint].copy() validation_y = targets[targets["timepoint"] == validation_timepoint].copy() # Calculate the model score for the validation data. 
        validation_error = model.score(validation_X, validation_y)
        null_validation_error = model._fit(np.zeros_like(model.coef_), validation_X, validation_y, calculate_optimal_distance=True)
        optimal_validation_error = model.optimal_model_emd
        differences_of_model_and_naive_errors.append(validation_error - null_validation_error)

        print(
            "%s\t%s\t%.2f\t%.2f\t%.2f\t%.2f\t%.2f\t%s\t%.2f\t%.2f" % (
                training_timepoints[-1].strftime("%Y-%m"),
                validation_timepoint.strftime("%Y-%m"),
                training_error,
                null_training_error,
                validation_error,
                null_validation_error,
                optimal_validation_error,
                model.coef_,
                (np.array(differences_of_model_and_naive_errors) < 0).sum() / float(len(differences_of_model_and_naive_errors)),
                end_time - start_time
            ),
            flush=True
        )

        # Get the estimated frequencies for training and validation sets to export.
        training_y_hat = model.predict(training_X)
        validation_y_hat = model.predict(validation_X)

        # Convert timestamps to a serializable format.
        for df in [training_X, training_y, training_y_hat, validation_X, validation_y, validation_y_hat]:
            for column in ["timepoint", "future_timepoint"]:
                if column in df.columns:
                    df[column] = df[column].dt.strftime("%Y-%m-%d")

        # Store training results, beta coefficients, and validation results.
        result = {
            "predictors": model.predictors,
            "training_data": {
                "y": training_y.to_dict(orient="records"),
                "y_hat": training_y_hat.to_dict(orient="records")
            },
            "training_n": training_X[group_by].unique().shape[0],
            "training_error": training_error,
            "coefficients": model.coef_.tolist(),
            "mean_stds": model.mean_stds_.tolist(),
            "validation_data": {
                "y": validation_y.to_dict(orient="records"),
                "y_hat": validation_y_hat.to_dict(orient="records")
            },
            "validation_n": validation_X[group_by].unique().shape[0],
            "validation_error": validation_error,
            "null_validation_error": null_validation_error,
            "optimal_validation_error": optimal_validation_error,
            "last_training_timepoint": training_timepoints[-1].strftime("%Y-%m-%d"),
            "validation_timepoint": validation_timepoint.strftime("%Y-%m-%d")
        }

        # Include tip attributes, if requested.
        if include_attributes:
            result["training_data"]["X"] = training_X.to_dict(orient="records")
            result["validation_data"]["X"] = validation_X.to_dict(orient="records")

        results.append(result)

    # Return results for all validation timepoints.
    print("Mean difference between model and naive: %.4f" % (sum(differences_of_model_and_naive_errors) / len(differences_of_model_and_naive_errors)), flush=True)
    print("Proportion of timepoints when model < naive: %.2f" % ((np.array(differences_of_model_and_naive_errors) < 0).sum() / float(len(differences_of_model_and_naive_errors))))

    return results


def test(model_class, model_kwargs, data, targets, timepoints, coefficients=None, group_by="clade_membership", include_attributes=False):
    """Calculate test scores for the given data and targets across the given timepoints.

    Parameters
    ----------
    model : ExponentialGrowthModel
        an instance of a model with defined hyperparameters including a list of
        predictors to use for fitting
    data : pandas.DataFrame
        standardized input attributes to use for model fitting
    targets : pandas.DataFrame
        observed outputs to test the model with
    timepoints : list
        a list of timepoint strings in YYYY-MM-DD format
    coefficients : ndarray
        an array of fixed coefficients for the given model's predictors
    group_by : string
        column of the tip attributes by which they should be grouped to
        calculate the total number of samples in the model (e.g., group by
        clade or strain)
    include_attributes : boolean
        specifies whether tip attribute data used to test models should be
        included in the output per timepoint

    Returns
    -------
    list
        a list of dictionaries containing test results with scores per timepoint
    """
    results = []
    differences_of_model_and_naive_errors = []
    for timepoint in timepoints:
        model = model_class(**model_kwargs)
        model.coef_ = coefficients
        model.mean_stds_ = np.zeros_like(coefficients)

        # Get training and validation timepoints.
        test_timepoint = pd.to_datetime(timepoint)

        # Get test data by timepoints.
        test_X = data[data["timepoint"] == test_timepoint].copy()
        test_y = targets[targets["timepoint"] == test_timepoint].copy()

        # Calculate the model score for the validation data.
        test_error = model.score(test_X, test_y)
        null_test_error = model._fit(np.zeros_like(model.coef_), test_X, test_y, calculate_optimal_distance=True)
        optimal_test_error = model.optimal_model_emd
        differences_of_model_and_naive_errors.append(test_error - null_test_error)

        print(
            "%s\t%.2f\t%.2f\t%.2f\t%s\t%.2f" % (
                test_timepoint.strftime("%Y-%m"),
                test_error,
                null_test_error,
                optimal_test_error,
                model.coef_,
                (np.array(differences_of_model_and_naive_errors) < 0).sum() / float(len(differences_of_model_and_naive_errors))
            ),
            flush=True
        )

        # Get the estimated frequencies for test sets to export.
        test_y_hat = model.predict(test_X)

        # Convert timestamps to a serializable format.
        for df in [test_X, test_y, test_y_hat]:
            for column in ["timepoint", "future_timepoint"]:
                if column in df.columns:
                    df[column] = df[column].dt.strftime("%Y-%m-%d")

        # Store test results and beta coefficients.
        result = {
            "predictors": model.predictors,
            "coefficients": model.coef_.tolist(),
            "mean_stds": model.mean_stds_.tolist(),
            "validation_data": {
                "y": test_y.to_dict(orient="records"),
                "y_hat": test_y_hat.to_dict(orient="records")
            },
            "validation_n": test_X[group_by].unique().shape[0],
            "validation_error": test_error,
            "null_validation_error": null_test_error,
            "optimal_validation_error": optimal_test_error,
            "validation_timepoint": test_timepoint.strftime("%Y-%m-%d")
        }

        # Include tip attributes, if requested.
        if include_attributes:
            result["validation_data"]["X"] = test_X.to_dict(orient="records")

        results.append(result)

    # Return results for all validation timepoints.
    print("Mean difference between model and naive: %.4f" % (sum(differences_of_model_and_naive_errors) / len(differences_of_model_and_naive_errors)), flush=True)
    print("Proportion of timepoints when model < naive: %.2f" % ((np.array(differences_of_model_and_naive_errors) < 0).sum() / float(len(differences_of_model_and_naive_errors))))

    return results


def summarize_scores(scores, include_scores=False):
    """Summarize model errors across timepoints.

    Parameters
    ----------
    scores : list
        a list of cross-validation results including training errors,
        cross-validation errors, and beta coefficients OR a list of test errors
    include_scores : boolean
        specifies whether cross-validation scores should be included in the
        output per timepoint

    Returns
    -------
    dict :
        a dictionary of all cross-validation results plus summary statistics
        for training, cross-validation, and beta coefficients OR test results
    """
    summary = {
        "predictors": scores[0]["predictors"]
    }

    if include_scores:
        summary["scores"] = scores

    validation_errors = [score["validation_error"] for score in scores]
    summary["cv_error_mean"] = np.mean(validation_errors)
    summary["cv_error_std"] = np.std(validation_errors)

    coefficients = np.array([
        np.array(score["coefficients"])
        for score in scores
    ])
    summary["coefficients_mean"] = coefficients.mean(axis=0).tolist()
    summary["coefficients_std"] = coefficients.std(axis=0).tolist()

    mean_stds = np.array([
        np.array(score["mean_stds"])
        for score in scores
    ])
    summary["mean_stds_mean"] = mean_stds.mean(axis=0).tolist()
    summary["mean_stds_std"] = mean_stds.std(axis=0).tolist()

    return summary


def get_errors_by_timepoint(scores):
    """Convert cross-validation errors into a data frame by timepoint and predictors.

    Parameters
    ----------
    scores : list
        a list of cross-validation results including training errors,
        cross-validation errors, and beta coefficients

    Returns
    -------
    pandas.DataFrame
    """
    predictors = "-".join(scores[0]["predictors"])
    errors_by_time = []
    for score in scores:
        errors_by_time.append({
            "predictors": predictors,
            "validation_timepoint": pd.to_datetime(score["validation_timepoint"]),
            "validation_error": score["validation_error"],
            "null_validation_error": score["null_validation_error"],
            "optimal_validation_error": score["optimal_validation_error"],
            "validation_n": score["validation_n"]
        })

    return pd.DataFrame(errors_by_time)


def get_coefficients_by_timepoint(scores):
    """Convert model coefficients into a data frame by timepoint and predictors.

    Parameters
    ----------
    scores : list
        a list of cross-validation results including training errors,
        cross-validation errors, and beta coefficients

    Returns
    -------
    pandas.DataFrame
    """
    predictors = "-".join(scores[0]["predictors"])
    coefficients_by_time = []
    for score in scores:
        for predictor, coefficient in zip(score["predictors"], score["coefficients"]):
            coefficients_by_time.append({
                "predictors": predictors,
                "predictor": predictor,
                "coefficient": coefficient,
                "validation_timepoint": pd.to_datetime(score["validation_timepoint"])
            })

    return pd.DataFrame(coefficients_by_time)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--tip-attributes", required=True, help="tab-delimited file describing tip attributes at all timepoints with standardized predictors")
    parser.add_argument("--output", required=True, help="JSON representing the model fit with training and cross-validation results, beta coefficients for predictors, and summary statistics")
    parser.add_argument("--predictors", nargs="+", help="tip attribute columns to use as predictors of final clade frequencies; optional if a fixed model is provided")
    parser.add_argument("--delta-months", type=int, help="number of months to project clade frequencies into the future")
    parser.add_argument("--target", required=True, choices=["clades", "distances"], help="target for models to fit")
    parser.add_argument("--final-clade-frequencies", help="tab-delimited file of clades per timepoint and their corresponding tips and tip frequencies at the given delta time in the future")
    parser.add_argument("--distances", help="tab-delimited file of distances between pairs of samples")
    parser.add_argument("--training-window", type=int, default=4, help="number of years required for model training")
    parser.add_argument("--l1-lambda", type=float, default=0.0, help="L1 regularization lambda")
    parser.add_argument("--cost-function", default="sse", choices=["sse", "rmse", "mae", "information_gain", "diffsum"], help="name of the function that returns the error between observed and estimated values")
    parser.add_argument("--pseudocount", type=float, help="pseudocount numerator to adjust all frequencies by, enabling some information theoretic metrics like information gain")
    parser.add_argument("--include-attributes", action="store_true", help="include attribute data used to train/validate models in the cross-validation output")
    parser.add_argument("--include-scores", action="store_true", help="include score data resulting from cross-validation output")
    parser.add_argument("--errors-by-timepoint", help="optional data frame of cross-validation errors by validation timepoint")
    parser.add_argument("--coefficients-by-timepoint", help="optional data frame of coefficients by validation timepoint")
    parser.add_argument("--fixed-model", help="optional model JSON to use as a fixed model for calculation of test error in forecasts")

    args = parser.parse_args()

    cost_functions_by_name = {
        "sse": sum_of_squared_errors,
        "rmse": root_mean_square_error,
        "mae": mean_absolute_error,
        "information_gain": negative_information_gain,
        "diffsum": sum_of_differences
    }

    # Load standardized tip attributes subsetting to tip name, clade, frequency,
    # and requested predictors.
    tips = pd.read_csv(
        args.tip_attributes,
        sep="\t",
        parse_dates=["timepoint"]
    )

    if args.target == "clades":
        # Load final clade tip frequencies.
        final_clade_tip_frequencies = pd.read_csv(
            args.final_clade_frequencies,
            sep="\t",
            parse_dates=["initial_timepoint", "final_timepoint"]
        )

        # If a pseudocount numerator has been provided, update the given tip
        # frequencies both from current and future timepoints.
        if args.pseudocount is not None and args.pseudocount > 0.0:
            tips = add_pseudocounts_to_frequencies(tips, args.pseudocount)
            print("Sum of tip frequencies by timepoint: ", tips.groupby("timepoint")["frequency"].sum())

            final_clade_tip_frequencies = add_pseudocounts_to_frequencies(
                final_clade_tip_frequencies,
                args.pseudocount,
                timepoint_column="initial_timepoint"
            )
            print("Sum of tip frequencies by timepoint: ", final_clade_tip_frequencies.groupby("initial_timepoint")["frequency"].sum())

        # Aggregate final clade frequencies.
        final_clade_frequencies = final_clade_tip_frequencies.groupby([
            "initial_timepoint",
            "clade_membership"
        ])["frequency"].sum().reset_index()

        # Rename initial timepoint column for comparison with tip attribute data.
        targets = final_clade_frequencies.rename(
            columns={"initial_timepoint": "timepoint"}
        )

        model_class = ExponentialGrowthModel
        model_kwargs = {}
        group_by_attribute = "clade_membership"
    elif args.target == "distances":
        # Scale each tip's weighted distance to future populations by one minus
        # the tip's current frequency. This ensures that lower frequency tips are
        # not considered closer to the future.
        tips["y"] = tips["weighted_distance_to_future"]

        # Get strain frequency per timepoint and subtract delta time from
        # timepoint to align strain frequencies with the previous timepoint and
        # make them appropriate as targets for the model.
        targets = tips.loc[:, ["strain", "timepoint", "frequency", "weighted_distance_to_present", "weighted_distance_to_future", "y"]].copy()
        targets["future_timepoint"] = targets["timepoint"]

        model_class = DistanceExponentialGrowthModel
        with open(args.distances, "r") as fh:
            print("Read distances", flush=True)
            reader = csv.DictReader(fh, delimiter="\t")
            print("Get distances by sample names", flush=True)
            distances_by_sample_names = get_distances_by_sample_names(reader)
            print("Data loaded", flush=True)

        model_kwargs = {"distances": distances_by_sample_names}
        group_by_attribute = "strain"

    # Identify all available timepoints from tip attributes.
    timepoints = tips["timepoint"].dt.strftime("%Y-%m-%d").unique()

    # If a fixed model is provided, calculate test errors. Otherwise, calculate
    # cross-validation errors.
    if args.fixed_model is not None:
        # Load model details and extract mean coefficients.
        with open(args.fixed_model, "r") as fh:
            model_json = json.load(fh)

        coefficients = np.array(model_json["coefficients_mean"])
        delta_months = model_json["delta_months"]
        delta_time = delta_months / 12.0
        l1_lambda = model_json["l1_lambda"]
        training_window = model_json["training_window"]
        cost_function_name = model_json["cost_function"]
        cost_function = cost_functions_by_name[cost_function_name]

        model_kwargs.update({
            "predictors": model_json["predictors"],
            "delta_time": delta_time,
            "l1_lambda": l1_lambda,
            "cost_function": cost_function
        })

        # Find the latest timepoint we can project from based on the given delta
        # months.
        latest_timepoint = pd.to_datetime(timepoints[-1]) - pd.DateOffset(months=delta_months)
        test_timepoints = [
            timepoint
            for timepoint in timepoints
            if pd.to_datetime(timepoint) <= latest_timepoint
        ]

        # Calculate test errors/scores for the given coefficients and data at
        # the identified test timepoints.
targets["timepoint"] = targets["timepoint"] - pd.DateOffset(months=delta_months) scores = test( model_class, model_kwargs, tips, targets, test_timepoints, coefficients, group_by=group_by_attribute, include_attributes=args.include_attributes ) else: # First, confirm that all predictors are defined in the given tip # attributes. if not all([predictor in tips.columns for predictor in args.predictors]): print("ERROR: Not all predictors could be found in the given tip attributes table.", file=sys.stderr) sys.exit(1) # Select the cost function. cost_function_name = args.cost_function cost_function = cost_functions_by_name[cost_function_name] # Identify train/validate splits from timepoints. training_window = args.training_window train_validate_timepoints = get_train_validate_timepoints( timepoints, args.delta_months, training_window ) # For each train/validate split, fit a model to the training data, and # evaluate the model with the validation data, storing the training results, # beta parameters, and validation results. delta_months = args.delta_months delta_time = delta_months / 12.0 l1_lambda = args.l1_lambda model_kwargs.update({ "predictors": args.predictors, "delta_time": delta_time, "l1_lambda": l1_lambda, "cost_function": cost_function }) # If this is a naive model, set the coefficients to zero so cross-validation # can run under naive model conditions. if "naive" in args.predictors: coefficients = np.zeros(len(args.predictors)) else: coefficients = None targets["timepoint"] = targets["timepoint"] - pd.DateOffset(months=delta_months) scores = cross_validate( model_class, model_kwargs, tips, targets, train_validate_timepoints, coefficients, group_by=group_by_attribute, include_attributes=args.include_attributes ) # Summarize model errors including in-sample errors by AIC, out-of-sample # errors by cross-validation, and beta parameters across timepoints. model_results = summarize_scores(scores, args.include_scores) # Annotate parameters used to produce models. model_results["cost_function"] = cost_function_name model_results["l1_lambda"] = l1_lambda model_results["delta_months"] = delta_months model_results["training_window"] = training_window model_results["pseudocount"] = args.pseudocount # Save model fitting hyperparameters, raw results, and summary of results to # JSON. with open(args.output, "w") as fh: json.dump(model_results, fh, indent=1) # Save errors by timepoint, if requested. if args.errors_by_timepoint: errors_by_timepoint_df = get_errors_by_timepoint(scores) errors_by_timepoint_df.to_csv(args.errors_by_timepoint, sep="\t", header=True, index=False) # Save coefficients by timepoint, if requested. if args.coefficients_by_timepoint: coefficients_by_timepoint_df = get_coefficients_by_timepoint(scores) coefficients_by_timepoint_df.to_csv(args.coefficients_by_timepoint, sep="\t", header=True, index=False) |
import argparse
import csv
import json
import numpy as np
import pandas as pd
import sys

from fit_model import DistanceExponentialGrowthModel
from weighted_distances import get_distances_by_sample_names


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--tip-attributes", required=True, help="tab-delimited file describing tip attributes at all timepoints with standardized predictors")
    parser.add_argument("--distances", help="tab-delimited file of distances between pairs of samples")
    parser.add_argument("--frequencies", help="JSON representing historical frequencies to project from")
    parser.add_argument("--model", required=True, help="JSON representing the model fit with training and cross-validation results, beta coefficients for predictors, and summary statistics")
    parser.add_argument("--delta-months", required=True, type=int, nargs="+", help="number of months to project clade frequencies into the future")
    parser.add_argument("--output-node-data", help="node data JSON of forecasts for the given tips")
    parser.add_argument("--output-frequencies", help="frequencies JSON extended with forecasts for the given tips")
    parser.add_argument("--output-table", help="table of forecasts for the given tips")

    args = parser.parse_args()

    # Confirm that at least one output file has been specified.
    outputs = [
        args.output_node_data,
        args.output_frequencies,
        args.output_table
    ]
    outputs_missing = [output is None for output in outputs]
    if all(outputs_missing):
        print("ERROR: No output files were specified", file=sys.stderr)
        sys.exit(1)

    # Load standardized tip attributes subsetting to tip name, clade, frequency,
    # and requested predictors.
    tips = pd.read_csv(
        args.tip_attributes,
        sep="\t",
        parse_dates=["timepoint"]
    )

    # Load distances.
    with open(args.distances, "r") as fh:
        reader = csv.DictReader(fh, delimiter="\t")
        distances_by_sample_names = get_distances_by_sample_names(reader)

    # Load model details.
    with open(args.model, "r") as fh:
        model_json = json.load(fh)

    predictors = model_json["predictors"]
    cost_function = model_json["cost_function"]
    l1_lambda = model_json["l1_lambda"]
    coefficients = np.array(model_json["coefficients_mean"])
    mean_stds = np.array(model_json["mean_stds_mean"])

    delta_month = args.delta_months[-1]
    delta_time = delta_month / 12.0
    delta_offset = pd.DateOffset(months=delta_month)

    model = DistanceExponentialGrowthModel(
        predictors=predictors,
        delta_time=delta_time,
        cost_function=cost_function,
        l1_lambda=l1_lambda,
        distances=distances_by_sample_names
    )
    model.coef_ = coefficients
    model.mean_stds_ = mean_stds

    # Collect fitness and projection.
    forecasts_df = model.predict(tips)
    forecasts_df["weighted_distance_to_future_by_%s" % "-".join(predictors)] = forecasts_df["y"]
    forecasts_df["future_timepoint"] = forecasts_df["timepoint"] + delta_offset

    # Collect dicts from the data frame.
    strain_to_fitness = {}
    strain_to_future_timepoint = {}
    strain_to_projected_frequency = {}
    strain_to_weighted_distance_to_future = {}
    for index, row in forecasts_df.iterrows():
        strain_to_fitness[row['strain']] = row['fitness']
        strain_to_future_timepoint[row['strain']] = row["future_timepoint"].strftime("%Y-%m-%d")
        strain_to_projected_frequency[row['strain']] = row['projected_frequency']
        strain_to_weighted_distance_to_future[row['strain']] = row['y']

    # Output to file.
    if args.output_node_data:
        # Populate node data.
        node_data = {}
        strains = list(tips['strain'])
        for strain in strains:
            node_data[strain] = {
                "fitness": strain_to_fitness[strain],
                "future_timepoint": strain_to_future_timepoint[strain],
                "projected_frequency": strain_to_projected_frequency[strain],
                "weighted_distance_to_future": strain_to_weighted_distance_to_future[strain]
            }

        with open(args.output_node_data, "w") as jsonfile:
            json.dump({"nodes": node_data}, jsonfile, indent=1)

    # Load historic frequencies.
    if args.frequencies:
        with open(args.frequencies, "r") as fh:
            frequencies = json.load(fh)

        pivots = frequencies.pop("pivots")
        projection_pivot = pivots[-1]
    else:
        frequencies = None

    forecasts = []
    for delta_month in args.delta_months:
        delta_time = delta_month / 12.0
        delta_offset = pd.DateOffset(months=delta_month)

        model = DistanceExponentialGrowthModel(
            predictors=predictors,
            delta_time=delta_time,
            cost_function=cost_function,
            l1_lambda=l1_lambda,
            distances=distances_by_sample_names
        )
        model.coef_ = coefficients
        model.mean_stds_ = mean_stds

        # Collect fitness and projection.
        forecasts_df = model.predict(tips)
        forecasts_df["future_timepoint"] = forecasts_df["timepoint"] + delta_offset

        # Collect dicts from the data frame.
        strain_to_projected_frequency = {}
        for index, row in forecasts_df.iterrows():
            strain_to_projected_frequency[row['strain']] = row['projected_frequency']

        if frequencies is not None:
            # Extend frequencies.
            for strain in frequencies.keys():
                trajectory = frequencies[strain]['frequencies']
                if strain in strain_to_projected_frequency:
                    trajectory.append(strain_to_projected_frequency[strain])
                else:
                    trajectory.append(0.0)

            # Extend pivots.
            pivots.append(projection_pivot + delta_time)

        # Collect forecast data frames, if requested.
        if args.output_table:
            forecasts.append(forecasts_df)

    # Reconnect pivots and label the projection pivot.
    if frequencies is not None:
        frequencies['pivots'] = pivots
        frequencies['projection_pivot'] = projection_pivot

    # Output to file.
    if args.output_frequencies:
        with open(args.output_frequencies, "w") as jsonfile:
            json.dump(frequencies, jsonfile, indent=1)

    # Save forecasts table, if requested.
    if args.output_table:
        all_forecasts = pd.concat(forecasts, ignore_index=True)
        all_forecasts["model"] = "-".join(predictors)
        all_forecasts.to_csv(args.output_table, sep="\t", index=False, header=True, na_rep="N/A")
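When --output-node-data is given, the forecasting script above writes a JSON with a top-level "nodes" object keyed by strain name, where each entry carries fitness, future_timepoint, projected_frequency, and weighted_distance_to_future. As a minimal sketch (the file path here is hypothetical, not part of the pipeline), those forecasts can be pulled back into a data frame like this:

# A rough sketch of inspecting the node-data output; the path is hypothetical.
import json
import pandas as pd

with open("results/forecast_node_data.json", "r") as fh:
    node_data = json.load(fh)["nodes"]

# One row per strain with the forecast fields written above.
forecasts = pd.DataFrame.from_dict(node_data, orient="index").rename_axis("strain").reset_index()
print(forecasts[["strain", "fitness", "projected_frequency", "weighted_distance_to_future"]].head())

The final listing below computes the frequency-weighted distances to present and future populations that the model-fitting and forecasting scripts consume.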
import argparse
from collections import defaultdict
import csv
import numpy as np
import pandas as pd
import sys


def get_distances_by_sample_names(distances):
    """Return a dictionary of distances by pairs of sample names.

    Parameters
    ----------
    distances : iterator
        an iterator of dictionaries with keys of distance, sample, and other_sample

    Returns
    -------
    dict :
        dictionary of distances by pairs of sample names
    """
    distances_by_sample_names = defaultdict(dict)
    for record in distances:
        sample_a = record["sample"]
        sample_b = record["other_sample"]
        distance = int(record["distance"])
        distances_by_sample_names[sample_a][sample_b] = distance

    return distances_by_sample_names


def get_distance_matrix_by_sample_names(samples_a, samples_b, distances):
    """Return a matrix of distances between pairs of given sample sets.

    Parameters
    ----------
    samples_a, samples_b : list
        names of samples whose pairwise distances should populate the matrix
        with the first samples in rows and the second samples in columns
    distances : dict
        dictionary of distances by pairs of sample names

    Returns
    -------
    ndarray :
        matrix of pairwise distances between the given samples

    >>> samples_a = ["a", "b"]
    >>> samples_b = ["c", "d"]
    >>> distances = {"a": {"c": 1, "d": 2}, "b": {"c": 3, "d": 4}}
    >>> get_distance_matrix_by_sample_names(samples_a, samples_b, distances)
    array([[1., 2.],
           [3., 4.]])
    >>> get_distance_matrix_by_sample_names(samples_b, samples_a, distances)
    array([[1., 3.],
           [2., 4.]])
    """
    matrix = np.zeros((len(samples_a), len(samples_b)))
    for i, sample_a in enumerate(samples_a):
        for j, sample_b in enumerate(samples_b):
            try:
                matrix[i, j] = distances[sample_a][sample_b]
            except KeyError:
                matrix[i, j] = distances[sample_b][sample_a]

    return matrix


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Annotate weighted distances between viruses",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("--tip-attributes", required=True, help="a tab-delimited file describing tip attributes at one or more timepoints")
    parser.add_argument("--distances", required=True, help="tab-delimited file with pairwise distances between samples")
    parser.add_argument("--delta-months", required=True, type=int, help="number of months to project clade frequencies into the future")
    parser.add_argument("--output", required=True, help="tab-delimited output file of tip attributes annotated with weighted distances to the present and future")

    args = parser.parse_args()

    # Load tip attributes.
    tips = pd.read_csv(args.tip_attributes, sep="\t", parse_dates=["timepoint"])

    # Load distances.
    with open(args.distances, "r") as fh:
        reader = csv.DictReader(fh, delimiter="\t")

        # Map distances by sample names.
        distances_by_sample_names = get_distances_by_sample_names(reader)

    # Find valid timepoints for calculating distances to the future.
    timepoints = tips["timepoint"].drop_duplicates()
    last_timepoint = timepoints.max() - pd.DateOffset(months=args.delta_months)

    # Calculate weighted distance to the present and future for each sample at a
    # given timepoint.
    weighted_distances = []
    for timepoint in timepoints:
        future_timepoint = timepoint + pd.DateOffset(months=args.delta_months)
        timepoint_tips = tips[tips["timepoint"] == timepoint]
        future_timepoint_tips = tips[tips["timepoint"] == future_timepoint]

        for current_tip, current_tip_frequency in timepoint_tips.loc[:, ["strain", "frequency"]].values:
            # Calculate the distance to the present for all timepoints.
            weighted_distance_to_present = 0.0
            for other_current_tip, other_current_tip_frequency in timepoint_tips.loc[:, ["strain", "frequency"]].values:
                weighted_distance_to_present += other_current_tip_frequency * distances_by_sample_names[current_tip][other_current_tip]

            # Calculate the distance to the future only for valid timepoints (those with future information).
            if timepoint <= last_timepoint:
                weighted_distance_to_future = 0.0
                for future_tip, future_tip_frequency in future_timepoint_tips.loc[:, ["strain", "frequency"]].values:
                    weighted_distance_to_future += future_tip_frequency * distances_by_sample_names[current_tip][future_tip]
            else:
                weighted_distance_to_future = np.nan

            weighted_distances.append({
                "timepoint": timepoint,
                "strain": current_tip,
                "weighted_distance_to_present": weighted_distance_to_present,
                "weighted_distance_to_future": weighted_distance_to_future
            })

    weighted_distances = pd.DataFrame(weighted_distances)

    # Calculate the magnitude of the difference between future and present
    # distances for each sample.
    weighted_distances["log2_distance_effect"] = np.log2(
        weighted_distances["weighted_distance_to_future"] / weighted_distances["weighted_distance_to_present"]
    )

    # Annotate samples with weighted distances.
    annotated_tips = tips.merge(
        weighted_distances,
        how="left",
        on=["strain", "timepoint"]
    )

    # Save the new data frame.
    annotated_tips.to_csv(args.output, sep="\t", index=False, na_rep="N/A")
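The helpers above turn a long-format distance table into a nested lookup, and the main loop computes a frequency-weighted distance for each strain: the sum over comparison strains of each strain's frequency times its pairwise distance to the focal strain. Below is a minimal sketch of that calculation with made-up strain names, distances, and frequencies; it imports from the weighted_distances module, matching the import used by the forecasting script above.

# A minimal sketch of the distance helpers on toy data; strains and values are made up.
from weighted_distances import get_distances_by_sample_names, get_distance_matrix_by_sample_names

# Toy long-format distance records like those read from the --distances TSV.
records = [
    {"sample": "A/strain/1", "other_sample": "A/strain/2", "distance": "3"},
    {"sample": "A/strain/1", "other_sample": "A/strain/3", "distance": "5"},
    {"sample": "A/strain/2", "other_sample": "A/strain/3", "distance": "4"},
]
distances = get_distances_by_sample_names(records)

# Rectangular matrix of distances between two (possibly different) sample sets.
matrix = get_distance_matrix_by_sample_names(["A/strain/1"], ["A/strain/2", "A/strain/3"], distances)
print(matrix)  # [[3. 5.]]

# Frequency-weighted distance of strain 1 to a hypothetical future population,
# mirroring the inner loop above: sum(frequency * pairwise distance).
future_frequencies = {"A/strain/2": 0.75, "A/strain/3": 0.25}
weighted_distance = sum(
    frequency * distances["A/strain/1"][other]
    for other, frequency in future_frequencies.items()
)
print(weighted_distance)  # 3 * 0.75 + 5 * 0.25 = 3.5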