OSDN Git Service

a1a92c8de4b1198cfcfb765f4f53c4a6b0de8108
[stew/Stew4.git] / src / net / argius / stew / ui / window / ResultSetTableModel.java
1 package net.argius.stew.ui.window;
2
3 import static java.awt.EventQueue.invokeLater;
4 import static java.sql.Types.*;
5 import static java.util.Collections.nCopies;
6 import static net.argius.stew.text.TextUtilities.join;
7 import java.math.*;
8 import java.sql.*;
9 import java.util.*;
10 import java.util.Map.Entry;
11 import java.util.concurrent.*;
12 import java.util.regex.*;
13 import javax.swing.table.*;
14 import net.argius.stew.*;
15
16 /**
17  * The TableModel for ResultSetTable.
18  * It mainly provides to synchronize with databases.
19  */
20 final class ResultSetTableModel extends DefaultTableModel {
21
22     static final Logger log = Logger.getLogger(ResultSetTableModel.class);
23
24     private static final long serialVersionUID = -8861356207097438822L;
25     private static final String PTN1 = "\\s*SELECT\\s.+?\\sFROM\\s+([^\\s]+).*";
26
27     private final int[] types;
28     private final String commandString;
29
30     private Connection conn;
31     private Object tableName;
32     private String[] primaryKeys;
33     private boolean updatable;
34     private boolean linkable;
35
36     ResultSetTableModel(ResultSetReference ref) throws SQLException {
37         super(0, getColumnCount(ref));
38         ResultSet rs = ref.getResultSet();
39         ColumnOrder order = ref.getOrder();
40         final String cmd = ref.getCommandString();
41         final boolean orderIsEmpty = order.size() == 0;
42         ResultSetMetaData meta = rs.getMetaData();
43         final int columnCount = getColumnCount();
44         int[] types = new int[columnCount];
45         for (int i = 0; i < columnCount; i++) {
46             final int type;
47             final String name;
48             if (orderIsEmpty) {
49                 type = meta.getColumnType(i + 1);
50                 name = meta.getColumnName(i + 1);
51             } else {
52                 type = meta.getColumnType(order.getOrder(i));
53                 name = order.getName(i);
54             }
55             types[i] = type;
56             @SuppressWarnings({"unchecked", "unused"})
57             Object o = columnIdentifiers.set(i, name);
58         }
59         this.types = types;
60         this.commandString = cmd;
61         try {
62             analyzeForLinking(rs, cmd);
63         } catch (Exception ex) {
64             log.warn(ex);
65         }
66     }
67
68     private static final class UnlinkedRow extends Vector<Object> {
69
70         UnlinkedRow(Vector<?> rowData) {
71             super(rowData);
72         }
73
74         UnlinkedRow(Object[] rowData) {
75             super(rowData.length);
76             for (final Object o : rowData) {
77                 add(o);
78             }
79         }
80
81     }
82
83     private static int getColumnCount(ResultSetReference ref) throws SQLException {
84         final int size = ref.getOrder().size();
85         return (size == 0) ? ref.getResultSet().getMetaData().getColumnCount() : size;
86     }
87
88     @Override
89     public Class<?> getColumnClass(int columnIndex) {
90         switch (types[columnIndex]) {
91             case CHAR:
92             case VARCHAR:
93             case LONGVARCHAR:
94                 return String.class;
95             case BOOLEAN:
96             case BIT:
97                 return Boolean.class;
98             case TINYINT:
99                 return Byte.class;
100             case SMALLINT:
101                 return Short.class;
102             case INTEGER:
103                 return Integer.class;
104             case BIGINT:
105                 return Long.class;
106             case REAL:
107                 return Float.class;
108             case DOUBLE:
109             case FLOAT:
110                 return Double.class;
111             case DECIMAL:
112             case NUMERIC:
113                 return BigDecimal.class;
114             default:
115                 return Object.class;
116         }
117     }
118
119     @Override
120     public boolean isCellEditable(int row, int column) {
121         if (primaryKeys == null || primaryKeys.length == 0) {
122             return false;
123         }
124         return super.isCellEditable(row, column);
125     }
126
127     @Override
128     public void setValueAt(Object newValue, int row, int column) {
129         if (!linkable) {
130             return;
131         }
132         final Object oldValue = getValueAt(row, column);
133         final boolean changed;
134         if (newValue == null) {
135             changed = (newValue != oldValue);
136         } else {
137             changed = !newValue.equals(oldValue);
138         }
139         if (changed) {
140             if (isLinkedRow(row)) {
141                 Object[] keys = columnIdentifiers.toArray();
142                 try {
143                     executeUpdate(getRowData(keys, row), keys[column], newValue);
144                 } catch (Exception ex) {
145                     log.error(ex);
146                     throw new RuntimeException(ex);
147                 }
148             } else {
149                 if (log.isTraceEnabled()) {
150                     log.debug("update unlinked row");
151                 }
152             }
153         } else {
154             if (log.isDebugEnabled()) {
155                 log.debug("skip to update");
156             }
157         }
158         super.setValueAt(newValue, row, column);
159     }
160
161     void addUnlinkedRow(Object[] rowData) {
162         addUnlinkedRow(convertToVector(rowData));
163     }
164
165     void addUnlinkedRow(Vector<?> rowData) {
166         addRow(new UnlinkedRow(rowData));
167     }
168
169     void insertUnlinkedRow(int row, Object[] rowData) {
170         insertUnlinkedRow(row, new UnlinkedRow(rowData));
171     }
172
173     void insertUnlinkedRow(int row, Vector<?> rowData) {
174         insertRow(row, new UnlinkedRow(rowData));
175     }
176
177     /**
178      * Links a row with database.
179      * @param rowIndex
180      * @return true if it successed, false if already linked
181      * @throws SQLException failed to link by SQL error
182      */
183     boolean linkRow(int rowIndex) throws SQLException {
184         if (isLinkedRow(rowIndex)) {
185             return false;
186         }
187         executeInsert(getRowData(columnIdentifiers.toArray(), rowIndex));
188         @SuppressWarnings("unchecked")
189         Vector<Object> rows = getDataVector();
190         rows.set(rowIndex, new Vector<Object>((Vector<?>)rows.get(rowIndex)));
191         return true;
192     }
193
194     /**
195      * Removes a linked row.
196      * @param rowIndex
197      * @return true if it successed, false if already linked
198      * @throws SQLException failed to link by SQL error
199      */
200     boolean removeLinkedRow(int rowIndex) throws SQLException {
201         if (!isLinkedRow(rowIndex)) {
202             return false;
203         }
204         executeDelete(getRowData(columnIdentifiers.toArray(), rowIndex));
205         super.removeRow(rowIndex);
206         return true;
207     }
208
209     private Map<Object, Object> getRowData(Object[] keys, int rowIndex) {
210         Map<Object, Object> rowData = new LinkedHashMap<Object, Object>();
211         for (int columnIndex = 0, n = keys.length; columnIndex < n; columnIndex++) {
212             rowData.put(keys[columnIndex], getValueAt(rowIndex, columnIndex));
213         }
214         return rowData;
215     }
216
217     /**
218      * Sorts this table.
219      * @param columnIndex
220      * @param descending
221      */
222     void sort(final int columnIndex, boolean descending) {
223         final int f = (descending) ? -1 : 1;
224         @SuppressWarnings("unchecked")
225         List<List<Object>> dataVector = getDataVector();
226         Collections.sort(dataVector, new RowComparator(f, columnIndex));
227     }
228
229     private static final class RowComparator implements Comparator<List<Object>> {
230
231         private final int f;
232         private final int columnIndex;
233
234         RowComparator(int f, int columnIndex) {
235             this.f = f;
236             this.columnIndex = columnIndex;
237         }
238
239         @Override
240         public int compare(List<Object> row1, List<Object> row2) {
241             return c(row1, row2) * f;
242         }
243
244         private int c(List<Object> row1, List<Object> row2) {
245             if (row1 == null || row2 == null) {
246                 return row1 == null ? row2 == null ? 0 : -1 : 1;
247             }
248             final Object o1 = row1.get(columnIndex);
249             final Object o2 = row2.get(columnIndex);
250             if (o1 == null || o2 == null) {
251                 return o1 == null ? o2 == null ? 0 : -1 : 1;
252             }
253             if (o1 instanceof Comparable<?> && o1.getClass() == o2.getClass()) {
254                 @SuppressWarnings("unchecked")
255                 Comparable<Object> c1 = (Comparable<Object>)o1;
256                 @SuppressWarnings("unchecked")
257                 Comparable<Object> c2 = (Comparable<Object>)o2;
258                 return c1.compareTo(c2);
259             }
260             return o1.toString().compareTo(o2.toString());
261         }
262
263     }
264
265     /**
266      * Checks whether this table is updatable.
267      * @return
268      */
269     boolean isUpdatable() {
270         return updatable;
271     }
272
273     /**
274      * Checks whether this table is linkable.
275      * @return
276      */
277     boolean isLinkable() {
278         return linkable;
279     }
280
281     /**
282      * Checks whether the specified row is linked.
283      * @param rowIndex
284      * @return
285      */
286     boolean isLinkedRow(int rowIndex) {
287         return !(getDataVector().get(rowIndex) instanceof UnlinkedRow);
288     }
289
290     /**
291      * Checks whether this table has unlinked rows.
292      * @return
293      */
294     boolean hasUnlinkedRows() {
295         for (final Object row : getDataVector()) {
296             if (row instanceof UnlinkedRow) {
297                 return true;
298             }
299         }
300         return false;
301     }
302
303     /**
304      * Checks whether specified connection is same as the connection it has.
305      * @return
306      */
307     boolean isSameConnection(Connection conn) {
308         return conn == this.conn;
309     }
310
311     /**
312      * Returns the command string that creates this.
313      * @return
314      */
315     String getCommandString() {
316         return commandString;
317     }
318
319     private void executeUpdate(Map<Object, Object> keyMap, Object targetKey, Object targetValue) throws SQLException {
320         final String sql = String.format("UPDATE %s SET %s=? WHERE %s",
321                                          tableName,
322                                          quoteIfNeeds(targetKey),
323                                          toKeyPhrase(primaryKeys));
324         List<Object> a = new ArrayList<Object>();
325         a.add(targetValue);
326         for (Object pk : primaryKeys) {
327             a.add(keyMap.get(pk));
328         }
329         executeSql(sql, a.toArray());
330     }
331
332     private void executeInsert(Map<Object, Object> rowData) throws SQLException {
333         final int dataSize = rowData.size();
334         List<Object> keys = new ArrayList<Object>(dataSize);
335         List<Object> values = new ArrayList<Object>(dataSize);
336         for (Entry<?, ?> entry : rowData.entrySet()) {
337             keys.add(quoteIfNeeds(String.valueOf(entry.getKey())));
338             values.add(entry.getValue());
339         }
340         final String sql = String.format("INSERT INTO %s (%s) VALUES (%s)",
341                                          tableName,
342                                          join(",", keys),
343                                          join(",", nCopies(dataSize, "?")));
344         executeSql(sql, values.toArray());
345     }
346
347     private void executeDelete(Map<Object, Object> keyMap) throws SQLException {
348         final String sql = String.format("DELETE FROM %s WHERE %s",
349                                          tableName,
350                                          toKeyPhrase(primaryKeys));
351         List<Object> a = new ArrayList<Object>();
352         for (Object pk : primaryKeys) {
353             a.add(keyMap.get(pk));
354         }
355         executeSql(sql, a.toArray());
356     }
357
358     private void executeSql(final String sql, final Object[] parameters) throws SQLException {
359         if (log.isDebugEnabled()) {
360             log.debug("SQL: " + sql);
361             log.debug("parameters: " + Arrays.asList(parameters));
362         }
363         final CountDownLatch latch = new CountDownLatch(1);
364         final List<SQLException> errors = new ArrayList<SQLException>();
365         final Connection conn = this.conn;
366         final int[] types = this.types;
367         // asynchronous execution
368         class SqlTask implements Runnable {
369             @Override
370             public void run() {
371                 try {
372                     if (conn.isClosed()) {
373                         throw new SQLException(ResourceManager.Default.get("e.not-connect"));
374                     }
375                     final PreparedStatement stmt = conn.prepareStatement(sql);
376                     try {
377                         ValueTransporter transporter = ValueTransporter.getInstance("");
378                         int index = 0;
379                         for (Object o : parameters) {
380                             boolean isNull = false;
381                             if (o == null || String.valueOf(o).length() == 0) {
382                                 if (getColumnClass(index) != String.class) {
383                                     isNull = true;
384                                 }
385                             }
386                             ++index;
387                             if (isNull) {
388                                 stmt.setNull(index, types[index - 1]);
389                             } else {
390                                 transporter.setObject(stmt, index, o);
391                             }
392                         }
393                         final int updatedCount = stmt.executeUpdate();
394                         if (updatedCount != 1) {
395                             throw new SQLException("updated count is not 1, but " + updatedCount);
396                         }
397                     } finally {
398                         stmt.close();
399                     }
400                 } catch (SQLException ex) {
401                     log.error(ex);
402                     errors.add(ex);
403                 } catch (Throwable th) {
404                     log.error(th);
405                     SQLException ex = new SQLException();
406                     ex.initCause(th);
407                     errors.add(ex);
408                 }
409                 latch.countDown();
410             }
411         }
412         DaemonThreadFactory.execute(new SqlTask());
413         try {
414             // waits for a task to stop
415             latch.await(3L, TimeUnit.SECONDS);
416         } catch (InterruptedException ex) {
417             throw new RuntimeException(ex);
418         }
419         if (latch.getCount() != 0) {
420             class SqlTaskErrorHandler implements Runnable {
421                 @Override
422                 public void run() {
423                     try {
424                         latch.await();
425                     } catch (InterruptedException ex) {
426                         log.warn(ex);
427                     }
428                     if (!errors.isEmpty()) {
429                         class ErrorNotifier implements Runnable {
430                             @Override
431                             public void run() {
432                                 WindowOutputProcessor.showErrorDialog(null, errors.get(0));
433                             }
434                         }
435                         invokeLater(new ErrorNotifier());
436                     }
437                 }
438             }
439             DaemonThreadFactory.execute(new SqlTaskErrorHandler());
440         } else if (!errors.isEmpty()) {
441             if (log.isDebugEnabled()) {
442                 for (final Exception ex : errors) {
443                     log.debug("", ex);
444                 }
445             }
446             throw errors.get(0);
447         }
448     }
449
450     private static String toKeyPhrase(Object[] keys) {
451         List<String> a = new ArrayList<String>(keys.length);
452         for (final Object key : keys) {
453             a.add(String.format("%s=?", key));
454         }
455         return join(" AND ", a);
456     }
457
458     private static String quoteIfNeeds(Object o) {
459         final String s = String.valueOf(o);
460         if (s.matches(".*\\W.*")) {
461             return String.format("\"%s\"", s);
462         }
463         return s;
464     }
465
466     private void analyzeForLinking(ResultSet rs, String cmd) throws SQLException {
467         if (rs == null) {
468             return;
469         }
470         Statement stmt = rs.getStatement();
471         if (stmt == null) {
472             return;
473         }
474         Connection conn = stmt.getConnection();
475         if (conn == null) {
476             return;
477         }
478         this.conn = conn;
479         if (conn.isReadOnly()) {
480             return;
481         }
482         final String tableName = findTableName(cmd);
483         if (tableName.length() == 0) {
484             return;
485         }
486         this.tableName = tableName;
487         this.updatable = true;
488         List<String> pkList = findPrimaryKeys(conn, tableName);
489         if (pkList.isEmpty()) {
490             return;
491         }
492         @SuppressWarnings("unchecked")
493         final Collection<Object> columnIdentifiers = this.columnIdentifiers;
494         if (!columnIdentifiers.containsAll(pkList)) {
495             return;
496         }
497         if (findUnion(cmd)) {
498             return;
499         }
500         this.primaryKeys = pkList.toArray(new String[pkList.size()]);
501         this.linkable = true;
502     }
503
504     /**
505      * Finds a table name.
506      * @param cmd command string or SQL
507      * @return table name if it found only a table, or empty string
508      */
509     static String findTableName(String cmd) {
510         if (cmd != null) {
511             StringBuilder buffer = new StringBuilder();
512             Scanner scanner = new Scanner(cmd);
513             try {
514                 while (scanner.hasNextLine()) {
515                     final String line = scanner.nextLine();
516                     buffer.append(line.replaceAll("/\\*.*?\\*/|//.*", ""));
517                     buffer.append(' ');
518                 }
519             } finally {
520                 scanner.close();
521             }
522             Pattern p = Pattern.compile(PTN1, Pattern.CASE_INSENSITIVE);
523             Matcher m = p.matcher(buffer);
524             if (m.matches()) {
525                 String afterFrom = m.group(1);
526                 String[] words = afterFrom.split("\\s");
527                 boolean foundComma = false;
528                 for (int i = 0; i < 2 && i < words.length; i++) {
529                     String word = words[i];
530                     if (word.indexOf(',') >= 0) {
531                         foundComma = true;
532                     }
533                 }
534                 if (!foundComma) {
535                     String word = words[0];
536                     if (word.matches("[A-Za-z0-9_\\.]+")) {
537                         return word;
538                     }
539                 }
540             }
541         }
542         return "";
543     }
544
545     private static List<String> findPrimaryKeys(Connection conn, String tableName) throws SQLException {
546         DatabaseMetaData dbmeta = conn.getMetaData();
547         final String cp0;
548         final String sp0;
549         final String tp0;
550         if (tableName.contains(".")) {
551             String[] splitted = tableName.split("\\.");
552             if (splitted.length >= 3) {
553                 cp0 = splitted[0];
554                 sp0 = splitted[1];
555                 tp0 = splitted[2];
556             } else {
557                 cp0 = null;
558                 sp0 = splitted[0];
559                 tp0 = splitted[1];
560             }
561         } else {
562             cp0 = null;
563             sp0 = dbmeta.getUserName();
564             tp0 = tableName;
565         }
566         final String cp;
567         final String sp;
568         final String tp;
569         if (dbmeta.storesLowerCaseIdentifiers()) {
570             cp = (cp0 == null) ? null : cp0.toLowerCase();
571             sp = (sp0 == null) ? null : sp0.toLowerCase();
572             tp = tp0.toLowerCase();
573         } else if (dbmeta.storesUpperCaseIdentifiers()) {
574             cp = (cp0 == null) ? null : cp0.toUpperCase();
575             sp = (sp0 == null) ? null : sp0.toUpperCase();
576             tp = tp0.toUpperCase();
577         } else {
578             cp = cp0;
579             sp = sp0;
580             tp = tp0;
581         }
582         if (cp == null && sp == null) {
583             return getPrimaryKeys(dbmeta, null, null, tp);
584         }
585         List<String> a = getPrimaryKeys(dbmeta, cp, sp, tp);
586         if (a.isEmpty()) {
587             return getPrimaryKeys(dbmeta, null, null, tp);
588         }
589         return a;
590     }
591
592     private static List<String> getPrimaryKeys(DatabaseMetaData dbmeta,
593                                                String catalog,
594                                                String schema,
595                                                String table) throws SQLException {
596         ResultSet rs = dbmeta.getPrimaryKeys(catalog, schema, table);
597         try {
598             List<String> pkList = new ArrayList<String>();
599             Set<String> schemaSet = new HashSet<String>();
600             while (rs.next()) {
601                 pkList.add(rs.getString(4));
602                 schemaSet.add(rs.getString(2));
603             }
604             if (schemaSet.size() != 1) {
605                 return Collections.emptyList();
606             }
607             return pkList;
608         } finally {
609             rs.close();
610         }
611     }
612
613     private static boolean findUnion(String sql) {
614         String s = sql;
615         if (s.indexOf('\'') >= 0) {
616             if (s.indexOf("\\'") >= 0) {
617                 s = s.replaceAll("\\'", "");
618             }
619             s = s.replaceAll("'[^']+'", "''");
620         }
621         StringTokenizer tokenizer = new StringTokenizer(s);
622         while (tokenizer.hasMoreTokens()) {
623             if (tokenizer.nextToken().equalsIgnoreCase("UNION")) {
624                 return true;
625             }
626         }
627         return false;
628     }
629
630 }