diff --git a/README.md b/README.md
index 09dfae0..59d345d 100644
--- a/README.md
+++ b/README.md
@@ -83,6 +83,109 @@ For a local socket connection, exclude the "host" and "port" portions::
 	e = create_engine("ibm_db_sa://user:pass@/database")
 ```
+## Logging in ibm_db_sa
+The `ibm_db_sa` module provides built-in logging support to help debug SQLAlchemy dialect operations, connection handling, reflection, and SQL execution.
+Logging can be enabled using the `ibmdbsa_log` query parameter in the database connection URL.
+
+### Logging Configuration
+Use the `ibmdbsa_log` parameter at the end of the connection URL.
+Possible values:
+```
+True -> Enable logging to the console
+"filename" -> Enable logging to a file (file is overwritten on each run)
+False -> Disable logging
+```
+
+---
+### Examples
+#### 1. Enable Logging to a File (Current Directory)
+Logs will be written to a file in the current working directory.
+```python
+from sqlalchemy import create_engine
+engine = create_engine(
+    "ibm_db_sa://userID:Password@host:port/database?ibmdbsa_log=log_ibmdbsa.txt"
+)
+```
+This will create the file:
+```
+log_ibmdbsa.txt
+```
+---
+#### 2. Enable Logging to a File (Specific Directory)
+You can provide an absolute file path.
+#### Windows example
+```python
+from sqlalchemy import create_engine
+engine = create_engine(
+    "ibm_db_sa://userID:Password@host:port/database?ibmdbsa_log=C:\\Users\\Logs\\log_ibmdbsa.txt"
+)
+```
+#### Linux / macOS example
+```python
+from sqlalchemy import create_engine
+engine = create_engine(
+    "ibm_db_sa://userID:Password@host:port/database?ibmdbsa_log=/var/log/log_ibmdbsa.txt"
+)
+```
+---
+#### 3. Enable Console Logging
+Logs will be printed to the terminal or console.
+```python
+from sqlalchemy import create_engine
+engine = create_engine(
+    "ibm_db_sa://userID:Password@host:port/database?ibmdbsa_log=True"
+)
+```
+---
+### Using SQLAlchemy URL Object
+Logging can also be configured using SQLAlchemy's `URL` object.
+#### Log to File
+```python
+from sqlalchemy import create_engine
+from sqlalchemy.engine import URL
+url_object = URL.create(
+    drivername="ibm_db_sa",
+    username="userID",
+    password="Password",
+    host="host",
+    port=port,
+    database="database",
+    query={"ibmdbsa_log": "log_ibmdbsa.txt"},
+)
+engine = create_engine(url_object)
+```
+---
+#### Log to Console
+```python
+from sqlalchemy import create_engine
+from sqlalchemy.engine import URL
+url_object = URL.create(
+    drivername="ibm_db_sa",
+    username="userID",
+    password="Password",
+    host="host",
+    port=port,
+    database="database",
+    query={"ibmdbsa_log": "True"},
+)
+engine = create_engine(url_object)
+```
+---
+### Notes
+- Logging configuration is automatically detected when the SQLAlchemy engine is created.
+- The `ibmdbsa_log` parameter is removed internally before the connection parameters are passed to the DBAPI driver.
+- If logging is not specified, logging remains disabled by default.
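+---
+### Combining with SQLAlchemy Engine Logging
+`ibmdbsa_log` controls only the dialect's own logging. As a minimal sketch (the `echo=True` flag below is standard SQLAlchemy, not part of `ibm_db_sa`), the two can be enabled together to correlate dialect activity with the SQL statements the engine emits:
+```python
+from sqlalchemy import create_engine
+
+# ibmdbsa_log=True enables the dialect's console logging;
+# echo=True additionally logs emitted SQL through SQLAlchemy itself.
+engine = create_engine(
+    "ibm_db_sa://userID:Password@host:port/database?ibmdbsa_log=True",
+    echo=True,
+)
+```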
+---
+#### Typical Use Cases
+Logging can help diagnose:
+- Connection issues
+- SQL execution problems
+- Reflection metadata queries
+- Dialect initialization
+- Performance troubleshooting
+
+
 
 Supported Databases
 -------------------
 - IBM DB2 Database for Linux/Unix/Windows versions 11.5 onwards
diff --git a/ibm_db_sa/base.py b/ibm_db_sa/base.py
index e96640c..7ac89fb 100644
--- a/ibm_db_sa/base.py
+++ b/ibm_db_sa/base.py
@@ -19,6 +19,7 @@
 """Support for IBM DB2 database
 """
+import sys
 import sqlalchemy
 import datetime, re
 from sqlalchemy import types as sa_types
@@ -29,6 +30,9 @@
 from sqlalchemy.sql import compiler
 from sqlalchemy.sql import operators
 from sqlalchemy.engine import default
+from sqlalchemy import event
+from sqlalchemy.engine import Engine
+from .logger import logger, log_entry_exit
 from sqlalchemy import __version__ as SA_VERSION_STR
 
 from . import reflection as ibm_reflection
@@ -285,162 +289,268 @@ class XML(sa_types.Text):
 
 
 class DB2TypeCompiler(compiler.GenericTypeCompiler):
-
-    def visit_TIMESTAMP(self, type_, **kw):
-        return "TIMESTAMP"
-
-    def visit_DATE(self, type_, **kw):
-        return "DATE"
-
-    def visit_TIME(self, type_, **kw):
-        return "TIME"
-
-    def visit_DATETIME(self, type_, **kw):
-        return self.visit_TIMESTAMP(type_, **kw)
-
-    def visit_SMALLINT(self, type_, **kw):
-        return "SMALLINT"
-
-    def visit_BOOLEAN(self, type_, **kw):
-        return "BOOLEAN"
-
-    def visit_INT(self, type_, **kw):
-        return "INT"
-
-    def visit_BIGINT(self, type_, **kw):
-        return "BIGINT"
-
-    def visit_FLOAT(self, type_, **kw):
-        return "FLOAT" if type_.precision is None else \
-            "FLOAT(%(precision)s)" % {'precision': type_.precision}
-
-    def visit_DOUBLE(self, type_, **kw):
-        return "DOUBLE"
-
-    def visit_XML(self, type_, **kw):
-        return "XML"
-
-    def visit_CLOB(self, type_, **kw):
-        return "CLOB"
-
-    def visit_BLOB(self, type_, **kw):
-        return "BLOB(1M)" if type_.length in (None, 0) else \
-            "BLOB(%(length)s)" % {'length': type_.length}
-
-    def visit_DBCLOB(self, type_, **kw):
-        return "DBCLOB(1M)" if type_.length in (None, 0) else \
-            "DBCLOB(%(length)s)" % {'length': type_.length}
-
-    def visit_VARCHAR(self, type_, **kw):
-        return "VARCHAR(%(length)s)" % {'length': type_.length}
-
-    def visit_LONGVARCHAR(self, type_, **kw):
-        return "LONG VARCHAR"
-
-    def visit_VARGRAPHIC(self, type_, **kw):
-        return "VARGRAPHIC(%(length)s)" % {'length': type_.length}
-
-    def visit_LONGVARGRAPHIC(self, type_, **kw):
-        return "LONG VARGRAPHIC"
-
-    def visit_CHAR(self, type_, **kw):
-        return "CHAR" if type_.length in (None, 0) else \
-            "CHAR(%(length)s)" % {'length': type_.length}
-
-    def visit_GRAPHIC(self, type_, **kw):
-        return "GRAPHIC" if type_.length in (None, 0) else \
-            "GRAPHIC(%(length)s)" % {'length': type_.length}
-
-    def visit_DECIMAL(self, type_, **kw):
-        if not type_.precision:
-            return "DECIMAL(31, 0)"
-        elif not type_.scale:
-            return "DECIMAL(%(precision)s, 0)" % {'precision': type_.precision}
-        else:
-            return "DECIMAL(%(precision)s, %(scale)s)" % {
-                'precision': type_.precision, 'scale': type_.scale}
-
-    def visit_numeric(self, type_, **kw):
-        return self.visit_DECIMAL(type_, **kw)
-
-    def visit_datetime(self, type_, **kw):
-        return self.visit_TIMESTAMP(type_, **kw)
-
-    def visit_date(self, type_, **kw):
-        return self.visit_DATE(type_, **kw)
-
-    def visit_time(self, type_, **kw):
-        return self.visit_TIME(type_, **kw)
-
-    def visit_integer(self, type_, **kw):
-        return self.visit_INT(type_, **kw)
-
-    def visit_boolean(self, type_, **kw):
-        return self.visit_BOOLEAN(type_, **kw)
-
-    def visit_float(self, type_, **kw):
-        return self.visit_FLOAT(type_, **kw)
-
-    def visit_unicode(self, type_, **kw):
-        check_server = getattr(DB2Dialect, 'serverType')
-        return (self.visit_VARGRAPHIC(type_, **kw) + " CCSID 1200") \
-            if check_server == "DB2" else self.visit_VARGRAPHIC(type_, **kw)
-
-    def visit_unicode_text(self, type_, **kw):
-        return self.visit_LONGVARGRAPHIC(type_, **kw)
-
-    def visit_string(self, type_, **kw):
-        return self.visit_VARCHAR(type_, **kw)
-
-    def visit_TEXT(self, type_, **kw):
-        return self.visit_CLOB(type_, **kw)
-
-    def visit_large_binary(self, type_, **kw):
-        return self.visit_BLOB(type_, **kw)
+    @log_entry_exit
+    def visit_TIMESTAMP(self, type_, **kw):
+        sql = "TIMESTAMP"
+        logger.debug(f"Type rendering -> TIMESTAMP -> {sql}")
+        return sql
+
+    @log_entry_exit
+    def visit_DATE(self, type_, **kw):
+        sql = "DATE"
+        logger.debug(f"Type rendering -> DATE -> {sql}")
+        return sql
+
+    @log_entry_exit
+    def visit_TIME(self, type_, **kw):
+        sql = "TIME"
+        logger.debug(f"Type rendering -> TIME -> {sql}")
+        return sql
+
+    def visit_DATETIME(self, type_, **kw):
+        logger.debug("Redirecting DATETIME to TIMESTAMP")
+        return self.visit_TIMESTAMP(type_, **kw)
+
+    def visit_SMALLINT(self, type_, **kw):
+        sql = "SMALLINT"
+        logger.debug(f"Type rendering -> SMALLINT -> {sql}")
+        return sql
+
+    def visit_BOOLEAN(self, type_, **kw):
+        sql = "BOOLEAN"
+        logger.debug(f"Type rendering -> BOOLEAN -> {sql}")
+        return sql
+
+    def visit_INT(self, type_, **kw):
+        sql = "INT"
+        logger.debug(f"Type rendering -> INT -> {sql}")
+        return sql
+
+    def visit_BIGINT(self, type_, **kw):
+        sql = "BIGINT"
+        logger.debug(f"Type rendering -> BIGINT -> {sql}")
+        return sql
+
+    def visit_FLOAT(self, type_, **kw):
+        precision = type_.precision
+        if precision is None:
+            sql = "FLOAT"
+        else:
+            sql = f"FLOAT({precision})"
+        logger.debug(f"Type rendering -> FLOAT -> precision={precision}, sql={sql}")
+        return sql
+
+    def visit_DOUBLE(self, type_, **kw):
+        sql = "DOUBLE"
+        logger.debug(f"Type rendering -> DOUBLE -> {sql}")
+        return sql
+
+    def visit_XML(self, type_, **kw):
+        sql = "XML"
+        logger.debug(f"Type rendering -> XML -> {sql}")
+        return sql
+
+    def visit_CLOB(self, type_, **kw):
+        sql = "CLOB"
+        logger.debug(f"Type rendering -> CLOB -> {sql}")
+        return sql
+
+    def visit_BLOB(self, type_, **kw):
+        length = type_.length
+        sql = "BLOB(1M)" if length in (None, 0) else f"BLOB({length})"
+        logger.debug(f"Type rendering -> BLOB -> length={length}, sql={sql}")
+        return sql
+
+    def visit_DBCLOB(self, type_, **kw):
+        length = type_.length
+        sql = "DBCLOB(1M)" if length in (None, 0) else f"DBCLOB({length})"
+        logger.debug(f"Type rendering -> DBCLOB -> length={length}, sql={sql}")
+        return sql
+
+    def visit_VARCHAR(self, type_, **kw):
+        length = type_.length
+        sql = f"VARCHAR({length})"
+        logger.debug(f"Type rendering -> VARCHAR -> length={length}, sql={sql}")
+        return sql
+
+    def visit_LONGVARCHAR(self, type_, **kw):
+        sql = "LONG VARCHAR"
+        logger.debug(f"Type rendering -> LONG VARCHAR -> {sql}")
+        return sql
+
+    def visit_VARGRAPHIC(self, type_, **kw):
+        length = type_.length
+        sql = f"VARGRAPHIC({length})"
+        logger.debug(f"Type rendering -> VARGRAPHIC -> length={length}, sql={sql}")
+        return sql
+
+    def visit_LONGVARGRAPHIC(self, type_, **kw):
+        sql = "LONG VARGRAPHIC"
+        logger.debug(f"Type rendering -> LONG VARGRAPHIC -> {sql}")
+        return sql
+
+    def visit_CHAR(self, type_, **kw):
+        length = type_.length
+        sql = "CHAR" if length in (None, 0) else f"CHAR({length})"
+        logger.debug(f"Type rendering -> CHAR -> length={length}, sql={sql}")
+        return sql
+
+    def visit_GRAPHIC(self, type_, **kw):
+        length = type_.length
+        sql = "GRAPHIC" if length in (None, 0) else f"GRAPHIC({length})"
+        logger.debug(f"Type rendering -> GRAPHIC -> length={length}, sql={sql}")
+        return sql
+
+    @log_entry_exit
+    def visit_DECIMAL(self, type_, **kw):
+        precision = type_.precision
+        scale = type_.scale
+        if not precision:
+            sql = "DECIMAL(31, 0)"
+        elif not scale:
+            sql = f"DECIMAL({precision}, 0)"
+        else:
+            sql = f"DECIMAL({precision}, {scale})"
+        logger.debug(
+            f"Type rendering -> DECIMAL -> "
+            f"precision={precision}, scale={scale}, sql={sql}"
+        )
+        return sql
+
+    def visit_numeric(self, type_, **kw):
+        logger.debug("Redirecting numeric to DECIMAL")
+        return self.visit_DECIMAL(type_, **kw)
+
+    def visit_datetime(self, type_, **kw):
+        logger.debug("Redirecting datetime to TIMESTAMP")
+        return self.visit_TIMESTAMP(type_, **kw)
+
+    def visit_date(self, type_, **kw):
+        logger.debug("Redirecting date to DATE")
+        return self.visit_DATE(type_, **kw)
+
+    def visit_time(self, type_, **kw):
+        logger.debug("Redirecting time to TIME")
+        return self.visit_TIME(type_, **kw)
+
+    def visit_integer(self, type_, **kw):
+        logger.debug("Redirecting integer to INT")
+        return self.visit_INT(type_, **kw)
+
+    def visit_boolean(self, type_, **kw):
+        logger.debug("Redirecting boolean to BOOLEAN")
+        return self.visit_BOOLEAN(type_, **kw)
+
+    def visit_float(self, type_, **kw):
+        logger.debug("Redirecting float to FLOAT")
+        return self.visit_FLOAT(type_, **kw)
+
+    def visit_unicode(self, type_, **kw):
+        check_server = getattr(DB2Dialect, "serverType")
+        base_sql = self.visit_VARGRAPHIC(type_, **kw)
+        if check_server == "DB2":
+            sql = base_sql + " CCSID 1200"
+        else:
+            sql = base_sql
+        logger.debug(
+            f"Type rendering -> UNICODE -> "
+            f"serverType={check_server}, sql={sql}"
+        )
+        return sql
+
+    def visit_unicode_text(self, type_, **kw):
+        logger.debug("Redirecting unicode_text to LONGVARGRAPHIC")
+        return self.visit_LONGVARGRAPHIC(type_, **kw)
+
+    def visit_string(self, type_, **kw):
+        logger.debug("Redirecting string to VARCHAR")
+        return self.visit_VARCHAR(type_, **kw)
+
+    def visit_TEXT(self, type_, **kw):
+        logger.debug("Redirecting TEXT to CLOB")
+        return self.visit_CLOB(type_, **kw)
+
+    def visit_large_binary(self, type_, **kw):
+        logger.debug("Redirecting large_binary to BLOB")
+        return self.visit_BLOB(type_, **kw)
 
 
 class DB2Compiler(compiler.SQLCompiler):
 
     if SA_VERSION_MM < (0, 9):
+        @log_entry_exit
         def visit_false(self, expr, **kw):
-            return '0'
+            logger.debug("Rendering FALSE literal as 0")
+            return "0"
 
+        @log_entry_exit
         def visit_true(self, expr, **kw):
-            return '1'
+            logger.debug("Rendering TRUE literal as 1")
+            return "1"
 
+    @log_entry_exit
     def get_cte_preamble(self, recursive):
+        logger.debug(f"Generating CTE preamble -> recursive={recursive}")
        return "WITH"
 
+    @log_entry_exit
     def visit_now_func(self, fn, **kw):
+        logger.debug("Rendering NOW function as CURRENT_TIMESTAMP")
        return "CURRENT_TIMESTAMP"
 
+    @log_entry_exit
     def for_update_clause(self, select, **kw):
-        if select.for_update is True:
-            return ' WITH RS USE AND KEEP UPDATE LOCKS'
-        elif select.for_update == 'read':
-            return ' WITH RS USE AND KEEP SHARE LOCKS'
+        for_update = select.for_update
+        logger.debug(f"Processing FOR UPDATE clause -> value={for_update}")
+        if for_update is True:
+            clause = " WITH RS USE AND KEEP UPDATE LOCKS"
+        elif for_update == "read":
+            clause = " WITH RS USE AND KEEP SHARE LOCKS"
         else:
-            return ''
+            clause = ""
+        logger.debug(f"Generated FOR UPDATE clause -> {clause}")
+        return clause
 
+    @log_entry_exit
     def visit_mod_binary(self, binary, operator, **kw):
-        return "mod(%s, %s)" % (self.process(binary.left),
-                                self.process(binary.right))
+        left_expr = binary.left
+        right_expr = binary.right
+        left = self.process(left_expr)
+        right = self.process(right_expr)
+        sql = f"mod({left}, {right})"
+        logger.debug(
+            f"Rendering MOD binary -> left={left}, right={right}, sql={sql}"
+        )
+        return sql
 
     def literalBindsFlagFrom_kw(self, kw=None):
         """Return True if literal_binds is requested in compile kwargs."""
         if not kw or not isinstance(kw, dict):
+            logger.debug("literal_binds check -> False (invalid kw)")
             return False
-        if kw.get("literal_binds"):
-            return True
-        ck = kw.get("compile_kwargs")
-        if isinstance(ck, dict) and ck.get("literal_binds"):
+        literal_binds = kw.get("literal_binds")
+        if literal_binds:
+            logger.debug("literal_binds detected at top level")
             return True
+        compile_kwargs = kw.get("compile_kwargs")
+        if isinstance(compile_kwargs, dict):
+            if compile_kwargs.get("literal_binds"):
+                logger.debug("literal_binds detected in compile_kwargs")
+                return True
+        logger.debug("literal_binds not requested")
         return False
 
+    @log_entry_exit
     def limit_clause(self, select, **kw):
         text = ""
         limit_clause = select._limit_clause
         offset_clause = select._offset_clause
         literal_binds = self.literalBindsFlagFrom_kw(kw)
+        logger.debug(
+            f"Processing LIMIT/OFFSET -> "
+            f"limit={limit_clause}, "
+            f"offset={offset_clause}, "
+            f"literal_binds={literal_binds}"
+        )
 
         def _render_clause(clause):
             if clause is None:
                 return None
@@ -459,16 +569,18 @@ def _render_clause(clause):
                 pass
             try:
                 if isinstance(clause, BindParameter):
-                    val = getattr(clause, "value", None)
-                    if val is not None:
-                        if isinstance(val, str):
-                            return f"'{val}'"
-                        return str(val)
+                    value = getattr(clause, "value", None)
+                    if value is not None:
+                        if isinstance(value, str):
+                            return f"'{value}'"
+                        return str(value)
             except Exception:
                 pass
             try:
                 return self.process(clause, **kw)
             except Exception as e:
+                logger.error("Failed to render LIMIT/OFFSET clause")
+                logger.exception("Stack trace in limit_clause")
                 raise exc.CompileError(
                     "dialect 'ibm_db_sa' cannot render LIMIT/OFFSET for this clause; "
                     "ensure the clause is a simple integer or is processable by the compiler."
@@ -476,97 +588,174 @@ def _render_clause(clause): limit_text = _render_clause(limit_clause) if limit_text is not None: - text += " LIMIT %s" % limit_text + text += f" LIMIT {limit_text}" + logger.debug(f"Applied LIMIT -> {limit_text}") offset_text = _render_clause(offset_clause) if offset_text is not None: - text += " OFFSET %s" % offset_text + text += f" OFFSET {offset_text}" + logger.debug(f"Applied OFFSET -> {offset_text}") + logger.debug(f"Generated LIMIT/OFFSET clause -> {text}") return text + @log_entry_exit def visit_select(self, select, **kw): - sql_ori = compiler.SQLCompiler.visit_select(self, select, **kw) - if ("LIMIT" in sql_ori.upper()) or ("FETCH FIRST" in sql_ori.upper()): - return sql_ori - limit_clause_obj = select._limit_clause - offset_clause_obj = select._offset_clause - if limit_clause_obj is not None: - limit_offset_clause = self.limit_clause(select, **kw) - if limit_offset_clause: - return sql_ori + limit_offset_clause - if offset_clause_obj is not None: - __rownum = 'Z.__ROWNUM' - sql_work = re.sub(r'FETCH FIRST \d+ ROWS ONLY', '', sql_ori, flags=re.IGNORECASE).strip() - sql_work = re.sub(r'\s+OFFSET\s+(?:\d+|__\[POSTCOMPILE_[^\]]+\]|:[A-Za-z0-9_]+|\?)\s*$', '', sql_work, - flags=re.IGNORECASE) - sql_split = re.split(r"[\s+]FROM ", sql_work, 1) - if len(sql_split) < 2: + try: + sql_ori = compiler.SQLCompiler.visit_select(self, select, **kw) + sql_upper = sql_ori.upper() + logger.debug("Processing SELECT compilation.") + if ("LIMIT" in sql_upper) or ("FETCH FIRST" in sql_upper): + logger.debug("LIMIT/FETCH already present. Returning original SQL.") + logger.debug(f"Final SELECT SQL -> {sql_ori}") return sql_ori - sql_sec = " \nFROM %s " % (sql_split[1]) - dummyVal = "Z.__db2_" - sql_pri = "" - sql_sel = "SELECT " - if select._distinct: - sql_sel = "SELECT DISTINCT " - sql_select_token = sql_split[0].split(",") - i = 0 - while i < len(sql_select_token): - if sql_select_token[i].count("TIMESTAMP(DATE(SUBSTR(CHAR(") == 1: + limit_clause_obj = select._limit_clause + offset_clause_obj = select._offset_clause + logger.debug( + f"SELECT limit/offset detection -> " + f"limit={limit_clause_obj}, offset={offset_clause_obj}" + ) + if limit_clause_obj is not None: + limit_offset_clause = self.limit_clause(select, **kw) + if limit_offset_clause: + final_sql = sql_ori + limit_offset_clause + logger.debug("Applying simple LIMIT/OFFSET clause.") + logger.debug(f"Final SELECT SQL -> {final_sql}") + return final_sql + if offset_clause_obj is not None: + logger.debug("Applying DB2 ROW_NUMBER based OFFSET rewrite.") + __rownum = "Z.__ROWNUM" + sql_work = re.sub( + r"FETCH FIRST \d+ ROWS ONLY", + "", + sql_ori, + flags=re.IGNORECASE, + ).strip() + sql_work = re.sub( + r"\s+OFFSET\s+(?:\d+|__\[POSTCOMPILE_[^\]]+\]|:[A-Za-z0-9_]+|\?)\s*$", + "", + sql_work, + flags=re.IGNORECASE, + ) + sql_split = re.split(r"[\s+]FROM ", sql_work, 1) + if len(sql_split) < 2: + logger.debug("Unable to split SELECT for OFFSET rewrite.") + logger.debug(f"Final SELECT SQL -> {sql_ori}") + return sql_ori + sql_sec = f" \nFROM {sql_split[1]} " + dummyVal = "Z.__db2_" + sql_pri = "" + sql_sel = "SELECT DISTINCT " if select._distinct else "SELECT " + sql_select_token = sql_split[0].split(",") + i = 0 + while i < len(sql_select_token): + token = sql_select_token[i] + if token.count("TIMESTAMP(DATE(SUBSTR(CHAR(") == 1: + sql_sel = f'{sql_sel} "{dummyVal}{i + 1}",' + sql_pri = ( + f'{sql_pri} {sql_select_token[i]},' + f'{sql_select_token[i + 1]},' + f'{sql_select_token[i + 2]},' + f'{sql_select_token[i + 3]} AS 
"{dummyVal}{i + 1}",' + ) + i += 4 + continue + if token.count(" AS ") == 1: + temp_col_alias = token.split(" AS ") + sql_pri = f"{sql_pri} {token}," + sql_sel = f"{sql_sel} {temp_col_alias[1]}," + i += 1 + continue + sql_pri = f'{sql_pri} {token} AS "{dummyVal}{i + 1}",' sql_sel = f'{sql_sel} "{dummyVal}{i + 1}",' - sql_pri = f'{sql_pri} {sql_select_token[i]},{sql_select_token[i + 1]},{sql_select_token[i + 2]},{sql_select_token[i + 3]} AS "{dummyVal}{i + 1}",' - i += 4 - continue - if sql_select_token[i].count(" AS ") == 1: - temp_col_alias = sql_select_token[i].split(" AS ") - sql_pri = f'{sql_pri} {sql_select_token[i]},' - sql_sel = f'{sql_sel} {temp_col_alias[1]},' i += 1 - continue - sql_pri = f'{sql_pri} {sql_select_token[i]} AS "{dummyVal}{i + 1}",' - sql_sel = f'{sql_sel} "{dummyVal}{i + 1}",' - i += 1 - sql_pri = sql_pri.rstrip(",") - sql_pri = f"{sql_pri}{sql_sec}" - sql_sel = sql_sel.rstrip(",") - sql = f'{sql_sel}, ( ROW_NUMBER() OVER() ) AS "{__rownum}" FROM ( {sql_pri} ) AS M' - sql = f'{sql_sel} FROM ( {sql} ) Z WHERE' - - def _process_clause_text(clause): - if clause is None: - return None - if select._simple_int_clause(clause): - return self.process(clause.render_literal_execute(), **kw) - else: + sql_pri = sql_pri.rstrip(",") + sql_pri = f"{sql_pri}{sql_sec}" + sql_sel = sql_sel.rstrip(",") + sql = ( + f'{sql_sel}, ( ROW_NUMBER() OVER() ) AS "{__rownum}" ' + f"FROM ( {sql_pri} ) AS M" + ) + sql = f'{sql_sel} FROM ( {sql} ) Z WHERE' + + def _process_clause_text(clause): + if clause is None: + return None + if select._simple_int_clause(clause): + return self.process( + clause.render_literal_execute(), **kw + ) return self.process(clause, **kw) - offset_text = _process_clause_text(offset_clause_obj) - limit_text = _process_clause_text(limit_clause_obj) - if offset_text is not None: - sql = f'{sql} "{__rownum}" > {offset_text}' - if offset_text is not None and limit_text is not None: - sql = f'{sql} AND ' - if limit_text is not None: + offset_text = _process_clause_text(offset_clause_obj) + limit_text = _process_clause_text(limit_clause_obj) if offset_text is not None: - sql = f'{sql} "{__rownum}" <= ({offset_text} + {limit_text})' - else: - sql = f'{sql} "{__rownum}" <= {limit_text}' - return f"( {sql} )" - return sql_ori + sql = f'{sql} "{__rownum}" > {offset_text}' + if offset_text is not None and limit_text is not None: + sql = f"{sql} AND " + if limit_text is not None: + if offset_text is not None: + sql = ( + f'{sql} "{__rownum}" <= ' + f"({offset_text} + {limit_text})" + ) + else: + sql = f'{sql} "{__rownum}" <= {limit_text}' + final_sql = f"( {sql} )" + logger.debug("Generated ROW_NUMBER based pagination SQL.") + logger.debug(f"Final SELECT SQL -> {final_sql}") + return final_sql + logger.debug("Returning original SELECT SQL.") + logger.debug(f"Final SELECT SQL -> {sql_ori}") + return sql_ori + except Exception as e: + logger.error(f"Error compiling SELECT statement: {e}") + logger.exception("Stack trace in visit_select") + raise + @log_entry_exit def visit_sequence(self, sequence, **kw): - if sequence.schema: - return "NEXT VALUE FOR %s.%s" % (sequence.schema, sequence.name) - return "NEXT VALUE FOR %s" % sequence.name - + try: + schema = sequence.schema + name = sequence.name + logger.debug(f"Rendering sequence -> schema={schema}, name={name}") + if schema: + sql = f"NEXT VALUE FOR {schema}.{name}" + else: + sql = f"NEXT VALUE FOR {name}" + logger.debug(f"Generated sequence SQL -> {sql}") + return sql + except Exception as e: + logger.error(f"Error rendering 
sequence: {e}") + logger.exception("Stack trace in visit_sequence") + raise + + @log_entry_exit def default_from(self): # DB2 uses SYSIBM.SYSDUMMY1 table for row count + logger.debug("Rendering default FROM clause (SYSIBM.SYSDUMMY1)") return " FROM SYSIBM.SYSDUMMY1" + @log_entry_exit def visit_function(self, func, result_map=None, **kwargs): - if func.name.upper() == "AVG": - return "AVG(DOUBLE(%s))" % (self.function_argspec(func, **kwargs)) - elif func.name.upper() == "CHAR_LENGTH": - return "CHAR_LENGTH(%s, %s)" % (self.function_argspec(func, **kwargs), 'OCTETS') - else: - return compiler.SQLCompiler.visit_function(self, func, **kwargs) + try: + func_name = func.name.upper() + logger.debug(f"Rendering function -> name={func_name}") + if func_name == "AVG": + args = self.function_argspec(func, **kwargs) + sql = f"AVG(DOUBLE({args}))" + logger.debug(f"Rewritten AVG function -> {sql}") + return sql + elif func_name == "CHAR_LENGTH": + args = self.function_argspec(func, **kwargs) + sql = f"CHAR_LENGTH({args}, OCTETS)" + logger.debug(f"Rewritten CHAR_LENGTH function -> {sql}") + return sql + sql = compiler.SQLCompiler.visit_function(self, func, **kwargs) + logger.debug(f"Default function rendering -> {sql}") + return sql + except Exception as e: + logger.error(f"Error rendering function {func}: {e}") + logger.exception("Stack trace in visit_function") + raise # TODO: this is wrong but need to know what DB2 is expecting here # if func.name.upper() == "LENGTH": @@ -574,245 +763,532 @@ def visit_function(self, func, result_map=None, **kwargs): # else: # return compiler.SQLCompiler.visit_function(self, func, **kwargs) + @log_entry_exit def visit_cast(self, cast, **kw): - type_ = cast.typeclause.type - - if SA_VERSION_MM >= (2, 0): - valid_types = ( - CHAR, VARCHAR, CLOB, String, Text, Unicode, UnicodeText, - BLOB, LargeBinary, VARBINARY, - SMALLINT, SmallInteger, - INTEGER, Integer, - BIGINT, BigInteger, - DECIMAL, NUMERIC, Float, REAL, DOUBLE, Numeric, - DATE, Date, TIME, Time, TIMESTAMP, DateTime, - BOOLEAN, Boolean, - NullType - ) - else: - valid_types = ( - CHAR, VARCHAR, CLOB, String, Text, Unicode, UnicodeText, - BLOB, LargeBinary, VARBINARY, - SMALLINT, SmallInteger, - INTEGER, Integer, - BIGINT, BigInteger, - DECIMAL, NUMERIC, Float, REAL, Numeric, - DATE, Date, TIME, Time, TIMESTAMP, DateTime, - BOOLEAN, Boolean, - NullType - ) - - if isinstance(type_, valid_types): - return super(DB2Compiler, self).visit_cast(cast, **kw) - else: + try: + type_ = cast.typeclause.type + logger.debug(f"Rendering CAST -> type={type_}") + if SA_VERSION_MM >= (2, 0): + valid_types = ( + CHAR, VARCHAR, CLOB, String, Text, Unicode, UnicodeText, + BLOB, LargeBinary, VARBINARY, + SMALLINT, SmallInteger, + INTEGER, Integer, + BIGINT, BigInteger, + DECIMAL, NUMERIC, Float, REAL, DOUBLE, Numeric, + DATE, Date, TIME, Time, TIMESTAMP, DateTime, + BOOLEAN, Boolean, + NullType + ) + else: + valid_types = ( + CHAR, VARCHAR, CLOB, String, Text, Unicode, UnicodeText, + BLOB, LargeBinary, VARBINARY, + SMALLINT, SmallInteger, + INTEGER, Integer, + BIGINT, BigInteger, + DECIMAL, NUMERIC, Float, REAL, Numeric, + DATE, Date, TIME, Time, TIMESTAMP, DateTime, + BOOLEAN, Boolean, + NullType + ) + if isinstance(type_, valid_types): + sql = super(DB2Compiler, self).visit_cast(cast, **kw) + logger.debug(f"Standard CAST rendering -> {sql}") + return sql + logger.debug("Unsupported CAST type, processing clause only.") return self.process(cast.clause) + except Exception as e: + logger.error(f"Error rendering CAST: {e}") + 
logger.exception("Stack trace in visit_cast") + raise def get_select_precolumns(self, select, **kwargs): - if isinstance(select._distinct, str): - return select._distinct.upper() + " " - elif select._distinct: - return "DISTINCT " + distinct_value = select._distinct + if isinstance(distinct_value, str): + result = distinct_value.upper() + " " + elif distinct_value: + result = "DISTINCT " else: - return "" + result = "" + logger.debug(f"SELECT precolumns -> {result.strip()}") + return result + @log_entry_exit def visit_join(self, join, asfrom=False, **kwargs): - join_type = " INNER JOIN " - if join.full: - join_type = " FULL OUTER JOIN " - elif join.isouter: - join_type = " LEFT OUTER JOIN " - - return ''.join( - (self.process(join.left, asfrom=True, **kwargs), - join_type, - self.process(join.right, asfrom=True, **kwargs), - " ON ", - self.process(join.onclause, **kwargs))) - + try: + join_type = " INNER JOIN " + if join.full: + join_type = " FULL OUTER JOIN " + elif join.isouter: + join_type = " LEFT OUTER JOIN " + logger.debug( + f"Rendering JOIN -> type={join_type.strip()}, " + f"left={join.left}, right={join.right}" + ) + sql = "".join( + ( + self.process(join.left, asfrom=True, **kwargs), + join_type, + self.process(join.right, asfrom=True, **kwargs), + " ON ", + self.process(join.onclause, **kwargs), + ) + ) + logger.debug(f"Generated JOIN SQL -> {sql}") + return sql + except Exception as e: + logger.error(f"Error rendering JOIN: {e}") + logger.exception("Stack trace in visit_join") + raise + + @log_entry_exit def visit_savepoint(self, savepoint_stmt): - return "SAVEPOINT %(sid)s ON ROLLBACK RETAIN CURSORS" % {'sid': self.preparer.format_savepoint(savepoint_stmt)} + sid = self.preparer.format_savepoint(savepoint_stmt) + sql = f"SAVEPOINT {sid} ON ROLLBACK RETAIN CURSORS" + logger.debug(f"Generated SAVEPOINT SQL -> {sql}") + return sql + @log_entry_exit def visit_rollback_to_savepoint(self, savepoint_stmt): - return 'ROLLBACK TO SAVEPOINT %(sid)s' % {'sid': self.preparer.format_savepoint(savepoint_stmt)} + sid = self.preparer.format_savepoint(savepoint_stmt) + sql = f"ROLLBACK TO SAVEPOINT {sid}" + logger.debug(f"Generated ROLLBACK TO SAVEPOINT SQL -> {sql}") + return sql + @log_entry_exit def visit_release_savepoint(self, savepoint_stmt): - return 'RELEASE TO SAVEPOINT %(sid)s' % {'sid': self.preparer.format_savepoint(savepoint_stmt)} + sid = self.preparer.format_savepoint(savepoint_stmt) + sql = f"RELEASE TO SAVEPOINT {sid}" + logger.debug(f"Generated RELEASE SAVEPOINT SQL -> {sql}") + return sql + @log_entry_exit def visit_unary(self, unary, **kw): - if (unary.operator == operators.exists) and kw.get('within_columns_clause', False): - usql = super(DB2Compiler, self).visit_unary(unary, **kw) - usql = "CASE WHEN " + usql + " THEN 1 ELSE 0 END" - return usql - else: - return super(DB2Compiler, self).visit_unary(unary, **kw) + try: + operator_ = unary.operator + within_columns = kw.get("within_columns_clause", False) + logger.debug( + f"Rendering UNARY -> operator={operator_}, " + f"within_columns_clause={within_columns}" + ) + if operator_ == operators.exists and within_columns: + usql = super(DB2Compiler, self).visit_unary(unary, **kw) + sql = f"CASE WHEN {usql} THEN 1 ELSE 0 END" + logger.debug(f"Rewritten EXISTS unary -> {sql}") + return sql + sql = super(DB2Compiler, self).visit_unary(unary, **kw) + logger.debug(f"Standard unary rendering -> {sql}") + return sql + except Exception as e: + logger.error(f"Error rendering unary expression: {e}") + logger.exception("Stack trace in 
visit_unary") + raise class DB2DDLCompiler(compiler.DDLCompiler): @staticmethod + @log_entry_exit def get_server_version_info(dialect): """Returns the DB2 server major and minor version as a list of ints.""" - return [int(ver_token) for ver_token in dialect.dbms_ver.split('.')[0:2]] \ - if hasattr(dialect, 'dbms_ver') else [] + try: + if hasattr(dialect, 'dbms_ver') and dialect.dbms_ver: + version_tokens = dialect.dbms_ver.split('.')[0:2] + version_info = [int(ver_token) for ver_token in version_tokens] + logger.debug( + f"Parsed server version -> raw={dialect.dbms_ver}, parsed={version_info}" + ) + return version_info + logger.warning("Dialect has no dbms_ver attribute or version is empty.") + return [] + except Exception as e: + logger.error(f"Failed to parse server version: {e}") + logger.exception("Stack trace in get_server_version_info") + raise @classmethod + @log_entry_exit def _is_nullable_unique_constraint_supported(cls, dialect): - """Checks to see if the DB2 version is at least 10.5. - This is needed for checking if unique constraints with null columns are supported. """ - dbms_name = getattr(dialect, 'dbms_name', None) - if hasattr(dialect, 'dbms_name'): - if not (dbms_name is None) and (dbms_name.find('DB2/') != -1): - return cls.get_server_version_info(dialect) >= [10, 5] - else: + Checks to see if DB2 version is at least 10.5. + Required to determine if unique constraints with nullable columns are supported. + """ + try: + dbms_name = getattr(dialect, 'dbms_name', None) + logger.debug(f"Checking nullable unique constraint support -> dbms_name={dbms_name}") + if not dbms_name: + logger.warning("DBMS name not available for constraint capability check.") + return False + if 'DB2/' in dbms_name: + version_info = cls.get_server_version_info(dialect) + supported = version_info >= [10, 5] + logger.info( + f"Nullable unique constraint support -> " + f"version={version_info}, supported={supported}" + ) + return supported + logger.debug("DBMS is not DB2 LUW. 
Nullable unique constraint not supported.") return False + except Exception as e: + logger.error(f"Error checking nullable unique constraint support: {e}") + logger.exception("Stack trace in _is_nullable_unique_constraint_supported") + raise + @log_entry_exit def get_column_specification(self, column, **kw): - col_spec = [self.preparer.format_column(column), - self.dialect.type_compiler.process(column.type, type_expression=column)] - - # column-options: "NOT NULL" - if not column.nullable or column.primary_key: - col_spec.append('NOT NULL') - - # default-clause: - default = self.get_column_default_string(column) - if default is not None: - col_spec.extend(['WITH DEFAULT', default]) - - if column is column.table._autoincrement_column: - col_spec.extend(['GENERATED BY DEFAULT', - 'AS IDENTITY', - '(START WITH 1)']) - column_spec = ' '.join(col_spec) - return column_spec - + try: + column_name = column.name + column_type = column.type + logger.debug( + f"Generating column specification -> " + f"name={column_name}, type={column_type}, " + f"nullable={column.nullable}, primary_key={column.primary_key}" + ) + col_spec = [ + self.preparer.format_column(column), + self.dialect.type_compiler.process(column.type, type_expression=column)] + # NOT NULL handling + if not column.nullable or column.primary_key: + col_spec.append('NOT NULL') + logger.debug("Applied NOT NULL constraint.") + # DEFAULT handling + default = self.get_column_default_string(column) + if default is not None: + col_spec.extend(['WITH DEFAULT', default]) + logger.debug(f"Applied default clause -> {default}") + # AUTOINCREMENT handling + auto_column = column.table._autoincrement_column + if column is auto_column: + logger.debug("Column is autoincrement column. Applying IDENTITY clause.") + col_spec.extend([ + 'GENERATED BY DEFAULT', + 'AS IDENTITY', + '(START WITH 1)' + ]) + column_spec = ' '.join(col_spec) + logger.debug(f"Final column specification generated -> {column_spec}") + return column_spec + except Exception as e: + logger.error(f"Error generating column specification: {e}") + logger.exception("Stack trace in get_column_specification") + raise + + @log_entry_exit def define_constraint_cascades(self, constraint): - text = "" - if constraint.ondelete is not None: - text += " ON DELETE %s" % constraint.ondelete - - if constraint.onupdate is not None: - util.warn( - "DB2 does not support UPDATE CASCADE for foreign keys.") - - return text - + try: + constraint_name = getattr(constraint, "name", None) + ondelete = constraint.ondelete + onupdate = constraint.onupdate + logger.debug( + f"Defining constraint cascades -> " + f"name={constraint_name}, " + f"ondelete={ondelete}, " + f"onupdate={onupdate}" + ) + text = "" + if ondelete is not None: + text += f" ON DELETE {ondelete}" + logger.debug(f"Applied ON DELETE clause -> {ondelete}") + if onupdate is not None: + logger.warning( + "DB2 does not support UPDATE CASCADE for foreign keys." + ) + util.warn( + "DB2 does not support UPDATE CASCADE for foreign keys." 
+ ) + logger.debug(f"Cascade definition result -> {text}") + return text + except Exception as e: + logger.error(f"Error defining constraint cascades: {e}") + logger.exception("Stack trace in define_constraint_cascades") + raise + + @log_entry_exit def visit_drop_constraint(self, drop, **kw): - constraint = drop.element - if isinstance(constraint, sa_schema.ForeignKeyConstraint): - qual = "FOREIGN KEY " - const = self.preparer.format_constraint(constraint) - elif isinstance(constraint, sa_schema.PrimaryKeyConstraint): - qual = "PRIMARY KEY " - const = "" - elif isinstance(constraint, sa_schema.UniqueConstraint): - qual = "UNIQUE " - if self._is_nullable_unique_constraint_supported(self.dialect): - for column in constraint: - if column.nullable: - constraint.uConstraint_as_index = True - if getattr(constraint, 'uConstraint_as_index', None): - qual = "INDEX " - const = self.preparer.format_constraint(constraint) - else: - qual = "" - const = self.preparer.format_constraint(constraint) - - return ("DROP %s%s" % (qual, const)) if \ - hasattr(constraint, 'uConstraint_as_index') and constraint.uConstraint_as_index else \ - ("ALTER TABLE %s DROP %s%s" % (self.preparer.format_table(constraint.table), qual, const)) - - def create_table_constraints(self, table, **kw): - if self._is_nullable_unique_constraint_supported(self.dialect): - for constraint in table._sorted_constraints: - if isinstance(constraint, sa_schema.UniqueConstraint): + try: + constraint = drop.element + constraint_name = getattr(constraint, "name", None) + constraint_table = getattr(constraint, "table", None) + constraint_type = type(constraint).__name__ + logger.debug( + f"Processing DROP constraint -> " + f"type={constraint_type}, " + f"name={constraint_name}" + ) + if isinstance(constraint, sa_schema.ForeignKeyConstraint): + qual = "FOREIGN KEY " + const = self.preparer.format_constraint(constraint) + elif isinstance(constraint, sa_schema.PrimaryKeyConstraint): + qual = "PRIMARY KEY " + const = "" + elif isinstance(constraint, sa_schema.UniqueConstraint): + qual = "UNIQUE " + nullable_supported = self._is_nullable_unique_constraint_supported( + self.dialect + ) + if nullable_supported: for column in constraint: - if column.nullable: - constraint.use_alter = True + column_nullable = column.nullable + if column_nullable: constraint.uConstraint_as_index = True - break - if getattr(constraint, 'uConstraint_as_index', None): - if not constraint.name: - index_name = "%s_%s_%s" % ('ukey', self.preparer.format_table(constraint.table), - '_'.join(column.name for column in constraint)) - else: - index_name = constraint.name - index = sa_schema.Index(index_name, *(column for column in constraint)) - index.unique = True - index.uConstraint_as_index = True - result = super(DB2DDLCompiler, self).create_table_constraints(table, **kw) - return result - + logger.debug( + "Nullable column detected in UNIQUE constraint. " + "Marking as INDEX." 
+ ) + if getattr(constraint, "uConstraint_as_index", None): + qual = "INDEX " + const = self.preparer.format_constraint(constraint) + else: + qual = "" + const = self.preparer.format_constraint(constraint) + use_index = getattr(constraint, "uConstraint_as_index", None) + if use_index: + drop_sql = f"DROP {qual}{const}" + else: + table_name = self.preparer.format_table(constraint_table) + drop_sql = f"ALTER TABLE {table_name} DROP {qual}{const}" + logger.debug(f"Generated DROP SQL -> {drop_sql}") + return drop_sql + except Exception as e: + logger.error(f"Error generating DROP constraint SQL: {e}") + logger.exception("Stack trace in visit_drop_constraint") + raise + + @log_entry_exit + def create_table_constraints(self, table, **kw): + try: + table_name = table.name + logger.debug(f"Processing CREATE TABLE constraints -> table={table_name}") + nullable_supported = self._is_nullable_unique_constraint_supported( + self.dialect + ) + if nullable_supported: + for constraint in table._sorted_constraints: + if isinstance(constraint, sa_schema.UniqueConstraint): + constraint_name = constraint.name + logger.debug(f"Evaluating UniqueConstraint -> name={constraint_name}") + for column in constraint: + column_name = column.name + column_nullable = column.nullable + if column_nullable: + constraint.use_alter = True + constraint.uConstraint_as_index = True + logger.debug( + f"Nullable column detected -> {column_name}. " + "Converting UNIQUE constraint to INDEX." + ) + break + use_index = getattr(constraint, "uConstraint_as_index", None) + if use_index: + if not constraint_name: + index_name = "%s_%s_%s" % ( + "ukey", + self.preparer.format_table(constraint.table), + "_".join(col.name for col in constraint), + ) + else: + index_name = constraint_name + logger.debug( + f"Creating index for nullable UNIQUE constraint -> " + f"index_name={index_name}" + ) + index = sa_schema.Index(index_name,*(col for col in constraint)) + index.unique = True + index.uConstraint_as_index = True + result = super(DB2DDLCompiler, self).create_table_constraints(table, **kw) + logger.debug(f"Final CREATE TABLE constraints SQL fragment -> {result}") + return result + except Exception as e: + logger.error(f"Error processing create_table_constraints: {e}") + logger.exception("Stack trace in create_table_constraints") + raise + + @log_entry_exit def visit_create_index(self, create, include_schema=True, include_table_schema=True, **kw): - if SA_VERSION_MM < (0, 8): - sql = super(DB2DDLCompiler, self).visit_create_index(create, **kw) - else: - sql = super(DB2DDLCompiler, self).visit_create_index(create, include_schema, include_table_schema, **kw) - if getattr(create.element, 'uConstraint_as_index', None): - sql += ' EXCLUDE NULL KEYS' - return sql - + try: + element = create.element + index_name = getattr(element, "name", None) + is_unique = getattr(element, "unique", None) + use_index = getattr(element, "uConstraint_as_index", None) + logger.debug( + f"Processing CREATE INDEX -> " + f"name={index_name}, " + f"unique={is_unique}, " + f"uConstraint_as_index={use_index}" + ) + if SA_VERSION_MM < (0, 8): + sql = super(DB2DDLCompiler, self).visit_create_index(create, **kw) + else: + sql = super(DB2DDLCompiler, self).visit_create_index(create,include_schema, include_table_schema, **kw) + if use_index: + sql += " EXCLUDE NULL KEYS" + logger.debug("Applied EXCLUDE NULL KEYS for nullable unique constraint index.") + logger.debug(f"Generated CREATE INDEX SQL -> {sql}") + return sql + except Exception as e: + logger.error(f"Error generating 
CREATE INDEX SQL: {e}") + logger.exception("Stack trace in visit_create_index") + raise + + @log_entry_exit def visit_add_constraint(self, create, **kw): - if self._is_nullable_unique_constraint_supported(self.dialect): - if isinstance(create.element, sa_schema.UniqueConstraint): - for column in create.element: - if column.nullable: - create.element.uConstraint_as_index = True + try: + element = create.element + constraint_type = type(element).__name__ + constraint_name = getattr(element, "name", None) + logger.debug( + f"Processing ADD CONSTRAINT -> " + f"type={constraint_type}, " + f"name={constraint_name}" + ) + nullable_supported = self._is_nullable_unique_constraint_supported(self.dialect) + if nullable_supported and isinstance(element, sa_schema.UniqueConstraint): + for column in element: + column_name = column.name + column_nullable = column.nullable + if column_nullable: + element.uConstraint_as_index = True + logger.debug( + f"Nullable column detected -> {column_name}. " + "Converting UNIQUE constraint to INDEX." + ) break - if getattr(create.element, 'uConstraint_as_index', None): - if not create.element.name: - index_name = "%s_%s_%s" % ('uk_index', self.preparer.format_table(create.element.table), - '_'.join(column.name for column in create.element)) + use_index = getattr(element, "uConstraint_as_index", None) + if use_index: + table = element.table + table_name = self.preparer.format_table(table) + if not constraint_name: + index_name = "%s_%s_%s" % ( + "uk_index", + table_name, + "_".join(col.name for col in element), + ) else: - index_name = create.element.name - index = sa_schema.Index(index_name, *(column for column in create.element)) + index_name = constraint_name + logger.debug( + f"Creating index for nullable UNIQUE constraint -> " + f"index_name={index_name}" + ) + index = sa_schema.Index(index_name,*(col for col in element)) index.unique = True index.uConstraint_as_index = True sql = self.visit_create_index(sa_schema.CreateIndex(index)) + logger.debug(f"Generated SQL via index conversion -> {sql}") return sql - sql = super(DB2DDLCompiler, self).visit_add_constraint(create) - return sql + sql = super(DB2DDLCompiler, self).visit_add_constraint(create) + logger.debug(f"Generated ADD CONSTRAINT SQL -> {sql}") + return sql + except Exception as e: + logger.error(f"Error generating ADD CONSTRAINT SQL: {e}") + logger.exception("Stack trace in visit_add_constraint") + raise class DB2IdentifierPreparer(compiler.IdentifierPreparer): - reserved_words = RESERVED_WORDS - illegal_initial_characters = set(range(0, 10)).union(["_", "$"]) + reserved_words = RESERVED_WORDS + illegal_initial_characters = set(range(0, 10)).union(["_", "$"]) + def __init__(self, dialect): + logger.debug("Initializing DB2IdentifierPreparer") + super(DB2IdentifierPreparer, self).__init__(dialect) + logger.debug( + f"IdentifierPreparer configuration -> " + f"reserved_words_count={len(self.reserved_words)}, " + f"illegal_initial_characters={self.illegal_initial_characters}" + ) class _SelectLastRowIDMixin(object): - _select_lastrowid = False - _lastrowid = None - - def get_lastrowid(self): - return self._lastrowid - - def pre_exec(self): - if self.isinsert: - tbl = self.compiled.statement.table - seq_column = tbl._autoincrement_column - insert_has_sequence = seq_column is not None - - self._select_lastrowid = insert_has_sequence and \ - not self.compiled.returning and \ - not self.compiled.inline - - def post_exec(self): - conn = self.root_connection - if self._select_lastrowid: - 
conn._cursor_execute(self.cursor, - "SELECT IDENTITY_VAL_LOCAL() FROM SYSIBM.SYSDUMMY1", - (), self) - row = self.cursor.fetchall()[0] - if row[0] is not None: - self._lastrowid = int(row[0]) + _select_lastrowid = False + _lastrowid = None + + @log_entry_exit + def get_lastrowid(self): + lastrowid = self._lastrowid + logger.debug(f"Returning lastrowid -> {lastrowid}") + return lastrowid + + @log_entry_exit + def pre_exec(self): + try: + is_insert = self.isinsert + compiled = self.compiled + logger.debug(f"pre_exec invoked -> isinsert={is_insert}") + if not is_insert: + logger.debug("Statement is not INSERT. Skipping identity logic.") + return + statement = compiled.statement + table = statement.table + seq_column = table._autoincrement_column + insert_has_sequence = seq_column is not None + returning_enabled = compiled.returning + inline_insert = compiled.inline + logger.debug( + f"Insert detected -> " + f"table={table.name}, " + f"autoincrement_column={getattr(seq_column, 'name', None)}, " + f"returning={returning_enabled}, " + f"inline={inline_insert}" + ) + select_lastrowid = ( + insert_has_sequence + and not returning_enabled + and not inline_insert + ) + self._select_lastrowid = select_lastrowid + logger.debug(f"Will fetch identity after insert -> {select_lastrowid}") + except Exception as e: + logger.error(f"Error during pre_exec: {e}") + logger.exception("Stack trace in pre_exec") + raise + + @log_entry_exit + def post_exec(self): + try: + select_lastrowid = self._select_lastrowid + if not select_lastrowid: + logger.debug("post_exec skipped identity fetch (not required)") + return + logger.debug("Fetching IDENTITY_VAL_LOCAL() after insert") + conn = self.root_connection + cursor = self.cursor + identity_sql = "SELECT IDENTITY_VAL_LOCAL() FROM SYSIBM.SYSDUMMY1" + logger.debug(f"Executing identity SQL -> {identity_sql}") + conn._cursor_execute( + cursor, + identity_sql, + (), + self + ) + row = cursor.fetchall()[0] + identity_value = row[0] + if identity_value is not None: + lastrowid = int(identity_value) + self._lastrowid = lastrowid + logger.info(f"Identity value retrieved successfully -> {lastrowid}") + else: + logger.warning("IDENTITY_VAL_LOCAL() returned NULL") + except Exception as e: + logger.error(f"Error during post_exec identity fetch: {e}") + logger.exception("Stack trace in post_exec") + raise class DB2ExecutionContext(_SelectLastRowIDMixin, default.DefaultExecutionContext): + @log_entry_exit def fire_sequence(self, seq, type_): - return self._execute_scalar("SELECT NEXTVAL FOR " + - self.dialect.identifier_preparer.format_sequence(seq) + - " FROM SYSIBM.SYSDUMMY1", type_) + sequence_name = str(seq) + try: + formatted_seq = self.dialect.identifier_preparer.format_sequence(seq) + sql = ("SELECT NEXTVAL FOR " + formatted_seq + " FROM SYSIBM.SYSDUMMY1") + logger.debug(f"Firing sequence -> name={sequence_name}, Generated SQL={sql}") + result = self._execute_scalar(sql, type_) + logger.info(f"Sequence value generated -> name={sequence_name}, value={result}") + return result + except Exception as e: + logger.error(f"Sequence execution failed -> name={sequence_name}, error={e}") + logger.exception("Stack trace for sequence execution failure") + raise class DB2Dialect(default.DefaultDialect): @@ -860,6 +1336,7 @@ class DB2Dialect(default.DefaultDialect): serverType = '' def __init__(self, **kw): + logger.debug("Creating DB2Dialect instance") super(DB2Dialect, self).__init__(**kw) self._reflector = self._reflector_cls(self) self.dbms_ver = None @@ -867,79 +1344,164 @@ def 
__init__(self, **kw): # reflection: these all defer to an BaseDB2Reflector # object which selects between DB2 and AS/400 schemas + @log_entry_exit def initialize(self, connection): - self.dbms_ver = getattr(connection.connection, 'dbms_ver', None) - self.dbms_name = getattr(connection.connection, 'dbms_name', None) - DB2Dialect.serverType = self.dbms_name - super(DB2Dialect, self).initialize(connection) - # check server type logic here - _reflector_cls = self._reflector_cls - if self.dbms_name == 'AS': - _reflector_cls = ibm_reflection.AS400Reflector - elif self.dbms_name == "DB2": - _reflector_cls = ibm_reflection.OS390Reflector - elif(self.dbms_name is None): - pass - elif "DB2/" in self.dbms_name: - _reflector_cls = ibm_reflection.DB2Reflector - elif "IDS/" in self.dbms_name: - _reflector_cls = ibm_reflection.DB2Reflector - elif self.dbms_name.startswith("DSN"): - _reflector_cls = ibm_reflection.OS390Reflector - self._reflector = _reflector_cls(self) - + logger.info("Initializing DB2Dialect") + try: + self.dbms_ver = getattr(connection.connection, 'dbms_ver', None) + self.dbms_name = getattr(connection.connection, 'dbms_name', None) + if not self.dbms_name: + logger.warning("DBMS name not detected from connection") + else: + logger.info( + f"Connected to DB Server -> name={self.dbms_name}, version={self.dbms_ver}" + ) + DB2Dialect.serverType = self.dbms_name + super(DB2Dialect, self).initialize(connection) + logger.debug( + f"SQLAlchemy version branch -> SA_VERSION_MM={SA_VERSION_MM}, " + f"returns_unicode_strings={self.returns_unicode_strings}" + ) + selected_reflector = self._reflector_cls + if self.dbms_name == 'AS': + selected_reflector = ibm_reflection.AS400Reflector + elif self.dbms_name == "DB2": + selected_reflector = ibm_reflection.OS390Reflector + elif self.dbms_name and "DB2/" in self.dbms_name: + selected_reflector = ibm_reflection.DB2Reflector + elif self.dbms_name and "IDS/" in self.dbms_name: + selected_reflector = ibm_reflection.DB2Reflector + elif self.dbms_name and self.dbms_name.startswith("DSN"): + selected_reflector = ibm_reflection.OS390Reflector + self._reflector = selected_reflector(self) + logger.info(f"Reflector selected -> {selected_reflector.__name__}") + except Exception as e: + logger.critical(f"Dialect initialization failed: {e}") + raise + + @log_entry_exit def get_columns(self, connection, table_name, schema=None, **kw): - return self._reflector.get_columns(connection, table_name, schema=schema, **kw) + logger.debug(f"Fetching columns -> table={table_name}, schema={schema}") + columns = self._reflector.get_columns(connection, table_name, schema=schema, **kw) + if not columns: + logger.warning(f"No columns found -> table={table_name}") + else: + logger.debug(f"Columns fetched -> count={len(columns)}") + return columns + @log_entry_exit def get_pk_constraint(self, connection, table_name, schema=None, **kw): - return self._reflector.get_pk_constraint(connection, table_name, schema=schema, **kw) + logger.debug(f"Fetching PK -> table={table_name}, schema={schema}") + pk = self._reflector.get_pk_constraint(connection, table_name, schema=schema, **kw) + if not pk or not pk.get("constrained_columns"): + logger.warning(f"No primary key found -> table={table_name}") + else: + logger.debug(f"PK columns -> {pk.get('constrained_columns')}") + return pk + @log_entry_exit def get_foreign_keys(self, connection, table_name, schema=None, **kw): - return self._reflector.get_foreign_keys(connection, table_name, schema=schema, **kw) + logger.debug(f"Fetching foreign keys -> 
table={table_name}, schema={schema}") + fks = self._reflector.get_foreign_keys(connection, table_name, schema=schema, **kw) + logger.debug(f"Foreign keys fetched -> count={len(fks)}") + return fks + @log_entry_exit def get_table_names(self, connection, schema=None, **kw): - return self._reflector.get_table_names(connection, schema=schema, **kw) + logger.debug(f"Fetching table names -> schema={schema}") + tables = self._reflector.get_table_names(connection, schema=schema, **kw) + logger.debug(f"Tables fetched -> count={len(tables)}") + return tables + @log_entry_exit def get_view_names(self, connection, schema=None, **kw): - return self._reflector.get_view_names(connection, schema=schema, **kw) + logger.debug(f"Fetching view names -> schema={schema}") + views = self._reflector.get_view_names(connection, schema=schema, **kw) + logger.debug(f"Views fetched -> count={len(views)}") + return views + @log_entry_exit def get_sequence_names(self, connection, schema=None, **kw): - return self._reflector.get_sequence_names(connection, schema=schema, **kw) + logger.debug(f"Fetching sequence names -> schema={schema}") + sequences = self._reflector.get_sequence_names(connection, schema=schema, **kw) + logger.debug(f"Sequences fetched -> count={len(sequences)}") + return sequences + @log_entry_exit def get_view_definition(self, connection, view_name, schema=None, **kw): - return self._reflector.get_view_definition(connection, view_name, schema=schema, **kw) + logger.debug(f"Fetching view definition -> view={view_name}, schema={schema}") + definition = self._reflector.get_view_definition(connection, view_name, schema=schema, **kw) + if definition: + logger.debug(f"View definition length -> {len(definition)} characters") + else: + logger.warning(f"View definition not found -> view={view_name}") + return definition + @log_entry_exit def get_indexes(self, connection, table_name, schema=None, **kw): - return self._reflector.get_indexes(connection, table_name, schema=schema, **kw) + logger.debug(f"Fetching indexes -> table={table_name}, schema={schema}") + indexes = self._reflector.get_indexes(connection, table_name, schema=schema, **kw) + logger.debug(f"Indexes fetched -> count={len(indexes)}") + return indexes + @log_entry_exit def get_unique_constraints(self, connection, table_name, schema=None, **kw): - return self._reflector.get_unique_constraints(connection, table_name, schema=schema, **kw) + logger.debug(f"Fetching unique constraints -> table={table_name}, schema={schema}") + constraints = self._reflector.get_unique_constraints(connection, table_name, schema=schema, **kw) + logger.debug(f"Unique constraints fetched -> count={len(constraints)}") + return constraints + @log_entry_exit def get_table_comment(self, connection, table_name, schema=None, **kw): - return self._reflector.get_table_comment(connection, table_name, schema=schema, **kw) + logger.debug(f"Fetching table comment -> table={table_name}, schema={schema}") + comment = self._reflector.get_table_comment(connection, table_name, schema=schema, **kw) + if comment: + logger.debug("Table comment present") + else: + logger.debug("No table comment found") + return comment + @log_entry_exit def normalize_name(self, name): - return self._reflector.normalize_name(name) + normalized = self._reflector.normalize_name(name) + logger.debug(f"Normalize -> original={name}, normalized={normalized}") + return normalized + @log_entry_exit def denormalize_name(self, name): - return self._reflector.denormalize_name(name) + denormalized = 
self._reflector.denormalize_name(name) + logger.debug(f"Denormalize -> original={name}, denormalized={denormalized}") + return denormalized + @log_entry_exit def has_table(self, connection, table_name, schema=None, **kw): - return self._reflector.has_table(connection, table_name, schema=schema, **kw) + exists = self._reflector.has_table(connection, table_name, schema=schema, **kw) + logger.debug(f"Table exists -> {exists}") + return exists + @log_entry_exit def has_sequence(self, connection, sequence_name, schema=None, **kw): - return self._reflector.has_sequence(connection, sequence_name, schema=schema, **kw) + exists = self._reflector.has_sequence(connection, sequence_name, schema=schema, **kw) + logger.debug(f"Sequence exists -> {exists}") + return exists + @log_entry_exit def get_schema_names(self, connection, **kw): - return self._reflector.get_schema_names(connection, **kw) + schemas = self._reflector.get_schema_names(connection, **kw) + logger.debug(f"Schemas fetched -> count={len(schemas)}") + return schemas + @log_entry_exit def get_primary_keys(self, connection, table_name, schema=None, **kw): - return self._reflector.get_primary_keys( - connection, table_name, schema=schema, **kw) + keys = self._reflector.get_primary_keys(connection, table_name, schema=schema, **kw) + logger.debug(f"Primary keys fetched -> count={len(keys)}") + return keys + @log_entry_exit def get_incoming_foreign_keys(self, connection, table_name, schema=None, **kw): - return self._reflector.get_incoming_foreign_keys( - connection, table_name, schema=schema, **kw) + fks = self._reflector.get_incoming_foreign_keys(connection, table_name, schema=schema, **kw) + logger.debug(f"Incoming foreign keys fetched -> count={len(fks)}") + return fks # legacy naming diff --git a/ibm_db_sa/ibm_db.py b/ibm_db_sa/ibm_db.py index 37e5c6f..9607acd 100644 --- a/ibm_db_sa/ibm_db.py +++ b/ibm_db_sa/ibm_db.py @@ -19,6 +19,7 @@ import re from sqlalchemy import __version__ as SA_VERSION_STR +from .logger import init_ibmdbsa_logging, logger, log_entry_exit m = re.match(r"^\s*(\d+)\.(\d+)", SA_VERSION_STR) SA_VERSION_MM = (int(m.group(1)), int(m.group(2))) if m else (0, 0) @@ -45,50 +46,63 @@ class _IBM_Numeric_ibm_db(sa_types.Numeric): - def result_processor(self, dialect, coltype): - def to_float(value): - if value is None: - return None - else: - return float(value) - if self.asdecimal: - return None - else: - return to_float + @log_entry_exit + def result_processor(self, dialect, coltype): + logger.debug("Creating result processor for _IBM_Numeric_ibm_db") + def to_float(value): + logger.debug("Processing numeric result value: %s", value) + if value is None: + return None + else: + return float(value) + if self.asdecimal: + logger.debug("Returning None processor since asdecimal=True") + return None + else: + logger.debug("Returning float conversion processor") + return to_float class DB2ExecutionContext_ibm_db(DB2ExecutionContext): _callproc_result = None _out_parameters = None + @log_entry_exit def get_lastrowid(self): + logger.debug("Fetching last inserted row id") return self.cursor.last_identity_val + @log_entry_exit def pre_exec(self): # check for the compiled_parameters attribute in self - if (hasattr(self, "compiled_parameters")): + logger.debug("Executing pre_exec checks") + if hasattr(self, "compiled_parameters"): # if a single execute, check for outparams + logger.debug("Compiled parameters detected") if len(self.compiled_parameters) == 1: for bindparam in self.compiled.binds.values(): if bindparam.isoutparam: + 
logger.debug("OUT parameter detected") self._out_parameters = True break else: - pass + logger.debug("No compiled_parameters attribute found") + @log_entry_exit def get_result_proxy(self): + logger.debug("Creating result proxy") if self._callproc_result and self._out_parameters: if SA_VERSION_MM < (0, 8): result = base.ResultProxy(self) else: result = _result.ResultProxy(self) result.out_parameters = {} - for bindparam in self.compiled.binds.values(): if bindparam.isoutparam: name = self.compiled.bind_names[bindparam] - result.out_parameters[name] = self._callproc_result[self.compiled.positiontup.index(name)] - + logger.debug("Processing OUT parameter: %s", name) + result.out_parameters[name] = \ + self._callproc_result[self.compiled.positiontup.index(name)] return result else: if SA_VERSION_MM < (0, 8): @@ -119,26 +133,36 @@ class DB2Dialect_ibm_db(DB2Dialect): if SA_VERSION_MM < (2, 0): @classmethod + @log_entry_exit def dbapi(cls): """ Returns: the underlying DBAPI driver module """ + logger.debug("Importing ibm_db_dbi DBAPI module") import ibm_db_dbi as module return module else: @classmethod + @log_entry_exit def import_dbapi(cls): """ Returns: the underlying DBAPI driver module """ + logger.debug("Importing ibm_db_dbi DBAPI module") import ibm_db_dbi as module return module + @log_entry_exit def do_execute(self, cursor, statement, parameters, context=None): + logger.debug("Executing SQL statement") + logger.debug("Statement: %s", statement) + logger.debug("Parameters: %s", parameters) if context and context._out_parameters: + logger.debug("Detected stored procedure execution") statement = statement.split('(', 1)[0].split()[1] context._callproc_result = cursor.callproc(statement, parameters) else: check_server = getattr(DB2Dialect, 'serverType') if ("round(" in statement.casefold()) and check_server == "DB2": + logger.debug("Applying round() workaround for DB2") value_index = 0 while '?' 
 
+    @log_entry_exit
     def do_execute(self, cursor, statement, parameters, context=None):
+        logger.debug("Executing SQL statement")
+        logger.debug("Statement: %s", statement)
+        logger.debug("Parameters: %s", parameters)
         if context and context._out_parameters:
+            logger.debug("Detected stored procedure execution")
             statement = statement.split('(', 1)[0].split()[1]
             context._callproc_result = cursor.callproc(statement, parameters)
         else:
             check_server = getattr(DB2Dialect, 'serverType')
             if ("round(" in statement.casefold()) and check_server == "DB2":
+                logger.debug("Applying round() workaround for DB2")
                 value_index = 0
+                while '?' in statement and value_index < len(parameters):
+                    statement = statement.replace('?', str(parameters[value_index]), 1)
@@ -147,8 +171,12 @@ def do_execute(self, cursor, statement, parameters, context=None):
             else:
                 cursor.execute(statement, parameters)
 
+    @log_entry_exit
     def _get_server_version_info(self, connection):
-        return connection.connection.server_info()
+        logger.debug("Fetching DB2 server version")
+        version = connection.connection.server_info()
+        logger.debug("Server version info: %s", version)
+        return version
 
     _isolation_lookup = set(['READ STABILITY', 'RS', 'UNCOMMITTED READ', 'UR',
                              'CURSOR STABILITY', 'CS', 'REPEATABLE READ', 'RR'])
@@ -160,37 +188,54 @@ def _get_server_version_info(self, connection):
     _isolation_levels_returned = {value: key for key, value
                                   in _isolation_levels_cli.items()}
 
+    @log_entry_exit
     def _get_cli_isolation_levels(self, level):
-        return self._isolation_levels_cli[level]
+        logger.debug("Fetching CLI isolation level mapping for: %s", level)
+        value = self._isolation_levels_cli[level]
+        logger.debug("CLI isolation level value: %s", value)
+        return value
 
+    @log_entry_exit
     def set_isolation_level(self, connection, level):
+        logger.debug("Requested isolation level: %s", level)
         if level is None:
+            logger.debug("Isolation level is None, defaulting to CS")
             level = 'CS'
         else:
             if len(level.strip()) < 1:
+                logger.debug("Isolation level empty after strip, defaulting to CS")
                 level = 'CS'
         level = level.upper().replace("-", " ").replace("_", " ")
+        logger.debug("Normalized isolation level: %s", level)
         if level not in self._isolation_lookup:
+            logger.error("Invalid isolation level requested: %s", level)
             raise ArgumentError(
                 "Invalid value '%s' for isolation_level. "
                 "Valid isolation levels for %s are %s"
                 % (level, self.name, ", ".join(self._isolation_lookup))
             )
         attrib = {SQL_ATTR_TXN_ISOLATION: self._get_cli_isolation_levels(level)}
+        logger.debug("Setting isolation level with attributes: %s", attrib)
         res = connection.set_option(attrib)
+        logger.debug("Isolation level set result: %s", res)
 
+    @log_entry_exit
     def get_isolation_level(self, connection):
-
+        logger.debug("Retrieving current isolation level")
         attrib = SQL_ATTR_TXN_ISOLATION
         res = connection.get_option(attrib)
-
+        logger.debug("Raw isolation level value from connection: %s", res)
         val = self._isolation_levels_returned[res]
+        logger.debug("Mapped isolation level: %s", val)
         return val
 
+    @log_entry_exit
     def reset_isolation_level(self, connection):
+        logger.debug("Resetting isolation level to default (CS)")
         self.set_isolation_level(connection, 'CS')
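The instrumented isolation-level methods are driven through the standard SQLAlchemy engine option, so a single flag exercises `set_isolation_level()`, `get_isolation_level()`, and the logging around them. A minimal sketch (placeholder credentials; valid levels are the members of `_isolation_lookup`):

```python
from sqlalchemy import create_engine, text

# Placeholder connection values; isolation_level="UR" triggers the
# set_isolation_level()/_get_cli_isolation_levels() paths shown above.
engine = create_engine(
    "ibm_db_sa://db2inst1:secret@db2host.example.com:50000/testdb?ibmdbsa_log=True",
    isolation_level="UR",
)
with engine.connect() as conn:
    conn.execute(text("SELECT 1 FROM SYSIBM.SYSDUMMY1"))
```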
     def create_connect_args(self, url):
+        url, ibmdbsa_log_value = init_ibmdbsa_logging(url)
         # DSN support through CLI configuration (../cfg/db2cli.ini),
         # while 2 connection attributes are mandatory: database alias
         # and UID (in support to current schema), all the other
@@ -198,57 +243,88 @@ def create_connect_args(self, url):
         # provided through db2cli.ini database catalog entry. Example
         # 1: ibm_db_sa:///?UID=db2inst1 or Example 2:
         # ibm_db_sa:///?DSN=;UID=db2inst1
+        logger.info("entry create_connect_args()")
         if not url.host:
+            logger.debug("Using DSN based connection")
             dsn = url.database
             uid = url.username
             pwd = url.password
+            logger.debug("DSN connection parameters -> database=%s user=%s", dsn, uid)
+            logger.info("exit create_connect_args()")
             return (dsn, uid, pwd, '', ''), {}
         else:
             # Full URL string support for connection to remote data servers
+            logger.debug("Using full connection URL for remote DB2 server")
             dsn_param = ['DATABASE=%s' % url.database,
                          'HOSTNAME=%s' % url.host,
                          'PROTOCOL=TCPIP']
+            logger.debug("Host: %s", url.host)
+            logger.debug("Database: %s", url.database)
             if url.port:
+                logger.debug("Port: %s", url.port)
                 dsn_param.append('PORT=%s' % url.port)
             if url.username:
+                logger.debug("User: %s", url.username)
                 dsn_param.append('UID=%s' % url.username)
             if url.password:
+                # if password contains ';' truncate at first ';' (existing logic)
                 if ';' in url.password:
-                    url.password = (url.password).partition(";")[0]
+                    logger.debug("Password contains ';', truncating")
+                    url = url._replace(password=(url.password).partition(";")[0])
                 dsn_param.append('PWD=%s' % url.password)
-
             # check for connection arguments
             connection_keys = ['Security', 'SSLClientKeystoredb', 'SSLClientKeystash',
                                'SSLServerCertificate', 'CurrentSchema']
-            query_keys = url.query.keys()
+            # rebuild query_keys in case url changed
+            query_keys = list(url.query.keys()) if url.query else []
 
             for key in connection_keys:
                 for query_key in query_keys:
                     if query_key.lower() == key.lower():
+                        logger.debug("Applying connection option: %s=%s", key, url.query[query_key])
                         dsn_param.append(
                             '%(connection_key)s=%(value)s' %
                             {'connection_key': key, 'value': url.query[query_key]})
                         url = url.difference_update_query([query_key])
                         break
-
             dsn = ';'.join(dsn_param)
             dsn += ';'
+            safe_dsn = dsn
+            if 'PWD=' in safe_dsn:
+                safe_dsn = re.sub(r'PWD=[^;]*', 'PWD=****', safe_dsn)
+            logger.debug("Constructed DB2 DSN: %s", safe_dsn)
+            logger.info("exit create_connect_args()")
             return (dsn, url.username, '', '', ''), {}
 
     # Retrieves current schema for the specified connection object
+    @log_entry_exit
     def _get_default_schema_name(self, connection):
-        return self.normalize_name(connection.connection.get_current_schema())
+        logger.debug("Fetching current schema from DB2")
+        schema = connection.connection.get_current_schema()
+        logger.debug("Current schema returned: %s", schema)
+        normalized_schema_name = self.normalize_name(schema)
+        logger.debug("Normalized schema: %s", normalized_schema_name)
+        return normalized_schema_name
 
     # Checks if the DB_API driver error indicates an invalid connection
+    @log_entry_exit
     def is_disconnect(self, ex, connection, cursor):
+        logger.debug("Checking if exception indicates disconnect")
+        logger.debug("Exception received: %s", ex)
         if isinstance(ex, (self.dbapi.ProgrammingError,
                            self.dbapi.OperationalError)):
-            connection_errors = ('Connection is not active', 'connection is no longer active',
-                                 'Connection Resource cannot be found', 'SQL30081N',
-                                 'CLI0108E', 'CLI0106E', 'SQL1224N')
+            connection_errors = ('Connection is not active',
+                                 'connection is no longer active',
+                                 'Connection Resource cannot be found',
+                                 'SQL30081N',
+                                 'CLI0108E',
+                                 'CLI0106E',
+                                 'SQL1224N')
             for err_msg in connection_errors:
                 if err_msg in str(ex):
+                    logger.debug("Disconnect detected due to error: %s", err_msg)
                     return True
+            return False
         else:
-            return False
+            logger.debug("Exception type does not indicate disconnect")
+            return False
 
 dialect = DB2Dialect_ibm_db
diff --git a/ibm_db_sa/logger.py b/ibm_db_sa/logger.py
new file mode 100644
index 0000000..00458ef
--- /dev/null
+++ b/ibm_db_sa/logger.py
@@ -0,0 +1,126 @@
+import logging as ibmdbsa_logging
+import functools
+import inspect
+
+logger = ibmdbsa_logging.getLogger("ibm_db_sa")
+logger.setLevel(ibmdbsa_logging.DEBUG)
+logger.propagate = False  # prevent propagation to root logger
+
+def configure_ibmdbsa_logging(target=False):
+    """
+    Configure ibm_db_sa logging.
+    target = True       -> console logging
+    target = "filename" -> file logging (overwrite file)
+    target = False      -> disable logging
+    """
+    # Prevent reconfiguration if already configured with same target
+    current_target = getattr(logger, "_ibmdbsa_target", None)
+    if current_target == target:
+        return
+    # Remove existing handlers
+    for handler in list(logger.handlers):
+        logger.removeHandler(handler)
+        try:
+            handler.close()
+        except Exception:
+            pass
+    if not target:
+        logger.disabled = True
+        logger._ibmdbsa_target = target
+        return
+    # Console logging
+    if target is True:
+        handler = ibmdbsa_logging.StreamHandler()
+    # File logging
+    elif isinstance(target, str):
+        handler = ibmdbsa_logging.FileHandler(target, mode="w", encoding='utf-8')
+    else:
+        logger.disabled = True
+        logger._ibmdbsa_target = target
+        return
+    formatter = ibmdbsa_logging.Formatter(
+        "%(asctime)s - [ibm_db_sa] - %(levelname)s - %(message)s",
+        "%Y-%m-%d %H:%M:%S"
+    )
+    handler.setFormatter(formatter)
+    logger.addHandler(handler)
+    logger.disabled = False
+    logger._ibmdbsa_target = target
+    logger.debug(f"IBM_DB_SA logging initialized -> {target}")
+
+def init_ibmdbsa_logging(url):
+    """
+    Extract 'ibmdbsa_log' from SQLAlchemy URL query parameters,
+    configure logging if present, and remove the parameter
+    so it is not passed to the DBAPI layer.
+    """
+    ibmdbsa_log_value = None
+    try:
+        query_keys = list(url.query.keys()) if url.query else []
+    except Exception:
+        query_keys = []
+    for qk in query_keys:
+        if qk.lower() == "ibmdbsa_log":
+            raw_val = url.query[qk]
+            if isinstance(raw_val, str):
+                val = raw_val.lower()
+                if val in ("true", "1", "yes", "y"):
+                    ibmdbsa_log_value = True
+                elif val in ("false", "0", "no", "n", ""):
+                    ibmdbsa_log_value = False
+                else:
+                    ibmdbsa_log_value = raw_val
+            else:
+                ibmdbsa_log_value = raw_val
+            # remove parameter so DBAPI never receives it
+            url = url.difference_update_query([qk])
+            break
+    if ibmdbsa_log_value is not None:
+        configure_ibmdbsa_logging(ibmdbsa_log_value)
+        logger.debug(
+            f"ibm_db_sa logging configured via URL parameter -> {ibmdbsa_log_value}"
+        )
+    return url, ibmdbsa_log_value
+
+def log_entry_exit(func):
+    """Logs entry, exit, execution time, and exceptions."""
+    import time
+    @functools.wraps(func)
+    async def async_wrapper(*args, **kwargs):
+        start = time.time()
+        try:
+            logger.info(f"Entry: {func.__name__}")
+            result = await func(*args, **kwargs)
+            duration = round((time.time() - start) * 1000, 2)
+            logger.info(f"Exit: {func.__name__} (took {duration} ms)")
+            return result
+        except Exception as e:
+            logger.exception(f"Exception in {func.__name__}: {e}")
+            raise
+    @functools.wraps(func)
+    def sync_wrapper(*args, **kwargs):
+        start = time.time()
+        try:
+            logger.info(f"Entry: {func.__name__}")
+            result = func(*args, **kwargs)
+            duration = round((time.time() - start) * 1000, 2)
+            logger.info(f"Exit: {func.__name__} (took {duration} ms)")
+            return result
+        except Exception as e:
+            logger.exception(f"Exception in {func.__name__}: {e}")
+            raise
+    return async_wrapper if inspect.iscoroutinefunction(func) else sync_wrapper
+
+def _format_args(args, kwargs):
+    parts = []
+    if args:
+        parts.append(", ".join(map(str, args)))
+    if kwargs:
+        parts.append(", ".join(f"{k}={v}" for k, v in kwargs.items()))
+    return ", ".join(parts)
+
+__all__ = [
+    "logger",
+    "configure_ibmdbsa_logging",
+    "init_ibmdbsa_logging",
+    "log_entry_exit"
+]
\ No newline at end of file
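The module can also be driven directly, without a connection URL, which is convenient in tests. A minimal sketch using `configure_ibmdbsa_logging` and the decorator (the decorated helper below is hypothetical):

```python
from ibm_db_sa.logger import configure_ibmdbsa_logging, log_entry_exit

# Send ibm_db_sa records to a file; mode="w" means it is overwritten per run.
configure_ibmdbsa_logging("ibmdbsa_debug.log")

@log_entry_exit
def lookup(name):  # hypothetical helper, stands in for any dialect method
    return name.upper()

lookup("employees")  # logs "Entry: lookup" / "Exit: lookup (took ... ms)"

# Passing False detaches the handlers and disables the logger again.
configure_ibmdbsa_logging(False)
```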
+ """ + ibmdbsa_log_value = None + try: + query_keys = list(url.query.keys()) if url.query else [] + except Exception: + query_keys = [] + for qk in query_keys: + if qk.lower() == "ibmdbsa_log": + raw_val = url.query[qk] + if isinstance(raw_val, str): + val = raw_val.lower() + if val in ("true", "1", "yes", "y"): + ibmdbsa_log_value = True + elif val in ("false", "0", "no", "n", ""): + ibmdbsa_log_value = False + else: + ibmdbsa_log_value = raw_val + else: + ibmdbsa_log_value = raw_val + # remove parameter so DBAPI never receives it + url = url.difference_update_query([qk]) + break + if ibmdbsa_log_value is not None: + configure_ibmdbsa_logging(ibmdbsa_log_value) + logger.debug( + f"ibm_db_sa logging enabled via URL parameter -> {ibmdbsa_log_value}" + ) + return url, ibmdbsa_log_value + +def log_entry_exit(func): + """Logs entry, exit, execution time, and exceptions.""" + import time + @functools.wraps(func) + async def async_wrapper(*args, **kwargs): + start = time.time() + try: + logger.info(f"Entry: {func.__name__}") + result = await func(*args, **kwargs) + duration = round((time.time() - start) * 1000, 2) + logger.info(f"Exit: {func.__name__} (took {duration} ms)") + return result + except Exception as e: + logger.exception(f"Exception in {func.__name__}: {e}") + raise + @functools.wraps(func) + def sync_wrapper(*args, **kwargs): + start = time.time() + try: + logger.info(f"Entry: {func.__name__}") + result = func(*args, **kwargs) + duration = round((time.time() - start) * 1000, 2) + logger.info(f"Exit: {func.__name__} (took {duration} ms)") + return result + except Exception as e: + logger.exception(f"Exception in {func.__name__}: {e}") + raise + return async_wrapper if inspect.iscoroutinefunction(func) else sync_wrapper + +def _format_args(args, kwargs): + parts = [] + if args: + parts.append(", ".join(map(str, args))) + if kwargs: + parts.append(", ".join(f"{k}={v}" for k, v in kwargs.items())) + return ", ".join(parts) + +__all__ = [ + "logger", + "configure_ibmdbsa_logging", + "init_ibmdbsa_logging", + "log_entry_exit" +] \ No newline at end of file diff --git a/ibm_db_sa/pyodbc.py b/ibm_db_sa/pyodbc.py index 4988dfd..31d6131 100644 --- a/ibm_db_sa/pyodbc.py +++ b/ibm_db_sa/pyodbc.py @@ -16,7 +16,9 @@ # | Authors: Jaimy Azle, Rahul Priyadarshi | # | Contributors: Mike Bayer | # +--------------------------------------------------------------------------+ +import re from sqlalchemy import util +from .logger import init_ibmdbsa_logging, logger, log_entry_exit import urllib from sqlalchemy.connectors.pyodbc import PyODBCConnector from .base import _SelectLastRowIDMixin, DB2ExecutionContext, DB2Dialect @@ -37,54 +39,61 @@ class DB2Dialect_pyodbc(PyODBCConnector, DB2Dialect): pyodbc_driver_name = "IBM DB2 ODBC DRIVER" def create_connect_args(self, url): + url, ibmdbsa_log_value = init_ibmdbsa_logging(url) + logger.info("entry create_connect_args()") + logger.debug("Starting create_connect_args for DB2Dialect_pyodbc") opts = url.translate_connect_args(username='user') opts.update(url.query) - keys = opts query = url.query - connect_args = {} for param in ('ansi', 'unicode_results', 'autocommit'): if param in keys: + logger.debug("Setting connect_arg %s=%s", param, keys[param]) connect_args[param] = util.asbool(keys.pop(param)) - if 'odbc_connect' in keys: + logger.debug("Using provided ODBC connection string") connectors = [urllib.parse.unquote_plus(keys.pop('odbc_connect'))] else: - dsn_connection = 'dsn' in keys or \ - ('host' in keys and 'database' not in keys) + dsn_connection = 
 
 class AS400Dialect_pyodbc(PyODBCConnector, DB2Dialect):
@@ -101,49 +110,60 @@ class AS400Dialect_pyodbc(PyODBCConnector, DB2Dialect):
     _reflector_cls = ibm_reflection.AS400Reflector
 
     def create_connect_args(self, url):
+        url, ibmdbsa_log_value = init_ibmdbsa_logging(url)
+        logger.info("entry create_connect_args()")
+        logger.debug("Starting create_connect_args for AS400Dialect_pyodbc")
         opts = url.translate_connect_args(username='user')
         opts.update(url.query)
-
         keys = opts
         query = url.query
-
         connect_args = {}
         for param in ('ansi', 'unicode_results', 'autocommit'):
-            if param in keys:
-                connect_args[param] = util.asbool(keys.pop(param))
-
+            if param in keys:
+                logger.debug("Setting connect_arg %s=%s", param, keys[param])
+                connect_args[param] = util.asbool(keys.pop(param))
         if 'odbc_connect' in keys:
-            connectors = [urllib.parse.unquote_plus(keys.pop('odbc_connect'))]
+            logger.debug("Using provided ODBC connection string")
+            connectors = [urllib.parse.unquote_plus(keys.pop('odbc_connect'))]
         else:
-            dsn_connection = 'dsn' in keys or \
-                ('host' in keys and 'database' not in keys)
-            if dsn_connection:
-                connectors = ['dsn=%s' % (keys.pop('host', '') or \
-                    keys.pop('dsn', ''))]
-            else:
-                connectors = ["DRIVER={%s}" % keys.pop('driver', self.pyodbc_driver_name),
-                              'System=%s' % keys.pop('host', ''),
-                              'DBQ=QGPL']
-                connectors.append("PKG=QGPL/DEFAULT(IBM),2,0,1,0,512")
-                db_name = keys.pop('database', '')
-                if db_name:
-                    connectors.append("DATABASE=%s" % db_name)
-
-                user = keys.pop("user", None)
-                if 
user: - connectors.append("UID=%s" % user) - connectors.append("PWD=%s" % keys.pop('password', '')) - else: - connectors.append("trusted_connection=yes") - - # if set to 'Yes', the ODBC layer will try to automagically convert - # textual data from your database encoding to your client encoding - # This should obviously be set to 'No' if you query a cp1253 encoded - # database from a latin1 client... - if 'odbc_autotranslate' in keys: - connectors.append("AutoTranslate=%s" % keys.pop("odbc_autotranslate")) - - connectors.extend(['%s=%s' % (k,v) for k,v in keys.items()]) - return [[";".join (connectors)], connect_args] + dsn_connection = 'dsn' in keys or ('host' in keys and 'database' not in keys) + if dsn_connection: + logger.debug("Using DSN based connection") + connectors = ['dsn=%s' % (keys.pop('host', '') or keys.pop('dsn', ''))] + else: + logger.debug("Using IBM i Access ODBC driver") + connectors = [ + "DRIVER={%s}" % keys.pop('driver', self.pyodbc_driver_name), + 'System=%s' % keys.pop('host', ''), + 'DBQ=QGPL' + ] + connectors.append("PKG=QGPL/DEFAULT(IBM),2,0,1,0,512") + db_name = keys.pop('database', '') + if db_name: + logger.debug("Database: %s", db_name) + connectors.append("DATABASE=%s" % db_name) + user = keys.pop("user", None) + if user: + logger.debug("User provided for connection") + connectors.append("UID=%s" % user) + connectors.append("PWD=%s" % keys.pop('password', '')) + else: + logger.debug("Using trusted connection") + connectors.append("trusted_connection=yes") + # if set to 'Yes', the ODBC layer will try to automagically convert + # textual data from your database encoding to your client encoding + # This should obviously be set to 'No' if you query a cp1253 encoded + # database from a latin1 client... + if 'odbc_autotranslate' in keys: + logger.debug("Setting AutoTranslate option") + connectors.append( + "AutoTranslate=%s" % keys.pop("odbc_autotranslate") + ) + connectors.extend(['%s=%s' % (k, v) for k, v in keys.items()]) + conn_str = ";".join(connectors) + safe_conn_str = re.sub(r'PWD=[^;]*', 'PWD=****', conn_str, flags=re.IGNORECASE) + logger.debug("Constructed AS400 ODBC connection string: %s", safe_conn_str) + logger.info("exit create_connect_args()") + return [[";".join(connectors)], connect_args] diff --git a/ibm_db_sa/reflection.py b/ibm_db_sa/reflection.py index 396ab56..a318557 100644 --- a/ibm_db_sa/reflection.py +++ b/ibm_db_sa/reflection.py @@ -22,6 +22,7 @@ from sqlalchemy import Table, MetaData, Column from sqlalchemy.engine import reflection from sqlalchemy import * +from .logger import logger, log_entry_exit import re import codecs from sys import version_info @@ -38,55 +39,120 @@ def process_bind_param(self, value, dialect): class BaseReflector(object): + @log_entry_exit def __init__(self, dialect): self.dialect = dialect self.ischema_names = dialect.ischema_names self.identifier_preparer = dialect.identifier_preparer + logger.debug( + f"BaseReflector initialized -> " + f"dialect={dialect}, " + ) + @log_entry_exit def normalize_name(self, name): - if isinstance(name, str): - name = name - if name is not None: - return name.lower() if name.upper() == name and \ - not self.identifier_preparer._requires_quotes(name.lower()) \ - else name - return name - - def denormalize_name(self, name): - if name is None: - return None - elif name.lower() == name and \ - not self.identifier_preparer._requires_quotes(name.lower()): - name = name.upper() - if not self.dialect.supports_unicode_binds: + try: + original_name = name if isinstance(name, str): name = 
name
+            if name is not None:
+                requires_quotes = self.identifier_preparer._requires_quotes(
+                    name.lower()
+                )
+                result = (
+                    name.lower()
+                    if name.upper() == name and not requires_quotes
+                    else name
+                )
+                logger.debug(
+                    f"normalize_name -> original={original_name}, "
+                    f"requires_quotes={requires_quotes}, "
+                    f"result={result}"
+                )
+                return result
+            logger.debug("normalize_name -> input is None")
+            return name
+        except Exception as e:
+            logger.error(f"Error in normalize_name: {e}")
+            logger.exception("Stack trace in normalize_name")
+            raise
+
+    @log_entry_exit
+    def denormalize_name(self, name):
+        try:
+            original_name = name
+            if name is None:
+                logger.debug("denormalize_name -> input is None")
+                return None
+            lower_name = name.lower()
+            requires_quotes = self.identifier_preparer._requires_quotes(
+                lower_name
+            )
+            if lower_name == name and not requires_quotes:
+                name = name.upper()
+            supports_unicode = self.dialect.supports_unicode_binds
+            if not supports_unicode:
+                if isinstance(name, str):
+                    name = name
+                else:
-                name = codecs.decode(name)
-        else:
-            if version_info[0] < 3:
-                name = unicode(name)
-            else:
-                name = str(name)
-        return name
+                    name = codecs.decode(name)
+            else:
+                if version_info[0] < 3:
+                    name = unicode(name)
+                else:
+                    name = str(name)
+            logger.debug(
+                f"denormalize_name -> original={original_name}, "
+                f"requires_quotes={requires_quotes}, "
+                f"supports_unicode_binds={supports_unicode}, "
+                f"result={name}"
+            )
+            return name
+        except Exception as e:
+            logger.error(f"Error in denormalize_name: {e}")
+            logger.exception("Stack trace in denormalize_name")
+            raise
 
+    @log_entry_exit
     def _get_default_schema_name(self, connection):
         """Return: current setting of the schema attribute"""
-        default_schema_name = connection.execute(
-            u'SELECT CURRENT_SCHEMA FROM SYSIBM.SYSDUMMY1').scalar()
-        if isinstance(default_schema_name, str):
-            default_schema_name = default_schema_name.strip()
-        elif version_info[0] < 3:
-            if isinstance(default_schema_name, unicode):
-                default_schema_name = default_schema_name.strip().__str__()
-        else:
-            if isinstance(default_schema_name, str):
-                default_schema_name = default_schema_name.strip().__str__()
-        return self.normalize_name(default_schema_name)
+        try:
+            logger.debug("Fetching default schema name from database.")
+            default_schema_name = connection.execute(
+                u"SELECT CURRENT_SCHEMA FROM SYSIBM.SYSDUMMY1"
+            ).scalar()
+            logger.debug(
+                f"Raw default schema fetched -> {default_schema_name}"
+            )
+            if isinstance(default_schema_name, str):
+                default_schema_name = default_schema_name.strip()
+            elif version_info[0] < 3:
+                if isinstance(default_schema_name, unicode):
+                    default_schema_name = (
+                        default_schema_name.strip().__str__()
+                    )
+            else:
+                if isinstance(default_schema_name, str):
+                    default_schema_name = (
+                        default_schema_name.strip().__str__()
+                    )
+            normalized = self.normalize_name(default_schema_name)
+            logger.debug(
+                f"Normalized default schema -> {normalized}"
+            )
+            return normalized
+        except Exception as e:
+            logger.error(f"Error fetching default schema name: {e}")
+            logger.exception("Stack trace in _get_default_schema_name")
+            raise
 
     @property
     def default_schema_name(self):
-        return self.dialect.default_schema_name
+        schema_name = self.dialect.default_schema_name
+        logger.debug(
+            f"Accessing default_schema_name property -> {schema_name}"
+        )
+        return schema_name
 
 
 class DB2Reflector(BaseReflector):
@@ -174,330 +240,576 @@ class DB2Reflector(BaseReflector):
             Column("SEQNAME", CoerceUnicode, key="seqname"),
             schema="SYSCAT")
 
+    @log_entry_exit
     def 
has_table(self, connection, table_name, schema=None, **kw): - current_schema = self.denormalize_name( - schema or self.default_schema_name) - if table_name.startswith("'") and table_name.endswith("'"): - table_name = table_name.replace("'", "") - table_name = self.normalize_name(table_name) - else: - table_name = self.denormalize_name(table_name) - if current_schema: - whereclause = sql.and_(self.sys_tables.c.tabschema == current_schema, - self.sys_tables.c.tabname == table_name) - else: - whereclause = self.sys_tables.c.tabname == table_name - s = sql.select(self.sys_tables.c.tabname).where(whereclause) - c = connection.execute(s) - return c.first() is not None - + try: + logger.debug(f"Checking table existence -> schema={schema}, table={table_name}") + current_schema = self.denormalize_name(schema or self.default_schema_name) + original_table_name = table_name + if table_name.startswith("'") and table_name.endswith("'"): + table_name = table_name.replace("'", "") + table_name = self.normalize_name(table_name) + else: + table_name = self.denormalize_name(table_name) + logger.debug( + f"Resolved identifiers -> " + f"schema={current_schema}, " + f"table={table_name}" + ) + if current_schema: + whereclause = sql.and_( + self.sys_tables.c.tabschema == current_schema, + self.sys_tables.c.tabname == table_name + ) + else: + whereclause = self.sys_tables.c.tabname == table_name + s = sql.select(self.sys_tables.c.tabname).where(whereclause) + logger.debug(f"Generated has_table SQL -> {s}") + result = connection.execute(s).first() is not None + logger.debug(f"has_table result -> table={original_table_name}, exists={result}") + return result + except Exception as e: + logger.error(f"Error checking table existence: {e}") + logger.exception("Stack trace in has_table") + raise + + @log_entry_exit def has_sequence(self, connection, sequence_name, schema=None): - current_schema = self.denormalize_name(schema or self.default_schema_name) - sequence_name = self.denormalize_name(sequence_name) - if current_schema: - whereclause = sql.and_(self.sys_sequences.c.seqschema == current_schema, - self.sys_sequences.c.seqname == sequence_name) - else: - whereclause = self.sys_sequences.c.seqname == sequence_name - s = sql.select(self.sys_sequences.c.seqname).where(whereclause) - c = connection.execute(s) - return c.first() is not None + try: + logger.debug(f"Checking sequence existence -> schema={schema}, sequence={sequence_name}") + current_schema = self.denormalize_name(schema or self.default_schema_name) + sequence_name = self.denormalize_name(sequence_name) + logger.debug( + f"Resolved identifiers -> " + f"schema={current_schema}, " + f"sequence={sequence_name}" + ) + if current_schema: + whereclause = sql.and_( + self.sys_sequences.c.seqschema == current_schema, + self.sys_sequences.c.seqname == sequence_name + ) + else: + whereclause = self.sys_sequences.c.seqname == sequence_name + s = sql.select(self.sys_sequences.c.seqname).where(whereclause) + logger.debug(f"Generated has_sequence SQL -> {s}") + result = connection.execute(s).first() is not None + logger.debug(f"has_sequence result -> sequence={sequence_name}, exists={result}") + return result + except Exception as e: + logger.error(f"Error checking sequence existence: {e}") + logger.exception("Stack trace in has_sequence") + raise @reflection.cache + @log_entry_exit def get_sequence_names(self, connection, schema=None, **kw): - current_schema = self.denormalize_name(schema or self.default_schema_name) - sys_sequence = self.sys_sequences - query = 
sql.select(sys_sequence.c.seqname).\ - where(sys_sequence.c.seqschema == current_schema).\ - order_by(sys_sequence.c.seqschema, sys_sequence.c.seqname) - return [self.normalize_name(r[0]) for r in connection.execute(query)] + try: + current_schema = self.denormalize_name(schema or self.default_schema_name) + logger.debug(f"Fetching sequence names -> schema={current_schema}") + sys_sequence = self.sys_sequences + query = ( + sql.select(sys_sequence.c.seqname) + .where(sys_sequence.c.seqschema == current_schema) + .order_by( + sys_sequence.c.seqschema, + sys_sequence.c.seqname + ) + ) + logger.debug(f"Generated get_sequence_names SQL -> {query}") + result = [self.normalize_name(r[0]) for r in connection.execute(query)] + logger.debug(f"Reflected sequences -> count={len(result)}, sequences={result}") + return result + except Exception as e: + logger.error(f"Error fetching sequence names: {e}") + logger.exception("Stack trace in get_sequence_names") + raise @reflection.cache + @log_entry_exit def get_schema_names(self, connection, **kw): - sysschema = self.sys_schemas - query = sql.select(sysschema.c.schemaname).\ - where(not_(sysschema.c.schemaname.like('SYS%'))).\ - order_by(sysschema.c.schemaname) - return [self.normalize_name(r[0].rstrip()) for r in connection.execute(query)] + try: + logger.debug("Fetching schema names.") + sysschema = self.sys_schemas + query = ( + sql.select(sysschema.c.schemaname) + .where(not_(sysschema.c.schemaname.like('SYS%'))) + .order_by(sysschema.c.schemaname) + ) + logger.debug(f"Generated get_schema_names SQL -> {query}") + result = [self.normalize_name(r[0].rstrip()) for r in connection.execute(query)] + logger.debug(f"Reflected schemas -> count={len(result)}, schemas={result}") + return result + except Exception as e: + logger.error(f"Error fetching schema names: {e}") + logger.exception("Stack trace in get_schema_names") + raise @reflection.cache + @log_entry_exit def get_table_names(self, connection, schema=None, **kw): - current_schema = self.denormalize_name(schema or self.default_schema_name) - systbl = self.sys_tables - query = sql.select(systbl.c.tabname).\ - where(systbl.c.type == 'T').\ - where(systbl.c.tabschema == current_schema).\ - order_by(systbl.c.tabname) - return [self.normalize_name(r[0]) for r in connection.execute(query)] + try: + current_schema = self.denormalize_name(schema or self.default_schema_name) + logger.debug(f"Fetching table names -> schema={current_schema}") + systbl = self.sys_tables + query = ( + sql.select(systbl.c.tabname) + .where(systbl.c.type == 'T') + .where(systbl.c.tabschema == current_schema) + .order_by(systbl.c.tabname) + ) + logger.debug(f"Generated get_table_names SQL -> {query}") + result = [self.normalize_name(r[0]) for r in connection.execute(query)] + logger.debug(f"Reflected tables -> count={len(result)}, tables={result}") + return result + except Exception as e: + logger.error(f"Error fetching table names: {e}") + logger.exception("Stack trace in get_table_names") + raise @reflection.cache + @log_entry_exit def get_table_comment(self, connection, table_name, schema=None, **kw): - current_schema = self.denormalize_name(schema or self.default_schema_name) - table_name = self.denormalize_name(table_name) - systbl = self.sys_tables - query = sql.select(systbl.c.remarks).\ - where(systbl.c.tabschema == current_schema).\ - where(systbl.c.tabname == table_name) - return {'text': connection.execute(query).scalar()} + try: + current_schema = self.denormalize_name(schema or self.default_schema_name) + table_name = 
self.denormalize_name(table_name) + logger.debug(f"Fetching table comment -> schema={current_schema}, table={table_name}") + systbl = self.sys_tables + query = ( + sql.select(systbl.c.remarks) + .where(systbl.c.tabschema == current_schema) + .where(systbl.c.tabname == table_name) + ) + logger.debug(f"Generated get_table_comment SQL -> {query}") + comment = connection.execute(query).scalar() + logger.debug(f"Table comment result -> {comment}") + return {'text': comment} + except Exception as e: + logger.error(f"Error fetching table comment: {e}") + logger.exception("Stack trace in get_table_comment") + raise @reflection.cache + @log_entry_exit def get_view_names(self, connection, schema=None, **kw): - current_schema = self.denormalize_name(schema or self.default_schema_name) - - query = sql.select(self.sys_views.c.viewname).\ - where(self.sys_views.c.viewschema == current_schema).\ - order_by(self.sys_views.c.viewname) - - return [self.normalize_name(r[0]) for r in connection.execute(query)] + try: + current_schema = self.denormalize_name(schema or self.default_schema_name) + logger.debug(f"Fetching view names -> schema={current_schema}") + query = ( + sql.select(self.sys_views.c.viewname) + .where(self.sys_views.c.viewschema == current_schema) + .order_by(self.sys_views.c.viewname) + ) + logger.debug(f"Generated get_view_names SQL -> {query}") + result = [self.normalize_name(r[0]) for r in connection.execute(query)] + logger.debug(f"Reflected views -> count={len(result)}, views={result}") + return result + except Exception as e: + logger.error(f"Error fetching view names: {e}") + logger.exception("Stack trace in get_view_names") + raise @reflection.cache + @log_entry_exit def get_view_definition(self, connection, viewname, schema=None, **kw): - current_schema = self.denormalize_name(schema or self.default_schema_name) - viewname = self.denormalize_name(viewname) - - query = sql.select(self.sys_views.c.text).\ - where(self.sys_views.c.viewschema == current_schema).\ - where(self.sys_views.c.viewname == viewname) - - return connection.execute(query).scalar() + try: + current_schema = self.denormalize_name(schema or self.default_schema_name) + viewname = self.denormalize_name(viewname) + logger.debug(f"Fetching view definition -> schema={current_schema}, view={viewname}") + query = ( + sql.select(self.sys_views.c.text) + .where(self.sys_views.c.viewschema == current_schema) + .where(self.sys_views.c.viewname == viewname) + ) + logger.debug(f"Generated get_view_definition SQL -> {query}") + definition = connection.execute(query).scalar() + logger.debug(f"View definition length -> {len(definition) if definition else 0}") + return definition + except Exception as e: + logger.error(f"Error fetching view definition: {e}") + logger.exception("Stack trace in get_view_definition") + raise @reflection.cache + @log_entry_exit def get_columns(self, connection, table_name, schema=None, **kw): - current_schema = self.denormalize_name(schema or self.default_schema_name) - table_name = self.denormalize_name(table_name) - syscols = self.sys_columns - - query = sql.select(syscols.c.colname, syscols.c.typename, - syscols.c.defaultval, syscols.c.nullable, - syscols.c.length, syscols.c.scale, - syscols.c.identity, syscols.c.generated, - syscols.c.remarks).\ - where(and_( - syscols.c.tabschema == current_schema, - syscols.c.tabname == table_name)).\ - order_by(syscols.c.colno) - sa_columns = [] - for r in connection.execute(query): - coltype = r[1].upper() - if coltype in ['DECIMAL', 'NUMERIC']: - coltype = 
self.ischema_names.get(coltype)(int(r[4]), int(r[5])) - elif coltype in ['CHARACTER', 'CHAR', 'VARCHAR', - 'GRAPHIC', 'VARGRAPHIC']: - coltype = self.ischema_names.get(coltype)(int(r[4])) - else: - try: - coltype = self.ischema_names[coltype] - except KeyError: - util.warn("Did not recognize type '%s' of column '%s'" % - (coltype, r[0])) - coltype = coltype = sa_types.NULLTYPE - - sa_columns.append({ + try: + current_schema = self.denormalize_name(schema or self.default_schema_name) + table_name = self.denormalize_name(table_name) + logger.debug(f"Fetching columns -> schema={current_schema}, table={table_name}") + syscols = self.sys_columns + query = ( + sql.select( + syscols.c.colname, syscols.c.typename, + syscols.c.defaultval, syscols.c.nullable, + syscols.c.length, syscols.c.scale, + syscols.c.identity, syscols.c.generated, + syscols.c.remarks + ) + .where(and_( + syscols.c.tabschema == current_schema, + syscols.c.tabname == table_name + )) + .order_by(syscols.c.colno) + ) + logger.debug(f"Generated get_columns SQL -> {query}") + sa_columns = [] + for r in connection.execute(query): + raw_type = r[1].upper() + logger.debug( + f"Processing column -> " + f"name={r[0]}, type={raw_type}, " + f"length={r[4]}, scale={r[5]}" + ) + if raw_type in ['DECIMAL', 'NUMERIC']: + coltype = self.ischema_names.get(raw_type)(int(r[4]), int(r[5])) + elif raw_type in ['CHARACTER', 'CHAR', 'VARCHAR', + 'GRAPHIC', 'VARGRAPHIC']: + coltype = self.ischema_names.get(raw_type)(int(r[4])) + else: + try: + coltype = self.ischema_names[raw_type] + except KeyError: + logger.warning( + f"Unrecognized column type '{raw_type}' " + f"for column '{r[0]}'" + ) + coltype = sa_types.NULLTYPE + column_info = { 'name': self.normalize_name(r[0]), 'type': coltype, 'nullable': r[3] == 'Y', 'default': r[2] or None, 'autoincrement': (r[6] == 'Y') and (r[7] != ' '), 'comment': r[8] or None, - }) - return sa_columns + } + logger.debug(f"Column reflected -> {column_info}") + sa_columns.append(column_info) + logger.debug(f"Total columns reflected -> count={len(sa_columns)}") + return sa_columns + except Exception as e: + logger.error(f"Error reflecting columns: {e}") + logger.exception("Stack trace in get_columns") + raise @reflection.cache + @log_entry_exit def get_pk_constraint(self, connection, table_name, schema=None, **kw): - current_schema = self.denormalize_name(schema or self.default_schema_name) - table_name = self.denormalize_name(table_name) - sysindexes = self.sys_indexes - col_finder = re.compile(r"(\w+)") - query = sql.select(sysindexes.c.colnames, sysindexes.c.indname).\ - where(and_(sysindexes.c.tabschema == current_schema, - sysindexes.c.tabname == table_name, - sysindexes.c.uniquerule == 'P')).\ - order_by(sysindexes.c.tabschema, sysindexes.c.tabname) - pk_columns = [] - pk_name = None - for r in connection.execute(query): - cols = col_finder.findall(r[0]) - pk_columns.extend(cols) - if not pk_name: - pk_name = self.normalize_name(r[1]) - - return {"constrained_columns": [self.normalize_name(col) for col in pk_columns], - "name": pk_name} + try: + current_schema = self.denormalize_name(schema or self.default_schema_name) + table_name = self.denormalize_name(table_name) + logger.debug(f"Fetching primary key -> schema={current_schema}, table={table_name}") + sysindexes = self.sys_indexes + col_finder = re.compile(r"(\w+)") + query = ( + sql.select(sysindexes.c.colnames, sysindexes.c.indname) + .where(and_( + sysindexes.c.tabschema == current_schema, + sysindexes.c.tabname == table_name, + sysindexes.c.uniquerule == 'P' 
+ )) + .order_by( + sysindexes.c.tabschema, + sysindexes.c.tabname + )) + logger.debug(f"Generated get_pk_constraint SQL -> {query}") + pk_columns = [] + pk_name = None + for r in connection.execute(query): + cols = col_finder.findall(r[0]) + pk_columns.extend(cols) + if not pk_name: + pk_name = self.normalize_name(r[1]) + normalized_columns = [self.normalize_name(col) for col in pk_columns] + logger.debug( + f"Primary key reflected -> " + f"name={pk_name}, columns={normalized_columns}" + ) + return { + "constrained_columns": normalized_columns, + "name": pk_name + } + except Exception as e: + logger.error(f"Error reflecting primary key: {e}") + logger.exception("Stack trace in get_pk_constraint") + raise @reflection.cache + @log_entry_exit def get_primary_keys(self, connection, table_name, schema=None, **kw): - current_schema = self.denormalize_name(schema or self.default_schema_name) - table_name = self.denormalize_name(table_name) - syscols = self.sys_columns - col_finder = re.compile(r"(\w+)") - query = sql.select(syscols.c.colname).\ - where(and_( - syscols.c.tabschema == current_schema, - syscols.c.tabname == table_name, - syscols.c.keyseq > 0 - )).\ - order_by(syscols.c.tabschema, syscols.c.tabname) - pk_columns = [] - for r in connection.execute(query): - cols = col_finder.findall(r[0]) - pk_columns.extend(cols) - return [self.normalize_name(col) for col in pk_columns] + try: + current_schema = self.denormalize_name(schema or self.default_schema_name) + table_name = self.denormalize_name(table_name) + logger.debug(f"Fetching primary keys -> schema={current_schema}, table={table_name}") + syscols = self.sys_columns + col_finder = re.compile(r"(\w+)") + query = ( + sql.select(syscols.c.colname) + .where(and_( + syscols.c.tabschema == current_schema, + syscols.c.tabname == table_name, + syscols.c.keyseq > 0 + )) + .order_by(syscols.c.tabschema, syscols.c.tabname) + ) + logger.debug(f"Generated get_primary_keys SQL -> {query}") + pk_columns = [] + for r in connection.execute(query): + cols = col_finder.findall(r[0]) + pk_columns.extend(cols) + normalized_columns = [self.normalize_name(col) for col in pk_columns] + logger.debug(f"Primary keys reflected -> columns={normalized_columns}") + return normalized_columns + except Exception as e: + logger.error(f"Error reflecting primary keys: {e}") + logger.exception("Stack trace in get_primary_keys") + raise @reflection.cache + @log_entry_exit def get_foreign_keys(self, connection, table_name, schema=None, **kw): - default_schema = self.default_schema_name - current_schema = self.denormalize_name(schema or default_schema) - default_schema = self.normalize_name(default_schema) - table_name = self.denormalize_name(table_name) - sysfkeys = self.sys_foreignkeys - systbl = self.sys_tables - query = sql.select(sysfkeys.c.fkname, sysfkeys.c.fktabschema, - sysfkeys.c.fktabname, sysfkeys.c.fkcolname, - sysfkeys.c.pkname, sysfkeys.c.pktabschema, - sysfkeys.c.pktabname, sysfkeys.c.pkcolname).\ - select_from( - join(systbl, - sysfkeys, - sql.and_( - systbl.c.tabname == sysfkeys.c.pktabname, - systbl.c.tabschema == sysfkeys.c.pktabschema - ) - ) - ).where(systbl.c.type == 'T').\ - where(systbl.c.tabschema == current_schema).\ - where(sysfkeys.c.fktabname == table_name).\ - order_by(systbl.c.tabname) - - fschema = {} - for r in connection.execute(query): - if not (r[0]) in fschema: - referred_schema = self.normalize_name(r[5]) - - # if no schema specified and referred schema here is the - # default, then set to None - if schema is None and \ - 
referred_schema == default_schema: - referred_schema = None - - fschema[r[0]] = { - 'name': self.normalize_name(r[0]), - 'constrained_columns': [self.normalize_name(r[3])], - 'referred_schema': referred_schema, - 'referred_table': self.normalize_name(r[6]), - 'referred_columns': [self.normalize_name(r[7])]} - else: - fschema[r[0]]['constrained_columns'].append(self.normalize_name(r[3])) - fschema[r[0]]['referred_columns'].append(self.normalize_name(r[7])) - return [value for key, value in fschema.items()] + try: + default_schema = self.default_schema_name + current_schema = self.denormalize_name(schema or default_schema) + normalized_default_schema = self.normalize_name(default_schema) + table_name = self.denormalize_name(table_name) + logger.debug( + f"Fetching foreign keys -> " + f"schema={current_schema}, table={table_name}" + ) + sysfkeys = self.sys_foreignkeys + systbl = self.sys_tables + query = ( + sql.select( + sysfkeys.c.fkname, sysfkeys.c.fktabschema, + sysfkeys.c.fktabname, sysfkeys.c.fkcolname, + sysfkeys.c.pkname, sysfkeys.c.pktabschema, + sysfkeys.c.pktabname, sysfkeys.c.pkcolname + ) + .select_from( + join( + systbl, + sysfkeys, + sql.and_( + systbl.c.tabname == sysfkeys.c.pktabname, + systbl.c.tabschema == sysfkeys.c.pktabschema + ) + ) + ) + .where(systbl.c.type == 'T') + .where(systbl.c.tabschema == current_schema) + .where(sysfkeys.c.fktabname == table_name) + .order_by(systbl.c.tabname) + ) + logger.debug(f"Generated get_foreign_keys SQL -> {query}") + fschema = {} + for r in connection.execute(query): + fk_name = r[0] + if fk_name not in fschema: + referred_schema = self.normalize_name(r[5]) + # if no schema specified and referred schema here is the + # default, then set to None + if schema is None and \ + referred_schema == normalized_default_schema: + referred_schema = None + fschema[fk_name] = { + 'name': self.normalize_name(fk_name), + 'constrained_columns': [self.normalize_name(r[3])], + 'referred_schema': referred_schema, + 'referred_table': self.normalize_name(r[6]), + 'referred_columns': [self.normalize_name(r[7])] + } + logger.debug(f"Foreign key discovered -> {fschema[fk_name]}") + else: + fschema[fk_name]['constrained_columns'].append(self.normalize_name(r[3])) + fschema[fk_name]['referred_columns'].append(self.normalize_name(r[7])) + result = [value for value in fschema.values()] + logger.debug(f"Total foreign keys reflected -> count={len(result)}") + return result + except Exception as e: + logger.error(f"Error reflecting foreign keys: {e}") + logger.exception("Stack trace in get_foreign_keys") + raise @reflection.cache + @log_entry_exit def get_incoming_foreign_keys(self, connection, table_name, schema=None, **kw): - default_schema = self.default_schema_name - current_schema = self.denormalize_name(schema or default_schema) - default_schema = self.normalize_name(default_schema) - table_name = self.denormalize_name(table_name) - sysfkeys = self.sys_foreignkeys - query = sql.select(sysfkeys.c.fkname, sysfkeys.c.fktabschema, - sysfkeys.c.fktabname, sysfkeys.c.fkcolname, - sysfkeys.c.pkname, sysfkeys.c.pktabschema, - sysfkeys.c.pktabname, sysfkeys.c.pkcolname).\ - where(and_( - sysfkeys.c.pktabschema == current_schema, - sysfkeys.c.pktabname == table_name - )).\ - order_by(sysfkeys.c.colno) - - fschema = {} - for r in connection.execute(query): - if not fschema.has_key(r[0]): - constrained_schema = self.normalize_name(r[1]) - - # if no schema specified and referred schema here is the - # default, then set to None - if schema is None and \ - constrained_schema 
== default_schema: - constrained_schema = None - - fschema[r[0]] = { - 'name': self.normalize_name(r[0]), - 'constrained_schema': constrained_schema, - 'constrained_table': self.normalize_name(r[2]), - 'constrained_columns': [self.normalize_name(r[3])], - 'referred_schema': schema, - 'referred_table': self.normalize_name(r[6]), - 'referred_columns': [self.normalize_name(r[7])]} - else: - fschema[r[0]]['constrained_columns'].append(self.normalize_name(r[3])) - fschema[r[0]]['referred_columns'].append(self.normalize_name(r[7])) - return [value for key, value in fschema.items()] + try: + default_schema = self.default_schema_name + current_schema = self.denormalize_name(schema or default_schema) + normalized_default_schema = self.normalize_name(default_schema) + table_name = self.denormalize_name(table_name) + logger.debug( + f"Fetching incoming foreign keys -> " + f"schema={current_schema}, table={table_name}" + ) + sysfkeys = self.sys_foreignkeys + query = ( + sql.select( + sysfkeys.c.fkname, sysfkeys.c.fktabschema, + sysfkeys.c.fktabname, sysfkeys.c.fkcolname, + sysfkeys.c.pkname, sysfkeys.c.pktabschema, + sysfkeys.c.pktabname, sysfkeys.c.pkcolname + ) + .where(and_( + sysfkeys.c.pktabschema == current_schema, + sysfkeys.c.pktabname == table_name + )) + .order_by(sysfkeys.c.colno) + ) + logger.debug(f"Generated get_incoming_foreign_keys SQL -> {query}") + fschema = {} + for r in connection.execute(query): + fk_name = r[0] + if fk_name not in fschema: + constrained_schema = self.normalize_name(r[1]) + # if no schema specified and referred schema here is the + # default, then set to None + if schema is None and \ + constrained_schema == normalized_default_schema: + constrained_schema = None + fschema[fk_name] = { + 'name': self.normalize_name(fk_name), + 'constrained_schema': constrained_schema, + 'constrained_table': self.normalize_name(r[2]), + 'constrained_columns': [self.normalize_name(r[3])], + 'referred_schema': schema, + 'referred_table': self.normalize_name(r[6]), + 'referred_columns': [self.normalize_name(r[7])] + } + logger.debug(f"Incoming foreign key discovered -> {fschema[fk_name]}") + else: + fschema[fk_name]['constrained_columns'].append(self.normalize_name(r[3])) + fschema[fk_name]['referred_columns'].append(self.normalize_name(r[7])) + result = [value for value in fschema.values()] + logger.debug(f"Total incoming foreign keys reflected -> count={len(result)}") + return result + except Exception as e: + logger.error(f"Error reflecting incoming foreign keys: {e}") + logger.exception("Stack trace in get_incoming_foreign_keys") + raise @reflection.cache + @log_entry_exit def get_indexes(self, connection, table_name, schema=None, **kw): - current_schema = self.denormalize_name(schema or self.default_schema_name) - table_name = self.denormalize_name(table_name) - sysidx = self.sys_indexes - query = sql.select(sysidx.c.indname, sysidx.c.colnames, - sysidx.c.uniquerule, sysidx.c.system_required).\ - where(and_(sysidx.c.tabschema == current_schema,sysidx.c.tabname == table_name)).\ - order_by(sysidx.c.tabname) - indexes = [] - col_finder = re.compile(r"(\w+)") - for r in connection.execute(query): - if r[2] != 'P': - if r[2] == 'U' and r[3] != 0: + try: + current_schema = self.denormalize_name(schema or self.default_schema_name) + table_name = self.denormalize_name(table_name) + logger.debug(f"Fetching indexes -> schema={current_schema}, table={table_name}") + sysidx = self.sys_indexes + query = ( + sql.select(sysidx.c.indname, sysidx.c.colnames, + sysidx.c.uniquerule, 
sysidx.c.system_required + ) + .where(and_( + sysidx.c.tabschema == current_schema, + sysidx.c.tabname == table_name + )) + .order_by(sysidx.c.tabname) + ) + logger.debug(f"Generated get_indexes SQL -> {query}") + indexes = [] + col_finder = re.compile(r"(\w+)") + for r in connection.execute(query): + index_name = r[0] + column_text = r[1] + unique_rule = r[2] + system_required = r[3] + logger.debug( + f"Processing index row -> " + f"name={index_name}, unique_rule={unique_rule}, " + f"system_required={system_required}" + ) + if unique_rule == 'P': + logger.debug(f"Skipping primary key index -> {index_name}") continue - if 'sqlnotapplicable' in r[1].lower(): + if unique_rule == 'U' and system_required != 0: + logger.debug(f"Skipping system-required unique index -> {index_name}") continue - indexes.append({ - 'name': self.normalize_name(r[0]), - 'column_names': [self.normalize_name(col) - for col in col_finder.findall(r[1])], - 'unique': r[2] == 'U' - }) - return indexes + if 'sqlnotapplicable' in column_text.lower(): + logger.debug(f"Skipping internal index -> {index_name}") + continue + normalized_columns = [self.normalize_name(col) for col in col_finder.findall(column_text)] + index_info = { + 'name': self.normalize_name(index_name), + 'column_names': normalized_columns, + 'unique': unique_rule == 'U' + } + logger.debug(f"Index reflected -> {index_info}") + indexes.append(index_info) + logger.debug(f"Total indexes reflected -> count={len(indexes)}") + return indexes + except Exception as e: + logger.error(f"Error reflecting indexes: {e}") + logger.exception("Stack trace in get_indexes") + raise @reflection.cache + @log_entry_exit def get_unique_constraints(self, connection, table_name, schema=None, **kw): - current_schema = self.denormalize_name(schema or self.default_schema_name) - table_name = self.denormalize_name(table_name) - syskeycol = self.sys_keycoluse - sysconst = self.sys_tabconst - query = ( - sql.select(syskeycol.c.constname, syskeycol.c.colname) - .select_from( - join( - syskeycol, - sysconst, - and_( - syskeycol.c.constname == sysconst.c.constname, - syskeycol.c.tabschema == sysconst.c.tabschema, - syskeycol.c.tabname == sysconst.c.tabname, - ), - ) + try: + current_schema = self.denormalize_name(schema or self.default_schema_name) + table_name = self.denormalize_name(table_name) + logger.debug( + f"Fetching unique constraints -> " + f"schema={current_schema}, table={table_name}" ) - .where( - and_( - sysconst.c.tabname == table_name, - sysconst.c.tabschema == current_schema, - sysconst.c.type == "U", + syskeycol = self.sys_keycoluse + sysconst = self.sys_tabconst + query = ( + sql.select( + syskeycol.c.constname, + syskeycol.c.colname ) + .select_from( + join( + syskeycol, + sysconst, + and_( + syskeycol.c.constname == sysconst.c.constname, + syskeycol.c.tabschema == sysconst.c.tabschema, + syskeycol.c.tabname == sysconst.c.tabname, + ), + ) + ) + .where( + and_( + sysconst.c.tabname == table_name, + sysconst.c.tabschema == current_schema, + sysconst.c.type == "U", + ) + ) + .order_by(syskeycol.c.constname) ) - .order_by(syskeycol.c.constname) - ) - uniqueConsts = [] - currConst = None - for r in connection.execute(query): - if currConst == r[0]: - uniqueConsts[-1]["column_names"].append(self.normalize_name(r[1])) - else: - currConst = r[0] - uniqueConsts.append( - { + logger.debug(f"Generated get_unique_constraints SQL -> {query}") + uniqueConsts = [] + currConst = None + for r in connection.execute(query): + constraint_name = r[0] + column_name = 
self.normalize_name(r[1]) + if currConst == constraint_name: + uniqueConsts[-1]["column_names"].append(column_name) + logger.debug( + f"Appending column to constraint -> " + f"name={constraint_name}, column={column_name}" + ) + else: + currConst = constraint_name + constraint_info = { "name": self.normalize_name(currConst), - "column_names": [self.normalize_name(r[1])], + "column_names": [column_name], } - ) - return uniqueConsts + logger.debug(f"New unique constraint discovered -> {constraint_info}") + uniqueConsts.append(constraint_info) + logger.debug( + f"Total unique constraints reflected -> " + f"count={len(uniqueConsts)}" + ) + return uniqueConsts + except Exception as e: + logger.error(f"Error reflecting unique constraints: {e}") + logger.exception("Stack trace in get_unique_constraints") + raise class AS400Reflector(BaseReflector): @@ -588,288 +900,518 @@ class AS400Reflector(BaseReflector): Column("SEQUENCE_NAME", CoerceUnicode, key="seqname"), schema="QSYS2") + @log_entry_exit def has_table(self, connection, table_name, schema=None, **kw): - current_schema = self.denormalize_name( - schema or self.default_schema_name) - table_name = self.denormalize_name(table_name) - if current_schema: + try: + current_schema = self.denormalize_name(schema or self.default_schema_name) + table_name = self.denormalize_name(table_name) + logger.debug( + f"[AS400] Checking table existence -> " + f"schema={current_schema}, table={table_name}" + ) + if current_schema: whereclause = sql.and_( - self.sys_tables.c.tabschema == current_schema, - self.sys_tables.c.tabname == table_name) - else: + self.sys_tables.c.tabschema == current_schema, + self.sys_tables.c.tabname == table_name + ) + else: whereclause = self.sys_tables.c.tabname == table_name - s = sql.select(self.sys_tables).where(whereclause) - c = connection.execute(s) - return c.first() is not None - + s = sql.select(self.sys_tables).where(whereclause) + logger.debug(f"[AS400] Generated has_table SQL -> {s}") + result = connection.execute(s).first() is not None + logger.debug(f"[AS400] has_table result -> exists={result}") + return result + except Exception as e: + logger.error(f"[AS400] Error in has_table: {e}") + logger.exception("Stack trace in AS400 has_table") + raise + + @log_entry_exit def has_sequence(self, connection, sequence_name, schema=None): - current_schema = self.denormalize_name( - schema or self.default_schema_name) - sequence_name = self.denormalize_name(sequence_name) - if current_schema: + try: + current_schema = self.denormalize_name(schema or self.default_schema_name) + sequence_name = self.denormalize_name(sequence_name) + logger.debug( + f"[AS400] Checking sequence existence -> " + f"schema={current_schema}, sequence={sequence_name}" + ) + if current_schema: whereclause = sql.and_( - self.sys_sequences.c.seqschema == current_schema, - self.sys_sequences.c.seqname == sequence_name) - else: + self.sys_sequences.c.seqschema == current_schema, + self.sys_sequences.c.seqname == sequence_name + ) + else: whereclause = self.sys_sequences.c.seqname == sequence_name - s = sql.select(self.sys_sequences.c.seqname).where(whereclause) - c = connection.execute(s) - return c.first() is not None - + s = sql.select(self.sys_sequences.c.seqname).where(whereclause) + logger.debug(f"[AS400] Generated has_sequence SQL -> {s}") + result = connection.execute(s).first() is not None + logger.debug(f"[AS400] has_sequence result -> exists={result}") + return result + except Exception as e: + logger.error(f"[AS400] Error in has_sequence: {e}") + 
logger.exception("Stack trace in AS400 has_sequence") + raise + + @log_entry_exit def get_table_comment(self, connection, table_name, schema=None, **kw): - current_schema = self.denormalize_name(schema or self.default_schema_name) - table_name = self.denormalize_name(table_name) - systbl = self.sys_tables - query = sql.select(systbl.c.remarks).\ - where(systbl.c.tabschema == current_schema).\ - where(systbl.c.tabname == table_name) - return {'text': connection.execute(query).scalar()} + try: + current_schema = self.denormalize_name(schema or self.default_schema_name) + table_name = self.denormalize_name(table_name) + logger.debug( + f"[AS400] Fetching table comment -> " + f"schema={current_schema}, table={table_name}" + ) + systbl = self.sys_tables + query = ( + sql.select(systbl.c.remarks) + .where(systbl.c.tabschema == current_schema) + .where(systbl.c.tabname == table_name) + ) + logger.debug(f"[AS400] Generated get_table_comment SQL -> {query}") + comment = connection.execute(query).scalar() + logger.debug(f"[AS400] Table comment result -> {comment}") + return {'text': comment} + except Exception as e: + logger.error(f"[AS400] Error in get_table_comment: {e}") + logger.exception("Stack trace in AS400 get_table_comment") + raise @reflection.cache + @log_entry_exit def get_sequence_names(self, connection, schema=None, **kw): - current_schema = self.denormalize_name(schema or self.default_schema_name) - sys_sequence = self.sys_sequences - query = sql.select(sys_sequence.c.seqname).\ - where(sys_sequence.c.seqschema == current_schema).\ - order_by(sys_sequence.c.seqschema, sys_sequence.c.seqname) - return [self.normalize_name(r[0]) for r in connection.execute(query)] + try: + current_schema = self.denormalize_name(schema or self.default_schema_name) + logger.debug( + f"[AS400] Fetching sequence names -> " + f"schema={current_schema}" + ) + sys_sequence = self.sys_sequences + query = ( + sql.select(sys_sequence.c.seqname) + .where(sys_sequence.c.seqschema == current_schema) + .order_by(sys_sequence.c.seqschema, sys_sequence.c.seqname) + ) + logger.debug(f"[AS400] Generated get_sequence_names SQL -> {query}") + result = [self.normalize_name(r[0]) for r in connection.execute(query)] + logger.debug( + f"[AS400] Reflected sequences -> count={len(result)}, " + f"sequences={result}" + ) + return result + except Exception as e: + logger.error(f"[AS400] Error in get_sequence_names: {e}") + logger.exception("Stack trace in AS400 get_sequence_names") + raise @reflection.cache + @log_entry_exit def get_schema_names(self, connection, **kw): - sysschema = self.sys_schemas - if version_info[0] < 3: - query = sql.select(sysschema.c.schemaname). \ - where(~sysschema.c.schemaname.like(unicode('Q%'))). \ - where(~sysschema.c.schemaname.like(unicode('SYS%'))). \ - order_by(sysschema.c.schemaname) - else: - query = sql.select(sysschema.c.schemaname). \ - where(~sysschema.c.schemaname.like(str('Q%'))). \ - where(~sysschema.c.schemaname.like(str('SYS%'))). 
\ - order_by(sysschema.c.schemaname) - return [self.normalize_name(r[0].rstrip()) for r in connection.execute(query)] + try: + logger.debug("[AS400] Fetching schema names") + sysschema = self.sys_schemas + if version_info[0] < 3: + logger.debug("[AS400] Using unicode branch for schema filtering") + query = ( + sql.select(sysschema.c.schemaname) + .where(~sysschema.c.schemaname.like(unicode('Q%'))) + .where(~sysschema.c.schemaname.like(unicode('SYS%'))) + .order_by(sysschema.c.schemaname) + ) + else: + logger.debug("[AS400] Using str branch for schema filtering") + query = ( + sql.select(sysschema.c.schemaname) + .where(~sysschema.c.schemaname.like(str('Q%'))) + .where(~sysschema.c.schemaname.like(str('SYS%'))) + .order_by(sysschema.c.schemaname) + ) + logger.debug(f"[AS400] Generated get_schema_names SQL -> {query}") + result = [ + self.normalize_name(r[0].rstrip()) + for r in connection.execute(query) + ] + logger.debug(f"[AS400] Reflected schemas -> count={len(result)}, schemas={result}") + return result + except Exception as e: + logger.error(f"[AS400] Error in get_schema_names: {e}") + logger.exception("Stack trace in AS400 get_schema_names") + raise # Retrieves a list of table names for a given schema @reflection.cache + @log_entry_exit def get_table_names(self, connection, schema=None, **kw): - current_schema = self.denormalize_name(schema or self.default_schema_name) - systbl = self.sys_tables - if version_info[0] < 3: - query = sql.select(systbl.c.tabname). \ - where(systbl.c.tabtype == unicode('T')). \ - where(systbl.c.tabschema == current_schema). \ - order_by(systbl.c.tabname) - else: - query = sql.select(systbl.c.tabname). \ - where(systbl.c.tabtype == str('T')). \ - where(systbl.c.tabschema == current_schema). \ - order_by(systbl.c.tabname) - return [self.normalize_name(r[0]) for r in connection.execute(query)] + try: + current_schema = self.denormalize_name(schema or self.default_schema_name) + logger.debug(f"[AS400] Fetching table names -> schema={current_schema}") + systbl = self.sys_tables + if version_info[0] < 3: + logger.debug("[AS400] Using unicode branch for table type filter") + query = ( + sql.select(systbl.c.tabname) + .where(systbl.c.tabtype == unicode('T')) + .where(systbl.c.tabschema == current_schema) + .order_by(systbl.c.tabname) + ) + else: + logger.debug("[AS400] Using str branch for table type filter") + query = ( + sql.select(systbl.c.tabname) + .where(systbl.c.tabtype == str('T')) + .where(systbl.c.tabschema == current_schema) + .order_by(systbl.c.tabname) + ) + logger.debug(f"[AS400] Generated get_table_names SQL -> {query}") + result = [ + self.normalize_name(r[0]) + for r in connection.execute(query) + ] + logger.debug(f"[AS400] Reflected tables -> count={len(result)}, tables={result}") + return result + except Exception as e: + logger.error(f"[AS400] Error in get_table_names: {e}") + logger.exception("Stack trace in AS400 get_table_names") + raise @reflection.cache + @log_entry_exit def get_view_names(self, connection, schema=None, **kw): - current_schema = self.denormalize_name( - schema or self.default_schema_name) - - query = sql.select(self.sys_views.c.viewname).\ - where(self.sys_views.c.viewschema == current_schema).\ - order_by(self.sys_views.c.viewname) - return [self.normalize_name(r[0]) for r in connection.execute(query)] + try: + current_schema = self.denormalize_name(schema or self.default_schema_name) + logger.debug(f"[AS400] Fetching view names -> schema={current_schema}") + query = ( + sql.select(self.sys_views.c.viewname) + 
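# restrict the scan to the requested schema and keep the ordering deterministic
+                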
.where(self.sys_views.c.viewschema == current_schema) + .order_by(self.sys_views.c.viewname) + ) + logger.debug(f"[AS400] Generated get_view_names SQL -> {query}") + result = [ + self.normalize_name(r[0]) + for r in connection.execute(query) + ] + logger.debug(f"[AS400] Reflected views -> count={len(result)}, views={result}") + return result + except Exception as e: + logger.error(f"[AS400] Error in get_view_names: {e}") + logger.exception("Stack trace in AS400 get_view_names") + raise @reflection.cache + @log_entry_exit def get_view_definition(self, connection, viewname, schema=None, **kw): - current_schema = self.denormalize_name( - schema or self.default_schema_name) - viewname = self.denormalize_name(viewname) - - query = sql.select(self.sys_views.c.text).\ - where(self.sys_views.c.viewschema == current_schema).\ - where(self.sys_views.c.viewname == viewname) - return connection.execute(query).scalar() + try: + current_schema = self.denormalize_name(schema or self.default_schema_name) + viewname = self.denormalize_name(viewname) + logger.debug( + f"[AS400] Fetching view definition -> " + f"schema={current_schema}, view={viewname}" + ) + query = ( + sql.select(self.sys_views.c.text) + .where(self.sys_views.c.viewschema == current_schema) + .where(self.sys_views.c.viewname == viewname) + ) + logger.debug(f"[AS400] Generated get_view_definition SQL -> {query}") + definition = connection.execute(query).scalar() + logger.debug( + f"[AS400] View definition length -> " + f"{len(definition) if definition else 0}" + ) + return definition + except Exception as e: + logger.error(f"[AS400] Error in get_view_definition: {e}") + logger.exception("Stack trace in AS400 get_view_definition") + raise @reflection.cache + @log_entry_exit def get_columns(self, connection, table_name, schema=None, **kw): - current_schema = self.denormalize_name(schema or self.default_schema_name) - table_name = self.denormalize_name(table_name) - syscols = self.sys_columns - - query = sql.select(syscols.c.colname,syscols.c.typename, - syscols.c.defaultval, syscols.c.nullable, - syscols.c.length, syscols.c.scale, - syscols.c.isid, syscols.c.idgenerate, - syscols.c.remark).\ - where(and_( - syscols.c.tabschema == current_schema, - syscols.c.tabname == table_name)).\ - order_by(syscols.c.colno) - sa_columns = [] - for r in connection.execute(query): - coltype = r[1].upper() - if coltype in ['DECIMAL', 'NUMERIC']: - coltype = self.ischema_names.get(coltype)(int(r[4]), int(r[5])) - elif coltype in ['CHARACTER', 'CHAR', 'VARCHAR', - 'GRAPHIC', 'VARGRAPHIC']: - coltype = self.ischema_names.get(coltype)(int(r[4])) - else: - try: - coltype = self.ischema_names[coltype] - except KeyError: - util.warn("Did not recognize type '%s' of column '%s'" % - (coltype, r[0])) - coltype = coltype = sa_types.NULLTYPE - - if version_info[0] < 3: - sa_columns.append({ - 'name': self.normalize_name(r[0]), - 'type': coltype, - 'nullable': r[3] == unicode('Y'), - 'default': r[2], - 'autoincrement': (r[6] == unicode('YES')) and (r[7] != None), - 'comment': r[8] or None, - }) - else: - sa_columns.append({ + try: + current_schema = self.denormalize_name(schema or self.default_schema_name) + table_name = self.denormalize_name(table_name) + logger.debug( + f"[AS400] Fetching columns -> " + f"schema={current_schema}, table={table_name}" + ) + syscols = self.sys_columns + query = ( + sql.select( + syscols.c.colname, syscols.c.typename, + syscols.c.defaultval, syscols.c.nullable, + syscols.c.length, syscols.c.scale, + syscols.c.isid, syscols.c.idgenerate, + 
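# isid / idgenerate mark identity columns; they feed the
+                    # autoincrement flag computed below
+                    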
syscols.c.remark + ) + .where(and_( + syscols.c.tabschema == current_schema, + syscols.c.tabname == table_name + )) + .order_by(syscols.c.colno) + ) + logger.debug(f"[AS400] Generated get_columns SQL -> {query}") + sa_columns = [] + for r in connection.execute(query): + raw_type = r[1].upper() + logger.debug( + f"[AS400] Processing column -> " + f"name={r[0]}, type={raw_type}, " + f"length={r[4]}, scale={r[5]}" + ) + if raw_type in ['DECIMAL', 'NUMERIC']: + coltype = self.ischema_names.get(raw_type)(int(r[4]), int(r[5])) + elif raw_type in ['CHARACTER', 'CHAR', 'VARCHAR', + 'GRAPHIC', 'VARGRAPHIC']: + coltype = self.ischema_names.get(raw_type)(int(r[4])) + else: + try: + coltype = self.ischema_names[raw_type] + except KeyError: + logger.warning( + f"[AS400] Unrecognized type '{raw_type}' " + f"for column '{r[0]}'" + ) + coltype = sa_types.NULLTYPE + if version_info[0] < 3: + nullable_flag = r[3] == unicode('Y') + autoinc_flag = (r[6] == unicode('YES')) and (r[7] is not None) + else: + nullable_flag = r[3] == str('Y') + autoinc_flag = (r[6] == str('YES')) and (r[7] is not None) + column_info = { 'name': self.normalize_name(r[0]), 'type': coltype, - 'nullable': r[3] == str('Y'), + 'nullable': nullable_flag, 'default': r[2], - 'autoincrement': (r[6] == str('YES')) and (r[7] != None), + 'autoincrement': autoinc_flag, 'comment': r[8] or None, - }) - return sa_columns + } + logger.debug(f"[AS400] Column reflected -> {column_info}") + sa_columns.append(column_info) + logger.debug(f"[AS400] Total columns reflected -> count={len(sa_columns)}") + return sa_columns + except Exception as e: + logger.error(f"[AS400] Error reflecting columns: {e}") + logger.exception("Stack trace in AS400 get_columns") + raise @reflection.cache + @log_entry_exit def get_pk_constraint(self, connection, table_name, schema=None, **kw): - current_schema = self.denormalize_name( - schema or self.default_schema_name) - table_name = self.denormalize_name(table_name) - sysconst = self.sys_table_constraints - syskeyconst = self.sys_key_constraints - - query = sql.select(syskeyconst.c.colname, sysconst.c.tabname, sysconst.c.conname).\ - where(and_( - syskeyconst.c.conschema == sysconst.c.conschema, - syskeyconst.c.conname == sysconst.c.conname, - sysconst.c.tabschema == current_schema, - sysconst.c.tabname == table_name, - sysconst.c.contype == 'PRIMARY KEY')).\ - order_by(syskeyconst.c.colno) - - pk_columns = [] - pk_name = None - for key in connection.execute(query): - pk_columns.append(self.normalize_name(key[0])) - if not pk_name: - pk_name = self.normalize_name(key[2]) - return {"constrained_columns": pk_columns, "name": pk_name} + try: + current_schema = self.denormalize_name(schema or self.default_schema_name) + table_name = self.denormalize_name(table_name) + logger.debug( + f"[AS400] Fetching PK constraint -> " + f"schema={current_schema}, table={table_name}" + ) + sysconst = self.sys_table_constraints + syskeyconst = self.sys_key_constraints + query = ( + sql.select(syskeyconst.c.colname, sysconst.c.tabname, sysconst.c.conname) + .where(and_( + syskeyconst.c.conschema == sysconst.c.conschema, + syskeyconst.c.conname == sysconst.c.conname, + sysconst.c.tabschema == current_schema, + sysconst.c.tabname == table_name, + sysconst.c.contype == 'PRIMARY KEY' + )).order_by(syskeyconst.c.colno) + ) + logger.debug(f"[AS400] Generated get_pk_constraint SQL -> {query}") + pk_columns = [] + pk_name = None + for key in connection.execute(query): + pk_columns.append(self.normalize_name(key[0])) + if not pk_name: + pk_name = 
self.normalize_name(key[2]) + logger.debug(f"[AS400] PK reflected -> name={pk_name}, columns={pk_columns}") + return {"constrained_columns": pk_columns, "name": pk_name} + except Exception as e: + logger.error(f"[AS400] Error reflecting PK constraint: {e}") + logger.exception("Stack trace in AS400 get_pk_constraint") + raise @reflection.cache + @log_entry_exit def get_primary_keys(self, connection, table_name, schema=None, **kw): - current_schema = self.denormalize_name( - schema or self.default_schema_name) - table_name = self.denormalize_name(table_name) - sysconst = self.sys_table_constraints - syskeyconst = self.sys_key_constraints - - if version_info[0] < 3: - query = sql.select(syskeyconst.c.colname, sysconst.c.tabname). \ - where(and_( - syskeyconst.c.conschema == sysconst.c.conschema, - syskeyconst.c.conname == sysconst.c.conname, - sysconst.c.tabschema == current_schema, - sysconst.c.tabname == table_name, - sysconst.c.contype == unicode('PRIMARY KEY'))). \ - order_by(syskeyconst.c.colno) - else: - query = sql.select(syskeyconst.c.colname, sysconst.c.tabname). \ - where(and_( - syskeyconst.c.conschema == sysconst.c.conschema, - syskeyconst.c.conname == sysconst.c.conname, - sysconst.c.tabschema == current_schema, - sysconst.c.tabname == table_name, - sysconst.c.contype == str('PRIMARY KEY'))). \ - order_by(syskeyconst.c.colno) - - return [self.normalize_name(key[0]) - for key in connection.execute(query)] + try: + current_schema = self.denormalize_name(schema or self.default_schema_name) + table_name = self.denormalize_name(table_name) + logger.debug( + f"[AS400] Fetching primary keys -> " + f"schema={current_schema}, table={table_name}" + ) + sysconst = self.sys_table_constraints + syskeyconst = self.sys_key_constraints + if version_info[0] < 3: + logger.debug("[AS400] Using unicode branch for PK lookup") + query = ( + sql.select(syskeyconst.c.colname, sysconst.c.tabname) + .where(and_( + syskeyconst.c.conschema == sysconst.c.conschema, + syskeyconst.c.conname == sysconst.c.conname, + sysconst.c.tabschema == current_schema, + sysconst.c.tabname == table_name, + sysconst.c.contype == unicode('PRIMARY KEY') + )) + .order_by(syskeyconst.c.colno) + ) + else: + logger.debug("[AS400] Using str branch for PK lookup") + query = ( + sql.select(syskeyconst.c.colname, sysconst.c.tabname) + .where(and_( + syskeyconst.c.conschema == sysconst.c.conschema, + syskeyconst.c.conname == sysconst.c.conname, + sysconst.c.tabschema == current_schema, + sysconst.c.tabname == table_name, + sysconst.c.contype == str('PRIMARY KEY') + )) + .order_by(syskeyconst.c.colno) + ) + logger.debug(f"[AS400] Generated get_primary_keys SQL -> {query}") + result = [ + self.normalize_name(key[0]) + for key in connection.execute(query) + ] + logger.debug(f"[AS400] Primary keys reflected -> {result}") + return result + except Exception as e: + logger.error(f"[AS400] Error reflecting primary keys: {e}") + logger.exception("Stack trace in AS400 get_primary_keys") + raise @reflection.cache + @log_entry_exit def get_foreign_keys(self, connection, table_name, schema=None, **kw): - default_schema = self.default_schema_name - current_schema = self.denormalize_name(schema or default_schema) - default_schema = self.normalize_name(default_schema) - table_name = self.denormalize_name(table_name) - sysfkeys = self.sys_foreignkeys - query = sql.select(sysfkeys.c.fkname, sysfkeys.c.fktabschema, - sysfkeys.c.fktabname, sysfkeys.c.fkcolname, - sysfkeys.c.pkname, sysfkeys.c.pktabschema, - sysfkeys.c.pktabname, sysfkeys.c.pkcolname).\ - 
where(and_( - sysfkeys.c.fktabschema == current_schema, - sysfkeys.c.fktabname == table_name)).\ - order_by(sysfkeys.c.colno) - fschema = {} - for r in connection.execute(query): - if r[0] not in fschema: - referred_schema = self.normalize_name(r[5]) - - # if no schema specified and referred schema here is the - # default, then set to None - if schema is None and \ - referred_schema == default_schema: - referred_schema = None - - fschema[r[0]] = {'name': self.normalize_name(r[0]), - 'constrained_columns': [self.normalize_name(r[3])], - 'referred_schema': referred_schema, - 'referred_table': self.normalize_name(r[6]), - 'referred_columns': [self.normalize_name(r[7])]} - else: - fschema[r[0]]['constrained_columns'].append(self.normalize_name(r[3])) - fschema[r[0]]['referred_columns'].append(self.normalize_name(r[7])) - return [value for key, value in fschema.items()] + try: + default_schema = self.default_schema_name + current_schema = self.denormalize_name(schema or default_schema) + normalized_default_schema = self.normalize_name(default_schema) + table_name = self.denormalize_name(table_name) + logger.debug( + f"[AS400] Fetching foreign keys -> " + f"schema={current_schema}, table={table_name}" + ) + sysfkeys = self.sys_foreignkeys + query = ( + sql.select( + sysfkeys.c.fkname, sysfkeys.c.fktabschema, + sysfkeys.c.fktabname, sysfkeys.c.fkcolname, + sysfkeys.c.pkname, sysfkeys.c.pktabschema, + sysfkeys.c.pktabname, sysfkeys.c.pkcolname + ) + .where(and_( + sysfkeys.c.fktabschema == current_schema, + sysfkeys.c.fktabname == table_name + )) + .order_by(sysfkeys.c.colno) + ) + logger.debug(f"[AS400] Generated get_foreign_keys SQL -> {query}") + fschema = {} + for r in connection.execute(query): + fk_name = r[0] + if fk_name not in fschema: + referred_schema = self.normalize_name(r[5]) + # if no schema specified and referred schema here is the + # default, then set to None + if schema is None and \ + referred_schema == normalized_default_schema: + referred_schema = None + fschema[fk_name] = { + 'name': self.normalize_name(fk_name), + 'constrained_columns': [self.normalize_name(r[3])], + 'referred_schema': referred_schema, + 'referred_table': self.normalize_name(r[6]), + 'referred_columns': [self.normalize_name(r[7])] + } + logger.debug(f"[AS400] Foreign key discovered -> {fschema[fk_name]}") + else: + fschema[fk_name]['constrained_columns'].append(self.normalize_name(r[3])) + fschema[fk_name]['referred_columns'].append(self.normalize_name(r[7])) + result = list(fschema.values()) + logger.debug(f"[AS400] Total foreign keys reflected -> count={len(result)}") + return result + except Exception as e: + logger.error(f"[AS400] Error reflecting foreign keys: {e}") + logger.exception("Stack trace in AS400 get_foreign_keys") + raise # Retrieves a list of index names for a given schema @reflection.cache + @log_entry_exit def get_indexes(self, connection, table_name, schema=None, **kw): - current_schema = self.denormalize_name( - schema or self.default_schema_name) - table_name = self.denormalize_name(table_name) - - sysidx = self.sys_indexes - syskey = self.sys_keys - - query = sql.select(sysidx.c.indname,sysidx.c.uniquerule, - syskey.c.colname).\ - where(and_( - syskey.c.indschema == sysidx.c.indschema, - syskey.c.indname == sysidx.c.indname, - sysidx.c.tabschema == current_schema, - sysidx.c.tabname == table_name)).\ - order_by(syskey.c.indname, syskey.c.colno) - indexes = {} - for r in connection.execute(query): - key = r[0].upper() - if key in indexes: - 
indexes[key]['column_names'].append(self.normalize_name(r[2])) - else: - if version_info[0] < 3: - indexes[key] = { - 'name': self.normalize_name(r[0]), - 'column_names': [self.normalize_name(r[2])], - 'unique': r[1] == unicode('Y') - } + try: + current_schema = self.denormalize_name(schema or self.default_schema_name) + table_name = self.denormalize_name(table_name) + logger.debug( + f"[AS400] Fetching indexes -> " + f"schema={current_schema}, table={table_name}" + ) + sysidx = self.sys_indexes + syskey = self.sys_keys + query = ( + sql.select( + sysidx.c.indname, + sysidx.c.uniquerule, + syskey.c.colname + ) + .where(and_( + syskey.c.indschema == sysidx.c.indschema, + syskey.c.indname == sysidx.c.indname, + sysidx.c.tabschema == current_schema, + sysidx.c.tabname == table_name + )) + .order_by(syskey.c.indname, syskey.c.colno) + ) + logger.debug(f"[AS400] Generated get_indexes SQL -> {query}") + indexes = {} + for r in connection.execute(query): + index_name_raw = r[0] + unique_flag_raw = r[1] + column_raw = r[2] + key = index_name_raw.upper() + logger.debug( + f"[AS400] Processing index row -> " + f"name={index_name_raw}, " + f"unique_flag={unique_flag_raw}, " + f"column={column_raw}" + ) + if key in indexes: + indexes[key]['column_names'].append(self.normalize_name(column_raw)) else: + if version_info[0] < 3: + is_unique = unique_flag_raw == unicode('Y') + else: + is_unique = unique_flag_raw == str('Y') indexes[key] = { - 'name': self.normalize_name(r[0]), - 'column_names': [self.normalize_name(r[2])], - 'unique': r[1] == str('Y') + 'name': self.normalize_name(index_name_raw), + 'column_names': [self.normalize_name(column_raw)], + 'unique': is_unique } - return [value for key, value in indexes.items()] + logger.debug( + f"[AS400] New index discovered -> " + f"{indexes[key]}" + ) + result = list(indexes.values()) + logger.debug(f"[AS400] Total indexes reflected -> count={len(result)}") + return result + except Exception as e: + logger.error(f"[AS400] Error reflecting indexes: {e}") + logger.exception("Stack trace in AS400 get_indexes") + raise @reflection.cache + @log_entry_exit def get_unique_constraints(self, connection, table_name, schema=None, **kw): + logger.debug( + f"[AS400] get_unique_constraints invoked -> " + f"schema={schema}, table={table_name}" + ) uniqueConsts = [] + logger.debug( + "[AS400] Unique constraints not implemented for AS400 " + "(returning empty list)" + ) return uniqueConsts @@ -958,321 +1500,469 @@ class OS390Reflector(BaseReflector): Column("NAME", CoerceUnicode, key="seqname"), schema="SYSIBM") + @log_entry_exit def has_table(self, connection, table_name, schema=None, **kw): - current_schema = self.denormalize_name( - schema or self.default_schema_name) - table_name = self.denormalize_name(table_name) - if current_schema: - whereclause = sql.and_(self.sys_tables.c.tabschema == current_schema, - self.sys_tables.c.tabname == table_name) - else: - whereclause = self.sys_tables.c.tabname == table_name - s = sql.select(self.sys_tables.c.tabname).where(whereclause) - c = connection.execute(s) - return c.first() is not None - + try: + current_schema = self.denormalize_name(schema or self.default_schema_name) + table_name = self.denormalize_name(table_name) + logger.debug( + f"Checking table existence (OS390) -> " + f"schema={current_schema}, table={table_name}" + ) + if current_schema: + whereclause = sql.and_( + self.sys_tables.c.tabschema == current_schema, + self.sys_tables.c.tabname == table_name + ) + else: + whereclause = self.sys_tables.c.tabname == 
table_name + s = sql.select(self.sys_tables.c.tabname).where(whereclause) + logger.debug(f"has_table SQL -> {s}") + result = connection.execute(s).first() is not None + logger.debug(f"has_table result -> {result}") + return result + except Exception: + logger.exception("Error in has_table (OS390)") + raise + + @log_entry_exit def has_sequence(self, connection, sequence_name, schema=None): - current_schema = self.denormalize_name(schema or self.default_schema_name) - sequence_name = self.denormalize_name(sequence_name) - if current_schema: - whereclause = sql.and_(self.sys_sequences.c.seqschema == current_schema, - self.sys_sequences.c.seqname == sequence_name) - else: - whereclause = self.sys_sequences.c.seqname == sequence_name - s = sql.select(self.sys_sequences.c.seqname).where(whereclause) - c = connection.execute(s) - return c.first() is not None + try: + current_schema = self.denormalize_name(schema or self.default_schema_name) + sequence_name = self.denormalize_name(sequence_name) + logger.debug( + f"Checking sequence existence (OS390) -> " + f"schema={current_schema}, sequence={sequence_name}" + ) + if current_schema: + whereclause = sql.and_( + self.sys_sequences.c.seqschema == current_schema, + self.sys_sequences.c.seqname == sequence_name + ) + else: + whereclause = self.sys_sequences.c.seqname == sequence_name + s = sql.select(self.sys_sequences.c.seqname).where(whereclause) + logger.debug(f"has_sequence SQL -> {s}") + result = connection.execute(s).first() is not None + logger.debug(f"has_sequence result -> {result}") + return result + except Exception: + logger.exception("Error in has_sequence (OS390)") + raise @reflection.cache + @log_entry_exit def get_sequence_names(self, connection, schema=None, **kw): - current_schema = self.denormalize_name(schema or self.default_schema_name) - sys_sequence = self.sys_sequences - query = sql.select(sys_sequence.c.seqname).\ - where(sys_sequence.c.seqschema == current_schema).\ - order_by(sys_sequence.c.seqschema, sys_sequence.c.seqname) - return [self.normalize_name(r[0]) for r in connection.execute(query)] + try: + current_schema = self.denormalize_name(schema or self.default_schema_name) + logger.debug(f"Fetching sequence names (OS390) -> schema={current_schema}") + sys_sequence = self.sys_sequences + query = ( + sql.select(sys_sequence.c.seqname) + .where(sys_sequence.c.seqschema == current_schema) + .order_by(sys_sequence.c.seqschema, sys_sequence.c.seqname) + ) + logger.debug(f"get_sequence_names SQL -> {query}") + result = [self.normalize_name(r[0]) for r in connection.execute(query)] + logger.debug(f"Sequences found -> count={len(result)}") + return result + except Exception: + logger.exception("Error in get_sequence_names (OS390)") + raise @reflection.cache + @log_entry_exit def get_schema_names(self, connection, **kw): - sysschema = self.sys_tables - query = sql.select(sysschema.c.tabschema).\ - where(not_(sysschema.c.tabschema.like('SYS%'))).\ - distinct(sysschema.c.tabschema) - return [self.normalize_name(r[0].rstrip()) for r in connection.execute(query)] - + try: + logger.debug("[OS390] get_schema_names invoked") + sysschema = self.sys_tables + query = sql.select(sysschema.c.tabschema). \ + where(not_(sysschema.c.tabschema.like('SYS%'))). 
\ + distinct(sysschema.c.tabschema) + logger.debug(f"[OS390] get_schema_names SQL -> {query}") + result = [ + self.normalize_name(r[0].rstrip()) + for r in connection.execute(query) + ] + logger.debug(f"[OS390] schemas found -> count={len(result)}") + return result + except Exception: + logger.exception("[OS390] Error in get_schema_names") + raise + + @log_entry_exit def get_table_comment(self, connection, table_name, schema=None, **kw): - current_schema = self.denormalize_name(schema or self.default_schema_name) - table_name = self.denormalize_name(table_name) - systbl = self.sys_tables - query = sql.select(systbl.c.remarks).\ - where(systbl.c.tabschema == current_schema).\ - where(systbl.c.tabname == table_name) - return {'text': connection.execute(query).scalar()} + try: + current_schema = self.denormalize_name(schema or self.default_schema_name) + table_name = self.denormalize_name(table_name) + logger.debug(f"[OS390] get_table_comment -> schema={current_schema}, table={table_name}") + systbl = self.sys_tables + query = sql.select(systbl.c.remarks). \ + where(systbl.c.tabschema == current_schema). \ + where(systbl.c.tabname == table_name) + logger.debug(f"[OS390] get_table_comment SQL -> {query}") + comment = connection.execute(query).scalar() + logger.debug(f"[OS390] table comment -> {comment}") + return {'text': comment} + except Exception: + logger.exception("[OS390] Error in get_table_comment") + raise @reflection.cache + @log_entry_exit def get_table_names(self, connection, schema=None, **kw): - current_schema = self.denormalize_name(schema or self.default_schema_name) - systbl = self.sys_tables - query = sql.select(systbl.c.tabname).\ - where(systbl.c.type == 'T').\ - where(systbl.c.tabschema == current_schema).\ - order_by(systbl.c.tabname) - return [self.normalize_name(r[0]) for r in connection.execute(query)] + try: + current_schema = self.denormalize_name(schema or self.default_schema_name) + logger.debug(f"[OS390] get_table_names -> schema={current_schema}") + systbl = self.sys_tables + query = sql.select(systbl.c.tabname). \ + where(systbl.c.type == 'T'). \ + where(systbl.c.tabschema == current_schema). \ + order_by(systbl.c.tabname) + logger.debug(f"[OS390] get_table_names SQL -> {query}") + result = [self.normalize_name(r[0]) for r in connection.execute(query)] + logger.debug(f"[OS390] tables found -> count={len(result)}") + return result + except Exception: + logger.exception("[OS390] Error in get_table_names") + raise @reflection.cache + @log_entry_exit def get_view_names(self, connection, schema=None, **kw): - current_schema = self.denormalize_name(schema or self.default_schema_name) - - query = sql.select(self.sys_views.c.viewname).\ - where(self.sys_views.c.viewschema == current_schema).\ - order_by(self.sys_views.c.viewname) - - return [self.normalize_name(r[0]) for r in connection.execute(query)] + try: + current_schema = self.denormalize_name(schema or self.default_schema_name) + logger.debug(f"[OS390] get_view_names -> schema={current_schema}") + query = sql.select(self.sys_views.c.viewname). \ + where(self.sys_views.c.viewschema == current_schema). 
\ + order_by(self.sys_views.c.viewname) + logger.debug(f"[OS390] get_view_names SQL -> {query}") + result = [self.normalize_name(r[0]) for r in connection.execute(query)] + logger.debug(f"[OS390] views found -> count={len(result)}") + return result + except Exception: + logger.exception("[OS390] Error in get_view_names") + raise @reflection.cache + @log_entry_exit def get_view_definition(self, connection, viewname, schema=None, **kw): - current_schema = self.denormalize_name(schema or self.default_schema_name) - viewname = self.denormalize_name(viewname) - - query = sql.select(self.sys_views.c.text).\ - where(self.sys_views.c.viewschema == current_schema).\ - where(self.sys_views.c.viewname == viewname) - - return connection.execute(query).scalar() + try: + current_schema = self.denormalize_name(schema or self.default_schema_name) + viewname = self.denormalize_name(viewname) + logger.debug( + f"[OS390] get_view_definition -> " + f"schema={current_schema}, view={viewname}" + ) + query = sql.select(self.sys_views.c.text). \ + where(self.sys_views.c.viewschema == current_schema). \ + where(self.sys_views.c.viewname == viewname) + logger.debug(f"[OS390] get_view_definition SQL -> {query}") + result = connection.execute(query).scalar() + logger.debug( + f"[OS390] view definition length -> " + f"{len(result) if result else 0}" + ) + return result + except Exception: + logger.exception("[OS390] Error in get_view_definition") + raise @reflection.cache + @log_entry_exit def get_columns(self, connection, table_name, schema=None, **kw): - current_schema = self.denormalize_name(schema or self.default_schema_name) - table_name = self.denormalize_name(table_name) - syscols = self.sys_columns - - query = sql.select(syscols.c.colname, syscols.c.typename, - syscols.c.defaultval, syscols.c.nullable, - syscols.c.length, syscols.c.scale, - syscols.c.generated, syscols.c.remark).\ - where(and_( + try: + current_schema = self.denormalize_name(schema or self.default_schema_name) + table_name = self.denormalize_name(table_name) + logger.debug(f"[OS390] get_columns -> schema={current_schema}, table={table_name}") + syscols = self.sys_columns + query = sql.select(syscols.c.colname, syscols.c.typename, + syscols.c.defaultval, syscols.c.nullable, + syscols.c.length, syscols.c.scale, + syscols.c.generated, syscols.c.remark). \ + where(and_( syscols.c.tabschema == current_schema, - syscols.c.tabname == table_name)).\ - order_by(syscols.c.colno) - sa_columns = [] - for r in connection.execute(query): - coltype = r[1].upper() - if coltype in ['DECIMAL', 'NUMERIC']: - coltype = self.ischema_names.get(coltype)(int(r[4]), int(r[5])) - elif coltype in ['CHARACTER', 'CHAR', 'VARCHAR', - 'GRAPHIC', 'VARGRAPHIC']: - coltype = self.ischema_names.get(coltype)(int(r[4])) - else: - try: - coltype = self.ischema_names[coltype] - except KeyError: - util.warn("Did not recognize type '%s' of column '%s'" % - (coltype, r[0])) - coltype = coltype = sa_types.NULLTYPE - - sa_columns.append({ + syscols.c.tabname == table_name)). 
\
-            order_by(syscols.c.colno)
+                    syscols.c.tabname == table_name)). \
+                order_by(syscols.c.colno)
+            logger.debug(f"[OS390] get_columns SQL -> {query}")
+            sa_columns = []
+            for r in connection.execute(query):
+                rowtype = r[1].upper()
+                logger.debug(f"[OS390] Processing column -> name={r[0]}, raw_type={rowtype}")
+                if rowtype in ['DECIMAL', 'NUMERIC']:
+                    coltype = self.ischema_names.get(rowtype)(int(r[4]), int(r[5]))
+                elif rowtype in ['CHARACTER', 'CHAR', 'VARCHAR',
+                                 'GRAPHIC', 'VARGRAPHIC']:
+                    coltype = self.ischema_names.get(rowtype)(int(r[4]))
+                else:
+                    try:
+                        coltype = self.ischema_names[rowtype]
+                    except KeyError:
+                        logger.warning(f"[OS390] Unknown type '{rowtype}' for column '{r[0]}'")
+                        util.warn(
+                            "Did not recognize type '%s' of column '%s'" %
+                            (rowtype, r[0])
+                        )
+                        coltype = sa_types.NULLTYPE
+                sa_columns.append({
+                    'name': self.normalize_name(r[0]),
+                    'type': coltype,
+                    'nullable': r[3] == 'Y',
+                    'default': r[2] or None,
+                    # r[2] == 'J' already implies r[2] != ' ', so the old
+                    # redundant second comparison is dropped
+                    'autoincrement': r[2] == 'J',
+                    'comment': r[7] or None,
+                })
-                'name': self.normalize_name(r[0]),
-                'type': coltype,
-                'nullable': r[3] == 'Y',
-                'default': r[2] or None,
-                'autoincrement': (r[2] == 'J') and (r[2] != ' ') ,
-                'comment': r[7] or None,
-            })
-        return sa_columns
+            logger.debug(f"[OS390] get_columns completed -> count={len(sa_columns)}")
+            return sa_columns
+        except Exception:
+            logger.exception("[OS390] Error in get_columns")
+            raise

     @reflection.cache
+    @log_entry_exit
     def get_pk_constraint(self, connection, table_name, schema=None, **kw):
-        current_schema = self.denormalize_name(schema or self.default_schema_name)
-        table_name = self.denormalize_name(table_name)
-        sysindexes = self.sys_columns
-        col_finder = re.compile(r"(\w+)")
-        query = sql.select(sysindexes.c.colname).\
-            where(and_(
-                sysindexes.c.tabschema == current_schema,
-                sysindexes.c.tabname == table_name,
-                sysindexes.c.keyseq > 0)).\
-            order_by(sysindexes.c.tabschema, sysindexes.c.tabname)
+        try:
+            current_schema = self.denormalize_name(schema or self.default_schema_name)
+            table_name = self.denormalize_name(table_name)
+            logger.debug(f"[OS390] get_pk_constraint -> schema={current_schema}, table={table_name}")
+            sysindexes = self.sys_columns
+            col_finder = re.compile(r"(\w+)")
+            query = sql.select(sysindexes.c.colname). \
+                where(and_(
+                    sysindexes.c.tabschema == current_schema,
+                    sysindexes.c.tabname == table_name,
+                    sysindexes.c.keyseq > 0)). \
+                order_by(sysindexes.c.tabschema, sysindexes.c.tabname)
+            logger.debug(f"[OS390] get_pk_constraint SQL -> {query}")
+            pk_columns = []
+            for r in connection.execute(query):
+                cols = col_finder.findall(r[0])
+                pk_columns.extend(cols)
+            result = {
+                "constrained_columns": [self.normalize_name(col) for col in pk_columns],
+                "name": None
+            }
+            logger.debug(f"[OS390] get_pk_constraint result -> {result}")
+            return result
+        except Exception:
+            logger.exception("[OS390] Error in get_pk_constraint")
+            raise

     @reflection.cache
+    @log_entry_exit
     def get_primary_keys(self, connection, table_name, schema=None, **kw):
-        current_schema = self.denormalize_name(schema or self.default_schema_name)
-        table_name = self.denormalize_name(table_name)
-        sysindexes = self.sys_columns
-        col_finder = re.compile(r"(\w+)")
-        query = sql.select(sysindexes.c.colname).\
-            where(and_(
+        try:
+            current_schema = self.denormalize_name(schema or self.default_schema_name)
+            table_name = self.denormalize_name(table_name)
+            logger.debug(f"[OS390] get_primary_keys -> schema={current_schema}, table={table_name}")
+            sysindexes = self.sys_columns
+            col_finder = re.compile(r"(\w+)")
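+            # findall() below tolerates COLNAME values that pack several
+            # identifiers into a single string
+            query = sql.select(sysindexes.c.colname). 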
\ + where(and_( sysindexes.c.tabschema == current_schema, sysindexes.c.tabname == table_name, - sysindexes.c.keyseq > 0)).\ - order_by(sysindexes.c.tabschema, sysindexes.c.tabname) - pk_columns = [] - for r in connection.execute(query): - cols = col_finder.findall(r[0]) - pk_columns.extend(cols) - return [self.normalize_name(col) for col in pk_columns] + sysindexes.c.keyseq > 0)). \ + order_by(sysindexes.c.tabschema, sysindexes.c.tabname) + logger.debug(f"[OS390] get_primary_keys SQL -> {query}") + pk_columns = [] + for r in connection.execute(query): + cols = col_finder.findall(r[0]) + pk_columns.extend(cols) + result = [self.normalize_name(col) for col in pk_columns] + logger.debug(f"[OS390] get_primary_keys result -> {result}") + return result + except Exception: + logger.exception("[OS390] Error in get_primary_keys") + raise @reflection.cache + @log_entry_exit def get_foreign_keys(self, connection, table_name, schema=None, **kw): - default_schema = self.default_schema_name - current_schema = self.denormalize_name(schema or default_schema) - default_schema = self.normalize_name(default_schema) - table_name = self.denormalize_name(table_name) - sysfkeys = self.sys_foreignkeys - sysrels = self.sys_rels - syscolspk = self.sys_columns - sysindex = self.sys_indexes - query = sql.select(sysrels.c.fkname, sysrels.c.fktabschema, - sysrels.c.fktabname, sysfkeys.c.fkcolname, - sysindex.c.indname, sysrels.c.pktabschema, - sysrels.c.pktabname, syscolspk.c.colname).\ - where(and_( + try: + default_schema = self.default_schema_name + current_schema = self.denormalize_name(schema or default_schema) + default_schema = self.normalize_name(default_schema) + table_name = self.denormalize_name(table_name) + logger.debug(f"[OS390] get_foreign_keys -> schema={current_schema}, table={table_name}") + sysfkeys = self.sys_foreignkeys + sysrels = self.sys_rels + syscolspk = self.sys_columns + sysindex = self.sys_indexes + query = sql.select( + sysrels.c.fkname, sysrels.c.fktabschema, + sysrels.c.fktabname, sysfkeys.c.fkcolname, + sysindex.c.indname, sysrels.c.pktabschema, + sysrels.c.pktabname, syscolspk.c.colname). \ + where(and_( sysrels.c.fktabschema == current_schema, sysrels.c.fktabname == table_name, sysrels.c.fktabname == sysfkeys.c.fktabname, sysrels.c.pktabname == syscolspk.c.tabname, - syscolspk.c.tabname == sysindex.c.tabname,syscolspk.c.keyseq > 0)).\ - order_by(sysfkeys.c.colno) - - fschema = {} - for r in connection.execute(query): - if not (r[0]) in fschema: - referred_schema = self.normalize_name(r[5]) - - # if no schema specified and referred schema here is the - # default, then set to None - if schema is None and \ - referred_schema == default_schema: - referred_schema = None - - fschema[r[0]] = { - 'name': self.normalize_name(r[0]), - 'constrained_columns': [self.normalize_name(r[3])], - 'referred_schema': referred_schema, - 'referred_table': self.normalize_name(r[6]), - 'referred_columns': [self.normalize_name(r[7])]} - else: - fschema[r[0]]['constrained_columns'].append(self.normalize_name(r[3])) - fschema[r[0]]['referred_columns'].append(self.normalize_name(r[7])) - return [value for key, value in fschema.items()] + syscolspk.c.tabname == sysindex.c.tabname, + syscolspk.c.keyseq > 0)). 
\ + order_by(sysfkeys.c.colno) + logger.debug(f"[OS390] get_foreign_keys SQL -> {query}") + fschema = {} + for r in connection.execute(query): + if r[0] not in fschema: + referred_schema = self.normalize_name(r[5]) + # if no schema specified and referred schema here is the + # default, then set to None + if schema is None and referred_schema == default_schema: + referred_schema = None + fschema[r[0]] = { + 'name': self.normalize_name(r[0]), + 'constrained_columns': [self.normalize_name(r[3])], + 'referred_schema': referred_schema, + 'referred_table': self.normalize_name(r[6]), + 'referred_columns': [self.normalize_name(r[7])] + } + else: + fschema[r[0]]['constrained_columns'].append(self.normalize_name(r[3])) + fschema[r[0]]['referred_columns'].append(self.normalize_name(r[7])) + result = [value for key, value in fschema.items()] + logger.debug(f"[OS390] get_foreign_keys result count -> {len(result)}") + return result + except Exception: + logger.exception("[OS390] Error in get_foreign_keys") + raise @reflection.cache + @log_entry_exit def get_incoming_foreign_keys(self, connection, table_name, schema=None, **kw): - default_schema = self.default_schema_name - current_schema = self.denormalize_name(schema or default_schema) - default_schema = self.normalize_name(default_schema) - table_name = self.denormalize_name(table_name) - sysfkeys = self.sys_foreignkeys - sysrels = self.sys_rels - syscolspk = self.sys_columns - sysindex = self.sys_indexes - query = sql.select(sysrels.c.fkname, sysrels.c.fktabschema, - sysrels.c.fktabname, sysfkeys.c.fkcolname, - sysindex.c.indname, sysrels.c.pktabschema, - sysrels.c.pktabname, syscolspk.c.colname).\ - where(and_( + try: + default_schema = self.default_schema_name + current_schema = self.denormalize_name(schema or default_schema) + default_schema = self.normalize_name(default_schema) + table_name = self.denormalize_name(table_name) + logger.debug(f"[OS390] get_incoming_foreign_keys -> schema={current_schema}, table={table_name}") + sysfkeys = self.sys_foreignkeys + sysrels = self.sys_rels + syscolspk = self.sys_columns + sysindex = self.sys_indexes + query = sql.select( + sysrels.c.fkname, sysrels.c.fktabschema, + sysrels.c.fktabname, sysfkeys.c.fkcolname, + sysindex.c.indname, sysrels.c.pktabschema, + sysrels.c.pktabname, syscolspk.c.colname). \ + where(and_( syscolspk.c.tabschema == current_schema, syscolspk.c.tabname == table_name, sysrels.c.fktabname == sysfkeys.c.fktabname, sysrels.c.pktabname == syscolspk.c.tabname, syscolspk.c.tabname == sysindex.c.tabname, - syscolspk.c.keyseq > 0)).\ - order_by(sysfkeys.c.colno) - - fschema = {} - for r in connection.execute(query): - if not fschema.has_key(r[0]): - constrained_schema = self.normalize_name(r[1]) - - # if no schema specified and referred schema here is the - # default, then set to None - if schema is None and \ - constrained_schema == default_schema: - constrained_schema = None - - fschema[r[0]] = { - 'name': self.normalize_name(r[0]), - 'constrained_schema': constrained_schema, - 'constrained_table': self.normalize_name(r[2]), - 'constrained_columns': [self.normalize_name(r[3])], - 'referred_schema': schema, - 'referred_table': self.normalize_name(r[6]), - 'referred_columns': [self.normalize_name(r[7])]} - else: - fschema[r[0]]['constrained_columns'].append(self.normalize_name(r[3])) - fschema[r[0]]['referred_columns'].append(self.normalize_name(r[7])) - return [value for key, value in fschema.items()] + syscolspk.c.keyseq > 0)). 
\ + order_by(sysfkeys.c.colno) + logger.debug(f"[OS390] get_incoming_foreign_keys SQL -> {query}") + fschema = {} + for r in connection.execute(query): + if r[0] not in fschema: + constrained_schema = self.normalize_name(r[1]) + # if no schema specified and referred schema here is the + # default, then set to None + if schema is None and constrained_schema == default_schema: + constrained_schema = None + fschema[r[0]] = { + 'name': self.normalize_name(r[0]), + 'constrained_schema': constrained_schema, + 'constrained_table': self.normalize_name(r[2]), + 'constrained_columns': [self.normalize_name(r[3])], + 'referred_schema': schema, + 'referred_table': self.normalize_name(r[6]), + 'referred_columns': [self.normalize_name(r[7])] + } + else: + fschema[r[0]]['constrained_columns'].append(self.normalize_name(r[3])) + fschema[r[0]]['referred_columns'].append(self.normalize_name(r[7])) + result = [value for key, value in fschema.items()] + logger.debug(f"[OS390] get_incoming_foreign_keys result count -> {len(result)}") + return result + except Exception: + logger.exception("[OS390] Error in get_incoming_foreign_keys") + raise @reflection.cache + @log_entry_exit def get_indexes(self, connection, table_name, schema=None, **kw): - current_schema = self.denormalize_name(schema or self.default_schema_name) - table_name = self.denormalize_name(table_name) - sysidx = self.sys_indexes - syscolpk = self.sys_columns - query = sql.select(sysidx.c.indname, syscolpk.c.colname, sysidx.c.uniquerule, sysidx.c.system_required).\ - where(and_( + try: + current_schema = self.denormalize_name(schema or self.default_schema_name) + table_name = self.denormalize_name(table_name) + logger.debug(f"[OS390] get_indexes -> schema={current_schema}, table={table_name}") + sysidx = self.sys_indexes + syscolpk = self.sys_columns + query = sql.select( + sysidx.c.indname, syscolpk.c.colname, + sysidx.c.uniquerule, sysidx.c.system_required). \ + where(and_( sysidx.c.tabschema == current_schema, sysidx.c.tabname == table_name, syscolpk.c.colname == sysidx.c.tabname, - syscolpk.c.keyseq > 0)).\ - order_by(sysidx.c.tabname) - indexes = [] - col_finder = re.compile(r"(\w+)") - for r in connection.execute(query): - if r[2] != 'P': - if r[2] == 'U' and r[3] != 0: - continue - indexes.append({ + syscolpk.c.keyseq > 0)). 
\ + order_by(sysidx.c.tabname) + logger.debug(f"[OS390] get_indexes SQL -> {query}") + indexes = [] + col_finder = re.compile(r"(\w+)") + for r in connection.execute(query): + if r[2] != 'P': + if r[2] == 'U' and r[3] != 0: + continue + indexes.append({ 'name': self.normalize_name(r[0]), 'column_names': [self.normalize_name(col) - for col in col_finder.findall(r[1])], + for col in col_finder.findall(r[1])], 'unique': r[2] == 'U' }) - return indexes + logger.debug(f"[OS390] get_indexes result count -> {len(indexes)}") + return indexes + except Exception: + logger.exception("[OS390] Error in get_indexes") + raise @reflection.cache + @log_entry_exit def get_unique_constraints(self, connection, table_name, schema=None, **kw): - current_schema = self.denormalize_name(schema or self.default_schema_name) - table_name = self.denormalize_name(table_name) - syskeycol = self.sys_keycoluse - sysconst = self.sys_tabconst - query = ( - sql.select(syskeycol.c.constname, syskeycol.c.colname) - .select_from( - join( - syskeycol, - sysconst, - and_( - syskeycol.c.constname == sysconst.c.constname, - syskeycol.c.tabschema == sysconst.c.tabschema, - syskeycol.c.tabname == sysconst.c.tabname, - ), + try: + current_schema = self.denormalize_name(schema or self.default_schema_name) + table_name = self.denormalize_name(table_name) + logger.debug(f"[OS390] get_unique_constraints -> schema={current_schema}, table={table_name}") + syskeycol = self.sys_keycoluse + sysconst = self.sys_tabconst + query = ( + sql.select(syskeycol.c.constname, syskeycol.c.colname) + .select_from( + join( + syskeycol, + sysconst, + and_( + syskeycol.c.constname == sysconst.c.constname, + syskeycol.c.tabschema == sysconst.c.tabschema, + syskeycol.c.tabname == sysconst.c.tabname, + ), + ) ) - ) - .where( - and_( - sysconst.c.tabname == table_name, - sysconst.c.tabschema == current_schema, - sysconst.c.type == "U", + .where( + and_( + sysconst.c.tabname == table_name, + sysconst.c.tabschema == current_schema, + sysconst.c.type == "U", + ) ) + .order_by(syskeycol.c.constname) ) - .order_by(syskeycol.c.constname) - ) - uniqueConsts = [] - currConst = None - for r in connection.execute(query): - if currConst == r[0]: - uniqueConsts[-1]["column_names"].append(self.normalize_name(r[1])) - else: - currConst = r[0] - uniqueConsts.append( - { + logger.debug(f"[OS390] get_unique_constraints SQL -> {query}") + uniqueConsts = [] + currConst = None + for r in connection.execute(query): + if currConst == r[0]: + uniqueConsts[-1]["column_names"].append(self.normalize_name(r[1])) + else: + currConst = r[0] + uniqueConsts.append({ "name": self.normalize_name(currConst), "column_names": [self.normalize_name(r[1])], - } - ) - return uniqueConsts + }) + logger.debug(f"[OS390] get_unique_constraints result count -> {len(uniqueConsts)}") + return uniqueConsts + except Exception: + logger.exception("[OS390] Error in get_unique_constraints") + raise