fixup! fixup! [IMP] use sqlparse also to determine which ddl to update

This commit is contained in:
Holger Brunn
2022-07-13 00:31:26 +02:00
parent b99468dec9
commit ba91374b7e
2 changed files with 96 additions and 107 deletions

View File

@@ -10,88 +10,67 @@ except ImportError:
sqlparse = None sqlparse = None
SECTION_NAME = "pglogical" SECTION_NAME = "pglogical"
DDL_KEYWORDS = ("CREATE", "ALTER", "DROP", "TRUNCATE") DDL_KEYWORDS = set(["CREATE", "ALTER", "DROP", "TRUNCATE"])
QUALIFY_KEYWORDS = DDL_KEYWORDS + ("INHERITS", "FROM", "JOIN") QUALIFY_KEYWORDS = DDL_KEYWORDS | set(["FROM", "INHERITS", "JOIN"])
NO_QUALIFY_KEYWORDS = ("COLUMN",) NO_QUALIFY_KEYWORDS = set(["AS", "COLUMN", "ON", "RETURNS", "SELECT"])
TEMPORARY = set(["TEMP", "TEMPORARY"])
def schema_qualify(parsed_query, temp_tables, schema="public"): def schema_qualify(tokens, temp_tables, keywords=None, schema="public"):
""" """
Yield tokens and add a schema to objects if there's none, but record and Add tokens to add a schema to objects if there's none, but record and
exclude temporary tables exclude temporary tables
""" """
token_iterator = parsed_query.flatten() Identifier = sqlparse.sql.Identifier
Name = sqlparse.tokens.Name Name = sqlparse.tokens.Name
Punctuation = sqlparse.tokens.Punctuation Punctuation = sqlparse.tokens.Punctuation
Symbol = sqlparse.tokens.String.Symbol Token = sqlparse.sql.Token
is_temp_table = False Statement = sqlparse.sql.Statement
for token in token_iterator: Function = sqlparse.sql.Function
yield token Parenthesis = sqlparse.sql.Parenthesis
if token.is_keyword and token.normalized in QUALIFY_KEYWORDS: keywords = list(keywords or [])
# we check if the name coming after {create,drop,alter} object keywords
# is schema qualified, and if not, add the schema we got passed for token in tokens.tokens:
next_token = False if token.is_keyword:
while True: keywords.append(token.normalized)
try:
next_token = token_iterator.__next__()
except StopIteration:
# this is invalid sql
next_token = False
break
if next_token.is_keyword and next_token.normalized in (
'TEMP', 'TEMPORARY'
):
# don't qualify CREATE TEMP TABLE statements
is_temp_table = True
break
if next_token.is_keyword and next_token.normalized in NO_QUALIFY_KEYWORDS:
yield next_token
next_token = False
break
if not (next_token.is_whitespace or next_token.is_keyword):
break
yield next_token
if is_temp_table:
is_temp_table = False
yield next_token
while True:
try:
next_token = token_iterator.__next__()
except StopIteration:
next_token = False
break
if next_token.ttype in (Name, Symbol):
temp_tables.append(str(next_token))
yield next_token
next_token = False
break
else:
yield next_token
if not next_token:
continue continue
if next_token.ttype not in (Name, Symbol): elif token.is_whitespace:
yield next_token
continue continue
elif token.__class__ == Identifier and not token.is_wildcard():
str_token = str(token)
needs_qualification = "." not in str_token
# qualify tokens that are direct children of a statement as in ALTER TABLE ...
if token.parent.__class__ == Statement:
pass
# or of an expression parsed as function as in CREATE TABLE table
# but not within brackets
if token.parent.__class__ == Function:
needs_qualification &= not token.within(Parenthesis)
elif token.parent.__class__ == Parenthesis:
needs_qualification &= (
keywords and (keywords[-1] in QUALIFY_KEYWORDS) or False
)
# but not if the identifier is ie a column name
if set(keywords) & NO_QUALIFY_KEYWORDS:
needs_qualification &= (
keywords and (keywords[-1] in QUALIFY_KEYWORDS) or False
)
# and also not if this is a temporary table
if needs_qualification:
if len(keywords) > 1 and keywords[-2] in TEMPORARY:
needs_qualification = False
temp_tables.append(str_token)
elif str_token in temp_tables:
needs_qualification = False
temp_tables.remove(str_token)
if needs_qualification:
token.insert_before(0, Token(Punctuation, "."))
token.insert_before(0, Token(Name, schema))
keywords = []
elif token.is_group:
schema_qualify(token, temp_tables, keywords=keywords, schema=schema)
object_name_or_schema = next_token return tokens.tokens
needs_schema = False
next_token = False
try:
next_token = token_iterator.__next__()
needs_schema = str(next_token) != '.'
except StopIteration:
needs_schema = True
if needs_schema and str(object_name_or_schema) in temp_tables:
temp_tables.remove(str(object_name_or_schema))
elif needs_schema:
yield sqlparse.sql.Token(Name, schema)
yield sqlparse.sql.Token(Punctuation, '.')
yield object_name_or_schema
if next_token:
yield next_token
def post_load(): def post_load():
@@ -134,12 +113,17 @@ def post_load():
token.is_keyword and token.normalized in token.is_keyword and token.normalized in
# don't replicate constraints, triggers, indexes # don't replicate constraints, triggers, indexes
("CONSTRAINT", "TRIGGER", "INDEX") ("CONSTRAINT", "TRIGGER", "INDEX")
for parsed in parsed_queries for token in parsed.tokens[1:] for parsed in parsed_queries
for token in parsed.tokens[1:]
): ):
qualified_query = ''.join( qualified_query = "".join(
''.join(str(token) for token in schema_qualify( "".join(
parsed_query, temp_tables, str(token)
)) for token in schema_qualify(
parsed_query,
temp_tables,
)
)
for parsed_query in parsed_queries for parsed_query in parsed_queries
) )
mogrified = self.mogrify(qualified_query, params).decode("utf8") mogrified = self.mogrify(qualified_query, params).decode("utf8")

View File

@@ -58,9 +58,11 @@ class TestPglogical(TransactionCase):
], ],
) )
with self._config(dict(pglogical={"replication_sets": "ddl_sql"})),\ with self._config(
self.assertLogs("odoo.addons.pglogical") as log,\ dict(pglogical={"replication_sets": "ddl_sql"})
mock.patch("odoo.addons.pglogical.hooks.sqlparse") as mock_sqlparse: ), self.assertLogs("odoo.addons.pglogical") as log, mock.patch(
"odoo.addons.pglogical.hooks.sqlparse"
) as mock_sqlparse:
mock_sqlparse.__bool__.return_value = False mock_sqlparse.__bool__.return_value = False
post_load() post_load()
@@ -109,35 +111,38 @@ class TestPglogical(TransactionCase):
"""Test that schema qualifications are the only changes""" """Test that schema qualifications are the only changes"""
temp_tables = [] temp_tables = []
for statement in ( for statement in (
'create table if not exists testtable', "create table if not exists testtable",
'drop table testtable', "drop table testtable",
'alter table testtable', "alter table testtable",
'''create table """create table
testtable testtable
(col1 int, col2 int); select * from testtable''', (col1 int, col2 int); select * from testtable""",
'alter table testschema.test drop column somecol', "alter table testschema.test drop column somecol",
' DROP view if exists testtable', " DROP view if exists testtable",
'truncate table testtable', "truncate table testtable",
'''CREATE FUNCTION testtable(integer, integer) RETURNS integer """CREATE FUNCTION testtable(integer, integer) RETURNS integer
AS 'select $1 + $2;' AS 'select $1 + $2;'
LANGUAGE SQL LANGUAGE SQL
IMMUTABLE IMMUTABLE
RETURNS NULL ON NULL INPUT''', RETURNS NULL ON NULL INPUT""",
'drop table', "drop table",
"alter table 'test'", "alter table 'test'",
'ALTER TABLE "testtable" ADD COLUMN "test_field" double precision', 'ALTER TABLE "testtable" ADD COLUMN "test_field" double precision',
'CREATE TEMP TABLE "temptable" (col1 char) INHERITS (ir_translation)', 'CREATE TEMP TABLE "temptable" (col1 char) INHERITS (testtable)',
'DROP TABLE "temptable"', 'DROP TABLE "temptable"',
'create view testtable as select col1, col2 from testtable join ' "create view testtable as select col1, col2 from testtable join "
'testtable test1 on col3=test1.col4)', "testtable test1 on col3=test1.col4)",
'CREATE TABLE public."ir_model" (id SERIAL NOT NULL, PRIMARY KEY(id))',
): ):
qualified_query = ''.join( qualified_query = "".join(
''.join(str(token) for token in schema_qualify(parsed_query, temp_tables)) "".join(
str(token) for token in schema_qualify(parsed_query, temp_tables)
)
for parsed_query in sqlparse.parse(statement) for parsed_query in sqlparse.parse(statement)
) )
self.assertEqual( self.assertEqual(
qualified_query, qualified_query,
statement.replace('testtable', 'public.testtable').replace( statement.replace("testtable", "public.testtable").replace(
'"public.testtable"', 'public."testtable"' '"public.testtable"', 'public."testtable"'
) ),
) )