diff options
author | paulhsia <paulhsia@chromium.org> | 2019-09-08 15:30:06 +0800 |
---|---|---|
committer | Commit Bot <commit-bot@chromium.org> | 2019-09-18 18:13:32 +0000 |
commit | 9c2b8cdb555c2ef6d7c1937a7886e4c6e83f0e0e (patch) | |
tree | 39f25b07f4b04a60a33ff55c54dba0e552512667 | |
parent | f725c161a7bd515f049c32d4cdfa507be6029841 (diff) | |
download | adhd-9c2b8cdb555c2ef6d7c1937a7886e4c6e83f0e0e.tar.gz |
CRAS: rclient: Add rclient_validate_stream_connect_message
- Unify cras_connect_message checks to
rclient_validate_stream_connect_message which supports
- client id check
- stream direction check
- Add supported_directions to cras_rclient.
BUG=chromium:937765
TEST=Build and run unit tests
Change-Id: I0c17c1d89ba1ad2653f4bca7feac93abb46c3e30
Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/third_party/adhd/+/1790225
Reviewed-by: Yu-Hsuan Hsu <yuhsuan@chromium.org>
Commit-Queue: Chih-Yang Hsia <paulhsia@chromium.org>
Tested-by: Chih-Yang Hsia <paulhsia@chromium.org>
-rw-r--r-- | cras/src/common/cras_types.h | 10 | ||||
-rw-r--r-- | cras/src/server/cras_control_rclient.c | 8 | ||||
-rw-r--r-- | cras/src/server/cras_playback_rclient.c | 17 | ||||
-rw-r--r-- | cras/src/server/cras_rclient.h | 2 | ||||
-rw-r--r-- | cras/src/server/cras_rclient_util.c | 24 | ||||
-rw-r--r-- | cras/src/server/cras_rclient_util.h | 14 | ||||
-rw-r--r-- | cras/src/tests/control_rclient_unittest.cc | 74 | ||||
-rw-r--r-- | cras/src/tests/playback_rclient_unittest.cc | 40 |
8 files changed, 158 insertions, 31 deletions
diff --git a/cras/src/common/cras_types.h b/cras/src/common/cras_types.h index ebe76d89..6ff7e1ae 100644 --- a/cras/src/common/cras_types.h +++ b/cras/src/common/cras_types.h @@ -59,6 +59,16 @@ enum CRAS_STREAM_DIRECTION { CRAS_NUM_DIRECTIONS }; +/* Bitmask for supporting all CRAS_STREAM_DIRECTION. */ +#define CRAS_STREAM_ALL_DIRECTION ((1 << CRAS_NUM_DIRECTIONS) - 1) + +/* Converts CRAS_STREAM_DIRECTION to bitmask. */ +static inline int +cras_stream_direction_mask(const enum CRAS_STREAM_DIRECTION dir) +{ + return (1 << dir); +} + /* * Flags for stream types. * BULK_AUDIO_OK - This stream is OK with receiving up to a full shm of samples diff --git a/cras/src/server/cras_control_rclient.c b/cras/src/server/cras_control_rclient.c index 811e03ee..f51bd65a 100644 --- a/cras/src/server/cras_control_rclient.c +++ b/cras/src/server/cras_control_rclient.c @@ -40,6 +40,10 @@ static int handle_client_stream_connect(struct cras_rclient *client, int rc, header_fd, samples_fd; int stream_fds[2]; + rc = rclient_validate_stream_connect_message(client, msg); + if (rc) + goto close_shm_fd; + unpack_cras_audio_format(&remote_fmt, &msg->format); rc = rclient_validate_stream_connect_fds(aud_fd, client_shm_fd, @@ -653,6 +657,10 @@ struct cras_rclient *cras_control_rclient_create(int fd, size_t id) client->fd = fd; client->id = id; client->ops = &cras_control_rclient_ops; + client->supported_directions = CRAS_STREAM_ALL_DIRECTION; + /* Filters CRAS_STREAM_UNDEFINED stream out. */ + client->supported_directions ^= + cras_stream_direction_mask(CRAS_STREAM_UNDEFINED); cras_fill_client_connected(&msg, client->id); state_fd = cras_sys_state_shm_fd(); diff --git a/cras/src/server/cras_playback_rclient.c b/cras/src/server/cras_playback_rclient.c index 9fe4471e..248ebdab 100644 --- a/cras/src/server/cras_playback_rclient.c +++ b/cras/src/server/cras_playback_rclient.c @@ -30,20 +30,9 @@ static int handle_client_stream_connect(struct cras_rclient *client, int rc, header_fd, samples_fd; int stream_fds[2]; - if (!cras_valid_stream_id(msg->stream_id, client->id)) { - syslog(LOG_ERR, - "stream_connect: invalid stream_id: %x for " - "client: %zx.\n", - msg->stream_id, client->id); - rc = -EINVAL; - goto reply_err; - } - - if (msg->direction != CRAS_STREAM_OUTPUT) { - syslog(LOG_ERR, "Invalid stream direction.\n"); - rc = -EINVAL; + rc = rclient_validate_stream_connect_message(client, msg); + if (rc) goto reply_err; - } unpack_cras_audio_format(&remote_fmt, &msg->format); @@ -220,6 +209,8 @@ struct cras_rclient *cras_playback_rclient_create(int fd, size_t id) client->id = id; client->ops = &cras_playback_rclient_ops; + client->supported_directions = + cras_stream_direction_mask(CRAS_STREAM_OUTPUT); cras_fill_client_connected(&msg, client->id); state_fd = cras_sys_state_shm_fd(); diff --git a/cras/src/server/cras_rclient.h b/cras/src/server/cras_rclient.h index d63407d2..c1545a87 100644 --- a/cras/src/server/cras_rclient.h +++ b/cras/src/server/cras_rclient.h @@ -17,12 +17,14 @@ struct cras_server_message; * id - The id of the client. * fd - Connection for client communication. * ops - cras_rclient_ops for the cras_rclient. + * supported_directions - Bit mask for supported stream directions. */ struct cras_rclient { struct cras_observer_client *observer; size_t id; int fd; const struct cras_rclient_ops *ops; + int supported_directions; }; /* Operations for cras_rclient. diff --git a/cras/src/server/cras_rclient_util.c b/cras/src/server/cras_rclient_util.c index 47daa2ae..2f1d88f9 100644 --- a/cras/src/server/cras_rclient_util.c +++ b/cras/src/server/cras_rclient_util.c @@ -12,6 +12,7 @@ #include "cras_rclient_util.h" #include "cras_rstream.h" #include "cras_tm.h" +#include "cras_types.h" #include "cras_util.h" #include "stream_list.h" @@ -53,6 +54,29 @@ void rclient_fill_cras_rstream_config(struct cras_rclient *client, stream_config->client = client; } +int rclient_validate_stream_connect_message( + const struct cras_rclient *client, + const struct cras_connect_message *msg) +{ + if (!cras_valid_stream_id(msg->stream_id, client->id)) { + syslog(LOG_ERR, + "stream_connect: invalid stream_id: %x for " + "client: %zx.\n", + msg->stream_id, client->id); + return -EINVAL; + } + + int direction = cras_stream_direction_mask(msg->direction); + if (!(client->supported_directions & direction)) { + syslog(LOG_ERR, + "stream_connect: invalid stream direction: %x for " + "client: %zx.\n", + msg->direction, client->id); + return -EINVAL; + } + return 0; +} + int rclient_validate_stream_connect_fds(int audio_fd, int client_shm_fd, size_t client_shm_size) { diff --git a/cras/src/server/cras_rclient_util.h b/cras/src/server/cras_rclient_util.h index e9c0079e..8292a113 100644 --- a/cras/src/server/cras_rclient_util.h +++ b/cras/src/server/cras_rclient_util.h @@ -32,6 +32,20 @@ void rclient_fill_cras_rstream_config( const struct cras_audio_format *remote_format, struct cras_rstream_config *stream_config); +/* Checks if the incoming stream connect message contains + * - stream_id matches client->id. + * - direction supported by the client. + * + * Args: + * client - The cras_rclient which gets the message. + * + * Returns: + * 0 on success, negative error on failure. + */ +int rclient_validate_stream_connect_message( + const struct cras_rclient *client, + const struct cras_connect_message *msg); + /* * Converts an old version of connect message to the correct * cras_connect_message. Returns zero on success, negative on failure. diff --git a/cras/src/tests/control_rclient_unittest.cc b/cras/src/tests/control_rclient_unittest.cc index a5d17629..75573faf 100644 --- a/cras/src/tests/control_rclient_unittest.cc +++ b/cras/src/tests/control_rclient_unittest.cc @@ -136,7 +136,7 @@ class RClientMessagesSuite : public testing::Test { rc = pipe(pipe_fds_); if (rc < 0) return; - rclient_ = cras_control_rclient_create(pipe_fds_[1], 800); + rclient_ = cras_control_rclient_create(pipe_fds_[1], 1); rc = read(pipe_fds_[0], &msg, sizeof(msg)); if (rc < 0) return; @@ -246,6 +246,78 @@ TEST_F(RClientMessagesSuite, ConnectMsgFromOldClient) { EXPECT_EQ(1, cras_server_metrics_stream_config_called); } +TEST_F(RClientMessagesSuite, StreamConnectMessageValidDirection) { + struct cras_client_stream_connected out_msg; + int rc; + int called = 0; + + for (int i = 0; i < CRAS_NUM_DIRECTIONS; i++) { + connect_msg_.direction = static_cast<CRAS_STREAM_DIRECTION>(i); + if (connect_msg_.direction == CRAS_STREAM_UNDEFINED) + continue; + called++; + cras_rstream_create_stream_out = rstream_; + cras_iodev_attach_stream_retval = 0; + + fd_ = 100; + rc = rclient_->ops->handle_message_from_client( + rclient_, &connect_msg_.header, &fd_, 1); + EXPECT_EQ(0, rc); + EXPECT_EQ(called, cras_make_fd_nonblocking_called); + + rc = read(pipe_fds_[0], &out_msg, sizeof(out_msg)); + EXPECT_EQ(sizeof(out_msg), rc); + EXPECT_EQ(stream_id_, out_msg.stream_id); + EXPECT_EQ(0, out_msg.err); + EXPECT_EQ(called, stream_list_add_stream_called); + EXPECT_EQ(0, stream_list_disconnect_stream_called); + EXPECT_EQ(called, cras_server_metrics_stream_config_called); + } +} + +TEST_F(RClientMessagesSuite, StreamConnectMessageInvalidDirection) { + struct cras_client_stream_connected out_msg; + int rc; + + connect_msg_.direction = CRAS_STREAM_UNDEFINED; + cras_rstream_create_stream_out = rstream_; + cras_iodev_attach_stream_retval = 0; + + fd_ = 100; + rc = rclient_->ops->handle_message_from_client(rclient_, &connect_msg_.header, + &fd_, 1); + EXPECT_EQ(-EINVAL, rc); + EXPECT_EQ(0, cras_make_fd_nonblocking_called); + + rc = read(pipe_fds_[0], &out_msg, sizeof(out_msg)); + EXPECT_EQ(sizeof(out_msg), rc); + EXPECT_EQ(stream_id_, out_msg.stream_id); + EXPECT_EQ(-EINVAL, out_msg.err); + EXPECT_EQ(0, stream_list_add_stream_called); + EXPECT_EQ(0, stream_list_disconnect_stream_called); + EXPECT_EQ(0, cras_server_metrics_stream_config_called); +} + +TEST_F(RClientMessagesSuite, StreamConnectMessageInvalidClientId) { + struct cras_client_stream_connected out_msg; + int rc; + + connect_msg_.stream_id = 0x20002; // stream_id with invalid client_id + + fd_ = 100; + rc = rclient_->ops->handle_message_from_client(rclient_, &connect_msg_.header, + &fd_, 1); + EXPECT_EQ(-EINVAL, rc); + EXPECT_EQ(0, cras_make_fd_nonblocking_called); + EXPECT_EQ(0, stream_list_add_stream_called); + EXPECT_EQ(0, stream_list_disconnect_stream_called); + + rc = read(pipe_fds_[0], &out_msg, sizeof(out_msg)); + EXPECT_EQ(sizeof(out_msg), rc); + EXPECT_EQ(-EINVAL, out_msg.err); + EXPECT_EQ(connect_msg_.stream_id, out_msg.stream_id); +} + TEST_F(RClientMessagesSuite, SuccessReply) { struct cras_client_stream_connected out_msg; int rc; diff --git a/cras/src/tests/playback_rclient_unittest.cc b/cras/src/tests/playback_rclient_unittest.cc index f3aa2b73..01d8370b 100644 --- a/cras/src/tests/playback_rclient_unittest.cc +++ b/cras/src/tests/playback_rclient_unittest.cc @@ -128,24 +128,30 @@ TEST_F(CPRMessageSuite, StreamConnectMessageInvalidDirection) { struct cras_connect_message msg; cras_stream_id_t stream_id = 0x10002; - cras_fill_connect_message(&msg, CRAS_STREAM_INPUT, stream_id, - CRAS_STREAM_TYPE_DEFAULT, CRAS_CLIENT_TYPE_UNKNOWN, - 480, 240, /*flags=*/0, /*effects=*/0, fmt, - NO_DEVICE, /*client_shm_size=*/0); - ASSERT_EQ(stream_id, msg.stream_id); - fd_ = 100; - rc = - rclient_->ops->handle_message_from_client(rclient_, &msg.header, &fd_, 1); - EXPECT_EQ(-EINVAL, rc); - EXPECT_EQ(0, cras_make_fd_nonblocking_called); - EXPECT_EQ(0, stream_list_add_called); - EXPECT_EQ(0, stream_list_rm_called); - - rc = read(pipe_fds_[0], &out_msg, sizeof(out_msg)); - EXPECT_EQ(sizeof(out_msg), rc); - EXPECT_EQ(-EINVAL, out_msg.err); - EXPECT_EQ(stream_id, out_msg.stream_id); + for (int i = 0; i < CRAS_NUM_DIRECTIONS; i++) { + const auto dir = static_cast<CRAS_STREAM_DIRECTION>(i); + if (dir == CRAS_STREAM_OUTPUT) + continue; + cras_fill_connect_message(&msg, dir, stream_id, CRAS_STREAM_TYPE_DEFAULT, + CRAS_CLIENT_TYPE_UNKNOWN, 480, 240, /*flags=*/0, + /*effects=*/0, fmt, NO_DEVICE, + /*client_shm_size=*/0); + ASSERT_EQ(stream_id, msg.stream_id); + + fd_ = 100; + rc = rclient_->ops->handle_message_from_client(rclient_, &msg.header, &fd_, + 1); + EXPECT_EQ(-EINVAL, rc); + EXPECT_EQ(0, cras_make_fd_nonblocking_called); + EXPECT_EQ(0, stream_list_add_called); + EXPECT_EQ(0, stream_list_rm_called); + + rc = read(pipe_fds_[0], &out_msg, sizeof(out_msg)); + EXPECT_EQ(sizeof(out_msg), rc); + EXPECT_EQ(-EINVAL, out_msg.err); + EXPECT_EQ(stream_id, out_msg.stream_id); + } } TEST_F(CPRMessageSuite, StreamConnectMessageInvalidClientId) { |