fixup! fixup! fixup! fixup! [IMP] use sqlparse also to determine which ddl to update

This commit is contained in:
Holger Brunn
2022-07-12 10:02:11 +02:00
parent fbe5bc5cb4
commit 38ae4a696c
2 changed files with 35 additions and 7 deletions

View File

@@ -10,20 +10,24 @@ except ImportError:
sqlparse = None sqlparse = None
SECTION_NAME = "pglogical" SECTION_NAME = "pglogical"
DDL_KEYWORDS = ("CREATE", "ALTER", "DROP", "TRUNCATE") DDL_KEYWORDS = ("CREATE", "ALTER", "DROP", "TRUNCATE", "INHERITS")
def schema_qualify(parsed_query, schema="public"): def schema_qualify(parsed_query, temp_tables, schema="public"):
""" """
Yield tokens and add a schema to objects if there's none Yield tokens and add a schema to objects if there's none, but record and
exclude temporary tables
""" """
token_iterator = parsed_query.flatten() token_iterator = parsed_query.flatten()
Name = sqlparse.tokens.Name Name = sqlparse.tokens.Name
Punctuation = sqlparse.tokens.Punctuation Punctuation = sqlparse.tokens.Punctuation
Symbol = sqlparse.tokens.String.Symbol Symbol = sqlparse.tokens.String.Symbol
is_qualified = False is_qualified = False
is_temp_table = False
for token in token_iterator: for token in token_iterator:
yield token yield token
if token.is_keyword and token.normalized == "INHERITS":
is_qualified = False
if not is_qualified and token.is_keyword and token.normalized in DDL_KEYWORDS: if not is_qualified and token.is_keyword and token.normalized in DDL_KEYWORDS:
# we check if the name coming after {create,drop,alter} object keywords # we check if the name coming after {create,drop,alter} object keywords
# is schema qualified, and if not, add the schema we got passed # is schema qualified, and if not, add the schema we got passed
@@ -40,10 +44,26 @@ def schema_qualify(parsed_query, schema="public"):
): ):
# don't qualify CREATE TEMP TABLE statements # don't qualify CREATE TEMP TABLE statements
is_qualified = True is_qualified = True
is_temp_table = True
break break
if not (next_token.is_whitespace or next_token.is_keyword): if not (next_token.is_whitespace or next_token.is_keyword):
break break
yield next_token yield next_token
if is_temp_table:
yield next_token
while True:
try:
next_token = token_iterator.__next__()
except StopIteration:
next_token = False
break
if next_token.ttype in (Name, Symbol):
temp_tables.append(str(next_token))
yield next_token
next_token = False
break
else:
yield next_token
if not next_token: if not next_token:
continue continue
if next_token.ttype not in (Name, Symbol): if next_token.ttype not in (Name, Symbol):
@@ -59,7 +79,9 @@ def schema_qualify(parsed_query, schema="public"):
except StopIteration: except StopIteration:
needs_schema = True needs_schema = True
if needs_schema: if needs_schema and str(object_name_or_schema) in temp_tables:
temp_tables.remove(str(object_name_or_schema))
elif needs_schema:
yield sqlparse.sql.Token(Name, schema) yield sqlparse.sql.Token(Name, schema)
yield sqlparse.sql.Token(Punctuation, '.') yield sqlparse.sql.Token(Punctuation, '.')
@@ -102,6 +124,7 @@ def post_load():
"""Wrap DDL in pglogical.replicate_ddl_command""" """Wrap DDL in pglogical.replicate_ddl_command"""
# short circuit statements that must be as fast as possible # short circuit statements that must be as fast as possible
if query[:6] not in ("SELECT", "UPDATE"): if query[:6] not in ("SELECT", "UPDATE"):
temp_tables = getattr(self, "__temp_tables", [])
parsed_queries = sqlparse.parse(query) parsed_queries = sqlparse.parse(query)
if any( if any(
parsed_query.get_type() in DDL_KEYWORDS parsed_query.get_type() in DDL_KEYWORDS
@@ -113,12 +136,15 @@ def post_load():
for parsed in parsed_queries for token in parsed.tokens[1:] for parsed in parsed_queries for token in parsed.tokens[1:]
): ):
qualified_query = ''.join( qualified_query = ''.join(
''.join(str(token) for token in schema_qualify(parsed_query)) ''.join(str(token) for token in schema_qualify(
parsed_query, temp_tables,
))
for parsed_query in parsed_queries for parsed_query in parsed_queries
) )
mogrified = self.mogrify(qualified_query, params).decode("utf8") mogrified = self.mogrify(qualified_query, params).decode("utf8")
query = "SELECT pglogical.replicate_ddl_command(%s, %s)" query = "SELECT pglogical.replicate_ddl_command(%s, %s)"
params = (mogrified, execute.replication_sets) params = (mogrified, execute.replication_sets)
setattr(self, "__temp_tables", temp_tables)
return execute.origin(self, query, params=params, log_exceptions=log_exceptions) return execute.origin(self, query, params=params, log_exceptions=log_exceptions)
execute.origin = execute_org execute.origin = execute_org

View File

@@ -107,6 +107,7 @@ class TestPglogical(TransactionCase):
def test_schema_qualify(self): def test_schema_qualify(self):
"""Test that schema qualifications are the only changes""" """Test that schema qualifications are the only changes"""
temp_tables = []
for statement in ( for statement in (
'create table if not exists testtable', 'create table if not exists testtable',
'drop table testtable', 'drop table testtable',
@@ -125,10 +126,11 @@ class TestPglogical(TransactionCase):
'drop table', 'drop table',
"alter table 'test'", "alter table 'test'",
'ALTER TABLE "testtable" ADD COLUMN "test_field" double precision', 'ALTER TABLE "testtable" ADD COLUMN "test_field" double precision',
'CREATE TEMP TABLE "temptable" (col1 char)', 'CREATE TEMP TABLE "temptable" (col1 char) INHERITS (ir_translation)',
'DROP TABLE "temptable"',
): ):
qualified_query = ''.join( qualified_query = ''.join(
''.join(str(token) for token in schema_qualify(parsed_query)) ''.join(str(token) for token in schema_qualify(parsed_query, temp_tables))
for parsed_query in sqlparse.parse(statement) for parsed_query in sqlparse.parse(statement)
) )
self.assertEqual( self.assertEqual(