#include <mex.h>
#include <math.h>
#include <algorithm>
#include <climits>
#include <cmath>
#include <vector>
using namespace std;
/* 
	function NN = KNN(AdjMatrixElems,DSs,Para)	
*/
// Growable int / double vectors and their vector-of-vectors ("array") forms.
using IntVector    = vector<int>;
using IntArray     = vector<IntVector>;
using DoubleVector = vector<double>;
using DoubleArray  = vector<DoubleVector>;

// An int paired with a double. mexFunction stores a vertex index in I and that
// vertex's accumulated affinity in D; QuickSort orders a vector of these by D.
struct IntDouble_TYP
{
	int    I;
	double D;
};
using IntDouble = IntDouble_TYP;
using IDVector  = vector<IntDouble>;

// Two ints plus a double. mexFunction uses it for one adjacency-matrix entry:
// I[0] = row, I[1] = column (both 0-based), D = weight.
struct IntIntDouble_TYP
{
	int    I[2];
	double D;
};
using IntIntDouble = IntIntDouble_TYP;
using IIDVector    = vector<IntIntDouble>;

// Scalar settings read from the Para input: K (number of vertices reported per
// set, i.e. rows of the output) and the total number of vertices in the graph.
struct Parameter_TYP
{
	int K;
	int NumOfVertices;
};
using Parameter = Parameter_TYP;

void QuickSort(IDVector& data,int left,int right);
void mexFunction(int nlhs, mxArray *plhs[],int nrhs, const mxArray *prhs[])
{
	if(nrhs!=3)
		return;

	IIDVector AMElems;
	IntArray DSs;
	Parameter Para;	

	//read sparse adjacency matrix elements
	int NumOfEntries = mxGetM(prhs[0]);
	double* pData = mxGetPr(prhs[0]);
	for(int i=0;i<NumOfEntries;i++)
	{
		IntIntDouble Iid;
		Iid.I[0] = (int) pData[i]-1;
		Iid.I[1] = (int) pData[NumOfEntries+i]-1;
		Iid.D    = pData[2*NumOfEntries+i];

		if(Iid.I[0]<Iid.I[1])
			AMElems.push_back(Iid);
	}

	int NumOfDS = mxGetM(prhs[1]);
	for(int i=0;i<NumOfDS;i++)
	{
		mxArray* pCell;
		pCell=mxGetCell(prhs[1],i);
		int NumOfElem = mxGetM(pCell);
		pData = mxGetPr(pCell);
		IntVector Iv;
		for(int j=0;j<NumOfElem;j++)
			Iv.push_back((int)pData[j]-1);
		DSs.push_back(Iv);
	}

	pData = mxGetPr(prhs[2]);
	Para.NumOfVertices = (int)pData[0];
	Para.K             = (int)pData[1];
	
	IntArray KNN;
	for(int i=0;i<NumOfDS;i++)
	{
		//compute reward for each vertex
		DoubleVector AffValues;
		IntVector RevMap;
		for(int j=0;j<Para.NumOfVertices;j++)
		{			
			AffValues.push_back(0);
			RevMap.push_back(-1);
		}

		for(int j=0;j<DSs[i].size();j++)
			RevMap[DSs[i][j]] = j;

		
		for(int j=0;j<AMElems.size();j++)
		{
			IntIntDouble Iid = AMElems[j];
			int Index[2];
			Index[0] = RevMap[Iid.I[0]];
			Index[1] = RevMap[Iid.I[1]];
			if(Index[0]>=0)
				AffValues[Iid.I[1]] += Iid.D;
			if(Index[1]>=0)
				AffValues[Iid.I[0]] += Iid.D;
		}

		IDVector Idv;
		for(int j=0;j<Para.NumOfVertices;j++)
		{
			IntDouble Id;
			Id.I = j;
			Id.D = AffValues[j];
			Idv.push_back(Id);
		}

		QuickSort(Idv,0,Idv.size()-1);
		IntVector Iv;
		for(int j=0;j<Para.K;j++)
			Iv.push_back(Idv[j].I);
		KNN.push_back(Iv);
	}

	//output
	nlhs = 1;
	plhs[0] = mxCreateDoubleMatrix(Para.K,NumOfDS,mxREAL);
	pData   = mxGetPr(plhs[0]);
	for(int i=0;i<NumOfDS;i++)
		for(int j=0;j<Para.K;j++)
			pData[i*Para.K+j] = KNN[i][j]+1;
}

/*
 * Sorts data[left..right] (both ends inclusive) in place, in descending order
 * of the D field.
 *
 * Uses std::stable_sort, which gives O(n log n) time on every input, needs no
 * recursion, and keeps elements with equal D in their original relative order.
 * A hand-written quicksort that always pivots on data[left] degrades to O(n^2)
 * time and O(n) recursion depth when many keys are equal. mexFunction produces
 * exactly that input: every vertex with no edge into the set has affinity 0.
 * That recursion depth can overflow the stack.
 *
 * NaN keys are placed after every other value, so the comparator stays a strict
 * weak ordering as std::stable_sort requires.
 *
 * A range with fewer than two elements (left >= right) is left untouched.
 * Precondition when left < right: 0 <= left and right < data.size().
 */
void QuickSort(IDVector& data,int left,int right)
{
	if(left>=right)
		return;

	IDVector::iterator first = data.begin()+left;
	IDVector::iterator last  = data.begin()+right+1;   // one past the inclusive end
	std::stable_sort(first,last,[](const IntDouble& a,const IntDouble& b) -> bool
	{
		if(std::isnan(a.D))
			return false;        // NaN never precedes anything
		if(std::isnan(b.D))
			return true;         // every number precedes NaN
		return a.D>b.D;          // larger values first
	});
}