# Copyright (C) 2022 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for trace URI parsing, resolvers and the resolver registry."""

import unittest

from perfetto.trace_uri_resolver.util import parse_trace_uri
from perfetto.trace_uri_resolver.resolver import _args_dict_from_uri
from perfetto.trace_uri_resolver.resolver import TraceUriResolver
from perfetto.trace_uri_resolver.registry import ResolverRegistry


class SimpleResolver(TraceUriResolver):
  """Test resolver registered under the 'simple' prefix.

  Resolves to two results: the first yields the encoded `foo` argument, the
  second yields the encoded `bar` argument and carries both arguments as
  metadata.
  """
  PREFIX = 'simple'

  def __init__(self, foo=None, bar=None):
    self.foo = foo
    self.bar = bar

  def foo_gen(self):
    """Yields `foo` encoded with str.encode(), or b'' when `foo` is falsy."""
    yield self.foo.encode() if self.foo else b''

  def bar_gen(self):
    """Yields `bar` encoded with str.encode(), or b'' when `bar` is falsy."""
    yield self.bar.encode() if self.bar else b''

  def resolve(self):
    return [
        TraceUriResolver.Result(self.foo_gen()),
        TraceUriResolver.Result(
            self.bar_gen(), metadata={
                'foo': self.foo,
                'bar': self.bar
            }),
    ]


class RecursiveResolver(SimpleResolver):
  """Test resolver registered under the 'recursive' prefix.

  Its constructor is inherited from SimpleResolver. Besides two direct results,
  it returns a 'simple:' trace URI string and a SimpleResolver instance built
  from the same arguments; test_recursive_resolve expects the registry to
  expand each of those into SimpleResolver's two results.
  """
  PREFIX = 'recursive'

  def resolve(self):
    return [
        TraceUriResolver.Result(self.foo_gen()),
        # Fixed metadata so these direct results are distinguishable from the
        # recursively resolved SimpleResolver results.
        TraceUriResolver.Result(
            self.bar_gen(), metadata={
                'foo': 'foo',
                'bar': 'bar'
            }),
        TraceUriResolver.Result(f'simple:foo={self.foo};bar={self.bar}'),
        TraceUriResolver.Result(SimpleResolver(foo=self.foo, bar=self.bar)),
    ]


class TestResolver(unittest.TestCase):
  """Tests for ResolverRegistry, parse_trace_uri and _args_dict_from_uri."""

  def test_simple_resolve(self):
    registry = ResolverRegistry([SimpleResolver])

    # A single URI string.
    res = registry.resolve('simple:foo=x;bar=y')
    self.assertEqual(len(res), 2)
    (foo_res, bar_res) = res
    self._check_resolver_result(foo_res, bar_res)

    # A list holding one URI string.
    (foo_res, bar_res) = registry.resolve(['simple:foo=x;bar=y'])
    self._check_resolver_result(foo_res, bar_res)

    # A resolver instance, on its own and wrapped in a list.
    resolver = SimpleResolver(foo='x', bar='y')
    (foo_res, bar_res) = registry.resolve(resolver)
    self._check_resolver_result(foo_res, bar_res)

    (foo_res, bar_res) = registry.resolve([resolver])
    self._check_resolver_result(foo_res, bar_res)

    # A mixed list: results come back in the order of the list entries.
    (foo_a, bar_b, foo_x, bar_y) = registry.resolve(
        ['simple:foo=a;bar=b', resolver])
    self._check_resolver_result(foo_a, bar_b, foo='a', bar='b')
    self._check_resolver_result(foo_x, bar_y)

  def test_simple_resolve_missing_arg(self):
    registry = ResolverRegistry([SimpleResolver])

    (foo_res, bar_res) = registry.resolve('simple:foo=x')
    self._check_resolver_result(foo_res, bar_res, bar=None)

    (foo_res, bar_res) = registry.resolve('simple:bar=y')
    self._check_resolver_result(foo_res, bar_res, foo=None)

    (foo_res, bar_res) = registry.resolve('simple:')
    self._check_resolver_result(foo_res, bar_res, foo=None, bar=None)

  def test_recursive_resolve(self):
    registry = ResolverRegistry([SimpleResolver])
    registry.register(RecursiveResolver)

    # 2 direct results + 2 from the URI string + 2 from the resolver object.
    res = registry.resolve('recursive:foo=x;bar=y')
    self.assertEqual(len(res), 6)

    (non_rec_foo, non_rec_bar, rec_foo_str, rec_bar_str, rec_foo_obj,
     rec_bar_obj) = res
    self._check_resolver_result(
        non_rec_foo, non_rec_bar, foo_metadata='foo', bar_metadata='bar')
    self._check_resolver_result(rec_foo_str, rec_bar_str)
    self._check_resolver_result(rec_foo_obj, rec_bar_obj)

  def test_parse_trace_uri(self):
    self.assertEqual(parse_trace_uri('/foo/bar'), (None, '/foo/bar'))
    self.assertEqual(parse_trace_uri('foo/bar'), (None, 'foo/bar'))
    self.assertEqual(parse_trace_uri('/foo/b:ar'), (None, '/foo/b:ar'))
    self.assertEqual(parse_trace_uri('./foo/b:ar'), (None, './foo/b:ar'))
    self.assertEqual(parse_trace_uri('foo/b:ar'), ('foo/b', 'ar'))

  def test_args_dict_from_uri(self):
    self.assertEqual(_args_dict_from_uri('foo:'), {})
    self.assertEqual(_args_dict_from_uri('foo:bar=baz'), {
        'bar': 'baz',
    })
    self.assertEqual(
        _args_dict_from_uri('foo:key=v1,v2'), {'key': ['v1', 'v2']})
    self.assertEqual(
        _args_dict_from_uri('foo:bar=baz;key=v1,v2'), {
            'bar': 'baz',
            'key': ['v1', 'v2']
        })

  def _check_resolver_result(self,
                             foo_res,
                             bar_res,
                             foo='x',
                             bar='y',
                             foo_metadata=None,
                             bar_metadata=None):
    """Asserts that a result pair matches SimpleResolver-style output.

    Args:
      foo_res: result whose generator should yield only the encoded `foo`.
      bar_res: result whose generator should yield only the encoded `bar` and
        whose metadata should hold the 'foo' and 'bar' keys.
      foo: expected `foo` value; a falsy value means the generator yields b''.
      bar: expected `bar` value; a falsy value means the generator yields b''.
      foo_metadata: expected metadata['foo'] if truthy, otherwise `foo`.
      bar_metadata: expected metadata['bar'] if truthy, otherwise `bar`.
    """
    self.assertEqual(
        tuple(foo_res.generator), (foo.encode() if foo else b'',))
    self.assertEqual(
        tuple(bar_res.generator), (bar.encode() if bar else b'',))
    self.assertEqual(
        bar_res.metadata, {
            'foo': foo_metadata if foo_metadata else foo,
            'bar': bar_metadata if bar_metadata else bar
        })