From b76c3ed67bcc9823545ff702c695ce9af86e0911 Mon Sep 17 00:00:00 2001 From: Holger Brunn Date: Wed, 18 May 2022 07:04:43 +0200 Subject: [PATCH] fixup! [IMP] use sqlparse also to determine which ddl to update --- pglogical/hooks.py | 62 +++++++++++++++++++++++++++++-- pglogical/tests/test_pglogical.py | 44 +++++++++++++++++++++- 2 files changed, 102 insertions(+), 4 deletions(-) diff --git a/pglogical/hooks.py b/pglogical/hooks.py index 8180b70a..8910feba 100644 --- a/pglogical/hooks.py +++ b/pglogical/hooks.py @@ -10,6 +10,58 @@ except ImportError: sqlparse = None SECTION_NAME = "pglogical" +DDL_KEYWORDS = ("CREATE", "ALTER", "DROP", "TRUNCATE") + + +def schema_qualify(parsed_query, schema="public"): + """ + Yield tokens and add a schema to objects if there's none + """ + token_iterator = parsed_query.flatten() + Name = sqlparse.tokens.Name + Punctuation = sqlparse.tokens.Punctuation + is_qualified = False + for token in token_iterator: + yield token + if not is_qualified and token.is_keyword and token.normalized in DDL_KEYWORDS: + # we check if the name coming after {create,drop,alter} object keywords + # is schema qualified, and if not, add the schema we got passed + next_token = False + while True: + try: + next_token = token_iterator.__next__() + except StopIteration: + # this is invalid sql + next_token = False + break + if not (next_token.is_whitespace or next_token.is_keyword): + break + yield next_token + if not next_token: + continue + if next_token.ttype != Name: + yield next_token + continue + + object_name_or_schema = next_token + needs_schema = False + next_token = False + try: + next_token = token_iterator.__next__() + needs_schema = str(next_token) != '.' + except StopIteration: + needs_schema = True + + if needs_schema: + yield sqlparse.sql.Token(Name, schema) + yield sqlparse.sql.Token(Punctuation, '.') + + yield object_name_or_schema + + if next_token: + yield next_token + + is_qualified = True def post_load(): @@ -42,10 +94,10 @@ def post_load(): def execute(self, query, params=None, log_exceptions=None): """Wrap DDL in pglogical.replicate_ddl_command""" # short circuit statements that must be as fast as possible - if query[:6] != "SELECT": + if query[:6] not in ("SELECT", "UPDATE"): parsed_queries = sqlparse.parse(query) if any( - parsed_query.get_type() in ("CREATE", "ALTER", "DROP") + parsed_query.get_type() in DDL_KEYWORDS for parsed_query in parsed_queries ) and not any( token.is_keyword and token.normalized in @@ -53,7 +105,11 @@ def post_load(): ("CONSTRAINT", "TRIGGER", "INDEX") for parsed in parsed_queries for token in parsed.tokens[1:] ): - mogrified = self.mogrify(query, params).decode("utf8") + qualified_query = ''.join( + ''.join(str(token) for token in schema_qualify(parsed_query)) + for parsed_query in parsed_queries + ) + mogrified = self.mogrify(qualified_query, params).decode("utf8") query = "SELECT pglogical.replicate_ddl_command(%s, %s)" params = (mogrified, execute.replication_sets) return execute.origin(self, query, params=params, log_exceptions=log_exceptions) diff --git a/pglogical/tests/test_pglogical.py b/pglogical/tests/test_pglogical.py index 4ab9d75a..f3e0f8f6 100644 --- a/pglogical/tests/test_pglogical.py +++ b/pglogical/tests/test_pglogical.py @@ -6,7 +6,7 @@ from contextlib import contextmanager from odoo.sql_db import Cursor from odoo.tests.common import TransactionCase from odoo.tools.config import config -from ..hooks import post_load +from ..hooks import post_load, schema_qualify, sqlparse class TestPglogical(TransactionCase): @@ -58,6 +58,20 @@ class TestPglogical(TransactionCase): ], ) + with self._config(dict(pglogical={"replication_sets": "ddl_sql"})),\ + self.assertLogs("odoo.addons.pglogical") as log,\ + mock.patch("odoo.addons.pglogical.hooks.sqlparse") as mock_sqlparse: + mock_sqlparse.__bool__.return_value = False + post_load() + + self.assertEqual( + log.output, + [ + "ERROR:odoo.addons.pglogical:" + "DDL replication not supported - sqlparse is not available" + ], + ) + def test_patching(self): """Test patching the cursor succeeds""" with self._config(dict(pglogical=dict(replication_sets="set1,set2"))): @@ -90,3 +104,31 @@ class TestPglogical(TransactionCase): ) finally: Cursor.execute = getattr(Cursor.execute, "origin", Cursor.execute) + + def test_schema_qualify(self): + """Test that schema qualifications are the only changes""" + for statement in ( + 'create table if not exists testtable', + 'drop table testtable', + 'alter table testtable', + '''create table + testtable + (col1 int, col2 int); select * from test''', + 'alter table testschema.test drop column somecol', + ' DROP view if exists testtable', + 'truncate table testtable', + '''CREATE FUNCTION testtable(integer, integer) RETURNS integer + AS 'select $1 + $2;' + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT''', + 'drop table', + "alter table 'test'", + ): + qualified_query = ''.join( + ''.join(str(token) for token in schema_qualify(parsed_query)) + for parsed_query in sqlparse.parse(statement) + ) + self.assertEqual( + qualified_query, statement.replace('testtable', 'public.testtable') + )