OSDN Git Service

Ensure that typmod decoration on a datatype name is validated in all cases,
[pg-rex/syncrep.git] / src / backend / commands / aggregatecmds.c
index a787f7a..e49a7da 100644 (file)
@@ -4,12 +4,12 @@
  *
  *       Routines for aggregate-manipulation commands
  *
- * Portions Copyright (c) 1996-2003, PostgreSQL Global Development Group
+ * Portions Copyright (c) 1996-2007, PostgreSQL Global Development Group
  * Portions Copyright (c) 1994, Regents of the University of California
  *
  *
  * IDENTIFICATION
- *       $PostgreSQL: pgsql/src/backend/commands/aggregatecmds.c,v 1.19 2004/06/25 21:55:53 tgl Exp $
+ *       $PostgreSQL: pgsql/src/backend/commands/aggregatecmds.c,v 1.44 2007/11/11 19:22:48 tgl Exp $
  *
  * DESCRIPTION
  *       The "DefineFoo" routines take the parse tree and pick out the
 #include "postgres.h"
 
 #include "access/heapam.h"
-#include "catalog/catname.h"
 #include "catalog/dependency.h"
 #include "catalog/indexing.h"
-#include "catalog/namespace.h"
 #include "catalog/pg_aggregate.h"
 #include "catalog/pg_proc.h"
 #include "catalog/pg_type.h"
 
 /*
  *     DefineAggregate
+ *
+ * "oldstyle" signals the old (pre-8.2) style where the aggregate input type
+ * is specified by a BASETYPE element in the parameters.  Otherwise,
+ * "args" defines the input type(s).
  */
 void
-DefineAggregate(List *names, List *parameters)
+DefineAggregate(List *name, List *args, bool oldstyle, List *parameters)
 {
        char       *aggName;
        Oid                     aggNamespace;
        AclResult       aclresult;
        List       *transfuncName = NIL;
        List       *finalfuncName = NIL;
+       List       *sortoperatorName = NIL;
        TypeName   *baseType = NULL;
        TypeName   *transType = NULL;
        char       *initval = NULL;
-       Oid                     baseTypeId;
+       Oid                *aggArgTypes;
+       int                     numArgs;
        Oid                     transTypeId;
        ListCell   *pl;
 
        /* Convert list of names to a name and namespace */
-       aggNamespace = QualifiedNameGetCreationNamespace(names, &aggName);
+       aggNamespace = QualifiedNameGetCreationNamespace(name, &aggName);
 
        /* Check we have creation rights in target namespace */
        aclresult = pg_namespace_aclcheck(aggNamespace, GetUserId(), ACL_CREATE);
@@ -72,8 +76,8 @@ DefineAggregate(List *names, List *parameters)
                DefElem    *defel = (DefElem *) lfirst(pl);
 
                /*
-                * sfunc1, stype1, and initcond1 are accepted as obsolete
-                * spellings for sfunc, stype, initcond.
+                * sfunc1, stype1, and initcond1 are accepted as obsolete spellings
+                * for sfunc, stype, initcond.
                 */
                if (pg_strcasecmp(defel->defname, "sfunc") == 0)
                        transfuncName = defGetQualifiedName(defel);
@@ -81,6 +85,8 @@ DefineAggregate(List *names, List *parameters)
                        transfuncName = defGetQualifiedName(defel);
                else if (pg_strcasecmp(defel->defname, "finalfunc") == 0)
                        finalfuncName = defGetQualifiedName(defel);
+               else if (pg_strcasecmp(defel->defname, "sortop") == 0)
+                       sortoperatorName = defGetQualifiedName(defel);
                else if (pg_strcasecmp(defel->defname, "basetype") == 0)
                        baseType = defGetTypeName(defel);
                else if (pg_strcasecmp(defel->defname, "stype") == 0)
@@ -101,10 +107,6 @@ DefineAggregate(List *names, List *parameters)
        /*
         * make sure we have our required definitions
         */
-       if (baseType == NULL)
-               ereport(ERROR,
-                               (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
-                                errmsg("aggregate basetype must be specified")));
        if (transType == NULL)
                ereport(ERROR,
                                (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
@@ -115,24 +117,67 @@ DefineAggregate(List *names, List *parameters)
                                 errmsg("aggregate sfunc must be specified")));
 
        /*
-        * look up the aggregate's base type (input datatype) and transtype.
-        *
-        * We have historically allowed the command to look like basetype = 'ANY'
-        * so we must do a case-insensitive comparison for the name ANY.  Ugh.
-        *
-        * basetype can be a pseudo-type, but transtype can't, since we need to
-        * be able to store values of the transtype.  However, we can allow
-        * polymorphic transtype in some cases (AggregateCreate will check).
+        * look up the aggregate's input datatype(s).
         */
-       if (pg_strcasecmp(TypeNameToString(baseType), "ANY") == 0)
-               baseTypeId = ANYOID;
+       if (oldstyle)
+       {
+               /*
+                * Old style: use basetype parameter.  This supports aggregates of
+                * zero or one input, with input type ANY meaning zero inputs.
+                *
+                * Historically we allowed the command to look like basetype = 'ANY'
+                * so we must do a case-insensitive comparison for the name ANY. Ugh.
+                */
+               if (baseType == NULL)
+                       ereport(ERROR,
+                                       (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
+                                        errmsg("aggregate input type must be specified")));
+
+               if (pg_strcasecmp(TypeNameToString(baseType), "ANY") == 0)
+               {
+                       numArgs = 0;
+                       aggArgTypes = NULL;
+               }
+               else
+               {
+                       numArgs = 1;
+                       aggArgTypes = (Oid *) palloc(sizeof(Oid));
+                       aggArgTypes[0] = typenameTypeId(NULL, baseType, NULL);
+               }
+       }
        else
-               baseTypeId = typenameTypeId(baseType);
+       {
+               /*
+                * New style: args is a list of TypeNames (possibly zero of 'em).
+                */
+               ListCell   *lc;
+               int                     i = 0;
+
+               if (baseType != NULL)
+                       ereport(ERROR,
+                                       (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
+                                        errmsg("basetype is redundant with aggregate input type specification")));
+
+               numArgs = list_length(args);
+               aggArgTypes = (Oid *) palloc(sizeof(Oid) * numArgs);
+               foreach(lc, args)
+               {
+                       TypeName   *curTypeName = (TypeName *) lfirst(lc);
+
+                       aggArgTypes[i++] = typenameTypeId(NULL, curTypeName, NULL);
+               }
+       }
 
-       transTypeId = typenameTypeId(transType);
-       if (get_typtype(transTypeId) == 'p' &&
-               transTypeId != ANYARRAYOID &&
-               transTypeId != ANYELEMENTOID)
+       /*
+        * look up the aggregate's transtype.
+        *
+        * transtype can't be a pseudo-type, since we need to be able to store
+        * values of the transtype.  However, we can allow polymorphic transtype
+        * in some cases (AggregateCreate will check).
+        */
+       transTypeId = typenameTypeId(NULL, transType, NULL);
+       if (get_typtype(transTypeId) == TYPTYPE_PSEUDO &&
+               !IsPolymorphicType(transTypeId))
                ereport(ERROR,
                                (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
                                 errmsg("aggregate transition data type cannot be %s",
@@ -143,9 +188,11 @@ DefineAggregate(List *names, List *parameters)
         */
        AggregateCreate(aggName,        /* aggregate name */
                                        aggNamespace,           /* namespace */
+                                       aggArgTypes,    /* input data type(s) */
+                                       numArgs,
                                        transfuncName,          /* step function name */
                                        finalfuncName,          /* final function name */
-                                       baseTypeId, /* type of data being aggregated */
+                                       sortoperatorName,       /* sort operator name */
                                        transTypeId,    /* transition data type */
                                        initval);       /* initial condition */
 }
@@ -156,28 +203,26 @@ DefineAggregate(List *names, List *parameters)
  *             Deletes an aggregate.
  */
 void
-RemoveAggregate(RemoveAggrStmt *stmt)
+RemoveAggregate(RemoveFuncStmt *stmt)
 {
-       List       *aggName = stmt->aggname;
-       TypeName   *aggType = stmt->aggtype;
-       Oid                     basetypeID;
+       List       *aggName = stmt->name;
+       List       *aggArgs = stmt->args;
        Oid                     procOid;
        HeapTuple       tup;
        ObjectAddress object;
 
-       /*
-        * if a basetype is passed in, then attempt to find an aggregate for
-        * that specific type.
-        *
-        * else attempt to find an aggregate with a basetype of ANYOID. This
-        * means that the aggregate is to apply to all basetypes (eg, COUNT).
-        */
-       if (aggType)
-               basetypeID = typenameTypeId(aggType);
-       else
-               basetypeID = ANYOID;
+       /* Look up function and make sure it's an aggregate */
+       procOid = LookupAggNameTypeNames(aggName, aggArgs, stmt->missing_ok);
 
-       procOid = find_aggregate_func(aggName, basetypeID, false);
+       if (!OidIsValid(procOid))
+       {
+               /* we only get here if stmt->missing_ok is true */
+               ereport(NOTICE,
+                               (errmsg("aggregate %s(%s) does not exist, skipping",
+                                               NameListToString(aggName),
+                                               TypeNameListToString(aggArgs))));
+               return;
+       }
 
        /*
         * Find the function tuple, do permissions and validity checks
@@ -190,19 +235,17 @@ RemoveAggregate(RemoveAggrStmt *stmt)
 
        /* Permission check: must own agg or its namespace */
        if (!pg_proc_ownercheck(procOid, GetUserId()) &&
-               !pg_namespace_ownercheck(((Form_pg_proc) GETSTRUCT(tup))->pronamespace,
-                                                                GetUserId()))
+         !pg_namespace_ownercheck(((Form_pg_proc) GETSTRUCT(tup))->pronamespace,
+                                                          GetUserId()))
                aclcheck_error(ACLCHECK_NOT_OWNER, ACL_KIND_PROC,
                                           NameListToString(aggName));
 
-       /* find_aggregate_func already checked it is an aggregate */
-
        ReleaseSysCache(tup);
 
        /*
         * Do the deletion
         */
-       object.classId = RelOid_pg_proc;
+       object.classId = ProcedureRelationId;
        object.objectId = procOid;
        object.objectSubId = 0;
 
@@ -211,9 +254,8 @@ RemoveAggregate(RemoveAggrStmt *stmt)
 
 
 void
-RenameAggregate(List *name, TypeName *basetype, const char *newname)
+RenameAggregate(List *name, List *args, const char *newname)
 {
-       Oid                     basetypeOid;
        Oid                     procOid;
        Oid                     namespaceOid;
        HeapTuple       tup;
@@ -221,20 +263,10 @@ RenameAggregate(List *name, TypeName *basetype, const char *newname)
        Relation        rel;
        AclResult       aclresult;
 
-       /*
-        * if a basetype is passed in, then attempt to find an aggregate for
-        * that specific type; else attempt to find an aggregate with a basetype
-        * of ANYOID. This means that the aggregate applies to all basetypes
-        * (eg, COUNT).
-        */
-       if (basetype)
-               basetypeOid = typenameTypeId(basetype);
-       else
-               basetypeOid = ANYOID;
-
-       rel = heap_openr(ProcedureRelationName, RowExclusiveLock);
+       rel = heap_open(ProcedureRelationId, RowExclusiveLock);
 
-       procOid = find_aggregate_func(name, basetypeOid, false);
+       /* Look up function and make sure it's an aggregate */
+       procOid = LookupAggNameTypeNames(name, args, false);
 
        tup = SearchSysCacheCopy(PROCOID,
                                                         ObjectIdGetDatum(procOid),
@@ -246,27 +278,18 @@ RenameAggregate(List *name, TypeName *basetype, const char *newname)
        namespaceOid = procForm->pronamespace;
 
        /* make sure the new name doesn't exist */
-       if (SearchSysCacheExists(PROCNAMENSP,
+       if (SearchSysCacheExists(PROCNAMEARGSNSP,
                                                         CStringGetDatum(newname),
-                                                        Int16GetDatum(procForm->pronargs),
-                                                        PointerGetDatum(procForm->proargtypes),
-                                                        ObjectIdGetDatum(namespaceOid)))
-       {
-               if (basetypeOid == ANYOID)
-                       ereport(ERROR,
-                                       (errcode(ERRCODE_DUPLICATE_FUNCTION),
-                                errmsg("function %s(*) already exists in schema \"%s\"",
-                                               newname,
+                                                        PointerGetDatum(&procForm->proargtypes),
+                                                        ObjectIdGetDatum(namespaceOid),
+                                                        0))
+               ereport(ERROR,
+                               (errcode(ERRCODE_DUPLICATE_FUNCTION),
+                                errmsg("function %s already exists in schema \"%s\"",
+                                               funcname_signature_string(newname,
+                                                                                                 procForm->pronargs,
+                                                                                          procForm->proargtypes.values),
                                                get_namespace_name(namespaceOid))));
-               else
-                       ereport(ERROR,
-                                       (errcode(ERRCODE_DUPLICATE_FUNCTION),
-                                        errmsg("function %s already exists in schema \"%s\"",
-                                                       funcname_signature_string(newname,
-                                                                                                         procForm->pronargs,
-                                                                                                 procForm->proargtypes),
-                                                       get_namespace_name(namespaceOid))));
-       }
 
        /* must be owner */
        if (!pg_proc_ownercheck(procOid, GetUserId()))
@@ -292,28 +315,18 @@ RenameAggregate(List *name, TypeName *basetype, const char *newname)
  * Change aggregate owner
  */
 void
-AlterAggregateOwner(List *name, TypeName *basetype, AclId newOwnerSysId)
+AlterAggregateOwner(List *name, List *args, Oid newOwnerId)
 {
-       Oid                     basetypeOid;
        Oid                     procOid;
        HeapTuple       tup;
        Form_pg_proc procForm;
        Relation        rel;
+       AclResult       aclresult;
 
-       /*
-        * if a basetype is passed in, then attempt to find an aggregate for
-        * that specific type; else attempt to find an aggregate with a basetype
-        * of ANYOID. This means that the aggregate applies to all basetypes
-        * (eg, COUNT).
-        */
-       if (basetype)
-               basetypeOid = typenameTypeId(basetype);
-       else
-               basetypeOid = ANYOID;
-
-       rel = heap_openr(ProcedureRelationName, RowExclusiveLock);
+       rel = heap_open(ProcedureRelationId, RowExclusiveLock);
 
-       procOid = find_aggregate_func(name, basetypeOid, false);
+       /* Look up function and make sure it's an aggregate */
+       procOid = LookupAggNameTypeNames(name, args, false);
 
        tup = SearchSysCacheCopy(PROCOID,
                                                         ObjectIdGetDatum(procOid),
@@ -322,20 +335,36 @@ AlterAggregateOwner(List *name, TypeName *basetype, AclId newOwnerSysId)
                elog(ERROR, "cache lookup failed for function %u", procOid);
        procForm = (Form_pg_proc) GETSTRUCT(tup);
 
-       /* 
+       /*
         * If the new owner is the same as the existing owner, consider the
         * command to have succeeded.  This is for dump restoration purposes.
         */
-       if (procForm->proowner != newOwnerSysId)
+       if (procForm->proowner != newOwnerId)
        {
-               /* Otherwise, must be superuser to change object ownership */
+               /* Superusers can always do it */
                if (!superuser())
-                       ereport(ERROR,
-                                       (errcode(ERRCODE_INSUFFICIENT_PRIVILEGE),
-                                        errmsg("must be superuser to change owner")));
+               {
+                       /* Otherwise, must be owner of the existing object */
+                       if (!pg_proc_ownercheck(procOid, GetUserId()))
+                               aclcheck_error(ACLCHECK_NOT_OWNER, ACL_KIND_PROC,
+                                                          NameListToString(name));
+
+                       /* Must be able to become new owner */
+                       check_is_member_of_role(GetUserId(), newOwnerId);
+
+                       /* New owner must have CREATE privilege on namespace */
+                       aclresult = pg_namespace_aclcheck(procForm->pronamespace,
+                                                                                         newOwnerId,
+                                                                                         ACL_CREATE);
+                       if (aclresult != ACLCHECK_OK)
+                               aclcheck_error(aclresult, ACL_KIND_NAMESPACE,
+                                                          get_namespace_name(procForm->pronamespace));
+               }
 
-               /* Modify the owner --- okay to scribble on tup because it's a copy */
-               procForm->proowner = newOwnerSysId;
+               /*
+                * Modify the owner --- okay to scribble on tup because it's a copy
+                */
+               procForm->proowner = newOwnerId;
 
                simple_heap_update(rel, &tup->t_self, tup);
                CatalogUpdateIndexes(rel, tup);