#!/usr/bin/python3
#
# Copyright (c) 2019 Collabora, Ltd.
#
# SPDX-License-Identifier: Apache-2.0
#
# Author(s):    Ryan Pavlik
#
# Purpose:      This script checks some "business logic" in the XML registry.
"""Check some "business logic" in the Vulkan XML registry.

Running this file as a script builds a ``Checker``, runs its ``check()``
method and exits with status 1 if the checker reports failure.
"""

import re
import sys
from pathlib import Path

from check_spec_links import VulkanEntityDatabase as OrigEntityDatabase
from reg import Registry
from spec_tools.consistency_tools import XMLChecker
from spec_tools.util import findNamedElem, getElemName, getElemType
from vkconventions import VulkanConventions as APIConventions

# These are extensions which do not follow the usual naming conventions,
# specifying the alternate convention they follow
EXTENSION_ENUM_NAME_SPELLING_CHANGE = {
    'VK_EXT_swapchain_colorspace': 'VK_EXT_SWAPCHAIN_COLOR_SPACE',
}

# These are extensions whose names *look* like they end in version numbers,
# but don't
EXTENSION_NAME_VERSION_EXCEPTIONS = (
    'VK_AMD_gpu_shader_int16',
    'VK_EXT_index_type_uint8',
    'VK_EXT_shader_image_atomic_int64',
    'VK_EXT_video_decode_h264',
    'VK_EXT_video_decode_h265',
    'VK_EXT_video_encode_h264',
    'VK_EXT_video_encode_h265',
    'VK_KHR_external_fence_win32',
    'VK_KHR_external_memory_win32',
    'VK_KHR_external_semaphore_win32',
    'VK_KHR_shader_atomic_int64',
    'VK_KHR_shader_float16_int8',
    'VK_KHR_spirv_1_4',
    'VK_NV_external_memory_win32',
    'VK_RESERVED_do_not_use_146',
    'VK_RESERVED_do_not_use_94',
)

# Exceptions to pointer parameter naming rules
# Keyed by (entity name, type, name).
CHECK_PARAM_POINTER_NAME_EXCEPTIONS = {
    ('vkGetDrmDisplayEXT', 'VkDisplayKHR', 'display'): None,
}

# Exceptions to pNext member requiring an optional attribute
CHECK_MEMBER_PNEXT_OPTIONAL_EXCEPTIONS = (
    'VkVideoEncodeInfoKHR',
)


def get_extension_commands(reg):
    """Return the set of command names required by any extension in `reg`.

    Only ``<command>`` elements that carry a ``name`` attribute and sit
    directly under an extension's ``<require>`` element are collected.
    """
    return {
        cmd.get("name")
        for ext in reg.extensions
        for cmd in ext.findall("./require/command[@name]")
    }


def get_enum_value_names(reg, enum_type):
    """Return the set of enumerant names defined by the enum group `enum_type`.

    `enum_type` is looked up in ``reg.groupdict``; a missing key raises
    ``KeyError``.
    """
    group_elem = reg.groupdict[enum_type].elem
    return {val.get("name") for val in group_elem.findall("./enum[@name]")}


# Regular expression matching an extension name ending in a (possible) version number.
# The group names must match the .group('base') / .group('version') lookups
# in Checker.check_extension.
EXTNAME_RE = re.compile(r'(?P<base>(\w+[A-Za-z]))(?P<version>\d+)')

DESTROY_PREFIX = "vkDestroy"
TYPEENUM = "VkStructureType"

SPECIFICATION_DIR = Path(__file__).parent.parent

# Matches a revision-history bullet in an extension appendix; the 'num'
# group captures the revision number.
REVISION_RE = re.compile(r' *[*] Revision (?P<num>[1-9][0-9]*),.*')


def get_extension_source(extname):
    """Return the path (as a string) of the appendix text file for `extname`."""
    fn = '{}.txt'.format(extname)
    return str(SPECIFICATION_DIR / 'appendices' / fn)


def _read_spec_revisions(filename):
    """Return the list of revision numbers (ints) found in `filename`.

    Each line is right-stripped and matched against REVISION_RE from its
    start; lines that do not match are ignored.
    """
    revisions = []
    with open(filename, 'r', encoding='utf-8') as fp:
        for line in fp:
            match = REVISION_RE.match(line.rstrip())
            if match:
                revisions.append(int(match.group('num')))
    return revisions


class EntityDatabase(OrigEntityDatabase):
    """Entity database that keeps every extension and prefers lxml for parsing."""

    # Override base class method to not exclude 'disabled' extensions
    def getExclusionSet(self):
        """Return a set of "support=" attribute strings that should not be included in the database.

        Called only during construction."""
        return set()

    def makeRegistry(self):
        """Build the Registry from xml/vk.xml.

        When lxml is importable the registry file is parsed with it;
        otherwise the base class implementation is used.
        """
        try:
            import lxml.etree as etree
        except ImportError:
            return super().makeRegistry()

        registryFile = str(SPECIFICATION_DIR / 'xml/vk.xml')
        registry = Registry()
        registry.filename = registryFile
        registry.loadElementTree(etree.parse(registryFile))
        return registry


class Checker(XMLChecker):
    """XML registry checker configured with the Vulkan conventions and database."""

    def __init__(self):
        """Set up return-code tables, registry-derived name sets and the base checker."""
        manual_types_to_codes = {
            # These are hard-coded "manual" return codes:
            # the codes of the value (string, list, or tuple)
            # are available for a command if-and-only-if
            # the key type is passed as an input.
            "VkFormat": "VK_ERROR_FORMAT_NOT_SUPPORTED"
        }
        forward_only = {
            # Like the above, but these are only valid in the
            # "type implies return code" direction
        }
        reverse_only = {
            # like the above, but these are only valid in the
            # "return code implies type or its descendant" direction
            # "XrDuration": "XR_TIMEOUT_EXPIRED"
        }
        # Some return codes are related in that only one of a set
        # may be returned by a command
        # (eg. XR_ERROR_SESSION_RUNNING and XR_ERROR_SESSION_NOT_RUNNING)
        self.exclusive_return_code_sets = tuple(
            # set(("XR_ERROR_SESSION_NOT_RUNNING", "XR_ERROR_SESSION_RUNNING")),
        )

        # Map of extension number -> [ list of extension names ]
        # This is used to report collisions.
        self.extension_number_reservations = {}

        conventions = APIConventions()
        db = EntityDatabase()

        self.extension_cmds = get_extension_commands(db.registry)
        self.return_codes = get_enum_value_names(db.registry, 'VkResult')
        self.structure_types = get_enum_value_names(db.registry, TYPEENUM)

        # Dict of entity name to a list of messages to suppress. (Exclude any context data and "Warning:"/"Error:")
        # Keys are entity names, values are tuples or lists of message text to suppress.
        suppressions = {}

        # Initialize superclass
        super().__init__(entity_db=db, conventions=conventions,
                         manual_types_to_codes=manual_types_to_codes,
                         forward_only_types_to_codes=forward_only,
                         reverse_only_types_to_codes=reverse_only,
                         suppressions=suppressions)

    def check_command_return_codes_basic(self, name, info,
                                         successcodes, errorcodes):
        """Check a command's return codes for consistency.

        Called on every command."""
        # Check that all extension commands can return the code associated
        # with trying to use an extension that wasn't enabled.
        # if name in self.extension_cmds and UNSUPPORTED not in errorcodes:
        #     self.record_error("Missing expected return code",
        #                       UNSUPPORTED,
        #                       "implied due to being an extension command")

        codes = successcodes.union(errorcodes)

        # Check that all return codes are recognized.
        unrecognized = codes - self.return_codes
        if unrecognized:
            self.record_error("Unrecognized return code(s):",
                              unrecognized)

        elem = info.elem
        params = [(getElemName(elt), elt) for elt in elem.findall('param')]

        def is_count_output(param_name, elt):
            # Must end with Count or Size,
            # not be const,
            # and be a pointer (detected by naming convention)
            # NOTE(review): elt.tail is the text that follows the param
            # element's end tag; a 'const' qualifier inside the param would
            # be in elt.text instead - confirm which one is intended.
            return param_name.endswith(('Count', 'Size')) \
                and (elt.tail is None or 'const' not in elt.tail) \
                and param_name.startswith('p')

        countParams = [elt
                       for param_name, elt in params
                       if is_count_output(param_name, elt)]
        if countParams:
            assert len(countParams) == 1
            if 'VK_INCOMPLETE' not in successcodes:
                self.record_error(
                    "Apparent enumeration of an array without VK_INCOMPLETE in successcodes.")

        elif 'VK_INCOMPLETE' in successcodes:
            self.record_error(
                "VK_INCOMPLETE in successcodes of command that is apparently not an array enumeration.")

    def check_param(self, param):
        """Check a member of a struct or a param of a function.

        Called from check_params."""
        super().check_param(param)

        if not self.is_api_type(param):
            return

        param_text = "".join(param.itertext())
        param_name = getElemName(param)
        type_elem = param.find('type')

        # Make sure the number of leading "p" matches the pointer count.
        # The '*' characters follow the <type> element, i.e. live in its tail.
        type_tail = type_elem.tail
        pointercount = type_tail.count('*') if type_tail else 0
        if pointercount:
            prefix = 'p' * pointercount
            if not param_name.startswith(prefix):
                param_type = type_elem.text
                message = "Apparently incorrect pointer-related name prefix for {} - expected it to start with '{}'".format(
                    param_text, prefix)
                if (self.entity, param_type, param_name) in CHECK_PARAM_POINTER_NAME_EXCEPTIONS:
                    self.record_warning('(Allowed exception)', message, elem=param)
                else:
                    self.record_error(message, elem=param)

        # Make sure pNext members have optional="true" attributes
        if param_name == self.conventions.nextpointer_member_name:
            # A missing attribute yields None, which also fails this test.
            if param.get('optional') != 'true':
                message = '{}.pNext member is missing \'optional="true"\' attribute'.format(self.entity)
                if self.entity in CHECK_MEMBER_PNEXT_OPTIONAL_EXCEPTIONS:
                    self.record_warning('(Allowed exception)', message, elem=param)
                else:
                    self.record_error(message, elem=param)

    def check_type(self, name, info, category):
        """Check a type's XML data for consistency.

        Called from check."""

        elem = info.elem
        type_elts = [elt
                     for elt in elem.findall("member")
                     if getElemType(elt) == TYPEENUM]
        if category == 'struct' and type_elts:
            if len(type_elts) > 1:
                self.record_error(
                    "Have more than one member of type", TYPEENUM)
            else:
                type_elt = type_elts[0]
                val = type_elt.get('values')
                if val and val not in self.structure_types:
                    self.record_error("Unknown structure type constant", val)

            # Check the pointer chain member, if present.
            next_name = self.conventions.nextpointer_member_name
            next_member = findNamedElem(elem.findall('member'), next_name)
            if next_member is not None:
                # Ensure that the 'optional' attribute is set to 'true'
                # (a missing attribute yields None, which also fails).
                if next_member.get('optional') != 'true':
                    message = '{}.{} member is missing \'optional="true"\' attribute'.format(name, next_name)
                    if name in CHECK_MEMBER_PNEXT_OPTIONAL_EXCEPTIONS:
                        self.record_warning('(Allowed exception)', message)
                    else:
                        self.record_error(message)

        elif category == "bitmask":
            if 'Flags' in name:
                expected_require = name.replace('Flags', 'FlagBits')
                require = elem.get('require')
                if require is not None and expected_require != require:
                    self.record_error("Unexpected require attribute value:",
                                      "got", require,
                                      "but expected", expected_require)
        super().check_type(name, info, category)

    def check_extension(self, name, info):
        """Check an extension's XML data for consistency.

        Called from check."""
        elem = info.elem
        enums = elem.findall('./require/enum[@name]')

        # Look for other extensions using that number
        # Keep track of this extension number reservation
        ext_number = elem.get('number')
        reservations = self.extension_number_reservations.setdefault(ext_number, [])
        if reservations:
            self.record_error('Extension number {} has more than one reservation: {}, {}'.format(
                ext_number, name, ', '.join(reservations)))
        reservations.append(name)

        # If extension name is not on the exception list and matches the
        # versioned-extension pattern, map the extension name to the version
        # name with the version as a separate word. Otherwise just map it to
        # the upper-case version of the extension name.

        matches = EXTNAME_RE.fullmatch(name)
        ext_versioned_name = False
        if name in EXTENSION_ENUM_NAME_SPELLING_CHANGE:
            ext_enum_name = EXTENSION_ENUM_NAME_SPELLING_CHANGE[name]
        elif matches is None or name in EXTENSION_NAME_VERSION_EXCEPTIONS:
            # This is the usual case, either a name that doesn't look
            # versioned, or one that does but is on the exception list.
            ext_enum_name = name.upper()
        else:
            # This is a versioned extension name.
            # Treat the version number as a separate word.
            base = matches.group('base')
            version = matches.group('version')
            ext_enum_name = base.upper() + '_' + version
            # Keep track of this case
            ext_versioned_name = True

        # Look for the expected SPEC_VERSION token name
        version_name = "{}_SPEC_VERSION".format(ext_enum_name)
        version_elem = findNamedElem(enums, version_name)

        if version_elem is None:
            # Did not find a SPEC_VERSION enum matching the extension name
            if ext_versioned_name:
                suffix = (
                    '\n'
                    '    Make sure that trailing version numbers in extension names are treated\n'
                    '    as separate words in extension enumerant names. If this is an extension\n'
                    '    whose name ends in a number which is not a version, such as "...h264"\n'
                    '    or "...int16", add it to EXTENSION_NAME_VERSION_EXCEPTIONS in\n'
                    '    scripts/xml_consistency.py.'
                )
            else:
                suffix = ''
            self.record_error('Missing version enum {}{}'.format(version_name, suffix))
        elif elem.get('supported') == self.conventions.xml_api_name:
            # Skip unsupported / disabled extensions for these checks

            fn = get_extension_source(name)
            revisions = _read_spec_revisions(fn)
            ver_from_xml = version_elem.get('value')
            if revisions:
                ver_from_text = str(max(revisions))
                if ver_from_xml != ver_from_text:
                    self.record_error("Version enum mismatch: spec text indicates", ver_from_text,
                                      "but XML says", ver_from_xml)
            elif ver_from_xml == '1':
                self.record_warning(
                    "Cannot find version history in spec text - make sure it has lines starting exactly like '* Revision 1, ....'",
                    filename=fn)
            else:
                self.record_warning("Cannot find version history in spec text, but XML reports a non-1 version number", ver_from_xml,
                                    " - make sure the spec text has lines starting exactly like '* Revision 1, ....'",
                                    filename=fn)

        name_define = "{}_EXTENSION_NAME".format(ext_enum_name)
        name_elem = findNamedElem(enums, name_define)
        if name_elem is None:
            self.record_error("Missing name enum", name_define)
        else:
            # Note: etree handles the XML entities here and turns &quot; back into "
            expected_name = '"{}"'.format(name)
            name_val = name_elem.get('value')
            if name_val != expected_name:
                self.record_error("Incorrect name enum: expected", expected_name,
                                  "got", name_val)

        super().check_extension(name, elem)


if __name__ == "__main__":

    ckr = Checker()
    ckr.check()

    if ckr.fail:
        sys.exit(1)