package de.unikassel.cs.kde.kdd;

import java.util.LinkedList;
import java.util.List;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.table.NominalMapping;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.AbstractLearner;
import com.rapidminer.operator.learner.LearnerCapability;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.tools.LogService;


/**
 * A k-nearest-neighbor learner.
 *
 * <p>Training memorizes the values of all regular (numerical) attributes and the label of every
 * training example. To predict an unseen example, the {@code k} training examples with the
 * smallest normalized Manhattan distance are determined. For a nominal label they vote for a
 * class, each vote weighted by the inverse of the neighbor's distance. For a numerical label the
 * inverse-distance weighted mean of the neighbors' labels is predicted.
 */
public class KDDKNearestNeighbors extends AbstractLearner {

	/** Name of the parameter holding the number of neighbors to consult. */
	public static final String PARAMETER_K = "k";

	public KDDKNearestNeighbors(OperatorDescription description) {
		super(description);
	}

	/**
	 * Builds a KNN model that memorizes the given training set.
	 *
	 * @param exampleSet the labeled training examples
	 * @return the trained model
	 * @throws OperatorException if the parameter {@code k} cannot be read
	 */
	public Model learn(ExampleSet exampleSet) throws OperatorException {
		return new KDDKNearestNeighborsModel(exampleSet, getParameterAsInt(PARAMETER_K));
	}

	/**
	 * Declares the supported data: numerical attributes combined with a binominal,
	 * polynominal or numerical label.
	 *
	 * @param lc the capability queried by RapidMiner
	 * @return {@code true} iff the capability is supported
	 */
	public boolean supportsCapability(LearnerCapability lc) {
		return lc == LearnerCapability.NUMERICAL_ATTRIBUTES
				|| lc == LearnerCapability.POLYNOMINAL_CLASS
				|| lc == LearnerCapability.BINOMINAL_CLASS
				|| lc == LearnerCapability.NUMERICAL_CLASS;
	}

	/** Method to get parameters from RapidMiner.
	 * If you need more parameters, change this method.
	 *
	 * @return the parameters of the superclass plus {@link #PARAMETER_K}
	 */
	@Override
	public List<ParameterType> getParameterTypes() {
		final List<ParameterType> types = super.getParameterTypes();
		final ParameterType type = new ParameterTypeInt(PARAMETER_K, "The number of nearest neighbors used for prediction.", 1, Integer.MAX_VALUE, 2);
		type.setExpert(false);
		types.add(type);
		return types;
	}



	/** Represents a KNN model.
	 *
	 * <p>The training data is copied into plain arrays, so the model does not depend on the
	 * training example set after construction. Prediction uses per-call scratch fields
	 * ({@code neighbors}, {@code distances}, {@code neighborCount}, {@code norm}), so one model
	 * instance must not be used for predictions from several threads at once.
	 *
	 * @author rja
	 *
	 */
	private static class KDDKNearestNeighborsModel extends PredictionModel {

		private static final LogService log = LogService.getGlobal();

		private static final long serialVersionUID = -8808551748509677749L;

		/** Lower bound for distances when turning them into weights (avoids division by zero). */
		private static final double MIN_DISTANCE = 1e-10;

		private transient int[] neighbors;       //< indices (into samples) of the current k NN, nearest first
		private transient double[] distances;    //< distances of the k NN, ascending
		private transient int neighborCount;     //< number of valid entries in neighbors/distances
		private transient double[] norm;         //< normalization values, one per attribute

		private final int k;
		private final String[] attributeNames;   //< names of the regular training attributes
		private final double[][] samples;        //< samples[s][a]: value of attribute a in training example s
		private final double[] labels;           //< label value of every training example
		private final double[] minimum;          //< per-attribute minimum over the training set (NaN ignored)
		private final double[] maximum;          //< per-attribute maximum over the training set (NaN ignored)


		/** Constructor initializing with training set and number of neighbors k.
		 *
		 * @param trainingExampleSet the labeled training examples; only regular attributes are used
		 * @param k the number of neighbors to consult, at least 1
		 * @throws IllegalArgumentException if {@code k < 1}
		 */
		protected KDDKNearestNeighborsModel(final ExampleSet trainingExampleSet, final int k) {
			super(trainingExampleSet);
			if (k < 1) {
				throw new IllegalArgumentException("k must be at least 1, but was " + k);
			}
			this.k = k;

			final List<Attribute> attributeList = new LinkedList<Attribute>();
			for (final Attribute attribute : trainingExampleSet.getAttributes()) {
				attributeList.add(attribute);
			}
			final Attribute[] attributes = attributeList.toArray(new Attribute[attributeList.size()]);
			final int dimensions = attributes.length;

			this.attributeNames = new String[dimensions];
			this.minimum = new double[dimensions];
			this.maximum = new double[dimensions];
			for (int i = 0; i < dimensions; i++) {
				attributeNames[i] = attributes[i].getName();
				minimum[i] = Double.POSITIVE_INFINITY;
				maximum[i] = Double.NEGATIVE_INFINITY;
			}

			/*
			 * copying training set (values + labels), tracking attribute ranges on the way
			 */
			this.samples = new double[trainingExampleSet.size()][];
			this.labels = new double[trainingExampleSet.size()];
			int row = 0;
			for (final Example example : trainingExampleSet) {
				final double[] values = new double[dimensions];
				for (int i = 0; i < dimensions; i++) {
					final double value = example.getValue(attributes[i]);
					values[i] = value;
					if (!Double.isNaN(value)) { // NaN marks a missing value
						minimum[i] = Math.min(minimum[i], value);
						maximum[i] = Math.max(maximum[i], value);
					}
				}
				samples[row] = values;
				labels[row] = example.getLabel();
				row++;
			}
		}


		/**
		 * Writes a prediction for every example into {@code predictedLabel}.
		 *
		 * @param exampleSet the examples to predict
		 * @param predictedLabel the attribute receiving the prediction; nominal labels receive the
		 *        voted class, numerical labels the weighted mean of the neighbors' labels
		 * @return {@code exampleSet}, with predictions set
		 */
		@Override
		public ExampleSet performPrediction(final ExampleSet exampleSet, final Attribute predictedLabel) throws OperatorException {
			if (predictedLabel.isNominal()) {
				final NominalMapping mapping = predictedLabel.getMapping();
				/*
				 * predicting classes for the examples
				 */
				for (final Example example: exampleSet) {
					example.setValue(predictedLabel, mapping.mapIndex(predictClass(example)));
				}
			} else {
				/*
				 * numerical label: getMapping() is not available, predict a value instead
				 */
				for (final Example example: exampleSet) {
					example.setValue(predictedLabel, predictValue(example));
				}
			}
			return exampleSet;
		}



		/**
		 * Predict a class using the k nearest neighbors.
		 *
		 * @param newExample the new example to classify
		 * @return the index of the predicted class, or 0 if no neighbor could be found
		 */
		public int predictClass(final Example newExample) {

			calculateNormalization(newExample);
			findNeighbors(newExample);

			if (neighborCount == 0) {
				log.logWarning("Could not find any neighbors ... ");
				return 0;
			}

			return getVote();
		}

		/**
		 * Predict a numerical value using the k nearest neighbors.
		 *
		 * @param newExample the new example
		 * @return the inverse-distance weighted mean of the neighbors' labels, or NaN (missing)
		 *         if no neighbor could be found
		 */
		public double predictValue(final Example newExample) {
			calculateNormalization(newExample);
			findNeighbors(newExample);

			if (neighborCount == 0) {
				log.logWarning("Could not find any neighbors ... ");
				return Double.NaN;
			}

			return getWeightedMean();
		}

		/**
		 * Let the neighbors vote for a label, each vote weighted by the inverse of the neighbor's
		 * distance. Ties are resolved in favor of the label of the nearest neighbor.
		 *
		 * <p>Requires a preceding {@link #findNeighbors(Example)}.
		 *
		 * @return index of the label with the highest summed weight, 0 if there are no neighbors
		 */
		protected int getVote() {
			if (neighborCount == 0) {
				return 0;
			}
			int largestLabel = 0;
			for (int n = 0; n < neighborCount; n++) {
				largestLabel = Math.max(largestLabel, (int) labels[neighbors[n]]);
			}
			final double[] votes = new double[largestLabel + 1];
			for (int n = 0; n < neighborCount; n++) {
				votes[(int) labels[neighbors[n]]] += weight(distances[n]);
			}
			// start with the nearest neighbor's label so that it wins ties
			int best = (int) labels[neighbors[0]];
			for (int label = 0; label < votes.length; label++) {
				if (votes[label] > votes[best]) {
					best = label;
				}
			}
			return best;
		}

		/**
		 * Computes the inverse-distance weighted mean of the neighbors' labels.
		 *
		 * <p>Requires a preceding {@link #findNeighbors(Example)}.
		 *
		 * @return the weighted mean, NaN if there are no neighbors
		 */
		protected double getWeightedMean() {
			double weightSum = 0.0;
			double valueSum = 0.0;
			for (int n = 0; n < neighborCount; n++) {
				final double w = weight(distances[n]);
				weightSum += w;
				valueSum += w * labels[neighbors[n]];
			}
			return valueSum / weightSum;
		}

		/**
		 * Find (at most) k nearest neighbors.
		 * Checks every training example and replaces the farthest current neighbor by better
		 * candidates. Afterwards {@code neighbors}/{@code distances} hold the neighbors sorted by
		 * ascending distance and {@code neighborCount} their number. Training examples without a
		 * label are skipped. On equal distances the earlier training example is preferred.
		 *
		 * <p>Uses {@code norm}; call {@link #calculateNormalization(Example)} first for the same
		 * example.
		 *
		 * @param example the example whose neighbors are searched
		 */
		protected void findNeighbors(final Example example) {
			final double[] query = valuesOf(example);
			if (norm == null || norm.length != query.length) {
				calculateNormalization(example);
			}

			final int capacity = Math.min(k, samples.length);
			neighbors = new int[capacity];
			distances = new double[capacity];
			neighborCount = 0;

			for (int s = 0; s < samples.length; s++) {
				if (Double.isNaN(labels[s])) {
					continue; // unlabeled training example cannot vote
				}
				final double distance = distance(query, samples[s]);
				if (Double.isNaN(distance)) {
					continue;
				}
				if (neighborCount == capacity && !(distance < distances[capacity - 1])) {
					continue; // not better than the farthest current neighbor
				}
				// append if there is room, otherwise overwrite the farthest neighbor ...
				int position = (neighborCount < capacity) ? neighborCount++ : capacity - 1;
				// ... and move the candidate forward to keep the arrays sorted
				while (position > 0 && distances[position - 1] > distance) {
					distances[position] = distances[position - 1];
					neighbors[position] = neighbors[position - 1];
					position--;
				}
				distances[position] = distance;
				neighbors[position] = s;
			}
		}


		/**
		 * calculate normalization table.
		 * norm[i] = 1.0/((Attribute Maximum - Attribute Minimum) * dimensions),
		 * where minimum and maximum range over the training set and the given example (missing
		 * values ignored). Attributes without a positive range get norm 0, i.e. they do not
		 * contribute to the distance.
		 *
		 * @param example the example to be classified next
		 */
		protected void calculateNormalization(final Example example)	{
			final double[] query = valuesOf(example);
			final int dimensions = query.length;
			norm = new double[dimensions];
			for (int i = 0; i < dimensions; i++) {
				double min = minimum[i];
				double max = maximum[i];
				if (!Double.isNaN(query[i])) {
					min = Math.min(min, query[i]);
					max = Math.max(max, query[i]);
				}
				final double range = max - min;
				norm[i] = (range > 0) ? 1.0 / (range * dimensions) : 0.0;
			}
		}

		/**
		 * Normalized Manhattan distance; attributes missing on either side are ignored.
		 * With the normalization above each attribute contributes at most 1/dimensions.
		 */
		private double distance(final double[] query, final double[] sample) {
			double sum = 0.0;
			for (int i = 0; i < query.length; i++) {
				final double difference = Math.abs(query[i] - sample[i]);
				if (!Double.isNaN(difference)) {
					sum += difference * norm[i];
				}
			}
			return sum;
		}

		/** Inverse-distance vote weight, bounded so that exact matches do not divide by zero. */
		private static double weight(final double distance) {
			return 1.0 / Math.max(distance, MIN_DISTANCE);
		}

		/**
		 * Reads the training attributes (matched by name) from an example. Attributes the
		 * example lacks are reported as NaN (missing).
		 */
		private double[] valuesOf(final Example example) {
			final double[] values = new double[attributeNames.length];
			for (int i = 0; i < values.length; i++) {
				final Attribute attribute = example.getAttributes().get(attributeNames[i]);
				values[i] = (attribute == null) ? Double.NaN : example.getValue(attribute);
			}
			return values;
		}

	}
}