xref: /third_party/markupsafe/__init__.py (revision c44ef7f9)
1import functools
2import string
3import sys
4import typing as t
5
if t.TYPE_CHECKING:
    # Everything in this branch exists only for static type checkers; it is
    # skipped at runtime, so names defined here must be referenced through
    # string annotations elsewhere in the module.
    import typing_extensions as te

    class HasHTML(te.Protocol):
        """Structural type for objects that provide an ``__html__()`` method
        returning a string.
        """

        def __html__(self) -> str:
            pass

    # Parameter specification used in the annotations of
    # ``_simple_escaping_wrapper``.
    _P = te.ParamSpec("_P")


# Version string of this package.
__version__ = "2.1.5"
17
18
def _simple_escaping_wrapper(func: "t.Callable[_P, str]") -> "t.Callable[_P, Markup]":
    """Build a ``Markup`` method from the plain ``str`` method ``func``.

    The returned method escapes its string-like positional and keyword
    arguments with ``self.escape``, calls ``func`` with the escaped values,
    and wraps the result in the class of ``self``.
    """

    @functools.wraps(func)
    def wrapped(self: "Markup", *args: "_P.args", **kwargs: "_P.kwargs") -> "Markup":
        escape = self.escape
        # Positional arguments are copied into a list so they can be replaced
        # by index; keyword arguments are replaced directly in ``kwargs``.
        positional = _escape_argspec(list(args), enumerate(args), escape)
        _escape_argspec(kwargs, kwargs.items(), escape)
        result = func(self, *positional, **kwargs)  # type: ignore[arg-type]
        return self.__class__(result)

    return wrapped  # type: ignore[return-value]
27
28
class Markup(str):
    """A string that is ready to be safely inserted into an HTML or XML
    document, either because it was escaped or because it was marked
    safe.

    Passing an object to the constructor converts it to text and wraps
    it to mark it safe without escaping. To escape the text, use the
    :meth:`escape` class method instead.

    >>> Markup("Hello, <em>World</em>!")
    Markup('Hello, <em>World</em>!')
    >>> Markup(42)
    Markup('42')
    >>> Markup.escape("Hello, <em>World</em>!")
    Markup('Hello, &lt;em&gt;World&lt;/em&gt;!')

    This implements the ``__html__()`` interface that some frameworks
    use. Passing an object that implements ``__html__()`` will wrap the
    output of that method, marking it safe.

    >>> class Foo:
    ...     def __html__(self):
    ...         return '<a href="/foo">foo</a>'
    ...
    >>> Markup(Foo())
    Markup('<a href="/foo">foo</a>')

    This is a subclass of :class:`str`. It has the same methods, but
    escapes their arguments and returns a ``Markup`` instance.

    >>> Markup("<em>%s</em>") % ("foo & bar",)
    Markup('<em>foo &amp; bar</em>')
    >>> Markup("<em>Hello</em> ") + "<foo>"
    Markup('<em>Hello</em> &lt;foo&gt;')
    """

    __slots__ = ()

    def __new__(
        cls, base: t.Any = "", encoding: t.Optional[str] = None, errors: str = "strict"
    ) -> "te.Self":
        """Create the string without escaping ``base``.

        If ``base`` has an ``__html__`` attribute, the value returned by
        calling it is used instead of ``base`` itself. ``encoding`` and
        ``errors`` are forwarded to :class:`str` only when ``encoding`` is
        given.
        """
        if hasattr(base, "__html__"):
            base = base.__html__()

        if encoding is None:
            return super().__new__(cls, base)

        return super().__new__(cls, base, encoding, errors)

    def __html__(self) -> "te.Self":
        """Return ``self``; the value is already markup."""
        return self

    def __add__(self, other: t.Union[str, "HasHTML"]) -> "te.Self":
        """Concatenate with the escaped form of ``other``.

        Returns ``NotImplemented`` unless ``other`` is a ``str`` or has an
        ``__html__`` attribute.
        """
        if isinstance(other, str) or hasattr(other, "__html__"):
            return self.__class__(super().__add__(self.escape(other)))

        return NotImplemented

    def __radd__(self, other: t.Union[str, "HasHTML"]) -> "te.Self":
        """Reflected concatenation: escape ``other`` and append ``self``.

        Returns ``NotImplemented`` unless ``other`` is a ``str`` or has an
        ``__html__`` attribute.
        """
        if isinstance(other, str) or hasattr(other, "__html__"):
            return self.escape(other).__add__(self)

        return NotImplemented

    def __mul__(self, num: "te.SupportsIndex") -> "te.Self":
        """Repeat the markup ``num`` times, keeping the result's class.

        Returns ``NotImplemented`` when the type of ``num`` does not define
        ``__index__``.
        """
        # Accept every index-capable operand (as the annotation promises),
        # not only ``int`` instances, so the result is always re-wrapped in
        # this class instead of bypassing this method.
        if hasattr(type(num), "__index__"):
            return self.__class__(super().__mul__(num))

        return NotImplemented

    __rmul__ = __mul__

    def __mod__(self, arg: t.Any) -> "te.Self":
        """Apply ``%`` formatting, escaping each substituted value."""
        if isinstance(arg, tuple):
            # a tuple of arguments, each wrapped
            arg = tuple(_MarkupEscapeHelper(x, self.escape) for x in arg)
        elif hasattr(type(arg), "__getitem__") and not isinstance(arg, str):
            # a mapping of arguments, wrapped
            arg = _MarkupEscapeHelper(arg, self.escape)
        else:
            # a single argument, wrapped with the helper and a tuple
            arg = (_MarkupEscapeHelper(arg, self.escape),)

        return self.__class__(super().__mod__(arg))

    def __repr__(self) -> str:
        """Return ``ClassName(<repr of the underlying str>)``."""
        return f"{self.__class__.__name__}({super().__repr__()})"

    # The docstrings of the next four methods are copied from ``str`` right
    # after each definition, so they are described with comments instead.

    # Escape every item of ``seq`` before joining with ``self``.
    def join(self, seq: t.Iterable[t.Union[str, "HasHTML"]]) -> "te.Self":
        return self.__class__(super().join(map(self.escape, seq)))

    join.__doc__ = str.join.__doc__

    # Split like ``str.split`` and wrap each piece in this class.
    def split(  # type: ignore[override]
        self, sep: t.Optional[str] = None, maxsplit: int = -1
    ) -> t.List["te.Self"]:
        return [self.__class__(v) for v in super().split(sep, maxsplit)]

    split.__doc__ = str.split.__doc__

    # Split like ``str.rsplit`` and wrap each piece in this class.
    def rsplit(  # type: ignore[override]
        self, sep: t.Optional[str] = None, maxsplit: int = -1
    ) -> t.List["te.Self"]:
        return [self.__class__(v) for v in super().rsplit(sep, maxsplit)]

    rsplit.__doc__ = str.rsplit.__doc__

    # Split like ``str.splitlines`` and wrap each line in this class.
    def splitlines(  # type: ignore[override]
        self, keepends: bool = False
    ) -> t.List["te.Self"]:
        return [self.__class__(v) for v in super().splitlines(keepends)]

    splitlines.__doc__ = str.splitlines.__doc__

    def unescape(self) -> str:
        """Convert escaped markup back into a text string. This replaces
        HTML entities with the characters they represent.

        >>> Markup("Main &raquo; <em>About</em>").unescape()
        'Main » <em>About</em>'
        """
        from html import unescape

        return unescape(str(self))

    def striptags(self) -> str:
        """:meth:`unescape` the markup, remove tags, and normalize
        whitespace to single spaces.

        >>> Markup("Main &raquo;\t<em>About</em>").striptags()
        'Main » About'
        """
        value = str(self)

        # Look for comments then tags separately. Otherwise, a comment that
        # contains a tag would end early, leaving some of the comment behind.

        while True:
            # keep finding comment start marks
            start = value.find("<!--")

            if start == -1:
                break

            # find a comment end mark beyond the start, otherwise stop
            end = value.find("-->", start)

            if end == -1:
                break

            value = f"{value[:start]}{value[end + 3:]}"

        # remove tags using the same method
        while True:
            start = value.find("<")

            if start == -1:
                break

            end = value.find(">", start)

            if end == -1:
                break

            value = f"{value[:start]}{value[end + 1:]}"

        # collapse spaces
        value = " ".join(value.split())
        return self.__class__(value).unescape()

    @classmethod
    def escape(cls, s: t.Any) -> "te.Self":
        """Escape a string. Calls :func:`escape` and ensures that for
        subclasses the correct type is returned.
        """
        rv = escape(s)

        if rv.__class__ is not cls:
            return cls(rv)

        return rv  # type: ignore[return-value]

    # ``str`` methods re-exposed so that string-like arguments are escaped
    # and the result is returned as an instance of this class.
    __getitem__ = _simple_escaping_wrapper(str.__getitem__)
    capitalize = _simple_escaping_wrapper(str.capitalize)
    title = _simple_escaping_wrapper(str.title)
    lower = _simple_escaping_wrapper(str.lower)
    upper = _simple_escaping_wrapper(str.upper)
    replace = _simple_escaping_wrapper(str.replace)
    ljust = _simple_escaping_wrapper(str.ljust)
    rjust = _simple_escaping_wrapper(str.rjust)
    lstrip = _simple_escaping_wrapper(str.lstrip)
    rstrip = _simple_escaping_wrapper(str.rstrip)
    center = _simple_escaping_wrapper(str.center)
    strip = _simple_escaping_wrapper(str.strip)
    translate = _simple_escaping_wrapper(str.translate)
    expandtabs = _simple_escaping_wrapper(str.expandtabs)
    swapcase = _simple_escaping_wrapper(str.swapcase)
    zfill = _simple_escaping_wrapper(str.zfill)
    casefold = _simple_escaping_wrapper(str.casefold)

    if sys.version_info >= (3, 9):
        removeprefix = _simple_escaping_wrapper(str.removeprefix)
        removesuffix = _simple_escaping_wrapper(str.removesuffix)

    def partition(self, sep: str) -> t.Tuple["te.Self", "te.Self", "te.Self"]:
        """Partition around the first occurrence of the escaped ``sep`` and
        wrap each of the three parts in this class.
        """
        head, matched, tail = super().partition(self.escape(sep))
        cls = self.__class__
        return cls(head), cls(matched), cls(tail)

    def rpartition(self, sep: str) -> t.Tuple["te.Self", "te.Self", "te.Self"]:
        """Partition around the last occurrence of the escaped ``sep`` and
        wrap each of the three parts in this class.
        """
        head, matched, tail = super().rpartition(self.escape(sep))
        cls = self.__class__
        return cls(head), cls(matched), cls(tail)

    def format(self, *args: t.Any, **kwargs: t.Any) -> "te.Self":
        """Format like ``str.format``, escaping each substituted field via
        :class:`EscapeFormatter`.
        """
        formatter = EscapeFormatter(self.escape)
        return self.__class__(formatter.vformat(self, args, kwargs))

    def format_map(  # type: ignore[override]
        self, map: t.Mapping[str, t.Any]
    ) -> "te.Self":
        """Format like ``str.format_map``, escaping each substituted field
        via :class:`EscapeFormatter`.
        """
        formatter = EscapeFormatter(self.escape)
        return self.__class__(formatter.vformat(self, (), map))

    def __html_format__(self, format_spec: str) -> "te.Self":
        """Return ``self`` when formatted with an empty format spec.

        :raises ValueError: If ``format_spec`` is not empty.
        """
        if format_spec:
            raise ValueError("Unsupported format specification for Markup.")

        return self
258
259
class EscapeFormatter(string.Formatter):
    """A :class:`string.Formatter` whose formatted fields are passed through
    an escape callable before being returned.
    """

    __slots__ = ("escape",)

    def __init__(self, escape: t.Callable[[t.Any], Markup]) -> None:
        self.escape = escape
        super().__init__()

    def format_field(self, value: t.Any, format_spec: str) -> str:
        """Format one field and return its escaped text.

        ``__html_format__`` is preferred when ``value`` has it, then
        ``__html__`` (which cannot be combined with a format spec), and
        otherwise the standard field formatting is used.
        """
        if hasattr(value, "__html_format__"):
            formatted = value.__html_format__(format_spec)
            return str(self.escape(formatted))

        if not hasattr(value, "__html__"):
            # We need to make sure the format spec is str here as
            # otherwise the wrong callback methods are invoked.
            formatted = string.Formatter.format_field(self, value, str(format_spec))
            return str(self.escape(formatted))

        if format_spec:
            raise ValueError(
                f"Format specifier {format_spec} given, but {type(value)} does not"
                " define __html_format__. A class that defines __html__ must define"
                " __html_format__ to work with format specifiers."
            )

        formatted = value.__html__()
        return str(self.escape(formatted))
283
284
# Type variable constrained to ``list`` or ``dict``; it lets
# ``_escape_argspec`` be annotated as returning the same container type it
# receives.
_ListOrDict = t.TypeVar("_ListOrDict", list, dict)
286
287
def _escape_argspec(
    obj: _ListOrDict, iterable: t.Iterable[t.Any], escape: t.Callable[[t.Any], Markup]
) -> _ListOrDict:
    """Escape, in place, the string-like entries of an argument container.

    ``iterable`` yields ``(key, value)`` pairs. Whenever ``value`` is a
    ``str`` or has an ``__html__`` attribute, ``obj[key]`` is replaced with
    ``escape(value)``; other entries are left untouched. The same ``obj`` is
    returned.
    """
    for slot, candidate in iterable:
        if not isinstance(candidate, str) and not hasattr(candidate, "__html__"):
            continue

        obj[slot] = escape(candidate)

    return obj
297
298
class _MarkupEscapeHelper:
    """Wrapper used by :meth:`Markup.__mod__` around ``%``-format operands.

    The wrapped object is escaped when rendered as text, while numeric
    conversions are delegated to it unchanged.
    """

    __slots__ = ("obj", "escape")

    def __init__(self, obj: t.Any, escape: t.Callable[[t.Any], Markup]) -> None:
        self.obj = obj
        self.escape = escape

    def __getitem__(self, item: t.Any) -> "te.Self":
        # Looked-up values are wrapped again so they are escaped as well.
        looked_up = self.obj[item]
        return self.__class__(looked_up, self.escape)

    def __str__(self) -> str:
        escaped = self.escape(self.obj)
        return str(escaped)

    def __repr__(self) -> str:
        escaped_repr = self.escape(repr(self.obj))
        return str(escaped_repr)

    def __int__(self) -> int:
        wrapped = self.obj
        return int(wrapped)

    def __float__(self) -> float:
        wrapped = self.obj
        return float(wrapped)
322
323
324# circular import
325try:
326    from ._speedups import escape as escape
327    from ._speedups import escape_silent as escape_silent
328    from ._speedups import soft_str as soft_str
329except ImportError:
330    from ._native import escape as escape
331    from ._native import escape_silent as escape_silent  # noqa: F401
332    from ._native import soft_str as soft_str  # noqa: F401
333