Source code for pg4n.cmp_domain_checker

import re
from dataclasses import dataclass
from typing import Optional

import sqlglot.expressions as exp

from .errfmt import ErrorFormatter
from .sqlparser import Column, SqlParser

VT100_UNDERLINE = "\x1b[4m"
VT100_RESET = "\x1b[0m"


@dataclass(frozen=True)
class CmpContext:
    """
    A comparison that was flagged as suspicious, together with the two
    columns it compares. Instances are immutable (``frozen=True``).
    """

    # The comparison predicate node taken from the parsed query.
    expression: exp.Predicate
    # First column operand of the comparison (resolved to its Column info).
    a: Column
    # Second column operand of the comparison (resolved to its Column info).
    b: Column
class CmpDomainChecker:
    """
    Warns about comparisons between columns from different domains, e.g. a
    column of type VARCHAR(20) compared with a column of type VARCHAR(50).
    """

    # Extracts the leading alphabetic base name of a type name (e.g. "VARCHAR"
    # from "VARCHAR(20)"). Digit/precision parts are read from Column.type
    # directly. Compiled once instead of on every comparison.
    _TYPE_START_MATCHER = re.compile(r"^([a-zA-Z]+).*$")

    def __init__(self, parsed_sql: exp.Expression, columns: list[Column]):
        """
        :param parsed_sql: Parsed (sqlglot) query to analyze.
        :param columns: Known columns with type information, used to resolve
            the column references that appear in comparisons.
        """
        self.parsed_sql: exp.Expression = parsed_sql
        self.columns: list[Column] = columns
        self.suspicious_cmp_contexts: list[CmpContext] = []
        self.warning_msg: Optional[str] = None

    def check(self) -> Optional[str]:
        """
        Does analysis for suspicious comparisons between different domains,
        e.g. comparing columns of type VARCHAR(20) and VARCHAR(50).

        Returns a warning message if something was found, otherwise None.
        """
        # Reset results so repeated calls don't report the same comparison
        # more than once.
        self.suspicious_cmp_contexts = []
        self.warning_msg = None
        self._detect_suspicious_cmps(self.parsed_sql, self.columns)
        return self.warning_msg

    def _are_from_compatible_domains(self, a: Column, b: Column) -> bool:
        """
        Tests whether columns a and b come from compatible domains.

        Types whose base names differ (or cannot be parsed) are not
        evaluated and are reported as compatible. For types with the same
        base name:
        - if neither specifies digits, they are compatible;
        - otherwise, both digits and precision must be equal (a value
          specified on only one side counts as a mismatch).
        """
        match_a = self._TYPE_START_MATCHER.match(a.type.name)
        match_b = self._TYPE_START_MATCHER.match(b.type.name)
        if match_a is None or match_b is None:
            return True

        # We don't evaluate types of different names
        # TODO: Consider evaluating similar types e.g. VARCHAR and CHAR
        if match_a.group(1) != match_b.group(1):
            return True

        if a.type.digits is None and b.type.digits is None:
            return True

        # At least one side specifies digits. Equality also handles the
        # "specified on one side only" case (None != value), and the
        # precisions are now compared too. Previously, differing precisions
        # with equal digits were treated as compatible.
        return (
            a.type.digits == b.type.digits
            and a.type.precision == b.type.precision
        )

    def _detect_suspicious_cmp(
        self, cmp: exp.Predicate, columns: list[Column]
    ) -> bool:
        """
        Detects whether 'cmp' compares columns from different domains
        (e.g. a: VARCHAR(10) < b: VARCHAR(50)) and, if so, records it in
        self.suspicious_cmp_contexts.

        Only works for comparisons between 2 columns: if fewer than 2
        operands resolve to known columns (e.g. one operand is a literal),
        nothing is recorded. If more than 2 resolve, only the first two are
        compared. Ignores any casts used for the operands.

        Returns True if a suspicious comparison was recorded, else False.
        """
        # TODO: Take casts into account.
        # TODO: Consider comparisons with literals as well.
        # TODO: Investigate if postgresql already errors/warns of these kind of
        # comparisons
        # FIXME: Stop assuming columns with the same name coming from possibly
        # different tables have the same datatype
        cmp_columns = []
        for column_expression in cmp.find_all(exp.Column):
            cmp_column_name = SqlParser.get_column_name_from_column_expression(
                column_expression
            )
            column = _find_column(cmp_column_name, columns)
            if column is not None:
                cmp_columns.append(column)

        # The comparison has at least 1 operand that is not a known column
        # (e.g. a literal).
        if len(cmp_columns) < 2:
            return False

        if self._are_from_compatible_domains(cmp_columns[0], cmp_columns[1]):
            return False

        self.suspicious_cmp_contexts.append(
            CmpContext(cmp, cmp_columns[0], cmp_columns[1])
        )
        return True

    def _detect_suspicious_cmps(
        self, select_statement: exp.Select, columns: list[Column]
    ):
        """
        Checks every binary WHERE predicate of select_statement and, if any
        suspicious comparisons are found, sets self.warning_msg to one
        formatted warning per comparison, separated by newlines.
        """
        # This filters predicates we are not interested in such as IN or EXISTS
        for predicate in SqlParser.find_where_predicates(select_statement):
            if isinstance(predicate, exp.Binary):
                self._detect_suspicious_cmp(predicate, columns)

        if not self.suspicious_cmp_contexts:
            return

        whole_statement = str(select_statement)
        # TODO: Develop more principled way of matching warning_name with
        # the options expected in the configuration files.
        # removesuffix strips the literal suffix; rstrip would strip any
        # trailing characters from the set {C, h, e, c, k, r}.
        warning_name = type(self).__name__.removesuffix("Checker")

        warnings = []
        for context in self.suspicious_cmp_contexts:
            domain1 = context.a.type.name
            domain2 = context.b.type.name
            warning = f"Comparison between different domains ({domain1}, {domain2})"
            underlined_query = self._underline_cmp(
                whole_statement, context.expression
            )
            formatter = ErrorFormatter(warning, warning_name, underlined_query)
            warnings.append(formatter.format())

        # Accumulate all warnings. Previously each iteration overwrote the
        # message, so only the last comparison was reported.
        self.warning_msg = "\n".join(warnings)

    @staticmethod
    def _underline_cmp(whole_statement: str, cmp_exp: exp.Expression) -> str:
        """
        Returns whole_statement with the text of cmp_exp wrapped in VT100
        underline escape codes. If the text cannot be located, whole_statement
        is returned unchanged.
        """
        cmp_text = str(cmp_exp)
        start = -1

        # Prefer locating the comparison inside its enclosing WHERE clause, so
        # an identical text earlier in the statement isn't underlined instead.
        containing_where = cmp_exp.find_ancestor(exp.Where)
        if containing_where is not None:
            where_text = str(containing_where)
            where_start = whole_statement.find(where_text)
            offset_in_where = where_text.find(cmp_text)
            if where_start != -1 and offset_in_where != -1:
                start = where_start + offset_in_where

        if start == -1:
            # No WHERE ancestor, or it could not be matched textually: fall
            # back to the first occurrence anywhere in the statement.
            start = whole_statement.find(cmp_text)
        if start == -1:
            return whole_statement

        end = start + len(cmp_text)
        return (
            whole_statement[:start]
            + VT100_UNDERLINE
            + whole_statement[start:end]
            + VT100_RESET
            + whole_statement[end:]
        )
def _find_column(column_name: str, columns: list[Column]) -> Optional[Column]:
    """
    Returns the first column in ``columns`` whose name equals
    ``column_name``, or None when no column has that name.
    """
    # TODO: Just use dict inside an object
    return next(
        (candidate for candidate in columns if column_name == candidate.name),
        None,
    )