# Copyright 2015 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The Python implementation of the GRPC interoperability test client."""

import json
import os

from absl import app
from absl.flags import argparse_flags
from google import auth as google_auth
from google.auth import jwt as google_auth_jwt
import grpc

from src.proto.grpc.testing import test_pb2_grpc
from tests.interop import methods
from tests.interop import resources


def parse_interop_client_args(argv):
    """Parse the interop client's command-line flags.

    Args:
      argv: The full argument vector; ``argv[0]`` (the program name) is
        skipped.

    Returns:
      The ``argparse`` namespace holding every flag defined below.
    """
    parser = argparse_flags.ArgumentParser()
    parser.add_argument(
        "--server_host",
        default="localhost",
        type=str,
        help="the host to which to connect",
    )
    parser.add_argument(
        "--server_port",
        type=int,
        required=True,
        help="the port to which to connect",
    )
    parser.add_argument(
        "--test_case",
        default="large_unary",
        type=str,
        help="the test case to execute",
    )
    parser.add_argument(
        "--use_tls",
        default=False,
        type=resources.parse_bool,
        help="require a secure connection",
    )
    parser.add_argument(
        "--use_alts",
        default=False,
        type=resources.parse_bool,
        help="require an ALTS secure connection",
    )
    parser.add_argument(
        "--use_test_ca",
        default=False,
        type=resources.parse_bool,
        help="replace platform root CAs with ca.pem",
    )
    parser.add_argument(
        "--custom_credentials_type",
        choices=["compute_engine_channel_creds"],
        default=None,
        help="use google default credentials",
    )
    parser.add_argument(
        "--server_host_override",
        type=str,
        help="the server host to which to claim to connect",
    )
    parser.add_argument(
        "--oauth_scope", type=str, help="scope for OAuth tokens"
    )
    parser.add_argument(
        "--default_service_account",
        type=str,
        help="email address of the default service account",
    )
    parser.add_argument(
        "--grpc_test_use_grpclb_with_child_policy",
        type=str,
        help=(
            "If non-empty, set a static service config on channels created by "
            + "grpc::CreateTestChannel, that configures the grpclb LB policy "
            + "with a child policy being the value of this flag (e.g."
            " round_robin "
            + "or pick_first)."
        ),
    )
    return parser.parse_args(argv[1:])


def _google_auth_request():
    """Return a new ``google.auth.transport.requests.Request``.

    Python does not load a package's submodules just because the package was
    imported, so the ``requests`` transport is imported explicitly here rather
    than reached via ``google_auth.transport.requests`` attribute access.
    Importing at call time keeps the dependency limited to the test cases
    that actually need it.
    """
    from google.auth.transport import requests as google_auth_requests

    return google_auth_requests.Request()


def _auth_metadata_call_credentials(google_credentials, request):
    """Wrap google-auth credentials as gRPC per-call metadata credentials.

    Args:
      google_credentials: A google-auth credentials object.
      request: The google-auth transport request handed to
        ``AuthMetadataPlugin`` (may be ``None``, as the JWT case passes).

    Returns:
      ``grpc.CallCredentials`` built from an ``AuthMetadataPlugin``.
    """
    # Explicit submodule import for the same reason as _google_auth_request.
    from google.auth.transport import grpc as google_auth_grpc

    return grpc.metadata_call_credentials(
        google_auth_grpc.AuthMetadataPlugin(
            credentials=google_credentials, request=request
        )
    )


def _compute_engine_call_credentials(args):
    """Build call credentials from Google default credentials.

    The credentials are scoped to ``args.oauth_scope``. Shared by the
    ``compute_engine_creds`` test case and the
    ``compute_engine_channel_creds`` custom channel credentials type.
    """
    google_credentials, unused_project_id = google_auth.default(
        scopes=[args.oauth_scope]
    )
    return _auth_metadata_call_credentials(
        google_credentials, _google_auth_request()
    )


def _create_call_credentials(args):
    """Return the per-call credentials required by ``args.test_case``.

    Returns:
      ``grpc.CallCredentials`` for the ``oauth2_auth_token``,
      ``compute_engine_creds`` and ``jwt_token_creds`` test cases, or
      ``None`` for every other test case.

    Raises:
      KeyError: For ``jwt_token_creds`` when the google-auth credentials
        environment variable is not set.
    """
    if args.test_case == "oauth2_auth_token":
        google_credentials, unused_project_id = google_auth.default(
            scopes=[args.oauth_scope]
        )
        # Fetch the token eagerly; only the raw access token is sent.
        google_credentials.refresh(_google_auth_request())
        return grpc.access_token_call_credentials(google_credentials.token)
    elif args.test_case == "compute_engine_creds":
        return _compute_engine_call_credentials(args)
    elif args.test_case == "jwt_token_creds":
        google_credentials = (
            google_auth_jwt.OnDemandCredentials.from_service_account_file(
                os.environ[google_auth.environment_vars.CREDENTIALS]
            )
        )
        return _auth_metadata_call_credentials(google_credentials, None)
    else:
        return None


def _grpclb_service_config(child_policy):
    """Return a service-config JSON string selecting grpclb with *child_policy*.

    ``json.dumps`` produces exactly the string the previous hand-formatted
    template produced for plain policy names, and also escapes names that
    would otherwise yield invalid JSON.
    """
    return json.dumps(
        {
            "loadBalancingConfig": [
                {"grpclb": {"childPolicy": [{child_policy: {}}]}}
            ]
        }
    )


def get_secure_channel_parameters(args):
    """Compute channel credentials and options for a secure channel.

    Exactly one of these is used, checked in order:
    ``--custom_credentials_type``, ``--use_tls``, ``--use_alts``.

    Args:
      args: Parsed flags from ``parse_interop_client_args``.

    Returns:
      A ``(channel_credentials, channel_options)`` tuple, where
      ``channel_options`` is a tuple of ``(key, value)`` pairs.

    Raises:
      ValueError: If the custom credentials type is unknown, if it is
        combined with a test case that needs its own call credentials, or if
        no secure transport was requested at all.
    """
    call_credentials = _create_call_credentials(args)

    channel_opts = ()
    if args.grpc_test_use_grpclb_with_child_policy:
        channel_opts += (
            (
                "grpc.service_config",
                _grpclb_service_config(
                    args.grpc_test_use_grpclb_with_child_policy
                ),
            ),
        )

    if args.custom_credentials_type is not None:
        if args.custom_credentials_type == "compute_engine_channel_creds":
            # This channel type carries its own call credentials; a test case
            # that also sets call credentials is a configuration error.
            # Checked explicitly because `assert` is stripped under -O.
            if call_credentials is not None:
                raise ValueError(
                    "Custom credentials type '{}' cannot be combined with "
                    "test case '{}', which sets its own call "
                    "credentials".format(
                        args.custom_credentials_type, args.test_case
                    )
                )
            channel_credentials = grpc.compute_engine_channel_credentials(
                _compute_engine_call_credentials(args)
            )
        else:
            raise ValueError(
                "Unknown credentials type '{}'".format(
                    args.custom_credentials_type
                )
            )
    elif args.use_tls:
        if args.use_test_ca:
            root_certificates = resources.test_root_certificates()
        else:
            root_certificates = None  # will load default roots.

        channel_credentials = grpc.ssl_channel_credentials(root_certificates)
        if call_credentials is not None:
            channel_credentials = grpc.composite_channel_credentials(
                channel_credentials, call_credentials
            )

        if args.server_host_override:
            channel_opts += (
                (
                    "grpc.ssl_target_name_override",
                    args.server_host_override,
                ),
            )
    elif args.use_alts:
        channel_credentials = grpc.alts_channel_credentials()
    else:
        # Previously this fell through to an UnboundLocalError.
        raise ValueError(
            "get_secure_channel_parameters requires --use_tls, --use_alts "
            "or --custom_credentials_type"
        )

    return channel_credentials, channel_opts


def _create_channel(args):
    """Open a secure or insecure channel to ``server_host:server_port``."""
    target = "{}:{}".format(args.server_host, args.server_port)

    if (
        args.use_tls
        or args.use_alts
        or args.custom_credentials_type is not None
    ):
        channel_credentials, options = get_secure_channel_parameters(args)
        return grpc.secure_channel(target, channel_credentials, options)
    else:
        return grpc.insecure_channel(target)


def create_stub(channel, args):
    """Return the stub type that ``args.test_case`` exercises on *channel*."""
    if args.test_case == "unimplemented_service":
        return test_pb2_grpc.UnimplementedServiceStub(channel)
    else:
        return test_pb2_grpc.TestServiceStub(channel)


def _test_case_from_arg(test_case_arg):
    """Return the ``methods.TestCase`` member whose value is *test_case_arg*.

    Raises:
      ValueError: If no member has that value.
    """
    for test_case in methods.TestCase:
        if test_case_arg == test_case.value:
            return test_case
    else:
        raise ValueError('No test case "%s"!' % test_case_arg)


def test_interoperability(args):
    """Run the single interop test case selected by ``args.test_case``."""
    channel = _create_channel(args)
    stub = create_stub(channel, args)
    test_case = _test_case_from_arg(args.test_case)
    test_case.test_interoperability(stub, args)


if __name__ == "__main__":
    app.run(test_interoperability, flags_parser=parse_interop_client_args)