From 38ae4a696c8195efcd92054a9bee825726edbded Mon Sep 17 00:00:00 2001 From: Holger Brunn Date: Tue, 12 Jul 2022 10:02:11 +0200 Subject: [PATCH] fixup! fixup! fixup! fixup! [IMP] use sqlparse also to determine which ddl to update --- pglogical/hooks.py | 36 ++++++++++++++++++++++++++----- pglogical/tests/test_pglogical.py | 6 ++++-- 2 files changed, 35 insertions(+), 7 deletions(-) diff --git a/pglogical/hooks.py b/pglogical/hooks.py index cf8f5eb6..9f6ed8c1 100644 --- a/pglogical/hooks.py +++ b/pglogical/hooks.py @@ -10,20 +10,24 @@ except ImportError: sqlparse = None SECTION_NAME = "pglogical" -DDL_KEYWORDS = ("CREATE", "ALTER", "DROP", "TRUNCATE") +DDL_KEYWORDS = ("CREATE", "ALTER", "DROP", "TRUNCATE", "INHERITS") -def schema_qualify(parsed_query, schema="public"): +def schema_qualify(parsed_query, temp_tables, schema="public"): """ - Yield tokens and add a schema to objects if there's none + Yield tokens and add a schema to objects if there's none, but record and + exclude temporary tables """ token_iterator = parsed_query.flatten() Name = sqlparse.tokens.Name Punctuation = sqlparse.tokens.Punctuation Symbol = sqlparse.tokens.String.Symbol is_qualified = False + is_temp_table = False for token in token_iterator: yield token + if token.is_keyword and token.normalized == "INHERITS": + is_qualified = False if not is_qualified and token.is_keyword and token.normalized in DDL_KEYWORDS: # we check if the name coming after {create,drop,alter} object keywords # is schema qualified, and if not, add the schema we got passed @@ -40,10 +44,26 @@ def schema_qualify(parsed_query, schema="public"): ): # don't qualify CREATE TEMP TABLE statements is_qualified = True + is_temp_table = True break if not (next_token.is_whitespace or next_token.is_keyword): break yield next_token + if is_temp_table: + yield next_token + while True: + try: + next_token = token_iterator.__next__() + except StopIteration: + next_token = False + break + if next_token.ttype in (Name, Symbol): + temp_tables.append(str(next_token)) + yield next_token + next_token = False + break + else: + yield next_token if not next_token: continue if next_token.ttype not in (Name, Symbol): @@ -59,7 +79,9 @@ def schema_qualify(parsed_query, schema="public"): except StopIteration: needs_schema = True - if needs_schema: + if needs_schema and str(object_name_or_schema) in temp_tables: + temp_tables.remove(str(object_name_or_schema)) + elif needs_schema: yield sqlparse.sql.Token(Name, schema) yield sqlparse.sql.Token(Punctuation, '.') @@ -102,6 +124,7 @@ def post_load(): """Wrap DDL in pglogical.replicate_ddl_command""" # short circuit statements that must be as fast as possible if query[:6] not in ("SELECT", "UPDATE"): + temp_tables = getattr(self, "__temp_tables", []) parsed_queries = sqlparse.parse(query) if any( parsed_query.get_type() in DDL_KEYWORDS @@ -113,12 +136,15 @@ def post_load(): for parsed in parsed_queries for token in parsed.tokens[1:] ): qualified_query = ''.join( - ''.join(str(token) for token in schema_qualify(parsed_query)) + ''.join(str(token) for token in schema_qualify( + parsed_query, temp_tables, + )) for parsed_query in parsed_queries ) mogrified = self.mogrify(qualified_query, params).decode("utf8") query = "SELECT pglogical.replicate_ddl_command(%s, %s)" params = (mogrified, execute.replication_sets) + setattr(self, "__temp_tables", temp_tables) return execute.origin(self, query, params=params, log_exceptions=log_exceptions) execute.origin = execute_org diff --git a/pglogical/tests/test_pglogical.py b/pglogical/tests/test_pglogical.py index 59a39a7b..170c5976 100644 --- a/pglogical/tests/test_pglogical.py +++ b/pglogical/tests/test_pglogical.py @@ -107,6 +107,7 @@ class TestPglogical(TransactionCase): def test_schema_qualify(self): """Test that schema qualifications are the only changes""" + temp_tables = [] for statement in ( 'create table if not exists testtable', 'drop table testtable', @@ -125,10 +126,11 @@ class TestPglogical(TransactionCase): 'drop table', "alter table 'test'", 'ALTER TABLE "testtable" ADD COLUMN "test_field" double precision', - 'CREATE TEMP TABLE "temptable" (col1 char)', + 'CREATE TEMP TABLE "temptable" (col1 char) INHERITS (ir_translation)', + 'DROP TABLE "temptable"', ): qualified_query = ''.join( - ''.join(str(token) for token in schema_qualify(parsed_query)) + ''.join(str(token) for token in schema_qualify(parsed_query, temp_tables)) for parsed_query in sqlparse.parse(statement) ) self.assertEqual(