# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Error handling."""
from __future__ import annotations
import ast
import re
import sys
import types
from typing import Any
from . import core
def _parse_backtrace(backtrace: str) -> list[tuple[str, int, str]]:
    """Extract frame records from a Python-style backtrace string.

    Every line is stripped of surrounding whitespace and then matched, from
    its first character, against the layout
    ``File "<filename>", line <lineno>, in <func>``. Lines that do not start
    with that layout are skipped.

    Parameters
    ----------
    backtrace : str
        The backtrace string, one frame per line.

    Returns
    -------
    result : List[Tuple[str, int, str]]
        The (filename, lineno, func) triples, in the order they appear.
    """
    frame_re = re.compile(r'File "(.+?)", line (\d+), in (.+)')
    frames = []
    for raw_line in backtrace.split("\n"):
        found = frame_re.match(raw_line.strip())
        if found is None:
            continue
        path, line_text, func_name = found.groups()
        try:
            line_number = int(line_text)
        except ValueError:
            # Best effort: drop a frame whose line number cannot be converted.
            continue
        frames.append((path, line_number, func_name))
    return frames
class TracebackManager:
    """Build synthetic traceback entries for (filename, lineno, func) locations.

    Compiled code objects are memoized per location in ``_code_cache`` so a
    location that shows up repeatedly is only compiled once.
    """

    def __init__(self) -> None:
        """Create a manager with an empty code-object cache."""
        self._code_cache: dict[tuple[str, int, str], types.CodeType] = {}

    def _get_cached_code_object(self, filename: str, lineno: int, func: str) -> types.CodeType:
        """Return the code object for a location, compiling it on first use."""
        cache_key = (filename, lineno, func)
        cached = self._code_cache.get(cache_key)
        if cached is not None:
            return cached

        # The expression only calls ``_getframe()``; it is parsed to an AST
        # first so the column info can be reset before compiling.
        expr_tree = ast.parse("_getframe()", filename=filename, mode="eval")
        # Column info is not accurate in the original trace, so zero it out.
        for node in ast.walk(expr_tree):
            for attr in ("col_offset", "end_col_offset"):
                if hasattr(node, attr):
                    setattr(node, attr, 0)

        # Hack: relabel the compiled code so that it reports the requested
        # function name and line number.
        relabeled = compile(expr_tree, filename, "eval").replace(
            co_name=func, co_firstlineno=lineno
        )
        self._code_cache[cache_key] = relabeled
        return relabeled

    def _create_frame(self, filename: str, lineno: int, func: str) -> types.FrameType:
        """Create a frame object whose code reports filename, lineno, and func."""
        code = self._get_cached_code_object(filename, lineno, func)
        # Evaluating the relabeled code calls ``sys._getframe`` from inside it,
        # so the frame handed back is the one executing that code.
        namespace = {"_getframe": sys._getframe}
        # pylint: disable=eval-used
        return eval(code, namespace, namespace)

    def append_traceback(
        self,
        tb: types.TracebackType | None,
        filename: str,
        lineno: int,
        func: str,
    ) -> types.TracebackType:
        """Create a traceback entry for a location and link it to ``tb``.

        Parameters
        ----------
        tb : types.TracebackType
            The existing traceback (may be None); it becomes ``tb_next`` of
            the new entry.
        filename : str
            The filename reported by the new entry.
        lineno : int
            The line number reported by the new entry.
        func : str
            The function name reported by the new entry.

        Returns
        -------
        new_tb : types.TracebackType
            The new traceback entry, with ``tb`` as its successor.
        """
        synthetic = self._create_frame(filename, lineno, func)
        return types.TracebackType(tb, synthetic, synthetic.f_lasti, lineno)
# Module-level manager instance; _with_append_backtrace routes every frame
# through it, so its code-object cache is shared across calls.
_TRACEBACK_MANAGER = TracebackManager()
def _with_append_backtrace(py_error: BaseException, backtrace: str) -> BaseException:
    """Attach the frames described by ``backtrace`` to ``py_error``.

    Starting from the error's current traceback, one synthetic traceback
    entry is linked in for each frame parsed out of ``backtrace``, in parse
    order. The error is returned with the resulting traceback set on it.
    """
    chain = py_error.__traceback__
    for frame_info in _parse_backtrace(backtrace):
        # frame_info is a (filename, lineno, func) triple.
        chain = _TRACEBACK_MANAGER.append_traceback(chain, *frame_info)
    return py_error.with_traceback(chain)
def _traceback_to_backtrace_str(tb: types.TracebackType | None) -> str:
    """Render a traceback chain as a backtrace string.

    One ``File "<filename>", line <lineno>, in <func>`` line is produced per
    traceback entry. The lines are emitted in the reverse of the ``tb_next``
    walking order, because the backtrace is stored in the opposite order to a
    Python traceback. ``None`` yields an empty string.
    """
    entries = []
    node = tb
    while node is not None:
        code = node.tb_frame.f_code
        entries.append(f' File "{code.co_filename}", line {node.tb_lineno}, in {code.co_name}\n')
        node = node.tb_next
    # Flip the order: the last entry in the chain comes first in the backtrace.
    entries.reverse()
    return "".join(entries)
# Install the two traceback helpers defined above as attributes of ``core``.
core._WITH_APPEND_BACKTRACE = _with_append_backtrace
core._TRACEBACK_TO_BACKTRACE_STR = _traceback_to_backtrace_str
# Registration of error classes with the ffi error handler.
def register_error(
    name_or_cls: str | type | None = None,
    cls: type | None = None,
) -> Any:
    """Register an error class so it can be recognized by the ffi error handler.

    Can be called directly with a name and a class, used as a bare decorator
    (the class is passed as ``name_or_cls``), or called with only a name to
    obtain a decorator.

    Parameters
    ----------
    name_or_cls : str or class
        The name of the error class, or the class itself; in the latter case
        the class's ``__name__`` is used as the name.
    cls : class
        The class to register.

    Returns
    -------
    fregister : function
        The registered class when a class was supplied; otherwise a register
        function that takes the class, registers it, and returns it.

    Examples
    --------
    .. code-block:: python

        @tvm.error.register_error
        class MyError(RuntimeError):
            pass

        err_inst = tvm.error.create_ffi_error("MyError: xyz")
        assert isinstance(err_inst, MyError)
    """
    if callable(name_or_cls):
        # Bare-decorator form: the first argument is the class itself.
        cls, name_or_cls = name_or_cls, name_or_cls.__name__

    def register(mycls: type) -> type:
        """Record the class under its error name in both lookup tables of ``core``."""
        if isinstance(name_or_cls, str):
            err_name = name_or_cls
        else:
            err_name = mycls.__name__
        core.ERROR_NAME_TO_TYPE[err_name] = mycls
        core.ERROR_TYPE_TO_NAME[mycls] = err_name
        return mycls

    return register if cls is None else register(cls)
# Register the built-in exception types under their own names.
for _err_name, _err_cls in (
    ("RuntimeError", RuntimeError),
    ("ValueError", ValueError),
    ("TypeError", TypeError),
    ("AttributeError", AttributeError),
    ("KeyError", KeyError),
    ("IndexError", IndexError),
    ("AssertionError", AssertionError),
):
    register_error(_err_name, _err_cls)
# Do not leave the loop temporaries behind in the module namespace.
del _err_name, _err_cls