diff options
Diffstat (limited to 'src')
29 files changed, 1599 insertions, 207 deletions
diff --git a/src/access_api_handler.cc b/src/access_api_handler.cc new file mode 100644 index 0000000..7c39b20 --- /dev/null +++ b/src/access_api_handler.cc @@ -0,0 +1,227 @@ +// Copyright 2016 The Weave Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "src/access_api_handler.h" + +#include <base/bind.h> +#include <weave/device.h> + +#include "src/access_black_list_manager.h" +#include "src/commands/schema_constants.h" +#include "src/data_encoding.h" +#include "src/json_error_codes.h" + +namespace weave { + +namespace { + +const char kComponent[] = "accessControl"; +const char kTrait[] = "_accessControlBlackList"; +const char kStateSize[] = "_accessControlBlackList.size"; +const char kStateCapacity[] = "_accessControlBlackList.capacity"; +const char kUserId[] = "userId"; +const char kApplicationId[] = "applicationId"; +const char kExpirationTimeout[] = "expirationTimeoutSec"; +const char kBlackList[] = "blackList"; + +bool GetIds(const base::DictionaryValue& parameters, + std::vector<uint8_t>* user_id_decoded, + std::vector<uint8_t>* app_id_decoded, + ErrorPtr* error) { + std::string user_id; + parameters.GetString(kUserId, &user_id); + if (!Base64Decode(user_id, user_id_decoded)) { + Error::AddToPrintf(error, FROM_HERE, errors::commands::kInvalidPropValue, + "Invalid user id '%s'", user_id.c_str()); + return false; + } + + std::string app_id; + parameters.GetString(kApplicationId, &app_id); + if (!Base64Decode(app_id, app_id_decoded)) { + Error::AddToPrintf(error, FROM_HERE, errors::commands::kInvalidPropValue, + "Invalid app id '%s'", user_id.c_str()); + return false; + } + + return true; +} + +} // namespace + +AccessApiHandler::AccessApiHandler(Device* device, + AccessBlackListManager* manager) + : device_{device}, manager_{manager} { + device_->AddTraitDefinitionsFromJson(R"({ + "_accessControlBlackList": { + "commands": { + "block": { + "minimalRole": "owner", + "parameters": { + "userId": { + "type": "string" + }, + "applicationId": { + "type": "string" + }, + "expirationTimeoutSec": { + "type": "integer" + } + } + }, + "unblock": { + "minimalRole": "owner", + "parameters": { + "userId": { + "type": "string" + }, + "applicationId": { + "type": "string" + } + } + }, + "list": { + "minimalRole": "owner", + "parameters": {}, + "results": { + "blackList": { + "type": "array", + "items": { + "type": "object", + "properties": { + "userId": { + "type": "string" + }, + "applicationId": { + "type": "string" + } + }, + "additionalProperties": false + } + } + } + } + }, + "state": { + "size": { + "type": "integer", + "isRequired": true + }, + "capacity": { + "type": "integer", + "isRequired": true + } + } + } + })"); + CHECK(device_->AddComponent(kComponent, {kTrait}, nullptr)); + UpdateState(); + + device_->AddCommandHandler( + kComponent, "_accessControlBlackList.block", + base::Bind(&AccessApiHandler::Block, weak_ptr_factory_.GetWeakPtr())); + device_->AddCommandHandler( + kComponent, "_accessControlBlackList.unblock", + base::Bind(&AccessApiHandler::Unblock, weak_ptr_factory_.GetWeakPtr())); + device_->AddCommandHandler( + kComponent, "_accessControlBlackList.list", + base::Bind(&AccessApiHandler::List, weak_ptr_factory_.GetWeakPtr())); +} + +void AccessApiHandler::Block(const std::weak_ptr<Command>& cmd) { + auto command = cmd.lock(); + if (!command) + return; + + CHECK(command->GetState() == Command::State::kQueued) + << EnumToString(command->GetState()); + command->SetProgress(base::DictionaryValue{}, nullptr); + + const auto& parameters = command->GetParameters(); + std::vector<uint8_t> user_id; + std::vector<uint8_t> app_id; + ErrorPtr error; + if (!GetIds(parameters, &user_id, &app_id, &error)) { + command->Abort(error.get(), nullptr); + return; + } + + int timeout_sec = 0; + parameters.GetInteger(kExpirationTimeout, &timeout_sec); + + base::Time expiration = + base::Time::Now() + base::TimeDelta::FromSeconds(timeout_sec); + + manager_->Block(user_id, app_id, expiration, + base::Bind(&AccessApiHandler::OnCommandDone, + weak_ptr_factory_.GetWeakPtr(), cmd)); +} + +void AccessApiHandler::Unblock(const std::weak_ptr<Command>& cmd) { + auto command = cmd.lock(); + if (!command) + return; + + CHECK(command->GetState() == Command::State::kQueued) + << EnumToString(command->GetState()); + command->SetProgress(base::DictionaryValue{}, nullptr); + + const auto& parameters = command->GetParameters(); + std::vector<uint8_t> user_id; + std::vector<uint8_t> app_id; + ErrorPtr error; + if (!GetIds(parameters, &user_id, &app_id, &error)) { + command->Abort(error.get(), nullptr); + return; + } + + manager_->Unblock(user_id, app_id, + base::Bind(&AccessApiHandler::OnCommandDone, + weak_ptr_factory_.GetWeakPtr(), cmd)); +} + +void AccessApiHandler::List(const std::weak_ptr<Command>& cmd) { + auto command = cmd.lock(); + if (!command) + return; + + CHECK(command->GetState() == Command::State::kQueued) + << EnumToString(command->GetState()); + command->SetProgress(base::DictionaryValue{}, nullptr); + + std::unique_ptr<base::ListValue> entries{new base::ListValue}; + for (const auto& e : manager_->GetEntries()) { + std::unique_ptr<base::DictionaryValue> entry{new base::DictionaryValue}; + entry->SetString(kUserId, Base64Encode(e.user_id)); + entry->SetString(kApplicationId, Base64Encode(e.app_id)); + entries->Append(entry.release()); + } + + base::DictionaryValue result; + result.Set(kBlackList, entries.release()); + + command->Complete(result, nullptr); +} + +void AccessApiHandler::OnCommandDone(const std::weak_ptr<Command>& cmd, + ErrorPtr error) { + auto command = cmd.lock(); + if (!command) + return; + UpdateState(); + if (error) { + command->Abort(error.get(), nullptr); + return; + } + command->Complete({}, nullptr); +} + +void AccessApiHandler::UpdateState() { + base::DictionaryValue state; + state.SetInteger(kStateSize, manager_->GetSize()); + state.SetInteger(kStateCapacity, manager_->GetCapacity()); + device_->SetStateProperties(kComponent, state, nullptr); +} + +} // namespace weave diff --git a/src/access_api_handler.h b/src/access_api_handler.h new file mode 100644 index 0000000..821ce02 --- /dev/null +++ b/src/access_api_handler.h @@ -0,0 +1,47 @@ +// Copyright 2016 The Weave Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef LIBWEAVE_SRC_ACCESS_API_HANDLER_H_ +#define LIBWEAVE_SRC_ACCESS_API_HANDLER_H_ + +#include <memory> + +#include <base/memory/weak_ptr.h> +#include <weave/error.h> + +namespace weave { + +class AccessBlackListManager; +class Command; +class Device; + +// Handles commands for 'accessControlBlackList' trait. +// Objects of the class subscribe for notification from CommandManager and +// execute incoming commands. +// Handled commands: +// accessControlBlackList.block +// accessControlBlackList.unblock +// accessControlBlackList.list +class AccessApiHandler final { + public: + AccessApiHandler(Device* device, AccessBlackListManager* manager); + + private: + void Block(const std::weak_ptr<Command>& command); + void Unblock(const std::weak_ptr<Command>& command); + void List(const std::weak_ptr<Command>& command); + void UpdateState(); + + void OnCommandDone(const std::weak_ptr<Command>& command, ErrorPtr error); + + Device* device_{nullptr}; + AccessBlackListManager* manager_{nullptr}; + + base::WeakPtrFactory<AccessApiHandler> weak_ptr_factory_{this}; + DISALLOW_COPY_AND_ASSIGN(AccessApiHandler); +}; + +} // namespace weave + +#endif // LIBWEAVE_SRC_ACCESS_API_HANDLER_H_ diff --git a/src/access_api_handler_unittest.cc b/src/access_api_handler_unittest.cc new file mode 100644 index 0000000..3e7f5d7 --- /dev/null +++ b/src/access_api_handler_unittest.cc @@ -0,0 +1,259 @@ +// Copyright 2016 The Weave Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "src/access_api_handler.h" + +#include <gtest/gtest.h> +#include <weave/provider/test/fake_task_runner.h> +#include <weave/test/mock_device.h> +#include <weave/test/unittest_utils.h> + +#include "src/component_manager_impl.h" +#include "src/access_black_list_manager.h" +#include "src/data_encoding.h" + +using testing::_; +using testing::AnyOf; +using testing::Invoke; +using testing::Return; +using testing::StrictMock; +using testing::WithArgs; + +namespace weave { + +class MockAccessBlackListManager : public AccessBlackListManager { + public: + MOCK_METHOD4(Block, + void(const std::vector<uint8_t>&, + const std::vector<uint8_t>&, + const base::Time&, + const DoneCallback&)); + MOCK_METHOD3(Unblock, + void(const std::vector<uint8_t>&, + const std::vector<uint8_t>&, + const DoneCallback&)); + MOCK_CONST_METHOD2(IsBlocked, + bool(const std::vector<uint8_t>&, + const std::vector<uint8_t>&)); + MOCK_CONST_METHOD0(GetEntries, std::vector<Entry>()); + MOCK_CONST_METHOD0(GetSize, size_t()); + MOCK_CONST_METHOD0(GetCapacity, size_t()); +}; + +class AccessApiHandlerTest : public ::testing::Test { + protected: + void SetUp() override { + EXPECT_CALL(device_, AddTraitDefinitionsFromJson(_)) + .WillRepeatedly(Invoke([this](const std::string& json) { + EXPECT_TRUE(component_manager_.LoadTraits(json, nullptr)); + })); + EXPECT_CALL(device_, SetStateProperties(_, _, _)) + .WillRepeatedly( + Invoke(&component_manager_, &ComponentManager::SetStateProperties)); + EXPECT_CALL(device_, SetStateProperty(_, _, _, _)) + .WillRepeatedly( + Invoke(&component_manager_, &ComponentManager::SetStateProperty)); + EXPECT_CALL(device_, AddComponent(_, _, _)) + .WillRepeatedly(Invoke([this](const std::string& name, + const std::vector<std::string>& traits, + ErrorPtr* error) { + return component_manager_.AddComponent("", name, traits, error); + })); + + EXPECT_CALL(device_, + AddCommandHandler(_, AnyOf("_accessControlBlackList.block", + "_accessControlBlackList.unblock", + "_accessControlBlackList.list"), + _)) + .WillRepeatedly( + Invoke(&component_manager_, &ComponentManager::AddCommandHandler)); + + EXPECT_CALL(access_manager_, GetSize()).WillRepeatedly(Return(0)); + + EXPECT_CALL(access_manager_, GetCapacity()).WillRepeatedly(Return(10)); + + handler_.reset(new AccessApiHandler{&device_, &access_manager_}); + } + + const base::DictionaryValue& AddCommand(const std::string& command) { + std::string id; + auto command_instance = component_manager_.ParseCommandInstance( + *test::CreateDictionaryValue(command.c_str()), Command::Origin::kLocal, + UserRole::kOwner, &id, nullptr); + EXPECT_NE(nullptr, command_instance.get()); + component_manager_.AddCommand(std::move(command_instance)); + EXPECT_EQ(Command::State::kDone, + component_manager_.FindCommand(id)->GetState()); + return component_manager_.FindCommand(id)->GetResults(); + } + + std::unique_ptr<base::DictionaryValue> GetState() { + std::string path = + component_manager_.FindComponentWithTrait("_accessControlBlackList"); + EXPECT_FALSE(path.empty()); + const auto* component = component_manager_.FindComponent(path, nullptr); + EXPECT_TRUE(component); + const base::DictionaryValue* state = nullptr; + EXPECT_TRUE( + component->GetDictionary("state._accessControlBlackList", &state)); + return std::unique_ptr<base::DictionaryValue>{state->DeepCopy()}; + } + + StrictMock<provider::test::FakeTaskRunner> task_runner_; + ComponentManagerImpl component_manager_{&task_runner_}; + StrictMock<test::MockDevice> device_; + StrictMock<MockAccessBlackListManager> access_manager_; + std::unique_ptr<AccessApiHandler> handler_; +}; + +TEST_F(AccessApiHandlerTest, Initialization) { + const base::DictionaryValue* trait = nullptr; + ASSERT_TRUE(component_manager_.GetTraits().GetDictionary( + "_accessControlBlackList", &trait)); + + auto expected = R"({ + "commands": { + "block": { + "minimalRole": "owner", + "parameters": { + "userId": { + "type": "string" + }, + "applicationId": { + "type": "string" + }, + "expirationTimeoutSec": { + "type": "integer" + } + } + }, + "unblock": { + "minimalRole": "owner", + "parameters": { + "userId": { + "type": "string" + }, + "applicationId": { + "type": "string" + } + } + }, + "list": { + "minimalRole": "owner", + "parameters": {}, + "results": { + "blackList": { + "type": "array", + "items": { + "type": "object", + "properties": { + "userId": { + "type": "string" + }, + "applicationId": { + "type": "string" + } + }, + "additionalProperties": false + } + } + } + } + }, + "state": { + "size": { + "type": "integer", + "isRequired": true + }, + "capacity": { + "type": "integer", + "isRequired": true + } + } + })"; + EXPECT_JSON_EQ(expected, *trait); + + expected = R"({ + "capacity": 10, + "size": 0 + })"; + EXPECT_JSON_EQ(expected, *GetState()); +} + +TEST_F(AccessApiHandlerTest, Block) { + EXPECT_CALL(access_manager_, Block(std::vector<uint8_t>{1, 2, 3}, + std::vector<uint8_t>{3, 4, 5}, _, _)) + .WillOnce(WithArgs<3>( + Invoke([](const DoneCallback& callback) { callback.Run(nullptr); }))); + EXPECT_CALL(access_manager_, GetSize()).WillRepeatedly(Return(1)); + + AddCommand(R"({ + 'name' : '_accessControlBlackList.block', + 'component': 'accessControl', + 'parameters': { + 'userId': 'AQID', + 'applicationId': 'AwQF', + 'expirationTimeoutSec': 1234 + } + })"); + + auto expected = R"({ + "capacity": 10, + "size": 1 + })"; + EXPECT_JSON_EQ(expected, *GetState()); +} + +TEST_F(AccessApiHandlerTest, Unblock) { + EXPECT_CALL(access_manager_, Unblock(std::vector<uint8_t>{1, 2, 3}, + std::vector<uint8_t>{3, 4, 5}, _)) + .WillOnce(WithArgs<2>( + Invoke([](const DoneCallback& callback) { callback.Run(nullptr); }))); + EXPECT_CALL(access_manager_, GetSize()).WillRepeatedly(Return(4)); + + AddCommand(R"({ + 'name' : '_accessControlBlackList.unblock', + 'component': 'accessControl', + 'parameters': { + 'userId': 'AQID', + 'applicationId': 'AwQF', + 'expirationTimeoutSec': 1234 + } + })"); + + auto expected = R"({ + "capacity": 10, + "size": 4 + })"; + EXPECT_JSON_EQ(expected, *GetState()); +} + +TEST_F(AccessApiHandlerTest, List) { + std::vector<AccessBlackListManager::Entry> entries{ + {{11, 12, 13}, {21, 22, 23}, base::Time::FromTimeT(1410000000)}, + {{31, 32, 33}, {41, 42, 43}, base::Time::FromTimeT(1420000000)}, + }; + EXPECT_CALL(access_manager_, GetEntries()).WillOnce(Return(entries)); + EXPECT_CALL(access_manager_, GetSize()).WillRepeatedly(Return(4)); + + auto expected = R"({ + "blackList": [ { + "applicationId": "FRYX", + "userId": "CwwN" + }, { + "applicationId": "KSor", + "userId": "HyAh" + } ] + })"; + + const auto& results = AddCommand(R"({ + 'name' : '_accessControlBlackList.list', + 'component': 'accessControl', + 'parameters': { + } + })"); + + EXPECT_JSON_EQ(expected, results); +} +} // namespace weave diff --git a/src/access_black_list_manager.h b/src/access_black_list_manager.h new file mode 100644 index 0000000..b56226a --- /dev/null +++ b/src/access_black_list_manager.h @@ -0,0 +1,56 @@ +// Copyright 2016 The Weave Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef LIBWEAVE_SRC_ACCESS_BLACK_LIST_H_ +#define LIBWEAVE_SRC_ACCESS_BLACK_LIST_H_ + +#include <vector> + +#include <base/time/time.h> + +namespace weave { + +class AccessBlackListManager { + public: + struct Entry { + // user_id is empty, app_id is empty: block everything. + // user_id is not empty, app_id is empty: block if user_id matches. + // user_id is empty, app_id is not empty: block if app_id matches. + // user_id is not empty, app_id is not empty: block if both match. + std::vector<uint8_t> user_id; + std::vector<uint8_t> app_id; + + // Time after which to discard the rule. + base::Time expiration; + }; + virtual ~AccessBlackListManager() = default; + + virtual void Block(const std::vector<uint8_t>& user_id, + const std::vector<uint8_t>& app_id, + const base::Time& expiration, + const DoneCallback& callback) = 0; + virtual void Unblock(const std::vector<uint8_t>& user_id, + const std::vector<uint8_t>& app_id, + const DoneCallback& callback) = 0; + virtual bool IsBlocked(const std::vector<uint8_t>& user_id, + const std::vector<uint8_t>& app_id) const = 0; + virtual std::vector<Entry> GetEntries() const = 0; + virtual size_t GetSize() const = 0; + virtual size_t GetCapacity() const = 0; +}; + +inline bool operator==(const AccessBlackListManager::Entry& l, + const AccessBlackListManager::Entry& r) { + return l.user_id == r.user_id && l.app_id == r.app_id && + l.expiration == r.expiration; +} + +inline bool operator!=(const AccessBlackListManager::Entry& l, + const AccessBlackListManager::Entry& r) { + return !(l == r); +} + +} // namespace weave + +#endif // LIBWEAVE_SRC_ACCESS_BLACK_LIST_H_ diff --git a/src/access_black_list_manager_impl.cc b/src/access_black_list_manager_impl.cc new file mode 100644 index 0000000..992a680 --- /dev/null +++ b/src/access_black_list_manager_impl.cc @@ -0,0 +1,163 @@ +// Copyright 2016 The Weave Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "src/access_black_list_manager_impl.h" + +#include <base/json/json_reader.h> +#include <base/json/json_writer.h> +#include <base/values.h> + +#include "src/commands/schema_constants.h" +#include "src/data_encoding.h" + +namespace weave { + +namespace { +const char kConfigFileName[] = "black_list"; + +const char kUser[] = "user"; +const char kApp[] = "app"; +const char kExpiration[] = "expiration"; +} + +AccessBlackListManagerImpl::AccessBlackListManagerImpl( + provider::ConfigStore* store, + size_t capacity, + base::Clock* clock) + : capacity_{capacity}, clock_{clock}, store_{store} { + Load(); +} + +void AccessBlackListManagerImpl::Load() { + if (!store_) + return; + if (auto list = base::ListValue::From( + base::JSONReader::Read(store_->LoadSettings(kConfigFileName)))) { + for (const auto& e : *list) { + const base::DictionaryValue* entry{nullptr}; + std::string user; + std::string app; + decltype(entries_)::key_type key; + int expiration; + if (e->GetAsDictionary(&entry) && entry->GetString(kUser, &user) && + Base64Decode(user, &key.first) && entry->GetString(kApp, &app) && + Base64Decode(app, &key.second) && + entry->GetInteger(kExpiration, &expiration)) { + base::Time expiration_time = base::Time::FromTimeT(expiration); + if (expiration_time > clock_->Now()) + entries_[key] = expiration_time; + } + } + if (entries_.size() < list->GetSize()) { + // Save some storage space by saving without expired entries. + Save({}); + } + } +} + +void AccessBlackListManagerImpl::Save(const DoneCallback& callback) { + if (!store_) { + if (!callback.is_null()) + callback.Run(nullptr); + return; + } + + base::ListValue list; + for (const auto& e : entries_) { + scoped_ptr<base::DictionaryValue> entry{new base::DictionaryValue}; + entry->SetString(kUser, Base64Encode(e.first.first)); + entry->SetString(kApp, Base64Encode(e.first.second)); + entry->SetInteger(kExpiration, e.second.ToTimeT()); + list.Append(std::move(entry)); + } + + std::string json; + base::JSONWriter::Write(list, &json); + store_->SaveSettings(kConfigFileName, json, callback); +} + +void AccessBlackListManagerImpl::RemoveExpired() { + for (auto i = begin(entries_); i != end(entries_);) { + if (i->second <= clock_->Now()) + i = entries_.erase(i); + else + ++i; + } +} + +void AccessBlackListManagerImpl::Block(const std::vector<uint8_t>& user_id, + const std::vector<uint8_t>& app_id, + const base::Time& expiration, + const DoneCallback& callback) { + // Iterating is OK as Save below is more expensive. + RemoveExpired(); + if (expiration <= clock_->Now()) { + if (!callback.is_null()) { + ErrorPtr error; + Error::AddTo(&error, FROM_HERE, "aleady_expired", + "Entry already expired"); + callback.Run(std::move(error)); + } + return; + } + if (entries_.size() >= capacity_) { + if (!callback.is_null()) { + ErrorPtr error; + Error::AddTo(&error, FROM_HERE, "blacklist_is_full", + "Unable to store more entries"); + callback.Run(std::move(error)); + } + return; + } + auto& value = entries_[std::make_pair(user_id, app_id)]; + value = std::max(value, expiration); + Save(callback); +} + +void AccessBlackListManagerImpl::Unblock(const std::vector<uint8_t>& user_id, + const std::vector<uint8_t>& app_id, + const DoneCallback& callback) { + if (!entries_.erase(std::make_pair(user_id, app_id))) { + if (!callback.is_null()) { + ErrorPtr error; + Error::AddTo(&error, FROM_HERE, "entry_not_found", "Unknown entry"); + callback.Run(std::move(error)); + } + return; + } + // Iterating is OK as Save below is more expensive. + RemoveExpired(); + Save(callback); +} + +bool AccessBlackListManagerImpl::IsBlocked( + const std::vector<uint8_t>& user_id, + const std::vector<uint8_t>& app_id) const { + for (const auto& user : {{}, user_id}) { + for (const auto& app : {{}, app_id}) { + auto both = entries_.find(std::make_pair(user, app)); + if (both != end(entries_) && both->second > clock_->Now()) + return true; + } + } + return false; +} + +std::vector<AccessBlackListManager::Entry> +AccessBlackListManagerImpl::GetEntries() const { + std::vector<Entry> result; + for (const auto& e : entries_) + result.push_back({e.first.first, e.first.second, e.second}); + return result; +} + +size_t AccessBlackListManagerImpl::GetSize() const { + return entries_.size(); +} + +size_t AccessBlackListManagerImpl::GetCapacity() const { + return capacity_; +} + +} // namespace weave diff --git a/src/access_black_list_manager_impl.h b/src/access_black_list_manager_impl.h new file mode 100644 index 0000000..1c175db --- /dev/null +++ b/src/access_black_list_manager_impl.h @@ -0,0 +1,58 @@ +// Copyright 2016 The Weave Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef LIBWEAVE_SRC_ACCESS_BLACK_LIST_IMPL_H_ +#define LIBWEAVE_SRC_ACCESS_BLACK_LIST_IMPL_H_ + +#include <map> +#include <utility> + +#include <base/time/default_clock.h> +#include <base/time/time.h> +#include <weave/error.h> +#include <weave/provider/config_store.h> + +#include "src/access_black_list_manager.h" + +namespace weave { + +class AccessBlackListManagerImpl : public AccessBlackListManager { + public: + explicit AccessBlackListManagerImpl(provider::ConfigStore* store, + size_t capacity = 1024, + base::Clock* clock = nullptr); + + // AccessBlackListManager implementation. + void Block(const std::vector<uint8_t>& user_id, + const std::vector<uint8_t>& app_id, + const base::Time& expiration, + const DoneCallback& callback) override; + void Unblock(const std::vector<uint8_t>& user_id, + const std::vector<uint8_t>& app_id, + const DoneCallback& callback) override; + bool IsBlocked(const std::vector<uint8_t>& user_id, + const std::vector<uint8_t>& app_id) const override; + std::vector<Entry> GetEntries() const override; + size_t GetSize() const override; + size_t GetCapacity() const override; + + private: + void Load(); + void Save(const DoneCallback& callback); + void RemoveExpired(); + + const size_t capacity_{0}; + base::DefaultClock default_clock_; + base::Clock* clock_{&default_clock_}; + + provider::ConfigStore* store_{nullptr}; + std::map<std::pair<std::vector<uint8_t>, std::vector<uint8_t>>, base::Time> + entries_; + + DISALLOW_COPY_AND_ASSIGN(AccessBlackListManagerImpl); +}; + +} // namespace weave + +#endif // LIBWEAVE_SRC_ACCESS_BLACK_LIST_IMPL_H_ diff --git a/src/access_black_list_manager_impl_unittest.cc b/src/access_black_list_manager_impl_unittest.cc new file mode 100644 index 0000000..fd9f226 --- /dev/null +++ b/src/access_black_list_manager_impl_unittest.cc @@ -0,0 +1,165 @@ +// Copyright 2016 The Weave Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "src/access_black_list_manager_impl.h" + +#include <gmock/gmock.h> +#include <gtest/gtest.h> +#include <weave/provider/test/mock_config_store.h> +#include <weave/test/unittest_utils.h> + +#include "src/test/mock_clock.h" +#include "src/bind_lambda.h" + +using testing::_; +using testing::Return; +using testing::StrictMock; + +namespace weave { + +class AccessBlackListManagerImplTest : public testing::Test { + protected: + void SetUp() { + std::string to_load = R"([{ + "user": "BQID", + "app": "BwQF", + "expiration": 1410000000 + }, { + "user": "AQID", + "app": "AwQF", + "expiration": 1419999999 + }])"; + + EXPECT_CALL(config_store_, LoadSettings("black_list")) + .WillOnce(Return(to_load)); + + EXPECT_CALL(config_store_, SaveSettings("black_list", _, _)) + .WillOnce(testing::WithArgs<1, 2>(testing::Invoke( + [](const std::string& json, const DoneCallback& callback) { + std::string to_save = R"([{ + "user": "AQID", + "app": "AwQF", + "expiration": 1419999999 + }])"; + EXPECT_JSON_EQ(to_save, *test::CreateValue(json)); + if (!callback.is_null()) + callback.Run(nullptr); + }))); + + EXPECT_CALL(clock_, Now()) + .WillRepeatedly(Return(base::Time::FromTimeT(1412121212))); + manager_.reset(new AccessBlackListManagerImpl{&config_store_, 10, &clock_}); + } + StrictMock<test::MockClock> clock_; + StrictMock<provider::test::MockConfigStore> config_store_{false}; + std::unique_ptr<AccessBlackListManagerImpl> manager_; +}; + +TEST_F(AccessBlackListManagerImplTest, Init) { + EXPECT_EQ(1u, manager_->GetSize()); + EXPECT_EQ(10u, manager_->GetCapacity()); + EXPECT_EQ((std::vector<AccessBlackListManagerImpl::Entry>{{ + {1, 2, 3}, {3, 4, 5}, base::Time::FromTimeT(1419999999), + }}), + manager_->GetEntries()); +} + +TEST_F(AccessBlackListManagerImplTest, Block) { + EXPECT_CALL(config_store_, SaveSettings("black_list", _, _)) + .WillOnce(testing::WithArgs<1, 2>(testing::Invoke( + [](const std::string& json, const DoneCallback& callback) { + std::string to_save = R"([{ + "user": "AQID", + "app": "AwQF", + "expiration": 1419999999 + }, { + "app": "CAgI", + "user": "BwcH", + "expiration": 1419990000 + }])"; + EXPECT_JSON_EQ(to_save, *test::CreateValue(json)); + if (!callback.is_null()) + callback.Run(nullptr); + }))); + manager_->Block({7, 7, 7}, {8, 8, 8}, base::Time::FromTimeT(1419990000), {}); +} + +TEST_F(AccessBlackListManagerImplTest, BlockExpired) { + manager_->Block({}, {}, base::Time::FromTimeT(1400000000), + base::Bind([](ErrorPtr error) { + EXPECT_TRUE(error->HasError("aleady_expired")); + })); +} + +TEST_F(AccessBlackListManagerImplTest, BlockListIsFull) { + EXPECT_CALL(config_store_, SaveSettings("black_list", _, _)) + .WillRepeatedly(testing::WithArgs<1, 2>(testing::Invoke( + [](const std::string& json, const DoneCallback& callback) { + if (!callback.is_null()) + callback.Run(nullptr); + }))); + for (size_t i = manager_->GetSize(); i < manager_->GetCapacity(); ++i) { + manager_->Block( + {99, static_cast<uint8_t>(i / 256), static_cast<uint8_t>(i % 256)}, + {8, 8, 8}, base::Time::FromTimeT(1419990000), {}); + EXPECT_EQ(i + 1, manager_->GetSize()); + } + manager_->Block({99}, {8, 8, 8}, base::Time::FromTimeT(1419990000), + base::Bind([](ErrorPtr error) { + EXPECT_TRUE(error->HasError("blacklist_is_full")); + })); +} + +TEST_F(AccessBlackListManagerImplTest, Unblock) { + EXPECT_CALL(config_store_, SaveSettings("black_list", _, _)) + .WillOnce(testing::WithArgs<1, 2>(testing::Invoke( + [](const std::string& json, const DoneCallback& callback) { + EXPECT_JSON_EQ("[]", *test::CreateValue(json)); + if (!callback.is_null()) + callback.Run(nullptr); + }))); + manager_->Unblock({1, 2, 3}, {3, 4, 5}, {}); +} + +TEST_F(AccessBlackListManagerImplTest, UnblockNotFound) { + manager_->Unblock({5, 2, 3}, {5, 4, 5}, base::Bind([](ErrorPtr error) { + EXPECT_TRUE(error->HasError("entry_not_found")); + })); +} + +TEST_F(AccessBlackListManagerImplTest, IsBlockedFalse) { + EXPECT_FALSE(manager_->IsBlocked({7, 7, 7}, {8, 8, 8})); +} + +class AccessBlackListManagerImplIsBlockedTest + : public AccessBlackListManagerImplTest, + public testing::WithParamInterface< + std::tuple<std::vector<uint8_t>, std::vector<uint8_t>>> { + public: + void SetUp() override { + AccessBlackListManagerImplTest::SetUp(); + EXPECT_CALL(config_store_, SaveSettings("black_list", _, _)) + .WillOnce(testing::WithArgs<2>( + testing::Invoke([](const DoneCallback& callback) { + if (!callback.is_null()) + callback.Run(nullptr); + }))); + manager_->Block(std::get<0>(GetParam()), std::get<1>(GetParam()), + base::Time::FromTimeT(1419990000), {}); + } +}; + +TEST_P(AccessBlackListManagerImplIsBlockedTest, IsBlocked) { + EXPECT_TRUE(manager_->IsBlocked({7, 7, 7}, {8, 8, 8})); +} + +INSTANTIATE_TEST_CASE_P( + Filters, + AccessBlackListManagerImplIsBlockedTest, + testing::Combine(testing::Values(std::vector<uint8_t>{}, + std::vector<uint8_t>{7, 7, 7}), + testing::Values(std::vector<uint8_t>{}, + std::vector<uint8_t>{8, 8, 8}))); + +} // namespace weave diff --git a/src/base_api_handler.h b/src/base_api_handler.h index 1dbbac8..6eebfca 100644 --- a/src/base_api_handler.h +++ b/src/base_api_handler.h @@ -33,7 +33,7 @@ class BaseApiHandler final { void OnConfigChanged(const Settings& settings); DeviceRegistrationInfo* device_info_; - Device* device_; + Device* device_{nullptr}; base::WeakPtrFactory<BaseApiHandler> weak_ptr_factory_{this}; DISALLOW_COPY_AND_ASSIGN(BaseApiHandler); diff --git a/src/config.cc b/src/config.cc index 44d20dd..21a1c1f 100644 --- a/src/config.cc +++ b/src/config.cc @@ -33,6 +33,7 @@ const char kClientSecret[] = "client_secret"; const char kApiKey[] = "api_key"; const char kOAuthURL[] = "oauth_url"; const char kServiceURL[] = "service_url"; +const char kXmppEndpoint[] = "xmpp_endpoint"; const char kName[] = "name"; const char kDescription[] = "description"; const char kLocation[] = "location"; @@ -51,6 +52,7 @@ const char kRootClientTokenOwner[] = "root_client_token_owner"; const char kWeaveUrl[] = "https://www.googleapis.com/weave/v1/"; const char kDeprecatedUrl[] = "https://www.googleapis.com/clouddevices/v1/"; +const char kXmppEndpoint[] = "talk.google.com:5223"; namespace { @@ -69,6 +71,7 @@ Config::Settings CreateDefaultSettings() { Config::Settings result; result.oauth_url = "https://accounts.google.com/o/oauth2/"; result.service_url = kWeaveUrl; + result.xmpp_endpoint = kXmppEndpoint; result.local_anonymous_access_role = AuthScope::kViewer; result.pairing_modes.insert(PairingType::kPinCode); result.device_id = base::GenerateGUID(); @@ -119,6 +122,7 @@ void Config::Load() { CHECK(!settings_.api_key.empty()); CHECK(!settings_.oauth_url.empty()); CHECK(!settings_.service_url.empty()); + CHECK(!settings_.xmpp_endpoint.empty()); CHECK(!settings_.oem_name.empty()); CHECK(!settings_.model_name.empty()); CHECK(!settings_.model_id.empty()); @@ -190,6 +194,10 @@ void Config::Transaction::LoadState() { set_service_url(tmp); } + if (dict->GetString(config_keys::kXmppEndpoint, &tmp)) { + set_xmpp_endpoint(tmp); + } + if (dict->GetString(config_keys::kName, &tmp)) set_name(tmp); @@ -249,6 +257,7 @@ void Config::Save() { dict.SetString(config_keys::kApiKey, settings_.api_key); dict.SetString(config_keys::kOAuthURL, settings_.oauth_url); dict.SetString(config_keys::kServiceURL, settings_.service_url); + dict.SetString(config_keys::kXmppEndpoint, settings_.xmpp_endpoint); dict.SetString(config_keys::kRefreshToken, settings_.refresh_token); dict.SetString(config_keys::kCloudId, settings_.cloud_id); dict.SetString(config_keys::kDeviceId, settings_.device_id); diff --git a/src/config.h b/src/config.h index 6dc0a07..8e0a8f3 100644 --- a/src/config.h +++ b/src/config.h @@ -68,6 +68,9 @@ class Config final { void set_service_url(const std::string& url) { settings_->service_url = url; } + void set_xmpp_endpoint(const std::string& endpoint) { + settings_->xmpp_endpoint = endpoint; + } void set_name(const std::string& name) { settings_->name = name; } void set_description(const std::string& description) { settings_->description = description; diff --git a/src/config_unittest.cc b/src/config_unittest.cc index 4b0e5b4..bb2743a 100644 --- a/src/config_unittest.cc +++ b/src/config_unittest.cc @@ -62,6 +62,7 @@ TEST_F(ConfigTest, Defaults) { EXPECT_EQ("", GetSettings().api_key); EXPECT_EQ("https://accounts.google.com/o/oauth2/", GetSettings().oauth_url); EXPECT_EQ("https://www.googleapis.com/weave/v1/", GetSettings().service_url); + EXPECT_EQ("talk.google.com:5223", GetSettings().xmpp_endpoint); EXPECT_EQ("", GetSettings().oem_name); EXPECT_EQ("", GetSettings().model_name); EXPECT_EQ("", GetSettings().model_id); @@ -146,7 +147,8 @@ TEST_F(ConfigTest, LoadState) { "refresh_token": "state_refresh_token", "robot_account": "state_robot_account", "secret": "c3RhdGVfc2VjcmV0", - "service_url": "state_service_url" + "service_url": "state_service_url", + "xmpp_endpoint": "state_xmpp_endpoint" })"; EXPECT_CALL(config_store_, LoadSettings(kConfigName)).WillOnce(Return(state)); @@ -157,6 +159,7 @@ TEST_F(ConfigTest, LoadState) { EXPECT_EQ("state_api_key", GetSettings().api_key); EXPECT_EQ("state_oauth_url", GetSettings().oauth_url); EXPECT_EQ("state_service_url", GetSettings().service_url); + EXPECT_EQ("state_xmpp_endpoint", GetSettings().xmpp_endpoint); EXPECT_EQ(GetDefaultSettings().oem_name, GetSettings().oem_name); EXPECT_EQ(GetDefaultSettings().model_name, GetSettings().model_name); EXPECT_EQ(GetDefaultSettings().model_id, GetSettings().model_id); @@ -200,6 +203,9 @@ TEST_F(ConfigTest, Setters) { change.set_service_url("set_service_url"); EXPECT_EQ("set_service_url", GetSettings().service_url); + change.set_xmpp_endpoint("set_xmpp_endpoint"); + EXPECT_EQ("set_xmpp_endpoint", GetSettings().xmpp_endpoint); + change.set_name("set_name"); EXPECT_EQ("set_name", GetSettings().name); @@ -277,7 +283,8 @@ TEST_F(ConfigTest, Setters) { 'refresh_token': 'set_token', 'robot_account': 'set_account', 'secret': 'AQIDBAU=', - 'service_url': 'set_service_url' + 'service_url': 'set_service_url', + 'xmpp_endpoint': 'set_xmpp_endpoint' })"; EXPECT_JSON_EQ(expected, *test::CreateValue(json)); callback.Run(nullptr); diff --git a/src/device_manager.cc b/src/device_manager.cc index 097f854..deb5404 100644 --- a/src/device_manager.cc +++ b/src/device_manager.cc @@ -8,6 +8,8 @@ #include <base/bind.h> +#include "src/access_api_handler.h" +#include "src/access_black_list_manager_impl.h" #include "src/base_api_handler.h" #include "src/commands/schema_constants.h" #include "src/component_manager_impl.h" @@ -40,6 +42,10 @@ DeviceManager::DeviceManager(provider::ConfigStore* config_store, network, auth_manager_.get())); base_api_handler_.reset(new BaseApiHandler{device_info_.get(), this}); + black_list_manager_.reset(new AccessBlackListManagerImpl{config_store}); + access_api_handler_.reset( + new AccessApiHandler{this, black_list_manager_.get()}); + device_info_->Start(); if (http_server) { diff --git a/src/device_manager.h b/src/device_manager.h index d40ba8e..d77bacc 100644 --- a/src/device_manager.h +++ b/src/device_manager.h @@ -10,6 +10,8 @@ namespace weave { +class AccessApiHandler; +class AccessBlackListManager; class BaseApiHandler; class Config; class ComponentManager; @@ -107,6 +109,8 @@ class DeviceManager final : public Device { std::unique_ptr<ComponentManager> component_manager_; std::unique_ptr<DeviceRegistrationInfo> device_info_; std::unique_ptr<BaseApiHandler> base_api_handler_; + std::unique_ptr<AccessBlackListManager> black_list_manager_; + std::unique_ptr<AccessApiHandler> access_api_handler_; std::unique_ptr<privet::Manager> privet_; base::WeakPtrFactory<DeviceManager> weak_ptr_factory_{this}; diff --git a/src/device_registration_info.cc b/src/device_registration_info.cc index 7c20084..0dc1f54 100644 --- a/src/device_registration_info.cc +++ b/src/device_registration_info.cc @@ -463,8 +463,9 @@ void DeviceRegistrationInfo::StartNotificationChannel() { current_notification_channel_ = pull_channel_.get(); notification_channel_starting_ = true; - primary_notification_channel_.reset(new XmppChannel{ - GetSettings().robot_account, access_token_, task_runner_, network_}); + primary_notification_channel_.reset( + new XmppChannel{GetSettings().robot_account, access_token_, + GetSettings().xmpp_endpoint, task_runner_, network_}); primary_notification_channel_->Start(this); } @@ -833,17 +834,25 @@ bool DeviceRegistrationInfo::UpdateServiceConfig( const std::string& api_key, const std::string& oauth_url, const std::string& service_url, + const std::string& xmpp_endpoint, ErrorPtr* error) { if (HaveRegistrationCredentials()) { return Error::AddTo(error, FROM_HERE, kErrorAlreayRegistered, "Unable to change config for registered device"); } Config::Transaction change{config_}; - change.set_client_id(client_id); - change.set_client_secret(client_secret); - change.set_api_key(api_key); - change.set_oauth_url(oauth_url); - change.set_service_url(service_url); + if (!client_id.empty()) + change.set_client_id(client_id); + if (!client_secret.empty()) + change.set_client_secret(client_secret); + if (!api_key.empty()) + change.set_api_key(api_key); + if (!oauth_url.empty()) + change.set_oauth_url(oauth_url); + if (!service_url.empty()) + change.set_service_url(service_url); + if (!xmpp_endpoint.empty()) + change.set_xmpp_endpoint(xmpp_endpoint); return true; } diff --git a/src/device_registration_info.h b/src/device_registration_info.h index f670b68..a296258 100644 --- a/src/device_registration_info.h +++ b/src/device_registration_info.h @@ -78,6 +78,7 @@ class DeviceRegistrationInfo : public NotificationDelegate, const std::string& api_key, const std::string& oauth_url, const std::string& service_url, + const std::string& xmpp_endpoint, ErrorPtr* error); void GetDeviceInfo(const CloudRequestDoneCallback& callback); diff --git a/src/device_registration_info_unittest.cc b/src/device_registration_info_unittest.cc index 7908c8b..bbc167e 100644 --- a/src/device_registration_info_unittest.cc +++ b/src/device_registration_info_unittest.cc @@ -44,6 +44,7 @@ namespace { namespace test_data { +const char kXmppEndpoint[] = "xmpp.server.com:1234"; const char kServiceURL[] = "http://gcd.server.com/"; const char kOAuthURL[] = "http://oauth.server.com/"; const char kApiKey[] = "GOadRdTf9FERf0k4w6EFOof56fUJ3kFDdFL3d7f"; @@ -144,6 +145,7 @@ class DeviceRegistrationInfoTest : public ::testing::Test { settings->model_id = "AAAAA"; settings->oauth_url = test_data::kOAuthURL; settings->service_url = test_data::kServiceURL; + settings->xmpp_endpoint = test_data::kXmppEndpoint; return true; })); config_.reset(new Config{&config_store_}); diff --git a/src/notification/xmpp_channel.cc b/src/notification/xmpp_channel.cc index ceb45ed..f9d7924 100644 --- a/src/notification/xmpp_channel.cc +++ b/src/notification/xmpp_channel.cc @@ -7,6 +7,7 @@ #include <string> #include <base/bind.h> +#include <base/strings/string_number_conversions.h> #include <weave/provider/network.h> #include <weave/provider/task_runner.h> @@ -16,6 +17,7 @@ #include "src/notification/notification_parser.h" #include "src/notification/xml_node.h" #include "src/privet/openssl_utils.h" +#include "src/string_utils.h" #include "src/utils.h" namespace weave { @@ -74,9 +76,6 @@ const BackoffEntry::Policy kDefaultBackoffPolicy = { false, }; -const char kDefaultXmppHost[] = "talk.google.com"; -const uint16_t kDefaultXmppPort = 5223; - // Used for keeping connection alive. const int kRegularPingIntervalSeconds = 60; const int kRegularPingTimeoutSeconds = 30; @@ -91,10 +90,12 @@ const int kConnectingTimeoutAfterNetChangeSeconds = 30; XmppChannel::XmppChannel(const std::string& account, const std::string& access_token, + const std::string& xmpp_endpoint, provider::TaskRunner* task_runner, provider::Network* network) : account_{account}, access_token_{access_token}, + xmpp_endpoint_{xmpp_endpoint}, network_{network}, backoff_entry_{&kDefaultBackoffPolicy}, task_runner_{task_runner}, @@ -285,10 +286,16 @@ void XmppChannel::HandleMessageStanza(std::unique_ptr<XmlNode> stanza) { void XmppChannel::CreateSslSocket() { CHECK(!stream_); state_ = XmppState::kConnecting; - LOG(INFO) << "Starting XMPP connection to " << kDefaultXmppHost << ":" - << kDefaultXmppPort; + LOG(INFO) << "Starting XMPP connection to: " << xmpp_endpoint_; + + std::pair<std::string, std::string> host_port = + SplitAtFirst(xmpp_endpoint_, ":", true); + CHECK(!host_port.first.empty()); + CHECK(!host_port.second.empty()); + uint32_t port = 0; + CHECK(base::StringToUint(host_port.second, &port)) << xmpp_endpoint_; - network_->OpenSslSocket(kDefaultXmppHost, kDefaultXmppPort, + network_->OpenSslSocket(host_port.first, port, base::Bind(&XmppChannel::OnSslSocketReady, task_ptr_factory_.GetWeakPtr())); } diff --git a/src/notification/xmpp_channel.h b/src/notification/xmpp_channel.h index 50e84d2..b0a4468 100644 --- a/src/notification/xmpp_channel.h +++ b/src/notification/xmpp_channel.h @@ -45,6 +45,7 @@ class XmppChannel : public NotificationChannel, // so you will need to reset the XmppClient every time this happens. XmppChannel(const std::string& account, const std::string& access_token, + const std::string& xmpp_endpoint, provider::TaskRunner* task_runner, provider::Network* network); ~XmppChannel() override = default; @@ -124,12 +125,15 @@ class XmppChannel : public NotificationChannel, // Robot account name for the device. std::string account_; - // Full JID of this device. - std::string jid_; - // OAuth access token for the account. Expires fairly frequently. std::string access_token_; + // Xmpp endpoint. + std::string xmpp_endpoint_; + + // Full JID of this device. + std::string jid_; + provider::Network* network_{nullptr}; std::unique_ptr<Stream> stream_; diff --git a/src/notification/xmpp_channel_unittest.cc b/src/notification/xmpp_channel_unittest.cc index 674fe22..dfa2a79 100644 --- a/src/notification/xmpp_channel_unittest.cc +++ b/src/notification/xmpp_channel_unittest.cc @@ -26,6 +26,7 @@ namespace { constexpr char kAccountName[] = "Account@Name"; constexpr char kAccessToken[] = "AccessToken"; +constexpr char kEndpoint[] = "endpoint:456"; constexpr char kStartStreamMessage[] = "<stream:stream to='clouddevices.gserviceaccount.com' " @@ -84,7 +85,8 @@ class FakeXmppChannel : public XmppChannel { public: explicit FakeXmppChannel(provider::TaskRunner* task_runner, provider::Network* network) - : XmppChannel{kAccountName, kAccessToken, task_runner, network}, + : XmppChannel{kAccountName, kAccessToken, kEndpoint, task_runner, + network}, stream_{new test::FakeStream{task_runner_}}, fake_stream_{stream_.get()} {} @@ -122,7 +124,7 @@ class MockNetwork : public provider::test::MockNetwork { class XmppChannelTest : public ::testing::Test { protected: XmppChannelTest() { - EXPECT_CALL(network_, OpenSslSocket("talk.google.com", 5223, _)) + EXPECT_CALL(network_, OpenSslSocket("endpoint", 456, _)) .WillOnce( WithArgs<2>(Invoke(&xmpp_client_, &FakeXmppChannel::Connect))); } diff --git a/src/privet/auth_manager.cc b/src/privet/auth_manager.cc index 66d04c4..c82887e 100644 --- a/src/privet/auth_manager.cc +++ b/src/privet/auth_manager.cc @@ -18,6 +18,7 @@ extern "C" { #include "third_party/libuweave/src/macaroon.h" +#include "third_party/libuweave/src/macaroon_caveat_internal.h" } namespace weave { @@ -25,9 +26,19 @@ namespace privet { namespace { +const time_t kJ2000ToTimeT = 946684800; const size_t kMaxMacaroonSize = 1024; const size_t kMaxPendingClaims = 10; const char kInvalidTokenError[] = "invalid_token"; +const int kSessionIdTtlMinutes = 1; + +uint32_t ToJ2000Time(const base::Time& time) { + return std::max(time.ToTimeT(), kJ2000ToTimeT) - kJ2000ToTimeT; +} + +base::Time FromJ2000Time(uint32_t time) { + return base::Time::FromTimeT(time + kJ2000ToTimeT); +} template <class T> void AppendToArray(T value, std::vector<uint8_t>* array) { @@ -37,78 +48,108 @@ void AppendToArray(T value, std::vector<uint8_t>* array) { class Caveat { public: - // TODO(vitalybuka): Use _get_buffer_size_ when available. - Caveat(UwMacaroonCaveatType type, uint32_t value) : buffer(8) { - CHECK(uw_macaroon_caveat_create_with_uint_(type, value, buffer.data(), - buffer.size(), &caveat)); + Caveat(UwMacaroonCaveatType type, size_t str_len) + : buffer_(uw_macaroon_caveat_creation_get_buffsize_(type, str_len)) { + CHECK(!buffer_.empty()); } + const UwMacaroonCaveat& GetCaveat() const { return caveat_; } + + protected: + UwMacaroonCaveat caveat_{}; + std::vector<uint8_t> buffer_; + + DISALLOW_COPY_AND_ASSIGN(Caveat); +}; - // TODO(vitalybuka): Use _get_buffer_size_ when available. - Caveat(UwMacaroonCaveatType type, const std::string& value) - : buffer(std::max<size_t>(value.size(), 32u) * 2) { - CHECK(uw_macaroon_caveat_create_with_str_( - type, reinterpret_cast<const uint8_t*>(value.data()), value.size(), - buffer.data(), buffer.size(), &caveat)); +class ScopeCaveat : public Caveat { + public: + explicit ScopeCaveat(UwMacaroonCaveatScopeType scope) + : Caveat(kUwMacaroonCaveatTypeScope, 0) { + CHECK(uw_macaroon_caveat_create_scope_(scope, buffer_.data(), + buffer_.size(), &caveat_)); } - const UwMacaroonCaveat& GetCaveat() const { return caveat; } + DISALLOW_COPY_AND_ASSIGN(ScopeCaveat); +}; - private: - UwMacaroonCaveat caveat; - std::vector<uint8_t> buffer; +class TimestampCaveat : public Caveat { + public: + explicit TimestampCaveat(const base::Time& timestamp) + : Caveat(kUwMacaroonCaveatTypeDelegationTimestamp, 0) { + CHECK(uw_macaroon_caveat_create_delegation_timestamp_( + ToJ2000Time(timestamp), buffer_.data(), buffer_.size(), &caveat_)); + } - DISALLOW_COPY_AND_ASSIGN(Caveat); + DISALLOW_COPY_AND_ASSIGN(TimestampCaveat); }; -bool CheckCaveatType(const UwMacaroonCaveat& caveat, - UwMacaroonCaveatType type, - ErrorPtr* error) { - UwMacaroonCaveatType caveat_type{}; - if (!uw_macaroon_caveat_get_type_(&caveat, &caveat_type)) { - return Error::AddTo(error, FROM_HERE, kInvalidTokenError, - "Unable to get type"); +class ExpirationCaveat : public Caveat { + public: + explicit ExpirationCaveat(const base::Time& timestamp) + : Caveat(kUwMacaroonCaveatTypeExpirationAbsolute, 0) { + CHECK(uw_macaroon_caveat_create_expiration_absolute_( + ToJ2000Time(timestamp), buffer_.data(), buffer_.size(), &caveat_)); } - if (caveat_type != type) { - return Error::AddTo(error, FROM_HERE, kInvalidTokenError, - "Unexpected caveat type"); + DISALLOW_COPY_AND_ASSIGN(ExpirationCaveat); +}; + +class UserIdCaveat : public Caveat { + public: + explicit UserIdCaveat(const std::vector<uint8_t>& id) + : Caveat(kUwMacaroonCaveatTypeDelegateeUser, id.size()) { + CHECK(uw_macaroon_caveat_create_delegatee_user_( + id.data(), id.size(), buffer_.data(), buffer_.size(), &caveat_)); } - return true; -} + DISALLOW_COPY_AND_ASSIGN(UserIdCaveat); +}; -bool ReadCaveat(const UwMacaroonCaveat& caveat, - UwMacaroonCaveatType type, - uint32_t* value, - ErrorPtr* error) { - if (!CheckCaveatType(caveat, type, error)) - return false; +class AppIdCaveat : public Caveat { + public: + explicit AppIdCaveat(const std::vector<uint8_t>& id) + : Caveat(kUwMacaroonCaveatTypeDelegateeApp, id.size()) { + CHECK(uw_macaroon_caveat_create_delegatee_app_( + id.data(), id.size(), buffer_.data(), buffer_.size(), &caveat_)); + } - if (!uw_macaroon_caveat_get_value_uint_(&caveat, value)) { - return Error::AddTo(error, FROM_HERE, kInvalidTokenError, - "Unable to read caveat"); + DISALLOW_COPY_AND_ASSIGN(AppIdCaveat); +}; + +class ServiceCaveat : public Caveat { + public: + explicit ServiceCaveat(const std::string& id) + : Caveat(kUwMacaroonCaveatTypeDelegateeService, id.size()) { + CHECK(uw_macaroon_caveat_create_delegatee_service_( + reinterpret_cast<const uint8_t*>(id.data()), id.size(), buffer_.data(), + buffer_.size(), &caveat_)); } - return true; -} + DISALLOW_COPY_AND_ASSIGN(ServiceCaveat); +}; -bool ReadCaveat(const UwMacaroonCaveat& caveat, - UwMacaroonCaveatType type, - std::string* value, - ErrorPtr* error) { - if (!CheckCaveatType(caveat, type, error)) - return false; +class SessionIdCaveat : public Caveat { + public: + explicit SessionIdCaveat(const std::string& id) + : Caveat(kUwMacaroonCaveatTypeLanSessionID, id.size()) { + CHECK(uw_macaroon_caveat_create_lan_session_id_( + reinterpret_cast<const uint8_t*>(id.data()), id.size(), buffer_.data(), + buffer_.size(), &caveat_)); + } - const uint8_t* start{nullptr}; - size_t size{0}; - if (!uw_macaroon_caveat_get_value_str_(&caveat, &start, &size)) { - return Error::AddTo(error, FROM_HERE, kInvalidTokenError, - "Unable to read caveat"); + DISALLOW_COPY_AND_ASSIGN(SessionIdCaveat); +}; + +class ClientAuthTokenCaveat : public Caveat { + public: + ClientAuthTokenCaveat() + : Caveat(kUwMacaroonCaveatTypeClientAuthorizationTokenV1, 0) { + CHECK(uw_macaroon_caveat_create_client_authorization_token_( + nullptr, 0, buffer_.data(), buffer_.size(), &caveat_)); } - value->assign(reinterpret_cast<const char*>(start), size); - return true; -} + DISALLOW_COPY_AND_ASSIGN(ClientAuthTokenCaveat); +}; std::vector<uint8_t> CreateSecret() { std::vector<uint8_t> secret(kSha256OutputSize); @@ -122,18 +163,53 @@ bool IsClaimAllowed(RootClientTokenOwner curret, RootClientTokenOwner claimer) { std::vector<uint8_t> CreateMacaroonToken( const std::vector<uint8_t>& secret, - const std::vector<UwMacaroonCaveat>& caveats) { + const base::Time& time, + const std::vector<const UwMacaroonCaveat*>& caveats) { CHECK_EQ(kSha256OutputSize, secret.size()); + + UwMacaroonContext context{}; + CHECK(uw_macaroon_context_create_(ToJ2000Time(time), nullptr, 0, &context)); + UwMacaroon macaroon{}; - CHECK(uw_macaroon_new_from_root_key_(&macaroon, secret.data(), secret.size(), - caveats.data(), caveats.size())); + CHECK(uw_macaroon_create_from_root_key_(&macaroon, secret.data(), + secret.size(), &context, + caveats.data(), caveats.size())); - std::vector<uint8_t> token(kMaxMacaroonSize); + std::vector<uint8_t> serialized_token(kMaxMacaroonSize); size_t len = 0; - CHECK(uw_macaroon_dump_(&macaroon, token.data(), token.size(), &len)); - token.resize(len); + CHECK(uw_macaroon_serialize_(&macaroon, serialized_token.data(), + serialized_token.size(), &len)); + serialized_token.resize(len); - return token; + return serialized_token; +} + +std::vector<uint8_t> ExtendMacaroonToken( + const UwMacaroon& macaroon, + const base::Time& time, + const std::vector<const UwMacaroonCaveat*>& caveats) { + UwMacaroonContext context{}; + CHECK(uw_macaroon_context_create_(ToJ2000Time(time), nullptr, 0, &context)); + + UwMacaroon prev_macaroon = macaroon; + std::vector<uint8_t> prev_buffer(kMaxMacaroonSize); + std::vector<uint8_t> new_buffer(kMaxMacaroonSize); + + for (auto caveat : caveats) { + UwMacaroon new_macaroon{}; + CHECK(uw_macaroon_extend_(&prev_macaroon, &new_macaroon, &context, caveat, + new_buffer.data(), new_buffer.size())); + new_buffer.swap(prev_buffer); + prev_macaroon = new_macaroon; + } + + std::vector<uint8_t> serialized_token(kMaxMacaroonSize); + size_t len = 0; + CHECK(uw_macaroon_serialize_(&prev_macaroon, serialized_token.data(), + serialized_token.size(), &len)); + serialized_token.resize(len); + + return serialized_token; } bool LoadMacaroon(const std::vector<uint8_t>& token, @@ -141,8 +217,8 @@ bool LoadMacaroon(const std::vector<uint8_t>& token, UwMacaroon* macaroon, ErrorPtr* error) { buffer->resize(kMaxMacaroonSize); - if (!uw_macaroon_load_(token.data(), token.size(), buffer->data(), - buffer->size(), macaroon)) { + if (!uw_macaroon_deserialize_(token.data(), token.size(), buffer->data(), + buffer->size(), macaroon)) { return Error::AddTo(error, FROM_HERE, kInvalidTokenError, "Invalid token format"); } @@ -151,10 +227,16 @@ bool LoadMacaroon(const std::vector<uint8_t>& token, bool VerifyMacaroon(const std::vector<uint8_t>& secret, const UwMacaroon& macaroon, + const base::Time& time, + UwMacaroonValidationResult* result, ErrorPtr* error) { CHECK_EQ(kSha256OutputSize, secret.size()); - if (!uw_macaroon_verify_(&macaroon, secret.data(), secret.size())) { - return Error::AddTo(error, FROM_HERE, "invalid_signature", + UwMacaroonContext context = {}; + CHECK(uw_macaroon_context_create_(ToJ2000Time(time), nullptr, 0, &context)); + + if (!uw_macaroon_validate_(&macaroon, secret.data(), secret.size(), &context, + result)) { + return Error::AddTo(error, FROM_HERE, "invalid_token", "Invalid token signature"); } return true; @@ -239,14 +321,22 @@ AuthManager::~AuthManager() {} std::vector<uint8_t> AuthManager::CreateAccessToken(const UserInfo& user_info, base::TimeDelta ttl) const { - Caveat scope{kUwMacaroonCaveatTypeScope, ToMacaroonScope(user_info.scope())}; - Caveat user{kUwMacaroonCaveatTypeIdentifier, user_info.user_id()}; - Caveat issued{kUwMacaroonCaveatTypeExpiration, - static_cast<uint32_t>((Now() + ttl).ToTimeT())}; + const base::Time now = Now(); + TimestampCaveat issued{now}; + ScopeCaveat scope{ToMacaroonScope(user_info.scope())}; + // Macaroons have no caveats for auth type. So we just append the type to the + // user ID. + std::vector<uint8_t> id_with_type{user_info.id().user}; + id_with_type.push_back(static_cast<uint8_t>(user_info.id().type)); + UserIdCaveat user{id_with_type}; + AppIdCaveat app{user_info.id().app}; + ExpirationCaveat expiration{now + ttl}; return CreateMacaroonToken( - access_secret_, + access_secret_, now, { - scope.GetCaveat(), user.GetCaveat(), issued.GetCaveat(), + + &issued.GetCaveat(), &scope.GetCaveat(), &user.GetCaveat(), + &app.GetCaveat(), &expiration.GetCaveat(), }); } @@ -256,37 +346,40 @@ bool AuthManager::ParseAccessToken(const std::vector<uint8_t>& token, std::vector<uint8_t> buffer; UwMacaroon macaroon{}; - uint32_t scope{0}; - std::string user_id; - uint32_t expiration{0}; - + UwMacaroonValidationResult result{}; + const base::Time now = Now(); if (!LoadMacaroon(token, &buffer, &macaroon, error) || - !VerifyMacaroon(access_secret_, macaroon, error) || - macaroon.num_caveats != 3 || - !ReadCaveat(macaroon.caveats[0], kUwMacaroonCaveatTypeScope, &scope, - error) || - !ReadCaveat(macaroon.caveats[1], kUwMacaroonCaveatTypeIdentifier, - &user_id, error) || - !ReadCaveat(macaroon.caveats[2], kUwMacaroonCaveatTypeExpiration, - &expiration, error)) { + macaroon.num_caveats != 5 || + !VerifyMacaroon(access_secret_, macaroon, now, &result, error)) { return Error::AddTo(error, FROM_HERE, errors::kInvalidAuthorization, "Invalid token"); } - AuthScope auth_scope{FromMacaroonScope(scope)}; + AuthScope auth_scope{FromMacaroonScope(result.granted_scope)}; if (auth_scope == AuthScope::kNone) { return Error::AddTo(error, FROM_HERE, errors::kInvalidAuthorization, "Invalid token data"); } - base::Time time{base::Time::FromTimeT(expiration)}; - if (time < clock_->Now()) { - return Error::AddTo(error, FROM_HERE, errors::kAuthorizationExpired, - "Token is expired"); - } - + // If token is valid and token was not extended, it should has precisely this + // values. + CHECK_GE(FromJ2000Time(result.expiration_time), now); + CHECK_EQ(2u, result.num_delegatees); + CHECK_EQ(kUwMacaroonDelegateeTypeUser, result.delegatees[0].type); + CHECK_EQ(kUwMacaroonDelegateeTypeApp, result.delegatees[1].type); + CHECK_GT(result.delegatees[0].id_len, 1u); + std::vector<uint8_t> user_id{ + result.delegatees[0].id, + result.delegatees[0].id + result.delegatees[0].id_len}; + // Last byte is used for type. See |CreateAccessToken|. + AuthType type = static_cast<AuthType>(user_id.back()); + user_id.pop_back(); + + std::vector<uint8_t> app_id{ + result.delegatees[1].id, + result.delegatees[1].id + result.delegatees[1].id_len}; if (user_info) - *user_info = UserInfo{auth_scope, user_id}; + *user_info = UserInfo{auth_scope, UserAppId{type, user_id, app_id}}; return true; } @@ -309,7 +402,7 @@ std::vector<uint8_t> AuthManager::ClaimRootClientAuthToken( std::unique_ptr<AuthManager>{new AuthManager{nullptr, {}}}, owner)); if (pending_claims_.size() > kMaxPendingClaims) pending_claims_.pop_front(); - return pending_claims_.back().first->GetRootClientAuthToken(); + return pending_claims_.back().first->GetRootClientAuthToken(owner); } bool AuthManager::ConfirmClientAuthToken(const std::vector<uint8_t>& token, @@ -332,14 +425,20 @@ bool AuthManager::ConfirmClientAuthToken(const std::vector<uint8_t>& token, return true; } -std::vector<uint8_t> AuthManager::GetRootClientAuthToken() const { - Caveat scope{kUwMacaroonCaveatTypeScope, kUwMacaroonCaveatScopeTypeOwner}; - Caveat issued{kUwMacaroonCaveatTypeIssued, - static_cast<uint32_t>(Now().ToTimeT())}; - return CreateMacaroonToken(auth_secret_, - { - scope.GetCaveat(), issued.GetCaveat(), - }); +std::vector<uint8_t> AuthManager::GetRootClientAuthToken( + RootClientTokenOwner owner) const { + CHECK(RootClientTokenOwner::kNone != owner); + ClientAuthTokenCaveat auth_token; + const base::Time now = Now(); + TimestampCaveat issued{now}; + + ServiceCaveat client{owner == RootClientTokenOwner::kCloud ? "google.com" + : ""}; + return CreateMacaroonToken( + auth_secret_, now, + { + &auth_token.GetCaveat(), &issued.GetCaveat(), &client.GetCaveat(), + }); } base::Time AuthManager::Now() const { @@ -350,8 +449,9 @@ bool AuthManager::IsValidAuthToken(const std::vector<uint8_t>& token, ErrorPtr* error) const { std::vector<uint8_t> buffer; UwMacaroon macaroon{}; + UwMacaroonValidationResult result{}; if (!LoadMacaroon(token, &buffer, &macaroon, error) || - !VerifyMacaroon(auth_secret_, macaroon, error)) { + !VerifyMacaroon(auth_secret_, macaroon, Now(), &result, error)) { return Error::AddTo(error, FROM_HERE, errors::kInvalidAuthCode, "Invalid token"); } @@ -365,19 +465,63 @@ bool AuthManager::CreateAccessTokenFromAuth( AuthScope* access_token_scope, base::TimeDelta* access_token_ttl, ErrorPtr* error) const { - // TODO(vitalybuka): implement token validation. - if (!IsValidAuthToken(auth_token, error)) - return false; + std::vector<uint8_t> buffer; + UwMacaroon macaroon{}; + UwMacaroonValidationResult result{}; + const base::Time now = Now(); + if (!LoadMacaroon(auth_token, &buffer, &macaroon, error) || + !VerifyMacaroon(auth_secret_, macaroon, now, &result, error)) { + return Error::AddTo(error, FROM_HERE, errors::kInvalidAuthCode, + "Invalid token"); + } + + AuthScope auth_scope{FromMacaroonScope(result.granted_scope)}; + if (auth_scope == AuthScope::kNone) { + return Error::AddTo(error, FROM_HERE, errors::kInvalidAuthCode, + "Invalid token data"); + } + + // TODO: Integrate black list checks. + auto delegates_rbegin = std::reverse_iterator<const UwMacaroonDelegateeInfo*>( + result.delegatees + result.num_delegatees); + auto delegates_rend = + std::reverse_iterator<const UwMacaroonDelegateeInfo*>(result.delegatees); + auto last_user_id = + std::find_if(delegates_rbegin, delegates_rend, + [](const UwMacaroonDelegateeInfo& delegatee) { + return delegatee.type == kUwMacaroonDelegateeTypeUser; + }); + auto last_app_id = + std::find_if(delegates_rbegin, delegates_rend, + [](const UwMacaroonDelegateeInfo& delegatee) { + return delegatee.type == kUwMacaroonDelegateeTypeApp; + }); + + if (last_user_id == delegates_rend || !last_user_id->id_len) { + return Error::AddTo(error, FROM_HERE, errors::kInvalidAuthCode, + "User ID is missing"); + } + + const char* session_id = reinterpret_cast<const char*>(result.lan_session_id); + if (!IsValidSessionId({session_id, session_id + result.lan_session_id_len})) { + return Error::AddTo(error, FROM_HERE, errors::kInvalidAuthCode, + "Invalid session id"); + } + + CHECK_GE(FromJ2000Time(result.expiration_time), now); if (!access_token) return true; - // TODO(vitalybuka): User and scope must be parsed from auth_token. - UserInfo info{config_ ? config_->GetSettings().local_anonymous_access_role - : AuthScope::kViewer, - base::GenerateGUID()}; + std::vector<uint8_t> user_id{last_user_id->id, + last_user_id->id + last_user_id->id_len}; + std::vector<uint8_t> app_id; + if (last_app_id != delegates_rend) + app_id.assign(last_app_id->id, last_app_id->id + last_app_id->id_len); + + UserInfo info{auth_scope, {AuthType::kLocal, user_id, app_id}}; - // TODO(vitalybuka): TTL also should be reduced in accordance with auth_token. + ttl = std::min(ttl, FromJ2000Time(result.expiration_time) - now); *access_token = CreateAccessToken(info, ttl); if (access_token_scope) @@ -388,11 +532,45 @@ bool AuthManager::CreateAccessTokenFromAuth( return true; } -std::vector<uint8_t> AuthManager::CreateSessionId() { - std::vector<uint8_t> result; - AppendToArray(Now().ToTimeT(), &result); - AppendToArray(++session_counter_, &result); - return result; +std::string AuthManager::CreateSessionId() const { + return std::to_string(ToJ2000Time(Now())) + ":" + + std::to_string(++session_counter_); +} + +bool AuthManager::IsValidSessionId(const std::string& session_id) const { + base::Time ssid_time = FromJ2000Time(std::atoi(session_id.c_str())); + return Now() - base::TimeDelta::FromMinutes(kSessionIdTtlMinutes) <= + ssid_time && + ssid_time <= Now(); +} + +std::vector<uint8_t> AuthManager::DelegateToUser( + const std::vector<uint8_t>& token, + base::TimeDelta ttl, + const UserInfo& user_info) const { + std::vector<uint8_t> buffer; + UwMacaroon macaroon{}; + CHECK(LoadMacaroon(token, &buffer, &macaroon, nullptr)); + + const base::Time now = Now(); + TimestampCaveat issued{now}; + ExpirationCaveat expiration{now + ttl}; + ScopeCaveat scope{ToMacaroonScope(user_info.scope())}; + UserIdCaveat user{user_info.id().user}; + AppIdCaveat app{user_info.id().app}; + SessionIdCaveat session{CreateSessionId()}; + + std::vector<const UwMacaroonCaveat*> caveats{ + &issued.GetCaveat(), &expiration.GetCaveat(), &scope.GetCaveat(), + &user.GetCaveat(), + }; + + if (!user_info.id().app.empty()) + caveats.push_back(&app.GetCaveat()); + + caveats.push_back(&session.GetCaveat()); + + return ExtendMacaroonToken(macaroon, now, caveats); } } // namespace privet diff --git a/src/privet/auth_manager.h b/src/privet/auth_manager.h index 309d80e..f0a5761 100644 --- a/src/privet/auth_manager.h +++ b/src/privet/auth_manager.h @@ -9,6 +9,7 @@ #include <string> #include <vector> +#include <base/gtest_prod_util.h> #include <base/time/default_clock.h> #include <base/time/time.h> #include <weave/error.h> @@ -54,7 +55,7 @@ class AuthManager { bool ConfirmClientAuthToken(const std::vector<uint8_t>& token, ErrorPtr* error); - std::vector<uint8_t> GetRootClientAuthToken() const; + std::vector<uint8_t> GetRootClientAuthToken(RootClientTokenOwner owner) const; bool IsValidAuthToken(const std::vector<uint8_t>& token, ErrorPtr* error) const; bool CreateAccessTokenFromAuth(const std::vector<uint8_t>& auth_token, @@ -67,13 +68,21 @@ class AuthManager { void SetAuthSecret(const std::vector<uint8_t>& secret, RootClientTokenOwner owner); - std::vector<uint8_t> CreateSessionId(); + std::string CreateSessionId() const; + bool IsValidSessionId(const std::string& session_id) const; private: + friend class AuthManagerTest; + + // Test helpers. Device does not need to implement delegation. + std::vector<uint8_t> DelegateToUser(const std::vector<uint8_t>& token, + base::TimeDelta ttl, + const UserInfo& user_info) const; + Config* config_{nullptr}; // Can be nullptr for tests. base::DefaultClock default_clock_; base::Clock* clock_{&default_clock_}; - uint32_t session_counter_{0}; + mutable uint32_t session_counter_{0}; std::vector<uint8_t> auth_secret_; // Persistent. std::vector<uint8_t> certificate_fingerprint_; diff --git a/src/privet/auth_manager_unittest.cc b/src/privet/auth_manager_unittest.cc index 70750ad..294aefa 100644 --- a/src/privet/auth_manager_unittest.cc +++ b/src/privet/auth_manager_unittest.cc @@ -10,6 +10,7 @@ #include "src/config.h" #include "src/data_encoding.h" +#include "src/privet/mock_delegates.h" #include "src/test/mock_clock.h" using testing::Return; @@ -29,6 +30,11 @@ class AuthManagerTest : public testing::Test { } protected: + std::vector<uint8_t> DelegateToUser(const std::vector<uint8_t>& token, + base::TimeDelta ttl, + const UserInfo& user_info) const { + return auth_.DelegateToUser(token, ttl, user_info); + } const std::vector<uint8_t> kSecret1{ 78, 40, 39, 68, 29, 19, 70, 86, 38, 61, 13, 55, 33, 32, 51, 52, 34, 43, 97, 48, 8, 56, 11, 99, 50, 59, 24, 26, 31, 71, 76, 28}; @@ -64,49 +70,90 @@ TEST_F(AuthManagerTest, Constructor) { } TEST_F(AuthManagerTest, CreateAccessToken) { - EXPECT_EQ("UABRUHgcSZDry0bvIsoJv+WDQgEURQJjMjM0RgUaVArkgA==", + EXPECT_EQ("WC2FRggaG52hAEIBFEYJRDIzNABCCkBGBRobnaEAUFAF46oQlMmXgnLstt7wU2w=", Base64Encode(auth_.CreateAccessToken( - UserInfo{AuthScope::kViewer, "234"}, {}))); - EXPECT_EQ("UL7YEruLg5QQRDIp2+u1cqCDQgEIRQJjMjU3RgUaVArkgA==", + UserInfo{AuthScope::kViewer, TestUserId{"234"}}, {}))); + EXPECT_EQ("WC2FRggaG52hAEIBCEYJRDI1NwBCCkBGBRobnaEAUEdWRNHcu/0mA6c3e0tgDrk=", Base64Encode(auth_.CreateAccessToken( - UserInfo{AuthScope::kManager, "257"}, {}))); - EXPECT_EQ("UPFGeZRanR1wLGYLP5ZDkXiDQgECRQJjNDU2RgUaVArkgA==", + UserInfo{AuthScope::kManager, TestUserId{"257"}}, {}))); + EXPECT_EQ("WC2FRggaG52hAEIBAkYJRDQ1NgBCCkBGBRobnaEAUH2ZLgUPdTtjNRa+PoDkMW4=", Base64Encode(auth_.CreateAccessToken( - UserInfo{AuthScope::kOwner, "456"}, {}))); + UserInfo{AuthScope::kOwner, TestUserId{"456"}}, {}))); auto new_time = clock_.Now() + base::TimeDelta::FromDays(11); EXPECT_CALL(clock_, Now()).WillRepeatedly(Return(new_time)); - EXPECT_EQ("UMm9KlF3OEtZFBmhScJpl4uDQgEORQJjMzQ1RgUaVBllAA==", + EXPECT_EQ("WC2FRggaG6whgEIBDkYJRDM0NQBCCkBGBRobrCGAUDAFptj7bbYmbpaa6Wpb1Wo=", Base64Encode(auth_.CreateAccessToken( - UserInfo{AuthScope::kUser, "345"}, {}))); + UserInfo{AuthScope::kUser, TestUserId{"345"}}, {}))); } TEST_F(AuthManagerTest, CreateSameToken) { - EXPECT_EQ(auth_.CreateAccessToken(UserInfo{AuthScope::kViewer, "555"}, {}), - auth_.CreateAccessToken(UserInfo{AuthScope::kViewer, "555"}, {})); + EXPECT_EQ(auth_.CreateAccessToken( + UserInfo{AuthScope::kViewer, TestUserId{"555"}}, {}), + auth_.CreateAccessToken( + UserInfo{AuthScope::kViewer, TestUserId{"555"}}, {})); +} + +TEST_F(AuthManagerTest, CreateSameTokenWithApp) { + EXPECT_EQ(auth_.CreateAccessToken( + UserInfo{AuthScope::kViewer, + {AuthType::kLocal, {1, 2, 3}, {4, 5, 6}}}, + {}), + auth_.CreateAccessToken( + UserInfo{AuthScope::kViewer, + {AuthType::kLocal, {1, 2, 3}, {4, 5, 6}}}, + {})); +} + +TEST_F(AuthManagerTest, CreateSameTokenWithDifferentType) { + EXPECT_NE(auth_.CreateAccessToken( + UserInfo{AuthScope::kViewer, + {AuthType::kLocal, {1, 2, 3}, {4, 5, 6}}}, + {}), + auth_.CreateAccessToken( + UserInfo{AuthScope::kViewer, + {AuthType::kPairing, {1, 2, 3}, {4, 5, 6}}}, + {})); +} + +TEST_F(AuthManagerTest, CreateSameTokenWithDifferentApp) { + EXPECT_NE(auth_.CreateAccessToken( + UserInfo{AuthScope::kViewer, + {AuthType::kLocal, {1, 2, 3}, {4, 5, 6}}}, + {}), + auth_.CreateAccessToken( + UserInfo{AuthScope::kViewer, + {AuthType::kLocal, {1, 2, 3}, {4, 5, 7}}}, + {})); } TEST_F(AuthManagerTest, CreateTokenDifferentScope) { - EXPECT_NE(auth_.CreateAccessToken(UserInfo{AuthScope::kViewer, "456"}, {}), - auth_.CreateAccessToken(UserInfo{AuthScope::kOwner, "456"}, {})); + EXPECT_NE(auth_.CreateAccessToken( + UserInfo{AuthScope::kViewer, TestUserId{"456"}}, {}), + auth_.CreateAccessToken( + UserInfo{AuthScope::kOwner, TestUserId{"456"}}, {})); } TEST_F(AuthManagerTest, CreateTokenDifferentUser) { - EXPECT_NE(auth_.CreateAccessToken(UserInfo{AuthScope::kOwner, "456"}, {}), - auth_.CreateAccessToken(UserInfo{AuthScope::kOwner, "789"}, {})); + EXPECT_NE(auth_.CreateAccessToken( + UserInfo{AuthScope::kOwner, TestUserId{"456"}}, {}), + auth_.CreateAccessToken( + UserInfo{AuthScope::kOwner, TestUserId{"789"}}, {})); } TEST_F(AuthManagerTest, CreateTokenDifferentTime) { - auto token = auth_.CreateAccessToken(UserInfo{AuthScope::kOwner, "567"}, {}); + auto token = auth_.CreateAccessToken( + UserInfo{AuthScope::kOwner, TestUserId{"567"}}, {}); EXPECT_CALL(clock_, Now()) .WillRepeatedly(Return(base::Time::FromTimeT(1400000000))); - EXPECT_NE(token, - auth_.CreateAccessToken(UserInfo{AuthScope::kOwner, "567"}, {})); + EXPECT_NE(token, auth_.CreateAccessToken( + UserInfo{AuthScope::kOwner, TestUserId{"567"}}, {})); } TEST_F(AuthManagerTest, CreateTokenDifferentInstance) { - EXPECT_NE(auth_.CreateAccessToken(UserInfo{AuthScope::kUser, "123"}, {}), + EXPECT_NE(auth_.CreateAccessToken( + UserInfo{AuthScope::kUser, TestUserId{"123"}}, {}), AuthManager({}, {}).CreateAccessToken( - UserInfo{AuthScope::kUser, "123"}, {})); + UserInfo{AuthScope::kUser, TestUserId{"123"}}, {})); } TEST_F(AuthManagerTest, ParseAccessToken) { @@ -117,18 +164,24 @@ TEST_F(AuthManagerTest, ParseAccessToken) { AuthManager auth{{}, {}, {}, &clock_}; - auto token = auth.CreateAccessToken(UserInfo{AuthScope::kUser, "5"}, - base::TimeDelta::FromSeconds(i)); + auto token = + auth.CreateAccessToken(UserInfo{AuthScope::kUser, TestUserId{"5"}}, + base::TimeDelta::FromSeconds(i)); UserInfo user_info; EXPECT_FALSE(auth_.ParseAccessToken(token, &user_info, nullptr)); EXPECT_TRUE(auth.ParseAccessToken(token, &user_info, nullptr)); EXPECT_EQ(AuthScope::kUser, user_info.scope()); - EXPECT_EQ("5", user_info.user_id()); + EXPECT_EQ(TestUserId{"5"}, user_info.id()); EXPECT_CALL(clock_, Now()) .WillRepeatedly(Return(kStartTime + base::TimeDelta::FromSeconds(i))); EXPECT_TRUE(auth.ParseAccessToken(token, &user_info, nullptr)); + auto extended = + DelegateToUser(token, base::TimeDelta::FromSeconds(1000), + UserInfo{AuthScope::kUser, TestUserId{"234"}}); + EXPECT_FALSE(auth.ParseAccessToken(extended, &user_info, nullptr)); + EXPECT_CALL(clock_, Now()) .WillRepeatedly( Return(kStartTime + base::TimeDelta::FromSeconds(i + 1))); @@ -137,35 +190,135 @@ TEST_F(AuthManagerTest, ParseAccessToken) { } TEST_F(AuthManagerTest, GetRootClientAuthToken) { - EXPECT_EQ("UK1ACOc3cWGjGBoTIX2bd3qCQgECRgMaVArkgA==", - Base64Encode(auth_.GetRootClientAuthToken())); + EXPECT_EQ("WCCDQxkgAUYIGhudoQBCDEBQZgRhYq78I8GtFUZHNBbfGw==", + Base64Encode( + auth_.GetRootClientAuthToken(RootClientTokenOwner::kClient))); +} + +TEST_F(AuthManagerTest, GetRootClientAuthTokenDifferentOwner) { + EXPECT_EQ( + "WCqDQxkgAUYIGhudoQBMDEpnb29nbGUuY29tUOoLAxSUAZAAv54drarqhag=", + Base64Encode(auth_.GetRootClientAuthToken(RootClientTokenOwner::kCloud))); } TEST_F(AuthManagerTest, GetRootClientAuthTokenDifferentTime) { auto new_time = clock_.Now() + base::TimeDelta::FromDays(15); EXPECT_CALL(clock_, Now()).WillRepeatedly(Return(new_time)); - EXPECT_EQ("UBpNF8g/GbNUmAyHg1qqJr+CQgECRgMaVB6rAA==", - Base64Encode(auth_.GetRootClientAuthToken())); + EXPECT_EQ("WCCDQxkgAUYIGhuxZ4BCDEBQjO+OTbjjTzZ/Dvk66nfQqg==", + Base64Encode( + auth_.GetRootClientAuthToken(RootClientTokenOwner::kClient))); } TEST_F(AuthManagerTest, GetRootClientAuthTokenDifferentSecret) { AuthManager auth{kSecret2, {}, kSecret1, &clock_}; - EXPECT_EQ("UFTBUcgd9d0HnPRnLeroN2mCQgECRgMaVArkgA==", - Base64Encode(auth.GetRootClientAuthToken())); + EXPECT_EQ( + "WCCDQxkgAUYIGhudoQBCDEBQ2MZF8YXv5pbtmMxwz9VtLA==", + Base64Encode(auth.GetRootClientAuthToken(RootClientTokenOwner::kClient))); } TEST_F(AuthManagerTest, IsValidAuthToken) { - EXPECT_TRUE(auth_.IsValidAuthToken(auth_.GetRootClientAuthToken(), nullptr)); + EXPECT_TRUE(auth_.IsValidAuthToken( + auth_.GetRootClientAuthToken(RootClientTokenOwner::kClient), nullptr)); // Multiple attempts with random secrets. for (size_t i = 0; i < 1000; ++i) { AuthManager auth{{}, {}, {}, &clock_}; - auto token = auth.GetRootClientAuthToken(); + auto token = auth.GetRootClientAuthToken(RootClientTokenOwner::kClient); EXPECT_FALSE(auth_.IsValidAuthToken(token, nullptr)); EXPECT_TRUE(auth.IsValidAuthToken(token, nullptr)); } } +TEST_F(AuthManagerTest, CreateSessionId) { + EXPECT_EQ("463315200:1", auth_.CreateSessionId()); +} + +TEST_F(AuthManagerTest, IsValidSessionId) { + EXPECT_TRUE(auth_.IsValidSessionId("463315200:1")); + EXPECT_TRUE(auth_.IsValidSessionId("463315200:2")); + EXPECT_TRUE(auth_.IsValidSessionId("463315150")); + + // Future + EXPECT_FALSE(auth_.IsValidSessionId("463315230:1")); + + // Expired + EXPECT_FALSE(auth_.IsValidSessionId("463315100:1")); +} + +TEST_F(AuthManagerTest, CreateAccessTokenFromAuth) { + std::vector<uint8_t> access_token; + AuthScope scope; + base::TimeDelta ttl; + auto root = auth_.GetRootClientAuthToken(RootClientTokenOwner::kCloud); + auto extended = DelegateToUser(root, base::TimeDelta::FromSeconds(1000), + UserInfo{AuthScope::kUser, TestUserId{"234"}}); + EXPECT_EQ( + "WE+IQxkgAUYIGhudoQBMDEpnb29nbGUuY29tRggaG52hAEYFGhudpOhCAQ5FCUMyMzRNEUs0" + "NjMzMTUyMDA6MVCRVKU+0SpOoBppnwqdKMwP", + Base64Encode(extended)); + EXPECT_TRUE( + auth_.CreateAccessTokenFromAuth(extended, base::TimeDelta::FromDays(1), + &access_token, &scope, &ttl, nullptr)); + UserInfo user_info; + EXPECT_TRUE(auth_.ParseAccessToken(access_token, &user_info, nullptr)); + EXPECT_EQ(scope, user_info.scope()); + EXPECT_EQ(AuthScope::kUser, user_info.scope()); + + EXPECT_EQ(TestUserId{"234"}, user_info.id()); +} + +TEST_F(AuthManagerTest, CreateAccessTokenFromAuthNotMinted) { + std::vector<uint8_t> access_token; + auto root = auth_.GetRootClientAuthToken(RootClientTokenOwner::kClient); + ErrorPtr error; + EXPECT_FALSE(auth_.CreateAccessTokenFromAuth( + root, base::TimeDelta::FromDays(1), nullptr, nullptr, nullptr, &error)); + EXPECT_TRUE(error->HasError("invalidAuthCode")); +} + +TEST_F(AuthManagerTest, CreateAccessTokenFromAuthValidateAfterSomeTime) { + auto root = auth_.GetRootClientAuthToken(RootClientTokenOwner::kClient); + auto extended = DelegateToUser(root, base::TimeDelta::FromSeconds(1000), + UserInfo{AuthScope::kUser, TestUserId{"234"}}); + + // new_time < session_id_expiration < token_expiration. + auto new_time = clock_.Now() + base::TimeDelta::FromSeconds(15); + EXPECT_CALL(clock_, Now()).WillRepeatedly(Return(new_time)); + EXPECT_TRUE( + auth_.CreateAccessTokenFromAuth(extended, base::TimeDelta::FromDays(1), + nullptr, nullptr, nullptr, nullptr)); +} + +TEST_F(AuthManagerTest, CreateAccessTokenFromAuthExpired) { + auto root = auth_.GetRootClientAuthToken(RootClientTokenOwner::kClient); + auto extended = DelegateToUser(root, base::TimeDelta::FromSeconds(10), + UserInfo{AuthScope::kUser, TestUserId{"234"}}); + ErrorPtr error; + + // token_expiration < new_time < session_id_expiration. + auto new_time = clock_.Now() + base::TimeDelta::FromSeconds(15); + EXPECT_CALL(clock_, Now()).WillRepeatedly(Return(new_time)); + EXPECT_FALSE( + auth_.CreateAccessTokenFromAuth(extended, base::TimeDelta::FromDays(1), + nullptr, nullptr, nullptr, &error)); + EXPECT_TRUE(error->HasError("invalidAuthCode")); +} + +TEST_F(AuthManagerTest, CreateAccessTokenFromAuthExpiredSessionid) { + auto root = auth_.GetRootClientAuthToken(RootClientTokenOwner::kClient); + auto extended = DelegateToUser(root, base::TimeDelta::FromSeconds(1000), + UserInfo{AuthScope::kUser, TestUserId{"234"}}); + ErrorPtr error; + + // session_id_expiration < new_time < token_expiration. + auto new_time = clock_.Now() + base::TimeDelta::FromSeconds(200); + EXPECT_CALL(clock_, Now()).WillRepeatedly(Return(new_time)); + EXPECT_FALSE( + auth_.CreateAccessTokenFromAuth(extended, base::TimeDelta::FromDays(1), + nullptr, nullptr, nullptr, &error)); + EXPECT_TRUE(error->HasError("invalidAuthCode")); +} + class AuthManagerClaimTest : public testing::Test { public: void SetUp() override { EXPECT_EQ(auth_.GetAuthSecret().size(), 32u); } @@ -241,18 +394,5 @@ TEST_F(AuthManagerClaimTest, TokenOverflow) { EXPECT_FALSE(auth_.ConfirmClientAuthToken(token, nullptr)); } -TEST_F(AuthManagerClaimTest, CreateAccessTokenFromAuth) { - std::vector<uint8_t> access_token; - AuthScope scope; - base::TimeDelta ttl; - EXPECT_TRUE(auth_.CreateAccessTokenFromAuth( - auth_.GetRootClientAuthToken(), base::TimeDelta::FromDays(1), - &access_token, &scope, &ttl, nullptr)); - UserInfo user_info; - EXPECT_TRUE(auth_.ParseAccessToken(access_token, &user_info, nullptr)); - EXPECT_EQ(scope, user_info.scope()); - EXPECT_FALSE(user_info.user_id().empty()); -} - } // namespace privet } // namespace weave diff --git a/src/privet/cloud_delegate.cc b/src/privet/cloud_delegate.cc index 5f31fee..49fceaa 100644 --- a/src/privet/cloud_delegate.cc +++ b/src/privet/cloud_delegate.cc @@ -165,7 +165,7 @@ class CloudDelegateImpl : public CloudDelegate { const UserInfo& user_info, const CommandDoneCallback& callback) override { CHECK(user_info.scope() != AuthScope::kNone); - CHECK(!user_info.user_id().empty()); + CHECK(!user_info.id().IsEmpty()); ErrorPtr error; UserRole role; @@ -182,7 +182,7 @@ class CloudDelegateImpl : public CloudDelegate { if (!command_instance) return callback.Run({}, std::move(error)); component_manager_->AddCommand(std::move(command_instance)); - command_owners_[id] = user_info.user_id(); + command_owners_[id] = user_info.id(); callback.Run(*component_manager_->FindCommand(id)->ToJson(), nullptr); } @@ -230,7 +230,7 @@ class CloudDelegateImpl : public CloudDelegate { private: void OnCommandAdded(Command* command) { // Set to "" for any new unknown command. - command_owners_.insert(std::make_pair(command->GetID(), "")); + command_owners_.insert(std::make_pair(command->GetID(), UserAppId{})); } void OnCommandRemoved(Command* command) { @@ -309,14 +309,17 @@ class CloudDelegateImpl : public CloudDelegate { return command; } - bool CanAccessCommand(const std::string& owner_id, + bool CanAccessCommand(const UserAppId& owner, const UserInfo& user_info, ErrorPtr* error) const { CHECK(user_info.scope() != AuthScope::kNone); - CHECK(!user_info.user_id().empty()); + CHECK(!user_info.id().IsEmpty()); if (user_info.scope() == AuthScope::kManager || - owner_id == user_info.user_id()) { + (owner.type == user_info.id().type && + owner.user == user_info.id().user && + (user_info.id().app.empty() || // Token is not restricted to the app. + owner.app == user_info.id().app))) { return true; } @@ -341,7 +344,7 @@ class CloudDelegateImpl : public CloudDelegate { int registation_retry_count_{0}; // Map of command IDs to user IDs. - std::map<std::string, std::string> command_owners_; + std::map<std::string, UserAppId> command_owners_; // Backoff entry for retrying device registration. BackoffEntry backoff_entry_{®ister_backoff_policy}; diff --git a/src/privet/mock_delegates.h b/src/privet/mock_delegates.h index c75d438..c2e9a89 100644 --- a/src/privet/mock_delegates.h +++ b/src/privet/mock_delegates.h @@ -28,6 +28,11 @@ namespace weave { namespace privet { +struct TestUserId : public UserAppId { + TestUserId(const std::string& user_id) + : UserAppId{AuthType::kAnonymous, {user_id.begin(), user_id.end()}, {}} {} +}; + ACTION_TEMPLATE(RunCallback, HAS_1_TEMPLATE_PARAMS(int, k), AND_0_VALUE_PARAMS()) { @@ -103,9 +108,12 @@ class MockSecurityDelegate : public SecurityDelegate { .WillRepeatedly(Return(true)); EXPECT_CALL(*this, ParseAccessToken(_, _, _)) - .WillRepeatedly( - DoAll(SetArgPointee<1>(UserInfo{AuthScope::kViewer, "1234567"}), - Return(true))); + .WillRepeatedly(DoAll(SetArgPointee<1>(UserInfo{ + AuthScope::kViewer, + UserAppId{AuthType::kLocal, + {'1', '2', '3', '4', '5', '6', '7'}, + {}}}), + Return(true))); EXPECT_CALL(*this, GetPairingTypes()) .WillRepeatedly(Return(std::set<PairingType>{ diff --git a/src/privet/openssl_utils.cc b/src/privet/openssl_utils.cc index f38fd1a..17ebf70 100644 --- a/src/privet/openssl_utils.cc +++ b/src/privet/openssl_utils.cc @@ -18,13 +18,9 @@ namespace privet { std::vector<uint8_t> HmacSha256(const std::vector<uint8_t>& key, const std::vector<uint8_t>& data) { std::vector<uint8_t> mac(kSha256OutputSize); - uint8_t hmac_state[uw_crypto_hmac_required_buffer_size_()]; - CHECK(uw_crypto_hmac_init_(hmac_state, sizeof(hmac_state), key.data(), - key.size())); - CHECK(uw_crypto_hmac_update_(hmac_state, sizeof(hmac_state), data.data(), - data.size())); - CHECK(uw_crypto_hmac_final_(hmac_state, sizeof(hmac_state), mac.data(), - mac.size())); + const UwCryptoHmacMsg messages[] = {{data.data(), data.size()}}; + CHECK(uw_crypto_hmac_(key.data(), key.size(), messages, arraysize(messages), + mac.data(), mac.size())); return mac; } diff --git a/src/privet/privet_handler_unittest.cc b/src/privet/privet_handler_unittest.cc index fa79e77..20f5aa0 100644 --- a/src/privet/privet_handler_unittest.cc +++ b/src/privet/privet_handler_unittest.cc @@ -484,7 +484,8 @@ class PrivetHandlerTestWithAuth : public PrivetHandlerTest { auth_header_ = "Privet 123"; EXPECT_CALL(security_, ParseAccessToken(_, _, _)) .WillRepeatedly(DoAll( - SetArgPointee<1>(UserInfo{AuthScope::kOwner, "1"}), Return(true))); + SetArgPointee<1>(UserInfo{AuthScope::kOwner, TestUserId{"1"}}), + Return(true))); } }; @@ -658,7 +659,8 @@ TEST_F(PrivetHandlerSetupTest, GcdSetup) { TEST_F(PrivetHandlerSetupTest, GcdSetupAsMaster) { EXPECT_CALL(security_, ParseAccessToken(_, _, _)) .WillRepeatedly(DoAll( - SetArgPointee<1>(UserInfo{AuthScope::kManager, "1"}), Return(true))); + SetArgPointee<1>(UserInfo{AuthScope::kManager, TestUserId{"1"}}), + Return(true))); const char kInput[] = R"({ 'gcd': { 'ticketId': 'testTicket', diff --git a/src/privet/privet_types.h b/src/privet/privet_types.h index 49c4522..0f51862 100644 --- a/src/privet/privet_types.h +++ b/src/privet/privet_types.h @@ -29,17 +29,42 @@ enum class WifiType { kWifi50, }; +struct UserAppId { + UserAppId() = default; + + UserAppId(AuthType auth_type, + const std::vector<uint8_t>& user_id, + const std::vector<uint8_t>& app_id) + : type{auth_type}, + user{user_id}, + app{user_id.empty() ? user_id : app_id} {} + + bool IsEmpty() const { return user.empty(); } + + AuthType type{}; + std::vector<uint8_t> user; + std::vector<uint8_t> app; +}; + +inline bool operator==(const UserAppId& l, const UserAppId& r) { + return l.user == r.user && l.app == r.app; +} + +inline bool operator!=(const UserAppId& l, const UserAppId& r) { + return l.user != r.user || l.app != r.app; +} + class UserInfo { public: explicit UserInfo(AuthScope scope = AuthScope::kNone, - const std::string& user_id = {}) - : scope_{scope}, user_id_{scope == AuthScope::kNone ? "" : user_id} {} + const UserAppId& id = {}) + : scope_{scope}, id_{scope == AuthScope::kNone ? UserAppId{} : id} {} AuthScope scope() const { return scope_; } - const std::string& user_id() const { return user_id_; } + const UserAppId& id() const { return id_; } private: AuthScope scope_; - std::string user_id_; + UserAppId id_; }; class ConnectionState final { diff --git a/src/privet/security_manager.cc b/src/privet/security_manager.cc index 0f00699..3b08613 100644 --- a/src/privet/security_manager.cc +++ b/src/privet/security_manager.cc @@ -91,9 +91,10 @@ bool SecurityManager::CreateAccessTokenImpl(AuthType auth_type, std::vector<uint8_t>* access_token, AuthScope* access_token_scope, base::TimeDelta* access_token_ttl) { - UserInfo user_info{desired_scope, - std::to_string(static_cast<int>(auth_type)) + "/" + - std::to_string(++last_user_id_)}; + auto user_id = std::to_string(++last_user_id_); + UserInfo user_info{ + desired_scope, + UserAppId{auth_type, {user_id.begin(), user_id.end()}, {}}}; const base::TimeDelta kTtl = base::TimeDelta::FromSeconds(kAccessTokenExpirationSeconds); @@ -388,7 +389,7 @@ bool SecurityManager::CancelPairing(const std::string& session_id, } std::string SecurityManager::CreateSessionId() { - return Base64Encode(auth_manager_->CreateSessionId()); + return auth_manager_->CreateSessionId(); } void SecurityManager::RegisterPairingListeners( diff --git a/src/privet/security_manager_unittest.cc b/src/privet/security_manager_unittest.cc index 43b7f00..f596de9 100644 --- a/src/privet/security_manager_unittest.cc +++ b/src/privet/security_manager_unittest.cc @@ -25,6 +25,7 @@ #include "src/config.h" #include "src/data_encoding.h" #include "src/privet/auth_manager.h" +#include "src/privet/mock_delegates.h" #include "src/privet/openssl_utils.h" #include "src/test/mock_clock.h" #include "third_party/chromium/crypto/p224_spake.h" @@ -170,7 +171,7 @@ TEST_F(SecurityManagerTest, AccessToken) { UserInfo info; EXPECT_TRUE(security_.ParseAccessToken(token, &info, nullptr)); EXPECT_EQ(requested_scope, info.scope()); - EXPECT_EQ("0/" + std::to_string(i), info.user_id()); + EXPECT_EQ(TestUserId{std::to_string(i)}, info.id()); } } |