Source code for gavo.adql.nodes

"""
Node classes and factories used in ADQL tree processing.
"""

#c Copyright 2008-2023, the GAVO project <gavo@ari.uni-heidelberg.de>
#c
#c This program is free software, covered by the GNU GPL.  See the
#c COPYING file in the source distribution.


import fnmatch
import functools
import itertools
import re
import weakref
from functools import reduce

from gavo import stc
from gavo import utils
from gavo.adql import common
from gavo.adql import fieldinfo
from gavo.adql import fieldinfos
from gavo.adql import grammar
from gavo.stc import tapstc
from gavo.utils import parsetricks


################ Various helpers

class ReplaceNode(utils.ExecutiveAction):
	"""can be raised by code in the constructor of an ADQLNode to
	replace itself.

	It is constructed with the (single) ADQLNode that should stand in
	its stead.

	This is intended as a special service for ufuncs that want to insert
	complex, annotatable expressions.  We also use that in certain
	situations when casting.
	"""
	# Caught in ADQLNode.fromParseResult, which then returns
	# replacingNode instead of the node under construction.
	def __init__(self, replacingNode):
		self.replacingNode = replacingNode
def symbolAction(*symbols):
	"""is a decorator to mark functions as being a parseAction for symbol.

	This is evaluated by getADQLGrammar below.  Be careful not to alter
	global state in such a handler.
	"""
	def deco(func):
		# support double decoration: later decorations extend the
		# existing symbol list rather than clobbering it.
		registered = getattr(func, "parseActionFor", None)
		if registered is None:
			func.parseActionFor = list(symbols)
		else:
			registered.extend(symbols)
		func.fromParseResult = func
		return func
	return deco
def getType(arg):
	"""returns the type of an ADQL node or the value of str if arg is
	a string.
	"""
	# plain strings have no .type attribute; use the str type itself
	# as their marker.
	return str if isinstance(arg, str) else arg.type
def flatten(arg):
	"""returns the SQL serialized representation of arg.

	Strings pass through unchanged, numbers are repr-ed, ParseResults
	are flattened recursively and blank-joined; everything else must
	provide its own flatten() method (as ADQLNodes do).
	"""
	# (a commented-out pdb trap for arg is None used to live here; it
	# has been removed as dead debugging code)
	if isinstance(arg, str):
		return arg
	elif isinstance(arg, (int, float)):
		return repr(arg)
	elif isinstance(arg, parsetricks.ParseResults):
		return " ".join(flatten(c) for c in arg)
	else:
		return arg.flatten()
def autocollapse(nodeBuilder, children):
	"""inhibits the construction via nodeBuilder if children consists of
	a single ADQLNode.

	This function will automatically be inserted into the constructor
	chain if the node defines an attribute collapsible=True.
	"""
	if len(children) == 1:
		onlyChild = children[0]
		if isinstance(onlyChild, ADQLNode):
			return onlyChild
	return nodeBuilder.fromParseResult(children)
def collectUserData(infoChildren):
	"""returns the concatenated userData tuples of infoChildren's
	fieldInfos, together with a flag saying whether any of them was
	tainted.
	"""
	collected = ()
	anyTainted = False
	for child in infoChildren:
		info = child.fieldInfo
		collected = collected + info.userData
		anyTainted = anyTainted or info.tainted
	return collected, anyTainted
def flattenKWs(obj, *fmtTuples):
	"""returns a string built from the obj according to format tuples.

	A format tuple consists of a literal string and an attribute name.
	If the corresponding attribute is non-None, the plain string and the
	flattened attribute value are inserted into the result string,
	otherwise both are ignored.

	Nonexisting attributes are taken to have None values.

	To allow unconditional literals, the attribute name can be None.
	The corresponding literal is always inserted.

	All contributions are separated by single blanks.

	This is a helper method for flatten methods of parsed-out elements.
	"""
	parts = []
	for literal, attName in fmtTuples:
		if attName is None:
			# unconditional literal
			parts.append(literal)
			continue
		value = getattr(obj, attName, None)
		if value is None:
			continue
		if literal:
			parts.append(literal)
		parts.append(flatten(value))
	return " ".join(parts)
def cleanNamespace(ns):
	"""removes all names starting with an underscore (and "cls") from
	the dict ns.

	This is intended for _getInitKWs methods.  A new dict is returned;
	ns itself is not modified.
	"""
	return {key: value for key, value in ns.items()
		if not key.startswith("_") and key != "cls"}
def getChildrenOfType(nodeSeq, type):
	"""returns a list of children of type type in the sequence nodeSeq.

	(type here is an ADQL grammar nonterminal name, or the str type
	for plain strings; see getType.  The parameter name shadows the
	builtin for historical reasons.)
	"""
	return [c for c in nodeSeq if getType(c)==type]
def getChildrenOfClass(nodeSeq, cls):
	"""returns a list of all items of nodeSeq that are instances of cls."""
	return [child for child in nodeSeq if isinstance(child, cls)]
class BOMB_OUT(object):
	# sentinel default for getChildOfType/getChildOfClass: when no
	# explicit default is passed, a missing child raises instead of
	# returning a default value.  (object() would do, but the class
	# gives a readable repr in tracebacks.)
	pass
def _uniquify(matches, default, exArgs): # helper method for getChildOfX -- see there if len(matches)==0: if default is not BOMB_OUT: return default raise common.NoChild(*exArgs) if len(matches)!=1: raise common.MoreThanOneChild(*exArgs) return matches[0]
def getChildOfType(nodeSeq, type, default=BOMB_OUT):
	"""returns the unique node of type in nodeSeq.

	If there is no such node in nodeSeq or more than one, a NoChild or
	MoreThanOneChild exception is raised,  Instead of raising NoChild,
	default is returned if given.
	"""
	return _uniquify(getChildrenOfType(nodeSeq, type), default, (type, nodeSeq))


def getChildOfClass(nodeSeq, cls, default=BOMB_OUT):
	"""returns the unique node of class in nodeSeq.

	See getChildOfType.
	"""
	return _uniquify(getChildrenOfClass(nodeSeq, cls), default, (cls, nodeSeq))
def parseArgs(parseResult):
	"""returns a sequence of ADQL nodes suitable as function arguments
	from parseResult.

	This is for cleaning up _parseResults["args"], i.e. stuff from the
	Args symbol decorator in grammar.
	"""
	args = []
	for _arg in parseResult:
		# _arg is either another ParseResult, an ADQL identifier, or an ADQLNode
		if isinstance(_arg, (ADQLNode, str, utils.QuotedName)):
			args.append(_arg)
		else:
			# a raw ParseResult: wrap it (or, if it is a single node,
			# just unwrap it -- that's what autocollapse does)
			args.append(autocollapse(GenericValueExpression, _arg))
	return tuple(args)
def getStringLiteral(node, description="Argument"):
	"""ensures that node only contains a constant string literal and
	returns its value if so.

	The function raises an adql.Error mentioning description otherwise.
	"""
	if node.type == 'characterStringLiteral':
		return node.value
	raise common.Error(
		"%s must be a constant string literal"%description)
######################### Misc helpers related to simple query planning def _getDescendants(args): """returns the nodes in the sequence args and all their descendants. This is a helper function for when you have to analyse what's contributing to complex terms. """ descendants = list(args) for arg in args: if hasattr(arg, "iterTree"): descendants.extend(c[1] for c in arg.iterTree()) return descendants
def iterFieldInfos(args):
	"""iterates over the fieldInfo objects found within the children of
	the node list args.
	"""
	for node in _getDescendants(args):
		fieldInfo = getattr(node, "fieldInfo", None)
		if fieldInfo is not None:
			yield fieldInfo
def _isConstant(args):
	"""returns true if no columnReference-s are found below the node
	list args.
	"""
	for desc in _getDescendants(args):
		if getattr(desc, "type", None)=="columnReference":
			return False
	return True


def _estimateTableSize(args):
	"""returns an estimate for the size of a table mentioned in the node
	list args.

	Actually, we wait for the first column in fieldInfo userdata that
	has a reference to a table that knows its nrows.  If that comes,
	that's our estimate.  If it doesn't come, we return None.
	"""
	for fi in iterFieldInfos(args):
		if fi.ignoreTableStats:
			# set by WithQuery.addFieldInfos: stats of CTE columns are
			# misleading, so don't use them at all.
			return
		for ud in fi.userData:
			sizeEst = getattr(ud.parent, "nrows", None)
			if sizeEst is not None:
				return sizeEst
	return None


def _sortLargeFirst(arg1, arg2):
	"""returns arg1, arg2 ordered such that the first return value deals
	with the larger table if we can figure that out.

	This is for distance; postgres in general only uses an index for
	them if the point stands alone (rather than in the circle).  So, it
	normally pays to have the larger table first in our expressions
	(which are point op geom where applicable).

	This will also swap constant arguments second (so, into the circle).
	"""
	if _isConstant([arg1]):
		return arg2, arg1
	if _isConstant([arg2]):
		return arg1, arg2

	size1, size2 = _estimateTableSize([arg1]), _estimateTableSize([arg2])
	if size1 is None:
		if size2 is None:
			# we know nothing; don't change anything to keep the user in control
			return arg1, arg2
		else:
			# we assume all large tables are nrows-annotated, so presumably
			# arg1 isn't large.  So, swap.
			return arg2, arg1
	else:
		if size2 is None:
			# see one comment up
			return arg1, arg2
		else:
			if size1>size2:
				return arg1, arg2
			else:
				return arg2, arg1


######################### Generic Node definitions
class ADQLNode(utils.AutoNode):
	"""A node within an ADQL parse tree.

	ADQL nodes may be parsed out; in that case, they have individual
	attributes and are craftily flattened in special methods.  We do
	this for nodes that are morphed.

	Other nodes basically just have a children attribute, and their
	flattening is just a concatenation of their flattened children.
	This is convenient as long as they are not morphed.

	To derive actual classes, define

	- the _a_<name> class attributes you need,
	- the type (a nonterminal from the ADQL grammar)
	- bindings if the class handles more than one symbol
	  (in which case type is ignored)
	- a class method _getInitKWs(cls, parseResult); see below.
	- a method flatten() -> string if you define a parsed ADQLNode.
	- a method _polish() that is called just before the constructor is
	  done and can be used to create more attributes.  There is no need
	  to call _polish of superclasses.

	The _getInitKWs methods must return a dictionary mapping constructor
	argument names to values.  You do not need to manually call
	superclass _getInitKWs, since the fromParseResult classmethod figures
	out all _getInitKWs in the inheritance tree itself.  It calls all
	of them in the normal MRO and updates the argument dictionary in
	reverse order.

	The fromParseResult class method additionally filters out all names
	starting with an underscore; this is to allow easy returning of
	locals().
	"""
	type = None

	@classmethod
	def fromParseResult(cls, parseResult):
		initArgs = {}
		# reversed MRO so that more derived classes can override keywords
		# produced by their bases.
		for superclass in reversed(cls.mro()):
			if hasattr(superclass, "_getInitKWs"):
				initArgs.update(superclass._getInitKWs(parseResult))
		try:
			return cls(**cleanNamespace(initArgs))
		except TypeError:
			# almost always a _getInitKWs returning an unexpected keyword
			raise common.BadKeywords("%s, %s"%(cls, cleanNamespace(initArgs)))
		except ReplaceNode as rn:
			# a constructor (typically via _polish) asked to be replaced
			# by a different node
			return rn.replacingNode

	def _setupNode(self):
		# run all _polish methods in the inheritance tree, base-first
		for cls in reversed(self.__class__.mro()):
			if hasattr(cls, "_polish"):
				cls._polish(self)
		self._setupNodeNext(ADQLNode)

	def __repr__(self):
		return "<ADQL Node %s>"%(self.type)

	def flatten(self):
		"""returns a string representation of the text content of the tree.

		This default implementation will only work if you returned all
		parsed elements as children.  This, in turn, is something you only
		want to do if you are sure that the node in question will not be
		morphed.

		Otherwise, override it to create an SQL fragment out of the parsed
		attributes.
		"""
		return " ".join(flatten(c) for c in self.children)

	def asTree(self):
		# returns a nested tuple representation of this subtree, mainly
		# for debugging and tests
		res = []
		for name, val in self.iterChildren():
			if isinstance(val, ADQLNode):
				res.append(val.asTree())
		return self._treeRepr()+tuple(res)

	def _treeRepr(self):
		return (self.type,)

	def iterTree(self):
		# depth-first iteration over (name, node) pairs, children before
		# their parents
		for name, val in self.iterChildren():
			if isinstance(val, ADQLNode):
				for item in val.iterTree():
					yield item
			yield name, val
class TransparentMixin(object):
	"""a mixin just pulling through the children and serializing them.
	"""
	_a_children = ()

	@classmethod
	def _getInitKWs(cls, _parseResult):
		return {"children": list(_parseResult)}
class FieldInfoedNode(ADQLNode):
	"""An ADQL node that carries a FieldInfo.

	This is true for basically everything in the tree below a derived
	column.  This class is the basis for column annotation.

	You'll usually have to override addFieldInfo.  The default
	implementation just looks in its immediate children for anything
	having a fieldInfo, and if there's exactly one such child, it adopts
	that fieldInfo as its own, not changing anything.

	FieldInfoedNode, when change()d, keep their field info.  This is
	usually what you want when morphing, but sometimes you might need
	adjustments.
	"""
	fieldInfo = None

	def _getInfoChildren(self):
		return [c for c in self.iterNodeChildren() if hasattr(c, "fieldInfo")]

	def addFieldInfo(self, context):
		infoChildren = self._getInfoChildren()
		if len(infoChildren)==1:
			# exactly one annotated child: simply adopt its info
			self.fieldInfo = infoChildren[0].fieldInfo
		else:
			if len(infoChildren):
				msg = "More than one"
			else:
				msg = "No"
			raise common.Error("%s child with fieldInfo with"
				" no behaviour defined in %s, children %s"%(
					msg,
					self.__class__.__name__,
					list(self.iterChildren())))

	def change(self, **kwargs):
		# overridden to carry the fieldInfo over to the changed copy
		other = ADQLNode.change(self, **kwargs)
		other.fieldInfo = self.fieldInfo
		return other
class FunctionNode(FieldInfoedNode):
	"""An ADQLNodes having a function name and arguments.

	The rules having this as action must use the Arg "decorator" in
	grammar.py around their arguments and must have a string-valued
	result "fName".

	FunctionNodes have attributes args (unflattened arguments), and
	funName (a string containing the function name, all upper case).
	"""
	_a_args = ()
	_a_funName = None

	@classmethod
	def _getInitKWs(cls, _parseResult):
		try:
			args = parseArgs(_parseResult["args"]) #noflake: locals returned
		except KeyError: # Zero-Arg function
			# args stays unbound; it then simply isn't in locals() and the
			# _a_args class default (an empty tuple) applies.
			pass
		funName = _parseResult["fName"].upper() #noflake: locals returned
		return locals()

	def flatten(self):
		return "%s(%s)"%(self.funName, ", ".join(flatten(a) for a in self.args))
class ColumnBearingNode(ADQLNode):
	"""A Node type defining selectable columns.

	These are tables, subqueries, etc.  This class is the basis for the
	annotation of tables and subqueries.

	Their getFieldInfo(name)->fi method gives annotation.FieldInfos
	objects for their columns, None for unknown columns.

	These keep their fieldInfos on a change()
	"""
	fieldInfos = None
	originalTable = None

	def getFieldInfo(self, name):
		# returns None implicitly when fieldInfos has not been set yet
		if self.fieldInfos:
			return self.fieldInfos.getFieldInfo(name)

	def getAllNames(self): # pragma: no cover
		"""yields all relation names mentioned in this node.
		"""
		raise TypeError("Override getAllNames for ColumnBearingNodes.")

	def change(self, **kwargs):
		# overridden to carry the fieldInfos over to the changed copy
		other = ADQLNode.change(self, **kwargs)
		other.fieldInfos = self.fieldInfos
		return other
############# Toplevel query language node types (for query analysis)
class TableName(ADQLNode):
	# A (possibly qualified) table name, catalog.schema.name, any prefix
	# of which may be missing (None).  _polish computes the flattened
	# qualified name qName.
	type = "tableName"
	_a_cat = None
	_a_schema = None
	_a_name = None

	# NOTE(review): defining __eq__ without __hash__ makes instances
	# unhashable in Python 3 -- confirm TableNames are never used as
	# dict keys or set members.
	def __eq__(self, other):
		# TableNames compare case-insensitively against other TableNames
		# and against plain strings.
		if hasattr(other, "qName"):
			return self.qName.lower()==other.qName.lower()
		try:
			return self.qName.lower()==other.lower()
		except AttributeError:
			# other has no lower, so it's neither a string nor a table name;
			# thus, fall through to non-equal case
			pass
		return False

	def __ne__(self, other):
		return not self==other

	def __bool__(self):
		return bool(self.name)

	def __str__(self):
		return "TableName(%s)"%self.qName

	def _polish(self):
		# Implementation detail: We map tap_upload to temporary tables
		# here; therefore, we can just nil out anything called tap_upload.
		# If we need more flexibility, this probably is the place to implement
		# the mapping.
		if self.schema and self.schema.lower()=="tap_upload":
			self.schema = None
		self.qName = ".".join(flatten(n)
			for n in (self.cat, self.schema, self.name) if n)

	@classmethod
	def _getInitKWs(cls, _parseResult):
		# every other token in the parse result is a literal "."
		_parts = _parseResult[::2]
		cat, schema, name = [None]*(3-len(_parts))+_parts
		return locals()

	def flatten(self):
		return self.qName

	def lower(self):
		"""returns self's qualified name in lower case.
		"""
		return self.qName.lower()

	@staticmethod
	def _normalizePart(part):
		# quoted identifiers keep their capitalisation, regular ones are
		# lowercased
		if isinstance(part, utils.QuotedName):
			return part.name
		else:
			return part.lower()

	def getNormalized(self):
		"""returns self's qualified name lowercased for regular identifiers,
		in original capitalisation otherwise.
		"""
		return ".".join(self._normalizePart(p)
			for p in [self.cat, self.schema, self.name] if p is not None)
class PlainTableRef(ColumnBearingNode):
	"""A reference to a simple table.

	The tableName is the name this table can be referenced as from within
	SQL, originalName is the name within the database; they are equal
	unless a correlationSpecification has been given.
	"""
	type = "possiblyAliasedTable"
	_a_tableName = None  # a TableName instance
	_a_originalTable = None  # a TableName instance
	_a_sampling = None

	@classmethod
	def _getInitKWs(cls, _parseResult):
		if _parseResult.get("alias"):
			# with an alias, the SQL-visible name differs from the
			# database-side one
			tableName = TableName(name=_parseResult.get("alias"))
			originalTable = _parseResult.get("tableName")
		else:
			tableName = getChildOfType(_parseResult, "tableName")
			originalTable = tableName #noflake: locals returned
		if _parseResult.get("tablesample"):
			# tablesample parses as TABLESAMPLE ( <percentage> )
			sampling = float(_parseResult.get("tablesample")[2])
		return locals()

	def addFieldInfos(self, context):
		self.fieldInfos = fieldinfos.TableFieldInfos.makeForNode(self, context)

	def _polish(self):
		self.qName = flatten(self.tableName)

	def flatten(self):
		origTable = flatten(self.originalTable)
		if origTable!=self.qName:
			literal = "%s AS %s"%(origTable, flatten(self.tableName))
		else:
			literal = self.qName

		if self.sampling:
			# TODO: Postgres dependency; this should be in morphpg
			literal = "%s TABLESAMPLE SYSTEM (%s)"%(literal, self.sampling)

		return literal

	def getAllNames(self):
		yield self.tableName.qName

	def getAllTables(self):
		yield self

	def makeUpId(self):
		# for suggestAName
		n = self.tableName.name
		if isinstance(n, utils.QuotedName):
			# strip everything that is not a regular SQL identifier char
			return "_"+re.sub("[^A-Za-z0-9_]", "", n.name)
		else:
			return n
class DerivedTable(ColumnBearingNode):
	# A subquery with a correlation specification, i.e.
	# (SELECT ...) AS alias.
	type = "derivedTable"
	_a_query = None
	_a_tableName = None

	def getFieldInfo(self, name):
		# column lookup is delegated to the embedded query
		return self.query.getFieldInfo(name)

	# fieldInfos live on the embedded query; expose them as a pass-through
	# property so annotation code can treat DerivedTables like tables.
	def _get_fieldInfos(self):
		return self.query.fieldInfos
	def _set_fieldInfos(self, val):
		self.query.fieldInfos = val
	fieldInfos = property(_get_fieldInfos, _set_fieldInfos)

	@classmethod
	def _getInitKWs(cls, _parseResult):
		tmp = {'tableName': TableName(name=str(_parseResult.get("alias"))),
			'query': getChildOfClass(_parseResult, SelectExpression),
		}
		return tmp

	def flatten(self):
		return "(%s) AS %s"%(flatten(self.query), flatten(self.tableName))

	def getAllNames(self):
		yield self.tableName.qName

	def getAllTables(self):
		yield self

	def makeUpId(self):
		# for suggestAName
		n = self.tableName.name
		if isinstance(n, utils.QuotedName):
			return "_"+re.sub("[^A-Za-z0-9_]", "", n.name)
		else:
			return n
class SetGeneratingFunction(ColumnBearingNode, TransparentMixin):
	"""a function that can stand instead of a table.

	For starters, we only do generate_series here.  Let's see where
	this leads.
	"""
	type = "setGeneratingFunction"
	_a_functionName = None
	_a_args = None
	_a_name = None  # name is both the name of the column and the "table"
	                # here.  This will come from a correlationSpec where
	                # available.  It's generate_series otherwise.

	@classmethod
	def _getInitKWs(cls, _parseResult):
		functionName = _parseResult[0]
		# TODO: We really should allow more than two arguments here
		args = [_parseResult[2], _parseResult[4]]
		name = _parseResult.get("alias")
		if name is None:
			name = functionName
		return locals()

	def _polish(self):
		self.tableName = self.name

	def getFieldInfo(self, name):
		return self.fieldInfos.getFieldInfo(name)

	def getAllTables(self):
		yield self

	def addFieldInfos(self, context):
		# TODO: Infer types from argument types
		self.fieldInfos = fieldinfos.FieldInfos(self, context)
		self.fieldInfos.addColumn(self.name,
			fieldinfo.FieldInfo("integer", None, None, sqlName=self.name))

	def getAllNames(self):
		yield self.name

	def makeUpId(self):
		return self.name
class JoinSpecification(ADQLNode, TransparentMixin):
	"""A join specification ("ON" or "USING").
	"""
	type = "joinSpecification"

	_a_children = ()
	_a_predicate = None  # "ON" or "USING" (upper-cased)
	_a_usingColumns = ()  # only filled in for USING joins

	@classmethod
	def _getInitKWs(cls, _parseResult):
		predicate = _parseResult[0].upper()
		if predicate=="USING":
			# drop the comma tokens from the parsed column name list
			usingColumns = [ #noflake: locals returned
				n for n in _parseResult["columnNames"] if n!=',']
		children = list(_parseResult) #noflake: locals returned
		return locals()
class JoinOperator(ADQLNode, TransparentMixin):
	"""the complete join operator (including all LEFT, RIGHT, ",", and
	whatever).
	"""
	type = "joinOperator"

	def isCrossJoin(self):
		# a plain comma also is a cross join in SQL
		return self.children[0] in (',', 'CROSS')
class JoinedTable(ColumnBearingNode):
	"""A joined table.

	These aren't made directly by the parser since parsing a join into
	a binary structure is very hard using pyparsing.  Instead, there's
	the helper function makeJoinedTableTree handling the joinedTable
	symbol that manually creates a binary tree.
	"""
	type = None
	originalTable = None
	# class-level defaults; note that this single TableName instance is
	# shared by all JoinedTables that never set their own.
	tableName = TableName()
	qName = None

	_a_leftOperand = None
	_a_operator = None
	_a_rightOperand = None
	_a_joinSpecification = None

	@classmethod
	def _getInitKWs(cls, _parseResult):
		leftOperand = _parseResult[0] #noflake: locals returned
		operator = _parseResult[1] #noflake: locals returned
		rightOperand = _parseResult[2] #noflake: locals returned
		if len(_parseResult)>3:
			joinSpecification = _parseResult[3] #noflake: locals returned
		return locals()

	def flatten(self):
		js = ""
		if self.joinSpecification is not None:
			js = flatten(self.joinSpecification)
		return "%s %s %s %s"%(
			self.leftOperand.flatten(),
			self.operator.flatten(),
			self.rightOperand.flatten(),
			js)

	def addFieldInfos(self, context):
		self.fieldInfos = fieldinfos.TableFieldInfos.makeForNode(self, context)

	def _polish(self):
		self.joinedTables = [self.leftOperand, self.rightOperand]

	def getAllNames(self):
		"""iterates over all fully qualified table names mentioned in this
		(possibly joined) table reference.
		"""
		for t in self.joinedTables:
			yield t.tableName.qName

	def getTableForName(self, name):
		return self.fieldInfos.locateTable(name)

	def makeUpId(self):
		# for suggestAName
		return "_".join(t.makeUpId() for t in self.joinedTables)

	def getJoinType(self):
		"""returns a keyword indicating how result rows are formed in this
		join.

		This can be NATURAL (all common columns are folded into one),
		USING (check the joinSpecification what columns are folded), CROSS
		(no columns are folded).
		"""
		if self.operator.isCrossJoin():
			if self.joinSpecification is not None:
				raise common.Error("Cannot use cross join with a join predicate.")
			return "CROSS"
		if self.joinSpecification is not None:
			if self.joinSpecification.predicate=="USING":
				return "USING"
			if self.joinSpecification.predicate=="ON":
				# ON joins fold no columns, so they behave like CROSS here
				return "CROSS"
		return "NATURAL"

	def getAllTables(self):
		"""returns all actual tables and subqueries (not sub-joins)
		within this join.
		"""
		res = []

		def collect(node):
			# recurse into sub-joins (recognized by their having
			# a leftOperand), collect everything else
			if hasattr(node.leftOperand, "leftOperand"):
				collect(node.leftOperand)
			else:
				res.append(node.leftOperand)
			if hasattr(node.rightOperand, "leftOperand"):
				collect(node.rightOperand)
			else:
				res.append(node.rightOperand)

		collect(self)
		return res
class SubJoin(ADQLNode):
	"""A sub join (JoinedTable surrounded by parens).

	The parse result is just the parens and a joinedTable; we need
	to camouflage as that joinedTable.
	"""
	type = "subJoin"
	_a_joinedTable = None

	@classmethod
	def _getInitKWs(cls, _parseResult):
		# _parseResult is ( joinedTable ); pick the payload
		return {"joinedTable": _parseResult[1]}

	def flatten(self):
		return "("+self.joinedTable.flatten()+")"

	def __getattr__(self, attName):
		# anything not defined here is delegated to the wrapped join
		return getattr(self.joinedTable, attName)
@symbolAction("joinedTable")
def makeBinaryJoinTree(children):
	"""takes the parse result for a join and generates a binary tree of
	JoinedTable nodes from it.

	It's much easier to do this in a separate step than to force a
	non-left-recursive grammar to spit out the right parse tree in the
	first place.
	"""
	children = list(children)
	while len(children)>1:
		# each reduction consumes table, operator, table and -- if
		# present -- a join specification
		if len(children)>3 and isinstance(children[3], JoinSpecification):
			exprLen = 4
		else:
			exprLen = 3
		args = children[:exprLen]
		# fold the leftmost join expression into a single node (left
		# associative, as in SQL)
		children[:exprLen] = [JoinedTable.fromParseResult(args)]
	return children[0]
class TransparentNode(ADQLNode, TransparentMixin):
	"""An abstract base for Nodes that don't parse out anything.
	"""
	type = None


class WhereClause(TransparentNode):
	type = "whereClause"


class Grouping(TransparentNode):
	type = "groupByClause"


class Having(TransparentNode):
	type = "havingClause"


class OrderBy(TransparentNode):
	type = "sortSpecification"
class OffsetSpec(ADQLNode):
	# an ADQL 2.1 OFFSET <int> clause
	type = "offsetSpec"

	_a_offset = None

	@classmethod
	def _getInitKWs(cls, _parseResult):
		return {"offset": int(_parseResult[1])}

	def flatten(self):
		if self.offset is not None:
			# for morphpg, this never happens because _PGQS deals with it
			# (and sets self.offset to None).
			return "OFFSET %d"%self.offset
		return ""
class SelectQuery(ColumnBearingNode):
	"""A single SELECT ... FROM ... [WHERE ...] query (without any set
	operators or CTEs around it).
	"""
	type = "selectQuery"

	_a_setQuantifier = None
	_a_setLimit = None
	_a_selectList = None
	_a_fromClause = None
	_a_whereClause = None
	_a_groupby = None
	_a_having = None
	_a_orderBy = None

	def _polish(self):
		# self.query lets code treat this node like a DerivedTable
		self.query = weakref.proxy(self)

	@classmethod
	def _getInitKWs(cls, _parseResult):
		res = {}
		for name in ["setQuantifier", "setLimit", "fromClause",
				"whereClause", "groupby", "having", "orderBy"]:
			res[name] = _parseResult.get(name)
		res["selectList"] = getChildOfType(_parseResult, "selectList")
		return res

	def _iterSelectList(self):
		# yields the DerivedColumns of the select list, expanding
		# qualified stars (tab.*) via the from clause.
		for f in self.selectList.selectFields:
			if isinstance(f, DerivedColumn):
				yield f
			elif isinstance(f, QualifiedStar):
				for sf in self.fromClause.getFieldsForTable(f.sourceTable):
					yield sf
			else:
				raise common.Error("Unexpected %s in select list"%getType(f))

	def getSelectFields(self):
		if self.selectList.allFieldsQuery:
			return self.fromClause.getAllFields()
		else:
			return self._iterSelectList()

	def addFieldInfos(self, context):
		self.fieldInfos = fieldinfos.QueryFieldInfos.makeForNode(self, context)

	def resolveField(self, fieldName):
		return self.fromClause.resolveField(fieldName)

	def getAllNames(self):
		return self.fromClause.getAllNames()

	def flatten(self):
		return flattenKWs(self,
			("SELECT", None),
			("", "setQuantifier"),
			("TOP", "setLimit"),
			("", "selectList"),
			("", "fromClause"),
			("", "whereClause"),
			("", "groupby"),
			("", "having"),
			("", "orderBy"))

	def suggestAName(self):
		"""returns a string that may or may not be a nice name for a table
		resulting from this query.

		Whatever is being returned here, it's a regular SQL identifier.
		"""
		try:
			sources = [tableRef.makeUpId()
				for tableRef in self.fromClause.getAllTables()]
			if sources:
				return "_".join(sources)
			else:
				return "query_result"
		# was a bare except:, which also swallowed SystemExit and
		# KeyboardInterrupt; narrowed to Exception.
		except Exception:
			# should not happen, but we don't want to bomb from here
			import traceback
			traceback.print_exc()
			return "weird_table_report_this"

	def getContributingNames(self):
		"""returns a set of table names mentioned below this node.
		"""
		names = set()
		for name, val in self.iterTree():
			if isinstance(val, TableName):
				names.add(val.flatten())
		return names
class SetOperationNode(ColumnBearingNode, TransparentMixin):
	"""A node containing a set expression.

	This is UNION, INTERSECT, or EXCEPT.  In all cases, we need to check
	all contributing sub-expressions have compatible degree.

	For now, in violation of SQL1992, we require identical names on all
	operands -- sql92 in 7.10 says

	[if column names are unequal], the <column name> of the i-th column of
	TR is implementation-dependent and different from the <column name> of
	any column, other than itself, of any table referenced by any <table
	reference> contained in the SQL-statement.

	Yikes.

	These collapse to keep things simple in the typical case.
	"""
	def _assertFieldInfosCompatible(self):
		"""errors out if operands have incompatible signatures.

		For convenience, if all are compatible, the common signature (ie,
		fieldInfos) is returned.
		"""
		fieldInfos = None
		for child in self.children:
			# Skip WithQueries -- they're not part of set operations.
			if hasattr(child, "fieldInfos") and not isinstance(child, WithQuery):
				if fieldInfos is None:
					# first operand: adopt (a copy of) its signature
					fieldInfos = child.fieldInfos.copy(self)
				else:
					fieldInfos.assertIsCompatible(child.fieldInfos)

				# blank out metadata on which the operands disagree
				for (l1, mycol), (l2, theircol) in zip(
						fieldInfos, child.fieldInfos):
					if mycol.ucd != theircol.ucd:
						mycol.ucd = ""
					if mycol.unit != theircol.unit:
						mycol.unit = ""
		return fieldInfos

	def addFieldInfos(self, context):
		self.fieldInfos = self._assertFieldInfosCompatible()

	def getAllNames(self):
		for index, child in enumerate(self.children):
			if hasattr(child, "getAllNames"):
				for name in child.getAllNames():
					yield name
			elif hasattr(child, "suggestAName"):
				yield child.suggestAName()
			else: # pragma: no cover
				assert False, "no name"

	def getSelectClauses(self):
		# recursively yields all nodes that carry a setLimit (i.e., actual
		# select clauses) below this node.
		for child in self.children:
			for sc in getattr(child, "getSelectClauses", lambda: [])():
				yield sc
			if hasattr(child, "setLimit"):
				yield child
class SetTerm(SetOperationNode):
	# INTERSECT/EXCEPT level of the set-operator grammar
	type = "querySetTerm"
	collapsible = True


class SetExpression(SetOperationNode):
	# UNION level of the set-operator grammar
	type = "querySetExpression"
	collapsible = True

	# we hand the various introspection items through to our first child;
	# if one day we want to make sure the relations of all children are
	# reconcilable, this probably should happen in a _polish method.
	@property
	def fromClause(self):
		return self.children[0].fromClause

	def getContributingNames(self):
		res = []
		for child in self.children:
			if hasattr(child, "getContributingNames"):
				res.extend(child.getContributingNames())
		return res
class WithQuery(SetOperationNode):
	"""A query from a with clause.

	This essentially does everything a table does.
	"""
	type = "withQuery"

	def addFieldInfos(self, context):
		# overridden because we need to discard statistics here.  We
		# have code that looks at table sizes (look for "optimize"),
		# and it looks at columns' parents to do that.  These stats
		# will be totally off in CTEs.  To avoid dis-optimisations,
		# we note it's all off here.
		SetOperationNode.addFieldInfos(self, context)
		for _, fi in self.fieldInfos:
			fi.ignoreTableStats = True

	def _polish(self):
		# children[0] is the CTE's name
		self.name = self.children[0]
		for c in self.children:
			# this should be a selectExpression, but this we want to be sure
			# we don't fail when morphers replace the main query node
			# (as the pg morpher does)
			if hasattr(c, "setLimit"):
				self.select = c
				break
		else: # pragma: no cover
			raise NotImplementedError("WithQuery without select?")
class SelectExpression(SetOperationNode):
	"""A complete query excluding CTEs.

	The main ugly thing here is the set limit; the querySpecification has
	max of the limits of the children, if existing, otherwise to None.

	Other than that, we hand through attribute access to our first child.
	If there is a set expression on the top level, this will have a
	complex structure; the first-child thing still ought to work since
	after annotation we'll have errored out if set operator arguments
	aren't reasonably congurent.
	"""
	type = "selectExpression"
	_a_setLimit = None
	_a_offset = None

	def getSelectClauses(self):
		for child in self.children:
			for sc in getattr(child, "getSelectClauses", lambda: [])():
				yield sc
			if hasattr(child, "setLimit"):
				yield child

	def _polish(self):
		if self.setLimit is None:
			# the effective limit is the max of the children's limits
			# (if any child declares one)
			limits = [selectClause.setLimit
				for selectClause in self.getSelectClauses()]
			limits = [int(s) for s in limits if s]
			if limits:
				self.setLimit = max(limits)

		for child in self.children:
			# pull a child's OFFSET up to this level and clear it there
			if isinstance(child, OffsetSpec) and child.offset is not None:
				self.offset = child.offset
				child.offset = None

	def __getattr__(self, attrName):
		# hand everything else through to the first child (see class
		# docstring)
		return getattr(self.children[0], attrName)
class QuerySpecification(TransparentNode):
	"""The toplevel query objects including CTEs.

	Apart from any CTEs, that's just a SelectExpression (which is always
	the last child), and we hand through essentially all attribute access
	to it.
	"""
	type = "querySpecification"

	def _polish(self):
		self.withTables = []
		for child in self.children:
			if isinstance(child, WithQuery):
				self.withTables.append(child)

	# setLimit is delegated to the trailing SelectExpression so callers
	# can read and write it directly on the top-level node.
	def _setSetLimit(self, val):
		self.children[-1].setLimit = val
	def _getSetLimit(self):
		return self.children[-1].setLimit
	setLimit = property(_getSetLimit, _setSetLimit)

	def __getattr__(self, attrName):
		# our last child always is a selectExpression
		return getattr(self.children[-1], attrName)
class _BaseColumnReference(FieldInfoedNode):
	"""common base class for column references.

	Normal column references will be handled by the
	dispatchColumnReference function below; there we look at the
	columnReference type to figure out whether something is constant,
	though, and so multiple non-terminals share this.

	That's why we're unbinding everything here and leave it to deriving
	classes to set up bindings.
	"""
	type = "columnReference"
	bindings = []

	_a_refName = None # if given, a TableName instance
	_a_name = None

	def _polish(self):
		if not self.refName:
			self.refName = None
		# the serialisable (possibly table-qualified) column name
		self.colName = ".".join(
			flatten(p) for p in (self.refName, self.name) if p)

	@classmethod
	def _getInitKWs(cls, _parseResult):
		# drop the separating dots and left-pad with Nones so we always
		# have (catalog, schema, table, column)
		names = [_c for _c in _parseResult if _c!="."]
		names = [None]*(4-len(names))+names
		refName = TableName(cat=names[0],
			schema=names[1],
			name=names[2])
		if not refName:
			refName = None
		return {
			"name": names[-1],
			"refName": refName}

	def addFieldInfo(self, context):
		self.fieldInfo = context.getFieldInfo(self.name, self.refName)

		srcColumn = None
		if self.fieldInfo.userData:
			srcColumn = self.fieldInfo.userData[0]

		if hasattr(srcColumn, "originalName"):
			# This is a column from a VOTable upload we have renamed to avoid
			# clashes with postgres-reserved column names.  Update the name
			# so the "bad" name doesn't appear in the serialised query.
			if not isinstance(self.name, utils.QuotedName):
				self.name = srcColumn.name
				self._polish()

	def flatten(self):
		if self.fieldInfo and self.fieldInfo.sqlName:
			# annotation provided a replacement name for serialisation
			return ".".join(
				flatten(p) for p in (self.refName, self.fieldInfo.sqlName) if p)
		return self.colName

	def _treeRepr(self):
		return (self.type, self.name)
class ColumnReference(_BaseColumnReference):
	"""a plain column reference.

	No bindings: instances are constructed by dispatchColumnReference
	below rather than directly as a grammar parse action.
	"""
	pass
class GeometryValue(_BaseColumnReference):
	"""a column reference appearing where the grammar expects a geometry
	value.
	"""
	bindings = ["geometryValue"]
class ColumnReferenceByUCD(_BaseColumnReference):
	"""a column reference through a UCD pattern (the UCDCOL extension).

	These are tricky: As, when parsing, we don't know where the columns
	might come from, we have to later figure out where to get our
	metadata from.
	"""
	bindings = ["columnReferenceByUCD"]

	_a_ucdWanted = None

	@classmethod
	def _getInitKWs(cls, _parseResult):
		# name/refName only become known during annotation
		return {
			"ucdWanted": _parseResult[2].value,
			"name": utils.Undefined,
			"refName": utils.Undefined}

	def addFieldInfo(self, context):
		# I've not really thought about where these might turn up.
		# Hence, I just heuristically walk up the ancestor stack
		# until I find a from clause.  TODO: think about if that's valid.
		for ancestor in reversed(context.ancestors):
			if hasattr(ancestor, "fromClause"):
				break
		else:
			raise common.Error("UCDCOL outside of query specification with FROM")

		for field in ancestor.fromClause.getAllFields():
			# NOTE(review): assumes field.fieldInfo.ucd is a string;
			# fnmatch would raise on a None ucd -- confirm upstream.
			if fnmatch.fnmatch(field.fieldInfo.ucd, self.ucdWanted):
				# take over the identity of the first matching column
				self.fieldInfo = field.fieldInfo.change()
				self.name = self.colName = field.name
				self.refName = None
				break
		else:
			raise utils.NotFoundError(self.ucdWanted, "column matching ucd",
				"from clause")

		self.fieldInfo.properties["src-expression"] = flatten(self)
@symbolAction("columnReference")
def dispatchColumnReference(parseResult):
	"""hands through an already-built ColumnReferenceByUCD, constructing
	a plain ColumnReference otherwise.

	This dispatch is there so ColumnReference is not bothered by the
	by-UCD hack in the normal case.  It should go if we punt UCDCOL, and
	the columnReference binding should then go back to ColumnReference.
	"""
	isByUCD = (len(parseResult)==1
		and isinstance(parseResult[0], ColumnReferenceByUCD))
	if isByUCD:
		return parseResult[0]
	return ColumnReference.fromParseResult(parseResult)
class FromClause(ADQLNode):
	"""a FROM clause.

	tableReference is the single (possibly artificially joined) table
	reference; tables is the raw list of parsed table references.
	"""
	type = "fromClause"

	_a_tableReference = ()
	_a_tables = ()

	@classmethod
	def _getInitKWs(cls, parseResult):
		parseResult = list(parseResult)
		if len(parseResult)==1:
			tableReference = parseResult[0]
		else:
			# it's a cross join; to save repeating the logic, we'll
			# just build an artificial join as the table reference
			tableReference = reduce(lambda left, right:
				JoinedTable(leftOperand=left,
					operator=JoinOperator(children=[","]),
					rightOperand=right),
				parseResult)
		return {
			"tableReference": tableReference,
			"tables": parseResult}

	def flatten(self):
		return "FROM %s"%(' , '.join(t.flatten() for t in self.tables))

	def getAllNames(self):
		"""returns the names of all tables taking part in this from clause.
		"""
		return self.tableReference.getAllNames()

	def resolveField(self, name):
		# delegate column resolution to the (joined) table reference
		return self.tableReference.getFieldInfo(name)

	def _makeColumnReference(self, sourceTableName, colPair):
		"""returns a ColumnReference object for a name, colInfo pair from a
		table's fieldInfos.
		"""
		cr = ColumnReference(name=colPair[0], refName=sourceTableName)
		cr.fieldInfo = colPair[1]
		return cr

	def getAllFields(self):
		"""returns all fields from all tables in this FROM.

		These will be qualified names.  Columns taking part in joins are
		resolved here.  This will only work for annotated tables.
		"""
		res = []
		commonColumns = common.computeCommonColumns(self.tableReference)
		commonColumnsMade = set()
		for table in self.getAllTables():
			for label, fi in table.fieldInfos.seq:
				if label in commonColumns:
					# join columns collapse into a single, unqualified column
					if label not in commonColumnsMade:
						res.append(self._makeColumnReference(
							None, (label, fi)))
						commonColumnsMade.add(label)
				else:
					res.append(self._makeColumnReference(
						table.tableName, (label, fi)))
		return res

	def getFieldsForTable(self, srcTableName):
		"""returns the fields in srcTable.

		srcTableName is a TableName.
		"""
		if fieldinfos.tableNamesMatch(self.tableReference, srcTableName):
			table = self.tableReference
		else:
			table = self.tableReference.fieldInfos.locateTable(srcTableName)
		return [self._makeColumnReference(table.tableName, ci)
			for ci in table.fieldInfos.seq]

	def getAllTables(self):
		return self.tableReference.getAllTables()
class DerivedColumn(FieldInfoedNode):
	"""A column within a select list.
	"""
	type = "derivedColumn"
	_a_expr = None
	_a_alias = None
	_a_tainted = True

	def _polish(self):
		# plain column references keep their metadata untainted
		if getType(self.expr)=="columnReference":
			self.tainted = False

	@property
	def name(self):
		# todo: be a bit more careful here to come up with meaningful
		# names (users don't like the funny names).  Also: do
		# we make sure somewhere we're getting unique names?
		if self.alias is not None:
			return self.alias
		elif hasattr(self.expr, "name"):
			return self.expr.name
		else:
			return utils.intToFunnyWord(id(self))

	@classmethod
	def _getInitKWs(cls, _parseResult):
		expr = _parseResult["expr"] #noflake: locals returned
		alias = _parseResult.get("alias") #noflake: locals returned
		return locals()

	def flatten(self):
		return flattenKWs(self,
			("", "expr"),
			("AS", "alias"))

	def _treeRepr(self):
		return (self.type, self.name)

	def addFieldInfo(self, context):
		FieldInfoedNode.addFieldInfo(self, context)
		if self.fieldInfo and self.alias:
			# record the alias as this column's source expression
			self.fieldInfo = self.fieldInfo.change()
			if hasattr(self.alias, "name"):
				# a QName, presumably
				self.fieldInfo.properties["src-expression"] = self.alias.name
			else:
				self.fieldInfo.properties["src-expression"] = self.alias
class QualifiedStar(ADQLNode):
	"""a <table>.* in a select list.
	"""
	type = "qualifiedStar"
	_a_sourceTable = None # A TableName for the column source

	@classmethod
	def _getInitKWs(cls, _parseResult):
		parts = _parseResult[:-2:2] # kill dots and star
		cat, schema, name = [None]*(3-len(parts))+parts
		return {"sourceTable": TableName(cat=cat, schema=schema, name=name)}

	def flatten(self):
		return "%s.*"%flatten(self.sourceTable)
class SelectList(ADQLNode):
	"""the list of columns selected within a select expression.
	"""
	type = "selectList"
	_a_selectFields = ()
	_a_allFieldsQuery = False

	@classmethod
	def _getInitKWs(cls, _parseResult):
		allFieldsQuery = _parseResult.get("starSel", False)
		if allFieldsQuery:
			# Will be filled in by query, we don't have the from clause here.
			selectFields = None #noflake: locals returned
		else:
			selectFields = list(itertools.chain(#noflake: locals returned
				*_parseResult.get("fieldSel")))
		return locals()

	def flatten(self):
		if self.allFieldsQuery:
			# allFieldsQuery is the star token as matched by the grammar
			return self.allFieldsQuery
		else:
			return ", ".join(flatten(sf) for sf in self.selectFields)
######## all expression parts we need to consider when inferring units and such
class Comparison(ADQLNode):
	"""a comparison predicate (op1 <operator> op2).

	This node is kept explicit because we want to morph the braindead
	contains(...)=1 into a true boolean function call.
	"""
	type = "comparisonPredicate"
	_a_op1 = None
	_a_opr = None
	_a_op2 = None

	@classmethod
	def _getInitKWs(cls, _parseResult):
		leftOperand, operator, rightOperand = _parseResult
		return {
			"op1": leftOperand,
			"opr": operator,
			"op2": rightOperand}

	def flatten(self):
		return " ".join([flatten(self.op1), self.opr, flatten(self.op2)])
def _guessNumericType(literal): """returns a guess for a type suitable to hold a numeric value given in literal. I don't want to pull through the literal symbol that matched from grammar in all cases. Thus, at times I simply guess the type (and yes, I'm aware that -32768 still is a smallint). """ try: val = int(literal) if abs(val)<32767: type = "smallint" elif abs(val)<2147483648: type = "integer" else: type = "bigint" except ValueError: if literal.lower()=="null": type = None else: # I'm too lazy to spell out all sorts of float. type = "double precision" return type
class Factor(FieldInfoedNode, TransparentMixin):
	"""a factor within an SQL expression.

	factors may have only one (direct) child with a field info and copy
	this.  They can have no child with a field info, in which case they're
	simply numeric (about the weakest assumption: They're doubles).
	"""
	type = "factor"
	collapsible = True

	def addFieldInfo(self, context):
		infoChildren = self._getInfoChildren()
		if infoChildren:
			assert len(infoChildren)==1
			self.fieldInfo = infoChildren[0].fieldInfo
		else:
			# no annotated child: guess a numeric type from the literal text
			self.fieldInfo = fieldinfo.FieldInfo(
				_guessNumericType("".join(self.children)), "", "")
		self.fieldInfo.properties["src-expression"] = flatten(self)
class ArrayReference(FieldInfoedNode, TransparentMixin):
	"""a subscript into an array-valued expression.

	The field info is derived from the subscripted child's by stripping
	one trailing [n] subscript from its type string.
	"""
	type = "arrayReference"
	collapsible = False

	def addFieldInfo(self, context):
		infoChild = self.children[0]
		childInfo = infoChild.fieldInfo

		if childInfo.type is None:
			raise common.Error("Cannot subscript a typeless thing in %s"%(
				self.flatten()))

		# strip the last [n] subscript to obtain the element type; raw
		# string here -- the original "\[" was an invalid string escape
		# (SyntaxWarning on modern Pythons)
		lastSubscript = re.search(r"\[[0-9]*\]$", childInfo.type)
		if lastSubscript is None:
			raise common.Error("Cannot subscript a non-array in %s"%(
				self.flatten()))

		self.fieldInfo = fieldinfo.FieldInfo(
			childInfo.type[:lastSubscript.start()],
			childInfo.unit, childInfo.ucd, childInfo.userData,
			tainted=True # array might actually have semantics
			)
		self.fieldInfo.properties["src-expression"] = flatten(self)
[docs]class CombiningFINode(FieldInfoedNode):
[docs] def addFieldInfo(self, context): infoChildren = self._getInfoChildren() if not infoChildren: if len(self.children)==1: # probably a naked numeric literal in the grammar, e.g., # in mathFunction self.fieldInfo = fieldinfo.FieldInfo( _guessNumericType(self.children[0]), "", "") else: raise common.Error("Oops -- did not expect '%s' when annotating %s"%( "".join(self.children), self)) elif len(infoChildren)==1: self.fieldInfo = infoChildren[0].fieldInfo else: self.fieldInfo = self._combineFieldInfos() if self.fieldInfo: self.fieldInfo.properties["src-expression"] = flatten(self)
class Term(CombiningFINode, TransparentMixin):
	"""a multiplicative term (operands joined by * or /) within an SQL
	expression.
	"""
	type = "term"
	collapsible = True

	def _combineFieldInfos(self):
		# children alternate operand, operator, operand, ...; fold the
		# multiplication/division unit semantics over them left to right
		parts = iter(self.children)
		combined = next(parts).fieldInfo
		for operator, operand in zip(parts, parts):
			combined = fieldinfo.FieldInfo.fromMulExpression(
				operator, combined, operand.fieldInfo)
		return combined
class NumericValueExpression(CombiningFINode, TransparentMixin):
	"""a sum or difference of terms.
	"""
	type = "numericValueExpression"
	collapsible = True

	def _combineFieldInfos(self):
		# children alternate operand, operator, operand, ...; fold the
		# addition/subtraction unit semantics over them left to right
		parts = iter(self.children)
		combined = next(parts).fieldInfo
		for operator, operand in zip(parts, parts):
			combined = fieldinfo.FieldInfo.fromAddExpression(
				operator, combined, operand.fieldInfo)
		return combined
class StringValueExpression(FieldInfoedNode, TransparentMixin):
	"""a concatenation of string-valued expressions.
	"""
	type = "stringValueExpression"
	collapsible = True

	def addFieldInfo(self, context):
		# This is concatenation; we treat it as if we'd be adding numbers
		infoChildren = self._getInfoChildren()
		if infoChildren:
			fi1 = infoChildren.pop(0).fieldInfo
			# once any operand is unicode, the whole result is unicode
			if fi1.type=="unicode":
				baseType = "unicode"
			else:
				baseType = "text"
			while infoChildren:
				if infoChildren[0].fieldInfo.type=="unicode":
					baseType = "unicode"
				fi1 = fieldinfo.FieldInfo.fromAddExpression(
					"+", fi1, infoChildren.pop(0).fieldInfo, forceType=baseType)
			self.fieldInfo = fi1
		else:
			# no annotated children: a plain string literal expression
			self.fieldInfo = fieldinfo.FieldInfo(
				"text", "", "")
		self.fieldInfo.properties["src-expression"] = flatten(self)
class GenericValueExpression(CombiningFINode, TransparentMixin):
	"""A container for value expressions that we don't want to look at
	closer.

	It is returned by the makeValueExpression factory below to collect
	stray children.
	"""
	type = "genericValueExpression"
	collapsible = True

	def _combineFieldInfos(self):
		# We don't really know what these children are.  Only if all the
		# children's field infos agree in unit and UCD do we venture a
		# guess (namely, a tainted copy of the first child's info);
		# otherwise we give up and hope someone can make a string of it.
		infoChildren = self._getInfoChildren()
		distinctUnits = set(c.fieldInfo.unit for c in infoChildren)
		distinctUCDs = set(c.fieldInfo.ucd for c in infoChildren)

		if len(distinctUnits)==1 and len(distinctUCDs)==1:
			return infoChildren[0].fieldInfo.change(tainted=True)
		return fieldinfo.FieldInfo("text", "", "")
@symbolAction("valueExpression")
def makeValueExpression(children):
	"""hands a single child through, wrapping multiple children into a
	GenericValueExpression re-typed to valueExpression.
	"""
	if len(children)==1:
		return children[0]
	node = GenericValueExpression.fromParseResult(children)
	node.type = "valueExpression"
	return node
class SetFunction(TransparentMixin, FieldInfoedNode):
	"""An aggregate function.

	These typically amend the ucd by a word from the stat family and
	copy over the unit.  There are exceptions, however; see the funcDefs
	table in the class definition.
	"""
	type = "setFunctionSpecification"

	# function name -> (UCD template, new unit, new type); None entries
	# mean "keep whatever the argument has"
	funcDefs = {
		'AVG': ('{u};stat.mean', None, "double precision"),
		'STDDEV': ('stat.stdev;{u}', None, "double precision"),
		'MAX': ('stat.max;{u}', None, None),
		'MIN': ('stat.min;{u}', None, None),
		'SUM': (None, None, None),
		'COUNT': ('meta.number;{u}', '', "integer"),}

	def addFieldInfo(self, context):
		funcName = self.children[0].upper()
		ucdPref, newUnit, newType = self.funcDefs[funcName]

		# try to find out about our child
		infoChildren = self._getInfoChildren()
		if infoChildren:
			assert len(infoChildren)==1
			fi = infoChildren[0].fieldInfo
		else:
			fi = fieldinfo.FieldInfo("double precision", "", "")

		if ucdPref is None:
			# ucd of a sum is the ucd of the summands?
			ucd = fi.ucd
		elif fi.ucd:
			ucd = ucdPref.format(u=fi.ucd)
		else:
			# no UCD given; if we're count, we're meta.number, otherwise we
			# don't know
			if funcName=="COUNT":
				ucd = "meta.number"
			else:
				ucd = None

		# most of these keep the unit of what they're working on
		if newUnit is None:
			newUnit = fi.unit

		# most of these keep the type of what they're working on
		if newType is None:
			newType = fi.type

		self.fieldInfo = fieldinfo.FieldInfo(
			newType, unit=newUnit, ucd=ucd, userData=fi.userData,
			tainted=fi.tainted)
		self.fieldInfo.properties["src-expression"] = flatten(self)
class NumericValueFunction(FunctionNode):
	"""A numeric function.

	This is really a mixed bag.  We work through handlers here.  See
	the funcDefs table in the class definition.  Unknown functions
	result in dimlesses.
	"""
	type = "numericValueFunction"
	collapsible = True # if it's a real function call, it has at least
		# a name, parens and an argument and thus won't be collapsed.

	# function name -> (unit override, UCD override, handler name);
	# empty-string overrides mean "leave the handler's result alone"
	funcDefs = {
		"ACOS": ('rad', '', None),
		"ASIN": ('rad', '', None),
		"ATAN": ('rad', '', None),
		"ATAN2": ('rad', '', None),
		"PI": ('', '', None),
		"RAND": ('', '', None),
		"EXP": ('', '', None),
		"LOG": ('', '', None),
		"LOG10": ('', '', None),
		"SQRT": ('', '', None),
		"SQUARE": ('', '', None),
		"POWER": ('', '', None),
		"ABS": (None, None, "keepMeta"),
		"CEILING": (None, None, "keepMeta"),
		"FLOOR": (None, None, "keepMeta"),
		"ROUND": (None, None, "keepMeta"),
		"TRUNCATE": (None, None, "keepMeta"),
		"DEGREES": ('deg', None, "keepMeta"),
		"RADIANS": ('rad', None, "keepMeta"),
		# bitwise operators: hopeless
	}

	def _handle_keepMeta(self, infoChildren):
		# pass the (single) argument's unit and ucd through
		fi = infoChildren[0].fieldInfo
		return fi.unit, fi.ucd

	def addFieldInfo(self, context):
		infoChildren = self._getInfoChildren()
		unit, ucd = '', ''

		overrideUnit, overrideUCD, handlerName = self.funcDefs.get(
			self.funName, ('', '', None))
		if handlerName:
			unit, ucd = getattr(self, "_handle_"+handlerName)(infoChildren)
		if overrideUnit:
			unit = overrideUnit
		if overrideUCD:
			ucd = overrideUCD

		self.fieldInfo = fieldinfo.FieldInfo("double precision",
			unit, ucd, *collectUserData(infoChildren))
		# functions generally mangle their arguments' semantics
		self.fieldInfo.tainted = True
		self.fieldInfo.properties["src-expression"] = flatten(self)
class ScalarArrayFunction(FunctionNode):
	"""All of these are functions somehow aggregating array elements.

	For those that we know we try to infer units and UCDs.
	"""
	type = "scalarArrayFunction"

	# function name -> (UCD template, new unit, arity)
	funcDefs = {
		'ARR_MIN': ('stat.min;{u}', None, 1),
		'ARR_MAX': ('stat.max;{u}', None, 1),
		'ARR_AVG': ('{u};stat.mean', None, 1),
		'ARR_STDDEV': ('stat.stdev;{u}', None, 1),
		'ARR_SUM': ('{u};arith.sum', None, 1),
		'ARR_COUNT': ('meta.number;{u}', "", 1),
		'ARR_DOT': (None, None, 2),
	}

	def addFieldInfo(self, context):
		ucdchange, unitchange, nPar = self.funcDefs.get(
			self.funName, (None, None, 1))

		infoChildren = self._getInfoChildren()
		if len(infoChildren)!=nPar:
			raise common.Error(
				"Grammar vs. annotation function arity mismatch!")

		# we're being very liberal with upcasting to double precision
		# here since that may help fend off integer overflows.
		type = "double precision"

		if nPar==1:
			baseFI = infoChildren[0].fieldInfo
			mat = fieldinfo.isArray(baseFI.type)
			if not mat:
				raise common.Error(
					"Scalar array function called on non-array?")
			if self.funName=="ARR_COUNT":
				type = "integer"
		elif nPar==2:
			# no single base (e.g., ARR_DOT): fabricate a dimless info
			# keeping the children as user data
			baseFI = fieldinfo.FieldInfo(type, "", "",
				userData=infoChildren)
		else:
			assert False

		ucd, unit = "", baseFI.unit
		if ucdchange is not None and baseFI.ucd:
			ucd = ucdchange.format(u=baseFI.ucd)
		if unitchange is not None:
			unit = unitchange

		self.fieldInfo = fieldinfo.FieldInfo(type, unit, ucd,
			*collectUserData(infoChildren))
class ArrayMapFunction(FunctionNode):
	"""The arr_map extension.

	arg 1 is the expression; any x in it will be the filled with the
	array elements.  arg 2 is the array to take the elements from.
	"""
	type = "arrayMapFunction"

	def addFieldInfo(self, context):
		# the mapped expression is numeric by construction (x is
		# pre-annotated as double precision in _polish), so the result is
		# an array of doubles
		self.fieldInfo = fieldinfo.FieldInfo("double precision[]", "", "")

	def _polish(self):
		# we need to pre-annotate all column references to "x" (our unbound
		# variable) in our first argument with a literal "x" to keep
		# our annotation engine from looking at them
		def preAnnotate(node):
			for child in node.iterNodeChildren():
				if isinstance(child, ColumnReference) and child.name=="x":
					child.fieldInfo = fieldinfo.FieldInfo(
						"double precision", "", "")
				preAnnotate(child)
		preAnnotate(self.args[0])
class StringValueFunction(FunctionNode):
	"""a function returning a string.

	No unit/UCD inference is attempted; the result is plain text keeping
	the user data of any annotated children.
	"""
	type = "stringValueFunction"

	def addFieldInfo(self, context):
		userData, _ = collectUserData(self._getInfoChildren())
		self.fieldInfo = fieldinfo.FieldInfo("text", "", "",
			userData=userData)
		self.fieldInfo.properties["src-expression"] = flatten(self)
class TimestampFunction(FunctionNode):
	"""a function returning a timestamp.
	"""
	type = "timestampFunction"

	def addFieldInfo(self, context):
		subordinates = self._getInfoChildren()

		# copy ucd and STC frame from the first annotated child, if any
		# (local renamed from "stc" -- that shadowed the gavo.stc module)
		ucd, stcFrame = None, None
		if subordinates:
			firstInfo = subordinates[0].fieldInfo
			ucd, stcFrame = firstInfo.ucd, firstInfo.stc

		userData, tainted = collectUserData(subordinates)
		self.fieldInfo = fieldinfo.FieldInfo("timestamp", "",
			ucd=ucd, stc=stcFrame, userData=userData, tainted=tainted)
class InUnitFunction(FieldInfoedNode):
	"""the in_unit(expr, unit) extension: expr converted to unit.
	"""
	type = "inUnitFunction"
	_a_expr = None
	_a_unit = None

	# filled in during annotation; required for flattening
	conversionFactor = None

	@classmethod
	def _getInitKWs(cls, _parseResult):
		return {
			'expr': _parseResult[2],
			'unit': _parseResult[4].value,
		}

	def addFieldInfo(self, context):
		try:
			# deferred import: adql should remain usable without gavo.base
			from gavo.base import computeConversionFactor, IncompatibleUnits, BadUnit
		except ImportError: # pragma: no cover
			raise utils.ReportableError("in_unit only available with gavo.base"
				" installed")
		try:
			self.conversionFactor = computeConversionFactor(
				self.expr.fieldInfo.unit, self.unit)
			self.fieldInfo = self.expr.fieldInfo.change(unit=self.unit)
		except IncompatibleUnits as msg:
			raise common.Error("in_unit error: %s"%msg)
		except BadUnit as msg:
			raise common.Error("Bad unit passed to in_unit: %s"%msg)

	def flatten(self):
		if self.conversionFactor is None: # pragma: no cover
			raise common.Error("in_unit can only be flattened in annotated"
				" trees")
		if isinstance(self.expr, ColumnReference):
			exprPat = "%s"
		else:
			# parenthesise compound expressions so the factor binds right
			exprPat = "(%s)"
		return "(%s * %.16g)"%(exprPat%flatten(self.expr),
			self.conversionFactor)

	def change(self, **kwargs):
		# preserve the annotation-time conversion factor across copies
		copy = FieldInfoedNode.change(self, **kwargs)
		copy.conversionFactor = self.conversionFactor
		return copy
class CharacterStringLiteral(FieldInfoedNode):
	"""a string literal.

	According to the current grammar, these are always sequences of
	quoted strings; we store the concatenation of their contents with
	the quotes stripped.
	"""
	type = "characterStringLiteral"
	bindings = ["characterStringLiteral", "generalLiteral"]

	_a_value = None

	@classmethod
	def _getInitKWs(cls, _parseResult):
		# strip the enclosing quotes from each fragment and join them up
		return {"value": "".join(part[1:-1] for part in _parseResult)}

	def flatten(self):
		return "'%s'"%self.value

	def addFieldInfo(self, context):
		self.fieldInfo = fieldinfo.FieldInfo("text", "", "")
class CastSpecification(FieldInfoedNode, TransparentMixin):
	"""a CAST(value AS type) expression.
	"""
	type = "castSpecification"
	_a_value = None
	_a_newType = None

	@classmethod
	def _getInitKWs(cls, _parseResult):
		value = _parseResult["value"]
		newType = _parseResult["newType"].lower()
		# map SQL character types to our internal type names
		if newType.startswith("char ("):
			newType = "text"
		elif newType.startswith("national char"):
			newType = "unicode"
		return locals()

	# upper-cased type name -> name of the geometry class (defined below)
	# to construct instead of a cast
	geometryConstructors = {
		"POINT": "Point",
		"CIRCLE": "Circle",
		"POLYGON": "Polygon"}

	def _polish(self):
		# the casts to geometries need to be turned into constructor calls.
		# We could theoretically do this while morphing, but then annotation
		# would get a lot more complicated.
		newNodeName = self.geometryConstructors.get(
			self.newType.upper(), None)
		if newNodeName is None:
			return

		# we only have names in geometryConstructors because we're defined
		# lexically above the geometries
		nodeClass = globals()[newNodeName]
		# ReplaceNode makes the tree builder substitute the new node for us
		raise ReplaceNode(nodeClass.fromCastArgument(self.value))

	def addFieldInfo(self, context):
		# We copy units and UCDs from the subordinate value (if it's there;
		# NULLs have nothing, of course).  That has the somewhat unfortunate
		# effect that we may be declaring units on strings.  Ah well.
		if hasattr(self.value, "fieldInfo"):
			self.fieldInfo = self.value.fieldInfo.change(
				type=self.newType, tainted=True)
		else:
			self.fieldInfo = fieldinfo.FieldInfo(self.newType, "", "")

	def flatten(self):
		if self.children[-2].lower().endswith("char ( * )"):
			# postgres, surprisingly, can't cast to char(*), so we just hack
			# it to text (the endswith is for national char)
			self.children[-2] = "TEXT"
		return FieldInfoedNode.flatten(self)
class CoalesceExpression(FieldInfoedNode, TransparentMixin):
	"""a COALESCE(...) expression.

	The trouble as to inference here is that of course anything might
	happen per-row, as people could stick in whatever.  But we assume
	well-meaning people and thus just take the annotation from the first
	annotated argument (which is supposed to be the "normal case").
	"""
	type = "coalesceExpression"

	def addFieldInfo(self, context):
		annotated = self._getInfoChildren()
		if not annotated:
			# there's probably only nulls here; make it so we produce
			# the most robust NULL VOTable has.
			self.fieldInfo = fieldinfo.FieldInfo('real', None, None)
			return
		self.fieldInfo = annotated[0].fieldInfo.change(tainted=True)
class CaseExpression(FieldInfoedNode, TransparentMixin):
	"""a CASE ... WHEN ... THEN ... [ELSE ...] END expression.

	For annotation we collect the result expressions (whatever directly
	follows a THEN or ELSE token) and use the first one that carries
	real metadata.
	"""
	type = "searchedCase"
	bindings = ["searchedCase", "simpleCase"]

	def addFieldInfo(self, context):
		# collect the result expressions: each child directly following a
		# THEN or ELSE token (fix: flag is now a proper bool; the original
		# mixed 1 and False)
		results, nextIsResult = [], False
		for c in self.children:
			if c in {"THEN", "ELSE"}:
				nextIsResult = True
			else:
				if nextIsResult:
					results.append(c)
					nextIsResult = False

		for res in results:
			# we try to find a case that has proper user data (i.e., a column)
			# attached; anything else will probably be empty info.
			if res.fieldInfo and res.fieldInfo.userData:
				self.fieldInfo = res.fieldInfo.change(tainted=True)
				break
		else:
			# We probably only have literals.  Get the type from the first
			# and run.  Should we make sure that all cases have matching types?
			for res in results:
				if res.fieldInfo and res.fieldInfo.type:
					self.fieldInfo = res.fieldInfo.change(tainted=True)
					break
			else:
				# we cannot venture any guess what we'll return.  Let's avoid
				# wild guesses for now and just fail.
				raise common.Error("CASE statement seems to have no results.")
###################### Geometry and stuff that needs morphing into real SQL
class CoosysMixin(object):
	"""is a mixin that works cooSys into FieldInfos for ADQL geometries.
	"""
	_a_cooSys = None

	@classmethod
	def _getInitKWs(cls, _parseResult):
		refFrame = _parseResult.get("coordSys", "")
		if isinstance(refFrame, ColumnReference): # pragma: no cover
			# per-row reference frames are not supported
			raise NotImplementedError("References frames must not be column"
				" references.")
		return {"cooSys": refFrame}
class GeometryNode(CoosysMixin, FieldInfoedNode):
	"""Nodes for geometry constructors.

	In ADQL 2.1, most of these became polymorphous.  For instance,
	circles can be constructed with a point as the first (or second, if
	a coosys is present) argument; that point can also be a column
	reference.

	Also, these will always get morphed in some way (as the database
	certainly doesn't understand ADQL geometries).  So, we're trying to
	give the morphers a fair chance of not getting confused despite the
	wild variety of argument forms and types.

	stcArgs is a list of symbolic names that *might* contain stc (or
	similar) information.  Some of the actual attributes will be None.

	Flatten is only there for debugging; it'll return invalid SQL.

	OrigArgs is not for client consumption; clients must go through the
	symbolic names.

	If you want your geometry to support casts, give it a
	fromCastArgs(args) method that will, in general, probably replace
	the node with a GeometryCast node.  You'll then have to teach
	GeometryCast how to annotate your geometry and add the type name to
	the geometryConstructors dictionary in CastSpecification.
	"""
	_a_origArgs = None

	def flatten(self):
		# debugging only; this is not valid SQL (see class docstring)
		return "%s%s"%(self.type.upper(),
			"".join(flatten(arg) for arg in self.origArgs))

	@classmethod
	def _getInitKWs(cls, _parseResult):
		# everything after the function name token
		return {"origArgs": list(_parseResult[1:])}

	def addFieldInfo(self, context):
		# field infos of all bound, annotated stcArgs attributes
		fis = [attr.fieldInfo
			for attr in (getattr(self, arg) for arg in self.stcArgs
				if getattr(self, arg))
			if attr and attr.fieldInfo]
		childUserData, childUnits = [], []
		thisSystem = tapstc.getSTCForTAP(self.cooSys)

		# get reference frame from first child if not given in node and
		# one is defined there.
		if thisSystem.astroSystem.spaceFrame.refFrame is None:
			if fis and fis[0].stc:
				thisSystem = fis[0].stc

		ignoreTableStats = False
		for index, fi in enumerate(fis):
			childUserData.extend(fi.userData)
			childUnits.append(fi.unit)
			ignoreTableStats = ignoreTableStats or fi.ignoreTableStats
			# complain (without failing) about mismatched reference systems
			if not context.policy.match(fi.stc, thisSystem):
				context.errors.append("When constructing %s: Argument %d has"
					" incompatible STC"%(self.type, index+1))

		self.fieldInfo = fieldinfo.FieldInfo(
			type=self.sqlType,
			ucd="",
			unit="",
			userData=tuple(childUserData),
			stc=thisSystem,
			ignoreTableStats=ignoreTableStats)
		self.fieldInfo.properties["xtype"] = self.xtype
class GeometryCast(GeometryNode):
	"""A cast to a geometry type.

	For these, we defer to functions built into the database.
	"""
	_a_argument = None
	_a_cast_function = None

	# database cast function -> pgsphere type it yields
	resultingTypes = {
		"cast_to_point": "spoint",
		"cast_to_circle": "scircle",
		"cast_to_polygon": "spoly",}

	def addFieldInfo(self, context):
		resultType = self.resultingTypes[self.cast_function]
		self.fieldInfo = fieldinfo.FieldInfo(
			type=resultType,
			unit='',
			ucd='',
			stc=None)

	def flatten(self):
		return "%s(%s)"%(self.cast_function, self.argument.flatten())
class Point(GeometryNode):
	"""a POINT geometry.
	"""
	type = "point"
	_a_x = _a_y = None
	xtype = "point"
	sqlType = "spoint"

	# symbolic argument names that may carry STC info (see GeometryNode)
	stcArgs = ("x", "y")

	def flatten(self):
		return "%s(%s)"%(self.type.upper(),
			", ".join(flatten(arg) for arg in [self.x, self.y]))

	@classmethod
	def _getInitKWs(cls, _parseResult):
		x, y = parseArgs(_parseResult["args"]) #noflake: locals returned
		return locals()

	@classmethod
	def fromCastArgument(self, arg):
		# CastSpecification._polish catches the ReplaceNode raised here
		# and substitutes the GeometryCast for the cast node
		raise ReplaceNode(GeometryCast(
			argument=arg,
			cast_function="cast_to_point"))
class Circle(GeometryNode):
    """A circle parsed from ADQL.

    There are two ways a circle is specified: either with (x, y, radius)
    or as (center, radius).  In the second case, center is an
    spoint-valued column reference.  Cases with a point-valued literal
    are turned into the first variant during parsing.
    """
    type = "circle"
    _a_radius = None
    _a_center = None
    stcArgs = ("center", "radius")
    xtype = "circle"
    sqlType = "scircle"

    @classmethod
    def _getInitKWs(cls, _parseResult):
        args = parseArgs(_parseResult["args"])
        res = {a: None for a in cls.stcArgs}

        if len(args)==2:
            res["center"], res["radius"] = args[0], args[1]
        elif len(args)==3:
            # (x, y, radius) form: wrap the coordinates into a Point so
            # downstream code only ever sees the (center, radius) form.
            res["center"] = Point(cooSys=_parseResult.get("coordSys", ""),
                x=args[0], y=args[1])
            res["radius"] = args[2]
        else: # pragma: no cover
            assert False, "Grammar let through invalid args to Circle"

        return res

    @classmethod
    def fromCastArgument(cls, arg):
        # Fix: a classmethod's first parameter is conventionally cls, not
        # self (PEP 8); behavior is unchanged.
        raise ReplaceNode(GeometryCast(
            argument=arg, cast_function="cast_to_circle"))
class MOC(GeometryNode):
    """a MOC in an ADQL syntax tree.

    This can be constructed from an ASCII-MOC string or from an order
    and a geometry value expression.
    """
    type = "moc"
    _a_literal = None
    _a_order = None
    _a_geometry = None
    stcArgs = ()
    xtype = "moc"
    sqlType = "smoc"

    @classmethod
    def _getInitKWs(cls, _parseResult):
        # locals() is returned below; presumably only names matching the
        # _a_ attributes (literal, order, geometry) become node attributes,
        # which is why the argument list is called _args -- TODO confirm
        # against the node base machinery.
        _args = parseArgs(_parseResult["args"])
        if len(_args)==1:
            # one argument: an ASCII-MOC literal
            literal = _args[0]
        elif len(_args)==2:
            # two arguments: a MOC order plus a geometry expression
            order, geometry = _args[0], _args[1]
        else:
            raise common.Error("MOC() takes either one literal or order, geo")
        return locals()

    def flatten(self):
        # there's no point morphing this; when people put this into db
        # engines, they can just as well use the ADQL signature.
        if self.literal is None:
            return "smoc(%s, %s)"%(flatten(self.order), flatten(self.geometry))
        else:
            return "smoc(%s)"%flatten(self.literal)
[docs]class Box(GeometryNode): type = "box" _a_x = _a_y = _a_width = _a_height = None stcArgs = ("x", "y", "width", "height") xtype = "polygon" sqlType = "sbox" @classmethod def _getInitKWs(cls, _parseResult): x, y, width, height = parseArgs( #noflake: locals returned _parseResult["args"]) return locals()
class PolygonCoos(FieldInfoedNode):
    """a base class for the various argument forms of polygons.

    We want to tell them apart to let the grammar tell the tree builder
    what it thinks the arguments were.  Polygon may have to reconsider
    this when it learns the types of its arguments, but we don't want to
    discard the information coming from the grammar.
    """
    _a_args = None

    @classmethod
    def _getInitKWs(cls, _parseResult):
        return {"args": parseArgs(_parseResult["args"])}

    def addFieldInfo(self, context):
        # nothing to do: Polygon never consults these field infos.
        pass

    def flatten(self):
        return ", ".join(flatten(coo) for coo in self.args)
[docs]class PolygonSplitCooArgs(PolygonCoos): type = "polygonSplitCooArgs"
[docs]class PolygonPointCooArgs(PolygonCoos): type = "polygonPointCooArgs"
class Polygon(GeometryNode):
    """An ADQL POLYGON geometry.

    Arguments arrive either as a flat list of numeric coordinates
    ("split" args, attribute coos, a list of (x, y) pairs) or as a list
    of point-valued expressions (attribute points); exactly one of coos
    and points is non-None.  See PolygonCoos for how the grammar tells
    the two apart.
    """
    type = "polygon"
    _a_coos = None
    _a_points = None
    stcArgs = ("coos", "points")
    xtype = "polygon"
    sqlType = "spoly"

    @classmethod
    def _getInitKWs(cls, _parseResult):
        # XXX TODO: The grammar will parse even-numbered arguments >=6 into
        # splitCooArgs.  We can't fix that here as we don't have reliable
        # type information at this point.  Fix coos/points confusion
        # in addFieldInfo, I'd say
        arg = parseArgs(_parseResult["args"])[0]

        if arg.type=="polygonPointCooArgs":
            # geometry-typed arguments
            res = {"points": tuple(parseArgs(arg.args))}
            # See if they're all literal points, in which case we fall
            # back to the split args (typo fixed: was "which which case")
            for item in res["points"]:
                if item.type!="point":
                    return res
            # all points: mutate args to let us fall through to the split
            # coos case (typo fixed: was "coup")
            arg.type = "polygonSplitCooArgs"
            newArgs = []
            for item in res["points"]:
                newArgs.extend([item.x, item.y])
            arg.args = newArgs

        if arg.type=="polygonSplitCooArgs":
            # turn numeric expressions into pairs
            coos, toDo = [], list(arg.args)
            while toDo:
                coos.append(tuple(toDo[:2]))
                del toDo[:2]
            res = {"coos": coos}

        else: # pragma: no cover
            assert False, "Invalid arguments to polygon"

        return res

    @classmethod
    def fromCastArgument(cls, arg):
        # Fix: a classmethod's first parameter is conventionally cls, not
        # self (PEP 8); behavior is unchanged.
        raise ReplaceNode(GeometryCast(
            argument=arg, cast_function="cast_to_polygon"))

    def addFieldInfo(self, context):
        # Fix: parameter was called "name", inconsistent with every other
        # addFieldInfo(self, context) in this module; it is unused here,
        # so the rename is behavior-neutral.
        if self.points is not None:
            systemSource = self.points
        elif self.coos is not None:
            systemSource = (c[0] for c in self.coos)
        else: # pragma: no cover
            assert False

        if self.cooSys and self.cooSys!="UNKNOWN":
            thisSystem = tapstc.getSTCForTAP(self.cooSys)

        # NOTE(review): the loop below can override an explicitly given
        # cooSys with the first child frame found, and the for/else sets
        # UNKNOWN when no child carries a frame -- confirm this precedence
        # is intended.
        for geo in systemSource:
            if geo.fieldInfo.stc and geo.fieldInfo.stc.astroSystem.spaceFrame.refFrame:
                thisSystem = geo.fieldInfo.stc
                break
        else:
            thisSystem = tapstc.getSTCForTAP("UNKNOWN")

        userData, tainted = collectUserData(
            self.points or [c[0] for c in self.coos]+[c[1] for c in self.coos])
        self.fieldInfo = fieldinfo.FieldInfo(
            type=self.sqlType, unit="deg", ucd="phys.angArea",
            userData=userData, tainted=tainted, stc=thisSystem)
_regionMakers = []
[docs]def registerRegionMaker(fun): """adds a region maker to the region resolution chain. region makers are functions taking the argument to REGION and trying to do something with it. They should return either some kind of FieldInfoedNode that will then replace the REGION or None, in which case the next function will be tried. As a convention, region specifiers here should always start with an identifier (like simbad, siapBbox, etc, basically [A-Za-z]+). The rest is up to the region maker, but whitespace should separate this rest from the identifier. The entire region functionality will probably disappear with TAP 1.1. Don't do anything with it any more. Use ufuncs instead. """ _regionMakers.append(fun)
[docs]@symbolAction("region") def makeRegion(children): if len(children)!=4 or not isinstance(children[2], CharacterStringLiteral): raise common.RegionError("Invalid argument to REGION: '%s'."% "".join(flatten(c) for c in children[2:-1]), hint="Here, regions must be simple strings; concatenations or" " non-constant parts are forbidden. Use ADQL geometry expressions" " instead.") arg = children[2].value for r in _regionMakers: res = r(arg) if res is not None: return res raise common.RegionError("Invalid argument to REGION: '%s'."% arg, hint="None of the region parsers known to this service could" " make anything of your string. While STC-S should in general" " be comprehendable to TAP services, it's probably better to" " use ADQL geometry functions.")
class STCSRegion(FieldInfoedNode):
    """a region given as an STC-S string.

    Instances are built by makeSTCSRegion, never by the parser (hence
    the empty bindings).
    """
    bindings = []  # we're constructed by makeSTCSRegion, not by the parser
    type = "stcsRegion"
    xtype = "adql:REGION"

    # from tapstc -- STCSRegion or a utils.pgsphere object
    _a_tapstcObj = None

    def _polish(self):
        # pull the coordinate system up from the wrapped tapstc object
        self.cooSys = self.tapstcObj.cooSys

    def addFieldInfo(self, context):
        # XXX TODO: take type and unit from tapstcObj
        self.fieldInfo = fieldinfo.FieldInfo("spoly", unit="deg", ucd=None,
            stc=tapstc.getSTCForTAP(self.cooSys))

    def flatten(self): # pragma: no cover
        # regions have no direct SQL serialization; a morpher must have
        # replaced this node before flattening.
        raise common.FlattenError("STCRegion objectcs cannot be flattened, they"
            " must be morphed.")
[docs]def makeSTCSRegion(spec): try: return STCSRegion(stc.parseSimpleSTCS(spec)) except stc.STCSParseError: #Not a valid STC spec, try next region parser return None
registerRegionMaker(makeSTCSRegion)
class Centroid(FunctionNode):
    """the ADQL CENTROID function; its result is annotated as a point."""
    type = "centroid"

    def addFieldInfo(self, context):
        # propagate the children's user data; units/ucds of a centroid
        # are not meaningful, so they stay empty.
        userData, _tainted = collectUserData(self._getInfoChildren())
        self.fieldInfo = fieldinfo.FieldInfo(type="spoint",
            unit="", ucd="", userData=userData)
class Distance(FunctionNode):
    """the ADQL DISTANCE function; yields an angular distance in degrees."""
    type = "distanceFunction"

    def addFieldInfo(self, context):
        self.fieldInfo = fieldinfo.FieldInfo(type="double precision",
            unit="deg", ucd="pos.angDistance",
            userData=collectUserData(self._getInfoChildren())[0])

    def optimize(self, stack):
        # reorder the two arguments via _sortLargeFirst (defined elsewhere
        # in this module; presumably puts the "larger" expression first so
        # the database can use an index -- confirm there).
        assert len(self.args)==2, "unexpected arguments in distance"
        self.args = list(self.args)
        self.args[0], self.args[1] = _sortLargeFirst(self.args[0], self.args[1])

    @classmethod
    def _getInitKWs(cls, _parseResult):
        # locals() is returned; args presumably becomes the node's args
        # attribute via the FunctionNode machinery -- TODO confirm.
        args = parseArgs(_parseResult["args"])
        if len(args)==4:
            # always normalise to (point, point)
            args = [
                Point(cooSys="", x=args[0], y=args[1]),
                Point(cooSys="", x=args[2], y=args[3])]
        return locals()
class PredicateGeometryFunction(FunctionNode):
    """a geometry predicate (CONTAINS/INTERSECTS); evaluates to 0 or 1."""
    type = "predicateGeometryFunction"

    # shared result annotation: predicate geometry functions are integers
    _pgFieldInfo = fieldinfo.FieldInfo("integer", "", "")

    def optimize(self, stack):
        # NOTE: assert is stripped under python -O; kept as in the
        # original (mirrors the other "grammar let through" guards).
        if len(self.args)!=2:
            assert False, "Grammar let through bad arguments to pgf"
        self.args = list(self.args)

        # by ADQL, an INTERSECTS with a point has to become a CONTAINS
        if self.funName=="INTERSECTS":
            ltype = getattr(self.args[0].fieldInfo, "type", None)
            rtype = getattr(self.args[1].fieldInfo, "type", None)
            if ltype=='spoint':
                self.funName = "CONTAINS"
            elif rtype=='spoint':
                # CONTAINS wants the point first, so swap the arguments
                self.funName = "CONTAINS"
                self.args[0], self.args[1] = self.args[1], self.args[0]

        leftInd, rightInd = 0, 1
        # optimise the common case of contains(point, circle); both
        # q3c and pgsphere won't use an index (properly) if the sequence
        # of the arguments is "wrong".
        if (self.args[leftInd].type=="point"
                and self.args[rightInd].type=="circle"):
            if _isConstant([self.args[leftInd]]):
                # constant point: swap it with the circle's center
                self.args[leftInd], self.args[rightInd].center = \
                    self.args[rightInd].center, self.args[leftInd]
            else:
                # otherwise let _sortLargeFirst decide the order
                self.args[leftInd], self.args[rightInd].center = _sortLargeFirst(
                    self.args[leftInd], self.args[rightInd].center)
            # in case we swapped, coosys meta might be out of whack, so
            # fix that:
            self.args[rightInd].cooSys = self.args[rightInd].center.cooSys

    def addFieldInfo(self, context):
        # swallow all upstream info, it really doesn't help here
        self.fieldInfo = self._pgFieldInfo

    def flatten(self):
        return "%s(%s)"%(self.funName,
            ", ".join(flatten(a) for a in self.args))
class PointFunction(FunctionNode):
    """a function operating on a point: COORDSYS, or a coordinate
    extractor (funName ending in a digit, e.g. COORD1/COORD2 -- inferred
    from the funName[-1] indexing below)."""
    type = "pointFunction"

    def _makeCoordsysFieldInfo(self):
        # COORDSYS returns the name of the reference frame as text
        return fieldinfo.FieldInfo("text", unit="", ucd="meta.ref;pos.frame")

    def _makeCoordFieldInfo(self):
        # this should pull in the metadata from the 1st or 2nd component
        # of the argument.  However, given the way geometries are
        # constructed in ADQL, what comes back here is in degrees in the
        # frame of the child always.  We're a bit pickier with the user
        # data -- if there's exactly two user data fields in the child, we
        # assume the child has been built from individual columns, and we
        # try to retrieve the one pulled out.
        childFieldInfo = self.args[0].fieldInfo
        if len(childFieldInfo.userData)==2:
            # funName's trailing digit (1 or 2) selects the component
            userData = (childFieldInfo.userData[int(self.funName[-1])-1],)
        else:
            userData = childFieldInfo.userData
        return fieldinfo.FieldInfo("double precision",
            ucd=None, unit="deg", userData=userData)

    def addFieldInfo(self, context):
        if self.funName=="COORDSYS":
            makeFieldInfo = self._makeCoordsysFieldInfo
        else:
            # it's coordN
            makeFieldInfo = self._makeCoordFieldInfo
        self.fieldInfo = makeFieldInfo()
class Area(FunctionNode):
    """the ADQL AREA function; yields the area of a geometry in deg**2."""
    type = "area"

    def addFieldInfo(self, context):
        # keep the children's user data; unit and ucd are fixed for areas.
        userData, _tainted = collectUserData(self._getInfoChildren())
        self.fieldInfo = fieldinfo.FieldInfo(type="double precision",
            unit="deg**2", ucd="phys.angSize", userData=userData)
_ADDITIONAL_NODES = []
[docs]def registerNode(node): """registers a node class or a symbolAction from a module other than node. This is a bit of magic -- some module can call this to register a node class that is then bound to some parse action as if it were in nodes. I'd expect this to be messy in the presence of chaotic imports (when classes are not necessarily singletons and a single module can be imported more than once. For now, I ignore this potential bomb. """ # adding a node probably changes the grammar bindings, so we # need to clear the cached grammar getTreeBuildingGrammar.cache_clear() # if a module registering a node is re-loaded, we have to remove the # node it previously registered, or we'll get a symbol clash. # It's not totally straightforward to figure out what the previous # version was; for now, let's look at type, but given the various # ways nodes might use to bind, that's potentially not enough. for existing in _ADDITIONAL_NODES: if existing.type==node.type: _ADDITIONAL_NODES.remove(existing) break _ADDITIONAL_NODES.append(node)
[docs]def getNodeClasses(): """returns a list of node classes (and standalone parse actions) available for tree building. This is what needs to be passed to adql.getADQLGrammarCopy to get a proper parse tree out of the parser. """ res = [] for item in itertools.chain(globals().values(), _ADDITIONAL_NODES): if isinstance(item, type) and issubclass(item, ADQLNode): res.append(item) if hasattr(item, "parseActionFor"): res.append(item) return res
[docs]@functools.lru_cache(1) def getTreeBuildingGrammar(): return grammar.getADQLGrammarCopy(getNodeClasses())