1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3"""
4Copyright (c) 2024 Huawei Device Co., Ltd.
5Licensed under the Apache License, Version 2.0 (the "License");
6you may not use this file except in compliance with the License.
7You may obtain a copy of the License at
8
9    http://www.apache.org/licenses/LICENSE-2.0
10
11Unless required by applicable law or agreed to in writing, software
12distributed under the License is distributed on an "AS IS" BASIS,
13WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14See the License for the specific language governing permissions and
15limitations under the License.
16
17Description: Responsible for websocket communication.
18"""
19
20import asyncio
21import json
22import logging
23
24import websockets.protocol
25from websockets import connect, ConnectionClosed
26
27from aw.fport import Fport
28
29
class WebSocket(object):
    """Bridge websocket connections to asyncio queues for debugger tests.

    One connection is opened to the connect server. For every ``addInstance``
    message it reports (while fewer than ``debugger_server_connection_threshold``
    instances are connected) another connection is opened to a debugger server.
    Every connection is serviced by a sender task (drains a queue into the
    socket) and a receiver task (pushes incoming frames into a queue), so test
    procedures only interact with queues.
    """

    def __init__(self, connect_server_port, debugger_server_port):
        self.connect_server_port = connect_server_port
        # Starting port for the next debugger-server connection; advanced after each one.
        self.debugger_server_port = debugger_server_port
        # Maximum number of debugger-server clients connected at the same time.
        self.debugger_server_connection_threshold = 3

        # Created in main_task: the async queues must be initialized inside the running task.
        self.to_send_msg_queue_for_connect_server = None
        self.received_msg_queue_for_connect_server = None

        self.to_send_msg_queues = {}  # key: instance_id, value: to_send_msg_queue
        self.received_msg_queues = {}  # key: instance_id, value: received_msg_queue
        # Single-slot queue through which newly connected instance ids are handed to the procedure.
        self.debugger_server_instance = None

    @staticmethod
    async def recv_msg_of_debugger_server(instance_id, queue):
        """Return (and log) the next message placed on *queue* for *instance_id*.

        Messages put there by ``_receiver`` are the raw frames returned by
        ``client.recv()``; they are not JSON-decoded here.
        """
        message = await queue.get()
        queue.task_done()
        logging.info(f'[<==] Instance {instance_id} receive message: {message}')
        return message

    @staticmethod
    async def send_msg_to_debugger_server(instance_id, queue, message):
        """Queue *message* for the sender task of *instance_id*; always returns True.

        The sender JSON-encodes the message, except for the literal string
        ``'close'``, which closes the connection instead.
        """
        await queue.put(message)
        logging.info(f'[==>] Instance {instance_id} send message: {message}')
        return True

    @staticmethod
    def _ensure_open(client, owner):
        """Log and raise ``AssertionError`` unless *client* is in the OPEN state.

        Unlike a bare ``assert`` this check is not stripped under ``python -O``,
        and the raised error carries the diagnostic message.
        """
        if client.state != websockets.protocol.OPEN:
            message = f'Client state of {owner} is: {client.state}'
            logging.error(message)
            raise AssertionError(message)

    @staticmethod
    async def _sender(client, send_queue):
        """Forward messages from *send_queue* to *client* until told to stop.

        Each message is sent JSON-encoded. The literal string ``'close'`` is a
        sentinel: it closes the connection with reason ``'close'`` and ends the
        task instead of being sent.
        """
        WebSocket._ensure_open(client, '_sender')
        while True:
            send_message = await send_queue.get()
            send_queue.task_done()
            if send_message == 'close':
                await client.close(reason='close')
                return
            await client.send(json.dumps(send_message))

    @staticmethod
    async def _receiver(client, received_queue):
        """Push every frame received from *client* onto *received_queue*.

        Frames are queued as returned by ``client.recv()``. The task ends,
        after logging, when the connection is closed.
        """
        WebSocket._ensure_open(client, '_receiver')
        while True:
            try:
                response = await client.recv()
                await received_queue.put(response)
            except ConnectionClosed:
                logging.info('Debugger server connection closed')
                return

    async def get_instance(self):
        """Wait for and return the next instance id published by ``_store_instance``."""
        instance_id = await self.debugger_server_instance.get()
        self.debugger_server_instance.task_done()
        return instance_id

    async def recv_msg_of_connect_server(self):
        """Return the next raw message received from the connect server."""
        message = await self.received_msg_queue_for_connect_server.get()
        self.received_msg_queue_for_connect_server.task_done()
        return message

    async def send_msg_to_connect_server(self, message):
        """Queue *message* for the connect-server sender task; always returns True."""
        await self.to_send_msg_queue_for_connect_server.put(message)
        logging.info(f'[==>] Connect server send message: {message}')
        return True

    async def main_task(self, taskpool, procedure, pid):
        """Connect to the connect server and start its I/O tasks and *procedure*.

        :param taskpool: object whose ``submit`` schedules a coroutine.
        :param procedure: callable taking this ``WebSocket`` and returning the
            coroutine that drives the test.
        :param pid: passed through to ``Fport.fport_debugger_server``
            (presumably the debuggee's process id — confirm with Fport).
        """
        # the async queue must be initialized in task
        self.to_send_msg_queue_for_connect_server = asyncio.Queue()
        self.received_msg_queue_for_connect_server = asyncio.Queue()
        # maxsize=1: _store_instance waits until the procedure has taken the previous id.
        self.debugger_server_instance = asyncio.Queue(maxsize=1)

        connect_server_client = await self._connect_connect_server()
        taskpool.submit(self._sender(connect_server_client, self.to_send_msg_queue_for_connect_server))
        taskpool.submit(self._receiver_of_connect_server(connect_server_client,
                                                         self.received_msg_queue_for_connect_server,
                                                         taskpool, pid))
        taskpool.submit(procedure(self))

    def _connect_connect_server(self):
        """Return an awaitable ``connect`` handle for the connect server.

        Open timeout is 10 s; keepalive pings are disabled (``ping_interval=None``).
        """
        client = connect(f'ws://localhost:{self.connect_server_port}',
                         open_timeout=10,
                         ping_interval=None)
        return client

    def _connect_debugger_server(self):
        """Return an awaitable ``connect`` handle for the current debugger-server port.

        Open timeout is 6 s; keepalive pings are disabled (``ping_interval=None``).
        """
        client = connect(f'ws://localhost:{self.debugger_server_port}',
                         open_timeout=6,
                         ping_interval=None)
        return client

    async def _receiver_of_connect_server(self, client, receive_queue, taskpool, pid):
        """Relay connect-server messages and manage debugger-server connections.

        Every raw message is pushed onto *receive_queue*. Additionally,
        ``addInstance`` opens a debugger-server connection (while fewer than
        ``debugger_server_connection_threshold`` are active) and
        ``destroyInstance`` closes the matching one. Returns when the
        connect-server connection is closed.
        """
        self._ensure_open(client, '_receiver_of_connect_server')
        # Ids of instances that currently own a debugger-server connection. Tracking
        # ids instead of a bare counter lets destroyInstance skip instances that
        # were never connected (e.g. threshold reached) or were already destroyed.
        active_instances = set()
        while True:
            try:
                raw_response = await client.recv()
                await receive_queue.put(raw_response)
                logging.info(f'[<==] Connect server receive message: {raw_response}')
                response = json.loads(raw_response)
                message_type = response.get('type') if isinstance(response, dict) else None

                # The debugger server client is only responsible for adding and removing instances
                if (message_type == 'addInstance' and
                        len(active_instances) < self.debugger_server_connection_threshold):
                    instance_id = response['instanceId']
                    await self._add_debugger_instance(instance_id, taskpool, pid)
                    active_instances.add(instance_id)

                elif message_type == 'destroyInstance':
                    await self._close_debugger_instance(response['instanceId'], active_instances)

            except ConnectionClosed:
                logging.info('Connect server connection closed')
                return

    async def _add_debugger_instance(self, instance_id, taskpool, pid):
        """Connect a debugger-server client for *instance_id* and start its I/O tasks.

        Obtains a port from ``Fport.fport_debugger_server`` (a non-positive
        result is treated as failure), connects, registers the instance's
        send/receive queues, and finally publishes the id via ``_store_instance``.
        """
        port = Fport.fport_debugger_server(self.debugger_server_port, pid, instance_id)
        if port <= 0:
            message = ('Failed to fport debugger server for 3 times, '
                       'the port is very likely occupied')
            logging.error(message)
            raise AssertionError(message)
        self.debugger_server_port = port
        debugger_server_client = await self._connect_debugger_server()
        logging.info(f'InstanceId: {instance_id}, port: {self.debugger_server_port}, '
                     f'debugger server connected')
        # The next instance starts from the following port.
        self.debugger_server_port += 1

        to_send_msg_queue = asyncio.Queue()
        received_msg_queue = asyncio.Queue()
        self.to_send_msg_queues[instance_id] = to_send_msg_queue
        self.received_msg_queues[instance_id] = received_msg_queue
        taskpool.submit(coroutine=self._sender(debugger_server_client, to_send_msg_queue))
        taskpool.submit(coroutine=self._receiver(debugger_server_client, received_msg_queue))

        await self._store_instance(instance_id)

    async def _close_debugger_instance(self, instance_id, active_instances):
        """Ask the sender of *instance_id* to close its debugger-server connection.

        Returns True when the instance was active and ``'close'`` was queued.
        Returns False (after logging) when *instance_id* is not in
        *active_instances*, instead of failing on a missing queue.
        """
        if instance_id not in active_instances:
            logging.warning(f'Instance {instance_id} has no debugger server connection, '
                            f'ignore destroyInstance')
            return False
        active_instances.discard(instance_id)
        to_send_msg_queue = self.to_send_msg_queues[instance_id]
        await self.send_msg_to_debugger_server(instance_id, to_send_msg_queue, 'close')
        return True

    async def _store_instance(self, instance_id):
        """Publish *instance_id* for ``get_instance``; waits while the queue is full."""
        await self.debugger_server_instance.put(instance_id)
        return True
166