AT_DISPATCH to AT_DISPATCH_V2 Converter
This skill helps convert PyTorch's legacy AT_DISPATCH macros to the new AT_DISPATCH_V2 format, as defined in aten/src/ATen/Dispatch_v2.h.
When to use this skill
Use this skill when:
- Converting AT_DISPATCH_* macros to AT_DISPATCH_V2
- Porting ATen kernels to use the new dispatch API
- Working with files in `aten/src/ATen/native/` that use dispatch macros
- User mentions "AT_DISPATCH", "dispatch v2", "Dispatch_v2.h", or macro conversion
Quick reference
Old format:
```cpp
AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, dtype, "kernel_name", [&]() {
  // lambda body
});
```
New format:
```cpp
AT_DISPATCH_V2(dtype, "kernel_name", AT_WRAP([&]() {
  // lambda body
}), AT_EXPAND(AT_ALL_TYPES), kBFloat16, kHalf, kBool);
```
Key transformations
- Reorder arguments: `scalar_type` and `name` come first, then the lambda, then the types
- Wrap the lambda: use `AT_WRAP(lambda)` to handle internal commas
- Expand type groups: use `AT_EXPAND(AT_ALL_TYPES)` instead of implicit expansion
- List individual types: add extra types (`kHalf`, `kBFloat16`, etc.) after the expanded groups
- Add include: `#include <ATen/Dispatch_v2.h>` near the other Dispatch includes
Instructions
Step 1: Add the Dispatch_v2.h include
Add the v2 header near the existing `#include <ATen/Dispatch.h>`:

```cpp
#include <ATen/Dispatch.h>
#include <ATen/Dispatch_v2.h>
```
Keep the old Dispatch.h include for now (other code may still need it).
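If you script Step 1, the edit amounts to inserting one line after the existing include. The sketch below shows that as plain string manipulation; `add_v2_include` is a hypothetical helper, not part of ATen or any PyTorch tooling, and it assumes the include is spelled exactly `#include <ATen/Dispatch.h>`.

```cpp
#include <cassert>
#include <string>

// Hypothetical helper: insert the v2 include on the line after the existing
// #include <ATen/Dispatch.h>, leaving the old include in place.
// Returns the source unchanged if the v2 header is already included or if
// there is no Dispatch.h include to anchor on.
std::string add_v2_include(const std::string& source) {
  const std::string v1 = "#include <ATen/Dispatch.h>";
  const std::string v2 = "#include <ATen/Dispatch_v2.h>";
  if (source.find(v2) != std::string::npos) return source;  // already present
  const std::size_t pos = source.find(v1);
  if (pos == std::string::npos) return source;  // nothing to anchor on
  const std::size_t line_end = source.find('\n', pos);
  if (line_end == std::string::npos) return source + "\n" + v2 + "\n";
  // Splice the new line in directly after the old include's newline.
  return source.substr(0, line_end + 1) + v2 + "\n" + source.substr(line_end + 1);
}
```

Running it twice is harmless: the second call finds the v2 include and returns the text unchanged.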
Step 2: Identify the old dispatch pattern
Common patterns to convert:
- `AT_DISPATCH_ALL_TYPES_AND{2,3,4}(type1, type2, ..., scalar_type, name, lambda)`
- `AT_DISPATCH_FLOATING_TYPES_AND{2,3}(type1, type2, ..., scalar_type, name, lambda)`
- `AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND{2,3}(type1, ..., scalar_type, name, lambda)`
- `AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND{2,3}(type1, ..., scalar_type, name, lambda)`
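Because the arity is encoded in the macro name, identifying the pattern is mostly a string split. A minimal sketch, assuming names follow the `AT_DISPATCH_<BASE>_AND<N>` shape listed above; `parse_legacy_name` and `LegacyDispatchName` are hypothetical names for illustration only. It also treats a bare `_AND` suffix (as in `AT_DISPATCH_ALL_TYPES_AND`, which takes one extra type) as a count of one.

```cpp
#include <cassert>
#include <optional>
#include <string>

// Parts of a legacy macro name such as "AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2".
struct LegacyDispatchName {
  std::string base;  // e.g. "ALL_TYPES_AND_COMPLEX"
  int extra_types;   // N from an "_AND<N>" suffix; 1 for a bare "_AND"; else 0
};

// Hypothetical helper (not part of ATen): splits a legacy macro name into its
// base type group and the number of individually listed extra types.
std::optional<LegacyDispatchName> parse_legacy_name(const std::string& macro) {
  const std::string prefix = "AT_DISPATCH_";
  if (macro.rfind(prefix, 0) != 0) return std::nullopt;  // not a dispatch macro
  const std::string rest = macro.substr(prefix.size());
  // Only a trailing "_AND" or "_AND<digit>" is an arity suffix; the inner
  // "_AND_" of "ALL_TYPES_AND_COMPLEX" belongs to the base name.
  const std::size_t pos = rest.rfind("_AND");
  if (pos != std::string::npos) {
    if (pos + 4 == rest.size()) return LegacyDispatchName{rest.substr(0, pos), 1};
    const char digit = rest[pos + 4];
    if (pos + 5 == rest.size() && digit >= '1' && digit <= '9')
      return LegacyDispatchName{rest.substr(0, pos), digit - '0'};
  }
  return LegacyDispatchName{rest, 0};
}
```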
Step 3: Map the old macro to type groups
Identify which type group macro corresponds to the base types:
| Old macro base | AT_DISPATCH_V2 type group |
|---|---|
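| `ALL_TYPES` | `AT_EXPAND(AT_ALL_TYPES)` |
| `FLOATING_TYPES` | `AT_EXPAND(AT_FLOATING_TYPES)` |
| `INTEGRAL_TYPES` | `AT_EXPAND(AT_INTEGRAL_TYPES)` |
| `COMPLEX_TYPES` | `AT_EXPAND(AT_COMPLEX_TYPES)` |
| `ALL_TYPES_AND_COMPLEX` | `AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX)` |
| `FLOATING_AND_COMPLEX_TYPES` | `AT_EXPAND(AT_FLOATING_TYPES), AT_EXPAND(AT_COMPLEX_TYPES)` |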
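Composite groups can also be written as several `AT_EXPAND()` entries; both spellings below cover the same types:

```cpp
// Old: AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(type1, type2, ...)
// New: AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), type1, type2
//  or: AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_COMPLEX_TYPES), type1, type2
```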
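The mapping table can be encoded as a lookup so that an unrecognized base falls through to manual review instead of being guessed. This is a sketch with a hypothetical `v2_type_groups` function; the strings it returns are the argument text to paste into the V2 call.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical lookup (not part of ATen) from the legacy base name to the
// AT_DISPATCH_V2 type-group arguments, following the mapping table.
std::vector<std::string> v2_type_groups(const std::string& base) {
  static const std::map<std::string, std::vector<std::string>> table = {
      {"ALL_TYPES", {"AT_EXPAND(AT_ALL_TYPES)"}},
      {"FLOATING_TYPES", {"AT_EXPAND(AT_FLOATING_TYPES)"}},
      {"INTEGRAL_TYPES", {"AT_EXPAND(AT_INTEGRAL_TYPES)"}},
      {"COMPLEX_TYPES", {"AT_EXPAND(AT_COMPLEX_TYPES)"}},
      {"ALL_TYPES_AND_COMPLEX", {"AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX)"}},
      // No composite group for floating + complex is listed, so emit two entries.
      {"FLOATING_AND_COMPLEX_TYPES",
       {"AT_EXPAND(AT_FLOATING_TYPES)", "AT_EXPAND(AT_COMPLEX_TYPES)"}},
  };
  const auto it = table.find(base);
  // An empty result means the base is not covered and needs manual review.
  return it == table.end() ? std::vector<std::string>{} : it->second;
}
```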
Step 4: Extract the individual types
From `AT_DISPATCH_*_AND2(type1, type2, ...)` or `AT_DISPATCH_*_AND3(type1, type2, type3, ...)`, extract the individual types (`type1`, `type2`, etc.).
These become the trailing arguments after the type group:
```cpp
AT_DISPATCH_V2(..., AT_EXPAND(AT_ALL_TYPES), kBFloat16, kHalf, kBool)
//                                           ^^^^^^^^^^^^^^^^^^^^^^^
//                                           Individual types from AND3
```
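In the legacy macros the extra types come first, so the first N arguments (N from the `_AND<N>` suffix) are the individual types and the remaining three are `scalar_type`, `name`, and the lambda. The sketch below assumes the call has already been split at its top-level commas (a real tool has to skip commas nested inside the lambda); `split_legacy_args` and `LegacyArgs` are hypothetical names.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>
#include <vector>

// Arguments of a legacy AT_DISPATCH_*_AND<N> call, grouped by role.
struct LegacyArgs {
  std::vector<std::string> extra_types;  // the first N arguments
  std::string scalar_type;               // argument N
  std::string name;                      // argument N + 1
  std::string lambda;                    // argument N + 2
};

// Hypothetical helper: assign roles to already-split argument strings.
LegacyArgs split_legacy_args(const std::vector<std::string>& args, int n_extra) {
  if (n_extra < 0 || args.size() != static_cast<std::size_t>(n_extra) + 3) {
    throw std::invalid_argument("expected N extra types + scalar_type, name, lambda");
  }
  LegacyArgs out;
  out.extra_types.assign(args.begin(), args.begin() + n_extra);
  out.scalar_type = args[n_extra];
  out.name = args[n_extra + 1];
  out.lambda = args[n_extra + 2];
  return out;
}
```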
Step 5: Transform to AT_DISPATCH_V2
Apply the transformation:
Pattern:
```cpp
AT_DISPATCH_V2(
    scalar_type,       // 1st: the dtype expression
    "name",            // 2nd: the debug string
    AT_WRAP(lambda),   // 3rd: the lambda, wrapped in AT_WRAP
    type_groups,       // 4th+: type groups, each wrapped in AT_EXPAND()
    individual_types   // last: individual types
)
```
Example transformation:
```cpp
// BEFORE
AT_DISPATCH_ALL_TYPES_AND3(
    kBFloat16, kHalf, kBool,
    iter.dtype(),
    "min_values_cuda",
    [&]() {
      min_values_kernel_cuda_impl<scalar_t>(iter);
    }
);

// AFTER
AT_DISPATCH_V2(
    iter.dtype(),
    "min_values_cuda",
    AT_WRAP([&]() {
      min_values_kernel_cuda_impl<scalar_t>(iter);
    }),
    AT_EXPAND(AT_ALL_TYPES),
    kBFloat16, kHalf, kBool
);
```
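Once the pieces are extracted, building the new call is concatenation in the order given by the pattern. A sketch of that step as plain C++ text generation; `build_v2_call` is a hypothetical helper, not an ATen API.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical helper: emit AT_DISPATCH_V2 call text in the order
// scalar_type, name, AT_WRAP(lambda), type groups, individual types.
std::string build_v2_call(const std::string& scalar_type, const std::string& name,
                          const std::string& lambda,
                          const std::vector<std::string>& type_groups,
                          const std::vector<std::string>& extra_types) {
  std::string out = "AT_DISPATCH_V2(" + scalar_type + ", " + name +
                    ", AT_WRAP(" + lambda + ")";
  for (const auto& g : type_groups) out += ", " + g;  // already AT_EXPAND(...)
  for (const auto& t : extra_types) out += ", " + t;  // bare kHalf, kBool, ...
  out += ")";
  return out;
}
```

Keeping type groups and individual types as separate inputs makes it hard to wrap an individual type in `AT_EXPAND()` by accident.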
Step 6: Handle multi-line lambdas
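`AT_WRAP` matters most for lambdas whose bodies contain commas outside parentheses, such as template argument lists (`<scalar_t, scalar_t>`) or brace initializers. The preprocessor splits macro arguments at every comma that is not enclosed in parentheses, so without `AT_WRAP` such a lambda would be cut into several arguments: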
```cpp
AT_DISPATCH_V2(
    dtype,
    "complex_kernel",
    AT_WRAP([&]() {
      gpu_reduce_kernel<scalar_t, scalar_t>(  // comma in the template argument list
          iter,
          MinOps<scalar_t>{},
          thrust::pair<scalar_t, int64_t>(at::numeric_limits<scalar_t>::upper_bound(), 0));
    }),
    AT_EXPAND(AT_ALL_TYPES)
);
```
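The splitting rule can be observed without ATen. The sketch uses two hypothetical toy macros: `TOY_WRAP`, a pass-through variadic macro in the spirit of `AT_WRAP` (an assumption about its behavior, not a copy of its definition), and `TOY_THIRD_ARG`, which stringizes its third argument so you can see where the preprocessor split the list. It shows only the argument-splitting rule; the real machinery in `Dispatch_v2.h` does more.

```cpp
#include <cassert>
#include <cstring>

// Toy stand-in for AT_WRAP: re-emits its arguments unchanged.
#define TOY_WRAP(...) __VA_ARGS__

// Stringizes the third macro argument. '#' suppresses macro expansion of the
// argument, so the string shows exactly what the preprocessor bound to 'c'.
#define TOY_THIRD_ARG(a, b, c, ...) #c

// The lambda tokens are only stringized here, never compiled.
// Unwrapped: the comma in pair<int, int> is not inside parentheses, so the
// third argument ends at "std::pair<int" and the rest spills into '...'.
const char* unwrapped =
    TOY_THIRD_ARG(dtype, "op", [&]() { std::pair<int, int> p{}; }, kHalf);

// Wrapped: the whole lambda sits inside TOY_WRAP( ... ), so it is one argument.
const char* wrapped =
    TOY_THIRD_ARG(dtype, "op", TOY_WRAP([&]() { std::pair<int, int> p{}; }), kHalf);
```

Commas inside parentheses, such as the argument list of `gpu_reduce_kernel(...)`, are already safe; the template argument list is what forces the wrap.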
Step 7: Verify the conversion
Check that:
- `AT_WRAP()` wraps the entire lambda
- Type groups use `AT_EXPAND()`
- Individual types don't have `AT_EXPAND()` (just `kBFloat16`, not `AT_EXPAND(kBFloat16)`)
- Argument order is: scalar_type, name, lambda, types
- Include added: `#include <ATen/Dispatch_v2.h>`
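Part of this checklist can be automated with shallow text checks. The sketch below is a hypothetical `check_v2_call` lint, not a C++ parser: it covers the call-level items and leaves the include check to a separate search. A lambda body that itself mentions `AT_EXPAND(` would confuse the order check.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical text check for one converted call. Returns the problems found;
// an empty vector means the call passes these checks.
std::vector<std::string> check_v2_call(const std::string& call) {
  std::vector<std::string> problems;
  if (call.rfind("AT_DISPATCH_V2(", 0) != 0)
    problems.push_back("call does not start with AT_DISPATCH_V2(");
  const std::size_t wrap = call.find("AT_WRAP(");
  if (wrap == std::string::npos)
    problems.push_back("lambda is not wrapped in AT_WRAP()");
  // Individual ScalarType constants (kHalf, kBool, ...) must not be expanded.
  if (call.find("AT_EXPAND(k") != std::string::npos)
    problems.push_back("individual type wrapped in AT_EXPAND()");
  // Argument order: the wrapped lambda comes before any type group.
  const std::size_t expand = call.find("AT_EXPAND(");
  if (wrap != std::string::npos && expand != std::string::npos && expand < wrap)
    problems.push_back("type groups appear before the lambda");
  return problems;
}
```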
Type group reference
Available type group macros (use with AT_EXPAND()):
```cpp
AT_INTEGRAL_TYPES            // kByte, kChar, kInt, kLong, kShort
AT_FLOATING_TYPES            // kDouble, kFloat
AT_COMPLEX_TYPES             // kComplexDouble, kComplexFloat
AT_QINT_TYPES                // kQInt8, kQUInt8, kQInt32
AT_ALL_TYPES                 // INTEGRAL_TYPES + FLOATING_TYPES
AT_ALL_TYPES_AND_COMPLEX     // ALL_TYPES + COMPLEX_TYPES
AT_INTEGRAL_TYPES_V2         // INTEGRAL_TYPES + BAREBONES_UNSIGNED_TYPES
AT_BAREBONES_UNSIGNED_TYPES  // kUInt16, kUInt32, kUInt64
AT_FLOAT8_TYPES              // Float8 variants
```
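Type groups are plain comma-separated lists and composite groups are built from smaller ones, which is why `AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_COMPLEX_TYPES)` and `AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX)` name the same types. The sketch mirrors that composition with a toy enum and `TOY_*` macros, hypothetical stand-ins for the `c10::ScalarType` constants and `AT_*` groups rather than their real definitions, and checks that both spellings yield identical lists.

```cpp
#include <array>
#include <cassert>

// Toy stand-in for c10::ScalarType, limited to the constants used below.
enum class ToyType { kByte, kChar, kInt, kLong, kShort, kDouble, kFloat,
                     kComplexDouble, kComplexFloat };

// Toy groups: bare comma-separated lists; composites reuse smaller groups.
#define TOY_EXPAND(X) X
#define TOY_INTEGRAL_TYPES \
  ToyType::kByte, ToyType::kChar, ToyType::kInt, ToyType::kLong, ToyType::kShort
#define TOY_FLOATING_TYPES ToyType::kDouble, ToyType::kFloat
#define TOY_COMPLEX_TYPES ToyType::kComplexDouble, ToyType::kComplexFloat
#define TOY_ALL_TYPES TOY_EXPAND(TOY_INTEGRAL_TYPES), TOY_EXPAND(TOY_FLOATING_TYPES)
#define TOY_ALL_TYPES_AND_COMPLEX \
  TOY_EXPAND(TOY_ALL_TYPES), TOY_EXPAND(TOY_COMPLEX_TYPES)

// Spelling the composite group directly...
constexpr std::array<ToyType, 9> combined = {TOY_EXPAND(TOY_ALL_TYPES_AND_COMPLEX)};
// ...gives the same list as expanding its two parts separately.
constexpr std::array<ToyType, 9> split = {TOY_EXPAND(TOY_ALL_TYPES),
                                          TOY_EXPAND(TOY_COMPLEX_TYPES)};
```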
Common patterns
Pattern: AT_DISPATCH_ALL_TYPES_AND2
```cpp
// Before
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, dtype, "op", [&]() {
  kernel<scalar_t>(data);
});

// After
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
  kernel<scalar_t>(data);
}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBFloat16);
```
Pattern: AT_DISPATCH_FLOATING_TYPES_AND3
```cpp
// Before
AT_DISPATCH_FLOATING_TYPES_AND3(kHalf, kBFloat16, kFloat8_e4m3fn,
    tensor.scalar_type(), "float_op", [&] {
  process<scalar_t>(tensor);
});

// After
AT_DISPATCH_V2(tensor.scalar_type(), "float_op", AT_WRAP([&] {
  process<scalar_t>(tensor);
}), AT_EXPAND(AT_FLOATING_TYPES), kHalf, kBFloat16, kFloat8_e4m3fn);
```
Pattern: AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2
```cpp
// Before
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
    kComplexHalf, kHalf,
    self.scalar_type(),
    "complex_op",
    [&] {
      result = compute<scalar_t>(self);
    }
);

// After
AT_DISPATCH_V2(
    self.scalar_type(),
    "complex_op",
    AT_WRAP([&] {
      result = compute<scalar_t>(self);
    }),
    AT_EXPAND(AT_ALL_TYPES),
    AT_EXPAND(AT_COMPLEX_TYPES),
    kComplexHalf,
    kHalf
);
```
Edge cases
Case 1: No extra types (rare)
```cpp
// Before
AT_DISPATCH_ALL_TYPES(dtype, "op", [&]() {
  kernel<scalar_t>();
});

// After
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
  kernel<scalar_t>();
}), AT_EXPAND(AT_ALL_TYPES));
```
Case 2: Many individual types (AND4, AND5, etc.)
```cpp
// Before
AT_DISPATCH_FLOATING_TYPES_AND4(kHalf, kBFloat16, kFloat8_e4m3fn, kFloat8_e5m2,
    dtype, "float8_op", [&]() {
  kernel<scalar_t>();
});

// After
AT_DISPATCH_V2(dtype, "float8_op", AT_WRAP([&]() {
  kernel<scalar_t>();
}), AT_EXPAND(AT_FLOATING_TYPES), kHalf, kBFloat16, kFloat8_e4m3fn, kFloat8_e5m2);
```
Case 3: Lambda with no captures
```cpp
// Before
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, dtype, "op", []() {
  static_kernel<scalar_t>();
});

// After
AT_DISPATCH_V2(dtype, "op", AT_WRAP([]() {
  static_kernel<scalar_t>();
}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBool);
```
Benefits of AT_DISPATCH_V2
- No arity in macro name: Don't need different macros for AND2, AND3, AND4
- Composable type sets: Mix and match type groups with `AT_EXPAND()`
- Extensible: Easy to add more types without hitting macro limits
- Clearer: Type groups are explicit, not implicit in macro name
Important notes
- Keep `#include <ATen/Dispatch.h>`; other code may still need it
- `AT_WRAP()` is mandatory; it prevents comma parsing issues in the lambda
- Type groups need `AT_EXPAND()`, individual types don't
- The v2 API is in `aten/src/ATen/Dispatch_v2.h`; refer to it for full docs
- See the header file for the Python script that regenerates the macro implementation
Workflow
When asked to convert AT_DISPATCH macros:
- Read the file to identify all AT_DISPATCH uses
- Add `#include <ATen/Dispatch_v2.h>` if not present
- For each dispatch macro:
  - Identify the pattern and extract components
  - Map the base type group
  - Extract individual types
  - Construct the AT_DISPATCH_V2 call
  - Apply with Edit tool
- Show the user the complete converted file
- Explain what was changed
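The first workflow step, finding the calls to convert, can be approximated with a regular expression. This sketch (a hypothetical `find_legacy_dispatch`, built on `std::regex`) lists legacy macro names and skips `AT_DISPATCH_V2`; it also matches occurrences in comments or string literals, so treat its output as a checklist to review by hand rather than ground truth.

```cpp
#include <cassert>
#include <regex>
#include <string>
#include <vector>

// Hypothetical scanner: returns the names of AT_DISPATCH_* macro calls in
// source order, excluding calls that already use AT_DISPATCH_V2.
std::vector<std::string> find_legacy_dispatch(const std::string& source) {
  static const std::regex pattern(R"(\bAT_DISPATCH_[A-Z0-9_]*\s*\()");
  std::vector<std::string> found;
  for (auto it = std::sregex_iterator(source.begin(), source.end(), pattern);
       it != std::sregex_iterator(); ++it) {
    const std::string match = it->str();
    // Keep only the macro name: drop any whitespace and the opening parenthesis.
    const std::string name = match.substr(0, match.find_first_of(" \t\n("));
    if (name != "AT_DISPATCH_V2") found.push_back(name);
  }
  return found;
}
```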
Do NOT compile or test the code - focus on accurate conversion only.