tvm
virtual_device.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
26 #ifndef TVM_TARGET_VIRTUAL_DEVICE_H_
27 #define TVM_TARGET_VIRTUAL_DEVICE_H_
28 
29 #include <tvm/ir/transform.h>
30 #include <tvm/target/target.h>
31 
32 #include <string>
33 #include <unordered_set>
34 #include <utility>
35 
36 namespace tvm {
37 
45 using MemoryScope = ffi::String;
46 
// NOTE: These device types are plain ints rather than DLDeviceType enumerators: their values lie
// outside the range of the original enum, and converting them to it would risk undefined
// behavior. They are `inline constexpr` so every translation unit that includes this header
// shares a single definition (they are odr-used by inline code in this header).

/*!
 * \brief A 'null' device type; does not correspond to any DLDeviceType enumerator.
 *
 * TODO(mbs): This is to help us as we transition away from representing the 'homogeneous' case
 * as a singleton target map indexed by the invalid DLDeviceType '0'.
 */
inline constexpr int kNullDeviceType = 0;

/*!
 * \brief An 'invalid' device type; does not correspond to any DLDeviceType enumerator.
 */
inline constexpr int kInvalidDeviceType = -1;
56 
172 class VirtualDeviceNode : public AttrsNodeReflAdapter<VirtualDeviceNode> {
173  private:
185  int /* actually DLDeviceType */ device_type_int;
186 
187  public:
188  DLDeviceType device_type() const { return static_cast<DLDeviceType>(device_type_int); }
189 
197 
206 
213 
218  bool IsFullyUnconstrained() const {
219  return !target.defined() && device_type() == kInvalidDeviceType && virtual_device_id == -1 &&
220  memory_scope.empty();
221  }
222 
227  bool IsFullyConstrained() const {
228  return target.defined() && virtual_device_id != -1 && !memory_scope.empty();
229  }
230 
237  Device ToDevice() const {
238  TVM_FFI_ICHECK(device_type_int != kInvalidDeviceType);
239  TVM_FFI_ICHECK(virtual_device_id != -1);
240  Device device;
241  device.device_type = device_type();
242  device.device_id = virtual_device_id;
243  return device;
244  }
245 
246  static void RegisterReflection() {
247  namespace refl = tvm::ffi::reflection;
248  refl::ObjectDef<VirtualDeviceNode>()
249  .def_ro("device_type_int", &VirtualDeviceNode::device_type_int,
250  "The type of the virtual device.", refl::DefaultValue(kInvalidDeviceType))
251  .def_ro("virtual_device_id", &VirtualDeviceNode::virtual_device_id,
252  "The device id of the virtual device.", refl::DefaultValue(-1))
253  .def_ro("target", &VirtualDeviceNode::target,
254  "The target describing how to compile for the virtual device.",
255  refl::DefaultValue(Target()))
256  .def_ro("memory_scope", &VirtualDeviceNode::memory_scope,
257  "The area of memory w.r.t. the virtual device where data is stored.",
258  refl::DefaultValue(""));
259  }
261 
262  friend class VirtualDevice;
263 };
264 
268 class VirtualDevice : public ffi::ObjectRef {
269  public:
280  TVM_DLL explicit VirtualDevice(int device_type_int = kInvalidDeviceType,
281  int virtual_device_id = -1, Target target = {},
282  MemoryScope memory_scope = {});
283 
286 
291  static VirtualDevice ForDeviceType(DLDeviceType device_type, int virtual_device_id = -1) {
292  TVM_FFI_ICHECK_GT(device_type, 0);
293  return VirtualDevice(device_type, virtual_device_id);
294  }
295  static VirtualDevice ForDeviceType(int device_type, int virtual_device_id = -1) {
296  return ForDeviceType(static_cast<DLDeviceType>(device_type), virtual_device_id);
297  }
298  static VirtualDevice ForDeviceType(const Integer& device_type, int virtual_device_id = -1) {
299  return ForDeviceType(static_cast<int>(device_type->value), virtual_device_id);
300  }
301 
303  static VirtualDevice ForDevice(const Device& device) {
304  return ForDeviceType(device.device_type, device.device_id);
305  }
306 
308  static VirtualDevice ForDeviceAndTarget(const Device& device, Target target) {
309  return VirtualDevice(device.device_type, device.device_id, std::move(target));
310  }
311 
313  static VirtualDevice ForTarget(Target target) {
314  DLDeviceType device_type = static_cast<DLDeviceType>(target->GetTargetDeviceType());
315  return VirtualDevice(device_type, /*virtual_device_id=*/0, std::move(target));
316  }
317 
319  static VirtualDevice ForMemoryScope(MemoryScope memory_scope) {
320  return VirtualDevice(kInvalidDeviceType, -1, {}, std::move(memory_scope));
321  }
322 
324  TVM_DLL static VirtualDevice ForDeviceTargetAndMemoryScope(const Device& device, Target target,
325  MemoryScope memory_scope) {
326  return VirtualDevice(device.device_type, device.device_id, std::move(target),
327  std::move(memory_scope));
328  }
329 
335  TVM_DLL static ffi::Optional<VirtualDevice> Join(const VirtualDevice& lhs,
336  const VirtualDevice& rhs);
337 
342  TVM_DLL static VirtualDevice Default(const VirtualDevice& lhs, const VirtualDevice& rhs);
343 
345 
346  friend class VirtualDeviceCache; // Private implementation helper.
347 };
348 
357 class TVM_DLL VirtualDeviceCache {
358  public:
360  VirtualDevice Make(int device_type = kInvalidDeviceType, int virtual_device_id = -1,
361  Target target = {}, MemoryScope memory_scope = {});
362 
366  VirtualDevice Unique(const VirtualDevice& virtual_device);
367 
368  private:
370  std::unordered_set<VirtualDevice, ffi::StructuralHash, ffi::StructuralEqual> cache_;
371 };
372 
/*!
 * \brief The string key "virtual_device".
 *
 * NOTE(review): presumably used as the attribute name under which a VirtualDevice is attached;
 * confirm against callers. Declared `inline` so all translation units share one definition.
 */
inline constexpr const char* kVirtualDevice = "virtual_device";
379 
380 } // namespace tvm
381 
382 #endif // TVM_TARGET_VIRTUAL_DEVICE_H_
Adapter for AttrsNode with the new reflection API.
Definition: attrs.h:391
Base class of all attribute class.
Definition: attrs.h:102
Container of constant int that adds more constructors.
Definition: expr.h:601
Managed reference class to TargetNode.
Definition: target.h:135
A cache of VirtualDevices. This can be used:
Definition: virtual_device.h:357
VirtualDevice Unique(const VirtualDevice &virtual_device)
Returns the unique VirtualDevice structurally equal to the given virtual_device.
VirtualDevice Make(int device_type=kInvalidDeviceType, int virtual_device_id=-1, Target target={}, MemoryScope memory_scope={})
Returns the unique VirtualDevice representing given fields.
Describes at compile time the constraints on where data is to be stored at runtime, down to the (virtual) device and memory scope.
Definition: virtual_device.h:172
static void RegisterReflection()
Definition: virtual_device.h:246
DLDeviceType device_type() const
Definition: virtual_device.h:188
Target target
The Target describing how to compile for the virtual device.
Definition: virtual_device.h:205
int virtual_device_id
The device identifier for the virtual device. This must be resolved to a physical device identifier e...
Definition: virtual_device.h:196
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("target.VirtualDevice", VirtualDeviceNode, BaseAttrsNode)
MemoryScope memory_scope
The scope of memory w.r.t. the virtual device which holds data.
Definition: virtual_device.h:212
Device ToDevice() const
Returns the (virtual) Device implied by this VirtualDevice. Both the device_type and virtual_device_id must be constrained.
Definition: virtual_device.h:237
bool IsFullyUnconstrained() const
Returns true if virtual device is 'fully unconstrained', i.e. no target, device type, virtual device id or memory scope is constrained.
Definition: virtual_device.h:218
bool IsFullyConstrained() const
Returns true if virtual device is 'fully constrained', i.e. target, device id and memory scope are all constrained.
Definition: virtual_device.h:227
Managed reference class to VirtualDeviceNode.
Definition: virtual_device.h:268
VirtualDevice(int device_type_int=kInvalidDeviceType, int virtual_device_id=-1, Target target={}, MemoryScope memory_scope={})
Construct a virtual device.
static VirtualDevice ForDeviceTargetAndMemoryScope(const Device &device, Target target, MemoryScope memory_scope)
Returns the VirtualDevice for device, target and memory_scope.
Definition: virtual_device.h:324
static ffi::Optional< VirtualDevice > Join(const VirtualDevice &lhs, const VirtualDevice &rhs)
Returns the 'join' of lhs and rhs. The result will agree pointwise with lhs and rhs on all their constrained fields.
static VirtualDevice ForMemoryScope(MemoryScope memory_scope)
Returns the VirtualDevice for memory_scope alone.
Definition: virtual_device.h:319
static VirtualDevice Default(const VirtualDevice &lhs, const VirtualDevice &rhs)
Returns the 'default' of lhs and rhs. The result will be lhs, except any unconstrained fields in lhs ...
static VirtualDevice FullyUnconstrained()
Returns the unique fully unconstrained VirtualDevice.
static VirtualDevice ForDeviceAndTarget(const Device &device, Target target)
Returns the VirtualDevice for device and target.
Definition: virtual_device.h:308
static VirtualDevice ForDevice(const Device &device)
Returns the VirtualDevice for device.
Definition: virtual_device.h:303
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(VirtualDevice, ffi::ObjectRef, VirtualDeviceNode)
static VirtualDevice ForDeviceType(const Integer &device_type, int virtual_device_id=-1)
Definition: virtual_device.h:298
static VirtualDevice ForDeviceType(DLDeviceType device_type, int virtual_device_id=-1)
Returns the VirtualDevice for device_type and (if not -1) virtual_device_id. The target and memory scope will be unconstrained.
Definition: virtual_device.h:291
static VirtualDevice ForDeviceType(int device_type, int virtual_device_id=-1)
Definition: virtual_device.h:295
static VirtualDevice ForTarget(Target target)
Returns the VirtualDevice for target.
Definition: virtual_device.h:313
constexpr const char * device_type
The device type.
Definition: stmt.h:1011
An object that builds and maintains block scope and StmtSref mapping for Dependence analysis.
Definition: analyzer.h:37
constexpr const char * kVirtualDevice
Definition: virtual_device.h:378
constexpr int kInvalidDeviceType
Definition: virtual_device.h:55
ffi::String MemoryScope
Abstract label for an area of memory.
Definition: global_info.h:37
DLDevice Device
Definition: device_api.h:43
constexpr int kNullDeviceType
Definition: virtual_device.h:52
Compilation target object.