none
求助(为什么这个KNN算法时好时坏,一下可以运行,一下不可以) RRS feed

  • 问题

  • 头文件:
    
    #ifndef KNN_H
    #define KNN_H
    
    // the number of training points训练集的点数
    #define N 150 
    // the number of patterns
    #define M 3//类的个数
    #define TEST_NUM 100
    #define K 9
    
    typedef struct{
        double a;
        double b; 
    	double c;
    	double d;
        int label; // the class label of the training point类标号
    }TrainingPoint;
    
    typedef struct{
        double value; //the distance from the test point to a training point测试点距训练集的距离
        int num; 
    }Distance;
    
    /*
    * return the class label of the point(x, y), k specifies the number
    * of neighbors selected
    */
    
    //返回选择的K近邻
    
    int knn_classify(double a, double b,double c,double d, int k);
    /*
    * find the indexs of the k nearest neighbors in the training array,
    * and copy the indexs into the array "index"
    */
    
    //将K近邻的类标号放入数组
    
    void find_knn(double a, double b, double c,double d,int index[], int k);
    /*
    * using the labels of the k nearest neighbors to vote 
    * to get the most frequent label
    */
    
    //投票
    
    int vote(int index[], int k);
    /*
    * sort the distances and find the indexes of k nearest neighbors.
    * Here "qsort" function provided in ANSI C was used
    */
    
    //qsort 算法找出K近邻
    
    
    void quick_sort(Distance distance[], int n, int index[], int k);
    /*
    * the function needed in "qsort"; return 0 if dist1 == dist2;
    * return 1 if dist1 > dist2; else return -1
    */
    
    //qsort算法比较大小
    
    
    int compare(const void * dist1, const void * dist2);
    void print();
    
    #endif
    
    

    KNN的实现部分:

    #include "knn.h"
    #include <math.h>
    #include <stdlib.h>
    #include <stdio.h>
    
    extern TrainingPoint training_data [N];//
    extern double test_data[TEST_NUM][4];
    extern int test_label[TEST_NUM];//定义类标号;
    //extern int K;
    
    //用knn算法获得类标号
    int knn_classify(double a, double b,double c,double d, int k) {
        int * index;
        int label = 0;
        int i = 0;
        
        index = (int *)malloc(k * sizeof(int));
        //step 1: 找出最近邻
        find_knn(a, b,c,d, index, k);
        
        //step 2: 类标号投票
        label = vote(index, k);
        
        free(index);
        
        return label;
    }
    
    //从训练集中找出最近邻坐标
    void find_knn(double a, double b, double c,double d,int index[], int k) 
    {
        Distance distance[N];
        int i = 0;
        
        //计算欧式距离
        for (i = 0; i < N; i++) {
            distance[i].value = sqrt((a - training_data[i].a) * (a - training_data[i].a)
                    + (b - training_data[i].b) * (b - training_data[i].b)+(c - training_data[i].c) * (c - training_data[i].c)+(d - training_data[i].d) * (d - training_data[i].d));
            distance[i].num = i;
        }
        
        //排序(qsort algorithm最快速)
        quick_sort(distance, N, index, k);
    }
    
    
    //类标号投票
    int vote(int index[], int k)
    {
        int i = 0;
        int max = 0;
        int cur = 0;
        int labels[M];
        int label = 0;
        
        for(i = 0; i < M; i++) {
            labels[i] = 0;
        }
         
        for(i = 0; i < k; i++) {
            label = training_data[index[i]].label;
            labels[label-1]++;
        }
                
        cur = labels[0];
        for(i = 1; i < M; i++) {
            if(labels[i] > cur) {
                cur = labels[i];
                max = i;
            }
        }
        
        return (max+1);
    }
    
    
    //sort the distance using quick sort algorithm
    
    //qsort排序
    
    void quick_sort(Distance distance[], int n, int index[], int k)
    {
        int i = 0;
        int j = 0;
        int max = 0;
        
        
        qsort(distance, n, sizeof(Distance), compare);
        
        //get the indexes of the k shortest distances
        for(i = 0; i < k; i++) {
            index[i] = distance[i].num;
        }
    }
    
    
    int compare(const void * arg1, const void * arg2)
    {
        Distance* dist1;
        Distance* dist2;
        dist1 = (Distance*) arg1;
        dist2 = (Distance*) arg2;
        if(dist1->value > dist2->value) {
            return 1;
        } else {
            if (dist1->value < dist2->value) {
                return -1;
            } else {
                return 0;
            }
        }
    }
    
    
    



    main.cpp主要负责文件读写

    #include <stdio.h>
    #include <stdlib.h>
    #include <string.h>
    #include <sys/timeb.h>
    #include <time.h>
    #include "knn.h"
    
    
    TrainingPoint training_data [N];
    double test_data [TEST_NUM][4];
    int test_label [TEST_NUM];
    
    //read n training data from 'filename' to buffer
    
    //将训练数据从文件读入buffer
    
    void read_training_data(char * filename, int n);
    //read n test data from 'filename' to buffer
    
    //将测试数据从文件读入Buffer
    void read_test_data(char * filename, int n);
    //write n data in buffer to 'filename'
    
    //将结果从Buffer写入文件
    void write_data(char * filename, int n);
    
    int main(int argc, char *argv[])
    {
      int i = 0;
      printf("read training data!\n");
      
      read_training_data("./aa.txt", N);
       printf("read test data!\n");
      read_test_data("./ab.txt", TEST_NUM);
      
      for(i = 0; i < TEST_NUM; i++) {
          test_label[i] = knn_classify(test_data[i][0], test_data[i][1],test_data[i][2],test_data[i][3], K);
      }
      
      write_data("./test_result.txt", TEST_NUM);
        printf("Press ENTER to continue...\n");
      getchar();
      return 0;
    }
    
    //read training data from file
    void read_training_data(char * filename, int n) 
    {
        FILE * fp;
        int i = 0;
        char * token;
        char line[100];
        if((fp = fopen(filename, "r")) != NULL) {   
            for(i = 0; i < n; i++) {
                fgets(line, 100, fp);
                token = strtok(line, ",");
                training_data[i].a = atof(token) ;
                token = strtok(NULL, ",");
                training_data[i].b = atof(token);
                token = strtok(NULL, ",");
                training_data[i].c = atof(token) ;
    			token = strtok(NULL, ",");
               training_data[i].d = atof(token) ;
    			token = strtok(NULL, ",");
    			training_data[i].label = atoi(token);
            }
            fclose(fp);
        }   
    }
    
    //read test data from file
    void read_test_data(char * filename, int n)
    {
        FILE * fp;
        int i = 0;
        char * token;
        char line[100];
        if((fp = fopen(filename, "r")) != NULL) {   
            for(i = 0; i < n; i++) {
                fgets(line, 100, fp);
                token = strtok(line, ",");
                test_data[i][0] = atof(token) ;
                token = strtok(NULL, ",");
                test_data[i][1] = atof(token);
    			  token = strtok(NULL, ",");
                test_data[i][2] = atof(token);
    			  token = strtok(NULL, ",");
                test_data[i][3] = atof(token);
            }
            fclose(fp);
        }   
    }
    
    //write the classification results into a txt file
    void write_data(char * filename, int n) 
    {
    
    
        FILE * fp;
        int i = 0;
        
        if((fp = fopen(filename, "w")) != NULL) {
            for(i = 0; i < n; i++) {
                fprintf(fp, "%f\t%f\t%f\t%f\t%d\n", test_data[i][0], test_data[i][1], test_data[i][2],test_data[i][3],test_label[i]);
            }
            fclose(fp);
        }
    }
    
    

    为什么会出现这样的错误,谢谢好心人的指点,不胜感激~!!

    2011年12月13日 11:35

答案

  • 应该是你的代码有问题吧。

    我刚才试了,有异常Run-Time Check Failure #2 - Stack around the variable 'labels' was corrupted.

    也有warning,但不是warning C6385,是warning C4996, 如下:

    warning C4996: 'fopen': This function or variable may be unsafe. 

    warning C4996: 'strtok': This function or variable may be unsafe

     

    如果你的代码很健壮的话。应该不会出现性能不稳定的情况。建议多检查代码,然后单步调试。

     

     

    • 已建议为答案 Helen Zhao 2011年12月20日 2:07
    • 已标记为答案 Helen Zhao 2011年12月20日 7:48
    2011年12月15日 6:08

全部回复

  • 其中还有这个warning,请问是什么意思 warning C6385: 无效的数据: 访问“distance”时,“1*0”个字节可读,但可能读取了“32”个字节: Lines: 83, 84, 85, 88, 91, 92
    如何修改,谢谢

    2011年12月13日 11:39
  • 应该是你的代码有问题吧。

    我刚才试了,有异常Run-Time Check Failure #2 - Stack around the variable 'labels' was corrupted.

    也有warning,但不是warning C6385,是warning C4996, 如下:

    warning C4996: 'fopen': This function or variable may be unsafe. 

    warning C4996: 'strtok': This function or variable may be unsafe

     

    如果你的代码很健壮的话。应该不会出现性能不稳定的情况。建议多检查代码,然后单步调试。

     

     

    • 已建议为答案 Helen Zhao 2011年12月20日 2:07
    • 已标记为答案 Helen Zhao 2011年12月20日 7:48
    2011年12月15日 6:08