您好,登錄后才能下訂單哦!
Knn算法的核心思想是如果一個樣本在特征空間中的K個最相鄰的樣本中的大多數屬于某一個類別,則該樣本也屬于這個類別,并具有這個類別上樣本的特性。該方法在確定分類決策上只依據最鄰近的一個或者幾個樣本的類別來決定待分樣本所屬的類別。Knn方法在類別決策時,只與極少量的相鄰樣本有關。由于Knn方法主要靠周圍有限的鄰近的樣本,而不是靠判別類域的方法來確定所屬類別的,因此對于類域的交叉或重疊較多的待分樣本集來說,Knn方法較其他方法更為合適。
Knn算法流程如下:
1. 計算當前測試數據與訓練數據中的每條數據的距離
2. 圈定距離最近的K個訓練對象,作為測試對象的近鄰
3. 計算這K個訓練對象中出現最多的那個類別,并將這個類別作為當前測試數據的類別
以上流程是Knn的大致流程,按照這個流程實現的MR效率并不高,可以在這之上進行優化。在這里只寫,跟著這個流程走的MR實現過程。
Mapper的設計:
由于測試數據相比于訓練數據來說,會小很多,因此將測試數據用Java API讀取,放到內存中。所以,在setup中需要對測試數據進行初始化。在map中,計算當前測試數據與每條訓練數據的距離,Mapper的值類型為:<Object, Text, IntWritable,MyWritable>。map輸出鍵類型為IntWritable,存放當前測試數據的下標,輸出值類型為MyWritable,這是自定義值類型,其中存放的是距離以及與測試數據比較的訓練數據的類別。
public class KnnMapper extends Mapper<Object, Text, IntWritable,MyWritable> { Logger log = LoggerFactory.getLogger(KnnMapper.class); private List<float[]> testData; @Override protected void setup(Context context) throws IOException, InterruptedException { // TODO Auto-generated method stub Configuration conf= context.getConfiguration(); conf.set("fs.defaultFS", "master:8020"); String testPath= conf.get("TestFilePath"); Path testDataPath= new Path(testPath); FileSystem fs = FileSystem.get(conf); this.testData = readTestData(fs,testDataPath); } @Override protected void map(Object key, Text value, Context context) throws IOException, InterruptedException { // TODO Auto-generated method stub String[] line = value.toString().split(","); float[] trainData = new float[line.length-1]; for(int i=0;i<trainData.length;i++){ trainData[i] = Float.valueOf(line[i]); log.info("訓練數據:"+line[i]+"類別:"+line[line.length-1]); } for(int i=0; i< this.testData.size();i++){ float[] testI = this.testData.get(i); float distance = Outh(testI, trainData); log.info("距離:"+distance); context.write(new IntWritable(i), new MyWritable(distance, line[line.length-1])); } } private List<float[]> readTestData(FileSystem fs,Path Path) throws IOException { //補充代碼完整 FSDataInputStream data = fs.open(Path); BufferedReader bf = new BufferedReader(new InputStreamReader(data)); String line = ""; List<float[]> list = new ArrayList<>(); while ((line = bf.readLine()) != null) { String[] items = line.split(","); float[] item = new float[items.length]; for(int i=0;i<items.length;i++){ item[i] = Float.valueOf(items[i]); } list.add(item); } return list; } // 計算歐式距離 private static float Outh(float[] testData, float[] inData) { float distance =0.0f; for(int i=0;i<testData.length;i++){ distance += (testData[i]-inData[i])*(testData[i]-inData[i]); } distance = (float)Math.sqrt(distance); return distance; } }
自定義值類型MyWritable如下:
public class MyWritable implements Writable{ private float distance; private String label; public MyWritable() { // TODO Auto-generated constructor stub } public MyWritable(float distance, String label){ this.distance = distance; this.label = label; } @Override public String toString() { // TODO Auto-generated method stub return this.distance+","+this.label; } @Override public void write(DataOutput out) throws IOException { // TODO Auto-generated method stub out.writeFloat(distance); out.writeUTF(label); } @Override public void readFields(DataInput in) throws IOException { // TODO Auto-generated method stub this.distance = in.readFloat(); this.label = in.readUTF(); } public float getDistance() { return distance; } public void setDistance(float distance) { this.distance = distance; } public String getLabel() { return label; } public void setLabel(String label) { this.label = label; } }
在Reducer端中,需要初始化參數K,也就是圈定距離最近的K個對象的K值。在reduce中需要對距離按照從小到大的距離排序,然后選取前K條數據,再計算這K條數據中,出現次數最多的那個類別并將這個類別與測試數據的下標相對應并以K,V的形式輸出到HDFS上。
public class KnnReducer extends Reducer<IntWritable, MyWritable, IntWritable, Text> { private int K; @Override protected void setup(Context context) throws IOException, InterruptedException { // TODO Auto-generated method stub this.K = context.getConfiguration().getInt("K", 5); } @Override /*** * key => 0 * values =>([1,lable1],[2,lable2],[3,label2],[2.5,lable2]) */ protected void reduce(IntWritable key, Iterable<MyWritable> values, Context context) throws IOException, InterruptedException { // TODO Auto-generated method stub MyWritable[] mywrit = new MyWritable[K]; for(int i=0;i<K;i++){ mywrit[i] = new MyWritable(Float.MAX_VALUE, "-1"); } // 找出距離最小的前k個 for (MyWritable m : values) { float distance = m.getDistance(); String label = m.getLabel(); for(MyWritable m1: mywrit){ if (distance < m1.getDistance()){ m1.setDistance(distance); m1.setLabel(label); } } } // 找出前k個中,出現次數最多的類別 String[] testClass = new String[K]; for(int i=0;i<K;i++){ testClass[i] = mywrit[i].getLabel(); } String countMost = mostEle(testClass); context.write(key, new Text(countMost)); } public static String mostEle(String[] strArray) { HashMap<String, Integer> map = new HashMap<>(); for (int i = 0; i < strArray.length; i++) { String str = strArray[i]; if (map.containsKey(str)) { int tmp = map.get(str); map.put(str, tmp+1); }else{ map.put(str, 1); } } // 得到hashmap中值最大的鍵,也就是出現次數最多的類別 Collection<Integer> count = map.values(); int maxCount = Collections.max(count); String maxString = ""; for(Map.Entry<String, Integer> entry: map.entrySet()){ if (maxCount == entry.getValue()) { maxString = entry.getKey(); } } return maxString; } }
最后輸出結果如下:
以上就是本文的全部內容,希望對大家的學習有所幫助,也希望大家多多支持億速云。
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。