OSDN Git Service

Ensure that typmod decoration on a datatype name is validated in all cases,
[pg-rex/syncrep.git] / src / backend / commands / aggregatecmds.c
index 552573d..e49a7da 100644 (file)
@@ -4,12 +4,12 @@
  *
  *       Routines for aggregate-manipulation commands
  *
- * Portions Copyright (c) 1996-2002, PostgreSQL Global Development Group
+ * Portions Copyright (c) 1996-2007, PostgreSQL Global Development Group
  * Portions Copyright (c) 1994, Regents of the University of California
  *
  *
  * IDENTIFICATION
- *       $Header: /cvsroot/pgsql/src/backend/commands/aggregatecmds.c,v 1.5 2002/09/04 20:31:14 momjian Exp $
+ *       $PostgreSQL: pgsql/src/backend/commands/aggregatecmds.c,v 1.44 2007/11/11 19:22:48 tgl Exp $
  *
  * DESCRIPTION
  *       The "DefineFoo" routines take the parse tree and pick out the
@@ -23,9 +23,8 @@
 #include "postgres.h"
 
 #include "access/heapam.h"
-#include "catalog/catname.h"
 #include "catalog/dependency.h"
-#include "catalog/namespace.h"
+#include "catalog/indexing.h"
 #include "catalog/pg_aggregate.h"
 #include "catalog/pg_proc.h"
 #include "catalog/pg_type.h"
 
 /*
  *     DefineAggregate
+ *
+ * "oldstyle" signals the old (pre-8.2) style where the aggregate input type
+ * is specified by a BASETYPE element in the parameters.  Otherwise,
+ * "args" defines the input type(s).
  */
 void
-DefineAggregate(List *names, List *parameters)
+DefineAggregate(List *name, List *args, bool oldstyle, List *parameters)
 {
        char       *aggName;
        Oid                     aggNamespace;
        AclResult       aclresult;
        List       *transfuncName = NIL;
        List       *finalfuncName = NIL;
+       List       *sortoperatorName = NIL;
        TypeName   *baseType = NULL;
        TypeName   *transType = NULL;
        char       *initval = NULL;
-       Oid                     baseTypeId;
+       Oid                *aggArgTypes;
+       int                     numArgs;
        Oid                     transTypeId;
-       List       *pl;
+       ListCell   *pl;
 
        /* Convert list of names to a name and namespace */
-       aggNamespace = QualifiedNameGetCreationNamespace(names, &aggName);
+       aggNamespace = QualifiedNameGetCreationNamespace(name, &aggName);
 
        /* Check we have creation rights in target namespace */
        aclresult = pg_namespace_aclcheck(aggNamespace, GetUserId(), ACL_CREATE);
        if (aclresult != ACLCHECK_OK)
-               aclcheck_error(aclresult, get_namespace_name(aggNamespace));
+               aclcheck_error(aclresult, ACL_KIND_NAMESPACE,
+                                          get_namespace_name(aggNamespace));
 
        foreach(pl, parameters)
        {
                DefElem    *defel = (DefElem *) lfirst(pl);
 
                /*
-                * sfunc1, stype1, and initcond1 are accepted as obsolete
-                * spellings for sfunc, stype, initcond.
+                * sfunc1, stype1, and initcond1 are accepted as obsolete spellings
+                * for sfunc, stype, initcond.
                 */
-               if (strcasecmp(defel->defname, "sfunc") == 0)
+               if (pg_strcasecmp(defel->defname, "sfunc") == 0)
                        transfuncName = defGetQualifiedName(defel);
-               else if (strcasecmp(defel->defname, "sfunc1") == 0)
+               else if (pg_strcasecmp(defel->defname, "sfunc1") == 0)
                        transfuncName = defGetQualifiedName(defel);
-               else if (strcasecmp(defel->defname, "finalfunc") == 0)
+               else if (pg_strcasecmp(defel->defname, "finalfunc") == 0)
                        finalfuncName = defGetQualifiedName(defel);
-               else if (strcasecmp(defel->defname, "basetype") == 0)
+               else if (pg_strcasecmp(defel->defname, "sortop") == 0)
+                       sortoperatorName = defGetQualifiedName(defel);
+               else if (pg_strcasecmp(defel->defname, "basetype") == 0)
                        baseType = defGetTypeName(defel);
-               else if (strcasecmp(defel->defname, "stype") == 0)
+               else if (pg_strcasecmp(defel->defname, "stype") == 0)
                        transType = defGetTypeName(defel);
-               else if (strcasecmp(defel->defname, "stype1") == 0)
+               else if (pg_strcasecmp(defel->defname, "stype1") == 0)
                        transType = defGetTypeName(defel);
-               else if (strcasecmp(defel->defname, "initcond") == 0)
+               else if (pg_strcasecmp(defel->defname, "initcond") == 0)
                        initval = defGetString(defel);
-               else if (strcasecmp(defel->defname, "initcond1") == 0)
+               else if (pg_strcasecmp(defel->defname, "initcond1") == 0)
                        initval = defGetString(defel);
                else
-                       elog(WARNING, "DefineAggregate: attribute \"%s\" not recognized",
-                                defel->defname);
+                       ereport(WARNING,
+                                       (errcode(ERRCODE_SYNTAX_ERROR),
+                                        errmsg("aggregate attribute \"%s\" not recognized",
+                                                       defel->defname)));
        }
 
        /*
         * make sure we have our required definitions
         */
-       if (baseType == NULL)
-               elog(ERROR, "Define: \"basetype\" unspecified");
        if (transType == NULL)
-               elog(ERROR, "Define: \"stype\" unspecified");
+               ereport(ERROR,
+                               (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
+                                errmsg("aggregate stype must be specified")));
        if (transfuncName == NIL)
-               elog(ERROR, "Define: \"sfunc\" unspecified");
+               ereport(ERROR,
+                               (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
+                                errmsg("aggregate sfunc must be specified")));
 
        /*
-        * look up the aggregate's base type (input datatype) and transtype.
-        *
-        * We have historically allowed the command to look like basetype = 'ANY'
-        * so we must do a case-insensitive comparison for the name ANY.  Ugh.
-        *
-        * basetype can be a pseudo-type, but transtype can't, since we need to
-        * be able to store values of the transtype.
+        * look up the aggregate's input datatype(s).
         */
-       if (strcasecmp(TypeNameToString(baseType), "ANY") == 0)
-               baseTypeId = ANYOID;
+       if (oldstyle)
+       {
+               /*
+                * Old style: use basetype parameter.  This supports aggregates of
+                * zero or one input, with input type ANY meaning zero inputs.
+                *
+                * Historically we allowed the command to look like basetype = 'ANY'
+                * so we must do a case-insensitive comparison for the name ANY. Ugh.
+                */
+               if (baseType == NULL)
+                       ereport(ERROR,
+                                       (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
+                                        errmsg("aggregate input type must be specified")));
+
+               if (pg_strcasecmp(TypeNameToString(baseType), "ANY") == 0)
+               {
+                       numArgs = 0;
+                       aggArgTypes = NULL;
+               }
+               else
+               {
+                       numArgs = 1;
+                       aggArgTypes = (Oid *) palloc(sizeof(Oid));
+                       aggArgTypes[0] = typenameTypeId(NULL, baseType, NULL);
+               }
+       }
        else
-               baseTypeId = typenameTypeId(baseType);
+       {
+               /*
+                * New style: args is a list of TypeNames (possibly zero of 'em).
+                */
+               ListCell   *lc;
+               int                     i = 0;
 
-       transTypeId = typenameTypeId(transType);
-       if (get_typtype(transTypeId) == 'p')
-               elog(ERROR, "Aggregate transition datatype cannot be %s",
-                        format_type_be(transTypeId));
+               if (baseType != NULL)
+                       ereport(ERROR,
+                                       (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
+                                        errmsg("basetype is redundant with aggregate input type specification")));
+
+               numArgs = list_length(args);
+               aggArgTypes = (Oid *) palloc(sizeof(Oid) * numArgs);
+               foreach(lc, args)
+               {
+                       TypeName   *curTypeName = (TypeName *) lfirst(lc);
+
+                       aggArgTypes[i++] = typenameTypeId(NULL, curTypeName, NULL);
+               }
+       }
+
+       /*
+        * look up the aggregate's transtype.
+        *
+        * transtype can't be a pseudo-type, since we need to be able to store
+        * values of the transtype.  However, we can allow polymorphic transtype
+        * in some cases (AggregateCreate will check).
+        */
+       transTypeId = typenameTypeId(NULL, transType, NULL);
+       if (get_typtype(transTypeId) == TYPTYPE_PSEUDO &&
+               !IsPolymorphicType(transTypeId))
+               ereport(ERROR,
+                               (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
+                                errmsg("aggregate transition data type cannot be %s",
+                                               format_type_be(transTypeId))));
 
        /*
         * Most of the argument-checking is done inside of AggregateCreate
         */
        AggregateCreate(aggName,        /* aggregate name */
                                        aggNamespace,           /* namespace */
+                                       aggArgTypes,    /* input data type(s) */
+                                       numArgs,
                                        transfuncName,          /* step function name */
                                        finalfuncName,          /* final function name */
-                                       baseTypeId, /* type of data being aggregated */
+                                       sortoperatorName,       /* sort operator name */
                                        transTypeId,    /* transition data type */
                                        initval);       /* initial condition */
 }
@@ -141,28 +203,26 @@ DefineAggregate(List *names, List *parameters)
  *             Deletes an aggregate.
  */
 void
-RemoveAggregate(RemoveAggrStmt *stmt)
+RemoveAggregate(RemoveFuncStmt *stmt)
 {
-       List       *aggName = stmt->aggname;
-       TypeName   *aggType = stmt->aggtype;
-       Oid                     basetypeID;
+       List       *aggName = stmt->name;
+       List       *aggArgs = stmt->args;
        Oid                     procOid;
        HeapTuple       tup;
        ObjectAddress object;
 
-       /*
-        * if a basetype is passed in, then attempt to find an aggregate for
-        * that specific type.
-        *
-        * else attempt to find an aggregate with a basetype of ANYOID. This
-        * means that the aggregate is to apply to all basetypes (eg, COUNT).
-        */
-       if (aggType)
-               basetypeID = typenameTypeId(aggType);
-       else
-               basetypeID = ANYOID;
+       /* Look up function and make sure it's an aggregate */
+       procOid = LookupAggNameTypeNames(aggName, aggArgs, stmt->missing_ok);
 
-       procOid = find_aggregate_func("RemoveAggregate", aggName, basetypeID);
+       if (!OidIsValid(procOid))
+       {
+               /* we only get here if stmt->missing_ok is true */
+               ereport(NOTICE,
+                               (errmsg("aggregate %s(%s) does not exist, skipping",
+                                               NameListToString(aggName),
+                                               TypeNameListToString(aggArgs))));
+               return;
+       }
 
        /*
         * Find the function tuple, do permissions and validity checks
@@ -171,25 +231,145 @@ RemoveAggregate(RemoveAggrStmt *stmt)
                                                 ObjectIdGetDatum(procOid),
                                                 0, 0, 0);
        if (!HeapTupleIsValid(tup)) /* should not happen */
-               elog(ERROR, "RemoveAggregate: couldn't find pg_proc tuple for %s",
-                        NameListToString(aggName));
+               elog(ERROR, "cache lookup failed for function %u", procOid);
 
        /* Permission check: must own agg or its namespace */
        if (!pg_proc_ownercheck(procOid, GetUserId()) &&
-               !pg_namespace_ownercheck(((Form_pg_proc) GETSTRUCT(tup))->pronamespace,
-                                                                GetUserId()))
-               aclcheck_error(ACLCHECK_NOT_OWNER, NameListToString(aggName));
-
-       /* find_aggregate_func already checked it is an aggregate */
+         !pg_namespace_ownercheck(((Form_pg_proc) GETSTRUCT(tup))->pronamespace,
+                                                          GetUserId()))
+               aclcheck_error(ACLCHECK_NOT_OWNER, ACL_KIND_PROC,
+                                          NameListToString(aggName));
 
        ReleaseSysCache(tup);
 
        /*
         * Do the deletion
         */
-       object.classId = RelOid_pg_proc;
+       object.classId = ProcedureRelationId;
        object.objectId = procOid;
        object.objectSubId = 0;
 
        performDeletion(&object, stmt->behavior);
 }
+
+
+void
+RenameAggregate(List *name, List *args, const char *newname)
+{
+       Oid                     procOid;
+       Oid                     namespaceOid;
+       HeapTuple       tup;
+       Form_pg_proc procForm;
+       Relation        rel;
+       AclResult       aclresult;
+
+       rel = heap_open(ProcedureRelationId, RowExclusiveLock);
+
+       /* Look up function and make sure it's an aggregate */
+       procOid = LookupAggNameTypeNames(name, args, false);
+
+       tup = SearchSysCacheCopy(PROCOID,
+                                                        ObjectIdGetDatum(procOid),
+                                                        0, 0, 0);
+       if (!HeapTupleIsValid(tup)) /* should not happen */
+               elog(ERROR, "cache lookup failed for function %u", procOid);
+       procForm = (Form_pg_proc) GETSTRUCT(tup);
+
+       namespaceOid = procForm->pronamespace;
+
+       /* make sure the new name doesn't exist */
+       if (SearchSysCacheExists(PROCNAMEARGSNSP,
+                                                        CStringGetDatum(newname),
+                                                        PointerGetDatum(&procForm->proargtypes),
+                                                        ObjectIdGetDatum(namespaceOid),
+                                                        0))
+               ereport(ERROR,
+                               (errcode(ERRCODE_DUPLICATE_FUNCTION),
+                                errmsg("function %s already exists in schema \"%s\"",
+                                               funcname_signature_string(newname,
+                                                                                                 procForm->pronargs,
+                                                                                          procForm->proargtypes.values),
+                                               get_namespace_name(namespaceOid))));
+
+       /* must be owner */
+       if (!pg_proc_ownercheck(procOid, GetUserId()))
+               aclcheck_error(ACLCHECK_NOT_OWNER, ACL_KIND_PROC,
+                                          NameListToString(name));
+
+       /* must have CREATE privilege on namespace */
+       aclresult = pg_namespace_aclcheck(namespaceOid, GetUserId(), ACL_CREATE);
+       if (aclresult != ACLCHECK_OK)
+               aclcheck_error(aclresult, ACL_KIND_NAMESPACE,
+                                          get_namespace_name(namespaceOid));
+
+       /* rename */
+       namestrcpy(&(((Form_pg_proc) GETSTRUCT(tup))->proname), newname);
+       simple_heap_update(rel, &tup->t_self, tup);
+       CatalogUpdateIndexes(rel, tup);
+
+       heap_close(rel, NoLock);
+       heap_freetuple(tup);
+}
+
+/*
+ * Change aggregate owner
+ */
+void
+AlterAggregateOwner(List *name, List *args, Oid newOwnerId)
+{
+       Oid                     procOid;
+       HeapTuple       tup;
+       Form_pg_proc procForm;
+       Relation        rel;
+       AclResult       aclresult;
+
+       rel = heap_open(ProcedureRelationId, RowExclusiveLock);
+
+       /* Look up function and make sure it's an aggregate */
+       procOid = LookupAggNameTypeNames(name, args, false);
+
+       tup = SearchSysCacheCopy(PROCOID,
+                                                        ObjectIdGetDatum(procOid),
+                                                        0, 0, 0);
+       if (!HeapTupleIsValid(tup)) /* should not happen */
+               elog(ERROR, "cache lookup failed for function %u", procOid);
+       procForm = (Form_pg_proc) GETSTRUCT(tup);
+
+       /*
+        * If the new owner is the same as the existing owner, consider the
+        * command to have succeeded.  This is for dump restoration purposes.
+        */
+       if (procForm->proowner != newOwnerId)
+       {
+               /* Superusers can always do it */
+               if (!superuser())
+               {
+                       /* Otherwise, must be owner of the existing object */
+                       if (!pg_proc_ownercheck(procOid, GetUserId()))
+                               aclcheck_error(ACLCHECK_NOT_OWNER, ACL_KIND_PROC,
+                                                          NameListToString(name));
+
+                       /* Must be able to become new owner */
+                       check_is_member_of_role(GetUserId(), newOwnerId);
+
+                       /* New owner must have CREATE privilege on namespace */
+                       aclresult = pg_namespace_aclcheck(procForm->pronamespace,
+                                                                                         newOwnerId,
+                                                                                         ACL_CREATE);
+                       if (aclresult != ACLCHECK_OK)
+                               aclcheck_error(aclresult, ACL_KIND_NAMESPACE,
+                                                          get_namespace_name(procForm->pronamespace));
+               }
+
+               /*
+                * Modify the owner --- okay to scribble on tup because it's a copy
+                */
+               procForm->proowner = newOwnerId;
+
+               simple_heap_update(rel, &tup->t_self, tup);
+               CatalogUpdateIndexes(rel, tup);
+       }
+
+       heap_close(rel, NoLock);
+       heap_freetuple(tup);
+}