fixup! [IMP] use sqlparse also to determine which ddl to update

This commit is contained in:
Holger Brunn
2022-05-18 07:04:43 +02:00
parent dec01e2bab
commit b76c3ed67b
2 changed files with 102 additions and 4 deletions

View File

@@ -10,6 +10,58 @@ except ImportError:
sqlparse = None sqlparse = None
SECTION_NAME = "pglogical" SECTION_NAME = "pglogical"
DDL_KEYWORDS = ("CREATE", "ALTER", "DROP", "TRUNCATE")
def schema_qualify(parsed_query, schema="public"):
"""
Yield tokens and add a schema to objects if there's none
"""
token_iterator = parsed_query.flatten()
Name = sqlparse.tokens.Name
Punctuation = sqlparse.tokens.Punctuation
is_qualified = False
for token in token_iterator:
yield token
if not is_qualified and token.is_keyword and token.normalized in DDL_KEYWORDS:
# we check if the name coming after {create,drop,alter} object keywords
# is schema qualified, and if not, add the schema we got passed
next_token = False
while True:
try:
next_token = token_iterator.__next__()
except StopIteration:
# this is invalid sql
next_token = False
break
if not (next_token.is_whitespace or next_token.is_keyword):
break
yield next_token
if not next_token:
continue
if next_token.ttype != Name:
yield next_token
continue
object_name_or_schema = next_token
needs_schema = False
next_token = False
try:
next_token = token_iterator.__next__()
needs_schema = str(next_token) != '.'
except StopIteration:
needs_schema = True
if needs_schema:
yield sqlparse.sql.Token(Name, schema)
yield sqlparse.sql.Token(Punctuation, '.')
yield object_name_or_schema
if next_token:
yield next_token
is_qualified = True
def post_load(): def post_load():
@@ -42,10 +94,10 @@ def post_load():
def execute(self, query, params=None, log_exceptions=None): def execute(self, query, params=None, log_exceptions=None):
"""Wrap DDL in pglogical.replicate_ddl_command""" """Wrap DDL in pglogical.replicate_ddl_command"""
# short circuit statements that must be as fast as possible # short circuit statements that must be as fast as possible
if query[:6] != "SELECT": if query[:6] not in ("SELECT", "UPDATE"):
parsed_queries = sqlparse.parse(query) parsed_queries = sqlparse.parse(query)
if any( if any(
parsed_query.get_type() in ("CREATE", "ALTER", "DROP") parsed_query.get_type() in DDL_KEYWORDS
for parsed_query in parsed_queries for parsed_query in parsed_queries
) and not any( ) and not any(
token.is_keyword and token.normalized in token.is_keyword and token.normalized in
@@ -53,7 +105,11 @@ def post_load():
("CONSTRAINT", "TRIGGER", "INDEX") ("CONSTRAINT", "TRIGGER", "INDEX")
for parsed in parsed_queries for token in parsed.tokens[1:] for parsed in parsed_queries for token in parsed.tokens[1:]
): ):
mogrified = self.mogrify(query, params).decode("utf8") qualified_query = ''.join(
''.join(str(token) for token in schema_qualify(parsed_query))
for parsed_query in parsed_queries
)
mogrified = self.mogrify(qualified_query, params).decode("utf8")
query = "SELECT pglogical.replicate_ddl_command(%s, %s)" query = "SELECT pglogical.replicate_ddl_command(%s, %s)"
params = (mogrified, execute.replication_sets) params = (mogrified, execute.replication_sets)
return execute.origin(self, query, params=params, log_exceptions=log_exceptions) return execute.origin(self, query, params=params, log_exceptions=log_exceptions)

View File

@@ -6,7 +6,7 @@ from contextlib import contextmanager
from odoo.sql_db import Cursor from odoo.sql_db import Cursor
from odoo.tests.common import TransactionCase from odoo.tests.common import TransactionCase
from odoo.tools.config import config from odoo.tools.config import config
from ..hooks import post_load from ..hooks import post_load, schema_qualify, sqlparse
class TestPglogical(TransactionCase): class TestPglogical(TransactionCase):
@@ -58,6 +58,20 @@ class TestPglogical(TransactionCase):
], ],
) )
with self._config(dict(pglogical={"replication_sets": "ddl_sql"})),\
self.assertLogs("odoo.addons.pglogical") as log,\
mock.patch("odoo.addons.pglogical.hooks.sqlparse") as mock_sqlparse:
mock_sqlparse.__bool__.return_value = False
post_load()
self.assertEqual(
log.output,
[
"ERROR:odoo.addons.pglogical:"
"DDL replication not supported - sqlparse is not available"
],
)
def test_patching(self): def test_patching(self):
"""Test patching the cursor succeeds""" """Test patching the cursor succeeds"""
with self._config(dict(pglogical=dict(replication_sets="set1,set2"))): with self._config(dict(pglogical=dict(replication_sets="set1,set2"))):
@@ -90,3 +104,31 @@ class TestPglogical(TransactionCase):
) )
finally: finally:
Cursor.execute = getattr(Cursor.execute, "origin", Cursor.execute) Cursor.execute = getattr(Cursor.execute, "origin", Cursor.execute)
def test_schema_qualify(self):
"""Test that schema qualifications are the only changes"""
for statement in (
'create table if not exists testtable',
'drop table testtable',
'alter table testtable',
'''create table
testtable
(col1 int, col2 int); select * from test''',
'alter table testschema.test drop column somecol',
' DROP view if exists testtable',
'truncate table testtable',
'''CREATE FUNCTION testtable(integer, integer) RETURNS integer
AS 'select $1 + $2;'
LANGUAGE SQL
IMMUTABLE
RETURNS NULL ON NULL INPUT''',
'drop table',
"alter table 'test'",
):
qualified_query = ''.join(
''.join(str(token) for token in schema_qualify(parsed_query))
for parsed_query in sqlparse.parse(statement)
)
self.assertEqual(
qualified_query, statement.replace('testtable', 'public.testtable')
)