# Source code for gavo.helpers.testtricks

"""
Helper functions and classes for unit tests and similar.

Whatever is useful to unit tests from here should be imported into
testhelpers, too.  Unit test modules should not be forced to import
this.
"""

#c Copyright 2008-2023, the GAVO project <gavo@ari.uni-heidelberg.de>
#c
#c This program is free software, covered by the GNU GPL.  See the
#c COPYING file in the source distribution.


import contextlib
import gzip
import os
import re
import tempfile

from lxml import etree

from gavo import base
from gavo import utils
from gavo.formal.testing import ( #noflake: exported names
	FakeRequest, assertHasStrings)
from gavo.utils import stanxml


def _nukeNamespaces(xmlString):
	"""returns xmlString (bytes) with default namespace declarations and
	element prefixes removed.

	Two things happen: every ``xmlns="..."`` is deleted, and within each
	tag a leading prefix made up of lowercase letters and digits (as
	in ``<ns:tag`` or ``</ns:tag``) is dropped.  Prefixed declarations
	(``xmlns:foo="..."``) and prefixes on attributes are left alone.

	This is a convenience for writing compact tests; it works on the
	serialised text and is easily fooled, so do not use it on anything
	you care about.

	The result is always bytes.
	"""
	withoutDefaultNS = re.sub(b'xmlns="[^"]*"', b"", xmlString)
	prefixPat = re.compile(rb'^(</?)(?:[a-z0-9]+:)')

	def stripPrefix(tagMatch):
		# tagMatch covers one complete <...> construct
		return prefixPat.sub(rb"\1", tagMatch.group())

	return re.sub(b"(?s)<[^>]*>", stripPrefix, withoutDefaultNS)


class _WrappedEtree:
	"""a thin proxy around an lxml element adding some test conveniences.

	We wrap rather than patch because lxml's _Element cannot be
	monkeypatched.  Attribute access and indexing that the wrapper does
	not handle itself are passed on to the wrapped tree.

	This is an implementation detail of getXMLTree; see there for what
	the extra methods are for.
	"""
	def __init__(self, tree):
		self._tree = tree

	def __getattr__(self, name):
		# only consulted when the wrapper itself has no such attribute
		return getattr(self._tree, name)

	def __getitem__(self, index):
		return self._tree[index]

	def uniqueXpath(self, path):
		"""returns the single match for path; anything but exactly one
		match is an AssertionError.
		"""
		matches = self.xpath(path)
		matchCount = len(matches)
		assert matchCount==1, "Xpath %s gave %d matches"%(path, matchCount)
		return matches[0]

	def getById(self, id):
		"""returns the unique element having an ``id`` attribute of id.
		"""
		query = "//*[@id='%s']"%id
		return self.uniqueXpath(query)

	def getByID(self, id):
		"""returns the unique element having an ``ID`` attribute of id.
		"""
		query = "//*[@ID='%s']"%id
		return self.uniqueXpath(query)

	def asString(self):
		"""returns the wrapped tree serialised by etree.tostring, decoded
		as utf-8.
		"""
		serialised = etree.tostring(self._tree)
		return serialised.decode("utf-8")


def getXMLTree(xmlString, debug=False):
	"""returns a wrapped lxml element parsed from ``xmlString`` after
	namespaces on elements have been nuked.

	xmlString is passed through _nukeNamespaces before parsing; that
	function works with bytes patterns, so pass in bytes.

	Nuking namespaces is of course not a good idea in general, so think
	twice before using this in production code.

	The result forwards to the lxml element (so, for instance, ``xpath``
	works as usual) and, to make writing tests easier, additionally has:

	* uniqueXpath(xpath), returning the single match and raising an
	  AssertionError if there is not exactly one.
	* getById(id), returning the unique element with that id, raising
	  an AssertionError otherwise.
	* getByID(id), as getById, but for VOTable-style ID.
	* asString(), returning a string serialisation of the tree.

	With debug=True, the parsed tree is dumped through etree.dump.
	"""
	strippedSource = _nukeNamespaces(xmlString)
	root = etree.fromstring(strippedSource)
	if debug:
		etree.dump(root)
	return _WrappedEtree(root)
class XSDResolver(etree.Resolver):
	"""A resolver for external entities only returning in-tree files.

	Schema files are looked up by their base name in the ``schemata``
	directory, located through base.getPathForDistFile.
	"""
	def __init__(self):
		self.basePath = "schemata"

	def getPathForName(self, name):
		"""returns the path base.getPathForDistFile gives for the schema
		file called name.

		Only the part of name after the last slash is used, so URLs and
		paths work as well as bare file names.
		"""
		xsdName = name.split("/")[-1]
		return base.getPathForDistFile(
			os.path.join(self.basePath, xsdName))

	def resolve(self, url, pubid, context):
		"""returns what resolve_filename gives for the local file
		corresponding to url.

		url can be a namespace URI known to stanxml.NSRegistry, in which
		case it is first translated to the schema registered for it.

		If no local file can be delivered, an error is sent to base.ui and
		None is returned implicitly.
		"""
		try:
			# resolve namespace URIs, too
			try:
				url = stanxml.NSRegistry.getSchemaForNS(url)
			except base.NotFoundError:
				# it's not a (known) namespace URI, try on
				pass

			path = self.getPathForName(url)
			res = self.resolve_filename(path, context)
			if res is not None:
				return res
		except Exception:
			# Best effort: any ordinary failure falls through to the error
			# message.  This is deliberately not a bare except, so
			# KeyboardInterrupt and SystemExit still propagate.
			pass
		base.ui.notifyError("Did not find local file for schema %s --"
			" this will fall back to network resources and thus probably"
			" be slow"%url)
# A parser that looks for schema files in the source tree first.
# RESOLVER is kept as a separate module-level name because it is
# also used on its own to map schema names to paths.
XSD_PARSER = etree.XMLParser()
RESOLVER = XSDResolver()
XSD_PARSER.resolvers.add(RESOLVER)
@contextlib.contextmanager
def MyParser():
	"""a context manager making XSD_PARSER lxml's default parser for the
	controlled block.

	If XSD_PARSER already is the default parser (for instance, because
	we are running within another MyParser block), nothing is changed
	and nothing is reset on exit.  Otherwise, lxml's default parser
	is reset through ``etree.set_default_parser()`` when the block is left.
	"""
	# get_default_parser must be *called* here; comparing the function
	# object itself to XSD_PARSER is always false, which would make an
	# inner MyParser block reset the default parser under an outer one.
	if etree.get_default_parser() is XSD_PARSER:
		yield
		return

	etree.set_default_parser(XSD_PARSER)
	try:
		yield
	finally:
		etree.set_default_parser()
class QNamer:
	"""A hack producing etree.QName-s by attribute access.

	Construct with the namespace URI the names should live in.  Underscores
	at either end of the attribute name are stripped before the QName
	is built, so names clashing with python keywords can be written like
	``import_``.
	"""
	def __init__(self, ns):
		self.ns = ns

	def __getattr__(self, name):
		localName = name.strip("_")
		return etree.QName(self.ns, localName)
# QName factory for the XML Schema namespace (XS.schema, XS.import_, ...)
XS = QNamer("http://www.w3.org/2001/XMLSchema")

# Names of the schema files that getDefaultValidator hands to
# getJointValidator; they are turned into paths via RESOLVER.getPathForName.
VO_SCHEMATA = [
	"Characterisation.xsd",
	"Colstats.xsd", # remove once it's in VODataService
	"ConeSearch.xsd",
	"DaCHS.xsd",
	"DataModel.xsd",
	"DocRegExt.xsd",
	"eudat-core.xsd",
	"oai_dc.xsd",
	"OAI-PMH.xsd",
	"RegistryInterface.xsd",
	"SIA.xsd",
	"SLAP.xsd",
	"SSA.xsd",
	"StandardsRegExt.xsd",
	"stc.xsd",
	"stc-v1.20.xsd",
	"coords-v1.20.xsd",
	"region-v1.20.xsd",
	"TAPRegExt.xsd",
	"UWS.xsd",
	"VODataService.xsd",
	"VOEvent.xsd",
	"VOEventRegExt.xsd",
	"VORegistry.xsd",
	"VOResource.xsd",
	"VOSIAvailability.xsd",
	"VOSICapabilities.xsd",
	"VOSITables.xsd",
	"VOTable-1.1.xsd",
	"VOTable-1.2.xsd",
	"VOTable.xsd",
	"mivot.xsd",
	"vo-dml.xsd",
	"xlink.xsd",
	"XMLSchema.xsd",
	"xml.xsd",]
def getJointValidator(schemaPaths):
	"""returns an lxml XMLSchema importing all the schemas in schemaPaths.

	Each item of schemaPaths is resolved to a file through
	RESOLVER.getPathForName (which only looks at the part after the last
	slash).  The file is parsed to learn its targetNamespace, and
	a synthetic schema with target namespace ``urn:combiner`` is built that
	imports every one of them.

	All of this runs with XSD_PARSER as default parser (see MyParser).
	"""
	with MyParser():
		imports = []
		for schemaName in schemaPaths:
			localPath = RESOLVER.getPathForName(schemaName)
			schemaRoot = etree.parse(localPath).getroot()
			imports.append((
				"http://vo.ari.uni-heidelberg.de/docs/schemata/"+schemaName,
				schemaRoot.get("targetNamespace")))

		combiner = etree.Element(
			XS.schema, attrib={"targetNamespace": "urn:combiner"})
		for location, targetNS in imports:
			importAttrs = {
				"namespace": targetNS,
				"schemaLocation": location}
			etree.SubElement(combiner, XS.import_, attrib=importAttrs)

		return etree.XMLSchema(etree.ElementTree(combiner))
def getDefaultValidator(extraSchemata=None):
	"""returns a validator that knows the schemata typically useful within
	the VO.

	extraSchemata, if given, is an iterable of further schema file names
	to include after the ones in VO_SCHEMATA.

	This will currently only work if DaCHS is installed from an SVN
	checkout with setup.py develop.

	What's returned has a method assertValid(et) that raises an exception
	if the elementtree et is not valid.  You can simply call it to
	get back True for valid and False for invalid.
	"""
	# None rather than a shared mutable [] as default; list() makes
	# tuples and other iterables work with the concatenation, too.
	extras = [] if extraSchemata is None else list(extraSchemata)
	return getJointValidator(VO_SCHEMATA+extras)
def _makeLXMLValidator():
	"""returns an lxml-based schema validating function for the VO XSDs

	This is not happening at import time as it is time-consuming, and the
	DaCHS server probably doesn't even validate anything.

	This is used below to build getXSDErrorsLXML.
	"""
	VALIDATOR = getDefaultValidator()

	def getErrors(data, leaveOffending=False):
		"""returns error messages for the XSD validation of the string in
		data.

		data can be serialised XML or something already parsed (anything
		having an xpath attribute is taken to be parsed).  For valid
		documents, None is returned, otherwise the validator's error log
		as a string.  Exceptions during parsing or validation are
		returned as their string representation rather than raised.

		With leaveOffending, invalid documents are written to
		badDocument.xml in the current directory.
		"""
		try:
			with MyParser():
				if hasattr(data, "xpath"):
					# we believe it's already parsed stuff
					tree = data
				else:
					tree = etree.fromstring(data)

				if VALIDATOR.validate(tree):
					return None

				if leaveOffending:
					if hasattr(data, "xpath"):
						data = etree.tostring(data, encoding="utf-8")
					elif isinstance(data, str):
						# the dump file is binary; without encoding, the
						# TypeError from write would end up being returned
						# in place of the validation errors.
						data = data.encode("utf-8")
					with open("badDocument.xml", "wb") as of:
						of.write(data)
				return str(VALIDATOR.error_log)
		except Exception as msg:
			return str(msg)

	return getErrors
def getXSDErrorsLXML(data, leaveOffending=False):
	"""returns error messages for the XSD validation of the string in data.

	The validating function is built by _makeLXMLValidator on the first
	call and then kept in the ``validate`` attribute of this function;
	data and leaveOffending are passed on to it unchanged.
	"""
	try:
		validate = getXSDErrorsLXML.validate
	except AttributeError:
		# first call: build the validator and remember it
		validate = getXSDErrorsLXML.validate = _makeLXMLValidator()
	return validate(data, leaveOffending)


# the name the rest of the code uses
getXSDErrors = getXSDErrorsLXML
class XSDTestMixin:
	"""a mixin contributing XML-related assertions to test cases.

	assertValidates runs XSD validation and raises an AssertionError
	carrying the validator's messages if there are any.  Pass
	leaveOffending=True to have the offending document stored in
	badDocument.xml.

	assertWellformed only checks that the document can be parsed.
	"""
	def assertValidates(self, xmlSource, leaveOffending=False):
		errors = getXSDErrors(xmlSource, leaveOffending)
		if not errors:
			return
		raise AssertionError(errors)

	def assertWellformed(self, xmlSource):
		try:
			etree.fromstring(xmlSource)
		except Exception as msg:
			raise AssertionError("XML not well-formed (%s)"%msg)
def getMemDiffer(ofClass=base.Structure):
	"""returns a function to call that returns a list of new DaCHS structures
	since this was called.

	If you watch everything, things get hairy because of course the state
	of this function (for instance) also creates references.  Hence, pass
	ofClass to choose what the function will track.

	This will call a gc.collect itself (and wouldn't make sense without that)

	The baseline only consists of the id()-s of the instances of ofClass
	that gc.get_objects() returns at call time, so it holds no references
	to the objects themselves.  The returned function takes no arguments,
	collects again and returns a list of the instances of ofClass that
	gc.get_objects() then has and whose ids are not in the baseline.
	"""
	import gc
	gc.collect()

	# NOTE(review): since only ids are kept, an object allocated at the
	# address of a freed baseline object would presumably be missed.
	seen_ids = set()
	for ob in gc.get_objects():
		try:
			if isinstance(ob, ofClass):
				seen_ids.add(id(ob))
		except ReferenceError:
			# object is already essentially gone, don't worry about it.
			pass
	# drop our own reference to the last object looked at
	# NOTE(review): placement (after rather than inside the loop) is
	# reconstructed from a flattened listing -- confirm.
	del ob

	def getNewObjects():
		gc.collect()
		newObjects = []
		for ob in gc.get_objects():
			try:
				if id(ob) not in seen_ids and isinstance(ob, ofClass):
					newObjects.append(ob)
			except ReferenceError:
				# again, don't worry about disappearing objects
				pass
		return newObjects

	return getNewObjects
def getUnreferenced(items):
	"""returns a list of those elements of items that none of the other
	elements of items refer to.

	"Refer to" is what gc.get_referrers reports; referrers that are not
	themselves members of items are ignored.  The order of items is
	preserved.
	"""
	import gc

	memberIds = {id(member) for member in items}

	def hasReferrerWithin(candidate):
		return any(
			id(referrer) in memberIds
			for referrer in gc.get_referrers(candidate))

	return [candidate for candidate in items
		if not hasReferrerWithin(candidate)]
def debugReferenceChain(ob):
	"""a sort-of-interactive way to investigate where ob is referenced.

	This prints the current object and then, one at a time, the objects
	gc.get_referrers reports for it.  After each referrer, a command is
	read from the terminal:

	* d -- enter pdb (look at ob, perhaps at nob)
	* u -- follow, i.e., make the referrer just shown the current object
	* x -- continue execution (i.e., return)
	* ? -- print a short list of commands

	Anything else (including just hitting return) moves on to the next
	referrer.  When the referrers run out, the loop starts over with
	the current object.  The function ends by itself when the current
	object has no referrers.
	"""
	import gc
	while True:
		print("Current object: ", repr(ob))
		refs = gc.get_referrers(ob)
		if not refs:
			print("Not referenced -- exiting")
			break

		while refs:
			# nob ("next ob") is deliberately a plain local so it can be
			# inspected from the pdb session opened by d.
			nob = refs.pop()
			# shows how many referrers are still left and the candidate
			print(len(refs), utils.makeEllipsis(repr(nob)))

			res = input("?")

			if res=="d":
				import pdb;pdb.Pdb(nosigint=True).set_trace()
			elif res=="x":
				return
			elif res=="u":
				# leave the inner loop; the outer one then restarts with nob
				ob = nob
				break
			elif res=="?":
				print("d, x, u, <empty>")
			elif not refs:
				print("Referrers exhausted, warping")
# NOTE(review): not referenced anywhere in this part of the file;
# presumably scratch space for memdebug sessions -- confirm before removing.
NEWIDS = set()


def memdebug(watchClass=base.Structure):
	"""a debug method to track memory usage after some code has run.

	This is typically run from ArchiveService.getChild, since request
	processing should be idempotent wrt memory after initial caching.

	This is for editing in place by DaCHS plumbers; accordingly, you're
	not supposed to make sense of this.

	What it does as it stands: It prints the number of objects
	gc.get_objects() returns.  If an earlier call has left a differ in
	base.getNewStructs, it reports on the objects that differ returns
	and runs debugReferenceChain on the first of them that is an instance
	of watchClass.  Finally, a fresh differ for watchClass is stored in
	base.getNewStructs.
	"""
	# NOTE(review): the nesting below was reconstructed from a listing
	# that had lost its line breaks; confirm against the canonical source.
	import gc
	print(">>>>>> total managed:", len(gc.get_objects()))
	if hasattr(base, "getNewStructs"):
		ns = base.getNewStructs()
		print(">>>>>> new objects:", len(ns))
		# getUnreferenced calls gc.get_referrers per item, so only do this
		# for moderately-sized lists
		if len(ns)<11000:
			ur = getUnreferenced(ns)
			print(">>>>>> new externally referenced:", len(ur))
			del ur
		print([ob for ob in ns if isinstance(ob, watchClass)])

		# edit-in-place switch for the interactive part
		if True:
			try:
				debugReferenceChain(
					[ob for ob in ns if isinstance(ob, watchClass)][0])
			except IndexError:
				# nothing of watchClass is new
				pass

	# (re-)establish the baseline for the next call
	base.getNewStructs = getMemDiffer(ofClass=watchClass)
@contextlib.contextmanager
def testFile(name, content, writeGz=False, inDir=base.getConfig("tempDir"),
		timestamp=None):
	"""a context manager that creates a file name with content in inDir.

	The full path name is returned.

	content can be bytes or str; in the latter case, it's utf-8 encoded
	before writing.

	With writeGz=True, content is gzipped on the fly (don't do this if
	the data already is gzipped).

	You can pass in name=None to get a temporary file name if you don't care
	about the name.

	If timestamp is not None, the file's access and modification times
	are set to it (as with os.utime).

	inDir will be created as a side effect if it doesn't exist but (right
	now, at least), not be removed.  The file itself is removed when the
	block is left, and also when writing it fails.
	"""
	# exist_ok avoids the race between checking for and creating inDir
	os.makedirs(inDir, exist_ok=True)

	if name is None:
		handle, destName = tempfile.mkstemp(dir=inDir)
		os.close(handle)
	else:
		destName = os.path.join(inDir, name)

	# everything from here on is in the try block so that a failed write
	# or utime does not leave the file behind.
	try:
		opener = gzip.GzipFile if writeGz else open
		# the with block closes the file even when write raises
		with opener(destName, mode="wb") as f:
			f.write(utils.bytify(content))

		# compare against None so a timestamp of 0 is honoured
		if timestamp is not None:
			os.utime(destName, times=(timestamp, timestamp))

		yield destName
	finally:
		try:
			os.unlink(destName)
		except OSError:
			# someone else may already have removed it; that's fine
			pass
@contextlib.contextmanager
def collectedEvents(*kinds):
	"""a context manager recording events sent through base.ui.

	Pass the names of the events to listen to.  What's yielded is
	a list that, as events come in, is filled with tuples consisting of
	the event name followed by the event's arguments.

	The subscriptions are removed again when the block is left.
	"""
	events = []

	def recorderFor(eventName):
		# a factory, so each recorder is bound to its own event name
		def record(*args):
			events.append((eventName,)+args)
		return record

	subscriptions = []
	for eventName in kinds:
		subscriptions.append((eventName, recorderFor(eventName)))

	for eventName, recorder in subscriptions:
		base.ui.subscribe(eventName, recorder)

	try:
		yield events
	finally:
		for eventName, recorder in subscriptions:
			base.ui.unsubscribe(eventName, recorder)