From 2c63245a14bb5be41cd75c4a4df223e19b3cb233 Mon Sep 17 00:00:00 2001
From: Peter Alexander <peter@blackhillock.co.uk>
Date: Sun, 12 Mar 2017 22:50:14 +0000
Subject: [PATCH] Use k-means clustering for grouping yields

---
 GAMS/IntExtOpt.gms                            |   4 +-
 debug_config.properties                       |   5 +-
 src/ac/ed/lurg/ModelConfig.java               |   5 +-
 src/ac/ed/lurg/ModelMain.java                 |  34 ++++-
 src/ac/ed/lurg/country/CountryAgent.java      |  59 +++++++-
 .../country/gams/GamsRasterOptimiser.java     | 127 +++++-------------
 .../lurg/country/gams/GamsRasterOutput.java   |   9 +-
 .../ed/lurg/country/gams/GamsRasterTest.java  |   2 +-
 .../ed/lurg/utils/cluster/CentriodPoint.java  |  14 +-
 src/ac/ed/lurg/utils/cluster/Cluster.java     |  30 ++---
 .../lurg/utils/cluster/ClusteringPoint.java   |   6 +-
 src/ac/ed/lurg/utils/cluster/KMeans.java      |  92 +++++++------
 src/ac/ed/lurg/utils/cluster/KMeansTest.java  |  10 +-
 src/ac/ed/lurg/yield/YieldClusterPoint.java   |  75 +++++++++++
 14 files changed, 279 insertions(+), 193 deletions(-)
 create mode 100644 src/ac/ed/lurg/yield/YieldClusterPoint.java

diff --git a/GAMS/IntExtOpt.gms b/GAMS/IntExtOpt.gms
index a825c564..9043c460 100644
--- a/GAMS/IntExtOpt.gms
+++ b/GAMS/IntExtOpt.gms
@@ -175,7 +175,7 @@ $gdxin
  PASTURE_IMPORT_CONSTRAINT .. importAmount('pasture') =E= 0;
  PASTURE_EXPORT_CONSTRAINT ..  exportAmount('pasture') =E= 0;
   
- IRRIGATION_CONSTRAINT(location) .. irrigConstraint(location) =G= sum(crop, irrigMaxRate(crop, location) * irrigI(crop, location) * area(crop, location)) / suitableLandArea(location);
+ IRRIGATION_CONSTRAINT(location) .. irrigConstraint(location) * suitableLandArea(location) =G= sum(crop, irrigMaxRate(crop, location) * irrigI(crop, location) * area(crop, location));
  
  AGRI_LAND_EXPANSION_CALC(location) .. agriLandExpansion(location) =G= sum(crop, area(crop, location) - previousArea(crop, location)); 
  
@@ -211,6 +211,8 @@ $gdxin
  otherIntensity.L(crop, location) = previousOtherIntensity(crop, location);
  area.L(crop, location) = previousArea(crop, location)
  
+ display suitableLandArea;
+ 
  SOLVE LAND_USE USING NLP MINIMIZING total_cost;
  
  display previousFertIntensity;
diff --git a/debug_config.properties b/debug_config.properties
index 91afd906..e6da74b9 100644
--- a/debug_config.properties
+++ b/debug_config.properties
@@ -6,7 +6,7 @@ YIELD_DIR=/Users/peteralexander/Documents/LURG/LPJ/IPSL/LPJG_PLUM_expt1.2_rcp45_
 #CHANGE_YIELD_DATA_YEAR=true
 #CHANGE_DEMAND_YEAR=false
 #DEBUG_LIMIT_COUNTRIES=true
-#DEBUG_COUNTRY_NAME=China
+#DEBUG_COUNTRY_NAME=United Republic of Tanzania
 #MAX_IMPORT_CHANGE=0.0
 #MARKET_ADJ_PRICE=false
 
@@ -19,9 +19,6 @@ IS_CALIBRATION_RUN = false
 #TECHNOLOGY_CHANGE_ANNUAL_RATE=0.005
 #TRADE_BARRIER_FACTOR_DEFAULT=0.3
 
-#NUM_CEREAL_CATEGORIES = 10
-#NUM_PASTURE_CATEGORIES = 3
-
 END_TIMESTEP=18
 TIMESTEP_SIZE=5
 
diff --git a/src/ac/ed/lurg/ModelConfig.java b/src/ac/ed/lurg/ModelConfig.java
index db287052..ea6252de 100644
--- a/src/ac/ed/lurg/ModelConfig.java
+++ b/src/ac/ed/lurg/ModelConfig.java
@@ -214,11 +214,10 @@ public class ModelConfig {
 	public static final boolean PROTECTED_AREAS_ENABLED = getBooleanProperty("PROTECTED_AREAS_ENABLED", true);
 	public static final double MIN_NATURAL_RATE = getDoubleProperty("MIN_NATURAL_RATE", 0.10);
 
-	public static final int NUM_CEREAL_CATEGORIES = getIntProperty("NUM_CEREAL_CATEGORIES", 15);
-	public static final int NUM_PASTURE_CATEGORIES = getIntProperty("NUM_PASTURE_CATEGORIES", 4);
-
 	public static final boolean DEBUG_LIMIT_COUNTRIES = getBooleanProperty("DEBUG_LIMIT_COUNTRIES", false);
 	public static final String DEBUG_COUNTRY_NAME = getProperty("DEBUG_COUNTRY_NAME", "United States of America");
 	public static final double PASTURE_MAX_IRRIGATION_RATE = getDoubleProperty("DEFAULT_MAX_IRRIGATION_RATE", 50.0); // shouldn't need this but some areas crops don't have a value, but was causing them to be selected
 	public static final int LPJG_TIMESTEP_SIZE = 5;
+	
+	public static final int NUM_YIELD_CLUSTERS = 400;
 }
diff --git a/src/ac/ed/lurg/ModelMain.java b/src/ac/ed/lurg/ModelMain.java
index a1af2620..19abc477 100644
--- a/src/ac/ed/lurg/ModelMain.java
+++ b/src/ac/ed/lurg/ModelMain.java
@@ -50,6 +50,7 @@ import ac.sac.raster.IntegerRasterItem;
 import ac.sac.raster.InterpolatingRasterSet;
 import ac.sac.raster.RasterHeaderDetails;
 import ac.sac.raster.RasterKey;
+import ac.sac.raster.RasterOutputer;
 import ac.sac.raster.RasterSet;
 
 public class ModelMain {
@@ -65,7 +66,9 @@ public class ModelMain {
 	private Map<CropType, Double> prevStockLevel;
 	private RasterSet<LandUseItem> prevLandUseRaster;
 	private RasterSet<IrrigationItem> currentIrrigationData;
-
+	private RasterSet<LandUseItem> globalLandUseRaster;
+	private RasterSet<IntegerRasterItem> globalLocationIdRaster;
+	
 	public static void main(String[] args)  {
 		ModelMain theModel = new ModelMain();
 		theModel.setup();
@@ -84,6 +87,7 @@ public class ModelMain {
 				
 		countryBoundaryRaster = getCountryBoundaryRaster();
 		countryAgents = createCountryAgents(compositeCountryManager.getAll());
+		globalLandUseRaster = new RasterSet<LandUseItem>(desiredProjection);
 
 		// in first timestep we don't have this info, but ok as constrained to import/export specified amount, values based on http://www.indexmundi.com/commodities/ for Jun 2010
 		prevWorldPrices = new HashMap<CropType, GlobalPrice>();  
@@ -125,8 +129,7 @@ public class ModelMain {
 		CropToDoubleMap totalImportCommodities = new CropToDoubleMap();
 		CropToDoubleMap totalExportCommodities = new CropToDoubleMap();
 
-		RasterSet<LandUseItem> globalLandUseRaster = new RasterSet<LandUseItem>(desiredProjection);
-		RasterSet<IntegerRasterItem> globalLocationIdRaster = new RasterSet<IntegerRasterItem>(desiredProjection);
+		globalLocationIdRaster = new RasterSet<IntegerRasterItem>(desiredProjection);
 
 		for (CountryAgent ca : countryAgents) {
 
@@ -159,14 +162,16 @@ public class ModelMain {
 						FileSystems.getDefault().getPath("/Users/peteralexander/Documents/R_Workspace/UNPLUM/temp/GamsTmp/" + timestep.getTimestep() + ".gdx")
 						, StandardCopyOption.REPLACE_EXISTING);
 				} catch (IOException e) {
-					// TODO Auto-generated catch block
 					LogWriter.printException(e);
 				}
 			} */
 
 			// update global rasters
 			globalLandUseRaster.putAll(result.getLandUses());
-			globalLocationIdRaster.putAll(result.getLocationIdRaster());
+			
+			// if first timestep get the clustering info, which doesn't change through time
+			if (timestep.isInitialTimestep())
+				globalLocationIdRaster.putAll(ca.getYieldClusters());
 
 			// Get values for world input costs
 			Map<CropType, CropUsageData> cropUsage = result.getCropUsageData();
@@ -205,7 +210,7 @@ public class ModelMain {
 
 
 		// output results
-		outputTimestepResults(timestep, globalLandUseRaster, globalLocationIdRaster, yieldSurfaces);
+		outputTimestepResults(timestep, globalLandUseRaster, yieldSurfaces);
 
 		// keep last to allow interpolation
 		prevLandUseRaster = globalLandUseRaster;
@@ -304,7 +309,7 @@ public class ModelMain {
 	}
 
 
-	private void outputTimestepResults(Timestep timestep, RasterSet<LandUseItem> landUseRaster, RasterSet<IntegerRasterItem> locationIdRaster, YieldRaster yieldSurfaces) {
+	private void outputTimestepResults(Timestep timestep, RasterSet<LandUseItem> landUseRaster, YieldRaster yieldSurfaces) {
 
 		writeLandCoverFile(timestep, landUseRaster);
 		writeGlobalMarketFile(timestep);
@@ -348,9 +353,24 @@ public class ModelMain {
 		// don't really need this a LPJ outputs have same data, although in a slightly different format
 //		outputLandCover(timestep.getYear(), landUseRaster, LandCoverType.CROPLAND);
 //		outputLandCover(timestep.getYear(), landUseRaster, LandCoverType.PASTURE); 
+		
+		if (timestep.isInitialTimestep())
+			outputClusters(globalLocationIdRaster);
 	}
 
+	private void outputClusters(RasterSet<IntegerRasterItem> landUseRaster) { // NOTE(review): despite its name the parameter carries cluster ids (the caller passes globalLocationIdRaster)
+	new RasterOutputer<IntegerRasterItem>(landUseRaster, "clusters") { // anonymous outputer constructed with the label "clusters"
+		@Override
+		public Double getValue(RasterKey location) {
+			IntegerRasterItem item = results.get(location);
+			if (item == null)
+				return null; // no cluster id recorded for this location
 
+			return (double)item.getInt(); // cluster id, boxed to Double for the outputer
+		}
+	}.writeOutput();
+}
+	
 /*	private void outputLandCover(int year, RasterSet<LandUseItem> landUseRaster, final LandCoverType lcType) {
 		new RasterOutputer<LandUseItem>(landUseRaster, lcType.getName() + "Area" + year) {
 			@Override
diff --git a/src/ac/ed/lurg/country/CountryAgent.java b/src/ac/ed/lurg/country/CountryAgent.java
index 550ba05f..f9cf2d5b 100644
--- a/src/ac/ed/lurg/country/CountryAgent.java
+++ b/src/ac/ed/lurg/country/CountryAgent.java
@@ -1,7 +1,11 @@
 package ac.ed.lurg.country;
 
+import java.util.Collection;
 import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
 import java.util.Map;
+import java.util.Map.Entry;
 
 import ac.ed.lurg.ModelConfig;
 import ac.ed.lurg.Timestep;
@@ -16,7 +20,13 @@ import ac.ed.lurg.landuse.LandUseItem;
 import ac.ed.lurg.types.CommodityType;
 import ac.ed.lurg.types.CropType;
 import ac.ed.lurg.utils.LogWriter;
+import ac.ed.lurg.utils.cluster.Cluster;
+import ac.ed.lurg.utils.cluster.KMeans;
+import ac.ed.lurg.yield.YieldClusterPoint;
 import ac.ed.lurg.yield.YieldRaster;
+import ac.ed.lurg.yield.YieldResponsesItem;
+import ac.sac.raster.IntegerRasterItem;
+import ac.sac.raster.RasterKey;
 import ac.sac.raster.RasterSet;
 
 public class CountryAgent {
@@ -26,10 +36,10 @@ public class CountryAgent {
 	
 	private GamsRasterOutput previousGamsRasterOutput;
 	private Timestep currentTimestep;
-	private YieldRaster countryYieldSurfaces;
 	private Map<CommodityType, Double> currentProjectedDemand;
 	private Map<CropType, CountryPrice> currentCountryPrices;
 	private Map<CropType, Double> tradeBarriers;
+	private RasterSet<IntegerRasterItem> yieldClusters;
 	
 	public CountryAgent(DemandManager demandManager,CompositeCountry country, RasterSet<LandUseItem> cropAreaRaster,
 			Map<CropType, CropUsageData> cropUsageData, Map<CropType, Double> tradeBarriers) {
@@ -45,10 +55,46 @@ public class CountryAgent {
 	public CompositeCountry getCountry() {
 		return country;
 	}
+	
+	public RasterSet<IntegerRasterItem> getYieldClusters() { // RasterKey -> cluster id raster; still null if determineProduction has not yet calculated the clusters
+		return yieldClusters;
+	}
+	
+	private RasterSet<IntegerRasterItem> calcYieldClusters(YieldRaster countryYieldSurfaces) { // k-means clusters the country's yield responses; returns RasterKey -> cluster id
+		
+		LogWriter.println("calcYieldClusters: for " + ModelConfig.NUM_YIELD_CLUSTERS + " clusters");	
+		
+		// build one YieldClusterPoint (RasterKey plus its yield responses) per entry of countryYieldSurfaces; entries with a null yield response are left out of the clustering
+		Collection<YieldClusterPoint> clusteringPoints = new HashSet<YieldClusterPoint>();
+		for (Entry<RasterKey, YieldResponsesItem> entry : countryYieldSurfaces.entrySet()) {
+			YieldResponsesItem yieldresp = entry.getValue();
+			if (yieldresp != null)
+				clusteringPoints.add(new YieldClusterPoint(entry.getKey(), yieldresp));
+		}
+				
+		// do the clustering; NOTE(review): 100 and 0.1 are presumably the iteration limit and convergence tolerance - confirm against KMeans.calculateClusters
+		KMeans<String, YieldClusterPoint> kmeans = new KMeans<String, YieldClusterPoint>(clusteringPoints, ModelConfig.NUM_YIELD_CLUSTERS);
+		kmeans.calculateClusters(100, 0.1);
+		kmeans.printClusters(); // logs every cluster together with each of its points
 
+		// reformat output: map each point's RasterKey to the id of its cluster, numbering the non-empty clusters consecutively from 1
+		List<Cluster<String, YieldClusterPoint>> yieldClusters = kmeans.getClusters(); // NOTE(review): this local shadows the yieldClusters field
+		RasterSet<IntegerRasterItem> mapping = new RasterSet<IntegerRasterItem>(countryYieldSurfaces.getHeaderDetails());
+		
+		int id = 1;
+		for (Cluster<String, YieldClusterPoint> c : yieldClusters) {
+			for (YieldClusterPoint p : c.getPoints())
+				mapping.put(p.getRasterKey(), new IntegerRasterItem(id));
+			
+			if (c.getPoints().size()>0) // an empty cluster does not consume an id
+				id++;
+		}
+
+		return mapping;
+	}
+	
 	public GamsRasterOutput determineProduction(Timestep timestep, YieldRaster countryYieldSurfaces, RasterSet<IrrigationItem> irrigData, Map<CropType, GlobalPrice> worldPrices) {
 		currentTimestep = timestep;
-		this.countryYieldSurfaces = countryYieldSurfaces;
 		
 		// get projected demand
 		currentProjectedDemand = demandManager.getDemand(country, timestep.getYear());
@@ -62,9 +108,12 @@ public class CountryAgent {
 			LogWriter.printlnError("No yield values for " + country + " so skipping it");
 		}
 		else {
+			if (yieldClusters==null)
+				yieldClusters = calcYieldClusters(countryYieldSurfaces);  // this should only be on the first timestep
+			
 			// optimize areas and intensity 
-			GamsRasterInput input = getGamsRasterInput(irrigData);
-			GamsRasterOptimiser opti = new GamsRasterOptimiser(input);
+			GamsRasterInput input = getGamsRasterInput(irrigData, countryYieldSurfaces);
+			GamsRasterOptimiser opti = new GamsRasterOptimiser(input, yieldClusters);
 			LogWriter.println("Running " + country.getName() + ", currentTimestep " + currentTimestep);
 			
 			GamsRasterOutput result = opti.run();
@@ -79,7 +128,7 @@ public class CountryAgent {
 		return currentProjectedDemand;
 	}
 
-	private GamsRasterInput getGamsRasterInput(RasterSet<IrrigationItem> irrigData) {
+	private GamsRasterInput getGamsRasterInput(RasterSet<IrrigationItem> irrigData, YieldRaster countryYieldSurfaces) {
 		double allowedImportChange;
 
 		if (currentTimestep.isInitialTimestep()) {  // initialisation time-step
diff --git a/src/ac/ed/lurg/country/gams/GamsRasterOptimiser.java b/src/ac/ed/lurg/country/gams/GamsRasterOptimiser.java
index 558a5a42..76ca0852 100644
--- a/src/ac/ed/lurg/country/gams/GamsRasterOptimiser.java
+++ b/src/ac/ed/lurg/country/gams/GamsRasterOptimiser.java
@@ -1,14 +1,10 @@
 package ac.ed.lurg.country.gams;
 
-import java.util.ArrayList;
-import java.util.Collections;
 import java.util.HashSet;
-import java.util.List;
 import java.util.Map;
 import java.util.Map.Entry;
 import java.util.Set;
 
-import ac.ed.lurg.ModelConfig;
 import ac.ed.lurg.landuse.Intensity;
 import ac.ed.lurg.landuse.IrrigationItem;
 import ac.ed.lurg.landuse.LandUseItem;
@@ -27,10 +23,11 @@ public class GamsRasterOptimiser {
 	public static final boolean DEBUG = false;
 
 	private GamsRasterInput rasterInputData;
-	private LazyTreeMap<Integer, Set<RasterKey>> mapping;
+	private RasterSet<IntegerRasterItem> mapping;
 
-	public GamsRasterOptimiser(GamsRasterInput rasterInputData) {
+	public GamsRasterOptimiser(GamsRasterInput rasterInputData, RasterSet<IntegerRasterItem> clusterMapping) {
 		this.rasterInputData = rasterInputData;
+		this.mapping = clusterMapping;
 	}
 
 	public GamsRasterOutput run() {
@@ -52,14 +49,7 @@ public class GamsRasterOptimiser {
 	private GamsRasterOutput convertToRaster(GamsLocationInput gamsInput, GamsLocationOutput gamsOutput) {		
 		RasterSet<LandUseItem> newIntensityRaster = allocAreas(gamsInput.getPreviousLandUse(), gamsOutput);
 
-		RasterSet<IntegerRasterItem> locationIdRaster = new RasterSet<IntegerRasterItem>(rasterInputData.getPreviousLandUses().getHeaderDetails());
-		for (Entry<Integer, Set<RasterKey>> entry : mapping.entrySet()) {
-			Integer locId = entry.getKey();
-			for (RasterKey key : entry.getValue())
-				locationIdRaster.put(key, new IntegerRasterItem(locId));
-		}
-
-		return new GamsRasterOutput(gamsOutput.getStatus(), newIntensityRaster, gamsOutput.getCommoditiesData(), gamsOutput.getCropAdjustments(), locationIdRaster);
+		return new GamsRasterOutput(gamsOutput.getStatus(), newIntensityRaster, gamsOutput.getCommoditiesData(), gamsOutput.getCropAdjustments());
 	}
 
 	private RasterSet<LandUseItem> createWithSameLandCovers(RasterSet<LandUseItem> toCopy) {
@@ -96,8 +86,13 @@ public class GamsRasterOptimiser {
 			Integer locId = entry.getKey();
 			LandUseItem newLandUseAggItem = entry.getValue();
 			LandUseItem prevLandUseAggItem = prevAreasAgg.get(locId);
-			Set<RasterKey> keys = mapping.get(locId);
-
+			
+			Set<RasterKey> keys = new HashSet<RasterKey>();
+			for (Entry<RasterKey, IntegerRasterItem> mapEntry : mapping.entrySet()) {
+				if (locId == mapEntry.getValue().getInt())
+					keys.add(mapEntry.getKey());
+			}
+					
 			if (DEBUG) 
 				checkedTotalAreas(newLandUseRaster.createSubsetForKeys(keys), locId + " before");
 
@@ -145,9 +140,11 @@ public class GamsRasterOptimiser {
 
 			for (RasterKey key : keys) {
 				LandUseItem newLandUseItem = newLandUseRaster.get(key);
-				for (CropType crop : CropType.values()) {
-					newLandUseItem.setCropFraction(crop, newLandUseAggItem.getCropFraction(crop));
-					newLandUseItem.setIntensity(crop, newLandUseAggItem.getIntensity(crop)); // intensities constant over single aggregated land category
+				if (newLandUseItem != null) {
+					for (CropType crop : CropType.values()) {
+						newLandUseItem.setCropFraction(crop, newLandUseAggItem.getCropFraction(crop));
+						newLandUseItem.setIntensity(crop, newLandUseAggItem.getIntensity(crop)); // intensities constant over single aggregated land category
+					}
 				}
 			}
 		}
@@ -174,11 +171,13 @@ public class GamsRasterOptimiser {
 			//if (DEBUG) LogWriter.println("  Processing raster key " + key);
 			LandUseItem newLandUseItem = newLandUseRaster.get(key);
 
-			double shortfall = newLandUseItem.moveAreas(toType, fromType, avgChange);
-			if (shortfall == 0)
-				keysWithSpace.add(key);
-			else
-				totalShortfall += shortfall;
+			if (newLandUseItem!=null) {
+				double shortfall = newLandUseItem.moveAreas(toType, fromType, avgChange);
+				if (shortfall == 0)
+					keysWithSpace.add(key);
+				else
+					totalShortfall += shortfall;
+			}
 		}
 
 		if (totalShortfall > 0 & keysWithSpace.size() > 0) {  // more to allocate and some free areas to allocate into
@@ -241,16 +240,7 @@ public class GamsRasterOptimiser {
 				logErrorWithCoord("Inconsistency F only:" + yresp.getYieldFertOnly(crop) + ", I only" + yresp.getYieldIrrigOnly(crop) + ", max " + yresp.getYieldMax(crop) + " at ", key, yieldRaster);
 			}
 		}
-
-		int numCerealCats = ModelConfig.NUM_CEREAL_CATEGORIES;
-		int numPastureCats = ModelConfig.NUM_PASTURE_CATEGORIES;
-
-		int thisShouldLookAtCropsOtherThanJustWheat; // need to consider other crops, and perhaps other yieldTypes as well - particularly fert/irrig responses
-		List<Double> wheatlDivisions = getDivisions(yieldRaster, CropType.WHEAT, numCerealCats); 
-		List<Double> pastureDivisions = getDivisions(yieldRaster, CropType.PASTURE, numPastureCats);
-
-		if (DEBUG) LogWriter.println("Making " + numCerealCats * numPastureCats + " categories");
-
+		
 		LazyTreeMap<Integer, YieldResponsesItem> aggregatedYields = new LazyTreeMap<Integer, YieldResponsesItem>() { 
 			protected YieldResponsesItem createValue() { return new YieldResponsesItem(); }
 		};
@@ -260,9 +250,7 @@ public class GamsRasterOptimiser {
 		LazyTreeMap<Integer, IrrigationItem> aggregatedIrrigCosts = new LazyTreeMap<Integer, IrrigationItem>() { 
 			protected IrrigationItem createValue() { return new IrrigationItem(); }
 		};
-		mapping = new LazyTreeMap<Integer, Set<RasterKey>>() { 
-			protected Set<RasterKey> createValue() { return new HashSet<RasterKey>(); }
-		};
+
 
 		int countFound = 0, countMissing = 0;
 
@@ -286,14 +274,11 @@ public class GamsRasterOptimiser {
 
 				IrrigationItem irrigItem = irrigRaster.get(key);
 
-				int cerealCat = findCategory(wheatlDivisions, yresp.getYieldNone(CropType.WHEAT) + yresp.getYieldMax(CropType.WHEAT));
-				int pastureCat = findCategory(pastureDivisions, yresp.getYieldNone(CropType.PASTURE) + yresp.getYieldMax(CropType.PASTURE));
-				Integer id = cerealCat + pastureCat * numCerealCats;
+				int clusterId = mapping.get(key).getInt();
 
-				YieldResponsesItem aggYResp = aggregatedYields.lazyGet(id);
-				LandUseItem aggLandUse = aggregatedAreas.lazyGet(id);
-				IrrigationItem aggIrig = aggregatedIrrigCosts.lazyGet(id);
-				mapping.lazyGet(id).add(key); 
+				YieldResponsesItem aggYResp = aggregatedYields.lazyGet(clusterId);
+				LandUseItem aggLandUse = aggregatedAreas.lazyGet(clusterId);
+				IrrigationItem aggIrig = aggregatedIrrigCosts.lazyGet(clusterId);
 
 				// Irrigation cost
 				double suitableAreaThisTime  = landUseItem.getSuitableLand();
@@ -348,8 +333,8 @@ public class GamsRasterOptimiser {
 
 		LogWriter.println("YieldResponsesItem: " + rasterInputData.getCountryInput().getCountry() + ", countFound=" + countFound + ", countMissing=" + countMissing);
 
-		for (Map.Entry<Integer, Set<RasterKey>> e : mapping.entrySet()) {
-			LogWriter.println(e.getKey() + " zone has " + e.getValue().size() + " raster areas");
+	//	for (Map.Entry<Integer, Set<RasterKey>> e : mapping.entrySet()) {
+		//	LogWriter.println(e.getKey() + " zone has " + e.getValue().size() + " raster areas");
 
 			/*	CropType[] cs = {CropType.WHEAT, CropType.MAIZE};
 			for (CropType c : cs) {
@@ -361,7 +346,7 @@ public class GamsRasterOptimiser {
 				}
 				LogWriter.println("");
 			} */
-		}
+	//	}
 		
 		double baseCropland = LandUseItem.getTotalLandCover(aggregatedAreas.values(), LandCoverType.CROPLAND);
 		double basePasture = LandUseItem.getTotalLandCover(aggregatedAreas.values(), LandCoverType.PASTURE);
@@ -387,52 +372,4 @@ public class GamsRasterOptimiser {
 		
 		return (aggV*aggArea + newV*newArea) / (aggArea + newArea);
 	}
-
-	private int findCategory(List<Double> divisions, double yield) {
-		int category;
-		int numDivisions = divisions.size();
-
-		for (category = 0; category<numDivisions; category++) {
-			if (yield < divisions.get(category)) {
-				break;
-			}
-		}
-		return category;
-	}
-
-	private List<Double> getDivisions(YieldRaster yieldRaster, CropType crop, int numCategories) {
-		List<Double> yieldValues = new ArrayList<Double>();
-
-		for (YieldResponsesItem yresp : yieldRaster.values()) {
-			if (yresp == null) {
-				if (DEBUG) LogWriter.println("GamsRasterOptimiser: Can't get value for crop " + crop);
-			}
-			else {
-				double d = yresp.getYieldNone(crop) + yresp.getYieldMax(crop) ;
-				//LogWriter.println("GamsRasterOptimiser: Got value for crop " + crop + " of " + d);
-
-				if (Double.isNaN(d) || d == 0.0) {
-					if (DEBUG) LogWriter.println("GamsRasterOptimiser: Got NaN or zero value for crop " + crop);
-				}
-				else {
-					yieldValues.add(d);
-				}
-			}
-		}
-
-		if (yieldValues.size() == 0) {
-			throw new RuntimeException("No yield values for country, crop = " + crop);
-		}
-
-		Collections.sort(yieldValues);
-
-		List<Double> divisions = new ArrayList<Double>();
-
-		for (int i=1; i < numCategories; i++) {
-			double d = yieldValues.get(yieldValues.size()*i/numCategories);
-			divisions.add(d);
-		}	
-
-		return divisions;
-	}
 }
diff --git a/src/ac/ed/lurg/country/gams/GamsRasterOutput.java b/src/ac/ed/lurg/country/gams/GamsRasterOutput.java
index 76bb49d6..5034f748 100644
--- a/src/ac/ed/lurg/country/gams/GamsRasterOutput.java
+++ b/src/ac/ed/lurg/country/gams/GamsRasterOutput.java
@@ -5,7 +5,6 @@ import java.util.Map;
 import ac.ed.lurg.landuse.CropUsageData;
 import ac.ed.lurg.landuse.LandUseItem;
 import ac.ed.lurg.types.CropType;
-import ac.sac.raster.IntegerRasterItem;
 import ac.sac.raster.RasterSet;
 
 public class GamsRasterOutput {
@@ -13,7 +12,6 @@ public class GamsRasterOutput {
 	private int status;
 	
 	private RasterSet<LandUseItem> landUses;
-	private RasterSet<IntegerRasterItem> locationIdRaster;
 	
 	private Map<CropType, CropUsageData> cropUsageData;
 	private Map<CropType, Double> cropAdjustments;
@@ -25,11 +23,10 @@ public class GamsRasterOutput {
 	}
 
 	public GamsRasterOutput(int status, RasterSet<LandUseItem> intensityRaster, Map<CropType, CropUsageData> cropUsageData, 
-			Map<CropType, Double> cropAdjustments, RasterSet<IntegerRasterItem> locationIdRaster) {
+			Map<CropType, Double> cropAdjustments) {
 		this(intensityRaster, cropUsageData);
 		this.status = status;
 		this.cropAdjustments = cropAdjustments;
-		this.locationIdRaster = locationIdRaster;
 	}
 	
 	public int getStatus() {
@@ -39,10 +36,6 @@ public class GamsRasterOutput {
 	public RasterSet<LandUseItem> getLandUses() {
 		return landUses;
 	}
-		
-	public RasterSet<IntegerRasterItem> getLocationIdRaster() {
-		return locationIdRaster;
-	}
 
 	public Map<CropType, CropUsageData> getCropUsageData() {
 		return cropUsageData;
diff --git a/src/ac/ed/lurg/country/gams/GamsRasterTest.java b/src/ac/ed/lurg/country/gams/GamsRasterTest.java
index 00ef4aef..39992302 100644
--- a/src/ac/ed/lurg/country/gams/GamsRasterTest.java
+++ b/src/ac/ed/lurg/country/gams/GamsRasterTest.java
@@ -21,7 +21,7 @@ public class GamsRasterTest extends GamsLocationTest {
 		GamsCountryInput countryLevelInputs = new GamsCountryInput(new CompositeCountry("Test"), getProjectedDemand(), getCountryPrices(), null);
 		GamsRasterInput input = new GamsRasterInput(new Timestep(0), getYieldRaster(), getPreviousAreaRaster(), getIrrigationCost(), countryLevelInputs);
 		
-		GamsRasterOptimiser opti = new GamsRasterOptimiser(input);		
+		GamsRasterOptimiser opti = new GamsRasterOptimiser(input, null);		
 		GamsRasterOutput output = opti.run();
 		LogWriter.println(output.toString());
 	}
diff --git a/src/ac/ed/lurg/utils/cluster/CentriodPoint.java b/src/ac/ed/lurg/utils/cluster/CentriodPoint.java
index 22a1e122..57142743 100644
--- a/src/ac/ed/lurg/utils/cluster/CentriodPoint.java
+++ b/src/ac/ed/lurg/utils/cluster/CentriodPoint.java
@@ -3,25 +3,27 @@ package ac.ed.lurg.utils.cluster;
 import java.util.Collection;
 import java.util.Map;
 
-public class CentriodPoint implements ClusteringPoint {
+public class CentriodPoint<K> implements ClusteringPoint<K> {
 
-	private Map<String, Double> valueMap;
+	private Map<K, Double> valueMap;
 
-	public CentriodPoint(Map<String, Double> valueMap) {
+	public CentriodPoint(Map<K, Double> valueMap) {
 		this.valueMap = valueMap;
 	}
 
-	public double getClusteringValue(String key)  {
+	@Override
+    public double getClusteringValue(K key) {
 		return valueMap.get(key);
 	}
 
-	public Collection<String> getAllClusteringKeys() {
+	@Override
+	public Collection<K> getAllClusteringKeys() {
 		return valueMap.keySet();
 	}
 
 	public String toString() {
 		StringBuffer sb = new StringBuffer();
-		for (Map.Entry<String, Double> e : valueMap.entrySet()) {
+		for (Map.Entry<K, Double> e : valueMap.entrySet()) {
 			sb.append(e.getKey() + "=" + e.getValue() + " ");
 		}
 		return sb.toString();
diff --git a/src/ac/ed/lurg/utils/cluster/Cluster.java b/src/ac/ed/lurg/utils/cluster/Cluster.java
index 91293201..d96f3640 100644
--- a/src/ac/ed/lurg/utils/cluster/Cluster.java
+++ b/src/ac/ed/lurg/utils/cluster/Cluster.java
@@ -5,24 +5,24 @@ import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Map;
 
-public class Cluster {
+public class Cluster<K, P extends ClusteringPoint<K>> {
 
-	private Collection<ClusteringPoint> points = new HashSet<ClusteringPoint>();
-	private ClusteringPoint centroid;
+	private Collection<P> points = new HashSet<P>();
+	private CentriodPoint<K> centroid;
 
-	public Cluster(ClusteringPoint centroid) {
+	public Cluster(CentriodPoint<K> centroid) {
 		this.centroid = centroid;
 	}
 
-	public Collection<ClusteringPoint> getPoints() {
+	public Collection<P> getPoints() {
 		return points;
 	}
 
-	protected void addPoint(ClusteringPoint point) {
+	protected void addPoint(P point) {
 		points.add(point);
 	}
 
-	public ClusteringPoint getCentroid() {
+	public CentriodPoint<K> getCentroid() {
 		return centroid;
 	}
 	
@@ -30,10 +30,10 @@ public class Cluster {
 		int n_points = points.size();
 		if(n_points == 0) return 0;
 
-		Map<String, Double> centroidValues = new HashMap<String, Double>();
+		Map<K, Double> centroidValues = new HashMap<K, Double>();
 
-		for(ClusteringPoint p : points) {
-	    	for (String key : p.getAllClusteringKeys()) {
+		for(ClusteringPoint<K> p : points) {
+	    	for (K key : p.getAllClusteringKeys()) {
 				Double soFar = centroidValues.get(key);
 				Double pointVal = p.getClusteringValue(key);
 				Double updated = soFar==null ? pointVal : pointVal + soFar;
@@ -41,10 +41,10 @@ public class Cluster {
 			}
 		}
 
-		for (Map.Entry<String, Double> e : centroidValues.entrySet())
+		for (Map.Entry<K, Double> e : centroidValues.entrySet())
 			centroidValues.put(e.getKey(), e.getValue()/n_points);
 
-		ClusteringPoint updatedCentroid = new CentriodPoint(centroidValues);
+		CentriodPoint<K> updatedCentroid = new CentriodPoint<K>(centroidValues);
 		double distanceMoved = distanceFromCentroid(updatedCentroid);
 		
 		centroid = updatedCentroid;
@@ -52,10 +52,10 @@ public class Cluster {
 	}
 	
     //Calculates the distance between two points.
-    protected double distanceFromCentroid(ClusteringPoint p) {
+    protected double distanceFromCentroid(ClusteringPoint<K> p) {
     	double squaredTotal=0;
-    	for (String key : centroid.getAllClusteringKeys()) {
-       		squaredTotal += Math.pow(p.getClusteringValue(key)-centroid.getClusteringValue(key), 2);
+    	for (K key : centroid.getAllClusteringKeys()) {
+       		squaredTotal += Math.pow(centroid.getClusteringValue(key)-p.getClusteringValue(key), 2);
     	}
     	
         return Math.sqrt(squaredTotal);
diff --git a/src/ac/ed/lurg/utils/cluster/ClusteringPoint.java b/src/ac/ed/lurg/utils/cluster/ClusteringPoint.java
index ca300134..62378bd0 100644
--- a/src/ac/ed/lurg/utils/cluster/ClusteringPoint.java
+++ b/src/ac/ed/lurg/utils/cluster/ClusteringPoint.java
@@ -3,9 +3,9 @@ package ac.ed.lurg.utils.cluster;
 import java.util.Collection;
 
 /** Interface that give information used in clustering */
-public interface ClusteringPoint {
+public interface ClusteringPoint<K> {
  
-    public double getClusteringValue(String key);
-	public Collection<String> getAllClusteringKeys();
+    public double getClusteringValue(K key);
+	public Collection<K> getAllClusteringKeys();
 
 }
\ No newline at end of file
diff --git a/src/ac/ed/lurg/utils/cluster/KMeans.java b/src/ac/ed/lurg/utils/cluster/KMeans.java
index e49a5230..e05c803b 100644
--- a/src/ac/ed/lurg/utils/cluster/KMeans.java
+++ b/src/ac/ed/lurg/utils/cluster/KMeans.java
@@ -1,27 +1,28 @@
 package ac.ed.lurg.utils.cluster;
 
+import java.util.ArrayList;
 import java.util.Collection;
 import java.util.HashMap;
-import java.util.HashSet;
+import java.util.List;
 import java.util.Map;
 import java.util.Random;
 
 import ac.ed.lurg.utils.LogWriter;
  
-public class KMeans {
+public class KMeans<K, P extends ClusteringPoint<K>> {
      
-    private Collection<ClusteringPoint> points;
-    private Collection<Cluster> clusters;
+    private Collection<P> points;
+    private List<Cluster<K,P>> clusters;
     
-    public KMeans(Collection<ClusteringPoint> points, int numClusters) {
+    private Map<K, Double> minValueMap = new HashMap<K, Double>();
+    private Map<K, Double> maxValueMap = new HashMap<K, Double>();
+
+    public KMeans(Collection<P> points, int numClusters) {
     	this.points = points;
-    	this.clusters = new HashSet<Cluster>();   
-    	
-        Map<String, Double> minValueMap = new HashMap<String, Double>();
-        Map<String, Double> maxValueMap = new HashMap<String, Double>();
+    	this.clusters = new ArrayList<Cluster<K,P>>();   
 
-        for (ClusteringPoint p : points) {
-	    	for (String key : p.getAllClusteringKeys()) {
+        for (ClusteringPoint<K> p : points) {
+	    	for (K key : p.getAllClusteringKeys()) {
 	    		double currentVal = p.getClusteringValue(key);
 	    		Double minV = minValueMap.get(key);
 	    		Double maxV = maxValueMap.get(key);
@@ -36,40 +37,51 @@ public class KMeans {
 	    	}
         }
     	
-		Random r = new Random();
-		double min, max, rand, v;
 		//Create Clusters, with random centroids
 		for (int i = 0; i < numClusters; i++) {
 			
-			Map<String, Double> randomCentroid = new HashMap<String, Double>();
-
-	    	for (Map.Entry<String, Double> e : minValueMap.entrySet()) {
-	    		min = e.getValue();
-	    		max = maxValueMap.get(e.getKey());
-	    		rand = r.nextDouble();
-	    		v = min + (max - min) * rand;
-	    		randomCentroid.put(e.getKey(), v);
-	    	}
-	    	
-	    	ClusteringPoint p = new CentriodPoint(randomCentroid);
-	    	LogWriter.println("Creating cluster at " + p);
-			Cluster cluster = new Cluster(p);
+			CentriodPoint<K> p = generateRandomCentriod();
+	 //   	LogWriter.println("Creating cluster at " + p);
+			Cluster<K,P> cluster = new Cluster<K,P>(p);
 			clusters.add(cluster);
 		}
     }
+
+	/** Builds a centroid whose value per key is a random interpolation between minValueMap and maxValueMap. */
+	private CentriodPoint<K> generateRandomCentriod() {
+		Random rng = new Random();
+		Map<K, Double> centroidValues = new HashMap<K, Double>();
+
+		for (Map.Entry<K, Double> lowerBound : minValueMap.entrySet()) {
+			K key = lowerBound.getKey();
+			double lower = lowerBound.getValue();
+			double upper = maxValueMap.get(key);
+			// fraction comes from nextDouble(), so it is in [0, 1)
+			double fraction = rng.nextDouble();
+			centroidValues.put(key, lower + (upper - lower) * fraction);
+		}
+
+		return new CentriodPoint<K>(centroidValues);
+	}
     
-	protected void printClusters() {
+	public void printClusters() {
 		int i=0;
-		for (Cluster c : clusters) {
+		int clustersWithPoints = 0;
+		for (Cluster<K,P> c : clusters) {
 			i++;
 			LogWriter.println("[Cluster: " + i+"]");
 			LogWriter.println("[Centroid: " + c.getCentroid() + "]");
-			LogWriter.println("[Points:");
-			for(ClusteringPoint p : c.getPoints())
+			LogWriter.println("[Points: (" + c.getPoints().size() + " points)");
+			for(P p : c.getPoints())
 				LogWriter.println(p.toString());
 			
+			if (c.getPoints().size() > 0)
+				clustersWithPoints++;
+			
 			LogWriter.println("]\n");
 		}
+		
+		LogWriter.println(clusters.size() + " clusters, of which " + clustersWithPoints + " have points");
 	}
 
     public double calculateClusters(int maxIterations, double tolerance) {
@@ -84,13 +96,13 @@ public class KMeans {
         	assignCluster();
         	
             //Calculate new centroids, and total distance between new and old centroids
-            for(Cluster c : clusters)
+            for(Cluster<K,P> c : clusters)
             	distance += c.calculateCentroid();
             
-        	LogWriter.println("#################");
-        	LogWriter.println("Iteration: " + iteration);
-        	LogWriter.println("Centroid distances: " + distance);
-        	printClusters();
+       // 	LogWriter.println("#################");
+       // 	LogWriter.println("Iteration: " + iteration);
+       // 	LogWriter.println("Centroid distances: " + distance);
+       // 	printClusters();
    	
         	if(distance <= tolerance || iteration > maxIterations) {
         		LogWriter.println("Finishing calculateClusters: Iteration " + iteration + ", centroid distances: " + distance);
@@ -100,12 +112,12 @@ public class KMeans {
         }
     }
     
-    public Collection<Cluster> getClusters() {
+    public List<Cluster<K,P>> getClusters() {
     	return clusters;
     }
     
     private void clearPointsFromAllClusters() {
-    	for(Cluster cluster : clusters) {
+    	for(Cluster<K,P> cluster : clusters) {
     		cluster.clearPoints();
     	}
     }
@@ -113,12 +125,12 @@ public class KMeans {
     private void assignCluster() {
         clearPointsFromAllClusters();
        
-        for(ClusteringPoint point : points) {
+        for(P point : points) {
         	double min = Double.MAX_VALUE;
-            Cluster cluster = null;                 
+            Cluster<K,P> cluster = null;                 
             double distance = 0.0; 
       	
-            for(Cluster c : clusters) {
+            for(Cluster<K,P> c : clusters) {
                 distance = c.distanceFromCentroid(point);
                 if(distance < min){
                     min = distance;
diff --git a/src/ac/ed/lurg/utils/cluster/KMeansTest.java b/src/ac/ed/lurg/utils/cluster/KMeansTest.java
index cbbd1cef..4ff77516 100644
--- a/src/ac/ed/lurg/utils/cluster/KMeansTest.java
+++ b/src/ac/ed/lurg/utils/cluster/KMeansTest.java
@@ -24,13 +24,13 @@ public class KMeansTest {
 	}
 	
 	public void doIt() {
-		KMeans kmeans = new KMeans(createRandomPoints(MIN_COORDINATE,MAX_COORDINATE,NUM_POINTS), NUM_CLUSTERS);
+		KMeans<String, ClusteringPoint<String>> kmeans = new KMeans<String, ClusteringPoint<String>>(createRandomPoints(MIN_COORDINATE,MAX_COORDINATE,NUM_POINTS), NUM_CLUSTERS);
 		kmeans.printClusters();
 		kmeans.calculateClusters(MAX_ITERATIONS, 0.6);
 	}
 	
     //Creates random point
-    protected static ClusteringPoint createRandomPoint(int min, int max) {
+    protected static ClusteringPoint<String> createRandomPoint(int min, int max) {
     	Random r = new Random();
     	double x = min + (max - min) * r.nextDouble();
     	double y = min + (max - min) * r.nextDouble();
@@ -41,12 +41,12 @@ public class KMeansTest {
     	values.put("Y", y);
     	values.put("Z", z);
 
-    	ClusteringPoint p = new CentriodPoint(values);
+    	ClusteringPoint<String> p = new CentriodPoint<String>(values);
     	return p;
     }
     
-    protected static List<ClusteringPoint> createRandomPoints(int min, int max, int number) {
-    	List<ClusteringPoint> points = new ArrayList<ClusteringPoint>(number);
+    protected static List<ClusteringPoint<String>> createRandomPoints(int min, int max, int number) {
+    	List<ClusteringPoint<String>> points = new ArrayList<ClusteringPoint<String>>(number);
     	for(int i = 0; i < number; i++) {
     		points.add(createRandomPoint(min,max));
     	}
diff --git a/src/ac/ed/lurg/yield/YieldClusterPoint.java b/src/ac/ed/lurg/yield/YieldClusterPoint.java
new file mode 100644
index 00000000..4b965f71
--- /dev/null
+++ b/src/ac/ed/lurg/yield/YieldClusterPoint.java
@@ -0,0 +1,105 @@
+package ac.ed.lurg.yield;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+
+import ac.ed.lurg.types.CropType;
+import ac.ed.lurg.utils.LogWriter;
+import ac.ed.lurg.utils.cluster.ClusteringPoint;
+import ac.sac.raster.RasterKey;
+
+/**
+ * A raster location described by a fixed set of yield values, exposed through
+ * {@link ClusteringPoint} so that it can be passed to the clustering utilities.
+ *
+ * The values are copied out of the {@link YieldResponsesItem} in the constructor, so later
+ * changes to that item are not reflected here.
+ */
+public class YieldClusterPoint implements ClusteringPoint<String> {
+
+	private static final String PASTURE = "pas";
+	private static final String WHEAT_MIN = "wheatMin";
+	private static final String WHEAT_MAX = "wheatMax";
+	private static final String RICE_MIN = "riceMin";
+	private static final String MAIZE_MIN = "maizeMin";
+	private static final String MAIZE_MAX = "maizeMax";
+	private static final String ROOTS_MIN = "rootsMin";
+
+	/** All clustering keys; built once and read-only because every point exposes the same keys. */
+	private static final List<String> ALL_KEYS = Collections.unmodifiableList(
+			Arrays.asList(PASTURE, WHEAT_MIN, WHEAT_MAX, RICE_MIN, MAIZE_MIN, MAIZE_MAX, ROOTS_MIN));
+
+	private final RasterKey rasterKey;
+	private final double wheatMin;
+	private final double wheatMax;
+	private final double riceMin;
+	private final double maizeMin;
+	private final double maizeMax;
+	private final double rootsMin;
+	private final double pasture;
+
+	/**
+	 * Reads the yields used for clustering from {@code yields} and stores them.
+	 *
+	 * @param rasterKey the raster key returned by {@link #getRasterKey()}
+	 * @param yields    source of the values: {@code getYieldNone} for wheat, rice, maize,
+	 *                  starchy roots and pasture, {@code getYieldMax} for wheat and maize
+	 */
+	public YieldClusterPoint(RasterKey rasterKey, YieldResponsesItem yields) {
+		this.rasterKey = rasterKey;
+
+		// TODO(review): would holding a reference to the YieldResponsesItem be better than caching these values?
+		this.wheatMin = yields.getYieldNone(CropType.WHEAT);
+		this.riceMin = yields.getYieldNone(CropType.RICE);
+		this.maizeMin = yields.getYieldNone(CropType.MAIZE);
+		this.rootsMin = yields.getYieldNone(CropType.STARCHY_ROOTS);
+		this.pasture = yields.getYieldNone(CropType.PASTURE);
+		this.wheatMax = yields.getYieldMax(CropType.WHEAT);
+		this.maizeMax = yields.getYieldMax(CropType.MAIZE);
+	}
+
+	/** @return the raster key given to the constructor */
+	public RasterKey getRasterKey() {
+		return rasterKey;
+	}
+
+	/**
+	 * Looks up the stored yield for one clustering key.
+	 *
+	 * @param key one of the keys from {@link #getAllClusteringKeys()}
+	 * @return the stored value; for any other key an error is logged and {@code Double.NaN} is returned
+	 */
+	@Override
+	public double getClusteringValue(String key) {
+		switch (key) {
+			case PASTURE:  return pasture;
+			case WHEAT_MIN:  return wheatMin;
+			case WHEAT_MAX:  return wheatMax;
+			case RICE_MIN:  return riceMin;
+			case MAIZE_MIN:  return maizeMin;
+			case MAIZE_MAX:  return maizeMax;
+			case ROOTS_MIN:  return rootsMin;
+			default:
+				LogWriter.printlnError("YieldClusterPoint.getClusteringValue: got unknown value " + key);
+				return Double.NaN;
+		}
+	}
+
+	/** @return the seven clustering keys as a shared, unmodifiable collection */
+	@Override
+	public Collection<String> getAllClusteringKeys() {
+		return ALL_KEYS;
+	}
+
+	/** @return the raster key followed by {@code key=value} for every clustering key */
+	@Override
+	public String toString() {
+		StringBuilder sb = new StringBuilder().append(rasterKey).append(": ");
+		for (String key : getAllClusteringKeys()) {
+			sb.append(key).append('=').append(getClusteringValue(key)).append(' ');
+		}
+		return sb.toString();
+	}
+}
-- 
GitLab