diff --git a/pglogical/hooks.py b/pglogical/hooks.py
index 3d9fa492..33ae6fe7 100644
--- a/pglogical/hooks.py
+++ b/pglogical/hooks.py
@@ -10,88 +10,67 @@
 except ImportError:
     sqlparse = None
 SECTION_NAME = "pglogical"
-DDL_KEYWORDS = ("CREATE", "ALTER", "DROP", "TRUNCATE")
-QUALIFY_KEYWORDS = DDL_KEYWORDS + ("INHERITS", "FROM", "JOIN")
-NO_QUALIFY_KEYWORDS = ("COLUMN",)
+DDL_KEYWORDS = set(["CREATE", "ALTER", "DROP", "TRUNCATE"])
+QUALIFY_KEYWORDS = DDL_KEYWORDS | set(["FROM", "INHERITS", "JOIN"])
+NO_QUALIFY_KEYWORDS = set(["AS", "COLUMN", "ON", "RETURNS", "SELECT"])
+TEMPORARY = set(["TEMP", "TEMPORARY"])


-def schema_qualify(parsed_query, temp_tables, schema="public"):
+def schema_qualify(tokens, temp_tables, keywords=None, schema="public"):
     """
-    Yield tokens and add a schema to objects if there's none, but record and
+    Insert tokens to schema-qualify objects that have no schema, but record and
     exclude temporary tables
     """
-    token_iterator = parsed_query.flatten()
+    Identifier = sqlparse.sql.Identifier
     Name = sqlparse.tokens.Name
     Punctuation = sqlparse.tokens.Punctuation
-    Symbol = sqlparse.tokens.String.Symbol
-    is_temp_table = False
-    for token in token_iterator:
-        yield token
-        if token.is_keyword and token.normalized in QUALIFY_KEYWORDS:
-            # we check if the name coming after {create,drop,alter} object keywords
-            # is schema qualified, and if not, add the schema we got passed
-            next_token = False
-            while True:
-                try:
-                    next_token = token_iterator.__next__()
-                except StopIteration:
-                    # this is invalid sql
-                    next_token = False
-                    break
-                if next_token.is_keyword and next_token.normalized in (
-                    'TEMP', 'TEMPORARY'
-                ):
-                    # don't qualify CREATE TEMP TABLE statements
-                    is_temp_table = True
-                    break
-                if next_token.is_keyword and next_token.normalized in NO_QUALIFY_KEYWORDS:
-                    yield next_token
-                    next_token = False
-                    break
-                if not (next_token.is_whitespace or next_token.is_keyword):
-                    break
-                yield next_token
-            if is_temp_table:
-                is_temp_table = False
-                yield next_token
-                while True:
-                    try:
-                        next_token = token_iterator.__next__()
-                    except StopIteration:
-                        next_token = False
-                        break
-                    if next_token.ttype in (Name, Symbol):
-                        temp_tables.append(str(next_token))
-                        yield next_token
-                        next_token = False
-                        break
-                    else:
-                        yield next_token
-                if not next_token:
-                    continue
-            if next_token.ttype not in (Name, Symbol):
-                yield next_token
-                continue
+    Token = sqlparse.sql.Token
+    Statement = sqlparse.sql.Statement
+    Function = sqlparse.sql.Function
+    Parenthesis = sqlparse.sql.Parenthesis
+    keywords = list(keywords or [])

-            object_name_or_schema = next_token
-            needs_schema = False
-            next_token = False
-            try:
-                next_token = token_iterator.__next__()
-                needs_schema = str(next_token) != '.'
-            except StopIteration:
-                needs_schema = True
+    for token in tokens.tokens:
+        if token.is_keyword:
+            keywords.append(token.normalized)
+            continue
+        elif token.is_whitespace:
+            continue
+        elif token.__class__ == Identifier and not token.is_wildcard():
+            str_token = str(token)
+            needs_qualification = "." not in str_token
+            # qualify identifiers that are direct children of a statement, as in ALTER TABLE ...
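+            # (direct statement children need no extra parent-based check beyond the "." test above)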
+            if token.parent.__class__ == Statement:
+                pass
+            # or of an expression parsed as a function, as in CREATE TABLE name(...),
+            # but not identifiers inside the brackets
+            if token.parent.__class__ == Function:
+                needs_qualification &= not token.within(Parenthesis)
+            elif token.parent.__class__ == Parenthesis:
+                needs_qualification &= (
+                    keywords and (keywords[-1] in QUALIFY_KEYWORDS) or False
+                )
+            # but not if the identifier is e.g. a column name
+            if set(keywords) & NO_QUALIFY_KEYWORDS:
+                needs_qualification &= (
+                    keywords and (keywords[-1] in QUALIFY_KEYWORDS) or False
+                )
+            # and also not if this is a temporary table
+            if needs_qualification:
+                if len(keywords) > 1 and keywords[-2] in TEMPORARY:
+                    needs_qualification = False
+                    temp_tables.append(str_token)
+                elif str_token in temp_tables:
+                    needs_qualification = False
+                    temp_tables.remove(str_token)
+            if needs_qualification:
+                token.insert_before(0, Token(Punctuation, "."))
+                token.insert_before(0, Token(Name, schema))
+            keywords = []
+        elif token.is_group:
+            schema_qualify(token, temp_tables, keywords=keywords, schema=schema)

-            if needs_schema and str(object_name_or_schema) in temp_tables:
-                temp_tables.remove(str(object_name_or_schema))
-            elif needs_schema:
-                yield sqlparse.sql.Token(Name, schema)
-                yield sqlparse.sql.Token(Punctuation, '.')
-
-            yield object_name_or_schema
-
-            if next_token:
-                yield next_token
+    return tokens.tokens


 def post_load():
@@ -128,18 +107,23 @@ def post_load():
         temp_tables = getattr(self, "__temp_tables", [])
         parsed_queries = sqlparse.parse(query)
         if any(
-                parsed_query.get_type() in DDL_KEYWORDS
-                for parsed_query in parsed_queries
+            parsed_query.get_type() in DDL_KEYWORDS
+            for parsed_query in parsed_queries
         ) and not any(
-                token.is_keyword and token.normalized in
-                # don't replicate constraints, triggers, indexes
-                ("CONSTRAINT", "TRIGGER", "INDEX")
-                for parsed in parsed_queries for token in parsed.tokens[1:]
+            token.is_keyword and token.normalized in
+            # don't replicate constraints, triggers, indexes
+            ("CONSTRAINT", "TRIGGER", "INDEX")
+            for parsed in parsed_queries
+            for token in parsed.tokens[1:]
         ):
-            qualified_query = ''.join(
-                ''.join(str(token) for token in schema_qualify(
-                    parsed_query, temp_tables,
-                ))
+            qualified_query = "".join(
+                "".join(
+                    str(token)
+                    for token in schema_qualify(
+                        parsed_query,
+                        temp_tables,
+                    )
+                )
                 for parsed_query in parsed_queries
             )
             mogrified = self.mogrify(qualified_query, params).decode("utf8")
diff --git a/pglogical/tests/test_pglogical.py b/pglogical/tests/test_pglogical.py
index c8ac7c06..dbdcbc80 100644
--- a/pglogical/tests/test_pglogical.py
+++ b/pglogical/tests/test_pglogical.py
@@ -58,9 +58,11 @@ class TestPglogical(TransactionCase):
             ],
         )

-        with self._config(dict(pglogical={"replication_sets": "ddl_sql"})),\
-            self.assertLogs("odoo.addons.pglogical") as log,\
-            mock.patch("odoo.addons.pglogical.hooks.sqlparse") as mock_sqlparse:
+        with self._config(
+            dict(pglogical={"replication_sets": "ddl_sql"})
+        ), self.assertLogs("odoo.addons.pglogical") as log, mock.patch(
+            "odoo.addons.pglogical.hooks.sqlparse"
+        ) as mock_sqlparse:
             mock_sqlparse.__bool__.return_value = False
             post_load()

@@ -109,35 +111,38 @@ class TestPglogical(TransactionCase):
         """Test that schema qualifications are the only changes"""
         temp_tables = []
         for statement in (
-            'create table if not exists testtable',
-            'drop table testtable',
-            'alter table testtable',
-            '''create table
+            "create table if not exists testtable",
+            "drop table testtable",
+            "alter table testtable",
+            """create table
                testtable
-               (col1 int, col2 int); select * from testtable''',
-            'alter table testschema.test drop column somecol',
-            ' DROP view if exists testtable',
-            'truncate table testtable',
-            '''CREATE FUNCTION testtable(integer, integer) RETURNS integer
+               (col1 int, col2 int); select * from testtable""",
+            "alter table testschema.test drop column somecol",
+            " DROP view if exists testtable",
+            "truncate table testtable",
+            """CREATE FUNCTION testtable(integer, integer) RETURNS integer
     AS 'select $1 + $2;'
     LANGUAGE SQL
     IMMUTABLE
-    RETURNS NULL ON NULL INPUT''',
-            'drop table',
-            "alter table 'test'",
-            'ALTER TABLE "testtable" ADD COLUMN "test_field" double precision',
-            'CREATE TEMP TABLE "temptable" (col1 char) INHERITS (ir_translation)',
-            'DROP TABLE "temptable"',
-            'create view testtable as select col1, col2 from testtable join '
-            'testtable test1 on col3=test1.col4)',
+    RETURNS NULL ON NULL INPUT""",
+            "drop table",
+            "alter table 'test'",
+            'ALTER TABLE "testtable" ADD COLUMN "test_field" double precision',
+            'CREATE TEMP TABLE "temptable" (col1 char) INHERITS (testtable)',
+            'DROP TABLE "temptable"',
+            "create view testtable as select col1, col2 from testtable join "
+            "testtable test1 on col3=test1.col4)",
+            'CREATE TABLE public."ir_model" (id SERIAL NOT NULL, PRIMARY KEY(id))',
         ):
-            qualified_query = ''.join(
-                ''.join(str(token) for token in schema_qualify(parsed_query, temp_tables))
+            qualified_query = "".join(
+                "".join(
+                    str(token) for token in schema_qualify(parsed_query, temp_tables)
+                )
                 for parsed_query in sqlparse.parse(statement)
             )
             self.assertEqual(
                 qualified_query,
-                statement.replace('testtable', 'public.testtable').replace(
+                statement.replace("testtable", "public.testtable").replace(
                     '"public.testtable"', 'public."testtable"'
-                )
+                ),
             )
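
Note (illustrative, not part of the patch): the reworked schema_qualify mutates the parsed
statement in place and returns its top-level tokens, so callers rebuild the query by joining
the token strings, as both the code in post_load() and the test do. A minimal sketch of that
call pattern, assuming the module is importable as odoo.addons.pglogical.hooks (the path the
tests patch):

    import sqlparse

    from odoo.addons.pglogical.hooks import schema_qualify

    temp_tables = []
    statement = 'ALTER TABLE "testtable" ADD COLUMN "test_field" double precision'
    qualified = "".join(
        "".join(str(token) for token in schema_qualify(parsed, temp_tables))
        for parsed in sqlparse.parse(statement)
    )
    # per the test expectations, unqualified table names gain the default "public" schema:
    # ALTER TABLE public."testtable" ADD COLUMN "test_field" double precision
    print(qualified)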