Source code for gavo.base.sqlsupport

# -*- encoding: utf-8 -*-
"""
Basic support for communicating with the database server.

This is currently very postgres specific.  If we really wanted to
support some other database, this would need massive refactoring.
"""

#c Copyright 2008-2023, the GAVO project <gavo@ari.uni-heidelberg.de>
#c
#c This program is free software, covered by the GNU GPL.  See the
#c COPYING file in the source distribution.


import contextlib
import functools
import os
import random
import re
import threading
import warnings
import weakref

import numpy

from gavo import utils
from gavo.base import config

debug = "GAVO_SQL_DEBUG" in os.environ

import psycopg2
import psycopg2.extensions
import psycopg2.pool

from psycopg2.extras import DictCursor #noflake: exported name

class Error(utils.Error):
	pass
NUMERIC_TYPES = frozenset(
	["smallint", "integer", "bigint", "real", "double precision"])

ORDERED_TYPES = frozenset(["timestamp", "text", "unicode"]) | NUMERIC_TYPES

_PG_TIME_UNITS = {
	"ms": 0.001,
	"s": 1.,
	"": 1.,
	"min": 60.,
	"h": 3600.,
	"d": 86400.,}
class SqlSetAdapter(object):
	"""is an adapter that formats python sequences as SQL sets

	-- as opposed to psycopg2's apparent default of building arrays
	out of them.
	"""
	def __init__(self, seq):
		self._seq = seq

	def prepare(self, conn):
		pass

	def getquoted(self):
		qobjs = []
		for o in self._seq:
			if isinstance(o, str):
				qobjs.append(psycopg2.extensions.adapt(str(o)).getquoted())
			else:
				qobjs.append(psycopg2.extensions.adapt(o).getquoted())
		return b'(%s)'%(b", ".join(qobjs))

	__str__ = getquoted
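# A hedged illustration (values chosen ad hoc): once registered below,
# this adapter renders python tuples and sets as SQL set literals, e.g.
#
#   SqlSetAdapter((1, 2, 3)).getquoted()    -> b'(1, 2, 3)'
#   SqlSetAdapter(("a", "b")).getquoted()   -> b"('a', 'b')"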
class SqlArrayAdapter(object):
	"""An adapter that formats python lists as SQL arrays.

	This makes, in the shameful tradition of VOTable, empty arrays
	equal to NULL.
	"""
	def __init__(self, seq):
		self._seq = seq
		for item in seq:
			if item is not None:
				self.itemType = type(item)
				break
		else:
			self.itemType = None

	def prepare(self, conn):
		pass

	def _addCastIfNecessary(self, serializedList):
		"""adds a typecast to serializedList if it needs one.

		This is when all entries in serializedList are NULL; so, we're
		fine anyway if the first element is non-NULL; if it's not, we
		try to guess.

		serializedList is changed in place, the method returns nothing.
		"""
		if not serializedList or serializedList[0]!=b"NULL":
			return

		if isinstance(self._seq, utils.floatlist):
			serializedList[0] = b"NULL::REAL"
		elif isinstance(self._seq, utils.intlist):
			serializedList[0] = b"NULL::INTEGER"

	def getquoted(self):
		if len(self._seq)==0:
			return b'NULL'

		if self.itemType and issubclass(self.itemType, str):
			# I need to be a bit verbose here because psycopg's default
			# encoding still is latin-1, and it seems there's no better
			# way to force it to utf-8 than this:
			qobjs = []
			for o in self._seq:
				item = psycopg2.extensions.adapt(o)
				item.encoding = "utf-8"
				qobjs.append(item.getquoted())
		else:
			qobjs = [psycopg2.extensions.adapt(o).getquoted()
				for o in self._seq]

		self._addCastIfNecessary(qobjs)
		return b'ARRAY[ %s ]'%(b", ".join(qobjs))

	__str__ = getquoted
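# A hedged illustration: lists become SQL arrays, and -- see the
# docstring above -- empty sequences become NULL:
#
#   SqlArrayAdapter([1.5, 2.5]).getquoted()  -> b'ARRAY[ 1.5, 2.5 ]'
#   SqlArrayAdapter([]).getquoted()          -> b'NULL'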
class FloatableAdapter(object):
	"""An adapter for things that do "float", in particular numpy.float*
	"""
	def __init__(self, val):
		self.val = float(val)

	def prepare(self, conn):
		pass

	def getquoted(self):
		if self.val!=self.val:
			return b"'nan'::real"
		else:
			return repr(self.val).encode("ascii")

	__str__ = getquoted
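# Note on the NaN branch above: repr(float("nan")) would yield the
# invalid SQL literal nan, which is why the adapter emits the quoted,
# cast form b"'nan'::real" for NaNs instead.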
class IntableAdapter(object):
	"""An adapter for things that do "int", in particular numpy.int*
	"""
	def __init__(self, val):
		self.val = int(val)

	def prepare(self, conn):
		pass

	def getquoted(self):
		return str(self.val).encode("ascii")

	__str__ = getquoted
class NULLAdapter(object):
	"""An adapter for things that should end up as NULL in the DB.
	"""
	def __init__(self, val):
		# val doesn't matter, we're making it NULL anyway
		pass

	def prepare(self, conn):
		pass

	def getquoted(self):
		return b"NULL"

	__str__ = getquoted
psycopg2.extensions.register_adapter(list, SqlArrayAdapter)
psycopg2.extensions.register_adapter(numpy.ndarray, SqlArrayAdapter)
psycopg2.extensions.register_adapter(tuple, SqlSetAdapter)
psycopg2.extensions.register_adapter(set, SqlSetAdapter)
psycopg2.extensions.register_adapter(frozenset, SqlSetAdapter)

for numpyType, adapter in [
		("float32", FloatableAdapter),
		("float64", FloatableAdapter),
		("float96", FloatableAdapter),
		("int8", IntableAdapter),
		("int16", IntableAdapter),
		("int32", IntableAdapter),
		("int64", IntableAdapter),]:
	try:
		psycopg2.extensions.register_adapter(
			getattr(numpy, numpyType), adapter)
	except AttributeError: # pragma: no cover
		# types not present on the python end we don't need to adapt
		pass

# Override psycopg2's mapping of numeric to decimal, because our
# serialisers (votable, fits, json) don't really work with decimal.
psycopg2.extensions.register_type(
	psycopg2.extensions.new_type(
		psycopg2.extensions.DECIMAL.values,
		"numeric_float",
		lambda value, cursor: float(value) if value is not None else None))

from gavo.utils import pyfits
psycopg2.extensions.register_adapter(pyfits.Undefined, NULLAdapter)

from psycopg2 import (OperationalError, #noflake: exported names
	DatabaseError, IntegrityError, ProgrammingError,
	InterfaceError, DataError, InternalError)
from psycopg2.extensions import QueryCanceledError #noflake: exported name
from psycopg2 import Error as DBError
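# A hedged illustration of the combined effect of these registrations
# (cur is assumed to be a cursor on any connection from this module):
#
#   cur.mogrify("SELECT %s, %s", ([1.5, 2.5], (1, 2)))
#     -> b'SELECT ARRAY[ 1.5, 2.5 ], (1, 2)'
#
# and, through the numeric_float typecaster, NUMERIC columns arrive as
# python floats rather than decimal.Decimal instances.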
class DebugCursor(psycopg2.extensions.cursor): # pragma: no cover
	def execute(self, sql, args=None):
		print("Executing %s %s"%(id(self.connection), sql))
		psycopg2.extensions.cursor.execute(self, sql, args)
		print("Finished %s %s"%(
			id(self.connection), self.query.decode("utf-8")))
		return self.rowcount

	def executemany(self, sql, args=[]):
		print("Executing many", sql)
		print("%d args, first one:\n%s"%(len(args), args[0]))
		res = psycopg2.extensions.cursor.executemany(self, sql, args)
		print("Finished many", self.query.decode("utf-8"))
		return res
class GAVOConnection(psycopg2.extensions.connection):
	"""A psycopg2 connection with some additional methods.

	This derivation is also done so we can attach the getDBConnection
	arguments to the connection; they are used when recovering from a
	database restart.
	"""
	# extensionFunctions is filled by senseEnvironment (and contains names
	# of postgres extension functions that might modify our behaviour)
	extensionFunctions = []
	@classmethod
	def senseEnvironment(cls, conn):
		"""configures us depending on what is in the database.

		The argument needs to be a connection to the database we will
		connect to.  In practice, initPsycopg calls this once during
		DaCHS startup.
		"""
		cls.extensionFunctions = frozenset(r[0] for r in conn.query(
			"SELECT proname FROM pg_proc WHERE"
			" proname in ('epoch_prop', 'q3c_ang2ipix',"
			" 'smoc_union', 'healpix_nest')"))
	def getParameter(self, key, cursor=None):
		"""returns the value of the postgres parameter key.

		This returns unprocessed values, probably almost always as
		strings.  Caveat emptor.

		The main purpose of this function is to help the parameters
		connection manager, so users shouldn't really mess with it.
		"""
		cursor = cursor or self.cursor()
		if not re.match("[A-Za-z_]+$", key):
			raise ValueError("Invalid settings key: %s"%key)
		cursor.execute("SHOW %s"%key)
		return list(cursor)[0][0]
	def configure(self, settings, cursor=None):
		"""sets a number of postgres connection parameters.

		settings is a list of (parameter, value) pairs, where value must
		be a python value that psycopg2 understands and that works for
		the parameter in question.

		This returns a settings-list that restores the previous values
		when passed to configure().
		"""
		cursor = cursor or self.cursor()
		resetTo = []
		for key, _ in settings:
			resetTo.append((key, self.getParameter(key, cursor)))

		for key, val in settings:
			cursor.execute("SET %s=%%(val)s"%key, {"val": val})

		return resetTo
	@contextlib.contextmanager
	def parameters(self, settings, cursor=None):
		"""executes a block with a certain set of parameters on a
		connection, resetting them to their original values again
		afterwards.

		Of course, this only works as expected if you're not sharing
		your connections too widely.

		This rolls back the connection by itself on database errors; we
		couldn't reset the parameters otherwise.
		"""
		cursor = cursor or self.cursor()
		resetTo = self.configure(settings, cursor)
		try:
			yield
		except Exception as ex:
			try:
				if isinstance(ex, psycopg2.Error):
					self.rollback()
				self.configure(resetTo)
			except psycopg2.Error:
				# we believe the connection was already closed and don't bother
				pass
			raise
		self.configure(resetTo, cursor)
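	# A hedged usage sketch for parameters() (the query is illustrative;
	# conn is a GAVOConnection):
	#
	#   with conn.parameters([("statement_timeout", "2000 ms")]):
	#       rows = list(conn.query("SELECT * FROM mytable"))
	#
	# statement_timeout is back at its previous value after the block.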
	def queryToDicts(self, query, args={}, timeout=None, caseFixer=None):
		"""iterates over dictionary rows for query.

		This is a thin wrapper around query(yieldDicts=True) provided
		for convenience and backwards compatibility.
		"""
		return self.query(query, args, timeout, True, caseFixer)
	def query(self, query, args={}, timeout=None,
			yieldDicts=False, caseFixer=None):
		"""iterates over result tuples for query.

		This is mainly for ad-hoc queries needing little metadata.

		You can pass yieldDicts=True to get dictionaries instead of
		tuples.  The dictionary keys are determined by what the database
		says the column titles are; thus, they are usually lower-cased
		variants of what's in the select-list.  To fix this, you can
		pass in a caseFixer dict mapping the lowercased names to
		properly cased versions.

		timeout is in seconds.

		Warning: this is an iterator, so unless you iterate over the
		result, the query will not get executed.  Hence, for non-select
		statements you will generally have to use conn.execute.
		"""
		cursor = self.cursor()

		params = []
		if timeout is not None:
			params.append(("statement_timeout", "%s ms"%int(timeout*1000)))

		try:
			with self.parameters(params, cursor):
				cursor.execute(query, args)
				if yieldDicts:
					keys = [cd[0] for cd in cursor.description]
					if caseFixer:
						keys = [caseFixer.get(key, key) for key in keys]
					for row in cursor:
						yield dict(zip(keys, row))
				else:
					for row in cursor:
						yield row
		finally:
			cursor.close()
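	# A hedged usage sketch for query() -- it returns a generator, so
	# iterate to actually run the statement (table and column names are
	# hypothetical):
	#
	#   for ra, dec in conn.query(
	#           "SELECT ra, dec FROM mytable WHERE mag<%(maglim)s",
	#           {"maglim": 12}, timeout=5):
	#       ...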
	def execute(self, query, args={}):
		"""executes query in a cursor.

		This returns the rowcount of the cursor used.
		"""
		cursor = self.cursor()
		try:
			cursor.execute(query, args)
			return cursor.rowcount
		finally:
			cursor.close()
	@contextlib.contextmanager
	def savepoint(self):
		"""sets up a section protected by a savepoint that will be
		released after use.

		If an exception happens in the controlled section, the connection
		will be rolled back to the savepoint.
		"""
		savepointName = "auto_%s"%(random.randint(0, 2147483647))
		self.execute("SAVEPOINT %s"%savepointName)
		try:
			yield
		except:
			self.execute("ROLLBACK TO SAVEPOINT %s"%savepointName)
			raise
		finally:
			self.execute("RELEASE SAVEPOINT %s"%savepointName)
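# A hedged usage sketch for GAVOConnection.savepoint: a failure in the
# controlled section only rolls back to the savepoint, so the enclosing
# transaction survives (the statement is illustrative):
#
#   try:
#       with conn.savepoint():
#           conn.execute("INSERT INTO mytable VALUES ('junk')")
#   except DBError:
#       pass  # conn can still be used and committed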
class DebugConnection(GAVOConnection): # pragma: no cover
	def cursor(self, *args, **kwargs):
		kwargs["cursor_factory"] = DebugCursor
		return psycopg2.extensions.connection.cursor(self, *args, **kwargs)

	def commit(self):
		print("Commit %s"%id(self))
		return GAVOConnection.commit(self)

	def rollback(self):
		print("Rollback %s"%id(self))
		return GAVOConnection.rollback(self)

	def getPID(self):
		cursor = self.cursor()
		cursor.execute("SELECT pg_backend_pid()")
		pid = list(cursor)[0][0]
		cursor.close()
		return pid
def getDBConnection(profile, debug=debug, autocommitted=False):
	"""returns an enhanced database connection through profile.

	You will typically rather use the context managers for the standard
	profiles (``getTableConnection`` and friends).  Use this function
	if you want to keep your connection out of connection pools or if
	you want to use non-standard profiles.

	profile will usually be a string naming a profile defined in
	``GAVO_ROOT/etc``.
	"""
	if isinstance(profile, str):
		profile = config.getDBProfile(profile)

	if debug: # pragma: no cover
		conn = psycopg2.connect(connection_factory=DebugConnection,
			**profile.getArgs())
		print("NEW CONN using %s (%s)"%(profile.name, conn.getPID()),
			id(conn))
		def closer():
			print("CONNECTION CLOSE", id(conn))
			return DebugConnection.close(conn)
		conn.close = closer
	else:
		try:
			conn = psycopg2.connect(connection_factory=GAVOConnection,
				**profile.getArgs())
		except OperationalError as msg:
			raise utils.ReportableError("Cannot connect to the database"
				" server.  The database library reported:\n\n%s"%str(msg),
				hint="This usually means you must adapt either the access"
				" profiles in $GAVO_DIR/etc or your database config (in"
				" particular, pg_hba.conf).")

	if autocommitted:
		conn.set_isolation_level(
			psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
	conn.set_client_encoding("UTF8")

	conn._getDBConnectionArgs = {
		"profile": profile,
		"debug": debug,
		"autocommitted": autocommitted}
	return conn
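# A hedged usage sketch for getDBConnection (the profile name must of
# course exist in the local configuration):
#
#   conn = getDBConnection("trustedquery")
#   try:
#       for row in conn.query("SELECT version()"):
#           print(row)
#       conn.commit()
#   finally:
#       conn.close()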
def _parseTableName(tableName, schema=None):
	"""returns schema, unqualified table name for the arguments.

	schema=None selects the default schema (public for postgresql).

	If tableName is qualified (i.e. schema.table), the schema given
	in the name overrides the schema argument.

	We do not support delimited identifiers for tables in DaCHS.  Hence,
	this will raise a ValueError for anything that wouldn't work as an
	SQL regular identifier (except that we don't filter for reserved
	words yet, which is an implementation detail that might change).
	"""
	parts = tableName.split(".")
	if len(parts)>2:
		raise ValueError("%s is not a SQL regular identifier"%repr(tableName))
	for p in parts:
		if not utils.identifierPattern.match(p):
			raise ValueError("%s is not a SQL regular identifier"%repr(
				tableName))

	if len(parts)==1:
		name = parts[0]
	else:
		schema, name = parts

	if schema is None:
		schema = "public"

	return schema.lower(), name.lower()


def _parseBannerString(bannerString):
	"""returns digits from a postgres server banner.

	This hardcodes the response given by postgres 8 and raises a
	ValueError if the expected format is not found.
	"""
	mat = re.match(r"PostgreSQL ([\d.]*)", bannerString)
	if not mat:
		raise ValueError("Cannot make out the Postgres server version from %s"%
			repr(bannerString))
	return tuple(int(s) for s in mat.group(1).split("."))
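# Hedged examples for the two parsers above:
#
#   _parseTableName("Stars", schema="MySchema")  -> ("myschema", "stars")
#   _parseTableName("MySchema.Stars")            -> ("myschema", "stars")
#   _parseTableName('"fancy table"')             raises ValueError
#
#   _parseBannerString("PostgreSQL 15.3 on x86_64-pc-linux-gnu")
#     -> (15, 3)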
def getPgVersion(digits=2):
	"""returns the version number of the postgres server executing
	untrusted (ADQL) queries, truncated to digits components.

	This is relatively expensive, as it will actually ask the server.
	"""
	with getUntrustedConn() as conn:
		bannerString = list(conn.query("SELECT version()"))[0][0]
	return _parseBannerString(bannerString)[:digits]
class PostgresQueryMixin(object):
	"""is a mixin containing various useful queries that are postgres
	specific.

	This mixin expects a parent that mixes in QuerierMixin (which, for
	now, also mixes in PostgresQueryMixin, so you won't need to mix
	this in yourself).
	"""
	def getPrimaryIndexName(self, tableName):
		"""returns the name of the index corresponding to the primary key
		on (the unqualified) tableName.
		"""
		return ("%s_pkey"%tableName).lower()

	def schemaExists(self, schema):
		"""returns True if the named schema exists in the database.
		"""
		matches = list(self.connection.query("SELECT nspname FROM"
			" pg_namespace WHERE LOWER(nspname)=%(schemaName)s", {
				'schemaName': schema.lower(),
			}))
		return len(matches)!=0
	def hasIndex(self, tableName, indexName, schema=None):
		"""returns True if table tableName has an index called indexName.

		See _parseTableName on the meaning of the arguments.
		"""
		schema, tableName = _parseTableName(tableName, schema)
		res = list(self.connection.query("SELECT indexname FROM"
			" pg_indexes WHERE schemaname=lower(%(schema)s) AND"
			" tablename=lower(%(tableName)s) AND"
			" indexname=lower(%(indexName)s)", locals()))
		return len(res)>0
	def _getColIndices(self, relOID, colNames):
		"""returns a sorted list of column indices of colNames in the
		relation relOID.

		This really is a helper for foreignKeyExists.
		"""
		colNames = set(n.lower() for n in colNames)
		res = [r[0] for r in self.connection.query(
			"SELECT attnum FROM pg_attribute WHERE"
			" attrelid=%(relOID)s and attname IN %(colNames)s",
			locals())]
		res.sort()
		return res
	def getForeignKeyName(self, srcTableName, destTableName, srcColNames,
			destColNames, schema=None):
		"""returns the name of the foreign key constraint on srcTableName's
		srcColNames referencing destTableName's destColNames.

		Warning: names in XColNames that are not column names in the
		respective tables are ignored.

		This raises a ValueError if no such foreign key exists or it is
		ambiguous.
		"""
		try:
			srcOID = self.getOIDForTable(srcTableName, schema)
			srcColInds = self._getColIndices( #noflake: used in locals()
				srcOID, srcColNames)
			destOID = self.getOIDForTable(destTableName, schema)
			destColInds = self._getColIndices( #noflake: used in locals()
				destOID, destColNames)
		except Error:
			# Some of the items related probably don't exist
			return False

		res = list(self.connection.query("""SELECT conname FROM pg_constraint
			WHERE contype='f'
			AND conrelid=%(srcOID)s
			AND confrelid=%(destOID)s
			AND conkey=%(srcColInds)s::SMALLINT[]
			AND confkey=%(destColInds)s::SMALLINT[]""", locals()))

		if len(res)==1:
			return res[0][0]
		else:
			raise ValueError("Non-existing or ambiguous foreign key")
	def foreignKeyExists(self, srcTableName, destTableName, srcColNames,
			destColNames, schema=None):
		try:
			_ = self.getForeignKeyName( #noflake: ignored value
				srcTableName, destTableName, srcColNames, destColNames,
				schema)
			return True
		except ValueError:
			return False
	@functools.lru_cache()
	def _resolveTypeCode(self, oid):
		"""returns a textual description for a type oid as returned by
		cursor.description.

		These descriptions are *not* DDL-ready.

		*** postgres specific ***
		"""
		res = list(self.connection.query(
			"select typname from pg_type where oid=%(oid)s", {"oid": oid}))
		return res[0][0]
	def getColumnsFromDB(self, tableName):
		"""returns a sequence of (name, type) pairs of the columns this
		table has in the database.

		If the table is not on disk, this will raise a NotFoundError.

		*** psycopg2 specific ***
		"""
		# _parseTableName bombs out on non-regular identifiers, hence
		# foiling a possible SQL injection
		_parseTableName(tableName)
		cursor = self.connection.cursor()
		try:
			cursor.execute("select * from %s limit 0"%tableName)
			return [(col.name, self._resolveTypeCode(col.type_code))
				for col in cursor.description]
		finally:
			cursor.close()

	def getRowEstimate(self, tableName):
		"""returns the size of the table in rows as estimated by the query
		planner.

		This will raise a KeyError with tableName if the table isn't
		known to postgres.
		"""
		res = list(self.connection.query(
			"SELECT reltuples FROM pg_class WHERE oid = %(tableName)s::regclass",
			locals()))
		# this is guaranteed to return something because of the ::regclass
		# cast that will fail for non-existing tables.
		return int(res[0][0])
	def roleExists(self, role):
		"""returns True if the role is known to the database.
		"""
		matches = list(self.connection.query(
			"SELECT usesysid FROM pg_user WHERE usename=%(role)s",
			locals()))
		return len(matches)!=0
	def getOIDForTable(self, tableName, schema=None):
		"""returns the current oid of tableName.

		tableName may be schema qualified.  If it is not, public is
		assumed.
		"""
		schema, tableName = _parseTableName(tableName, schema)
		return list(self.connection.query(
			"SELECT %(tableName)s::regclass::bigint",
			{"tableName": f"{schema}.{tableName}"}))[0][0]
	def getTableType(self, tableName, schema=None):
		"""returns the type of the relation tableName.

		If tableName does not exist, None is returned.  Otherwise, it's
		what is in the information schema for the table, which for
		postgres currently is one of BASE TABLE, VIEW, FOREIGN TABLE,
		MATERIALIZED VIEW, or LOCAL TEMPORARY.

		The DaCHS-idiomatic way to see if a relation exists is
		getTableType() is not None.

		You can pass in schema-qualified relation names, or the relation
		name and the schema separately.

		*** postgres specific ***
		"""
		schema, tableName = _parseTableName(tableName, schema)
		res = list(self.connection.query("""SELECT table_name, table_type
				FROM information_schema.tables
				WHERE (
					table_schema=%(schemaName)s
					OR table_type='LOCAL TEMPORARY')
				AND table_name=%(tableName)s""", {
			'tableName': tableName.lower(),
			'schemaName': schema.lower()}))

		if not res:
			# materialised views are not yet in information_schema.tables,
			# so we try again with a special postgres case.
			if list(self.connection.query(
					"select table_name from information_schema.tables"
					" where table_name='pg_matviews'")):
				res = list(self.connection.query("""SELECT matviewname,
							'MATERIALIZED VIEW' AS table_type
						FROM pg_matviews
						WHERE schemaname=%(schemaName)s
						AND matviewname=%(tableName)s""", {
					'tableName': tableName.lower(),
					'schemaName': schema.lower()}))

		if not res:
			return None

		assert len(res)==1
		return res[0][1]
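	# A hedged usage sketch: the DaCHS-idiomatic existence test reads
	#
	#   if querier.getTableType("myschema.mytable") is not None:
	#       ...  # the relation exists (as a table, view, etc.)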
	def dropTable(self, tableName, cascade=False):
		"""drops a table or view named by tableName.

		This does not raise an error if no such relation exists.

		*** postgres specific ***
		"""
		tableType = self.getTableType(tableName)
		if tableType is None:
			return

		dropQualification = {
			"VIEW": "VIEW",
			"MATERIALIZED VIEW": "MATERIALIZED VIEW",
			"FOREIGN TABLE": "FOREIGN TABLE",
			"BASE TABLE": "TABLE",
			"LOCAL TEMPORARY": "TABLE"}[tableType]
		self.connection.execute("DROP %s %s %s"%(
			dropQualification,
			tableName,
			"CASCADE" if cascade else ""))
	def getSchemaPrivileges(self, schema):
		"""returns the role->privilege dict for schema's ACL.
		"""
		res = list(self.connection.query("SELECT nspacl FROM pg_namespace"
			" WHERE nspname=%(schema)s", locals()))
		return self.parsePGACL(res[0][0])

	def getTablePrivileges(self, schema, tableName):
		"""returns the role->privilege dict for the relation tableName
		in schema.

		*** postgres specific ***
		"""
		res = list(self.connection.query("SELECT relacl FROM pg_class WHERE"
			" lower(relname)=lower(%(tableName)s) AND"
			" relnamespace=(SELECT oid FROM pg_namespace"
			" WHERE nspname=%(schema)s)", locals()))
		try:
			return self.parsePGACL(res[0][0])
		except IndexError:
			# Table doesn't exist, so no privileges
			return {}
	_privTable = {
		"arwdRx": "ALL",
		"arwdDxt": "ALL",
		"arwdRxt": "ALL",
		"arwdxt": "ALL",
		"r": "SELECT",
		"UC": "ALL",
		"U": "USAGE",
	}
	def parsePGACL(self, acl):
		"""returns a dict roleName->privilege for acl in postgres' ACL
		serialization.
		"""
		if acl is None:
			return {}
		res = []
		for acs in re.match("{(.*)}", acl).group(1).split(","):
			if acs!='':
				# empty ACLs don't match the RE, so catch them here
				role, privs, granter = re.match(
					"([^=]*)=([^/]*)/(.*)", acs).groups()
				res.append((role, self._privTable.get(privs, "READ")))
		return dict(res)
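	# A hedged example of parsePGACL on a typical pg_class.relacl value:
	#
	#   querier.parsePGACL("{postgres=arwdDxt/postgres,untrusted=r/postgres}")
	#     -> {"postgres": "ALL", "untrusted": "SELECT"}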
	def getACLFromRes(self, thingWithPrivileges):
		"""returns a dict of (role, ACL) as it is defined in
		thingWithPrivileges.

		thingWithPrivileges is something mixing in
		rscdef.common.PrivilegesMixin (or something that has readProfiles
		and allProfiles attributes containing sequences of profile names).
		"""
		res = []
		if hasattr(thingWithPrivileges, "schema"):
			# it's an RD
			readRight = "USAGE"
		else:
			readRight = "SELECT"

		for profile in thingWithPrivileges.readProfiles:
			res.append((config.getDBProfile(profile).roleName, readRight))
		for profile in thingWithPrivileges.allProfiles:
			res.append((config.getDBProfile(profile).roleName, "ALL"))
		return dict(res)
class StandardQueryMixin(object):
	"""is a mixin containing various useful queries that should work
	against all SQL systems.

	This mixin expects a parent that mixes in QuerierMixin (which, for
	now, also mixes in StandardQueryMixin, so you won't need to mix this
	in yourself).

	The parent also needs to mix in something like PostgresQueryMixin
	(I might want to define an interface there once I'd like to support
	other databases).
	"""
	def setSchemaPrivileges(self, rd):
		"""sets the privileges defined on rd to its schema.

		This function will never touch the public schema.
		"""
		schema = rd.schema.lower()
		if schema=="public":
			return
		self._updatePrivileges("SCHEMA %s"%schema,
			self.getSchemaPrivileges(schema), self.getACLFromRes(rd))

	def setTablePrivileges(self, tableDef):
		"""sets the privileges defined in tableDef for that table through
		querier.
		"""
		self._updatePrivileges(tableDef.getQName(),
			self.getTablePrivileges(tableDef.rd.schema, tableDef.id),
			self.getACLFromRes(tableDef))
	def _updatePrivileges(self, objectName, foundPrivs, shouldPrivs):
		"""is a helper for set[Table|Schema]Privileges.

		Requests for granting privileges to roles not known to the
		database are ignored, but a log entry is generated.
		"""
		for role in set(foundPrivs)-set(shouldPrivs):
			if role:
				self.connection.execute("REVOKE ALL PRIVILEGES ON %s FROM %s"%(
					objectName, role))

		for role in set(shouldPrivs)-set(foundPrivs):
			if role:
				if self.roleExists(role):
					self.connection.execute("GRANT %s ON %s TO %s"%(
						shouldPrivs[role], objectName, role))
				else:
					utils.sendUIEvent("Warning",
						"Request to grant privileges to non-existing"
						" database user %s dropped"%role)

		for role in set(shouldPrivs)&set(foundPrivs):
			if role:
				if shouldPrivs[role]!=foundPrivs[role]:
					self.connection.execute("REVOKE ALL PRIVILEGES ON %s FROM %s"%(
						objectName, role))
					self.connection.execute("GRANT %s ON %s TO %s"%(
						shouldPrivs[role], objectName, role))
class QuerierMixin(PostgresQueryMixin, StandardQueryMixin):
	"""is a mixin for "queriers", i.e., objects that maintain a db
	connection.

	The mixin assumes an attribute connection from the parent.
	"""
	defaultProfile = None
	# _reconnecting is used in query
	_reconnecting = False

	def enableAutocommit(self):
		self.connection.set_isolation_level(
			psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
	def _queryReconnecting(self, query, data, timeout):
		"""helps query in case of disconnections.
		"""
		self.connection = getDBConnection(
			**self.connection._getDBConnectionArgs)
		self._reconnecting = True
		res = self.connection.query(query, data, timeout)
		self._reconnecting = False
		return res
	def query(self, query, data={}, timeout=None):
		"""wraps conn.query, adding logic to re-establish lost connections.

		Don't use this method any more in new code.  It contains wicked
		logic to tell DDL statements (that run without anyone pulling
		the results) from actual selects.  That's a bad API.  Also note
		that the timeout is ignored for DDL statements.  We'll drop this
		some time in 2023.

		Use either connection.query or connection.execute in new code.
		"""
		warnings.warn("You are using querier.query (or perhaps table.query)."
			" This has terrible semantics; use querier.connection.query"
			" for statements returning rows and .execute for DDL statements.",
			category=FutureWarning)
		if self.connection is None:
			raise utils.ReportableError(
				"SimpleQuerier connection is None.",
				hint="This usually is because an AdhocQuerier's query method"
				" was used outside of a with block.")

		try:
			if query[:5].lower() in ["selec", "with "]:
				return self.connection.query(query, data, timeout)
			else:
				# it's DDL that we execute directly, ignoring the timeout
				self.connection.execute(query, data)
		except DBError as ex:
			if isinstance(ex, OperationalError) and self.connection.fileno()==-1:
				if not self._reconnecting:
					return self._queryReconnecting(query, data, timeout)
			raise
	def queryToDicts(self, *args, **kwargs):
		"""wraps conn.queryToDicts for backwards compatibility.
		"""
		return self.connection.queryToDicts(*args, **kwargs)
class UnmanagedQuerier(QuerierMixin):
	"""A simple interface to querying the database through a connection
	managed by someone else.

	This is typically used as in::

		with base.getTableConn() as conn:
			q = UnmanagedQuerier(conn)
			...

	This contains numerous methods abstracting DB functionality a bit.
	Documented ones include:

	* schemaExists(schema)
	* getColumnsFromDB(tableName)
	* getTableType(tableName) -- this will return None for non-existing
	  tables, which is DaCHS' official way to determine table existence.
	"""
	def __init__(self, connection):
		self.connection = connection
class AdhocQuerier(QuerierMixin):
	"""A simple interface to querying the database through pooled
	connections.

	These are constructed using the connection getters (by default
	getTableConn, alternatively getAdminConn) and then serve as context
	managers, handing back the connection as you exit the controlled
	block.

	Since they operate through pooled connections, no transaction
	management takes place.  These are typically for read-only things.

	You can use the query method and everything that's in the
	QuerierMixin.
	"""
	def __init__(self, connectionManager=None):
		if connectionManager is None:
			self.connectionManager = getTableConn
		else:
			self.connectionManager = connectionManager
		self.connection = None

	def __enter__(self):
		self._cm = self.connectionManager()
		self.connection = self._cm.__enter__()
		return self

	def __exit__(self, *args):
		self.connection = None
		return self._cm.__exit__(*args)
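# A hedged usage sketch for AdhocQuerier (dc.metastore is a table DaCHS
# itself maintains; any relation name works here):
#
#   with AdhocQuerier() as q:
#       if q.getTableType("dc.metastore") is not None:
#           cols = q.getColumnsFromDB("dc.metastore")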
class NonBlockingQuery:
	"""a query run in a pseudo-nonblocking way.

	While psycopg2 can do proper async, that doesn't play well with just
	about everything else DaCHS is doing so far.  So, here's a quick way
	to allow long-running queries that users can still interrupt.  The
	ugly secret is that it's based on threads.

	This should not be used within the server.  We might want to port
	the async taprunner (which runs outside of the server) to using
	this, though.

	To use it, construct it with conn, query and perhaps args and use
	it as a context manager.  Wait for its result attribute to become
	non-None; this will then be either a list of result rows or an
	Exception (which will also be raised when exiting the context
	manager).

	To abort a running query, call abort().
	"""
	def __init__(self, conn, query, args={}):
		self.conn, self.query, self.args = conn, query, args
		self.backendPID = list(
			self.conn.query("SELECT pg_backend_pid()"))[0][0]
		# result will be set only from the thread
		self.result = None

	def __enter__(self):
		self.thread = threading.Thread(target=self._runQuery)
		self.thread.daemon = True
		self.thread.start()
		return self

	def __exit__(self, *excInfo):
		self.cleanup(1)
		if excInfo==(None, None, None) and isinstance(self.result, Exception):
			# this probably is a memory leak, which is one of the reasons
			# this shouldn't be used in the server without more thought
			raise self.result
		return False

	def _runQuery(self):
		try:
			self.result = list(self.conn.query(self.query, self.args))
		except QueryCanceledError:
			# assume this happened on user request
			pass
		except Exception as ex:
			self.result = ex
	def abort(self):
		"""aborts the current query and reaps the thread.
		"""
		self.conn.cancel()
		self.cleanup(1)

	def cleanup(self, timeout=None):
		"""tries to reap the thread (i.e., join it).

		If the thread hasn't terminated within timeout seconds, a
		sqlsupport.Error is raised.
		"""
		self.thread.join(timeout=timeout)
		if self.thread.is_alive():
			raise Error("Could not join NonBlockingQuery")
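# A hedged usage sketch for NonBlockingQuery (the polling loop and the
# cancellation predicate are application-specific, and this assumes
# time is imported):
#
#   with NonBlockingQuery(conn, "SELECT sluggish_function()") as q:
#       while q.result is None:
#           time.sleep(0.5)
#           if user_wants_abort():  # hypothetical predicate
#               q.abort()
#               break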
def setDBMeta(conn, key, value):
	"""adds/overwrites (key, value) in the dc.metastore table within
	conn.

	conn must be an admin connection; this does not commit.

	key must be a string, value something unicodeable.
	"""
	conn.execute(
		"INSERT INTO dc.metastore (key, value) VALUES (%(key)s, %(value)s)", {
			'key': key,
			'value': str(value)})


def getDBMeta(key):
	"""returns the value for key from within dc.metastore.

	This always returns a unicode string.  Type conversions are the
	client's business.

	If no value exists, this raises a KeyError.
	"""
	with getTableConn() as conn:
		res = list(conn.query("SELECT value FROM dc.metastore WHERE"
			" key=%(key)s", {"key": key}))
		if not res:
			raise KeyError(key)
		return res[0][0]
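# A hedged usage sketch for the metastore helpers (the key name is made
# up; the writable manager commits on a clean block exit):
#
#   with getWritableAdminConn() as conn:
#       setDBMeta(conn, "demo_key", 42)
#   getDBMeta("demo_key")   -> '42'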
class CustomConnectionPool(psycopg2.pool.ThreadedConnectionPool):
	"""A threaded connection pool that returns connections made via
	profileName.
	"""
	# we keep weak references to pools we've created so we can invalidate
	# them all on a server restart to avoid having stale connections
	# around.
	knownPools = []

	def __init__(self, minconn, maxconn, profileName, autocommitted=True):
		# make sure no additional arguments come in, since we don't
		# support them.
		self.profileName = profileName
		self.autocommitted = autocommitted
		self.stale = False
		psycopg2.pool.ThreadedConnectionPool.__init__(
			self, minconn, maxconn)
		self.knownPools.append(weakref.ref(self))

	@classmethod
	def serverRestarted(cls):
		utils.sendUIEvent("Warning", "Suspecting a database restart."
			" Discarding old connection pools, asking to create new ones.")
		for pool in cls.knownPools:
			try:
				pool().stale = True
			except AttributeError:
				# pool already gone
				pass
		# we risk a race condition here; this is used rarely enough
		# that this shouldn't matter.
		cls.knownPools = []

	def _connect(self, key=None):
		"""creates a new connection with our selected profile and assigns
		it to key if key is not None.

		This is an implementation detail of psycopg2's connection pools.
		"""
		conn = getDBConnection(self.profileName)

		if self.autocommitted:
			try:
				conn.set_session(autocommit=True, readonly=True)
			except ProgrammingError:
				utils.sendUIEvent("Warning", "Uncommitted transaction"
					" escaped; please investigate and fix")
				conn.commit()

		if key is not None:
			self._used[key] = conn
			self._rused[id(conn)] = key
		else:
			self._pool.append(conn)
		return conn
def _cleanupAfterDBError(ex, conn, pool, poolLock):
	"""removes conn from pool after an error occurred.

	This is a helper for getConnFromPool below.
	"""
	if isinstance(ex, OperationalError) and ex.pgcode is None:
		# this is probably a db server restart.  Invalidate all
		# connections immediately.
		with poolLock:
			if pool:
				pool[0].serverRestarted()

	# Make sure the connection is closed; something bad happened
	# in it, so we don't want to re-use it
	try:
		pool[0].putconn(conn, close=True)
	except InterfaceError:
		# Connection already closed
		pass
	except Exception as msg:
		utils.sendUIEvent("Error",
			"Disaster: %s while force-closing connection"%msg)


def _makeConnectionManager(profileName, autocommitted=True,
		singleton=False):
	"""returns a context manager for a connection pool for profileName
	connections.

	With singleton=True, only one connection will be created rather than
	[db]poolSize ones.
	"""
	pool = []
	poolLock = threading.Lock()

	def makePool():
		if singleton:
			minConn = 1
		else:
			minConn = config.get("db", "poolSize")
		with poolLock:
			pool.append(CustomConnectionPool(
				minConn,
				# I don't think there's any point in maxConn at all the
				# way psycopg pools are done right now, so all I care
				# about here is that it won't get in our way.
				200,
				profileName, autocommitted))

	def getConnFromPool():
		# we delay pool creation since these functions are built during
		# sqlsupport import.  We probably don't have profiles ready
		# at that point.
		if not pool:
			makePool()

		if pool[0].stale:
			pool[0].closeall()
			pool.pop()
			makePool()

		conn = pool[0].getconn()
		try:
			yield conn
		except Exception as ex:
			# controlled block bombed out, do error handling
			_cleanupAfterDBError(ex, conn, pool, poolLock)
			raise
		else:
			# no exception raised, commit if not autocommitted
			if not autocommitted:
				conn.commit()

		try:
			pool[0].putconn(conn, close=conn.closed)
		except InterfaceError:
			# Connection already closed
			pass

	return contextlib.contextmanager(getConnFromPool)


getUntrustedConn = _makeConnectionManager("untrustedquery")
getTableConn = _makeConnectionManager("trustedquery")
getAdminConn = _makeConnectionManager("admin", singleton=True)

getWritableUntrustedConn = _makeConnectionManager("untrustedquery",
	autocommitted=False, singleton=True)
getWritableTableConn = _makeConnectionManager("trustedquery",
	autocommitted=False, singleton=True)
getWritableAdminConn = _makeConnectionManager("admin",
	autocommitted=False, singleton=True)
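# A hedged usage sketch for the pooled connection managers defined
# above; the read-only ones hand out autocommitted connections, while
# the getWritable* ones commit on a clean block exit (the DDL statement
# here is illustrative):
#
#   with getTableConn() as conn:
#       rows = list(conn.query("SELECT version()"))
#
#   with getWritableAdminConn() as conn:
#       conn.execute("CREATE TABLE dc.demo_scratch (x INTEGER)")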
def initPsycopg():
	"""does any DaCHS-specific database setup necessary.

	This is executed on sqlsupport import unless we are in initdachs (or
	setting up the testbed); see the foot of this module for how this is
	done.

	This needs to call GAVOConnection.senseEnvironment.
	"""
	conn = psycopg2.connect(connection_factory=GAVOConnection,
		**config.getDBProfile("feed").getArgs())
	try:
		try:
			from gavo.utils import pgsphere
			pgsphere.preparePgSphere(conn)
		except: # pragma: no cover
			warnings.warn("pgsphere missing -- ADQL, pg-SIAP, and SSA"
				" will not work")
		GAVOConnection.senseEnvironment(conn)
	finally:
		conn.close()
if "GAVO_INIT_RUNNING" not in os.environ: initPsycopg()