tvm
virtual_device.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
26 #ifndef TVM_TARGET_VIRTUAL_DEVICE_H_
27 #define TVM_TARGET_VIRTUAL_DEVICE_H_
28 
29 #include <tvm/ir/transform.h>
30 #include <tvm/target/target.h>
31 
32 #include <string>
33 #include <unordered_set>
34 #include <utility>
35 
36 namespace tvm {
37 
/*! \brief Abstract label for an area of memory (an alias of String). */
45 using MemoryScope = String;
46 
// NOTE: These sentinels are plain ints rather than enumerators: their values lie outside the
// range of the original DLDeviceType enum, and converting such out-of-range values to the enum
// can result in undefined behavior.

// A 'null' device type; it does not correspond to any DLDeviceType enumerator.
// TODO(mbs): This helps us as we transition away from representing the 'homogeneous' case
// as a singleton target map indexed by the invalid DLDeviceType '0'.
constexpr int kNullDeviceType{0};

// An 'invalid' device type; it does not correspond to any DLDeviceType enumerator.
constexpr int kInvalidDeviceType{-1};
56 
176 class VirtualDeviceNode : public AttrsNode<VirtualDeviceNode> {
177  private:
189  int /* actually DLDeviceType */ device_type_int;
190 
191  public:
192  DLDeviceType device_type() const { return static_cast<DLDeviceType>(device_type_int); }
193 
201 
210 
217 
222  bool IsFullyUnconstrained() const {
223  return !target.defined() && device_type() == kInvalidDeviceType && virtual_device_id == -1 &&
225  }
226 
231  bool IsFullyConstrained() const {
232  return target.defined() && virtual_device_id != -1 && !memory_scope.empty();
233  }
234 
241  Device ToDevice() const {
242  ICHECK(device_type_int != kInvalidDeviceType);
243  ICHECK(virtual_device_id != -1);
244  Device device;
245  device.device_type = device_type();
246  device.device_id = virtual_device_id;
247  return device;
248  }
249 
251  TVM_ATTR_FIELD(device_type_int)
252  .describe("The type of the virtual device.")
253  .set_default(kInvalidDeviceType);
255  .describe("The device id of the virtual device.")
256  .set_default(-1);
258  .describe("The target describing how to compile for the virtual device.")
259  .set_default(Target());
261  .describe("The area of memory w.r.t. the virtual device where data is stored.")
262  .set_default("");
263  }
264 
265  friend class VirtualDevice;
266 };
267 
271 class VirtualDevice : public ObjectRef {
272  public:
283  explicit VirtualDevice(int device_type_int = kInvalidDeviceType, int virtual_device_id = -1,
284  Target target = {}, MemoryScope memory_scope = {});
285 
288 
293  static VirtualDevice ForDeviceType(DLDeviceType device_type, int virtual_device_id = -1) {
294  ICHECK_GT(device_type, 0);
295  return VirtualDevice(device_type, virtual_device_id);
296  }
297  static VirtualDevice ForDeviceType(int device_type, int virtual_device_id = -1) {
298  return ForDeviceType(static_cast<DLDeviceType>(device_type), virtual_device_id);
299  }
300  static VirtualDevice ForDeviceType(const Integer& device_type, int virtual_device_id = -1) {
301  return ForDeviceType(static_cast<int>(device_type->value), virtual_device_id);
302  }
303 
305  static VirtualDevice ForDevice(const Device& device) {
306  return ForDeviceType(device.device_type, device.device_id);
307  }
308 
310  static VirtualDevice ForDeviceAndTarget(const Device& device, Target target) {
311  return VirtualDevice(device.device_type, device.device_id, std::move(target));
312  }
313 
315  static VirtualDevice ForTarget(Target target) {
316  DLDeviceType device_type = static_cast<DLDeviceType>(target->GetTargetDeviceType());
317  return VirtualDevice(device_type, /*virtual_device_id=*/0, std::move(target));
318  }
319 
321  static VirtualDevice ForMemoryScope(MemoryScope memory_scope) {
322  return VirtualDevice(kInvalidDeviceType, -1, {}, std::move(memory_scope));
323  }
324 
326  TVM_DLL static VirtualDevice ForDeviceTargetAndMemoryScope(const Device& device, Target target,
327  MemoryScope memory_scope) {
328  return VirtualDevice(device.device_type, device.device_id, std::move(target),
329  std::move(memory_scope));
330  }
331 
337  static Optional<VirtualDevice> Join(const VirtualDevice& lhs, const VirtualDevice& rhs);
338 
343  static VirtualDevice Default(const VirtualDevice& lhs, const VirtualDevice& rhs);
344 
346 
347  friend class VirtualDeviceCache; // Private implementation helper.
348 };
349 
359  public:
361  VirtualDevice Make(int device_type = kInvalidDeviceType, int virtual_device_id = -1,
362  Target target = {}, MemoryScope memory_scope = {});
363 
367  VirtualDevice Unique(const VirtualDevice& virtual_device);
368 
369  private:
371  std::unordered_set<VirtualDevice, StructuralHash, StructuralEqual> cache_;
372 };
373 
379 constexpr const char* kVirtualDevice = "virtual_device";
380 
381 } // namespace tvm
382 
383 #endif // TVM_TARGET_VIRTUAL_DEVICE_H_
The base class of all attribute nodes; uses the "curiously recurring template pattern".
Definition: attrs.h:870
Container of constant int that adds more constructors.
Definition: expr.h:632
Managed reference class to TargetNode.
Definition: target.h:200
A cache of VirtualDevices. This can be used:
Definition: virtual_device.h:358
VirtualDevice Unique(const VirtualDevice &virtual_device)
Returns the unique VirtualDevice structurally equal to the given virtual_device.
VirtualDevice Make(int device_type=kInvalidDeviceType, int virtual_device_id=-1, Target target={}, MemoryScope memory_scope={})
Returns the unique VirtualDevice representing given fields.
Describes at compile time the constraints on where data is to be stored at runtime down to the (virtu...
Definition: virtual_device.h:176
DLDeviceType device_type() const
Definition: virtual_device.h:192
Target target
The Target describing how to compile for the virtual device.
Definition: virtual_device.h:209
int virtual_device_id
The device identifier for the virtual device. This must be resolved to a physical device identifier e...
Definition: virtual_device.h:200
TVM_DECLARE_ATTRS(VirtualDeviceNode, "VirtualDevice")
Definition: virtual_device.h:250
MemoryScope memory_scope
The scope of memory w.r.t. the virtual device which holds data.
Definition: virtual_device.h:216
Device ToDevice() const
Returns the (virtual) Device implied by this VirtualDevice. Both the device_type and virtual_device_id must be constrained.
Definition: virtual_device.h:241
bool IsFullyUnconstrained() const
Returns true if virtual device is 'fully unconstrained', ie no target/device type,...
Definition: virtual_device.h:222
bool IsFullyConstrained() const
Returns true if the virtual device is 'fully constrained', i.e. the target, device id and memory scope are all specified.
Definition: virtual_device.h:231
Managed reference class to VirtualDeviceNode.
Definition: virtual_device.h:271
VirtualDevice(int device_type_int=kInvalidDeviceType, int virtual_device_id=-1, Target target={}, MemoryScope memory_scope={})
Construct a virtual device.
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(VirtualDevice, ObjectRef, VirtualDeviceNode)
static VirtualDevice ForDeviceTargetAndMemoryScope(const Device &device, Target target, MemoryScope memory_scope)
Returns the VirtualDevice for device, target and memory_scope.
Definition: virtual_device.h:326
static VirtualDevice ForMemoryScope(MemoryScope memory_scope)
Returns the VirtualDevice for memory_scope alone.
Definition: virtual_device.h:321
static VirtualDevice Default(const VirtualDevice &lhs, const VirtualDevice &rhs)
Returns the 'default' of lhs and rhs. The result will be lhs, except any unconstrained fields in lhs ...
static VirtualDevice FullyUnconstrained()
Returns the unique fully unconstrained VirtualDevice.
static VirtualDevice ForDeviceAndTarget(const Device &device, Target target)
Returns the VirtualDevice for device and target.
Definition: virtual_device.h:310
static VirtualDevice ForDevice(const Device &device)
Returns the VirtualDevice for device.
Definition: virtual_device.h:305
static Optional< VirtualDevice > Join(const VirtualDevice &lhs, const VirtualDevice &rhs)
Returns the 'join' of lhs and rhs. The result will agree pointwise with lhs and rhs on all their cons...
static VirtualDevice ForDeviceType(const Integer &device_type, int virtual_device_id=-1)
Definition: virtual_device.h:300
static VirtualDevice ForDeviceType(DLDeviceType device_type, int virtual_device_id=-1)
Returns the VirtualDevice for device_type and (if not -1) virtual_device_id. The target and memory scope are left unconstrained.
Definition: virtual_device.h:293
static VirtualDevice ForDeviceType(int device_type, int virtual_device_id=-1)
Definition: virtual_device.h:297
static VirtualDevice ForTarget(Target target)
Returns the VirtualDevice for target.
Definition: virtual_device.h:315
Base class of all object reference.
Definition: object.h:519
bool defined() const
Definition: object.h:552
Optional container that represents a nullable variant of T.
Definition: optional.h:51
Reference to string objects.
Definition: string.h:98
bool empty() const
Returns true if the string is empty.
Definition: string.h:208
#define TVM_ATTR_FIELD(FieldName)
Declare an attribute field.
Definition: attrs.h:76
constexpr const char * device_type
The device type.
Definition: stmt.h:1422
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
constexpr const char * kVirtualDevice
Definition: virtual_device.h:379
constexpr int kInvalidDeviceType
Definition: virtual_device.h:55
DLDevice Device
Definition: ndarray.h:43
String MemoryScope
Abstract label for an area of memory.
Definition: global_info.h:36
constexpr int kNullDeviceType
Definition: virtual_device.h:52
Compilation target object.