aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/access_api_handler.cc227
-rw-r--r--src/access_api_handler.h47
-rw-r--r--src/access_api_handler_unittest.cc259
-rw-r--r--src/access_black_list_manager.h56
-rw-r--r--src/access_black_list_manager_impl.cc163
-rw-r--r--src/access_black_list_manager_impl.h58
-rw-r--r--src/access_black_list_manager_impl_unittest.cc165
-rw-r--r--src/base_api_handler.h2
-rw-r--r--src/config.cc9
-rw-r--r--src/config.h3
-rw-r--r--src/config_unittest.cc11
-rw-r--r--src/device_manager.cc6
-rw-r--r--src/device_manager.h4
-rw-r--r--src/device_registration_info.cc23
-rw-r--r--src/device_registration_info.h1
-rw-r--r--src/device_registration_info_unittest.cc2
-rw-r--r--src/notification/xmpp_channel.cc19
-rw-r--r--src/notification/xmpp_channel.h10
-rw-r--r--src/notification/xmpp_channel_unittest.cc6
-rw-r--r--src/privet/auth_manager.cc402
-rw-r--r--src/privet/auth_manager.h15
-rw-r--r--src/privet/auth_manager_unittest.cc226
-rw-r--r--src/privet/cloud_delegate.cc17
-rw-r--r--src/privet/mock_delegates.h14
-rw-r--r--src/privet/openssl_utils.cc10
-rw-r--r--src/privet/privet_handler_unittest.cc6
-rw-r--r--src/privet/privet_types.h33
-rw-r--r--src/privet/security_manager.cc9
-rw-r--r--src/privet/security_manager_unittest.cc3
29 files changed, 1599 insertions, 207 deletions
diff --git a/src/access_api_handler.cc b/src/access_api_handler.cc
new file mode 100644
index 0000000..7c39b20
--- /dev/null
+++ b/src/access_api_handler.cc
@@ -0,0 +1,227 @@
+// Copyright 2016 The Weave Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "src/access_api_handler.h"
+
+#include <base/bind.h>
+#include <weave/device.h>
+
+#include "src/access_black_list_manager.h"
+#include "src/commands/schema_constants.h"
+#include "src/data_encoding.h"
+#include "src/json_error_codes.h"
+
+namespace weave {
+
+namespace {
+
+const char kComponent[] = "accessControl";
+const char kTrait[] = "_accessControlBlackList";
+const char kStateSize[] = "_accessControlBlackList.size";
+const char kStateCapacity[] = "_accessControlBlackList.capacity";
+const char kUserId[] = "userId";
+const char kApplicationId[] = "applicationId";
+const char kExpirationTimeout[] = "expirationTimeoutSec";
+const char kBlackList[] = "blackList";
+
+bool GetIds(const base::DictionaryValue& parameters,
+ std::vector<uint8_t>* user_id_decoded,
+ std::vector<uint8_t>* app_id_decoded,
+ ErrorPtr* error) {
+ std::string user_id;
+ parameters.GetString(kUserId, &user_id);
+ if (!Base64Decode(user_id, user_id_decoded)) {
+ Error::AddToPrintf(error, FROM_HERE, errors::commands::kInvalidPropValue,
+ "Invalid user id '%s'", user_id.c_str());
+ return false;
+ }
+
+ std::string app_id;
+ parameters.GetString(kApplicationId, &app_id);
+ if (!Base64Decode(app_id, app_id_decoded)) {
+ Error::AddToPrintf(error, FROM_HERE, errors::commands::kInvalidPropValue,
+ "Invalid app id '%s'", user_id.c_str());
+ return false;
+ }
+
+ return true;
+}
+
+} // namespace
+
+AccessApiHandler::AccessApiHandler(Device* device,
+ AccessBlackListManager* manager)
+ : device_{device}, manager_{manager} {
+ device_->AddTraitDefinitionsFromJson(R"({
+ "_accessControlBlackList": {
+ "commands": {
+ "block": {
+ "minimalRole": "owner",
+ "parameters": {
+ "userId": {
+ "type": "string"
+ },
+ "applicationId": {
+ "type": "string"
+ },
+ "expirationTimeoutSec": {
+ "type": "integer"
+ }
+ }
+ },
+ "unblock": {
+ "minimalRole": "owner",
+ "parameters": {
+ "userId": {
+ "type": "string"
+ },
+ "applicationId": {
+ "type": "string"
+ }
+ }
+ },
+ "list": {
+ "minimalRole": "owner",
+ "parameters": {},
+ "results": {
+ "blackList": {
+ "type": "array",
+ "items": {
+ "type": "object",
+ "properties": {
+ "userId": {
+ "type": "string"
+ },
+ "applicationId": {
+ "type": "string"
+ }
+ },
+ "additionalProperties": false
+ }
+ }
+ }
+ }
+ },
+ "state": {
+ "size": {
+ "type": "integer",
+ "isRequired": true
+ },
+ "capacity": {
+ "type": "integer",
+ "isRequired": true
+ }
+ }
+ }
+ })");
+ CHECK(device_->AddComponent(kComponent, {kTrait}, nullptr));
+ UpdateState();
+
+ device_->AddCommandHandler(
+ kComponent, "_accessControlBlackList.block",
+ base::Bind(&AccessApiHandler::Block, weak_ptr_factory_.GetWeakPtr()));
+ device_->AddCommandHandler(
+ kComponent, "_accessControlBlackList.unblock",
+ base::Bind(&AccessApiHandler::Unblock, weak_ptr_factory_.GetWeakPtr()));
+ device_->AddCommandHandler(
+ kComponent, "_accessControlBlackList.list",
+ base::Bind(&AccessApiHandler::List, weak_ptr_factory_.GetWeakPtr()));
+}
+
+void AccessApiHandler::Block(const std::weak_ptr<Command>& cmd) {
+ auto command = cmd.lock();
+ if (!command)
+ return;
+
+ CHECK(command->GetState() == Command::State::kQueued)
+ << EnumToString(command->GetState());
+ command->SetProgress(base::DictionaryValue{}, nullptr);
+
+ const auto& parameters = command->GetParameters();
+ std::vector<uint8_t> user_id;
+ std::vector<uint8_t> app_id;
+ ErrorPtr error;
+ if (!GetIds(parameters, &user_id, &app_id, &error)) {
+ command->Abort(error.get(), nullptr);
+ return;
+ }
+
+ int timeout_sec = 0;
+ parameters.GetInteger(kExpirationTimeout, &timeout_sec);
+
+ base::Time expiration =
+ base::Time::Now() + base::TimeDelta::FromSeconds(timeout_sec);
+
+ manager_->Block(user_id, app_id, expiration,
+ base::Bind(&AccessApiHandler::OnCommandDone,
+ weak_ptr_factory_.GetWeakPtr(), cmd));
+}
+
+void AccessApiHandler::Unblock(const std::weak_ptr<Command>& cmd) {
+ auto command = cmd.lock();
+ if (!command)
+ return;
+
+ CHECK(command->GetState() == Command::State::kQueued)
+ << EnumToString(command->GetState());
+ command->SetProgress(base::DictionaryValue{}, nullptr);
+
+ const auto& parameters = command->GetParameters();
+ std::vector<uint8_t> user_id;
+ std::vector<uint8_t> app_id;
+ ErrorPtr error;
+ if (!GetIds(parameters, &user_id, &app_id, &error)) {
+ command->Abort(error.get(), nullptr);
+ return;
+ }
+
+ manager_->Unblock(user_id, app_id,
+ base::Bind(&AccessApiHandler::OnCommandDone,
+ weak_ptr_factory_.GetWeakPtr(), cmd));
+}
+
+void AccessApiHandler::List(const std::weak_ptr<Command>& cmd) {
+ auto command = cmd.lock();
+ if (!command)
+ return;
+
+ CHECK(command->GetState() == Command::State::kQueued)
+ << EnumToString(command->GetState());
+ command->SetProgress(base::DictionaryValue{}, nullptr);
+
+ std::unique_ptr<base::ListValue> entries{new base::ListValue};
+ for (const auto& e : manager_->GetEntries()) {
+ std::unique_ptr<base::DictionaryValue> entry{new base::DictionaryValue};
+ entry->SetString(kUserId, Base64Encode(e.user_id));
+ entry->SetString(kApplicationId, Base64Encode(e.app_id));
+ entries->Append(entry.release());
+ }
+
+ base::DictionaryValue result;
+ result.Set(kBlackList, entries.release());
+
+ command->Complete(result, nullptr);
+}
+
+void AccessApiHandler::OnCommandDone(const std::weak_ptr<Command>& cmd,
+ ErrorPtr error) {
+ auto command = cmd.lock();
+ if (!command)
+ return;
+ UpdateState();
+ if (error) {
+ command->Abort(error.get(), nullptr);
+ return;
+ }
+ command->Complete({}, nullptr);
+}
+
+void AccessApiHandler::UpdateState() {
+ base::DictionaryValue state;
+ state.SetInteger(kStateSize, manager_->GetSize());
+ state.SetInteger(kStateCapacity, manager_->GetCapacity());
+ device_->SetStateProperties(kComponent, state, nullptr);
+}
+
+} // namespace weave
diff --git a/src/access_api_handler.h b/src/access_api_handler.h
new file mode 100644
index 0000000..821ce02
--- /dev/null
+++ b/src/access_api_handler.h
@@ -0,0 +1,47 @@
+// Copyright 2016 The Weave Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef LIBWEAVE_SRC_ACCESS_API_HANDLER_H_
+#define LIBWEAVE_SRC_ACCESS_API_HANDLER_H_
+
+#include <memory>
+
+#include <base/memory/weak_ptr.h>
+#include <weave/error.h>
+
+namespace weave {
+
+class AccessBlackListManager;
+class Command;
+class Device;
+
+// Handles commands for 'accessControlBlackList' trait.
+// Objects of the class subscribe for notification from CommandManager and
+// execute incoming commands.
+// Handled commands:
+// accessControlBlackList.block
+// accessControlBlackList.unblock
+// accessControlBlackList.list
+class AccessApiHandler final {
+ public:
+ AccessApiHandler(Device* device, AccessBlackListManager* manager);
+
+ private:
+ void Block(const std::weak_ptr<Command>& command);
+ void Unblock(const std::weak_ptr<Command>& command);
+ void List(const std::weak_ptr<Command>& command);
+ void UpdateState();
+
+ void OnCommandDone(const std::weak_ptr<Command>& command, ErrorPtr error);
+
+ Device* device_{nullptr};
+ AccessBlackListManager* manager_{nullptr};
+
+ base::WeakPtrFactory<AccessApiHandler> weak_ptr_factory_{this};
+ DISALLOW_COPY_AND_ASSIGN(AccessApiHandler);
+};
+
+} // namespace weave
+
+#endif // LIBWEAVE_SRC_ACCESS_API_HANDLER_H_
diff --git a/src/access_api_handler_unittest.cc b/src/access_api_handler_unittest.cc
new file mode 100644
index 0000000..3e7f5d7
--- /dev/null
+++ b/src/access_api_handler_unittest.cc
@@ -0,0 +1,259 @@
+// Copyright 2016 The Weave Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "src/access_api_handler.h"
+
+#include <gtest/gtest.h>
+#include <weave/provider/test/fake_task_runner.h>
+#include <weave/test/mock_device.h>
+#include <weave/test/unittest_utils.h>
+
+#include "src/component_manager_impl.h"
+#include "src/access_black_list_manager.h"
+#include "src/data_encoding.h"
+
+using testing::_;
+using testing::AnyOf;
+using testing::Invoke;
+using testing::Return;
+using testing::StrictMock;
+using testing::WithArgs;
+
+namespace weave {
+
+class MockAccessBlackListManager : public AccessBlackListManager {
+ public:
+ MOCK_METHOD4(Block,
+ void(const std::vector<uint8_t>&,
+ const std::vector<uint8_t>&,
+ const base::Time&,
+ const DoneCallback&));
+ MOCK_METHOD3(Unblock,
+ void(const std::vector<uint8_t>&,
+ const std::vector<uint8_t>&,
+ const DoneCallback&));
+ MOCK_CONST_METHOD2(IsBlocked,
+ bool(const std::vector<uint8_t>&,
+ const std::vector<uint8_t>&));
+ MOCK_CONST_METHOD0(GetEntries, std::vector<Entry>());
+ MOCK_CONST_METHOD0(GetSize, size_t());
+ MOCK_CONST_METHOD0(GetCapacity, size_t());
+};
+
+class AccessApiHandlerTest : public ::testing::Test {
+ protected:
+ void SetUp() override {
+ EXPECT_CALL(device_, AddTraitDefinitionsFromJson(_))
+ .WillRepeatedly(Invoke([this](const std::string& json) {
+ EXPECT_TRUE(component_manager_.LoadTraits(json, nullptr));
+ }));
+ EXPECT_CALL(device_, SetStateProperties(_, _, _))
+ .WillRepeatedly(
+ Invoke(&component_manager_, &ComponentManager::SetStateProperties));
+ EXPECT_CALL(device_, SetStateProperty(_, _, _, _))
+ .WillRepeatedly(
+ Invoke(&component_manager_, &ComponentManager::SetStateProperty));
+ EXPECT_CALL(device_, AddComponent(_, _, _))
+ .WillRepeatedly(Invoke([this](const std::string& name,
+ const std::vector<std::string>& traits,
+ ErrorPtr* error) {
+ return component_manager_.AddComponent("", name, traits, error);
+ }));
+
+ EXPECT_CALL(device_,
+ AddCommandHandler(_, AnyOf("_accessControlBlackList.block",
+ "_accessControlBlackList.unblock",
+ "_accessControlBlackList.list"),
+ _))
+ .WillRepeatedly(
+ Invoke(&component_manager_, &ComponentManager::AddCommandHandler));
+
+ EXPECT_CALL(access_manager_, GetSize()).WillRepeatedly(Return(0));
+
+ EXPECT_CALL(access_manager_, GetCapacity()).WillRepeatedly(Return(10));
+
+ handler_.reset(new AccessApiHandler{&device_, &access_manager_});
+ }
+
+ const base::DictionaryValue& AddCommand(const std::string& command) {
+ std::string id;
+ auto command_instance = component_manager_.ParseCommandInstance(
+ *test::CreateDictionaryValue(command.c_str()), Command::Origin::kLocal,
+ UserRole::kOwner, &id, nullptr);
+ EXPECT_NE(nullptr, command_instance.get());
+ component_manager_.AddCommand(std::move(command_instance));
+ EXPECT_EQ(Command::State::kDone,
+ component_manager_.FindCommand(id)->GetState());
+ return component_manager_.FindCommand(id)->GetResults();
+ }
+
+ std::unique_ptr<base::DictionaryValue> GetState() {
+ std::string path =
+ component_manager_.FindComponentWithTrait("_accessControlBlackList");
+ EXPECT_FALSE(path.empty());
+ const auto* component = component_manager_.FindComponent(path, nullptr);
+ EXPECT_TRUE(component);
+ const base::DictionaryValue* state = nullptr;
+ EXPECT_TRUE(
+ component->GetDictionary("state._accessControlBlackList", &state));
+ return std::unique_ptr<base::DictionaryValue>{state->DeepCopy()};
+ }
+
+ StrictMock<provider::test::FakeTaskRunner> task_runner_;
+ ComponentManagerImpl component_manager_{&task_runner_};
+ StrictMock<test::MockDevice> device_;
+ StrictMock<MockAccessBlackListManager> access_manager_;
+ std::unique_ptr<AccessApiHandler> handler_;
+};
+
+TEST_F(AccessApiHandlerTest, Initialization) {
+ const base::DictionaryValue* trait = nullptr;
+ ASSERT_TRUE(component_manager_.GetTraits().GetDictionary(
+ "_accessControlBlackList", &trait));
+
+ auto expected = R"({
+ "commands": {
+ "block": {
+ "minimalRole": "owner",
+ "parameters": {
+ "userId": {
+ "type": "string"
+ },
+ "applicationId": {
+ "type": "string"
+ },
+ "expirationTimeoutSec": {
+ "type": "integer"
+ }
+ }
+ },
+ "unblock": {
+ "minimalRole": "owner",
+ "parameters": {
+ "userId": {
+ "type": "string"
+ },
+ "applicationId": {
+ "type": "string"
+ }
+ }
+ },
+ "list": {
+ "minimalRole": "owner",
+ "parameters": {},
+ "results": {
+ "blackList": {
+ "type": "array",
+ "items": {
+ "type": "object",
+ "properties": {
+ "userId": {
+ "type": "string"
+ },
+ "applicationId": {
+ "type": "string"
+ }
+ },
+ "additionalProperties": false
+ }
+ }
+ }
+ }
+ },
+ "state": {
+ "size": {
+ "type": "integer",
+ "isRequired": true
+ },
+ "capacity": {
+ "type": "integer",
+ "isRequired": true
+ }
+ }
+ })";
+ EXPECT_JSON_EQ(expected, *trait);
+
+ expected = R"({
+ "capacity": 10,
+ "size": 0
+ })";
+ EXPECT_JSON_EQ(expected, *GetState());
+}
+
+TEST_F(AccessApiHandlerTest, Block) {
+ EXPECT_CALL(access_manager_, Block(std::vector<uint8_t>{1, 2, 3},
+ std::vector<uint8_t>{3, 4, 5}, _, _))
+ .WillOnce(WithArgs<3>(
+ Invoke([](const DoneCallback& callback) { callback.Run(nullptr); })));
+ EXPECT_CALL(access_manager_, GetSize()).WillRepeatedly(Return(1));
+
+ AddCommand(R"({
+ 'name' : '_accessControlBlackList.block',
+ 'component': 'accessControl',
+ 'parameters': {
+ 'userId': 'AQID',
+ 'applicationId': 'AwQF',
+ 'expirationTimeoutSec': 1234
+ }
+ })");
+
+ auto expected = R"({
+ "capacity": 10,
+ "size": 1
+ })";
+ EXPECT_JSON_EQ(expected, *GetState());
+}
+
+TEST_F(AccessApiHandlerTest, Unblock) {
+ EXPECT_CALL(access_manager_, Unblock(std::vector<uint8_t>{1, 2, 3},
+ std::vector<uint8_t>{3, 4, 5}, _))
+ .WillOnce(WithArgs<2>(
+ Invoke([](const DoneCallback& callback) { callback.Run(nullptr); })));
+ EXPECT_CALL(access_manager_, GetSize()).WillRepeatedly(Return(4));
+
+ AddCommand(R"({
+ 'name' : '_accessControlBlackList.unblock',
+ 'component': 'accessControl',
+ 'parameters': {
+ 'userId': 'AQID',
+ 'applicationId': 'AwQF',
+ 'expirationTimeoutSec': 1234
+ }
+ })");
+
+ auto expected = R"({
+ "capacity": 10,
+ "size": 4
+ })";
+ EXPECT_JSON_EQ(expected, *GetState());
+}
+
+TEST_F(AccessApiHandlerTest, List) {
+ std::vector<AccessBlackListManager::Entry> entries{
+ {{11, 12, 13}, {21, 22, 23}, base::Time::FromTimeT(1410000000)},
+ {{31, 32, 33}, {41, 42, 43}, base::Time::FromTimeT(1420000000)},
+ };
+ EXPECT_CALL(access_manager_, GetEntries()).WillOnce(Return(entries));
+ EXPECT_CALL(access_manager_, GetSize()).WillRepeatedly(Return(4));
+
+ auto expected = R"({
+ "blackList": [ {
+ "applicationId": "FRYX",
+ "userId": "CwwN"
+ }, {
+ "applicationId": "KSor",
+ "userId": "HyAh"
+ } ]
+ })";
+
+ const auto& results = AddCommand(R"({
+ 'name' : '_accessControlBlackList.list',
+ 'component': 'accessControl',
+ 'parameters': {
+ }
+ })");
+
+ EXPECT_JSON_EQ(expected, results);
+}
+} // namespace weave
diff --git a/src/access_black_list_manager.h b/src/access_black_list_manager.h
new file mode 100644
index 0000000..b56226a
--- /dev/null
+++ b/src/access_black_list_manager.h
@@ -0,0 +1,56 @@
+// Copyright 2016 The Weave Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef LIBWEAVE_SRC_ACCESS_BLACK_LIST_H_
+#define LIBWEAVE_SRC_ACCESS_BLACK_LIST_H_
+
+#include <vector>
+
+#include <base/time/time.h>
+
+namespace weave {
+
+class AccessBlackListManager {
+ public:
+ struct Entry {
+ // user_id is empty, app_id is empty: block everything.
+ // user_id is not empty, app_id is empty: block if user_id matches.
+ // user_id is empty, app_id is not empty: block if app_id matches.
+ // user_id is not empty, app_id is not empty: block if both match.
+ std::vector<uint8_t> user_id;
+ std::vector<uint8_t> app_id;
+
+ // Time after which to discard the rule.
+ base::Time expiration;
+ };
+ virtual ~AccessBlackListManager() = default;
+
+ virtual void Block(const std::vector<uint8_t>& user_id,
+ const std::vector<uint8_t>& app_id,
+ const base::Time& expiration,
+ const DoneCallback& callback) = 0;
+ virtual void Unblock(const std::vector<uint8_t>& user_id,
+ const std::vector<uint8_t>& app_id,
+ const DoneCallback& callback) = 0;
+ virtual bool IsBlocked(const std::vector<uint8_t>& user_id,
+ const std::vector<uint8_t>& app_id) const = 0;
+ virtual std::vector<Entry> GetEntries() const = 0;
+ virtual size_t GetSize() const = 0;
+ virtual size_t GetCapacity() const = 0;
+};
+
+inline bool operator==(const AccessBlackListManager::Entry& l,
+ const AccessBlackListManager::Entry& r) {
+ return l.user_id == r.user_id && l.app_id == r.app_id &&
+ l.expiration == r.expiration;
+}
+
+inline bool operator!=(const AccessBlackListManager::Entry& l,
+ const AccessBlackListManager::Entry& r) {
+ return !(l == r);
+}
+
+} // namespace weave
+
+#endif // LIBWEAVE_SRC_ACCESS_BLACK_LIST_H_
diff --git a/src/access_black_list_manager_impl.cc b/src/access_black_list_manager_impl.cc
new file mode 100644
index 0000000..992a680
--- /dev/null
+++ b/src/access_black_list_manager_impl.cc
@@ -0,0 +1,163 @@
+// Copyright 2016 The Weave Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "src/access_black_list_manager_impl.h"
+
+#include <base/json/json_reader.h>
+#include <base/json/json_writer.h>
+#include <base/values.h>
+
+#include "src/commands/schema_constants.h"
+#include "src/data_encoding.h"
+
+namespace weave {
+
+namespace {
+const char kConfigFileName[] = "black_list";
+
+const char kUser[] = "user";
+const char kApp[] = "app";
+const char kExpiration[] = "expiration";
+}
+
+AccessBlackListManagerImpl::AccessBlackListManagerImpl(
+ provider::ConfigStore* store,
+ size_t capacity,
+ base::Clock* clock)
+ : capacity_{capacity}, clock_{clock}, store_{store} {
+ Load();
+}
+
+void AccessBlackListManagerImpl::Load() {
+ if (!store_)
+ return;
+ if (auto list = base::ListValue::From(
+ base::JSONReader::Read(store_->LoadSettings(kConfigFileName)))) {
+ for (const auto& e : *list) {
+ const base::DictionaryValue* entry{nullptr};
+ std::string user;
+ std::string app;
+ decltype(entries_)::key_type key;
+ int expiration;
+ if (e->GetAsDictionary(&entry) && entry->GetString(kUser, &user) &&
+ Base64Decode(user, &key.first) && entry->GetString(kApp, &app) &&
+ Base64Decode(app, &key.second) &&
+ entry->GetInteger(kExpiration, &expiration)) {
+ base::Time expiration_time = base::Time::FromTimeT(expiration);
+ if (expiration_time > clock_->Now())
+ entries_[key] = expiration_time;
+ }
+ }
+ if (entries_.size() < list->GetSize()) {
+ // Save some storage space by saving without expired entries.
+ Save({});
+ }
+ }
+}
+
+void AccessBlackListManagerImpl::Save(const DoneCallback& callback) {
+ if (!store_) {
+ if (!callback.is_null())
+ callback.Run(nullptr);
+ return;
+ }
+
+ base::ListValue list;
+ for (const auto& e : entries_) {
+ scoped_ptr<base::DictionaryValue> entry{new base::DictionaryValue};
+ entry->SetString(kUser, Base64Encode(e.first.first));
+ entry->SetString(kApp, Base64Encode(e.first.second));
+ entry->SetInteger(kExpiration, e.second.ToTimeT());
+ list.Append(std::move(entry));
+ }
+
+ std::string json;
+ base::JSONWriter::Write(list, &json);
+ store_->SaveSettings(kConfigFileName, json, callback);
+}
+
+void AccessBlackListManagerImpl::RemoveExpired() {
+ for (auto i = begin(entries_); i != end(entries_);) {
+ if (i->second <= clock_->Now())
+ i = entries_.erase(i);
+ else
+ ++i;
+ }
+}
+
+void AccessBlackListManagerImpl::Block(const std::vector<uint8_t>& user_id,
+ const std::vector<uint8_t>& app_id,
+ const base::Time& expiration,
+ const DoneCallback& callback) {
+ // Iterating is OK as Save below is more expensive.
+ RemoveExpired();
+ if (expiration <= clock_->Now()) {
+ if (!callback.is_null()) {
+ ErrorPtr error;
+ Error::AddTo(&error, FROM_HERE, "aleady_expired",
+ "Entry already expired");
+ callback.Run(std::move(error));
+ }
+ return;
+ }
+ if (entries_.size() >= capacity_) {
+ if (!callback.is_null()) {
+ ErrorPtr error;
+ Error::AddTo(&error, FROM_HERE, "blacklist_is_full",
+ "Unable to store more entries");
+ callback.Run(std::move(error));
+ }
+ return;
+ }
+ auto& value = entries_[std::make_pair(user_id, app_id)];
+ value = std::max(value, expiration);
+ Save(callback);
+}
+
+void AccessBlackListManagerImpl::Unblock(const std::vector<uint8_t>& user_id,
+ const std::vector<uint8_t>& app_id,
+ const DoneCallback& callback) {
+ if (!entries_.erase(std::make_pair(user_id, app_id))) {
+ if (!callback.is_null()) {
+ ErrorPtr error;
+ Error::AddTo(&error, FROM_HERE, "entry_not_found", "Unknown entry");
+ callback.Run(std::move(error));
+ }
+ return;
+ }
+ // Iterating is OK as Save below is more expensive.
+ RemoveExpired();
+ Save(callback);
+}
+
+bool AccessBlackListManagerImpl::IsBlocked(
+ const std::vector<uint8_t>& user_id,
+ const std::vector<uint8_t>& app_id) const {
+ for (const auto& user : {{}, user_id}) {
+ for (const auto& app : {{}, app_id}) {
+ auto both = entries_.find(std::make_pair(user, app));
+ if (both != end(entries_) && both->second > clock_->Now())
+ return true;
+ }
+ }
+ return false;
+}
+
+std::vector<AccessBlackListManager::Entry>
+AccessBlackListManagerImpl::GetEntries() const {
+ std::vector<Entry> result;
+ for (const auto& e : entries_)
+ result.push_back({e.first.first, e.first.second, e.second});
+ return result;
+}
+
+size_t AccessBlackListManagerImpl::GetSize() const {
+ return entries_.size();
+}
+
+size_t AccessBlackListManagerImpl::GetCapacity() const {
+ return capacity_;
+}
+
+} // namespace weave
diff --git a/src/access_black_list_manager_impl.h b/src/access_black_list_manager_impl.h
new file mode 100644
index 0000000..1c175db
--- /dev/null
+++ b/src/access_black_list_manager_impl.h
@@ -0,0 +1,58 @@
+// Copyright 2016 The Weave Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef LIBWEAVE_SRC_ACCESS_BLACK_LIST_IMPL_H_
+#define LIBWEAVE_SRC_ACCESS_BLACK_LIST_IMPL_H_
+
+#include <map>
+#include <utility>
+
+#include <base/time/default_clock.h>
+#include <base/time/time.h>
+#include <weave/error.h>
+#include <weave/provider/config_store.h>
+
+#include "src/access_black_list_manager.h"
+
+namespace weave {
+
+class AccessBlackListManagerImpl : public AccessBlackListManager {
+ public:
+ explicit AccessBlackListManagerImpl(provider::ConfigStore* store,
+ size_t capacity = 1024,
+ base::Clock* clock = nullptr);
+
+ // AccessBlackListManager implementation.
+ void Block(const std::vector<uint8_t>& user_id,
+ const std::vector<uint8_t>& app_id,
+ const base::Time& expiration,
+ const DoneCallback& callback) override;
+ void Unblock(const std::vector<uint8_t>& user_id,
+ const std::vector<uint8_t>& app_id,
+ const DoneCallback& callback) override;
+ bool IsBlocked(const std::vector<uint8_t>& user_id,
+ const std::vector<uint8_t>& app_id) const override;
+ std::vector<Entry> GetEntries() const override;
+ size_t GetSize() const override;
+ size_t GetCapacity() const override;
+
+ private:
+ void Load();
+ void Save(const DoneCallback& callback);
+ void RemoveExpired();
+
+ const size_t capacity_{0};
+ base::DefaultClock default_clock_;
+ base::Clock* clock_{&default_clock_};
+
+ provider::ConfigStore* store_{nullptr};
+ std::map<std::pair<std::vector<uint8_t>, std::vector<uint8_t>>, base::Time>
+ entries_;
+
+ DISALLOW_COPY_AND_ASSIGN(AccessBlackListManagerImpl);
+};
+
+} // namespace weave
+
+#endif // LIBWEAVE_SRC_ACCESS_BLACK_LIST_IMPL_H_
diff --git a/src/access_black_list_manager_impl_unittest.cc b/src/access_black_list_manager_impl_unittest.cc
new file mode 100644
index 0000000..fd9f226
--- /dev/null
+++ b/src/access_black_list_manager_impl_unittest.cc
@@ -0,0 +1,165 @@
+// Copyright 2016 The Weave Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "src/access_black_list_manager_impl.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include <weave/provider/test/mock_config_store.h>
+#include <weave/test/unittest_utils.h>
+
+#include "src/test/mock_clock.h"
+#include "src/bind_lambda.h"
+
+using testing::_;
+using testing::Return;
+using testing::StrictMock;
+
+namespace weave {
+
+class AccessBlackListManagerImplTest : public testing::Test {
+ protected:
+ void SetUp() {
+ std::string to_load = R"([{
+ "user": "BQID",
+ "app": "BwQF",
+ "expiration": 1410000000
+ }, {
+ "user": "AQID",
+ "app": "AwQF",
+ "expiration": 1419999999
+ }])";
+
+ EXPECT_CALL(config_store_, LoadSettings("black_list"))
+ .WillOnce(Return(to_load));
+
+ EXPECT_CALL(config_store_, SaveSettings("black_list", _, _))
+ .WillOnce(testing::WithArgs<1, 2>(testing::Invoke(
+ [](const std::string& json, const DoneCallback& callback) {
+ std::string to_save = R"([{
+ "user": "AQID",
+ "app": "AwQF",
+ "expiration": 1419999999
+ }])";
+ EXPECT_JSON_EQ(to_save, *test::CreateValue(json));
+ if (!callback.is_null())
+ callback.Run(nullptr);
+ })));
+
+ EXPECT_CALL(clock_, Now())
+ .WillRepeatedly(Return(base::Time::FromTimeT(1412121212)));
+ manager_.reset(new AccessBlackListManagerImpl{&config_store_, 10, &clock_});
+ }
+ StrictMock<test::MockClock> clock_;
+ StrictMock<provider::test::MockConfigStore> config_store_{false};
+ std::unique_ptr<AccessBlackListManagerImpl> manager_;
+};
+
+TEST_F(AccessBlackListManagerImplTest, Init) {
+ EXPECT_EQ(1u, manager_->GetSize());
+ EXPECT_EQ(10u, manager_->GetCapacity());
+ EXPECT_EQ((std::vector<AccessBlackListManagerImpl::Entry>{{
+ {1, 2, 3}, {3, 4, 5}, base::Time::FromTimeT(1419999999),
+ }}),
+ manager_->GetEntries());
+}
+
+TEST_F(AccessBlackListManagerImplTest, Block) {
+ EXPECT_CALL(config_store_, SaveSettings("black_list", _, _))
+ .WillOnce(testing::WithArgs<1, 2>(testing::Invoke(
+ [](const std::string& json, const DoneCallback& callback) {
+ std::string to_save = R"([{
+ "user": "AQID",
+ "app": "AwQF",
+ "expiration": 1419999999
+ }, {
+ "app": "CAgI",
+ "user": "BwcH",
+ "expiration": 1419990000
+ }])";
+ EXPECT_JSON_EQ(to_save, *test::CreateValue(json));
+ if (!callback.is_null())
+ callback.Run(nullptr);
+ })));
+ manager_->Block({7, 7, 7}, {8, 8, 8}, base::Time::FromTimeT(1419990000), {});
+}
+
+TEST_F(AccessBlackListManagerImplTest, BlockExpired) {
+ manager_->Block({}, {}, base::Time::FromTimeT(1400000000),
+ base::Bind([](ErrorPtr error) {
+ EXPECT_TRUE(error->HasError("aleady_expired"));
+ }));
+}
+
+TEST_F(AccessBlackListManagerImplTest, BlockListIsFull) {
+ EXPECT_CALL(config_store_, SaveSettings("black_list", _, _))
+ .WillRepeatedly(testing::WithArgs<1, 2>(testing::Invoke(
+ [](const std::string& json, const DoneCallback& callback) {
+ if (!callback.is_null())
+ callback.Run(nullptr);
+ })));
+ for (size_t i = manager_->GetSize(); i < manager_->GetCapacity(); ++i) {
+ manager_->Block(
+ {99, static_cast<uint8_t>(i / 256), static_cast<uint8_t>(i % 256)},
+ {8, 8, 8}, base::Time::FromTimeT(1419990000), {});
+ EXPECT_EQ(i + 1, manager_->GetSize());
+ }
+ manager_->Block({99}, {8, 8, 8}, base::Time::FromTimeT(1419990000),
+ base::Bind([](ErrorPtr error) {
+ EXPECT_TRUE(error->HasError("blacklist_is_full"));
+ }));
+}
+
+TEST_F(AccessBlackListManagerImplTest, Unblock) {
+ EXPECT_CALL(config_store_, SaveSettings("black_list", _, _))
+ .WillOnce(testing::WithArgs<1, 2>(testing::Invoke(
+ [](const std::string& json, const DoneCallback& callback) {
+ EXPECT_JSON_EQ("[]", *test::CreateValue(json));
+ if (!callback.is_null())
+ callback.Run(nullptr);
+ })));
+ manager_->Unblock({1, 2, 3}, {3, 4, 5}, {});
+}
+
+TEST_F(AccessBlackListManagerImplTest, UnblockNotFound) {
+ manager_->Unblock({5, 2, 3}, {5, 4, 5}, base::Bind([](ErrorPtr error) {
+ EXPECT_TRUE(error->HasError("entry_not_found"));
+ }));
+}
+
+TEST_F(AccessBlackListManagerImplTest, IsBlockedFalse) {
+ EXPECT_FALSE(manager_->IsBlocked({7, 7, 7}, {8, 8, 8}));
+}
+
+class AccessBlackListManagerImplIsBlockedTest
+ : public AccessBlackListManagerImplTest,
+ public testing::WithParamInterface<
+ std::tuple<std::vector<uint8_t>, std::vector<uint8_t>>> {
+ public:
+ void SetUp() override {
+ AccessBlackListManagerImplTest::SetUp();
+ EXPECT_CALL(config_store_, SaveSettings("black_list", _, _))
+ .WillOnce(testing::WithArgs<2>(
+ testing::Invoke([](const DoneCallback& callback) {
+ if (!callback.is_null())
+ callback.Run(nullptr);
+ })));
+ manager_->Block(std::get<0>(GetParam()), std::get<1>(GetParam()),
+ base::Time::FromTimeT(1419990000), {});
+ }
+};
+
+TEST_P(AccessBlackListManagerImplIsBlockedTest, IsBlocked) {
+ EXPECT_TRUE(manager_->IsBlocked({7, 7, 7}, {8, 8, 8}));
+}
+
+INSTANTIATE_TEST_CASE_P(
+ Filters,
+ AccessBlackListManagerImplIsBlockedTest,
+ testing::Combine(testing::Values(std::vector<uint8_t>{},
+ std::vector<uint8_t>{7, 7, 7}),
+ testing::Values(std::vector<uint8_t>{},
+ std::vector<uint8_t>{8, 8, 8})));
+
+} // namespace weave
diff --git a/src/base_api_handler.h b/src/base_api_handler.h
index 1dbbac8..6eebfca 100644
--- a/src/base_api_handler.h
+++ b/src/base_api_handler.h
@@ -33,7 +33,7 @@ class BaseApiHandler final {
void OnConfigChanged(const Settings& settings);
DeviceRegistrationInfo* device_info_;
- Device* device_;
+ Device* device_{nullptr};
base::WeakPtrFactory<BaseApiHandler> weak_ptr_factory_{this};
DISALLOW_COPY_AND_ASSIGN(BaseApiHandler);
diff --git a/src/config.cc b/src/config.cc
index 44d20dd..21a1c1f 100644
--- a/src/config.cc
+++ b/src/config.cc
@@ -33,6 +33,7 @@ const char kClientSecret[] = "client_secret";
const char kApiKey[] = "api_key";
const char kOAuthURL[] = "oauth_url";
const char kServiceURL[] = "service_url";
+const char kXmppEndpoint[] = "xmpp_endpoint";
const char kName[] = "name";
const char kDescription[] = "description";
const char kLocation[] = "location";
@@ -51,6 +52,7 @@ const char kRootClientTokenOwner[] = "root_client_token_owner";
const char kWeaveUrl[] = "https://www.googleapis.com/weave/v1/";
const char kDeprecatedUrl[] = "https://www.googleapis.com/clouddevices/v1/";
+const char kXmppEndpoint[] = "talk.google.com:5223";
namespace {
@@ -69,6 +71,7 @@ Config::Settings CreateDefaultSettings() {
Config::Settings result;
result.oauth_url = "https://accounts.google.com/o/oauth2/";
result.service_url = kWeaveUrl;
+ result.xmpp_endpoint = kXmppEndpoint;
result.local_anonymous_access_role = AuthScope::kViewer;
result.pairing_modes.insert(PairingType::kPinCode);
result.device_id = base::GenerateGUID();
@@ -119,6 +122,7 @@ void Config::Load() {
CHECK(!settings_.api_key.empty());
CHECK(!settings_.oauth_url.empty());
CHECK(!settings_.service_url.empty());
+ CHECK(!settings_.xmpp_endpoint.empty());
CHECK(!settings_.oem_name.empty());
CHECK(!settings_.model_name.empty());
CHECK(!settings_.model_id.empty());
@@ -190,6 +194,10 @@ void Config::Transaction::LoadState() {
set_service_url(tmp);
}
+ if (dict->GetString(config_keys::kXmppEndpoint, &tmp)) {
+ set_xmpp_endpoint(tmp);
+ }
+
if (dict->GetString(config_keys::kName, &tmp))
set_name(tmp);
@@ -249,6 +257,7 @@ void Config::Save() {
dict.SetString(config_keys::kApiKey, settings_.api_key);
dict.SetString(config_keys::kOAuthURL, settings_.oauth_url);
dict.SetString(config_keys::kServiceURL, settings_.service_url);
+ dict.SetString(config_keys::kXmppEndpoint, settings_.xmpp_endpoint);
dict.SetString(config_keys::kRefreshToken, settings_.refresh_token);
dict.SetString(config_keys::kCloudId, settings_.cloud_id);
dict.SetString(config_keys::kDeviceId, settings_.device_id);
diff --git a/src/config.h b/src/config.h
index 6dc0a07..8e0a8f3 100644
--- a/src/config.h
+++ b/src/config.h
@@ -68,6 +68,9 @@ class Config final {
void set_service_url(const std::string& url) {
settings_->service_url = url;
}
+ void set_xmpp_endpoint(const std::string& endpoint) {
+ settings_->xmpp_endpoint = endpoint;
+ }
void set_name(const std::string& name) { settings_->name = name; }
void set_description(const std::string& description) {
settings_->description = description;
diff --git a/src/config_unittest.cc b/src/config_unittest.cc
index 4b0e5b4..bb2743a 100644
--- a/src/config_unittest.cc
+++ b/src/config_unittest.cc
@@ -62,6 +62,7 @@ TEST_F(ConfigTest, Defaults) {
EXPECT_EQ("", GetSettings().api_key);
EXPECT_EQ("https://accounts.google.com/o/oauth2/", GetSettings().oauth_url);
EXPECT_EQ("https://www.googleapis.com/weave/v1/", GetSettings().service_url);
+ EXPECT_EQ("talk.google.com:5223", GetSettings().xmpp_endpoint);
EXPECT_EQ("", GetSettings().oem_name);
EXPECT_EQ("", GetSettings().model_name);
EXPECT_EQ("", GetSettings().model_id);
@@ -146,7 +147,8 @@ TEST_F(ConfigTest, LoadState) {
"refresh_token": "state_refresh_token",
"robot_account": "state_robot_account",
"secret": "c3RhdGVfc2VjcmV0",
- "service_url": "state_service_url"
+ "service_url": "state_service_url",
+ "xmpp_endpoint": "state_xmpp_endpoint"
})";
EXPECT_CALL(config_store_, LoadSettings(kConfigName)).WillOnce(Return(state));
@@ -157,6 +159,7 @@ TEST_F(ConfigTest, LoadState) {
EXPECT_EQ("state_api_key", GetSettings().api_key);
EXPECT_EQ("state_oauth_url", GetSettings().oauth_url);
EXPECT_EQ("state_service_url", GetSettings().service_url);
+ EXPECT_EQ("state_xmpp_endpoint", GetSettings().xmpp_endpoint);
EXPECT_EQ(GetDefaultSettings().oem_name, GetSettings().oem_name);
EXPECT_EQ(GetDefaultSettings().model_name, GetSettings().model_name);
EXPECT_EQ(GetDefaultSettings().model_id, GetSettings().model_id);
@@ -200,6 +203,9 @@ TEST_F(ConfigTest, Setters) {
change.set_service_url("set_service_url");
EXPECT_EQ("set_service_url", GetSettings().service_url);
+ change.set_xmpp_endpoint("set_xmpp_endpoint");
+ EXPECT_EQ("set_xmpp_endpoint", GetSettings().xmpp_endpoint);
+
change.set_name("set_name");
EXPECT_EQ("set_name", GetSettings().name);
@@ -277,7 +283,8 @@ TEST_F(ConfigTest, Setters) {
'refresh_token': 'set_token',
'robot_account': 'set_account',
'secret': 'AQIDBAU=',
- 'service_url': 'set_service_url'
+ 'service_url': 'set_service_url',
+ 'xmpp_endpoint': 'set_xmpp_endpoint'
})";
EXPECT_JSON_EQ(expected, *test::CreateValue(json));
callback.Run(nullptr);
diff --git a/src/device_manager.cc b/src/device_manager.cc
index 097f854..deb5404 100644
--- a/src/device_manager.cc
+++ b/src/device_manager.cc
@@ -8,6 +8,8 @@
#include <base/bind.h>
+#include "src/access_api_handler.h"
+#include "src/access_black_list_manager_impl.h"
#include "src/base_api_handler.h"
#include "src/commands/schema_constants.h"
#include "src/component_manager_impl.h"
@@ -40,6 +42,10 @@ DeviceManager::DeviceManager(provider::ConfigStore* config_store,
network, auth_manager_.get()));
base_api_handler_.reset(new BaseApiHandler{device_info_.get(), this});
+ black_list_manager_.reset(new AccessBlackListManagerImpl{config_store});
+ access_api_handler_.reset(
+ new AccessApiHandler{this, black_list_manager_.get()});
+
device_info_->Start();
if (http_server) {
diff --git a/src/device_manager.h b/src/device_manager.h
index d40ba8e..d77bacc 100644
--- a/src/device_manager.h
+++ b/src/device_manager.h
@@ -10,6 +10,8 @@
namespace weave {
+class AccessApiHandler;
+class AccessBlackListManager;
class BaseApiHandler;
class Config;
class ComponentManager;
@@ -107,6 +109,8 @@ class DeviceManager final : public Device {
std::unique_ptr<ComponentManager> component_manager_;
std::unique_ptr<DeviceRegistrationInfo> device_info_;
std::unique_ptr<BaseApiHandler> base_api_handler_;
+ std::unique_ptr<AccessBlackListManager> black_list_manager_;
+ std::unique_ptr<AccessApiHandler> access_api_handler_;
std::unique_ptr<privet::Manager> privet_;
base::WeakPtrFactory<DeviceManager> weak_ptr_factory_{this};
diff --git a/src/device_registration_info.cc b/src/device_registration_info.cc
index 7c20084..0dc1f54 100644
--- a/src/device_registration_info.cc
+++ b/src/device_registration_info.cc
@@ -463,8 +463,9 @@ void DeviceRegistrationInfo::StartNotificationChannel() {
current_notification_channel_ = pull_channel_.get();
notification_channel_starting_ = true;
- primary_notification_channel_.reset(new XmppChannel{
- GetSettings().robot_account, access_token_, task_runner_, network_});
+ primary_notification_channel_.reset(
+ new XmppChannel{GetSettings().robot_account, access_token_,
+ GetSettings().xmpp_endpoint, task_runner_, network_});
primary_notification_channel_->Start(this);
}
@@ -833,17 +834,25 @@ bool DeviceRegistrationInfo::UpdateServiceConfig(
const std::string& api_key,
const std::string& oauth_url,
const std::string& service_url,
+ const std::string& xmpp_endpoint,
ErrorPtr* error) {
if (HaveRegistrationCredentials()) {
return Error::AddTo(error, FROM_HERE, kErrorAlreayRegistered,
"Unable to change config for registered device");
}
Config::Transaction change{config_};
- change.set_client_id(client_id);
- change.set_client_secret(client_secret);
- change.set_api_key(api_key);
- change.set_oauth_url(oauth_url);
- change.set_service_url(service_url);
+ if (!client_id.empty())
+ change.set_client_id(client_id);
+ if (!client_secret.empty())
+ change.set_client_secret(client_secret);
+ if (!api_key.empty())
+ change.set_api_key(api_key);
+ if (!oauth_url.empty())
+ change.set_oauth_url(oauth_url);
+ if (!service_url.empty())
+ change.set_service_url(service_url);
+ if (!xmpp_endpoint.empty())
+ change.set_xmpp_endpoint(xmpp_endpoint);
return true;
}
diff --git a/src/device_registration_info.h b/src/device_registration_info.h
index f670b68..a296258 100644
--- a/src/device_registration_info.h
+++ b/src/device_registration_info.h
@@ -78,6 +78,7 @@ class DeviceRegistrationInfo : public NotificationDelegate,
const std::string& api_key,
const std::string& oauth_url,
const std::string& service_url,
+ const std::string& xmpp_endpoint,
ErrorPtr* error);
void GetDeviceInfo(const CloudRequestDoneCallback& callback);
diff --git a/src/device_registration_info_unittest.cc b/src/device_registration_info_unittest.cc
index 7908c8b..bbc167e 100644
--- a/src/device_registration_info_unittest.cc
+++ b/src/device_registration_info_unittest.cc
@@ -44,6 +44,7 @@ namespace {
namespace test_data {
+const char kXmppEndpoint[] = "xmpp.server.com:1234";
const char kServiceURL[] = "http://gcd.server.com/";
const char kOAuthURL[] = "http://oauth.server.com/";
const char kApiKey[] = "GOadRdTf9FERf0k4w6EFOof56fUJ3kFDdFL3d7f";
@@ -144,6 +145,7 @@ class DeviceRegistrationInfoTest : public ::testing::Test {
settings->model_id = "AAAAA";
settings->oauth_url = test_data::kOAuthURL;
settings->service_url = test_data::kServiceURL;
+ settings->xmpp_endpoint = test_data::kXmppEndpoint;
return true;
}));
config_.reset(new Config{&config_store_});
diff --git a/src/notification/xmpp_channel.cc b/src/notification/xmpp_channel.cc
index ceb45ed..f9d7924 100644
--- a/src/notification/xmpp_channel.cc
+++ b/src/notification/xmpp_channel.cc
@@ -7,6 +7,7 @@
#include <string>
#include <base/bind.h>
+#include <base/strings/string_number_conversions.h>
#include <weave/provider/network.h>
#include <weave/provider/task_runner.h>
@@ -16,6 +17,7 @@
#include "src/notification/notification_parser.h"
#include "src/notification/xml_node.h"
#include "src/privet/openssl_utils.h"
+#include "src/string_utils.h"
#include "src/utils.h"
namespace weave {
@@ -74,9 +76,6 @@ const BackoffEntry::Policy kDefaultBackoffPolicy = {
false,
};
-const char kDefaultXmppHost[] = "talk.google.com";
-const uint16_t kDefaultXmppPort = 5223;
-
// Used for keeping connection alive.
const int kRegularPingIntervalSeconds = 60;
const int kRegularPingTimeoutSeconds = 30;
@@ -91,10 +90,12 @@ const int kConnectingTimeoutAfterNetChangeSeconds = 30;
XmppChannel::XmppChannel(const std::string& account,
const std::string& access_token,
+ const std::string& xmpp_endpoint,
provider::TaskRunner* task_runner,
provider::Network* network)
: account_{account},
access_token_{access_token},
+ xmpp_endpoint_{xmpp_endpoint},
network_{network},
backoff_entry_{&kDefaultBackoffPolicy},
task_runner_{task_runner},
@@ -285,10 +286,16 @@ void XmppChannel::HandleMessageStanza(std::unique_ptr<XmlNode> stanza) {
void XmppChannel::CreateSslSocket() {
CHECK(!stream_);
state_ = XmppState::kConnecting;
- LOG(INFO) << "Starting XMPP connection to " << kDefaultXmppHost << ":"
- << kDefaultXmppPort;
+ LOG(INFO) << "Starting XMPP connection to: " << xmpp_endpoint_;
+
+ std::pair<std::string, std::string> host_port =
+ SplitAtFirst(xmpp_endpoint_, ":", true);
+ CHECK(!host_port.first.empty());
+ CHECK(!host_port.second.empty());
+ uint32_t port = 0;
+ CHECK(base::StringToUint(host_port.second, &port)) << xmpp_endpoint_;
- network_->OpenSslSocket(kDefaultXmppHost, kDefaultXmppPort,
+ network_->OpenSslSocket(host_port.first, port,
base::Bind(&XmppChannel::OnSslSocketReady,
task_ptr_factory_.GetWeakPtr()));
}
diff --git a/src/notification/xmpp_channel.h b/src/notification/xmpp_channel.h
index 50e84d2..b0a4468 100644
--- a/src/notification/xmpp_channel.h
+++ b/src/notification/xmpp_channel.h
@@ -45,6 +45,7 @@ class XmppChannel : public NotificationChannel,
// so you will need to reset the XmppClient every time this happens.
XmppChannel(const std::string& account,
const std::string& access_token,
+ const std::string& xmpp_endpoint,
provider::TaskRunner* task_runner,
provider::Network* network);
~XmppChannel() override = default;
@@ -124,12 +125,15 @@ class XmppChannel : public NotificationChannel,
// Robot account name for the device.
std::string account_;
- // Full JID of this device.
- std::string jid_;
-
// OAuth access token for the account. Expires fairly frequently.
std::string access_token_;
+ // Xmpp endpoint.
+ std::string xmpp_endpoint_;
+
+ // Full JID of this device.
+ std::string jid_;
+
provider::Network* network_{nullptr};
std::unique_ptr<Stream> stream_;
diff --git a/src/notification/xmpp_channel_unittest.cc b/src/notification/xmpp_channel_unittest.cc
index 674fe22..dfa2a79 100644
--- a/src/notification/xmpp_channel_unittest.cc
+++ b/src/notification/xmpp_channel_unittest.cc
@@ -26,6 +26,7 @@ namespace {
constexpr char kAccountName[] = "Account@Name";
constexpr char kAccessToken[] = "AccessToken";
+constexpr char kEndpoint[] = "endpoint:456";
constexpr char kStartStreamMessage[] =
"<stream:stream to='clouddevices.gserviceaccount.com' "
@@ -84,7 +85,8 @@ class FakeXmppChannel : public XmppChannel {
public:
explicit FakeXmppChannel(provider::TaskRunner* task_runner,
provider::Network* network)
- : XmppChannel{kAccountName, kAccessToken, task_runner, network},
+ : XmppChannel{kAccountName, kAccessToken, kEndpoint, task_runner,
+ network},
stream_{new test::FakeStream{task_runner_}},
fake_stream_{stream_.get()} {}
@@ -122,7 +124,7 @@ class MockNetwork : public provider::test::MockNetwork {
class XmppChannelTest : public ::testing::Test {
protected:
XmppChannelTest() {
- EXPECT_CALL(network_, OpenSslSocket("talk.google.com", 5223, _))
+ EXPECT_CALL(network_, OpenSslSocket("endpoint", 456, _))
.WillOnce(
WithArgs<2>(Invoke(&xmpp_client_, &FakeXmppChannel::Connect)));
}
diff --git a/src/privet/auth_manager.cc b/src/privet/auth_manager.cc
index 66d04c4..c82887e 100644
--- a/src/privet/auth_manager.cc
+++ b/src/privet/auth_manager.cc
@@ -18,6 +18,7 @@
extern "C" {
#include "third_party/libuweave/src/macaroon.h"
+#include "third_party/libuweave/src/macaroon_caveat_internal.h"
}
namespace weave {
@@ -25,9 +26,19 @@ namespace privet {
namespace {
+const time_t kJ2000ToTimeT = 946684800;
const size_t kMaxMacaroonSize = 1024;
const size_t kMaxPendingClaims = 10;
const char kInvalidTokenError[] = "invalid_token";
+const int kSessionIdTtlMinutes = 1;
+
+uint32_t ToJ2000Time(const base::Time& time) {
+ return std::max(time.ToTimeT(), kJ2000ToTimeT) - kJ2000ToTimeT;
+}
+
+base::Time FromJ2000Time(uint32_t time) {
+ return base::Time::FromTimeT(time + kJ2000ToTimeT);
+}
template <class T>
void AppendToArray(T value, std::vector<uint8_t>* array) {
@@ -37,78 +48,108 @@ void AppendToArray(T value, std::vector<uint8_t>* array) {
class Caveat {
public:
- // TODO(vitalybuka): Use _get_buffer_size_ when available.
- Caveat(UwMacaroonCaveatType type, uint32_t value) : buffer(8) {
- CHECK(uw_macaroon_caveat_create_with_uint_(type, value, buffer.data(),
- buffer.size(), &caveat));
+ Caveat(UwMacaroonCaveatType type, size_t str_len)
+ : buffer_(uw_macaroon_caveat_creation_get_buffsize_(type, str_len)) {
+ CHECK(!buffer_.empty());
}
+ const UwMacaroonCaveat& GetCaveat() const { return caveat_; }
+
+ protected:
+ UwMacaroonCaveat caveat_{};
+ std::vector<uint8_t> buffer_;
+
+ DISALLOW_COPY_AND_ASSIGN(Caveat);
+};
- // TODO(vitalybuka): Use _get_buffer_size_ when available.
- Caveat(UwMacaroonCaveatType type, const std::string& value)
- : buffer(std::max<size_t>(value.size(), 32u) * 2) {
- CHECK(uw_macaroon_caveat_create_with_str_(
- type, reinterpret_cast<const uint8_t*>(value.data()), value.size(),
- buffer.data(), buffer.size(), &caveat));
+class ScopeCaveat : public Caveat {
+ public:
+ explicit ScopeCaveat(UwMacaroonCaveatScopeType scope)
+ : Caveat(kUwMacaroonCaveatTypeScope, 0) {
+ CHECK(uw_macaroon_caveat_create_scope_(scope, buffer_.data(),
+ buffer_.size(), &caveat_));
}
- const UwMacaroonCaveat& GetCaveat() const { return caveat; }
+ DISALLOW_COPY_AND_ASSIGN(ScopeCaveat);
+};
- private:
- UwMacaroonCaveat caveat;
- std::vector<uint8_t> buffer;
+class TimestampCaveat : public Caveat {
+ public:
+ explicit TimestampCaveat(const base::Time& timestamp)
+ : Caveat(kUwMacaroonCaveatTypeDelegationTimestamp, 0) {
+ CHECK(uw_macaroon_caveat_create_delegation_timestamp_(
+ ToJ2000Time(timestamp), buffer_.data(), buffer_.size(), &caveat_));
+ }
- DISALLOW_COPY_AND_ASSIGN(Caveat);
+ DISALLOW_COPY_AND_ASSIGN(TimestampCaveat);
};
-bool CheckCaveatType(const UwMacaroonCaveat& caveat,
- UwMacaroonCaveatType type,
- ErrorPtr* error) {
- UwMacaroonCaveatType caveat_type{};
- if (!uw_macaroon_caveat_get_type_(&caveat, &caveat_type)) {
- return Error::AddTo(error, FROM_HERE, kInvalidTokenError,
- "Unable to get type");
+class ExpirationCaveat : public Caveat {
+ public:
+ explicit ExpirationCaveat(const base::Time& timestamp)
+ : Caveat(kUwMacaroonCaveatTypeExpirationAbsolute, 0) {
+ CHECK(uw_macaroon_caveat_create_expiration_absolute_(
+ ToJ2000Time(timestamp), buffer_.data(), buffer_.size(), &caveat_));
}
- if (caveat_type != type) {
- return Error::AddTo(error, FROM_HERE, kInvalidTokenError,
- "Unexpected caveat type");
+ DISALLOW_COPY_AND_ASSIGN(ExpirationCaveat);
+};
+
+class UserIdCaveat : public Caveat {
+ public:
+ explicit UserIdCaveat(const std::vector<uint8_t>& id)
+ : Caveat(kUwMacaroonCaveatTypeDelegateeUser, id.size()) {
+ CHECK(uw_macaroon_caveat_create_delegatee_user_(
+ id.data(), id.size(), buffer_.data(), buffer_.size(), &caveat_));
}
- return true;
-}
+ DISALLOW_COPY_AND_ASSIGN(UserIdCaveat);
+};
-bool ReadCaveat(const UwMacaroonCaveat& caveat,
- UwMacaroonCaveatType type,
- uint32_t* value,
- ErrorPtr* error) {
- if (!CheckCaveatType(caveat, type, error))
- return false;
+class AppIdCaveat : public Caveat {
+ public:
+ explicit AppIdCaveat(const std::vector<uint8_t>& id)
+ : Caveat(kUwMacaroonCaveatTypeDelegateeApp, id.size()) {
+ CHECK(uw_macaroon_caveat_create_delegatee_app_(
+ id.data(), id.size(), buffer_.data(), buffer_.size(), &caveat_));
+ }
- if (!uw_macaroon_caveat_get_value_uint_(&caveat, value)) {
- return Error::AddTo(error, FROM_HERE, kInvalidTokenError,
- "Unable to read caveat");
+ DISALLOW_COPY_AND_ASSIGN(AppIdCaveat);
+};
+
+class ServiceCaveat : public Caveat {
+ public:
+ explicit ServiceCaveat(const std::string& id)
+ : Caveat(kUwMacaroonCaveatTypeDelegateeService, id.size()) {
+ CHECK(uw_macaroon_caveat_create_delegatee_service_(
+ reinterpret_cast<const uint8_t*>(id.data()), id.size(), buffer_.data(),
+ buffer_.size(), &caveat_));
}
- return true;
-}
+ DISALLOW_COPY_AND_ASSIGN(ServiceCaveat);
+};
-bool ReadCaveat(const UwMacaroonCaveat& caveat,
- UwMacaroonCaveatType type,
- std::string* value,
- ErrorPtr* error) {
- if (!CheckCaveatType(caveat, type, error))
- return false;
+class SessionIdCaveat : public Caveat {
+ public:
+ explicit SessionIdCaveat(const std::string& id)
+ : Caveat(kUwMacaroonCaveatTypeLanSessionID, id.size()) {
+ CHECK(uw_macaroon_caveat_create_lan_session_id_(
+ reinterpret_cast<const uint8_t*>(id.data()), id.size(), buffer_.data(),
+ buffer_.size(), &caveat_));
+ }
- const uint8_t* start{nullptr};
- size_t size{0};
- if (!uw_macaroon_caveat_get_value_str_(&caveat, &start, &size)) {
- return Error::AddTo(error, FROM_HERE, kInvalidTokenError,
- "Unable to read caveat");
+ DISALLOW_COPY_AND_ASSIGN(SessionIdCaveat);
+};
+
+class ClientAuthTokenCaveat : public Caveat {
+ public:
+ ClientAuthTokenCaveat()
+ : Caveat(kUwMacaroonCaveatTypeClientAuthorizationTokenV1, 0) {
+ CHECK(uw_macaroon_caveat_create_client_authorization_token_(
+ nullptr, 0, buffer_.data(), buffer_.size(), &caveat_));
}
- value->assign(reinterpret_cast<const char*>(start), size);
- return true;
-}
+ DISALLOW_COPY_AND_ASSIGN(ClientAuthTokenCaveat);
+};
std::vector<uint8_t> CreateSecret() {
std::vector<uint8_t> secret(kSha256OutputSize);
@@ -122,18 +163,53 @@ bool IsClaimAllowed(RootClientTokenOwner curret, RootClientTokenOwner claimer) {
std::vector<uint8_t> CreateMacaroonToken(
const std::vector<uint8_t>& secret,
- const std::vector<UwMacaroonCaveat>& caveats) {
+ const base::Time& time,
+ const std::vector<const UwMacaroonCaveat*>& caveats) {
CHECK_EQ(kSha256OutputSize, secret.size());
+
+ UwMacaroonContext context{};
+ CHECK(uw_macaroon_context_create_(ToJ2000Time(time), nullptr, 0, &context));
+
UwMacaroon macaroon{};
- CHECK(uw_macaroon_new_from_root_key_(&macaroon, secret.data(), secret.size(),
- caveats.data(), caveats.size()));
+ CHECK(uw_macaroon_create_from_root_key_(&macaroon, secret.data(),
+ secret.size(), &context,
+ caveats.data(), caveats.size()));
- std::vector<uint8_t> token(kMaxMacaroonSize);
+ std::vector<uint8_t> serialized_token(kMaxMacaroonSize);
size_t len = 0;
- CHECK(uw_macaroon_dump_(&macaroon, token.data(), token.size(), &len));
- token.resize(len);
+ CHECK(uw_macaroon_serialize_(&macaroon, serialized_token.data(),
+ serialized_token.size(), &len));
+ serialized_token.resize(len);
- return token;
+ return serialized_token;
+}
+
+std::vector<uint8_t> ExtendMacaroonToken(
+ const UwMacaroon& macaroon,
+ const base::Time& time,
+ const std::vector<const UwMacaroonCaveat*>& caveats) {
+ UwMacaroonContext context{};
+ CHECK(uw_macaroon_context_create_(ToJ2000Time(time), nullptr, 0, &context));
+
+ UwMacaroon prev_macaroon = macaroon;
+ std::vector<uint8_t> prev_buffer(kMaxMacaroonSize);
+ std::vector<uint8_t> new_buffer(kMaxMacaroonSize);
+
+ for (auto caveat : caveats) {
+ UwMacaroon new_macaroon{};
+ CHECK(uw_macaroon_extend_(&prev_macaroon, &new_macaroon, &context, caveat,
+ new_buffer.data(), new_buffer.size()));
+ new_buffer.swap(prev_buffer);
+ prev_macaroon = new_macaroon;
+ }
+
+ std::vector<uint8_t> serialized_token(kMaxMacaroonSize);
+ size_t len = 0;
+ CHECK(uw_macaroon_serialize_(&prev_macaroon, serialized_token.data(),
+ serialized_token.size(), &len));
+ serialized_token.resize(len);
+
+ return serialized_token;
}
bool LoadMacaroon(const std::vector<uint8_t>& token,
@@ -141,8 +217,8 @@ bool LoadMacaroon(const std::vector<uint8_t>& token,
UwMacaroon* macaroon,
ErrorPtr* error) {
buffer->resize(kMaxMacaroonSize);
- if (!uw_macaroon_load_(token.data(), token.size(), buffer->data(),
- buffer->size(), macaroon)) {
+ if (!uw_macaroon_deserialize_(token.data(), token.size(), buffer->data(),
+ buffer->size(), macaroon)) {
return Error::AddTo(error, FROM_HERE, kInvalidTokenError,
"Invalid token format");
}
@@ -151,10 +227,16 @@ bool LoadMacaroon(const std::vector<uint8_t>& token,
bool VerifyMacaroon(const std::vector<uint8_t>& secret,
const UwMacaroon& macaroon,
+ const base::Time& time,
+ UwMacaroonValidationResult* result,
ErrorPtr* error) {
CHECK_EQ(kSha256OutputSize, secret.size());
- if (!uw_macaroon_verify_(&macaroon, secret.data(), secret.size())) {
- return Error::AddTo(error, FROM_HERE, "invalid_signature",
+ UwMacaroonContext context = {};
+ CHECK(uw_macaroon_context_create_(ToJ2000Time(time), nullptr, 0, &context));
+
+ if (!uw_macaroon_validate_(&macaroon, secret.data(), secret.size(), &context,
+ result)) {
+ return Error::AddTo(error, FROM_HERE, "invalid_token",
"Invalid token signature");
}
return true;
@@ -239,14 +321,22 @@ AuthManager::~AuthManager() {}
std::vector<uint8_t> AuthManager::CreateAccessToken(const UserInfo& user_info,
base::TimeDelta ttl) const {
- Caveat scope{kUwMacaroonCaveatTypeScope, ToMacaroonScope(user_info.scope())};
- Caveat user{kUwMacaroonCaveatTypeIdentifier, user_info.user_id()};
- Caveat issued{kUwMacaroonCaveatTypeExpiration,
- static_cast<uint32_t>((Now() + ttl).ToTimeT())};
+ const base::Time now = Now();
+ TimestampCaveat issued{now};
+ ScopeCaveat scope{ToMacaroonScope(user_info.scope())};
+ // Macaroons have no caveats for auth type. So we just append the type to the
+ // user ID.
+ std::vector<uint8_t> id_with_type{user_info.id().user};
+ id_with_type.push_back(static_cast<uint8_t>(user_info.id().type));
+ UserIdCaveat user{id_with_type};
+ AppIdCaveat app{user_info.id().app};
+ ExpirationCaveat expiration{now + ttl};
return CreateMacaroonToken(
- access_secret_,
+ access_secret_, now,
{
- scope.GetCaveat(), user.GetCaveat(), issued.GetCaveat(),
+
+ &issued.GetCaveat(), &scope.GetCaveat(), &user.GetCaveat(),
+ &app.GetCaveat(), &expiration.GetCaveat(),
});
}
@@ -256,37 +346,40 @@ bool AuthManager::ParseAccessToken(const std::vector<uint8_t>& token,
std::vector<uint8_t> buffer;
UwMacaroon macaroon{};
- uint32_t scope{0};
- std::string user_id;
- uint32_t expiration{0};
-
+ UwMacaroonValidationResult result{};
+ const base::Time now = Now();
if (!LoadMacaroon(token, &buffer, &macaroon, error) ||
- !VerifyMacaroon(access_secret_, macaroon, error) ||
- macaroon.num_caveats != 3 ||
- !ReadCaveat(macaroon.caveats[0], kUwMacaroonCaveatTypeScope, &scope,
- error) ||
- !ReadCaveat(macaroon.caveats[1], kUwMacaroonCaveatTypeIdentifier,
- &user_id, error) ||
- !ReadCaveat(macaroon.caveats[2], kUwMacaroonCaveatTypeExpiration,
- &expiration, error)) {
+ macaroon.num_caveats != 5 ||
+ !VerifyMacaroon(access_secret_, macaroon, now, &result, error)) {
return Error::AddTo(error, FROM_HERE, errors::kInvalidAuthorization,
"Invalid token");
}
- AuthScope auth_scope{FromMacaroonScope(scope)};
+ AuthScope auth_scope{FromMacaroonScope(result.granted_scope)};
if (auth_scope == AuthScope::kNone) {
return Error::AddTo(error, FROM_HERE, errors::kInvalidAuthorization,
"Invalid token data");
}
- base::Time time{base::Time::FromTimeT(expiration)};
- if (time < clock_->Now()) {
- return Error::AddTo(error, FROM_HERE, errors::kAuthorizationExpired,
- "Token is expired");
- }
-
+ // If token is valid and token was not extended, it should has precisely this
+ // values.
+ CHECK_GE(FromJ2000Time(result.expiration_time), now);
+ CHECK_EQ(2u, result.num_delegatees);
+ CHECK_EQ(kUwMacaroonDelegateeTypeUser, result.delegatees[0].type);
+ CHECK_EQ(kUwMacaroonDelegateeTypeApp, result.delegatees[1].type);
+ CHECK_GT(result.delegatees[0].id_len, 1u);
+ std::vector<uint8_t> user_id{
+ result.delegatees[0].id,
+ result.delegatees[0].id + result.delegatees[0].id_len};
+ // Last byte is used for type. See |CreateAccessToken|.
+ AuthType type = static_cast<AuthType>(user_id.back());
+ user_id.pop_back();
+
+ std::vector<uint8_t> app_id{
+ result.delegatees[1].id,
+ result.delegatees[1].id + result.delegatees[1].id_len};
if (user_info)
- *user_info = UserInfo{auth_scope, user_id};
+ *user_info = UserInfo{auth_scope, UserAppId{type, user_id, app_id}};
return true;
}
@@ -309,7 +402,7 @@ std::vector<uint8_t> AuthManager::ClaimRootClientAuthToken(
std::unique_ptr<AuthManager>{new AuthManager{nullptr, {}}}, owner));
if (pending_claims_.size() > kMaxPendingClaims)
pending_claims_.pop_front();
- return pending_claims_.back().first->GetRootClientAuthToken();
+ return pending_claims_.back().first->GetRootClientAuthToken(owner);
}
bool AuthManager::ConfirmClientAuthToken(const std::vector<uint8_t>& token,
@@ -332,14 +425,20 @@ bool AuthManager::ConfirmClientAuthToken(const std::vector<uint8_t>& token,
return true;
}
-std::vector<uint8_t> AuthManager::GetRootClientAuthToken() const {
- Caveat scope{kUwMacaroonCaveatTypeScope, kUwMacaroonCaveatScopeTypeOwner};
- Caveat issued{kUwMacaroonCaveatTypeIssued,
- static_cast<uint32_t>(Now().ToTimeT())};
- return CreateMacaroonToken(auth_secret_,
- {
- scope.GetCaveat(), issued.GetCaveat(),
- });
+std::vector<uint8_t> AuthManager::GetRootClientAuthToken(
+ RootClientTokenOwner owner) const {
+ CHECK(RootClientTokenOwner::kNone != owner);
+ ClientAuthTokenCaveat auth_token;
+ const base::Time now = Now();
+ TimestampCaveat issued{now};
+
+ ServiceCaveat client{owner == RootClientTokenOwner::kCloud ? "google.com"
+ : ""};
+ return CreateMacaroonToken(
+ auth_secret_, now,
+ {
+ &auth_token.GetCaveat(), &issued.GetCaveat(), &client.GetCaveat(),
+ });
}
base::Time AuthManager::Now() const {
@@ -350,8 +449,9 @@ bool AuthManager::IsValidAuthToken(const std::vector<uint8_t>& token,
ErrorPtr* error) const {
std::vector<uint8_t> buffer;
UwMacaroon macaroon{};
+ UwMacaroonValidationResult result{};
if (!LoadMacaroon(token, &buffer, &macaroon, error) ||
- !VerifyMacaroon(auth_secret_, macaroon, error)) {
+ !VerifyMacaroon(auth_secret_, macaroon, Now(), &result, error)) {
return Error::AddTo(error, FROM_HERE, errors::kInvalidAuthCode,
"Invalid token");
}
@@ -365,19 +465,63 @@ bool AuthManager::CreateAccessTokenFromAuth(
AuthScope* access_token_scope,
base::TimeDelta* access_token_ttl,
ErrorPtr* error) const {
- // TODO(vitalybuka): implement token validation.
- if (!IsValidAuthToken(auth_token, error))
- return false;
+ std::vector<uint8_t> buffer;
+ UwMacaroon macaroon{};
+ UwMacaroonValidationResult result{};
+ const base::Time now = Now();
+ if (!LoadMacaroon(auth_token, &buffer, &macaroon, error) ||
+ !VerifyMacaroon(auth_secret_, macaroon, now, &result, error)) {
+ return Error::AddTo(error, FROM_HERE, errors::kInvalidAuthCode,
+ "Invalid token");
+ }
+
+ AuthScope auth_scope{FromMacaroonScope(result.granted_scope)};
+ if (auth_scope == AuthScope::kNone) {
+ return Error::AddTo(error, FROM_HERE, errors::kInvalidAuthCode,
+ "Invalid token data");
+ }
+
+ // TODO: Integrate black list checks.
+ auto delegates_rbegin = std::reverse_iterator<const UwMacaroonDelegateeInfo*>(
+ result.delegatees + result.num_delegatees);
+ auto delegates_rend =
+ std::reverse_iterator<const UwMacaroonDelegateeInfo*>(result.delegatees);
+ auto last_user_id =
+ std::find_if(delegates_rbegin, delegates_rend,
+ [](const UwMacaroonDelegateeInfo& delegatee) {
+ return delegatee.type == kUwMacaroonDelegateeTypeUser;
+ });
+ auto last_app_id =
+ std::find_if(delegates_rbegin, delegates_rend,
+ [](const UwMacaroonDelegateeInfo& delegatee) {
+ return delegatee.type == kUwMacaroonDelegateeTypeApp;
+ });
+
+ if (last_user_id == delegates_rend || !last_user_id->id_len) {
+ return Error::AddTo(error, FROM_HERE, errors::kInvalidAuthCode,
+ "User ID is missing");
+ }
+
+ const char* session_id = reinterpret_cast<const char*>(result.lan_session_id);
+ if (!IsValidSessionId({session_id, session_id + result.lan_session_id_len})) {
+ return Error::AddTo(error, FROM_HERE, errors::kInvalidAuthCode,
+ "Invalid session id");
+ }
+
+ CHECK_GE(FromJ2000Time(result.expiration_time), now);
if (!access_token)
return true;
- // TODO(vitalybuka): User and scope must be parsed from auth_token.
- UserInfo info{config_ ? config_->GetSettings().local_anonymous_access_role
- : AuthScope::kViewer,
- base::GenerateGUID()};
+ std::vector<uint8_t> user_id{last_user_id->id,
+ last_user_id->id + last_user_id->id_len};
+ std::vector<uint8_t> app_id;
+ if (last_app_id != delegates_rend)
+ app_id.assign(last_app_id->id, last_app_id->id + last_app_id->id_len);
+
+ UserInfo info{auth_scope, {AuthType::kLocal, user_id, app_id}};
- // TODO(vitalybuka): TTL also should be reduced in accordance with auth_token.
+ ttl = std::min(ttl, FromJ2000Time(result.expiration_time) - now);
*access_token = CreateAccessToken(info, ttl);
if (access_token_scope)
@@ -388,11 +532,45 @@ bool AuthManager::CreateAccessTokenFromAuth(
return true;
}
-std::vector<uint8_t> AuthManager::CreateSessionId() {
- std::vector<uint8_t> result;
- AppendToArray(Now().ToTimeT(), &result);
- AppendToArray(++session_counter_, &result);
- return result;
+std::string AuthManager::CreateSessionId() const {
+ return std::to_string(ToJ2000Time(Now())) + ":" +
+ std::to_string(++session_counter_);
+}
+
+bool AuthManager::IsValidSessionId(const std::string& session_id) const {
+ base::Time ssid_time = FromJ2000Time(std::atoi(session_id.c_str()));
+ return Now() - base::TimeDelta::FromMinutes(kSessionIdTtlMinutes) <=
+ ssid_time &&
+ ssid_time <= Now();
+}
+
+std::vector<uint8_t> AuthManager::DelegateToUser(
+ const std::vector<uint8_t>& token,
+ base::TimeDelta ttl,
+ const UserInfo& user_info) const {
+ std::vector<uint8_t> buffer;
+ UwMacaroon macaroon{};
+ CHECK(LoadMacaroon(token, &buffer, &macaroon, nullptr));
+
+ const base::Time now = Now();
+ TimestampCaveat issued{now};
+ ExpirationCaveat expiration{now + ttl};
+ ScopeCaveat scope{ToMacaroonScope(user_info.scope())};
+ UserIdCaveat user{user_info.id().user};
+ AppIdCaveat app{user_info.id().app};
+ SessionIdCaveat session{CreateSessionId()};
+
+ std::vector<const UwMacaroonCaveat*> caveats{
+ &issued.GetCaveat(), &expiration.GetCaveat(), &scope.GetCaveat(),
+ &user.GetCaveat(),
+ };
+
+ if (!user_info.id().app.empty())
+ caveats.push_back(&app.GetCaveat());
+
+ caveats.push_back(&session.GetCaveat());
+
+ return ExtendMacaroonToken(macaroon, now, caveats);
}
} // namespace privet
diff --git a/src/privet/auth_manager.h b/src/privet/auth_manager.h
index 309d80e..f0a5761 100644
--- a/src/privet/auth_manager.h
+++ b/src/privet/auth_manager.h
@@ -9,6 +9,7 @@
#include <string>
#include <vector>
+#include <base/gtest_prod_util.h>
#include <base/time/default_clock.h>
#include <base/time/time.h>
#include <weave/error.h>
@@ -54,7 +55,7 @@ class AuthManager {
bool ConfirmClientAuthToken(const std::vector<uint8_t>& token,
ErrorPtr* error);
- std::vector<uint8_t> GetRootClientAuthToken() const;
+ std::vector<uint8_t> GetRootClientAuthToken(RootClientTokenOwner owner) const;
bool IsValidAuthToken(const std::vector<uint8_t>& token,
ErrorPtr* error) const;
bool CreateAccessTokenFromAuth(const std::vector<uint8_t>& auth_token,
@@ -67,13 +68,21 @@ class AuthManager {
void SetAuthSecret(const std::vector<uint8_t>& secret,
RootClientTokenOwner owner);
- std::vector<uint8_t> CreateSessionId();
+ std::string CreateSessionId() const;
+ bool IsValidSessionId(const std::string& session_id) const;
private:
+ friend class AuthManagerTest;
+
+ // Test helpers. Device does not need to implement delegation.
+ std::vector<uint8_t> DelegateToUser(const std::vector<uint8_t>& token,
+ base::TimeDelta ttl,
+ const UserInfo& user_info) const;
+
Config* config_{nullptr}; // Can be nullptr for tests.
base::DefaultClock default_clock_;
base::Clock* clock_{&default_clock_};
- uint32_t session_counter_{0};
+ mutable uint32_t session_counter_{0};
std::vector<uint8_t> auth_secret_; // Persistent.
std::vector<uint8_t> certificate_fingerprint_;
diff --git a/src/privet/auth_manager_unittest.cc b/src/privet/auth_manager_unittest.cc
index 70750ad..294aefa 100644
--- a/src/privet/auth_manager_unittest.cc
+++ b/src/privet/auth_manager_unittest.cc
@@ -10,6 +10,7 @@
#include "src/config.h"
#include "src/data_encoding.h"
+#include "src/privet/mock_delegates.h"
#include "src/test/mock_clock.h"
using testing::Return;
@@ -29,6 +30,11 @@ class AuthManagerTest : public testing::Test {
}
protected:
+ std::vector<uint8_t> DelegateToUser(const std::vector<uint8_t>& token,
+ base::TimeDelta ttl,
+ const UserInfo& user_info) const {
+ return auth_.DelegateToUser(token, ttl, user_info);
+ }
const std::vector<uint8_t> kSecret1{
78, 40, 39, 68, 29, 19, 70, 86, 38, 61, 13, 55, 33, 32, 51, 52,
34, 43, 97, 48, 8, 56, 11, 99, 50, 59, 24, 26, 31, 71, 76, 28};
@@ -64,49 +70,90 @@ TEST_F(AuthManagerTest, Constructor) {
}
TEST_F(AuthManagerTest, CreateAccessToken) {
- EXPECT_EQ("UABRUHgcSZDry0bvIsoJv+WDQgEURQJjMjM0RgUaVArkgA==",
+ EXPECT_EQ("WC2FRggaG52hAEIBFEYJRDIzNABCCkBGBRobnaEAUFAF46oQlMmXgnLstt7wU2w=",
Base64Encode(auth_.CreateAccessToken(
- UserInfo{AuthScope::kViewer, "234"}, {})));
- EXPECT_EQ("UL7YEruLg5QQRDIp2+u1cqCDQgEIRQJjMjU3RgUaVArkgA==",
+ UserInfo{AuthScope::kViewer, TestUserId{"234"}}, {})));
+ EXPECT_EQ("WC2FRggaG52hAEIBCEYJRDI1NwBCCkBGBRobnaEAUEdWRNHcu/0mA6c3e0tgDrk=",
Base64Encode(auth_.CreateAccessToken(
- UserInfo{AuthScope::kManager, "257"}, {})));
- EXPECT_EQ("UPFGeZRanR1wLGYLP5ZDkXiDQgECRQJjNDU2RgUaVArkgA==",
+ UserInfo{AuthScope::kManager, TestUserId{"257"}}, {})));
+ EXPECT_EQ("WC2FRggaG52hAEIBAkYJRDQ1NgBCCkBGBRobnaEAUH2ZLgUPdTtjNRa+PoDkMW4=",
Base64Encode(auth_.CreateAccessToken(
- UserInfo{AuthScope::kOwner, "456"}, {})));
+ UserInfo{AuthScope::kOwner, TestUserId{"456"}}, {})));
auto new_time = clock_.Now() + base::TimeDelta::FromDays(11);
EXPECT_CALL(clock_, Now()).WillRepeatedly(Return(new_time));
- EXPECT_EQ("UMm9KlF3OEtZFBmhScJpl4uDQgEORQJjMzQ1RgUaVBllAA==",
+ EXPECT_EQ("WC2FRggaG6whgEIBDkYJRDM0NQBCCkBGBRobrCGAUDAFptj7bbYmbpaa6Wpb1Wo=",
Base64Encode(auth_.CreateAccessToken(
- UserInfo{AuthScope::kUser, "345"}, {})));
+ UserInfo{AuthScope::kUser, TestUserId{"345"}}, {})));
}
TEST_F(AuthManagerTest, CreateSameToken) {
- EXPECT_EQ(auth_.CreateAccessToken(UserInfo{AuthScope::kViewer, "555"}, {}),
- auth_.CreateAccessToken(UserInfo{AuthScope::kViewer, "555"}, {}));
+ EXPECT_EQ(auth_.CreateAccessToken(
+ UserInfo{AuthScope::kViewer, TestUserId{"555"}}, {}),
+ auth_.CreateAccessToken(
+ UserInfo{AuthScope::kViewer, TestUserId{"555"}}, {}));
+}
+
+TEST_F(AuthManagerTest, CreateSameTokenWithApp) {
+ EXPECT_EQ(auth_.CreateAccessToken(
+ UserInfo{AuthScope::kViewer,
+ {AuthType::kLocal, {1, 2, 3}, {4, 5, 6}}},
+ {}),
+ auth_.CreateAccessToken(
+ UserInfo{AuthScope::kViewer,
+ {AuthType::kLocal, {1, 2, 3}, {4, 5, 6}}},
+ {}));
+}
+
+TEST_F(AuthManagerTest, CreateSameTokenWithDifferentType) {
+ EXPECT_NE(auth_.CreateAccessToken(
+ UserInfo{AuthScope::kViewer,
+ {AuthType::kLocal, {1, 2, 3}, {4, 5, 6}}},
+ {}),
+ auth_.CreateAccessToken(
+ UserInfo{AuthScope::kViewer,
+ {AuthType::kPairing, {1, 2, 3}, {4, 5, 6}}},
+ {}));
+}
+
+TEST_F(AuthManagerTest, CreateSameTokenWithDifferentApp) {
+ EXPECT_NE(auth_.CreateAccessToken(
+ UserInfo{AuthScope::kViewer,
+ {AuthType::kLocal, {1, 2, 3}, {4, 5, 6}}},
+ {}),
+ auth_.CreateAccessToken(
+ UserInfo{AuthScope::kViewer,
+ {AuthType::kLocal, {1, 2, 3}, {4, 5, 7}}},
+ {}));
}
TEST_F(AuthManagerTest, CreateTokenDifferentScope) {
- EXPECT_NE(auth_.CreateAccessToken(UserInfo{AuthScope::kViewer, "456"}, {}),
- auth_.CreateAccessToken(UserInfo{AuthScope::kOwner, "456"}, {}));
+ EXPECT_NE(auth_.CreateAccessToken(
+ UserInfo{AuthScope::kViewer, TestUserId{"456"}}, {}),
+ auth_.CreateAccessToken(
+ UserInfo{AuthScope::kOwner, TestUserId{"456"}}, {}));
}
TEST_F(AuthManagerTest, CreateTokenDifferentUser) {
- EXPECT_NE(auth_.CreateAccessToken(UserInfo{AuthScope::kOwner, "456"}, {}),
- auth_.CreateAccessToken(UserInfo{AuthScope::kOwner, "789"}, {}));
+ EXPECT_NE(auth_.CreateAccessToken(
+ UserInfo{AuthScope::kOwner, TestUserId{"456"}}, {}),
+ auth_.CreateAccessToken(
+ UserInfo{AuthScope::kOwner, TestUserId{"789"}}, {}));
}
TEST_F(AuthManagerTest, CreateTokenDifferentTime) {
- auto token = auth_.CreateAccessToken(UserInfo{AuthScope::kOwner, "567"}, {});
+ auto token = auth_.CreateAccessToken(
+ UserInfo{AuthScope::kOwner, TestUserId{"567"}}, {});
EXPECT_CALL(clock_, Now())
.WillRepeatedly(Return(base::Time::FromTimeT(1400000000)));
- EXPECT_NE(token,
- auth_.CreateAccessToken(UserInfo{AuthScope::kOwner, "567"}, {}));
+ EXPECT_NE(token, auth_.CreateAccessToken(
+ UserInfo{AuthScope::kOwner, TestUserId{"567"}}, {}));
}
TEST_F(AuthManagerTest, CreateTokenDifferentInstance) {
- EXPECT_NE(auth_.CreateAccessToken(UserInfo{AuthScope::kUser, "123"}, {}),
+ EXPECT_NE(auth_.CreateAccessToken(
+ UserInfo{AuthScope::kUser, TestUserId{"123"}}, {}),
AuthManager({}, {}).CreateAccessToken(
- UserInfo{AuthScope::kUser, "123"}, {}));
+ UserInfo{AuthScope::kUser, TestUserId{"123"}}, {}));
}
TEST_F(AuthManagerTest, ParseAccessToken) {
@@ -117,18 +164,24 @@ TEST_F(AuthManagerTest, ParseAccessToken) {
AuthManager auth{{}, {}, {}, &clock_};
- auto token = auth.CreateAccessToken(UserInfo{AuthScope::kUser, "5"},
- base::TimeDelta::FromSeconds(i));
+ auto token =
+ auth.CreateAccessToken(UserInfo{AuthScope::kUser, TestUserId{"5"}},
+ base::TimeDelta::FromSeconds(i));
UserInfo user_info;
EXPECT_FALSE(auth_.ParseAccessToken(token, &user_info, nullptr));
EXPECT_TRUE(auth.ParseAccessToken(token, &user_info, nullptr));
EXPECT_EQ(AuthScope::kUser, user_info.scope());
- EXPECT_EQ("5", user_info.user_id());
+ EXPECT_EQ(TestUserId{"5"}, user_info.id());
EXPECT_CALL(clock_, Now())
.WillRepeatedly(Return(kStartTime + base::TimeDelta::FromSeconds(i)));
EXPECT_TRUE(auth.ParseAccessToken(token, &user_info, nullptr));
+ auto extended =
+ DelegateToUser(token, base::TimeDelta::FromSeconds(1000),
+ UserInfo{AuthScope::kUser, TestUserId{"234"}});
+ EXPECT_FALSE(auth.ParseAccessToken(extended, &user_info, nullptr));
+
EXPECT_CALL(clock_, Now())
.WillRepeatedly(
Return(kStartTime + base::TimeDelta::FromSeconds(i + 1)));
@@ -137,35 +190,135 @@ TEST_F(AuthManagerTest, ParseAccessToken) {
}
TEST_F(AuthManagerTest, GetRootClientAuthToken) {
- EXPECT_EQ("UK1ACOc3cWGjGBoTIX2bd3qCQgECRgMaVArkgA==",
- Base64Encode(auth_.GetRootClientAuthToken()));
+ EXPECT_EQ("WCCDQxkgAUYIGhudoQBCDEBQZgRhYq78I8GtFUZHNBbfGw==",
+ Base64Encode(
+ auth_.GetRootClientAuthToken(RootClientTokenOwner::kClient)));
+}
+
+TEST_F(AuthManagerTest, GetRootClientAuthTokenDifferentOwner) {
+ EXPECT_EQ(
+ "WCqDQxkgAUYIGhudoQBMDEpnb29nbGUuY29tUOoLAxSUAZAAv54drarqhag=",
+ Base64Encode(auth_.GetRootClientAuthToken(RootClientTokenOwner::kCloud)));
}
TEST_F(AuthManagerTest, GetRootClientAuthTokenDifferentTime) {
auto new_time = clock_.Now() + base::TimeDelta::FromDays(15);
EXPECT_CALL(clock_, Now()).WillRepeatedly(Return(new_time));
- EXPECT_EQ("UBpNF8g/GbNUmAyHg1qqJr+CQgECRgMaVB6rAA==",
- Base64Encode(auth_.GetRootClientAuthToken()));
+ EXPECT_EQ("WCCDQxkgAUYIGhuxZ4BCDEBQjO+OTbjjTzZ/Dvk66nfQqg==",
+ Base64Encode(
+ auth_.GetRootClientAuthToken(RootClientTokenOwner::kClient)));
}
TEST_F(AuthManagerTest, GetRootClientAuthTokenDifferentSecret) {
AuthManager auth{kSecret2, {}, kSecret1, &clock_};
- EXPECT_EQ("UFTBUcgd9d0HnPRnLeroN2mCQgECRgMaVArkgA==",
- Base64Encode(auth.GetRootClientAuthToken()));
+ EXPECT_EQ(
+ "WCCDQxkgAUYIGhudoQBCDEBQ2MZF8YXv5pbtmMxwz9VtLA==",
+ Base64Encode(auth.GetRootClientAuthToken(RootClientTokenOwner::kClient)));
}
TEST_F(AuthManagerTest, IsValidAuthToken) {
- EXPECT_TRUE(auth_.IsValidAuthToken(auth_.GetRootClientAuthToken(), nullptr));
+ EXPECT_TRUE(auth_.IsValidAuthToken(
+ auth_.GetRootClientAuthToken(RootClientTokenOwner::kClient), nullptr));
// Multiple attempts with random secrets.
for (size_t i = 0; i < 1000; ++i) {
AuthManager auth{{}, {}, {}, &clock_};
- auto token = auth.GetRootClientAuthToken();
+ auto token = auth.GetRootClientAuthToken(RootClientTokenOwner::kClient);
EXPECT_FALSE(auth_.IsValidAuthToken(token, nullptr));
EXPECT_TRUE(auth.IsValidAuthToken(token, nullptr));
}
}
+TEST_F(AuthManagerTest, CreateSessionId) {
+ EXPECT_EQ("463315200:1", auth_.CreateSessionId());
+}
+
+TEST_F(AuthManagerTest, IsValidSessionId) {
+ EXPECT_TRUE(auth_.IsValidSessionId("463315200:1"));
+ EXPECT_TRUE(auth_.IsValidSessionId("463315200:2"));
+ EXPECT_TRUE(auth_.IsValidSessionId("463315150"));
+
+ // Future
+ EXPECT_FALSE(auth_.IsValidSessionId("463315230:1"));
+
+ // Expired
+ EXPECT_FALSE(auth_.IsValidSessionId("463315100:1"));
+}
+
+TEST_F(AuthManagerTest, CreateAccessTokenFromAuth) {
+ std::vector<uint8_t> access_token;
+ AuthScope scope;
+ base::TimeDelta ttl;
+ auto root = auth_.GetRootClientAuthToken(RootClientTokenOwner::kCloud);
+ auto extended = DelegateToUser(root, base::TimeDelta::FromSeconds(1000),
+ UserInfo{AuthScope::kUser, TestUserId{"234"}});
+ EXPECT_EQ(
+ "WE+IQxkgAUYIGhudoQBMDEpnb29nbGUuY29tRggaG52hAEYFGhudpOhCAQ5FCUMyMzRNEUs0"
+ "NjMzMTUyMDA6MVCRVKU+0SpOoBppnwqdKMwP",
+ Base64Encode(extended));
+ EXPECT_TRUE(
+ auth_.CreateAccessTokenFromAuth(extended, base::TimeDelta::FromDays(1),
+ &access_token, &scope, &ttl, nullptr));
+ UserInfo user_info;
+ EXPECT_TRUE(auth_.ParseAccessToken(access_token, &user_info, nullptr));
+ EXPECT_EQ(scope, user_info.scope());
+ EXPECT_EQ(AuthScope::kUser, user_info.scope());
+
+ EXPECT_EQ(TestUserId{"234"}, user_info.id());
+}
+
+TEST_F(AuthManagerTest, CreateAccessTokenFromAuthNotMinted) {
+ std::vector<uint8_t> access_token;
+ auto root = auth_.GetRootClientAuthToken(RootClientTokenOwner::kClient);
+ ErrorPtr error;
+ EXPECT_FALSE(auth_.CreateAccessTokenFromAuth(
+ root, base::TimeDelta::FromDays(1), nullptr, nullptr, nullptr, &error));
+ EXPECT_TRUE(error->HasError("invalidAuthCode"));
+}
+
+TEST_F(AuthManagerTest, CreateAccessTokenFromAuthValidateAfterSomeTime) {
+ auto root = auth_.GetRootClientAuthToken(RootClientTokenOwner::kClient);
+ auto extended = DelegateToUser(root, base::TimeDelta::FromSeconds(1000),
+ UserInfo{AuthScope::kUser, TestUserId{"234"}});
+
+ // new_time < session_id_expiration < token_expiration.
+ auto new_time = clock_.Now() + base::TimeDelta::FromSeconds(15);
+ EXPECT_CALL(clock_, Now()).WillRepeatedly(Return(new_time));
+ EXPECT_TRUE(
+ auth_.CreateAccessTokenFromAuth(extended, base::TimeDelta::FromDays(1),
+ nullptr, nullptr, nullptr, nullptr));
+}
+
+TEST_F(AuthManagerTest, CreateAccessTokenFromAuthExpired) {
+ auto root = auth_.GetRootClientAuthToken(RootClientTokenOwner::kClient);
+ auto extended = DelegateToUser(root, base::TimeDelta::FromSeconds(10),
+ UserInfo{AuthScope::kUser, TestUserId{"234"}});
+ ErrorPtr error;
+
+ // token_expiration < new_time < session_id_expiration.
+ auto new_time = clock_.Now() + base::TimeDelta::FromSeconds(15);
+ EXPECT_CALL(clock_, Now()).WillRepeatedly(Return(new_time));
+ EXPECT_FALSE(
+ auth_.CreateAccessTokenFromAuth(extended, base::TimeDelta::FromDays(1),
+ nullptr, nullptr, nullptr, &error));
+ EXPECT_TRUE(error->HasError("invalidAuthCode"));
+}
+
+TEST_F(AuthManagerTest, CreateAccessTokenFromAuthExpiredSessionid) {
+ auto root = auth_.GetRootClientAuthToken(RootClientTokenOwner::kClient);
+ auto extended = DelegateToUser(root, base::TimeDelta::FromSeconds(1000),
+ UserInfo{AuthScope::kUser, TestUserId{"234"}});
+ ErrorPtr error;
+
+ // session_id_expiration < new_time < token_expiration.
+ auto new_time = clock_.Now() + base::TimeDelta::FromSeconds(200);
+ EXPECT_CALL(clock_, Now()).WillRepeatedly(Return(new_time));
+ EXPECT_FALSE(
+ auth_.CreateAccessTokenFromAuth(extended, base::TimeDelta::FromDays(1),
+ nullptr, nullptr, nullptr, &error));
+ EXPECT_TRUE(error->HasError("invalidAuthCode"));
+}
+
class AuthManagerClaimTest : public testing::Test {
public:
void SetUp() override { EXPECT_EQ(auth_.GetAuthSecret().size(), 32u); }
@@ -241,18 +394,5 @@ TEST_F(AuthManagerClaimTest, TokenOverflow) {
EXPECT_FALSE(auth_.ConfirmClientAuthToken(token, nullptr));
}
-TEST_F(AuthManagerClaimTest, CreateAccessTokenFromAuth) {
- std::vector<uint8_t> access_token;
- AuthScope scope;
- base::TimeDelta ttl;
- EXPECT_TRUE(auth_.CreateAccessTokenFromAuth(
- auth_.GetRootClientAuthToken(), base::TimeDelta::FromDays(1),
- &access_token, &scope, &ttl, nullptr));
- UserInfo user_info;
- EXPECT_TRUE(auth_.ParseAccessToken(access_token, &user_info, nullptr));
- EXPECT_EQ(scope, user_info.scope());
- EXPECT_FALSE(user_info.user_id().empty());
-}
-
} // namespace privet
} // namespace weave
diff --git a/src/privet/cloud_delegate.cc b/src/privet/cloud_delegate.cc
index 5f31fee..49fceaa 100644
--- a/src/privet/cloud_delegate.cc
+++ b/src/privet/cloud_delegate.cc
@@ -165,7 +165,7 @@ class CloudDelegateImpl : public CloudDelegate {
const UserInfo& user_info,
const CommandDoneCallback& callback) override {
CHECK(user_info.scope() != AuthScope::kNone);
- CHECK(!user_info.user_id().empty());
+ CHECK(!user_info.id().IsEmpty());
ErrorPtr error;
UserRole role;
@@ -182,7 +182,7 @@ class CloudDelegateImpl : public CloudDelegate {
if (!command_instance)
return callback.Run({}, std::move(error));
component_manager_->AddCommand(std::move(command_instance));
- command_owners_[id] = user_info.user_id();
+ command_owners_[id] = user_info.id();
callback.Run(*component_manager_->FindCommand(id)->ToJson(), nullptr);
}
@@ -230,7 +230,7 @@ class CloudDelegateImpl : public CloudDelegate {
private:
void OnCommandAdded(Command* command) {
// Set to "" for any new unknown command.
- command_owners_.insert(std::make_pair(command->GetID(), ""));
+ command_owners_.insert(std::make_pair(command->GetID(), UserAppId{}));
}
void OnCommandRemoved(Command* command) {
@@ -309,14 +309,17 @@ class CloudDelegateImpl : public CloudDelegate {
return command;
}
- bool CanAccessCommand(const std::string& owner_id,
+ bool CanAccessCommand(const UserAppId& owner,
const UserInfo& user_info,
ErrorPtr* error) const {
CHECK(user_info.scope() != AuthScope::kNone);
- CHECK(!user_info.user_id().empty());
+ CHECK(!user_info.id().IsEmpty());
if (user_info.scope() == AuthScope::kManager ||
- owner_id == user_info.user_id()) {
+ (owner.type == user_info.id().type &&
+ owner.user == user_info.id().user &&
+ (user_info.id().app.empty() || // Token is not restricted to the app.
+ owner.app == user_info.id().app))) {
return true;
}
@@ -341,7 +344,7 @@ class CloudDelegateImpl : public CloudDelegate {
int registation_retry_count_{0};
// Map of command IDs to user IDs.
- std::map<std::string, std::string> command_owners_;
+ std::map<std::string, UserAppId> command_owners_;
// Backoff entry for retrying device registration.
BackoffEntry backoff_entry_{&register_backoff_policy};
diff --git a/src/privet/mock_delegates.h b/src/privet/mock_delegates.h
index c75d438..c2e9a89 100644
--- a/src/privet/mock_delegates.h
+++ b/src/privet/mock_delegates.h
@@ -28,6 +28,11 @@ namespace weave {
namespace privet {
+struct TestUserId : public UserAppId {
+ TestUserId(const std::string& user_id)
+ : UserAppId{AuthType::kAnonymous, {user_id.begin(), user_id.end()}, {}} {}
+};
+
ACTION_TEMPLATE(RunCallback,
HAS_1_TEMPLATE_PARAMS(int, k),
AND_0_VALUE_PARAMS()) {
@@ -103,9 +108,12 @@ class MockSecurityDelegate : public SecurityDelegate {
.WillRepeatedly(Return(true));
EXPECT_CALL(*this, ParseAccessToken(_, _, _))
- .WillRepeatedly(
- DoAll(SetArgPointee<1>(UserInfo{AuthScope::kViewer, "1234567"}),
- Return(true)));
+ .WillRepeatedly(DoAll(SetArgPointee<1>(UserInfo{
+ AuthScope::kViewer,
+ UserAppId{AuthType::kLocal,
+ {'1', '2', '3', '4', '5', '6', '7'},
+ {}}}),
+ Return(true)));
EXPECT_CALL(*this, GetPairingTypes())
.WillRepeatedly(Return(std::set<PairingType>{
diff --git a/src/privet/openssl_utils.cc b/src/privet/openssl_utils.cc
index f38fd1a..17ebf70 100644
--- a/src/privet/openssl_utils.cc
+++ b/src/privet/openssl_utils.cc
@@ -18,13 +18,9 @@ namespace privet {
std::vector<uint8_t> HmacSha256(const std::vector<uint8_t>& key,
const std::vector<uint8_t>& data) {
std::vector<uint8_t> mac(kSha256OutputSize);
- uint8_t hmac_state[uw_crypto_hmac_required_buffer_size_()];
- CHECK(uw_crypto_hmac_init_(hmac_state, sizeof(hmac_state), key.data(),
- key.size()));
- CHECK(uw_crypto_hmac_update_(hmac_state, sizeof(hmac_state), data.data(),
- data.size()));
- CHECK(uw_crypto_hmac_final_(hmac_state, sizeof(hmac_state), mac.data(),
- mac.size()));
+ const UwCryptoHmacMsg messages[] = {{data.data(), data.size()}};
+ CHECK(uw_crypto_hmac_(key.data(), key.size(), messages, arraysize(messages),
+ mac.data(), mac.size()));
return mac;
}
diff --git a/src/privet/privet_handler_unittest.cc b/src/privet/privet_handler_unittest.cc
index fa79e77..20f5aa0 100644
--- a/src/privet/privet_handler_unittest.cc
+++ b/src/privet/privet_handler_unittest.cc
@@ -484,7 +484,8 @@ class PrivetHandlerTestWithAuth : public PrivetHandlerTest {
auth_header_ = "Privet 123";
EXPECT_CALL(security_, ParseAccessToken(_, _, _))
.WillRepeatedly(DoAll(
- SetArgPointee<1>(UserInfo{AuthScope::kOwner, "1"}), Return(true)));
+ SetArgPointee<1>(UserInfo{AuthScope::kOwner, TestUserId{"1"}}),
+ Return(true)));
}
};
@@ -658,7 +659,8 @@ TEST_F(PrivetHandlerSetupTest, GcdSetup) {
TEST_F(PrivetHandlerSetupTest, GcdSetupAsMaster) {
EXPECT_CALL(security_, ParseAccessToken(_, _, _))
.WillRepeatedly(DoAll(
- SetArgPointee<1>(UserInfo{AuthScope::kManager, "1"}), Return(true)));
+ SetArgPointee<1>(UserInfo{AuthScope::kManager, TestUserId{"1"}}),
+ Return(true)));
const char kInput[] = R"({
'gcd': {
'ticketId': 'testTicket',
diff --git a/src/privet/privet_types.h b/src/privet/privet_types.h
index 49c4522..0f51862 100644
--- a/src/privet/privet_types.h
+++ b/src/privet/privet_types.h
@@ -29,17 +29,42 @@ enum class WifiType {
kWifi50,
};
+struct UserAppId {
+ UserAppId() = default;
+
+ UserAppId(AuthType auth_type,
+ const std::vector<uint8_t>& user_id,
+ const std::vector<uint8_t>& app_id)
+ : type{auth_type},
+ user{user_id},
+ app{user_id.empty() ? user_id : app_id} {}
+
+ bool IsEmpty() const { return user.empty(); }
+
+ AuthType type{};
+ std::vector<uint8_t> user;
+ std::vector<uint8_t> app;
+};
+
+inline bool operator==(const UserAppId& l, const UserAppId& r) {
+ return l.user == r.user && l.app == r.app;
+}
+
+inline bool operator!=(const UserAppId& l, const UserAppId& r) {
+ return l.user != r.user || l.app != r.app;
+}
+
class UserInfo {
public:
explicit UserInfo(AuthScope scope = AuthScope::kNone,
- const std::string& user_id = {})
- : scope_{scope}, user_id_{scope == AuthScope::kNone ? "" : user_id} {}
+ const UserAppId& id = {})
+ : scope_{scope}, id_{scope == AuthScope::kNone ? UserAppId{} : id} {}
AuthScope scope() const { return scope_; }
- const std::string& user_id() const { return user_id_; }
+ const UserAppId& id() const { return id_; }
private:
AuthScope scope_;
- std::string user_id_;
+ UserAppId id_;
};
class ConnectionState final {
diff --git a/src/privet/security_manager.cc b/src/privet/security_manager.cc
index 0f00699..3b08613 100644
--- a/src/privet/security_manager.cc
+++ b/src/privet/security_manager.cc
@@ -91,9 +91,10 @@ bool SecurityManager::CreateAccessTokenImpl(AuthType auth_type,
std::vector<uint8_t>* access_token,
AuthScope* access_token_scope,
base::TimeDelta* access_token_ttl) {
- UserInfo user_info{desired_scope,
- std::to_string(static_cast<int>(auth_type)) + "/" +
- std::to_string(++last_user_id_)};
+ auto user_id = std::to_string(++last_user_id_);
+ UserInfo user_info{
+ desired_scope,
+ UserAppId{auth_type, {user_id.begin(), user_id.end()}, {}}};
const base::TimeDelta kTtl =
base::TimeDelta::FromSeconds(kAccessTokenExpirationSeconds);
@@ -388,7 +389,7 @@ bool SecurityManager::CancelPairing(const std::string& session_id,
}
std::string SecurityManager::CreateSessionId() {
- return Base64Encode(auth_manager_->CreateSessionId());
+ return auth_manager_->CreateSessionId();
}
void SecurityManager::RegisterPairingListeners(
diff --git a/src/privet/security_manager_unittest.cc b/src/privet/security_manager_unittest.cc
index 43b7f00..f596de9 100644
--- a/src/privet/security_manager_unittest.cc
+++ b/src/privet/security_manager_unittest.cc
@@ -25,6 +25,7 @@
#include "src/config.h"
#include "src/data_encoding.h"
#include "src/privet/auth_manager.h"
+#include "src/privet/mock_delegates.h"
#include "src/privet/openssl_utils.h"
#include "src/test/mock_clock.h"
#include "third_party/chromium/crypto/p224_spake.h"
@@ -170,7 +171,7 @@ TEST_F(SecurityManagerTest, AccessToken) {
UserInfo info;
EXPECT_TRUE(security_.ParseAccessToken(token, &info, nullptr));
EXPECT_EQ(requested_scope, info.scope());
- EXPECT_EQ("0/" + std::to_string(i), info.user_id());
+ EXPECT_EQ(TestUserId{std::to_string(i)}, info.id());
}
}