1#!/usr/bin/env python3
2# vim: set expandtab shiftwidth=4:
3# -*- Mode: python; coding: utf-8; indent-tabs-mode: nil -*- */
4#
5# Copyright © 2018 Red Hat, Inc.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a
8# copy of this software and associated documentation files (the "Software"),
9# to deal in the Software without restriction, including without limitation
10# the rights to use, copy, modify, merge, publish, distribute, sublicense,
11# and/or sell copies of the Software, and to permit persons to whom the
12# Software is furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice (including the next
15# paragraph) shall be included in all copies or substantial portions of the
16# Software.
17#
18# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
21# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
23# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
24# DEALINGS IN THE SOFTWARE.
25
26import os
27import sys
28import time
29import math
30import multiprocessing
31import argparse
32from pathlib import Path
33
34try:
35    import libevdev
36    import yaml
37    import pyudev
38except ModuleNotFoundError as e:
39    print("Error: {}".format(e), file=sys.stderr)
40    print(
41        "One or more python modules are missing. Please install those "
42        "modules and re-run this tool."
43    )
44    sys.exit(1)
45
46
# The only recording "version" accepted; check_file() rejects any other value.
SUPPORTED_FILE_VERSION = 1
48
49
def error(msg, **kwargs):
    """Print *msg* to stderr; any extra keyword arguments go to print()."""
    print(msg, file=sys.stderr, **kwargs)
52
53
class YamlException(Exception):
    """Raised when the recording lacks required data or is malformed."""
56
57
def fetch(yaml, key):
    """Helper function to avoid confusing a YAML error with a
    normal KeyError bug.

    Returns ``yaml[key]``. A missing key, an out-of-range index, or a
    section that is not indexable at all (e.g. a key present in the YAML
    but left empty, which loads as None) is reported as a YamlException.
    In all of these cases the recording, not the code, is at fault.

    Raises:
        YamlException: if *key* cannot be looked up in *yaml*.
    """
    try:
        return yaml[key]
    except (KeyError, IndexError, TypeError) as e:
        raise YamlException("Failed to get '{}' from recording.".format(key)) from e
65
66
def check_udev_properties(yaml_data, uinput):
    """
    Compare the properties our new uinput device has with the ones from the
    recording and ring the alarm bell if one of them is off.

    *yaml_data* is the recorded device; its ``udev`` -> ``properties`` list
    holds ``NAME=value`` strings. *uinput* is the device created from it;
    its ``devnode`` is used to look the device up in udev.

    Every finding is printed as a warning on stderr:
    - a value mismatch;
    - a recorded property the device lacks;
    - a device property the recording lacks, but only for the prefixes
      libinput-record stores.

    Raises:
        YamlException: if the recording has no udev properties section.
    """
    # The list of properties we add to the recording, see libinput-record.c
    recorded_prefixes = (
        "ID_INPUT",
        "LIBINPUT",
        "EVDEV_ABS",
        "MOUSE_DPI",
        "POINTINGSTICK_",
    )

    yaml_udev_section = fetch(yaml_data, "udev")
    yaml_udev_props = fetch(yaml_udev_section, "properties")
    yaml_props = {}
    for prop in yaml_udev_props:
        name, sep, value = prop.partition("=")
        if not sep:
            # This check only ever warns; don't abort the whole replay
            # because of one malformed recorded entry.
            error(f"Warning: ignoring malformed recorded udev property: {prop}")
            continue
        yaml_props[name] = value

    # We don't assign this one to virtual devices
    yaml_props.pop("LIBINPUT_DEVICE_GROUP", None)

    # give udev some time to catch up
    time.sleep(0.2)
    context = pyudev.Context()
    udev_device = pyudev.Devices.from_device_file(context, uinput.devnode)
    for name, value in udev_device.properties.items():
        if name in yaml_props:
            # Remove matched names so only the missing ones remain below.
            recorded = yaml_props.pop(name)
            if recorded != value:
                error(
                    f"Warning: udev property mismatch: recording has {name}={recorded}, "
                    f"device has {name}={value}"
                )
        elif name.startswith(recorded_prefixes):
            error(f"Warning: unexpected property: {name}={value}")

    # the ones we found above were removed from the dict
    for name, value in yaml_props.items():
        error(f"Warning: device is missing recorded udev property: {name}={value}")
110
111
def create(device):
    """Create a uinput device that mirrors one recorded device.

    *device* is one entry of the recording's ``devices`` list. Its ``evdev``
    section supplies:
    - the device name;
    - the four-element id ``[bustype, vendor, product, version]``;
    - the enabled event codes per type, with ``absinfo`` for EV_ABS codes;
    - the input properties.

    Once the device exists, its udev properties are compared against the
    recording by check_udev_properties().

    Returns the uinput device created by libevdev.

    Raises:
        YamlException: if a required part of the recording is missing or
            malformed.
    """
    evdev = fetch(device, "evdev")

    d = libevdev.Device()
    d.name = fetch(evdev, "name")

    ids = fetch(evdev, "id")
    if len(ids) != 4:
        raise YamlException("Invalid ID format: {}".format(ids))
    d.id = dict(zip(["bustype", "vendor", "product", "version"], ids))

    codes = fetch(evdev, "codes")
    for evtype, evcodes in codes.items():
        for code in evcodes:
            data = None
            if evtype == libevdev.EV_ABS.value:
                # Look the code up via fetch() so a code without absinfo is
                # reported as a recording error, not as a bare KeyError.
                values = fetch(fetch(evdev, "absinfo"), code)
                # Layout: [minimum, maximum, fuzz, flat, resolution]
                if len(values) < 5:
                    raise YamlException(
                        "Invalid absinfo format for code {}: {}".format(code, values)
                    )
                data = libevdev.InputAbsInfo(
                    minimum=values[0],
                    maximum=values[1],
                    fuzz=values[2],
                    flat=values[3],
                    resolution=values[4],
                )
            elif evtype == libevdev.EV_REP.value:
                # Repeat delay/period are hard-coded, not taken from the
                # recording.
                if code == libevdev.EV_REP.REP_DELAY.value:
                    data = 500
                elif code == libevdev.EV_REP.REP_PERIOD.value:
                    data = 20
            d.enable(libevdev.evbit(evtype, code), data=data)

    properties = fetch(evdev, "properties")
    for prop in properties:
        d.enable(libevdev.propbit(prop))

    uinput = d.create_uinput_device()

    check_udev_properties(device, uinput)

    return uinput
153
154
def print_events(devnode, indent, evs):
    """Print *evs* one per line, tagged with the basename of *devnode*.

    After the tag, each line is indented by ``indent * 8`` spaces. The
    timestamp is shown as zero-padded ``sec.usec``, followed by the event
    type name, the left-aligned code name, and the value.
    """
    tag = os.path.basename(devnode)
    for ev in evs:
        padding = " " * (indent * 8)
        line = (
            f"{tag}: {padding}{ev.sec:06d}.{ev.usec:06d} "
            f"{ev.type.name} / {ev.code.name:<20s} {ev.value:4d}"
        )
        print(line)
169
170
def collect_events(frame):
    """Convert one recorded evdev frame into a list of libevdev.InputEvent.

    Each entry of *frame* is a ``(sec, usec, type, code, value)`` tuple.
    Key-repeat events (EV_KEY with value 2) are dropped. If dropping them
    leaves nothing but SYN_REPORT events, the whole frame is dropped and an
    empty list is returned.
    """
    evs = []
    events_skipped = False
    for (sec, usec, evtype, evcode, value) in frame:
        if evtype == libevdev.EV_KEY.value and value == 2:  # key repeat
            events_skipped = True
            continue

        e = libevdev.InputEvent(
            libevdev.evbit(evtype, evcode), value=value, sec=sec, usec=usec
        )
        evs.append(e)

    # If we skipped some events and now all we have left is the
    # SYN_REPORTs, we drop the SYN_REPORTs as well. The matches() test must
    # be the predicate of all(). Used as a filter instead, all() would only
    # test the events' truthiness and be always true, dropping every frame
    # that contained a repeat.
    if events_skipped and all(e.matches(libevdev.EV_SYN.SYN_REPORT) for e in evs):
        return []
    return evs
190
191
def replay(device, verbose):
    """Send one device's recorded events through its uinput node in real time.

    *device* must already carry the private keys ``__uinput``, ``__index``
    and ``__first_event_offset`` set up by loop(). Frames without an
    ``evdev`` entry, or that collect_events() reduces to nothing, are
    skipped. When *verbose* is set, every sent frame is printed.
    """
    events = fetch(device, "events")
    if events is None:
        return
    uinput = device["__uinput"]

    # The first event may have a nonzero offset but we want to replay
    # immediately regardless. When replaying multiple devices, the first
    # offset is the offset from the first event on any device.
    offset = time.time() - device["__first_event_offset"]
    if offset < 0:
        error("WARNING: event time offset is in the future, refusing to replay")
        return

    margin = 150 / 1e6  # 150 µs error margin

    # each 'evdev' set contains one SYN_REPORT so we only need to check for
    # the time offset once per event
    for entry in events:
        try:
            frame = fetch(entry, "evdev")
        except YamlException:
            continue

        evs = collect_events(frame)
        if not evs:
            continue

        head = evs[0]
        delay = head.sec + head.usec / 1e6 + offset - time.time()
        if delay > margin:
            time.sleep(delay - margin)

        uinput.send_events(evs)
        if verbose:
            print_events(uinput.devnode, device["__index"], evs)
228
229
def first_timestamp(device):
    """Return the time, in seconds, of the device's first recorded evdev frame.

    Event entries without an ``evdev`` key are skipped. Returns None if the
    device has no events, or none of them carry evdev data.
    """
    for entry in fetch(device, "events") or []:
        try:
            frame = fetch(entry, "evdev")
        except YamlException:
            continue
        sec, usec, *_rest = frame[0]
        return sec + usec / 1.0e6

    return None
241
242
def wrap(func, *args):
    """Call ``func(*args)``, silently ending on Ctrl-C.

    Used as the target of the replay child processes so an interrupt does
    not print a traceback from each of them. Any other exception propagates.
    """
    try:
        func(*args)
    except KeyboardInterrupt:
        return None
248
249
def loop(args, recording):
    """Create a uinput device per recorded device, then replay them.

    The devnode and name of each created device are printed. All devices are
    replayed in parallel, one child process each, relative to the earliest
    event of any device.

    A replay starts after ``args.replay_after`` seconds if that is >= 0,
    otherwise when the user hits enter. This repeats until ``args.once`` is
    set. If the recording has no events at all, the function only waits for
    enter and returns.
    """
    devices = fetch(recording, "devices")

    first_timestamps = tuple(
        ts for ts in (first_timestamp(d) for d in devices) if ts is not None
    )
    # All devices need to start replaying at the same time, so let's find
    # the very first event and offset everything by that timestamp.
    toffset = min(first_timestamps, default=math.inf)

    for idx, d in enumerate(devices):
        uinput = create(d)
        print("{}: {}".format(uinput.devnode, uinput.name))
        # Cheaper to hide these in the dict than to work around it; replay()
        # reads them back from the device dict.
        d["__uinput"] = uinput
        d["__index"] = idx
        d["__first_event_offset"] = toffset

    if not first_timestamps:
        input("No events in recording. Hit enter to quit")
        return

    while True:
        if args.replay_after >= 0:
            time.sleep(args.replay_after)
        else:
            input("Hit enter to start replaying")

        processes = [
            multiprocessing.Process(target=wrap, args=(replay, d, args.verbose))
            for d in devices
        ]
        # Start every process before joining any so devices replay in parallel.
        for p in processes:
            p.start()
        for p in processes:
            p.join()

        if args.once:
            break
292
293
def create_device_quirk(device):
    """Build a libinput quirks section for a recorded device.

    Returns None when the device has no ``quirks`` entry or it is empty.
    Otherwise returns a ``[libinput-replay <name>]`` section that matches on
    the device name, vendor and product, followed by the recorded quirk
    lines joined with newlines.
    """
    try:
        quirks = fetch(device, "quirks")
    except YamlException:
        return None
    if not quirks:
        return None

    # Where the device has a quirk, we match on name, vendor and product.
    # That's the best match we can assemble here from the info we have.
    evdev = fetch(device, "evdev")
    name = fetch(evdev, "name")
    ids = fetch(evdev, "id")  # [bustype, vendor, product, version]
    quirk = (
        "[libinput-replay {name}]\n"
        "MatchName={name}\n"
        "MatchVendor=0x{ids[1]:04X}\n"
        "MatchProduct=0x{ids[2]:04X}\n"
    ).format(name=name, ids=ids)
    quirk += "\n".join(quirks)
    return quirk
314
315
def setup_quirks(recording):
    """Write the recording's device quirks to libinput's local overrides file.

    Returns the Path of the written ``/etc/libinput/local-overrides.quirks``.
    The caller is expected to remove it when done. Returns None if no
    device has quirks.

    If that file already exists, an error is printed and the process exits
    with status 1, so an existing user file is never overwritten.
    """
    devices = fetch(recording, "devices")
    quirks = []
    for d in devices:
        if "quirks" in d:
            quirk = create_device_quirk(d)
            if quirk:
                quirks.append(quirk)
    if not quirks:
        return None

    overrides = Path("/etc/libinput/local-overrides.quirks")
    if overrides.exists():
        error("{} exists, please move it out of the way first".format(overrides))
        sys.exit(1)

    overrides.parent.mkdir(exist_ok=True)
    with overrides.open("w") as fd:
        fd.write("# This file was generated by libinput replay\n")
        fd.write("# Unless libinput replay is running right now, remove this file.\n")
        fd.write("\n\n".join(quirks))

    return overrides
343
344
def check_file(recording):
    """Validate the header of a parsed recording.

    Raises YamlException if the file version is not SUPPORTED_FILE_VERSION.
    Prints a warning, without raising, if the ``ndevices`` count disagrees
    with the number of entries in ``devices``.
    """
    version = fetch(recording, "version")
    if version != SUPPORTED_FILE_VERSION:
        raise YamlException(
            f"Invalid file format: {version}, expected {SUPPORTED_FILE_VERSION}"
        )

    expected = fetch(recording, "ndevices")
    devices = fetch(recording, "devices")
    actual = len(devices)
    if expected != actual:
        error(f"WARNING: truncated file, expected {expected} devices, got {actual}")
362
363
def main():
    """Command-line entry point: parse arguments, load and replay a recording.

    Any quirks file written by setup_quirks() is removed again on exit.
    Errors are reported on stderr instead of as tracebacks:
    - the recording cannot be opened;
    - the recording is not valid YAML, or is malformed;
    - a device cannot be created.
    """
    parser = argparse.ArgumentParser(description="Replay a device recording")
    parser.add_argument(
        "recording",
        metavar="recorded-file.yaml",
        type=str,
        help="Path to device recording",
    )
    parser.add_argument(
        "--replay-after",
        type=int,
        default=-1,
        help="Automatically replay once after N seconds",
    )
    parser.add_argument(
        "--once",
        action="store_true",
        default=False,
        help="Stop and exit after one replay",
    )
    parser.add_argument("--verbose", action="store_true")
    args = parser.parse_args()

    quirks_file = None

    try:
        # Load the whole recording up-front so the file is not held open for
        # the (possibly long, interactive) replay session.
        try:
            with open(args.recording) as f:
                y = yaml.safe_load(f)
        except OSError as e:
            error("Error: failed to open recording: {}".format(e))
            return
        except yaml.YAMLError as e:
            raise YamlException(e) from e
        # An empty file loads as None; anything but a mapping can't be valid.
        if not isinstance(y, dict):
            raise YamlException("recording is empty or not a YAML mapping")

        check_file(y)
        quirks_file = setup_quirks(y)
        loop(args, y)
    except KeyboardInterrupt:
        pass
    except OSError as e:  # also covers PermissionError
        error("Error: failed to open device: {}".format(e))
    except YamlException as e:
        error("Error: failed to parse recording: {}".format(e))
    finally:
        if quirks_file:
            quirks_file.unlink()
404
405
# Only run the replay when executed as a script, not when imported.
if __name__ == "__main__":
    main()
408