RINDTI: Simplifying Drug-Target Interaction Prediction with Protein Residue Interaction Networks
This repository aims to simplify drug-target interaction (DTI) prediction based on protein residue interaction networks (RINs).
Overview
The workflow takes a simple collection of inputs - protein structures and drug interaction data - and turns them into a fully functional GNN model.
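Once the environment from the Installation section below is set up, the pipeline is driven by Snakemake. The following is a minimal sketch of an invocation, assuming the repository's standard Snakemake layout; the exact targets and configuration options are defined in the workflow itself and may differ:

```bash
# Activate the environment created during installation, then run the workflow.
conda activate rindti

# Dry-run first to see which rules would be executed.
snakemake -n

# Run the full preprocessing pipeline on 4 cores.
snakemake --cores 4
```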
Installation
- Clone the repository: git clone https://github.com/ilsenatorov/rindti
- Change into the root directory: cd rindti
- (Optional) Install mamba: conda install -n base -c conda-forge mamba
- Create the conda environment (might take some time): mamba env create -f workflow/envs/main.yaml
- Activate the environment: conda activate rindti
- Test the installation: pytest
Documentation
Check out the documentation for more information.
Contributing
If you would like to contribute to the repository, please check out the contributing guide.
Code Snippets
```
script:
    "../scripts/parse_dataset.py"
```
```
script:
    "../scripts/split_data.py"
```
```
script:
    "../scripts/prepare_all.py"
```
```
script:
    "../scripts/prepare_drugs.py"
```
```
shell:
    """
    rinerator {input.pdb} {params.dir}/{wildcards.prot} > {log} 2>&1
    """
```
```
script:
    "../scripts/parse_rinerator.py"
```
```
script:
    "../scripts/distance_based.py"
```
```
script:
    "../scripts/prot_esm.py"
```
```
script:
    "../scripts/pretrain_prot_data.py"
```
```
script:
    "../scripts/create_pymol_scripts.py"
```
```
shell:
    "pymol -k -y -c {input.script} > {log} 2>&1"
```
```python
import os
import resource


def create_script(protein: str, inp: str, params: dict):
    """
    Create pymol parsing script for a protein according to the params.
    """
    resources = params.resources
    results = params.results
    fmt_keywords = {"protein": protein, "resources": resources, "results": results}
    script = [
        "import psico.fullinit",
        "from glob import glob",
        'cmd.load("{inp}")',
    ]
    if params.method == "plddt":
        script.append('cmd.select("result", "b > {threshold}")')
        fmt_keywords["threshold"] = params.other_params[params.method]["threshold"]
    else:  # template-based
        script += [
            'lst = glob("{resources}/templates/*.pdb")',
            'templates = [x.split("/")[-1].split(".")[0] for x in lst]',
            "for i in lst:cmd.load(i)",
            'scores = {{x : cmd.tmalign("{protein}", x) for x in templates}}',
            "max_score = max(scores, key=scores.get)",
            'cmd.extra_fit("name CA", max_score, "tmalign")',
        ]
        fmt_keywords["radius"] = params.other_params[params.method]["radius"]
        if params.method == "bsite":
            script.append('cmd.select("result", "br. {protein} within {radius} of organic")')
        elif params.method == "template":
            script.append('cmd.select("result", "br. {protein} within {radius} of not {protein} and name CA")')
    script.append('cmd.save("{parsed_structs_dir}/{protein}.pdb", "result")')
    fmt_keywords["parsed_structs_dir"] = params.parsed_structs_dir
    fmt_keywords["structs"] = params.method
    fmt_keywords["inp"] = inp
    return "\n".join(script).format(**fmt_keywords)


if __name__ == "__main__":
    for inp, out in zip(snakemake.input, snakemake.output):
        protein = os.path.basename(out).split(".")[0]
        with open(out, "w") as file:
            file.write(create_script(protein, inp, snakemake.params))
```
```python
import torch
from encd import encd
from utils import onehot_encode

node_encoding = encd["prot"]["node"]


def encode_residue(residue: str, node_feats: str):
    """Encode a residue"""
    residue = residue.lower()
    if node_feats == "label":
        if residue not in node_encoding:
            return node_encoding["unk"]
        return node_encoding[residue] + 1
    elif node_feats == "onehot":
        return onehot_encode(node_encoding[residue], len(node_encoding))
    else:
        raise ValueError("Unknown node_feats type!")


class Residue:
    """Residue class"""

    def __init__(self, line: str) -> None:
        self.name = line[17:20].strip()
        self.num = int(line[22:26].strip())
        self.chainID = line[21].strip()
        self.x = float(line[30:38].strip())
        self.y = float(line[38:46].strip())
        self.z = float(line[46:54].strip())


class Structure:
    """Structure class"""

    def __init__(self, filename: str, node_feats: str) -> None:
        self.residues = {}
        self.parse_file(filename)
        self.node_feats = node_feats

    def parse_file(self, filename: str) -> None:
        """Parse PDB file"""
        for line in open(filename, "r"):
            if line.startswith("ATOM") and line[12:16].strip() == "CA":
                res = Residue(line)
                self.residues[res.num] = res

    def get_coords(self) -> torch.Tensor:
        """Get coordinates of all atoms"""
        coords = [[res.x, res.y, res.z] for res in self.residues.values()]
        return torch.tensor(coords)

    def get_nodes(self) -> torch.Tensor:
        """Get features of all nodes of a graph"""
        return torch.tensor([encode_residue(res.name, self.node_feats) for res in self.residues.values()])

    def get_edges(self, threshold: float) -> torch.Tensor:
        """Get edges of a graph using threshold as a cutoff"""
        coords = self.get_coords()
        dist = torch.cdist(coords, coords)
        edges = torch.where(dist < threshold)
        edges = torch.cat([arr.view(-1, 1) for arr in edges], axis=1)
        edges = edges[edges[:, 0] != edges[:, 1]]  # remove self-loops
        return edges.t()

    def get_graph(self, threshold: float) -> dict:
        """Get a graph using threshold as a cutoff"""
        nodes = self.get_nodes()
        edges = self.get_edges(threshold)
        return dict(x=nodes, edge_index=edges)


if __name__ == "__main__":
    import pickle

    import pandas as pd
    from joblib import Parallel, delayed
    from tqdm import tqdm

    if "snakemake" in globals():
        all_structures = snakemake.input.pdbs
        threshold = snakemake.params.threshold

        def get_graph(filename: str) -> dict:
            """Single function to be run in parallel."""
            return Structure(filename, snakemake.params.node_feats).get_graph(threshold)

        data = Parallel(n_jobs=snakemake.threads)(delayed(get_graph)(i) for i in tqdm(all_structures))
        df = pd.DataFrame(pd.Series(data, name="data"))
        df["filename"] = all_structures
        df["ID"] = df["filename"].apply(lambda x: x.split("/")[-1].split(".")[0])
        df.set_index("ID", inplace=True)
        df.drop("filename", axis=1, inplace=True)
        df.to_pickle(snakemake.output.pickle)
    else:
        import os
        import os.path as osp

        from jsonargparse import CLI

        def run(pdb_dir: str, output: str, threads: int = 1, threshold: float = 5, node_feats: str = "label"):
            """Run the pipeline"""

            def get_graph(filename: str) -> dict:
                """Calculate a single graph from a file"""
                return Structure(filename, node_feats).get_graph(threshold)

            pdbs = [osp.join(pdb_dir, x) for x in os.listdir(pdb_dir)]
            data = Parallel(n_jobs=threads)(delayed(get_graph)(i) for i in tqdm(pdbs))
            df = pd.DataFrame(pd.Series(data, name="data"))
            df["filename"] = pdbs
            df["ID"] = df["filename"].apply(lambda x: x.split("/")[-1].split(".")[0])
            df.set_index("ID", inplace=True)
            df.drop("filename", axis=1, inplace=True)
            # Pickle the per-protein graph DataFrame, mirroring the snakemake branch above.
            df.to_pickle(output)

        cli = CLI(run)
```
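Outside of Snakemake, the jsonargparse CLI(run) entry point at the bottom exposes the parameters of run() as command-line options. A hypothetical invocation is shown below; the paths are placeholders, and the exact flag spelling depends on how jsonargparse renders the function signature:

```bash
# Build distance-based residue graphs for all PDB files in a directory (placeholder paths).
python workflow/scripts/distance_based.py \
    --pdb_dir resources/structures \
    --output results/prot_data.pkl \
    --threads 4 \
    --threshold 5 \
    --node_feats label
```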
```python
import numpy as np
import pandas as pd


def posneg_filter(inter: pd.DataFrame) -> pd.DataFrame:
    """Only keep drugs that have at least 1 positive and negative interaction"""
    pos = inter[inter["Y"] == 1]["Drug_ID"].unique()
    neg = inter[inter["Y"] == 0]["Drug_ID"].unique()
    both = set(pos).intersection(set(neg))
    inter = inter[inter["Drug_ID"].isin(both)]
    return inter


def sample(inter: pd.DataFrame, how: str = "under") -> pd.DataFrame:
    """Sample the interactions dataset

    Args:
        inter (pd.DataFrame): whole data, has to be binary class
        how (str, optional): over or undersample. Oversample adds fake negatives,
            undersample removes extra positives. Defaults to "under".
    """
    if how == "none":
        return inter
    total = []
    pos = inter[inter["Y"] == 1]
    neg = inter[inter["Y"] == 0]
    for prot in inter["Target_ID"].unique():
        possample = pos[pos["Target_ID"] == prot]
        negsample = neg[neg["Target_ID"] == prot]
        poscount = possample.shape[0]
        negcount = negsample.shape[0]
        if poscount == 0:
            continue
        if poscount >= negcount:
            if how == "under":
                total.append(possample.sample(negcount))
                total.append(negsample)
            elif how == "over":
                total.append(possample)
                total.append(negsample)
                subsample = inter[inter["Target_ID"] != prot].sample(poscount - negcount)
                subsample["Target_ID"] = prot
                subsample["Y"] = 0
                total.append(subsample)
            else:
                raise ValueError("Unknown sampling method!")
        else:
            total.append(possample)
            total.append(negsample.sample(poscount))
    return pd.concat(total)


if __name__ == "__main__":
    from pytorch_lightning import seed_everything

    seed_everything(snakemake.config["seed"])

    inter = pd.read_csv(snakemake.input.inter, sep="\t")
    config = snakemake.config["parse_dataset"]

    # If duplicates, take median of entries
    inter = inter.groupby(["Drug_ID", "Target_ID"]).agg("median").reset_index()

    if config["task"] == "class":
        inter["Y"] = inter["Y"].apply(lambda x: int(x < config["threshold"]))
    elif config["task"] == "reg":
        if config["log"]:
            inter["Y"] = inter["Y"].apply(np.log10)
    else:
        raise ValueError("Unknown task!")

    if config["filtering"] != "all" and config["sampling"] != "none" and config["task"] == "reg":
        raise ValueError(
            "Can't use filtering {filter} with task {task}!".format(filter=config["filtering"], task=config["task"])
        )

    if config["filtering"] == "posneg":
        inter = posneg_filter(inter)
    elif config["filtering"] != "all":
        raise ValueError("No such type of filtering!")

    inter = sample(inter, how=config["sampling"])
    inter.to_csv(snakemake.output.inter, index=False, sep="\t")
```
```python
import os
from typing import Tuple

import numpy as np
import pandas as pd
import torch
from encd import encd
from utils import onehot_encode


class ProteinEncoder:
    """Performs all the encoding steps for a single sif file."""

    def __init__(self, node_feats: str, edge_feats: str):
        self.node_feats = node_feats
        self.edge_feats = edge_feats

    def encode_residue(self, residue: str) -> np.array:
        """Fully encode residue - one-hot and node_feats

        Args:
            residue (str): One-letter residue name

        Returns:
            np.array: Concatenated node_feats and one-hot encoding of residue name
        """
        residue = residue.lower()
        if self.node_feats == "label":
            return encd["prot"]["node"][residue] + 1
        elif self.node_feats == "onehot":
            return onehot_encode(encd["prot"]["node"][residue], len(encd["prot"]["node"]))
        else:
            raise ValueError("Unknown node_feats type!")

    def parse_sif(self, filename: str) -> Tuple[pd.DataFrame, pd.DataFrame]:
        """Parse a single sif file

        Args:
            filename (str): SIF file location

        Returns:
            Tuple[DataFrame, DataFrame]: nodes, edges DataFrames
        """
        nodes = []
        edges = []
        if not os.path.exists(filename):
            return None, None
        with open(filename, "r") as file:
            for line in file:
                line = line.strip()
                splitline = line.split()
                if len(splitline) != 3:
                    continue
                node1, edgetype, node2 = splitline
                node1split = node1.split(":")
                node2split = node2.split(":")
                if len(node1split) != 4:
                    continue
                if len(node2split) != 4:
                    continue
                chain1, resn1, x1, resaa1 = node1split
                chain2, resn2, x2, resaa2 = node2split
                if x1 != "_" or x2 != "_":
                    continue
                if resaa1.lower() not in encd["prot"]["node"] or resaa2.lower() not in encd["prot"]["node"]:
                    continue
                resn1 = int(resn1)
                resn2 = int(resn2)
                if resn1 == resn2:
                    continue
                edgesplit = edgetype.split(":")
                if len(edgesplit) != 2:
                    continue
                node1 = {"chain": chain1, "resn": resn1, "resaa": resaa1}
                node2 = {"chain": chain2, "resn": resn2, "resaa": resaa2}
                edgetype, _ = edgesplit
                edge1 = {
                    "resn1": resn1,
                    "resn2": resn2,
                    "type": edgetype,
                }
                edge2 = {
                    "resn1": resn2,
                    "resn2": resn1,
                    "type": edgetype,
                }
                nodes.append(node1)
                nodes.append(node2)
                edges.append(edge1)
                edges.append(edge2)
        nodes = pd.DataFrame(nodes).drop_duplicates()
        try:
            nodes = nodes.sort_values("resn").reset_index(drop=True).reset_index().set_index("resn")
        except Exception as e:
            print(nodes)
            print(filename)
            print(e)
            return None, None
        for node in nodes.index:
            if (node - 1) in nodes.index:
                edges.append({"resn1": node, "resn2": node - 1, "type": "pept"})
                edges.append({"resn2": node, "resn1": node - 1, "type": "pept"})
        edges = pd.DataFrame(edges).drop_duplicates()
        node_idx = nodes["index"].to_dict()
        edges["node1"] = edges["resn1"].apply(lambda x: node_idx[x])
        edges["node2"] = edges["resn2"].apply(lambda x: node_idx[x])
        return nodes, edges

    def encode_nodes(self, nodes: pd.DataFrame) -> torch.Tensor:
        """Given dataframe of nodes create node node_feats

        Args:
            nodes (pd.DataFrame): nodes dataframe from parse_sif

        Returns:
            torch.Tensor: Tensor of node node_feats [n_nodes, *]
        """
        nodes.drop_duplicates(inplace=True)
        node_attr = [self.encode_residue(x) for x in nodes["resaa"]]
        node_attr = np.asarray(node_attr)
        if self.node_feats == "label":
            node_attr = torch.tensor(node_attr, dtype=torch.long)
        else:
            node_attr = torch.tensor(node_attr, dtype=torch.float32)
        return node_attr

    def encode_edges(self, edges: pd.DataFrame) -> Tuple[torch.Tensor, torch.Tensor]:
        """Given dataframe of edges, create edge index and edge node_feats

        Args:
            edges (pd.DataFrame): edges dataframe from parse_sif

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: edge index [2, n_edges], edge attributes [n_edges, *]
        """
        if self.edge_feats == "none":
            edges.drop("type", axis=1, inplace=True)
            edges.drop_duplicates(inplace=True)
        edge_index = edges[["node1", "node2"]].astype(int).values
        edge_index = torch.tensor(edge_index, dtype=torch.long)
        edge_index = edge_index.t().contiguous()
        if self.edge_feats == "none":
            return edge_index, None
        edge_feats = edges["type"].apply(lambda x: encd["prot"]["edge"][x])
        if self.edge_feats == "label":
            edge_feats = torch.tensor(edge_feats, dtype=torch.long)
            return edge_index, edge_feats
        elif self.edge_feats == "onehot":
            edge_feats = edge_feats.apply(onehot_encode, count=len(encd["prot"]["edge"]))
            edge_feats = torch.tensor(edge_feats, dtype=torch.float)
            return edge_index, edge_feats

    def __call__(self, protein_sif: str) -> dict:
        """Fully process the protein

        Args:
            protein_sif (str): File location for sif file

        Returns:
            dict: standard format with x for node node_feats, edge_index for edges etc
        """
        try:
            nodes, edges = self.parse_sif(protein_sif)
            if nodes is None:
                return np.nan
            node_attr = self.encode_nodes(nodes)
            edge_index, edge_feats = self.encode_edges(edges)
            return dict(
                x=node_attr,
                edge_index=edge_index,
                edge_feats=edge_feats,
                # index_mapping=nodes["index"].to_dict(),
            )
        except Exception as e:
            print(protein_sif)
            print(e)
            return np.nan


def extract_name(protein_sif: str) -> str:
    """Extract the protein name from the sif filename"""
    return protein_sif.split("/")[-1].split("_")[0]


if __name__ == "__main__":
    if "snakemake" in globals():
        prots = pd.Series(list(snakemake.input.rins), name="sif")
        prots = pd.DataFrame(prots)
        prots["ID"] = prots["sif"].apply(extract_name)
        prots.set_index("ID", inplace=True)
        prot_encoder = ProteinEncoder(snakemake.params.node_feats, snakemake.params.edge_feats)
        prots["data"] = prots["sif"].apply(prot_encoder)
        prots.to_pickle(snakemake.output.pickle)
    else:
        import argparse

        from joblib import Parallel, delayed
        from tqdm import tqdm

        parser = argparse.ArgumentParser(description="Prepare protein data from rinerator")
        parser.add_argument("--sifs", nargs="+", required=True, help="Rinerator output folders")
        parser.add_argument("--output", required=True, help="Output pickle file")
        parser.add_argument("--node_feats", type=str, default="label")
        parser.add_argument("--edge_feats", type=str, default="none")
        parser.add_argument("--threads", type=int, default=1, help="Number of threads to use")
        args = parser.parse_args()

        prots = pd.DataFrame(pd.Series(args.sifs, name="sif"))
        prots["ID"] = prots["sif"].apply(extract_name)
        prots.set_index("ID", inplace=True)
        prot_encoder = ProteinEncoder(args.node_feats, args.edge_feats)
        data = Parallel(n_jobs=args.threads)(delayed(prot_encoder)(i) for i in tqdm(prots["sif"]))
        prots["data"] = data
        prots.to_pickle(args.output)
```
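This script can also be run standalone; its argparse options are defined explicitly in the __main__ block above, so a call could look like the following (the file locations are placeholders):

```bash
# Encode RINerator SIF output into graph dictionaries and save them as a pickle (placeholder paths).
python workflow/scripts/parse_rinerator.py \
    --sifs results/rinerator/*.sif \
    --output results/rin_prot_data.pkl \
    --node_feats label \
    --edge_feats none \
    --threads 4
```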
```python
import pickle
from typing import Iterable

import pandas as pd
from pandas.core.frame import DataFrame
from utils import get_config


def process(row: pd.Series) -> dict:
    """Process each interaction."""
    split = row["split"]
    return {
        "label": row["Y"],
        "split": split,
        "prot_id": row["Target_ID"],
        "drug_id": row["Drug_ID"],
    }


def process_df(df: DataFrame) -> Iterable[dict]:
    """Apply process() function to each row of the DataFrame"""
    return [process(row) for (_, row) in df.iterrows()]


def del_index_mapping(x: dict) -> dict:
    """Delete 'index_mapping' entry from the dict"""
    if "index_mapping" in x:
        del x["index_mapping"]
    return x


if __name__ == "__main__":
    interactions = pd.read_csv(snakemake.input.inter, sep="\t")

    with open(snakemake.input.drugs, "rb") as file:
        drugs = pickle.load(file)
    with open(snakemake.input.prots, "rb") as file:
        prots = pickle.load(file)

    interactions = interactions[interactions["Target_ID"].isin(prots.index)]
    interactions = interactions[interactions["Drug_ID"].isin(drugs.index)]
    prots = prots[prots.index.isin(interactions["Target_ID"].unique())]
    drugs = drugs[drugs.index.isin(interactions["Drug_ID"].unique())]

    prot_count = interactions["Target_ID"].value_counts()
    drug_count = interactions["Drug_ID"].value_counts()
    prots["data"] = prots.apply(lambda x: {**x["data"], "count": prot_count[x.name]}, axis=1)
    drugs["data"] = drugs.apply(lambda x: {**x["data"], "count": drug_count[x.name]}, axis=1)

    full_data = process_df(interactions)

    snakemake.config["data"] = {
        "prot": get_config(prots, "prot"),
        "drug": get_config(drugs, "drug"),
    }

    final_data = {
        "data": full_data,
        "config": snakemake.config,
        "prots": prots,
        "drugs": drugs,
    }

    with open(snakemake.output.combined_pickle, "wb") as file:
        pickle.dump(final_data, file, protocol=-1)
```
```python
import numpy as np
import pandas as pd
import torch
from encd import encd
from rdkit import Chem
from rdkit.Chem import rdmolfiles, rdmolops
from torch_geometric.utils import to_undirected
from utils import onehot_encode


class DrugEncoder:
    """Drug encoder, goes from SMILES to dictionary of torch data

    Args:
        node_feats (str): 'label' or 'onehot'
        edge_feats (str): 'label' or 'onehot'
        max_num_atoms (int, optional): filter out molecules that are too big. Defaults to 150.
    """

    def __init__(self, node_feats: str, edge_feats: str, max_num_atoms: int = 150):
        assert node_feats in {"label", "onehot", "glycan", "glycanone", "IUPAC"}
        assert edge_feats in {"label", "onehot", "none"}
        self.node_feats = node_feats
        self.edge_feats = edge_feats
        self.max_num_atoms = max_num_atoms

    def encode_node(self, atom_num, atom):
        """Encode single atom"""
        if atom_num not in encd["drug"]["node"].keys():
            atom_num = "other"
        if self.node_feats == "glycan":
            if atom_num in encd["glycan"]:
                return encd["glycan"][atom_num] + encd["chirality"][atom.GetChiralTag()]
            else:
                return encd["glycan"]["other"] + encd["chirality"][atom.GetChiralTag()]
        label = encd["drug"]["node"][atom_num]
        if self.node_feats == "onehot":
            return onehot_encode(label, len(encd["drug"]["node"]))
        return label + 1

    def encode_edge(self, edge):
        """Encode single edge"""
        label = encd["drug"]["edge"][edge]
        if self.edge_feats == "onehot":
            return onehot_encode(label, len(encd["drug"]["edge"]))
        elif self.edge_feats == "label":
            return label
        else:
            raise ValueError("This shouldn't be called for edge type none")

    def __call__(self, smiles: str) -> dict:
        """Generate drug Data from smiles

        Args:
            smiles (str): SMILES

        Returns:
            dict: dict with x, edge_index etc or np.nan for bad entries
        """
        if smiles != smiles:  # check for nans, i.e. missing smiles strings in dataset
            return np.nan
        mol = Chem.MolFromSmiles(smiles)
        if not mol:  # when rdkit fails to read a molecule it returns None
            return np.nan
        new_order = rdmolfiles.CanonicalRankAtoms(mol)
        mol = rdmolops.RenumberAtoms(mol, new_order)
        edges = []
        edge_feats = [] if self.edge_feats != "none" else None
        for bond in mol.GetBonds():
            start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            edges.append([start, end])
            btype = str(bond.GetBondType())
            # If bond type is unknown, remove molecule
            if btype not in encd["drug"]["edge"].keys():
                return np.nan
            if self.edge_feats != "none":
                edge_feats.append(self.encode_edge(btype))
        if not edges:  # If no edges (bonds) were found, remove molecule
            return np.nan
        atom_features = []
        for atom in mol.GetAtoms():
            atom_num = atom.GetAtomicNum()
            atom_features.append(self.encode_node(atom_num, atom))
        if len(atom_features) > self.max_num_atoms:
            return np.nan
        if self.node_feats == "label":
            x = torch.tensor(atom_features, dtype=torch.long)
        else:
            x = torch.tensor(atom_features, dtype=torch.float32)
        edge_index = torch.tensor(edges).t().contiguous()
        if self.edge_feats == "onehot":
            edge_feats = torch.tensor(edge_feats, dtype=torch.float32)
        elif self.edge_feats == "label":
            edge_feats = torch.tensor(edge_feats, dtype=torch.long)
        elif self.edge_feats == "none":
            edge_feats = None
        else:
            raise ValueError("Unknown edge encoding!")
        if self.edge_feats != "none":
            edge_index, edge_feats = to_undirected(edge_index, edge_feats)
        else:
            edge_index = to_undirected(edge_index)
        return dict(x=x, edge_index=edge_index, edge_feats=edge_feats)


if __name__ == "__main__":
    drug_enc = DrugEncoder(snakemake.params.node_feats, snakemake.params.edge_feats, snakemake.params.max_num_atoms)
    ligs = pd.read_csv(snakemake.input.lig, sep="\t").set_index("Drug_ID")
    ligs["data"] = ligs["Drug"].apply(drug_enc)
    ligs = ligs[ligs["data"].notna()]
    ligs.to_pickle(snakemake.output.pickle)
```
```python
import pickle

import pandas as pd
from utils import get_config

prot_table = pd.read_csv(snakemake.input.prot_table, sep="\t")
prot_data = pd.read_pickle(snakemake.input.prot_data)
prot_y = prot_table.set_index("Target_ID")["Y"].to_dict()

dims_config = get_config(prot_data, "prot")
dims_config["num_classes"] = len(prot_y)
snakemake.config["prots"]["data"] = dims_config

y_encoder = {v: k for k, v in enumerate(sorted(set(prot_y.values())))}

result = []
for k, v in prot_data["data"].items():
    v["y"] = y_encoder[prot_y[k]]
    v["id"] = k
    result.append(v)

with open(snakemake.output.pretrain_prot_data, "wb") as file:
    pickle.dump(
        {
            "data": result,
            "config": snakemake.config["prots"],
            "decoder": {v: k for k, v in y_encoder.items()},
        },
        file,
    )
```
```python
import os

import esm
import pandas as pd
import torch
from extract_esm import create_parser
from extract_esm import main as extract_main


def generate_esm_python(prot: pd.DataFrame) -> pd.DataFrame:
    """Return esms."""
    model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
    batch_converter = alphabet.get_batch_converter()
    model.eval()  # disables dropout for deterministic results
    prot.set_index("Target_ID", inplace=True)
    data = [(k, v) for k, v in prot["Target"].to_dict().items()]
    _, _, batch_tokens = batch_converter(data)
    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[33], return_contacts=True)
    token_representations = results["representations"][33]
    sequence_representations = []
    for i, (_, seq) in enumerate(data):
        sequence_representations.append(token_representations[i, 1 : len(seq) + 1].mean(0))
    data = [{"x": x} for x in sequence_representations]
    prot["data"] = data
    prot = prot.to_dict("index")
    return prot


def generate_esm_script(prot: pd.DataFrame) -> pd.DataFrame:
    """Create an ESM script for batch processing."""
    prot_ids, seqs = list(zip(*[(k, v) for k, v in prot["Target"].to_dict().items()]))
    os.makedirs("./esms", exist_ok=True)
    with open("./esms/prots.fasta", "w") as fasta:
        for prot_id, seq in zip(prot_ids, seqs):
            fasta.write(f">{prot_id}\n{seq[:1022]}\n")
    esm_parser = create_parser()
    esm_args = esm_parser.parse_args(
        ["esm1b_t33_650M_UR50S", "esms/prots.fasta", "esms/", "--repr_layers", "33", "--include", "mean"]
    )
    extract_main(esm_args)
    data = []
    for prot_id in prot_ids:
        data.append({"x": torch.load(f"./esms/{prot_id}.pt")["mean_representations"][33].unsqueeze(0)})
    # os.rmdir("./esms")
    prot["data"] = data
    # prot = prot.to_dict("index")
    return prot


if __name__ == "__main__":
    import pickle

    prots = pd.read_csv(snakemake.input.seqs, sep="\t").set_index("Target_ID")
    prots = generate_esm_script(prots)
    prots.to_pickle(snakemake.output.pickle)
```
```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split


def split_groups(
    inter: pd.DataFrame,
    col_name: str = "Target_ID",
    bin_size: int = 10,
    train_frac: float = 0.7,
    val_frac: float = 0.2,
) -> pd.DataFrame:
    """Split data by protein (cold-target)

    Tries to ensure good size of all sets by sorting the prots by number of interactions
    and performing splits within bins of 10

    Args:
        inter (pd.DataFrame): interaction DataFrame
        col_name (str): Which column to split on ('Target_ID' or 'Drug_ID' usually)
        bin_size (int, optional): Size of the bins to perform individual splits in. Defaults to 10.
        train_frac (float, optional): value from 0 to 1, how much of the data goes into train
        val_frac (float, optional): value from 0 to 1, how much of the data goes into validation

    Returns:
        pd.DataFrame: DataFrame with a new 'split' column
    """
    sorted_index = [x for x in inter[col_name].value_counts().index]
    train_prop = int(bin_size * train_frac)
    val_prop = int(bin_size * val_frac)
    train = []
    val = []
    test = []
    for i in range(0, len(sorted_index), bin_size):
        subset = sorted_index[i : i + bin_size]
        train_bin = list(np.random.choice(subset, min(len(subset), train_prop), replace=False))
        train += train_bin
        subset = [x for x in subset if x not in train_bin]
        val_bin = list(np.random.choice(subset, min(len(subset), val_prop), replace=False))
        val += val_bin
        subset = [x for x in subset if x not in val_bin]
        test += subset
    train_idx = inter[inter[col_name].isin(train)].index
    val_idx = inter[inter[col_name].isin(val)].index
    test_idx = inter[inter[col_name].isin(test)].index
    inter.loc[train_idx, "split"] = "train"
    inter.loc[val_idx, "split"] = "val"
    inter.loc[test_idx, "split"] = "test"
    return inter


def split_random(inter: pd.DataFrame, train_frac: float = 0.7, val_frac: float = 0.2) -> pd.DataFrame:
    """Split the dataset in a completely random fashion

    Args:
        inter (pd.DataFrame): interaction DataFrame
        train_frac (float, optional): value from 0 to 1, how much of the data goes into train
        val_frac (float, optional): value from 0 to 1, how much of the data goes into validation

    Returns:
        pd.DataFrame: DataFrame with a new 'split' column
    """
    train, valtest = train_test_split(inter, train_size=train_frac)
    val, test = train_test_split(valtest, train_size=val_frac)
    train.loc[:, "split"] = "train"
    val.loc[:, "split"] = "val"
    test.loc[:, "split"] = "test"
    inter = pd.concat([train, val, test])
    return inter


if __name__ == "__main__":
    from pytorch_lightning import seed_everything

    seed_everything(snakemake.config["seed"])

    inter = pd.read_csv(snakemake.input.inter, sep="\t")
    fracs = {"train_frac": snakemake.params.train, "val_frac": snakemake.params.val}
    if snakemake.params.method == "target":
        inter = split_groups(inter, col_name="Target_ID", **fracs)
    elif snakemake.params.method == "drug":
        inter = split_groups(inter, col_name="Drug_ID", **fracs)
    elif snakemake.params.method == "random":
        inter = split_random(inter)
    else:
        raise NotImplementedError("Unknown split type!")
    inter.to_csv(snakemake.output.split_data, sep="\t")
```