diff options
Diffstat (limited to 'grpc/tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc_channelz.py')
-rw-r--r-- | grpc/tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc_channelz.py | 61 |
1 files changed, 36 insertions, 25 deletions
diff --git a/grpc/tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc_channelz.py b/grpc/tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc_channelz.py index 3bf4b261..b4e6b187 100644 --- a/grpc/tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc_channelz.py +++ b/grpc/tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc_channelz.py @@ -95,22 +95,25 @@ class ChannelzServiceClient(framework.rpc.grpc.GrpcClientHelper): return server_socket return None - def find_channels_for_target(self, target: str) -> Iterator[Channel]: - return (channel for channel in self.list_channels() + def find_channels_for_target(self, target: str, + **kwargs) -> Iterator[Channel]: + return (channel for channel in self.list_channels(**kwargs) if channel.data.target == target) - def find_server_listening_on_port(self, port: int) -> Optional[Server]: - for server in self.list_servers(): + def find_server_listening_on_port(self, port: int, + **kwargs) -> Optional[Server]: + for server in self.list_servers(**kwargs): listen_socket_ref: SocketRef for listen_socket_ref in server.listen_socket: - listen_socket = self.get_socket(listen_socket_ref.socket_id) + listen_socket = self.get_socket(listen_socket_ref.socket_id, + **kwargs) listen_address: Address = listen_socket.local if (self.is_sock_tcpip_address(listen_address) and listen_address.tcpip_address.port == port): return server return None - def list_channels(self) -> Iterator[Channel]: + def list_channels(self, **kwargs) -> Iterator[Channel]: """ Iterate over all pages of all root channels. @@ -125,12 +128,13 @@ class ChannelzServiceClient(framework.rpc.grpc.GrpcClientHelper): start += 1 response = self.call_unary_with_deadline( rpc='GetTopChannels', - req=_GetTopChannelsRequest(start_channel_id=start)) + req=_GetTopChannelsRequest(start_channel_id=start), + **kwargs) for channel in response.channel: start = max(start, channel.ref.channel_id) yield channel - def list_servers(self) -> Iterator[Server]: + def list_servers(self, **kwargs) -> Iterator[Server]: """Iterate over all pages of all servers that exist in the process.""" start: int = -1 response: Optional[_GetServersResponse] = None @@ -139,12 +143,14 @@ class ChannelzServiceClient(framework.rpc.grpc.GrpcClientHelper): # value by adding 1 to the highest seen result ID. start += 1 response = self.call_unary_with_deadline( - rpc='GetServers', req=_GetServersRequest(start_server_id=start)) + rpc='GetServers', + req=_GetServersRequest(start_server_id=start), + **kwargs) for server in response.server: start = max(start, server.ref.server_id) yield server - def list_server_sockets(self, server: Server) -> Iterator[Socket]: + def list_server_sockets(self, server: Server, **kwargs) -> Iterator[Socket]: """List all server sockets that exist in server process. Iterating over the results will resolve additional pages automatically. @@ -158,39 +164,44 @@ class ChannelzServiceClient(framework.rpc.grpc.GrpcClientHelper): response = self.call_unary_with_deadline( rpc='GetServerSockets', req=_GetServerSocketsRequest(server_id=server.ref.server_id, - start_socket_id=start)) + start_socket_id=start), + **kwargs) socket_ref: SocketRef for socket_ref in response.socket_ref: start = max(start, socket_ref.socket_id) # Yield actual socket - yield self.get_socket(socket_ref.socket_id) + yield self.get_socket(socket_ref.socket_id, **kwargs) - def list_channel_sockets(self, channel: Channel) -> Iterator[Socket]: + def list_channel_sockets(self, channel: Channel, + **kwargs) -> Iterator[Socket]: """List all sockets of all subchannels of a given channel.""" - for subchannel in self.list_channel_subchannels(channel): - yield from self.list_subchannels_sockets(subchannel) + for subchannel in self.list_channel_subchannels(channel, **kwargs): + yield from self.list_subchannels_sockets(subchannel, **kwargs) - def list_channel_subchannels(self, - channel: Channel) -> Iterator[Subchannel]: + def list_channel_subchannels(self, channel: Channel, + **kwargs) -> Iterator[Subchannel]: """List all subchannels of a given channel.""" for subchannel_ref in channel.subchannel_ref: - yield self.get_subchannel(subchannel_ref.subchannel_id) + yield self.get_subchannel(subchannel_ref.subchannel_id, **kwargs) - def list_subchannels_sockets(self, - subchannel: Subchannel) -> Iterator[Socket]: + def list_subchannels_sockets(self, subchannel: Subchannel, + **kwargs) -> Iterator[Socket]: """List all sockets of a given subchannel.""" for socket_ref in subchannel.socket_ref: - yield self.get_socket(socket_ref.socket_id) + yield self.get_socket(socket_ref.socket_id, **kwargs) - def get_subchannel(self, subchannel_id) -> Subchannel: + def get_subchannel(self, subchannel_id, **kwargs) -> Subchannel: """Return a single Subchannel, otherwise raises RpcError.""" response: _GetSubchannelResponse = self.call_unary_with_deadline( rpc='GetSubchannel', - req=_GetSubchannelRequest(subchannel_id=subchannel_id)) + req=_GetSubchannelRequest(subchannel_id=subchannel_id), + **kwargs) return response.subchannel - def get_socket(self, socket_id) -> Socket: + def get_socket(self, socket_id, **kwargs) -> Socket: """Return a single Socket, otherwise raises RpcError.""" response: _GetSocketResponse = self.call_unary_with_deadline( - rpc='GetSocket', req=_GetSocketRequest(socket_id=socket_id)) + rpc='GetSocket', + req=_GetSocketRequest(socket_id=socket_id), + **kwargs) return response.socket |