Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 81 additions & 1 deletion bits_helpers/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def resolve_spec_data(spec, data, defaults, branch_basename="", branch_stream=""
# final: %%(%(v1)s_key)s
# "final" will have the value "bar" (first expanded to "%(foo_key)s" and
# then to value of "foo_key" i.e. "bar")
while re.search("\%\([a-zA-Z][a-zA-Z0-9_]*\)s", data):
while re.search(r"\%\([a-zA-Z][a-zA-Z0-9_]*\)s", data):
data = data % all_vars
return data

Expand Down Expand Up @@ -421,11 +421,16 @@ def represent_ordereddict(dumper, data):

def parseRecipe(reader):
assert(reader.__call__)
filename = os.path.basename(getattr(reader, "url", None) or "")[:-3] if (getattr(reader, "url", None) or "").endswith(".sh") else os.path.basename(getattr(reader, "url", None) or "")
err, spec, recipe = (None, None, None)
try:
d = reader()
header,recipe = d.split("---", 1)
spec = yamlLoad(header)
if not filename:
filename = spec["package"]
if "from" in spec:
spec = getSpecFromDir(spec, filename, os.path.join(os.environ.get("BITS_REPO_DIR", ""), spec["from"]))
validateSpec(spec)
except RuntimeError as e:
err = str(e)
Expand Down Expand Up @@ -728,6 +733,81 @@ def getGeneratedPackages(configDir):
x=sys.path.pop(0)
return pkgs

def getSpecFromDir(override_spec, pkg, configDir, visited=None):
if visited is None:
visited = set()
if len(visited) >= len(getConfigPaths(os.environ.get("BITS_REPO_DIR"))):
raise RuntimeError("Circular dependency detected")
genPackages = getGeneratedPackages(configDir)
filename, pkgdir = resolveFilename({}, pkg, configDir, genPackages)
if pkgdir in visited:
raise RuntimeError("Circular dependency detected")
visited.add(pkgdir)
reader = getRecipeReader(filename, configDir, genPackages)
d = reader()
# Handle auto-generated packages that may not have "---" separator
if "---" in d:
header, recipe = d.split("---", 1)
else:
# For auto-generated packages, treat entire content as header
header = d
recipe = ""
spec = yamlLoad(header)
if "from" in spec:
new_config_dir = os.path.join(os.path.dirname(configDir), spec["from"])
final_base = getSpecFromDir(spec, pkg, new_config_dir, visited)
return handleMergePolicy(override_spec, final_base)
return handleMergePolicy(override_spec, spec)

def handleMergePolicy(override_spec, final_base):
mergePolicy = override_spec.get("merge_policy", {})
remove_keys = mergePolicy.get("remove", [])
if isinstance(remove_keys, str):
remove_keys = remove_keys.replace(" ", "").split(",")
for k in remove_keys:
if k in final_base:
final_base.pop(k, None)
merge_keys = mergePolicy.get("merge", [])
if isinstance(merge_keys, str):
merge_keys = merge_keys.replace(" ", "").split(",")
override_spec.pop("merge_policy", None)
override_spec.pop("from", None)

for key in merge_keys:
if key not in override_spec:
continue
if key not in final_base:
final_base[key] = override_spec[key]
else:
if isinstance(final_base[key], OrderedDict) and isinstance(override_spec[key], OrderedDict):
merged = final_base[key].copy()
merged.update(override_spec[key])
final_base[key] = merged
else:
raise ValueError(f"Merge key not allowed for {key} as it's of type {type(final_base.get(key, 'unknown'))}")
for k, v in override_spec.items():
final_base[k] = override_spec[k]
return final_base

def getRecipeFromDir(pkg, configDir, visited=None):
if visited is None:
visited = set()
if len(visited) >= len(getConfigPaths(os.environ.get("BITS_REPO_DIR"))):
raise RuntimeError("Circular dependency detected")
genPackages = getGeneratedPackages(configDir)
filename, pkgdir = resolveFilename({}, pkg, configDir, genPackages)
if pkgdir in visited:
raise RuntimeError("Circular dependency detected")
visited.add(pkgdir)
reader = getRecipeReader(filename, configDir, genPackages)
d = reader()
header, recipe = d.split("---", 1)
spec = yamlLoad(header)
if "inherits_body" in spec:
new_config_dir = os.path.join(os.path.dirname(configDir), spec["inherits_body"])
return getRecipeFromDir(pkg, new_config_dir, visited)
return recipe

class Hasher:
def __init__(self) -> None:
self.h = hashlib.sha1()
Expand Down
102 changes: 102 additions & 0 deletions tests/test_getspecfromdir_autogenerated.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
#!/usr/bin/env python3
"""
Unit test for the getSpecFromDir fix for auto-generated packages.
"""

import unittest
import os
import sys
import tempfile
from collections import OrderedDict
from unittest.mock import patch

# Add the bits_helpers to the path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

from bits_helpers.utilities import getSpecFromDir, getGeneratedPackages

class TestGetSpecFromDirAutoGenerated(unittest.TestCase):
"""Test cases for getSpecFromDir with auto-generated packages."""

def setUp(self):
"""Set up temporary directories for testing."""
self.temp_dir = tempfile.mkdtemp()
self.config_dir = os.path.join(self.temp_dir, "config")
os.makedirs(self.config_dir)
self.pkg_dir = os.path.join(self.config_dir, "autogen")
os.makedirs(self.pkg_dir)

# Set environment variable
self.old_bits_repo_dir = os.environ.get("BITS_REPO_DIR")
os.environ["BITS_REPO_DIR"] = self.temp_dir

def tearDown(self):
"""Clean up temporary directories."""
import shutil
shutil.rmtree(self.temp_dir)

# Restore environment variable
if self.old_bits_repo_dir is not None:
os.environ["BITS_REPO_DIR"] = self.old_bits_repo_dir
else:
os.environ.pop("BITS_REPO_DIR", None)

def test_autogenerated_package_without_separator(self):
"""Test getSpecFromDir with auto-generated package that has no '---' separator."""
# Create a packages.py that generates a package with only header
packages_py = os.path.join(self.pkg_dir, "packages.py")
with open(packages_py, 'w') as f:
f.write('''
def getPackages(pkgs, pkgdir):
pkgs["testpkg"] = {
"command": "echo 'package: testpkg\\nversion: v1.0\\ntag: v1.0'",
"version": "v1.0",
"pkgdir": pkgdir
}
''')

# Create a spec that uses "from" to inherit from the auto-generated package
override_spec = OrderedDict([
("package", "testpkg"),
("from", "autogen"),
("version", "v2.0")
])

# This should work without throwing an exception
result = getSpecFromDir(override_spec, "testpkg", self.config_dir)

# Verify the result
self.assertEqual(result.get("package"), "testpkg")
self.assertEqual(result.get("version"), "v2.0") # Override should be applied
self.assertEqual(result.get("tag"), "v1.0") # Base should be inherited

def test_autogenerated_package_with_separator(self):
"""Test getSpecFromDir with auto-generated package that has '---' separator."""
# Create a packages.py that generates a package with header and body
packages_py = os.path.join(self.pkg_dir, "packages.py")
with open(packages_py, 'w') as f:
f.write('''
def getPackages(pkgs, pkgdir):
pkgs["testpkg2"] = {
"command": "echo 'package: testpkg2\\nversion: v1.0\\n---\\necho \\"Building package\\"'",
"version": "v1.0",
"pkgdir": pkgdir
}
''')

# Create a spec that uses "from" to inherit from the auto-generated package
override_spec = OrderedDict([
("package", "testpkg2"),
("from", "autogen"),
("version", "v2.0")
])

# This should work with packages that have "---" separator
result = getSpecFromDir(override_spec, "testpkg2", self.config_dir)

# Verify the result
self.assertEqual(result.get("package"), "testpkg2")
self.assertEqual(result.get("version"), "v2.0") # Override should be applied

if __name__ == '__main__':
unittest.main()