#!/usr/bin/env python3

# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Custom mmi2grpc gRPC compiler.

protoc plugin: reads a serialized CodeGeneratorRequest from stdin and writes
a serialized CodeGeneratorResponse to stdout. For every proto file listed in
`file_to_generate` it emits one `<name>_grpc.py` module containing, for each
service of that file:

  * a client class wrapping a channel, with one Python method per RPC,
  * a `<Service>Servicer` class whose methods all report UNIMPLEMENTED,
  * an `add_<Service>Servicer_to_server(servicer, server)` function.
"""

import sys

# One indentation level of the generated Python code.
INDENT = '    '

# Parsed CodeGeneratorRequest. Populated by main() before any generator
# function runs; import_type() reads its `proto_file` list to resolve message
# types to the proto file that declares them.
request = None


def eprint(*args, **kwargs):
    """Print to stderr (stdout carries the binary plugin response)."""
    print(*args, file=sys.stderr, **kwargs)


def _strip_proto_suffix(name):
    """Return `name` without a trailing '.proto' (unchanged if absent)."""
    suffix = '.proto'
    return name[:-len(suffix)] if name.endswith(suffix) else name


def _indent(lines, levels=1):
    """Prefix every non-empty line with `levels` indentation levels."""
    prefix = INDENT * levels
    return [prefix + line if line else line for line in lines]


def has_type(proto_file, type_name):
    """Return True if `proto_file` declares a top-level message `type_name`."""
    return any(
        message.name == type_name for message in proto_file.message_type)


def import_type(imports, type):
    """Resolve a message type to the Python expression naming it.

    Args:
        imports: set of import statements for the generated module; the
            import needed to reach the message's `_pb2` module is added to it.
        type: fully-qualified message type name as found in a method
            descriptor, e.g. '.package.Message'.

    Returns:
        '<module alias>.<Message>', valid in the generated module once the
        statements in `imports` have been emitted.

    Raises:
        LookupError: no proto file of the request has the matching package
            and declares the message at top level. Only top-level messages
            are looked up, so nested message types end up here as well.
    """
    package, _, type_name = type.lstrip('.').rpartition('.')

    proto_file = next(
        (candidate for candidate in request.proto_file
         if candidate.package == package and has_type(candidate, type_name)),
        None)
    if proto_file is None:
        raise LookupError(
            f'message type {type!r} not found in any proto file of the '
            f'request')

    python_path = _strip_proto_suffix(proto_file.name).replace('/', '.')
    as_name = python_path.replace('.', '_dot_') + '__pb2'
    module_path, _, module_name = python_path.rpartition('.')
    module_name += '_pb2'

    if module_path:
        imports.add(f'from {module_path} import {module_name} as {as_name}')
    else:
        # Proto file located at the root of the proto path: there is no
        # parent package to import the module from.
        imports.add(f'import {module_name} as {as_name}')

    return f'{as_name}.{type_name}'


def generate_service_method(imports, file, service, method):
    """Generate the client-side method for one RPC.

    Client-streaming RPCs take the request iterator directly; unary-request
    RPCs build the request message from the keyword arguments.

    Returns:
        The lines of the method, not indented.
    """
    input_mode = 'stream' if method.client_streaming else 'unary'
    output_mode = 'stream' if method.server_streaming else 'unary'

    input_type = import_type(imports, method.input_type)
    output_type = import_type(imports, method.output_type)

    if method.client_streaming:
        signature = 'self, iterator, **kwargs'
        call_args = 'iterator, **kwargs'
    else:
        signature = 'self, wait_for_ready=None, **kwargs'
        call_args = f'{input_type}(**kwargs), wait_for_ready=wait_for_ready'

    return [
        f'def {method.name}({signature}):',
        f'    return self.channel.{input_mode}_{output_mode}(',
        f"        '/{file.package}.{service.name}/{method.name}',",
        f'        request_serializer={input_type}.SerializeToString,',
        f'        response_deserializer={output_type}.FromString',
        f'    )({call_args})',
    ]


def generate_service(imports, file, service):
    """Generate the client class of `service`.

    Returns:
        The lines of the class, ending with one empty line.
    """
    lines = [
        f'class {service.name}:',
        '    def __init__(self, channel):',
        '        self.channel = channel',
    ]
    for method in service.method:
        lines.append('')
        lines.extend(
            _indent(generate_service_method(imports, file, service, method)))
    lines.append('')
    return lines


def generate_servicer_method(method):
    """Generate the default (UNIMPLEMENTED) servicer method for one RPC.

    Returns:
        The lines of the method, not indented.
    """
    request_arg = 'request_iterator' if method.client_streaming else 'request'
    return [
        f'def {method.name}(self, {request_arg}, context):',
        '    context.set_code(grpc.StatusCode.UNIMPLEMENTED)',
        '    context.set_details("Method not implemented!")',
        '    raise NotImplementedError("Method not implemented!")',
    ]


def generate_servicer(service):
    """Generate the `<Service>Servicer` base class of `service`.

    Returns:
        The lines of the class, ending with one empty line.
    """
    lines = [f'class {service.name}Servicer:']
    if not service.method:
        # A class statement needs a body even when the service has no RPC.
        lines.append(INDENT + 'pass')
    for method in service.method:
        lines.append('')
        lines.extend(_indent(generate_servicer_method(method)))
    lines.append('')
    return lines


def generate_rpc_method_handler(imports, method):
    """Generate the `rpc_method_handlers` dict entry for one RPC.

    Returns:
        The lines of the entry (trailing comma included), not indented.
    """
    input_mode = 'stream' if method.client_streaming else 'unary'
    output_mode = 'stream' if method.server_streaming else 'unary'

    input_type = import_type(imports, method.input_type)
    output_type = import_type(imports, method.output_type)

    return [
        f"'{method.name}': "
        f'grpc.{input_mode}_{output_mode}_rpc_method_handler(',
        f'    servicer.{method.name},',
        f'    request_deserializer={input_type}.FromString,',
        f'    response_serializer={output_type}.SerializeToString,',
        '),',
    ]


def generate_add_servicer_to_server_method(imports, file, service):
    """Generate `add_<Service>Servicer_to_server` for `service`.

    Returns:
        The lines of the function, without a trailing empty line.
    """
    lines = [
        f'def add_{service.name}Servicer_to_server(servicer, server):',
        '    rpc_method_handlers = {',
    ]
    for method in service.method:
        lines.extend(_indent(generate_rpc_method_handler(imports, method), 2))
    lines.extend([
        '    }',
        '    generic_handler = grpc.method_handlers_generic_handler(',
        f"        '{file.package}.{service.name}', rpc_method_handlers)",
        '    server.add_generic_rpc_handlers((generic_handler,))',
    ])
    return lines


def generate_file(file):
    """Generate the full source of the `_grpc.py` module for `file`.

    Args:
        file: descriptor of the proto file to generate code for.

    Returns:
        The module source: imports, client classes, servicer classes and
        `add_*Servicer_to_server` functions, in that order.
    """
    imports = set()
    services = []
    servicers = []
    add_servicer_methods = []

    for service in file.service:
        services.extend(generate_service(imports, file, service))
        servicers.extend(generate_servicer(service))
        if add_servicer_methods:
            # Keep consecutive top-level functions visually separated.
            add_servicer_methods.append('')
        add_servicer_methods.extend(
            generate_add_servicer_to_server_method(imports, file, service))

    # `imports` is a set: sort it so that the output is reproducible.
    header = ['import grpc', *sorted(imports)]

    sections = (header, services, servicers, add_servicer_methods)
    return '\n\n'.join('\n'.join(section) for section in sections) + '\n'


def main():
    """Run the plugin: request on stdin, response on stdout."""
    global request

    # Imported here rather than at module level so that the generator
    # functions above stay importable without protobuf installed.
    from google.protobuf.compiler.plugin_pb2 import CodeGeneratorRequest, \
        CodeGeneratorResponse

    request = CodeGeneratorRequest.FromString(sys.stdin.buffer.read())
    proto_files = {
        proto_file.name: proto_file for proto_file in request.proto_file
    }

    files = [
        CodeGeneratorResponse.File(
            name=_strip_proto_suffix(file_name) + '_grpc.py',
            content=generate_file(proto_files[file_name]),
        )
        for file_name in request.file_to_generate
    ]

    response = CodeGeneratorResponse(file=files)
    sys.stdout.buffer.write(response.SerializeToString())


if __name__ == '__main__':
    main()