OSDN Git Service

Delete Subversion Tags (Revision, Id)
[stigmata/stigmata.git] / src / main / java / jp / sourceforge / stigmata / birthmarks / comparators / CosineSimilarityBirthmarkComparator.java
1 package jp.sourceforge.stigmata.birthmarks.comparators;
2
3 import java.util.HashMap;
4 import java.util.Map;
5
6 import jp.sourceforge.stigmata.Birthmark;
7 import jp.sourceforge.stigmata.BirthmarkContext;
8 import jp.sourceforge.stigmata.BirthmarkElement;
9 import jp.sourceforge.stigmata.birthmarks.ValueCountable;
10 import jp.sourceforge.stigmata.spi.BirthmarkSpi;
11
12 /**
13  * Comparing birthmarks by cosine similarity algorithm. This class compares
14  * birthmarks which must be implemented
15  * {@link ValueCountable <code>ValueCountable</code>} interface.
16  * 
17  * @author Haruaki Tamada
18  */
19 public class CosineSimilarityBirthmarkComparator extends
20         AbstractBirthmarkComparator{
21
22     public CosineSimilarityBirthmarkComparator(BirthmarkSpi spi){
23         super(spi);
24     }
25
26     @Override
27     public double compare(Birthmark b1, Birthmark b2, BirthmarkContext context){
28         if(!b1.getType().equals(b2.getType())){
29             return Double.NaN;
30         }
31         if(b1.getElementCount() == 0 && b2.getElementCount() == 0){
32             return 1d;
33         }
34         else if(b1.getElementCount() == 0 || b2.getElementCount() == 0){
35             return 0d;
36         }
37
38         Map<String, CountPair> pairs = new HashMap<String, CountPair>();
39         addCount(pairs, b1, true);
40         addCount(pairs, b2, false);
41
42         double norm1 = norm(pairs, true);
43         double norm2 = norm(pairs, false);
44         double product = innerproduct(pairs);
45         double similarity = product / (norm1 * norm2);
46         // System.out.printf("%g / (%g * %g) = %g%n", product, norm1, norm2, similarity);
47
48         // double radian = Math.acos(product / (norm1 * norm2));
49         // double angle = 90 - (180 * radian / Math.PI);
50         // double sim = angle / 90;
51         // System.out.printf("angle: %g (%g�x, %g)%n", radian, angle, sim);
52
53         return similarity;
54     }
55
56     private double innerproduct(Map<String, CountPair> pairs){
57         double sum = 0;
58         for(CountPair pair: pairs.values()){
59             sum += pair.get(true) * pair.get(false);
60         }
61         return sum;
62     }
63
64     private double norm(Map<String, CountPair> pairs, boolean first){
65         double sum = 0;
66         for(CountPair pair: pairs.values()){
67             sum += pair.get(first) * pair.get(first);
68         }
69         return Math.sqrt(sum);
70     }
71
72     private void addCount(Map<String, CountPair> pairs, Birthmark birthmark, boolean first){
73         for(BirthmarkElement element: birthmark){
74             ValueCountable vc = (ValueCountable)element;
75             CountPair cp = pairs.get(vc.getValueName());
76             if(cp == null){
77                 cp = new CountPair();
78                 pairs.put(vc.getValueName(), cp);
79             }
80             cp.set(first, vc.getValueCount());
81         }
82     }
83
84     private class CountPair{
85         private int c1 = 0;
86         private int c2 = 0;
87
88         public int get(boolean first){
89             if(first){
90                 return c1;
91             }
92             else{
93                 return c2;
94             }
95         }
96
97         public void set(boolean first, int count){
98             if(first){
99                 c1 = count;
100             }
101             else{
102                 c2 = count;
103             }
104         }
105     }
106
107     /**
108      * This method is used for debugging.
109      */
110     @SuppressWarnings("unused")
111     private void printAll(Map<String, CountPair> pairs){
112         System.out.println("----------");
113         for(Map.Entry<String, CountPair> entry: pairs.entrySet()){
114             CountPair pair = entry.getValue();
115             System.out.printf("%40s: %5d, %5d%n", entry.getKey(), pair.get(true), pair.get(false));
116         }
117     }
118 }