tvm
profiling.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_RUNTIME_PROFILING_H_
25 #define TVM_RUNTIME_PROFILING_H_
26 
28 #include <tvm/runtime/device_api.h>
29 #include <tvm/runtime/object.h>
31 #include <tvm/runtime/registry.h>
32 
33 #include <stack>
34 #include <string>
35 #include <unordered_map>
36 #include <utility>
37 #include <vector>
38 
39 namespace tvm {
40 
41 namespace runtime {
42 
/*!
 * \brief Abstract base class for all Timer implementations.
 *
 * Subclasses implement the three pure virtual methods below: Start, Stop and
 * SyncAndGetElapsedNanos.
 */
49 class TimerNode : public Object {
50  public:
  /*! \brief Start the timer. */
55  virtual void Start() = 0;
  /*! \brief Stop the timer. */
60  virtual void Stop() = 0;
  /*!
   * \brief Synchronize timer state and return the elapsed time between Start and Stop.
   * \return The elapsed time; the method name indicates the unit is nanoseconds.
   */
71  virtual int64_t SyncAndGetElapsedNanos() = 0;
72 
  /*! \brief Virtual destructor, since this is a polymorphic base with pure virtual methods. */
73  virtual ~TimerNode() {}
74 
  /*! \brief Type key identifying this node type in the TVM object system. */
75  static constexpr const char* _type_key = "TimerNode";
77 };
78 
/*!
 * \brief Timer for a specific device (ObjectRef handle to a TimerNode).
 *
 * NOTE(review): source lines 87-138 and 141 of this class are not shown in this
 * listing; consult the real header for the remaining members.
 */
85 class Timer : public ObjectRef {
86  public:
  /*!
   * \brief Obtain a timer for \p dev and start it.
   * \param dev The device to time.
   * \return The timer handle.
   * NOTE(review): "starts the timer" is inferred from the method name; confirm in the
   * implementation.
   */
139  static TVM_DLL Timer Start(Device dev);
140 
142 };
143 
152 
153 namespace profiling {
/*!
 * \brief Wrapper for Device because Device is not passable across the PackedFunc interface.
 *
 * NOTE(review): the wrapped `Device device` member (source line 159) is not shown in
 * this listing.
 */
157 struct DeviceWrapperNode : public Object {
160 
  /*!
   * \brief Wrap a device.
   * \param device The device to store in the `device` member.
   */
162  explicit DeviceWrapperNode(Device device) : device(device) {}
163 
  /*! \brief Type key identifying this node type in the TVM object system. */
164  static constexpr const char* _type_key = "runtime.profiling.DeviceWrapper";
166 };
167 
/*! \brief ObjectRef handle wrapping a Device (backed by DeviceWrapperNode). */
169 class DeviceWrapper : public ObjectRef {
170  public:
  /*!
   * \brief Allocate a new DeviceWrapperNode holding \p dev.
   * \param dev The device to wrap.
   */
171  explicit DeviceWrapper(Device dev) { data_ = make_object<DeviceWrapperNode>(dev); }
173 };
174 
/*!
 * \brief Data collected from a profiling run. Includes per-call metrics and per-device metrics.
 *
 * NOTE(review): the data members `calls` (source line 186) and `device_metrics`
 * (source line 194) are not shown in this listing.
 */
177 class ReportNode : public Object {
178  public:
  /*! \brief Produce a CSV representation of this report (column layout not visible here). */
199  String AsCSV() const;
  /*!
   * \brief Produce a human-readable table of this report.
   * \param sort Presumably whether to sort rows — confirm in the implementation.
   * \param aggregate Presumably whether to aggregate repeated calls — confirm.
   * \param compute_col_sums Presumably whether to append column totals — confirm.
   * NOTE(review): the argument order here is (sort, aggregate), the reverse of
   * Profiler::Report(aggregate, sort); check positional call sites carefully.
   */
214  String AsTable(bool sort = true, bool aggregate = true, bool compute_col_sums = true) const;
  /*! \brief Produce a JSON representation of this report. */
247  String AsJSON() const;
248 
  /*! \brief Type key identifying this node type in the TVM object system. */
249  static constexpr const char* _type_key = "runtime.profiling.Report";
251 };
252 
/*! \brief ObjectRef handle to a ReportNode. */
253 class Report : public ObjectRef {
254  public:
  /*!
   * \brief Construct a report.
   * \param calls A list of function calls and the metrics recorded for each call.
   * \param device_metrics Metrics collected for the entire run, keyed per device.
   */
259  explicit Report(Array<Map<String, ObjectRef>> calls,
260  Map<String, Map<String, ObjectRef>> device_metrics);
261 
  /*!
   * \brief Deserialize a report from a JSON string.
   * \param json The JSON text (presumably in the format produced by ReportNode::AsJSON —
   *        confirm).
   * \return The reconstructed report.
   */
266  static Report FromJSON(String json);
268 };
269 
/*!
 * \brief Interface for user defined profiling metric collection.
 *
 * Subclasses implement Init, Start and Stop.
 */
288 class MetricCollectorNode : public Object {
289  public:
  /*!
   * \brief Initialize the collector for the given devices.
   * \param devs The devices that will be profiled (wrapped so they cross the PackedFunc
   *        interface).
   */
294  virtual void Init(Array<DeviceWrapper> devs) = 0;
  /*!
   * \brief Begin collecting metrics on \p dev.
   * \param dev The device the measured call runs on.
   * \return An opaque object; presumably it is handed back to Stop for the same call
   *         (CallFrame stores collectors paired with an ObjectRef) — confirm.
   */
301  virtual ObjectRef Start(Device dev) = 0;
  /*!
   * \brief Finish collecting metrics.
   * \param obj The object previously returned by Start (assumption — confirm).
   * \return Collected metrics, presumably keyed by metric name.
   */
307  virtual Map<String, ObjectRef> Stop(ObjectRef obj) = 0;
308 
  /*! \brief Virtual destructor, since this is a polymorphic base with pure virtual methods. */
309  virtual ~MetricCollectorNode() {}
310 
  /*! \brief Type key identifying this node type in the TVM object system. */
311  static constexpr const char* _type_key = "runtime.profiling.MetricCollector";
313 };
314 
/*!
 * \brief Wrapper for MetricCollectorNode.
 *
 * NOTE(review): the body (source line 318) is not shown in this listing.
 */
316 class MetricCollector : public ObjectRef {
317  public:
319 };
320 
/*!
 * \brief Record for a single profiled call, as stored by Profiler (in `calls_` and `in_flight_`).
 *
 * NOTE(review): the fields `dev` (source line 324), `name` (line 326) and `timer`
 * (line 328) are not shown in this listing.
 */
322 struct CallFrame {
  /*! \brief Extra metrics attached to this call (presumably passed via Profiler::StartCall/StopCall). */
330  std::unordered_map<std::string, ObjectRef> extra_metrics;
  /*!
   * \brief Each MetricCollector paired with an ObjectRef; presumably the value its Start
   *        returned for this call — confirm.
   */
334  std::vector<std::pair<MetricCollector, ObjectRef>> extra_collectors;
335 };
336 
/*!
 * \brief Records per-call metrics on a set of devices and produces a Report.
 *
 * Plain C++ class (not an Object). Running state is tracked in `is_running_`.
 */
356 class Profiler {
357  public:
  /*!
   * \brief Construct a profiler.
   * \param devs The devices to profile.
   * \param metric_collectors Additional metric collectors to use.
   */
370  explicit Profiler(std::vector<Device> devs, std::vector<MetricCollector> metric_collectors);
  /*! \brief Start profiling. */
375  void Start();
  /*! \brief Stop profiling. */
380  void Stop();
  /*!
   * \brief Begin recording a call.
   * \param name Name of the call.
   * \param dev Device the call runs on.
   * \param extra_metrics Extra metrics to attach to the call (empty by default).
   */
391  void StartCall(String name, Device dev,
392  std::unordered_map<std::string, ObjectRef> extra_metrics = {});
  /*!
   * \brief Finish recording a call.
   * \param extra_metrics Extra metrics to attach to the call (empty by default).
   * NOTE(review): takes no name/device; presumably closes the most recently started call,
   * since in-flight calls are kept in a std::stack — confirm in the implementation.
   */
397  void StopCall(std::unordered_map<std::string, ObjectRef> extra_metrics = {});
  /*!
   * \brief Build a report from the recorded calls.
   * \param aggregate Presumably whether to aggregate repeated calls — confirm.
   * \param sort Presumably whether to sort the result — confirm.
   * The return type is spelled profiling::Report because, inside Profiler, the unqualified
   * name Report would refer to this member function.
   * NOTE(review): argument order (aggregate, sort) is the reverse of
   * ReportNode::AsTable(sort, aggregate, ...); check positional call sites.
   */
403  profiling::Report Report(bool aggregate = true, bool sort = true);
  /*! \brief Check if the profiler is currently running. */
407  bool IsRunning() const { return is_running_; }
408 
409  private:
  // Devices being profiled.
410  std::vector<Device> devs_;
  // Running flag returned by IsRunning(); starts out false.
411  bool is_running_{false};
  // Recorded call frames (presumably completed calls — confirm).
412  std::vector<CallFrame> calls_;
  // Call frames that have been started; a stack, so nested calls unwind LIFO.
413  std::stack<CallFrame> in_flight_;
  // User-supplied metric collectors.
414  std::vector<MetricCollector> collectors_;
415 };
416 
417 /*! \brief A duration in time. */
418 class DurationNode : public Object {
419  public:
420  /*! \brief The duration as a floating point number of microseconds. */
421  double microseconds;
422 
423  /*! \brief Construct a new duration.
424  * \param a The duration in microseconds.
425  */
426  explicit DurationNode(double a) : microseconds(a) {}
427 
  /*! \brief Type key identifying this node type in the TVM object system. */
428  static constexpr const char* _type_key = "runtime.profiling.Duration";
430 };
431 
432 /*! \brief A percentage of something. */
433 class PercentNode : public Object {
434  public:
435  /*! \brief The percent as a floating point value out of 100%, i.e. if `percent` is 10 then we have 10%. */
436  double percent;
437 
438  /*! \brief Construct a new percentage.
439  * \param a The percentage out of 100.
440  */
441  explicit PercentNode(double a) : percent(a) {}
442 
  /*! \brief Type key identifying this node type in the TVM object system. */
443  static constexpr const char* _type_key = "runtime.profiling.Percent";
445 };
446 
447 /*! \brief A count of something. */
448 class CountNode : public Object {
449  public:
450  /*! \brief The actual count. */
451  int64_t value;
452 
453  /*! \brief Construct a new count.
454  * \param a The count.
455  */
456  explicit CountNode(int64_t a) : value(a) {}
457 
  /*! \brief Type key identifying this node type in the TVM object system. */
458  static constexpr const char* _type_key = "runtime.profiling.Count";
460 };
461 
/*!
 * \brief String representation of the shapes of several arrays.
 * \param shapes The arrays whose shapes are described (exact format not visible here).
 */
466 String ShapeString(const std::vector<NDArray>& shapes);
/*!
 * \brief String representation of a shape held in an NDArray.
 * \param shape The shape (presumably a 1-D array of extents — confirm).
 * \param dtype The element data type to include in the description.
 */
472 String ShapeString(NDArray shape, DLDataType dtype);
/*!
 * \brief String representation of a shape encoded as a vector.
 * \param shape The shape extents.
 * \param dtype The element data type to include in the description.
 */
478 String ShapeString(const std::vector<int64_t>& shape, DLDataType dtype);
479 
480 } // namespace profiling
481 } // namespace runtime
482 } // namespace tvm
483 
484 #endif // TVM_RUNTIME_PROFILING_H_
Definition: profiling.h:448
double microseconds
Definition: profiling.h:421
Definition: profiling.h:418
static constexpr const char * _type_key
Definition: profiling.h:75
bool IsRunning() const
Check if the profiler is currently running.
Definition: profiling.h:407
virtual int64_t SyncAndGetElapsedNanos()=0
Synchronize timer state and return elapsed time between Start and Stop.
CountNode(int64_t a)
Definition: profiling.h:456
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:36
Device dev
Definition: profiling.h:324
DeviceWrapperNode(Device device)
Definition: profiling.h:162
std::vector< std::pair< MetricCollector, ObjectRef > > extra_collectors
Definition: profiling.h:334
double percent
Definition: profiling.h:436
virtual void Start()=0
Start the timer.
Map< String, Map< String, ObjectRef > > device_metrics
Metrics collected for the entire run of the model on a per-device basis.
Definition: profiling.h:194
Base class of all object containers.
Definition: object.h:165
Timer DefaultTimer(Device dev)
Default timer if one does not exist for the device.
Managed NDArray. The array is backed by reference counted blocks.
Definition: ndarray.h:59
Interface for user defined profiling metric collection.
Definition: profiling.h:288
#define TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:737
virtual void Stop()=0
Stop the timer.
Wrapper for MetricCollectorNode.
Definition: profiling.h:316
DurationNode(double a)
Definition: profiling.h:426
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:270
std::unordered_map< std::string, ObjectRef > extra_metrics
Definition: profiling.h:330
Array< Map< String, ObjectRef > > calls
A list of function calls and the metrics recorded for that call.
Definition: profiling.h:186
Timer timer
Definition: profiling.h:328
Base class for all Timer implementations.
Definition: profiling.h:49
Reference to string objects.
Definition: string.h:129
Data collected from a profiling run. Includes per-call metrics and per-device metrics.
Definition: profiling.h:177
Abstract device memory management API.
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of input tensor.
Definition: transform.h:1608
DLDevice Device
Definition: ndarray.h:43
Definition: profiling.h:253
Base class of all object reference.
Definition: object.h:504
Definition: profiling.h:322
TVM_DECLARE_BASE_OBJECT_INFO(TimerNode, Object)
Definition: profiling.h:433
A managed object in the TVM runtime.
#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType)
Helper macro to declare type information in a final class.
Definition: object.h:664
Device device
Definition: profiling.h:159
int64_t value
Definition: profiling.h:451
Map container of NodeRef->NodeRef in DSL graph. Map implements copy-on-write semantics, which means the map is mutable, but a copy is made when the map is referenced in more than one place.
Definition: map.h:1235
Wrapper for Device because Device is not passable across the PackedFunc interface.
Definition: profiling.h:157
virtual ~TimerNode()
Definition: profiling.h:73
virtual ~MetricCollectorNode()
Definition: profiling.h:309
Wrapper for Device.
Definition: profiling.h:169
Definition: profiling.h:356
#define TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:721
String ShapeString(const std::vector< int64_t > &shape, DLDataType dtype)
String representation of a shape encoded as a vector.
String name
Definition: profiling.h:326
Type-erased function used across the TVM API.
Timer for a specific device.
Definition: profiling.h:85
This file defines the TVM global function registry.
DeviceWrapper(Device dev)
Definition: profiling.h:171
PercentNode(double a)
Definition: profiling.h:441