fixup! fixup! [IMP] use sqlparse also to determine which ddl to update

Holger Brunn
2022-07-13 00:31:26 +02:00
parent b99468dec9
commit ba91374b7e
2 changed files with 96 additions and 107 deletions


@@ -10,88 +10,67 @@ except ImportError:
     sqlparse = None
 SECTION_NAME = "pglogical"
-DDL_KEYWORDS = ("CREATE", "ALTER", "DROP", "TRUNCATE")
-QUALIFY_KEYWORDS = DDL_KEYWORDS + ("INHERITS", "FROM", "JOIN")
-NO_QUALIFY_KEYWORDS = ("COLUMN",)
+DDL_KEYWORDS = set(["CREATE", "ALTER", "DROP", "TRUNCATE"])
+QUALIFY_KEYWORDS = DDL_KEYWORDS | set(["FROM", "INHERITS", "JOIN"])
+NO_QUALIFY_KEYWORDS = set(["AS", "COLUMN", "ON", "RETURNS", "SELECT"])
+TEMPORARY = set(["TEMP", "TEMPORARY"])


-def schema_qualify(parsed_query, temp_tables, schema="public"):
+def schema_qualify(tokens, temp_tables, keywords=None, schema="public"):
     """
-    Yield tokens and add a schema to objects if there's none, but record and
+    Add tokens to add a schema to objects if there's none, but record and
     exclude temporary tables
     """
-    token_iterator = parsed_query.flatten()
+    Identifier = sqlparse.sql.Identifier
     Name = sqlparse.tokens.Name
     Punctuation = sqlparse.tokens.Punctuation
-    Symbol = sqlparse.tokens.String.Symbol
-    is_temp_table = False
-    for token in token_iterator:
-        yield token
-        if token.is_keyword and token.normalized in QUALIFY_KEYWORDS:
-            # we check if the name coming after {create,drop,alter} object keywords
-            # is schema qualified, and if not, add the schema we got passed
-            next_token = False
-            while True:
-                try:
-                    next_token = token_iterator.__next__()
-                except StopIteration:
-                    # this is invalid sql
-                    next_token = False
-                    break
-                if next_token.is_keyword and next_token.normalized in (
-                    'TEMP', 'TEMPORARY'
-                ):
-                    # don't qualify CREATE TEMP TABLE statements
-                    is_temp_table = True
-                    break
-                if next_token.is_keyword and next_token.normalized in NO_QUALIFY_KEYWORDS:
-                    yield next_token
-                    next_token = False
-                    break
-                if not (next_token.is_whitespace or next_token.is_keyword):
-                    break
-                yield next_token
-            if is_temp_table:
-                is_temp_table = False
-                yield next_token
-                while True:
-                    try:
-                        next_token = token_iterator.__next__()
-                    except StopIteration:
-                        next_token = False
-                        break
-                    if next_token.ttype in (Name, Symbol):
-                        temp_tables.append(str(next_token))
-                        yield next_token
-                        next_token = False
-                        break
-                    else:
-                        yield next_token
-            if not next_token:
-                continue
-            if next_token.ttype not in (Name, Symbol):
-                yield next_token
-                continue
+    Token = sqlparse.sql.Token
+    Statement = sqlparse.sql.Statement
+    Function = sqlparse.sql.Function
+    Parenthesis = sqlparse.sql.Parenthesis
+    keywords = list(keywords or [])
-            object_name_or_schema = next_token
-            needs_schema = False
-            next_token = False
-            try:
-                next_token = token_iterator.__next__()
-                needs_schema = str(next_token) != '.'
-            except StopIteration:
-                needs_schema = True
+    for token in tokens.tokens:
+        if token.is_keyword:
+            keywords.append(token.normalized)
+            continue
+        elif token.is_whitespace:
+            continue
+        elif token.__class__ == Identifier and not token.is_wildcard():
+            str_token = str(token)
+            needs_qualification = "." not in str_token
+            # qualify tokens that are direct children of a statement as in ALTER TABLE ...
+            if token.parent.__class__ == Statement:
+                pass
+            # or of an expression parsed as function as in CREATE TABLE table
+            # but not within brackets
+            if token.parent.__class__ == Function:
+                needs_qualification &= not token.within(Parenthesis)
+            elif token.parent.__class__ == Parenthesis:
+                needs_qualification &= (
+                    keywords and (keywords[-1] in QUALIFY_KEYWORDS) or False
+                )
+            # but not if the identifier is ie a column name
+            if set(keywords) & NO_QUALIFY_KEYWORDS:
+                needs_qualification &= (
+                    keywords and (keywords[-1] in QUALIFY_KEYWORDS) or False
+                )
+            # and also not if this is a temporary table
+            if needs_qualification:
+                if len(keywords) > 1 and keywords[-2] in TEMPORARY:
+                    needs_qualification = False
+                    temp_tables.append(str_token)
+                elif str_token in temp_tables:
+                    needs_qualification = False
+                    temp_tables.remove(str_token)
+            if needs_qualification:
+                token.insert_before(0, Token(Punctuation, "."))
+                token.insert_before(0, Token(Name, schema))
+            keywords = []
+        elif token.is_group:
+            schema_qualify(token, temp_tables, keywords=keywords, schema=schema)
-            if needs_schema and str(object_name_or_schema) in temp_tables:
-                temp_tables.remove(str(object_name_or_schema))
-            elif needs_schema:
-                yield sqlparse.sql.Token(Name, schema)
-                yield sqlparse.sql.Token(Punctuation, '.')
-                yield object_name_or_schema
-            if next_token:
-                yield next_token
+    return tokens.tokens


 def post_load():
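
A minimal sketch of what the rewritten schema_qualify is meant to do, assuming sqlparse is installed and the function is imported from the patched module; the statements and the names foo, bar and staging are invented for illustration:

import sqlparse

temp_tables = []

# A bare table name should be qualified with the default "public" schema.
# sqlparse parses "foo (bar INTEGER)" as a Function group, which is the
# case the comments in the diff above refer to.
parsed = sqlparse.parse("CREATE TABLE foo (bar INTEGER)")[0]
schema_qualify(parsed, temp_tables)
print(str(parsed))
# expected: CREATE TABLE public.foo (bar INTEGER)

# Temporary tables are recorded instead of qualified, so later statements
# touching them can be left alone.
parsed = sqlparse.parse("CREATE TEMPORARY TABLE staging (x INTEGER)")[0]
schema_qualify(parsed, temp_tables)
print(str(parsed))
# expected: CREATE TEMPORARY TABLE staging (x INTEGER), temp_tables == ["staging"]

Unlike the old generator, which yielded a rewritten token stream, the new version mutates identifiers in place via insert_before and recurses into token groups, so the caller can simply re-serialize the parsed statement afterwards.
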
@@ -128,18 +107,23 @@ def post_load():
         temp_tables = getattr(self, "__temp_tables", [])
         parsed_queries = sqlparse.parse(query)
         if any(
-                parsed_query.get_type() in DDL_KEYWORDS
-                for parsed_query in parsed_queries
+            parsed_query.get_type() in DDL_KEYWORDS
+            for parsed_query in parsed_queries
         ) and not any(
-                token.is_keyword and token.normalized in
-                # don't replicate constraints, triggers, indexes
-                ("CONSTRAINT", "TRIGGER", "INDEX")
-                for parsed in parsed_queries for token in parsed.tokens[1:]
+            token.is_keyword and token.normalized in
+            # don't replicate constraints, triggers, indexes
+            ("CONSTRAINT", "TRIGGER", "INDEX")
+            for parsed in parsed_queries
+            for token in parsed.tokens[1:]
         ):
-            qualified_query = ''.join(
-                ''.join(str(token) for token in schema_qualify(
-                    parsed_query, temp_tables,
-                ))
+            qualified_query = "".join(
+                "".join(
+                    str(token)
+                    for token in schema_qualify(
+                        parsed_query,
+                        temp_tables,
+                    )
+                )
                 for parsed_query in parsed_queries
             )
             mogrified = self.mogrify(qualified_query, params).decode("utf8")
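
Read on its own, the condition above decides which statements are replicated at all. A stand-alone sketch of that filter, with an invented should_replicate wrapper around the committed logic:

import sqlparse

DDL_KEYWORDS = set(["CREATE", "ALTER", "DROP", "TRUNCATE"])

def should_replicate(query):
    # Replicate only DDL statements (per sqlparse's statement type) that
    # don't touch constraints, triggers or indexes.
    parsed_queries = sqlparse.parse(query)
    is_ddl = any(
        parsed_query.get_type() in DDL_KEYWORDS
        for parsed_query in parsed_queries
    )
    touches_excluded = any(
        token.is_keyword and token.normalized in ("CONSTRAINT", "TRIGGER", "INDEX")
        for parsed in parsed_queries
        for token in parsed.tokens[1:]
    )
    return is_ddl and not touches_excluded

print(should_replicate("ALTER TABLE foo ADD COLUMN bar INTEGER"))  # True
print(should_replicate("CREATE INDEX idx_foo ON foo (bar)"))       # False
print(should_replicate("SELECT * FROM foo"))                       # False

Scanning parsed.tokens[1:] skips the leading DDL keyword itself, so statements like DROP INDEX are still excluded by the INDEX token that follows.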