#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Copyright (c) 2024 Huawei Device Co., Ltd.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

Description: Responsible for websocket communication.
"""

import asyncio
import json
import logging

import websockets.protocol
from websockets import connect, ConnectionClosed

from aw.fport import Fport


class WebSocket(object):
    """Websocket bridge to one connect server and several debugger servers.

    Messages are exchanged with the background sender/receiver coroutines
    through asyncio queues: one queue pair for the connect server and one
    queue pair per debugger-server instance (keyed by instance id).
    """

    def __init__(self, connect_server_port, debugger_server_port):
        """Store the ports; the asyncio queues are created later in main_task.

        Args:
            connect_server_port: local port used for the connect server URL.
            debugger_server_port: first local port tried for a debugger server.
        """
        self.connect_server_port = connect_server_port
        self.debugger_server_port = debugger_server_port
        # Upper bound on simultaneously tracked debugger-server connections.
        self.debugger_server_connection_threshold = 3

        self.to_send_msg_queue_for_connect_server = None
        self.received_msg_queue_for_connect_server = None

        self.to_send_msg_queues = {}  # key: instance_id, value: to_send_msg_queue
        self.received_msg_queues = {}  # key: instance_id, value: received_msg_queue
        self.debugger_server_instance = None

    @staticmethod
    async def recv_msg_of_debugger_server(instance_id, queue):
        """Wait for the next message in ``queue``, log it and return it."""
        message = await queue.get()
        queue.task_done()
        logging.info(f'[<==] Instance {instance_id} receive message: {message}')
        return message

    @staticmethod
    async def send_msg_to_debugger_server(instance_id, queue, message):
        """Put ``message`` on ``queue`` for the sender coroutine; return True."""
        await queue.put(message)
        logging.info(f'[==>] Instance {instance_id} send message: {message}')
        return True

    @staticmethod
    def _ensure_open(client, owner):
        """Raise AssertionError (after logging) unless ``client`` is OPEN.

        An explicit check is used instead of an ``assert`` statement so the
        validation is not removed when Python runs with ``-O``.
        """
        if client.state == websockets.protocol.OPEN:
            return
        error_message = f'Client state of {owner} is: {client.state}'
        logging.error(error_message)
        raise AssertionError(error_message)

    @staticmethod
    async def _sender(client, send_queue):
        """Forward queued messages to ``client`` as JSON until told to close.

        The string 'close' is a sentinel: it closes the client and ends the
        coroutine instead of being sent.

        Raises:
            AssertionError: if the client is not in the OPEN state on entry.
        """
        WebSocket._ensure_open(client, '_sender')
        while True:
            send_message = await send_queue.get()
            send_queue.task_done()
            if send_message == 'close':
                await client.close(reason='close')
                return
            await client.send(json.dumps(send_message))

    @staticmethod
    async def _receiver(client, received_queue):
        """Copy every message received from ``client`` into ``received_queue``.

        Returns once the connection is closed.

        Raises:
            AssertionError: if the client is not in the OPEN state on entry.
        """
        WebSocket._ensure_open(client, '_receiver')
        while True:
            try:
                response = await client.recv()
                await received_queue.put(response)
            except ConnectionClosed:
                logging.info('Debugger server connection closed')
                return

    async def get_instance(self):
        """Wait for and return the next stored debugger-server instance id."""
        instance_id = await self.debugger_server_instance.get()
        self.debugger_server_instance.task_done()
        return instance_id

    async def recv_msg_of_connect_server(self):
        """Wait for and return the next message received from the connect server."""
        message = await self.received_msg_queue_for_connect_server.get()
        self.received_msg_queue_for_connect_server.task_done()
        return message

    async def send_msg_to_connect_server(self, message):
        """Queue ``message`` for sending to the connect server; return True."""
        await self.to_send_msg_queue_for_connect_server.put(message)
        logging.info(f'[==>] Connect server send message: {message}')
        return True

    async def main_task(self, taskpool, procedure, pid):
        """Connect to the connect server and submit the background coroutines.

        Args:
            taskpool: object whose ``submit`` accepts a coroutine.
            procedure: callable taking this WebSocket and returning a coroutine.
            pid: passed through to the connect-server receiver for port forwarding.
        """
        # the async queue must be initialized in task
        self.to_send_msg_queue_for_connect_server = asyncio.Queue()
        self.received_msg_queue_for_connect_server = asyncio.Queue()
        self.debugger_server_instance = asyncio.Queue(maxsize=1)

        connect_server_client = await self._connect_connect_server()
        taskpool.submit(self._sender(connect_server_client, self.to_send_msg_queue_for_connect_server))
        taskpool.submit(self._receiver_of_connect_server(connect_server_client,
                                                         self.received_msg_queue_for_connect_server,
                                                         taskpool, pid))
        taskpool.submit(procedure(self))

    def _connect_connect_server(self):
        """Return the (not yet awaited) connection to the connect server."""
        client = connect(f'ws://localhost:{self.connect_server_port}',
                         open_timeout=10,
                         ping_interval=None)
        return client

    def _connect_debugger_server(self):
        """Return the (not yet awaited) connection to the debugger server."""
        client = connect(f'ws://localhost:{self.debugger_server_port}',
                         open_timeout=6,
                         ping_interval=None)
        return client

    async def _add_debugger_instance(self, instance_id, taskpool, pid):
        """Connect to the debugger server of ``instance_id`` and register its queues.

        Raises:
            AssertionError: if Fport.fport_debugger_server does not return a
                positive port.
        """
        # NOTE(review): presumably forwards a local port to the instance's
        # debugger server and returns the port actually used - confirm in Fport.
        port = Fport.fport_debugger_server(self.debugger_server_port, pid, instance_id)
        if not port > 0:
            error_message = ('Failed to fport debugger server for 3 times, '
                             'the port is very likely occupied')
            logging.error(error_message)
            raise AssertionError(error_message)
        self.debugger_server_port = port
        debugger_server_client = await self._connect_debugger_server()
        logging.info(f'InstanceId: {instance_id}, port: {self.debugger_server_port}, '
                     f'debugger server connected')
        # The next instance starts from the following port.
        self.debugger_server_port += 1

        to_send_msg_queue = asyncio.Queue()
        received_msg_queue = asyncio.Queue()
        self.to_send_msg_queues[instance_id] = to_send_msg_queue
        self.received_msg_queues[instance_id] = received_msg_queue
        taskpool.submit(coroutine=self._sender(debugger_server_client, to_send_msg_queue))
        taskpool.submit(coroutine=self._receiver(debugger_server_client, received_msg_queue))

        await self._store_instance(instance_id)

    async def _receiver_of_connect_server(self, client, receive_queue, taskpool, pid):
        """Receive connect-server messages and track debugger-server instances.

        Every raw message is put on ``receive_queue``. 'addInstance' opens a
        debugger-server connection (up to the connection threshold) and
        'destroyInstance' asks the matching sender to close its connection.
        Returns once the connect-server connection is closed.

        Raises:
            AssertionError: if the client is not OPEN on entry, or if port
                forwarding for a new instance fails.
        """
        self._ensure_open(client, '_receiver_of_connect_server')
        num_debugger_server_client = 0
        while True:
            try:
                response = await client.recv()
                await receive_queue.put(response)
                logging.info(f'[<==] Connect server receive message: {response}')
                response = json.loads(response)

                # The debugger server client is only responsible for adding and removing instances
                if (response['type'] == 'addInstance' and
                        num_debugger_server_client < self.debugger_server_connection_threshold):
                    await self._add_debugger_instance(response['instanceId'], taskpool, pid)
                    num_debugger_server_client += 1

                elif response['type'] == 'destroyInstance':
                    instance_id = response['instanceId']
                    to_send_msg_queue = self.to_send_msg_queues.get(instance_id)
                    if to_send_msg_queue is None:
                        # The instance was never registered (for example its
                        # addInstance was skipped because the threshold was
                        # reached); there is no connection to close.
                        logging.warning(f'Instance {instance_id} is not tracked, destroyInstance ignored')
                        continue
                    await self.send_msg_to_debugger_server(instance_id, to_send_msg_queue, 'close')
                    num_debugger_server_client -= 1

            except ConnectionClosed:
                logging.info('Connect server connection closed')
                return

    async def _store_instance(self, instance_id):
        """Put ``instance_id`` on the instance queue (waits while it is full)."""
        await self.debugger_server_instance.put(instance_id)
        return True