# blob: 721d048d65d5b1e1e384079419e6dd5b7185bb2a
"""DataManager to share data between steps."""
from __future__ import absolute_import
import copy
import six
from helpers import repo_utils
class DataManager(object):
    """Manages passing data between build steps.

    Holds three kinds of state:
      * an in-memory key/value store, written with add_data() and read
        with get_data();
      * the properties handed to the constructor, read with
        get_property();
      * per-project lookup tables (path, remote, revision) that are built
        lazily through repo_utils and dropped by clear_project_info().
    """

    def __init__(self, executor, properties):
        """Creates an empty data store with no project lookups loaded.

        Args:
          executor: object passed unchanged to every repo_utils call.
          properties: dict-like object queried with .get(), or a falsy
            value (e.g. None) when there are no properties.
        """
        self._data_store = {}
        self._executor = executor
        self._project_path_lookup = None
        self._project_remote_lookup = None
        self._project_revision_lookup = None
        self._properties = properties

    def add_data(self, key, data, replace=False):
        """Adds data to be retrieved later by other DataManager consumers.

        Adding a value equal to the one already stored under |key| is a
        no-op, whatever |replace| is.

        Args:
          key: string value to associate the stored data with. Must not
            already have a different value stored with that key unless
            |replace| is True.
          data: any value, not cloned.
          replace: if True, any data stored in |key| will be replaced.
            if False, a KeyError is raised if |key| already has a value.
            (defaults to False)

        Raises:
          KeyError: if the key already has a different associated value
            and replace is False.
        """
        assert isinstance(key, six.string_types)
        # A missing key reads as None here, so adding None under an unused
        # key returns without storing anything.
        if self._data_store.get(key) == data:
            return
        if key in self._data_store and not replace:
            raise KeyError(key)
        self._data_store[key] = data

    def clear_project_info(self):
        """Clears the project info to be reset in the next run."""
        self._project_path_lookup = None
        self._project_remote_lookup = None
        self._project_revision_lookup = None

    def contains_project(self, project):
        """Boolean if the given project is included in the repo client.

        Only the path lookup table is built here; the remote and revision
        tables are left for ensure_project_info().

        Args:
          project: The name of the project to look for.

        Returns:
          True if |project| is a key of the path lookup table.
        """
        if not self._project_path_lookup:
            self._project_path_lookup = repo_utils.get_project_path_lookup(
                self._executor, self.manifest_url)
        return project in self._project_path_lookup

    def ensure_project_info(self):
        """Initializes all project lookup tables.

        Does nothing when the path table is non-empty and the remote and
        revision tables have both been set. Otherwise the manifest content
        is fetched once and passed to each of the three lookup builders.
        """
        # contains_project() fills in the path table only, so a populated
        # path table does not imply the other two exist. Checking all three
        # keeps get_project_remote()/get_project_revision() from calling
        # .get() on None after a contains_project() call.
        if (self._project_path_lookup
                and self._project_remote_lookup is not None
                and self._project_revision_lookup is not None):
            return
        manifest_content = repo_utils.get_manifest(self._executor)
        self._project_path_lookup = repo_utils.get_project_path_lookup(
            self._executor, self.manifest_url, manifest=manifest_content)
        self._project_remote_lookup = repo_utils.get_project_remote_lookup(
            self._executor, self.manifest_url, manifest=manifest_content)
        self._project_revision_lookup = (
            repo_utils.get_project_revision_lookup(
                self._executor, self.manifest_url, self.manifest_branch,
                manifest=manifest_content))

    def get_data(self, key, default=None):
        """Returns previously stored data associated with a known key.

        Args:
          key: string key the data is associated with.
          default: Default return value if the key is not found.

        Returns:
          The value associated with the provided key, or |default| when
          nothing is stored under it.
        """
        assert isinstance(key, six.string_types)
        return self._data_store.get(key, default)

    def get_project_path(self, project):
        """Returns the path to the project.

        Args:
          project: The name of the project (usually patch_project) to
            lookup.

        Returns:
          The relative path from the base checkout to the project, or None
          if the project is not in the path lookup table.
        """
        self.ensure_project_info()
        return self._project_path_lookup.get(project)

    def get_project_path_lookup_table(self):
        """Returns the project path lookup table.

        The internal table itself is returned, not a copy.
        """
        self.ensure_project_info()
        return self._project_path_lookup

    def get_project_remote(self, project):
        """Returns the remote for the project.

        Args:
          project: The name of the project (usually patch_project) to
            lookup.

        Returns:
          The remote for the project, or None if the project is not in the
          remote lookup table.
        """
        self.ensure_project_info()
        return self._project_remote_lookup.get(project)

    def get_project_revision(self, project):
        """Returns the revision for the project.

        Args:
          project: The name of the project (usually patch_project) to
            lookup.

        Returns:
          The revision for the project, or None if the project is not in
          the revision lookup table.
        """
        self.ensure_project_info()
        return self._project_revision_lookup.get(project)

    def get_property(self, name, default=None):
        """Returns the requested property value, or default if not found.

        Args:
          name: The property name to look up.
          default: Value returned when there are no properties at all or
            |name| is not among them.
        """
        if not self._properties:
            return default
        return self._properties.get(name, default)

    @property
    def manifest_branch(self):
        """Returns the manifest branch to use for this step.

        Read from the 'manifest_branch' property; None when it is unset.
        """
        return self.get_property('manifest_branch')

    @property
    def manifest_url(self):
        """Returns the manifest url to use for this step.

        Read from the 'manifest_url' property; None when it is unset.
        """
        return self.get_property('manifest_url')
class DataManagerMock(DataManager):
    """DataManager variant for tests that exposes the stored data."""

    def get_all_data(self):
        """Returns a deep copy of the whole internal data store.

        Returns:
          A copy made with copy.deepcopy, so changes to it do not reach
          the data held by this object.
        """
        snapshot = copy.deepcopy(self._data_store)
        return snapshot