OSDN Git Service

オーバーライドしているメソッド全てに @Override アノテーションを追加した.
[stigmata/stigmata-core.git] / src / main / java / jp / sourceforge / stigmata / birthmarks / comparators / DPMatchingBirthmarkComparator.java
1 package jp.sourceforge.stigmata.birthmarks.comparators;
2
3 /*
4  * $Id$
5  */
6
7 import jp.sourceforge.stigmata.Birthmark;
8 import jp.sourceforge.stigmata.BirthmarkContext;
9 import jp.sourceforge.stigmata.BirthmarkElement;
10 import jp.sourceforge.stigmata.spi.BirthmarkSpi;
11
12 /**
13  * calculate similarities between two birthmarks by DP matching algorithm.
14  *
15  * @author Haruaki TAMADA
16  * @version $Revision$ 
17  */
18 public class DPMatchingBirthmarkComparator extends AbstractBirthmarkComparator{
19     private int mismatchPenalty = 5;
20     private int shiftPenalty = 1;
21
22     public DPMatchingBirthmarkComparator(BirthmarkSpi spi){
23         super(spi);
24     }
25
26     public int getMismatchPenalty(){
27         return mismatchPenalty;
28     }
29
30     public void setMismatchPenalty(int mismatchPenalty){
31         this.mismatchPenalty = mismatchPenalty;
32     }
33
34     public int getShiftPenalty(){
35         return shiftPenalty;
36     }
37
38     public void setShiftPenalty(int shiftPenalty){
39         this.shiftPenalty = shiftPenalty;
40     }
41
42     @Override
43     public double compare(Birthmark b1, Birthmark b2, BirthmarkContext context){
44         if(!b1.getType().equals(b2.getType())){
45             return Double.NaN;
46         }
47
48         BirthmarkElement[] element1 = b1.getElements();
49         BirthmarkElement[] element2 = b2.getElements();
50         if(element1.length > 0 && element2.length > 0){
51             int[][] cost = createCostMatrix(element1, element2);
52
53             int max = (element1.length + element2.length) * (getMismatchPenalty() + getShiftPenalty());
54             int distance = cost[element1.length - 1][element2.length - 1];
55
56             return (double)(max - distance) / max;
57         }
58         else if(element1.length == 0 && element2.length == 0){
59             return 1d;
60         }
61         else{
62             return 0d;
63         }
64     }
65
66     @Override
67     public int getCompareCount(Birthmark b1, Birthmark b2){
68         return b1.getElementCount() + b2.getElementCount();
69     }
70
71     private int[][] createCostMatrix(BirthmarkElement[] targetX, BirthmarkElement[] targetY){
72         int[][] mismatches = getMismatchMatrix(targetX, targetY);
73         int[][] cost = new int[targetX.length][targetY.length];
74
75         cost[0][0] = mismatches[0][0] * getMismatchPenalty();
76
77         for(int i = 1; i < targetX.length; i++){
78             cost[i][0] = cost[i - 1][0] + getShiftPenalty() + mismatches[i][0] * getMismatchPenalty();
79         }
80         for(int i = 1; i < targetY.length; i++){
81             cost[0][i] = cost[0][i - 1] + getShiftPenalty() + mismatches[0][i] * getMismatchPenalty();
82         }
83         for(int i = 1; i < targetX.length; i++){
84             for(int j = 1; j < targetY.length; j++){
85                 int crossCost      = cost[i - 1][j - 1] + mismatches[i][j] * getMismatchPenalty();
86                 int horizontalCost = cost[i - 1][j    ] + mismatches[i][j] * getMismatchPenalty() + getShiftPenalty();
87                 int verticalCost   = cost[i    ][j - 1] + mismatches[i][j] * getMismatchPenalty() + getShiftPenalty();
88
89                 if(crossCost <= horizontalCost && crossCost <= verticalCost){
90                     cost[i][j] = crossCost;
91                 }
92                 else if(horizontalCost <= verticalCost){
93                     cost[i][j] = horizontalCost;
94                 }
95                 else{
96                     cost[i][j] = verticalCost;
97                 }
98             }
99         }
100         return cost;
101     }
102
103     private int[][] getMismatchMatrix(BirthmarkElement[] targetX, BirthmarkElement[] targetY){
104         int[][] mismatches = new int[targetX.length][targetY.length];
105
106         for(int i = 0; i < mismatches.length; i++){
107             for(int j = 0; j < mismatches[i].length; j++){
108                 if(targetX[i] == null){
109                     if(targetY[j] == null)            mismatches[i][j] = 0;
110                     else                              mismatches[i][j] = 1;
111                 }
112                 else{
113                     if(targetX[i].equals(targetY[j])) mismatches[i][j] = 0;
114                     else                              mismatches[i][j] = 1;
115                 }
116             }
117         }
118         return mismatches;
119     }
120 }