1e509ee18Sopenharmony_ci#!/usr/bin/env python3
2e509ee18Sopenharmony_ci# -*- coding: utf-8 -*-
3e509ee18Sopenharmony_ci"""
4e509ee18Sopenharmony_ciCopyright (c) 2024 Huawei Device Co., Ltd.
5e509ee18Sopenharmony_ciLicensed under the Apache License, Version 2.0 (the "License");
6e509ee18Sopenharmony_ciyou may not use this file except in compliance with the License.
7e509ee18Sopenharmony_ciYou may obtain a copy of the License at
8e509ee18Sopenharmony_ci
9e509ee18Sopenharmony_ci    http://www.apache.org/licenses/LICENSE-2.0
10e509ee18Sopenharmony_ci
11e509ee18Sopenharmony_ciUnless required by applicable law or agreed to in writing, software
12e509ee18Sopenharmony_cidistributed under the License is distributed on an "AS IS" BASIS,
13e509ee18Sopenharmony_ciWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14e509ee18Sopenharmony_ciSee the License for the specific language governing permissions and
15e509ee18Sopenharmony_cilimitations under the License.
16e509ee18Sopenharmony_ci
17e509ee18Sopenharmony_ciDescription: Responsible for websocket communication.
18e509ee18Sopenharmony_ci"""
19e509ee18Sopenharmony_ci
20e509ee18Sopenharmony_ciimport asyncio
21e509ee18Sopenharmony_ciimport json
22e509ee18Sopenharmony_ciimport logging
23e509ee18Sopenharmony_ci
24e509ee18Sopenharmony_ciimport websockets.protocol
25e509ee18Sopenharmony_cifrom websockets import connect, ConnectionClosed
26e509ee18Sopenharmony_ci
27e509ee18Sopenharmony_cifrom aw.fport import Fport
28e509ee18Sopenharmony_ci
29e509ee18Sopenharmony_ci
class WebSocket(object):
    """Bridge between a test procedure and the connect/debugger websocket servers.

    All traffic goes through asyncio queues. Callers put outgoing messages on a
    "to send" queue, and a ``_sender`` task drains that queue into the websocket.
    ``_receiver`` tasks push every incoming raw message onto a "received" queue.
    One connection is made to the connect server. One more connection is made
    to a debugger server for each ``addInstance`` notification, with at most
    ``debugger_server_connection_threshold`` of them open at a time.
    """

    def __init__(self, connect_server_port, debugger_server_port):
        """
        Args:
            connect_server_port: local port of the connect server
                (connected as ``ws://localhost:<port>``).
            debugger_server_port: first local port tried when forwarding a
                debugger server; it is advanced as instances get connected.
        """
        self.connect_server_port = connect_server_port
        self.debugger_server_port = debugger_server_port
        # Maximum number of debugger server connections open at the same time.
        self.debugger_server_connection_threshold = 3

        # Created in main_task(), i.e. inside the running event loop.
        self.to_send_msg_queue_for_connect_server = None
        self.received_msg_queue_for_connect_server = None

        self.to_send_msg_queues = {}  # key: instance_id, value: to_send_msg_queue
        self.received_msg_queues = {}  # key: instance_id, value: received_msg_queue
        # Queue(maxsize=1) created in main_task(); carries newly connected instance ids.
        self.debugger_server_instance = None

    @staticmethod
    def _check_client_open(client, owner):
        """Log and raise ``AssertionError`` unless ``client`` is in the OPEN state.

        Unlike ``assert cond, logging.error(...)``, this check also runs under
        ``python -O``. It carries a real message; the old form always raised
        with a ``None`` message.
        """
        if client.state != websockets.protocol.OPEN:
            message = f'Client state of {owner} is: {client.state}'
            logging.error(message)
            raise AssertionError(message)

    @staticmethod
    async def recv_msg_of_debugger_server(instance_id, queue):
        """Wait for the next raw message received for ``instance_id`` and log it."""
        message = await queue.get()
        queue.task_done()
        logging.info(f'[<==] Instance {instance_id} receive message: {message}')
        return message

    @staticmethod
    async def send_msg_to_debugger_server(instance_id, queue, message):
        """Queue ``message`` for the sender task of ``instance_id``; always returns True."""
        await queue.put(message)
        logging.info(f'[==>] Instance {instance_id} send message: {message}')
        return True

    @staticmethod
    async def _sender(client, send_queue):
        """Forward messages from ``send_queue`` to ``client`` until told to stop.

        Each message is JSON-encoded with ``json.dumps`` before sending. The
        string ``'close'`` is a sentinel: it closes the connection with reason
        ``'close'`` and ends the task.
        """
        WebSocket._check_client_open(client, '_sender')
        while True:
            send_message = await send_queue.get()
            send_queue.task_done()
            if send_message == 'close':
                await client.close(reason='close')
                return
            await client.send(json.dumps(send_message))

    @staticmethod
    async def _receiver(client, received_queue):
        """Put every raw message from ``client`` on ``received_queue`` until the connection closes."""
        WebSocket._check_client_open(client, '_receiver')
        while True:
            try:
                response = await client.recv()
                await received_queue.put(response)
            except ConnectionClosed:
                logging.info('Debugger server connection closed')
                return

    async def get_instance(self):
        """Wait for and return the id of the next newly connected instance."""
        instance_id = await self.debugger_server_instance.get()
        self.debugger_server_instance.task_done()
        return instance_id

    async def recv_msg_of_connect_server(self):
        """Wait for and return the next raw message received from the connect server."""
        message = await self.received_msg_queue_for_connect_server.get()
        self.received_msg_queue_for_connect_server.task_done()
        return message

    async def send_msg_to_connect_server(self, message):
        """Queue ``message`` for the connect server sender task; always returns True."""
        await self.to_send_msg_queue_for_connect_server.put(message)
        logging.info(f'[==>] Connect server send message: {message}')
        return True

    async def main_task(self, taskpool, procedure, pid):
        """Connect to the connect server and start the sender, receiver and procedure tasks.

        Args:
            taskpool: object whose ``submit`` accepts a coroutine to run.
            procedure: callable taking this ``WebSocket`` and returning a coroutine.
            pid: forwarded to ``Fport.fport_debugger_server`` (presumably the
                debugged process id; confirm against Fport).
        """
        # The async queues must be created inside the running task (asyncio.Queue
        # was bound to the event loop at construction before Python 3.10).
        self.to_send_msg_queue_for_connect_server = asyncio.Queue()
        self.received_msg_queue_for_connect_server = asyncio.Queue()
        # maxsize=1: _store_instance waits until the previous id was consumed.
        self.debugger_server_instance = asyncio.Queue(maxsize=1)

        connect_server_client = await self._connect_connect_server()
        taskpool.submit(self._sender(connect_server_client, self.to_send_msg_queue_for_connect_server))
        taskpool.submit(self._receiver_of_connect_server(connect_server_client,
                                                         self.received_msg_queue_for_connect_server,
                                                         taskpool, pid))
        taskpool.submit(procedure(self))

    def _connect_connect_server(self):
        """Return the awaitable ``connect`` handshake for the connect server."""
        client = connect(f'ws://localhost:{self.connect_server_port}',
                         open_timeout=10,
                         ping_interval=None)
        return client

    def _connect_debugger_server(self):
        """Return the awaitable ``connect`` handshake for the current debugger server port."""
        client = connect(f'ws://localhost:{self.debugger_server_port}',
                         open_timeout=6,
                         ping_interval=None)
        return client

    async def _receiver_of_connect_server(self, client, receive_queue, taskpool, pid):
        """Forward connect-server messages and manage debugger server connections.

        Every raw message is first put on ``receive_queue``. Then:

        * ``addInstance``: if fewer than ``debugger_server_connection_threshold``
          debugger servers are connected, open one for that instance.
        * ``destroyInstance``: close that instance's debugger server connection
          if one was opened; otherwise the message is ignored.

        Returns when the connect server connection is closed.
        """
        self._check_client_open(client, '_receiver_of_connect_server')
        # Instance ids that currently own a debugger server connection.
        connected_instances = set()
        while True:
            try:
                response = await client.recv()
                await receive_queue.put(response)
                logging.info(f'[<==] Connect server receive message: {response}')
                response = json.loads(response)
                msg_type = response.get('type') if isinstance(response, dict) else None

                # The debugger server client is only responsible for adding and removing instances
                if (msg_type == 'addInstance' and
                        len(connected_instances) < self.debugger_server_connection_threshold):
                    instance_id = response['instanceId']
                    await self._add_debugger_server_instance(instance_id, taskpool, pid)
                    connected_instances.add(instance_id)

                elif msg_type == 'destroyInstance':
                    await self._destroy_debugger_server_instance(response['instanceId'],
                                                                 connected_instances)

            except ConnectionClosed:
                logging.info('Connect server connection closed')
                return

    async def _add_debugger_server_instance(self, instance_id, taskpool, pid):
        """Forward a port, connect to the instance's debugger server, and start its I/O tasks.

        Creates the per-instance send/receive queues and submits a ``_sender``
        and a ``_receiver`` task to ``taskpool``. Finally it publishes
        ``instance_id`` for ``get_instance``; this may wait until the previous
        id has been consumed.

        Raises:
            AssertionError: if the port forwarding reports failure (port <= 0).
        """
        port = Fport.fport_debugger_server(self.debugger_server_port, pid, instance_id)
        if port <= 0:
            message = ('Failed to fport debugger server for 3 times, '
                       'the port is very likely occupied')
            logging.error(message)
            raise AssertionError(message)
        self.debugger_server_port = port
        debugger_server_client = await self._connect_debugger_server()
        logging.info(f'InstanceId: {instance_id}, port: {self.debugger_server_port}, '
                     f'debugger server connected')
        # The next instance starts probing from the following port.
        self.debugger_server_port += 1

        to_send_msg_queue = asyncio.Queue()
        received_msg_queue = asyncio.Queue()
        self.to_send_msg_queues[instance_id] = to_send_msg_queue
        self.received_msg_queues[instance_id] = received_msg_queue
        taskpool.submit(coroutine=self._sender(debugger_server_client, to_send_msg_queue))
        taskpool.submit(coroutine=self._receiver(debugger_server_client, received_msg_queue))

        await self._store_instance(instance_id)

    async def _destroy_debugger_server_instance(self, instance_id, connected_instances):
        """Ask the sender task of ``instance_id`` to close its debugger server connection.

        ``connected_instances`` is the set of ids that currently have a
        connection; ``instance_id`` is removed from it. Some ids have no
        connection, e.g. instances skipped because the threshold was reached,
        or ids that were already destroyed. Those are logged and ignored
        instead of raising ``KeyError``.

        Returns:
            True if a close request was queued, False if the id was ignored.
        """
        if instance_id not in connected_instances:
            logging.warning(f'Instance {instance_id} has no debugger server connection, '
                            f'ignore destroyInstance')
            return False
        connected_instances.remove(instance_id)
        to_send_msg_queue = self.to_send_msg_queues[instance_id]
        await self.send_msg_to_debugger_server(instance_id, to_send_msg_queue, 'close')
        return True

    async def _store_instance(self, instance_id):
        """Publish ``instance_id`` for ``get_instance``; waits while the queue is full."""
        await self.debugger_server_instance.put(instance_id)
        return True
166