diff options
Diffstat (limited to 'gae/webapp/src/endpoint/endpoint_base_test.py')
-rw-r--r-- | gae/webapp/src/endpoint/endpoint_base_test.py | 256 |
1 files changed, 256 insertions, 0 deletions
diff --git a/gae/webapp/src/endpoint/endpoint_base_test.py b/gae/webapp/src/endpoint/endpoint_base_test.py new file mode 100644 index 0000000..2eb397a --- /dev/null +++ b/gae/webapp/src/endpoint/endpoint_base_test.py @@ -0,0 +1,256 @@ +#!/usr/bin/env python +# +# Copyright (C) 2018 The Android Open Source Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import endpoints +import json +import unittest + +try: + from unittest import mock +except ImportError: + import mock + +from webapp.src import vtslab_status as Status +from webapp.src.endpoint import endpoint_base +from webapp.src.proto import model +from webapp.src.testing import unittest_base + + +class EndpointBaseTest(unittest_base.UnitTestBase): + """A class to test endpoint_base.EndpointBase class. + + Attributes: + eb: An EndpointBase class instance. + """ + + def setUp(self): + """Initializes test""" + super(EndpointBaseTest, self).setUp() + self.eb = endpoint_base.EndpointBase() + + def testGetAssignedMessagesAttributes(self): + attrs = ["hostname", "priority", "test_branch"] + job_message = model.JobMessage() + for attr in attrs: + setattr(job_message, attr, attr) + result = self.eb.GetAttributes(job_message, assigned_only=True) + self.assertEqual(set(attrs), set(result)) + + def testGetAssignedModelAttributes(self): + attrs = ["hostname", "priority", "test_branch"] + job = model.JobModel() + for attr in attrs: + setattr(job, attr, attr) + result = self.eb.GetAttributes(job, assigned_only=True) + self.assertEqual(set(attrs), set(result)) + + def testGetAllMessagesAttributes(self): + attrs = ["hostname", "priority", "test_branch"] + full_attrs = [ + "test_type", "hostname", "priority", "test_name", + "require_signed_device_build", "has_bootloader_img", + "has_radio_img", "device", "serial", "build_storage_type", + "manifest_branch", "build_target", "build_id", "pab_account_id", + "shards", "param", "status", "period", "gsi_storage_type", + "gsi_branch", "gsi_build_target", "gsi_build_id", + "gsi_pab_account_id", "gsi_vendor_version", "test_storage_type", + "test_branch", "test_build_target", "test_build_id", + "test_pab_account_id", "retry_count", "infra_log_url", + "image_package_repo_base", "report_bucket", + "report_spreadsheet_id", "report_persistent_url", + "report_reference_url" + ] + job_message = model.JobMessage() + for attr in attrs: + setattr(job_message, attr, attr) + result = self.eb.GetAttributes(job_message, assigned_only=False) + self.assertTrue(set(full_attrs) <= set(result)) + + def testGetAllModelAttributes(self): + attrs = ["hostname", "priority", "test_branch"] + full_attrs = [ + "test_type", "hostname", "priority", "test_name", + "require_signed_device_build", "has_bootloader_img", + "has_radio_img", "device", "serial", "build_storage_type", + "manifest_branch", "build_target", "build_id", "pab_account_id", + "shards", "param", "status", "period", "gsi_storage_type", + "gsi_branch", "gsi_build_target", "gsi_build_id", + "gsi_pab_account_id", "gsi_vendor_version", "test_storage_type", + "test_branch", "test_build_target", "test_build_id", + "test_pab_account_id", "timestamp", "heartbeat_stamp", + "retry_count", "infra_log_url", "parent_schedule", + "image_package_repo_base", "report_bucket", + "report_spreadsheet_id", "report_persistent_url", + "report_reference_url" + ] + job = model.JobModel() + for attr in attrs: + setattr(job, attr, attr) + result = self.eb.GetAttributes(job, assigned_only=False) + self.assertTrue(set(full_attrs) <= set(result)) + + def testGetSingleEntity(self): + """Asserts to get a single entity.""" + device = self.GenerateDeviceModel() + device.put() + + request_body = (endpoints.ResourceContainer( + model.GetRequestMessage).combined_message_class( + size=0, + offset=0, + filter="", + sort="", + direction="", + )) + result, more = self.eb.Get( + request=request_body, + metaclass=model.DeviceModel, + message=model.DeviceInfoMessage) + self.assertEqual(len(result), 1) + self.assertFalse(more) + + def testGetHundredEntities(self): + """Asserts to get hundred entities.""" + for _ in xrange(100): + device = self.GenerateDeviceModel() + device.put() + + request_body = (endpoints.ResourceContainer( + model.GetRequestMessage).combined_message_class( + size=0, + offset=0, + filter="", + sort="", + direction="", + )) + result, more = self.eb.Get( + request=request_body, + metaclass=model.DeviceModel, + message=model.DeviceInfoMessage) + self.assertEqual(len(result), 100) + self.assertFalse(more) + + def testGetEntitiesWithPagination(self): + """Asserts to get entities with pagination.""" + for _ in xrange(100): + device = self.GenerateDeviceModel() + device.put() + + request_body = (endpoints.ResourceContainer( + model.GetRequestMessage).combined_message_class( + size=60, + offset=0, + filter="", + sort="", + direction="", + )) + result, more = self.eb.Get( + request=request_body, + metaclass=model.DeviceModel, + message=model.DeviceInfoMessage) + self.assertEqual(len(result), 60) + self.assertTrue(more) + + request_body = (endpoints.ResourceContainer( + model.GetRequestMessage).combined_message_class( + size=100, + offset=60, + filter="", + sort="", + direction="", + )) + result, more = self.eb.Get( + request=request_body, + metaclass=model.DeviceModel, + message=model.DeviceInfoMessage) + self.assertEqual(len(result), 40) + self.assertFalse(more) + + def testGetWithFilter(self): + """Asserts to get entities with filter.""" + for _ in xrange(50): + device = self.GenerateDeviceModel() + device.put() + + for _ in xrange(50): + device = self.GenerateDeviceModel(product="product") + device.put() + + filter = [{ + "key": "product", + "method": Status.FILTER_METHOD[Status.FILTER_EqualTo], + "value": "product" + }] + filter_string = json.dumps(filter) + request_body = (endpoints.ResourceContainer( + model.GetRequestMessage).combined_message_class( + size=0, + offset=0, + filter=filter_string, + sort="", + direction="", + )) + result, more = self.eb.Get( + request=request_body, + metaclass=model.DeviceModel, + message=model.DeviceInfoMessage) + self.assertEqual(len(result), 50) + self.assertFalse(more) + + def testGetWithSort(self): + """Asserts to get entities with sort.""" + for _ in xrange(100): + device = self.GenerateDeviceModel() + device.put() + + request_body = (endpoints.ResourceContainer( + model.GetRequestMessage).combined_message_class( + size=0, + offset=0, + filter="", + sort="serial", + direction="asc", + )) + + result, more = self.eb.Get( + request=request_body, + metaclass=model.DeviceModel, + message=model.DeviceInfoMessage) + self.assertEqual(len(result), 100) + for i in xrange(len(result) - 1): + self.assertTrue(result[i]["serial"] < result[i + 1]["serial"]) + + request_body = (endpoints.ResourceContainer( + model.GetRequestMessage).combined_message_class( + size=0, + offset=0, + filter="", + sort="serial", + direction="desc", + )) + + result, more = self.eb.Get( + request=request_body, + metaclass=model.DeviceModel, + message=model.DeviceInfoMessage) + self.assertEqual(len(result), 100) + for i in xrange(len(result) - 1): + self.assertTrue(result[i]["serial"] > result[i + 1]["serial"]) + + +if __name__ == "__main__": + unittest.main() |