OSDN Git Service

18bf1835c92403f25e5ae439d9e340fdae6acb73
[stigmata/stigmata.git] / src / main / java / jp / sourceforge / stigmata / birthmarks / comparators / DPMatchingBirthmarkComparator.java
1 package jp.sourceforge.stigmata.birthmarks.comparators;
2
3 /*
4  * $Id$
5  */
6
7 import jp.sourceforge.stigmata.Birthmark;
8 import jp.sourceforge.stigmata.BirthmarkContext;
9 import jp.sourceforge.stigmata.BirthmarkElement;
10 import jp.sourceforge.stigmata.spi.BirthmarkSpi;
11
12 /**
13  * calculate similarities between two birthmarks by DP matching algorithm.
14  *
15  * @author Haruaki TAMADA
16  */
17 public class DPMatchingBirthmarkComparator extends AbstractBirthmarkComparator{
18     private int mismatchPenalty = 5;
19     private int shiftPenalty = 1;
20
21     public DPMatchingBirthmarkComparator(BirthmarkSpi spi){
22         super(spi);
23     }
24
25     public int getMismatchPenalty(){
26         return mismatchPenalty;
27     }
28
29     public void setMismatchPenalty(int mismatchPenalty){
30         this.mismatchPenalty = mismatchPenalty;
31     }
32
33     public int getShiftPenalty(){
34         return shiftPenalty;
35     }
36
37     public void setShiftPenalty(int shiftPenalty){
38         this.shiftPenalty = shiftPenalty;
39     }
40
41     @Override
42     public double compare(Birthmark b1, Birthmark b2, BirthmarkContext context){
43         if(!b1.getType().equals(b2.getType())){
44             return Double.NaN;
45         }
46
47         BirthmarkElement[] element1 = b1.getElements();
48         BirthmarkElement[] element2 = b2.getElements();
49         if(element1.length > 0 && element2.length > 0){
50             int[][] cost = createCostMatrix(element1, element2);
51
52             int max = (element1.length + element2.length) * (getMismatchPenalty() + getShiftPenalty());
53             int distance = cost[element1.length - 1][element2.length - 1];
54
55             return (double)(max - distance) / max;
56         }
57         else if(element1.length == 0 && element2.length == 0){
58             return 1d;
59         }
60         else{
61             return 0d;
62         }
63     }
64
65     @Override
66     public int getCompareCount(Birthmark b1, Birthmark b2){
67         return b1.getElementCount() + b2.getElementCount();
68     }
69
70     private int[][] createCostMatrix(BirthmarkElement[] targetX, BirthmarkElement[] targetY){
71         int[][] mismatches = getMismatchMatrix(targetX, targetY);
72         int[][] cost = new int[targetX.length][targetY.length];
73
74         cost[0][0] = mismatches[0][0] * getMismatchPenalty();
75
76         for(int i = 1; i < targetX.length; i++){
77             cost[i][0] = cost[i - 1][0] + getShiftPenalty() + mismatches[i][0] * getMismatchPenalty();
78         }
79         for(int i = 1; i < targetY.length; i++){
80             cost[0][i] = cost[0][i - 1] + getShiftPenalty() + mismatches[0][i] * getMismatchPenalty();
81         }
82         for(int i = 1; i < targetX.length; i++){
83             for(int j = 1; j < targetY.length; j++){
84                 int crossCost      = cost[i - 1][j - 1] + mismatches[i][j] * getMismatchPenalty();
85                 int horizontalCost = cost[i - 1][j    ] + mismatches[i][j] * getMismatchPenalty() + getShiftPenalty();
86                 int verticalCost   = cost[i    ][j - 1] + mismatches[i][j] * getMismatchPenalty() + getShiftPenalty();
87
88                 if(crossCost <= horizontalCost && crossCost <= verticalCost){
89                     cost[i][j] = crossCost;
90                 }
91                 else if(horizontalCost <= verticalCost){
92                     cost[i][j] = horizontalCost;
93                 }
94                 else{
95                     cost[i][j] = verticalCost;
96                 }
97             }
98         }
99         return cost;
100     }
101
102     private int[][] getMismatchMatrix(BirthmarkElement[] targetX, BirthmarkElement[] targetY){
103         int[][] mismatches = new int[targetX.length][targetY.length];
104
105         for(int i = 0; i < mismatches.length; i++){
106             for(int j = 0; j < mismatches[i].length; j++){
107                 if(targetX[i] == null){
108                     if(targetY[j] == null)            mismatches[i][j] = 0;
109                     else                              mismatches[i][j] = 1;
110                 }
111                 else{
112                     if(targetX[i].equals(targetY[j])) mismatches[i][j] = 0;
113                     else                              mismatches[i][j] = 1;
114                 }
115             }
116         }
117         return mismatches;
118     }
119 }