OSDN Git Service

make it possible to call WindowOutputProcessor.showErrorDialog from non
[stew/Stew4.git] / src / net / argius / stew / ui / window / ResultSetTableModel.java
1 package net.argius.stew.ui.window;
2
3 import static java.sql.Types.*;
4 import static java.util.Collections.nCopies;
5 import static net.argius.stew.text.TextUtilities.join;
6 import java.math.*;
7 import java.sql.*;
8 import java.util.*;
9 import java.util.Map.Entry;
10 import java.util.concurrent.*;
11 import java.util.regex.*;
12 import javax.swing.table.*;
13 import net.argius.stew.*;
14
15 /**
16  * The TableModel for ResultSetTable.
17  * It mainly provides to synchronize with databases.
18  */
19 final class ResultSetTableModel extends DefaultTableModel {
20
21     static final Logger log = Logger.getLogger(ResultSetTableModel.class);
22
23     private static final long serialVersionUID = -8861356207097438822L;
24     private static final String PTN1 = "\\s*SELECT\\s.+?\\sFROM\\s+([^\\s]+).*";
25
26     private final int[] types;
27     private final String commandString;
28
29     private Connection conn;
30     private Object tableName;
31     private String[] primaryKeys;
32     private boolean updatable;
33     private boolean linkable;
34
35     ResultSetTableModel(ResultSetReference ref) throws SQLException {
36         super(0, getColumnCount(ref));
37         ResultSet rs = ref.getResultSet();
38         ColumnOrder order = ref.getOrder();
39         final String cmd = ref.getCommandString();
40         final boolean orderIsEmpty = order.size() == 0;
41         ResultSetMetaData meta = rs.getMetaData();
42         final int columnCount = getColumnCount();
43         int[] types = new int[columnCount];
44         for (int i = 0; i < columnCount; i++) {
45             final int type;
46             final String name;
47             if (orderIsEmpty) {
48                 type = meta.getColumnType(i + 1);
49                 name = meta.getColumnName(i + 1);
50             } else {
51                 type = meta.getColumnType(order.getOrder(i));
52                 name = order.getName(i);
53             }
54             types[i] = type;
55             @SuppressWarnings({"unchecked", "unused"})
56             Object o = columnIdentifiers.set(i, name);
57         }
58         this.types = types;
59         this.commandString = cmd;
60         try {
61             analyzeForLinking(rs, cmd);
62         } catch (Exception ex) {
63             log.warn(ex);
64         }
65     }
66
67     private static final class UnlinkedRow extends Vector<Object> {
68
69         UnlinkedRow(Vector<?> rowData) {
70             super(rowData);
71         }
72
73         UnlinkedRow(Object[] rowData) {
74             super(rowData.length);
75             for (final Object o : rowData) {
76                 add(o);
77             }
78         }
79
80     }
81
82     private static int getColumnCount(ResultSetReference ref) throws SQLException {
83         final int size = ref.getOrder().size();
84         return (size == 0) ? ref.getResultSet().getMetaData().getColumnCount() : size;
85     }
86
87     @Override
88     public Class<?> getColumnClass(int columnIndex) {
89         switch (types[columnIndex]) {
90             case CHAR:
91             case VARCHAR:
92             case LONGVARCHAR:
93                 return String.class;
94             case BOOLEAN:
95             case BIT:
96                 return Boolean.class;
97             case TINYINT:
98                 return Byte.class;
99             case SMALLINT:
100                 return Short.class;
101             case INTEGER:
102                 return Integer.class;
103             case BIGINT:
104                 return Long.class;
105             case REAL:
106                 return Float.class;
107             case DOUBLE:
108             case FLOAT:
109                 return Double.class;
110             case DECIMAL:
111             case NUMERIC:
112                 return BigDecimal.class;
113             default:
114                 return Object.class;
115         }
116     }
117
118     @Override
119     public boolean isCellEditable(int row, int column) {
120         if (primaryKeys == null || primaryKeys.length == 0) {
121             return false;
122         }
123         return super.isCellEditable(row, column);
124     }
125
126     @Override
127     public void setValueAt(Object newValue, int row, int column) {
128         if (!linkable) {
129             return;
130         }
131         final Object oldValue = getValueAt(row, column);
132         final boolean changed;
133         if (newValue == null) {
134             changed = (newValue != oldValue);
135         } else {
136             changed = !newValue.equals(oldValue);
137         }
138         if (changed) {
139             if (isLinkedRow(row)) {
140                 Object[] keys = columnIdentifiers.toArray();
141                 try {
142                     executeUpdate(getRowData(keys, row), keys[column], newValue);
143                 } catch (Exception ex) {
144                     log.error(ex);
145                     throw new RuntimeException(ex);
146                 }
147             } else {
148                 if (log.isTraceEnabled()) {
149                     log.debug("update unlinked row");
150                 }
151             }
152         } else {
153             if (log.isDebugEnabled()) {
154                 log.debug("skip to update");
155             }
156         }
157         super.setValueAt(newValue, row, column);
158     }
159
160     void addUnlinkedRow(Object[] rowData) {
161         addUnlinkedRow(convertToVector(rowData));
162     }
163
164     void addUnlinkedRow(Vector<?> rowData) {
165         addRow(new UnlinkedRow(rowData));
166     }
167
168     void insertUnlinkedRow(int row, Object[] rowData) {
169         insertUnlinkedRow(row, new UnlinkedRow(rowData));
170     }
171
172     void insertUnlinkedRow(int row, Vector<?> rowData) {
173         insertRow(row, new UnlinkedRow(rowData));
174     }
175
176     /**
177      * Links a row with database.
178      * @param rowIndex
179      * @return true if it successed, false if already linked
180      * @throws SQLException failed to link by SQL error
181      */
182     boolean linkRow(int rowIndex) throws SQLException {
183         if (isLinkedRow(rowIndex)) {
184             return false;
185         }
186         executeInsert(getRowData(columnIdentifiers.toArray(), rowIndex));
187         @SuppressWarnings("unchecked")
188         Vector<Object> rows = getDataVector();
189         rows.set(rowIndex, new Vector<Object>((Vector<?>)rows.get(rowIndex)));
190         return true;
191     }
192
193     /**
194      * Removes a linked row.
195      * @param rowIndex
196      * @return true if it successed, false if already linked
197      * @throws SQLException failed to link by SQL error
198      */
199     boolean removeLinkedRow(int rowIndex) throws SQLException {
200         if (!isLinkedRow(rowIndex)) {
201             return false;
202         }
203         executeDelete(getRowData(columnIdentifiers.toArray(), rowIndex));
204         super.removeRow(rowIndex);
205         return true;
206     }
207
208     private Map<Object, Object> getRowData(Object[] keys, int rowIndex) {
209         Map<Object, Object> rowData = new LinkedHashMap<Object, Object>();
210         for (int columnIndex = 0, n = keys.length; columnIndex < n; columnIndex++) {
211             rowData.put(keys[columnIndex], getValueAt(rowIndex, columnIndex));
212         }
213         return rowData;
214     }
215
216     /**
217      * Sorts this table.
218      * @param columnIndex
219      * @param descending
220      */
221     void sort(final int columnIndex, boolean descending) {
222         final int f = (descending) ? -1 : 1;
223         @SuppressWarnings("unchecked")
224         List<List<Object>> dataVector = getDataVector();
225         Collections.sort(dataVector, new RowComparator(f, columnIndex));
226     }
227
228     private static final class RowComparator implements Comparator<List<Object>> {
229
230         private final int f;
231         private final int columnIndex;
232
233         RowComparator(int f, int columnIndex) {
234             this.f = f;
235             this.columnIndex = columnIndex;
236         }
237
238         @Override
239         public int compare(List<Object> row1, List<Object> row2) {
240             return c(row1, row2) * f;
241         }
242
243         private int c(List<Object> row1, List<Object> row2) {
244             if (row1 == null || row2 == null) {
245                 return row1 == null ? row2 == null ? 0 : -1 : 1;
246             }
247             final Object o1 = row1.get(columnIndex);
248             final Object o2 = row2.get(columnIndex);
249             if (o1 == null || o2 == null) {
250                 return o1 == null ? o2 == null ? 0 : -1 : 1;
251             }
252             if (o1 instanceof Comparable<?> && o1.getClass() == o2.getClass()) {
253                 @SuppressWarnings("unchecked")
254                 Comparable<Object> c1 = (Comparable<Object>)o1;
255                 @SuppressWarnings("unchecked")
256                 Comparable<Object> c2 = (Comparable<Object>)o2;
257                 return c1.compareTo(c2);
258             }
259             return o1.toString().compareTo(o2.toString());
260         }
261
262     }
263
264     /**
265      * Checks whether this table is updatable.
266      * @return
267      */
268     boolean isUpdatable() {
269         return updatable;
270     }
271
272     /**
273      * Checks whether this table is linkable.
274      * @return
275      */
276     boolean isLinkable() {
277         return linkable;
278     }
279
280     /**
281      * Checks whether the specified row is linked.
282      * @param rowIndex
283      * @return
284      */
285     boolean isLinkedRow(int rowIndex) {
286         return !(getDataVector().get(rowIndex) instanceof UnlinkedRow);
287     }
288
289     /**
290      * Checks whether this table has unlinked rows.
291      * @return
292      */
293     boolean hasUnlinkedRows() {
294         for (final Object row : getDataVector()) {
295             if (row instanceof UnlinkedRow) {
296                 return true;
297             }
298         }
299         return false;
300     }
301
302     /**
303      * Checks whether specified connection is same as the connection it has.
304      * @return
305      */
306     boolean isSameConnection(Connection conn) {
307         return conn == this.conn;
308     }
309
310     /**
311      * Returns the command string that creates this.
312      * @return
313      */
314     String getCommandString() {
315         return commandString;
316     }
317
318     private void executeUpdate(Map<Object, Object> keyMap, Object targetKey, Object targetValue) throws SQLException {
319         final String sql = String.format("UPDATE %s SET %s=? WHERE %s",
320                                          tableName,
321                                          quoteIfNeeds(targetKey),
322                                          toKeyPhrase(primaryKeys));
323         List<Object> a = new ArrayList<Object>();
324         a.add(targetValue);
325         for (Object pk : primaryKeys) {
326             a.add(keyMap.get(pk));
327         }
328         executeSql(sql, a.toArray());
329     }
330
331     private void executeInsert(Map<Object, Object> rowData) throws SQLException {
332         final int dataSize = rowData.size();
333         List<Object> keys = new ArrayList<Object>(dataSize);
334         List<Object> values = new ArrayList<Object>(dataSize);
335         for (Entry<?, ?> entry : rowData.entrySet()) {
336             keys.add(quoteIfNeeds(String.valueOf(entry.getKey())));
337             values.add(entry.getValue());
338         }
339         final String sql = String.format("INSERT INTO %s (%s) VALUES (%s)",
340                                          tableName,
341                                          join(",", keys),
342                                          join(",", nCopies(dataSize, "?")));
343         executeSql(sql, values.toArray());
344     }
345
346     private void executeDelete(Map<Object, Object> keyMap) throws SQLException {
347         final String sql = String.format("DELETE FROM %s WHERE %s",
348                                          tableName,
349                                          toKeyPhrase(primaryKeys));
350         List<Object> a = new ArrayList<Object>();
351         for (Object pk : primaryKeys) {
352             a.add(keyMap.get(pk));
353         }
354         executeSql(sql, a.toArray());
355     }
356
357     private void executeSql(final String sql, final Object[] parameters) throws SQLException {
358         if (log.isDebugEnabled()) {
359             log.debug("SQL: " + sql);
360             log.debug("parameters: " + Arrays.asList(parameters));
361         }
362         final CountDownLatch latch = new CountDownLatch(1);
363         final List<SQLException> errors = new ArrayList<SQLException>();
364         final Connection conn = this.conn;
365         final int[] types = this.types;
366         // asynchronous execution
367         class SqlTask implements Runnable {
368             @Override
369             public void run() {
370                 try {
371                     if (conn.isClosed()) {
372                         throw new SQLException(ResourceManager.Default.get("e.not-connect"));
373                     }
374                     final PreparedStatement stmt = conn.prepareStatement(sql);
375                     try {
376                         ValueTransporter transporter = ValueTransporter.getInstance("");
377                         int index = 0;
378                         for (Object o : parameters) {
379                             boolean isNull = false;
380                             if (o == null || String.valueOf(o).length() == 0) {
381                                 if (getColumnClass(index) != String.class) {
382                                     isNull = true;
383                                 }
384                             }
385                             ++index;
386                             if (isNull) {
387                                 stmt.setNull(index, types[index - 1]);
388                             } else {
389                                 transporter.setObject(stmt, index, o);
390                             }
391                         }
392                         final int updatedCount = stmt.executeUpdate();
393                         if (updatedCount != 1) {
394                             throw new SQLException("updated count is not 1, but " + updatedCount);
395                         }
396                     } finally {
397                         stmt.close();
398                     }
399                 } catch (SQLException ex) {
400                     log.error(ex);
401                     errors.add(ex);
402                 } catch (Throwable th) {
403                     log.error(th);
404                     SQLException ex = new SQLException();
405                     ex.initCause(th);
406                     errors.add(ex);
407                 }
408                 latch.countDown();
409             }
410         }
411         DaemonThreadFactory.execute(new SqlTask());
412         try {
413             // waits for a task to stop
414             latch.await(3L, TimeUnit.SECONDS);
415         } catch (InterruptedException ex) {
416             throw new RuntimeException(ex);
417         }
418         if (latch.getCount() != 0) {
419             class SqlTaskErrorHandler implements Runnable {
420                 @Override
421                 public void run() {
422                     try {
423                         latch.await();
424                     } catch (InterruptedException ex) {
425                         log.warn(ex);
426                     }
427                     if (!errors.isEmpty()) {
428                         WindowOutputProcessor.showErrorDialog(null, errors.get(0));
429                     }
430                 }
431             }
432             DaemonThreadFactory.execute(new SqlTaskErrorHandler());
433         } else if (!errors.isEmpty()) {
434             if (log.isDebugEnabled()) {
435                 for (final Exception ex : errors) {
436                     log.debug("", ex);
437                 }
438             }
439             throw errors.get(0);
440         }
441     }
442
443     private static String toKeyPhrase(Object[] keys) {
444         List<String> a = new ArrayList<String>(keys.length);
445         for (final Object key : keys) {
446             a.add(String.format("%s=?", key));
447         }
448         return join(" AND ", a);
449     }
450
451     private static String quoteIfNeeds(Object o) {
452         final String s = String.valueOf(o);
453         if (s.matches(".*\\W.*")) {
454             return String.format("\"%s\"", s);
455         }
456         return s;
457     }
458
459     private void analyzeForLinking(ResultSet rs, String cmd) throws SQLException {
460         if (rs == null) {
461             return;
462         }
463         Statement stmt = rs.getStatement();
464         if (stmt == null) {
465             return;
466         }
467         Connection conn = stmt.getConnection();
468         if (conn == null) {
469             return;
470         }
471         this.conn = conn;
472         if (conn.isReadOnly()) {
473             return;
474         }
475         final String tableName = findTableName(cmd);
476         if (tableName.length() == 0) {
477             return;
478         }
479         this.tableName = tableName;
480         this.updatable = true;
481         List<String> pkList = findPrimaryKeys(conn, tableName);
482         if (pkList.isEmpty()) {
483             return;
484         }
485         @SuppressWarnings("unchecked")
486         final Collection<Object> columnIdentifiers = this.columnIdentifiers;
487         if (!columnIdentifiers.containsAll(pkList)) {
488             return;
489         }
490         if (findUnion(cmd)) {
491             return;
492         }
493         this.primaryKeys = pkList.toArray(new String[pkList.size()]);
494         this.linkable = true;
495     }
496
497     /**
498      * Finds a table name.
499      * @param cmd command string or SQL
500      * @return table name if it found only a table, or empty string
501      */
502     static String findTableName(String cmd) {
503         if (cmd != null) {
504             StringBuilder buffer = new StringBuilder();
505             Scanner scanner = new Scanner(cmd);
506             try {
507                 while (scanner.hasNextLine()) {
508                     final String line = scanner.nextLine();
509                     buffer.append(line.replaceAll("/\\*.*?\\*/|//.*", ""));
510                     buffer.append(' ');
511                 }
512             } finally {
513                 scanner.close();
514             }
515             Pattern p = Pattern.compile(PTN1, Pattern.CASE_INSENSITIVE);
516             Matcher m = p.matcher(buffer);
517             if (m.matches()) {
518                 String afterFrom = m.group(1);
519                 String[] words = afterFrom.split("\\s");
520                 boolean foundComma = false;
521                 for (int i = 0; i < 2 && i < words.length; i++) {
522                     String word = words[i];
523                     if (word.indexOf(',') >= 0) {
524                         foundComma = true;
525                     }
526                 }
527                 if (!foundComma) {
528                     String word = words[0];
529                     if (word.matches("[A-Za-z0-9_\\.]+")) {
530                         return word;
531                     }
532                 }
533             }
534         }
535         return "";
536     }
537
538     private static List<String> findPrimaryKeys(Connection conn, String tableName) throws SQLException {
539         DatabaseMetaData dbmeta = conn.getMetaData();
540         final String cp0;
541         final String sp0;
542         final String tp0;
543         if (tableName.contains(".")) {
544             String[] splitted = tableName.split("\\.");
545             if (splitted.length >= 3) {
546                 cp0 = splitted[0];
547                 sp0 = splitted[1];
548                 tp0 = splitted[2];
549             } else {
550                 cp0 = null;
551                 sp0 = splitted[0];
552                 tp0 = splitted[1];
553             }
554         } else {
555             cp0 = null;
556             sp0 = dbmeta.getUserName();
557             tp0 = tableName;
558         }
559         final String cp;
560         final String sp;
561         final String tp;
562         if (dbmeta.storesLowerCaseIdentifiers()) {
563             cp = (cp0 == null) ? null : cp0.toLowerCase();
564             sp = (sp0 == null) ? null : sp0.toLowerCase();
565             tp = tp0.toLowerCase();
566         } else if (dbmeta.storesUpperCaseIdentifiers()) {
567             cp = (cp0 == null) ? null : cp0.toUpperCase();
568             sp = (sp0 == null) ? null : sp0.toUpperCase();
569             tp = tp0.toUpperCase();
570         } else {
571             cp = cp0;
572             sp = sp0;
573             tp = tp0;
574         }
575         if (cp == null && sp == null) {
576             return getPrimaryKeys(dbmeta, null, null, tp);
577         }
578         List<String> a = getPrimaryKeys(dbmeta, cp, sp, tp);
579         if (a.isEmpty()) {
580             return getPrimaryKeys(dbmeta, null, null, tp);
581         }
582         return a;
583     }
584
585     private static List<String> getPrimaryKeys(DatabaseMetaData dbmeta,
586                                                String catalog,
587                                                String schema,
588                                                String table) throws SQLException {
589         ResultSet rs = dbmeta.getPrimaryKeys(catalog, schema, table);
590         try {
591             List<String> pkList = new ArrayList<String>();
592             Set<String> schemaSet = new HashSet<String>();
593             while (rs.next()) {
594                 pkList.add(rs.getString(4));
595                 schemaSet.add(rs.getString(2));
596             }
597             if (schemaSet.size() != 1) {
598                 return Collections.emptyList();
599             }
600             return pkList;
601         } finally {
602             rs.close();
603         }
604     }
605
606     private static boolean findUnion(String sql) {
607         String s = sql;
608         if (s.indexOf('\'') >= 0) {
609             if (s.indexOf("\\'") >= 0) {
610                 s = s.replaceAll("\\'", "");
611             }
612             s = s.replaceAll("'[^']+'", "''");
613         }
614         StringTokenizer tokenizer = new StringTokenizer(s);
615         while (tokenizer.hasMoreTokens()) {
616             if (tokenizer.nextToken().equalsIgnoreCase("UNION")) {
617                 return true;
618             }
619         }
620         return false;
621     }
622
623 }