summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorpaulhsia <paulhsia@chromium.org>2019-09-08 15:30:06 +0800
committerCommit Bot <commit-bot@chromium.org>2019-09-18 18:13:32 +0000
commit9c2b8cdb555c2ef6d7c1937a7886e4c6e83f0e0e (patch)
tree39f25b07f4b04a60a33ff55c54dba0e552512667
parentf725c161a7bd515f049c32d4cdfa507be6029841 (diff)
downloadadhd-9c2b8cdb555c2ef6d7c1937a7886e4c6e83f0e0e.tar.gz
CRAS: rclient: Add rclient_validate_stream_connect_message
- Unify cras_connect_message checks to rclient_validate_stream_connect_message which supports - client id check - stream direction check - Add supported_directions to cras_rclient. BUG=chromium:937765 TEST=Build and run unit tests Change-Id: I0c17c1d89ba1ad2653f4bca7feac93abb46c3e30 Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/third_party/adhd/+/1790225 Reviewed-by: Yu-Hsuan Hsu <yuhsuan@chromium.org> Commit-Queue: Chih-Yang Hsia <paulhsia@chromium.org> Tested-by: Chih-Yang Hsia <paulhsia@chromium.org>
-rw-r--r--cras/src/common/cras_types.h10
-rw-r--r--cras/src/server/cras_control_rclient.c8
-rw-r--r--cras/src/server/cras_playback_rclient.c17
-rw-r--r--cras/src/server/cras_rclient.h2
-rw-r--r--cras/src/server/cras_rclient_util.c24
-rw-r--r--cras/src/server/cras_rclient_util.h14
-rw-r--r--cras/src/tests/control_rclient_unittest.cc74
-rw-r--r--cras/src/tests/playback_rclient_unittest.cc40
8 files changed, 158 insertions, 31 deletions
diff --git a/cras/src/common/cras_types.h b/cras/src/common/cras_types.h
index ebe76d89..6ff7e1ae 100644
--- a/cras/src/common/cras_types.h
+++ b/cras/src/common/cras_types.h
@@ -59,6 +59,16 @@ enum CRAS_STREAM_DIRECTION {
CRAS_NUM_DIRECTIONS
};
+/* Bitmask for supporting all CRAS_STREAM_DIRECTION. */
+#define CRAS_STREAM_ALL_DIRECTION ((1 << CRAS_NUM_DIRECTIONS) - 1)
+
+/* Converts CRAS_STREAM_DIRECTION to bitmask. */
+static inline int
+cras_stream_direction_mask(const enum CRAS_STREAM_DIRECTION dir)
+{
+ return (1 << dir);
+}
+
/*
* Flags for stream types.
* BULK_AUDIO_OK - This stream is OK with receiving up to a full shm of samples
diff --git a/cras/src/server/cras_control_rclient.c b/cras/src/server/cras_control_rclient.c
index 811e03ee..f51bd65a 100644
--- a/cras/src/server/cras_control_rclient.c
+++ b/cras/src/server/cras_control_rclient.c
@@ -40,6 +40,10 @@ static int handle_client_stream_connect(struct cras_rclient *client,
int rc, header_fd, samples_fd;
int stream_fds[2];
+ rc = rclient_validate_stream_connect_message(client, msg);
+ if (rc)
+ goto close_shm_fd;
+
unpack_cras_audio_format(&remote_fmt, &msg->format);
rc = rclient_validate_stream_connect_fds(aud_fd, client_shm_fd,
@@ -653,6 +657,10 @@ struct cras_rclient *cras_control_rclient_create(int fd, size_t id)
client->fd = fd;
client->id = id;
client->ops = &cras_control_rclient_ops;
+ client->supported_directions = CRAS_STREAM_ALL_DIRECTION;
+ /* Filters CRAS_STREAM_UNDEFINED stream out. */
+ client->supported_directions ^=
+ cras_stream_direction_mask(CRAS_STREAM_UNDEFINED);
cras_fill_client_connected(&msg, client->id);
state_fd = cras_sys_state_shm_fd();
diff --git a/cras/src/server/cras_playback_rclient.c b/cras/src/server/cras_playback_rclient.c
index 9fe4471e..248ebdab 100644
--- a/cras/src/server/cras_playback_rclient.c
+++ b/cras/src/server/cras_playback_rclient.c
@@ -30,20 +30,9 @@ static int handle_client_stream_connect(struct cras_rclient *client,
int rc, header_fd, samples_fd;
int stream_fds[2];
- if (!cras_valid_stream_id(msg->stream_id, client->id)) {
- syslog(LOG_ERR,
- "stream_connect: invalid stream_id: %x for "
- "client: %zx.\n",
- msg->stream_id, client->id);
- rc = -EINVAL;
- goto reply_err;
- }
-
- if (msg->direction != CRAS_STREAM_OUTPUT) {
- syslog(LOG_ERR, "Invalid stream direction.\n");
- rc = -EINVAL;
+ rc = rclient_validate_stream_connect_message(client, msg);
+ if (rc)
goto reply_err;
- }
unpack_cras_audio_format(&remote_fmt, &msg->format);
@@ -220,6 +209,8 @@ struct cras_rclient *cras_playback_rclient_create(int fd, size_t id)
client->id = id;
client->ops = &cras_playback_rclient_ops;
+ client->supported_directions =
+ cras_stream_direction_mask(CRAS_STREAM_OUTPUT);
cras_fill_client_connected(&msg, client->id);
state_fd = cras_sys_state_shm_fd();
diff --git a/cras/src/server/cras_rclient.h b/cras/src/server/cras_rclient.h
index d63407d2..c1545a87 100644
--- a/cras/src/server/cras_rclient.h
+++ b/cras/src/server/cras_rclient.h
@@ -17,12 +17,14 @@ struct cras_server_message;
* id - The id of the client.
* fd - Connection for client communication.
* ops - cras_rclient_ops for the cras_rclient.
+ * supported_directions - Bit mask for supported stream directions.
*/
struct cras_rclient {
struct cras_observer_client *observer;
size_t id;
int fd;
const struct cras_rclient_ops *ops;
+ int supported_directions;
};
/* Operations for cras_rclient.
diff --git a/cras/src/server/cras_rclient_util.c b/cras/src/server/cras_rclient_util.c
index 47daa2ae..2f1d88f9 100644
--- a/cras/src/server/cras_rclient_util.c
+++ b/cras/src/server/cras_rclient_util.c
@@ -12,6 +12,7 @@
#include "cras_rclient_util.h"
#include "cras_rstream.h"
#include "cras_tm.h"
+#include "cras_types.h"
#include "cras_util.h"
#include "stream_list.h"
@@ -53,6 +54,29 @@ void rclient_fill_cras_rstream_config(struct cras_rclient *client,
stream_config->client = client;
}
+int rclient_validate_stream_connect_message(
+ const struct cras_rclient *client,
+ const struct cras_connect_message *msg)
+{
+ if (!cras_valid_stream_id(msg->stream_id, client->id)) {
+ syslog(LOG_ERR,
+ "stream_connect: invalid stream_id: %x for "
+ "client: %zx.\n",
+ msg->stream_id, client->id);
+ return -EINVAL;
+ }
+
+ int direction = cras_stream_direction_mask(msg->direction);
+ if (!(client->supported_directions & direction)) {
+ syslog(LOG_ERR,
+ "stream_connect: invalid stream direction: %x for "
+ "client: %zx.\n",
+ msg->direction, client->id);
+ return -EINVAL;
+ }
+ return 0;
+}
+
int rclient_validate_stream_connect_fds(int audio_fd, int client_shm_fd,
size_t client_shm_size)
{
diff --git a/cras/src/server/cras_rclient_util.h b/cras/src/server/cras_rclient_util.h
index e9c0079e..8292a113 100644
--- a/cras/src/server/cras_rclient_util.h
+++ b/cras/src/server/cras_rclient_util.h
@@ -32,6 +32,20 @@ void rclient_fill_cras_rstream_config(
const struct cras_audio_format *remote_format,
struct cras_rstream_config *stream_config);
+/* Checks if the incoming stream connect message contains
+ * - stream_id matches client->id.
+ * - direction supported by the client.
+ *
+ * Args:
+ * client - The cras_rclient which gets the message.
+ *
+ * Returns:
+ * 0 on success, negative error on failure.
+ */
+int rclient_validate_stream_connect_message(
+ const struct cras_rclient *client,
+ const struct cras_connect_message *msg);
+
/*
* Converts an old version of connect message to the correct
* cras_connect_message. Returns zero on success, negative on failure.
diff --git a/cras/src/tests/control_rclient_unittest.cc b/cras/src/tests/control_rclient_unittest.cc
index a5d17629..75573faf 100644
--- a/cras/src/tests/control_rclient_unittest.cc
+++ b/cras/src/tests/control_rclient_unittest.cc
@@ -136,7 +136,7 @@ class RClientMessagesSuite : public testing::Test {
rc = pipe(pipe_fds_);
if (rc < 0)
return;
- rclient_ = cras_control_rclient_create(pipe_fds_[1], 800);
+ rclient_ = cras_control_rclient_create(pipe_fds_[1], 1);
rc = read(pipe_fds_[0], &msg, sizeof(msg));
if (rc < 0)
return;
@@ -246,6 +246,78 @@ TEST_F(RClientMessagesSuite, ConnectMsgFromOldClient) {
EXPECT_EQ(1, cras_server_metrics_stream_config_called);
}
+TEST_F(RClientMessagesSuite, StreamConnectMessageValidDirection) {
+ struct cras_client_stream_connected out_msg;
+ int rc;
+ int called = 0;
+
+ for (int i = 0; i < CRAS_NUM_DIRECTIONS; i++) {
+ connect_msg_.direction = static_cast<CRAS_STREAM_DIRECTION>(i);
+ if (connect_msg_.direction == CRAS_STREAM_UNDEFINED)
+ continue;
+ called++;
+ cras_rstream_create_stream_out = rstream_;
+ cras_iodev_attach_stream_retval = 0;
+
+ fd_ = 100;
+ rc = rclient_->ops->handle_message_from_client(
+ rclient_, &connect_msg_.header, &fd_, 1);
+ EXPECT_EQ(0, rc);
+ EXPECT_EQ(called, cras_make_fd_nonblocking_called);
+
+ rc = read(pipe_fds_[0], &out_msg, sizeof(out_msg));
+ EXPECT_EQ(sizeof(out_msg), rc);
+ EXPECT_EQ(stream_id_, out_msg.stream_id);
+ EXPECT_EQ(0, out_msg.err);
+ EXPECT_EQ(called, stream_list_add_stream_called);
+ EXPECT_EQ(0, stream_list_disconnect_stream_called);
+ EXPECT_EQ(called, cras_server_metrics_stream_config_called);
+ }
+}
+
+TEST_F(RClientMessagesSuite, StreamConnectMessageInvalidDirection) {
+ struct cras_client_stream_connected out_msg;
+ int rc;
+
+ connect_msg_.direction = CRAS_STREAM_UNDEFINED;
+ cras_rstream_create_stream_out = rstream_;
+ cras_iodev_attach_stream_retval = 0;
+
+ fd_ = 100;
+ rc = rclient_->ops->handle_message_from_client(rclient_, &connect_msg_.header,
+ &fd_, 1);
+ EXPECT_EQ(-EINVAL, rc);
+ EXPECT_EQ(0, cras_make_fd_nonblocking_called);
+
+ rc = read(pipe_fds_[0], &out_msg, sizeof(out_msg));
+ EXPECT_EQ(sizeof(out_msg), rc);
+ EXPECT_EQ(stream_id_, out_msg.stream_id);
+ EXPECT_EQ(-EINVAL, out_msg.err);
+ EXPECT_EQ(0, stream_list_add_stream_called);
+ EXPECT_EQ(0, stream_list_disconnect_stream_called);
+ EXPECT_EQ(0, cras_server_metrics_stream_config_called);
+}
+
+TEST_F(RClientMessagesSuite, StreamConnectMessageInvalidClientId) {
+ struct cras_client_stream_connected out_msg;
+ int rc;
+
+ connect_msg_.stream_id = 0x20002; // stream_id with invalid client_id
+
+ fd_ = 100;
+ rc = rclient_->ops->handle_message_from_client(rclient_, &connect_msg_.header,
+ &fd_, 1);
+ EXPECT_EQ(-EINVAL, rc);
+ EXPECT_EQ(0, cras_make_fd_nonblocking_called);
+ EXPECT_EQ(0, stream_list_add_stream_called);
+ EXPECT_EQ(0, stream_list_disconnect_stream_called);
+
+ rc = read(pipe_fds_[0], &out_msg, sizeof(out_msg));
+ EXPECT_EQ(sizeof(out_msg), rc);
+ EXPECT_EQ(-EINVAL, out_msg.err);
+ EXPECT_EQ(connect_msg_.stream_id, out_msg.stream_id);
+}
+
TEST_F(RClientMessagesSuite, SuccessReply) {
struct cras_client_stream_connected out_msg;
int rc;
diff --git a/cras/src/tests/playback_rclient_unittest.cc b/cras/src/tests/playback_rclient_unittest.cc
index f3aa2b73..01d8370b 100644
--- a/cras/src/tests/playback_rclient_unittest.cc
+++ b/cras/src/tests/playback_rclient_unittest.cc
@@ -128,24 +128,30 @@ TEST_F(CPRMessageSuite, StreamConnectMessageInvalidDirection) {
struct cras_connect_message msg;
cras_stream_id_t stream_id = 0x10002;
- cras_fill_connect_message(&msg, CRAS_STREAM_INPUT, stream_id,
- CRAS_STREAM_TYPE_DEFAULT, CRAS_CLIENT_TYPE_UNKNOWN,
- 480, 240, /*flags=*/0, /*effects=*/0, fmt,
- NO_DEVICE, /*client_shm_size=*/0);
- ASSERT_EQ(stream_id, msg.stream_id);
- fd_ = 100;
- rc =
- rclient_->ops->handle_message_from_client(rclient_, &msg.header, &fd_, 1);
- EXPECT_EQ(-EINVAL, rc);
- EXPECT_EQ(0, cras_make_fd_nonblocking_called);
- EXPECT_EQ(0, stream_list_add_called);
- EXPECT_EQ(0, stream_list_rm_called);
-
- rc = read(pipe_fds_[0], &out_msg, sizeof(out_msg));
- EXPECT_EQ(sizeof(out_msg), rc);
- EXPECT_EQ(-EINVAL, out_msg.err);
- EXPECT_EQ(stream_id, out_msg.stream_id);
+ for (int i = 0; i < CRAS_NUM_DIRECTIONS; i++) {
+ const auto dir = static_cast<CRAS_STREAM_DIRECTION>(i);
+ if (dir == CRAS_STREAM_OUTPUT)
+ continue;
+ cras_fill_connect_message(&msg, dir, stream_id, CRAS_STREAM_TYPE_DEFAULT,
+ CRAS_CLIENT_TYPE_UNKNOWN, 480, 240, /*flags=*/0,
+ /*effects=*/0, fmt, NO_DEVICE,
+ /*client_shm_size=*/0);
+ ASSERT_EQ(stream_id, msg.stream_id);
+
+ fd_ = 100;
+ rc = rclient_->ops->handle_message_from_client(rclient_, &msg.header, &fd_,
+ 1);
+ EXPECT_EQ(-EINVAL, rc);
+ EXPECT_EQ(0, cras_make_fd_nonblocking_called);
+ EXPECT_EQ(0, stream_list_add_called);
+ EXPECT_EQ(0, stream_list_rm_called);
+
+ rc = read(pipe_fds_[0], &out_msg, sizeof(out_msg));
+ EXPECT_EQ(sizeof(out_msg), rc);
+ EXPECT_EQ(-EINVAL, out_msg.err);
+ EXPECT_EQ(stream_id, out_msg.stream_id);
+ }
}
TEST_F(CPRMessageSuite, StreamConnectMessageInvalidClientId) {