391 lines
16 KiB
Python
Executable File
391 lines
16 KiB
Python
Executable File
#!/usr/bin/python3
|
|
#
|
|
# Copyright (c) 2019 Collabora, Ltd.
|
|
#
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
#
|
|
# Author(s): Ryan Pavlik <ryan.pavlik@collabora.com>
|
|
#
|
|
# Purpose: This script checks some "business logic" in the XML registry.
|
|
|
|
import re
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
from check_spec_links import VulkanEntityDatabase as OrigEntityDatabase
|
|
from reg import Registry
|
|
from spec_tools.consistency_tools import XMLChecker
|
|
from spec_tools.util import findNamedElem, getElemName, getElemType
|
|
from vkconventions import VulkanConventions as APIConventions
|
|
|
|
# These are extensions which do not follow the usual naming conventions,
|
|
# specifying the alternate convention they follow
|
|
EXTENSION_ENUM_NAME_SPELLING_CHANGE = {
|
|
'VK_EXT_swapchain_colorspace': 'VK_EXT_SWAPCHAIN_COLOR_SPACE',
|
|
}
|
|
|
|
# These are extensions whose names *look* like they end in version numbers,
|
|
# but don't
|
|
EXTENSION_NAME_VERSION_EXCEPTIONS = (
|
|
'VK_AMD_gpu_shader_int16',
|
|
'VK_EXT_index_type_uint8',
|
|
'VK_EXT_shader_image_atomic_int64',
|
|
'VK_EXT_video_decode_h264',
|
|
'VK_EXT_video_decode_h265',
|
|
'VK_EXT_video_encode_h264',
|
|
'VK_EXT_video_encode_h265',
|
|
'VK_KHR_external_fence_win32',
|
|
'VK_KHR_external_memory_win32',
|
|
'VK_KHR_external_semaphore_win32',
|
|
'VK_KHR_shader_atomic_int64',
|
|
'VK_KHR_shader_float16_int8',
|
|
'VK_KHR_spirv_1_4',
|
|
'VK_NV_external_memory_win32',
|
|
'VK_RESERVED_do_not_use_146',
|
|
'VK_RESERVED_do_not_use_94',
|
|
)
|
|
|
|
# Exceptions to pointer parameter naming rules
|
|
# Keyed by (entity name, type, name).
|
|
CHECK_PARAM_POINTER_NAME_EXCEPTIONS = {
|
|
('vkGetDrmDisplayEXT', 'VkDisplayKHR', 'display') : None,
|
|
}
|
|
|
|
# Exceptions to pNext member requiring an optional attribute
|
|
CHECK_MEMBER_PNEXT_OPTIONAL_EXCEPTIONS = (
|
|
'VkVideoEncodeInfoKHR',
|
|
)
|
|
|
|
def get_extension_commands(reg):
|
|
extension_cmds = set()
|
|
for ext in reg.extensions:
|
|
for cmd in ext.findall("./require/command[@name]"):
|
|
extension_cmds.add(cmd.get("name"))
|
|
return extension_cmds
|
|
|
|
|
|
def get_enum_value_names(reg, enum_type):
|
|
names = set()
|
|
result_elem = reg.groupdict[enum_type].elem
|
|
for val in result_elem.findall("./enum[@name]"):
|
|
names.add(val.get("name"))
|
|
return names
|
|
|
|
|
|
# Regular expression matching an extension name ending in a (possible) version number
|
|
EXTNAME_RE = re.compile(r'(?P<base>(\w+[A-Za-z]))(?P<version>\d+)')
|
|
|
|
DESTROY_PREFIX = "vkDestroy"
|
|
TYPEENUM = "VkStructureType"
|
|
|
|
SPECIFICATION_DIR = Path(__file__).parent.parent
|
|
REVISION_RE = re.compile(r' *[*] Revision (?P<num>[1-9][0-9]*),.*')
|
|
|
|
|
|
def get_extension_source(extname):
|
|
fn = '{}.txt'.format(extname)
|
|
return str(SPECIFICATION_DIR / 'appendices' / fn)
|
|
|
|
|
|
class EntityDatabase(OrigEntityDatabase):
|
|
|
|
# Override base class method to not exclude 'disabled' extensions
|
|
def getExclusionSet(self):
|
|
"""Return a set of "support=" attribute strings that should not be included in the database.
|
|
|
|
Called only during construction."""
|
|
|
|
return set(())
|
|
|
|
def makeRegistry(self):
|
|
try:
|
|
import lxml.etree as etree
|
|
HAS_LXML = True
|
|
except ImportError:
|
|
HAS_LXML = False
|
|
if not HAS_LXML:
|
|
return super().makeRegistry()
|
|
|
|
registryFile = str(SPECIFICATION_DIR / 'xml/vk.xml')
|
|
registry = Registry()
|
|
registry.filename = registryFile
|
|
registry.loadElementTree(etree.parse(registryFile))
|
|
return registry
|
|
|
|
|
|
class Checker(XMLChecker):
|
|
def __init__(self):
|
|
manual_types_to_codes = {
|
|
# These are hard-coded "manual" return codes:
|
|
# the codes of the value (string, list, or tuple)
|
|
# are available for a command if-and-only-if
|
|
# the key type is passed as an input.
|
|
"VkFormat": "VK_ERROR_FORMAT_NOT_SUPPORTED"
|
|
}
|
|
forward_only = {
|
|
# Like the above, but these are only valid in the
|
|
# "type implies return code" direction
|
|
}
|
|
reverse_only = {
|
|
# like the above, but these are only valid in the
|
|
# "return code implies type or its descendant" direction
|
|
# "XrDuration": "XR_TIMEOUT_EXPIRED"
|
|
}
|
|
# Some return codes are related in that only one of a set
|
|
# may be returned by a command
|
|
# (eg. XR_ERROR_SESSION_RUNNING and XR_ERROR_SESSION_NOT_RUNNING)
|
|
self.exclusive_return_code_sets = tuple(
|
|
# set(("XR_ERROR_SESSION_NOT_RUNNING", "XR_ERROR_SESSION_RUNNING")),
|
|
)
|
|
# Map of extension number -> [ list of extension names ]
|
|
self.extension_number_reservations = {
|
|
}
|
|
|
|
# This is used to report collisions.
|
|
conventions = APIConventions()
|
|
db = EntityDatabase()
|
|
|
|
self.extension_cmds = get_extension_commands(db.registry)
|
|
self.return_codes = get_enum_value_names(db.registry, 'VkResult')
|
|
self.structure_types = get_enum_value_names(db.registry, TYPEENUM)
|
|
|
|
# Dict of entity name to a list of messages to suppress. (Exclude any context data and "Warning:"/"Error:")
|
|
# Keys are entity names, values are tuples or lists of message text to suppress.
|
|
suppressions = {}
|
|
|
|
# Initialize superclass
|
|
super().__init__(entity_db=db, conventions=conventions,
|
|
manual_types_to_codes=manual_types_to_codes,
|
|
forward_only_types_to_codes=forward_only,
|
|
reverse_only_types_to_codes=reverse_only,
|
|
suppressions=suppressions)
|
|
|
|
def check_command_return_codes_basic(self, name, info,
|
|
successcodes, errorcodes):
|
|
"""Check a command's return codes for consistency.
|
|
|
|
Called on every command."""
|
|
# Check that all extension commands can return the code associated
|
|
# with trying to use an extension that wasn't enabled.
|
|
# if name in self.extension_cmds and UNSUPPORTED not in errorcodes:
|
|
# self.record_error("Missing expected return code",
|
|
# UNSUPPORTED,
|
|
# "implied due to being an extension command")
|
|
|
|
codes = successcodes.union(errorcodes)
|
|
|
|
# Check that all return codes are recognized.
|
|
unrecognized = codes - self.return_codes
|
|
if unrecognized:
|
|
self.record_error("Unrecognized return code(s):",
|
|
unrecognized)
|
|
|
|
elem = info.elem
|
|
params = [(getElemName(elt), elt) for elt in elem.findall('param')]
|
|
|
|
def is_count_output(name, elt):
|
|
# Must end with Count or Size,
|
|
# not be const,
|
|
# and be a pointer (detected by naming convention)
|
|
return (name.endswith('Count') or name.endswith('Size')) \
|
|
and (elt.tail is None or 'const' not in elt.tail) \
|
|
and (name.startswith('p'))
|
|
|
|
countParams = [elt
|
|
for name, elt in params
|
|
if is_count_output(name, elt)]
|
|
if countParams:
|
|
assert(len(countParams) == 1)
|
|
if 'VK_INCOMPLETE' not in successcodes:
|
|
self.record_error(
|
|
"Apparent enumeration of an array without VK_INCOMPLETE in successcodes.")
|
|
|
|
elif 'VK_INCOMPLETE' in successcodes:
|
|
self.record_error(
|
|
"VK_INCOMPLETE in successcodes of command that is apparently not an array enumeration.")
|
|
|
|
def check_param(self, param):
|
|
"""Check a member of a struct or a param of a function.
|
|
|
|
Called from check_params."""
|
|
super().check_param(param)
|
|
|
|
if not self.is_api_type(param):
|
|
return
|
|
|
|
param_text = "".join(param.itertext())
|
|
param_name = getElemName(param)
|
|
|
|
# Make sure the number of leading "p" matches the pointer count.
|
|
pointercount = param.find('type').tail
|
|
if pointercount:
|
|
pointercount = pointercount.count('*')
|
|
if pointercount:
|
|
prefix = 'p' * pointercount
|
|
if not param_name.startswith(prefix):
|
|
param_type = param.find('type').text
|
|
message = "Apparently incorrect pointer-related name prefix for {} - expected it to start with '{}'".format(
|
|
param_text, prefix)
|
|
if (self.entity, param_type, param_name) in CHECK_PARAM_POINTER_NAME_EXCEPTIONS:
|
|
self.record_warning('(Allowed exception)', message, elem=param)
|
|
else:
|
|
self.record_error(message, elem=param)
|
|
|
|
# Make sure pNext members have optional="true" attributes
|
|
if param_name == self.conventions.nextpointer_member_name:
|
|
optional = param.get('optional')
|
|
if optional is None or optional != 'true':
|
|
message = '{}.pNext member is missing \'optional="true"\' attribute'.format(self.entity)
|
|
if self.entity in CHECK_MEMBER_PNEXT_OPTIONAL_EXCEPTIONS:
|
|
self.record_warning('(Allowed exception)', message, elem=param)
|
|
else:
|
|
self.record_error(message, elem=param)
|
|
|
|
def check_type(self, name, info, category):
|
|
"""Check a type's XML data for consistency.
|
|
|
|
Called from check."""
|
|
|
|
elem = info.elem
|
|
type_elts = [elt
|
|
for elt in elem.findall("member")
|
|
if getElemType(elt) == TYPEENUM]
|
|
if category == 'struct' and type_elts:
|
|
if len(type_elts) > 1:
|
|
self.record_error(
|
|
"Have more than one member of type", TYPEENUM)
|
|
else:
|
|
type_elt = type_elts[0]
|
|
val = type_elt.get('values')
|
|
if val and val not in self.structure_types:
|
|
self.record_error("Unknown structure type constant", val)
|
|
|
|
# Check the pointer chain member, if present.
|
|
next_name = self.conventions.nextpointer_member_name
|
|
next_member = findNamedElem(info.elem.findall('member'), next_name)
|
|
if next_member is not None:
|
|
# Ensure that the 'optional' attribute is set to 'true'
|
|
optional = next_member.get('optional')
|
|
if optional is None or optional != 'true':
|
|
message = '{}.{} member is missing \'optional="true"\' attribute'.format(name, next_name)
|
|
if name in CHECK_MEMBER_PNEXT_OPTIONAL_EXCEPTIONS:
|
|
self.record_warning('(Allowed exception)', message)
|
|
else:
|
|
self.record_error(message)
|
|
|
|
elif category == "bitmask":
|
|
if 'Flags' in name:
|
|
expected_require = name.replace('Flags', 'FlagBits')
|
|
require = info.elem.get('require')
|
|
if require is not None and expected_require != require:
|
|
self.record_error("Unexpected require attribute value:",
|
|
"got", require,
|
|
"but expected", expected_require)
|
|
super().check_type(name, info, category)
|
|
|
|
def check_extension(self, name, info):
|
|
"""Check an extension's XML data for consistency.
|
|
|
|
Called from check."""
|
|
elem = info.elem
|
|
enums = elem.findall('./require/enum[@name]')
|
|
|
|
# Look for other extensions using that number
|
|
# Keep track of this extension number reservation
|
|
ext_number = elem.get('number')
|
|
if ext_number in self.extension_number_reservations:
|
|
conflicts = self.extension_number_reservations[ext_number]
|
|
self.record_error('Extension number {} has more than one reservation: {}, {}'.format(
|
|
ext_number, name, ', '.join(conflicts)))
|
|
self.extension_number_reservations[ext_number].append(name)
|
|
else:
|
|
self.extension_number_reservations[ext_number] = [ name ]
|
|
|
|
# If extension name is not on the exception list and matches the
|
|
# versioned-extension pattern, map the extension name to the version
|
|
# name with the version as a separate word. Otherwise just map it to
|
|
# the upper-case version of the extension name.
|
|
|
|
matches = EXTNAME_RE.fullmatch(name)
|
|
ext_versioned_name = False
|
|
if name in EXTENSION_ENUM_NAME_SPELLING_CHANGE:
|
|
ext_enum_name = EXTENSION_ENUM_NAME_SPELLING_CHANGE.get(name)
|
|
elif matches is None or name in EXTENSION_NAME_VERSION_EXCEPTIONS:
|
|
# This is the usual case, either a name that doesn't look
|
|
# versioned, or one that does but is on the exception list.
|
|
ext_enum_name = name.upper()
|
|
else:
|
|
# This is a versioned extension name.
|
|
# Treat the version number as a separate word.
|
|
base = matches.group('base')
|
|
version = matches.group('version')
|
|
ext_enum_name = base.upper() + '_' + version
|
|
# Keep track of this case
|
|
ext_versioned_name = True
|
|
|
|
# Look for the expected SPEC_VERSION token name
|
|
version_name = "{}_SPEC_VERSION".format(ext_enum_name)
|
|
version_elem = findNamedElem(enums, version_name)
|
|
|
|
if version_elem is None:
|
|
# Did not find a SPEC_VERSION enum matching the extension name
|
|
if ext_versioned_name:
|
|
suffix = '\n\
|
|
Make sure that trailing version numbers in extension names are treated\n\
|
|
as separate words in extension enumerant names. If this is an extension\n\
|
|
whose name ends in a number which is not a version, such as "...h264"\n\
|
|
or "...int16", add it to EXTENSION_NAME_VERSION_EXCEPTIONS in\n\
|
|
scripts/xml_consistency.py.'
|
|
else:
|
|
suffix = ''
|
|
self.record_error('Missing version enum {}{}'.format(version_name, suffix))
|
|
elif info.elem.get('supported') == self.conventions.xml_api_name:
|
|
# Skip unsupported / disabled extensions for these checks
|
|
|
|
fn = get_extension_source(name)
|
|
revisions = []
|
|
with open(fn, 'r', encoding='utf-8') as fp:
|
|
for line in fp:
|
|
line = line.rstrip()
|
|
match = REVISION_RE.match(line)
|
|
if match:
|
|
revisions.append(int(match.group('num')))
|
|
ver_from_xml = version_elem.get('value')
|
|
if revisions:
|
|
ver_from_text = str(max(revisions))
|
|
if ver_from_xml != ver_from_text:
|
|
self.record_error("Version enum mismatch: spec text indicates", ver_from_text,
|
|
"but XML says", ver_from_xml)
|
|
else:
|
|
if ver_from_xml == '1':
|
|
self.record_warning(
|
|
"Cannot find version history in spec text - make sure it has lines starting exactly like '* Revision 1, ....'",
|
|
filename=fn)
|
|
else:
|
|
self.record_warning("Cannot find version history in spec text, but XML reports a non-1 version number", ver_from_xml,
|
|
" - make sure the spec text has lines starting exactly like '* Revision 1, ....'",
|
|
filename=fn)
|
|
|
|
name_define = "{}_EXTENSION_NAME".format(ext_enum_name)
|
|
name_elem = findNamedElem(enums, name_define)
|
|
if name_elem is None:
|
|
self.record_error("Missing name enum", name_define)
|
|
else:
|
|
# Note: etree handles the XML entities here and turns " back into "
|
|
expected_name = '"{}"'.format(name)
|
|
name_val = name_elem.get('value')
|
|
if name_val != expected_name:
|
|
self.record_error("Incorrect name enum: expected", expected_name,
|
|
"got", name_val)
|
|
|
|
super().check_extension(name, elem)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
ckr = Checker()
|
|
ckr.check()
|
|
|
|
if ckr.fail:
|
|
sys.exit(1)
|