# pysqlite2/test/hooks.py: tests for various SQLite-specific hooks
#
# Copyright (C) 2006-2007 Gerhard Häring <gh@ghaering.de>
#
# This file is part of pysqlite.
#
# This software is provided 'as-is', without any express or implied
# warranty. In no event will the authors be held liable for any damages
# arising from the use of this software.
#
# Permission is granted to anyone to use this software for any purpose,
# including commercial applications, and to alter it and redistribute it
# freely, subject to the following restrictions:
#
# 1. The origin of this software must not be misrepresented; you must not
#    claim that you wrote the original software. If you use this software
#    in a product, an acknowledgment in the product documentation would be
#    appreciated but is not required.
# 2. Altered source versions must be plainly marked as such, and must not be
#    misrepresented as being the original software.
# 3. This notice may not be removed or altered from any source distribution.
import contextlib
import sqlite3 as sqlite
import unittest

from test.support.os_helper import TESTFN, unlink

from test.test_sqlite3.test_dbapi import memory_database, cx_limit
from test.test_sqlite3.test_userfunctions import with_tracebacks


class CollationTests(unittest.TestCase):
    """Tests for Connection.create_collation()."""

    # Each test registers con.close via addCleanup() so the connection is
    # released even when an assertion in the test fails.

    def test_create_collation_not_string(self):
        con = sqlite.connect(":memory:")
        self.addCleanup(con.close)
        with self.assertRaises(TypeError):
            con.create_collation(None, lambda x, y: (x > y) - (x < y))

    def test_create_collation_not_callable(self):
        con = sqlite.connect(":memory:")
        self.addCleanup(con.close)
        with self.assertRaises(TypeError) as cm:
            con.create_collation("X", 42)
        self.assertEqual(str(cm.exception), 'parameter must be callable')

    def test_create_collation_not_ascii(self):
        # Registering a collation under a non-ASCII name must not raise.
        con = sqlite.connect(":memory:")
        self.addCleanup(con.close)
        con.create_collation("collä", lambda x, y: (x > y) - (x < y))

    def test_create_collation_bad_upper(self):
        # The collation must still be registered, and usable by name, when
        # the name is a str subclass whose upper() returns None.
        class BadUpperStr(str):
            def upper(self):
                return None
        con = sqlite.connect(":memory:")
        self.addCleanup(con.close)
        mycoll = lambda x, y: -((x > y) - (x < y))
        con.create_collation(BadUpperStr("mycoll"), mycoll)
        result = con.execute("""
            select x from (
            select 'a' as x
            union
            select 'b' as x
            ) order by x collate mycoll
            """).fetchall()
        self.assertEqual(result[0][0], 'b')
        self.assertEqual(result[1][0], 'a')

    def test_collation_is_used(self):
        def mycoll(x, y):
            # reverse order
            return -((x > y) - (x < y))

        con = sqlite.connect(":memory:")
        self.addCleanup(con.close)
        con.create_collation("mycoll", mycoll)
        sql = """
            select x from (
            select 'a' as x
            union
            select 'b' as x
            union
            select 'c' as x
            ) order by x collate mycoll
            """
        result = con.execute(sql).fetchall()
        self.assertEqual(result, [('c',), ('b',), ('a',)],
                         msg='the expected order was not returned')

        # Passing None removes the collation, so the same query must fail.
        con.create_collation("mycoll", None)
        with self.assertRaises(sqlite.OperationalError) as cm:
            con.execute(sql).fetchall()
        self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll')

    def test_collation_returns_large_integer(self):
        # Results of magnitude 2**32 must still order rows by their sign.
        def mycoll(x, y):
            # reverse order
            return -((x > y) - (x < y)) * 2**32
        con = sqlite.connect(":memory:")
        self.addCleanup(con.close)
        con.create_collation("mycoll", mycoll)
        sql = """
            select x from (
            select 'a' as x
            union
            select 'b' as x
            union
            select 'c' as x
            ) order by x collate mycoll
            """
        result = con.execute(sql).fetchall()
        self.assertEqual(result, [('c',), ('b',), ('a',)],
                         msg="the expected order was not returned")

    def test_collation_register_twice(self):
        """
        Register two different collation functions under the same name.
        Verify that the last one is actually used.
        """
        con = sqlite.connect(":memory:")
        self.addCleanup(con.close)
        con.create_collation("mycoll", lambda x, y: (x > y) - (x < y))
        con.create_collation("mycoll", lambda x, y: -((x > y) - (x < y)))
        result = con.execute("""
            select x from (select 'a' as x union select 'b' as x) order by x collate mycoll
            """).fetchall()
        self.assertEqual(result[0][0], 'b')
        self.assertEqual(result[1][0], 'a')

    def test_deregister_collation(self):
        """
        Register a collation, then deregister it. Make sure an error is raised if we try
        to use it.
        """
        con = sqlite.connect(":memory:")
        self.addCleanup(con.close)
        con.create_collation("mycoll", lambda x, y: (x > y) - (x < y))
        con.create_collation("mycoll", None)
        with self.assertRaises(sqlite.OperationalError) as cm:
            con.execute("select 'a' as x union select 'b' as x order by x collate mycoll")
        self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll')


class ProgressTests(unittest.TestCase):
    """Tests for Connection.set_progress_handler()."""

    def test_progress_handler_used(self):
        """
        Test that the progress handler is invoked once it is set.
        """
        con = sqlite.connect(":memory:")
        self.addCleanup(con.close)
        progress_calls = []
        def progress():
            progress_calls.append(None)
            return 0
        con.set_progress_handler(progress, 1)
        con.execute("""
            create table foo(a, b)
            """)
        self.assertTrue(progress_calls)

    def test_opcode_count(self):
        """
        Test that the opcode argument is respected.
        """
        con = sqlite.connect(":memory:")
        self.addCleanup(con.close)
        progress_calls = []
        def progress():
            progress_calls.append(None)
            return 0
        con.set_progress_handler(progress, 1)
        curs = con.cursor()
        curs.execute("""
            create table foo (a, b)
            """)
        first_count = len(progress_calls)
        # Reset in place; progress() keeps appending to this same list.
        progress_calls.clear()
        # A larger opcode interval must not yield more handler invocations.
        con.set_progress_handler(progress, 2)
        curs.execute("""
            create table bar (a, b)
            """)
        second_count = len(progress_calls)
        self.assertGreaterEqual(first_count, second_count)

    def test_cancel_operation(self):
        """
        Test that returning a non-zero value stops the operation in progress.
        """
        con = sqlite.connect(":memory:")
        self.addCleanup(con.close)
        def progress():
            return 1
        con.set_progress_handler(progress, 1)
        curs = con.cursor()
        self.assertRaises(
            sqlite.OperationalError,
            curs.execute,
            "create table bar (a, b)")

    def test_clear_handler(self):
        """
        Test that setting the progress handler to None clears the previously set handler.
        """
        con = sqlite.connect(":memory:")
        self.addCleanup(con.close)
        action = 0
        def progress():
            nonlocal action
            action = 1
            return 0
        con.set_progress_handler(progress, 1)
        con.set_progress_handler(None, 1)
        con.execute("select 1 union select 2 union select 3").fetchall()
        self.assertEqual(action, 0, "progress handler was not cleared")

    @with_tracebacks(ZeroDivisionError, name="bad_progress")
    def test_error_in_progress_handler(self):
        # An exception raised inside the handler surfaces as OperationalError.
        con = sqlite.connect(":memory:")
        self.addCleanup(con.close)
        def bad_progress():
            1 / 0
        con.set_progress_handler(bad_progress, 1)
        with self.assertRaises(sqlite.OperationalError):
            con.execute("""
                create table foo(a, b)
                """)

    @with_tracebacks(ZeroDivisionError, name="bad_progress")
    def test_error_in_progress_handler_result(self):
        # Failing to evaluate the handler's return value as a bool must also
        # surface as OperationalError.
        con = sqlite.connect(":memory:")
        self.addCleanup(con.close)
        class BadBool:
            def __bool__(self):
                1 / 0
        def bad_progress():
            return BadBool()
        con.set_progress_handler(bad_progress, 1)
        with self.assertRaises(sqlite.OperationalError):
            con.execute("""
                create table foo(a, b)
                """)


class TraceCallbackTests(unittest.TestCase):
    """Tests for Connection.set_trace_callback()."""

    @contextlib.contextmanager
    def check_stmt_trace(self, cx, expected):
        """Assert that the statements traced on *cx* in the block equal *expected*.

        The trace callback is always removed on exit. The comparison runs
        only after the block completes normally, so it can neither leave the
        callback installed on failure nor mask an exception raised inside
        the block.
        """
        traced = []
        cx.set_trace_callback(traced.append)
        try:
            yield
        finally:
            cx.set_trace_callback(None)
        self.assertEqual(traced, expected)

    def test_trace_callback_used(self):
        """
        Test that the trace callback is invoked once it is set.
        """
        con = sqlite.connect(":memory:")
        self.addCleanup(con.close)
        traced_statements = []
        def trace(statement):
            traced_statements.append(statement)
        con.set_trace_callback(trace)
        con.execute("create table foo(a, b)")
        self.assertTrue(traced_statements)
        self.assertTrue(any("create table foo" in stmt for stmt in traced_statements))

    def test_clear_trace_callback(self):
        """
        Test that setting the trace callback to None clears the previously set callback.
        """
        con = sqlite.connect(":memory:")
        self.addCleanup(con.close)
        traced_statements = []
        def trace(statement):
            traced_statements.append(statement)
        con.set_trace_callback(trace)
        con.set_trace_callback(None)
        con.execute("create table foo(a, b)")
        self.assertFalse(traced_statements, "trace callback was not cleared")

    def test_unicode_content(self):
        """
        Test that the statement can contain unicode literals.
        """
        unicode_value = '\xf6\xe4\xfc\xd6\xc4\xdc\xdf\u20ac'
        con = sqlite.connect(":memory:")
        self.addCleanup(con.close)
        traced_statements = []
        def trace(statement):
            traced_statements.append(statement)
        con.set_trace_callback(trace)
        con.execute("create table foo(x)")
        # The value is deliberately inlined as an SQL literal (not bound as a
        # parameter) so that it appears in the traced statement text.
        con.execute("insert into foo(x) values ('%s')" % unicode_value)
        con.commit()
        self.assertTrue(any(unicode_value in stmt for stmt in traced_statements),
                        "Unicode data %s garbled in trace callback: %s"
                        % (ascii(unicode_value), ', '.join(map(ascii, traced_statements))))

    def test_trace_callback_content(self):
        # set_trace_callback() shouldn't produce duplicate content (bpo-26187)
        traced_statements = []
        def trace(statement):
            traced_statements.append(statement)

        queries = ["create table foo(x)",
                   "insert into foo(x) values(1)"]
        self.addCleanup(unlink, TESTFN)
        # Nested closing() guarantees con1 is closed even if opening con2
        # fails, and both are closed before the file is unlinked.
        with contextlib.closing(sqlite.connect(TESTFN, isolation_level=None)) as con1:
            with contextlib.closing(sqlite.connect(TESTFN)) as con2:
                con1.set_trace_callback(trace)
                cur = con1.cursor()
                cur.execute(queries[0])
                # A schema change from a second connection between con1's
                # statements is the scenario from bpo-26187.
                con2.execute("create table bar(x)")
                cur.execute(queries[1])
        self.assertEqual(traced_statements, queries)

    def test_trace_expanded_sql(self):
        expected = [
            "create table t(t)",
            "BEGIN ",
            "insert into t values(0)",
            "insert into t values(1)",
            "insert into t values(2)",
            "COMMIT",
        ]
        with memory_database() as cx, self.check_stmt_trace(cx, expected):
            with cx:
                cx.execute("create table t(t)")
                cx.executemany("insert into t values(?)", ((v,) for v in range(3)))

    @with_tracebacks(
        sqlite.DataError,
        regex="Expanded SQL string exceeds the maximum string length"
    )
    def test_trace_too_much_expanded_sql(self):
        # If the expanded string is too large, we'll fall back to the
        # unexpanded SQL statement (for SQLite 3.14.0 and newer).
        # The resulting string length is limited by the runtime limit
        # SQLITE_LIMIT_LENGTH.
        template = "select 1 as a where a="
        category = sqlite.SQLITE_LIMIT_LENGTH
        with memory_database() as cx, cx_limit(cx, category=category) as lim:
            ok_param = "a"
            bad_param = "a" * lim

            unexpanded_query = template + "?"
            expected = [unexpanded_query]
            if sqlite.sqlite_version_info < (3, 14, 0):
                expected = []
            with self.check_stmt_trace(cx, expected):
                cx.execute(unexpanded_query, (bad_param,))

            expanded_query = f"{template}'{ok_param}'"
            with self.check_stmt_trace(cx, [expanded_query]):
                cx.execute(unexpanded_query, (ok_param,))

    @with_tracebacks(ZeroDivisionError, regex="division by zero")
    def test_trace_bad_handler(self):
        with memory_database() as cx:
            cx.set_trace_callback(lambda stmt: 5/0)
            cx.execute("select 1")


if __name__ == "__main__":
    unittest.main()