# vim: set syntax=python expandtab!:
import csv
from functools import partial
from collections import defaultdict, Counter
import matplotlib
matplotlib.use("agg")
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid.axislines import Subplot
import pysam


__author__ = "Johannes Köster"
__license__ = "MIT"


"""
PEANUT analysis workflow.
The workflow expects the Homo sapiens reference genome as "ref/genome.fa", together with
BWA, Bowtie 2, CUSHAW2 (prefix genome.cushaw2), CUSHAW3 (prefix genome.cushaw3) and nvBowtie (prefix genome.nvbwt) indexes in the same directory.
This can be achieved via symlinks.
For unbiased performance estimates, ensure that all files reside on the same local hard disk
and not on a network share.
Further, the workflow should not be executed with multiple jobs in parallel.
"""


############## definitions ################

# occupancy
# GPU block sizes (threads per block) considered for occupancy: 32, 64, ..., 576.
BLOCKSIZES = [size for size in range(32, 577, 32)]

# sensitivity
# Minimum percent identities given to PEANUT, and the maximum edit distances
# of the RABEMA gold standards (0-5, then 10, 15, 20).
RABEMA_SCORES = [score for score in range(60, 101, 5)]
RABEMA_ERRORS = [error for error in range(6)] + [error for error in range(10, 21, 5)]

SENSITIVITY_TESTS = expand("rabema/evaluation/sim.{readlength}.{score}.{error}.txt", readlength=[100], score=RABEMA_SCORES, error=RABEMA_ERRORS)

# performance
PEANUT_TESTS = expand("peanut.align.{type}", type=["strata", "all"])
PEANUT_TESTS_RECALL = expand("peanut.align.{type}", type=["strata_noxa", "all_noxa"])
# Mappers benchmarked for run time and recall. Disabled mappers:
# - bowtie2.all does not terminate due to swapping
# - nvbowtie is optimized for Tesla GPUs with larger memory than our test
#   system and fails with bad_alloc
PERFORMANCE_MAPPERS = [
	mapper
	for mapper in ["bowtie2.best", "cushaw2-gpu", "bowtie2.all", "razers3", "ngm", "cushaw3", "bwamem", "nvbowtie", "mrfast"]
	if mapper not in ("bowtie2.all", "nvbowtie")
]
# Mappers skipped on the large dataset: ngm exits with an error there.
# (Problems were also noted for razers3 (bad_alloc) and mrfast (crashes);
# only ngm is excluded by this list.)
PERFORMANCE_MAPPERS_LARGE_BLACKLIST = ["ngm"]
# Excluded from performance-vs-recall comparisons: PEANUT "all" mode is not
# needed, and razers3 throws an error (TODO find better parameters for razers3).
PERFORMANCE_VS_RECALL_BLACKLIST = ["razers3", "peanut.align.all_noxa"]


# Gzipped input FASTQ file(s) per performance dataset; two files are the
# mates of a paired-end dataset.
PERFORMANCE_FASTQ = {
	"ERR281333.5000000": expand("sampled/ERR281333_{read}.5000000.fastq.gz", read=["1", "2"]),
	"ERR281333_1.5000000": ["sampled/ERR281333_1.5000000.fastq.gz"],
	"simulated.5000000": ["sampled/simulated.5000000.fastq.gz"],
	"simulated.1000": ["sampled/simulated.1000.fastq.gz"],
	"ERR091787.25000000": expand("sampled/ERR091787_{read}.25000000.fastq.gz", read=["1", "2"])
}

def get_performance_fastq(wildcards, gzip=True):
	"""Return the FASTQ path(s) of the performance dataset ``wildcards.dataset``.

	:param wildcards: Snakemake wildcards; ``dataset`` must be a key of
		``PERFORMANCE_FASTQ``.
	:param gzip: if False, return the paths of the decompressed files
		(trailing ``.gz`` removed), for mappers that cannot read gzipped input;
		those files are produced by the ``unzip_fastq`` rule.
	:returns: list of one (single-end) or two (paired-end) paths.
	:raises KeyError: for an unknown dataset.
	"""
	fastqs = PERFORMANCE_FASTQ[wildcards.dataset]
	if gzip:
		return fastqs
	# Strip only a trailing ".gz". This also avoids relying on an ``os`` module
	# that this file never imports.
	return [f[:-len(".gz")] if f.endswith(".gz") else f for f in fastqs]

# RABEMA recall and precision reports of every mapper on the simulated
# 5,000,000-read dataset (parameter set / run 0).
PERFORMANCE_RECALL = expand(
	["performance/evaluation/simulated.5000000.{mapper}.{run}.{benchmark_type}.rabema_report_tsv"],
	mapper=PERFORMANCE_MAPPERS + PEANUT_TESTS_RECALL,
	benchmark_type=["recall", "precision"],
	run=0,
)
# Run-time files of all mappers on the 5,000,000-read datasets.
PERFORMANCE_RUNTIME = expand(
	["performance/{dataset}.{mapper}.{run}.time"],
	dataset=["ERR281333_1.5000000", "simulated.5000000", "ERR281333.5000000"],
	mapper=PERFORMANCE_MAPPERS + PEANUT_TESTS,
	run=0,
)
# Run-time files on the large dataset, without the blacklisted mappers.
PERFORMANCE_RUNTIME_LARGE = expand(
	["performance/{dataset}.{mapper}.{run}.time"],
	dataset=["ERR091787.25000000"],
	mapper=[m for m in PERFORMANCE_MAPPERS if m not in PERFORMANCE_MAPPERS_LARGE_BLACKLIST] + PEANUT_TESTS,
	run=0,
)
# Run-time files of the PEANUT configurations only.
PERFORMANCE_PEANUT = expand(
	["performance/{dataset}.{mapper}.{run}.time"],
	dataset=["ERR281333_1.5000000", "simulated.5000000", "ERR281333.5000000"],
	mapper=PEANUT_TESTS,
	run=0,
)
# Pie charts of PEANUT's run-time distribution per dataset.
RUN_TIME_DIST = expand(
	"plots/{dataset}.run_time_dist.pdf",
	dataset=["ERR281333_1.5000000", "simulated.5000000", "ERR281333.5000000", "ERR091787.25000000"],
)


# Make every shell command fail when any stage of a pipeline fails.
shell.prefix("set -o pipefail; ")

##################### targets #####################


rule all:
	# Default target: all plots, the RABEMA recall/precision reports, the
	# run-time measurements (including the large dataset) and the RABEMA
	# sensitivity evaluations.
	input:
		"plots/occupancy.pdf",
		"plots/benchmark_best-mappers_recall.pdf",
		"plots/benchmark_best-mappers_precision.pdf",
		"plots/benchmark_all-mappers_all.pdf",
		"plots/index_size.pdf",
		"plots/mapq_fpr.pdf",
		RUN_TIME_DIST,
		PERFORMANCE_RECALL,
		PERFORMANCE_RUNTIME,
		PERFORMANCE_RUNTIME_LARGE,
		SENSITIVITY_TESTS


rule all_performance_peanut:
	# Target: run-time measurements of the PEANUT configurations only.
	input:
		PERFORMANCE_PEANUT


rule all_performance:
	# Target: run-time measurements of all mappers on the 5,000,000-read
	# datasets (the large ERR091787 dataset is not included).
	input:
		PERFORMANCE_RUNTIME


rule all_run_time_dist:
	# Target: PEANUT run-time distribution plots for every dataset.
	input:
		RUN_TIME_DIST


##################### general #####################


rule index_repeatcount:
	# Build a PEANUT index of the human reference with the given minimum repeat
	# count and write index statistics as CSV.
	# NOTE(review): "--stats {output}" expands to "--stats <csv> <hdf5>", so the
	# HDF5 index path follows as a positional argument; this relies on the
	# order of the outputs - confirm against peanut's CLI.
	input:
		"ref/genome.fa"
	output:
		"stats/index/genome.{minrepeat}.csv", "index/genome.{minrepeat}.hdf5"
	resources: gpu=1
	shell:
		"peanut index {input} --min-repeat-count {wildcards.minrepeat} --stats {output}"


rule index_default:
	# Build a PEANUT index with default settings from any "<prefix>.fasta"
	# (e.g. the yeast genome used by the RABEMA sensitivity tests).
	input:
		"{prefix}.fasta"
	output:
		"{prefix}.hdf5"
	shell:
		"peanut index {input} {output}"


rule sample_fastq:
	# Keep the first {reads} records of a gzipped FASTQ file (4 lines per record).
	# pipefail (enabled globally via shell.prefix) is switched off here because
	# head closes the pipe early, which makes zcat fail.
	input: "reads/{dataset}.fastq.gz"
	output: "sampled/{dataset}.{reads}.fastq.gz"
	params:
		lines=lambda wildcards: str(4 * int(wildcards.reads))
	shell:
		"set +o pipefail; zcat {input} | head -n {params.lines} | gzip > {output}"


rule sam_to_bam:
	# Convert any SAM file to BAM.
	input:  "{prefix}.sam"
	output: "{prefix}.bam"
	shell:  "samtools view -Sb {input} > {output}"


rule sort_bam:
	# Sort a BAM file by coordinate (".sorted.bam") or by read name
	# (".namesorted.bam", samtools -n).
	# NOTE(review): the output is given as a prefix without ".bam", which looks
	# like the pre-1.0 samtools sort syntax (it appends ".bam" itself); newer
	# samtools versions need -o instead.
	input:  "{prefix}.bam"
	output: "{prefix}.{sorttype,(sorted|namesorted)}.bam"
	params: flags=lambda wildcards: "-n" if wildcards.sorttype == "namesorted" else ""
	shell:  "samtools sort {params.flags} {input} {wildcards.prefix}.{wildcards.sorttype}"


rule index_bam:
	# Create the .bai index of a coordinate-sorted BAM file.
	input:
		"{prefix}.sorted.bam"
	output:
		"{prefix}.sorted.bam.bai"
	shell:
		"samtools index {input}"


rule unzip_fastq:
	# Decompress a gzipped FASTQ file, keeping the original.
	input:  "{prefix}.fastq.gz"
	output: "{prefix}.fastq"
	shell:
		"gzip -d -c {input} > {output}"


##################### occupancy ##################

# GPU kernels whose occupancy is extracted from the profiler output.
ALL_KERNELS = {
	"create_queries_index",
	"create_queries_occ_count",
	"create_queries_occ",
	"popcount_index",
	"filter_reference",
	"create_candidates",
	"validate_hits",
}
# Kernels shown in the occupancy plot, mapped to their plot labels.
REPRESENTATIVE_KERNELS = dict(
	create_queries_index="index construction",
	filter_reference="filtration",
	validate_hits="validation",
)
# Environment variables that enable CSV profiling of the GPU kernels.
PROFILE = " ".join([
	"COMPUTE_PROFILE=1",
	"COMPUTE_PROFILE_CSV=1",
	"COMPUTE_PROFILE_CONFIG=profile.config",
])


rule profile_peanut:
	# Map sampled reads with PEANUT with profiling enabled via the
	# COMPUTE_PROFILE* environment variables, using the same block size for
	# filtration and validation. Alignments are discarded; only the profiler
	# log is kept.
	# The index is passed before the reads, matching how `peanut map` is called
	# in performance_peanut and rabema_peanut (the read file list comes last).
	input:
		index="index/genome.2500.hdf5",
		fastq="sampled/{reads}.fastq.gz"
	output:
		profile="profile/{reads}.{blocksize}.profile.txt"
	resources: gpu=1
	threads: 8
	shell:
		"{PROFILE} CL_LOG_ERRORS=stdout COMPUTE_PROFILE_LOG={output.profile} "
		"peanut map --threads {threads} "
		"--blocksize-filtration {wildcards.blocksize} "
		"--blocksize-validation {wildcards.blocksize} "
		"{input.index} {input.fastq} > /dev/null"


rule extract_occupancy:
	# Collect the occupancy of each kernel per block size from the profiler
	# CSV logs into one tab-separated table (kernel x block size, "-" = missing).
	input:
		expand("profile/1000000.{blocksize}.profile.txt", blocksize=range(32, 513, 32))
	output:
		csv="profile/occupancy.csv"
	run:
		occupancy = defaultdict(dict)
		for profile_path in input:
			with open(profile_path) as handle:
				for row in csv.reader(handle, delimiter=","):
					# Rows without an 11th column carry no occupancy value.
					if len(row) <= 10:
						continue
					try:
						name = row[0]
						# Columns 5-7 are presumably the block dimensions - TODO confirm
						# against the profiler's CSV header.
						threads_per_block = int(row[5]) * int(row[6]) * int(row[7])
						value = float(row[10])
					except ValueError:
						# Header and other non-numeric rows.
						continue
					if name in ALL_KERNELS:
						occupancy[name][threads_per_block] = value

		all_sizes = sorted({size for per_kernel in occupancy.values() for size in per_kernel})
		with open(output.csv, "w") as out:
			print("method", *all_sizes, sep="\t", file=out)
			for name, per_kernel in occupancy.items():
				cells = [per_kernel.get(size, "-") for size in all_sizes]
				print(name, *cells, sep="\t", file=out)


rule plot_occupancy:
	# Plot occupancy over block size for the representative kernels only.
	input:
		csv="profile/occupancy.csv"
	output:
		pdf="plots/occupancy.pdf"
	run:
		figure()
		line_styles = ["-", "--", ":"]
		with open(input.csv) as handle:
			rows = csv.reader(handle, delimiter="\t")
			header = next(rows)
			sizes = [int(size) for size in header[1:]]
			shown = (row for row in rows if row[0] in REPRESENTATIVE_KERNELS)
			for idx, row in enumerate(shown):
				kernel, values = row[0], row[1:]
				# "-" marks block sizes without a measurement.
				xs = [size for size, value in zip(sizes, values) if value != "-"]
				ys = [float(value) for value in values if value != "-"]
				plt.plot(xs, ys, line_styles[idx], label=REPRESENTATIVE_KERNELS[kernel])
		plt.xlabel("block size")
		plt.ylabel("occupancy")
		plt.legend(loc="lower right", handlelength=2.5)
		plt.ylim([0, 1])
		plt.savefig(output.pdf)


################## sensitivity ##################


rule rabema_download_data:
	# Download and unpack the RABEMA benchmark data (Saccharomyces genome and
	# reads) into rabema/, deleting the archive afterwards.
	output: 
		"rabema/data/saccharomyces/genome.fasta",
		"rabema/data/saccharomyces/reads_454/SRR000853.10k.fastq"
	params:
		archive="rabema-data.tar.bz2"
	shell:  
		"""
		wget -O {params.archive} http://www.seqan.de/wp-content/plugins/download-monitor/download.php?id=27
		tar -xf {params.archive}
		mv rabema-data rabema
		rm {params.archive}
		"""


rule rabema_create_fastq_simulated:
	# Simulate 10,000 Illumina reads of length {readlength} from the yeast
	# genome with mason; the SAM file mason writes next to the FASTQ is removed.
	# Raw string: "\d" in a plain string literal is an invalid escape sequence.
	input:
		"rabema/data/saccharomyces/genome.fasta"
	output:
		r"rabema/sim.{readlength,\d+}.fastq"
	shell:
		"""
		mason illumina -N 10000 --no-N -n {wildcards.readlength} -sq -o {output} {input}
		rm {output}.sam
		"""


rule razers3_create_sam:
	# Map the simulated yeast reads with RazerS 3 at 100% recognition rate and
	# percent identity 100 - {error}, with up to 10,000,000 hits per read, as
	# the basis of the RABEMA gold standard.
	# Raw string: "\d" in a plain string literal is an invalid escape sequence.
	input:
		"rabema/data/saccharomyces/genome.fasta",
		"rabema/sim.{readlength}.fastq"
	output:
		r"rabema/gold.{readlength}.{error,\d+}.pre.sam"
	params:
		identity=lambda wildcards: str(100 - int(wildcards.error))
	threads: 6
	shell:
		"razers3 --dont-shrink-alignments --verbose "
		"--thread-count {threads} "
		"--recognition-rate 100 "
		"--percent-identity {params.identity} "
		"--max-hits 10000000 "
		"--output {output} "
		"{input}"


rule rabema_prepare_sam:
	# Post-process the RazerS 3 SAM file with rabema_prepare_sam.
	# Raw string: "\d" in a plain string literal is an invalid escape sequence.
	input:
		"rabema/gold.{readlength}.{error}.pre.sam"
	output:
		r"rabema/gold.{readlength,\d+}.{error,\d+}.sam"
	shell:
		"rabema_prepare_sam -i {input} -o {output}"


rule rabema_build_gold_standard:
	# Build the RABEMA gold standard (GSI file) for maximum edit distance
	# {error} from the coordinate-sorted RazerS 3 alignments.
	input:
		ref="rabema/data/saccharomyces/genome.fasta",
		bam="rabema/gold.{readlength}.{error}.sorted.bam"
	output: "rabema/gold.{readlength}.{error}.gsi"
	shell:
		"rabema_build_gold_standard "
		"--max-error {wildcards.error} "
		"--distance-metric edit "
		"--out-gsi {output} "
		"--reference {input.ref} "
		"--in-bam {input.bam}"


rule rabema_peanut:
	# Map the simulated yeast reads with PEANUT (all strata, semiglobal, up to
	# 1000 hits) at minimum percent identity {score}, writing BAM.
	# Wildcard constraints belong in the output only; the output uses a raw
	# string because "\d" in a plain string literal is an invalid escape.
	input:
		ref="rabema/data/saccharomyces/genome.hdf5",
		fastq="rabema/sim.{readlength}.fastq"
	output:
		r"rabema/peanut.{readlength,\d+}.{score,\d+}.bam"
	resources: gpu=1
	shell:
		"peanut map --no-xa --max-hits 1000 --semiglobal --strata all --gap-open-penalty 2 --gap-extend-penalty 1 --percent-identity {wildcards.score} {input.ref} {input.fastq} | samtools view -Sbh - > {output}"



rule debug_peanut:
	# Map rabema/debug.fastq against the yeast index with PEANUT for debugging.
	# Index before reads (as in rabema_peanut), and the BAM is redirected into
	# {output}: without ">" samtools would treat the output path as a region.
	input:
		ref="rabema/data/saccharomyces/genome.hdf5",
		fastq="rabema/debug.fastq"
	output:
		aligned="rabema/debug.bam"
	resources: gpu=1
	shell:
		"""
		peanut map --max-hits 1000 --semiglobal --gap-open-penalty 2 --gap-extend-penalty 1 --percent-identity 60 {input.ref} {input.fastq} | samtools view -Shb - > {output}
		"""


rule rabema_evaluate_simulated:
	# Evaluate PEANUT's name-sorted alignments (percent identity {score})
	# against the gold standard for edit distance {error}, in benchmark category
	# "all". The report is written from stdout to the output; stderr goes to the log.
	input:
		ref="rabema/data/saccharomyces/genome.fasta",
		peanut="rabema/peanut.{readlength}.{score}.namesorted.bam",
		gold="rabema/gold.{readlength}.{error}.gsi"
	output:
		"rabema/evaluation/sim.{readlength}.{score}.{error}.txt"
	log:
		"logs/rabema_evaluate.{readlength}.{score}.{error}.log"
	shell:
		"rabema_evaluate --reference {input.ref} "
		"--show-missed-intervals "
		"--max-error {wildcards.error} "
		"--distance-metric edit "
		"--benchmark-category all "
		"--in-gsi {input.gold} "
		"--in-bam {input.peanut} > {output} 2> {log}"


######################### performance ######################


rule performance_dataset_real:
	# Download mate file {read} of run ERR281333 from the EBI SRA FTP server.
	output:
		"reads/ERR281333_{read}.fastq.gz"
	shell:
		"wget -O {output} ftp://ftp.sra.ebi.ac.uk/vol1/fastq/ERR281/ERR281333/ERR281333_{wildcards.read}.fastq.gz"


rule performance_dataset_real2:
	# Download mate file {read} of run ERR091787 (an Illumina Platinum Genomes
	# sample from an African male) from the EBI SRA FTP server.
	output:
		"reads/ERR091787_{read}.fastq.gz"
	shell:
		"wget -O {output} ftp://ftp.sra.ebi.ac.uk/vol1/fastq/ERR091/ERR091787/ERR091787_{wildcards.read}.fastq.gz"


rule performance_dataset_simulated:
	# Simulate {count} Illumina reads of length 100 from the human reference
	# with mason. The FASTQ is gzipped; the SAM file mason writes next to it is
	# kept as an output (later sorted and used for the oracle gold standard).
	input:
		"ref/genome.fa"
	output:
		"sampled/simulated.{count}.fastq.gz", "sampled/simulated.{count}.fastq.sam"
	params:
		fastq="sampled/simulated.{count}.fastq"
	shell:
		"""
		mason illumina -N {wildcards.count} -n 100 -sq -o {params.fastq} {input}
		gzip {params.fastq}
		"""


rule performance_gold_standard_oracle:
	# Build an oracle-mode RABEMA gold standard from mason's own alignments of
	# the simulated reads.
	# Raw string: "\d" in a plain string literal is an invalid escape sequence.
	input:
		bam="sampled/simulated.{count}.fastq.sorted.bam",
		ref="ref/genome.fa"
	output:
		r"performance/simulated.{count,\d+}.gsi"
	shell:
		"rabema_build_gold_standard --oracle-mode -o {output} -r {input.ref} -b {input.bam}"


rule performance_gold_standard_map_all:
	# Map the uncompressed simulated reads with RazerS 3 at 100% recognition
	# rate and at least 85% identity, with up to 1,000,000 hits per read, as
	# the basis of the "all" gold standard.
	input:
		"ref/genome.fa",
		"sampled/simulated.{count}.fastq"
	output:
		"sampled/simulated.{count}.gold_all.pre.sam"
	threads: 8
	shell:
		"razers3 --dont-shrink-alignments "
		"--parallel-window-size 20000 "
		"--thread-count {threads} "
		"--recognition-rate 100 "
		"--percent-identity 85 "
		"--max-hits 1000000 "
		"--output {output} "
		"{input}"


rule performance_rabema_prepare_bam:
	# Post-process the RazerS 3 "all" mapping with rabema_prepare_sam
	# (despite the rule name, both input and output are SAM).
	input:
		"sampled/simulated.{count}.gold_all.pre.sam"
	output:
		"sampled/simulated.{count}.gold_all.sam"
	shell:
		"rabema_prepare_sam -i {input} -o {output}"


rule performance_gold_standard_all:
	# Build the "all" RABEMA gold standard (edit distance up to 15) from the
	# sorted RazerS 3 alignments of the simulated reads.
	input:
		bam="sampled/simulated.{count}.gold_all.sorted.bam",
		ref="ref/genome.fa"
	output:
		"performance/simulated.{count}.all.gsi"
	shell:
		"rabema_build_gold_standard "
		"--max-error 15 "
		"--distance-metric edit "
		"--out-gsi {output} "
		"--reference {input.ref} "
		"--in-bam {input.bam}"


rule performance_rabema:
	# Oracle-mode RABEMA evaluation of a mapper's name-sorted BAM against the
	# simulation gold standard; the "precision" type adds --only-unique-reads.
	# Raw string: "\d" in a plain string literal is an invalid escape sequence.
	input:
		ref="ref/genome.fa",
		gsi="performance/simulated.{count}.gsi",
		bam="performance/simulated.{count}.{mapper}.{run}.namesorted.bam"
	output:
		r"performance/evaluation/simulated.{count,[0-9]+}.{mapper}.{run,\d+}.{type,(precision|recall)}.rabema_report_tsv"
	params:
		lambda wildcards: "--only-unique-reads" if wildcards.type == "precision" else ""
	log:
		"logs/performance/rabema_evaluate.simulated.{count}.{mapper}.{run}.{type}.log"
	shell:
		"rabema_evaluate --show-missed-intervals --show-additional-hits "
		"--show-invalid-hits --oracle-mode "
		"{params} "
		"-r {input.ref} -g {input.gsi} -b {input.bam} "
		"--out-tsv {output} 2> {log}"


rule performance_rabema_all:
	# RABEMA evaluation (edit distance up to 15) of a mapper's name-sorted BAM
	# against the "all" gold standard, in benchmark category {type}.
	# Raw string: "\d" in a plain string literal is an invalid escape sequence.
	input:
		ref="ref/genome.fa",
		gsi="performance/simulated.{count}.all.gsi",
		bam="performance/simulated.{count}.{mapper}.{run}.namesorted.bam"
	output:
		r"performance/evaluation/simulated.{count,[0-9]+}.{mapper}.{run,\d+}.{type,(all|any-best|all-best)}.rabema_report_tsv"
	log:
		"logs/performance/rabema_evaluate.simulated.{count}.{mapper}.{run}.{type}.log"
	shell:
		"rabema_evaluate --show-missed-intervals --show-additional-hits "
		"--show-invalid-hits "
		"--max-error 15 "
		"--distance-metric edit "
		"--benchmark-category {wildcards.type} "
		"-r {input.ref} -g {input.gsi} -b {input.bam} "
		"--out-tsv {output} 2> {log}"


# Alternative option strings per mapper; the {run} wildcard selects one
# (run 0 = the tool's defaults).
RUN_PARAMS = dict([
	("peanut", ["", "--percent-identity 90", "--percent-identity 95"]),
	("bwa_mem", ["", "-r 2.5", "-r 3.5"]),
	("bowtie2", ["", "--very-sensitive", "--very-fast"]),
	("razers3", ["", "--percent-identity 80", "--percent-identity 90"]),
	("ngm", ["", "--min-identity 0.8", "--min-identity 0.95"]),
	("cushaw2_gpu", ["", "-min_id 0.8", "-min_id 0.95"]),
	("cushaw3", ["", "-min_id 0.8", "-min_id 0.95"]),
])


def get_run_params(mapper):
	"""Return a Snakemake params function selecting ``mapper``'s option string.

	The returned callable looks up ``RUN_PARAMS[mapper]`` at call time and
	indexes it with ``int(wildcards.run)``.
	"""
	return lambda wildcards: RUN_PARAMS[mapper][int(wildcards.run)]


# PEANUT command line options for each "peanut.align.{type}" configuration.
PEANUT_TYPE2PARAMS = {
	"all": "--strata all",
	"strata": "--strata 1",
	"strata_noxa": "--strata 1 --no-xa",
	"all_noxa": "--strata all --no-xa",
	"noalign": "--strata 1 --no-alignments",
}


rule performance_peanut:
	# Benchmark PEANUT three times in configuration {type}. The `time` output of
	# every run is appended to {output.time}; the SAM and log are overwritten
	# per run, so the last run's files remain.
	# Raw strings: "\d" in a plain string literal is an invalid escape sequence.
	resources: benchmark=1, gpu=1
	input:
		fastq=get_performance_fastq, index="index/genome.2500.hdf5"
	output:
		sam=r"performance/{dataset}.peanut.align.{type,(all|strata|strata_noxa|all_noxa|noalign)}.{run,\d+}.sam",
		stats=r"performance/{dataset}.peanut.align.{type}.{run,\d+}.stats",
		time=r"performance/{dataset}.peanut.align.{type}.{run,\d+}.time"
	params:
		type=lambda wildcards: PEANUT_TYPE2PARAMS[wildcards.type],
		run=get_run_params("peanut")
	log:
		"logs/{dataset}.peanut.align.{type}.{run}.log"
	shell:
		"rm -f {output.time}; "
		"for i in 1 2 3; do "
		"(time peanut --stats {output.stats} map --threads 8 --query-buffer 1000000 {params.type} "
		" {input.index} {input.fastq} --semiglobal {params.run} "
		" > {output.sam} 2> {log}) 2>> {output.time}; "
		"done"


rule performance_bwa_mem:
	# Benchmark BWA-MEM (8 threads) three times, appending each run's `time`
	# output to {output.time}.
	# Raw strings: "\d" in a plain string literal is an invalid escape sequence.
	resources: benchmark=1
	input:
		"ref/genome.fa",
		get_performance_fastq
	output:
		sam=r"performance/{dataset}.bwamem.{run,\d+}.sam",
		time=r"performance/{dataset}.bwamem.{run,\d+}.time"
	params:
		run=get_run_params("bwa_mem")
	log:
		"logs/{dataset}.bwamem.{run}.log"
	shell:
		"rm -f {output.time}; "
		"for i in 1 2 3; "
		"do (time bwa mem {params.run} -t 8 {input} > {output.sam} 2> {log}) 2>> {output.time}; "
		"done"


rule performance_razers3:
	# Benchmark RazerS 3 (8 threads) on the uncompressed reads three times,
	# appending each run's `time` output to {output.time}.
	# Raw strings: "\d" in a plain string literal is an invalid escape sequence.
	resources: benchmark=1
	input:
		"ref/genome.fa",
		partial(get_performance_fastq, gzip=False)
	output:
		sam=r"performance/{dataset}.razers3.{run,\d+}.sam",
		time=r"performance/{dataset}.razers3.{run,\d+}.time"
	params:
		run=get_run_params("razers3")
	log:
		"logs/{dataset}.razers3.{run}.log"
	shell:
		"rm -f {output.time}; "
		"for i in 1 2 3; "
		"do (time razers3 {params.run} -tc 8 {input} --output {output.sam} 2> {log}) 2>> {output.time}; "
		"done"


rule performance_razers3_partitioned:
	# RazerS 3 on the large paired-end ERR091787 dataset: each mate file is
	# split into two chunks of 50,000,000 lines (12.5M reads), which are mapped
	# one after another; the total time of each of three repetitions is
	# appended to {output.time}.
	# NOTE(review): every chunk is written to the same --output file, so
	# {output.sam} only keeps the last chunk's alignments; only the timings
	# are meaningful.
	# Raw strings: "\d" in a plain string literal is an invalid escape sequence.
	resources: benchmark=1
	input:
		ref="ref/genome.fa",
		fastq=expand("sampled/ERR091787_{read}.25000000.fastq", read="1 2".split())
	output:
		sam=r"performance/ERR091787.25000000.razers3.{run,\d+}.sam",
		time=r"performance/ERR091787.25000000.razers3.{run,\d+}.time"
	params:
		run=get_run_params("razers3"),
		lines=str(25000000 * 4 // 2),
		args = " ".join(map("{:02d}".format, range(2)))
	log:
		"logs/ERR091787.25000000.razers3.{run}.log"
	run:
		shell(
		"""
		for f in {input.fastq}
		do
			rm -f $f.*
			split -a 2 -d -l {params.lines} $f $f.
			rename 's/$/.fastq/' $f.*
		done
		rm -f {output.time}
		rm -f {log}
		for i in 1 2 3
		do
			(time (
				for chunk in {params.args}
				do
					echo "chunk $chunk"
					razers3 {params.run} -tc 8 {input.ref} {input.fastq[0]}.$chunk.fastq {input.fastq[1]}.$chunk.fastq --output {output.sam} &>> {log}
				done
			)) 2>> {output.time}
		done
		"""
		)

# The large-dataset rule takes precedence over the generic performance_razers3,
# whose output pattern also matches these files.
ruleorder: performance_razers3_partitioned > performance_razers3


rule performance_ngm:
	# Benchmark NextGenMap (GPU, 8 threads) three times, appending each run's
	# `time` output to {output.time}. Paired datasets use -1/-2, single-end -q.
	# Raw strings: "\d" in a plain string literal is an invalid escape sequence.
	resources: benchmark=1
	input:
		ref="ref/genome.fa",
		fastq=get_performance_fastq
	output:
		sam=r"performance/{dataset}.ngm.{run,\d+}.sam",
		time=r"performance/{dataset}.ngm.{run,\d+}.time"
	params:
		run=get_run_params("ngm")
	log:
		"logs/{dataset}.ngm.{run}.log"
	run:
		if len(input.fastq) > 1:
			fastqs = "-1 {} -2 {}".format(*input.fastq)
		else:
			# Format the single path itself, not the whole file list.
			fastqs = "-q {}".format(input.fastq[0])
		shell("rm -f {output.time}; "
		"for i in 1 2 3; "
		"do (time ngm {params.run} --gpu -t 8 {fastqs} -r {input.ref} -o {output.sam} &> {log}) 2>> {output.time}; "
		"done")


rule performance_cushaw3:
	# Benchmark CUSHAW3 (8 threads, index prefix ref/genome.cushaw3) three
	# times, appending each run's `time` output to {output.time}. Paired
	# datasets are passed with -q, single-end with -f.
	# Raw strings: "\d" in a plain string literal is an invalid escape sequence.
	resources: benchmark=1
	input:
		ref="ref/genome.fa",
		fastq=get_performance_fastq
	output:
		sam=r"performance/{dataset}.cushaw3.{run,\d+}.sam",
		time=r"performance/{dataset}.cushaw3.{run,\d+}.time"
	params:
		run=get_run_params("cushaw3")
	log:
		"logs/{dataset}.cushaw3.{run}.log"
	run:
		inputparam = "-q" if len(input.fastq) > 1 else "-f"
		shell(
			"rm -f {output.time}; "
			"for i in 1 2 3; "
			"do (time cushaw3 align {params.run} -t 8 -r ref/genome.cushaw3 {inputparam} {input.fastq} -o {output.sam} &> {log}) 2>> {output.time}; "
			"done")


rule performance_cushaw2_gpu:
	# Benchmark CUSHAW2-GPU (8 threads, index prefix ref/genome.cushaw2) three
	# times, appending each run's `time` output to {output.time}. Paired
	# datasets are passed with -q, single-end with -f.
	# Raw strings: "\d" in a plain string literal is an invalid escape sequence.
	resources: benchmark=1
	input:
		ref="ref/genome.fa",
		fastq=get_performance_fastq
	output:
		sam=r"performance/{dataset}.cushaw2-gpu.{run,\d+}.sam",
		time=r"performance/{dataset}.cushaw2-gpu.{run,\d+}.time"
	params:
		run=get_run_params("cushaw2_gpu")
	log:
		"logs/{dataset}.cushaw2-gpu.{run}.log"
	run:
		inputparam = "-q" if len(input.fastq) > 1 else "-f"
		shell(
			"rm -f {output.time}; "
			"for i in 1 2 3; "
			"do (time cushaw2-gpu {params.run} -t 8 {inputparam} {input.fastq} -r ref/genome.cushaw2 -o {output.sam} &> {log}) 2>> {output.time}; "
			"done")


rule performance_bowtie2:
	# Benchmark Bowtie 2 (8 threads) in "best" mode or with --all three times,
	# appending each run's `time` output to {output.time}.
	# The log includes {type}: the best and all variants of the same dataset
	# and run would otherwise write to the same log file.
	# Raw strings: "\d" in a plain string literal is an invalid escape sequence.
	resources: benchmark=1
	input:
		"ref/genome.1.bt2",
		fastq=get_performance_fastq
	output:
		sam=r"performance/{dataset}.bowtie2.{type,(best|all)}.{run,\d+}.sam",
		time=r"performance/{dataset}.bowtie2.{type,(best|all)}.{run,\d+}.time"
	params:
		prefix="ref/genome",
		all=lambda wildcards: "--all" if wildcards.type == "all" else "",
		run=get_run_params("bowtie2")
	log:
		"logs/{dataset}.bowtie2.{type}.{run}.log"
	run:
		# Single-end reads via -U (the single path, not the list), mates via -1/-2.
		fastq = "-U {}".format(input.fastq[0]) if len(input.fastq) == 1 else "-1 {} -2 {}".format(*input.fastq)
		shell("rm -f {output.time}; for i in 1 2 3; do (time bowtie2 {params.run} --threads 8 {params.all} {params.prefix} {fastq} > {output.sam} 2> {log}) 2>> {output.time}; done")


rule performance_nvbowtie:
	# Benchmark nvBowtie (index prefix ref/genome.nvbwt) three times, appending
	# each run's `time` output to {output.time}.
	# Raw strings: "\d" in a plain string literal is an invalid escape sequence.
	resources: benchmark=1
	input:
		"ref/genome.fa",
		fastq=get_performance_fastq
	output:
		sam=r"performance/{dataset}.nvbowtie.{run,\d+}.sam",
		time=r"performance/{dataset}.nvbowtie.{run,\d+}.time"
	params:
		prefix="ref/genome.nvbwt"
	log:
		"logs/{dataset}.nvbowtie.{run}.log"
	shell:
		"rm -f {output.time}; "
		"for i in 1 2 3; "
		"do (time nvBowtie --file-ref {params.prefix} {input.fastq} {output.sam} &> {log}) 2>> {output.time}; "
		"done"


rule performance_mrfast:
	# Benchmark mrFAST: the uncompressed reads are split into 100 chunks
	# (1000 for datasets with more than 5,000,000 reads), which are mapped with
	# 8 parallel processes; each of three repetitions appends its `time` output
	# to {output.time}. Afterwards the per-chunk SAMs are concatenated.
	# Raw strings: "\d" in a plain string literal is an invalid escape sequence.
	resources: benchmark=1
	input:
		ref="ref/genome.fa",
		fastq=partial(get_performance_fastq, gzip=False)
	output:
		sam=r"performance/{dataset}.mrfast.{run,\d+}.sam",
		time=r"performance/{dataset}.mrfast.{run,\d+}.time"
	log:
		"logs/{dataset}.mrfast.{run}.log"
	run:
		_, reads = wildcards.dataset.split(".")
		chunks = 1000 if int(reads) > 5000000 else 100
		args = " ".join(map("{:04d}".format, range(chunks)))
		# "{{}}" becomes a literal "{}" that parallel replaces with the chunk suffix.
		seq_args = "--seq {}.{{}}".format(input.fastq) if len(input.fastq) == 1 else "--pe --min 100 --max 500 --seq1 {}.{{}} --seq2 {}.{{}}".format(*input.fastq)
		# The chunk size is rounded up to whole 4-line FASTQ records, so split
		# never cuts a record and never creates more than {chunks} files (a
		# truncated division could leave a remainder chunk that is never mapped).
		shell("""
		rm -f {output.sam}.*
		for f in {input.fastq}
		do
			flines=`wc -l $f | cut -f1 -d' '`
			lines=$(( ($flines + 4 * {chunks} - 1) / (4 * {chunks}) * 4 ))
			rm -f $f.*
			split -a 4 -d -l $lines $f $f.
		done

		rm -f {output.time}
		for i in 1 2 3
		do
			(time parallel -j 8 -i sh -c 'mrfast --search {input.ref} {seq_args} -o {output.sam}.{{}}' -- {args} ) 2>> {output.time}
		done
		samtools view -SH {output.sam}.0000 > {output.sam}
		for f in {output.sam}.*
		do
			# ignore errors, SAM will be incomplete but we are only interested in run time
			samtools view -S $f >> {output.sam} || true
		done
		""")


####################### PLOT_BENCHMARK_RESULTS #################################

# "best-mappers" plots: PEANUT in strata mode plus every performance mapper
# except razers3 and mrfast, which only appear in the all-mappers comparison.
BENCHMARK_BEST_MAPPERS = ["peanut.align.strata_noxa"] + [
	mapper for mapper in PERFORMANCE_MAPPERS if mapper not in ("razers3", "mrfast")
]
# "all-mappers" plots.
BENCHMARK_ALL_MAPPERS = ["peanut.align.all_noxa", "razers3", "mrfast"]

# Plot labels, keyed by the first dot-separated component of a mapper name.
MAPPER_LABELS = {
	"peanut": "PEANUT",
	"bwamem": "BWA-MEM",
	"ngm": "NextGenMap",
	"bowtie2": "Bowtie 2",
	"cushaw3": "CUSHAW3",
	"cushaw2-gpu": "CUSHAW2-GPU",
	"razers3": "RazerS 3",
	"mrfast": "MrFast",
}


def plot_benchmarks_input(wildcards):
	"""Return the RABEMA report files needed for one benchmark plot.

	"best-mappers" plots use the 5,000,000-read simulation and
	BENCHMARK_BEST_MAPPERS; any other mapper type uses the 1000-read simulation
	and BENCHMARK_ALL_MAPPERS. Only parameter set (run) 0 is used.
	"""
	pattern = "performance/evaluation/simulated.{reads}.{mapper}.0.{type}.rabema_report_tsv"
	if wildcards.mapper_type == "best-mappers":
		mappers, reads = BENCHMARK_BEST_MAPPERS, 5000000
	else:
		mappers, reads = BENCHMARK_ALL_MAPPERS, 1000
	return expand(pattern, mapper=mappers, type=wildcards.type, reads=reads)


rule plot_benchmarks:
	# Plot the cumulative RABEMA rate (found / expected, in %) over the maximum
	# edit distance 0-15 for each mapper of the chosen group.
	input:
		benchmarks=plot_benchmarks_input
	output:
		"plots/benchmark_{mapper_type}_{type}.pdf"
	run:
		best = wildcards.mapper_type == "best-mappers"
		mappers = BENCHMARK_BEST_MAPPERS if best else BENCHMARK_ALL_MAPPERS
		columns = "error_rate num_max num_found percent_found norm_max norm_found percent_norm_found".split()

		def parse_rabema(path):
			"""Return (cumulative rate in %, per-row percent_norm_found) of a report."""
			table = np.genfromtxt(path, dtype=None, names=columns)
			positives = np.cumsum(table["norm_max"])
			true_positives = np.cumsum(table["norm_found"])
			return true_positives / positives * 100, table["percent_norm_found"]

		max_errors = np.arange(16)
		figure()
		line_styles = ["-", "--", ":", "k-", "k--", "k:"]
		for idx, (mapper, report) in enumerate(zip(mappers, input.benchmarks)):
			cumulative, _ = parse_rabema(report)
			n_rows = cumulative.size
			# Stretch to all 16 edit distances, repeating the last measured value.
			curve = np.resize(cumulative, max_errors.size)
			curve[n_rows:] = curve[n_rows - 1]
			label = MAPPER_LABELS[mapper.split(".")[0]]
			plt.plot(max_errors, curve, line_styles[idx], label=label)
		ylabel = "sensitivity" if wildcards.type == "all" else wildcards.type
		plt.xlabel("maximum edit distance")
		plt.ylabel(ylabel + " [%]")
		plt.legend(loc="lower left", ncol=1 if wildcards.type == "all" else 2, handlelength=2.5)
		plt.ylim([80 if best else 60, 100])
		plt.xlim([0,15])

		plt.savefig(output[0])



##################### MAPQ ###############################

rule extract_mapq_fpr:
	# Count, per mapping quality, the mapped PEANUT hits (P) and the false
	# positives among them (FP): hits for which no gold-standard alignment of
	# the same read overlaps the hit's start position. The three rows mapq, P
	# and FP are written with np.savetxt.
	# The output is a text table; its key was previously misnamed "pdf".
	input:
		peanut="performance/simulated.5000000.peanut.align.all_noxa.0.namesorted.bam",
		gold="sampled/simulated.5000000.fastq.sorted.bam",
		idx="sampled/simulated.5000000.fastq.sorted.bam.bai"
	output:
		txt="performance/evaluation/mapq_fpr.txt"
	run:
		# pysam is imported at the top of the file.
		total_reads = 5000000
		FP = Counter()
		P = Counter()
		with pysam.Samfile(input.peanut, "rb") as peanut, pysam.Samfile(input.gold, "rb") as gold:
			for i, hit in enumerate(peanut):
				if not i % 1000:
					print("processing hit", i)
				if not hit.is_unmapped:
					ref = peanut.getrname(hit.tid)
					is_fp = not any(
						a.qname == hit.qname
						for a in gold.fetch(
							reference=ref, start=hit.pos, end=hit.pos + 1
						)
					)
					if is_fp:
						FP[hit.mapq] += 1
					P[hit.mapq] += 1
				# NOTE(review): this caps the number of hits, not reads (all_noxa can
				# report several hits per read), and stops after hit index
				# total_reads + 1 - confirm the intended limit.
				if i > total_reads:
					break
		mapqs = sorted(P)
		p = [P[m] for m in mapqs]
		fp = [FP[m] for m in mapqs]
		np.savetxt(output.txt, [mapqs, p, fp])

rule plot_mapq_fpr:
	# Plot the measured false positive rate FP / P per mapping quality next to
	# the rate 10^(-q/10) implied by the Phred-scaled mapping quality q.
	input:
		"performance/evaluation/mapq_fpr.txt"
	output:
		pdf="plots/mapq_fpr.pdf"
	run:
		table = np.loadtxt(input[0])
		mapqs, positives, false_positives = table[0], table[1], table[2]
		expected = [10 ** (-q / 10) for q in mapqs]

		figure()
		plt.plot(mapqs, false_positives / positives, "-", label="measured FPR")
		plt.plot(mapqs, expected, "--", label="expected FPR")
		plt.legend(loc="upper right", handlelength=2.5)
		plt.xlabel("mapping quality")
		plt.ylim([-0.01, 1])
		plt.savefig(output.pdf)


####################### index size ##########################

rule plot_index_size:
	# Plot the index size ratio curves on a logarithmic x axis, with a dashed
	# vertical marker at K = 16/15 (where (K/16 + 2) / (K + 1) equals 1).
	output:
		"plots/index_size.pdf"
	run:
		figure()
		eps = np.linspace(0, 1)
		(_line,) = plt.semilogx(eps, 1 + eps / (16 * (1 + eps)), "-")

		def calc_ratio(k):
			return (k / 16 + 2) / (k + 1)

		ks = np.linspace(1, 100)
		plt.semilogx(ks, calc_ratio(ks), "-")
		plt.xlabel("K")
		plt.ylabel("index size ratio")

		k_break = 16 / 15
		plt.plot([k_break, k_break], [0, calc_ratio(k_break)], "k--")

		plt.savefig(output[0])


###################### run time distribution ##################

rule plot_run_time_dist:
	# Pie chart of PEANUT's run time per phase, summed from the "*duration"
	# rows of the stats file of the strata run.
	input:
		"performance/{dataset}.peanut.align.strata.0.stats"
	output:
		"plots/{dataset}.run_time_dist.pdf"
	run:
		totals = {}
		with open(input[0]) as handle:
			for row in csv.reader(handle, delimiter="\t"):
				if row[0].endswith("duration"):
					# Key by the phase name, the part before the first "_".
					totals[row[0].split("_")[0]] = sum(float(value) for value in row[1:])
		phases = ["index", "filtration", "validation", "postprocessing", "writing"]
		durations = [totals[phase] for phase in phases]
		plt.figure(figsize=(3,2))
		# Only the darker half of the reversed Blues colormap is used.
		colors = plt.cm.Blues_r(np.linspace(0., 1., len(phases) * 2))
		labels = ["indexing"] + phases[1:]
		patches, texts, autotexts = plt.pie(durations, labels=labels, autopct="%1.1f%%", colors=colors)
		for wedge in patches:
			wedge.set_edgecolor("w")
		for pct_text in autotexts:
			pct_text.set_color("w")
		plt.gca().set_aspect("equal")
		plt.savefig(output[0], bbox_inches="tight")


def figure(figsize=None, right=False, top=False):
	"""Create a new figure with one axislines Subplot and return that axes.

	:param figsize: passed on to ``plt.figure``.
	:param right: show the right axis line (hidden by default).
	:param top: show the top axis line (hidden by default).
	:returns: the Subplot added to the new figure (existing callers may ignore it).
	"""
	fig = plt.figure(figsize=figsize)
	ax = Subplot(fig, 111)
	# Previously both lines were always hidden and right/top were ignored;
	# the defaults keep that appearance.
	ax.axis["right"].set_visible(bool(right))
	ax.axis["top"].set_visible(bool(top))
	fig.add_subplot(ax)
	return ax
