Source code for gavo.base.sqlmunge

"""
Helpers for building SQL expressions.

Some of this code is concerned with SQL factories.  These are functions
with the signature::

	func(field, val, outPars) -> fragment

outPars is a dictionary that is used to transmit literal values into SQL.
The result must be an SQL boolean expression for embedding into a WHERE clause
(use None to signal no constraint).  Field is the field for which the
expression is being generated.

Registered factories are currently never called when val is a list or tuple;
there's special hard-coded behaviour for that in _getSQLFactory.

To enter values in outPars, use getSQLKey.  Its docstring contains
an example of what that looks like.
"""

#c Copyright 2008-2025, the GAVO project <gavo@ari.uni-heidelberg.de>
#c
#c This program is free software, covered by the GNU GPL.  See the
#c COPYING file in the source distribution.

import datetime

from gavo import utils
from gavo.stc import mjdToDateTime

from gavo.utils.dachstypes import (
	Any, Callable, Dict, List, Optional, Union, TYPE_CHECKING)

if TYPE_CHECKING:
	from gavo.svcs import InputKey


# IEEE 754 infinities; _getSQLForInterval compares interval bounds
# against these to recognise half-open intervals.
plusInfinity = float("inf")
minusInfinity = -plusInfinity


def joinOperatorExpr(operator: str, operands: List[str]) -> Optional[str]:
	"""returns the non-empty operands joined by operator.

	Falsy operands (None, empty strings) are dropped first.  If nothing
	remains, the result is None; a single remaining operand is returned
	as it is; otherwise every operand is parenthesised and the results
	are joined with operator (surrounding whitespace is stripped).

	>>> joinOperatorExpr("AND", ["a", None, "b"])
	'(a) AND (b)'
	"""
	nonEmpty = list(filter(None, operands))

	if len(nonEmpty)>1:
		return operator.join(" (%s) "%op for op in nonEmpty).strip()
	if nonEmpty:
		return nonEmpty[0]
	return None


def _isSameSQLValue(oldValue: Any, newValue: Any) -> bool:
	"""returns True if a query parameter holding oldValue can be re-used
	for newValue.

	This requires identical types (so that, e.g., True does not re-use a
	parameter holding 1) and an equality that reduces to a single boolean.
	If it does not (e.g., multi-element numpy arrays) or the comparison
	raises, we play safe and return False.
	"""
	if type(oldValue) is not type(newValue):
		return False
	try:
		return bool(oldValue==newValue)
	except (TypeError, ValueError):
		return False


def getSQLKey(key: str, value: Any, sqlPars: Dict[str, Any]) -> str:
	"""adds value to sqlPars and returns a key for inclusion in a SQL query.

	This function is used to build parameter dictionaries for SQL queries,
	avoiding overwriting parameters with accidental name clashes.

	key is usually a string matching the identifier pattern or a QuotedName
	(the latter are going to be horribly mogrified).

	As an extra service, if value is a list, it is turned into a set
	(rather than the default, which would be an array).  We don't believe
	there's a great need to match against arrays.  If you must match
	against arrays, use numpy arrays.

	An existing key is only re-used for a value of the same type that
	compares equal to the value already stored; anything else (including
	values whose equality cannot be reduced to a single boolean, like
	multi-element numpy arrays) gets a fresh key.

	>>> sqlPars = {}
	>>> getSQLKey("foo", 13, sqlPars)
	'foo0'
	>>> getSQLKey("foo", 14, sqlPars)
	'foo1'
	>>> getSQLKey("foo", 13, sqlPars)
	'foo0'
	>>> sqlPars["foo0"], sqlPars["foo1"]; sqlPars = {}
	(13, 14)
	>>> "WHERE foo<%%(%s)s OR foo>%%(%s)s"%(getSQLKey("foo", 1, sqlPars),
	...   getSQLKey("foo", 15, sqlPars))
	'WHERE foo<%(foo0)s OR foo>%(foo1)s'
	>>> getSQLKey(utils.QuotedName("-x-"), "x", sqlPars)
	'id2dx2d0'
	>>> getSQLKey("flag", 1, sqlPars), getSQLKey("flag", True, sqlPars)
	('flag0', 'flag1')
	"""
	if isinstance(key, utils.QuotedName):
		key = key.makeIdentifier()
	if isinstance(value, list):
		value = frozenset(value)

	ct = 0
	while True:
		dataKey = "%s%d"%(key, ct)
		if dataKey not in sqlPars:
			break
		# a plain == here would blow up on numpy arrays and would merge
		# values of different types (1 vs. True).
		if _isSameSQLValue(sqlPars[dataKey], value):
			break
		ct += 1

	sqlPars[dataKey] = value
	return dataKey


_SQLFactory = Callable[["InputKey", Any, Dict[str, Any]], Optional[str]]

# Maps type names to SQL factories; _getSQLFactory looks up field.type here.
_REGISTRED_SQL_FACTORIES: Dict[str, _SQLFactory] = {}


def registerSQLFactory(type: str, factory: _SQLFactory) -> None:
	"""registers factory as an SQL factory for the type type (a string).

	A SQL factory turns an expression of a special type (e.g., vexpr-float
	for VizieR-like float searches) into literal SQL.  It is called with
	the field being constrained (annotated as an InputKey), the value,
	and the dictionary of SQL query parameters, which it will in general
	amend through getSQLKey (avoid doing value serialisation yourself;
	you will produce SQL injection surface).  It returns an SQL boolean
	expression, or None for no constraint.

	Registering a factory for a type that already has one replaces the
	old factory.  Registered factories are not used for list or tuple
	values or for fields with xtype interval.
	"""
	_REGISTRED_SQL_FACTORIES.update({type: factory})


def _getSQLForSequence(
		field: "InputKey",
		val: List[Any],
		sqlPars: Dict[str, Any]) -> str:
	"""returns an IN condition matching field against the items of val.

	An empty sequence, or one consisting of a single None, yields an
	empty condition.  The items are passed on as a set.
	"""
	if len(val)==0 or (len(val)==1 and val[0] is None):
		return ""
	return "%s IN %%(%s)s"%(field.name, getSQLKey(field.name, set(val), sqlPars))


_FloatLike = Union[float, datetime.date]


def _convertIfFinite(
		val: _FloatLike,
		converter: Callable[[_FloatLike], _FloatLike]) -> _FloatLike:
	"""returns converter(val) if val is a finite number, val otherwise.

	Infinities (and NaNs) are passed through unchanged, so callers can
	still recognise open interval ends.  Values that already are dates
	(or datetimes) are passed through, too; comparing them against the
	float infinities would raise a TypeError in Python 3.
	"""
	if isinstance(val, datetime.date):
		return val
	if minusInfinity<val<plusInfinity:
		return converter(val)
	return val


def _getSQLForInterval(
		field: "InputKey",
		val: List[_FloatLike],
		sqlPars: Dict[str, Any]) -> Optional[str]:
	"""returns SQL for DALI intervals.

	This presumes that val is a 2-array of numbers (lower bound, upper
	bound) and will return an empty condition otherwise, including when
	val has no length at all (e.g., a scalar).

	A +Inf upper bound results in a "greater than" condition, a -Inf
	lower bound in a "less than" condition; otherwise, BETWEEN is used.
	For fields with the database-column-is-date property, finite bounds
	are converted with mjdToDateTime.
	"""
	try:
		if len(val)!=2:
			return ""
	except TypeError:
		# e.g., a scalar left over after getSQLForField unwrapped a
		# 1-sequence.
		return ""

	if field.hasProperty("database-column-is-date"):
		val = [_convertIfFinite(v, mjdToDateTime) for v in val]
	lower, upper = val[0], val[1]

	if upper==plusInfinity:
		return "%s > %%(%s)s"%(field.name, getSQLKey(field.name, lower, sqlPars))
	elif lower==minusInfinity:
		return "%s < %%(%s)s"%(field.name, getSQLKey(field.name, upper, sqlPars))
	else:
		return "%s BETWEEN %%(%s)s AND %%(%s)s"%(field.name,
			getSQLKey(field.name, lower, sqlPars),
			getSQLKey(field.name, upper, sqlPars))


def _getSQLForSimple(
		field: "InputKey",
		val: str,
		sqlPars: Dict[str, Any]) -> str:
	"""returns an equality condition between field and val.
	"""
	return "%s=%%(%s)s"%(field.name, getSQLKey(field.name, val, sqlPars))


def _getSQLFactory(field: "InputKey", value: Any
		) -> Callable[["InputKey", Any, Dict[str, Any]], Optional[str]]:
	"""returns an SQL factory for matching field's values against value.

	Fields with xtype interval always get the interval factory.  Other
	list or tuple values get an IN condition.  After that, factories
	registered for field.type through registerSQLFactory are consulted.
	Everything else is compared for equality.
	"""
	if field.xtype=="interval":
		return _getSQLForInterval
	elif isinstance(value, (list, tuple)):
		return _getSQLForSequence
	elif field.type in _REGISTRED_SQL_FACTORIES:
		return _REGISTRED_SQL_FACTORIES[field.type]
	else:
		return _getSQLForSimple


def getSQLForField(
		field: "InputKey",
		inPars: Dict[str, Any],
		sqlPars: Dict[str, Any]) -> Optional[str]:
	"""returns an SQL fragment for a column-like thing.

	This will be empty if no input in inPars is present.  If it is, (a) new
	key(s) will be left in sqlPars.

	getSQLForField defines the default behaviour; in DBCore condDescs,
	it can be overridden using phrase makers.

	inPars is supposed to be "typed"; we do not catch general parse errors
	here.

	A missing value, None, and a one-element list or tuple containing
	just None all yield None (no constraint).
	"""
	val = inPars.get(field.name)
	if val is None:
		return None

	# identifying 1-sequences with scalars probably was a bad idea,
	# but it's hard to get out of that now.
	if isinstance(val, (list, tuple)) and len(val)==1:
		val = val[0] # type: ignore # protected polymorphism
		# [None] means "no input", too (as in _getSQLForSequence).  Without
		# this check, None would reach a factory; for plain fields, that is
		# an equality test against NULL, which never matches anything.
		if val is None:
			return None

	factory = _getSQLFactory(field, val)
	return factory(field, val, sqlPars)


def _test():
	"""runs the doctests of the __main__ module (i.e., this module when it
	is executed as a script).
	"""
	from doctest import testmod
	testmod()


if __name__=="__main__":
	_test()