"""
Infrastructure for obtaining metadata about tables and
columns in the data center.
This has the command line interface for dachs info, and the annotation
machinery is also used by dachs limits; the common functionality should
probably move to rsc at some point (cf. rsc.dbtable.annotateDBTable).
The core here is annotateDBTable.  This will gather various pieces
of table metadata and fetch column metadata through _annotateColumns.  That,
in turn, constructs a query fetching the metadata from the database.  Since
writing this query is a bit involved, it is done in terms of a sequence
of AnnotationMakers.  These combine SQL making (through OutputFields)
and then pulling the results out of the database result (in their
annotate methods).  The end result is that keys in the columns'
annotations dictionaries are added.
"""
#c Copyright 2008-2023, the GAVO project <gavo@ari.uni-heidelberg.de>
#c
#c This program is free software, covered by the GNU GPL.  See the
#c COPYING file in the source distribution.
import datetime
import math
import time
from gavo import api
from gavo import base
from gavo import stc
from gavo import svcs
from gavo import utils
from gavo.protocols import scs
def getDefaultSamplePercent(tableSize):
	"""returns a hopefully sensible value for samplePercent depending
	on the tableSize.

	Tables of unknown or zero size and tables with up to 1e6 rows are
	scanned completely (100 is returned).  Beyond that, with
	d = log10(tableSize)-6, the result is int(100/(1+d**1.5)), but never
	less than 1.  That gives 50 for 1e7 rows, 26 for 1e8, 16 for 1e9,
	and 11 for 1e10.

	This is based on gut feeling; there might be a theory for it, though.
	"""
	extraDecades = math.log10(tableSize)-6 if tableSize else 0
	if extraDecades<=0:
		return 100

	fraction = 1/(1+extraDecades**1.5)
	return max(int(100*fraction), 1)
class AnnotationMaker(object):
	"""A class for producing column annotations.

	An annotation simply is a dictionary with some well-known keys.  They
	are generated from DB queries.  It is this class' responsibility
	to collect the DB query result columns pertaining to a column and
	produce the annotation dictionary from them.

	To make this happen, it is constructed with the column (which gets
	an empty annotations dictionary unless it already has one).  Then, for
	each property queried, getOutputFieldFor is called; it returns the
	OutputField to put into the statistics query.  Finally, annotate
	is called with the DB result row (see annotateDBTable) to actually
	fill the column's annotations dictionary.
	"""
	def __init__(self, column):
		self.column = column
		# keep annotations that might already be attached to the column.
		try:
			self.column.annotations
		except AttributeError:
			self.column.annotations = {}
		# maps names of result columns to either an annotation key or
		# an extractor callable (see getOutputFieldFor).
		self.propDests = {}

	def doesWork(self):
		"""returns a true value if the annotator will contribute a query.

		What is returned actually is the mapping of registered result columns,
		which is empty (and hence false) until getOutputFieldFor is called.
		"""
		return self.propDests

	def getOutputFieldFor(self, propName, propFunc, nameMaker, extractor=None):
		"""returns an OutputField that will generate a propName annotation
		from the propFunc function.

		propFunc for now has a %(name)s where the column name must be
		inserted.

		nameMaker is something like a base.VOTNameMaker; it is used to
		make the name of the result column.

		extractor can be a callable receiving the annotation dictionary and
		the result of propFunc (in this order); it then must modify the
		annotation dictionary to reflect the result (the default is to just
		add the result under propName).
		"""
		resultName = nameMaker.makeName(propName+"_"+self.column.key)
		self.propDests[resultName] = extractor or propName
		selectExpr = propFunc%{"name": self.column.name}
		return api.makeStruct(svcs.OutputField,
			name=resultName,
			select=selectExpr,
			type=self.column.type)

	def annotate(self, resultRow):
		"""builds an annotation of the column from resultRow.

		resultRow is a dictionary containing values for all result column
		names registered through getOutputFieldFor.  NULL (None) results are
		ignored.

		If the column already has an annotation, only the new keys will be
		overwritten.
		"""
		for resultName, dest in self.propDests.items():
			value = resultRow[resultName]
			if value is None:
				continue

			if isinstance(dest, str):
				self.column.annotations[dest] = value
			else:
				dest(self.column.annotations, value)
def getAnnotators(td):
	"""returns a pair of output fields and annotators to gather
	column statistics for td.

	The annotators are AnnotationMaker instances that will add the
	relevant annotations to td's columns.  The output fields are used as the
	select list of a single query over td's table (see _annotateColumns).

	The rules applying are:

	* columns with a statistics property of "no" are skipped.
	* columns with a type in base.ORDERED_TYPES or a char type get
	  max_value and min_value.
	* columns with a type in base.NUMERIC_TYPES get percentile03,
	  median, and percentile97.
	* columns with a statistics property of "enumerate" get
	  discrete_values (the relative frequencies of their distinct values).
	  Such columns must be of type text; otherwise, a ReportableError
	  is raised.
	* all columns not skipped get fill_factor, the fraction of
	  non-NULL values.
	"""
	outputFields, annotators = [], []
	nameMaker = base.VOTNameMaker()

	for col in td:
		statisticsMode = col.getProperty("statistics", None)
		if statisticsMode=="no":
			continue

		annotator = AnnotationMaker(col)
		if col.type in base.ORDERED_TYPES or col.type.startswith("char"):
			outputFields.append(annotator.getOutputFieldFor("max_value",
				"MAX(%(name)s)", nameMaker))
			outputFields.append(annotator.getOutputFieldFor("min_value",
				"MIN(%(name)s)", nameMaker))

		if col.type in base.NUMERIC_TYPES:
			outputFields.append(annotator.getOutputFieldFor("percentiles",
				"percentile_cont(ARRAY[0.03, 0.5, 0.97]) WITHIN GROUP"
					" (ORDER BY %(name)s)", nameMaker,
				extractor=lambda annotations, res:
					annotations.update(dict(zip(
						["percentile03", "median", "percentile97"], res)))))

		if statisticsMode=="enumerate":
			if col.type!="text":
				# Not an assert, since python -O would skip that.  The query
				# below substitutes the string 'NULL' for NULL values, which
				# only makes sense for text columns.
				raise base.ReportableError(
					f"Column {col.name} has statistics='enumerate', but only"
					" columns of type text can be enumerated (this one is"
					f" {col.type}).")
			outputFields.append(annotator.getOutputFieldFor("discrete_values",
				"(SELECT jsonb_object_agg(COALESCE(val, 'NULL'), ct) FROM ("
					" SELECT %(name)s AS val, count(*) AS ct"
					" FROM {} GROUP BY %(name)s) AS {})".format(
						td.getQName(),
						"subquery_"+col.name),
					nameMaker,
					extractor=_normalizeDist))

		outputFields.append(annotator.getOutputFieldFor("fill_factor",
			"AVG(CASE WHEN %(name)s IS NULL THEN 0 ELSE 1 END)",
			nameMaker))

		if annotator.doesWork():
			annotators.append(annotator)

	return outputFields, annotators
def _normalizeDist(annotations, dist):
	"""stores dist normalized to sum(dist.values())=1 in
	annotations[discrete_values]
	This is probably only useful as the extractor for discrete_values
	output fields.
	"""
	normalizer = sum(dist.values(), 0)
	if not normalizer:
		# no metadata detected, which might be benign.
		return
	annotations["discrete_values"] = dict((key, val/normalizer)
		for key, val in dist.items())
def _estimateStatRuntime(rowsInspected):
	"""returns some text guessing at how long a stat run will take based
	on now many rows will have to be inspected.
	This assumes not too unreasonable hardware of ~2020; adjust secondsScale
	over time
	"""
	secondsScale = 5e5
	if rowsInspected<secondsScale:
		return "seconds"
	
	elif rowsInspected<secondsScale*60:
		return "minutes"
	elif rowsInspected<secondsScale*3600:
		return "hours"
	
	elif rowsInspected<secondsScale*3600*24:
		return "days"
	else:
		return "too much time"
def _noteAndDelay(fullTime):
	"""a delaying iterator also doing the UI of _annotateColumns.

	On each step, this reports the whole seconds elapsed since the first
	step, together with the fullTime estimate string, as progress.  It then
	sleeps for a second before yielding.  The iterator never ends by itself;
	callers have to break out of it.
	"""
	began = time.time()
	while True:
		secondsSoFar = int(time.time()-began)
		api.ui.notifyProgress(str(secondsSoFar)+"s of "+fullTime)
		time.sleep(1)
		yield
def _annotateColumns(conn, td, samplePercent):
	"""Adds statistics annotations on td.

	This usually takes a long time, and thus this will produce progress
	messages to the UI.

	This changes the column objects in td (see AnnotationMaker.annotate).

	conn is the database connection to run the statistics query on.
	samplePercent is the percentage of rows to look at; 100 or None
	mean all rows.

	A KeyboardInterrupt while the query runs aborts the query and is
	turned into a ReportableError.
	"""
	dbtable = api.TableForDef(td, connection=conn)
	if samplePercent==100:
		samplePercent = None

	outputFields, annotators = getAnnotators(td)
	if not outputFields:
		# all columns have statistics="no"; there is nothing to query for,
		# and an empty select list would not make a valid query.
		return

	_, sqlQuery, pars = dbtable.getQuery(
		outputFields, "", samplePercent=samplePercent)

	# td.nrows may be None when the table size could not be estimated;
	# this is just for the UI, so don't fail on it.
	rowsInspected = (td.nrows or 0)*(samplePercent or 100)/100.

	with base.NonBlockingQuery(conn, sqlQuery, pars) as runningQuery:
		try:
			for _ in _noteAndDelay(_estimateStatRuntime(rowsInspected)):
				if runningQuery.result is not None:
					break
		except KeyboardInterrupt:
			runningQuery.abort()
			raise base.ReportableError("Aborted statistics gathering."
				"  You might consider using --sample-percent")

	# the result columns come in the sequence of outputFields.
	resultRow = {field.name: value
		for field, value in zip(outputFields, runningQuery.result[0])}
	for annotator in annotators:
		annotator.annotate(resultRow)
def annotateDBTable(td,
		samplePercent=None,
		acquireColumnMeta=True):
	"""returns the TableDef td with domain annotations for its columns.

	td must be an existing on-Disk table.  This sets td.nrows to an
	estimate of the table size (see estimateTableSize).  If
	acquireColumnMeta is False, that is all that is done.

	Otherwise, the columns of td are annotated in place.  samplePercent
	uses TABLESAMPLE SYSTEM to only look at about this percentage
	of the rows (which doesn't work for views).  If it is None, a default
	depending on the table size is taken from getDefaultSamplePercent.

	The annotations come in a dictionary-valued attribute annotations on
	the column object.  The keys in there correspond to column names
	from //dc_tables.  Which columns get which annotations is decided by
	getAnnotators: columns with a statistics property of "no" are skipped,
	and the statistics gathered for the others depend on their types.

	A ReportableError is raised if the table cannot be queried.
	"""
	api.ui.notifyProcessStarts(f"Getting stats for {td.getQName()}")
	try:
		api.ui.notifyProgress("Estimate nrows")
		with base.getTableConn() as conn:
			try:
				td.nrows = estimateTableSize(td.getQName(), conn)
			except base.DBError:
				raise base.ui.logOldExc(
					api.ReportableError(f"Table {td.getQName()} cannot be queried.",
						hint="This could be because it is an in-memory table.  Add"
						" onDisk='True' to make tables reside in the database in that"
						" case.  Or run dachs imp to import it if it just hasn't been"
						" created."))

			if samplePercent is None:
				samplePercent = getDefaultSamplePercent(td.nrows)

			if acquireColumnMeta:
				# conn is managed by the getTableConn context manager and may be
				# released (e.g., returned to a pool) once the with block is left,
				# so the statistics must be gathered in here.
				_annotateColumns(conn, td, samplePercent)
	finally:
		api.ui.notifyProcessEnded()

	return td
def getSCSCoverageQuery(td, order):
	"""returns a database query for getting a MOC for a table suitable
	for cone search.

	The query collects the distinct healpix_nest indices at order for all
	rows with non-NULL positions in the cone search columns (as returned
	by scs.getConeColumns) and builds an smoc from them.

	This will return None if no such query can be built, i.e., when
	scs.getConeColumns raises a StructureError or NotFoundError.
	"""
	try:
		raCol, decCol = scs.getConeColumns(td)
	except (base.StructureError, base.NotFoundError):
		return None

	raName, decName = str(raCol.name), str(decCol.name)
	# After the %-formatting here, the SQL contains format('%%s', hpx);
	# presumably the database layer applies one more round of
	# %-interpolation, leaving format('%s', hpx) -- TODO confirm.
	return "\n".join([
		"SELECT smoc('%d/' || string_agg(format('%%%%s', hpx), ','))"%order,
		"FROM (",
		"  SELECT DISTINCT healpix_nest(%d,"
			" spoint(RADIANS(%s), RADIANS(%s))) AS hpx "%(
				order, raName, decName),
		"FROM %s"%td.getQName(),
		"WHERE %s IS NOT NULL AND %s IS NOT NULL"%(raName, decName),
		"GROUP BY hpx",
		") as q"])
def getSSAPCoverageQuery(td, order):
	"""returns a database query for getting a MOC for a table using
	one of our standard SSAP mixins (//ssap#hcd, //ssap#mixc, //ssap#view).

	The MOC is the sum of SCIRCLEs around ssa_location.  The second SCIRCLE
	argument is RADIANS(ssa_aperture), with 1/3600 used where ssa_aperture
	is NULL.  Rows with a NULL ssa_location are ignored.

	This will return None if no such query can be built.
	"""
	handledMixins = ("//ssap#hcd", "//ssap#mixc", "//ssap#view")
	if not any(td.mixesIn(mixin) for mixin in handledMixins):
		return None

	return "\n".join([
		"SELECT SUM(SMOC(%d,"%order,
		"  SCIRCLE(ssa_location, RADIANS(COALESCE(ssa_aperture, 1/3600.)))))",
		"FROM %s WHERE ssa_location IS NOT NULL"%td.getQName()])
def getSIAPCoverageQuery(td, order):
	"""returns a database query for getting a MOC for a table using
	//siap#pgs (i.e., SIAv1).

	The MOC is the sum of the SMOCs of all non-NULL coverage values.

	This will return None if no such query can be built.

	For SIAv2, no such thing is available yet; the obscore querier
	should work.  However, we don't really have standalone SIAv2
	resources in DaCHS yet.
	"""
	if td.mixesIn("//siap#pgs"):
		return "\n".join([
			"SELECT SUM(SMOC(%d, coverage))"%order,
			"FROM %s WHERE coverage IS NOT NULL"%td.getQName()])

	return None
def getObscoreCoverageQuery(td, order):
	"""returns a database query for getting a MOC for tables with obscore
	columns.

	The footprint of each row is taken from s_region if that is an spoly or
	smoc column and is non-NULL.  Otherwise, if there are s_ra, s_dec, and
	s_fov columns, a disc is built from them.  The MOC is the sum of these
	footprints.

	For s_region columns of other types, a warning is issued and s_region
	is ignored.

	This will return None if no such query can be built.
	"""
	geoSources = []
	if "s_region" in td:
		sRegionType = td.getColumnByName("s_region").type
		if sRegionType=="spoly":
			geoSources.append("SMOC(%d, s_region),"%order)
		elif sRegionType=="smoc":
			geoSources.append("smoc_degrade(%d, s_region),"%order)
		else:
			base.ui.notifyWarning("Table has unsupported s_region type %s"
				" when determining coverage."%sRegionType)

	if "s_ra" in td and "s_dec" in td and "s_fov" in td:
		# s_fov is the diameter of a circle containing the footprint (as in
		# ObsCore), whereas smoc_disc expects a radius; hence the halving.
		# The order is formatted with %d like in all other MOC queries.
		geoSources.append(
			"smoc_disc(%d, RADIANS(s_ra), RADIANS(s_dec), RADIANS(s_fov/2.)),"
			%order)

	if not geoSources:
		return None

	# every geoSource ends with a comma, so NULL terminates the COALESCE
	# argument list.
	fragments = [
		"SELECT SUM(coverage)",
		"FROM (SELECT",
		"  COALESCE(",
		"    "+(" ".join(geoSources)),
		"    NULL) AS coverage",
		"  FROM %s"%td.getQName(),
		"  ) AS q"]
	return "\n".join(fragments)
def getMOCQuery(td, order):
	"""returns a MOC-generating query for a tableDef with standard
	columns.

	The query makers for SIAP (v1), SSAP, SCS, and obscore-like tables
	are tried in this order, and the first query built wins.  If none
	of them can build a query, a ReportableError is raised.

	(this is a helper for getMOCForStdTable)
	"""
	queryMakers = (
		getSIAPCoverageQuery,
		getSSAPCoverageQuery,
		getSCSCoverageQuery,
		getObscoreCoverageQuery)

	for makeQuery in queryMakers:
		candidate = makeQuery(td, order)
		if candidate is not None:
			return candidate

	raise base.ReportableError("Table %s does not have columns DaCHS knows"
		" how to get a coverage from."%td.getFullId())
def getMOCForStdTable(td, order=6):
	"""returns a MOC for a tableDef with one of the standard protocol mixins
	or with standard columns.

	The query is made by getMOCQuery, which knows about SIAP (v1), SSAP,
	SCS, and obscore-like tables and tests for them in this order.

	None is returned if the database's getTableType reports nothing for
	td (presumably because the table does not exist).  Otherwise, the
	value in the first column of the first result row is returned.
	"""
	with base.getTableConn() as conn:
		tableType = base.UnmanagedQuerier(conn).getTableType(td.getQName())
		if not tableType:
			return None

		resultRows = list(conn.query(getMOCQuery(td, order)))
		moc = resultRows[0][0]

	return moc
def _getTimeTransformer(col):
	"""returns a function turning values in col to MJDs.

	This is very much a matter of heuristics; we build upon what's happening
	in utils.serializers.  The rules, in order of precedence, are:

	* timestamp and date columns are converted with stc.dateTimeToMJD
	* columns that stc.isMJD accepts are passed through unchanged
	* columns with unit yr or a are taken to be julian years
	* columns with unit d have stc.JD_MJD subtracted (i.e., they are
	  presumably taken to be JDs)
	* columns with unit s are taken to be unix timestamps (seconds
	  since 1970-01-01T00:00:00 UTC)

	For all other columns, NotImplementedError is raised.
	"""
	if col.type in ["timestamp", "date"]:
		return lambda val: stc.dateTimeToMJD(val)

	elif stc.isMJD(col):
		return utils.identity

	elif col.unit=="yr" or col.unit=="a":
		return lambda val: stc.dateTimeToMJD(stc.jYearToDateTime(val))

	elif col.unit=="d":
		return lambda val: val-stc.JD_MJD

	elif col.unit=="s":
		# The datetime module itself has no utcfromtimestamp (that is a
		# method of datetime.datetime, and it is deprecated since python 3.12).
		# This produces the same naive UTC datetime it would have returned.
		return lambda val: stc.dateTimeToMJD(
			datetime.datetime.fromtimestamp(val, datetime.timezone.utc
				).replace(tzinfo=None))

	else:
		raise NotImplementedError("Cannot figure out how to get an MJD"
			" from column %s"%col.name)
def _min(s):
	return "MIN(%s)"%s
def _max(s):
	return "MAX(%s)"%s
def getTimeLimitsExprs(td):
	"""returns SQL expressions for the minimal and maximal time coverage
	of the rows of the table defined by td, together with a transformer.

	What is returned is a triple (minExpr, maxExpr, transformer), as
	required by iterScalarLimits.  minExpr and maxExpr are aggregate
	expressions (MIN(...) and MAX(...)).  transformer is a function that
	(hopefully) turns the values they produce into MJDs.

	This tries a couple of known criteria for columns containing times
	in some order, and the first one matching wins:

	* obscore-like t_min and t_max columns
	* a column with the SSAP time location utype
	* a dateObs column (as in our SIAP mixins)
	* unique columns with the UCDs time.start and time.end
	* a scalar column with the UCD time.epoch or time.epoch;obs

	This will raise a NotFoundError if none of our heuristics work.
	"""
	# obscore and friends
	try:
		return (_min(td.getColumnByName("t_min").name),
			_max(td.getColumnByName("t_max").name),
			_getTimeTransformer(td.getColumnByName("t_min")))
	except base.NotFoundError:
		pass

	# SSAP
	try:
		col = td.columns.getColumnByUtype(
			"ssa:Char.TimeAxis.Coverage.Location.Value"
			)
		return _min(col.name), _max(col.name), utils.identity
	except base.NotFoundError:
		pass

	# our SIAP mixins
	try:
		col = td.getColumnByName("dateObs"
			)
		return _min(col.name), _max(col.name), utils.identity
	except base.NotFoundError:
		pass

	# Any table with appropriate, sufficiently unique UCDs
	try:
		col = td.getColumnByUCD("time.start")
		# we assume time.start and time.end work the same way and one
		# transformer is enough.  Like all other return values, this must be
		# a flat triple, as iterScalarLimits unpacks it into three names.
		return (_min(col.name),
			_max(td.getColumnByUCD("time.end").name),
			_getTimeTransformer(col))
	except base.StructureError:
		pass

	for obsUCD in ["time.epoch", "time.epoch;obs"]:
		try:
			for col in td.getColumnsByUCD(obsUCD):
				if not col.isScalar():
					raise ValueError("Cannot determine scalar coverage from arrays")
				return _min(col.name), _max(col.name), _getTimeTransformer(col)
		except (ValueError, base.StructureError):
			pass

	raise base.NotFoundError("temporal coverage", "Columns to figure out",
		"table "+td.getFullId())
def getSpectralLimitsExprs(td):
	"""returns SQL expressions for the limits of the spectral coverage
	of the table defined by td, together with a transformer.

	What is returned is a triple (MAX(upper limit column), MIN(lower limit
	column), getSpectralTransformer("m")).  It is built from the first of
	these column pairs present in td: obscore-like em_max/em_min, SSAP's
	ssa_specend/ssa_specstart, and SIAv1's bandpassHi/bandpassLo.

	Note that the maximum comes first.  This is presumably because the
	transformer turns wavelengths into something inversely proportional,
	such as energies, so the largest wavelength yields the lower limit.
	TODO: confirm against getSpectralTransformer.

	If none of these column pairs is found, this raises a NotFoundError.
	"""
	# (upper limit column, lower limit column), in order of precedence:
	# obscore and friends, SSAP, SIAv1.
	limitColumnPairs = [
		("em_max", "em_min"),
		("ssa_specend", "ssa_specstart"),
		("bandpassHi", "bandpassLo"),
	]

	for upperName, lowerName in limitColumnPairs:
		try:
			# NOTE(review): getSpectralTransformer is neither defined nor
			# imported in the part of this module seen in review; confirm
			# it is available at runtime.
			return (_max(td.getColumnByName(upperName).name),
				_min(td.getColumnByName(lowerName).name),
				getSpectralTransformer("m"))
		except base.NotFoundError:
			pass

	raise base.NotFoundError("spectral coverage", "Columns to figure out",
		"table "+td.getFullId())
def iterScalarLimits(td, columnsDeterminer):
	"""yields [lower, upper] intervals for time or spectral coverage.

	columnsDeterminer is a function td -> (minExpr, maxExpr, transformer)
	expected to raise a NotFoundError if no appropriate columns can be
	found.  This is either getTimeLimitsExprs or getSpectralLimitsExprs at
	this point.  minExpr and maxExpr are SQL expressions selected from
	td's table.  transformer is a function val -> val turning what's coming
	back from the database into what's expected by the coverage machinery;
	it is applied to both limits.

	Nothing is yielded in these cases:

	* columnsDeterminer finds no columns
	* the database's getTableType reports nothing for the table
	* the result rows have a NULL limit

	A ZeroDivisionError from the transformer is turned into a
	ReportableError.

	It's conceivable that at some point we'll give multiple intervals,
	and hence this is an iterator.
	"""
	try:
		minExpr, maxExpr, transformer = columnsDeterminer(td)
	except base.NotFoundError:
		return

	qName = td.getQName()
	limitsQuery = "SELECT %s, %s FROM %s"%(str(minExpr), str(maxExpr), qName)

	with base.getTableConn() as conn:
		if not base.UnmanagedQuerier(conn).getTableType(qName):
			return

		for res in conn.query(limitsQuery):
			try:
				if res[0] is None or res[1] is None:
					continue
				yield [transformer(res[0]), transformer(res[1])]
			except ZeroDivisionError:
				raise base.ReportableError(f"Invalid limits in {td}: {res}")
def _format_val(val):
	"""returns a string representation of val for inclusion in our
	info table.
	"""
	return utils.makeEllipsis(utils.safe_str(val), 30)
def _format_percent(val):
	"""formats the float val as a percentage.
	"""
	return "%.0f%%"%(val*100)
_PROP_SEQ = (
	("min_value", _format_val),
	("median", _format_val),
	("max_value", _format_val),
	("fill_factor", _format_percent))
def printTableInfo(td, samplePercent=None):
	"""tries to obtain various information on the properties of the
	database table described by td and prints it as a table to stdout.

	For each column, this shows the properties listed in _PROP_SEQ
	(min_value, median, max_value, fill_factor); "-" marks values
	that are not available.  samplePercent is passed on to annotateDBTable.
	"""
	annotateDBTable(td, samplePercent=samplePercent)

	propTable = [("col",)+tuple(propName for propName, _ in _PROP_SEQ)]
	for col in td:
		# columns with statistics="no" are skipped by getAnnotators and
		# hence may not have an annotations attribute at all.
		annotations = getattr(col, "annotations", {})
		propTable.append((col.name,)+tuple(
			formatter(annotations[propName]) if propName in annotations else "-"
			for propName, formatter in _PROP_SEQ))

	print(utils.formatSimpleTable(propTable))
def estimateTableSize(tableName, connection):
	"""returns an estimate for the size of the table tableName.

	If postgres' row estimate is below 1e6, this runs a COUNT(*) and
	hence is precise.  For larger tables, postgres' estimate is passed
	through utils.roundO2M.

	The function will return None for non-existing tables (i.e., when
	getting the row estimate raises a KeyError).
	"""
	try:
		estsize = base.UnmanagedQuerier(connection).getRowEstimate(tableName)
	except KeyError:
		return None

	if estsize<1e6:
		# for small tables, we can be precise; but this will also kick in
		# for views (which have estsize=0), where this may run forever.
		# I don't think this can be helped.
		countRows = list(connection.query("SELECT COUNT(*) FROM %s"%tableName))
		return countRows[0][0] if countRows else None

	return utils.roundO2M(estsize)
def parseCmdline():
	"""returns the parsed command line arguments for dachs info.

	The namespace returned has the attributes tableId, mocOrder (None
	unless -m is given), and samplePercent (None unless -s is given).
	"""
	from argparse import ArgumentParser

	argParser = ArgumentParser(
		description="Displays various stats about the table referred to in"
			" the argument.")
	argParser.add_argument("tableId", help="Table id (of the form rdId#tableId)")
	argParser.add_argument("-m", "--moc-order",
		type=int, default=None, dest="mocOrder", metavar="MOC_ORDER",
		help="Also print a MOC giving the coverage at MOC_ORDER (use MOC_ORDER=6"
			" for about 1 deg resolution).")
	argParser.add_argument("-s", "--sample-percent",
		type=float, default=None, dest="samplePercent", metavar="P",
		help="Only look at P percent of the table to determine min/max/mean.")
	return argParser.parse_args()
def main():
	"""the entry point of dachs info.

	This prints column statistics for the table given on the command line
	and, if a MOC order is given, the table's coverage MOC at that order.
	"""
	args = parseCmdline()
	td = api.getReferencedElement(args.tableId, api.TableDef)
	printTableInfo(td, args.samplePercent)

	# 0 is a valid MOC order, so we must not just test mocOrder for truth.
	if args.mocOrder is not None:
		print("Computing MOC at order %s..."%args.mocOrder)
		print(getMOCForStdTable(td, args.mocOrder))