Merge tag '6.17-rc6-ksmbd-fixes' of git://git.samba.org/ksmbd

Pull smb server fixes from Steve French:

 - Two fixes for remaining_data_length and offset checks in the receive
   path (the validation pattern is sketched just before the diff)

 - Don't go over the maximum number of send SGEs, which made smbdirect
   sends fail (and trigger a disconnect); see the sketches after the
   writev and init_params hunks

* tag '6.17-rc6-ksmbd-fixes' of git://git.samba.org/ksmbd:
  ksmbd: smbdirect: verify remaining_data_length respects max_fragmented_recv_size
  ksmbd: smbdirect: validate data_offset and data_length field of smb_direct_data_transfer
  smb: server: let smb_direct_writev() respect SMB_DIRECT_MAX_SEND_SGES
Author: Linus Torvalds
Date:   2025-09-17 18:23:01 -07:00
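
The first two fixes share one pattern: treat remaining_data_length,
data_offset and data_length as untrusted wire input, and do the
arithmetic in a wider type so the bounds checks cannot themselves
overflow. Below is a minimal userspace sketch of that pattern; the
struct and the harness are hypothetical stand-ins for
smb_direct_data_transfer and the transport state, and only the checks
mirror the patch.

#include <stdint.h>
#include <stdio.h>

/* Hypothetical stand-in for the wire header (the real fields are
 * little-endian u32; endianness handling is omitted here). */
struct wire_hdr {
	uint32_t remaining_data_length;
	uint32_t data_offset;
	uint32_t data_length;
};

/* Return 0 if the fragment is acceptable, -1 to drop the connection. */
static int validate_fragment(const struct wire_hdr *h,
			     uint32_t byte_len, /* like wc->byte_len */
			     uint32_t max_fragmented_recv_size)
{
	/* Promote to 64 bits before adding: the sum of two u32 values
	 * can wrap in 32-bit arithmetic. */
	if (byte_len < h->data_offset ||
	    byte_len < (uint64_t)h->data_offset + h->data_length)
		return -1;
	if (h->remaining_data_length > max_fragmented_recv_size ||
	    h->data_length > max_fragmented_recv_size ||
	    (uint64_t)h->remaining_data_length + h->data_length >
	    max_fragmented_recv_size)
		return -1;
	return 0;
}

int main(void)
{
	/* 0xFFFFF000 + 0x2000 wraps to 0x1000 in u32, so a naive
	 * "byte_len < data_offset + data_length" check would pass. */
	struct wire_hdr evil = {
		.remaining_data_length = 0,
		.data_offset = 0xFFFFF000u,
		.data_length = 0x2000u,
	};
	uint32_t naive = evil.data_offset + evil.data_length;

	printf("naive sum: 0x%x, verdict: %d\n",
	       naive, validate_fragment(&evil, 0x2000u, 1u << 20));
	return 0;
}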

--- a/fs/smb/server/transport_rdma.c
+++ b/fs/smb/server/transport_rdma.c
@@ -554,7 +554,7 @@ static void recv_done(struct ib_cq *cq, struct ib_wc *wc)
 	case SMB_DIRECT_MSG_DATA_TRANSFER: {
 		struct smb_direct_data_transfer *data_transfer =
 			(struct smb_direct_data_transfer *)recvmsg->packet;
-		unsigned int data_length;
+		u32 remaining_data_length, data_offset, data_length;
 		int avail_recvmsg_count, receive_credits;
 
 		if (wc->byte_len <
@@ -564,15 +564,25 @@ static void recv_done(struct ib_cq *cq, struct ib_wc *wc)
 			return;
 		}
 
+		remaining_data_length = le32_to_cpu(data_transfer->remaining_data_length);
 		data_length = le32_to_cpu(data_transfer->data_length);
-		if (data_length) {
-			if (wc->byte_len < sizeof(struct smb_direct_data_transfer) +
-			    (u64)data_length) {
-				put_recvmsg(t, recvmsg);
-				smb_direct_disconnect_rdma_connection(t);
-				return;
-			}
+		data_offset = le32_to_cpu(data_transfer->data_offset);
+		if (wc->byte_len < data_offset ||
+		    wc->byte_len < (u64)data_offset + data_length) {
+			put_recvmsg(t, recvmsg);
+			smb_direct_disconnect_rdma_connection(t);
+			return;
+		}
+		if (remaining_data_length > t->max_fragmented_recv_size ||
+		    data_length > t->max_fragmented_recv_size ||
+		    (u64)remaining_data_length + (u64)data_length >
+		    (u64)t->max_fragmented_recv_size) {
+			put_recvmsg(t, recvmsg);
+			smb_direct_disconnect_rdma_connection(t);
+			return;
+		}
 
+		if (data_length) {
 			if (t->full_packet_received)
 				recvmsg->first_segment = true;
@@ -1209,78 +1219,130 @@ static int smb_direct_writev(struct ksmbd_transport *t,
 			     bool need_invalidate, unsigned int remote_key)
 {
 	struct smb_direct_transport *st = smb_trans_direct_transfort(t);
-	int remaining_data_length;
-	int start, i, j;
-	int max_iov_size = st->max_send_size -
+	size_t remaining_data_length;
+	size_t iov_idx;
+	size_t iov_ofs;
+	size_t max_iov_size = st->max_send_size -
 			sizeof(struct smb_direct_data_transfer);
 	int ret;
-	struct kvec vec;
 	struct smb_direct_send_ctx send_ctx;
+	int error = 0;
 
 	if (st->status != SMB_DIRECT_CS_CONNECTED)
 		return -ENOTCONN;
 
 	//FIXME: skip RFC1002 header..
+	if (WARN_ON_ONCE(niovs <= 1 || iov[0].iov_len != 4))
+		return -EINVAL;
 	buflen -= 4;
+	iov_idx = 1;
+	iov_ofs = 0;
 
 	remaining_data_length = buflen;
 	ksmbd_debug(RDMA, "Sending smb (RDMA): smb_len=%u\n", buflen);
 
 	smb_direct_send_ctx_init(st, &send_ctx, need_invalidate, remote_key);
-	start = i = 1;
-	buflen = 0;
-	while (true) {
-		buflen += iov[i].iov_len;
-		if (buflen > max_iov_size) {
-			if (i > start) {
-				remaining_data_length -=
-					(buflen - iov[i].iov_len);
-				ret = smb_direct_post_send_data(st, &send_ctx,
-								&iov[start], i - start,
-								remaining_data_length);
-				if (ret)
-					goto done;
-			} else {
-				/* iov[start] is too big, break it */
-				int nvec = (buflen + max_iov_size - 1) /
-						max_iov_size;
-
-				for (j = 0; j < nvec; j++) {
-					vec.iov_base =
-						(char *)iov[start].iov_base +
-						j * max_iov_size;
-					vec.iov_len =
-						min_t(int, max_iov_size,
-						      buflen - max_iov_size * j);
-					remaining_data_length -= vec.iov_len;
-					ret = smb_direct_post_send_data(st, &send_ctx, &vec, 1,
-									remaining_data_length);
-					if (ret)
-						goto done;
-				}
-				i++;
-				if (i == niovs)
-					break;
-			}
-			start = i;
-			buflen = 0;
-		} else {
-			i++;
-			if (i == niovs) {
-				/* send out all remaining vecs */
-				remaining_data_length -= buflen;
-				ret = smb_direct_post_send_data(st, &send_ctx,
-								&iov[start], i - start,
-								remaining_data_length);
-				if (ret)
-					goto done;
-				break;
-			}
-		}
-	}
+	while (remaining_data_length) {
+		struct kvec vecs[SMB_DIRECT_MAX_SEND_SGES - 1]; /* minus smbdirect hdr */
+		size_t possible_bytes = max_iov_size;
+		size_t possible_vecs;
+		size_t bytes = 0;
+		size_t nvecs = 0;
+
+		/*
+		 * For the last message remaining_data_length should
+		 * have been 0 already!
+		 */
+		if (WARN_ON_ONCE(iov_idx >= niovs)) {
+			error = -EINVAL;
+			goto done;
+		}
+
+		/*
+		 * We have 2 factors which limit the arguments we pass
+		 * to smb_direct_post_send_data():
+		 *
+		 * 1. The number of supported sges for the send,
+		 *    while one is reserved for the smbdirect header.
+		 *    And we currently need one SGE per page.
+		 * 2. The number of negotiated payload bytes per send.
+		 */
+		possible_vecs = min_t(size_t, ARRAY_SIZE(vecs), niovs - iov_idx);
+
+		while (iov_idx < niovs && possible_vecs && possible_bytes) {
+			struct kvec *v = &vecs[nvecs];
+			int page_count;
+
+			v->iov_base = ((u8 *)iov[iov_idx].iov_base) + iov_ofs;
+			v->iov_len = min_t(size_t,
+					   iov[iov_idx].iov_len - iov_ofs,
+					   possible_bytes);
+			page_count = get_buf_page_count(v->iov_base, v->iov_len);
+			if (page_count > possible_vecs) {
+				/*
+				 * If the number of pages in the buffer
+				 * is too much (because we currently require
+				 * one SGE per page), we need to limit the
+				 * length.
+				 *
+				 * We know possible_vecs is at least 1,
+				 * so we always keep the first page.
+				 *
+				 * We need to calculate the number of extra
+				 * pages (epages) we can also keep.
+				 *
+				 * We calculate the number of bytes in the
+				 * first page (fplen), this should never be
+				 * larger than v->iov_len because page_count is
+				 * at least 2, but adding a limitation feels
+				 * better.
+				 *
+				 * Then we calculate the number of bytes (elen)
+				 * we can keep for the extra pages.
+				 */
+				size_t epages = possible_vecs - 1;
+				size_t fpofs = offset_in_page(v->iov_base);
+				size_t fplen = min_t(size_t, PAGE_SIZE - fpofs, v->iov_len);
+				size_t elen = min_t(size_t, v->iov_len - fplen, epages*PAGE_SIZE);
+
+				v->iov_len = fplen + elen;
+				page_count = get_buf_page_count(v->iov_base, v->iov_len);
+				if (WARN_ON_ONCE(page_count > possible_vecs)) {
+					/*
+					 * Something went wrong in the above
+					 * logic...
+					 */
+					error = -EINVAL;
+					goto done;
+				}
+			}
+			possible_vecs -= page_count;
+			nvecs += 1;
+			possible_bytes -= v->iov_len;
+			bytes += v->iov_len;
+
+			iov_ofs += v->iov_len;
+			if (iov_ofs >= iov[iov_idx].iov_len) {
+				iov_idx += 1;
+				iov_ofs = 0;
+			}
+		}
+
+		remaining_data_length -= bytes;
+
+		ret = smb_direct_post_send_data(st, &send_ctx,
+						vecs, nvecs,
+						remaining_data_length);
+		if (unlikely(ret)) {
+			error = ret;
+			goto done;
+		}
+	}
 
 done:
 	ret = smb_direct_flush_send_list(st, &send_ctx, true);
+	if (unlikely(!ret && error))
+		ret = error;
 
 	/*
 	 * As an optimization, we don't wait for individual I/O to finish
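
The rewritten send loop above clips each kvec so it never needs more
SGEs than remain in the budget: it keeps the partial first page plus
as many whole extra pages as still fit. The standalone sketch below
replays that page math; buf_page_count() and limit_to_pages() are
hypothetical helpers modelled on get_buf_page_count() and the
fplen/elen computation, with PAGE_SIZE fixed at 4096 for the demo.

#include <stdint.h>
#include <stddef.h>
#include <stdio.h>

#define PAGE_SIZE 4096u

/* Number of pages a [base, base + len) buffer touches. */
static size_t buf_page_count(uintptr_t base, size_t len)
{
	uintptr_t mask = ~((uintptr_t)PAGE_SIZE - 1);
	uintptr_t start = base & mask;
	uintptr_t end = (base + len + PAGE_SIZE - 1) & mask;

	return (size_t)((end - start) / PAGE_SIZE);
}

/* Trim len so the buffer spans at most max_pages pages: keep the
 * (partial) first page plus up to max_pages - 1 whole extra pages. */
static size_t limit_to_pages(uintptr_t base, size_t len, size_t max_pages)
{
	size_t fpofs = (size_t)(base & (PAGE_SIZE - 1)); /* offset in first page */
	size_t fplen = PAGE_SIZE - fpofs;                /* bytes left in that page */
	size_t epages = max_pages - 1;                   /* extra pages we may keep */
	size_t elen;

	if (fplen > len)
		fplen = len;
	elen = len - fplen;
	if (elen > epages * PAGE_SIZE)
		elen = epages * PAGE_SIZE;
	return fplen + elen;
}

int main(void)
{
	uintptr_t base = 0x1000 + 100;	/* starts 100 bytes into a page */
	size_t len = 5 * PAGE_SIZE;	/* touches 6 pages because of the offset */
	size_t budget = 3;		/* SGEs left for payload */
	size_t trimmed = limit_to_pages(base, len, budget);

	printf("pages: %zu -> %zu (len %zu -> %zu)\n",
	       buf_page_count(base, len), buf_page_count(base, trimmed),
	       len, trimmed);
	return 0;
}

Because elen is a multiple of PAGE_SIZE (unless the buffer ends
first), the trimmed vector ends on a page boundary, and the next loop
iteration resumes at the updated iov_ofs with a fresh SGE budget.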
@@ -1744,6 +1806,11 @@ static int smb_direct_init_params(struct smb_direct_transport *t,
 		return -EINVAL;
 	}
 
+	if (device->attrs.max_send_sge < SMB_DIRECT_MAX_SEND_SGES) {
+		pr_err("warning: device max_send_sge = %d too small\n",
+		       device->attrs.max_send_sge);
+		return -EINVAL;
+	}
 	if (device->attrs.max_recv_sge < SMB_DIRECT_MAX_RECV_SGES) {
 		pr_err("warning: device max_recv_sge = %d too small\n",
 		       device->attrs.max_recv_sge);
@@ -1767,7 +1834,7 @@ static int smb_direct_init_params(struct smb_direct_transport *t,
 	cap->max_send_wr = max_send_wrs;
 	cap->max_recv_wr = t->recv_credit_max;
-	cap->max_send_sge = max_sge_per_wr;
+	cap->max_send_sge = SMB_DIRECT_MAX_SEND_SGES;
 	cap->max_recv_sge = SMB_DIRECT_MAX_RECV_SGES;
 	cap->max_inline_data = 0;
 	cap->max_rdma_ctxs = t->max_rw_credits;
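
The last two hunks make the queue-pair capabilities and the send path
agree on a single compile-time limit: devices that cannot supply
SMB_DIRECT_MAX_SEND_SGES are rejected up front, and the QP is created
with exactly that many send SGEs rather than a device-derived value.
A hedged sketch of the idea, with illustrative constant values and
plain structs standing in for ib_device_attr and ib_qp_cap:

#include <stdio.h>

/* Illustrative values; the real constants live in transport_rdma.c. */
#define SMB_DIRECT_MAX_SEND_SGES 8
#define SMB_DIRECT_MAX_RECV_SGES 1

struct dev_attrs { int max_send_sge; int max_recv_sge; };
struct qp_cap    { int max_send_sge; int max_recv_sge; };

/* Mirror of the fixed negotiation: reject too-small devices up front,
 * then request exactly what the send path may use, never the device max. */
static int init_caps(const struct dev_attrs *attrs, struct qp_cap *cap)
{
	if (attrs->max_send_sge < SMB_DIRECT_MAX_SEND_SGES) {
		fprintf(stderr, "device max_send_sge = %d too small\n",
			attrs->max_send_sge);
		return -1;
	}
	if (attrs->max_recv_sge < SMB_DIRECT_MAX_RECV_SGES) {
		fprintf(stderr, "device max_recv_sge = %d too small\n",
			attrs->max_recv_sge);
		return -1;
	}
	cap->max_send_sge = SMB_DIRECT_MAX_SEND_SGES;
	cap->max_recv_sge = SMB_DIRECT_MAX_RECV_SGES;
	return 0;
}

int main(void)
{
	struct dev_attrs dev = { .max_send_sge = 32, .max_recv_sge = 4 };
	struct qp_cap cap;

	if (init_caps(&dev, &cap) == 0)
		printf("QP caps: send_sge=%d recv_sge=%d\n",
		       cap.max_send_sge, cap.max_recv_sge);
	return 0;
}

Before the fix, the send path could post work requests with more SGEs
than the QP supports, and the failed send tore down the connection;
pinning both sides of the contract to SMB_DIRECT_MAX_SEND_SGES is what
makes the smb_direct_writev() change above safe.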