tree-optimization/101267 - fix SLP vect with masked operations

This fixes the missed handling of external/constant mask SLP
operations, for the testcase in particular masked loads.  The
patch adjusts the vect_check_scalar_mask API to reflect the
required vect_is_simple_use SLP compatible API plus adjusts
for the special handling of masked loads in SLP discovery.

The issue is likely latent.

2021-06-30  Richard Biener  <rguenther@suse.de>

	PR tree-optimization/101267
	* tree-vect-stmts.c (vect_check_scalar_mask): Adjust
	API and use SLP compatible interface of vect_is_simple_use.
	Reject not vectorized SLP defs for callers that do not support
	that.
	(vect_check_store_rhs): Handle masked stores and pass down
	the appropriate operator index.
	(vectorizable_call): Adjust.
	(vectorizable_store): Likewise.
	(vectorizable_load): Likewise.  Handle SLP pecularity of
	masked loads.
	(vect_is_simple_use): Remove special-casing of masked stores.

	* gfortran.dg/pr101267.f90: New testcase.
This commit is contained in:
Richard Biener 2021-06-30 12:35:45 +02:00
parent e61ffa2014
commit a075350ee7
2 changed files with 77 additions and 38 deletions

View File

@ -0,0 +1,23 @@
! { dg-do compile }
! { dg-options "-Ofast" }
! { dg-additional-options "-march=znver2" { target x86_64-*-* i?86-*-* } }
SUBROUTINE sfddagd( regime, znt,ite ,jte )
REAL, DIMENSION( ime, IN) :: regime, znt
REAL, DIMENSION( ite, jte) :: wndcor_u
LOGICAL wrf_dm_on_monitor
IF( int4 == 1 ) THEN
DO j=jts,jtf
DO i=itsu,itf
reg = regime(i, j)
IF( reg > 10.0 ) THEN
znt0 = znt(i-1, j) + znt(i, j)
IF( znt0 <= 0.2) THEN
wndcor_u(i,j) = 0.2
ENDIF
ENDIF
ENDDO
ENDDO
IF ( wrf_dm_on_monitor()) THEN
ENDIF
ENDIF
END

View File

@ -2439,17 +2439,31 @@ get_load_store_type (vec_info *vinfo, stmt_vec_info stmt_info,
return true;
}
/* Return true if boolean argument MASK is suitable for vectorizing
conditional operation STMT_INFO. When returning true, store the type
of the definition in *MASK_DT_OUT and the type of the vectorized mask
in *MASK_VECTYPE_OUT. */
/* Return true if boolean argument at MASK_INDEX is suitable for vectorizing
conditional operation STMT_INFO. When returning true, store the mask
in *MASK, the type of its definition in *MASK_DT_OUT, the type of the
vectorized mask in *MASK_VECTYPE_OUT and the SLP node corresponding
to the mask in *MASK_NODE if MASK_NODE is not NULL. */
static bool
vect_check_scalar_mask (vec_info *vinfo, stmt_vec_info stmt_info, tree mask,
vect_def_type *mask_dt_out,
tree *mask_vectype_out)
vect_check_scalar_mask (vec_info *vinfo, stmt_vec_info stmt_info,
slp_tree slp_node, unsigned mask_index,
tree *mask, slp_tree *mask_node,
vect_def_type *mask_dt_out, tree *mask_vectype_out)
{
if (!VECT_SCALAR_BOOLEAN_TYPE_P (TREE_TYPE (mask)))
enum vect_def_type mask_dt;
tree mask_vectype;
slp_tree mask_node_1;
if (!vect_is_simple_use (vinfo, stmt_info, slp_node, mask_index,
mask, &mask_node_1, &mask_dt, &mask_vectype))
{
if (dump_enabled_p ())
dump_printf_loc (MSG_MISSED_OPTIMIZATION, vect_location,
"mask use not simple.\n");
return false;
}
if (!VECT_SCALAR_BOOLEAN_TYPE_P (TREE_TYPE (*mask)))
{
if (dump_enabled_p ())
dump_printf_loc (MSG_MISSED_OPTIMIZATION, vect_location,
@ -2457,7 +2471,7 @@ vect_check_scalar_mask (vec_info *vinfo, stmt_vec_info stmt_info, tree mask,
return false;
}
if (TREE_CODE (mask) != SSA_NAME)
if (TREE_CODE (*mask) != SSA_NAME)
{
if (dump_enabled_p ())
dump_printf_loc (MSG_MISSED_OPTIMIZATION, vect_location,
@ -2465,13 +2479,15 @@ vect_check_scalar_mask (vec_info *vinfo, stmt_vec_info stmt_info, tree mask,
return false;
}
enum vect_def_type mask_dt;
tree mask_vectype;
if (!vect_is_simple_use (mask, vinfo, &mask_dt, &mask_vectype))
/* If the caller is not prepared for adjusting an external/constant
SLP mask vector type fail. */
if (slp_node
&& !mask_node
&& SLP_TREE_DEF_TYPE (mask_node_1) != vect_internal_def)
{
if (dump_enabled_p ())
dump_printf_loc (MSG_MISSED_OPTIMIZATION, vect_location,
"mask use not simple.\n");
"SLP mask argument is not vectorized.\n");
return false;
}
@ -2501,6 +2517,8 @@ vect_check_scalar_mask (vec_info *vinfo, stmt_vec_info stmt_info, tree mask,
*mask_dt_out = mask_dt;
*mask_vectype_out = mask_vectype;
if (mask_node)
*mask_node = mask_node_1;
return true;
}
@ -2525,10 +2543,18 @@ vect_check_store_rhs (vec_info *vinfo, stmt_vec_info stmt_info,
return false;
}
unsigned op_no = 0;
if (gcall *call = dyn_cast <gcall *> (stmt_info->stmt))
{
if (gimple_call_internal_p (call)
&& internal_store_fn_p (gimple_call_internal_fn (call)))
op_no = internal_fn_stored_value_index (gimple_call_internal_fn (call));
}
enum vect_def_type rhs_dt;
tree rhs_vectype;
slp_tree slp_op;
if (!vect_is_simple_use (vinfo, stmt_info, slp_node, 0,
if (!vect_is_simple_use (vinfo, stmt_info, slp_node, op_no,
&rhs, &slp_op, &rhs_dt, &rhs_vectype))
{
if (dump_enabled_p ())
@ -3163,9 +3189,8 @@ vectorizable_call (vec_info *vinfo,
{
if ((int) i == mask_opno)
{
op = gimple_call_arg (stmt, i);
if (!vect_check_scalar_mask (vinfo,
stmt_info, op, &dt[i], &vectypes[i]))
if (!vect_check_scalar_mask (vinfo, stmt_info, slp_node, mask_opno,
&op, &slp_op[i], &dt[i], &vectypes[i]))
return false;
continue;
}
@ -7213,13 +7238,10 @@ vectorizable_store (vec_info *vinfo,
}
int mask_index = internal_fn_mask_index (ifn);
if (mask_index >= 0)
{
mask = gimple_call_arg (call, mask_index);
if (!vect_check_scalar_mask (vinfo, stmt_info, mask, &mask_dt,
&mask_vectype))
return false;
}
if (mask_index >= 0
&& !vect_check_scalar_mask (vinfo, stmt_info, slp_node, mask_index,
&mask, NULL, &mask_dt, &mask_vectype))
return false;
}
op = vect_get_store_rhs (stmt_info);
@ -8494,13 +8516,13 @@ vectorizable_load (vec_info *vinfo,
return false;
int mask_index = internal_fn_mask_index (ifn);
if (mask_index >= 0)
{
mask = gimple_call_arg (call, mask_index);
if (!vect_check_scalar_mask (vinfo, stmt_info, mask, &mask_dt,
&mask_vectype))
return false;
}
if (mask_index >= 0
&& !vect_check_scalar_mask (vinfo, stmt_info, slp_node,
/* ??? For SLP we only have operands for
the mask operand. */
slp_node ? 0 : mask_index,
&mask, NULL, &mask_dt, &mask_vectype))
return false;
}
tree vectype = STMT_VINFO_VECTYPE (stmt_info);
@ -11484,13 +11506,7 @@ vect_is_simple_use (vec_info *vinfo, stmt_vec_info stmt, slp_tree slp_node,
*op = gimple_op (ass, operand + 1);
}
else if (gcall *call = dyn_cast <gcall *> (stmt->stmt))
{
if (gimple_call_internal_p (call)
&& internal_store_fn_p (gimple_call_internal_fn (call)))
operand = internal_fn_stored_value_index (gimple_call_internal_fn
(call));
*op = gimple_call_arg (call, operand);
}
*op = gimple_call_arg (call, operand);
else
gcc_unreachable ();
return vect_is_simple_use (*op, vinfo, dt, vectype, def_stmt_info_out);