tvm_ffi/
error.rs

1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements.  See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership.  The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License.  You may obtain a copy of the License at
9 *
10 *   http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied.  See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19use crate::derive::{Object, ObjectRef};
20use crate::object::{Object, ObjectArc};
21use std::ffi::c_void;
22use tvm_ffi_sys::TVMFFIBacktraceUpdateMode::kTVMFFIBacktraceUpdateModeAppend;
23use tvm_ffi_sys::{
24    TVMFFIByteArray, TVMFFIErrorCell, TVMFFIErrorCreate, TVMFFIErrorMoveFromRaised,
25    TVMFFIErrorSetRaised, TVMFFIObjectHandle, TVMFFITypeIndex,
26};
27
/// Error kind, wrapped in a newtype so that kinds are not confused with messages.
///
/// The wrapped string is the kind name (e.g. `"ValueError"`) that is handed to the FFI
/// when an error is created and read back from an existing error object.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ErrorKind<'a>(&'a str);

impl<'a> ErrorKind<'a> {
    /// Returns the kind name.
    ///
    /// The returned slice borrows the underlying `'a` data rather than `self`, so it stays
    /// valid after this `ErrorKind` value (often a temporary such as `err.kind()`) is gone.
    pub fn as_str(&self) -> &'a str {
        self.0
    }
}

impl std::fmt::Display for ErrorKind<'_> {
    /// Writes the kind name, honoring width/alignment/precision flags (e.g. `{:>12}`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(self.0)
    }
}

// Predefined kinds; the names mirror Python's built-in exception names.

/// `ValueError` kind.
pub const VALUE_ERROR: ErrorKind<'static> = ErrorKind("ValueError");
/// `TypeError` kind.
pub const TYPE_ERROR: ErrorKind<'static> = ErrorKind("TypeError");
/// `RuntimeError` kind.
pub const RUNTIME_ERROR: ErrorKind<'static> = ErrorKind("RuntimeError");
/// `AttributeError` kind.
pub const ATTRIBUTE_ERROR: ErrorKind<'static> = ErrorKind("AttributeError");
/// `KeyError` kind.
pub const KEY_ERROR: ErrorKind<'static> = ErrorKind("KeyError");
/// `IndexError` kind.
pub const INDEX_ERROR: ErrorKind<'static> = ErrorKind("IndexError");
50
/// Error object: the object-side storage behind an [`Error`] handle.
///
/// `#[repr(C)]` fixes the field order: the base [`Object`] header first, then the
/// [`TVMFFIErrorCell`] holding the kind, message and backtrace (read by [`Error::kind`],
/// [`Error::message`], [`Error::backtrace`]) plus the `update_backtrace` callback.
/// FFI handles are reinterpreted as `*const ErrorObj` in `Error::new` and
/// `Error::from_raised`, so this layout must match the FFI-side error object.
/// NOTE(review): presumably mirrors the C `ErrorObj` definition — confirm before changing fields.
#[repr(C)]
#[derive(Object)]
#[type_key = "ffi.Error"]
#[type_index(TVMFFITypeIndex::kTVMFFIError)]
pub struct ErrorObj {
    // Base object header; must stay the first field for the `#[repr(C)]` layout.
    object: Object,
    // Kind / message / backtrace byte arrays and the backtrace update callback.
    cell: TVMFFIErrorCell,
}
60
/// Error reference class.
///
/// A reference-counted handle to an [`ErrorObj`]. Cloning presumably shares the same
/// underlying object (the count is inspected by [`Error::with_appended_backtrace`]).
#[derive(Clone, ObjectRef)]
pub struct Error {
    data: ObjectArc<ErrorObj>,
}

/// Default result type that uses [`Error`] as the error type.
pub type Result<T, E = Error> = std::result::Result<T, E>;
69
70impl Error {
71    pub fn new(kind: ErrorKind<'_>, message: &str, traceback: &str) -> Self {
72        unsafe {
73            let kind_data = TVMFFIByteArray::from_str(kind.as_str());
74            let message_data = TVMFFIByteArray::from_str(message);
75            let traceback_data = TVMFFIByteArray::from_str(traceback);
76            let mut error_handle: TVMFFIObjectHandle = std::ptr::null_mut();
77            let ret = TVMFFIErrorCreate(
78                &kind_data,
79                &message_data,
80                &traceback_data,
81                &mut error_handle,
82            );
83            assert_eq!(ret, 0, "Failed to create error object");
84            let error_obj = ObjectArc::from_raw(error_handle as *const ErrorObj);
85            Self { data: error_obj }
86        }
87    }
88
89    /// Create a new error by moving from raised error
90    ///
91    /// # Returns
92    /// The error from the raised error
93    pub fn from_raised() -> Self {
94        unsafe {
95            let mut error_handle: TVMFFIObjectHandle = std::ptr::null_mut();
96            TVMFFIErrorMoveFromRaised(&mut error_handle as *mut TVMFFIObjectHandle);
97            assert!(
98                !error_handle.is_null(),
99                "Calling Error::from_raised but no error was raised"
100            );
101            let error_obj = ObjectArc::from_raw(error_handle as *const ErrorObj);
102            Self { data: error_obj }
103        }
104    }
105
106    /// Set the error as raised
107    ///
108    /// # Arguments
109    /// * `error` - The error to set as raised
110    pub fn set_raised(error: &Self) {
111        unsafe {
112            TVMFFIErrorSetRaised(ObjectArc::as_raw(&error.data) as TVMFFIObjectHandle);
113        }
114    }
115
116    /// Get the kind of the error
117    ///
118    /// # Returns
119    /// The kind of the error
120    pub fn kind(&self) -> ErrorKind<'_> {
121        ErrorKind(&self.data.cell.kind.as_str())
122    }
123
124    /// Get the message of the error
125    ///
126    /// # Returns
127    /// The message of the error
128    pub fn message(&self) -> &str {
129        self.data.cell.message.as_str()
130    }
131
132    /// Get the backtrace of the error
133    ///
134    /// # Returns
135    /// The backtrace of the error
136    pub fn backtrace(&self) -> &str {
137        self.data.cell.backtrace.as_str()
138    }
139
140    /// Get the traceback of the error in the order of most recent call last
141    ///
142    /// # Returns
143    /// The traceback of the error
144    pub fn traceback_most_recent_call_last(&self) -> String {
145        let backtrace = self.backtrace();
146        let backtrace_lines = backtrace.split('\n');
147        let mut traceback = String::new();
148        for line in backtrace_lines.rev() {
149            traceback.push_str(line);
150            traceback.push('\n');
151        }
152        traceback
153    }
154
155    /// Append the backtrace to the error
156    ///
157    /// # Arguments
158    /// * `this` - The error to append the backtrace to
159    /// * `backtrace` - The backtrace to append
160    ///
161    /// # Returns
162    /// The error with the appended backtrace
163    pub fn with_appended_backtrace(this: Self, backtrace: &str) -> Self {
164        if ObjectArc::strong_count(&this.data) == 1 {
165            // this is the only reference to the error
166            // we can safely mutate the error
167            unsafe {
168                let backtrace_data = TVMFFIByteArray::from_str(backtrace);
169                (this.data.cell.update_backtrace)(
170                    ObjectArc::as_raw(&this.data) as *mut ErrorObj as *mut c_void,
171                    &backtrace_data,
172                    kTVMFFIBacktraceUpdateModeAppend as i32,
173                );
174                this
175            }
176        } else {
177            // we need to create a new error because there is more than one unique reference
178            // to the error
179            let mut new_backtrace = String::new();
180            new_backtrace.push_str(this.backtrace());
181            new_backtrace.push_str(backtrace);
182            return Error::new(this.kind(), this.message(), &new_backtrace);
183        }
184    }
185}
186
187impl std::fmt::Display for Error {
188    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
189        write!(
190            f,
191            "Traceback (most recent call last):\n{}{}: {}",
192            self.traceback_most_recent_call_last(),
193            self.kind().as_str(),
194            self.message()
195        )
196    }
197}
198
199impl std::fmt::Debug for Error {
200    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
201        std::fmt::Display::fmt(self, f)
202    }
203}
204
205impl std::error::Error for Error {}