Analysis pipeline for PEANUT performance and sensitivity

The following Snakemake pipeline performs the complete performance and sensitivity analysis of PEANUT and its competitors, as presented in our publication. The corresponding Snakefile can be downloaded here.

# vim: set syntax=python expandtab!:
import csv
from functools import partial
from collections import defaultdict, Counter
import matplotlib
matplotlib.use("agg")
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid.axislines import Subplot
import pysam


__author__ = "Johannes Köster"
__license__ = "MIT"


"""
PEANUT analysis workflow.

The workflow expects the Homo sapiens reference genome as "ref/genome.fa",
together with the BWA and Bowtie 2 indexes and the CUSHAW2 (prefix
genome.cushaw2), CUSHAW3 (prefix genome.cushaw3) and nvBowtie (prefix
genome.nvbwt) indexes in the same directory. This can be achieved via symlinks.
For unbiased performance estimates, ensure that all files are accessible
from the same local hard disk and do not reside on a network share.
Further, the workflow itself should not be executed multithreaded, i.e. jobs
should not run concurrently. The individual benchmark rules already use
multiple threads, and concurrent jobs would distort the measured run times.
"""


############## definitions ################

# occupancy
# Candidate block sizes 32, 64, ..., 576.
# NOTE(review): not referenced in the visible workflow; extract_occupancy uses
# its own range(32, 513, 32) -- confirm which range is intended.
BLOCKSIZES = list(range(32, 577, 32))

# sensitivity
# PEANUT percent-identity thresholds (60, 65, ..., 100) for the Rabema runs.
RABEMA_SCORES = list(range(60, 101, 5))
# Maximum error rates in percent for the gold standards: 0-5, 10, 15, 20.
RABEMA_ERRORS = list(range(6)) + list(range(10, 21, 5)) #+ list(range(30, 51, 10))

SENSITIVITY_TESTS = expand("rabema/evaluation/sim.{readlength}.{score}.{error}.txt", readlength=[100], score=RABEMA_SCORES, error=RABEMA_ERRORS)

# performance
# PEANUT variants (keys of PEANUT_TYPE2PARAMS) benchmarked for run time.
PEANUT_TESTS = expand("peanut.align.{type}", type="strata all".split())
# PEANUT variants without XA tags (--no-xa) used for the recall/precision evaluation.
PEANUT_TESTS_RECALL = expand("peanut.align.{type}", type="strata_noxa all_noxa".split())
PERFORMANCE_MAPPERS = "bowtie2.best cushaw2-gpu bowtie2.all razers3 ngm cushaw3 bwamem nvbowtie mrfast".split()
# bowtie2.all does not terminate due to swapping
PERFORMANCE_MAPPERS.remove("bowtie2.all")
# nvbowtie is optimized for Tesla GPUs with larger memory than our test system and fails with bad_alloc
PERFORMANCE_MAPPERS.remove("nvbowtie")
# ngm exits with an error for the large dataset, razers3 badalloc, mrfast crashes
# NOTE(review): only ngm is blacklisted; razers3 on the large dataset is run
# chunk-wise by performance_razers3_partitioned instead.
PERFORMANCE_MAPPERS_LARGE_BLACKLIST = "ngm".split()
# peanut all not needed, razers3 throwing error (TODO find better parameters for razers3)
# NOTE(review): not referenced in the visible workflow.
PERFORMANCE_VS_RECALL_BLACKLIST = "razers3 peanut.align.all_noxa".split()


# Dataset name -> gzipped FASTQ file(s); two files (_1/_2) denote paired-end reads.
PERFORMANCE_FASTQ = {
    "ERR281333.5000000": expand("sampled/ERR281333_{read}.5000000.fastq.gz", read="1 2".split()),
    "ERR281333_1.5000000": ["sampled/ERR281333_1.5000000.fastq.gz"],
    "simulated.5000000": ["sampled/simulated.5000000.fastq.gz"],
    "simulated.1000": ["sampled/simulated.1000.fastq.gz"],
    "ERR091787.25000000": expand("sampled/ERR091787_{read}.25000000.fastq.gz", read="1 2".split())
}

def get_performance_fastq(wildcards, gzip=True):
    """Return the FASTQ file(s) of the dataset named by ``wildcards.dataset``.

    Used as a Snakemake input function by the performance rules.

    :param wildcards: rule wildcards; only ``wildcards.dataset`` is read and
        must be a key of ``PERFORMANCE_FASTQ`` (otherwise ``KeyError``).
    :param gzip: if False, return the paths of the uncompressed files (as
        produced by ``unzip_fastq``) for mappers that cannot read gzip input.
    :return: list with one (single-end) or two (paired-end) paths.
    """
    fastqs = PERFORMANCE_FASTQ[wildcards.dataset]
    if not gzip:
        # Strip only a trailing ".gz" so that paths without that suffix are
        # returned unchanged; this also avoids depending on an ``os`` module
        # that is never imported in this file.
        fastqs = [f[:-len(".gz")] if f.endswith(".gz") else f for f in fastqs]
    return fastqs

# Rabema recall/precision reports of every evaluated mapper on the 5M simulated reads (run 0).
PERFORMANCE_RECALL = expand(
    "performance/evaluation/simulated.5000000.{mapper}.{run}.{benchmark_type}.rabema_report_tsv".split(),
    mapper=PERFORMANCE_MAPPERS + PEANUT_TESTS_RECALL,
    benchmark_type="recall precision".split(),
    run=0)
# Run time measurements on the regular-sized datasets.
PERFORMANCE_RUNTIME = expand("performance/{dataset}.{mapper}.{run}.time".split(),
    dataset="ERR281333_1.5000000 simulated.5000000 ERR281333.5000000".split(),
    mapper=PERFORMANCE_MAPPERS + PEANUT_TESTS,
    run=0)
# Run time measurements on the large paired-end dataset, without blacklisted mappers.
PERFORMANCE_RUNTIME_LARGE = expand("performance/{dataset}.{mapper}.{run}.time".split(),
    dataset="ERR091787.25000000".split(),
    mapper=[mapper for mapper in PERFORMANCE_MAPPERS if mapper not in PERFORMANCE_MAPPERS_LARGE_BLACKLIST] + PEANUT_TESTS,
    run=0)
# PEANUT-only subset of PERFORMANCE_RUNTIME.
PERFORMANCE_PEANUT = expand("performance/{dataset}.{mapper}.{run}.time".split(),
    dataset="ERR281333_1.5000000 simulated.5000000 ERR281333.5000000".split(),
    mapper=PEANUT_TESTS,
    run=0)
# One PEANUT run time distribution plot per dataset.
RUN_TIME_DIST = expand(
    "plots/{dataset}.run_time_dist.pdf",
    dataset="ERR281333_1.5000000 simulated.5000000 ERR281333.5000000 ERR091787.25000000".split()
)


# Make every shell pipeline fail if any command in it fails
# (sample_fastq explicitly disables this for its zcat | head pipe).
shell.prefix("set -o pipefail; ")

##################### targets #####################


rule all:
    # Default target: all plots, recall/precision evaluations, run time
    # measurements (regular and large datasets) and the Rabema sensitivity tests.
    input:
        "plots/occupancy.pdf",
        "plots/benchmark_best-mappers_recall.pdf",
        "plots/benchmark_best-mappers_precision.pdf",
        "plots/benchmark_all-mappers_all.pdf",
        "plots/index_size.pdf",
        "plots/mapq_fpr.pdf",
        RUN_TIME_DIST,
        PERFORMANCE_RECALL,
        PERFORMANCE_RUNTIME,
        PERFORMANCE_RUNTIME_LARGE,
        SENSITIVITY_TESTS


rule all_performance_peanut:
    # Only the PEANUT run time benchmarks on the regular-sized datasets.
    input:
        PERFORMANCE_PEANUT


rule all_performance:
    # Run time benchmarks of all mappers on the regular-sized datasets.
    input:
        PERFORMANCE_RUNTIME


rule all_run_time_dist:
    # PEANUT run time distribution plots for all datasets.
    input:
        RUN_TIME_DIST


##################### general #####################


rule index_repeatcount:
    # Build the PEANUT index of the reference with the given minimum repeat
    # count and write index statistics.
    # {output} expands to both output paths, so the command reads
    # "--stats <csv> <hdf5>": the CSV is the --stats argument and the HDF5 index
    # is presumably taken as the positional output -- confirm with `peanut index --help`.
    input:
        "ref/genome.fa"
    output:
        "stats/index/genome.{minrepeat}.csv", "index/genome.{minrepeat}.hdf5"
    resources: gpu=1
    shell:
        "peanut index {input} --min-repeat-count {wildcards.minrepeat} --stats {output}"


rule index_default:
    # Build a PEANUT index with default parameters for any .fasta file
    # (e.g. rabema/data/saccharomyces/genome.fasta needed by rabema_peanut).
    input:
        "{prefix}.fasta"
    output:
        "{prefix}.hdf5"
    shell:
        "peanut index {input} {output}"


rule sample_fastq:
    # Keep the first {reads} reads (4 FASTQ lines each) of a gzipped FASTQ file.
    # pipefail is disabled because head closes the pipe early, which makes zcat
    # exit with a SIGPIPE error.
    input: "reads/{dataset}.fastq.gz"
    output: "sampled/{dataset}.{reads}.fastq.gz"
    params:
        lines=lambda wildcards: str(4 * int(wildcards.reads))
    shell:
        "set +o pipefail; zcat {input} | head -n {params.lines} | gzip > {output}"


rule sam_to_bam:
    # Convert SAM to BAM.
    input:  "{prefix}.sam"
    output: "{prefix}.bam"
    shell:  "samtools view -Sb {input} > {output}"


rule sort_bam:
    # Sort a BAM file by coordinate ("sorted") or by read name ("namesorted", -n).
    # Uses the legacy samtools call "samtools sort <in.bam> <out.prefix>".
    # NOTE(review): newer samtools releases expect "-o <file>" instead -- confirm
    # against the installed version.
    input:  "{prefix}.bam"
    output: "{prefix}.{sorttype,(sorted|namesorted)}.bam"
    params: flags=lambda wildcards: "-n" if wildcards.sorttype == "namesorted" else ""
    shell:  "samtools sort {params.flags} {input} {wildcards.prefix}.{wildcards.sorttype}"


rule index_bam:
    # Index a coordinate-sorted BAM file (needed for region queries via fetch()).
    input:
        "{prefix}.sorted.bam"
    output:
        "{prefix}.sorted.bam.bai"
    shell:
        "samtools index {input}"


rule unzip_fastq:
    # Decompress a gzipped FASTQ file, keeping the original.
    input:  "{prefix}.fastq.gz"
    output: "{prefix}.fastq"
    shell:
        "gzip -d -c {input} > {output}"


##################### occupancy ##################

# PEANUT kernels whose occupancy is extracted from the profiler logs.
ALL_KERNELS = set("create_queries_index create_queries_occ_count create_queries_occ popcount_index filter_reference create_candidates validate_hits".split())
# Subset of kernels shown in the occupancy plot, with their plot labels.
REPRESENTATIVE_KERNELS = {"create_queries_index": "index construction", "filter_reference": "filtration", "validate_hits": "validation"}
# Environment variables enabling the compute profiler with CSV output
# (presumably NVIDIA's legacy command line profiler -- confirm).
PROFILE = "COMPUTE_PROFILE=1 COMPUTE_PROFILE_CSV=1 COMPUTE_PROFILE_CONFIG=profile.config"


rule profile_peanut:
    # Map the sampled reads once, with both filtration and validation block size
    # set to {blocksize}, recording the profiler log; alignments are discarded.
    input:
        "sampled/{reads}.fastq.gz", "index/genome.2500.hdf5"
    output:
        profile="profile/{reads}.{blocksize}.profile.txt"
    resources: gpu=1
    threads: 8
    shell:
        "{PROFILE} CL_LOG_ERRORS=stdout COMPUTE_PROFILE_LOG={output.profile} "
        "peanut map --threads {threads} "
        "--blocksize-filtration {wildcards.blocksize} "
        "--blocksize-validation {wildcards.blocksize} "
        "{input} > /dev/null"


rule extract_occupancy:
    # Collect the occupancy of each PEANUT kernel per block size from the
    # profiler CSV logs and write them as a tab-separated table
    # (one row per kernel, one column per block size, "-" where missing).
    input:
        expand("profile/1000000.{blocksize}.profile.txt", blocksize=range(32, 513, 32))
    output:
        csv="profile/occupancy.csv"
    run:
        # NOTE(review): the column layout (kernel name in column 0, block
        # dimensions in columns 5-7, occupancy in column 10) is inferred from
        # the indices used here -- confirm against the profiler's CSV format.
        occupancy = defaultdict(dict)
        for profile_path in input:
            with open(profile_path) as profile:
                for row in csv.reader(profile, delimiter=","):
                    # skip comment/short lines that cannot hold occupancy data
                    if len(row) <= 10:
                        continue
                    method = row[0]
                    if method not in ALL_KERNELS:
                        continue
                    try:
                        blocksize = int(row[5]) * int(row[6]) * int(row[7])
                        occ = float(row[10])
                    except ValueError:
                        # non-numeric rows, e.g. the CSV header
                        continue
                    occupancy[method][blocksize] = occ

        with open(output.csv, "w") as out:
            blocksizes = sorted({b for by_blocksize in occupancy.values() for b in by_blocksize})
            print("method", *blocksizes, sep="\t", file=out)
            for method, by_blocksize in occupancy.items():
                print(method, *[by_blocksize.get(b, "-") for b in blocksizes], sep="\t", file=out)


rule plot_occupancy:
    # Plot occupancy over block size for the representative kernels.
    input:
        csv="profile/occupancy.csv"
    output:
        pdf="plots/occupancy.pdf"
    run:
        figure()
        # one line style per representative kernel (3 kernels, 3 styles)
        styles = "- -- :".split()
        with open(input.csv) as f:
            reader = csv.reader(f, delimiter="\t")
            header = next(reader)
            # header row: "method" followed by the block sizes
            x = [int(b) for b in header[1:]]
            i = 0
            for l in reader:
                kernel = l[0]
                occupancy = l[1:]
                if kernel not in REPRESENTATIVE_KERNELS:
                    continue
                # drop block sizes without a measurement ("-")
                x_ = [b for b, occ in zip(x, occupancy) if occ != "-"]
                y = [float(occ) for occ in occupancy if occ != "-"]
                plt.plot(x_, y, styles[i], label=REPRESENTATIVE_KERNELS[kernel])
                i += 1
        plt.xlabel("block size")
        plt.ylabel("occupancy")
        plt.legend(loc="lower right", handlelength=2.5)
        plt.ylim([0, 1])
        plt.savefig(output.pdf)


################## sensitivity ##################


rule rabema_download_data:
    # Download and unpack the Rabema example data (Saccharomyces genome and
    # 454 reads) into the rabema/ directory.
    output: 
        "rabema/data/saccharomyces/genome.fasta",
        "rabema/data/saccharomyces/reads_454/SRR000853.10k.fastq"
    params:
        archive="rabema-data.tar.bz2"
    shell:  
        """
        wget -O {params.archive} http://www.seqan.de/wp-content/plugins/download-monitor/download.php?id=27
        tar -xf {params.archive}
        mv rabema-data rabema
        rm {params.archive}
        """


rule rabema_create_fastq_simulated:
    # Simulate 10000 Illumina reads of length {readlength} (without N bases)
    # from the yeast genome using Mason; the SAM file Mason writes alongside
    # the FASTQ is removed.
    input:
        "rabema/data/saccharomyces/genome.fasta"
    output:
        "rabema/sim.{readlength,\d+}.fastq"
    shell:
        """
        mason illumina -N 10000 --no-N -n {wildcards.readlength} -sq -o {output} {input}
        rm {output}.sam
        """


rule razers3_create_sam:
    # Compute all alignments of the simulated reads with at most {error}
    # percent errors (percent identity 100 - {error}) using RazerS 3 with a
    # 100% recognition rate.
    input:
        "rabema/data/saccharomyces/genome.fasta",
        "rabema/sim.{readlength}.fastq"
    output:
        "rabema/gold.{readlength}.{error,\d+}.pre.sam"
    params:
        identity=lambda wildcards: str(100 - int(wildcards.error))
    threads: 6
    shell:
        "razers3 --dont-shrink-alignments --verbose "
        "--thread-count {threads} "
        "--recognition-rate 100 "
        "--percent-identity {params.identity} "
        "--max-hits 10000000 "
        "--output {output} "
        "{input}"


rule rabema_prepare_sam:
    # Convert the RazerS 3 output with rabema_prepare_sam for gold-standard building.
    input:
        "rabema/gold.{readlength}.{error}.pre.sam"
    output:
        "rabema/gold.{readlength,\d+}.{error,\d+}.sam"
    shell:
        "rabema_prepare_sam -i {input} -o {output}"


rule rabema_build_gold_standard:
    # Build the Rabema gold standard (GSI) for up to {error} errors (edit distance).
    input:
        ref="rabema/data/saccharomyces/genome.fasta",
        bam="rabema/gold.{readlength}.{error}.sorted.bam"
    output: "rabema/gold.{readlength}.{error}.gsi"
    shell:
        "rabema_build_gold_standard "
        "--max-error {wildcards.error} "
        "--distance-metric edit "
        "--out-gsi {output} "
        "--reference {input.ref} "
        "--in-bam {input.bam}"


rule rabema_peanut:
    # Map the simulated reads with PEANUT (semiglobal, all strata, no XA tags,
    # up to 1000 hits) at percent identity {score}, and store the result as BAM.
    input:
        ref="rabema/data/saccharomyces/genome.hdf5",
        fastq="rabema/sim.{readlength,\d+}.fastq"
    output:
        "rabema/peanut.{readlength,\d+}.{score,\d+}.bam"
    resources: gpu=1
    shell:
        "peanut map --no-xa --max-hits 1000 --semiglobal --strata all --gap-open-penalty 2 --gap-extend-penalty 1 --percent-identity {wildcards.score} {input.ref} {input.fastq} | samtools view -Sbh - > {output}"



rule debug_peanut:
    # Debugging helper: map rabema/debug.fastq against the yeast PEANUT index
    # and write the alignments as BAM to rabema/debug.bam.
    # The index is passed before the reads (as in every other `peanut map`
    # call), and the BAM is redirected into the declared output file.
    # Previously `{output}` was given to samtools as a region argument, so the
    # output file was never created.
    input:
        ref="rabema/data/saccharomyces/genome.hdf5",
        fastq="rabema/debug.fastq"
    output:
        aligned="rabema/debug.bam"
    resources: gpu=1
    shell:
        """
        peanut map --max-hits 1000 --semiglobal --gap-open-penalty 2 --gap-extend-penalty 1 --percent-identity 60 {input.ref} {input.fastq} | samtools view -Shb - > {output}
        """


rule rabema_evaluate_simulated:
    # Evaluate PEANUT's name-sorted BAM (percent identity {score}) against the
    # gold standard for up to {error} errors; the report is written to stdout,
    # which is redirected into the output file.
    input:
        ref="rabema/data/saccharomyces/genome.fasta",
        peanut="rabema/peanut.{readlength}.{score}.namesorted.bam",
        gold="rabema/gold.{readlength}.{error}.gsi"
    output:
        "rabema/evaluation/sim.{readlength}.{score}.{error}.txt"
    log:
        "logs/rabema_evaluate.{readlength}.{score}.{error}.log"
    shell:
        "rabema_evaluate --reference {input.ref} "
        "--show-missed-intervals "
        "--max-error {wildcards.error} "
        "--distance-metric edit "
        "--benchmark-category all "
        "--in-gsi {input.gold} "
        "--in-bam {input.peanut} > {output} 2> {log}"


######################### performance ######################


rule performance_dataset_real:
    # Download the ERR281333 reads (mate {read}) from the EBI SRA FTP server.
    output:
        "reads/ERR281333_{read}.fastq.gz"
    shell:
        "wget -O {output} ftp://ftp.sra.ebi.ac.uk/vol1/fastq/ERR281/ERR281333/ERR281333_{wildcards.read}.fastq.gz"


rule performance_dataset_real2:
    # download an illumina platinum genome from an african male
    output:
        "reads/ERR091787_{read}.fastq.gz"
    shell:
        "wget -O {output} ftp://ftp.sra.ebi.ac.uk/vol1/fastq/ERR091/ERR091787/ERR091787_{wildcards.read}.fastq.gz"


rule performance_dataset_simulated:
    # Simulate {count} Illumina reads of length 100 from the human reference
    # with Mason and gzip them. The SAM file Mason writes next to the FASTQ
    # is kept; it serves as the oracle gold standard
    # (performance_gold_standard_oracle).
    input:
        "ref/genome.fa"
    output:
        "sampled/simulated.{count}.fastq.gz", "sampled/simulated.{count}.fastq.sam"
    params:
        fastq="sampled/simulated.{count}.fastq"
    shell:
        """
        mason illumina -N {wildcards.count} -n 100 -sq -o {params.fastq} {input}
        gzip {params.fastq}
        """


rule performance_gold_standard_oracle:
    # Build an oracle-mode Rabema gold standard from Mason's simulation SAM.
    input:
        bam="sampled/simulated.{count}.fastq.sorted.bam",
        ref="ref/genome.fa"
    output:
        "performance/simulated.{count,\d+}.gsi"
    shell:
        "rabema_build_gold_standard --oracle-mode -o {output} -r {input.ref} -b {input.bam}"


rule performance_gold_standard_map_all:
    # Compute all alignments of the simulated reads with at least 85% identity
    # with RazerS 3 at a 100% recognition rate (input for the "all" gold standard).
    input:
        "ref/genome.fa",
        "sampled/simulated.{count}.fastq"
    output:
        "sampled/simulated.{count}.gold_all.pre.sam"
    threads: 8
    shell:
        "razers3 --dont-shrink-alignments "
        "--parallel-window-size 20000 "
        "--thread-count {threads} "
        "--recognition-rate 100 "
        "--percent-identity 85 "
        "--max-hits 1000000 "
        "--output {output} "
        "{input}"


rule performance_rabema_prepare_bam:
    # Convert the RazerS 3 output with rabema_prepare_sam.
    input:
        "sampled/simulated.{count}.gold_all.pre.sam"
    output:
        "sampled/simulated.{count}.gold_all.sam"
    shell:
        "rabema_prepare_sam -i {input} -o {output}"


rule performance_gold_standard_all:
    # Build the "all" Rabema gold standard for up to 15 errors (edit distance).
    input:
        bam="sampled/simulated.{count}.gold_all.sorted.bam",
        ref="ref/genome.fa"
    output:
        "performance/simulated.{count}.all.gsi"
    shell:
        "rabema_build_gold_standard "
        "--max-error 15 "
        "--distance-metric edit "
        "--out-gsi {output} "
        "--reference {input.ref} "
        "--in-bam {input.bam}"


rule performance_rabema:
    # Evaluate a mapper's output against the oracle gold standard. For
    # "precision" only uniquely mapped reads are considered
    # (--only-unique-reads); "recall" uses all reads.
    input:
        ref="ref/genome.fa",
        gsi="performance/simulated.{count}.gsi",
        bam="performance/simulated.{count}.{mapper}.{run}.namesorted.bam"
    output:
        "performance/evaluation/simulated.{count,[0-9]+}.{mapper}.{run,\d+}.{type,(precision|recall)}.rabema_report_tsv"
    params:
        lambda wildcards: "--only-unique-reads" if wildcards.type == "precision" else ""
    log:
        "logs/performance/rabema_evaluate.simulated.{count}.{mapper}.{run}.{type}.log"
    shell:
        "rabema_evaluate --show-missed-intervals --show-additional-hits "
        "--show-invalid-hits --oracle-mode "
        "{params} "
        "-r {input.ref} -g {input.gsi} -b {input.bam} "
        "--out-tsv {output} 2> {log}"


rule performance_rabema_all:
    # Evaluate a mapper's output against the "all" gold standard (up to 15
    # errors, edit distance) in the Rabema benchmark category {type}.
    input:
        ref="ref/genome.fa",
        gsi="performance/simulated.{count}.all.gsi",
        bam="performance/simulated.{count}.{mapper}.{run}.namesorted.bam"
    output:
        "performance/evaluation/simulated.{count,[0-9]+}.{mapper}.{run,\d+}.{type,(all|any-best|all-best)}.rabema_report_tsv"
    log:
        "logs/performance/rabema_evaluate.simulated.{count}.{mapper}.{run}.{type}.log"
    shell:
        "rabema_evaluate --show-missed-intervals --show-additional-hits "
        "--show-invalid-hits "
        "--max-error 15 "
        "--distance-metric edit "
        "--benchmark-category {wildcards.type} "
        "-r {input.ref} -g {input.gsi} -b {input.bam} "
        "--out-tsv {output} 2> {log}"


# Extra command line arguments per mapper, indexed by the {run} wildcard
# (see get_run_params); run 0 uses the mapper's default parameters.
RUN_PARAMS = {
    "peanut": ["", "--percent-identity 90", "--percent-identity 95"],
    "bwa_mem": ["", "-r 2.5", "-r 3.5"],
    "bowtie2": ["", "--very-sensitive", "--very-fast"],
    "razers3": ["", "--percent-identity 80", "--percent-identity 90"],
    "ngm": ["", "--min-identity 0.8", "--min-identity 0.95"],
    "cushaw2_gpu": ["", "-min_id 0.8", "-min_id 0.95"],
    "cushaw3": ["", "-min_id 0.8", "-min_id 0.95"]
}


def get_run_params(mapper):
    """Build a Snakemake params function for the given mapper.

    The returned callable looks up ``RUN_PARAMS[mapper]`` lazily (at job
    time) and selects the entry indexed by the integer ``{run}`` wildcard.
    """
    def lookup(wildcards):
        variants = RUN_PARAMS[mapper]
        run_index = int(wildcards.run)
        return variants[run_index]

    return lookup


# PEANUT command line options for each peanut.align.{type} variant:
# all strata vs. only the best stratum, optionally without XA tags, or
# (noalign) without alignments.
PEANUT_TYPE2PARAMS = dict(
    all="--strata all",
    strata="--strata 1",
    strata_noxa="--strata 1 --no-xa",
    all_noxa="--strata all --no-xa",
    noalign="--strata 1 --no-alignments")


rule performance_peanut:
    # Benchmark PEANUT: the mapping is repeated three times, and the report of
    # `time` for each repetition is appended to the .time file. SAM, stats and
    # log are overwritten by every repetition.
    resources: benchmark=1, gpu=1
    input:
        fastq=get_performance_fastq, index="index/genome.2500.hdf5"
    output:
        sam="performance/{dataset}.peanut.align.{type,(all|strata|strata_noxa|all_noxa|noalign)}.{run,\d+}.sam",
        stats="performance/{dataset}.peanut.align.{type}.{run,\d+}.stats",
        time="performance/{dataset}.peanut.align.{type}.{run,\d+}.time"
    params:
        type=lambda wildcards: PEANUT_TYPE2PARAMS[wildcards.type],
        run=get_run_params("peanut")
    log:
        "logs/{dataset}.peanut.align.{type}.{run}.log"
    shell:
        "rm -f {output.time}; "
        "for i in 1 2 3; do "
        "(time peanut --stats {output.stats} map --threads 8 --query-buffer 1000000 {params.type} "
        " {input.index} {input.fastq} --semiglobal {params.run} "
        " > {output.sam} 2> {log}) 2>> {output.time}; "
        "done"


rule performance_bwa_mem:
    # Benchmark BWA-MEM (three timed repetitions, times appended to .time).
    # {input} expands to the reference followed by the FASTQ file(s).
    resources: benchmark=1
    input:
        "ref/genome.fa",
        get_performance_fastq
    output:
        sam="performance/{dataset}.bwamem.{run,\d+}.sam",
        time="performance/{dataset}.bwamem.{run,\d+}.time"
    params:
        run=get_run_params("bwa_mem")
    log:
        "logs/{dataset}.bwamem.{run}.log"
    shell:
        "rm -f {output.time}; "
        "for i in 1 2 3; "
        "do (time bwa mem {params.run} -t 8 {input} > {output.sam} 2> {log}) 2>> {output.time}; "
        "done"


rule performance_razers3:
    # Benchmark RazerS 3 on the uncompressed FASTQ file(s)
    # (three timed repetitions, times appended to .time).
    resources: benchmark=1
    input:
        "ref/genome.fa",
        partial(get_performance_fastq, gzip=False)
    output:
        sam="performance/{dataset}.razers3.{run,\d+}.sam",
        time="performance/{dataset}.razers3.{run,\d+}.time"
    params:
        run=get_run_params("razers3")
    log:
        "logs/{dataset}.razers3.{run}.log"
    shell:
        "rm -f {output.time}; "
        "for i in 1 2 3; "
        "do (time razers3 {params.run} -tc 8 {input} --output {output.sam} 2> {log}) 2>> {output.time}; "
        "done"


rule performance_razers3_partitioned:
    # Benchmark RazerS 3 on the large paired-end dataset ERR091787.25000000.
    # According to the workflow's notes RazerS 3 fails with bad_alloc on the
    # full dataset, so each mate file is split into 2 chunks of
    # params.lines lines. The chunk pairs are mapped one after another, and the
    # accumulated time of each repetition is appended to the .time file.
    resources: benchmark=1
    input:
        ref="ref/genome.fa",
        fastq=expand("sampled/ERR091787_{read}.25000000.fastq", read="1 2".split())
    output:
        sam="performance/ERR091787.25000000.razers3.{run,\d+}.sam",
        time="performance/ERR091787.25000000.razers3.{run,\d+}.time"
    params:
        run=get_run_params("razers3"),
        lines=str(int(25000000 * 4 / 2)),
        args = " ".join(map("{:02d}".format, range(2)))
    log:
        "logs/ERR091787.25000000.razers3.{run}.log"
    run:
        # params.lines = 25M reads * 4 FASTQ lines / 2 chunks; params.args lists
        # the chunk suffixes ("00 01") produced by `split -a 2 -d`.
        # NOTE: every chunk writes to the same --output file, so the final SAM
        # only contains the last chunk; only the run time is of interest here.
        shell(
        """
        for f in {input.fastq}
        do
            rm -f $f.*
            split -a 2 -d -l {params.lines} $f $f.
            rename 's/$/.fastq/' $f.*
        done
        rm -f {output.time}
        rm -f {log}
        for i in 1 2 3
        do
            (time (
                for chunk in {params.args}
                do
                    echo "chunk $chunk"
                    razers3 {params.run} -tc 8 {input.ref} {input.fastq[0]}.$chunk.fastq {input.fastq[1]}.$chunk.fastq --output {output.sam} &>> {log}
                done
            )) 2>> {output.time}
        done
        """
        )

ruleorder: performance_razers3_partitioned > performance_razers3


rule performance_ngm:
    # Benchmark NextGenMap in GPU mode (three timed repetitions).
    resources: benchmark=1
    input:
        ref="ref/genome.fa",
        fastq=get_performance_fastq
    output:
        sam="performance/{dataset}.ngm.{run,\d+}.sam",
        time="performance/{dataset}.ngm.{run,\d+}.time"
    params:
        run=get_run_params("ngm")
    log:
        "logs/{dataset}.ngm.{run}.log"
    run:
        # paired-end: -1 <mate1> -2 <mate2>; single-end: -q <reads>
        if len(input.fastq) > 1:
            fastqs = "-1 {} -2 {}".format(*input.fastq)
        else:
            fastqs = "-q {}".format(input.fastq)
        shell("rm -f {output.time}; "
        "for i in 1 2 3; "
        "do (time ngm {params.run} --gpu -t 8 {fastqs} -r {input.ref} -o {output.sam} &> {log}) 2>> {output.time}; "
        "done")


rule performance_cushaw3:
    # Benchmark CUSHAW3 against the prebuilt index ref/genome.cushaw3
    # (three timed repetitions).
    resources: benchmark=1
    input:
        ref="ref/genome.fa",
        fastq=get_performance_fastq
    output:
        sam="performance/{dataset}.cushaw3.{run,\d+}.sam",
        time="performance/{dataset}.cushaw3.{run,\d+}.time"
    params:
        run=get_run_params("cushaw3")
    log:
        "logs/{dataset}.cushaw3.{run}.log"
    run:
        # -q is used for two FASTQ files (paired-end), -f for a single one
        inputparam = "-q" if len(input.fastq) > 1 else "-f"
        shell(
            "rm -f {output.time}; "
            "for i in 1 2 3; "
            "do (time cushaw3 align {params.run} -t 8 -r ref/genome.cushaw3 {inputparam} {input.fastq} -o {output.sam} &> {log}) 2>> {output.time}; "
            "done")


rule performance_cushaw2_gpu:
    # Benchmark CUSHAW2-GPU against the prebuilt index ref/genome.cushaw2
    # (three timed repetitions).
    resources: benchmark=1
    input:
        ref="ref/genome.fa",
        fastq=get_performance_fastq
    output:
        sam="performance/{dataset}.cushaw2-gpu.{run,\d+}.sam",
        time="performance/{dataset}.cushaw2-gpu.{run,\d+}.time"
    params:
        run=get_run_params("cushaw2_gpu")
    log:
        "logs/{dataset}.cushaw2-gpu.{run}.log"
    run:
        # -q is used for two FASTQ files (paired-end), -f for a single one
        inputparam = "-q" if len(input.fastq) > 1 else "-f"
        shell(
            "rm -f {output.time}; "
            "for i in 1 2 3; "
            "do (time cushaw2-gpu {params.run} -t 8 {inputparam} {input.fastq} -r ref/genome.cushaw2 -o {output.sam} &> {log}) 2>> {output.time}; "
            "done")


rule performance_bowtie2:
    # Benchmark Bowtie 2 in "best" (default reporting) or "all" (--all) mode
    # with three timed repetitions, appending the times to the .time file.
    # The log path contains {type}, like the outputs. Otherwise the best and
    # all jobs for the same dataset/run would write to the same log file,
    # which Snakemake rejects because log and output wildcards differ.
    resources: benchmark=1
    input:
        "ref/genome.1.bt2",
        fastq=get_performance_fastq
    output:
        sam="performance/{dataset}.bowtie2.{type,(best|all)}.{run,\d+}.sam",
        time="performance/{dataset}.bowtie2.{type,(best|all)}.{run,\d+}.time"
    params:
        prefix="ref/genome",
        all=lambda wildcards: "--all" if wildcards.type == "all" else "",
        run=get_run_params("bowtie2")
    log:
        "logs/{dataset}.bowtie2.{type}.{run}.log"
    run:
        # single-end: -U <reads>; paired-end: -1 <mate1> -2 <mate2>
        if len(input.fastq) == 1:
            fastq = "-U {}".format(input.fastq)
        else:
            fastq = "-1 {} -2 {}".format(*input.fastq)
        shell("rm -f {output.time}; for i in 1 2 3; do (time bowtie2 {params.run} --threads 8 {params.all} {params.prefix} {fastq} > {output.sam} 2> {log}) 2>> {output.time}; done")


rule performance_nvbowtie:
    # Benchmark nvBowtie against ref/genome.nvbwt (three timed repetitions).
    # Note: nvbowtie is removed from PERFORMANCE_MAPPERS, so this rule is not
    # requested by the default targets.
    resources: benchmark=1
    input:
        "ref/genome.fa",
        fastq=get_performance_fastq
    output:
        sam="performance/{dataset}.nvbowtie.{run,\d+}.sam",
        time="performance/{dataset}.nvbowtie.{run,\d+}.time"
    params:
        prefix="ref/genome.nvbwt"
    log:
        "logs/{dataset}.nvbowtie.{run}.log"
    shell:
        "rm -f {output.time}; "
        "for i in 1 2 3; "
        "do (time nvBowtie --file-ref {params.prefix} {input.fastq} {output.sam} &> {log}) 2>> {output.time}; "
        "done"


rule performance_mrfast:
    # Benchmark mrFAST: the uncompressed FASTQ file(s) are split into chunks
    # that are mapped by 8 parallel mrfast processes. Each of the three
    # repetitions is timed, and the per-chunk SAM files are merged afterwards.
    resources: benchmark=1
    input:
        ref="ref/genome.fa",
        fastq=partial(get_performance_fastq, gzip=False)
    output:
        sam="performance/{dataset}.mrfast.{run,\d+}.sam",
        time="performance/{dataset}.mrfast.{run,\d+}.time"
    log:
        "logs/{dataset}.mrfast.{run}.log"
    run:
        # dataset names have the form "<name>.<number of reads>"
        _, reads = wildcards.dataset.split(".")
        chunks = 1000 if int(reads) > 5000000 else 100
        # chunk suffixes 0000 ... chunks-1, as produced by `split -a 4 -d`
        args = " ".join(map("{:04d}".format, range(chunks)))
        # "{{}}" survives formatting here as "{}", the placeholder that
        # `parallel` replaces with the chunk suffix
        seq_args = "--seq {}.{{}}".format(input.fastq) if len(input.fastq) == 1 else "--pe --min 100 --max 500 --seq1 {}.{{}} --seq2 {}.{{}}".format(*input.fastq)
        # NOTE(review): split must yield exactly `chunks` files of whole FASTQ
        # records, i.e. the line count must be divisible by `chunks` and
        # lines/chunks by 4; otherwise trailing reads are not mapped or
        # records are cut. This holds for the dataset sizes used here.
        shell("""
        rm -f {output.sam}.*
        for f in {input.fastq}
        do
            flines=`wc -l $f | cut -f1 -d' '`
            lines=$(($flines / {chunks}))
            rm -f $f.*
            split -a 4 -d -l $lines $f $f.
        done

        rm -f {output.time}
        for i in 1 2 3
        do
            (time parallel -j 8 -i sh -c 'mrfast --search {input.ref} {seq_args} -o {output.sam}.{{}}' -- {args} ) 2>> {output.time}
        done
        samtools view -SH {output.sam}.0000 > {output.sam}
        for f in {output.sam}.*
        do
            # ignore errors, SAM will be incomplete but we are only interested in run time
            samtools view -S $f >> {output.sam} || true
        done
        """)


####################### PLOT_BENCHMARK_RESULTS #################################

# Mappers compared in the best-mappers plots (5M simulated reads).
BENCHMARK_BEST_MAPPERS = [
    mapper
    for mapper in ["peanut.align.strata_noxa"] + PERFORMANCE_MAPPERS
    if mapper != "razers3" and mapper != "mrfast"  # only compare best mappers here
]
# Mappers compared in the all-mappers plots (1000 simulated reads).
BENCHMARK_ALL_MAPPERS = "peanut.align.all_noxa razers3 mrfast".split()

# Plot labels, keyed by the mapper name up to the first "."
# (e.g. "bowtie2.best" -> "bowtie2").
MAPPER_LABELS = {
    "peanut": "PEANUT",
    "bwamem": "BWA-MEM",
    "ngm": "NextGenMap",
    "bowtie2": "Bowtie 2",
    "cushaw3": "CUSHAW3",
    "cushaw2-gpu": "CUSHAW2-GPU",
    "razers3": "RazerS 3",
    "mrfast": "MrFast"
}


def plot_benchmarks_input(wildcards):
    """Return the Rabema reports needed by ``plot_benchmarks``.

    "best-mappers" uses the 5M-read simulation and BENCHMARK_BEST_MAPPERS;
    any other mapper type uses the 1000-read simulation and
    BENCHMARK_ALL_MAPPERS. The order of the returned files follows the order
    of the mapper list.
    """
    pattern = "performance/evaluation/simulated.{reads}.{mapper}.0.{type}.rabema_report_tsv"
    if wildcards.mapper_type != "best-mappers":
        mappers, reads = BENCHMARK_ALL_MAPPERS, 1000
    else:
        mappers, reads = BENCHMARK_BEST_MAPPERS, 5000000
    return expand(pattern, mapper=mappers, type=wildcards.type, reads=reads)


rule plot_benchmarks:
    # Plot cumulative Rabema results (in percent) over the maximum edit
    # distance 0..15, one line per mapper.
    input:
        benchmarks=plot_benchmarks_input
    output:
        "plots/benchmark_{mapper_type}_{type}.pdf"
    run:
        # must be the same list (and order) as used by plot_benchmarks_input,
        # since mappers and report files are zipped below
        mappers = BENCHMARK_BEST_MAPPERS if wildcards.mapper_type == "best-mappers" else BENCHMARK_ALL_MAPPERS

        def parse_rabema(f):
            """Return (cumulative found rate in %, per-row percent_norm_found) of a Rabema TSV report."""
            d = np.genfromtxt(f, dtype=None, names="error_rate num_max num_found percent_found norm_max norm_found percent_norm_found".split())
            p = np.cumsum(d["norm_max"])
            tp = np.cumsum(d["norm_found"])
            rate = tp / p * 100
            return rate, d["percent_norm_found"]

        error_rate = np.arange(16)
        figure()
        styles = "- -- : k- k-- k:".split()
        for i, (mapper, benchmark_file) in enumerate(zip(mappers, input.benchmarks)):
            cumulative_results, results = parse_rabema(benchmark_file)
            # only the cumulative values are plotted
            results = cumulative_results
            # pad (or truncate) to 16 error rates, repeating the last value
            l = results.size
            results = np.resize(results, error_rate.size)
            results[l:] = results[l-1]

            plt.plot(error_rate, results, styles[i], label=MAPPER_LABELS[mapper.split(".")[0]])
        plt.xlabel("maximum edit distance")
        plt.ylabel((wildcards.type if wildcards.type != "all" else "sensitivity") + " [%]")
        plt.legend(loc="lower left", ncol=2 if wildcards.type != "all" else 1, handlelength=2.5)
        ymin = 80 if wildcards.mapper_type == "best-mappers" else 60
        plt.ylim([ymin, 100])
        plt.xlim([0,15])

        plt.savefig(output[0])



##################### MAPQ ###############################

rule extract_mapq_fpr:
    # Count, per mapping quality, how many PEANUT hits there are and how many
    # of them are false positives with respect to the simulated ground truth.
    # Writes three rows (MAPQ values, hit counts, false-positive counts).
    input:
        peanut="performance/simulated.5000000.peanut.align.all_noxa.0.namesorted.bam",
        gold="sampled/simulated.5000000.fastq.sorted.bam",
        idx="sampled/simulated.5000000.fastq.sorted.bam.bai"
    output:
        txt="performance/evaluation/mapq_fpr.txt"
    run:
        total_reads = 5000000
        FP = Counter()  # false-positive hits per MAPQ
        P = Counter()   # all mapped hits per MAPQ
        with pysam.Samfile(input.peanut, "rb") as peanut, pysam.Samfile(input.gold, "rb") as gold:
            for i, hit in enumerate(peanut):
                if not i % 1000:
                    print("processing hit", i)
                if not hit.is_unmapped:
                    ref = peanut.getrname(hit.tid)
                    # a hit is a false positive if no ground-truth alignment of
                    # the same read overlaps the hit's start position
                    is_fp = not any(
                        a.qname == hit.qname
                        for a in gold.fetch(
                            reference=ref, start=hit.pos, end=hit.pos + 1
                        )
                    )
                    if is_fp:
                        FP[hit.mapq] += 1
                    P[hit.mapq] += 1
                # NOTE(review): this caps the number of BAM records (hits, not
                # reads) and stops after record index total_reads + 1 -- confirm
                # the intended limit.
                if i > total_reads:
                    break
        mapqs = sorted(P)
        p = [P[m] for m in mapqs]
        fp = [FP[m] for m in mapqs]
        np.savetxt(output.txt, [mapqs, p, fp])

rule plot_mapq_fpr:
    # Plot the measured false positive rate per MAPQ against the rate implied
    # by the Phred-scaled mapping quality, 10^(-MAPQ/10).
    input:
        "performance/evaluation/mapq_fpr.txt"
    output:
        pdf="plots/mapq_fpr.pdf"
    run:
        # rows written by extract_mapq_fpr: MAPQ values, hit counts, FP counts
        d = np.loadtxt(input[0])
        mapqs = d[0]
        p = d[1]
        fp = d[2]

        figure()
        plt.plot(mapqs, fp / p, "-", label="measured FPR")
        plt.plot(mapqs, [10 ** (-q / 10) for q in mapqs], "--", label="expected FPR")
        plt.legend(loc="upper right", handlelength=2.5)
        plt.xlabel("mapping quality")
        plt.ylim([-0.01, 1])
        plt.savefig(output.pdf)


####################### index size ##########################

rule plot_index_size:
    # Plot the index size ratio as a function of K on a logarithmic x axis.
    output:
        "plots/index_size.pdf"
    run:
        figure()
        # NOTE(review): this epsilon curve (epsilon in [0, 1]) shares the
        # log-scaled x axis with the K curve below, and epsilon = 0 cannot be
        # drawn on a log axis; `line` is unused. Confirm whether this curve is
        # still intended to appear in the figure.
        epsilon = np.linspace(0, 1)
        ratio = 1 + epsilon / (16 * (1 + epsilon))
        line, = plt.semilogx(epsilon, ratio, "-")

        calc_ratio = lambda k: (k / 16 + 2) / (k + 1)
        K = np.linspace(1, 100)
        ratio = calc_ratio(K)
        plt.semilogx(K, ratio, "-")
        plt.xlabel("K")
        plt.ylabel("index size ratio")
        #plt.xlim([1, 100])

        # calc_ratio(k) == 1 exactly at k = 16/15; mark it with a vertical line
        break_x = 16 / 15
        break_y = calc_ratio(break_x)
        
        plt.plot([break_x] * 2, [0, break_y], "k--")

        plt.savefig(output[0])


###################### run time distribution ##################

rule plot_run_time_dist:
    # Pie chart of PEANUT's run time split into its processing stages, based on
    # the stats file of the "strata" run.
    input:
        "performance/{dataset}.peanut.align.strata.0.stats"
    output:
        "plots/{dataset}.run_time_dist.pdf"
    run:
        durations = dict()
        with open(input[0]) as f:
            reader = csv.reader(f, delimiter="\t")
            for l in reader:
                # rows "<stage>_..._duration<TAB>v1<TAB>v2...": sum all values per stage
                if l[0].endswith("duration"):
                    durations[l[0].split("_")[0]] = sum(map(float, l[1:]))
        # assumes the stats file contains all of these stages (KeyError otherwise)
        items = "index filtration validation postprocessing writing".split()
        durations = [durations[i] for i in items]
        plt.figure(figsize=(3,2))
        cmap = plt.cm.Blues_r
        # twice as many colors as wedges, so only the darker half of the map is used
        colors = cmap(np.linspace(0., 1., len(items) * 2))
        # fix index label
        items[0] = "indexing"
        patches, texts, autotexts = plt.pie(durations, labels=items, autopct="%1.1f%%", colors=colors)
        for p in patches:
            p.set_edgecolor("w")
        for t in autotexts:
            t.set_color("w")
        plt.gca().set_aspect("equal")
        plt.savefig(output[0], bbox_inches="tight")


def figure(figsize=None, right=False, top=False):
    """Create a new figure with a single axisartist subplot and make it current.

    Subsequent ``plt.*`` plotting calls draw into this subplot.

    :param figsize: passed through to ``plt.figure``.
    :param right: whether to show the right axis line (hidden by default).
    :param top: whether to show the top axis line (hidden by default).
    :return: the created subplot (previously nothing was returned; callers
        that ignore the result are unaffected).
    """
    fig = plt.figure(figsize=figsize)
    ax = Subplot(fig, 111)
    # honor the right/top flags, which were previously accepted but ignored
    ax.axis["right"].set_visible(right)
    ax.axis["top"].set_visible(top)
    fig.add_subplot(ax)
    return ax