/*
 * File:        cric.cpp 
 * Author:      Bram Kuijvenhoven (bkuijvenhoven@student.tudelft.nl)
 * Date:        2005/04/11 [yyyy/mm/dd]
 * Description: Competing Risks Interval Censoring solver
 */

#include <iostream>
#include "cric.h"

// SQR(x): square of x, written as ((x)*(x)).
// Being a macro, the argument expression is evaluated twice, so do not pass
// expressions with side effects.
#define SQR(x) ((x)*(x))

/* reading/writing TCRICSamples
 */

/* ReadCRICSamples
 *
 * Reads a sample set from a text stream in the format
 *   sampleSize numCauses
 *   t1 t2 k1 k2      (repeated sampleSize times)
 *
 * On return, numCauses holds the value read from the header, or 0 if it
 * could not be read. sampleSize holds the number of sample records that
 * were actually read successfully. It is lowered when the stream fails or
 * ends early, and set to 0 when the header is missing or negative. As a
 * result, every element in [0, sampleSize) of the returned array is
 * initialised. A warning is written to cerr in both failure cases.
 *
 * The caller owns the returned array and must free it with delete[].
 */
TCRICSample *ReadCRICSamples(istream &is, int &sampleSize, int &numCauses) {

	sampleSize = 0;
	numCauses  = 0;

	int declaredSize = 0;
	if (!(is >> declaredSize >> numCauses) || declaredSize < 0) {
		cerr << "ReadCRICSamples: invalid or missing header" << endl;
		return new TCRICSample[0];
	}

	TCRICSample *samples = new TCRICSample[declaredSize];
	int numRead = 0;
	// Stop at the first record that cannot be parsed completely, so we never
	// report elements whose fields were left unassigned by a failed stream.
	while (numRead < declaredSize &&
	       (is >> samples[numRead].t1 >> samples[numRead].t2
	           >> samples[numRead].k1 >> samples[numRead].k2)) {
		numRead++;
	}
	if (numRead < declaredSize) {
		cerr << "ReadCRICSamples: expected " << declaredSize
		     << " samples, but only " << numRead << " could be read" << endl;
	}
	sampleSize = numRead;
	return samples;

}

/* WriteCRICSamples
 *
 * Emits a sample set in the text layout that ReadCRICSamples parses:
 * the sample size and the number of causes on their own lines, then one
 * line per sample containing t1, t2, k1 and k2. Each of these four values
 * is preceded by a single space and right-aligned in a field of width 3.
 */
void WriteCRICSamples(ostream &os, TCRICSample *samples, int sampleSize, int numCauses) {

	os << sampleSize << endl
	   << numCauses  << endl;
	for (int idx = 0; idx < sampleSize; ++idx) {
		const TCRICSample &s = samples[idx];
		os << " " << setw(3) << s.t1;
		os << " " << setw(3) << s.t2;
		os << " " << setw(3) << s.k1;
		os << " " << setw(3) << s.k2;
		os << endl;
	}

}

// A single observation time taken from the raw sample data.
// CCRICSolver::SetSample collects these, sorts them by t and uses them to
// build the array of unique observation times.
struct TSampleTime {
	double t;      // the time value (copied from a sample's t1 or t2)
	int    sample; // index of the originating entry in the sample array
	int    which;  // 0 if t came from t1, 1 if t came from t2
	int    unique; // index of t in the fTime array of unique times
};

// A sample after its times have been replaced by positions in the array of
// unique observation times (built in CCRICSolver::SetSample).
struct TProcessedSample {
	int k1, k2; // k1 and k2 copied from the original TCRICSample (1-based causes)
	int u1, u2; // indices into the fTime array of unique times
};

/* Helper output routines
 */
/* Diagnostic dump of processed samples.
 *
 * Prints the header line "  u1  u2  k1  k2". Each record then follows on
 * its own line, with u1, u2, k1 and k2 each preceded by a single space and
 * right-aligned in a field of width 3.
 */
void WriteProcessedSamples(ostream &os, TProcessedSample *processedSamples, int processedSampleSize) {

	os << "  u1  u2  k1  k2" << endl;
	for (int row = 0; row < processedSampleSize; ++row) {
		const TProcessedSample &ps = processedSamples[row];
		os << " " << setw(3) << ps.u1 << " " << setw(3) << ps.u2
		   << " " << setw(3) << ps.k1 << " " << setw(3) << ps.k2
		   << endl;
	}

}

/* Compare procs
 */
int CompareSampleTime(const void *a, const void *b) {
	if (((TSampleTime *) a)->t < ((TSampleTime *) b)->t) return -1;
	if (((TSampleTime *) a)->t > ((TSampleTime *) b)->t) return 1;
	return 0;
}

int CompareProcessedSample(const void *a, const void *b) {
	// first sort on u1
	if (((TProcessedSample *) a)->u1 < ((TProcessedSample *) b)->u1) return -1;
	if (((TProcessedSample *) a)->u1 > ((TProcessedSample *) b)->u1) return 1;
	// then on u2
	if (((TProcessedSample *) a)->u2 < ((TProcessedSample *) b)->u2) return -1;
	if (((TProcessedSample *) a)->u2 > ((TProcessedSample *) b)->u2) return 1;
	// then on k2
	if (((TProcessedSample *) a)->k2 < ((TProcessedSample *) b)->k2) return -1;
	if (((TProcessedSample *) a)->k2 > ((TProcessedSample *) b)->k2) return 1;
	// equal
	return 0;
}


void CCRICSolver::SetSample(TCRICSample *sample, int sampleSize, int numCauses) {

	/*
	 * This routine converts the sample data to fTime, fNi, fNki, fNkij and
	 * the derived index tables.
	 *
	 * - all t1/t2 values occurring in the sample data (t1 if k1 >= 0, t2 if
	 *   k2 >= 0) are put in the local time array
	 * - the time array is sorted
	 * - the fTime array is constructed, which stores the unique observation times
	 * - the processedSample array is filled (times replaced by fTime indices,
	 *   -1 for a time that is not present)
	 * - the processedSample array is sorted on (u1, u2, k2)
	 * - the fNiSize, fNkiSize, fNkijSize variables/arrays are filled
	 * - the fNi, fNki, fNkij arrays are filled; consecutive identical
	 *   (u1, u2, k2) records are merged into one entry with count n
	 * - fNeedLagrangian, fLambda and fEpsilon are set
	 * - the fDistIndex and fDistSize arrays are calculated
	 * - the tables are written to *fOutput
	 *
	 * A processed sample is classified as:
	 * - u1 == -1 : fNki[k2-1]  entry with i = u2
	 * - u2 == -1 : fNi         entry with i = u1
	 * - otherwise: fNkij[k2-1] entry with i = u1, j = u2
	 * NOTE(review): k2 is used as a 1-based cause index without range checks.
	 * Every sample with k2 >= 0 must satisfy 1 <= k2 <= numCauses. A sample
	 * with both k1 < 0 and k2 < 0 lands in the first case with a negative k2
	 * and indexes out of bounds.
	 *
	 * Only the local time and processedSample arrays are freed here. The
	 * member arrays allocated below stay with the solver.
	 */

	cerr << "CCRICsolver::SetSample(..) entered" << endl;
	
	int i;
	int k;
	int t;

	TSampleTime *time; 
	int          timeSize = 0;

	TProcessedSample *processedSample;
	int               processedSampleSize = 0;

	// set fNumDists
	fNumDists = numCauses;

	cerr << "Determining timeSize" << endl;

	// determine timeSize: one entry per present t1 (k1 >= 0) and per present t2 (k2 >= 0)
	for (i = 0; i < sampleSize; i++) {
		if (sample[i].k1 >= 0) timeSize++;
		if (sample[i].k2 >= 0) timeSize++;
	}

	cerr << "Filling time array" << endl;

	// fill time array
	time = new TSampleTime[timeSize];
	t = 0;
	for (i = 0; i < sampleSize; i++) {
		if (sample[i].k1 >= 0) { time[t].t = sample[i].t1; time[t].which = 0; time[t].sample = i; t++; }
		if (sample[i].k2 >= 0) { time[t].t = sample[i].t2; time[t].which = 1; time[t].sample = i; t++; }
	}

	cerr << "Sorting time array" << endl;
	
	// sort time array
	qsort(time, timeSize, sizeof(TSampleTime), CompareSampleTime);

	cerr << "Determining fTimeSize" << endl;
	
	// determine fTimeSize (number of unique times)
	double lastTime = 0.0; // note: actually no init is required here (i == 0 always starts a new time)
	fTimeSize = 0;
	for (i = 0; i < timeSize; i++) {
		if (time[i].t != lastTime || i == 0) {
			lastTime = time[i].t;
			fTimeSize++;
		}
	}

	cerr << "Filling fTime array" << endl;
	
	// fill fTime array (unique times) and record each time's unique index
	fTime = new double[fTimeSize];
	t = 0;
	for (i = 0; i < timeSize; i++) {
		if (time[i].t != lastTime || i == 0) {
			lastTime = time[i].t;
			fTime[t] = lastTime;
			t++;
		}
		time[i].unique = t-1;
	}

	cerr << "Filling processedSample array" << endl;

	processedSampleSize = sampleSize;
	processedSample = new TProcessedSample[processedSampleSize];
	// set k1, k2 members
	for (i = 0; i < sampleSize; i++) {
		processedSample[i].k1 = sample[i].k1;
		processedSample[i].k2 = sample[i].k2;
		processedSample[i].u1 = -1; // for sanity initialize u1 and u2 to an invalid value; -1 also marks "time not present"
		processedSample[i].u2 = -1;
	}
	// set u1, u2 members
	for (i = 0; i < timeSize; i++) {
		if (time[i].which) { // time[i].which == 1, indicating a t2
			processedSample[time[i].sample].u2 = time[i].unique;
		} else { // time[i].which == 0, indicating a t1
			processedSample[time[i].sample].u1 = time[i].unique;
		}
	}

	cerr << "Sorting processedSample array" << endl;

	qsort(processedSample, processedSampleSize, sizeof(TProcessedSample), CompareProcessedSample);

	//WriteProcessedSamples(cerr, processedSample, processedSampleSize);

	cerr << "Determining fNiSize, fNkiSize, fNkijSize" << endl;

	// determine fNiSize, fNkiSize, fNkijSize (number of distinct (u1, u2, k2) groups per table)
	fNiSize = 0;
	fNkiSize  = new int[numCauses];
	fNkijSize = new int[numCauses];
	for (k = 0; k < numCauses; k++) {
		fNkiSize[k]  = 0;
		fNkijSize[k] = 0;
	}
	int lastU1 = 0, lastU2 = 0, lastK2 = 0;
	for (i = 0; i < processedSampleSize; i++) {
		if (processedSample[i].k2 != lastK2 || processedSample[i].u2 != lastU2 || processedSample[i].u1 != lastU1 || i == 0) {
			// new unique sample
			lastU1 = processedSample[i].u1;
			lastU2 = processedSample[i].u2;
			lastK2 = processedSample[i].k2;
			// update Size info
			if (lastU1 == -1) {
				fNkiSize[lastK2-1]++;
			} else if (lastU2 == -1) {
				fNiSize++;
			} else {
				fNkijSize[lastK2-1]++;
			}
		}
	}

	cerr << "Allocating fNi, fNki, fNkij arrays" << endl;
	
	// allocate fNi, fNki, fNkij arrays
	fNi   = new TNi[fNiSize];
	fNki  = new TNi*[numCauses];
	fNkij = new TNij*[numCauses];
	for (k = 0; k < numCauses; k++) {
		fNki[k]  = new TNi[fNkiSize[k]];
		fNkij[k] = new TNij[fNkijSize[k]];
	}

	cerr << "Filling fNi, fNki, fNkij arrays" << endl;

	// fill fNi, fNki, fNkij arrays
	// The size counters are reused as insertion cursors. Reset the existing
	// fNkiSize/fNkijSize arrays rather than allocating new ones, which would
	// leak the arrays allocated above.
	fNiSize = 0;
	for (k = 0; k < numCauses; k++) {
		fNkiSize[k]  = 0;
		fNkijSize[k] = 0;
	}
	for (i = 0; i < processedSampleSize; i++) {
		if (processedSample[i].k2 != lastK2 || processedSample[i].u2 != lastU2 || processedSample[i].u1 != lastU1 || i == 0) {
			// new unique sample
			lastU1 = processedSample[i].u1;
			lastU2 = processedSample[i].u2;
			lastK2 = processedSample[i].k2;
			// update Size and add entry
			if (lastU1 == -1) {
				fNki[lastK2-1][fNkiSize[lastK2-1]].i = lastU2;
				fNki[lastK2-1][fNkiSize[lastK2-1]].n = 1;
				fNkiSize[lastK2-1]++;
			} else if (lastU2 == -1) {
				fNi[fNiSize].i = lastU1;
				fNi[fNiSize].n = 1;
				fNiSize++;
			} else {
				fNkij[lastK2-1][fNkijSize[lastK2-1]].i = lastU1;
				fNkij[lastK2-1][fNkijSize[lastK2-1]].j = lastU2;
				fNkij[lastK2-1][fNkijSize[lastK2-1]].n = 1;
				fNkijSize[lastK2-1]++;
			}
		} else {
			// repeating the same sample as previous iteration: bump its count
			if (lastU1 == -1) {
				fNki[lastK2-1][fNkiSize[lastK2-1]-1].n++;
			} else if (lastU2 == -1) {
				fNi[fNiSize-1].n++;
			} else {
				fNkij[lastK2-1][fNkijSize[lastK2-1]-1].n++;
			}
		}
	}

	// calculate fNeedLagrangian
	// fNi is filled in processedSample order, which is sorted on u1 first, so
	// its last entry holds the largest right-censoring time index.
	// fNeedLagrangian is false only if some sample is right-censored at the
	// last unique time. With no right-censored samples (fNiSize == 0) it is
	// true, and fNi[fNiSize-1] must not be read.
	fNeedLagrangian = (fNiSize == 0) || (fNi[fNiSize-1].i < fTimeSize-1);
	fLambda = sampleSize;
	fEpsilon = fAccuracy*sampleSize;

	cerr << "fNeedLagrangian: " << fNeedLagrangian << endl;

	cerr << "Calculating fDistIndex, fDistSize" << endl;
	
	// calculate fDistIndex, fDistSize
	// fDistIndex[k][u] is the position of unique time u within dist[k], or -1
	// if time u is not a support point of sub-distribution k.
	fDistIndex = new int *[numCauses];
	fDistSize = new int[numCauses];
	for (k = 0; k < numCauses; k++) {
		// alloc and init fDistIndex[k]
		fDistIndex[k] = new int[fTimeSize];
		for (i = 0; i < fTimeSize; i++) {
			fDistIndex[k][i] = -1;
		}
		// set entries in fDistIndex[k] that need to become an index to 0
		for (i = 0; i < fNiSize; i++) {
			fDistIndex[k][fNi[i].i] = 0;
		}
		for (i = 0; i < fNkiSize[k]; i++) {
			fDistIndex[k][fNki[k][i].i] = 0;
		}
		for (i = 0; i < fNkijSize[k]; i++) {
			fDistIndex[k][fNkij[k][i].i] = 0;
			fDistIndex[k][fNkij[k][i].j] = 0;
		}
		// calc fDistSize while enumerating the entries in fDistIndex that are 0
		fDistSize[k] = 0;
		for (i = 0; i < fTimeSize; i++) {
			if (fDistIndex[k][i] == 0) {
				fDistIndex[k][i] = fDistSize[k];
				fDistSize[k]++;
			}
		}
	}

	cerr << "Writing tables to fOutput" << endl;

	// write tables to fOutput: K, T, distIndex (indices written 1-based)
	*fOutput << "K: " << fNumDists << endl;
	*fOutput << "T: " << endl;
	for (i = 0; i < fTimeSize; i++)
		*fOutput << setw(6) << fTime[i] << endl;
	for (k = 0; k < fNumDists; k++) {
		*fOutput << "distIndex{" << (k+1) << "}:" << endl;
		for (i = 0; i < fTimeSize; i++)
			*fOutput << setw(6) << (fDistIndex[k][i]+1) << endl;
	}
	// FullOutput tables
	if (fFullOutput) {
		*fOutput << "needLagrangian: " << fNeedLagrangian << endl;
		*fOutput << "Ni::" << endl;
		*fOutput << setw(6) << "i" << setw(6) << "n" << endl;
		for (i = 0; i < fNiSize; i++)
			*fOutput << " " << setw(5) << (fNi[i].i+1) << " " << setw(5) << fNi[i].n << endl;
		for (k = 0; k < fNumDists; k++) {
			*fOutput << "Nki(" << (k+1) << ")::" << endl;
			*fOutput << setw(6) << "i" << setw(6) << "n" << endl;
			for (i = 0; i < fNkiSize[k]; i++)
				*fOutput << " " << setw(5) << (fNki[k][i].i+1) << " " << setw(5) << fNki[k][i].n << endl;
			*fOutput << "Nkij(" << (k+1) << ")::" << endl;
			*fOutput << setw(6) << "i" << setw(6) << "j" << setw(6) << "n" << endl;
			for (i = 0; i < fNkijSize[k]; i++)
				*fOutput << " " << setw(5) << (fNkij[k][i].i+1) << " " << setw(5) << (fNkij[k][i].j+1) << " " << setw(5) << fNkij[k][i].n << endl;
		}
	}

	cerr << "Freeing local arrays" << endl;

	// free local arrays (allocated with new[], so they must be released with delete[])
	delete[] time;
	delete[] processedSample;

	cerr << "CCRICsolver::SetSample(..) left" << endl;
}

double CCRICSolver::Phi(double **dist) {
	int k, i; // counter variables
	double phi = 0.0; // result variable
	for (k = 0; k < fNumDists; k++) {
		// Nki terms
		for (i = 0; i < fNkiSize[k]; i++) {
			TNi *Ni = &fNki[k][i];
			phi -= Ni->n*log(dist[k][fDistIndex[k][Ni->i]]);
		}
		// Nkij terms
		for (i = 0; i < fNkijSize[k]; i++) {
			TNij *Nij = &fNkij[k][i];
			phi -= Nij->n*log(dist[k][fDistIndex[k][Nij->j]]-dist[k][fDistIndex[k][Nij->i]]);
		}
	}
	// Ni terms
	double sumStart;
	if (fNeedLagrangian) {
		// calc F_+p, and store it in sumStart
		sumStart = 0.0;
		for (k = 0; k < fNumDists; k++) {
			if (fDistSize[k] > 0)
			  sumStart += dist[k][fDistSize[k]-1]; // take the last one in dist[k] as F_kp
		}
	} else {
		// take sumStart to be 1
		sumStart = 1.0;
	}
	// now, each time substract F_+i from sumStart
	for (i = 0; i < fNiSize; i++) {
		TNi *Ni = &fNi[i];
		// calc 1 - F_+i
		double sum = sumStart;
		for (k = 0; k < fNumDists; k++) {
			sum -= dist[k][fDistIndex[k][Ni->i]];
		}
		phi -= Ni->n*log(sum);
	}
	if (fNeedLagrangian)
		phi += fLambda*sumStart; // sumStart is F_+p here
	return phi;
}

/* GradPhi: writes the gradient of Phi with respect to every entry of dist
 * into grad. grad[k] must have room for fDistSize[k] values. It is
 * overwritten, not accumulated into.
 *
 * Writing F_k(i) for dist[k][fDistIndex[k][i]] and F_+(i) = sum_k F_k(i):
 *  - Nki : -n log F_k(i)            -> d/dF_k(i) += -n / F_k(i)
 *  - Nkij: -n log(F_k(j) - F_k(i))  -> d/dF_k(i) += n / (F_k(j) - F_k(i)),
 *                                      d/dF_k(j) -= the same amount
 *  - Ni  : -n log(S - F_+(i))       -> d/dF_k(i) += n / (S - F_+(i)) for every k
 *  - If fNeedLagrangian, S = F_+p (the last entry of each non-empty dist[k])
 *    and phi also contains fLambda * S. Each such last entry then receives
 *    fLambda - sum_{Ni} n / (S - F_+(i)).
 */
void CCRICSolver::GradPhi(double **dist, double **grad) {
	int k, i; // counter variables
	// initialise grad to all zeros
	for (k = 0; k < fNumDists; k++)
		for (i = 0; i < fDistSize[k]; i++)
			grad[k][i] = 0.0;
	// add the per-cause terms to the correct places in grad
	for (k = 0; k < fNumDists; k++) {
		// Nki terms
		for (i = 0; i < fNkiSize[k]; i++) {
			TNi *Ni = &fNki[k][i];
			int distI = fDistIndex[k][Ni->i];
			grad[k][distI] -= Ni->n/dist[k][distI];
		}
		// Nkij terms
		for (i = 0; i < fNkijSize[k]; i++) {
			TNij *Nij = &fNkij[k][i];
			int distI = fDistIndex[k][Nij->i];
			int distJ = fDistIndex[k][Nij->j];
			double delta = Nij->n/(dist[k][distJ]-dist[k][distI]);
			grad[k][distI] += delta;
			grad[k][distJ] -= delta;
		}
	}
	// Ni terms
	double sumStart;
	if (fNeedLagrangian) {
		// calc F_+p, and store it in sumStart
		sumStart = 0.0;
		for (k = 0; k < fNumDists; k++) {
			if (fDistSize[k] > 0)
			  sumStart += dist[k][fDistSize[k]-1]; // take the last one in dist[k] as F_kp
		}
	} else {
		sumStart = 1.0;
	}
	// calculate the terms N_i/(sumStart - F_+i), add them to the right places and,
	// for the Lagrangian case, subtract them from termSum (d/dF_kp, starting at fLambda)
	double termSum = fLambda;
	for (i = 0; i < fNiSize; i++) {
		TNi *Ni = &fNi[i];
		// calculate the term N_i/(sumStart - F_+i)
		double term = sumStart;
		for (k = 0; k < fNumDists; k++) {
			term -= dist[k][fDistIndex[k][Ni->i]];
		}
		term = Ni->n/term;
		if (fNeedLagrangian)
			termSum -= term;
		// add the term N_i/(sumStart - F_+i) to the right places in grad
		for (k = 0; k < fNumDists; k++)
			grad[k][fDistIndex[k][Ni->i]] += term;
	}
	if (fNeedLagrangian) {
		// add termSum to the last element of every non-empty sub-distribution.
		// An empty dist[k] contributes nothing to F_+p (see the guard above) and
		// has no last element; writing grad[k][-1] would be out of bounds.
		for (k = 0; k < fNumDists; k++)
			if (fDistSize[k] > 0)
				grad[k][fDistSize[k]-1] += termSum;
	}
}

/* HessianDiagPhi: writes the diagonal of the Hessian of Phi into hdiag.
 * hdiag[k] must have room for fDistSize[k] values. It is overwritten, not
 * accumulated into.
 *
 * With F_k(i) = dist[k][fDistIndex[k][i]] and F_+(i) = sum_k F_k(i), the
 * contributions are:
 *  - Nki : n / F_k(i)^2                          on entry i of cause k
 *  - Nkij: n / (F_k(j) - F_k(i))^2               on entries i and j of cause k
 *  - Ni  : n / (S - F_+(i))^2                    on entry i of every cause
 *  - If fNeedLagrangian (S = F_+p), the sum of the Ni contributions is also
 *    added to the last entry of each non-empty dist[k]. The fLambda * S term
 *    is linear and adds nothing.
 * NOTE(review): when fDistIndex[k][Ni.i] is itself the last entry of
 * dist[k], S - F_+(i) does not depend on that entry, yet both additions
 * above still apply to it. Presumably this is intended to keep the diagonal
 * positive; confirm.
 */
void CCRICSolver::HessianDiagPhi(double **dist, double **hdiag) {
	int k, i; // counter variables
	// initialise hdiag to all zeros
	for (k = 0; k < fNumDists; k++)
		for (i = 0; i < fDistSize[k]; i++)
			hdiag[k][i] = 0.0;
	// add the per-cause terms to the correct places in hdiag
	for (k = 0; k < fNumDists; k++) {
		// Nki terms
		for (i = 0; i < fNkiSize[k]; i++) {
			TNi *Ni = &fNki[k][i];
			int distI = fDistIndex[k][Ni->i];
			hdiag[k][distI] += Ni->n/SQR(dist[k][distI]);
		}
		// Nkij terms
		for (i = 0; i < fNkijSize[k]; i++) {
			TNij *Nij = &fNkij[k][i];
			int distI = fDistIndex[k][Nij->i];
			int distJ = fDistIndex[k][Nij->j];
			double delta = Nij->n/SQR(dist[k][distJ]-dist[k][distI]);
			hdiag[k][distI] += delta;
			hdiag[k][distJ] += delta;
		}
	}
	// Ni terms
	double sumStart;
	if (fNeedLagrangian) {
		// calc F_+p, and store it in sumStart
		sumStart = 0.0;
		for (k = 0; k < fNumDists; k++) {
			if (fDistSize[k] > 0)
			  sumStart += dist[k][fDistSize[k]-1]; // take the last one in dist[k] as F_kp
		}
	} else {
		sumStart = 1.0;
	}
	// calculate the terms N_i/(sumStart - F_+i)^2, add them to the right places and sum them
	double termSum = 0.0;
	for (i = 0; i < fNiSize; i++) {
		TNi *Ni = &fNi[i];
		// calculate the term N_i/(sumStart - F_+i)^2
		double term = sumStart;
		for (k = 0; k < fNumDists; k++) {
			term -= dist[k][fDistIndex[k][Ni->i]];
		}
		term = Ni->n/SQR(term);
		if (fNeedLagrangian)
			termSum += term;
		// add the term N_i/(sumStart - F_+i)^2 to the right places in hdiag
		for (k = 0; k < fNumDists; k++)
			hdiag[k][fDistIndex[k][Ni->i]] += term;
	}
	if (fNeedLagrangian) {
		// add termSum to the last element of every non-empty sub-distribution.
		// An empty dist[k] has no last element; hdiag[k][-1] would be out of bounds.
		for (k = 0; k < fNumDists; k++)
			if (fDistSize[k] > 0)
				hdiag[k][fDistSize[k]-1] += termSum;
	}
}

/* InitialEstimate: fills dist with the solver's starting point.
 *
 * Let K = fNumDists and m = K + sum_k fDistSize[k]. Sub-distribution k gets
 * the values 1/m, 2/m, ..., fDistSize[k]/m, i.e. a jump of 1/m at each of
 * its support points. The final values therefore sum to 1 - K/m.
 */
void CCRICSolver::InitialEstimate(double **dist) {
	int m = fNumDists; // starts at K; the support sizes are added below
	for (int k = 0; k < fNumDists; ++k)
		m += fDistSize[k];

	for (int k = 0; k < fNumDists; ++k) {
		double *F = dist[k];
		for (int d = 0; d < fDistSize[k]; ++d)
			F[d] = static_cast<double>(d + 1) / m; // floating-point division
	}
}
