
Mojo struct

KVBufferImpl

struct KVBufferImpl[out_type: DType, in_type: DType, shape: IndexList[3], group_size: Int, transpose_b: Bool, mut: Bool, dtype: DType, layout: Layout, address_space: AddressSpace, alignment: Int, origin: Origin[mut], masked: Bool, layout_int_type: DType, linear_idx_type: DType, //, config: KVBufferConfig, tensor_core_mma: TiledTensorCore[out_type, in_type, shape, group_size, transpose_b], swizzle: OptionalReg[Swizzle], BN: Int, WN: Int, BK: Int, depth: Int, num_threads: Int, num_stages: Int = 1, token_gen: Bool = False]

Fields

  • load_tile (LayoutTensor[dtype, Layout.row_major(num_stages * Self.num_mmas * Self.num_k_tiles, Self.simd_width), MutAnyOrigin, address_space=AddressSpace.LOCAL])
  • mma_tile (LayoutTensor[dtype, Self.mma_tile_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL])
  • smem_iter (LayoutTensorIter[dtype, Self.smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, circular=True])
  • bounds (Int)
  • load_tile_id (Int)
  • global_iterator (LayoutTensorIter[dtype, LayoutTensor._compute_tile_layout[mut, dtype, layout, origin, address_space, Layout(IntTuple(1), IntTuple(1)), layout_int_type, linear_idx_type, masked, alignment, config.btile_dim0, config.btile_dim1]()[0], origin, address_space=address_space, axis=config.iterator_axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked if masked else _tile_is_masked[layout, config.btile_dim0, config.btile_dim1]()])

Implemented traits

AnyType, KVBuffer, UnknownDestructibility

Aliases

__del__is_trivial

alias __del__is_trivial = True

base_layout

alias base_layout = Layout.row_major(config.btile_dim0, Self.simd_width)

GlobalTensorType

alias GlobalTensorType = LayoutTensor[dtype, layout, origin, address_space=address_space, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]

GlobalTiledIteratorType

alias GlobalTiledIteratorType = LayoutTensorIter[dtype, LayoutTensor._compute_tile_layout[mut, dtype, layout, origin, address_space, Layout(IntTuple(1), IntTuple(1)), layout_int_type, linear_idx_type, masked, alignment, config.btile_dim0, config.btile_dim1]()[0], origin, address_space=address_space, axis=config.iterator_axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked if masked else _tile_is_masked[layout, config.btile_dim0, config.btile_dim1]()]

LoadTileType

alias LoadTileType = LayoutTensor[dtype, Layout.row_major(num_stages * Self.num_mmas * Self.num_k_tiles, Self.simd_width), MutAnyOrigin, address_space=AddressSpace.LOCAL]

MMA_K

alias MMA_K = shape[2]

MMA_N

alias MMA_N = shape[1]

mma_tile_layout

alias mma_tile_layout = Layout.row_major(Self.num_mmas, Self.simd_width)

MMATileType

alias MMATileType = LayoutTensor[dtype, Self.mma_tile_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]

num_k_tiles

alias num_k_tiles = ceildiv(BK, Self.MMA_K * group_size)

num_mmas

alias num_mmas = ceildiv(config.wsize, Self.MMA_N)

num_repeats

alias num_repeats = config.btile_dim1 // Self.simd_width

num_warps_n

alias num_warps_n = (BN // WN)
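
To make the tile-count arithmetic above concrete, here is a minimal runnable sketch with made-up values (an assumed 16x8x16 MMA shape, BK = 32, and so on; none of these are defaults of KVBufferImpl):

    from math import ceildiv

    def main():
        # All values below are made-up placeholders, not defaults of KVBufferImpl.
        var mma_k = 16        # shape[2] of an assumed 16x8x16 tensor-core MMA
        var mma_n = 8         # shape[1]
        var group_size = 1
        var bk = 32           # BK: K-dimension block tile
        var wsize = 32        # stand-in for config.wsize
        var bn = 64           # BN
        var wn = 32           # WN

        print(ceildiv(bk, mma_k * group_size))  # num_k_tiles == 2
        print(ceildiv(wsize, mma_n))            # num_mmas    == 4
        print(bn // wn)                         # num_warps_n == 2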

SharedIterType

alias SharedIterType = LayoutTensorIter[dtype, Self.smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, circular=True]

SharedTileType

alias SharedTileType = LayoutTensorIter[dtype, Self.smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, circular=True].LayoutTensorType

SharedWarpTileType

alias SharedWarpTileType = LayoutTensor[dtype, LayoutTensor._compute_tile_layout[True, dtype, Self.smem_layout, MutAnyOrigin, AddressSpace.SHARED, Layout(IntTuple(1), IntTuple(1)), _get_index_type(AddressSpace.SHARED), _get_index_type(AddressSpace.SHARED), False, align_of[dtype](), Self.wtile_dim0, Self.wtile_dim1]()[0], MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=_get_index_type(AddressSpace.SHARED), linear_idx_type=_get_index_type(AddressSpace.SHARED), masked=_tile_is_masked[Self.smem_layout, Self.wtile_dim0, Self.wtile_dim1]()]

simd_width

alias simd_width = simd_width_of[dtype]()

smem_layout

alias smem_layout = blocked_product(Self.base_layout, Self.tiler_layout, True) if not token_gen else Layout.row_major(config.btile_dim0, config.btile_dim1)

thread_layout

alias thread_layout = Layout.row_major((min(num_threads, config.btile_dim0 * config.btile_dim1 // Self.simd_width) * Self.simd_width) // Self.smem_layout.stride[0].value(), Self.smem_layout.stride[0].value() // Self.simd_width) if token_gen else Layout.row_major(num_threads // 4, 4)
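
As a quick worked example of the token_gen = False branch, assuming num_threads = 128 (an arbitrary value, not a default of KVBufferImpl):

    from layout import Layout

    def main():
        # Assumed value, not a default of KVBufferImpl.
        var num_threads = 128
        # token_gen == False branch of thread_layout: a 32 x 4 thread arrangement.
        print(Layout.row_major(num_threads // 4, 4))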

tiler_layout

alias tiler_layout = Layout.row_major(1, Self.num_repeats)

wtile_dim0

alias wtile_dim0 = config.wtile_dim0

wtile_dim1

alias wtile_dim1 = config.wtile_dim1

Methods

__init__

__init__(out self, global_tile: LayoutTensor[dtype, layout, origin, address_space=address_space, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], num_b_rows: OptionalReg[Int], shared_ptr: LegacyUnsafePointer[Scalar[dtype], address_space=AddressSpace.SHARED, mut=mut, origin=origin])

get_dtype

static get_dtype() -> DType

Returns:

The element data type (DType) used by this buffer.

load_from_dram

load_from_dram(mut self)

get_mma_tile

get_mma_tile(self) -> LayoutTensor[dtype, Self.mma_tile_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]

Returns:

The thread-local MMA tile, a LayoutTensor in the LOCAL address space.

copy_to_shared

copy_to_shared[tile_id: Int = 0](self)

load_from_shared

load_from_shared[k_mma: Int](self)
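
This page lists the methods but not how they are meant to be sequenced. The following is a rough, hypothetical sketch of one plausible call order for an already constructed buffer; it assumes the KVBuffer trait requires the methods documented above and uses a placeholder loop bound (stages) in place of the buffer's num_k_tiles alias. It is illustrative only, not documented usage.

    # Hypothetical sketch only: the call order is an assumption, not documented
    # behavior, and `stages` stands in for the buffer's num_k_tiles alias.
    fn consume_kv[BufferT: KVBuffer, stages: Int](mut kv: BufferT):
        # Stage the next K/V block from global memory (DRAM) into registers.
        kv.load_from_dram()
        # Commit the staged registers into the circular shared-memory buffer.
        kv.copy_to_shared()
        # Walk the K dimension one MMA step at a time: copy a fragment from
        # shared memory into the thread-local MMA tile, then read that tile
        # back for the tensor-core multiply (not shown).
        @parameter
        for k_mma in range(stages):
            kv.load_from_shared[k_mma]()
            _ = kv.get_mma_tile()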
