"""pg4n.qepparser: run EXPLAIN ANALYZE on SQL queries and wrap the resulting query execution plans (QEPs)."""

from itertools import chain
from typing import Callable, Iterable, List, Optional, TypedDict

import psycopg
from psycopg import Connection

from . import util  # used to test relative imports

# TODO: break into variants discriminated by Node Type
# right now, the interface isn't safe to use because it's not clear what fields
# are available for each node type
#
# Shape of one plan node inside PostgreSQL's ``EXPLAIN (FORMAT JSON)`` output.
# NOTE(review): this functional TypedDict is total (the default), so type
# checkers treat every key as required, yet a real node only carries the keys
# relevant to its node type. QEPNode itself reads "Plans" via .get(), so treat
# every key here as possibly absent.
node = TypedDict(
    "Plan",  # runtime __name__ of the type; the module-level alias is ``node``
    {
        "Node Type": str,
        "Parent Relationship": str,
        "Startup Cost": float,
        "Total Cost": float,
        "Plan Rows": int,
        "Plan Width": int,
        "Actual Startup Time": float,
        "Actual Total Time": float,
        "Actual Rows": int,
        "Actual Loops": int,
        # child plan nodes; "node" is a forward reference to this module's
        # ``node`` alias (not to the type's runtime name "Plan")
        "Plans": List["node"],
        "Index Cond": str,
        "Filter": str,
        "Relation Name": str,
        "Alias": str,
        "Scan Direction": str,
        "Index Name": str,
        # NOTE(review): PostgreSQL emits trigger entries as objects, not
        # strings, and normally at the top level rather than per node — confirm
        "Triggers": list[str],
        # NOTE(review): presumably the pre-9.4 top-level name for total
        # execution time; looks misplaced on a node — confirm
        "Total Runtime": float,
        "One-Time Filter": str,
    },
)

# Shape of the top-level object returned by
# ``EXPLAIN (FORMAT JSON, ANALYZE, ...)``: the root plan node plus
# statement-level information. Like ``node``, this TypedDict is total, so type
# checkers consider every key required even if PostgreSQL omits some.
qep = TypedDict(
    "QEP",
    {
        "Plan": node,  # root of the plan tree
        # NOTE(review): PostgreSQL emits trigger entries as objects, not
        # plain strings — confirm before relying on this annotation
        "Triggers": list[str],
        "Planning Time": float,
        "Execution Time": float,
        # presumably the older (pre-9.4) name for "Execution Time" — confirm
        "Total Runtime": float,
    },
)


class QEPNode:
    """A node in a query execution plan.

    Wraps one plan-node dict from ``EXPLAIN (FORMAT JSON)`` output. The
    node's children (its ``"Plans"`` list, which is absent on leaf nodes) are
    exposed through iteration, ``len()`` and indexing, each wrapped as a
    QEPNode.
    """

    def __init__(self, node_: "node"):
        """Create a new QEPNode.

        :param node_: the plan-node dict to wrap (kept by reference, not copied)
        """
        self._node = node_

    def __iter__(self) -> Iterable["QEPNode"]:
        """Iterate over child nodes (yields nothing for a leaf node)."""
        return map(QEPNode, self.plans)

    def __len__(self) -> int:
        """Get the number of child nodes (0 for a leaf node)."""
        # Leaf nodes have no "Plans" key; going through ``plans`` keeps this
        # consistent with __iter__ instead of raising KeyError.
        return len(self.plans)

    def __getitem__(self, key: int) -> "QEPNode":
        """Get the child node at the given index.

        :param key: the index of the child node to get
        :returns: the child node at the given index
        :raises IndexError: if there is no child at ``key`` (including any
            index on a leaf node, which has no children)
        """
        return QEPNode(self.plans[key])

    def __str__(self):
        """Return the string form of the wrapped dict."""
        return self._node.__str__()

    def __repr__(self):
        """Return the repr of the wrapped dict."""
        return self._node.__repr__()

    @property
    def plan(self) -> "node":
        """A dict of the node's properties."""
        return self._node

    @property
    def plans(self) -> "list[node]":
        """A list of the node's children (empty for a leaf node)."""
        return self._node.get("Plans", [])

    def _walk(self) -> Iterable["node"]:
        """Yield this node's dict and every descendant's dict once, pre-order."""
        # An explicit stack instead of recursion, so very deep plans cannot hit
        # the interpreter's recursion limit. Children are pushed in reverse so
        # the first child is popped (visited) first.
        stack = [self._node]
        while stack:
            current = stack.pop()
            yield current
            stack.extend(reversed(current.get("Plans", [])))

    def find(self, pr: "Callable[[node], bool]", recursive=False) -> "list[node]":
        """Finds nodes matching the predicate.

        :param pr: a function that takes a node and returns True if it matches
        :param recursive: if True, search this node and all of its descendants
            (each visited exactly once, depth-first pre-order); otherwise only
            search this node and its direct children
        :returns: a list of matching nodes, this node first if it matches
        """
        if recursive:
            candidates = self._walk()
        else:
            candidates = chain((self._node,), self.plans)
        return [candidate for candidate in candidates if pr(candidate)]

    def rfind(self, pred: "Callable[[node], bool]") -> "list[node]":
        """Finds nodes matching the predicate, recursively.

        :param pred: a function that takes a node and returns True if it matches
        :returns: a list of matching nodes
        """
        return self.find(pred, recursive=True)

    def findval(self, key: str, val: object, recursive=False) -> "list[node]":
        """Finds nodes with the given key and value.

        :param key: the key to search for
        :param val: the value to search for
        :param recursive: if True, search recursively, otherwise only search
            this+children
        :returns: a list of matching nodes
        """
        return self.find(lambda x: x.get(key) == val, recursive)

    def rfindval(self, key: str, val: object) -> "list[node]":
        """Finds nodes with the given key and value, recursively.

        :param key: the key to search for
        :param val: the value to search for
        :returns: a list of matching nodes
        """
        return self.findval(key, val, recursive=True)
class QEPAnalysis:
    """Wrapper around the parsed output of ``EXPLAIN ANALYZE``.

    Holds the top-level QEP dict and gives convenient access to its root
    plan node, either as a raw dict or wrapped as a QEPNode.
    """

    def __init__(self, qep_: "qep"):
        """Wrap a QEP dict.

        :param qep_: the top-level EXPLAIN output object (kept by reference)
        """
        self._qep = qep_

    def __str__(self):
        """Return the string form of the wrapped QEP dict."""
        return str(self._qep)

    def __repr__(self):
        """Return the repr of the wrapped QEP dict."""
        return repr(self._qep)

    @property
    def plan(self) -> "node":
        """The root plan node as a raw dict (the ``"Plan"`` entry)."""
        return self._qep["Plan"]

    @property
    def root(self) -> "QEPNode":
        """The root plan node, wrapped as a QEPNode."""
        return QEPNode(self.plan)

    @property
    def qep(self) -> "qep":
        """The whole top-level QEP dict."""
        return self._qep
class QEPParser:
    """Performs analyses on given queries, returning resultant QEPAnalysis."""

    def __init__(self, *args, conn=None, constraint_exclusion=True, **kwargs):
        """Open (or adopt) a connection and configure ``constraint_exclusion``.

        :param args: positional arguments forwarded to ``psycopg.connect``
            when ``conn`` is not given
        :param conn: an existing connection to use instead of opening one;
            a connection passed in here is never closed by this object
        :param constraint_exclusion: run ``set constraint_exclusion = on``
            if True, ``set constraint_exclusion = off`` otherwise
        :param kwargs: keyword arguments forwarded to ``psycopg.connect``
            when ``conn`` is not given
        """
        # True when the connection is borrowed from the caller; only
        # connections opened here are closed in __del__.
        self._ref = conn is not None
        self._conn: Connection = (
            conn if conn is not None else psycopg.connect(*args, **kwargs)
        )
        # use constraint_exclusion to avoid unnecessary index scans
        setting_stmt = (
            "set constraint_exclusion = on;"
            if constraint_exclusion
            else "set constraint_exclusion = off;"
        )
        with self._conn.cursor() as cur:
            try:
                cur.execute(setting_stmt)
                # NOTE(review): this also commits any transaction already
                # open on a caller-supplied connection.
                self._conn.commit()
            except psycopg.Error:
                # Best effort: if the setting cannot be applied, leave the
                # connection usable rather than failing construction.
                self._conn.rollback()

    def __del__(self):
        """Close the connection, but only if this object opened it."""
        # __init__ may have raised (e.g. psycopg.connect failed) before these
        # attributes existed, so look them up defensively.
        conn = getattr(self, "_conn", None)
        if conn is not None and not getattr(self, "_ref", True):
            conn.close()

    def __call__(self, stmt: str, *args, **kwargs) -> "Optional[QEPAnalysis]":
        """Run ``EXPLAIN (FORMAT JSON, ANALYZE, VERBOSE)`` on a statement.

        With ANALYZE the statement is actually executed, so the transaction
        is always rolled back afterwards to discard any side effects.

        :param stmt: the query to analyze; surrounding whitespace and
            trailing semicolons are stripped before it is wrapped
        :param args: positional arguments passed to ``cursor.execute()``
        :param kwargs: keyword arguments passed to ``cursor.execute()``
        :returns: the parsed plan, or None if the database raised a
            ``psycopg.Error`` (the error is swallowed after rolling back)
        :raises ValueError: if the EXPLAIN output is not exactly one row with
            one column holding a one-element list containing a dict
        """
        stmt = (
            "explain (format json, analyze, verbose) "
            + stmt.strip().rstrip(";")
            + ";"
        )
        # Only the database round-trip lives in the try; shape validation
        # below raises ValueError independently of psycopg errors.
        try:
            with self._conn.cursor() as cur:
                cur.execute(stmt, *args, **kwargs)
                res = cur.fetchall()
            self._conn.rollback()
        except psycopg.Error:
            self._conn.rollback()
            return None

        if (n := len(res)) != 1:
            raise ValueError(f"Expected 1 row, got {n}")
        if (n := len(res[0])) != 1:
            raise ValueError(f"Expected 1 column, got {n}")
        if (n := len(res[0][0])) != 1:
            raise ValueError(f"Expected 1 item in column, got {n}")
        plan = res[0][0][0]
        if not isinstance(plan, dict):
            raise ValueError(f"Expected dict in column, got {type(plan)}")
        return QEPAnalysis(plan)

    def parse(self, stmt: str, *args, **kwargs) -> "Optional[QEPAnalysis]":
        """Alias for __call__.

        :returns: the parsed plan, or None if the database reported an error
        """
        return self(stmt, *args, **kwargs)