#include <stdlib.h>
#include <math.h>
#include <string.h>
#include <strings.h>
#include <err.h>
#include "trellis.h"
#include "probe.h"

/* Number of 2-bit symbols per k-mer; the decoders size each trellis
 * column as 1<<(2*KMERLEN) states. */
#define KMERLEN		4
/* Selects bits 4-5 of a k-mer index. The decoders use (kmer & mask2) >> 4
 * to pull out the second-most-significant 2-bit field of an 8-bit index. */
static int mask2 = 3 << 4;

/* Probability of call given error and reference
 * Entries are P(call | ref, call wrong)
 * call * NCOLOR + ref
 */
/*static double errorProbs[] = {
	0.0000000, 0.3533696, 0.3716659, 0.2749646,
	0.3680965, 0.0000000, 0.3980753, 0.2338282,
	0.3671465, 0.3778348, 0.0000000, 0.2550187, 
	0.5147307, 0.2713846, 0.2138847, 0.0000000};*/

/* Active miscall table, stored flat as a 4x4 matrix and read in this file
 * as errorProbs[call*4 + other]. The diagonal is zero and every row (fixed
 * first index) sums to one. The table is read-only, hence const.
 * NOTE(review): the header comment above describes the entries as
 * P(call | ref, call wrong), yet it is the rows indexed by the call that are
 * normalised -- confirm the intended conditioning.
 */
static const double errorProbs[] = {
	0.0000000, 0.2448456, 0.2843043, 0.4708502,
	0.2716964, 0.0000000, 0.3940012, 0.3343024,
	0.3154797, 0.3936510, 0.0000000, 0.2908693,
	0.3078265, 0.3049675, 0.3872060, 0.0000000};

/** Per-position maximum a posteriori decoding on the k-mer trellis.
 *
 * Builds the forward and backward trellises from probs, then for every
 * position sums fwd*probs*bkd over all k-mers, grouped by the symbol each
 * k-mer implies: bits 4-5 of the k-mer index when basespace is true,
 * probe1_color[kmer] otherwise. The symbol with the largest total is
 * called and the log of its normalised share is stored as its quality.
 *
 * @param seq        Input sequence; only its length and name are read here.
 * @param initBase   Initial base handed to both trellis builders.
 * @param basespace  Selects how a k-mer is mapped to the called symbol.
 * @param probs      Per-position k-mer probabilities (seq->len rows of
 *                   4^KMERLEN doubles).
 * @return Newly created sequence holding calls and log-posterior qualities,
 *         or NULL if an argument is NULL, the sequence is empty, or an
 *         allocation fails.
 */
SEQ map_decoder( SEQ seq, const base initBase, const bool basespace, const double * probs){
	double *fwd=NULL, *bkd=NULL;
	if(NULL==seq){ return NULL; }
	if(NULL==probs){ return NULL; }
	// No trellis column exists for an empty sequence; the builders would
	// otherwise initialise memory they never allocated.
	if(0==seq->len){ return NULL; }

	SEQ baseseq = new_SEQ(seq->len,true,false);
	if(NULL==baseseq){ goto cleanup;}
	baseseq->name = calloc(strlen(seq->name)+1,sizeof(char));
	if(NULL==baseseq->name){ goto cleanup; }
	strcpy(baseseq->name,seq->name);

	fwd = forwards_trellis(probs,initBase,seq->len,KMERLEN);
	if(NULL==fwd){ goto cleanup; }
	bkd = backwards_trellis(probs,initBase,seq->len,KMERLEN);
	if(NULL==bkd){ goto cleanup; }
	const size_t nkmer = 1<<(2*KMERLEN);

	/* Call base */
	for( int i=0 ; i<seq->len ; i++){
		const size_t offset = i*nkmer;
		double postprob[4] = {0.};
		for( size_t j=0 ; j<nkmer ; j++){
			const base c = basespace?(j&mask2)>>4:probe1_color[j];
			// fwd and bkd both exclude the emission at this position,
			// so it is multiplied in exactly once here.
			postprob[c] += fwd[offset+j] * probs[offset+j] * bkd[offset+j];
		}
		base maxb = baseA;
		double sum = postprob[0];
		for(int j=1 ; j<4 ; j++){
			if(postprob[j]>postprob[maxb]){ maxb = j; }
			sum += postprob[j];
		}
		baseseq->base[i] = maxb;
		// Only the winner's normalised log-posterior is kept
		baseseq->qual[i] = log(postprob[maxb]/sum);
	}
	free(bkd);
	free(fwd);
	return baseseq;

cleanup:
	free(bkd);
	free(fwd);
	if(NULL!=baseseq){
		// The name buffer was allocated above and would leak otherwise
		free(baseseq->name);
	}
	// NOTE(review): any buffers new_SEQ allocated inside the record are not
	// released by a plain free -- confirm whether a SEQ destructor should
	// be used here instead.
	free(baseseq);

	return NULL;
}

/** Calculate log-likelihood of a given sequence on the trellis.
 *
 * The state at position i is the 4-mer made of the symbol preceding
 * position i (initB for i==0) followed by the calls at i, i+1 and i+2.
 * Calls past the end of the sequence are taken as 0 (the "A" padding),
 * so sequences shorter than three calls are handled without reading
 * outside seq->base.
 *
 * @param seq    Sequence whose calls select one state per position.
 * @param initB  Symbol assumed to precede the first call.
 * @param probs  Per-position k-mer probabilities, 256 doubles per position.
 * @return Sum over positions of log(probs[pos*256 + state]); NAN if seq or
 *         probs is NULL.
 */
double sequence_likelihood(const SEQ seq, const base initB, const double * probs){
   if(NULL==seq){ return NAN; }
   if(NULL==probs){ return NAN; }

   const int nkmer = 256;
   const int mask = 255;
   const int lookahead = 3;

   // Initial k-mer: initB followed by the first three calls (0 where absent)
   int kmer = initB;
   for( int i=0 ; i<lookahead ; i++){
      const int b = (i<seq->len) ? seq->base[i] : 0;
      kmer = (kmer<<2) + b;
   }

   double loglike=0.;
   for( int i=0 ; i<seq->len ; i++){
      loglike += log(probs[i*nkmer+kmer]);
      // Shift in the call three positions ahead, padding with 0 at the tail
      const int next = (i+lookahead<seq->len) ? seq->base[i+lookahead] : 0;
      kmer = ((kmer<<2)&mask) + next;
   }

   return loglike;
}

/** Decode the single most probable path through the k-mer trellis.
 *
 * Builds the Viterbi trellis, picks the highest-scoring state in the last
 * column and walks backwards: at each step the predecessor is the best of
 * the four states in the previous column that share the current k-mer's
 * leading symbols. Each visited state is converted to a symbol through
 * bits 4-5 of its index (basespace) or probe1_color[] (otherwise).
 *
 * @param seq        Input sequence; only its length and name are read here.
 * @param initB      Initial base; replaced by baseA when basespace is false.
 * @param basespace  Selects how a k-mer is mapped to the emitted symbol.
 * @param probs      Per-position k-mer probabilities.
 * @return Newly created sequence of calls, or NULL if an argument is NULL,
 *         the sequence is empty, or an allocation fails.
 */
SEQ viterbi_decoder(SEQ seq, const base initB, const bool basespace, const double * probs){
	double *vit=NULL;
	if(NULL==probs){ return NULL;}
	if(NULL==seq){ return NULL; }
	// An empty sequence has no last column to start the traceback from
	if(0==seq->len){ return NULL; }

	SEQ baseseq = new_SEQ(seq->len,false,false);
	if(NULL==baseseq){ goto cleanup;}
	baseseq->name = calloc(strlen(seq->name)+1,sizeof(char));
	if(NULL==baseseq->name){ goto cleanup; }
	strcpy(baseseq->name,seq->name);

	vit = viterbi_trellis(probs,basespace?initB:baseA,seq->len,KMERLEN);
	if(NULL==vit){ goto cleanup; }

	// Decode trellis: start from the best state in the final column
	const size_t nkmer = 1<<(2*KMERLEN);
	const int stride = nkmer/4;	// distance between k-mers differing only in their first symbol
	const size_t last = seq->len * nkmer - nkmer;
	int kmer = 0;
	double best = 0.;
	for( size_t i=0 ; i<nkmer ; i++){
		if(vit[last+i]>best){
			best = vit[last+i];
			kmer = (int)i;
		}
	}

	for(int pos=seq->len-1 ; pos>0 ; pos--){
		baseseq->base[pos] = basespace?(kmer&mask2)>>4:probe1_color[kmer];
		// Predecessors share the current k-mer's leading symbols and differ
		// only in the symbol that was shifted out.
		const int pkmer = (kmer>>2)&(stride-1);
		const double * prev = vit + (pos-1)*nkmer;
		// Start from the first candidate so the index is always valid,
		// even if a score is NaN.
		int prevk = pkmer;
		for( int i=1 ; i<4 ; i++){
			const int cand = pkmer + i*stride;
			if(prev[cand]>prev[prevk]){ prevk = cand; }
		}
		kmer = prevk;
	}
	baseseq->base[0] = basespace?(kmer&mask2)>>4:probe1_color[kmer];

	free(vit);
	return baseseq;
cleanup:
	free(vit);
	if(NULL!=baseseq){
		// The name buffer was allocated above and would leak otherwise
		free(baseseq->name);
	}
	// NOTE(review): any buffers new_SEQ allocated inside the record are not
	// released by a plain free -- confirm whether a SEQ destructor should
	// be used here instead.
	free(baseseq);
	return NULL;
}

double * construct_trellis_probs(SEQ seq, const bool useEcc){
	if(NULL==seq){ return NULL; }
	if(!has_ecc(seq)){ return NULL; }
	if(!has_qualities(seq)){ return NULL;}

	size_t nkmer = 1 << 8; // = 4^4
	double * trellis_probs = calloc(seq->len*nkmer,sizeof(double));
	if(NULL==trellis_probs){ return NULL; }

	for( int i=0 ; i<seq->len ; i++){
		// Missing data => all probs are one
		double cprobs[4] = {1.,1.,1.,1.};
		double eprobs[4] = {1.,1.,1.,1.};
		if(baseN!=seq->base[i]){
			double p = exp(seq->qual[i]);
			for( int j=0 ; j<4 ; j++){
				cprobs[j] = (1.-p) * errorProbs[seq->base[i]*4+j];
			}
			cprobs[seq->base[i]] = p;
		}
		if(useEcc && baseN!=seq->ebase[i]){
			double p = (baseN!=seq->ebase[i])?exp(seq->equal[i]):1.;
			for( int j=0 ; j<4 ; j++){
				eprobs[j] = (1.-p) * errorProbs[seq->ebase[i]*4+j];;
			}
			eprobs[seq->ebase[i]] = p;
		}

		// kmer probs, assuming independence of probe sets
		for(int kmer=0 ; kmer<nkmer ; kmer++){
			trellis_probs[ i*nkmer + kmer ] = cprobs[probe1_color[kmer]] * eprobs[probe2_color[kmer]];
		}
	}
	return trellis_probs;
}

/** Advance the Viterbi trellis by one column.
 *
 * A k-mer in the new column with leading symbols `pref` can be reached from
 * the four previous-column k-mers pref + i*(nkmer/4), i = 0..3, i.e. those
 * that differ only in their first symbol. The best of those scores is
 * multiplied by the new column's probability for each of the four k-mers
 * (pref<<2)+0..3. The predecessor stride is derived from nkmer so the
 * routine is valid for any k-mer length, not just 4.
 *
 * @param prev_vit  Previous column, nkmer scores.
 * @param prob      Probabilities for the column being filled, nkmer entries.
 * @param nkmer     Number of k-mers per column.
 * @param vit       Output column, nkmer entries (overwritten).
 *
 * Does nothing if any pointer is NULL.
 */
void viterbi_trellis_sub( double * prev_vit, const double * prob, const size_t nkmer, double * vit ){
	if(NULL==prev_vit){ return; }
	if(NULL==prob){ return; }
	if(NULL==vit){ return; }

	const size_t stride = nkmer/4;
	memset(vit,0,nkmer*sizeof(double));
	for ( size_t pref=0 ; pref<stride ; pref++){
		// Scores are non-negative, so -1 guarantees the first one is taken
		double maxv = -1;
		for( size_t i=0 ; i<4 ; i++){
			const double v = prev_vit[pref + stride*i];
			if(v>maxv){ maxv = v; }
		}
		const size_t kmer_off = pref<<2;
		for ( size_t i=0 ; i<4 ; i++){
			vit[kmer_off+i] = maxv*prob[kmer_off+i];
		}
	}
}

/** Build the full Viterbi trellis for a sequence.
 *
 * The first column is seeded from trell_probs: only k-mers whose leading
 * symbol is initB are enabled, except that initB==baseA enables every
 * k-mer. Each later column is filled by viterbi_trellis_sub() from its
 * predecessor.
 *
 * @param trell_probs  Per-position k-mer probabilities, seqlen columns.
 * @param initB        Required leading symbol of the first k-mer
 *                     (baseA leaves the first column unconstrained).
 * @param seqlen       Number of columns.
 * @param kmerlen      Symbols per k-mer; a column holds 4^kmerlen entries.
 * @return calloc'd array of seqlen * 4^kmerlen scores, to be released with
 *         free(); NULL if trell_probs is NULL, seqlen or kmerlen is zero,
 *         or allocation fails.
 */
double * viterbi_trellis(const double * trell_probs, const base initB, const size_t seqlen, const size_t kmerlen){
	if(NULL==trell_probs){ return NULL; }
	// With no columns or no symbols there is nothing to seed; going on
	// would write into an empty allocation or shift by a wrapped amount.
	if(0==seqlen || 0==kmerlen){ return NULL; }
	const size_t nkmer = (size_t)1<<(2*kmerlen);
	// Let calloc detect overflow of the column count times column size
	double * vit = calloc(seqlen,nkmer*sizeof(double));
	if(NULL==vit){return NULL;}

	// Initialise - must begin with initB, unless initB is baseA in which
	// case every k-mer is allowed.
	size_t first = initB <<(2*kmerlen-2);
	size_t count = nkmer/4;
	if(baseA==initB){
		first = 0;
		count = nkmer;
	}
	for ( size_t i=0 ; i<count ; i++){
		vit[first+i] = trell_probs[first+i];
	}

	// Move along trellis
	for ( size_t i=1 ; i<seqlen ; i++){
		viterbi_trellis_sub(vit+i*nkmer-nkmer,trell_probs+i*nkmer,nkmer,vit+i*nkmer);
	}

	return vit;
}

/** Advance the forward trellis by one column.
 *
 * Every previous k-mer contributes prev_fwd*prev_prob to each of the four
 * k-mers obtained by shifting it left one symbol and appending 0..3. The
 * new column is then divided by its largest entry to keep values in range;
 * a column that is entirely zero is left as zeros instead of being divided
 * by zero (which would turn it into NaN).
 *
 * @param prev_fwd   Previous forward column, nkmer entries.
 * @param prev_prob  Probabilities belonging to the previous column.
 * @param nkmer      Number of k-mers per column.
 * @param fwd        Output column, nkmer entries (overwritten).
 *
 * Does nothing if any pointer is NULL.
 */
void forwards_trellis_sub ( const double * prev_fwd, const double * prev_prob, const size_t nkmer, double * fwd ){
	if(NULL==prev_fwd){ return; }
	if(NULL==prev_prob){ return; }
	if(NULL==fwd){ return;}

	const size_t mask = nkmer-1;
	memset(fwd,0,nkmer*sizeof(double));
	for ( size_t prevK=0 ; prevK<nkmer ; prevK++){
		const double d = prev_fwd[prevK] * prev_prob[prevK];
		// Drop the leading symbol; the four successors fill the freed slot
		const size_t prefK = (prevK<<2)&mask;
		for ( size_t i=0 ; i<4 ; i++){
			fwd[prefK+i] += d;
		}
	}

	// Rescale so the largest entry is one
	double ma=0.;
	for ( size_t kmer=0 ; kmer<nkmer ; kmer++){
		if(fwd[kmer]>ma){ ma = fwd[kmer]; }
	}
	if(ma>0.){
		for ( size_t kmer=0 ; kmer<nkmer ; kmer++){
			fwd[kmer] /= ma;
		}
	}
}

/** Build the (rescaled) forward trellis for a sequence.
 *
 * The first column is one for every k-mer whose leading symbol is initB
 * and zero elsewhere; each later column is produced from the previous
 * column and the previous column's probabilities by forwards_trellis_sub().
 * NOTE(review): unlike viterbi_trellis(), initB==baseA is not treated as
 * "unconstrained" here -- confirm that the difference is intended.
 *
 * @param trell_probs  Per-position k-mer probabilities, seqlen columns.
 * @param initB        Required leading symbol of the first k-mer.
 * @param seqlen       Number of columns.
 * @param kmerlen      Symbols per k-mer; a column holds 4^kmerlen entries.
 * @return calloc'd array of seqlen * 4^kmerlen values, to be released with
 *         free(); NULL if trell_probs is NULL, seqlen or kmerlen is zero,
 *         or allocation fails.
 */
double * forwards_trellis ( const double * trell_probs, const base initB, const size_t seqlen, const size_t kmerlen){
	if(NULL==trell_probs){ return NULL; }
	// With no columns or no symbols there is nothing to seed; going on
	// would write into an empty allocation or shift by a wrapped amount.
	if(0==seqlen || 0==kmerlen){ return NULL; }
	const size_t nkmer = (size_t)1<<(2*kmerlen);
	// Let calloc detect overflow of the column count times column size
	double * fwd = calloc(seqlen,nkmer*sizeof(double));
	if(NULL==fwd){ return NULL; }

	// Initialise - must begin with initB
	const size_t offset = initB <<(2*kmerlen-2);
	for ( size_t i=0 ; i<(nkmer/4) ; i++){
		fwd[offset+i] = 1.;
	}

	// Move along trellis
	for ( size_t i=1 ; i<seqlen ; i++){
		forwards_trellis_sub(fwd+i*nkmer-nkmer,trell_probs+i*nkmer-nkmer,nkmer,fwd+i*nkmer);
	}
	return fwd;
}


/** Step the backward trellis one column towards the start.
 *
 * For every k-mer in the column being filled, sums prev_bkd*prev_prob over
 * its four successors in the following column (the k-mer shifted left one
 * symbol with 0..3 appended). The column is then divided by its largest
 * entry to keep values in range; a column that is entirely zero is left as
 * zeros instead of being divided by zero (which would turn it into NaN).
 *
 * @param prev_bkd   Backward column of the following position, nkmer entries.
 * @param prev_prob  Probabilities of the following position, nkmer entries.
 * @param nkmer      Number of k-mers per column.
 * @param bkd        Output column, nkmer entries (overwritten).
 *
 * Does nothing if any pointer is NULL.
 */
void backwards_trellis_sub ( const double * prev_bkd, const double * prev_prob, const size_t nkmer, double * bkd ){
	if(NULL==prev_bkd){ return; }
	if(NULL==prev_prob){ return; }
	if(NULL==bkd){ return;}

	const size_t mask = nkmer-1;
	memset(bkd,0,nkmer*sizeof(double));
	for ( size_t curK=0 ; curK<nkmer ; curK++){
		// First of the four successor k-mers in the following column
		const size_t succK = (curK<<2)&mask;
		for ( size_t i=0 ; i<4 ; i++){
			bkd[curK] += prev_bkd[succK+i]*prev_prob[succK+i];
		}
	}

	// Rescale so the largest entry is one
	double ma=0.;
	for ( size_t kmer=0 ; kmer<nkmer ; kmer++){
		if(bkd[kmer]>ma){ ma = bkd[kmer]; }
	}
	if(ma>0.){
		for ( size_t kmer=0 ; kmer<nkmer ; kmer++){
			bkd[kmer] /= ma;
		}
	}
}

/** Build the (rescaled) backward trellis for a sequence.
 *
 * The last column is set to one for every k-mer; each earlier column is
 * produced from the following column and that column's probabilities by
 * backwards_trellis_sub().
 *
 * @param trell_probs  Per-position k-mer probabilities, seqlen columns.
 * @param initB        Not used by this function; kept so the signature
 *                     matches the other trellis builders.
 * @param seqlen       Number of columns.
 * @param kmerlen      Symbols per k-mer; a column holds 4^kmerlen entries.
 * @return calloc'd array of seqlen * 4^kmerlen values, to be released with
 *         free(); NULL if trell_probs is NULL, seqlen is zero, or
 *         allocation fails.
 */
double * backwards_trellis ( const double * trell_probs, const base initB, const size_t seqlen, const size_t kmerlen){
	(void)initB;
	if(NULL==trell_probs){ return NULL; }
	// Without columns there is no last column to initialise; the index
	// arithmetic below would wrap around and write out of bounds.
	if(0==seqlen){ return NULL; }
	const size_t nkmer = (size_t)1<<(2*kmerlen);
	// Let calloc detect overflow of the column count times column size
	double * bkd = calloc(seqlen,nkmer*sizeof(double));
	if(NULL==bkd){ return NULL; }

	// Initialise. All kmers equally likely
	double * last = bkd + (seqlen-1)*nkmer;
	for ( size_t k=0 ; k<nkmer ; k++){
		last[k] = 1.;
	}

	// Move along trellis
	for ( size_t i=seqlen-1 ; i>0 ; i--){
		backwards_trellis_sub(bkd+i*nkmer,trell_probs+i*nkmer,nkmer,bkd+i*nkmer-nkmer);
	}
	return bkd;
}

double * check_likelihoods(const double * fwd, const double * bwd, const double * trell_probs, const size_t seqlen, const size_t nkmer){
	if(NULL==fwd){ return NULL; }
	if(NULL==bwd){ return NULL; }
	if(NULL==trell_probs){ return NULL; }

	double * like = calloc(seqlen,sizeof(double));
	if(NULL==like){ return NULL; }

	for ( int i=0 ; i<seqlen ; i++){
		like[i] = 0.f;
		for ( int j=0 ; j<nkmer ; j++){
			like[i] += fwd[j] * bwd[j] * trell_probs[j];
		}
		like[i] = log(like[i]);
	}

	return like;
}

