1#!/usr/bin/env python3
2# coding=utf-8
3
4#
5# Copyright (c) 2022 Huawei Device Co., Ltd.
6# Licensed under the Apache License, Version 2.0 (the "License");
7# you may not use this file except in compliance with the License.
8# You may obtain a copy of the License at
9#
10#     http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing, software
13# distributed under the License is distributed on an "AS IS" BASIS,
14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15# See the License for the specific language governing permissions and
16# limitations under the License.
17#
18
19import enum
20import time
21
22try:
23    from PIL import Image
24    from PIL import ImageDraw
25except ImportError:
26    pass
27from devicetest.utils.time_util import TS
28
29
30class ImgLoc(enum.Enum):
31    '''
32    图片所在的枚举类
33    1左上 (North_West)
34    2中上 (North)
35    3右上 (Nort_East)
36    4左中 (West)
37    5全图 (All)
38    6右中 (East)
39    7左下 (South_West)
40    8中下 (South)
41    9右下(South_East)
42    '''
43
44    # 在大图的左上内
45    North_West = 1
46    # 代表子图部分全部位于大图的水平中线以上
47    North = 2
48    # 在大图的右上内
49    North_East = 3
50    # 代表子图部分全部位于大图的垂直中线左
51    West = 4
52    # 全图查找,当小图跟大图的水平中线, 且与锤子中线都有交集,就只能用该参数了
53    All = 5
54    # 代表子图部分全部位于大图的垂直中线右
55    East = 6
56    # 在大图的左下内
57    South_West = 7
58    # 代表子图部分全部位于大图的水平中线以下
59    South = 8
60    # 在大图的右下内
61    South_East = 9
62
63
64class ImgUtils:
65    @staticmethod
66    def img2arr(arr_img, rect=None, convert=True):
67        '''
68        @将位图流转化为二维二值数组
69        @param arr_img: instance of Image
70        '''
71        if convert and arr_img.mode != 'L':
72            arr_img = arr_img.convert('L')
73
74        width, height = arr_img.size
75        pix = arr_img.load()
76        if rect:
77            rect_l, rect_t, rect_r, rect_b = rect
78            result_list = []
79            for pix_h in range(height):
80                if not rect_t <= pix_h <= rect_b + 1:
81                    continue
82                temp_list = []
83                for pix_w in range(width):
84                    if not rect_l <= pix_w <= rect_r + 1:
85                        continue
86                    temp_list.append(pix[pix_w, pix_h])
87                result_list.append(temp_list)
88            return result_list
89
90        result_list = []
91        for pix_h in range(height):
92            temp_list = []
93            for pix_w in range(width):
94                temp_list.append(pix[pix_w, pix_h])
95            result_list.append(temp_list)
96
97        return result_list
98
99    @staticmethod
100    def get_rect(rect_width, rect_height, location):
101        '''
102        根据方位对象,获取图片的方位
103        '''
104        rect = (0, 0, rect_width, rect_height)
105        if location == ImgLoc.East.value:
106            rect = (int(rect_width >> 1), 0, rect_width, rect_height)
107        elif location == ImgLoc.South.value:
108            rect = (0, int(rect_height >> 1), rect_width, rect_height)
109        elif location == ImgLoc.West.value:
110            rect = (0, 0, int(rect_width >> 1), rect_height)
111        elif location == ImgLoc.North.value:
112            rect = (0, 0, rect_width, int(rect_height >> 1))
113
114        elif location == ImgLoc.North_East.value:
115            rect = (int(rect_width >> 1), 0,
116                    rect_width, int(rect_height >> 1))
117        elif location == ImgLoc.South_East.value:
118            rect = (int(rect_width >> 1),
119                    int(rect_height >> 1), rect_width, rect_height)
120        elif location == ImgLoc.North_West.value:
121            rect = (0, 0, int(rect_width >> 1), int(rect_height >> 1))
122        elif location == ImgLoc.South_West.value:
123            rect = (0, int(rect_height >> 1),
124                    int(rect_width >> 1), rect_height)
125
126        return rect
127
128    @staticmethod
129    def quick_find(fp1, fp2, similar=1, density=None,
130                   rect=None, loc=ImgLoc.All.value, debug=False):
131        '''
132        快速查找图片,指定的similar越大,速度越快
133        1. 如果similar不等于1,则使用density来加快查找速度 density(x,y)
134        表示对比的时候 每个横坐标上只对比 (width / 2^x)个点
135        每个纵坐标上只对比 height / 2^y个点,相当于只对比原来的
136        (width + height) / 2^(x+y) 个点
137        @param fp1: 大图片的绝对路径
138        @param fp2: 小图片的绝对路径
139        @param similar: 对比的相似度;如 0.7, 0.9, 1
140        @param density: 小图中对比的密度: (2,3) 表示每行对比 width >> 2个点;
141        没列对比 height >> 3个点
142        @param rect: 大图中指定区域内查找 (left,top,right,bottom)
143        @param loc: 小图在大图中的什么部位,是一个枚举对象,ImgLoc,注意需要
144        加.value;如: ImgLoc.North.value 或者 直接输入 1 -9 的数字也行
145        @param debug: 是否打印debug信息
146        '''
147        if debug:
148            TS.start()
149        _m1 = Image.open(fp1)
150        _m2 = Image.open(fp2)
151
152        m1_w, m1_h = _m1.size
153        if not rect:
154            rect = ImgUtils.get_rect(m1_w, m1_h, loc)
155
156        data1 = ImgUtils.img2arr(_m1.crop(rect) if rect else _m1)
157        data2 = ImgUtils.img2arr(_m2)
158        if debug:
159            TS.stop("before find_arr")
160        return ImgUtils.find_arr(data1, data2, similar, density, rect, debug)
161
162    @staticmethod
163    def find_arr(im1, im2, similar=1, density=None, rect=None, debug=False):
164        '''
165        在大图中查找小图
166        注意:如果density值为None,则系统自动设置,保证特征点在9 - 16个左右
167        (即 3 * 3 或 4 * 4之间)
168        @param im1 大图的二维数组
169        @param im2 小图的二维数组
170        @param similar 相似度
171        @param density (x,y)  x: 可以控制小图横坐标查找的点数
172        im2Width >> x 个点数
173        @param rect 在指定的区域中查找图片 (若指定,则可以大大节省时间)
174        (leftX,topY,rihgtX,bottomY)
175        @return (rect,similar) rect:找到的图片位置; similar:相似度
176        '''
177        if debug:
178            TS.start()
179
180        m2_width = len(im2[0])
181        m2_height = len(im2)
182        arr_width = len(im1[0]) - m2_width + 1
183        arr_height = len(im1) - m2_height + 1
184
185        denx, deny = 0, 0
186        if not density:
187            denx, deny = ImgUtils.get_density(m2_width, m2_height)
188        else:
189            denx, deny = density
190        den_yy = int(m2_height >> deny)
191        den_xx = int(m2_width >> denx)
192
193        total = den_yy * den_xx
194        if total == 0:
195            total = 1
196        max_fail_num = (1 - similar) * total
197        if debug:
198            print("denXX: %i; denYY: %i; total: %i" % (
199                den_xx, den_yy, total))
200            print("maxFailNum %i" % max_fail_num)
201        starttime = time.time()
202        endtime = starttime + 5.0 * 60.0
203        for arr_h in range(arr_height):
204            for arr_w in range(arr_width):
205                # 对图片对比设置超时限制
206                if time.time() <= endtime:
207                    # 1. 对比当前位置的图片是否符合要求
208                    fail_num = 0
209                    found = True
210                    for _yy in range(den_yy):
211                        for _xx in range(den_xx):
212                            x_den = _xx << denx
213                            y_den = _yy << deny
214                            m2_val = im2[y_den][x_den]
215                            m1_val = im1[arr_h + y_den][x_den + arr_w]
216                            if m1_val != m2_val:
217                                fail_num += 1
218                                if max_fail_num <= fail_num:
219                                    found = False
220                                    break
221                        if not found:
222                            break
223                    if found:
224                        if debug:
225                            TS.stop("find_arr")
226                        if rect:
227                            # @UnusedVariable
228                            rect_l, rect_t, rect_r, rect_b = rect
229                            return (1 - fail_num / total), (
230                                arr_w + rect_l, arr_h + rect_t, arr_w +
231                                m2_width + rect_l, arr_h + m2_height + rect_t)
232                        return (1 - fail_num / total), (
233                            arr_w, arr_h, arr_w + m2_width, arr_h + m2_height)
234                else:
235                    return None, None
236        if debug:
237            TS.stop("find_arr")
238        return None, None
239
240    @staticmethod
241    def img_filter(filter_img, *filters):
242        last_img = filter_img
243        for _filter in filters:
244            last_img = last_img.filter(_filter)
245        return last_img
246
247    @staticmethod
248    def get_density(width, height, maxWNum=4, maxHNum=4):
249        denx, deny = 0, 0
250        while width > maxWNum:
251            denx += 1
252            width = int(width >> 1)
253        while height > maxHNum:
254            deny += 1
255            height = int(height >> 1)
256        return denx, deny
257