OSDN Git Service

OSSCが作成したPlan制御ツール。
[pghintplan/pg_hint_plan.git] / pg_hint_plan.c
1 /*-------------------------------------------------------------------------
2  *
3  * pg_hint_plan.c
4  *              Per-relation planner hints: enable or disable scan and join methods through cost_hook.
5  *
6  * Copyright (c) 2011, PostgreSQL Global Development Group
7  *
8  * IDENTIFICATION
9  *        contrib/pg_hint_plan/pg_hint_plan.c
10  *
11  *-------------------------------------------------------------------------
12  */
13 #include "postgres.h"
14 #include "fmgr.h"
15 #include "utils/elog.h"
16 #include "utils/builtins.h"
17 #include "utils/memutils.h"
18 #include "optimizer/cost.h"
19
20 #ifdef PG_MODULE_MAGIC
21 PG_MODULE_MAGIC;
22 #endif
23
24 #define HASH_ENTRIES 201
25
/*
 * Bit flags kept in HashEntry.enforce_mask.  A set bit means the
 * corresponding scan or join method is enabled for the hinted relations.
 *
 * Note: as written, TYPE_BITS is a variable of this anonymous enum type,
 * not a type name.
 */
enum
{
	ENABLE_SEQSCAN		= 1 << 0,
	ENABLE_INDEXSCAN	= 1 << 1,
	ENABLE_BITMAPSCAN	= 1 << 2,
	ENABLE_TIDSCAN		= 1 << 3,
	ENABLE_NESTLOOP		= 1 << 4,
	ENABLE_MERGEJOIN	= 1 << 5,
	ENABLE_HASHJOIN		= 1 << 6
} TYPE_BITS;
36
37 typedef struct tidlist
38 {
39         int nrels;
40         Oid *oids;
41 } TidList;
42
43 typedef struct hash_entry
44 {
45         TidList tidlist;
46         unsigned char enforce_mask;
47         struct hash_entry *next;
48 } HashEntry;
49
50 static HashEntry *hashent[HASH_ENTRIES];
51 static bool (*org_cost_hook)(CostHookType type, PlannerInfo *root, Path *path1, Path *path2);
52 static bool print_log = false;
53 static bool tweak_enabled = true;
54
55 /* Module callbacks */
56 void            _PG_init(void);
57 void            _PG_fini(void);
58 Datum           pg_add_hint(PG_FUNCTION_ARGS);
59 Datum           pg_clear_hint(PG_FUNCTION_ARGS);
60 Datum       pg_dump_hint(PG_FUNCTION_ARGS);
61 Datum           pg_enable_hint(bool arg, bool *isnull);
62 Datum           pg_enable_log(bool arg, bool *isnull);
63
64 static char *rels_str(PlannerInfo *root, Path *path);
65 static void dump_rels(char *label, PlannerInfo *root, Path *path, bool found, bool enabled);
66 static void dump_joinrels(char *label, PlannerInfo *root, Path *inpath, Path *outpath, bool found, bool enabled);
67 static bool my_cost_hook(CostHookType type, PlannerInfo *root, Path *path1, Path *path2);
68 static void free_hashent(HashEntry *head);
69 static unsigned int calc_hash(TidList *tidlist);
70 static HashEntry *search_ent(TidList *tidlist);
71
72 PG_FUNCTION_INFO_V1(pg_add_hint);
73 PG_FUNCTION_INFO_V1(pg_clear_hint);
74
75 /*
76  * Module load callbacks
77  */
78 void
79 _PG_init(void)
80 {
81         int i;
82
83         org_cost_hook = cost_hook;
84         cost_hook = my_cost_hook;
85         
86         for (i = 0 ; i < HASH_ENTRIES ; i++)
87                 hashent[i] = NULL;
88 }
89
90 /*
91  * Module unload callback
92  */
93 void
94 _PG_fini(void)
95 {
96         int i;
97
98         cost_hook = org_cost_hook;
99
100         for (i = 0 ; i < HASH_ENTRIES ; i++)
101         {
102                 free_hashent(hashent[i]);
103                 hashent[i] = NULL;
104
105         }
106 }
107
108 char *rels_str(PlannerInfo *root, Path *path)
109 {
110         char buf[4096];                                                         
111         int relid;
112         int first = 1;
113         Bitmapset *tmpbms;
114
115         if (path->pathtype == T_Invalid) return strdup("");
116
117         tmpbms = bms_copy(path->parent->relids);
118
119         buf[0] = 0;
120         while ((relid = bms_first_member(tmpbms)) >= 0)
121         {
122                 char idbuf[8];
123                 snprintf(idbuf, sizeof(idbuf), first ? "%d" : ", %d",
124                                  root->simple_rte_array[relid]->relid);
125                 if (strlen(buf) + strlen(idbuf) < sizeof(buf))
126                         strcat(buf, idbuf);
127                 first = 0;
128         }
129
130         return strdup(buf);
131 }
132
133 static int oidsortcmp(const void *a, const void *b)
134 {
135         const Oid oida = *((const Oid *)a);
136         const Oid oidb = *((const Oid *)b);
137
138         return oida - oidb;
139 }
140
141 static TidList *maketidlist(PlannerInfo *root, Path *path1, Path *path2)
142 {
143         int relid;
144         Path *paths[2] = {path1, path2};
145         int i;
146         int j = 0;
147         int nrels = 0;
148         TidList *ret = (TidList *)malloc(sizeof(TidList));
149
150         for (i = 0 ; i < 2 ; i++)
151         {
152                 if (paths[i] != NULL)
153                         nrels += bms_num_members(paths[i]->parent->relids);
154         }
155
156         ret->nrels = nrels;
157         ret->oids = (Oid *)malloc(nrels * sizeof(Oid));
158
159         for (i = 0 ; i < 2 ; i++)
160         {
161                 Bitmapset *tmpbms;
162
163                 if (paths[i] == NULL) continue;
164
165                 tmpbms= bms_copy(paths[i]->parent->relids);
166
167                 while ((relid = bms_first_member(tmpbms)) >= 0)
168                         ret->oids[j++] = root->simple_rte_array[relid]->relid;
169         }
170
171         if (nrels > 1)
172                 qsort(ret->oids, nrels, sizeof(Oid), oidsortcmp);
173
174         return ret;
175 }
176
177 static void free_tidlist(TidList *tidlist)
178 {
179         if (tidlist)
180         {
181                 if (tidlist->oids)
182                         free(tidlist->oids);
183                 free(tidlist);
184         }
185 }
186
187 int n = 0;
188 static void dump_rels(char *label, PlannerInfo *root, Path *path, bool found, bool enabled)
189 {
190         char *relsstr;
191
192         if (!print_log) return;
193         relsstr = rels_str(root, path);
194         ereport(LOG, (errmsg_internal("%04d: %s for relation %s (%s, %s)\n",
195                                                                   n++, label, relsstr,
196                                                                   found ? "found" : "not found",
197                                                                   enabled ? "enabled" : "disabled")));
198         free(relsstr);
199 }
200
201 void dump_joinrels(char *label, PlannerInfo *root, Path *inpath, Path *outpath,
202                                    bool found, bool enabled)
203 {
204         char *irelstr, *orelstr;
205
206         if (!print_log) return;
207         irelstr = rels_str(root, inpath);
208         orelstr = rels_str(root, outpath);
209
210         ereport(LOG, (errmsg_internal("%04d: %s for relation ((%s),(%s)) (%s, %s)\n",
211                                                                   n++, label, irelstr, orelstr,
212                                                                   found ? "found" : "not found",
213                                                                   enabled ? "enabled" : "disabled")));
214         free(irelstr);
215         free(orelstr);
216 }
217
218
219 bool my_cost_hook(CostHookType type, PlannerInfo *root, Path *path1, Path *path2)
220 {
221         TidList *tidlist;
222         HashEntry *ent;
223         bool ret = false;
224
225         if (!tweak_enabled)
226         {
227                 switch (type)
228                 {
229                         case COSTHOOK_seqscan:
230                                 return enable_seqscan;
231                         case COSTHOOK_indexscan:
232                                 return enable_indexscan;
233                         case COSTHOOK_bitmapscan:
234                                 return enable_bitmapscan;
235                         case COSTHOOK_tidscan:
236                                 return enable_tidscan;
237                         case COSTHOOK_nestloop:
238                                 return enable_nestloop;
239                         case COSTHOOK_mergejoin:
240                                 return enable_mergejoin;
241                         case COSTHOOK_hashjoin:
242                                 return enable_hashjoin;
243                         default:
244                                 ereport(LOG, (errmsg_internal("Unknown cost type")));
245                                 break;
246                 }
247         }
248         switch (type)
249         {
250                 case COSTHOOK_seqscan:
251                         tidlist = maketidlist(root, path1, path2);
252                         ent = search_ent(tidlist);
253                         free_tidlist(tidlist);
254                         ret = (ent ? (ent->enforce_mask & ENABLE_SEQSCAN) :
255                                    enable_seqscan);
256                         dump_rels("cost_seqscan", root, path1, ent != NULL, ret);
257                         return ret;
258                 case COSTHOOK_indexscan:
259                         tidlist = maketidlist(root, path1, path2);
260                         ent = search_ent(tidlist);
261                         free_tidlist(tidlist);
262                         ret = (ent ? (ent->enforce_mask & ENABLE_INDEXSCAN) :
263                                    enable_indexscan);
264                         dump_rels("cost_indexscan", root, path1, ent != NULL, ret);
265                         return ret;
266                 case COSTHOOK_bitmapscan:
267                         if (path1->pathtype != T_BitmapHeapScan)
268                         {
269                                 ent = NULL;
270                                 ret = enable_bitmapscan;
271                         }
272                         else
273                         {
274                                 tidlist = maketidlist(root, path1, path2);
275                                 ent = search_ent(tidlist);
276                                 free_tidlist(tidlist);
277                                 ret = (ent ? (ent->enforce_mask & ENABLE_BITMAPSCAN) :
278                                            enable_bitmapscan);
279                         }
280                         dump_rels("cost_bitmapscan", root, path1, ent != NULL, ret);
281
282                         return ret;
283                 case COSTHOOK_tidscan:
284                         tidlist = maketidlist(root, path1, path2);
285                         ent = search_ent(tidlist);
286                         free_tidlist(tidlist);
287                         ret = (ent ? (ent->enforce_mask & ENABLE_TIDSCAN) :
288                                    enable_tidscan);
289                         dump_rels("cost_tidscan", root, path1, ent != NULL, ret);
290                         return ret;
291                 case COSTHOOK_nestloop:
292                         tidlist = maketidlist(root, path1, path2);
293                         ent = search_ent(tidlist);
294                         free_tidlist(tidlist);
295                         ret = (ent ? (ent->enforce_mask & ENABLE_NESTLOOP) :
296                                    enable_nestloop);
297                         dump_joinrels("cost_nestloop", root, path1, path2,
298                                                   ent != NULL, ret);
299                         return ret;
300                 case COSTHOOK_mergejoin:
301                         tidlist = maketidlist(root, path1, path2);
302                         ent = search_ent(tidlist);
303                         free_tidlist(tidlist);
304                         ret = (ent ? (ent->enforce_mask & ENABLE_MERGEJOIN) :
305                                    enable_mergejoin);
306                         dump_joinrels("cost_mergejoin", root, path1, path2,
307                                                   ent != NULL, ret);
308                         return ret;
309                 case COSTHOOK_hashjoin:
310                         tidlist = maketidlist(root, path1, path2);
311                         ent = search_ent(tidlist);
312                         free_tidlist(tidlist);
313                         ret = (ent ? (ent->enforce_mask & ENABLE_HASHJOIN) :
314                                    enable_hashjoin);
315                         dump_joinrels("cost_hashjoin", root, path1, path2,
316                                                   ent != NULL, ret);
317                         return ret;
318                 default:
319                         ereport(LOG, (errmsg_internal("Unknown cost type")));
320                         break;
321         }
322         
323         return true;
324 }
325
326 static void free_hashent(HashEntry *head)
327 {
328         HashEntry *next = head;
329
330         while (next)
331         {
332                 HashEntry *last = next;
333                 if (next->tidlist.oids != NULL) free(next->tidlist.oids);
334                 next = next->next;
335                 free(last);
336         }
337 }
338
339 static HashEntry *parse_tidlist(char **str)
340 {
341         char tidstr[8];
342         char *p0;
343         Oid tid[20]; /* ^^; */
344         int ntids = 0;
345         int i, len;
346         HashEntry *ret;
347
348         while (isdigit(**str) && ntids < 20)
349         {
350                 p0 = *str;
351                 while (isdigit(**str)) (*str)++;
352                 len = *str - p0;
353                 if (len >= 8) return NULL;
354                 strncpy(tidstr, p0, len);
355                 tidstr[len] = 0;
356                 
357                 /* Tis 0 is valid? I don't know :-p */
358                 if ((tid[ntids++] = atoi(tidstr)) == 0) return NULL;
359
360                 if (**str == ',') (*str)++;
361         }
362
363         if (ntids > 1)
364                 qsort(tid, ntids, sizeof(Oid), oidsortcmp);
365         ret = (HashEntry*)malloc(sizeof(HashEntry));
366         ret->next = NULL;
367         ret->enforce_mask = 0;
368         ret->tidlist.nrels = ntids;
369         ret->tidlist.oids = (Oid *)malloc(ntids * sizeof(Oid));
370         for (i = 0 ; i < ntids ; i++)
371                 ret->tidlist.oids[i] = tid[i];
372         return ret;     
373 }
374
375 static int parse_phrase(HashEntry **head, char **str)
376 {
377         char *cmds[]    = {"seq", "index", "nest", "merge", "hash", NULL};
378         unsigned char masks[] = {ENABLE_SEQSCAN, ENABLE_INDEXSCAN|ENABLE_BITMAPSCAN,
379                                                    ENABLE_NESTLOOP, ENABLE_MERGEJOIN, ENABLE_HASHJOIN};
380         char req[12];
381         int cmd;
382         HashEntry *ent = NULL;
383         char *p0;
384         int len;
385
386         p0 = *str;
387         while (isalpha(**str)) (*str)++;
388         len = *str - p0;
389         if (**str != '(' || len >= 12) return 0;
390         strncpy(req, p0, len);
391         req[len] = 0;
392         for (cmd = 0 ; cmds[cmd] && strcmp(cmds[cmd], req) ; cmd++);
393         if (cmds[cmd] == NULL) return 0;
394         (*str)++;
395         if ((ent = parse_tidlist(str)) == NULL) return 0;
396         if (*(*str)++ != ')') return 0;
397         if (**str != 0 && **str != ';') return 0;
398         if (**str == ';') (*str)++;
399         ent->enforce_mask = masks[cmd];
400         ent->next = NULL;
401         *head = ent;
402
403         return 1;
404 }
405
406
407 static HashEntry* parse_top(char* str)
408 {
409         HashEntry *head = NULL;
410         HashEntry *ent = NULL;
411
412         if (!parse_phrase(&head, &str))
413         {
414                 free_hashent(head);
415                 return NULL;
416         }
417         ent = head;
418
419         while (*str)
420         {
421                 if (!parse_phrase(&ent->next, &str))
422                 {
423                         free_hashent(head);
424                         return NULL;
425                 }
426                 ent = ent->next;
427         }
428
429         return head;
430 }
431
432 static bool ent_matches(TidList *key, HashEntry *ent2)
433 {
434         int i;
435
436         if (key->nrels != ent2->tidlist.nrels)
437                 return 0;
438
439         for (i = 0 ; i < key->nrels ; i++)
440                 if (key->oids[i] != ent2->tidlist.oids[i])
441                         return 0;
442
443         return 1;
444 }
445
446 static unsigned int calc_hash(TidList *tidlist)
447 {
448         unsigned int hash = 0;
449         int i = 0;
450         
451         for (i = 0 ; i < tidlist->nrels ; i++)
452         {
453                 int j = 0;
454                 for (j = 0 ; j < sizeof(Oid) ; j++)
455                         hash = hash * 2 + ((tidlist->oids[i] >> (j * 8)) & 0xff);
456         }
457
458         return hash % HASH_ENTRIES;
459 ;
460 }
461
462 static HashEntry *search_ent(TidList *tidlist)
463 {
464         HashEntry *ent;
465         if (tidlist == NULL) return NULL;
466
467         ent = hashent[calc_hash(tidlist)];
468         while(ent)
469         {
470                 if (ent_matches(tidlist, ent))
471                         return ent;
472                 ent = ent->next;
473         }
474
475         return NULL;
476 }
477
478 Datum
479 pg_add_hint(PG_FUNCTION_ARGS)
480 {
481         HashEntry *ret = NULL;
482         char *str = NULL;
483         int i = 0;
484
485         if (PG_NARGS() < 1)
486                 ereport(ERROR, (errmsg_internal("No argument")));
487
488         str = text_to_cstring(PG_GETARG_TEXT_PP(0));
489
490         ret = parse_top(str);
491
492         if (ret == NULL)
493                 ereport(ERROR, (errmsg_internal("Parse Error")));
494
495         while (ret)
496         {
497                 HashEntry *etmp = NULL;
498                 HashEntry *next = NULL;
499                 int hash = calc_hash(&ret->tidlist);
500                 while (hashent[hash] && ent_matches(&ret->tidlist, hashent[hash]))
501                 {
502                         etmp = hashent[hash]->next;
503                         hashent[hash]->next = NULL;
504                         free_hashent(hashent[hash]);
505                         hashent[hash] = etmp;
506
507                 }
508                 etmp = hashent[hash];
509                 while (etmp && etmp->next)
510                 {
511                         if (ent_matches(&ret->tidlist, etmp->next))
512                         {
513                                 HashEntry *etmp2 = etmp->next->next;
514                                 etmp->next->next = NULL;
515                                 free_hashent(etmp->next);
516                                 etmp->next = etmp2;
517                         } else
518                                 etmp = etmp->next;
519                 }
520
521                 i++;
522                 next = ret->next;
523                 ret->next = hashent[hash];
524                 hashent[hash] = ret;
525                 ret = next;
526         }
527         PG_RETURN_INT32(i);
528 }
529
530 Datum
531 pg_clear_hint(PG_FUNCTION_ARGS)
532 {
533         int i;
534         int n = 0;
535
536         for (i = 0 ; i < HASH_ENTRIES ; i++)
537         {
538                 free_hashent(hashent[i]);
539                 hashent[i] = NULL;
540                 n++;
541
542         }
543         PG_RETURN_INT32(n);
544 }
545
546 Datum
547 pg_enable_hint(bool arg, bool *isnull)
548 {
549         tweak_enabled = arg;
550         PG_RETURN_INT32(0);
551 }
552
553 Datum
554 pg_enable_log(bool arg, bool *isnull)
555 {
556         print_log = arg;
557         PG_RETURN_INT32(0);
558 }
559
/*
 * putsbuf
 *		Append str at the write cursor *p without passing bottom.
 *
 * *p is advanced by the number of characters copied.  No terminating NUL
 * is written.  Returns 1 if all of str was copied, 0 if the buffer filled
 * up first.
 */
static int putsbuf(char **p, char *bottom, char *str)
{
	char	   *dst = *p;
	const char *src = str;

	for (; dst < bottom && *src != '\0'; dst++, src++)
		*dst = *src;

	*p = dst;
	return (*src == '\0');
}
569
570 static void dump_ent(HashEntry *ent, char **p, char *bottom)
571 {
572         static char typesigs[] = "SIBTNMH";
573         char sigs[sizeof(typesigs)];
574         int i;
575
576         if (!putsbuf(p, bottom, "[(")) return;
577         for (i = 0 ; i < ent->tidlist.nrels ; i++)
578         {
579                 if (i && !putsbuf(p, bottom, ", ")) return;
580                 if (*p >= bottom) return;
581                 *p += snprintf(*p, bottom - *p, "%d", ent->tidlist.oids[i]);
582         }
583         if (!putsbuf(p, bottom, "), ")) return;
584         strcpy(sigs, typesigs);
585         for (i = 0 ; i < 7 ; i++) /* Magic number here! */
586         {
587                 if(((1<<i) & ent->enforce_mask) == 0)
588                         sigs[i] += 'a' - 'A';
589         }
590         if (!putsbuf(p, bottom, sigs)) return;
591         if (!putsbuf(p, bottom, "]")) return;
592 }
593
594 Datum
595 pg_dump_hint(PG_FUNCTION_ARGS)
596 {
597         char buf[16384]; /* ^^; */
598         char *bottom = buf + sizeof(buf);
599         char *p = buf;
600         int i;
601         int first = 1;
602
603         memset(buf, 0, sizeof(buf));
604         for (i = 0 ; i < HASH_ENTRIES ; i++)
605         {
606                 if (hashent[i])
607                 {
608                         HashEntry *ent = hashent[i];
609                         while (ent)
610                         {
611                                 if (first)
612                                         first = 0;
613                                 else
614                                         putsbuf(&p, bottom, ", ");
615                                 
616                                 dump_ent(ent, &p, bottom);
617                                 ent = ent->next;
618                         }
619                 }
620         }
621         if (p >= bottom) p--;
622         *p = 0;
623         
624         PG_RETURN_CSTRING(cstring_to_text(buf));
625 }