1use crate::core_arch::{simd::*, x86::*};
2
3#[cfg(test)]
4use stdarch_test::assert_instr;
5
/// Loads the AMX tile configuration from the memory block at `mem_addr`.
///
/// Forwards `mem_addr` unchanged to the `llvm.x86.ldtilecfg` LLVM intrinsic; the
/// `assert_instr` attribute names `ldtilecfg` as the instruction expected in test builds.
/// The tests in this file always pass a pointer to a 64-byte configuration image
/// (palette, start row, per-tile bytes-per-row and per-tile row counts).
///
/// # Safety
///
/// Requires the `amx-tile` target feature.
/// NOTE(review): `mem_addr` presumably must be readable for the whole 64-byte
/// configuration block and describe a valid configuration -- confirm against Intel's
/// documentation for `_tile_loadconfig`.
6#[inline]
14#[target_feature(enable = "amx-tile")]
15#[cfg_attr(test, assert_instr(ldtilecfg))]
16#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
17pub unsafe fn _tile_loadconfig(mem_addr: *const u8) {
18 ldtilecfg(mem_addr);
19}
20
/// Stores the current AMX tile configuration to the memory block at `mem_addr`.
///
/// Forwards `mem_addr` unchanged to the `llvm.x86.sttilecfg` LLVM intrinsic; the
/// `assert_instr` attribute names `sttilecfg` as the instruction expected in test builds.
/// The round-trip test in this file loads a configuration, stores it back through this
/// function and expects the two 64-byte images to compare equal.
///
/// # Safety
///
/// Requires the `amx-tile` target feature.
/// NOTE(review): `mem_addr` presumably must be writable for 64 bytes -- confirm against
/// Intel's documentation for `_tile_storeconfig`.
21#[inline]
27#[target_feature(enable = "amx-tile")]
28#[cfg_attr(test, assert_instr(sttilecfg))]
29#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
30pub unsafe fn _tile_storeconfig(mem_addr: *mut u8) {
31 sttilecfg(mem_addr);
32}
33
/// Loads tile `DST` from memory starting at `base`, using `stride` between rows.
///
/// Forwards to the `llvm.x86.tileloadd64` LLVM intrinsic with `DST` narrowed to `i8`.
/// `static_assert_uimm_bits!(DST, 3)` rejects, at compile time, tile indices that do not
/// fit in 3 unsigned bits (i.e. only 0..=7 are accepted, assuming the macro behaves as
/// its name says). In this file's tests a 16x64-byte tile is filled from a contiguous
/// 1024-byte buffer with `stride = 64`.
///
/// # Safety
///
/// Requires the `amx-tile` target feature.
/// NOTE(review): presumably a tile configuration must be loaded first and every row
/// addressed through `base`/`stride` must be readable -- confirm against Intel's
/// documentation for `_tile_loadd`.
34#[inline]
38#[rustc_legacy_const_generics(0)]
39#[target_feature(enable = "amx-tile")]
40#[cfg_attr(test, assert_instr(tileloadd, DST = 0))]
41#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
42pub unsafe fn _tile_loadd<const DST: i32>(base: *const u8, stride: usize) {
43 static_assert_uimm_bits!(DST, 3);
44 tileloadd64(DST as i8, base, stride);
45}
46
/// Releases the AMX tile state.
///
/// Calls the argument-less `llvm.x86.tilerelease` LLVM intrinsic; the `assert_instr`
/// attribute names `tilerelease` as the instruction expected in test builds. The tests in
/// this file call it once they are done with the tiles they configured.
///
/// # Safety
///
/// Requires the `amx-tile` target feature.
/// NOTE(review): presumably tile contents and configuration are reset afterwards, so
/// tiles must be reconfigured before further use -- confirm against Intel's documentation.
47#[inline]
51#[target_feature(enable = "amx-tile")]
52#[cfg_attr(test, assert_instr(tilerelease))]
53#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
54pub unsafe fn _tile_release() {
55 tilerelease();
56}
57
/// Stores the contents of tile `DST` to memory starting at `base`, using `stride`
/// between rows.
///
/// Despite its name, `DST` here selects the tile that is written out. The call forwards
/// to the `llvm.x86.tilestored64` LLVM intrinsic with `DST` narrowed to `i8`;
/// `static_assert_uimm_bits!(DST, 3)` limits the index to 3 unsigned bits (0..=7) at
/// compile time. The tests in this file store a 16x64-byte tile into a contiguous
/// 1024-byte buffer with `stride = 64`.
///
/// # Safety
///
/// Requires the `amx-tile` target feature.
/// NOTE(review): every row addressed through `base`/`stride` presumably must be
/// writable -- confirm against Intel's documentation for `_tile_stored`.
58#[inline]
62#[rustc_legacy_const_generics(0)]
63#[target_feature(enable = "amx-tile")]
64#[cfg_attr(test, assert_instr(tilestored, DST = 0))]
65#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
66pub unsafe fn _tile_stored<const DST: i32>(base: *mut u8, stride: usize) {
67 static_assert_uimm_bits!(DST, 3);
68 tilestored64(DST as i8, base, stride);
69}
70
/// Loads tile `DST` from memory starting at `base` with row `stride`, using the
/// `tileloaddt1` form of the tile load.
///
/// Forwards to the `llvm.x86.tileloaddt164` LLVM intrinsic with `DST` narrowed to `i8`;
/// `static_assert_uimm_bits!(DST, 3)` limits the index to 3 unsigned bits (0..=7) at
/// compile time. In this file's tests the loaded data is indistinguishable from what
/// `_tile_loadd` produces.
/// NOTE(review): the "stream"/`t1` variant presumably only adds a non-temporal caching
/// hint -- confirm against Intel's documentation for `_tile_stream_loadd`.
///
/// # Safety
///
/// Requires the `amx-tile` target feature; the same pointer requirements as
/// `_tile_loadd` presumably apply (to be confirmed).
71#[inline]
77#[rustc_legacy_const_generics(0)]
78#[target_feature(enable = "amx-tile")]
79#[cfg_attr(test, assert_instr(tileloaddt1, DST = 0))]
80#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
81pub unsafe fn _tile_stream_loadd<const DST: i32>(base: *const u8, stride: usize) {
82 static_assert_uimm_bits!(DST, 3);
83 tileloaddt164(DST as i8, base, stride);
84}
85
/// Zeroes tile `DST`.
///
/// Forwards to the `llvm.x86.tilezero` LLVM intrinsic with `DST` narrowed to `i8`;
/// `static_assert_uimm_bits!(DST, 3)` limits the index to 3 unsigned bits (0..=7) at
/// compile time. The test in this file zeroes tile 0, stores it over a buffer of ones
/// and expects all zeros.
///
/// # Safety
///
/// Requires the `amx-tile` target feature.
/// NOTE(review): presumably a tile configuration must have been loaded first -- confirm
/// against Intel's documentation for `_tile_zero`.
86#[inline]
90#[rustc_legacy_const_generics(0)]
91#[target_feature(enable = "amx-tile")]
92#[cfg_attr(test, assert_instr(tilezero, DST = 0))]
93#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
94pub unsafe fn _tile_zero<const DST: i32>() {
95 static_assert_uimm_bits!(DST, 3);
96 tilezero(DST as i8);
97}
98
/// Runs the `tdpbf16ps` tile operation on source tiles `A` and `B`, accumulating into
/// tile `DST`.
///
/// Forwards the three tile indices (narrowed to `i8`) to the `llvm.x86.tdpbf16ps` LLVM
/// intrinsic; each index is limited to 3 unsigned bits (0..=7) at compile time by
/// `static_assert_uimm_bits!`. In this file's test, a zeroed `DST`, an `A` full of
/// bf16 1.0 and a `B` full of bf16 2.0 yield 64.0 in every f32 element of `DST`.
/// NOTE(review): this matches a dot product of bf16 pairs accumulated into f32 -- confirm
/// the exact element pairing against Intel's documentation for `_tile_dpbf16ps`.
///
/// # Safety
///
/// Requires the `amx-bf16` target feature; the tiles presumably must be configured
/// beforehand (to be confirmed).
99#[inline]
105#[rustc_legacy_const_generics(0, 1, 2)]
106#[target_feature(enable = "amx-bf16")]
107#[cfg_attr(test, assert_instr(tdpbf16ps, DST = 0, A = 1, B = 2))]
108#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
109pub unsafe fn _tile_dpbf16ps<const DST: i32, const A: i32, const B: i32>() {
110 static_assert_uimm_bits!(DST, 3);
111 static_assert_uimm_bits!(A, 3);
112 static_assert_uimm_bits!(B, 3);
113 tdpbf16ps(DST as i8, A as i8, B as i8);
114}
115
/// Runs the `tdpbssd` tile operation on source tiles `A` and `B`, accumulating into
/// tile `DST`.
///
/// Forwards the three tile indices (narrowed to `i8`) to the `llvm.x86.tdpbssd` LLVM
/// intrinsic; each index is limited to 3 unsigned bits (0..=7) at compile time by
/// `static_assert_uimm_bits!`. In this file's test both sources hold signed bytes
/// (`A` = -1, `B` = -2) and every i32 element of the zeroed `DST` becomes 128.
/// NOTE(review): i.e. a signed-by-signed byte dot product accumulated into i32 -- confirm
/// details against Intel's documentation for `_tile_dpbssd`.
///
/// # Safety
///
/// Requires the `amx-int8` target feature; the tiles presumably must be configured
/// beforehand (to be confirmed).
116#[inline]
123#[rustc_legacy_const_generics(0, 1, 2)]
124#[target_feature(enable = "amx-int8")]
125#[cfg_attr(test, assert_instr(tdpbssd, DST = 0, A = 1, B = 2))]
126#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
127pub unsafe fn _tile_dpbssd<const DST: i32, const A: i32, const B: i32>() {
128 static_assert_uimm_bits!(DST, 3);
129 static_assert_uimm_bits!(A, 3);
130 static_assert_uimm_bits!(B, 3);
131 tdpbssd(DST as i8, A as i8, B as i8);
132}
133
/// Runs the `tdpbsud` tile operation on source tiles `A` and `B`, accumulating into
/// tile `DST`.
///
/// Forwards the three tile indices (narrowed to `i8`) to the `llvm.x86.tdpbsud` LLVM
/// intrinsic; each index is limited to 3 unsigned bits (0..=7) at compile time by
/// `static_assert_uimm_bits!`. In this file's test `A` holds signed bytes (-1), `B`
/// holds unsigned bytes (2) and every i32 element of the zeroed `DST` becomes -128.
/// NOTE(review): i.e. a signed-by-unsigned byte dot product accumulated into i32 --
/// confirm details against Intel's documentation for `_tile_dpbsud`.
///
/// # Safety
///
/// Requires the `amx-int8` target feature; the tiles presumably must be configured
/// beforehand (to be confirmed).
134#[inline]
141#[rustc_legacy_const_generics(0, 1, 2)]
142#[target_feature(enable = "amx-int8")]
143#[cfg_attr(test, assert_instr(tdpbsud, DST = 0, A = 1, B = 2))]
144#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
145pub unsafe fn _tile_dpbsud<const DST: i32, const A: i32, const B: i32>() {
146 static_assert_uimm_bits!(DST, 3);
147 static_assert_uimm_bits!(A, 3);
148 static_assert_uimm_bits!(B, 3);
149 tdpbsud(DST as i8, A as i8, B as i8);
150}
151
/// Runs the `tdpbusd` tile operation on source tiles `A` and `B`, accumulating into
/// tile `DST`.
///
/// Forwards the three tile indices (narrowed to `i8`) to the `llvm.x86.tdpbusd` LLVM
/// intrinsic; each index is limited to 3 unsigned bits (0..=7) at compile time by
/// `static_assert_uimm_bits!`. In this file's test `A` holds unsigned bytes (1), `B`
/// holds signed bytes (-2) and every i32 element of the zeroed `DST` becomes -128.
/// NOTE(review): i.e. an unsigned-by-signed byte dot product accumulated into i32 --
/// confirm details against Intel's documentation for `_tile_dpbusd`.
///
/// # Safety
///
/// Requires the `amx-int8` target feature; the tiles presumably must be configured
/// beforehand (to be confirmed).
152#[inline]
159#[rustc_legacy_const_generics(0, 1, 2)]
160#[target_feature(enable = "amx-int8")]
161#[cfg_attr(test, assert_instr(tdpbusd, DST = 0, A = 1, B = 2))]
162#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
163pub unsafe fn _tile_dpbusd<const DST: i32, const A: i32, const B: i32>() {
164 static_assert_uimm_bits!(DST, 3);
165 static_assert_uimm_bits!(A, 3);
166 static_assert_uimm_bits!(B, 3);
167 tdpbusd(DST as i8, A as i8, B as i8);
168}
169
/// Runs the `tdpbuud` tile operation on source tiles `A` and `B`, accumulating into
/// tile `DST`.
///
/// Forwards the three tile indices (narrowed to `i8`) to the `llvm.x86.tdpbuud` LLVM
/// intrinsic; each index is limited to 3 unsigned bits (0..=7) at compile time by
/// `static_assert_uimm_bits!`. In this file's test both sources hold unsigned bytes
/// (`A` = 1, `B` = 2) and every i32 element of the zeroed `DST` becomes 128.
/// NOTE(review): i.e. an unsigned-by-unsigned byte dot product accumulated into i32 --
/// confirm details against Intel's documentation for `_tile_dpbuud`.
///
/// # Safety
///
/// Requires the `amx-int8` target feature; the tiles presumably must be configured
/// beforehand (to be confirmed).
170#[inline]
177#[rustc_legacy_const_generics(0, 1, 2)]
178#[target_feature(enable = "amx-int8")]
179#[cfg_attr(test, assert_instr(tdpbuud, DST = 0, A = 1, B = 2))]
180#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
181pub unsafe fn _tile_dpbuud<const DST: i32, const A: i32, const B: i32>() {
182 static_assert_uimm_bits!(DST, 3);
183 static_assert_uimm_bits!(A, 3);
184 static_assert_uimm_bits!(B, 3);
185 tdpbuud(DST as i8, A as i8, B as i8);
186}
187
/// Runs the `tdpfp16ps` tile operation on source tiles `A` and `B`, accumulating into
/// tile `DST`.
///
/// Forwards the three tile indices (narrowed to `i8`) to the `llvm.x86.tdpfp16ps` LLVM
/// intrinsic; each index is limited to 3 unsigned bits (0..=7) at compile time by
/// `static_assert_uimm_bits!`. In this file's test, a zeroed `DST`, an `A` full of
/// f16 1.0 and a `B` full of f16 2.0 yield 64.0 in every f32 element of `DST`.
/// NOTE(review): this matches a dot product of f16 pairs accumulated into f32 -- confirm
/// the exact element pairing against Intel's documentation for `_tile_dpfp16ps`.
///
/// # Safety
///
/// Requires the `amx-fp16` target feature; the tiles presumably must be configured
/// beforehand (to be confirmed).
188#[inline]
194#[rustc_legacy_const_generics(0, 1, 2)]
195#[target_feature(enable = "amx-fp16")]
196#[cfg_attr(test, assert_instr(tdpfp16ps, DST = 0, A = 1, B = 2))]
197#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
198pub unsafe fn _tile_dpfp16ps<const DST: i32, const A: i32, const B: i32>() {
199 static_assert_uimm_bits!(DST, 3);
200 static_assert_uimm_bits!(A, 3);
201 static_assert_uimm_bits!(B, 3);
202 tdpfp16ps(DST as i8, A as i8, B as i8);
203}
204
/// Runs the `tcmmimfp16ps` tile operation on source tiles `A` and `B`, accumulating
/// into tile `DST`.
///
/// Forwards the three tile indices (narrowed to `i8`) to the `llvm.x86.tcmmimfp16ps`
/// LLVM intrinsic; each index is limited to 3 unsigned bits (0..=7) at compile time by
/// `static_assert_uimm_bits!`. In this file's test, a zeroed `DST`, an `A` full of
/// f16 1.0 and a `B` full of f16 2.0 yield 64.0 in every f32 element of `DST`.
/// NOTE(review): presumably the sources are interpreted as complex f16 (real, imaginary)
/// pairs and the imaginary part of the product is accumulated -- confirm against Intel's
/// documentation for `_tile_cmmimfp16ps`.
///
/// # Safety
///
/// Requires the `amx-complex` target feature; the tiles presumably must be configured
/// beforehand (to be confirmed).
205#[inline]
215#[rustc_legacy_const_generics(0, 1, 2)]
216#[target_feature(enable = "amx-complex")]
217#[cfg_attr(test, assert_instr(tcmmimfp16ps, DST = 0, A = 1, B = 2))]
218#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
219pub unsafe fn _tile_cmmimfp16ps<const DST: i32, const A: i32, const B: i32>() {
220 static_assert_uimm_bits!(DST, 3);
221 static_assert_uimm_bits!(A, 3);
222 static_assert_uimm_bits!(B, 3);
223 tcmmimfp16ps(DST as i8, A as i8, B as i8);
224}
225
/// Runs the `tcmmrlfp16ps` tile operation on source tiles `A` and `B`, accumulating
/// into tile `DST`.
///
/// Forwards the three tile indices (narrowed to `i8`) to the `llvm.x86.tcmmrlfp16ps`
/// LLVM intrinsic; each index is limited to 3 unsigned bits (0..=7) at compile time by
/// `static_assert_uimm_bits!`. In this file's test, a zeroed `DST`, an `A` full of
/// f16 1.0 and a `B` full of f16 2.0 leave every f32 element of `DST` at 0.0.
/// NOTE(review): presumably the sources are interpreted as complex f16 (real, imaginary)
/// pairs and the real part of the product is accumulated (which cancels to zero for the
/// test data) -- confirm against Intel's documentation for `_tile_cmmrlfp16ps`.
///
/// # Safety
///
/// Requires the `amx-complex` target feature; the tiles presumably must be configured
/// beforehand (to be confirmed).
226#[inline]
236#[rustc_legacy_const_generics(0, 1, 2)]
237#[target_feature(enable = "amx-complex")]
238#[cfg_attr(test, assert_instr(tcmmrlfp16ps, DST = 0, A = 1, B = 2))]
239#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
240pub unsafe fn _tile_cmmrlfp16ps<const DST: i32, const A: i32, const B: i32>() {
241 static_assert_uimm_bits!(DST, 3);
242 static_assert_uimm_bits!(A, 3);
243 static_assert_uimm_bits!(B, 3);
244 tcmmrlfp16ps(DST as i8, A as i8, B as i8);
245}
246
/// Runs the `tdpbf8ps` tile operation on source tiles `A` and `B`, accumulating into
/// tile `DST`.
///
/// Forwards the three tile indices (narrowed to `i8`) to the `llvm.x86.tdpbf8ps` LLVM
/// intrinsic; each index is limited to 3 unsigned bits (0..=7) at compile time by
/// `static_assert_uimm_bits!`. The instruction check is only enabled for test builds on
/// Linux or MSVC targets. In this file's test both sources hold BF8 bytes (1.0 and 2.0)
/// and every f32 element of the zeroed `DST` becomes 128.0.
/// NOTE(review): i.e. a BF8-by-BF8 dot product accumulated into f32 -- confirm details
/// against Intel's documentation for `_tile_dpbf8ps`.
///
/// # Safety
///
/// Requires the `amx-fp8` target feature; the tiles presumably must be configured
/// beforehand (to be confirmed).
247#[inline]
252#[rustc_legacy_const_generics(0, 1, 2)]
253#[target_feature(enable = "amx-fp8")]
254#[cfg_attr(
255 all(test, any(target_os = "linux", target_env = "msvc")),
256 assert_instr(tdpbf8ps, DST = 0, A = 1, B = 2)
257)]
258#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
259pub unsafe fn _tile_dpbf8ps<const DST: i32, const A: i32, const B: i32>() {
260 static_assert_uimm_bits!(DST, 3);
261 static_assert_uimm_bits!(A, 3);
262 static_assert_uimm_bits!(B, 3);
263 tdpbf8ps(DST as i8, A as i8, B as i8);
264}
265
/// Runs the `tdpbhf8ps` tile operation on source tiles `A` and `B`, accumulating into
/// tile `DST`.
///
/// Forwards the three tile indices (narrowed to `i8`) to the `llvm.x86.tdpbhf8ps` LLVM
/// intrinsic; each index is limited to 3 unsigned bits (0..=7) at compile time by
/// `static_assert_uimm_bits!`. The instruction check is only enabled for test builds on
/// Linux or MSVC targets. In this file's test `A` holds BF8 bytes (1.0), `B` holds HF8
/// bytes (2.0) and every f32 element of the zeroed `DST` becomes 128.0.
/// NOTE(review): i.e. a BF8-by-HF8 dot product accumulated into f32 -- confirm details
/// against Intel's documentation for `_tile_dpbhf8ps`.
///
/// # Safety
///
/// Requires the `amx-fp8` target feature; the tiles presumably must be configured
/// beforehand (to be confirmed).
266#[inline]
271#[rustc_legacy_const_generics(0, 1, 2)]
272#[target_feature(enable = "amx-fp8")]
273#[cfg_attr(
274 all(test, any(target_os = "linux", target_env = "msvc")),
275 assert_instr(tdpbhf8ps, DST = 0, A = 1, B = 2)
276)]
277#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
278pub unsafe fn _tile_dpbhf8ps<const DST: i32, const A: i32, const B: i32>() {
279 static_assert_uimm_bits!(DST, 3);
280 static_assert_uimm_bits!(A, 3);
281 static_assert_uimm_bits!(B, 3);
282 tdpbhf8ps(DST as i8, A as i8, B as i8);
283}
284
/// Runs the `tdphbf8ps` tile operation on source tiles `A` and `B`, accumulating into
/// tile `DST`.
///
/// Forwards the three tile indices (narrowed to `i8`) to the `llvm.x86.tdphbf8ps` LLVM
/// intrinsic; each index is limited to 3 unsigned bits (0..=7) at compile time by
/// `static_assert_uimm_bits!`. The instruction check is only enabled for test builds on
/// Linux or MSVC targets. In this file's test `A` holds HF8 bytes (1.0), `B` holds BF8
/// bytes (2.0) and every f32 element of the zeroed `DST` becomes 128.0.
/// NOTE(review): i.e. an HF8-by-BF8 dot product accumulated into f32 -- confirm details
/// against Intel's documentation for `_tile_dphbf8ps`.
///
/// # Safety
///
/// Requires the `amx-fp8` target feature; the tiles presumably must be configured
/// beforehand (to be confirmed).
285#[inline]
290#[rustc_legacy_const_generics(0, 1, 2)]
291#[target_feature(enable = "amx-fp8")]
292#[cfg_attr(
293 all(test, any(target_os = "linux", target_env = "msvc")),
294 assert_instr(tdphbf8ps, DST = 0, A = 1, B = 2)
295)]
296#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
297pub unsafe fn _tile_dphbf8ps<const DST: i32, const A: i32, const B: i32>() {
298 static_assert_uimm_bits!(DST, 3);
299 static_assert_uimm_bits!(A, 3);
300 static_assert_uimm_bits!(B, 3);
301 tdphbf8ps(DST as i8, A as i8, B as i8);
302}
303
/// Runs the `tdphf8ps` tile operation on source tiles `A` and `B`, accumulating into
/// tile `DST`.
///
/// Forwards the three tile indices (narrowed to `i8`) to the `llvm.x86.tdphf8ps` LLVM
/// intrinsic; each index is limited to 3 unsigned bits (0..=7) at compile time by
/// `static_assert_uimm_bits!`. The instruction check is only enabled for test builds on
/// Linux or MSVC targets. In this file's test both sources hold HF8 bytes (1.0 and 2.0)
/// and every f32 element of the zeroed `DST` becomes 128.0.
/// NOTE(review): i.e. an HF8-by-HF8 dot product accumulated into f32 -- confirm details
/// against Intel's documentation for `_tile_dphf8ps`.
///
/// # Safety
///
/// Requires the `amx-fp8` target feature; the tiles presumably must be configured
/// beforehand (to be confirmed).
304#[inline]
309#[rustc_legacy_const_generics(0, 1, 2)]
310#[target_feature(enable = "amx-fp8")]
311#[cfg_attr(
312 all(test, any(target_os = "linux", target_env = "msvc")),
313 assert_instr(tdphf8ps, DST = 0, A = 1, B = 2)
314)]
315#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
316pub unsafe fn _tile_dphf8ps<const DST: i32, const A: i32, const B: i32>() {
317 static_assert_uimm_bits!(DST, 3);
318 static_assert_uimm_bits!(A, 3);
319 static_assert_uimm_bits!(B, 3);
320 tdphf8ps(DST as i8, A as i8, B as i8);
321}
322
/// Loads tile `DST` from memory starting at `base` with row `stride`, using the
/// `tileloaddrs` form of the tile load.
///
/// Forwards to the `llvm.x86.tileloaddrs64` LLVM intrinsic with `DST` narrowed to `i8`;
/// `static_assert_uimm_bits!(DST, 3)` limits the index to 3 unsigned bits (0..=7) at
/// compile time. The instruction check is only enabled for test builds on Linux or MSVC
/// targets. In this file's tests the loaded data is indistinguishable from what
/// `_tile_loadd` produces.
/// NOTE(review): the `rs` suffix presumably denotes a "read-shared" cache hint -- confirm
/// against Intel's documentation for `_tile_loaddrs`.
///
/// # Safety
///
/// Requires the `amx-movrs` target feature; the same pointer requirements as
/// `_tile_loadd` presumably apply (to be confirmed).
323#[inline]
329#[rustc_legacy_const_generics(0)]
330#[target_feature(enable = "amx-movrs")]
331#[cfg_attr(
332 all(test, any(target_os = "linux", target_env = "msvc")),
333 assert_instr(tileloaddrs, DST = 0)
334)]
335#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
336pub unsafe fn _tile_loaddrs<const DST: i32>(base: *const u8, stride: usize) {
337 static_assert_uimm_bits!(DST, 3);
338 tileloaddrs64(DST as i8, base, stride);
339}
340
/// Loads tile `DST` from memory starting at `base` with row `stride`, using the
/// `tileloaddrst1` form of the tile load.
///
/// Forwards to the `llvm.x86.tileloaddrst164` LLVM intrinsic with `DST` narrowed to
/// `i8`; `static_assert_uimm_bits!(DST, 3)` limits the index to 3 unsigned bits (0..=7)
/// at compile time. The instruction check is only enabled for test builds on Linux or
/// MSVC targets. In this file's tests the loaded data is indistinguishable from what
/// `_tile_loadd` produces.
/// NOTE(review): presumably combines the "read-shared" and non-temporal ("stream")
/// hints -- confirm against Intel's documentation for `_tile_stream_loaddrs`.
///
/// # Safety
///
/// Requires the `amx-movrs` target feature; the same pointer requirements as
/// `_tile_loadd` presumably apply (to be confirmed).
341#[inline]
349#[rustc_legacy_const_generics(0)]
350#[target_feature(enable = "amx-movrs")]
351#[cfg_attr(
352 all(test, any(target_os = "linux", target_env = "msvc")),
353 assert_instr(tileloaddrst1, DST = 0)
354)]
355#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
356pub unsafe fn _tile_stream_loaddrs<const DST: i32>(base: *const u8, stride: usize) {
357 static_assert_uimm_bits!(DST, 3);
358 tileloaddrst164(DST as i8, base, stride);
359}
360
/// Runs the `tmmultf32ps` tile operation on source tiles `A` and `B`, accumulating into
/// tile `DST`.
///
/// Forwards the three tile indices (narrowed to `i8`) to the `llvm.x86.tmmultf32ps` LLVM
/// intrinsic; each index is limited to 3 unsigned bits (0..=7) at compile time by
/// `static_assert_uimm_bits!`. The instruction check is only enabled for test builds on
/// Linux or MSVC targets. In this file's test, 16x16 f32 tiles with `A[i][k] = i` and
/// `B[k][j] = j` turn a zeroed `DST` into `DST[i][j] = 16 * i * j`, i.e. the ordinary
/// matrix product of `A` and `B`.
/// NOTE(review): the `tf32` name suggests reduced-precision multiplication of the f32
/// inputs -- confirm against Intel's documentation for `_tile_mmultf32ps`.
///
/// # Safety
///
/// Requires the `amx-tf32` target feature; the tiles presumably must be configured
/// beforehand (to be confirmed).
361#[inline]
372#[rustc_legacy_const_generics(0, 1, 2)]
373#[target_feature(enable = "amx-tf32")]
374#[cfg_attr(
375 all(test, any(target_os = "linux", target_env = "msvc")),
376 assert_instr(tmmultf32ps, DST = 0, A = 1, B = 2)
377)]
378#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
379pub unsafe fn _tile_mmultf32ps<const DST: i32, const A: i32, const B: i32>() {
380 static_assert_uimm_bits!(DST, 3);
381 static_assert_uimm_bits!(A, 3);
382 static_assert_uimm_bits!(B, 3);
383 tmmultf32ps(DST as i8, A as i8, B as i8);
384}
385
/// Reads row `row` of tile `TILE` and returns it as a 512-bit vector of f32 values.
///
/// Forwards `TILE` (narrowed to `i8`) and `row` to the `llvm.x86.tcvtrowd2ps` LLVM
/// intrinsic, which returns an `f32x16` that is converted with `as_m512()`;
/// `static_assert_uimm_bits!(TILE, 3)` limits the tile index to 3 unsigned bits (0..=7)
/// at compile time. The instruction check is only enabled for test builds on Linux or
/// MSVC targets. In this file's test a tile whose row `i` holds the 32-bit integer `i`
/// comes back as sixteen f32 values equal to `i`, i.e. the integers are converted to
/// floating point.
///
/// # Safety
///
/// Requires the `amx-avx512` and `avx10.2` target features.
/// NOTE(review): `row` presumably must address a configured row of the tile -- confirm
/// against Intel's documentation for `_tile_cvtrowd2ps`.
386#[inline]
389#[rustc_legacy_const_generics(0)]
390#[target_feature(enable = "amx-avx512,avx10.2")]
391#[cfg_attr(
392 all(test, any(target_os = "linux", target_env = "msvc")),
393 assert_instr(tcvtrowd2ps, TILE = 0)
394)]
395#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
396pub unsafe fn _tile_cvtrowd2ps<const TILE: i32>(row: u32) -> __m512 {
397 static_assert_uimm_bits!(TILE, 3);
398 tcvtrowd2ps(TILE as i8, row).as_m512()
399}
400
/// Reads row `row` of tile `TILE` (f32 data) and returns it as a 512-bit vector of f16
/// values placed in the odd-numbered lanes.
///
/// Forwards `TILE` (narrowed to `i8`) and `row` to the `llvm.x86.tcvtrowps2phh` LLVM
/// intrinsic, which returns an `f16x32` that is converted with `as_m512h()`;
/// `static_assert_uimm_bits!(TILE, 3)` limits the tile index to 3 unsigned bits (0..=7)
/// at compile time. The instruction check is only enabled for test builds on Linux or
/// MSVC targets. In this file's test a row of f32 values `i` comes back with `i` in
/// every odd f16 lane and 0.0 in every even lane.
///
/// # Safety
///
/// Requires the `amx-avx512` and `avx10.2` target features.
/// NOTE(review): `row` presumably must address a configured row of the tile -- confirm
/// against Intel's documentation for `_tile_cvtrowps2phh`.
401#[inline]
405#[rustc_legacy_const_generics(0)]
406#[target_feature(enable = "amx-avx512,avx10.2")]
407#[cfg_attr(
408 all(test, any(target_os = "linux", target_env = "msvc")),
409 assert_instr(tcvtrowps2phh, TILE = 0)
410)]
411#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
412pub unsafe fn _tile_cvtrowps2phh<const TILE: i32>(row: u32) -> __m512h {
413 static_assert_uimm_bits!(TILE, 3);
414 tcvtrowps2phh(TILE as i8, row).as_m512h()
415}
416
/// Reads row `row` of tile `TILE` (f32 data) and returns it as a 512-bit vector of f16
/// values placed in the even-numbered lanes.
///
/// Forwards `TILE` (narrowed to `i8`) and `row` to the `llvm.x86.tcvtrowps2phl` LLVM
/// intrinsic, which returns an `f16x32` that is converted with `as_m512h()`;
/// `static_assert_uimm_bits!(TILE, 3)` limits the tile index to 3 unsigned bits (0..=7)
/// at compile time. The instruction check is only enabled for test builds on Linux or
/// MSVC targets. In this file's test a row of f32 values `i` comes back with `i` in
/// every even f16 lane and 0.0 in every odd lane.
///
/// # Safety
///
/// Requires the `amx-avx512` and `avx10.2` target features.
/// NOTE(review): `row` presumably must address a configured row of the tile -- confirm
/// against Intel's documentation for `_tile_cvtrowps2phl`.
417#[inline]
421#[rustc_legacy_const_generics(0)]
422#[target_feature(enable = "amx-avx512,avx10.2")]
423#[cfg_attr(
424 all(test, any(target_os = "linux", target_env = "msvc")),
425 assert_instr(tcvtrowps2phl, TILE = 0)
426)]
427#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
428pub unsafe fn _tile_cvtrowps2phl<const TILE: i32>(row: u32) -> __m512h {
429 static_assert_uimm_bits!(TILE, 3);
430 tcvtrowps2phl(TILE as i8, row).as_m512h()
431}
432
/// Copies row `row` of tile `TILE` into a 512-bit integer vector.
///
/// Forwards `TILE` (narrowed to `i8`) and `row` to the `llvm.x86.tilemovrow` LLVM
/// intrinsic, which returns an `i32x16` that is converted with `as_m512i()`;
/// `static_assert_uimm_bits!(TILE, 3)` limits the tile index to 3 unsigned bits (0..=7)
/// at compile time. The instruction check is only enabled for test builds on Linux or
/// MSVC targets. In this file's test a tile whose row `i` holds 64 bytes equal to `i`
/// comes back byte-for-byte unchanged.
///
/// # Safety
///
/// Requires the `amx-avx512` and `avx10.2` target features.
/// NOTE(review): `row` presumably must address a configured row of the tile -- confirm
/// against Intel's documentation for `_tile_movrow`.
433#[inline]
435#[rustc_legacy_const_generics(0)]
436#[target_feature(enable = "amx-avx512,avx10.2")]
437#[cfg_attr(
438 all(test, any(target_os = "linux", target_env = "msvc")),
439 assert_instr(tilemovrow, TILE = 0)
440)]
441#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
442pub unsafe fn _tile_movrow<const TILE: i32>(row: u32) -> __m512i {
443 static_assert_uimm_bits!(TILE, 3);
444 tilemovrow(TILE as i8, row).as_m512i()
445}
446
// Raw bindings to the LLVM x86 AMX intrinsics used by the public wrappers above.
// Each declaration is tied to its LLVM intrinsic through `link_name`; tile operands are
// passed as `i8` tile indices, and the wrappers always supply them as `CONST as i8`.
// NOTE(review): `improper_ctypes` is presumably allowed because the SIMD vector return
// types (`f32x16`, `f16x32`, `i32x16`) are not FFI-safe C types -- confirm.
447#[allow(improper_ctypes)]
448unsafe extern "C" {
// --- amx-tile: configuration, load/store, release, zero ---
449 #[link_name = "llvm.x86.ldtilecfg"]
450 fn ldtilecfg(mem_addr: *const u8);
451 #[link_name = "llvm.x86.sttilecfg"]
452 fn sttilecfg(mem_addr: *mut u8);
453 #[link_name = "llvm.x86.tileloadd64"]
454 fn tileloadd64(dst: i8, base: *const u8, stride: usize);
455 #[link_name = "llvm.x86.tileloaddt164"]
456 fn tileloaddt164(dst: i8, base: *const u8, stride: usize);
457 #[link_name = "llvm.x86.tilerelease"]
458 fn tilerelease();
459 #[link_name = "llvm.x86.tilestored64"]
460 fn tilestored64(dst: i8, base: *mut u8, stride: usize);
461 #[link_name = "llvm.x86.tilezero"]
462 fn tilezero(dst: i8);
// --- amx-bf16 / amx-int8 / amx-fp16 / amx-complex: dst (op)= a, b ---
463 #[link_name = "llvm.x86.tdpbf16ps"]
464 fn tdpbf16ps(dst: i8, a: i8, b: i8);
465 #[link_name = "llvm.x86.tdpbuud"]
466 fn tdpbuud(dst: i8, a: i8, b: i8);
467 #[link_name = "llvm.x86.tdpbusd"]
468 fn tdpbusd(dst: i8, a: i8, b: i8);
469 #[link_name = "llvm.x86.tdpbsud"]
470 fn tdpbsud(dst: i8, a: i8, b: i8);
471 #[link_name = "llvm.x86.tdpbssd"]
472 fn tdpbssd(dst: i8, a: i8, b: i8);
473 #[link_name = "llvm.x86.tdpfp16ps"]
474 fn tdpfp16ps(dst: i8, a: i8, b: i8);
475 #[link_name = "llvm.x86.tcmmimfp16ps"]
476 fn tcmmimfp16ps(dst: i8, a: i8, b: i8);
477 #[link_name = "llvm.x86.tcmmrlfp16ps"]
478 fn tcmmrlfp16ps(dst: i8, a: i8, b: i8);
// --- amx-fp8 ---
479 #[link_name = "llvm.x86.tdpbf8ps"]
480 fn tdpbf8ps(dst: i8, a: i8, b: i8);
481 #[link_name = "llvm.x86.tdpbhf8ps"]
482 fn tdpbhf8ps(dst: i8, a: i8, b: i8);
483 #[link_name = "llvm.x86.tdphbf8ps"]
484 fn tdphbf8ps(dst: i8, a: i8, b: i8);
485 #[link_name = "llvm.x86.tdphf8ps"]
486 fn tdphf8ps(dst: i8, a: i8, b: i8);
// --- amx-movrs ---
487 #[link_name = "llvm.x86.tileloaddrs64"]
488 fn tileloaddrs64(dst: i8, base: *const u8, stride: usize);
489 #[link_name = "llvm.x86.tileloaddrst164"]
490 fn tileloaddrst164(dst: i8, base: *const u8, stride: usize);
// --- amx-tf32 ---
491 #[link_name = "llvm.x86.tmmultf32ps"]
492 fn tmmultf32ps(dst: i8, a: i8, b: i8);
// --- amx-avx512: move/convert one tile row into a 512-bit vector ---
493 #[link_name = "llvm.x86.tcvtrowd2ps"]
494 fn tcvtrowd2ps(tile: i8, row: u32) -> f32x16;
495 #[link_name = "llvm.x86.tcvtrowps2phh"]
496 fn tcvtrowps2phh(tile: i8, row: u32) -> f16x32;
497 #[link_name = "llvm.x86.tcvtrowps2phl"]
498 fn tcvtrowps2phl(tile: i8, row: u32) -> f16x32;
499 #[link_name = "llvm.x86.tilemovrow"]
500 fn tilemovrow(tile: i8, row: u32) -> i32x16;
501}
502
503#[cfg(test)]
504mod tests {
505 use crate::core_arch::x86::_mm_cvtness_sbh;
506 use crate::core_arch::x86_64::*;
507 use core::{array, mem::transmute};
508 use stdarch_test::simd_test;
509 #[cfg(target_os = "linux")]
510 use syscalls::{Sysno, syscall};
511
/// In-memory image of the 64-byte tile configuration block that the tests hand to
/// `_tile_loadconfig` and receive back from `_tile_storeconfig`.
///
/// `repr(C, packed)` keeps every field at a fixed byte offset without padding
/// (1 + 1 + 14 + 16 + 16 + 8 + 8 = 64 bytes).
#[allow(non_camel_case_types)]
#[repr(C, packed)]
#[derive(Copy, Clone, Default, Debug, PartialEq)]
struct __tilecfg {
    /// Palette selector; the tests use 0 (default image) or 1.
    palette: u8,
    start_row: u8,
    /// Always written as zero by this module.
    reserved_a0: [u8; 14],
    /// Bytes per row, one entry per tile.
    colsb: [u16; 8],
    /// Always written as zero by this module.
    reserved_b0: [u16; 8],
    /// Number of rows, one entry per tile.
    rows: [u8; 8],
    /// Always written as zero by this module.
    reserved_c0: [u8; 8],
}

impl __tilecfg {
    /// Builds a configuration image from the caller-visible fields, leaving every
    /// reserved area zeroed.
    fn new(palette: u8, start_row: u8, colsb: [u16; 8], rows: [u8; 8]) -> Self {
        Self {
            palette,
            start_row,
            colsb,
            rows,
            // The derived `Default` zero-fills the three reserved arrays.
            ..Self::default()
        }
    }

    /// Address of the first byte of the image, for read-only consumers.
    const fn as_ptr(&self) -> *const u8 {
        (self as *const Self).cast::<u8>()
    }

    /// Address of the first byte of the image, for consumers that write into it.
    fn as_mut_ptr(&mut self) -> *mut u8 {
        (self as *mut Self).cast::<u8>()
    }
}
552
553 #[cfg(not(target_os = "linux"))]
554 #[target_feature(enable = "amx-tile")]
555 fn _init_amx() {}
556
557 #[cfg(target_os = "linux")]
558 #[target_feature(enable = "amx-tile")]
559 #[inline]
560 unsafe fn _init_amx() {
561 let mut ret: usize;
562 let mut xfeatures: usize = 0;
563 ret = syscall!(Sysno::arch_prctl, 0x1022, &mut xfeatures as *mut usize)
564 .expect("arch_prctl ARCH_GET_XCOMP_PERM syscall failed");
565 if ret != 0 {
566 panic!("Failed to get XFEATURES");
567 } else {
568 match 0b11 & (xfeatures >> 17) {
569 0 => panic!("AMX is not available"),
570 1 => {
571 ret = syscall!(Sysno::arch_prctl, 0x1023, 18)
572 .expect("arch_prctl ARCH_REQ_XCOMP_PERM syscall failed");
573 if ret != 0 {
574 panic!("Failed to enable AMX");
575 }
576 }
577 3 => {}
578 _ => unreachable!(),
579 }
580 }
581 }
582
583 #[simd_test(enable = "amx-tile")]
584 fn test_tile_loadconfig() {
585 unsafe {
586 let config = __tilecfg::default();
587 _tile_loadconfig(config.as_ptr());
588 _tile_release();
589 }
590 }
591
592 #[simd_test(enable = "amx-tile")]
593 fn test_tile_storeconfig() {
594 unsafe {
595 let config = __tilecfg::new(1, 0, [32; 8], [8; 8]);
596 _tile_loadconfig(config.as_ptr());
597 let mut _config = __tilecfg::default();
598 _tile_storeconfig(_config.as_mut_ptr());
599 _tile_release();
600 assert_eq!(config, _config);
601 }
602 }
603
604 #[simd_test(enable = "amx-tile")]
605 fn test_tile_zero() {
606 unsafe {
607 _init_amx();
608 let mut config = __tilecfg::default();
609 config.palette = 1;
610 config.colsb[0] = 64;
611 config.rows[0] = 16;
612 _tile_loadconfig(config.as_ptr());
613 _tile_zero::<0>();
614 let mut out = [[1_i8; 64]; 16];
615 _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
616 _tile_release();
617 assert_eq!(out, [[0; 64]; 16]);
618 }
619 }
620
621 #[simd_test(enable = "amx-tile")]
622 fn test_tile_stored() {
623 unsafe {
624 _init_amx();
625 let mut config = __tilecfg::default();
626 config.palette = 1;
627 config.colsb[0] = 64;
628 config.rows[0] = 16;
629 _tile_loadconfig(config.as_ptr());
630 _tile_zero::<0>();
631 let mut out = [[1_i8; 64]; 16];
632 _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
633 _tile_release();
634 assert_eq!(out, [[0; 64]; 16]);
635 }
636 }
637
638 #[simd_test(enable = "amx-tile")]
639 fn test_tile_loadd() {
640 unsafe {
641 _init_amx();
642 let mut config = __tilecfg::default();
643 config.palette = 1;
644 config.colsb[0] = 64;
645 config.rows[0] = 16;
646 _tile_loadconfig(config.as_ptr());
647 _tile_zero::<0>();
648 let mat = [1_i8; 1024];
649 _tile_loadd::<0>(&mat as *const i8 as *const u8, 64);
650 let mut out = [[0_i8; 64]; 16];
651 _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
652 _tile_release();
653 assert_eq!(out, [[1; 64]; 16]);
654 }
655 }
656
657 #[simd_test(enable = "amx-tile")]
658 fn test_tile_stream_loadd() {
659 unsafe {
660 _init_amx();
661 let mut config = __tilecfg::default();
662 config.palette = 1;
663 config.colsb[0] = 64;
664 config.rows[0] = 16;
665 _tile_loadconfig(config.as_ptr());
666 _tile_zero::<0>();
667 let mat = [1_i8; 1024];
668 _tile_stream_loadd::<0>(&mat as *const i8 as *const u8, 64);
669 let mut out = [[0_i8; 64]; 16];
670 _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
671 _tile_release();
672 assert_eq!(out, [[1; 64]; 16]);
673 }
674 }
675
676 #[simd_test(enable = "amx-tile")]
677 fn test_tile_release() {
678 unsafe {
679 _tile_release();
680 }
681 }
682
683 #[simd_test(enable = "amx-bf16,avx512f")]
684 fn test_tile_dpbf16ps() {
685 unsafe {
686 _init_amx();
687 let bf16_1: u16 = _mm_cvtness_sbh(1.0).to_bits();
688 let bf16_2: u16 = _mm_cvtness_sbh(2.0).to_bits();
689 let ones: [u8; 1024] = transmute([bf16_1; 512]);
690 let twos: [u8; 1024] = transmute([bf16_2; 512]);
691 let mut res = [[0f32; 16]; 16];
692 let mut config = __tilecfg::default();
693 config.palette = 1;
694 (0..=2).for_each(|i| {
695 config.colsb[i] = 64;
696 config.rows[i] = 16;
697 });
698 _tile_loadconfig(config.as_ptr());
699 _tile_zero::<0>();
700 _tile_loadd::<1>(&ones as *const u8, 64);
701 _tile_loadd::<2>(&twos as *const u8, 64);
702 _tile_dpbf16ps::<0, 1, 2>();
703 _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64);
704 _tile_release();
705 assert_eq!(res, [[64f32; 16]; 16]);
706 }
707 }
708
709 #[simd_test(enable = "amx-int8")]
710 fn test_tile_dpbssd() {
711 unsafe {
712 _init_amx();
713 let ones = [-1_i8; 1024];
714 let twos = [-2_i8; 1024];
715 let mut res = [[0_i32; 16]; 16];
716 let mut config = __tilecfg::default();
717 config.palette = 1;
718 (0..=2).for_each(|i| {
719 config.colsb[i] = 64;
720 config.rows[i] = 16;
721 });
722 _tile_loadconfig(config.as_ptr());
723 _tile_zero::<0>();
724 _tile_loadd::<1>(&ones as *const i8 as *const u8, 64);
725 _tile_loadd::<2>(&twos as *const i8 as *const u8, 64);
726 _tile_dpbssd::<0, 1, 2>();
727 _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64);
728 _tile_release();
729 assert_eq!(res, [[128_i32; 16]; 16]);
730 }
731 }
732
733 #[simd_test(enable = "amx-int8")]
734 fn test_tile_dpbsud() {
735 unsafe {
736 _init_amx();
737 let ones = [-1_i8; 1024];
738 let twos = [2_u8; 1024];
739 let mut res = [[0_i32; 16]; 16];
740 let mut config = __tilecfg::default();
741 config.palette = 1;
742 (0..=2).for_each(|i| {
743 config.colsb[i] = 64;
744 config.rows[i] = 16;
745 });
746 _tile_loadconfig(config.as_ptr());
747 _tile_zero::<0>();
748 _tile_loadd::<1>(&ones as *const i8 as *const u8, 64);
749 _tile_loadd::<2>(&twos as *const u8, 64);
750 _tile_dpbsud::<0, 1, 2>();
751 _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64);
752 _tile_release();
753 assert_eq!(res, [[-128_i32; 16]; 16]);
754 }
755 }
756
757 #[simd_test(enable = "amx-int8")]
758 fn test_tile_dpbusd() {
759 unsafe {
760 _init_amx();
761 let ones = [1_u8; 1024];
762 let twos = [-2_i8; 1024];
763 let mut res = [[0_i32; 16]; 16];
764 let mut config = __tilecfg::default();
765 config.palette = 1;
766 (0..=2).for_each(|i| {
767 config.colsb[i] = 64;
768 config.rows[i] = 16;
769 });
770 _tile_loadconfig(config.as_ptr());
771 _tile_zero::<0>();
772 _tile_loadd::<1>(&ones as *const u8, 64);
773 _tile_loadd::<2>(&twos as *const i8 as *const u8, 64);
774 _tile_dpbusd::<0, 1, 2>();
775 _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64);
776 _tile_release();
777 assert_eq!(res, [[-128_i32; 16]; 16]);
778 }
779 }
780
781 #[simd_test(enable = "amx-int8")]
782 fn test_tile_dpbuud() {
783 unsafe {
784 _init_amx();
785 let ones = [1_u8; 1024];
786 let twos = [2_u8; 1024];
787 let mut res = [[0_i32; 16]; 16];
788 let mut config = __tilecfg::default();
789 config.palette = 1;
790 (0..=2).for_each(|i| {
791 config.colsb[i] = 64;
792 config.rows[i] = 16;
793 });
794 _tile_loadconfig(config.as_ptr());
795 _tile_zero::<0>();
796 _tile_loadd::<1>(&ones as *const u8, 64);
797 _tile_loadd::<2>(&twos as *const u8, 64);
798 _tile_dpbuud::<0, 1, 2>();
799 _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64);
800 _tile_release();
801 assert_eq!(res, [[128_i32; 16]; 16]);
802 }
803 }
804
805 #[simd_test(enable = "amx-fp16")]
806 fn test_tile_dpfp16ps() {
807 unsafe {
808 _init_amx();
809 let ones = [1f16; 512];
810 let twos = [2f16; 512];
811 let mut res = [[0f32; 16]; 16];
812 let mut config = __tilecfg::default();
813 config.palette = 1;
814 (0..=2).for_each(|i| {
815 config.colsb[i] = 64;
816 config.rows[i] = 16;
817 });
818 _tile_loadconfig(config.as_ptr());
819 _tile_zero::<0>();
820 _tile_loadd::<1>(&ones as *const f16 as *const u8, 64);
821 _tile_loadd::<2>(&twos as *const f16 as *const u8, 64);
822 _tile_dpfp16ps::<0, 1, 2>();
823 _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64);
824 _tile_release();
825 assert_eq!(res, [[64f32; 16]; 16]);
826 }
827 }
828
829 #[simd_test(enable = "amx-complex")]
830 fn test_tile_cmmimfp16ps() {
831 unsafe {
832 _init_amx();
833 let ones = [1f16; 512];
834 let twos = [2f16; 512];
835 let mut res = [[0f32; 16]; 16];
836 let mut config = __tilecfg::default();
837 config.palette = 1;
838 (0..=2).for_each(|i| {
839 config.colsb[i] = 64;
840 config.rows[i] = 16;
841 });
842 _tile_loadconfig(config.as_ptr());
843 _tile_zero::<0>();
844 _tile_loadd::<1>(&ones as *const f16 as *const u8, 64);
845 _tile_loadd::<2>(&twos as *const f16 as *const u8, 64);
846 _tile_cmmimfp16ps::<0, 1, 2>();
847 _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64);
848 _tile_release();
849 assert_eq!(res, [[64f32; 16]; 16]);
850 }
851 }
852
853 #[simd_test(enable = "amx-complex")]
854 fn test_tile_cmmrlfp16ps() {
855 unsafe {
856 _init_amx();
857 let ones = [1f16; 512];
858 let twos = [2f16; 512];
859 let mut res = [[0f32; 16]; 16];
860 let mut config = __tilecfg::default();
861 config.palette = 1;
862 (0..=2).for_each(|i| {
863 config.colsb[i] = 64;
864 config.rows[i] = 16;
865 });
866 _tile_loadconfig(config.as_ptr());
867 _tile_zero::<0>();
868 _tile_loadd::<1>(&ones as *const f16 as *const u8, 64);
869 _tile_loadd::<2>(&twos as *const f16 as *const u8, 64);
870 _tile_cmmrlfp16ps::<0, 1, 2>();
871 _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64);
872 _tile_release();
873 assert_eq!(res, [[0f32; 16]; 16]);
874 }
875 }
876
// FP8 bit patterns used as tile fill values by the `amx-fp8` tests below. Each of those
// tests multiplies a tile of "ones" by a tile of "twos" and expects 128.0 per element,
// i.e. 64 byte-wise products of 1.0 * 2.0.
// NOTE(review): the names suggest BF8 is the 5-bit-exponent (E5M2) encoding and HF8 the
// 4-bit-exponent (E4M3) encoding -- confirm against Intel's AMX-FP8 documentation.
877 const BF8_ONE: u8 = 0x3c;
878 const BF8_TWO: u8 = 0x40;
879 const HF8_ONE: u8 = 0x38;
880 const HF8_TWO: u8 = 0x40;
881
882 #[simd_test(enable = "amx-fp8")]
883 fn test_tile_dpbf8ps() {
884 unsafe {
885 _init_amx();
886 let ones = [BF8_ONE; 1024];
887 let twos = [BF8_TWO; 1024];
888 let mut res = [[0.0_f32; 16]; 16];
889 let mut config = __tilecfg::default();
890 config.palette = 1;
891 (0..=2).for_each(|i| {
892 config.colsb[i] = 64;
893 config.rows[i] = 16;
894 });
895 _tile_loadconfig(config.as_ptr());
896 _tile_zero::<0>();
897 _tile_loadd::<1>(&ones as *const u8, 64);
898 _tile_loadd::<2>(&twos as *const u8, 64);
899 _tile_dpbf8ps::<0, 1, 2>();
900 _tile_stored::<0>(res.as_mut_ptr().cast(), 64);
901 _tile_release();
902 assert_eq!(res, [[128.0_f32; 16]; 16]);
903 }
904 }
905
906 #[simd_test(enable = "amx-fp8")]
907 fn test_tile_dpbhf8ps() {
908 unsafe {
909 _init_amx();
910 let ones = [BF8_ONE; 1024];
911 let twos = [HF8_TWO; 1024];
912 let mut res = [[0.0_f32; 16]; 16];
913 let mut config = __tilecfg::default();
914 config.palette = 1;
915 (0..=2).for_each(|i| {
916 config.colsb[i] = 64;
917 config.rows[i] = 16;
918 });
919 _tile_loadconfig(config.as_ptr());
920 _tile_zero::<0>();
921 _tile_loadd::<1>(&ones as *const u8, 64);
922 _tile_loadd::<2>(&twos as *const u8, 64);
923 _tile_dpbhf8ps::<0, 1, 2>();
924 _tile_stored::<0>(res.as_mut_ptr().cast(), 64);
925 _tile_release();
926 assert_eq!(res, [[128.0_f32; 16]; 16]);
927 }
928 }
929
930 #[simd_test(enable = "amx-fp8")]
931 fn test_tile_dphbf8ps() {
932 unsafe {
933 _init_amx();
934 let ones = [HF8_ONE; 1024];
935 let twos = [BF8_TWO; 1024];
936 let mut res = [[0.0_f32; 16]; 16];
937 let mut config = __tilecfg::default();
938 config.palette = 1;
939 (0..=2).for_each(|i| {
940 config.colsb[i] = 64;
941 config.rows[i] = 16;
942 });
943 _tile_loadconfig(config.as_ptr());
944 _tile_zero::<0>();
945 _tile_loadd::<1>(&ones as *const u8, 64);
946 _tile_loadd::<2>(&twos as *const u8, 64);
947 _tile_dphbf8ps::<0, 1, 2>();
948 _tile_stored::<0>(res.as_mut_ptr().cast(), 64);
949 _tile_release();
950 assert_eq!(res, [[128.0_f32; 16]; 16]);
951 }
952 }
953
954 #[simd_test(enable = "amx-fp8")]
955 fn test_tile_dphf8ps() {
956 unsafe {
957 _init_amx();
958 let ones = [HF8_ONE; 1024];
959 let twos = [HF8_TWO; 1024];
960 let mut res = [[0.0_f32; 16]; 16];
961 let mut config = __tilecfg::default();
962 config.palette = 1;
963 (0..=2).for_each(|i| {
964 config.colsb[i] = 64;
965 config.rows[i] = 16;
966 });
967 _tile_loadconfig(config.as_ptr());
968 _tile_zero::<0>();
969 _tile_loadd::<1>(&ones as *const u8, 64);
970 _tile_loadd::<2>(&twos as *const u8, 64);
971 _tile_dphf8ps::<0, 1, 2>();
972 _tile_stored::<0>(res.as_mut_ptr().cast(), 64);
973 _tile_release();
974 assert_eq!(res, [[128.0_f32; 16]; 16]);
975 }
976 }
977
978 #[simd_test(enable = "amx-movrs")]
979 fn test_tile_loaddrs() {
980 unsafe {
981 _init_amx();
982 let mut config = __tilecfg::default();
983 config.palette = 1;
984 config.colsb[0] = 64;
985 config.rows[0] = 16;
986 _tile_loadconfig(config.as_ptr());
987 _tile_zero::<0>();
988 let mat = [1_i8; 1024];
989 _tile_loaddrs::<0>(&mat as *const i8 as *const u8, 64);
990 let mut out = [[0_i8; 64]; 16];
991 _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
992 _tile_release();
993 assert_eq!(out, [[1; 64]; 16]);
994 }
995 }
996
997 #[simd_test(enable = "amx-movrs")]
998 fn test_tile_stream_loaddrs() {
999 unsafe {
1000 _init_amx();
1001 let mut config = __tilecfg::default();
1002 config.palette = 1;
1003 config.colsb[0] = 64;
1004 config.rows[0] = 16;
1005 _tile_loadconfig(config.as_ptr());
1006 _tile_zero::<0>();
1007 let mat = [1_i8; 1024];
1008 _tile_stream_loaddrs::<0>(&mat as *const i8 as *const u8, 64);
1009 let mut out = [[0_i8; 64]; 16];
1010 _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
1011 _tile_release();
1012 assert_eq!(out, [[1; 64]; 16]);
1013 }
1014 }
1015
1016 #[simd_test(enable = "amx-avx512,avx10.2")]
1017 fn test_tile_movrow() {
1018 unsafe {
1019 _init_amx();
1020 let array: [[u8; 64]; 16] = array::from_fn(|i| [i as _; _]);
1021
1022 let mut config = __tilecfg::default();
1023 config.palette = 1;
1024 config.colsb[0] = 64;
1025 config.rows[0] = 16;
1026 _tile_loadconfig(config.as_ptr());
1027 _tile_loadd::<0>(array.as_ptr().cast(), 64);
1028 for i in 0..16 {
1029 let row = _tile_movrow::<0>(i);
1030 assert_eq!(*row.as_u8x64().as_array(), [i as _; _]);
1031 }
1032 }
1033 }
1034
1035 #[simd_test(enable = "amx-avx512,avx10.2")]
1036 fn test_tile_cvtrowd2ps() {
1037 unsafe {
1038 _init_amx();
1039 let array: [[u32; 16]; 16] = array::from_fn(|i| [i as _; _]);
1040
1041 let mut config = __tilecfg::default();
1042 config.palette = 1;
1043 config.colsb[0] = 64;
1044 config.rows[0] = 16;
1045 _tile_loadconfig(config.as_ptr());
1046 _tile_loadd::<0>(array.as_ptr().cast(), 64);
1047 for i in 0..16 {
1048 let row = _tile_cvtrowd2ps::<0>(i);
1049 assert_eq!(*row.as_f32x16().as_array(), [i as _; _]);
1050 }
1051 }
1052 }
1053
1054 #[simd_test(enable = "amx-avx512,avx10.2")]
1055 fn test_tile_cvtrowps2phh() {
1056 unsafe {
1057 _init_amx();
1058 let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]);
1059
1060 let mut config = __tilecfg::default();
1061 config.palette = 1;
1062 config.colsb[0] = 64;
1063 config.rows[0] = 16;
1064 _tile_loadconfig(config.as_ptr());
1065 _tile_loadd::<0>(array.as_ptr().cast(), 64);
1066 for i in 0..16 {
1067 let row = _tile_cvtrowps2phh::<0>(i);
1068 assert_eq!(
1069 *row.as_f16x32().as_array(),
1070 array::from_fn(|j| if j & 1 == 0 { 0.0 } else { i as _ })
1071 );
1072 }
1073 }
1074 }
1075
1076 #[simd_test(enable = "amx-avx512,avx10.2")]
1077 fn test_tile_cvtrowps2phl() {
1078 unsafe {
1079 _init_amx();
1080 let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]);
1081
1082 let mut config = __tilecfg::default();
1083 config.palette = 1;
1084 config.colsb[0] = 64;
1085 config.rows[0] = 16;
1086 _tile_loadconfig(config.as_ptr());
1087 _tile_loadd::<0>(array.as_ptr().cast(), 64);
1088 for i in 0..16 {
1089 let row = _tile_cvtrowps2phl::<0>(i);
1090 assert_eq!(
1091 *row.as_f16x32().as_array(),
1092 array::from_fn(|j| if j & 1 == 0 { i as _ } else { 0.0 })
1093 );
1094 }
1095 }
1096 }
1097
1098 #[simd_test(enable = "amx-tf32")]
1099 fn test_tile_mmultf32ps() {
1100 unsafe {
1101 _init_amx();
1102 let a: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]);
1103 let b: [[f32; 16]; 16] = [array::from_fn(|j| j as _); _];
1104 let mut res = [[0.0; 16]; 16];
1105
1106 let mut config = __tilecfg::default();
1107 config.palette = 1;
1108 (0..=2).for_each(|i| {
1109 config.colsb[i] = 64;
1110 config.rows[i] = 16;
1111 });
1112 _tile_loadconfig(config.as_ptr());
1113 _tile_zero::<0>();
1114 _tile_loadd::<1>(a.as_ptr().cast(), 64);
1115 _tile_loadd::<2>(b.as_ptr().cast(), 64);
1116 _tile_mmultf32ps::<0, 1, 2>();
1117 _tile_stored::<0>(res.as_mut_ptr().cast(), 64);
1118 _tile_release();
1119
1120 let expected = array::from_fn(|i| array::from_fn(|j| 16.0 * i as f32 * j as f32));
1121 assert_eq!(res, expected);
1122 }
1123 }
1124}