Fix tdpbf16ps testcase

gcc/testsuite/ChangeLog:

	* gcc.target/i386/amx-check.h (check_float_tile_register):
	New check function for float to prevent precision loss.
	* gcc.target/i386/amxbf16-dpbf16ps-2.c: Correct the type convert
	and byte offset. Use the new check function.
This commit is contained in:
Haochen Jiang 2021-12-24 13:55:06 +08:00 committed by liuhongt
parent d1e111daee
commit 472568f5d8
2 changed files with 41 additions and 12 deletions

View File

@ -139,8 +139,27 @@ int check_tile_register (__tile* ref, __tile* target)
for (i = 0; i < rows; i++)
for (j = 0; j < colsb; j++)
if (ref->buf[i * colsb + j] != target->buf[i * colsb + j])
return 0;
if (ref->buf[i * colsb + j] != target->buf[i * colsb + j])
return 0;
return 1;
}
/* Compare float tile register value with __tile variable */
int check_float_tile_register (__tile* ref, __tile* target)
{
/* Tile register should be stored from tmm to
memory and compare with emulation results. */
int rows = target->rows;
int colsb = target->colsb / 4;
int i, j;
uint32_t *ref_buf = (uint32_t *) ref->buf;
uint32_t *target_buf = (uint32_t *) target->buf;
for (i = 0; i < rows; i++)
for (j = 0; j < colsb; j++)
if (abs(ref_buf[i * colsb + j] - target_buf[i * colsb + j]) > 1)
return 0;
return 1;
}

View File

@ -12,15 +12,25 @@ void test_amx_bf16_dpbf16ps ();
/* Transformation functions between bf16/float */
static uint16_t make_bf16 (float f)
{
uint32_t u = (uint32_t)f;
u = (u >> 16) & 0xffff;
return (uint16_t)u;
union
{
float f;
uint32_t u;
} fu;
fu.f = f;
fu.u = (fu.u >> 16) & 0xffff;
return (uint16_t) fu.u;
}
static float make_f32 (uint16_t bf)
{
uint32_t u = (uint32_t)(bf << 16);
return (float)u;
union
{
float f;
uint32_t u;
} fu;
fu.u = (uint32_t) bf << 16;
return fu.f;
}
/* Init tile buffer with bf16 pairs */
@ -54,10 +64,10 @@ void calc_matrix_dpbf16ps (__tile *dst, __tile *src1, __tile *src2)
for (t = 0; t < 2; t+=2)
{
dst_buf[i * N + k] +=
(make_f32(src1_buf[i * 4 * N + 4 * j + t]) *
make_f32(src2_buf[j * 4 * K + 4 * k + t])) +
(make_f32(src1_buf[i * 4 * N + 4 * j + t + 1]) *
make_f32(src2_buf[j * 4 * K + 4 * k + t + 1]));
(make_f32(src1_buf[i * 2 * N + 2 * j + t]) *
make_f32(src2_buf[j * 2 * K + 2 * k + t])) +
(make_f32(src1_buf[i * 2 * N + 2 * j + t + 1]) *
make_f32(src2_buf[j * 2 * K + 2 * k + t + 1]));
}
}
@ -80,6 +90,6 @@ void test_amx_bf16_dpbf16ps ()
_tile_dpbf16ps (1, 2, 3);
_tile_stored (1, dst_ref.buf, _STRIDE);
if (!check_tile_register (&dst_ref, &dst))
if (!check_float_tile_register (&dst_ref, &dst))
abort();
}