summaryrefslogtreecommitdiff
path: root/gae/webapp/src/endpoint/endpoint_base_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'gae/webapp/src/endpoint/endpoint_base_test.py')
-rw-r--r--gae/webapp/src/endpoint/endpoint_base_test.py256
1 files changed, 256 insertions, 0 deletions
diff --git a/gae/webapp/src/endpoint/endpoint_base_test.py b/gae/webapp/src/endpoint/endpoint_base_test.py
new file mode 100644
index 0000000..2eb397a
--- /dev/null
+++ b/gae/webapp/src/endpoint/endpoint_base_test.py
@@ -0,0 +1,256 @@
+#!/usr/bin/env python
+#
+# Copyright (C) 2018 The Android Open Source Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import endpoints
+import json
+import unittest
+
+try:
+ from unittest import mock
+except ImportError:
+ import mock
+
+from webapp.src import vtslab_status as Status
+from webapp.src.endpoint import endpoint_base
+from webapp.src.proto import model
+from webapp.src.testing import unittest_base
+
+
+class EndpointBaseTest(unittest_base.UnitTestBase):
+ """A class to test endpoint_base.EndpointBase class.
+
+ Attributes:
+ eb: An EndpointBase class instance.
+ """
+
+ def setUp(self):
+ """Initializes test"""
+ super(EndpointBaseTest, self).setUp()
+ self.eb = endpoint_base.EndpointBase()
+
+ def testGetAssignedMessagesAttributes(self):
+ attrs = ["hostname", "priority", "test_branch"]
+ job_message = model.JobMessage()
+ for attr in attrs:
+ setattr(job_message, attr, attr)
+ result = self.eb.GetAttributes(job_message, assigned_only=True)
+ self.assertEqual(set(attrs), set(result))
+
+ def testGetAssignedModelAttributes(self):
+ attrs = ["hostname", "priority", "test_branch"]
+ job = model.JobModel()
+ for attr in attrs:
+ setattr(job, attr, attr)
+ result = self.eb.GetAttributes(job, assigned_only=True)
+ self.assertEqual(set(attrs), set(result))
+
+ def testGetAllMessagesAttributes(self):
+ attrs = ["hostname", "priority", "test_branch"]
+ full_attrs = [
+ "test_type", "hostname", "priority", "test_name",
+ "require_signed_device_build", "has_bootloader_img",
+ "has_radio_img", "device", "serial", "build_storage_type",
+ "manifest_branch", "build_target", "build_id", "pab_account_id",
+ "shards", "param", "status", "period", "gsi_storage_type",
+ "gsi_branch", "gsi_build_target", "gsi_build_id",
+ "gsi_pab_account_id", "gsi_vendor_version", "test_storage_type",
+ "test_branch", "test_build_target", "test_build_id",
+ "test_pab_account_id", "retry_count", "infra_log_url",
+ "image_package_repo_base", "report_bucket",
+ "report_spreadsheet_id", "report_persistent_url",
+ "report_reference_url"
+ ]
+ job_message = model.JobMessage()
+ for attr in attrs:
+ setattr(job_message, attr, attr)
+ result = self.eb.GetAttributes(job_message, assigned_only=False)
+ self.assertTrue(set(full_attrs) <= set(result))
+
+ def testGetAllModelAttributes(self):
+ attrs = ["hostname", "priority", "test_branch"]
+ full_attrs = [
+ "test_type", "hostname", "priority", "test_name",
+ "require_signed_device_build", "has_bootloader_img",
+ "has_radio_img", "device", "serial", "build_storage_type",
+ "manifest_branch", "build_target", "build_id", "pab_account_id",
+ "shards", "param", "status", "period", "gsi_storage_type",
+ "gsi_branch", "gsi_build_target", "gsi_build_id",
+ "gsi_pab_account_id", "gsi_vendor_version", "test_storage_type",
+ "test_branch", "test_build_target", "test_build_id",
+ "test_pab_account_id", "timestamp", "heartbeat_stamp",
+ "retry_count", "infra_log_url", "parent_schedule",
+ "image_package_repo_base", "report_bucket",
+ "report_spreadsheet_id", "report_persistent_url",
+ "report_reference_url"
+ ]
+ job = model.JobModel()
+ for attr in attrs:
+ setattr(job, attr, attr)
+ result = self.eb.GetAttributes(job, assigned_only=False)
+ self.assertTrue(set(full_attrs) <= set(result))
+
+ def testGetSingleEntity(self):
+ """Asserts to get a single entity."""
+ device = self.GenerateDeviceModel()
+ device.put()
+
+ request_body = (endpoints.ResourceContainer(
+ model.GetRequestMessage).combined_message_class(
+ size=0,
+ offset=0,
+ filter="",
+ sort="",
+ direction="",
+ ))
+ result, more = self.eb.Get(
+ request=request_body,
+ metaclass=model.DeviceModel,
+ message=model.DeviceInfoMessage)
+ self.assertEqual(len(result), 1)
+ self.assertFalse(more)
+
+ def testGetHundredEntities(self):
+ """Asserts to get hundred entities."""
+ for _ in xrange(100):
+ device = self.GenerateDeviceModel()
+ device.put()
+
+ request_body = (endpoints.ResourceContainer(
+ model.GetRequestMessage).combined_message_class(
+ size=0,
+ offset=0,
+ filter="",
+ sort="",
+ direction="",
+ ))
+ result, more = self.eb.Get(
+ request=request_body,
+ metaclass=model.DeviceModel,
+ message=model.DeviceInfoMessage)
+ self.assertEqual(len(result), 100)
+ self.assertFalse(more)
+
+ def testGetEntitiesWithPagination(self):
+ """Asserts to get entities with pagination."""
+ for _ in xrange(100):
+ device = self.GenerateDeviceModel()
+ device.put()
+
+ request_body = (endpoints.ResourceContainer(
+ model.GetRequestMessage).combined_message_class(
+ size=60,
+ offset=0,
+ filter="",
+ sort="",
+ direction="",
+ ))
+ result, more = self.eb.Get(
+ request=request_body,
+ metaclass=model.DeviceModel,
+ message=model.DeviceInfoMessage)
+ self.assertEqual(len(result), 60)
+ self.assertTrue(more)
+
+ request_body = (endpoints.ResourceContainer(
+ model.GetRequestMessage).combined_message_class(
+ size=100,
+ offset=60,
+ filter="",
+ sort="",
+ direction="",
+ ))
+ result, more = self.eb.Get(
+ request=request_body,
+ metaclass=model.DeviceModel,
+ message=model.DeviceInfoMessage)
+ self.assertEqual(len(result), 40)
+ self.assertFalse(more)
+
+ def testGetWithFilter(self):
+ """Asserts to get entities with filter."""
+ for _ in xrange(50):
+ device = self.GenerateDeviceModel()
+ device.put()
+
+ for _ in xrange(50):
+ device = self.GenerateDeviceModel(product="product")
+ device.put()
+
+ filter = [{
+ "key": "product",
+ "method": Status.FILTER_METHOD[Status.FILTER_EqualTo],
+ "value": "product"
+ }]
+ filter_string = json.dumps(filter)
+ request_body = (endpoints.ResourceContainer(
+ model.GetRequestMessage).combined_message_class(
+ size=0,
+ offset=0,
+ filter=filter_string,
+ sort="",
+ direction="",
+ ))
+ result, more = self.eb.Get(
+ request=request_body,
+ metaclass=model.DeviceModel,
+ message=model.DeviceInfoMessage)
+ self.assertEqual(len(result), 50)
+ self.assertFalse(more)
+
+ def testGetWithSort(self):
+ """Asserts to get entities with sort."""
+ for _ in xrange(100):
+ device = self.GenerateDeviceModel()
+ device.put()
+
+ request_body = (endpoints.ResourceContainer(
+ model.GetRequestMessage).combined_message_class(
+ size=0,
+ offset=0,
+ filter="",
+ sort="serial",
+ direction="asc",
+ ))
+
+ result, more = self.eb.Get(
+ request=request_body,
+ metaclass=model.DeviceModel,
+ message=model.DeviceInfoMessage)
+ self.assertEqual(len(result), 100)
+ for i in xrange(len(result) - 1):
+ self.assertTrue(result[i]["serial"] < result[i + 1]["serial"])
+
+ request_body = (endpoints.ResourceContainer(
+ model.GetRequestMessage).combined_message_class(
+ size=0,
+ offset=0,
+ filter="",
+ sort="serial",
+ direction="desc",
+ ))
+
+ result, more = self.eb.Get(
+ request=request_body,
+ metaclass=model.DeviceModel,
+ message=model.DeviceInfoMessage)
+ self.assertEqual(len(result), 100)
+ for i in xrange(len(result) - 1):
+ self.assertTrue(result[i]["serial"] > result[i + 1]["serial"])
+
+
+if __name__ == "__main__":
+ unittest.main()