# Source code for gavo.formal.testing

"""
Helpers for testing code using gavo.formal
"""

#c Copyright 2008-2023, the GAVO project <gavo@ari.uni-heidelberg.de>
#c
#c This program is free software, covered by the GNU GPL.  See the
#c COPYING file in the source distribution.


import io
import math

from twisted.internet  import defer
from twisted.python import failure
from twisted.python import urlpath
from twisted.trial.unittest import TestCase as TrialTest
from twisted.web import http
from twisted.web import resource
from twisted.web import server
from twisted.web.test.requesthelper import DummyRequest


def debug(arg):
	"""enters the pdb debugger, then hands arg back unchanged.

	Since the argument is passed through, this can be inserted into a
	deferred callback chain to inspect intermediate results.
	"""
	from pdb import set_trace
	set_trace()
	return arg
def bytify(s):
	"""returns s utf-8 encoded if it is a string, unmodified otherwise.
	"""
	return s.encode("utf-8") if isinstance(s, str) else s
def debytify(b):
	"""returns b utf-8 decoded if it is bytes, unmodified otherwise.
	"""
	if not isinstance(b, bytes):
		return b
	return b.decode("utf-8")
def assertHasStrings(content, strings, inverse=False):
	"""asserts that every item of strings occurs in content.

	With inverse=True, asserts instead that none of them occur.

	Both content and each item of strings are bytified (utf-8 encoded
	if they are str) before the comparison.

	If an assertion fails, the bytified content is written to
	remote.data in the current directory before the AssertionError
	is re-raised.
	"""
	content = bytify(content)
	try:
		for needle in strings:
			present = bytify(needle) in content
			if inverse:
				assert not present, f"'{needle}' in remote.data"
			else:
				assert present, f"'{needle}' not in remote.data"
	except AssertionError:
		with open("remote.data", "wb") as dump:
			dump.write(content)
		raise
def raiseException(failure):
	"""re-raises the exception carried by failure.

	failure is expected to have a raiseException method (as twisted
	Failures do); this is convenient as an errback when an error should
	surface as a real exception.
	"""
	reraise = failure.raiseException
	reraise()
def bytify_seq(s):
	"""returns a new list with each element of s bytified.

	s must be a list, a tuple, or None (in which case None is returned).
	As a convenience, a single bytes or str value is treated like a
	one-element list.  Anything else raises an Exception.
	"""
	if s is None:
		return None

	items = [s] if isinstance(s, (bytes, str)) else s
	if isinstance(items, (list, tuple)):
		return list(map(bytify, items))

	raise Exception("bytify_seq really wants a sequence, not %s"%repr(items))
class FakeFile:
	"""a fake file upload.

	Both the file name and the payload may be given as bytes or as str;
	str values are utf-8 encoded.  The payload is made available as a
	BytesIO in the file_object attribute.
	"""
	def __init__(self, file_name, payload):
		encodedName = bytify(file_name)
		encodedPayload = bytify(payload)
		self.file_name = encodedName
		self.file_object = io.BytesIO(encodedPayload)
class FakeRequest(DummyRequest):
	"""A request for test purposes.

	We furnish t.w's DummyRequest with some extra facilities to let us
	be a bit lazy in having rather macro tests.  Also, stock twisted
	DummyRequest produces an endless loop with push producers (which is
	what we have), so we fix that, too.

	You can pass in args as a str -> str mapping; the strings will be
	encoded as utf-8 so request.args is bytes -> [bytes].  For convenience,
	we'll turn single values to lists.

	For uploads, you can pass (single) args with FakeFile-valued arguments;
	these end up in the files attribute (keyed by str, mapping to a list
	of FakeFiles) rather than in args.

	The headers, avatar, currentSegments, and cookies constructor arguments
	are accepted for signature compatibility but are not used here.
	"""
	method = b"GET"
	session = None
	startedWriting = 0
	# some code tests for a live connection using client
	client = True

	def __init__(self, uri=b'', headers=None, args=None, avatar=None,
			currentSegments=None, cookies=None, user="", password="",
			isSecure=False):
		uri = bytify(uri)
		if uri.startswith(b"/"):
			uri = uri[1:]
		postpath = []
		if uri:
			postpath = uri.split(b"/")

		DummyRequest.__init__(self, uri)

		self.files, self.args = {}, {}
		if args:
			for k, v in args.items():
				if isinstance(v, FakeFile):
					self.files[debytify(k)] = [v]
				else:
					self.args[bytify(k)] = bytify_seq(v)

		self.uri = uri
		self.postpath = postpath
		self.code = 200
		self.user, self.password = user, password
		self.deferred = defer.Deferred()
		self.accumulator = b""
		self.prepath = []
		self.finished = False
		# honour the isSecure argument (this used to be hard-wired to False)
		self.secure = isSecure
		self.channel = 1 # must be non-None for custom hangup detection
		self.lastModified = None

	def setHeader(self, name, value):
		# overridden because t.w.t doesn't overwrite
		self.responseHeaders.setRawHeaders(name, [value])

	def write(self, data):
		"""appends (bytified) data to accumulator.

		On the first write, a last-modified response header is set if
		setLastModified has recorded a time.
		"""
		if not self.startedWriting:
			if self.lastModified is not None:
				self.responseHeaders.setRawHeaders(
					b"last-modified",
					[http.datetimeToString(self.lastModified)])
			self.startedWriting = True
		self.accumulator += bytify(data)

	def notifyFinish(self):
		return self.deferred

	def prePathURL(self):
		"""returns a str URL built from the host and the consumed path.

		prepath segments are normally bytes (they are moved over from
		postpath), so they are decoded before joining; joining them as-is
		would raise a TypeError.
		"""
		return 'http://%s/%s'%(
			debytify(self.getHost()),
			'/'.join(debytify(segment) for segment in self.prepath))

	def setLastModified(self, when):
		# copied from twisted.web.server.Request
		when = int(math.ceil(when))
		if (not self.lastModified) or (self.lastModified < when):
			self.lastModified = when

		modifiedSince = self.getHeader(b"if-modified-since")
		if modifiedSince:
			firstPart = modifiedSince.split(b";", 1)[0]
			try:
				modifiedSince = http.stringToDatetime(firstPart)
			except ValueError:
				return None
			if modifiedSince >= self.lastModified:
				self.setResponseCode(http.NOT_MODIFIED)
				return http.CACHED
		return None

	def finish(self):
		"""marks the request finished and fires deferred with
		(accumulator, self).
		"""
		self.finished = True
		self.deferred.callback((self.accumulator, self))

	def finishCallback(self, arg):
		"""re-raises arg if it is a Failure; otherwise finishes the request
		unless that has already happened.
		"""
		if isinstance(arg, failure.Failure):
			arg.raiseException()
		if not self.finished:
			self.finish()

	def setHost(self, host, port):
		# port is accepted for interface compatibility but not stored
		self.host = host

	def getHost(self):
		return self.host

	def setResponseCode(self, code):
		self.code = code

	def URLPath(self):
		return urlpath.URLPath.fromString(self.path.decode("utf-8"))

	@property
	def path(self):
		return self.uri

	def isSecure(self):
		return self.secure

	def getLocationValue(self):
		"""returns a location header if this requests redirects, and raises
		an AssertionError otherwise.
		"""
		if not self.code or self.code//100!=3:
			raise AssertionError("Trying to get a redirection target for"
				" request with status %s"%self.code)
		return self.getResponseHeader("location")

	def processWithRoot(self, page):
		"""runs this request on page.

		This is probably a bad idea all around, and we should just be
		using trial.  But since sync tests are quite a bit more convenient,
		here this is.

		Of course, it only works if resource effectively renders sync (or
		has a renderSync method).
		"""
		rsc = resource.getChildForRequest(page, self)
		res = getattr(rsc, "renderSync", rsc.render)(self)
		if res:
			if isinstance(res, int) and res==server.NOT_DONE_YET:
				# this will only work if the thing is actually sync.
				# see servicetest._syncvosi for an inspiration there.
				# But in that case, accumulator will have it all.
				pass
			else:
				return res
		return self.accumulator

	def registerProducer(self, producer, isPush):
		"""remembers producer; only pull producers are handed on to
		DummyRequest, whose registerProducer loops over resumeProducing.
		"""
		self.producer = producer
		if not isPush:
			DummyRequest.registerProducer(
				self, producer, isPush)

	def unregisterProducer(self):
		# stop twisted pull producers, too
		self.go = 0
		self.channel = None
		# tolerate being called without a prior registerProducer
		try:
			del self.producer
		except AttributeError:
			pass

	def getResponseHeader(self, headerName):
		"""returns the first value of the response header headerName, or
		None if it is not set.
		"""
		return self.responseHeaders.getRawHeaders(headerName, [None])[0]

	def addUpload(self, name, content):
		"""adds a FakeFile upload for name.

		Like uploads passed to the constructor, files is keyed by the str
		form of name.
		"""
		self.files.setdefault(debytify(name), []).append(
			FakeFile(name, content))
def _doRender(page, request): result = page.render(request) if isinstance(result, int) and result==server.NOT_DONE_YET: # the thing is set up in a way that eventually some deferred # will fire and complete return request.deferred elif isinstance(result, bytes): request.write(result) request.finish() return request.deferred else: raise Exception("Unsupported render result: %s"%repr(result)) def _buildRequest( method, path, args, moreHeaders=None, requestClass=None): if requestClass is None: requestClass = FakeRequest req = requestClass(path, args=args) req.headers = {} if moreHeaders: for k, v in moreHeaders.items(): req.requestHeaders.setRawHeaders(k, [v]) req.method = bytify(method) return req
def runQuery(page, method, path, args, moreHeaders=None,
		requestMogrifier=None, requestClass=None, produceErrorDocument=None):
	"""runs a query on a page.

	A request of requestClass (FakeRequest by default) is built for
	method, path, and args; moreHeaders (a name -> value mapping) are
	added to its request headers.  If given, requestMogrifier is called
	with the request before rendering so tests can adjust it.  The
	request is meant to look like it is coming from localhost; that is
	up to the request class, not to this function.

	The thing returns a deferred firing a pair of the result (bytes,
	for FakeRequest) and the request (from which you can glean headers
	and such).

	produceErrorDocument must be a callable accepting a failure and the
	request if you want to exercise your error handling, too; in that
	case, the request's deferred is returned after calling it.  If you
	don't pass it in, exceptions raised synchronously during request
	handling will be re-raised.
	"""
	req = _buildRequest(
		method, path, args, moreHeaders=moreHeaders, requestClass=requestClass)
	if requestMogrifier is not None:
		requestMogrifier(req)

	try:
		rsc = resource.getChildForRequest(page, req)
		return _doRender(rsc, req)
	except Exception:
		if produceErrorDocument:
			# Failure() with no arguments takes the exception and its
			# traceback from sys.exc_info()
			produceErrorDocument(failure.Failure(), req)
			return req.deferred
		raise
class RenderTest(TrialTest):
	"""a base class for tests of twisted web resources.

	Set renderer to the resource under test; set errorHandler to a
	produceErrorDocument callable (see runQuery) to exercise error
	handling.
	"""
	renderer = None # Override with the resource to be tested.
	errorHandler = None # override with a runQuery produceErrorDocument

	runQuery = staticmethod(runQuery)

	def assertStringsIn(self, result, strings, inverse=False,
			customTest=None):
		# this wraps assertHasStrings to work better with
		# twisted results; in particular, we need to return the result.
		content = result[0]
		assertHasStrings(content, strings, inverse)

		try:
			if customTest is not None:
				customTest(content)
		except AssertionError:
			with open("remote.data", "wb") as f:
				f.write(content)
			raise

		return result

	def assertResultHasStrings(self, method, path, args, strings,
			rm=None, inverse=False, customTest=None):
		return self.runQuery(
			self.renderer, method, path, args, requestMogrifier=rm,
			produceErrorDocument=self.errorHandler
		).addCallback(self.assertStringsIn, strings, inverse=inverse,
			customTest=customTest)

	def assertGETHasStrings(self, path, args, strings, rm=None,
			customTest=None):
		return self.assertResultHasStrings("GET",
			path, args, strings, rm=rm, customTest=customTest)

	def assertGETLacksStrings(self, path, args, strings, rm=None,
			customTest=None):
		return self.assertResultHasStrings("GET",
			path, args, strings, rm=rm, inverse=True, customTest=customTest)

	def assertPOSTHasStrings(self, path, args, strings, rm=None,
			customTest=None):
		return self.assertResultHasStrings("POST",
			path, args, strings, rm=rm, customTest=customTest)

	def assertStatus(self, path, status, args=None, rm=None):
		# None instead of a shared mutable {} default
		if args is None:
			args = {}

		def check(res):
			self.assertEqual(res[1].code, status)
			return res

		return self.runQuery(
			self.renderer, "GET", path, args, requestMogrifier=rm,
			produceErrorDocument=self.errorHandler
		).addCallback(check)

	def assertGETRaises(self, path, args, exc, alsoCheck=None):
		"""asserts that a GET on path with args fails with exc.

		alsoCheck, if given, is called with the trapped failure.
		"""
		def cb(res):
			raise AssertionError("%s not raised (returned %s instead)"%(
				exc, res))

		def eb(flr):
			flr.trap(exc)
			if alsoCheck is not None:
				alsoCheck(flr)

		# maybeDeferred turns exceptions runQuery re-raises synchronously
		# (it does that when no errorHandler is set) into errbacks.
		# addCallbacks (rather than addCallback().addErrback()) keeps eb
		# from seeing cb's own AssertionError, which would otherwise be
		# swallowed when exc is AssertionError or one of its bases.
		return defer.maybeDeferred(
			self.runQuery, self.renderer, "GET", path, args,
			produceErrorDocument=self.errorHandler
		).addCallbacks(cb, eb)

	def assertGETIsValid(self, path, args=None):
		# assertResponseIsValid is not defined in this class; presumably
		# it is provided by subclasses or mixins -- verify before use.
		if args is None:
			args = {}
		return self.runQuery(self.renderer, "GET", path, args,
			produceErrorDocument=self.errorHandler
		).addCallback(self.assertResponseIsValid)