mirror of
https://github.com/OCA/server-backend.git
synced 2025-02-18 09:52:42 +02:00
fixup! fixup! fixup! fixup! [IMP] use sqlparse also to determine which ddl to update
This commit is contained in:
@@ -10,20 +10,24 @@ except ImportError:
|
||||
sqlparse = None
|
||||
|
||||
SECTION_NAME = "pglogical"
|
||||
DDL_KEYWORDS = ("CREATE", "ALTER", "DROP", "TRUNCATE")
|
||||
DDL_KEYWORDS = ("CREATE", "ALTER", "DROP", "TRUNCATE", "INHERITS")
|
||||
|
||||
|
||||
def schema_qualify(parsed_query, schema="public"):
|
||||
def schema_qualify(parsed_query, temp_tables, schema="public"):
|
||||
"""
|
||||
Yield tokens and add a schema to objects if there's none
|
||||
Yield tokens and add a schema to objects if there's none, but record and
|
||||
exclude temporary tables
|
||||
"""
|
||||
token_iterator = parsed_query.flatten()
|
||||
Name = sqlparse.tokens.Name
|
||||
Punctuation = sqlparse.tokens.Punctuation
|
||||
Symbol = sqlparse.tokens.String.Symbol
|
||||
is_qualified = False
|
||||
is_temp_table = False
|
||||
for token in token_iterator:
|
||||
yield token
|
||||
if token.is_keyword and token.normalized == "INHERITS":
|
||||
is_qualified = False
|
||||
if not is_qualified and token.is_keyword and token.normalized in DDL_KEYWORDS:
|
||||
# we check if the name coming after {create,drop,alter} object keywords
|
||||
# is schema qualified, and if not, add the schema we got passed
|
||||
@@ -40,10 +44,26 @@ def schema_qualify(parsed_query, schema="public"):
|
||||
):
|
||||
# don't qualify CREATE TEMP TABLE statements
|
||||
is_qualified = True
|
||||
is_temp_table = True
|
||||
break
|
||||
if not (next_token.is_whitespace or next_token.is_keyword):
|
||||
break
|
||||
yield next_token
|
||||
if is_temp_table:
|
||||
yield next_token
|
||||
while True:
|
||||
try:
|
||||
next_token = token_iterator.__next__()
|
||||
except StopIteration:
|
||||
next_token = False
|
||||
break
|
||||
if next_token.ttype in (Name, Symbol):
|
||||
temp_tables.append(str(next_token))
|
||||
yield next_token
|
||||
next_token = False
|
||||
break
|
||||
else:
|
||||
yield next_token
|
||||
if not next_token:
|
||||
continue
|
||||
if next_token.ttype not in (Name, Symbol):
|
||||
@@ -59,7 +79,9 @@ def schema_qualify(parsed_query, schema="public"):
|
||||
except StopIteration:
|
||||
needs_schema = True
|
||||
|
||||
if needs_schema:
|
||||
if needs_schema and str(object_name_or_schema) in temp_tables:
|
||||
temp_tables.remove(str(object_name_or_schema))
|
||||
elif needs_schema:
|
||||
yield sqlparse.sql.Token(Name, schema)
|
||||
yield sqlparse.sql.Token(Punctuation, '.')
|
||||
|
||||
@@ -102,6 +124,7 @@ def post_load():
|
||||
"""Wrap DDL in pglogical.replicate_ddl_command"""
|
||||
# short circuit statements that must be as fast as possible
|
||||
if query[:6] not in ("SELECT", "UPDATE"):
|
||||
temp_tables = getattr(self, "__temp_tables", [])
|
||||
parsed_queries = sqlparse.parse(query)
|
||||
if any(
|
||||
parsed_query.get_type() in DDL_KEYWORDS
|
||||
@@ -113,12 +136,15 @@ def post_load():
|
||||
for parsed in parsed_queries for token in parsed.tokens[1:]
|
||||
):
|
||||
qualified_query = ''.join(
|
||||
''.join(str(token) for token in schema_qualify(parsed_query))
|
||||
''.join(str(token) for token in schema_qualify(
|
||||
parsed_query, temp_tables,
|
||||
))
|
||||
for parsed_query in parsed_queries
|
||||
)
|
||||
mogrified = self.mogrify(qualified_query, params).decode("utf8")
|
||||
query = "SELECT pglogical.replicate_ddl_command(%s, %s)"
|
||||
params = (mogrified, execute.replication_sets)
|
||||
setattr(self, "__temp_tables", temp_tables)
|
||||
return execute.origin(self, query, params=params, log_exceptions=log_exceptions)
|
||||
|
||||
execute.origin = execute_org
|
||||
|
||||
@@ -107,6 +107,7 @@ class TestPglogical(TransactionCase):
|
||||
|
||||
def test_schema_qualify(self):
|
||||
"""Test that schema qualifications are the only changes"""
|
||||
temp_tables = []
|
||||
for statement in (
|
||||
'create table if not exists testtable',
|
||||
'drop table testtable',
|
||||
@@ -125,10 +126,11 @@ class TestPglogical(TransactionCase):
|
||||
'drop table',
|
||||
"alter table 'test'",
|
||||
'ALTER TABLE "testtable" ADD COLUMN "test_field" double precision',
|
||||
'CREATE TEMP TABLE "temptable" (col1 char)',
|
||||
'CREATE TEMP TABLE "temptable" (col1 char) INHERITS (ir_translation)',
|
||||
'DROP TABLE "temptable"',
|
||||
):
|
||||
qualified_query = ''.join(
|
||||
''.join(str(token) for token in schema_qualify(parsed_query))
|
||||
''.join(str(token) for token in schema_qualify(parsed_query, temp_tables))
|
||||
for parsed_query in sqlparse.parse(statement)
|
||||
)
|
||||
self.assertEqual(
|
||||
|
||||
Reference in New Issue
Block a user