diff --git a/include/net/netlink.h b/include/net/netlink.h index 9522a0bf1f3a..f1db8e594847 100644 --- a/include/net/netlink.h +++ b/include/net/netlink.h @@ -373,6 +373,9 @@ int nla_validate(const struct nlattr *head, int len, int maxtype, int nla_parse(struct nlattr **tb, int maxtype, const struct nlattr *head, int len, const struct nla_policy *policy, struct netlink_ext_ack *extack); +int nla_parse_strict(struct nlattr **tb, int maxtype, const struct nlattr *head, + int len, const struct nla_policy *policy, + struct netlink_ext_ack *extack); int nla_policy_len(const struct nla_policy *, int); struct nlattr *nla_find(const struct nlattr *head, int len, int attrtype); size_t nla_strlcpy(char *dst, const struct nlattr *nla, size_t dstsize); @@ -525,6 +528,20 @@ static inline int nlmsg_parse(const struct nlmsghdr *nlh, int hdrlen, nlmsg_attrlen(nlh, hdrlen), policy, extack); } +static inline int nlmsg_parse_strict(const struct nlmsghdr *nlh, int hdrlen, + struct nlattr *tb[], int maxtype, + const struct nla_policy *policy, + struct netlink_ext_ack *extack) +{ + if (nlh->nlmsg_len < nlmsg_msg_size(hdrlen)) { + NL_SET_ERR_MSG(extack, "Invalid header length"); + return -EINVAL; + } + + return nla_parse_strict(tb, maxtype, nlmsg_attrdata(nlh, hdrlen), + nlmsg_attrlen(nlh, hdrlen), policy, extack); +} + /** * nlmsg_find_attr - find a specific attribute in a netlink message * @nlh: netlink message header diff --git a/lib/nlattr.c b/lib/nlattr.c index 1e900bb414ef..d26de6156b97 100644 --- a/lib/nlattr.c +++ b/lib/nlattr.c @@ -391,9 +391,10 @@ EXPORT_SYMBOL(nla_policy_len); * * Returns 0 on success or a negative error code. */ -int nla_parse(struct nlattr **tb, int maxtype, const struct nlattr *head, - int len, const struct nla_policy *policy, - struct netlink_ext_ack *extack) +static int __nla_parse(struct nlattr **tb, int maxtype, + const struct nlattr *head, int len, + bool strict, const struct nla_policy *policy, + struct netlink_ext_ack *extack) { const struct nlattr *nla; int rem; @@ -403,27 +404,50 @@ int nla_parse(struct nlattr **tb, int maxtype, const struct nlattr *head, nla_for_each_attr(nla, head, len, rem) { u16 type = nla_type(nla); - if (type > 0 && type <= maxtype) { - if (policy) { - int err = validate_nla(nla, maxtype, policy, - extack); - - if (err < 0) - return err; + if (type == 0 || type > maxtype) { + if (strict) { + NL_SET_ERR_MSG(extack, "Unknown attribute type"); + return -EINVAL; } - - tb[type] = (struct nlattr *)nla; + continue; } + if (policy) { + int err = validate_nla(nla, maxtype, policy, extack); + + if (err < 0) + return err; + } + + tb[type] = (struct nlattr *)nla; } - if (unlikely(rem > 0)) + if (unlikely(rem > 0)) { pr_warn_ratelimited("netlink: %d bytes leftover after parsing attributes in process `%s'.\n", rem, current->comm); + NL_SET_ERR_MSG(extack, "bytes leftover after parsing attributes"); + if (strict) + return -EINVAL; + } return 0; } + +int nla_parse(struct nlattr **tb, int maxtype, const struct nlattr *head, + int len, const struct nla_policy *policy, + struct netlink_ext_ack *extack) +{ + return __nla_parse(tb, maxtype, head, len, false, policy, extack); +} EXPORT_SYMBOL(nla_parse); +int nla_parse_strict(struct nlattr **tb, int maxtype, const struct nlattr *head, + int len, const struct nla_policy *policy, + struct netlink_ext_ack *extack) +{ + return __nla_parse(tb, maxtype, head, len, true, policy, extack); +} +EXPORT_SYMBOL(nla_parse_strict); + /** * nla_find - Find a specific attribute in a stream of attributes * @head: head of attribute stream