diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 59c65dc47..febb49b03 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -24,5 +24,26 @@ jobs: pip install --upgrade pip pip install --upgrade setuptools make install-python - - name: Test - run: make test + - name: Start services + run: | + make start-containers + make docker-age + + - name: Test (Neo4j Backend) + env: + GRAPH_DB_TYPE: neo4j + run: | + make test + + - name: Test (Apache AGE Backend) + env: + GRAPH_DB_TYPE: age + run: | + make test + + - name: Stop services + if: always() + run: | + make docker-neo4j-rm + make docker-redis-rm + make docker-age-rm diff --git a/Makefile b/Makefile index fa11162e7..bd3505757 100644 --- a/Makefile +++ b/Makefile @@ -15,6 +15,13 @@ docker-neo4j-rm: docker-neo4j: docker start cre-neo4j 2>/dev/null || docker run -d --name cre-neo4j --env NEO4J_PLUGINS='["apoc"]' --env NEO4J_AUTH=neo4j/password --volume=`pwd`/.neo4j/data:/data --volume=`pwd`/.neo4j/logs:/logs --workdir=/var/lib/neo4j -p 7474:7474 -p 7687:7687 neo4j +docker-age-rm: + docker stop cre-age + docker rm -f cre-age + +docker-age: + docker start cre-age 2>/dev/null || docker run -d --name cre-age -p 5433:5432 -e POSTGRES_PASSWORD=password apache/age:latest + docker-redis-rm: docker stop cre-redis-stack docker rm -f cre-redis-stack @@ -123,9 +130,11 @@ import-projects: import-all: $(shell bash ./scripts/import-all.sh) -import-neo4j: +import-graph: [ -d "./venv" ] && . ./venv/bin/activate &&\ - export FLASK_APP="$(CURDIR)/cre.py" && python cre.py --populate_neo4j_db + export FLASK_APP="$(CURDIR)/cre.py" && python cre.py --populate_graph_db + +import-neo4j: import-graph preload-map-analysis: $(shell RUN_COUNT=5 bash ./scripts/preload_gap_analysis.sh) diff --git a/application/cmd/cre_main.py b/application/cmd/cre_main.py index ead5a4281..cfc1e7556 100644 --- a/application/cmd/cre_main.py +++ b/application/cmd/cre_main.py @@ -1,722 +1,727 @@ -import time -import argparse -import json -import logging -import os -import shutil -import yaml -import tempfile -import requests - -from typing import Any, Callable, Dict, Generator, List, Optional, Tuple -from rq import Queue, job, exceptions -from dacite import from_dict -from dacite.config import Config - -from application.utils.external_project_parsers.base_parser import BaseParser -from application import create_app # type: ignore -from application.config import CMDConfig -from application.database import db -from application.defs import cre_defs as defs -from application.defs import osib_defs as odefs -from application.utils import spreadsheet as sheet_utils -from application.utils import redis -from application.utils import spreadsheet_parsers -from alive_progress import alive_bar -from application.prompt_client import prompt_client as prompt_client -from application.utils import gap_analysis - -logging.basicConfig() -logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) - -app = None - - -def register_node(node: defs.Node, collection: db.Node_collection) -> db.Node: - """ - for each link find if either the root node or the link have a CRE, - then map the one who doesn't to the CRE - if both don't map to anything, just add them in the db as unlinked nodes - """ - if not node or not issubclass(node.__class__, defs.Node): - raise ValueError(f"node is None or not of type Node, node: {node}") - - linked_node = collection.add_node(node) - if node.embeddings: - collection.add_embedding( - linked_node, - doctype=node.doctype, - embeddings=node.embeddings, - embedding_text=node.embeddings_text, - ) - cre_less_nodes: List[defs.Node] = [] - - # we need to know the cres added in case we encounter a higher level CRE, - # in which case we get the higher level CRE to link to these cres - cres_added = [] - - for link in node.links: - if type(link.document).__name__ in [ - defs.Standard.__name__, - defs.Code.__name__, - defs.Tool.__name__, - ]: - # if a node links another node it is likely that a writer wants to reference something - # in that case, find which of the two nodes has at least one CRE attached to it and link both to the parent CRE - cres = collection.find_cres_of_node(link.document) - db_link = collection.add_node(link.document) - if cres: - for cre in cres: - collection.add_link(cre=cre, node=linked_node, ltype=link.ltype) - for unlinked_standard in cre_less_nodes: # if anything in this - collection.add_link( - cre=cre, - node=db.dbNodeFromNode(unlinked_standard), - ltype=link.ltype, - ) - else: - cres = collection.find_cres_of_node(linked_node) - if cres: - for cre in cres: - collection.add_link(cre=cre, node=db_link, ltype=link.ltype) - for unlinked_node in cre_less_nodes: - collection.add_link( - cre=cre, - node=db.dbNodeFromNode(unlinked_node), - ltype=link.ltype, - ) - else: # if neither the root nor a linked node has a CRE, add both as unlinked nodes - cre_less_nodes.append(link.document) - - if link.document.links and len(link.document.links) > 0: - register_node(node=link.document, collection=collection) - - elif type(link.document).__name__ == defs.CRE.__name__: - # dbcre,_ = register_cre(link.document, collection) # CREs are idempotent - c = collection.get_CREs(name=link.document.name)[0] - dbcre = db.dbCREfromCRE(c) - collection.add_link(dbcre, linked_node, ltype=link.ltype) - cres_added.append(dbcre) - for unlinked_standard in cre_less_nodes: # if anything in this - collection.add_link( - cre=dbcre, - node=db.dbNodeFromNode(unlinked_standard), - ltype=link.ltype, - ) - cre_less_nodes = [] - - return linked_node - - -def register_cre(cre: defs.CRE, collection: db.Node_collection) -> Tuple[db.CRE, bool]: - collection = collection.with_graph() - existing = False - if collection.get_CREs(name=cre.id): - existing = True - - dbcre: db.CRE = collection.add_cre(cre) - for link in cre.links: - if type(link.document) == defs.CRE: - other_cre, _ = register_cre(link.document, collection) - - # the following flips the PartOf relationship so that we only have contains relationship in the database - if link.ltype == defs.LinkTypes.Contains: - collection.add_internal_link( - higher=dbcre, - lower=other_cre, - ltype=defs.LinkTypes.Contains, - ) - elif link.ltype == defs.LinkTypes.PartOf: - collection.add_internal_link( - higher=other_cre, - lower=dbcre, - ltype=defs.LinkTypes.Contains, - ) - elif link.ltype == defs.LinkTypes.Related: - collection.add_internal_link( - higher=other_cre, - lower=dbcre, - ltype=defs.LinkTypes.Related, - ) - else: - raise ValueError(f"Unknown link type {link.ltype}") - else: - collection.add_link( - cre=dbcre, - node=register_node(node=link.document, collection=collection), - ltype=link.ltype, - ) - return dbcre, existing - - -def parse_file( - filename: str, yamldocs: List[Dict[str, Any]], scollection: db.Node_collection -) -> Optional[List[defs.Document]]: - """given yaml from export format deserialise to internal standards format and add standards to db""" - - resulting_objects = [] - for contents in yamldocs: - links = [] - - document: Optional[defs.Document] = None - register_callback: Optional[Callable[[Any, Any], Any]] = None - - if not isinstance( - contents, dict - ): # basic object matching, make sure we at least have an object, golang has this build in :( - logger.fatal("Malformed file %s, skipping" % filename) - return None - - if contents.get("links"): - links = contents.pop("links") - - if contents.get("doctype") == defs.Credoctypes.CRE.value: - document = from_dict( - data_class=defs.CRE, - data=contents, - config=Config(cast=[defs.Credoctypes]), - ) - # document = defs.CRE(**contents) - register_callback = register_cre - elif contents.get("doctype") in ( - defs.Credoctypes.Standard.value, - defs.Credoctypes.Code.value, - defs.Credoctypes.Tool.value, - ): - # document = defs.Standard(**contents) - doctype = contents.get("doctype") - data_class = ( - defs.Standard - if doctype == defs.Credoctypes.Standard.value - else ( - defs.Code - if doctype == defs.Credoctypes.Code.value - else defs.Tool if doctype == defs.Credoctypes.Tool.value else None - ) - ) - document = from_dict( - data_class=data_class, - data=contents, - config=Config(cast=[defs.Credoctypes]), - ) - register_callback = register_node - - for link in links: - doclink = parse_file( - filename=filename, - yamldocs=[link.get("document")], - scollection=scollection, - ) - - if doclink: - if len(doclink) > 1: - logger.fatal( - "Parsing single document returned 2 results this is a bug" - ) - document.add_link( - defs.Link( - document=doclink[0], - ltype=link.get("type"), - tags=link.get("tags"), - ) - ) - if register_callback: - register_callback(document, collection=scollection) # type: ignore - else: - logger.warning("Callback to register Document is None, likely missing data") - resulting_objects.append(document) - return resulting_objects - - -def register_standard( - standard_entries: List[defs.Standard], - collection: db.Node_collection = None, - generate_embeddings=True, - calculate_gap_analysis=True, - db_connection_str: str = "", -): - if os.environ.get("CRE_NO_GEN_EMBEDDINGS"): - generate_embeddings = False - - if not standard_entries: - logger.warning("register_standard() called with no standard_entries") - return - - if collection is None: - collection = db_connect(path=db_connection_str) - - conn = redis.connect() - ph = prompt_client.PromptHandler(database=collection) - importing_name = standard_entries[0].name - standard_hash = gap_analysis.make_resources_key([importing_name]) - if calculate_gap_analysis and conn.get(standard_hash): - logger.info( - f"Standard importing job with info-hash {standard_hash} has already returned, skipping" - ) - return - logger.info( - f"Registering resource {importing_name} of length {len(standard_entries)}" - ) - for node in standard_entries: - if not node: - logger.info( - f"encountered empty node while importing {standard_entries[0].name}" - ) - continue - register_node(node, collection) - if node.embeddings: - logger.debug( - f"node has embeddings populated, skipping generation for resource {importing_name}" - ) - generate_embeddings = False - if generate_embeddings and importing_name: - ph.generate_embeddings_for(importing_name) - - if calculate_gap_analysis and not os.environ.get("CRE_NO_CALCULATE_GAP_ANALYSIS"): - # calculate gap analysis - populate_neo4j_db(db_connection_str) - jobs = [] - pending_stadards = collection.standards() - for standard_name in pending_stadards: - if standard_name == importing_name: - continue - - fw_key = gap_analysis.make_resources_key([importing_name, standard_name]) - if not collection.gap_analysis_exists(fw_key): - fw_job = gap_analysis.schedule( - standards=[importing_name, standard_name], database=collection - ) - forward_job_id = fw_job.get("job_id") - try: - forward_job = job.Job.fetch(id=forward_job_id, connection=conn) - jobs.append(forward_job) - except exceptions.NoSuchJobError as nje: - logger.error( - f"Could not find gap analysis job for for {importing_name} and {standard_name} putting {standard_name} back in the queue" - ) - pending_stadards.append(standard_name) - - bw_key = gap_analysis.make_resources_key([standard_name, importing_name]) - if not collection.gap_analysis_exists(bw_key): - bw_job = gap_analysis.schedule( - standards=[standard_name, importing_name], database=collection - ) - backward_job_id = bw_job.get("job_id") - try: - backward_job = job.Job.fetch(id=backward_job_id, connection=conn) - jobs.append(backward_job) - except exceptions.NoSuchJobError as nje: - logger.error( - f"Could not find gap analysis job for for {importing_name} and {standard_name} putting {standard_name} back in the queue" - ) - pending_stadards.append(standard_name) - redis.wait_for_jobs(jobs) - conn.set(standard_hash, value="") - - -def parse_standards_from_spreadsheeet( - cre_file: List[Dict[str, Any]], - cache_location: str, - prompt_handler: prompt_client.PromptHandler, -) -> None: - """given a yaml with standards, build a list of standards in the db""" - collection = db_connect(cache_location) - if any(key.startswith("CRE hierarchy") for key in cre_file[0].keys()): - conn = redis.connect() - collection = collection.with_graph() - redis.empty_queues(conn) - q = Queue(connection=conn) - docs = spreadsheet_parsers.parse_hierarchical_export_format(cre_file) - total_resources = docs.keys() - jobs = [] - logger.info(f"Importing {len(docs.get(defs.Credoctypes.CRE.value))} CREs") - - with alive_bar(len(docs.get(defs.Credoctypes.CRE.value))) as bar: - for cre in docs.pop(defs.Credoctypes.CRE.value): - register_cre(cre, collection) - bar() - - if not os.environ.get("CRE_NO_NEO4J"): - populate_neo4j_db(cache_location) - if not os.environ.get("CRE_NO_GEN_EMBEDDINGS"): - prompt_handler.generate_embeddings_for(defs.Credoctypes.CRE.value) - - import_only = [] - if os.environ.get("CRE_ROOT_CSV_IMPORT_ONLY", None): - import_list = os.environ.get("CRE_ROOT_CSV_IMPORT_ONLY") - try: - import_list_json = json.loads(import_list) - except json.JSONDecodeError as jde: - env_value = os.environ.get("CRE_ROOT_CSV_IMPORT_ONLY") - logger.error(f"value '{env_value}' is not valid json") - raise jde - if type(import_list_json) == list: - import_only.extend(import_list_json) - else: - logger.warning( - f"CRE_ROOT_CSV_IMPORT_ONLY should be a list of standards to import, received {type(import_list_json)} {import_list}" - ) - database = db_connect(cache_location) - for standard_name, standard_entries in docs.items(): - if os.environ.get("CRE_NO_REIMPORT_IF_EXISTS") and database.get_nodes( - name=standard_name - ): - logger.info( - f"Already know of {standard_name} and CRE_NO_REIMPORT_IF_EXISTS is set, skipping" - ) - continue - if import_only and standard_name not in import_only: - logger.info( - f"skipping standard {standard_name} as it's not in the list of {import_only}" - ) - continue - jobs.append( - q.enqueue_call( - description=standard_name, - func=register_standard, - kwargs={ - "standard_entries": standard_entries, - "collection": None, - "db_connection_str": cache_location, - }, - timeout=gap_analysis.GAP_ANALYSIS_TIMEOUT, - ) - ) - t0 = time.perf_counter() - total_standards = len(jobs) - logger.info(f"Importing {total_standards} Standards") - with alive_bar(theme="classic", total=total_standards) as bar: - redis.wait_for_jobs(jobs, bar) - logger.info( - f"imported {total_standards} standards in {time.perf_counter()-t0} seconds" - ) - return total_resources - else: - logger.fatal(f"could not find any useful keys { cre_file[0].keys()}") - - -def get_cre_files_from_disk(cre_loc: str) -> Generator[str, None, None]: - for root, _, cre_docs in os.walk(cre_loc): - for name in cre_docs: - if name.endswith(".yaml") or name.endswith(".yml"): - yield os.path.join(root, name) - - -def add_from_spreadsheet(spreadsheet_url: str, cache_loc: str, cre_loc: str) -> None: - """--add --from_spreadsheet - use the cre db in this repo - import new mappings from - export db to ../../cres/ - """ - database = db_connect(path=cache_loc) - prompt_handler = ai_client_init(database=database) - spreadsheet = sheet_utils.read_spreadsheet( - url=spreadsheet_url, alias="new spreadsheet", validate=False - ) - for _, contents in spreadsheet.items(): - parse_standards_from_spreadsheeet(contents, cache_loc, prompt_handler) - - logger.info( - "Db located at %s got updated, files extracted at %s" % (cache_loc, cre_loc) - ) - - -def add_from_disk(cache_loc: str, cre_loc: str) -> None: - """--add --cre_loc - use the cre db in this repo - import new mappings from - export db to ../../cres/ - """ - database = db_connect(path=cache_loc) - for file in get_cre_files_from_disk(cre_loc): - with open(file, "rb") as standard: - parse_file( - filename=file, - yamldocs=list(yaml.safe_load_all(standard)), - scollection=database, - ) - - -def review_from_spreadsheet(cache: str, spreadsheet_url: str, share_with: str) -> None: - """--review --from_spreadsheet - copy db to new temp dir, - import new mappings from spreadsheet - export db to tmp dir - create new spreadsheet of the new CRE landscape for review - """ - loc, cache = prepare_for_review(cache) - database = db_connect(path=cache) - prompt_handler = ai_client_init(database=database) - spreadsheet = sheet_utils.read_spreadsheet( - url=spreadsheet_url, alias="new spreadsheet", validate=False - ) - for _, contents in spreadsheet.items(): - parse_standards_from_spreadsheeet(contents, database, prompt_handler) - - logger.info( - "Stored temporary files and database in %s if you want to use them next time, set cache to the location of the database in that dir" - % loc - ) - # logger.info("A spreadsheet view is at %s" % sheet_url) - - -def download_graph_from_upstream(cache: str) -> None: - imported_cres = {} - collection = db_connect(path=cache).with_graph() - - def download_cre_from_upstream(creid: str): - cre_response = requests.get( - os.environ.get("CRE_UPSTREAM_API_URL", "https://opencre.org/rest/v1") - + f"/id/{creid}" - ) - if cre_response.status_code != 200: - raise RuntimeError( - f"cannot connect to upstream status code {cre_response.status_code}" - ) - data = cre_response.json() - credict = data["data"] - cre = defs.Document.from_dict(credict) - if cre.id in imported_cres: - return - - register_cre(cre, collection) - imported_cres[cre.id] = "" - for link in cre.links: - if link.document.doctype == defs.Credoctypes.CRE: - download_cre_from_upstream(link.document.id) - - root_cres_response = requests.get( - os.environ.get("CRE_UPSTREAM_API_URL", "https://opencre.org/rest/v1") - + "/root_cres" - ) - if root_cres_response.status_code != 200: - raise RuntimeError( - f"cannot connect to upstream status code {root_cres_response.status_code}" - ) - data = root_cres_response.json() - for root_cre in data["data"]: - cre = defs.Document.from_dict(root_cre) - register_cre(cre, collection) - imported_cres[cre.id] = "" - for link in cre.links: - if link.document.doctype == defs.Credoctypes.CRE: - download_cre_from_upstream(link.document.id) - - -# def review_from_disk(cache: str, cre_file_loc: str, share_with: str) -> None: -# """--review --cre_loc -# copy db to new temp dir, -# import new mappings from yaml files defined in -# export db to tmp dir -# create new spreadsheet of the new CRE landscape for review -# """ -# loc, cache = prepare_for_review(cache) -# database = db_connect(path=cache) -# for file in get_cre_files_from_disk(cre_file_loc): -# with open(file, "rb") as standard: -# parse_file( -# filename=file, -# yamldocs=list(yaml.safe_load_all(standard)), -# scollection=database, -# ) - -# sheet_url = create_spreadsheet( -# collection=database, -# exported_documents=docs, -# title="cre_review", -# share_with=[share_with], -# ) -# logger.info( -# "Stored temporary files and database in %s if you want to use them next time, set cache to the location of the database in that dir" -# % loc -# ) -# logger.info("A spreadsheet view is at %s" % sheet_url) - - -def run(args: argparse.Namespace) -> None: # pragma: no cover - script_path = os.path.dirname(os.path.realpath(__file__)) - os.path.join(script_path, "../cres") - - # if args.review and args.from_spreadsheet: - # review_from_spreadsheet( - # cache=args.cache_file, - # spreadsheet_url=args.from_spreadsheet, - # share_with=args.email, - # ) - # elif args.review and args.cre_loc: - # review_from_disk( - # cache=args.cache_file, cre_file_loc=args.cre_loc, share_with=args.email - # ) - if args.add and args.from_spreadsheet: - add_from_spreadsheet( - spreadsheet_url=args.from_spreadsheet, - cache_loc=args.cache_file, - cre_loc=args.cre_loc, - ) - elif args.add and args.cre_loc and not args.from_spreadsheet: - add_from_disk(cache_loc=args.cache_file, cre_loc=args.cre_loc) - # elif args.review and args.osib_in: - # review_osib_from_file( - # file_loc=args.osib_in, cache=args.cache_file, cre_loc=args.cre_loc - # ) - - # elif args.add and args.osib_in: - # add_osib_from_file( - # file_loc=args.osib_in, cache=args.cache_file, cre_loc=args.cre_loc - # ) - - # elif args.osib_out: - # export_to_osib(file_loc=args.osib_out, cache=args.cache_file) - - if args.delete_map_analysis_for: - cache = db_connect(args.cache_file) - cache.delete_gapanalysis_results_for(args.delete_map_analysis_for) - if args.delete_resource: - cache = db_connect(args.cache_file) - cache.delete_nodes(args.delete_resource) - - # individual resource importing - if args.zap_in: - from application.utils.external_project_parsers.parsers import zap_alerts_parser - - BaseParser().register_resource( - zap_alerts_parser.ZAP, db_connection_str=args.cache_file - ) - if args.cheatsheets_in: - from application.utils.external_project_parsers.parsers import ( - cheatsheets_parser, - ) - - BaseParser().register_resource( - cheatsheets_parser.Cheatsheets, db_connection_str=args.cache_file - ) - if args.github_tools_in: - from application.utils.external_project_parsers.parsers import misc_tools_parser - - BaseParser().register_resource( - misc_tools_parser.MiscTools, db_connection_str=args.cache_file - ) - if args.capec_in: - from application.utils.external_project_parsers.parsers import capec_parser - - BaseParser().register_resource( - capec_parser.Capec, db_connection_str=args.cache_file - ) - if args.cwe_in: - from application.utils.external_project_parsers.parsers import cwe - - BaseParser().register_resource(cwe.CWE, db_connection_str=args.cache_file) - if args.csa_ccm_v4_in: - from application.utils.external_project_parsers.parsers import ccmv4 - - BaseParser().register_resource( - ccmv4.CloudControlsMatrix, db_connection_str=args.cache_file - ) - if args.iso_27001_in: - from application.utils.external_project_parsers.parsers import iso27001 - - BaseParser().register_resource( - iso27001.ISO27001, db_connection_str=args.cache_file - ) - if args.owasp_secure_headers_in: - from application.utils.external_project_parsers.parsers import secure_headers - - BaseParser().register_resource( - secure_headers.SecureHeaders, db_connection_str=args.cache_file - ) - if args.pci_dss_4_in: - from application.utils.external_project_parsers.parsers import pci_dss - - BaseParser().register_resource( - pci_dss.PciDss, db_connection_str=args.cache_file - ) - if args.juiceshop_in: - from application.utils.external_project_parsers.parsers import juiceshop - - BaseParser().register_resource( - juiceshop.JuiceShop, db_connection_str=args.cache_file - ) - if args.dsomm_in: - from application.utils.external_project_parsers.parsers import dsomm - - BaseParser().register_resource(dsomm.DSOMM, db_connection_str=args.cache_file) - if args.cloud_native_security_controls_in: - from application.utils.external_project_parsers.parsers import ( - cloud_native_security_controls, - ) - - BaseParser().register_resource( - cloud_native_security_controls.CloudNativeSecurityControls, - db_connection_str=args.cache_file, - ) - # /end individual resource importing - - if args.import_external_projects: - BaseParser().call_importers(db_connection_str=args.cache_file) - - if args.generate_embeddings: - generate_embeddings(args.cache_file) - if args.populate_neo4j_db: - populate_neo4j_db(args.cache_file) - if args.start_worker: - from application.worker import start_worker - - start_worker() - - if args.preload_map_analysis_target_url: - gap_analysis.preload(target_url=args.preload_map_analysis_target_url) - if args.upstream_sync: - download_graph_from_upstream(args.cache_file) - - -def ai_client_init(database: db.Node_collection): - return prompt_client.PromptHandler(database=database) - - -def db_connect(path: str): - global app - conf = CMDConfig(db_uri=path) - app = create_app(conf=conf) - collection = db.Node_collection() - app_context = app.app_context() - app_context.push() - logger.info(f"successfully connected to the database at {path}") - return collection - - -def create_spreadsheet( - collection: db.Node_collection, - exported_documents: List[Any], - title: str, - share_with: List[str], -) -> Any: - """Reads cre docs exported from a standards_collection.export() - dumps each doc into a workbook""" - flat_dicts = sheet_utils.prepare_spreadsheet(docs=exported_documents) - return sheet_utils.write_spreadsheet( - title=title, docs=flat_dicts, emails=share_with - ) - - -def prepare_for_review(cache: str) -> Tuple[str, str]: - loc = tempfile.mkdtemp() - cache_filename = os.path.basename(cache) - if os.path.isfile(cache): - shutil.copy(cache, loc) - else: - logger.fatal("Could not copy database %s this seems like a bug" % cache) - return loc, os.path.join(loc, cache_filename) - - -def generate_embeddings(db_url: str) -> None: - database = db_connect(path=db_url) - prompt_client.PromptHandler(database, load_all_embeddings=True) - - -def populate_neo4j_db(cache: str): - logger.info(f"Populating neo4j DB: Connecting to SQL DB") - database = db_connect(path=cache) - logger.info(f"Populating neo4j DB: Populating") - database.neo_db.populate_DB(database.session) - logger.info(f"Populating neo4j DB: Complete") +import time +import argparse +import json +import logging +import os +import shutil +import yaml +import tempfile +import requests + +from typing import Any, Callable, Dict, Generator, List, Optional, Tuple + +try: + from rq import Queue, job, exceptions +except (ValueError, ImportError): + Queue, job, exceptions = None, None, None + +from dacite import from_dict +from dacite.config import Config + +from application.utils.external_project_parsers.base_parser import BaseParser +from application import create_app # type: ignore +from application.config import CMDConfig +from application.database import db +from application.defs import cre_defs as defs +from application.defs import osib_defs as odefs +from application.utils import spreadsheet as sheet_utils +from application.utils import redis +from application.utils import spreadsheet_parsers +from alive_progress import alive_bar +from application.prompt_client import prompt_client as prompt_client +from application.utils import gap_analysis + +logging.basicConfig() +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + +app = None + + +def register_node(node: defs.Node, collection: db.Node_collection) -> db.Node: + """ + for each link find if either the root node or the link have a CRE, + then map the one who doesn't to the CRE + if both don't map to anything, just add them in the db as unlinked nodes + """ + if not node or not issubclass(node.__class__, defs.Node): + raise ValueError(f"node is None or not of type Node, node: {node}") + + linked_node = collection.add_node(node) + if node.embeddings: + collection.add_embedding( + linked_node, + doctype=node.doctype, + embeddings=node.embeddings, + embedding_text=node.embeddings_text, + ) + cre_less_nodes: List[defs.Node] = [] + + # we need to know the cres added in case we encounter a higher level CRE, + # in which case we get the higher level CRE to link to these cres + cres_added = [] + + for link in node.links: + if type(link.document).__name__ in [ + defs.Standard.__name__, + defs.Code.__name__, + defs.Tool.__name__, + ]: + # if a node links another node it is likely that a writer wants to reference something + # in that case, find which of the two nodes has at least one CRE attached to it and link both to the parent CRE + cres = collection.find_cres_of_node(link.document) + db_link = collection.add_node(link.document) + if cres: + for cre in cres: + collection.add_link(cre=cre, node=linked_node, ltype=link.ltype) + for unlinked_standard in cre_less_nodes: # if anything in this + collection.add_link( + cre=cre, + node=db.dbNodeFromNode(unlinked_standard), + ltype=link.ltype, + ) + else: + cres = collection.find_cres_of_node(linked_node) + if cres: + for cre in cres: + collection.add_link(cre=cre, node=db_link, ltype=link.ltype) + for unlinked_node in cre_less_nodes: + collection.add_link( + cre=cre, + node=db.dbNodeFromNode(unlinked_node), + ltype=link.ltype, + ) + else: # if neither the root nor a linked node has a CRE, add both as unlinked nodes + cre_less_nodes.append(link.document) + + if link.document.links and len(link.document.links) > 0: + register_node(node=link.document, collection=collection) + + elif type(link.document).__name__ == defs.CRE.__name__: + # dbcre,_ = register_cre(link.document, collection) # CREs are idempotent + c = collection.get_CREs(name=link.document.name)[0] + dbcre = db.dbCREfromCRE(c) + collection.add_link(dbcre, linked_node, ltype=link.ltype) + cres_added.append(dbcre) + for unlinked_standard in cre_less_nodes: # if anything in this + collection.add_link( + cre=dbcre, + node=db.dbNodeFromNode(unlinked_standard), + ltype=link.ltype, + ) + cre_less_nodes = [] + + return linked_node + + +def register_cre(cre: defs.CRE, collection: db.Node_collection) -> Tuple[db.CRE, bool]: + collection = collection.with_graph() + existing = False + if collection.get_CREs(name=cre.id): + existing = True + + dbcre: db.CRE = collection.add_cre(cre) + for link in cre.links: + if type(link.document) == defs.CRE: + other_cre, _ = register_cre(link.document, collection) + + # the following flips the PartOf relationship so that we only have contains relationship in the database + if link.ltype == defs.LinkTypes.Contains: + collection.add_internal_link( + higher=dbcre, + lower=other_cre, + ltype=defs.LinkTypes.Contains, + ) + elif link.ltype == defs.LinkTypes.PartOf: + collection.add_internal_link( + higher=other_cre, + lower=dbcre, + ltype=defs.LinkTypes.Contains, + ) + elif link.ltype == defs.LinkTypes.Related: + collection.add_internal_link( + higher=other_cre, + lower=dbcre, + ltype=defs.LinkTypes.Related, + ) + else: + raise ValueError(f"Unknown link type {link.ltype}") + else: + collection.add_link( + cre=dbcre, + node=register_node(node=link.document, collection=collection), + ltype=link.ltype, + ) + return dbcre, existing + + +def parse_file( + filename: str, yamldocs: List[Dict[str, Any]], scollection: db.Node_collection +) -> Optional[List[defs.Document]]: + """given yaml from export format deserialise to internal standards format and add standards to db""" + + resulting_objects = [] + for contents in yamldocs: + links = [] + + document: Optional[defs.Document] = None + register_callback: Optional[Callable[[Any, Any], Any]] = None + + if not isinstance( + contents, dict + ): # basic object matching, make sure we at least have an object, golang has this build in :( + logger.fatal("Malformed file %s, skipping" % filename) + return None + + if contents.get("links"): + links = contents.pop("links") + + if contents.get("doctype") == defs.Credoctypes.CRE.value: + document = from_dict( + data_class=defs.CRE, + data=contents, + config=Config(cast=[defs.Credoctypes]), + ) + # document = defs.CRE(**contents) + register_callback = register_cre + elif contents.get("doctype") in ( + defs.Credoctypes.Standard.value, + defs.Credoctypes.Code.value, + defs.Credoctypes.Tool.value, + ): + # document = defs.Standard(**contents) + doctype = contents.get("doctype") + data_class = ( + defs.Standard + if doctype == defs.Credoctypes.Standard.value + else ( + defs.Code + if doctype == defs.Credoctypes.Code.value + else defs.Tool if doctype == defs.Credoctypes.Tool.value else None + ) + ) + document = from_dict( + data_class=data_class, + data=contents, + config=Config(cast=[defs.Credoctypes]), + ) + register_callback = register_node + + for link in links: + doclink = parse_file( + filename=filename, + yamldocs=[link.get("document")], + scollection=scollection, + ) + + if doclink: + if len(doclink) > 1: + logger.fatal( + "Parsing single document returned 2 results this is a bug" + ) + document.add_link( + defs.Link( + document=doclink[0], + ltype=link.get("type"), + tags=link.get("tags"), + ) + ) + if register_callback: + register_callback(document, collection=scollection) # type: ignore + else: + logger.warning("Callback to register Document is None, likely missing data") + resulting_objects.append(document) + return resulting_objects + + +def register_standard( + standard_entries: List[defs.Standard], + collection: db.Node_collection = None, + generate_embeddings=True, + calculate_gap_analysis=True, + db_connection_str: str = "", +): + if os.environ.get("CRE_NO_GEN_EMBEDDINGS"): + generate_embeddings = False + + if not standard_entries: + logger.warning("register_standard() called with no standard_entries") + return + + if collection is None: + collection = db_connect(path=db_connection_str) + + conn = redis.connect() + ph = prompt_client.PromptHandler(database=collection) + importing_name = standard_entries[0].name + standard_hash = gap_analysis.make_resources_key([importing_name]) + if calculate_gap_analysis and conn.get(standard_hash): + logger.info( + f"Standard importing job with info-hash {standard_hash} has already returned, skipping" + ) + return + logger.info( + f"Registering resource {importing_name} of length {len(standard_entries)}" + ) + for node in standard_entries: + if not node: + logger.info( + f"encountered empty node while importing {standard_entries[0].name}" + ) + continue + register_node(node, collection) + if node.embeddings: + logger.debug( + f"node has embeddings populated, skipping generation for resource {importing_name}" + ) + generate_embeddings = False + if generate_embeddings and importing_name: + ph.generate_embeddings_for(importing_name) + + if calculate_gap_analysis and not os.environ.get("CRE_NO_CALCULATE_GAP_ANALYSIS"): + # calculate gap analysis + populate_graph_db(db_connection_str) + jobs = [] + pending_stadards = collection.standards() + for standard_name in pending_stadards: + if standard_name == importing_name: + continue + + fw_key = gap_analysis.make_resources_key([importing_name, standard_name]) + if not collection.gap_analysis_exists(fw_key): + fw_job = gap_analysis.schedule( + standards=[importing_name, standard_name], database=collection + ) + forward_job_id = fw_job.get("job_id") + try: + forward_job = job.Job.fetch(id=forward_job_id, connection=conn) + jobs.append(forward_job) + except exceptions.NoSuchJobError as nje: + logger.error( + f"Could not find gap analysis job for for {importing_name} and {standard_name} putting {standard_name} back in the queue" + ) + pending_stadards.append(standard_name) + + bw_key = gap_analysis.make_resources_key([standard_name, importing_name]) + if not collection.gap_analysis_exists(bw_key): + bw_job = gap_analysis.schedule( + standards=[standard_name, importing_name], database=collection + ) + backward_job_id = bw_job.get("job_id") + try: + backward_job = job.Job.fetch(id=backward_job_id, connection=conn) + jobs.append(backward_job) + except exceptions.NoSuchJobError as nje: + logger.error( + f"Could not find gap analysis job for for {importing_name} and {standard_name} putting {standard_name} back in the queue" + ) + pending_stadards.append(standard_name) + redis.wait_for_jobs(jobs) + conn.set(standard_hash, value="") + + +def parse_standards_from_spreadsheeet( + cre_file: List[Dict[str, Any]], + cache_location: str, + prompt_handler: prompt_client.PromptHandler, +) -> None: + """given a yaml with standards, build a list of standards in the db""" + collection = db_connect(cache_location) + if any(key.startswith("CRE hierarchy") for key in cre_file[0].keys()): + conn = redis.connect() + collection = collection.with_graph() + redis.empty_queues(conn) + q = Queue(connection=conn) + docs = spreadsheet_parsers.parse_hierarchical_export_format(cre_file) + total_resources = docs.keys() + jobs = [] + logger.info(f"Importing {len(docs.get(defs.Credoctypes.CRE.value))} CREs") + + with alive_bar(len(docs.get(defs.Credoctypes.CRE.value))) as bar: + for cre in docs.pop(defs.Credoctypes.CRE.value): + register_cre(cre, collection) + bar() + + if not os.environ.get("CRE_NO_NEO4J"): + populate_graph_db(cache_location) + if not os.environ.get("CRE_NO_GEN_EMBEDDINGS"): + prompt_handler.generate_embeddings_for(defs.Credoctypes.CRE.value) + + import_only = [] + if os.environ.get("CRE_ROOT_CSV_IMPORT_ONLY", None): + import_list = os.environ.get("CRE_ROOT_CSV_IMPORT_ONLY") + try: + import_list_json = json.loads(import_list) + except json.JSONDecodeError as jde: + env_value = os.environ.get("CRE_ROOT_CSV_IMPORT_ONLY") + logger.error(f"value '{env_value}' is not valid json") + raise jde + if type(import_list_json) == list: + import_only.extend(import_list_json) + else: + logger.warning( + f"CRE_ROOT_CSV_IMPORT_ONLY should be a list of standards to import, received {type(import_list_json)} {import_list}" + ) + database = db_connect(cache_location) + for standard_name, standard_entries in docs.items(): + if os.environ.get("CRE_NO_REIMPORT_IF_EXISTS") and database.get_nodes( + name=standard_name + ): + logger.info( + f"Already know of {standard_name} and CRE_NO_REIMPORT_IF_EXISTS is set, skipping" + ) + continue + if import_only and standard_name not in import_only: + logger.info( + f"skipping standard {standard_name} as it's not in the list of {import_only}" + ) + continue + jobs.append( + q.enqueue_call( + description=standard_name, + func=register_standard, + kwargs={ + "standard_entries": standard_entries, + "collection": None, + "db_connection_str": cache_location, + }, + timeout=gap_analysis.GAP_ANALYSIS_TIMEOUT, + ) + ) + t0 = time.perf_counter() + total_standards = len(jobs) + logger.info(f"Importing {total_standards} Standards") + with alive_bar(theme="classic", total=total_standards) as bar: + redis.wait_for_jobs(jobs, bar) + logger.info( + f"imported {total_standards} standards in {time.perf_counter()-t0} seconds" + ) + return total_resources + else: + logger.fatal(f"could not find any useful keys { cre_file[0].keys()}") + + +def get_cre_files_from_disk(cre_loc: str) -> Generator[str, None, None]: + for root, _, cre_docs in os.walk(cre_loc): + for name in cre_docs: + if name.endswith(".yaml") or name.endswith(".yml"): + yield os.path.join(root, name) + + +def add_from_spreadsheet(spreadsheet_url: str, cache_loc: str, cre_loc: str) -> None: + """--add --from_spreadsheet + use the cre db in this repo + import new mappings from + export db to ../../cres/ + """ + database = db_connect(path=cache_loc) + prompt_handler = ai_client_init(database=database) + spreadsheet = sheet_utils.read_spreadsheet( + url=spreadsheet_url, alias="new spreadsheet", validate=False + ) + for _, contents in spreadsheet.items(): + parse_standards_from_spreadsheeet(contents, cache_loc, prompt_handler) + + logger.info( + "Db located at %s got updated, files extracted at %s" % (cache_loc, cre_loc) + ) + + +def add_from_disk(cache_loc: str, cre_loc: str) -> None: + """--add --cre_loc + use the cre db in this repo + import new mappings from + export db to ../../cres/ + """ + database = db_connect(path=cache_loc) + for file in get_cre_files_from_disk(cre_loc): + with open(file, "rb") as standard: + parse_file( + filename=file, + yamldocs=list(yaml.safe_load_all(standard)), + scollection=database, + ) + + +def review_from_spreadsheet(cache: str, spreadsheet_url: str, share_with: str) -> None: + """--review --from_spreadsheet + copy db to new temp dir, + import new mappings from spreadsheet + export db to tmp dir + create new spreadsheet of the new CRE landscape for review + """ + loc, cache = prepare_for_review(cache) + database = db_connect(path=cache) + prompt_handler = ai_client_init(database=database) + spreadsheet = sheet_utils.read_spreadsheet( + url=spreadsheet_url, alias="new spreadsheet", validate=False + ) + for _, contents in spreadsheet.items(): + parse_standards_from_spreadsheeet(contents, database, prompt_handler) + + logger.info( + "Stored temporary files and database in %s if you want to use them next time, set cache to the location of the database in that dir" + % loc + ) + # logger.info("A spreadsheet view is at %s" % sheet_url) + + +def download_graph_from_upstream(cache: str) -> None: + imported_cres = {} + collection = db_connect(path=cache).with_graph() + + def download_cre_from_upstream(creid: str): + cre_response = requests.get( + os.environ.get("CRE_UPSTREAM_API_URL", "https://opencre.org/rest/v1") + + f"/id/{creid}" + ) + if cre_response.status_code != 200: + raise RuntimeError( + f"cannot connect to upstream status code {cre_response.status_code}" + ) + data = cre_response.json() + credict = data["data"] + cre = defs.Document.from_dict(credict) + if cre.id in imported_cres: + return + + register_cre(cre, collection) + imported_cres[cre.id] = "" + for link in cre.links: + if link.document.doctype == defs.Credoctypes.CRE: + download_cre_from_upstream(link.document.id) + + root_cres_response = requests.get( + os.environ.get("CRE_UPSTREAM_API_URL", "https://opencre.org/rest/v1") + + "/root_cres" + ) + if root_cres_response.status_code != 200: + raise RuntimeError( + f"cannot connect to upstream status code {root_cres_response.status_code}" + ) + data = root_cres_response.json() + for root_cre in data["data"]: + cre = defs.Document.from_dict(root_cre) + register_cre(cre, collection) + imported_cres[cre.id] = "" + for link in cre.links: + if link.document.doctype == defs.Credoctypes.CRE: + download_cre_from_upstream(link.document.id) + + +# def review_from_disk(cache: str, cre_file_loc: str, share_with: str) -> None: +# """--review --cre_loc +# copy db to new temp dir, +# import new mappings from yaml files defined in +# export db to tmp dir +# create new spreadsheet of the new CRE landscape for review +# """ +# loc, cache = prepare_for_review(cache) +# database = db_connect(path=cache) +# for file in get_cre_files_from_disk(cre_file_loc): +# with open(file, "rb") as standard: +# parse_file( +# filename=file, +# yamldocs=list(yaml.safe_load_all(standard)), +# scollection=database, +# ) + +# sheet_url = create_spreadsheet( +# collection=database, +# exported_documents=docs, +# title="cre_review", +# share_with=[share_with], +# ) +# logger.info( +# "Stored temporary files and database in %s if you want to use them next time, set cache to the location of the database in that dir" +# % loc +# ) +# logger.info("A spreadsheet view is at %s" % sheet_url) + + +def run(args: argparse.Namespace) -> None: # pragma: no cover + script_path = os.path.dirname(os.path.realpath(__file__)) + os.path.join(script_path, "../cres") + + # if args.review and args.from_spreadsheet: + # review_from_spreadsheet( + # cache=args.cache_file, + # spreadsheet_url=args.from_spreadsheet, + # share_with=args.email, + # ) + # elif args.review and args.cre_loc: + # review_from_disk( + # cache=args.cache_file, cre_file_loc=args.cre_loc, share_with=args.email + # ) + if args.add and args.from_spreadsheet: + add_from_spreadsheet( + spreadsheet_url=args.from_spreadsheet, + cache_loc=args.cache_file, + cre_loc=args.cre_loc, + ) + elif args.add and args.cre_loc and not args.from_spreadsheet: + add_from_disk(cache_loc=args.cache_file, cre_loc=args.cre_loc) + # elif args.review and args.osib_in: + # review_osib_from_file( + # file_loc=args.osib_in, cache=args.cache_file, cre_loc=args.cre_loc + # ) + + # elif args.add and args.osib_in: + # add_osib_from_file( + # file_loc=args.osib_in, cache=args.cache_file, cre_loc=args.cre_loc + # ) + + # elif args.osib_out: + # export_to_osib(file_loc=args.osib_out, cache=args.cache_file) + + if args.delete_map_analysis_for: + cache = db_connect(args.cache_file) + cache.delete_gapanalysis_results_for(args.delete_map_analysis_for) + if args.delete_resource: + cache = db_connect(args.cache_file) + cache.delete_nodes(args.delete_resource) + + # individual resource importing + if args.zap_in: + from application.utils.external_project_parsers.parsers import zap_alerts_parser + + BaseParser().register_resource( + zap_alerts_parser.ZAP, db_connection_str=args.cache_file + ) + if args.cheatsheets_in: + from application.utils.external_project_parsers.parsers import ( + cheatsheets_parser, + ) + + BaseParser().register_resource( + cheatsheets_parser.Cheatsheets, db_connection_str=args.cache_file + ) + if args.github_tools_in: + from application.utils.external_project_parsers.parsers import misc_tools_parser + + BaseParser().register_resource( + misc_tools_parser.MiscTools, db_connection_str=args.cache_file + ) + if args.capec_in: + from application.utils.external_project_parsers.parsers import capec_parser + + BaseParser().register_resource( + capec_parser.Capec, db_connection_str=args.cache_file + ) + if args.cwe_in: + from application.utils.external_project_parsers.parsers import cwe + + BaseParser().register_resource(cwe.CWE, db_connection_str=args.cache_file) + if args.csa_ccm_v4_in: + from application.utils.external_project_parsers.parsers import ccmv4 + + BaseParser().register_resource( + ccmv4.CloudControlsMatrix, db_connection_str=args.cache_file + ) + if args.iso_27001_in: + from application.utils.external_project_parsers.parsers import iso27001 + + BaseParser().register_resource( + iso27001.ISO27001, db_connection_str=args.cache_file + ) + if args.owasp_secure_headers_in: + from application.utils.external_project_parsers.parsers import secure_headers + + BaseParser().register_resource( + secure_headers.SecureHeaders, db_connection_str=args.cache_file + ) + if args.pci_dss_4_in: + from application.utils.external_project_parsers.parsers import pci_dss + + BaseParser().register_resource( + pci_dss.PciDss, db_connection_str=args.cache_file + ) + if args.juiceshop_in: + from application.utils.external_project_parsers.parsers import juiceshop + + BaseParser().register_resource( + juiceshop.JuiceShop, db_connection_str=args.cache_file + ) + if args.dsomm_in: + from application.utils.external_project_parsers.parsers import dsomm + + BaseParser().register_resource(dsomm.DSOMM, db_connection_str=args.cache_file) + if args.cloud_native_security_controls_in: + from application.utils.external_project_parsers.parsers import ( + cloud_native_security_controls, + ) + + BaseParser().register_resource( + cloud_native_security_controls.CloudNativeSecurityControls, + db_connection_str=args.cache_file, + ) + # /end individual resource importing + + if args.import_external_projects: + BaseParser().call_importers(db_connection_str=args.cache_file) + + if args.generate_embeddings: + generate_embeddings(args.cache_file) + if args.populate_graph_db: + populate_graph_db(args.cache_file) + if args.start_worker: + from application.worker import start_worker + + start_worker() + + if args.preload_map_analysis_target_url: + gap_analysis.preload(target_url=args.preload_map_analysis_target_url) + if args.upstream_sync: + download_graph_from_upstream(args.cache_file) + + +def ai_client_init(database: db.Node_collection): + return prompt_client.PromptHandler(database=database) + + +def db_connect(path: str): + global app + conf = CMDConfig(db_uri=path) + app = create_app(conf=conf) + collection = db.Node_collection() + app_context = app.app_context() + app_context.push() + logger.info(f"successfully connected to the database at {path}") + return collection + + +def create_spreadsheet( + collection: db.Node_collection, + exported_documents: List[Any], + title: str, + share_with: List[str], +) -> Any: + """Reads cre docs exported from a standards_collection.export() + dumps each doc into a workbook""" + flat_dicts = sheet_utils.prepare_spreadsheet(docs=exported_documents) + return sheet_utils.write_spreadsheet( + title=title, docs=flat_dicts, emails=share_with + ) + + +def prepare_for_review(cache: str) -> Tuple[str, str]: + loc = tempfile.mkdtemp() + cache_filename = os.path.basename(cache) + if os.path.isfile(cache): + shutil.copy(cache, loc) + else: + logger.fatal("Could not copy database %s this seems like a bug" % cache) + return loc, os.path.join(loc, cache_filename) + + +def generate_embeddings(db_url: str) -> None: + database = db_connect(path=db_url) + prompt_client.PromptHandler(database, load_all_embeddings=True) + + +def populate_graph_db(cache: str): + logger.info(f"Populating graph DB: Connecting to SQL DB") + database = db_connect(path=cache) + logger.info(f"Populating graph DB: Populating") + database.graph_db.populate_DB(database.session) + logger.info(f"Populating graph DB: Complete") diff --git a/application/config.py b/application/config.py index 8459f8208..693dd9b8b 100644 --- a/application/config.py +++ b/application/config.py @@ -13,6 +13,12 @@ class Config: GAP_ANALYSIS_OPTIMIZED = ( os.environ.get("GAP_ANALYSIS_OPTIMIZED", "False").lower() == "true" ) + GRAPH_DB_TYPE = os.environ.get("GRAPH_DB_TYPE", "neo4j").lower() + AGE_URL = ( + os.environ.get("AGE_URL") + or "postgresql://postgres:password@localhost:5433/postgres" + ) + AGE_GRAPH = os.environ.get("AGE_GRAPH") or "opencre" class DevelopmentConfig(Config): diff --git a/application/database/db.py b/application/database/db.py index 6c1613277..37f742a55 100644 --- a/application/database/db.py +++ b/application/database/db.py @@ -2,10 +2,13 @@ import networkx as nx import uuid import neo4j +import psycopg2 +from psycopg2.extras import RealDictCursor import os import logging import re import yaml +import threading from pprint import pprint @@ -601,7 +604,7 @@ def _gap_analysis_optimized(self, name_1, name_2): """ MATCH (BaseStandard:NeoStandard {name: $name1}) MATCH (CompareStandard:NeoStandard {name: $name2}) - MATCH p = allShortestPaths((BaseStandard)-[:(LINKED_TO|AUTOMATICALLY_LINKED_TO|SAME)*..20]-(CompareStandard)) + MATCH p = (BaseStandard)-[:(LINKED_TO|AUTOMATICALLY_LINKED_TO|SAME)*..20]-(CompareStandard) WITH p WHERE length(p) > 1 AND ALL(n in NODES(p) WHERE (n:NeoCRE or n = BaseStandard or n = CompareStandard) AND NOT n.name in $denylist) RETURN p @@ -622,7 +625,7 @@ def _gap_analysis_optimized(self, name_1, name_2): """ MATCH (BaseStandard:NeoStandard {name: $name1}) MATCH (CompareStandard:NeoStandard {name: $name2}) - MATCH p = allShortestPaths((BaseStandard)-[:(LINKED_TO|AUTOMATICALLY_LINKED_TO|SAME|CONTAINS)*..20]-(CompareStandard)) + MATCH p = (BaseStandard)-[:(LINKED_TO|AUTOMATICALLY_LINKED_TO|SAME|CONTAINS)*..20]-(CompareStandard) WITH p WHERE length(p) > 1 AND ALL(n in NODES(p) WHERE (n:NeoCRE or n = BaseStandard or n = CompareStandard) AND NOT n.name in $denylist) RETURN p @@ -645,7 +648,7 @@ def _gap_analysis_optimized(self, name_1, name_2): """ MATCH (BaseStandard:NeoStandard {name: $name1}) MATCH (CompareStandard:NeoStandard {name: $name2}) - MATCH p = allShortestPaths((BaseStandard)-[*..20]-(CompareStandard)) + MATCH p = (BaseStandard)-[*..20]-(CompareStandard) WITH p WHERE length(p) > 1 AND ALL (n in NODES(p) where (n:NeoCRE or n = BaseStandard or n = CompareStandard) AND NOT n.name in $denylist) RETURN p @@ -675,7 +678,7 @@ def _gap_analysis_original(self, name_1, name_2): """ MATCH (BaseStandard:NeoStandard {name: $name1}) MATCH (CompareStandard:NeoStandard {name: $name2}) - MATCH p = allShortestPaths((BaseStandard)-[*..20]-(CompareStandard)) + MATCH p = (BaseStandard)-[*..20]-(CompareStandard) WITH p WHERE length(p) > 1 AND ALL (n in NODES(p) where (n:NeoCRE or n = BaseStandard or n = CompareStandard) AND NOT n.name in $denylist) RETURN p @@ -689,7 +692,7 @@ def _gap_analysis_original(self, name_1, name_2): """ MATCH (BaseStandard:NeoStandard {name: $name1}) MATCH (CompareStandard:NeoStandard {name: $name2}) - MATCH p = allShortestPaths((BaseStandard)-[:(LINKED_TO|AUTOMATICALLY_LINKED_TO|CONTAINS)*..20]-(CompareStandard)) + MATCH p = (BaseStandard)-[:(LINKED_TO|AUTOMATICALLY_LINKED_TO|CONTAINS)*..20]-(CompareStandard) WITH p WHERE length(p) > 1 AND ALL(n in NODES(p) WHERE (n:NeoCRE or n = BaseStandard or n = CompareStandard) AND NOT n.name in $denylist) RETURN p @@ -793,14 +796,377 @@ def parse_node_no_links(node: NeoDocument) -> cre_defs.Document: return node.to_cre_def(node, parse_links=False) +class AGEDB: + __instance = None + conn: Optional[psycopg2.extensions.connection] = None + graph_name: str = "opencre" + connected: bool = False + _connection_attempted: bool = False + _disabled_permanently: bool = False + _is_connecting: bool = False + _lock = threading.Lock() + + @classmethod + def instance(cls): + """Singleton instance for Apache AGE database connection.""" + if cls.__instance is None: + with cls._lock: + if cls.__instance is None: + cls.__instance = cls.__new__(cls) + from application.config import Config + + cls.graph_name = Config.AGE_GRAPH or "opencre" + cls.conn = None + cls.connected = False + cls._connection_attempted = False + cls._disabled_permanently = False + cls._is_connecting = False + + # If already connected, return + if cls.connected: + return cls.__instance + + # CIRCUIT BREAKER: If disabled or already tried and failed, don't try again + if cls._disabled_permanently: + return cls.__instance + + # Start background connection if not already in progress + if not cls._is_connecting: + with cls._lock: + if not cls._is_connecting and not cls.connected: + cls._is_connecting = True + thread = threading.Thread(target=cls._connect_background) + thread.daemon = True + thread.start() + + return cls.__instance + + @classmethod + def _connect_background(cls): + """Background thread to handle the potentially hanging connection.""" + from application.config import Config + + try: + logger.info( + f"Background: Attempting to connect to Apache AGE at {Config.AGE_URL}..." + ) + cls.conn = psycopg2.connect(Config.AGE_URL, connect_timeout=10) + cls.conn.autocommit = True + + with cls.conn.cursor() as cursor: + cursor.execute("SET statement_timeout = 10000;") + cursor.execute("CREATE EXTENSION IF NOT EXISTS age;") + cursor.execute("LOAD 'age';") + cursor.execute('SET search_path = ag_catalog, "$user", public;') + # Create graph only if it doesn't exist — ignore 'already exists' error + try: + cursor.execute(f"SELECT create_graph('{cls.graph_name}');") + except Exception: + # Graph already exists, that's fine — reset connection state + cls.conn.rollback() + # Create indexes (ignore errors if already exist) + for idx_query in [ + f"SELECT * FROM cypher('{cls.graph_name}', $$ CREATE INDEX ON :NeoStandard(name) $$) as (a agtype);", + f"SELECT * FROM cypher('{cls.graph_name}', $$ CREATE INDEX ON :NeoCRE(name) $$) as (a agtype);", + ]: + try: + cursor.execute(idx_query) + except Exception: + cls.conn.rollback() + cursor.execute("SET statement_timeout = 30000;") + + cls.connected = True + logger.info("Background: Successfully connected to Apache AGE.") + except Exception as e: + logger.error( + f"Background: Apache AGE connection failed: {e}. disabling AGE for this session." + ) + cls._disabled_permanently = True + cls.connected = False + if cls.conn: + try: + cls.conn.close() + except: + pass + cls.conn = None + finally: + cls._is_connecting = False + cls._connection_attempted = True + + @classmethod + def instance_blocking(cls, timeout: int = 30): + """Synchronous version of instance() - blocks until connection is ready or timeout.""" + import time + + cls.instance() # Start background thread + waited = 0 + while cls._is_connecting and waited < timeout: + time.sleep(0.5) + waited += 0.5 + if not cls.connected: + logger.warning(f"AGE connection not available after {timeout}s.") + return cls.__instance + + @classmethod + def gap_analysis(cls, name_1, name_2): + from application.config import Config + + logger.info(f"AGE Gap Analysis for {name_1} >> {name_2}") + if not cls.conn: + cls.instance() + if not cls.conn: + return [], [] + + # Find base standard nodes + n1 = name_1.replace('"', '\\"') + base_query = f"SELECT * FROM cypher('{cls.graph_name}', $AGE$ MATCH (n:NeoStandard {{name: \"{n1}\"}}) RETURN n $AGE$) as (n agtype);" + with cls.conn.cursor(cursor_factory=RealDictCursor) as cursor: + cursor.execute(base_query) + base_recs = cursor.fetchall() + base_standards = [cls._parse_agtype(r["n"]) for r in base_recs] + + # Find paths + # Note: Apache AGE does not have allShortestPaths, so we use a limited depth search. + # Depth 6 is usually enough for most OpenCRE relationships. + n2 = name_2.replace('"', '\\"') + path_query = f""" + SELECT * FROM cypher('{cls.graph_name}', $AGE$ + MATCH (BaseStandard:NeoStandard {{name: \"{n1}\"}}) + MATCH (CompareStandard:NeoStandard {{name: \"{n2}\"}}) + MATCH p = (BaseStandard)-[*..6]-(CompareStandard) + RETURN p + LIMIT 50 + $AGE$) as (p agtype); + """ + with cls.conn.cursor(cursor_factory=RealDictCursor) as cursor: + try: + cursor.execute(path_query) + path_recs = cursor.fetchall() + paths = [cls._parse_agpath(r["p"]) for r in path_recs] + except Exception as e: + logger.error(f"AGE Path Query failed: {e}") + paths = [] + + return base_standards, paths + + @classmethod + def _parse_agtype(cls, agtype_val): + """Convert AGE's agtype (usually JSON string or dict) to a format compatible with the app.""" + if isinstance(agtype_val, str): + import json + import re + + # AGE appends ::vertex, ::edge, ::path at the end of the string. + # We strip them carefully avoiding :: inside values. + json_str = re.sub(r"::(vertex|edge|path)$", "", agtype_val) + try: + data = json.loads(json_str) + except json.JSONDecodeError: + # Fallback: if it's still failing, it might not be JSON at all + logger.error(f"Failed to parse agtype string: {agtype_val}") + data = {} + else: + data = agtype_val + + # Return a mock object that behaves like the NeoDocument models + class MockNode(dict): + def __init__(self, d): + self.id = str(d.get("id")) + self.label = d.get("label", "") + props = d.get("properties", {}) + self.name = props.get("name") + self.section = props.get("section") + self.subsection = props.get("subsection") + self.external_id = props.get("external_id") + self.section_id = props.get("section_id") + self.version = props.get("version") + self.description = props.get("description") + self.sql_id = str(props.get("sql_id", "")) + + # Map label to doctype for frontend compatibility + self.doctype = "Standard" + if self.label == "NeoCRE": + self.doctype = "CRE" + elif self.label == "NeoTool": + self.doctype = "Tool" + + # Store in dict for JSON serialization to frontend + super().__init__( + { + "id": self.id, + "name": self.name, + "section": self.section, + "sectionID": self.section_id, + "subsection": self.subsection, + "version": self.version, + "external_id": self.external_id, + "description": self.description, + "doctype": self.doctype, + "sql_id": self.sql_id, + } + ) + + def to_cre_def(self, *args, **kwargs): + from application.defs import cre_defs + + if self.doctype == "CRE": + return cre_defs.CRE( + name=self.name, + id=self.external_id, + description=self.description, + ) + return cre_defs.Standard( + name=self.name, + section=self.section, + subsection=self.subsection, + version=self.version, + sectionID=self.section_id, + ) + + return MockNode(data) + + @classmethod + def _parse_agpath(cls, agpath_val): + """Translate AGE path to Neo4j-like path record.""" + if isinstance(agpath_val, str): + import json + + # AGE sometimes appends ::path to the JSON string + json_str = agpath_val.split("::")[0] + try: + data = json.loads(json_str) + except json.JSONDecodeError: + logger.error(f"Failed to parse agpath string: {agpath_val}") + return {"start": {"id": "stub"}, "end": {"id": "stub"}, "path": []} + else: + data = agpath_val + + # AGE path is a list: [v1, e1, v2, e2, v3, ...] + # Note: Depending on AGE version and driver, structure might vary. + # Usually it's a dict with 'vertices' and 'edges' or a flat list. + # This implementation assumes flat list. + if not isinstance(data, list): + return {"start": {"id": "stub"}, "end": {"id": "stub"}, "path": []} + + vertices = [v for i, v in enumerate(data) if i % 2 == 0] + edges = [e for i, e in enumerate(data) if i % 2 != 0] + + path_steps = [] + for i in range(len(edges)): + step = { + "start": cls._parse_agtype(vertices[i]), + "end": cls._parse_agtype(vertices[i + 1]), + "relationship": edges[i].get("label"), + } + path_steps.append(step) + + return { + "start": cls._parse_agtype(vertices[0]), + "end": cls._parse_agtype(vertices[-1]), + "path": path_steps, + } + + @classmethod + def populate_DB(cls, session): + logger.info("Populating Apache AGE DB from Postgres") + if not cls.conn: + cls.instance_blocking(timeout=30) + if not cls.conn: + logger.error("No AGE connection, skipping population") + return + + with cls.conn.cursor() as cursor: + # Clear existing data + cursor.execute( + f"SELECT * FROM cypher('{cls.graph_name}', $AGE$ MATCH (n) DETACH DELETE n $AGE$) as (a agtype);" + ) + + # Migration of Standards + nodes = session.query(Node).all() + logger.info(f"Migrating {len(nodes)} Standards to AGE") + for node in nodes: + try: + name = (node.name or "").replace('"', '\\"') + section = (node.section or "").replace('"', '\\"') + section_id = (node.section_id or "").replace('"', '\\"') + subsection = (node.subsection or "").replace('"', '\\"') + version = (node.version or "").replace('"', '\\"') + link = (node.link or "").replace('"', '\\"') + sid = str(node.id) + cursor.execute( + f'SELECT * FROM cypher(\'{cls.graph_name}\', $AGE$ CREATE (n:NeoStandard {{name: "{name}", section: "{section}", section_id: "{section_id}", subsection: "{subsection}", version: "{version}", link: "{link}", sql_id: "{sid}"}}) $AGE$) as (a agtype);' + ) + except Exception as e: + logger.error(f"Failed to migrate Standard {node.id}: {e}") + + # Migration of CREs + cres = session.query(CRE).all() + logger.info(f"Migrating {len(cres)} CREs to AGE") + for cre in cres: + try: + name = (cre.name or "").replace('"', '\\"') + desc = (cre.description or "").replace('"', '\\"') + eid = (cre.external_id or "").replace('"', '\\"') + sid = str(cre.id) + cursor.execute( + f'SELECT * FROM cypher(\'{cls.graph_name}\', $AGE$ CREATE (n:NeoCRE {{name: "{name}", description: "{desc}", external_id: "{eid}", sql_id: "{sid}"}}) $AGE$) as (a agtype);' + ) + except Exception as e: + logger.error(f"Failed to migrate CRE {cre.id}: {e}") + + # Migration of CRE-CRE Links (InternalLinks) + internal_links = session.query(InternalLinks).all() + logger.info(f"Migrating {len(internal_links)} CRE-CRE links to AGE") + for link in internal_links: + try: + rel_type = re.sub(r"[^a-zA-Z0-9_]", "", link.type.upper()) + gid = str(link.group) + cid = str(link.cre) + cursor.execute( + f'SELECT * FROM cypher(\'{cls.graph_name}\', $AGE$ MATCH (a:NeoCRE {{sql_id: "{gid}"}}), (b:NeoCRE {{sql_id: "{cid}"}}) CREATE (a)-[:{rel_type}]->(b) $AGE$) as (a agtype);' + ) + except Exception as e: + # logger.debug(f"Failed to migrate CRE-CRE link {link.id}: {e}") + pass + + # Migration of CRE-Standard Links (Links) + links = session.query(Links).all() + logger.info(f"Migrating {len(links)} CRE-Standard links to AGE") + for link in links: + try: + rel_type = re.sub(r"[^a-zA-Z0-9_]", "", link.type.upper()) + cid = str(link.cre) + nid = str(link.node) + cursor.execute( + f'SELECT * FROM cypher(\'{cls.graph_name}\', $AGE$ MATCH (a:NeoCRE {{sql_id: "{cid}"}}), (b:NeoStandard {{sql_id: "{nid}"}}) CREATE (a)-[:{rel_type}]->(b) $AGE$) as (a agtype);' + ) + except Exception as e: + # logger.debug(f"Failed to migrate CRE-Standard link {link.id}: {e}") + pass + + logger.info( + f"Populated {len(nodes)} standards, {len(cres)} CREs, and synchronized all links into AGE graph '{cls.graph_name}'" + ) + + +class GraphDB: + @staticmethod + def instance(): + from application.config import Config + + if Config.GRAPH_DB_TYPE == "age": + return AGEDB.instance() + return NEO_DB.instance() + + class Node_collection: graph: inmemory_graph.CRE_Graph = None - neo_db: NEO_DB = None + graph_db: Any = None session = sqla.session def __init__(self) -> None: if not os.environ.get("NO_LOAD_GRAPH_DB"): - self.neo_db = NEO_DB.instance() + self.graph_db = GraphDB.instance() self.session = sqla.session def with_graph(self) -> "Node_collection": @@ -1799,7 +2165,7 @@ def standards(self) -> List[str]: return list(set([s[0] for s in standards])) def text_search(self, text: str) -> List[Optional[cre_defs.Document]]: - """Given a piece of text, tries to find the best match + r"""Given a piece of text, tries to find the best match for the text in the database. Shortcuts: 'CRE:' will search for the in cre external ids @@ -1818,7 +2184,7 @@ def text_search(self, text: str) -> List[Optional[cre_defs.Document]]: node_search = ( r"(Node|(?P" + types - + "))?((:| )?(?Phttps?://\S+))?((:| )(?P.+$))?" + + r"))?((:| )?(?Phttps?://\S+))?((:| )(?P.+$))?" ) match = re.search(cre_id_search, text, re.IGNORECASE) if match: @@ -2171,12 +2537,12 @@ def dbCREfromCRE(cre: cre_defs.CRE) -> CRE: def gap_analysis( - neo_db: NEO_DB, + graph_db: Any, node_names: List[str], cache_key: str = "", ): cre_db = Node_collection() - base_standard, paths = neo_db.gap_analysis(node_names[0], node_names[1]) + base_standard, paths = graph_db.gap_analysis(node_names[0], node_names[1]) logger.info(f"got db gap analysis for {'>>>'.join(node_names)}, calculating paths") if base_standard is None: return None @@ -2196,8 +2562,8 @@ def gap_analysis( extra_paths_dict[key] = {"paths": {}} for path in paths: - key = path["start"].id - end_key = path["end"].id + key = getattr(path["start"], "id", None) or path["start"].get("id") + end_key = getattr(path["end"], "id", None) or path["end"].get("id") if not end_key: logger.error( f"end_key is empty, this is a bug and this gap analysis will not progress" @@ -2237,8 +2603,8 @@ def gap_analysis( for key in extra_paths_dict: cre_db.add_gap_analysis_result( - cache_key=make_subresources_key(node_names, key), + cache_key=make_subresources_key(node_names, str(key)), ga_object=flask_json.dumps({"result": extra_paths_dict[key]}), ) - logger.info(f"stored gapa analysis for {'>>>'.join(node_names)}, successfully") + logger.info(f"stored gap analysis for {'>>>'.join(node_names)}, successfully") return (node_names, grouped_paths, extra_paths_dict) diff --git a/application/frontend/www/bundle.js.LICENSE.txt b/application/frontend/www/bundle.js.LICENSE.txt deleted file mode 100644 index f37e1e30c..000000000 --- a/application/frontend/www/bundle.js.LICENSE.txt +++ /dev/null @@ -1,66 +0,0 @@ -/* -object-assign -(c) Sindre Sorhus -@license MIT -*/ - -/*! - Copyright (c) 2015 Jed Watson. - Based on code that is Copyright 2013-2015, Facebook, Inc. - All rights reserved. -*/ - -/*! @license DOMPurify 3.0.5 | (c) Cure53 and other contributors | Released under the Apache license 2.0 and Mozilla Public License 2.0 | github.com/cure53/DOMPurify/blob/3.0.5/LICENSE */ - -/*! fromentries. MIT License. Feross Aboukhadijeh */ - -/** - * @license - * Copyright 2010-2023 Three.js Authors - * SPDX-License-Identifier: MIT - */ - -/** - * Prism: Lightweight, robust, elegant syntax highlighting - * - * @license MIT - * @author Lea Verou - * @namespace - * @public - */ - -/** @license React v0.20.2 - * scheduler.production.min.js - * - * Copyright (c) Facebook, Inc. and its affiliates. - * - * This source code is licensed under the MIT license found in the - * LICENSE file in the root directory of this source tree. - */ - -/** @license React v17.0.2 - * react-dom.production.min.js - * - * Copyright (c) Facebook, Inc. and its affiliates. - * - * This source code is licensed under the MIT license found in the - * LICENSE file in the root directory of this source tree. - */ - -/** @license React v17.0.2 - * react-is.production.min.js - * - * Copyright (c) Facebook, Inc. and its affiliates. - * - * This source code is licensed under the MIT license found in the - * LICENSE file in the root directory of this source tree. - */ - -/** @license React v17.0.2 - * react.production.min.js - * - * Copyright (c) Facebook, Inc. and its affiliates. - * - * This source code is licensed under the MIT license found in the - * LICENSE file in the root directory of this source tree. - */ diff --git a/application/tests/cre_main_test.py b/application/tests/cre_main_test.py index af313c8a6..78713c88e 100644 --- a/application/tests/cre_main_test.py +++ b/application/tests/cre_main_test.py @@ -6,7 +6,12 @@ from typing import Any, Dict, List from unittest import mock from unittest.mock import Mock, patch -from rq import Queue + +try: + from rq import Queue +except (ValueError, ImportError): + Queue = None + from application.utils import redis from application.prompt_client import prompt_client as prompt_client from application.tests.utils import data_gen diff --git a/application/tests/db_test.py b/application/tests/db_test.py index 1d13bd0be..6241bdbd3 100644 --- a/application/tests/db_test.py +++ b/application/tests/db_test.py @@ -1,2289 +1,2299 @@ -import networkx as nx -from application.utils.gap_analysis import make_resources_key, make_subresources_key -import string -import random -import os -import tempfile -import unittest -from unittest import mock -from unittest.mock import patch -import uuid -from copy import copy, deepcopy -from pprint import pprint -from typing import Any, Dict, List, Union -from flask import json as flask_json - -import yaml -from application.tests.utils.data_gen import export_format_data -from application import create_app, sqla # type: ignore -from application.database import db -from application.defs import cre_defs as defs - - -class TestDB(unittest.TestCase): - def tearDown(self) -> None: - sqla.session.remove() - sqla.drop_all() - self.app_context.pop() - - def setUp(self) -> None: - self.app = create_app(mode="test") - self.app_context = self.app.app_context() - self.app_context.push() - sqla.create_all() - - self.collection = db.Node_collection().with_graph() - self.collection.graph.with_graph( - graph=nx.DiGraph(), graph_data=[] - ) # initialize the graph singleton for the tests to be unique - - collection = self.collection - - dbcre = collection.add_cre( - defs.CRE(id="111-000", description="CREdesc", name="CREname") - ) - self.dbcre = dbcre - dbgroup = collection.add_cre( - defs.CRE(id="111-001", description="Groupdesc", name="GroupName") - ) - dbstandard = collection.add_node( - defs.Standard( - subsection="4.5.6", - section="FooStand", - name="BarStand", - hyperlink="https://example.com", - tags=["788-788", "b", "c"], - ) - ) - - collection.add_node( - defs.Standard( - subsection="4.5.6", - section="Unlinked", - name="Unlinked", - hyperlink="https://example.com", - ) - ) - - collection.session.add(dbcre) - collection.add_link(cre=dbcre, node=dbstandard, ltype=defs.LinkTypes.LinkedTo) - collection.add_internal_link( - lower=dbcre, higher=dbgroup, ltype=defs.LinkTypes.Contains - ) - - self.collection = collection - - def test_get_by_tags(self) -> None: - """ - Given: A CRE with no links and a combination of possible tags: - "tag1,dash-2,underscore_3,space 4,co_mb-ination%5" - A Standard with no links and a combination of possible tags - "tag1, dots.5.5, space 6 , several spaces and newline 7 \n" - some limited overlap between the tag-sets - Expect: - The CRE to be returned when searching for "tag-2" and for ["tag1","underscore_3"] - The Standard to be returned when searching for "space 6" and ["dots.5.5", "space 6"] - Both to be returned when searching for "space" and "tag1" - """ - - dbcre = db.CRE( - description="tagCREdesc1", - name="tagCREname1", - tags="tag1,dash-2,underscore_3,space 4,co_mb-ination%5", - external_id="111-111", - ) - cre = db.CREfromDB(dbcre) - cre.id = "111-111" - dbstandard = db.Node( - subsection="4.5.6.7", - section="tagsstand", - name="tagsstand", - link="https://example.com", - version="", - tags="tag1, dots.5.5, space 6 , several spaces and newline 7 \n", - ntype=defs.Standard.__name__, - ) - standard = db.nodeFromDB(dbstandard) - self.collection.session.add(dbcre) - self.collection.session.add(dbstandard) - self.collection.session.commit() - - self.maxDiff = None - self.assertEqual(self.collection.get_by_tags(["dash-2"]), [cre]) - self.assertEqual(self.collection.get_by_tags(["tag1", "underscore_3"]), [cre]) - self.assertEqual(self.collection.get_by_tags(["space 6"]), [standard]) - self.assertEqual( - self.collection.get_by_tags(["dots.5.5", "space 6"]), [standard] - ) - - self.assertCountEqual([cre, standard], self.collection.get_by_tags(["space"])) - self.assertCountEqual( - [cre, standard], self.collection.get_by_tags(["space", "tag1"]) - ) - self.assertCountEqual(self.collection.get_by_tags(["tag1"]), [cre, standard]) - - self.assertEqual(self.collection.get_by_tags([]), []) - self.assertEqual(self.collection.get_by_tags(["this should not be a tag"]), []) - - def test_get_standards_names(self) -> None: - result = self.collection.get_node_names() - expected = [("Standard", "BarStand"), ("Standard", "Unlinked")] - self.assertEqual(expected, result) - - def test_get_max_internal_connections(self) -> None: - self.assertEqual(self.collection.get_max_internal_connections(), 1) - - dbcrelo = db.CRE(name="internal connections test lo", description="ictlo") - dbcrehi = db.CRE(name="internal connections test hi", description="icthi") - self.collection.session.add(dbcrelo) - self.collection.session.add(dbcrehi) - self.collection.session.commit() - for i in range(0, 100): - dbcre = db.CRE(name=str(i) + " name", description=str(i) + " desc") - self.collection.session.add(dbcre) - self.collection.session.commit() - - # 1 low level cre to multiple groups - self.collection.session.add( - db.InternalLinks(group=dbcre.id, cre=dbcrelo.id) - ) - - # 1 hi level cre to multiple low level - self.collection.session.add( - db.InternalLinks(group=dbcrehi.id, cre=dbcre.id) - ) - - self.collection.session.commit() - - result = self.collection.get_max_internal_connections() - self.assertEqual(result, 100) - - def test_export(self) -> None: - """ - Given: - A CRE "CREname" that links to a CRE "GroupName" and a Standard "BarStand" - Expect: - 2 documents on disk, one for "CREname" - with a link to "BarStand" and "GroupName" and one for "GroupName" with a link to "CREName" - """ - loc = tempfile.mkdtemp() - self.collection = db.Node_collection().with_graph() - collection = self.collection - code0 = defs.Code(name="co0") - code1 = defs.Code(name="co1") - tool0 = defs.Tool(name="t0", tooltype=defs.ToolTypes.Unknown) - dbstandard = collection.add_node( - defs.Standard( - subsection="4.5.6", - section="FooStand", - sectionID="123-123", - name="BarStand", - hyperlink="https://example.com", - tags=["788-788", "b", "c"], - ) - ) - - collection.add_node( - defs.Standard( - subsection="4.5.6", - section="Unlinked", - sectionID="Unlinked", - name="Unlinked", - hyperlink="https://example.com", - ) - ) - self.collection.add_link( - self.dbcre, self.collection.add_node(code0), ltype=defs.LinkTypes.LinkedTo - ) - self.collection.add_node(code1) - self.collection.add_node(tool0) - - expected = [ - defs.CRE( - id="111-001", - description="Groupdesc", - name="GroupName", - links=[ - defs.Link( - document=defs.CRE( - id="111-000", description="CREdesc", name="CREname" - ), - ltype=defs.LinkTypes.Contains, - ) - ], - ), - defs.CRE( - id="111-000", - description="CREdesc", - name="CREname", - links=[ - defs.Link( - document=defs.CRE( - id="112-001", description="Groupdesc", name="GroupName" - ), - ltype=defs.LinkTypes.Contains, - ), - defs.Link( - document=defs.Standard( - name="BarStand", - section="FooStand", - sectionID="456", - subsection="4.5.6", - hyperlink="https://example.com", - tags=["788-788", "b", "c"], - ), - ltype=defs.LinkTypes.LinkedTo, - ), - defs.Link( - document=defs.Code(name="co0"), ltype=defs.LinkTypes.LinkedTo - ), - ], - ), - defs.Standard( - subsection="4.5.6", - section="Unlinked", - name="Unlinked", - sectionID="Unlinked", - hyperlink="https://example.com", - ), - defs.Tool(name="t0", tooltype=defs.ToolTypes.Unknown), - defs.Code(name="co1"), - ] - self.collection.export(loc) - - # load yamls from loc, parse, - # ensure yaml1 is result[0].todict and - # yaml2 is expected[1].todict - group = expected[0].todict() - cre = expected[1].todict() - groupname = ( - expected[0] - .id.replace("/", "-") - .replace(" ", "_") - .replace('"', "") - .replace("'", "") - + ".yaml" - ) - with open(os.path.join(loc, groupname), "r") as f: - doc = yaml.safe_load(f) - self.assertDictEqual(group, doc) - - crename = ( - expected[1] - .id.replace("/", "-") - .replace(" ", "_") - .replace('"', "") - .replace("'", "") - + ".yaml" - ) - self.maxDiff = None - with open(os.path.join(loc, crename), "r") as f: - doc = yaml.safe_load(f) - self.assertCountEqual(cre, doc) - - def test_StandardFromDB(self) -> None: - expected = defs.Standard( - name="foo", - section="bar", - sectionID="213", - subsection="foobar", - hyperlink="https://example.com/foo/bar", - version="1.1.1", - ) - self.assertEqual( - expected, - db.nodeFromDB( - db.Node( - name="foo", - section="bar", - subsection="foobar", - link="https://example.com/foo/bar", - version="1.1.1", - section_id="213", - ntype=defs.Standard.__name__, - ) - ), - ) - - def test_CREfromDB(self) -> None: - c = defs.CRE( - id="243-243", - doctype=defs.Credoctypes.CRE, - description="CREdesc", - name="CREname", - ) - self.assertEqual( - c, - db.CREfromDB( - db.CRE(external_id="243-243", description="CREdesc", name="CREname") - ), - ) - - def test_add_cre(self) -> None: - original_desc = str(uuid.uuid4()) - name = str(uuid.uuid4()) - - c = defs.CRE( - id="243-243", - doctype=defs.Credoctypes.CRE, - description=original_desc, - name=name, - ) - self.assertIsNone( - self.collection.session.query(db.CRE).filter(db.CRE.name == c.name).first() - ) - - # happy path, add new cre - newCRE = self.collection.add_cre(c) - dbcre = ( - self.collection.session.query(db.CRE).filter(db.CRE.name == c.name).first() - ) # ensure transaction happened (commit() called) - self.assertIsNotNone(dbcre.id) - self.assertEqual(dbcre.name, c.name) - self.assertEqual(dbcre.description, c.description) - self.assertEqual(dbcre.external_id, c.id) - - # ensure the right thing got returned - self.assertEqual(newCRE.name, c.name) - - # ensure no accidental update (add only adds) - c.description = "description2" - newCRE = self.collection.add_cre(c) - dbcre = ( - self.collection.session.query(db.CRE).filter(db.CRE.name == c.name).first() - ) - # ensure original description - self.assertEqual(dbcre.description, original_desc) - # ensure original description - self.assertEqual(newCRE.description, original_desc) - - def test_add_node(self) -> None: - original_section = str(uuid.uuid4()) - name = str(uuid.uuid4()) - - s = defs.Standard( - doctype=defs.Credoctypes.Standard, - section=original_section, - subsection=original_section, - name=name, - tags=["788-788", "b", "c"], - ) - - self.assertIsNone( - self.collection.session.query(db.Node) - .filter(db.Node.name == s.name) - .first() - ) - - # happy path, add new standard - newStandard = self.collection.add_node(s) - self.assertIsNotNone(newStandard) - - dbstandard = ( - self.collection.session.query(db.Node) - .filter(db.Node.name == s.name) - .first() - ) # ensure transaction happened (commit() called) - self.assertIsNotNone(dbstandard.id) - self.assertEqual(dbstandard.name, s.name) - self.assertEqual(dbstandard.section, s.section) - self.assertEqual(dbstandard.subsection, s.subsection) - self.assertEqual( - newStandard.name, s.name - ) # ensure the right thing got returned - self.assertEqual(dbstandard.ntype, s.doctype.value) - self.assertEqual(dbstandard.tags, ",".join(s.tags)) - # standards match on all of name,section, subsection <-- if you change even one of them it's a new entry - - def find_cres_of_cre(self) -> None: - dbcre = db.CRE(description="CREdesc1", name="CREname1") - groupless_cre = db.CRE(description="CREdesc2", name="CREname2") - dbgroup = db.CRE(description="Groupdesc1", name="GroupName1") - dbgroup2 = db.CRE(description="Groupdesc2", name="GroupName2") - - only_one_group = db.CRE(description="CREdesc3", name="CREname3") - - self.collection.session.add(dbcre) - self.collection.session.add(groupless_cre) - self.collection.session.add(dbgroup) - self.collection.session.add(dbgroup2) - self.collection.session.add(only_one_group) - self.collection.session.commit() - - internalLink = db.InternalLinks(cre=dbcre.id, group=dbgroup.id, type="Contains") - internalLink2 = db.InternalLinks( - cre=dbcre.id, group=dbgroup2.id, type="Contains" - ) - internalLink3 = db.InternalLinks( - cre=only_one_group.id, group=dbgroup.id, type="Contains" - ) - self.collection.session.add(internalLink) - self.collection.session.add(internalLink2) - self.collection.session.add(internalLink3) - self.collection.session.commit() - - # happy path, find cre with 2 groups - - groups = self.collection.find_cres_of_cre(dbcre) - if not groups: - self.fail("Expected exactly 2 cres") - self.assertEqual(len(groups), 2) - self.assertEqual(groups, [dbgroup, dbgroup2]) - - # find cre with 1 group - group = self.collection.find_cres_of_cre(only_one_group) - - if not group: - self.fail("Expected exactly 1 cre") - self.assertEqual(len(group), 1) - self.assertEqual(group, [dbgroup]) - - # ensure that None is return if there are no groups - groups = self.collection.find_cres_of_cre(groupless_cre) - self.assertIsNone(groups) - - def test_find_cres_of_standard(self) -> None: - dbcre = db.CRE(description="CREdesc1", name="CREname1") - dbgroup = db.CRE(description="CREdesc2", name="CREname2") - dbstandard1 = db.Node( - section="section1", - name="standard1", - ntype=defs.Standard.__name__, - ) - group_standard = db.Node( - section="section2", - name="standard2", - ntype=defs.Standard.__name__, - ) - lone_standard = db.Node( - section="section3", - name="standard3", - ntype=defs.Standard.__name__, - ) - - self.collection.session.add(dbcre) - self.collection.session.add(dbgroup) - self.collection.session.add(dbstandard1) - self.collection.session.add(group_standard) - self.collection.session.add(lone_standard) - self.collection.session.commit() - - self.collection.session.add(db.Links(cre=dbcre.id, node=dbstandard1.id)) - self.collection.session.add(db.Links(cre=dbgroup.id, node=dbstandard1.id)) - self.collection.session.add(db.Links(cre=dbgroup.id, node=group_standard.id)) - self.collection.session.commit() - - # happy path, 1 group and 1 cre link to 1 standard - cres = self.collection.find_cres_of_node(dbstandard1) - - if not cres: - self.fail("Expected 2 cres") - self.assertEqual(len(cres), 2) - self.assertEqual(cres, [dbcre, dbgroup]) - - # group links to standard - cres = self.collection.find_cres_of_node(group_standard) - - if not cres: - self.fail("Expected 1 cre") - self.assertEqual(len(cres), 1) - self.assertEqual(cres, [dbgroup]) - - # no links = None - cres = self.collection.find_cres_of_node(lone_standard) - self.assertIsNone(cres) - - def test_get_CREs(self) -> None: - """Given: a cre 'C1' that links to cres both as a group and a cre and other standards - return the CRE in Document format""" - collection = db.Node_collection() - dbc1 = db.CRE(external_id="123-123", description="gcCD1", name="gcC1") - dbc2 = db.CRE(description="gcCD2", name="gcC2", external_id="444-444") - dbc3 = db.CRE(description="gcCD3", name="gcC3", external_id="555-555") - db_id_only = db.CRE( - description="c_get_by_internal_id_only", - name="cgbiio", - external_id="666-666", - ) - dbs1 = db.Node( - ntype=defs.Standard.__name__, - name="gcS2", - section="gc1", - subsection="gc2", - link="gc3", - version="gc1.1.1", - ) - - dbs2 = db.Node( - ntype=defs.Standard.__name__, - name="gcS3", - section="gc1", - subsection="gc2", - link="gc3", - version="gc3.1.2", - ) - - parent_cre = db.CRE( - external_id="999-999", description="parent cre", name="pcre" - ) - parent_cre2 = db.CRE( - external_id="888-888", description="parent cre2", name="pcre2" - ) - partOf_cre = db.CRE( - external_id="777-777", description="part of cre", name="poc" - ) - - collection.session.add(dbc1) - collection.session.add(dbc2) - collection.session.add(dbc3) - collection.session.add(dbs1) - collection.session.add(dbs2) - collection.session.add(db_id_only) - - collection.session.add(parent_cre) - collection.session.add(parent_cre2) - collection.session.add(partOf_cre) - collection.session.commit() - - collection.session.add( - db.InternalLinks(type="Contains", group=dbc1.id, cre=dbc2.id) - ) - collection.session.add( - db.InternalLinks(type="Contains", group=dbc1.id, cre=dbc3.id) - ) - collection.session.add(db.Links(type="Linked To", cre=dbc1.id, node=dbs1.id)) - - collection.session.add( - db.InternalLinks( - type=defs.LinkTypes.Contains.value, - group=parent_cre.id, - cre=partOf_cre.id, - ) - ) - collection.session.add( - db.InternalLinks( - type=defs.LinkTypes.Contains.value, - group=parent_cre2.id, - cre=partOf_cre.id, - ) - ) - collection.session.commit() - self.maxDiff = None - - # we can retrieve children cres - self.assertEqual( - [ - db.CREfromDB(parent_cre).add_link( - defs.Link( - document=db.CREfromDB(partOf_cre), ltype=defs.LinkTypes.Contains - ) - ) - ], - collection.get_CREs(external_id=parent_cre.external_id), - ) - self.assertEqual( - [ - db.CREfromDB(parent_cre2).add_link( - defs.Link( - document=db.CREfromDB(partOf_cre), ltype=defs.LinkTypes.Contains - ) - ) - ], - collection.get_CREs(external_id=parent_cre2.external_id), - ) - - # we can retrieve children cres with inverted multiple (PartOf) links to their parents - self.assertEqual( - [ - db.CREfromDB(partOf_cre) - .add_link( - defs.Link( - document=db.CREfromDB(parent_cre), ltype=defs.LinkTypes.PartOf - ) - ) - .add_link( - defs.Link( - document=db.CREfromDB(parent_cre2), ltype=defs.LinkTypes.PartOf - ) - ) - ], - collection.get_CREs(external_id=partOf_cre.external_id), - ) - - cd1 = defs.CRE(id="123-123", description="gcCD1", name="gcC1") - cd2 = defs.CRE(id="444-444", description="gcCD2", name="gcC2") - cd3 = defs.CRE(id="555-555", description="gcCD3", name="gcC3") - c_id_only = defs.CRE( - id="666-666", description="c_get_by_internal_id_only", name="cgbiio" - ) - - expected = [ - copy(cd1) - .add_link( - defs.Link( - ltype=defs.LinkTypes.LinkedTo, - document=defs.Standard( - name="gcS2", - section="gc1", - subsection="gc2", - hyperlink="gc3", - version="gc1.1.1", - ), - ) - ) - .add_link( - defs.Link( - ltype=defs.LinkTypes.Contains, - document=copy(cd2), - ) - ) - .add_link(defs.Link(ltype=defs.LinkTypes.Contains, document=copy(cd3))) - ] - self.maxDiff = None - shallow_cd1 = copy(cd1) - shallow_cd1.links = [] - cd2.add_link(defs.Link(ltype=defs.LinkTypes.PartOf, document=shallow_cd1)) - cd3.add_link(defs.Link(ltype=defs.LinkTypes.PartOf, document=shallow_cd1)) - - # empty returns empty - self.assertEqual([], collection.get_CREs()) - - # getting "group cre 1" by name returns gcC1 - res = collection.get_CREs(name="gcC1") - self.assertEqual(len(expected), len(res)) - self.assertCountEqual(expected[0].todict(), res[0].todict()) - - # getting "group cre 1" by id returns gcC1 - res = collection.get_CREs(external_id="123-123") - self.assertEqual(len(expected), len(res)) - self.assertCountEqual(expected[0].todict(), res[0].todict()) - - # getting "group cre 1" by partial id returns gcC1 - res = collection.get_CREs(external_id="12%", partial=True) - self.assertEqual(len(expected), len(res)) - self.assertCountEqual(expected[0].todict(), res[0].todict()) - - # getting "group cre 1" by partial name returns gcC1, gcC2 and gcC3 - res = collection.get_CREs(name="gcC%", partial=True) - self.assertEqual(3, len(res)) - self.assertCountEqual( - [expected[0].todict(), cd2.todict(), cd3.todict()], - [r.todict() for r in res], - ) - - # getting "group cre 1" by partial name and partial id returns gcC1 - res = collection.get_CREs(external_id="1%", name="gcC%", partial=True) - self.assertEqual(len(expected), len(res)) - self.assertCountEqual(expected[0].todict(), res[0].todict()) - - # getting "group cre 1" by description returns gcC1 - res = collection.get_CREs(description="gcCD1") - self.assertEqual(len(expected), len(res)) - self.assertCountEqual(expected[0].todict(), res[0].todict()) - - # getting "group cre 1" by partial id and partial description returns gcC1 - res = collection.get_CREs(external_id="1%", description="gcC%", partial=True) - self.assertEqual(len(expected), len(res)) - self.assertCountEqual(expected[0].todict(), res[0].todict()) - - # getting all the gcC* cres by partial name and partial description returns gcC1, gcC2, gcC3 - res = collection.get_CREs(description="gcC%", name="gcC%", partial=True) - want = [expected[0], cd2, cd3] - for el in res: - found = False - for wel in want: - if el.todict() == wel.todict(): - found = True - self.assertTrue(found) - - self.assertEqual([], collection.get_CREs(external_id="123-123", name="gcC5")) - self.assertEqual([], collection.get_CREs(external_id="1234")) - self.assertEqual([], collection.get_CREs(name="gcC5")) - - # add a standard to gcC1 - collection.session.add(db.Links(type="Linked To", cre=dbc1.id, node=dbs2.id)) - - only_gcS2 = deepcopy(expected) # save a copy of the current expected - expected[0].add_link( - defs.Link( - ltype=defs.LinkTypes.LinkedTo, - document=defs.Standard( - name="gcS3", - section="gc1", - subsection="gc2", - hyperlink="gc3", - version="gc3.1.2", - ), - ) - ) - # we can retrieve the cre with the standard - res = collection.get_CREs(name="gcC1") - self.assertCountEqual(expected[0].todict(), res[0].todict()) - - # we can retrieve ONLY the standard - res = collection.get_CREs(name="gcC1", include_only=["gcS2"]) - self.assertDictEqual(only_gcS2[0].todict(), res[0].todict()) - - ccd2 = copy(cd2) - ccd2.links = [] - ccd3 = copy(cd3) - ccd3.links = [] - no_standards = [ - copy(cd1) - .add_link( - defs.Link( - ltype=defs.LinkTypes.Contains, - document=ccd2, - ) - ) - .add_link(defs.Link(ltype=defs.LinkTypes.Contains, document=ccd3)) - ] - - # if the standard is not linked, we retrieve as normal - res = collection.get_CREs(name="gcC1", include_only=["gcS0"]) - self.assertEqual(no_standards, res) - - self.assertEqual([c_id_only], collection.get_CREs(internal_id=db_id_only.id)) - - def test_get_standards(self) -> None: - """Given: a Standard 'S1' that links to cres - return the Standard in Document format""" - collection = db.Node_collection() - docs: Dict[str, Union[db.CRE, db.Node]] = { - "dbc1": db.CRE(external_id="123-123", description="CD1", name="C1"), - "dbc2": db.CRE(external_id="222-222", description="CD2", name="C2"), - "dbc3": db.CRE(external_id="333-333", description="CD3", name="C3"), - "dbs1": db.Node( - ntype=defs.Standard.__name__, - name="S1", - section="111-111", - section_id="123-123", - subsection="222-222", - link="333-333", - version="4", - ), - } - links = [("dbc1", "dbs1"), ("dbc2", "dbs1"), ("dbc3", "dbs1")] - for k, v in docs.items(): - collection.session.add(v) - collection.session.commit() - - for cre, standard in links: - collection.session.add( - db.Links(type="Linked To", cre=docs[cre].id, node=docs[standard].id) - ) - collection.session.commit() - - expected = [ - defs.Standard( - name="S1", - section="111-111", - sectionID="123-123", - subsection="222-222", - hyperlink="333-333", - version="4", - links=[ - defs.Link( - ltype=defs.LinkTypes.LinkedTo, - document=defs.CRE(id="123-123", name="C1", description="CD1"), - ), - defs.Link( - ltype=defs.LinkTypes.LinkedTo, - document=defs.CRE(id="222-222", name="C2", description="CD2"), - ), - defs.Link( - ltype=defs.LinkTypes.LinkedTo, - document=defs.CRE(id="333-333", name="C3", description="CD3"), - ), - ], - ) - ] - - res = collection.get_nodes(name="S1") - self.assertEqual(expected, res) - - def test_get_nodes_with_pagination(self) -> None: - """Given: a Standard 'S1' that links to cres - return the Standard in Document format and the total pages and the page we are in - """ - collection = db.Node_collection() - docs: Dict[str, Union[db.Node, db.CRE]] = { - "dbc1": db.CRE(external_id="123-123", description="CD1", name="C1"), - "dbc2": db.CRE(external_id="222-222", description="CD2", name="C2"), - "dbc3": db.CRE(external_id="333-333", description="CD3", name="C3"), - "dbs1": db.Node( - name="S1", - section="111-111", - section_id="123-123", - subsection="222-222", - link="333-333", - version="4", - ntype=defs.Standard.__name__, - ), - } - links = [("dbc1", "dbs1"), ("dbc2", "dbs1"), ("dbc3", "dbs1")] - for k, v in docs.items(): - collection.session.add(v) - collection.session.commit() - - for cre, standard in links: - collection.session.add( - db.Links( - cre=docs[cre].id, - node=docs[standard].id, - type=defs.LinkTypes.LinkedTo, - ) - ) - collection.session.commit() - - expected = [ - defs.Standard( - name="S1", - section="111-111", - sectionID="123-123", - subsection="222-222", - hyperlink="333-333", - version="4", - links=[ - defs.Link( - document=defs.CRE(name="C1", description="CD1", id="123-123"), - ltype=defs.LinkTypes.LinkedTo, - ), - defs.Link( - document=defs.CRE(id="222-222", name="C2", description="CD2"), - ltype=defs.LinkTypes.LinkedTo, - ), - defs.Link( - document=defs.CRE(id="333-333", name="C3", description="CD3"), - ltype=defs.LinkTypes.LinkedTo, - ), - ], - ) - ] - total_pages, res, _ = collection.get_nodes_with_pagination(name="S1") - self.assertEqual(total_pages, 1) - self.assertEqual(expected, res) - - only_c1 = [ - defs.Standard( - name="S1", - section="111-111", - sectionID="123-123", - subsection="222-222", - hyperlink="333-333", - version="4", - links=[ - defs.Link( - document=defs.CRE(name="C1", description="CD1", id="123-123"), - ltype=defs.LinkTypes.LinkedTo, - ) - ], - ) - ] - _, res, _ = collection.get_nodes_with_pagination(name="S1", include_only=["C1"]) - self.assertEqual(only_c1, res) - _, res, _ = collection.get_nodes_with_pagination( - name="S1", include_only=["123-123"] - ) - self.assertEqual(only_c1, res) - - self.assertEqual( - collection.get_nodes_with_pagination(name="this should not exit"), - (None, None, None), - ) - - def test_add_internal_link(self) -> None: - """test that internal links are added successfully, - edge cases: - cre or group don't exist - called on a cycle scenario""" - - cres = { - "dbca": self.collection.add_cre( - defs.CRE(id="111-111", description="CA", name="CA") - ), - "dbcb": self.collection.add_cre( - defs.CRE(id="222-222", description="CB", name="CB") - ), - "dbcc": self.collection.add_cre( - defs.CRE(id="333-333", description="CC", name="CC") - ), - } - - # happy path - self.collection.add_internal_link( - higher=cres["dbca"], lower=cres["dbcb"], ltype=defs.LinkTypes.Related - ) - - # "happy path, internal link exists" - res = ( - self.collection.session.query(db.InternalLinks) - .filter( - db.InternalLinks.group == cres["dbca"].id, - db.InternalLinks.cre == cres["dbcb"].id, - ) - .first() - ) - self.assertEqual((res.group, res.cre), (cres["dbca"].id, cres["dbcb"].id)) - - # no cycle, free to insert - self.collection.add_internal_link( - higher=cres["dbcb"], lower=cres["dbcc"], ltype=defs.LinkTypes.Related - ) - res = ( - self.collection.session.query(db.InternalLinks) - .filter( - db.InternalLinks.group == cres["dbcb"].id, - db.InternalLinks.cre == cres["dbcc"].id, - ) - .first() - ) - self.assertEqual((res.group, res.cre), (cres["dbcb"].id, cres["dbcc"].id)) - - # introdcues a cycle, should not be inserted - self.collection.add_internal_link( - higher=cres["dbcc"], lower=cres["dbca"], ltype=defs.LinkTypes.Related - ) - - # cycles are not inserted branch - none_res = ( - self.collection.session.query(db.InternalLinks) - .filter( - db.InternalLinks.group == cres["dbcc"].id, - db.InternalLinks.cre == cres["dbca"].id, - ) - .one_or_none() - ) - self.assertIsNone(none_res) - - def test_text_search(self) -> None: - """Given: - a cre(id="111-111"23-456,name=foo,description='lorem ipsum foo+bar') - a standard(name=Bar,section=blah,subsection=foo, hyperlink='https://example.com/blah/foo') - a standard(name=Bar,section=blah,subsection=foo1, hyperlink='https://example.com/blah/foo1') - a standard(name=Bar,section=blah1,subsection=foo, hyperlink='https://example.com/blah1/foo') - - full_text_search('123-456') returns cre:foo - full_text_search('CRE:foo') and full_text_search('CRE foo') returns cre:foo - full_text_search('CRE:123-456') and full_text_search('CRE 123-456') returns cre:foo - - full_text_search('Standard:Bar') and full_text_search('Standard Bar') returns: [standard:Bar:blah:foo, - standard:Bar:blah:foo1, - standard:Bar:blah1:foo] - - full_text_search('Standard:blah') and full_text_search('Standard blah') returns [standard:Bar::blah:foo, - standard:Bar:blah:foo1] - full_text_search('Standard:blah:foo') returns [standard:Bar:blah:foo] - full_text_search('Standard:foo') returns [standard:Bar:blah:foo, - standard:Bar:blah1:foo] - - - full_text_search('ipsum') returns cre:foo - full_text_search('foo') returns [cre:foo,standard:Bar:blah:foo, standard:Bar:blah:foo1,standard:Bar:blah1:foo] - """ - collection = db.Node_collection() - cre = defs.CRE( - id="123-456", name="textSearchCRE", description="lorem ipsum tsSection+tsC" - ) - collection.add_cre(cre) - - s1 = defs.Standard( - name="textSearchStandard", - section="tsSection", - subsection="tsSubSection", - hyperlink="https://example.com/tsSection/tsSubSection", - ) - collection.add_node(s1) - s2 = defs.Standard( - name="textSearchStandard", - section="tsSection", - subsection="tsSubSection1", - hyperlink="https://example.com/tsSection/tsSubSection1", - ) - collection.add_node(s2) - s3 = defs.Standard( - name="textSearchStandard", - section="tsSection1", - subsection="tsSubSection1", - hyperlink="https://example.com/tsSection1/tsSubSection1", - ) - collection.add_node(s3) - t1 = defs.Tool( - name="textSearchTool", - tooltype=defs.ToolTypes.Offensive, - hyperlink="https://example.com/textSearchTool", - description="test text search with tool", - sectionID="15", - section="rule 15", - ) - collection.add_node(t1) - collection.session.commit() - expected: Dict[str, List[Any]] = { - "123-456": [cre], - "CRE:textSearchCRE": [cre], - "CRE textSearchCRE": [cre], - "CRE:123-456": [cre], - "CRE 123-456": [cre], - "Standard:textSearchStandard": [s1, s2, s3], - "Standard textSearchStandard": [s1, s2, s3], - "Standard:tsSection": [s1, s2], - "Standard tsSection": [s1, s2], - "Standard:tsSection:tsSubSection1": [s2], - "Standard tsSection tsSubSection1": [s2], - "Standard:tsSubSection1": [s2, s3], - "Standard tsSubSection1": [s2, s3], - "Standard:https://example.com/tsSection/tsSubSection1": [s2], - "Standard https://example.com/tsSection1/tsSubSection1": [s3], - "https://example.com/tsSection": [s1, s2, s3], - "ipsum": [cre], - "tsSection": [cre, s1, s2, s3], - "https://example.com/textSearchTool": [t1], - "text search": [t1], - } - self.maxDiff = None - for k, val in expected.items(): - res = self.collection.text_search(k) - self.assertCountEqual(res, val) - - def test_dbNodeFromNode(self) -> None: - data = { - "tool": defs.Tool( - name="fooTool", - description="lorem ipsum tsSection+tsC", - tooltype=defs.ToolTypes.Defensive, - tags=["111-111", "222-222", "333-333"], - ), - "standard": defs.Standard( - name="stand", section="s1", subsection="s2", version="s3" - ), - "code": defs.Code( - name="ccc", - description="c2", - hyperlink="https://example.com/code/hyperlink", - tags=["111-111", "222-222"], - ), - } - expected = { - "tool": db.Node( - name="fooTool", - description="lorem ipsum tsSection+tsC", - tags=",".join( - [defs.ToolTypes.Defensive.value, "111-111", "222-222", "333-333"] - ), - ntype=defs.Credoctypes.Tool.value, - ), - "standard": db.Node( - name="stand", - section="s1", - subsection="s2", - version="s3", - ntype=defs.Credoctypes.Standard.value, - ), - "code": db.Node( - name="ccc", - description="c2", - link="https://example.com/code/hyperlink", - tags="1,2", - ntype=defs.Credoctypes.Code.value, - ), - } - for k, v in data.items(): - nd = db.dbNodeFromNode(v) - for vname, var in vars(nd).items(): - if var and not vname.startswith("_"): - self.assertEqual(var, vars(expected[k]).get(vname)) - - def test_nodeFromDB(self) -> None: - expected = { - "tool": defs.Tool( - name="fooTool", - description="lorem ipsum tsSection+tsC", - tooltype=defs.ToolTypes.Defensive, - tags=["111-111", "222-222", "333-333"], - ), - "standard": defs.Standard( - name="stand", section="s1", subsection="s2", version="s3" - ), - "code": defs.Code( - name="ccc", - description="c2", - hyperlink="https://example.com/code/hyperlink", - tags=["111-111", "222-222"], - ), - } - data = { - "tool": db.Node( - name="fooTool", - description="lorem ipsum tsSection+tsC", - tags=",".join( - [defs.ToolTypes.Defensive.value, "111-111", "222-222", "333-333"] - ), - ntype=defs.Credoctypes.Tool.value, - ), - "standard": db.Node( - name="stand", - section="s1", - subsection="s2", - version="s3", - ntype=defs.Credoctypes.Standard.value, - ), - "code": db.Node( - name="ccc", - description="c2", - link="https://example.com/code/hyperlink", - tags="111-111,222-222", - ntype=defs.Credoctypes.Code.value, - ), - } - for k, v in data.items(): - nd = db.nodeFromDB(v) - for vname, var in vars(nd).items(): - if var and not vname.startswith("_"): - self.assertCountEqual(var, vars(expected[k]).get(vname)) - - def test_object_select(self) -> None: - dbnode1 = db.Node( - name="fooTool", - description="lorem ipsum tsSection+tsC", - tags=f"{defs.ToolTypes.Defensive.value},1", - ) - dbnode2 = db.Node( - name="fooTool", - description="lorem2", - link="https://example.com/foo/bar", - tags=f"{defs.ToolTypes.Defensive.value},1", - ) - - self.collection = db.Node_collection() - collection = db.Node_collection() - collection.session.add(dbnode1) - collection.session.add(dbnode2) - self.assertEqual(collection.object_select(dbnode1), [dbnode1]) - self.assertEqual(collection.object_select(dbnode2), [dbnode2]) - self.assertCountEqual( - collection.object_select(db.Node(name="fooTool")), [dbnode1, dbnode2] - ) - - self.assertEqual(collection.object_select(None), []) - - def test_get_root_cres(self): - """Given: - 6 CRES: - * C0 <-- Root - * C1 <-- Root - * C2 Part Of C0 - * C3 Part Of C1 - * C4 Part Of C2 - * C5 Related to C0 - * C6 Part Of C1 - * C7 Contains C6 <-- Root - 3 Nodes: - * N0 Unlinked - * N1 Linked To C1 - * N2 Linked to C2 - * N3 Linked to C3 - * N4 Linked to C4 - Get_root_cres should return C0, C1 - """ - cres = [] - nodes = [] - dbcres = [] - dbnodes = [] - - # clean the db from setup - sqla.session.remove() - sqla.drop_all() - sqla.create_all() - - collection = db.Node_collection().with_graph() - - for i in range(0, 8): - if i == 0 or i == 1: - cres.append(defs.CRE(name=f">> C{i}", id=f"{i}{i}{i}-{i}{i}{i}")) - else: - cres.append(defs.CRE(name=f"C{i}", id=f"{i}{i}{i}-{i}{i}{i}")) - - dbcres.append(collection.add_cre(cres[i])) - nodes.append(defs.Standard(section=f"S{i}", name=f"N{i}")) - dbnodes.append(collection.add_node(nodes[i])) - cres[i].add_link( - defs.Link(document=copy(nodes[i]), ltype=defs.LinkTypes.LinkedTo) - ) - collection.add_link( - cre=dbcres[i], node=dbnodes[i], ltype=defs.LinkTypes.LinkedTo - ) - - cres[0].add_link( - defs.Link(document=cres[2].shallow_copy(), ltype=defs.LinkTypes.Contains) - ) - cres[1].add_link( - defs.Link(document=cres[3].shallow_copy(), ltype=defs.LinkTypes.Contains) - ) - cres[2].add_link( - defs.Link(document=cres[4].shallow_copy(), ltype=defs.LinkTypes.Contains) - ) - - cres[3].add_link( - defs.Link(document=cres[5].shallow_copy(), ltype=defs.LinkTypes.Contains) - ) - cres[6].add_link( - defs.Link(document=cres[7].shallow_copy(), ltype=defs.LinkTypes.PartOf) - ) - collection.add_internal_link( - higher=dbcres[0], lower=dbcres[2], ltype=defs.LinkTypes.Contains - ) - collection.add_internal_link( - higher=dbcres[1], lower=dbcres[3], ltype=defs.LinkTypes.Contains - ) - collection.add_internal_link( - higher=dbcres[2], lower=dbcres[4], ltype=defs.LinkTypes.Contains - ) - collection.add_internal_link( - higher=dbcres[3], lower=dbcres[5], ltype=defs.LinkTypes.Contains - ) - collection.add_internal_link( - higher=dbcres[7], lower=dbcres[6], ltype=defs.LinkTypes.Contains - ) - cres[7].add_link( - defs.Link(document=cres[6].shallow_copy(), ltype=defs.LinkTypes.Contains) - ) - - root_cres = collection.get_root_cres() - self.maxDiff = None - self.assertCountEqual(root_cres, [cres[0], cres[1], cres[7]]) - - @patch.object(db.NEO_DB, "gap_analysis") - def test_gap_analysis_disconnected(self, gap_mock): - collection = db.Node_collection() - collection.neo_db.connected = False - gap_mock.return_value = (None, None) - - self.assertEqual(db.gap_analysis(collection.neo_db, ["788-788", "b"]), None) - - @patch.object(db.NEO_DB, "gap_analysis") - def test_gap_analysis_no_nodes(self, gap_mock): - collection = db.Node_collection() - collection.neo_db.connected = True - - gap_mock.return_value = ([], []) - self.assertEqual( - db.gap_analysis(collection.neo_db, ["788-788", "b"]), - (["788-788", "b"], {}, {}), - ) - - @patch.object(db.NEO_DB, "gap_analysis") - def test_gap_analysis_no_links(self, gap_mock): - collection = db.Node_collection() - collection.neo_db.connected = True - - gap_mock.return_value = ([defs.CRE(name="bob", id="111-111")], []) - self.maxDiff = None - self.assertEqual( - db.gap_analysis(collection.neo_db, ["788-788", "b"]), - ( - ["788-788", "b"], - { - "111-111": { - "start": defs.CRE(name="bob", id="111-111"), - "paths": {}, - "extra": 0, - } - }, - {"111-111": {"paths": {}}}, - ), - ) - - @patch.object(db.NEO_DB, "gap_analysis") - def test_gap_analysis_one_link(self, gap_mock): - collection = db.Node_collection() - collection.neo_db.connected = True - path = [ - { - "end": defs.CRE(name="bob", id="111-111"), - "relationship": "LINKED_TO", - "start": defs.CRE(name="bob", id="788-788"), - }, - { - "end": defs.CRE(name="bob", id="222-222"), - "relationship": "LINKED_TO", - "start": defs.CRE(name="bob", id="788-788"), - }, - ] - gap_mock.return_value = ( - [defs.CRE(name="bob", id="788-788")], - [ - { - "start": defs.CRE(name="bob", id="788-788"), - "end": defs.CRE(name="bob", id="788-789"), - "path": path, - } - ], - ) - expected = ( - ["788-788", "788-789"], - { - "788-788": { - "start": defs.CRE(name="bob", id="788-788"), - "paths": { - "788-789": { - "end": defs.CRE(name="bob", id="788-789"), - "path": path, - "score": 0, - } - }, - "extra": 0, - } - }, - {"788-788": {"paths": {}}}, - ) - self.maxDiff = None - self.assertEqual( - db.gap_analysis(collection.neo_db, ["788-788", "788-789"]), expected - ) - - @patch.object(db.NEO_DB, "gap_analysis") - def test_gap_analysis_one_weak_link(self, gap_mock): - collection = db.Node_collection() - collection.neo_db.connected = True - path = [ - { - "end": defs.CRE(name="bob", id="111-111"), - "relationship": "LINKED_TO", - "start": defs.CRE(name="bob", id="788-788"), - }, - { - "end": defs.CRE(name="bob", id="222-222"), - "relationship": "RELATED", - "start": defs.CRE(name="bob", id="111-111"), - }, - { - "end": defs.CRE(name="bob", id="111-111"), - "relationship": "RELATED", - "start": defs.CRE(name="bob", id="222-222"), - }, - { - "end": defs.CRE(name="bob", id="333-333"), - "relationship": "LINKED_TO", - "start": defs.CRE(name="bob", id="222-222"), - }, - ] - gap_mock.return_value = ( - [defs.CRE(name="bob", id="111-111")], - [ - { - "start": defs.CRE(name="bob", id="111-111"), - "end": defs.CRE(name="bob", id="222-222"), - "path": path, - } - ], - ) - expected = ( - ["788-788", "b"], - { - "111-111": { - "start": defs.CRE(name="bob", id="111-111"), - "paths": {}, - "extra": 1, - } - }, - { - "111-111": { - "paths": { - "222-222": { - "end": defs.CRE(name="bob", id="222-222"), - "path": path, - "score": 4, - } - } - } - }, - ) - self.maxDiff = None - self.assertEqual(db.gap_analysis(collection.neo_db, ["788-788", "b"]), expected) - - @patch.object(db.NEO_DB, "gap_analysis") - def test_gap_analysis_duplicate_link_path_existing_lower(self, gap_mock): - collection = db.Node_collection() - collection.neo_db.connected = True - path = [ - { - "end": defs.CRE(name="bob", id="111-111"), - "relationship": "LINKED_TO", - "start": defs.CRE(name="bob", id="788-788"), - }, - { - "end": defs.CRE(name="bob", id="222-222"), - "relationship": "LINKED_TO", - "start": defs.CRE(name="bob", id="788-788"), - }, - ] - path2 = [ - { - "end": defs.CRE(name="bob", id="111-111"), - "relationship": "LINKED_TO", - "start": defs.CRE(name="bob", id="788-788"), - }, - { - "end": defs.CRE(name="bob", id="222-222"), - "relationship": "RELATED", - "start": defs.CRE(name="bob", id="788-788"), - }, - ] - gap_mock.return_value = ( - [defs.CRE(name="bob", id="111-111")], - [ - { - "start": defs.CRE(name="bob", id="111-111"), - "end": defs.CRE(name="bob", id="222-222"), - "path": path, - }, - { - "start": defs.CRE(name="bob", id="111-111"), - "end": defs.CRE(name="bob", id="222-222"), - "path": path2, - }, - ], - ) - expected = ( - ["788-788", "b"], - { - "111-111": { - "start": defs.CRE(name="bob", id="111-111"), - "paths": { - "222-222": { - "end": defs.CRE(name="bob", id="222-222"), - "path": path, - "score": 0, - } - }, - "extra": 0, - }, - }, - {"111-111": {"paths": {}}}, - ) - self.assertEqual(db.gap_analysis(collection.neo_db, ["788-788", "b"]), expected) - - @patch.object(db.NEO_DB, "gap_analysis") - def test_gap_analysis_duplicate_link_path_existing_lower_new_in_extras( - self, gap_mock - ): - collection = db.Node_collection() - collection.neo_db.connected = True - path = [ - { - "end": defs.CRE(name="bob", id="111-111"), - "relationship": "LINKED_TO", - "start": defs.CRE(name="bob", id="788-788"), - }, - { - "end": defs.CRE(name="bob", id="222-222"), - "relationship": "LINKED_TO", - "start": defs.CRE(name="bob", id="788-788"), - }, - ] - path2 = [ - { - "end": defs.CRE(name="bob", id="111-111"), - "relationship": "LINKED_TO", - "start": defs.CRE(name="bob", id="788-788"), - }, - { - "end": defs.CRE(name="bob", id="222-222"), - "relationship": "RELATED", - "start": defs.CRE(name="bob", id="788-788"), - }, - { - "end": defs.CRE(name="bob", id="222-222"), - "relationship": "RELATED", - "start": defs.CRE(name="bob", id="788-788"), - }, - ] - gap_mock.return_value = ( - [defs.CRE(name="bob", id="111-111")], - [ - { - "start": defs.CRE(name="bob", id="111-111"), - "end": defs.CRE(name="bob", id="222-222"), - "path": path, - }, - { - "start": defs.CRE(name="bob", id="111-111"), - "end": defs.CRE(name="bob", id="222-222"), - "path": path2, - }, - ], - ) - expected = ( - ["788-788", "b"], - { - "111-111": { - "start": defs.CRE(name="bob", id="111-111"), - "paths": { - "222-222": { - "end": defs.CRE(name="bob", id="222-222"), - "path": path, - "score": 0, - } - }, - "extra": 0, - }, - }, - {"111-111": {"paths": {}}}, - ) - self.assertEqual(db.gap_analysis(collection.neo_db, ["788-788", "b"]), expected) - - @patch.object(db.NEO_DB, "gap_analysis") - def test_gap_analysis_duplicate_link_path_existing_higher(self, gap_mock): - collection = db.Node_collection() - collection.neo_db.connected = True - path = [ - { - "end": defs.CRE(name="bob", id="111-111"), - "relationship": "LINKED_TO", - "start": defs.CRE(name="bob", id="788-788"), - }, - { - "end": defs.CRE(name="bob", id="222-222"), - "relationship": "LINKED_TO", - "start": defs.CRE(name="bob", id="788-788"), - }, - ] - path2 = [ - { - "end": defs.CRE(name="bob", id="111-111"), - "relationship": "LINKED_TO", - "start": defs.CRE(name="bob", id="788-788"), - }, - { - "end": defs.CRE(name="bob", id="222-222"), - "relationship": "RELATED", - "start": defs.CRE(name="bob", id="788-788"), - }, - ] - gap_mock.return_value = ( - [defs.CRE(name="bob", id="111-111")], - [ - { - "start": defs.CRE(name="bob", id="111-111"), - "end": defs.CRE(name="bob", id="222-222"), - "path": path2, - }, - { - "start": defs.CRE(name="bob", id="111-111"), - "end": defs.CRE(name="bob", id="222-222"), - "path": path, - }, - ], - ) - expected = ( - ["788-788", "b"], - { - "111-111": { - "start": defs.CRE(name="bob", id="111-111"), - "paths": { - "222-222": { - "end": defs.CRE(name="bob", id="222-222"), - "path": path, - "score": 0, - } - }, - "extra": 0, - } - }, - {"111-111": {"paths": {}}}, - ) - self.assertEqual(db.gap_analysis(collection.neo_db, ["788-788", "b"]), expected) - - @patch.object(db.NEO_DB, "gap_analysis") - def test_gap_analysis_duplicate_link_path_existing_higher_and_in_extras( - self, gap_mock - ): - collection = db.Node_collection() - collection.neo_db.connected = True - path = [ - { - "end": defs.CRE(name="bob", id="111-111"), - "relationship": "LINKED_TO", - "start": defs.CRE(name="bob", id="788-788"), - }, - { - "end": defs.CRE(name="bob", id="222-222"), - "relationship": "LINKED_TO", - "start": defs.CRE(name="bob", id="788-788"), - }, - ] - path2 = [ - { - "end": defs.CRE(name="bob", id="111-111"), - "relationship": "LINKED_TO", - "start": defs.CRE(name="bob", id="788-788"), - }, - { - "end": defs.CRE(name="bob", id="222-222"), - "relationship": "RELATED", - "start": defs.CRE(name="bob", id="788-788"), - }, - { - "end": defs.CRE(name="bob", id="222-222"), - "relationship": "RELATED", - "start": defs.CRE(name="bob", id="788-788"), - }, - ] - gap_mock.return_value = ( - [defs.CRE(name="bob", id="111-111")], - [ - { - "start": defs.CRE(name="bob", id="111-111"), - "end": defs.CRE(name="bob", id="222-222"), - "path": path2, - }, - { - "start": defs.CRE(name="bob", id="111-111"), - "end": defs.CRE(name="bob", id="222-222"), - "path": path, - }, - ], - ) - expected = ( - ["788-788", "b"], - { - "111-111": { - "start": defs.CRE(name="bob", id="111-111"), - "paths": { - "222-222": { - "end": defs.CRE(name="bob", id="222-222"), - "path": path, - "score": 0, - } - }, - "extra": 0, - } - }, - {"111-111": {"paths": {}}}, - ) - self.assertEqual(db.gap_analysis(collection.neo_db, ["788-788", "b"]), expected) - - @patch.object(db.NEO_DB, "gap_analysis") - def test_gap_analysis_dump_to_cache(self, gap_mock): - collection = db.Node_collection() - collection.neo_db.connected = True - path = [ - { - "end": defs.CRE(name="bob1", id="111-111"), - "relationship": "LINKED_TO", - "start": defs.CRE(name="bob7", id="788-788"), - "score": 0, - }, - { - "end": defs.CRE(name="bob2", id="222-222"), - "relationship": "RELATED", - "start": defs.CRE(name="bob1", id="111-111"), - "score": 2, - }, - { - "end": defs.CRE(name="bob1", id="111-111"), - "relationship": "RELATED", - "start": defs.CRE(name="bob2", id="222-222"), - "score": 2, - }, - { - "end": defs.CRE(name="bob3", id="333-333"), - "relationship": "LINKED_TO", - "start": defs.CRE(name="bob2", id="222-222"), - "score": 4, - }, - ] - gap_mock.return_value = ( - [defs.CRE(name="bob7", id="788-788")], - [ - { - "start": defs.CRE(name="bob7", id="788-788"), - "end": defs.CRE(name="bob2", id="222-222"), - "path": path, - } - ], - ) - - expected_response = ( - ["788-788", "222-222"], - { - "788-788": { - "start": defs.CRE(name="bob7", id="788-788"), - "paths": {}, - "extra": 1, - } - }, - { - "788-788": { - "paths": { - "222-222": { - "end": defs.CRE(name="bob2", id="222-222"), - "path": path, - "score": 4, - } - } - } - }, - ) - response = db.gap_analysis(collection.neo_db, ["788-788", "222-222"]) - - self.maxDiff = None - self.assertEqual( - response, (expected_response[0], expected_response[1], expected_response[2]) - ) - self.assertEqual( - collection.gap_analysis_exists(make_resources_key(["788-788", "222-222"])), - True, - ) - self.assertEqual( - collection.get_gap_analysis_result( - make_resources_key(["788-788", "222-222"]) - ), - flask_json.dumps({"result": expected_response[1]}), - ) - self.assertEqual( - collection.get_gap_analysis_result( - make_subresources_key(["788-788", "222-222"], "788-788") - ), - flask_json.dumps({"result": expected_response[2]["788-788"]}), - ) - - def test_neo_db_parse_node_code(self): - name = "name" - description = "description" - tags = "tags" - version = "version" - hyperlink = "version" - expected = defs.Code( - name=name, - description=description, - tags=tags, - version=version, - hyperlink=hyperlink, - links=[ - defs.Link( - defs.CRE(id="123-123", description="gcCD2", name="gcC2"), "Related" - ) - ], - ) - graph_node = db.NeoCode( - name=name, - description=description, - tags=tags, - version=version, - hyperlink=hyperlink, - related=[ - db.NeoCRE(external_id="123-123", description="gcCD2", name="gcC2"), - ], - ) - - self.assertEqual(db.NEO_DB.parse_node(graph_node).todict(), expected.todict()) - - def test_neo_db_parse_node_standard(self): - name = "name" - description = "description" - tags = "tags" - version = "version" - section = "section" - sectionID = "sectionID" - subsection = "subsection" - hyperlink = "version" - expected = defs.Standard( - name=name, - description=description, - tags=tags, - version=version, - section=section, - sectionID=sectionID, - subsection=subsection, - hyperlink=hyperlink, - links=[ - defs.Link( - defs.CRE(id="123-123", description="gcCD2", name="gcC2"), "Related" - ) - ], - ) - graph_node = db.NeoStandard( - name=name, - description=description, - tags=tags, - version=version, - section=section, - section_id=sectionID, - subsection=subsection, - hyperlink=hyperlink, - related=[ - db.NeoCRE(external_id="123-123", description="gcCD2", name="gcC2"), - ], - ) - self.assertEqual(db.NEO_DB.parse_node(graph_node).todict(), expected.todict()) - - def test_neo_db_parse_node_tool(self): - name = "name" - description = "description" - tags = "tags" - version = "version" - section = "section" - sectionID = "sectionID" - subsection = "subsection" - hyperlink = "version" - tooltype = defs.ToolTypes.Defensive - expected = defs.Tool( - name=name, - tooltype=tooltype, - description=description, - tags=tags, - version=version, - section=section, - sectionID=sectionID, - subsection=subsection, - hyperlink=hyperlink, - links=[ - defs.Link( - defs.CRE(id="123-123", description="gcCD2", name="gcC2"), "Related" - ) - ], - ) - graph_node = db.NeoTool( - name=name, - description=description, - tooltype=tooltype, - tags=tags, - version=version, - section=section, - section_id=sectionID, - subsection=subsection, - hyperlink=hyperlink, - related=[ - db.NeoCRE(external_id="123-123", description="gcCD2", name="gcC2"), - ], - ) - self.assertEqual(db.NEO_DB.parse_node(graph_node).todict(), expected.todict()) - - def test_neo_db_parse_node_cre(self): - name = "name" - description = "description" - tags = "tags" - external_id = "123-123" - expected = defs.CRE( - name=name, - description=description, - id=external_id, - tags=tags, - links=[ - defs.Link( - defs.CRE(id="123-123", description="gcCD2", name="gcC2"), "Contains" - ), - defs.Link( - defs.CRE(id="123-123", description="gcCD3", name="gcC3"), "Contains" - ), - defs.Link( - defs.Standard( - hyperlink="gc3", - name="gcS2", - section="gc1", - subsection="gc2", - version="gc1.1.1", - ), - "Linked To", - ), - ], - ) - graph_node = db.NeoCRE( - name=name, - description=description, - tags=tags, - external_id=external_id, - contained_in=[], - contains=[ - db.NeoCRE(external_id="123-123", description="gcCD2", name="gcC2"), - db.NeoCRE(external_id="123-123", description="gcCD3", name="gcC3"), - ], - linked=[ - db.NeoStandard( - hyperlink="gc3", - name="gcS2", - section="gc1", - subsection="gc2", - version="gc1.1.1", - ) - ], - same_as=[], - related=[], - auto_linked_to=[], - ) - - parsed = db.NEO_DB.parse_node(graph_node) - self.maxDiff = None - self.assertEqual(parsed.todict(), expected.todict()) - - def test_neo_db_parse_node_no_links_cre(self): - name = "name" - description = "description" - tags = "tags" - external_id = "123-123" - expected = defs.CRE( - name=name, description=description, id=external_id, tags=tags, links=[] - ) - graph_node = db.NeoCRE( - name=name, - description=description, - tags=tags, - external_id=external_id, - contained_in=[], - contains=[ - db.NeoCRE(external_id="123-123", description="gcCD2", name="gcC2"), - db.NeoCRE(external_id="123-123", description="gcCD3", name="gcC3"), - ], - linked=[ - db.NeoStandard( - hyperlink="gc3", - name="gcS2", - section="gc1", - subsection="gc2", - version="gc1.1.1", - ) - ], - same_as=[], - related=[], - ) - - parsed = db.NEO_DB.parse_node_no_links(graph_node) - self.maxDiff = None - self.assertEqual(parsed.todict(), expected.todict()) - - def test_neo_db_parse_node_Document(self): - name = "name" - id = "id" - description = "description" - tags = "tags" - graph_node = db.NeoDocument( - name=name, - document_id=id, - description=description, - tags=tags, - ) - with self.assertRaises(Exception) as cm: - db.NEO_DB.parse_node(graph_node) - - self.assertEqual(str(cm.exception), "Shouldn't be parsing a NeoDocument") - - def test_neo_db_parse_node_Node(self): - name = "name" - id = "id" - description = "description" - tags = "tags" - graph_node = db.NeoNode( - name=name, - document_id=id, - description=description, - tags=tags, - ) - with self.assertRaises(Exception) as cm: - db.NEO_DB.parse_node(graph_node) - - self.assertEqual(str(cm.exception), "Shouldn't be parsing a NeoNode") - - def test_get_embeddings_by_doc_type_paginated(self): - """Given: a range of embedding for Nodes and a range of embeddings for CREs - when called with doc_type CRE return the cre embeddings - when called with doc_type Standard/Tool return the node embeddings""" - # add cre embeddings - cre_embeddings = [] - for i in range(0, 10): - dbca = db.CRE(external_id=f"{i}", description=f"C{i}", name=f"C{i}") - self.collection.session.add(dbca) - self.collection.session.commit() - - embeddings = [random.uniform(-1, 1) for e in range(0, 768)] - embeddings_text = "".join( - random.choices(string.ascii_uppercase + string.digits, k=100) - ) - cre_embeddings.append( - self.collection.add_embedding( - db_object=dbca, - doctype=defs.Credoctypes.CRE.value, - embeddings=embeddings, - embedding_text=embeddings_text, - ) - ) - - # add node embeddings - node_embeddings = [] - for i in range(0, 10): - dbsa = db.Node( - subsection=f"4.5.{i}", - section=f"FooStand-{i}", - name="BarStand", - link="https://example.com", - ntype=defs.Credoctypes.Standard.value, - ) - self.collection.session.add(dbsa) - self.collection.session.commit() - - embeddings = [random.uniform(-1, 1) for e in range(0, 768)] - embeddings_text = "".join( - random.choices(string.ascii_uppercase + string.digits, k=100) - ) - ne = self.collection.add_embedding( - db_object=dbsa, - doctype=defs.Credoctypes.Standard.value, - embeddings=embeddings, - embedding_text=embeddings_text, - ) - node_embeddings.append(ne) - - ( - cre_emb, - total_pages, - curr_page, - ) = self.collection.get_embeddings_by_doc_type_paginated( - defs.Credoctypes.CRE.value, page=1, per_page=1 - ) - self.assertNotEqual(list(cre_emb.keys())[0], "") - self.assertIn(list(cre_emb.keys())[0], list([e.cre_id for e in cre_embeddings])) - self.assertNotIn( - list(cre_emb.keys())[0], list([e.node_id for e in cre_embeddings]) - ) - self.assertEqual(total_pages, 10) - self.assertEqual(curr_page, 1) - - ( - node_emb, - total_pages, - curr_page, - ) = self.collection.get_embeddings_by_doc_type_paginated( - defs.Credoctypes.Standard.value, page=1, per_page=1 - ) - self.assertNotEqual(list(node_emb.keys())[0], "") - self.assertIn( - list(node_emb.keys())[0], list([e.node_id for e in node_embeddings]) - ) - self.assertNotIn( - list(node_emb.keys())[0], list([e.cre_id for e in cre_embeddings]) - ) - self.assertEqual(total_pages, 10) - self.assertEqual(curr_page, 1) - - ( - tool_emb, - total_pages, - curr_page, - ) = self.collection.get_embeddings_by_doc_type_paginated( - defs.Credoctypes.Tool.value, page=1, per_page=1 - ) - self.assertEqual(total_pages, 0) - self.assertEqual(tool_emb, {}) - - def test_get_embeddings_by_doc_type(self): - """Given: a range of embedding for Nodes and a range of embeddings for CREs - when called with doc_type CRE return the cre embeddings - when called with doc_type Standard/Tool return the node embeddings""" - # add cre embeddings - cre_embeddings = [] - for i in range(0, 10): - dbca = db.CRE(external_id=f"{i}", description=f"C{i}", name=f"C{i}") - self.collection.session.add(dbca) - self.collection.session.commit() - - embeddings = [random.uniform(-1, 1) for e in range(0, 768)] - embeddings_text = "".join( - random.choices(string.ascii_uppercase + string.digits, k=100) - ) - cre_embeddings.append( - self.collection.add_embedding( - db_object=dbca, - doctype=defs.Credoctypes.CRE.value, - embeddings=embeddings, - embedding_text=embeddings_text, - ) - ) - - # add node embeddings - node_embeddings = [] - for i in range(0, 10): - dbsa = db.Node( - subsection=f"4.5.{i}", - section=f"FooStand-{i}", - name="BarStand", - link="https://example.com", - ntype=defs.Credoctypes.Standard.value, - ) - self.collection.session.add(dbsa) - self.collection.session.commit() - - embeddings = [random.uniform(-1, 1) for e in range(0, 768)] - embeddings_text = "".join( - random.choices(string.ascii_uppercase + string.digits, k=100) - ) - ne = self.collection.add_embedding( - db_object=dbsa, - doctype=defs.Credoctypes.Standard.value, - embeddings=embeddings, - embedding_text=embeddings_text, - ) - node_embeddings.append(ne) - - cre_emb = self.collection.get_embeddings_by_doc_type(defs.Credoctypes.CRE.value) - self.assertNotEqual(list(cre_emb.keys())[0], "") - self.assertIn(list(cre_emb.keys())[0], list([e.cre_id for e in cre_embeddings])) - self.assertNotIn( - list(cre_emb.keys())[0], list([e.node_id for e in cre_embeddings]) - ) - - node_emb = self.collection.get_embeddings_by_doc_type( - defs.Credoctypes.Standard.value - ) - self.assertNotEqual(list(node_emb.keys())[0], "") - self.assertIn( - list(node_emb.keys())[0], list([e.node_id for e in node_embeddings]) - ) - self.assertNotIn( - list(node_emb.keys())[0], list([e.cre_id for e in cre_embeddings]) - ) - - tool_emb = self.collection.get_embeddings_by_doc_type( - defs.Credoctypes.Tool.value - ) - self.assertEqual(tool_emb, {}) - - def test_get_standard_names(self): - for s in ["sa", "sb", "sc", "sd"]: - for sub in ["suba", "subb", "subc", "subd"]: - self.collection.add_node( - defs.Standard(name=s, section=sub, subsection=sub) - ) - self.assertCountEqual( - ["BarStand", "Unlinked", "sa", "sb", "sc", "sd"], - self.collection.standards(), - ) - - def test_all_cres_with_pagination(self): - """""" - cres = [] - nodes = [] - dbcres = [] - dbnodes = [] - sqla.session.remove() - sqla.drop_all() - sqla.create_all() - collection = db.Node_collection() - for i in range(0, 8): - if i == 0 or i == 1: - cres.append(defs.CRE(name=f">> C{i}", id=f"{i}{i}{i}-{i}{i}{i}")) - else: - cres.append(defs.CRE(name=f"C{i}", id=f"{i}")) - - dbcres.append(collection.add_cre(cres[i])) - nodes.append(defs.Standard(section=f"S{i}", name=f"N{i}")) - dbnodes.append(collection.add_node(nodes[i])) - cres[i].add_link( - defs.Link(document=copy(nodes[i]), ltype=defs.LinkTypes.LinkedTo) - ) - collection.add_link( - cre=dbcres[i], node=dbnodes[i], ltype=defs.LinkTypes.LinkedTo - ) - - collection.session.commit() - - paginated_cres, page, total_pages = collection.all_cres_with_pagination( - page=1, per_page=2 - ) - self.maxDiff = None - # from pprint import pprint - # pprint(cres) - self.assertEqual(paginated_cres, [cres[0], cres[1]]) - self.assertEqual(page, 1) - self.assertEqual(total_pages, 4) - - def test_all_cres_with_pagination(self): - """""" - cres = [] - nodes = [] - dbcres = [] - dbnodes = [] - sqla.session.remove() - sqla.drop_all() - sqla.create_all() - collection = db.Node_collection() - for i in range(0, 8): - if i == 0 or i == 1: - cres.append(defs.CRE(name=f">> C{i}", id=f"{i}{i}{i}-{i}{i}{i}")) - else: - cres.append(defs.CRE(name=f"C{i}", id=f"{i}{i}{i}-{i}{i}{i}")) - - dbcres.append(collection.add_cre(cres[i])) - nodes.append(defs.Standard(section=f"S{i}", name=f"N{i}")) - dbnodes.append(collection.add_node(nodes[i])) - cres[i].add_link( - defs.Link(document=copy(nodes[i]), ltype=defs.LinkTypes.LinkedTo) - ) - collection.add_link( - cre=dbcres[i], node=dbnodes[i], ltype=defs.LinkTypes.LinkedTo - ) - - collection.session.commit() - - paginated_cres, page, total_pages = collection.all_cres_with_pagination( - page=1, per_page=2 - ) - self.maxDiff = None - self.assertEqual(paginated_cres, [cres[0], cres[1]]) - self.assertEqual(page, 1) - self.assertEqual(total_pages, 4) - - def test_get_cre_hierarchy(self) -> None: - # this needs a clean database and a clean graph so reinit everything - # sqla.session.remove() - # sqla.drop_all() - # sqla.create_all() - collection = self.collection # db.Node_collection().with_graph() - # collection.graph.with_graph(graph=nx.DiGraph(), graph_data=[]) - - _, inputDocs = export_format_data() - importItems = [] - for name, items in inputDocs.items(): - for item in items: - importItems.append(item) - if name == defs.Credoctypes.CRE: - dbitem = collection.add_cre(item) - else: - dbitem = collection.add_node(item) - for link in item.links: - if link.document.doctype == defs.Credoctypes.CRE: - linked_item = collection.add_cre(link.document) - if item.doctype == defs.Credoctypes.CRE: - collection.add_internal_link( - dbitem, linked_item, ltype=link.ltype - ) - else: - collection.add_link( - node=dbitem, cre=linked_item, ltype=link.ltype - ) - else: - linked_item = collection.add_node(link.document) - if item.doctype == defs.Credoctypes.CRE: - collection.add_link( - cre=dbitem, node=linked_item, ltype=link.ltype - ) - else: - collection.add_internal_link( - cre=linked_item, node=dbitem, ltype=link.ltype - ) - cres = inputDocs[defs.Credoctypes.CRE] - c0 = [c for c in cres if c.name == "C0"][0] - self.assertEqual(collection.get_cre_hierarchy(c0), 0) - c2 = [c for c in cres if c.name == "C2"][0] - self.assertEqual(collection.get_cre_hierarchy(c2), 1) - c3 = [c for c in cres if c.name == "C3"][0] - self.assertEqual(collection.get_cre_hierarchy(c3), 2) - c4 = [c for c in cres if c.name == "C4"][0] - self.assertEqual(collection.get_cre_hierarchy(c4), 3) - c5 = [c for c in cres if c.name == "C5"][0] - self.assertEqual(collection.get_cre_hierarchy(c5), 4) - c6 = [c for c in cres if c.name == "C6"][0] - self.assertEqual(collection.get_cre_hierarchy(c6), 0) - c7 = [c for c in cres if c.name == "C7"][0] - self.assertEqual(collection.get_cre_hierarchy(c7), 0) - c8 = [c for c in cres if c.name == "C8"][0] - self.assertEqual(collection.get_cre_hierarchy(c8), 0) +import networkx as nx +from application.utils.gap_analysis import make_resources_key, make_subresources_key +import string +import random +import os +import tempfile +import unittest +from unittest import mock +from unittest.mock import patch +import uuid +from copy import copy, deepcopy +from pprint import pprint +from typing import Any, Dict, List, Union +from flask import json as flask_json + +import yaml +from application.tests.utils.data_gen import export_format_data +from application import create_app, sqla # type: ignore +from application.database import db +from application.defs import cre_defs as defs + + +class TestDB(unittest.TestCase): + def tearDown(self) -> None: + sqla.session.remove() + sqla.drop_all() + self.app_context.pop() + + def setUp(self) -> None: + self.app = create_app(mode="test") + self.app_context = self.app.app_context() + self.app_context.push() + sqla.create_all() + + self.collection = db.Node_collection().with_graph() + self.collection.graph.with_graph( + graph=nx.DiGraph(), graph_data=[] + ) # initialize the graph singleton for the tests to be unique + + collection = self.collection + + dbcre = collection.add_cre( + defs.CRE(id="111-000", description="CREdesc", name="CREname") + ) + self.dbcre = dbcre + dbgroup = collection.add_cre( + defs.CRE(id="111-001", description="Groupdesc", name="GroupName") + ) + dbstandard = collection.add_node( + defs.Standard( + subsection="4.5.6", + section="FooStand", + name="BarStand", + hyperlink="https://example.com", + tags=["788-788", "b", "c"], + ) + ) + + collection.add_node( + defs.Standard( + subsection="4.5.6", + section="Unlinked", + name="Unlinked", + hyperlink="https://example.com", + ) + ) + + collection.session.add(dbcre) + collection.add_link(cre=dbcre, node=dbstandard, ltype=defs.LinkTypes.LinkedTo) + collection.add_internal_link( + lower=dbcre, higher=dbgroup, ltype=defs.LinkTypes.Contains + ) + + self.collection = collection + + def test_get_by_tags(self) -> None: + """ + Given: A CRE with no links and a combination of possible tags: + "tag1,dash-2,underscore_3,space 4,co_mb-ination%5" + A Standard with no links and a combination of possible tags + "tag1, dots.5.5, space 6 , several spaces and newline 7 \n" + some limited overlap between the tag-sets + Expect: + The CRE to be returned when searching for "tag-2" and for ["tag1","underscore_3"] + The Standard to be returned when searching for "space 6" and ["dots.5.5", "space 6"] + Both to be returned when searching for "space" and "tag1" + """ + + dbcre = db.CRE( + description="tagCREdesc1", + name="tagCREname1", + tags="tag1,dash-2,underscore_3,space 4,co_mb-ination%5", + external_id="111-111", + ) + cre = db.CREfromDB(dbcre) + cre.id = "111-111" + dbstandard = db.Node( + subsection="4.5.6.7", + section="tagsstand", + name="tagsstand", + link="https://example.com", + version="", + tags="tag1, dots.5.5, space 6 , several spaces and newline 7 \n", + ntype=defs.Standard.__name__, + ) + standard = db.nodeFromDB(dbstandard) + self.collection.session.add(dbcre) + self.collection.session.add(dbstandard) + self.collection.session.commit() + + self.maxDiff = None + self.assertEqual(self.collection.get_by_tags(["dash-2"]), [cre]) + self.assertEqual(self.collection.get_by_tags(["tag1", "underscore_3"]), [cre]) + self.assertEqual(self.collection.get_by_tags(["space 6"]), [standard]) + self.assertEqual( + self.collection.get_by_tags(["dots.5.5", "space 6"]), [standard] + ) + + self.assertCountEqual([cre, standard], self.collection.get_by_tags(["space"])) + self.assertCountEqual( + [cre, standard], self.collection.get_by_tags(["space", "tag1"]) + ) + self.assertCountEqual(self.collection.get_by_tags(["tag1"]), [cre, standard]) + + self.assertEqual(self.collection.get_by_tags([]), []) + self.assertEqual(self.collection.get_by_tags(["this should not be a tag"]), []) + + def test_get_standards_names(self) -> None: + result = self.collection.get_node_names() + expected = [("Standard", "BarStand"), ("Standard", "Unlinked")] + self.assertEqual(expected, result) + + def test_get_max_internal_connections(self) -> None: + self.assertEqual(self.collection.get_max_internal_connections(), 1) + + dbcrelo = db.CRE(name="internal connections test lo", description="ictlo") + dbcrehi = db.CRE(name="internal connections test hi", description="icthi") + self.collection.session.add(dbcrelo) + self.collection.session.add(dbcrehi) + self.collection.session.commit() + for i in range(0, 100): + dbcre = db.CRE(name=str(i) + " name", description=str(i) + " desc") + self.collection.session.add(dbcre) + self.collection.session.commit() + + # 1 low level cre to multiple groups + self.collection.session.add( + db.InternalLinks(group=dbcre.id, cre=dbcrelo.id) + ) + + # 1 hi level cre to multiple low level + self.collection.session.add( + db.InternalLinks(group=dbcrehi.id, cre=dbcre.id) + ) + + self.collection.session.commit() + + result = self.collection.get_max_internal_connections() + self.assertEqual(result, 100) + + def test_export(self) -> None: + """ + Given: + A CRE "CREname" that links to a CRE "GroupName" and a Standard "BarStand" + Expect: + 2 documents on disk, one for "CREname" + with a link to "BarStand" and "GroupName" and one for "GroupName" with a link to "CREName" + """ + loc = tempfile.mkdtemp() + self.collection = db.Node_collection().with_graph() + collection = self.collection + code0 = defs.Code(name="co0") + code1 = defs.Code(name="co1") + tool0 = defs.Tool(name="t0", tooltype=defs.ToolTypes.Unknown) + dbstandard = collection.add_node( + defs.Standard( + subsection="4.5.6", + section="FooStand", + sectionID="123-123", + name="BarStand", + hyperlink="https://example.com", + tags=["788-788", "b", "c"], + ) + ) + + collection.add_node( + defs.Standard( + subsection="4.5.6", + section="Unlinked", + sectionID="Unlinked", + name="Unlinked", + hyperlink="https://example.com", + ) + ) + self.collection.add_link( + self.dbcre, self.collection.add_node(code0), ltype=defs.LinkTypes.LinkedTo + ) + self.collection.add_node(code1) + self.collection.add_node(tool0) + + expected = [ + defs.CRE( + id="111-001", + description="Groupdesc", + name="GroupName", + links=[ + defs.Link( + document=defs.CRE( + id="111-000", description="CREdesc", name="CREname" + ), + ltype=defs.LinkTypes.Contains, + ) + ], + ), + defs.CRE( + id="111-000", + description="CREdesc", + name="CREname", + links=[ + defs.Link( + document=defs.CRE( + id="112-001", description="Groupdesc", name="GroupName" + ), + ltype=defs.LinkTypes.Contains, + ), + defs.Link( + document=defs.Standard( + name="BarStand", + section="FooStand", + sectionID="456", + subsection="4.5.6", + hyperlink="https://example.com", + tags=["788-788", "b", "c"], + ), + ltype=defs.LinkTypes.LinkedTo, + ), + defs.Link( + document=defs.Code(name="co0"), ltype=defs.LinkTypes.LinkedTo + ), + ], + ), + defs.Standard( + subsection="4.5.6", + section="Unlinked", + name="Unlinked", + sectionID="Unlinked", + hyperlink="https://example.com", + ), + defs.Tool(name="t0", tooltype=defs.ToolTypes.Unknown), + defs.Code(name="co1"), + ] + self.collection.export(loc) + + # load yamls from loc, parse, + # ensure yaml1 is result[0].todict and + # yaml2 is expected[1].todict + group = expected[0].todict() + cre = expected[1].todict() + groupname = ( + expected[0] + .id.replace("/", "-") + .replace(" ", "_") + .replace('"', "") + .replace("'", "") + + ".yaml" + ) + with open(os.path.join(loc, groupname), "r") as f: + doc = yaml.safe_load(f) + self.assertDictEqual(group, doc) + + crename = ( + expected[1] + .id.replace("/", "-") + .replace(" ", "_") + .replace('"', "") + .replace("'", "") + + ".yaml" + ) + self.maxDiff = None + with open(os.path.join(loc, crename), "r") as f: + doc = yaml.safe_load(f) + self.assertCountEqual(cre, doc) + + def test_StandardFromDB(self) -> None: + expected = defs.Standard( + name="foo", + section="bar", + sectionID="213", + subsection="foobar", + hyperlink="https://example.com/foo/bar", + version="1.1.1", + ) + self.assertEqual( + expected, + db.nodeFromDB( + db.Node( + name="foo", + section="bar", + subsection="foobar", + link="https://example.com/foo/bar", + version="1.1.1", + section_id="213", + ntype=defs.Standard.__name__, + ) + ), + ) + + def test_CREfromDB(self) -> None: + c = defs.CRE( + id="243-243", + doctype=defs.Credoctypes.CRE, + description="CREdesc", + name="CREname", + ) + self.assertEqual( + c, + db.CREfromDB( + db.CRE(external_id="243-243", description="CREdesc", name="CREname") + ), + ) + + def test_add_cre(self) -> None: + original_desc = str(uuid.uuid4()) + name = str(uuid.uuid4()) + + c = defs.CRE( + id="243-243", + doctype=defs.Credoctypes.CRE, + description=original_desc, + name=name, + ) + self.assertIsNone( + self.collection.session.query(db.CRE).filter(db.CRE.name == c.name).first() + ) + + # happy path, add new cre + newCRE = self.collection.add_cre(c) + dbcre = ( + self.collection.session.query(db.CRE).filter(db.CRE.name == c.name).first() + ) # ensure transaction happened (commit() called) + self.assertIsNotNone(dbcre.id) + self.assertEqual(dbcre.name, c.name) + self.assertEqual(dbcre.description, c.description) + self.assertEqual(dbcre.external_id, c.id) + + # ensure the right thing got returned + self.assertEqual(newCRE.name, c.name) + + # ensure no accidental update (add only adds) + c.description = "description2" + newCRE = self.collection.add_cre(c) + dbcre = ( + self.collection.session.query(db.CRE).filter(db.CRE.name == c.name).first() + ) + # ensure original description + self.assertEqual(dbcre.description, original_desc) + # ensure original description + self.assertEqual(newCRE.description, original_desc) + + def test_add_node(self) -> None: + original_section = str(uuid.uuid4()) + name = str(uuid.uuid4()) + + s = defs.Standard( + doctype=defs.Credoctypes.Standard, + section=original_section, + subsection=original_section, + name=name, + tags=["788-788", "b", "c"], + ) + + self.assertIsNone( + self.collection.session.query(db.Node) + .filter(db.Node.name == s.name) + .first() + ) + + # happy path, add new standard + newStandard = self.collection.add_node(s) + self.assertIsNotNone(newStandard) + + dbstandard = ( + self.collection.session.query(db.Node) + .filter(db.Node.name == s.name) + .first() + ) # ensure transaction happened (commit() called) + self.assertIsNotNone(dbstandard.id) + self.assertEqual(dbstandard.name, s.name) + self.assertEqual(dbstandard.section, s.section) + self.assertEqual(dbstandard.subsection, s.subsection) + self.assertEqual( + newStandard.name, s.name + ) # ensure the right thing got returned + self.assertEqual(dbstandard.ntype, s.doctype.value) + self.assertEqual(dbstandard.tags, ",".join(s.tags)) + # standards match on all of name,section, subsection <-- if you change even one of them it's a new entry + + def find_cres_of_cre(self) -> None: + dbcre = db.CRE(description="CREdesc1", name="CREname1") + groupless_cre = db.CRE(description="CREdesc2", name="CREname2") + dbgroup = db.CRE(description="Groupdesc1", name="GroupName1") + dbgroup2 = db.CRE(description="Groupdesc2", name="GroupName2") + + only_one_group = db.CRE(description="CREdesc3", name="CREname3") + + self.collection.session.add(dbcre) + self.collection.session.add(groupless_cre) + self.collection.session.add(dbgroup) + self.collection.session.add(dbgroup2) + self.collection.session.add(only_one_group) + self.collection.session.commit() + + internalLink = db.InternalLinks(cre=dbcre.id, group=dbgroup.id, type="Contains") + internalLink2 = db.InternalLinks( + cre=dbcre.id, group=dbgroup2.id, type="Contains" + ) + internalLink3 = db.InternalLinks( + cre=only_one_group.id, group=dbgroup.id, type="Contains" + ) + self.collection.session.add(internalLink) + self.collection.session.add(internalLink2) + self.collection.session.add(internalLink3) + self.collection.session.commit() + + # happy path, find cre with 2 groups + + groups = self.collection.find_cres_of_cre(dbcre) + if not groups: + self.fail("Expected exactly 2 cres") + self.assertEqual(len(groups), 2) + self.assertEqual(groups, [dbgroup, dbgroup2]) + + # find cre with 1 group + group = self.collection.find_cres_of_cre(only_one_group) + + if not group: + self.fail("Expected exactly 1 cre") + self.assertEqual(len(group), 1) + self.assertEqual(group, [dbgroup]) + + # ensure that None is return if there are no groups + groups = self.collection.find_cres_of_cre(groupless_cre) + self.assertIsNone(groups) + + def test_find_cres_of_standard(self) -> None: + dbcre = db.CRE(description="CREdesc1", name="CREname1") + dbgroup = db.CRE(description="CREdesc2", name="CREname2") + dbstandard1 = db.Node( + section="section1", + name="standard1", + ntype=defs.Standard.__name__, + ) + group_standard = db.Node( + section="section2", + name="standard2", + ntype=defs.Standard.__name__, + ) + lone_standard = db.Node( + section="section3", + name="standard3", + ntype=defs.Standard.__name__, + ) + + self.collection.session.add(dbcre) + self.collection.session.add(dbgroup) + self.collection.session.add(dbstandard1) + self.collection.session.add(group_standard) + self.collection.session.add(lone_standard) + self.collection.session.commit() + + self.collection.session.add(db.Links(cre=dbcre.id, node=dbstandard1.id)) + self.collection.session.add(db.Links(cre=dbgroup.id, node=dbstandard1.id)) + self.collection.session.add(db.Links(cre=dbgroup.id, node=group_standard.id)) + self.collection.session.commit() + + # happy path, 1 group and 1 cre link to 1 standard + cres = self.collection.find_cres_of_node(dbstandard1) + + if not cres: + self.fail("Expected 2 cres") + self.assertEqual(len(cres), 2) + self.assertEqual(cres, [dbcre, dbgroup]) + + # group links to standard + cres = self.collection.find_cres_of_node(group_standard) + + if not cres: + self.fail("Expected 1 cre") + self.assertEqual(len(cres), 1) + self.assertEqual(cres, [dbgroup]) + + # no links = None + cres = self.collection.find_cres_of_node(lone_standard) + self.assertIsNone(cres) + + def test_get_CREs(self) -> None: + """Given: a cre 'C1' that links to cres both as a group and a cre and other standards + return the CRE in Document format""" + collection = db.Node_collection() + dbc1 = db.CRE(external_id="123-123", description="gcCD1", name="gcC1") + dbc2 = db.CRE(description="gcCD2", name="gcC2", external_id="444-444") + dbc3 = db.CRE(description="gcCD3", name="gcC3", external_id="555-555") + db_id_only = db.CRE( + description="c_get_by_internal_id_only", + name="cgbiio", + external_id="666-666", + ) + dbs1 = db.Node( + ntype=defs.Standard.__name__, + name="gcS2", + section="gc1", + subsection="gc2", + link="gc3", + version="gc1.1.1", + ) + + dbs2 = db.Node( + ntype=defs.Standard.__name__, + name="gcS3", + section="gc1", + subsection="gc2", + link="gc3", + version="gc3.1.2", + ) + + parent_cre = db.CRE( + external_id="999-999", description="parent cre", name="pcre" + ) + parent_cre2 = db.CRE( + external_id="888-888", description="parent cre2", name="pcre2" + ) + partOf_cre = db.CRE( + external_id="777-777", description="part of cre", name="poc" + ) + + collection.session.add(dbc1) + collection.session.add(dbc2) + collection.session.add(dbc3) + collection.session.add(dbs1) + collection.session.add(dbs2) + collection.session.add(db_id_only) + + collection.session.add(parent_cre) + collection.session.add(parent_cre2) + collection.session.add(partOf_cre) + collection.session.commit() + + collection.session.add( + db.InternalLinks(type="Contains", group=dbc1.id, cre=dbc2.id) + ) + collection.session.add( + db.InternalLinks(type="Contains", group=dbc1.id, cre=dbc3.id) + ) + collection.session.add(db.Links(type="Linked To", cre=dbc1.id, node=dbs1.id)) + + collection.session.add( + db.InternalLinks( + type=defs.LinkTypes.Contains.value, + group=parent_cre.id, + cre=partOf_cre.id, + ) + ) + collection.session.add( + db.InternalLinks( + type=defs.LinkTypes.Contains.value, + group=parent_cre2.id, + cre=partOf_cre.id, + ) + ) + collection.session.commit() + self.maxDiff = None + + # we can retrieve children cres + self.assertEqual( + [ + db.CREfromDB(parent_cre).add_link( + defs.Link( + document=db.CREfromDB(partOf_cre), ltype=defs.LinkTypes.Contains + ) + ) + ], + collection.get_CREs(external_id=parent_cre.external_id), + ) + self.assertEqual( + [ + db.CREfromDB(parent_cre2).add_link( + defs.Link( + document=db.CREfromDB(partOf_cre), ltype=defs.LinkTypes.Contains + ) + ) + ], + collection.get_CREs(external_id=parent_cre2.external_id), + ) + + # we can retrieve children cres with inverted multiple (PartOf) links to their parents + self.assertEqual( + [ + db.CREfromDB(partOf_cre) + .add_link( + defs.Link( + document=db.CREfromDB(parent_cre), ltype=defs.LinkTypes.PartOf + ) + ) + .add_link( + defs.Link( + document=db.CREfromDB(parent_cre2), ltype=defs.LinkTypes.PartOf + ) + ) + ], + collection.get_CREs(external_id=partOf_cre.external_id), + ) + + cd1 = defs.CRE(id="123-123", description="gcCD1", name="gcC1") + cd2 = defs.CRE(id="444-444", description="gcCD2", name="gcC2") + cd3 = defs.CRE(id="555-555", description="gcCD3", name="gcC3") + c_id_only = defs.CRE( + id="666-666", description="c_get_by_internal_id_only", name="cgbiio" + ) + + expected = [ + copy(cd1) + .add_link( + defs.Link( + ltype=defs.LinkTypes.LinkedTo, + document=defs.Standard( + name="gcS2", + section="gc1", + subsection="gc2", + hyperlink="gc3", + version="gc1.1.1", + ), + ) + ) + .add_link( + defs.Link( + ltype=defs.LinkTypes.Contains, + document=copy(cd2), + ) + ) + .add_link(defs.Link(ltype=defs.LinkTypes.Contains, document=copy(cd3))) + ] + self.maxDiff = None + shallow_cd1 = copy(cd1) + shallow_cd1.links = [] + cd2.add_link(defs.Link(ltype=defs.LinkTypes.PartOf, document=shallow_cd1)) + cd3.add_link(defs.Link(ltype=defs.LinkTypes.PartOf, document=shallow_cd1)) + + # empty returns empty + self.assertEqual([], collection.get_CREs()) + + # getting "group cre 1" by name returns gcC1 + res = collection.get_CREs(name="gcC1") + self.assertEqual(len(expected), len(res)) + self.assertCountEqual(expected[0].todict(), res[0].todict()) + + # getting "group cre 1" by id returns gcC1 + res = collection.get_CREs(external_id="123-123") + self.assertEqual(len(expected), len(res)) + self.assertCountEqual(expected[0].todict(), res[0].todict()) + + # getting "group cre 1" by partial id returns gcC1 + res = collection.get_CREs(external_id="12%", partial=True) + self.assertEqual(len(expected), len(res)) + self.assertCountEqual(expected[0].todict(), res[0].todict()) + + # getting "group cre 1" by partial name returns gcC1, gcC2 and gcC3 + res = collection.get_CREs(name="gcC%", partial=True) + self.assertEqual(3, len(res)) + self.assertCountEqual( + [expected[0].todict(), cd2.todict(), cd3.todict()], + [r.todict() for r in res], + ) + + # getting "group cre 1" by partial name and partial id returns gcC1 + res = collection.get_CREs(external_id="1%", name="gcC%", partial=True) + self.assertEqual(len(expected), len(res)) + self.assertCountEqual(expected[0].todict(), res[0].todict()) + + # getting "group cre 1" by description returns gcC1 + res = collection.get_CREs(description="gcCD1") + self.assertEqual(len(expected), len(res)) + self.assertCountEqual(expected[0].todict(), res[0].todict()) + + # getting "group cre 1" by partial id and partial description returns gcC1 + res = collection.get_CREs(external_id="1%", description="gcC%", partial=True) + self.assertEqual(len(expected), len(res)) + self.assertCountEqual(expected[0].todict(), res[0].todict()) + + # getting all the gcC* cres by partial name and partial description returns gcC1, gcC2, gcC3 + res = collection.get_CREs(description="gcC%", name="gcC%", partial=True) + want = [expected[0], cd2, cd3] + for el in res: + found = False + for wel in want: + if el.todict() == wel.todict(): + found = True + self.assertTrue(found) + + self.assertEqual([], collection.get_CREs(external_id="123-123", name="gcC5")) + self.assertEqual([], collection.get_CREs(external_id="1234")) + self.assertEqual([], collection.get_CREs(name="gcC5")) + + # add a standard to gcC1 + collection.session.add(db.Links(type="Linked To", cre=dbc1.id, node=dbs2.id)) + + only_gcS2 = deepcopy(expected) # save a copy of the current expected + expected[0].add_link( + defs.Link( + ltype=defs.LinkTypes.LinkedTo, + document=defs.Standard( + name="gcS3", + section="gc1", + subsection="gc2", + hyperlink="gc3", + version="gc3.1.2", + ), + ) + ) + # we can retrieve the cre with the standard + res = collection.get_CREs(name="gcC1") + self.assertCountEqual(expected[0].todict(), res[0].todict()) + + # we can retrieve ONLY the standard + res = collection.get_CREs(name="gcC1", include_only=["gcS2"]) + self.assertDictEqual(only_gcS2[0].todict(), res[0].todict()) + + ccd2 = copy(cd2) + ccd2.links = [] + ccd3 = copy(cd3) + ccd3.links = [] + no_standards = [ + copy(cd1) + .add_link( + defs.Link( + ltype=defs.LinkTypes.Contains, + document=ccd2, + ) + ) + .add_link(defs.Link(ltype=defs.LinkTypes.Contains, document=ccd3)) + ] + + # if the standard is not linked, we retrieve as normal + res = collection.get_CREs(name="gcC1", include_only=["gcS0"]) + self.assertEqual(no_standards, res) + + self.assertEqual([c_id_only], collection.get_CREs(internal_id=db_id_only.id)) + + def test_get_standards(self) -> None: + """Given: a Standard 'S1' that links to cres + return the Standard in Document format""" + collection = db.Node_collection() + docs: Dict[str, Union[db.CRE, db.Node]] = { + "dbc1": db.CRE(external_id="123-123", description="CD1", name="C1"), + "dbc2": db.CRE(external_id="222-222", description="CD2", name="C2"), + "dbc3": db.CRE(external_id="333-333", description="CD3", name="C3"), + "dbs1": db.Node( + ntype=defs.Standard.__name__, + name="S1", + section="111-111", + section_id="123-123", + subsection="222-222", + link="333-333", + version="4", + ), + } + links = [("dbc1", "dbs1"), ("dbc2", "dbs1"), ("dbc3", "dbs1")] + for k, v in docs.items(): + collection.session.add(v) + collection.session.commit() + + for cre, standard in links: + collection.session.add( + db.Links(type="Linked To", cre=docs[cre].id, node=docs[standard].id) + ) + collection.session.commit() + + expected = [ + defs.Standard( + name="S1", + section="111-111", + sectionID="123-123", + subsection="222-222", + hyperlink="333-333", + version="4", + links=[ + defs.Link( + ltype=defs.LinkTypes.LinkedTo, + document=defs.CRE(id="123-123", name="C1", description="CD1"), + ), + defs.Link( + ltype=defs.LinkTypes.LinkedTo, + document=defs.CRE(id="222-222", name="C2", description="CD2"), + ), + defs.Link( + ltype=defs.LinkTypes.LinkedTo, + document=defs.CRE(id="333-333", name="C3", description="CD3"), + ), + ], + ) + ] + + res = collection.get_nodes(name="S1") + self.assertEqual(expected, res) + + def test_get_nodes_with_pagination(self) -> None: + """Given: a Standard 'S1' that links to cres + return the Standard in Document format and the total pages and the page we are in + """ + collection = db.Node_collection() + docs: Dict[str, Union[db.Node, db.CRE]] = { + "dbc1": db.CRE(external_id="123-123", description="CD1", name="C1"), + "dbc2": db.CRE(external_id="222-222", description="CD2", name="C2"), + "dbc3": db.CRE(external_id="333-333", description="CD3", name="C3"), + "dbs1": db.Node( + name="S1", + section="111-111", + section_id="123-123", + subsection="222-222", + link="333-333", + version="4", + ntype=defs.Standard.__name__, + ), + } + links = [("dbc1", "dbs1"), ("dbc2", "dbs1"), ("dbc3", "dbs1")] + for k, v in docs.items(): + collection.session.add(v) + collection.session.commit() + + for cre, standard in links: + collection.session.add( + db.Links( + cre=docs[cre].id, + node=docs[standard].id, + type=defs.LinkTypes.LinkedTo, + ) + ) + collection.session.commit() + + expected = [ + defs.Standard( + name="S1", + section="111-111", + sectionID="123-123", + subsection="222-222", + hyperlink="333-333", + version="4", + links=[ + defs.Link( + document=defs.CRE(name="C1", description="CD1", id="123-123"), + ltype=defs.LinkTypes.LinkedTo, + ), + defs.Link( + document=defs.CRE(id="222-222", name="C2", description="CD2"), + ltype=defs.LinkTypes.LinkedTo, + ), + defs.Link( + document=defs.CRE(id="333-333", name="C3", description="CD3"), + ltype=defs.LinkTypes.LinkedTo, + ), + ], + ) + ] + total_pages, res, _ = collection.get_nodes_with_pagination(name="S1") + self.assertEqual(total_pages, 1) + self.assertEqual(expected, res) + + only_c1 = [ + defs.Standard( + name="S1", + section="111-111", + sectionID="123-123", + subsection="222-222", + hyperlink="333-333", + version="4", + links=[ + defs.Link( + document=defs.CRE(name="C1", description="CD1", id="123-123"), + ltype=defs.LinkTypes.LinkedTo, + ) + ], + ) + ] + _, res, _ = collection.get_nodes_with_pagination(name="S1", include_only=["C1"]) + self.assertEqual(only_c1, res) + _, res, _ = collection.get_nodes_with_pagination( + name="S1", include_only=["123-123"] + ) + self.assertEqual(only_c1, res) + + self.assertEqual( + collection.get_nodes_with_pagination(name="this should not exit"), + (None, None, None), + ) + + def test_add_internal_link(self) -> None: + """test that internal links are added successfully, + edge cases: + cre or group don't exist + called on a cycle scenario""" + + cres = { + "dbca": self.collection.add_cre( + defs.CRE(id="111-111", description="CA", name="CA") + ), + "dbcb": self.collection.add_cre( + defs.CRE(id="222-222", description="CB", name="CB") + ), + "dbcc": self.collection.add_cre( + defs.CRE(id="333-333", description="CC", name="CC") + ), + } + + # happy path + self.collection.add_internal_link( + higher=cres["dbca"], lower=cres["dbcb"], ltype=defs.LinkTypes.Related + ) + + # "happy path, internal link exists" + res = ( + self.collection.session.query(db.InternalLinks) + .filter( + db.InternalLinks.group == cres["dbca"].id, + db.InternalLinks.cre == cres["dbcb"].id, + ) + .first() + ) + self.assertEqual((res.group, res.cre), (cres["dbca"].id, cres["dbcb"].id)) + + # no cycle, free to insert + self.collection.add_internal_link( + higher=cres["dbcb"], lower=cres["dbcc"], ltype=defs.LinkTypes.Related + ) + res = ( + self.collection.session.query(db.InternalLinks) + .filter( + db.InternalLinks.group == cres["dbcb"].id, + db.InternalLinks.cre == cres["dbcc"].id, + ) + .first() + ) + self.assertEqual((res.group, res.cre), (cres["dbcb"].id, cres["dbcc"].id)) + + # introdcues a cycle, should not be inserted + self.collection.add_internal_link( + higher=cres["dbcc"], lower=cres["dbca"], ltype=defs.LinkTypes.Related + ) + + # cycles are not inserted branch + none_res = ( + self.collection.session.query(db.InternalLinks) + .filter( + db.InternalLinks.group == cres["dbcc"].id, + db.InternalLinks.cre == cres["dbca"].id, + ) + .one_or_none() + ) + self.assertIsNone(none_res) + + def test_text_search(self) -> None: + """Given: + a cre(id="111-111"23-456,name=foo,description='lorem ipsum foo+bar') + a standard(name=Bar,section=blah,subsection=foo, hyperlink='https://example.com/blah/foo') + a standard(name=Bar,section=blah,subsection=foo1, hyperlink='https://example.com/blah/foo1') + a standard(name=Bar,section=blah1,subsection=foo, hyperlink='https://example.com/blah1/foo') + + full_text_search('123-456') returns cre:foo + full_text_search('CRE:foo') and full_text_search('CRE foo') returns cre:foo + full_text_search('CRE:123-456') and full_text_search('CRE 123-456') returns cre:foo + + full_text_search('Standard:Bar') and full_text_search('Standard Bar') returns: [standard:Bar:blah:foo, + standard:Bar:blah:foo1, + standard:Bar:blah1:foo] + + full_text_search('Standard:blah') and full_text_search('Standard blah') returns [standard:Bar::blah:foo, + standard:Bar:blah:foo1] + full_text_search('Standard:blah:foo') returns [standard:Bar:blah:foo] + full_text_search('Standard:foo') returns [standard:Bar:blah:foo, + standard:Bar:blah1:foo] + + + full_text_search('ipsum') returns cre:foo + full_text_search('foo') returns [cre:foo,standard:Bar:blah:foo, standard:Bar:blah:foo1,standard:Bar:blah1:foo] + """ + collection = db.Node_collection() + cre = defs.CRE( + id="123-456", name="textSearchCRE", description="lorem ipsum tsSection+tsC" + ) + collection.add_cre(cre) + + s1 = defs.Standard( + name="textSearchStandard", + section="tsSection", + subsection="tsSubSection", + hyperlink="https://example.com/tsSection/tsSubSection", + ) + collection.add_node(s1) + s2 = defs.Standard( + name="textSearchStandard", + section="tsSection", + subsection="tsSubSection1", + hyperlink="https://example.com/tsSection/tsSubSection1", + ) + collection.add_node(s2) + s3 = defs.Standard( + name="textSearchStandard", + section="tsSection1", + subsection="tsSubSection1", + hyperlink="https://example.com/tsSection1/tsSubSection1", + ) + collection.add_node(s3) + t1 = defs.Tool( + name="textSearchTool", + tooltype=defs.ToolTypes.Offensive, + hyperlink="https://example.com/textSearchTool", + description="test text search with tool", + sectionID="15", + section="rule 15", + ) + collection.add_node(t1) + collection.session.commit() + expected: Dict[str, List[Any]] = { + "123-456": [cre], + "CRE:textSearchCRE": [cre], + "CRE textSearchCRE": [cre], + "CRE:123-456": [cre], + "CRE 123-456": [cre], + "Standard:textSearchStandard": [s1, s2, s3], + "Standard textSearchStandard": [s1, s2, s3], + "Standard:tsSection": [s1, s2], + "Standard tsSection": [s1, s2], + "Standard:tsSection:tsSubSection1": [s2], + "Standard tsSection tsSubSection1": [s2], + "Standard:tsSubSection1": [s2, s3], + "Standard tsSubSection1": [s2, s3], + "Standard:https://example.com/tsSection/tsSubSection1": [s2], + "Standard https://example.com/tsSection1/tsSubSection1": [s3], + "https://example.com/tsSection": [s1, s2, s3], + "ipsum": [cre], + "tsSection": [cre, s1, s2, s3], + "https://example.com/textSearchTool": [t1], + "text search": [t1], + } + self.maxDiff = None + for k, val in expected.items(): + res = self.collection.text_search(k) + self.assertCountEqual(res, val) + + def test_dbNodeFromNode(self) -> None: + data = { + "tool": defs.Tool( + name="fooTool", + description="lorem ipsum tsSection+tsC", + tooltype=defs.ToolTypes.Defensive, + tags=["111-111", "222-222", "333-333"], + ), + "standard": defs.Standard( + name="stand", section="s1", subsection="s2", version="s3" + ), + "code": defs.Code( + name="ccc", + description="c2", + hyperlink="https://example.com/code/hyperlink", + tags=["111-111", "222-222"], + ), + } + expected = { + "tool": db.Node( + name="fooTool", + description="lorem ipsum tsSection+tsC", + tags=",".join( + [defs.ToolTypes.Defensive.value, "111-111", "222-222", "333-333"] + ), + ntype=defs.Credoctypes.Tool.value, + ), + "standard": db.Node( + name="stand", + section="s1", + subsection="s2", + version="s3", + ntype=defs.Credoctypes.Standard.value, + ), + "code": db.Node( + name="ccc", + description="c2", + link="https://example.com/code/hyperlink", + tags="1,2", + ntype=defs.Credoctypes.Code.value, + ), + } + for k, v in data.items(): + nd = db.dbNodeFromNode(v) + for vname, var in vars(nd).items(): + if var and not vname.startswith("_"): + self.assertEqual(var, vars(expected[k]).get(vname)) + + def test_nodeFromDB(self) -> None: + expected = { + "tool": defs.Tool( + name="fooTool", + description="lorem ipsum tsSection+tsC", + tooltype=defs.ToolTypes.Defensive, + tags=["111-111", "222-222", "333-333"], + ), + "standard": defs.Standard( + name="stand", section="s1", subsection="s2", version="s3" + ), + "code": defs.Code( + name="ccc", + description="c2", + hyperlink="https://example.com/code/hyperlink", + tags=["111-111", "222-222"], + ), + } + data = { + "tool": db.Node( + name="fooTool", + description="lorem ipsum tsSection+tsC", + tags=",".join( + [defs.ToolTypes.Defensive.value, "111-111", "222-222", "333-333"] + ), + ntype=defs.Credoctypes.Tool.value, + ), + "standard": db.Node( + name="stand", + section="s1", + subsection="s2", + version="s3", + ntype=defs.Credoctypes.Standard.value, + ), + "code": db.Node( + name="ccc", + description="c2", + link="https://example.com/code/hyperlink", + tags="111-111,222-222", + ntype=defs.Credoctypes.Code.value, + ), + } + for k, v in data.items(): + nd = db.nodeFromDB(v) + for vname, var in vars(nd).items(): + if var and not vname.startswith("_"): + self.assertCountEqual(var, vars(expected[k]).get(vname)) + + def test_object_select(self) -> None: + dbnode1 = db.Node( + name="fooTool", + description="lorem ipsum tsSection+tsC", + tags=f"{defs.ToolTypes.Defensive.value},1", + ) + dbnode2 = db.Node( + name="fooTool", + description="lorem2", + link="https://example.com/foo/bar", + tags=f"{defs.ToolTypes.Defensive.value},1", + ) + + self.collection = db.Node_collection() + collection = db.Node_collection() + collection.session.add(dbnode1) + collection.session.add(dbnode2) + self.assertEqual(collection.object_select(dbnode1), [dbnode1]) + self.assertEqual(collection.object_select(dbnode2), [dbnode2]) + self.assertCountEqual( + collection.object_select(db.Node(name="fooTool")), [dbnode1, dbnode2] + ) + + self.assertEqual(collection.object_select(None), []) + + def test_get_root_cres(self): + """Given: + 6 CRES: + * C0 <-- Root + * C1 <-- Root + * C2 Part Of C0 + * C3 Part Of C1 + * C4 Part Of C2 + * C5 Related to C0 + * C6 Part Of C1 + * C7 Contains C6 <-- Root + 3 Nodes: + * N0 Unlinked + * N1 Linked To C1 + * N2 Linked to C2 + * N3 Linked to C3 + * N4 Linked to C4 + Get_root_cres should return C0, C1 + """ + cres = [] + nodes = [] + dbcres = [] + dbnodes = [] + + # clean the db from setup + sqla.session.remove() + sqla.drop_all() + sqla.create_all() + + collection = db.Node_collection().with_graph() + + for i in range(0, 8): + if i == 0 or i == 1: + cres.append(defs.CRE(name=f">> C{i}", id=f"{i}{i}{i}-{i}{i}{i}")) + else: + cres.append(defs.CRE(name=f"C{i}", id=f"{i}{i}{i}-{i}{i}{i}")) + + dbcres.append(collection.add_cre(cres[i])) + nodes.append(defs.Standard(section=f"S{i}", name=f"N{i}")) + dbnodes.append(collection.add_node(nodes[i])) + cres[i].add_link( + defs.Link(document=copy(nodes[i]), ltype=defs.LinkTypes.LinkedTo) + ) + collection.add_link( + cre=dbcres[i], node=dbnodes[i], ltype=defs.LinkTypes.LinkedTo + ) + + cres[0].add_link( + defs.Link(document=cres[2].shallow_copy(), ltype=defs.LinkTypes.Contains) + ) + cres[1].add_link( + defs.Link(document=cres[3].shallow_copy(), ltype=defs.LinkTypes.Contains) + ) + cres[2].add_link( + defs.Link(document=cres[4].shallow_copy(), ltype=defs.LinkTypes.Contains) + ) + + cres[3].add_link( + defs.Link(document=cres[5].shallow_copy(), ltype=defs.LinkTypes.Contains) + ) + cres[6].add_link( + defs.Link(document=cres[7].shallow_copy(), ltype=defs.LinkTypes.PartOf) + ) + collection.add_internal_link( + higher=dbcres[0], lower=dbcres[2], ltype=defs.LinkTypes.Contains + ) + collection.add_internal_link( + higher=dbcres[1], lower=dbcres[3], ltype=defs.LinkTypes.Contains + ) + collection.add_internal_link( + higher=dbcres[2], lower=dbcres[4], ltype=defs.LinkTypes.Contains + ) + collection.add_internal_link( + higher=dbcres[3], lower=dbcres[5], ltype=defs.LinkTypes.Contains + ) + collection.add_internal_link( + higher=dbcres[7], lower=dbcres[6], ltype=defs.LinkTypes.Contains + ) + cres[7].add_link( + defs.Link(document=cres[6].shallow_copy(), ltype=defs.LinkTypes.Contains) + ) + + root_cres = collection.get_root_cres() + self.maxDiff = None + self.assertCountEqual(root_cres, [cres[0], cres[1], cres[7]]) + + @patch.object(db.graph_db, "gap_analysis") + def test_gap_analysis_disconnected(self, gap_mock): + collection = db.Node_collection() + collection.graph_db.connected = False + gap_mock.return_value = (None, None) + + self.assertEqual(db.gap_analysis(collection.graph_db, ["788-788", "b"]), None) + + @patch.object(db.graph_db, "gap_analysis") + def test_gap_analysis_no_nodes(self, gap_mock): + collection = db.Node_collection() + collection.graph_db.connected = True + + gap_mock.return_value = ([], []) + self.assertEqual( + db.gap_analysis(collection.graph_db, ["788-788", "b"]), + (["788-788", "b"], {}, {}), + ) + + @patch.object(db.graph_db, "gap_analysis") + def test_gap_analysis_no_links(self, gap_mock): + collection = db.Node_collection() + collection.graph_db.connected = True + + gap_mock.return_value = ([defs.CRE(name="bob", id="111-111")], []) + self.maxDiff = None + self.assertEqual( + db.gap_analysis(collection.graph_db, ["788-788", "b"]), + ( + ["788-788", "b"], + { + "111-111": { + "start": defs.CRE(name="bob", id="111-111"), + "paths": {}, + "extra": 0, + } + }, + {"111-111": {"paths": {}}}, + ), + ) + + @patch.object(db.graph_db, "gap_analysis") + def test_gap_analysis_one_link(self, gap_mock): + collection = db.Node_collection() + collection.graph_db.connected = True + path = [ + { + "end": defs.CRE(name="bob", id="111-111"), + "relationship": "LINKED_TO", + "start": defs.CRE(name="bob", id="788-788"), + }, + { + "end": defs.CRE(name="bob", id="222-222"), + "relationship": "LINKED_TO", + "start": defs.CRE(name="bob", id="788-788"), + }, + ] + gap_mock.return_value = ( + [defs.CRE(name="bob", id="788-788")], + [ + { + "start": defs.CRE(name="bob", id="788-788"), + "end": defs.CRE(name="bob", id="788-789"), + "path": path, + } + ], + ) + expected = ( + ["788-788", "788-789"], + { + "788-788": { + "start": defs.CRE(name="bob", id="788-788"), + "paths": { + "788-789": { + "end": defs.CRE(name="bob", id="788-789"), + "path": path, + "score": 0, + } + }, + "extra": 0, + } + }, + {"788-788": {"paths": {}}}, + ) + self.maxDiff = None + self.assertEqual( + db.gap_analysis(collection.graph_db, ["788-788", "788-789"]), expected + ) + + @patch.object(db.graph_db, "gap_analysis") + def test_gap_analysis_one_weak_link(self, gap_mock): + collection = db.Node_collection() + collection.graph_db.connected = True + path = [ + { + "end": defs.CRE(name="bob", id="111-111"), + "relationship": "LINKED_TO", + "start": defs.CRE(name="bob", id="788-788"), + }, + { + "end": defs.CRE(name="bob", id="222-222"), + "relationship": "RELATED", + "start": defs.CRE(name="bob", id="111-111"), + }, + { + "end": defs.CRE(name="bob", id="111-111"), + "relationship": "RELATED", + "start": defs.CRE(name="bob", id="222-222"), + }, + { + "end": defs.CRE(name="bob", id="333-333"), + "relationship": "LINKED_TO", + "start": defs.CRE(name="bob", id="222-222"), + }, + ] + gap_mock.return_value = ( + [defs.CRE(name="bob", id="111-111")], + [ + { + "start": defs.CRE(name="bob", id="111-111"), + "end": defs.CRE(name="bob", id="222-222"), + "path": path, + } + ], + ) + expected = ( + ["788-788", "b"], + { + "111-111": { + "start": defs.CRE(name="bob", id="111-111"), + "paths": {}, + "extra": 1, + } + }, + { + "111-111": { + "paths": { + "222-222": { + "end": defs.CRE(name="bob", id="222-222"), + "path": path, + "score": 4, + } + } + } + }, + ) + self.maxDiff = None + self.assertEqual( + db.gap_analysis(collection.graph_db, ["788-788", "b"]), expected + ) + + @patch.object(db.graph_db, "gap_analysis") + def test_gap_analysis_duplicate_link_path_existing_lower(self, gap_mock): + collection = db.Node_collection() + collection.graph_db.connected = True + path = [ + { + "end": defs.CRE(name="bob", id="111-111"), + "relationship": "LINKED_TO", + "start": defs.CRE(name="bob", id="788-788"), + }, + { + "end": defs.CRE(name="bob", id="222-222"), + "relationship": "LINKED_TO", + "start": defs.CRE(name="bob", id="788-788"), + }, + ] + path2 = [ + { + "end": defs.CRE(name="bob", id="111-111"), + "relationship": "LINKED_TO", + "start": defs.CRE(name="bob", id="788-788"), + }, + { + "end": defs.CRE(name="bob", id="222-222"), + "relationship": "RELATED", + "start": defs.CRE(name="bob", id="788-788"), + }, + ] + gap_mock.return_value = ( + [defs.CRE(name="bob", id="111-111")], + [ + { + "start": defs.CRE(name="bob", id="111-111"), + "end": defs.CRE(name="bob", id="222-222"), + "path": path, + }, + { + "start": defs.CRE(name="bob", id="111-111"), + "end": defs.CRE(name="bob", id="222-222"), + "path": path2, + }, + ], + ) + expected = ( + ["788-788", "b"], + { + "111-111": { + "start": defs.CRE(name="bob", id="111-111"), + "paths": { + "222-222": { + "end": defs.CRE(name="bob", id="222-222"), + "path": path, + "score": 0, + } + }, + "extra": 0, + }, + }, + {"111-111": {"paths": {}}}, + ) + self.assertEqual( + db.gap_analysis(collection.graph_db, ["788-788", "b"]), expected + ) + + @patch.object(db.graph_db, "gap_analysis") + def test_gap_analysis_duplicate_link_path_existing_lower_new_in_extras( + self, gap_mock + ): + collection = db.Node_collection() + collection.graph_db.connected = True + path = [ + { + "end": defs.CRE(name="bob", id="111-111"), + "relationship": "LINKED_TO", + "start": defs.CRE(name="bob", id="788-788"), + }, + { + "end": defs.CRE(name="bob", id="222-222"), + "relationship": "LINKED_TO", + "start": defs.CRE(name="bob", id="788-788"), + }, + ] + path2 = [ + { + "end": defs.CRE(name="bob", id="111-111"), + "relationship": "LINKED_TO", + "start": defs.CRE(name="bob", id="788-788"), + }, + { + "end": defs.CRE(name="bob", id="222-222"), + "relationship": "RELATED", + "start": defs.CRE(name="bob", id="788-788"), + }, + { + "end": defs.CRE(name="bob", id="222-222"), + "relationship": "RELATED", + "start": defs.CRE(name="bob", id="788-788"), + }, + ] + gap_mock.return_value = ( + [defs.CRE(name="bob", id="111-111")], + [ + { + "start": defs.CRE(name="bob", id="111-111"), + "end": defs.CRE(name="bob", id="222-222"), + "path": path, + }, + { + "start": defs.CRE(name="bob", id="111-111"), + "end": defs.CRE(name="bob", id="222-222"), + "path": path2, + }, + ], + ) + expected = ( + ["788-788", "b"], + { + "111-111": { + "start": defs.CRE(name="bob", id="111-111"), + "paths": { + "222-222": { + "end": defs.CRE(name="bob", id="222-222"), + "path": path, + "score": 0, + } + }, + "extra": 0, + }, + }, + {"111-111": {"paths": {}}}, + ) + self.assertEqual( + db.gap_analysis(collection.graph_db, ["788-788", "b"]), expected + ) + + @patch.object(db.graph_db, "gap_analysis") + def test_gap_analysis_duplicate_link_path_existing_higher(self, gap_mock): + collection = db.Node_collection() + collection.graph_db.connected = True + path = [ + { + "end": defs.CRE(name="bob", id="111-111"), + "relationship": "LINKED_TO", + "start": defs.CRE(name="bob", id="788-788"), + }, + { + "end": defs.CRE(name="bob", id="222-222"), + "relationship": "LINKED_TO", + "start": defs.CRE(name="bob", id="788-788"), + }, + ] + path2 = [ + { + "end": defs.CRE(name="bob", id="111-111"), + "relationship": "LINKED_TO", + "start": defs.CRE(name="bob", id="788-788"), + }, + { + "end": defs.CRE(name="bob", id="222-222"), + "relationship": "RELATED", + "start": defs.CRE(name="bob", id="788-788"), + }, + ] + gap_mock.return_value = ( + [defs.CRE(name="bob", id="111-111")], + [ + { + "start": defs.CRE(name="bob", id="111-111"), + "end": defs.CRE(name="bob", id="222-222"), + "path": path2, + }, + { + "start": defs.CRE(name="bob", id="111-111"), + "end": defs.CRE(name="bob", id="222-222"), + "path": path, + }, + ], + ) + expected = ( + ["788-788", "b"], + { + "111-111": { + "start": defs.CRE(name="bob", id="111-111"), + "paths": { + "222-222": { + "end": defs.CRE(name="bob", id="222-222"), + "path": path, + "score": 0, + } + }, + "extra": 0, + } + }, + {"111-111": {"paths": {}}}, + ) + self.assertEqual( + db.gap_analysis(collection.graph_db, ["788-788", "b"]), expected + ) + + @patch.object(db.graph_db, "gap_analysis") + def test_gap_analysis_duplicate_link_path_existing_higher_and_in_extras( + self, gap_mock + ): + collection = db.Node_collection() + collection.graph_db.connected = True + path = [ + { + "end": defs.CRE(name="bob", id="111-111"), + "relationship": "LINKED_TO", + "start": defs.CRE(name="bob", id="788-788"), + }, + { + "end": defs.CRE(name="bob", id="222-222"), + "relationship": "LINKED_TO", + "start": defs.CRE(name="bob", id="788-788"), + }, + ] + path2 = [ + { + "end": defs.CRE(name="bob", id="111-111"), + "relationship": "LINKED_TO", + "start": defs.CRE(name="bob", id="788-788"), + }, + { + "end": defs.CRE(name="bob", id="222-222"), + "relationship": "RELATED", + "start": defs.CRE(name="bob", id="788-788"), + }, + { + "end": defs.CRE(name="bob", id="222-222"), + "relationship": "RELATED", + "start": defs.CRE(name="bob", id="788-788"), + }, + ] + gap_mock.return_value = ( + [defs.CRE(name="bob", id="111-111")], + [ + { + "start": defs.CRE(name="bob", id="111-111"), + "end": defs.CRE(name="bob", id="222-222"), + "path": path2, + }, + { + "start": defs.CRE(name="bob", id="111-111"), + "end": defs.CRE(name="bob", id="222-222"), + "path": path, + }, + ], + ) + expected = ( + ["788-788", "b"], + { + "111-111": { + "start": defs.CRE(name="bob", id="111-111"), + "paths": { + "222-222": { + "end": defs.CRE(name="bob", id="222-222"), + "path": path, + "score": 0, + } + }, + "extra": 0, + } + }, + {"111-111": {"paths": {}}}, + ) + self.assertEqual( + db.gap_analysis(collection.graph_db, ["788-788", "b"]), expected + ) + + @patch.object(db.graph_db, "gap_analysis") + def test_gap_analysis_dump_to_cache(self, gap_mock): + collection = db.Node_collection() + collection.graph_db.connected = True + path = [ + { + "end": defs.CRE(name="bob1", id="111-111"), + "relationship": "LINKED_TO", + "start": defs.CRE(name="bob7", id="788-788"), + "score": 0, + }, + { + "end": defs.CRE(name="bob2", id="222-222"), + "relationship": "RELATED", + "start": defs.CRE(name="bob1", id="111-111"), + "score": 2, + }, + { + "end": defs.CRE(name="bob1", id="111-111"), + "relationship": "RELATED", + "start": defs.CRE(name="bob2", id="222-222"), + "score": 2, + }, + { + "end": defs.CRE(name="bob3", id="333-333"), + "relationship": "LINKED_TO", + "start": defs.CRE(name="bob2", id="222-222"), + "score": 4, + }, + ] + gap_mock.return_value = ( + [defs.CRE(name="bob7", id="788-788")], + [ + { + "start": defs.CRE(name="bob7", id="788-788"), + "end": defs.CRE(name="bob2", id="222-222"), + "path": path, + } + ], + ) + + expected_response = ( + ["788-788", "222-222"], + { + "788-788": { + "start": defs.CRE(name="bob7", id="788-788"), + "paths": {}, + "extra": 1, + } + }, + { + "788-788": { + "paths": { + "222-222": { + "end": defs.CRE(name="bob2", id="222-222"), + "path": path, + "score": 4, + } + } + } + }, + ) + response = db.gap_analysis(collection.graph_db, ["788-788", "222-222"]) + + self.maxDiff = None + self.assertEqual( + response, (expected_response[0], expected_response[1], expected_response[2]) + ) + self.assertEqual( + collection.gap_analysis_exists(make_resources_key(["788-788", "222-222"])), + True, + ) + self.assertEqual( + collection.get_gap_analysis_result( + make_resources_key(["788-788", "222-222"]) + ), + flask_json.dumps({"result": expected_response[1]}), + ) + self.assertEqual( + collection.get_gap_analysis_result( + make_subresources_key(["788-788", "222-222"], "788-788") + ), + flask_json.dumps({"result": expected_response[2]["788-788"]}), + ) + + def test_neo_db_parse_node_code(self): + name = "name" + description = "description" + tags = "tags" + version = "version" + hyperlink = "version" + expected = defs.Code( + name=name, + description=description, + tags=tags, + version=version, + hyperlink=hyperlink, + links=[ + defs.Link( + defs.CRE(id="123-123", description="gcCD2", name="gcC2"), "Related" + ) + ], + ) + graph_node = db.NeoCode( + name=name, + description=description, + tags=tags, + version=version, + hyperlink=hyperlink, + related=[ + db.NeoCRE(external_id="123-123", description="gcCD2", name="gcC2"), + ], + ) + + self.assertEqual(db.graph_db.parse_node(graph_node).todict(), expected.todict()) + + def test_neo_db_parse_node_standard(self): + name = "name" + description = "description" + tags = "tags" + version = "version" + section = "section" + sectionID = "sectionID" + subsection = "subsection" + hyperlink = "version" + expected = defs.Standard( + name=name, + description=description, + tags=tags, + version=version, + section=section, + sectionID=sectionID, + subsection=subsection, + hyperlink=hyperlink, + links=[ + defs.Link( + defs.CRE(id="123-123", description="gcCD2", name="gcC2"), "Related" + ) + ], + ) + graph_node = db.NeoStandard( + name=name, + description=description, + tags=tags, + version=version, + section=section, + section_id=sectionID, + subsection=subsection, + hyperlink=hyperlink, + related=[ + db.NeoCRE(external_id="123-123", description="gcCD2", name="gcC2"), + ], + ) + self.assertEqual(db.graph_db.parse_node(graph_node).todict(), expected.todict()) + + def test_neo_db_parse_node_tool(self): + name = "name" + description = "description" + tags = "tags" + version = "version" + section = "section" + sectionID = "sectionID" + subsection = "subsection" + hyperlink = "version" + tooltype = defs.ToolTypes.Defensive + expected = defs.Tool( + name=name, + tooltype=tooltype, + description=description, + tags=tags, + version=version, + section=section, + sectionID=sectionID, + subsection=subsection, + hyperlink=hyperlink, + links=[ + defs.Link( + defs.CRE(id="123-123", description="gcCD2", name="gcC2"), "Related" + ) + ], + ) + graph_node = db.NeoTool( + name=name, + description=description, + tooltype=tooltype, + tags=tags, + version=version, + section=section, + section_id=sectionID, + subsection=subsection, + hyperlink=hyperlink, + related=[ + db.NeoCRE(external_id="123-123", description="gcCD2", name="gcC2"), + ], + ) + self.assertEqual(db.graph_db.parse_node(graph_node).todict(), expected.todict()) + + def test_neo_db_parse_node_cre(self): + name = "name" + description = "description" + tags = "tags" + external_id = "123-123" + expected = defs.CRE( + name=name, + description=description, + id=external_id, + tags=tags, + links=[ + defs.Link( + defs.CRE(id="123-123", description="gcCD2", name="gcC2"), "Contains" + ), + defs.Link( + defs.CRE(id="123-123", description="gcCD3", name="gcC3"), "Contains" + ), + defs.Link( + defs.Standard( + hyperlink="gc3", + name="gcS2", + section="gc1", + subsection="gc2", + version="gc1.1.1", + ), + "Linked To", + ), + ], + ) + graph_node = db.NeoCRE( + name=name, + description=description, + tags=tags, + external_id=external_id, + contained_in=[], + contains=[ + db.NeoCRE(external_id="123-123", description="gcCD2", name="gcC2"), + db.NeoCRE(external_id="123-123", description="gcCD3", name="gcC3"), + ], + linked=[ + db.NeoStandard( + hyperlink="gc3", + name="gcS2", + section="gc1", + subsection="gc2", + version="gc1.1.1", + ) + ], + same_as=[], + related=[], + auto_linked_to=[], + ) + + parsed = db.graph_db.parse_node(graph_node) + self.maxDiff = None + self.assertEqual(parsed.todict(), expected.todict()) + + def test_neo_db_parse_node_no_links_cre(self): + name = "name" + description = "description" + tags = "tags" + external_id = "123-123" + expected = defs.CRE( + name=name, description=description, id=external_id, tags=tags, links=[] + ) + graph_node = db.NeoCRE( + name=name, + description=description, + tags=tags, + external_id=external_id, + contained_in=[], + contains=[ + db.NeoCRE(external_id="123-123", description="gcCD2", name="gcC2"), + db.NeoCRE(external_id="123-123", description="gcCD3", name="gcC3"), + ], + linked=[ + db.NeoStandard( + hyperlink="gc3", + name="gcS2", + section="gc1", + subsection="gc2", + version="gc1.1.1", + ) + ], + same_as=[], + related=[], + ) + + parsed = db.graph_db.parse_node_no_links(graph_node) + self.maxDiff = None + self.assertEqual(parsed.todict(), expected.todict()) + + def test_neo_db_parse_node_Document(self): + name = "name" + id = "id" + description = "description" + tags = "tags" + graph_node = db.NeoDocument( + name=name, + document_id=id, + description=description, + tags=tags, + ) + with self.assertRaises(Exception) as cm: + db.graph_db.parse_node(graph_node) + + self.assertEqual(str(cm.exception), "Shouldn't be parsing a NeoDocument") + + def test_neo_db_parse_node_Node(self): + name = "name" + id = "id" + description = "description" + tags = "tags" + graph_node = db.NeoNode( + name=name, + document_id=id, + description=description, + tags=tags, + ) + with self.assertRaises(Exception) as cm: + db.graph_db.parse_node(graph_node) + + self.assertEqual(str(cm.exception), "Shouldn't be parsing a NeoNode") + + def test_get_embeddings_by_doc_type_paginated(self): + """Given: a range of embedding for Nodes and a range of embeddings for CREs + when called with doc_type CRE return the cre embeddings + when called with doc_type Standard/Tool return the node embeddings""" + # add cre embeddings + cre_embeddings = [] + for i in range(0, 10): + dbca = db.CRE(external_id=f"{i}", description=f"C{i}", name=f"C{i}") + self.collection.session.add(dbca) + self.collection.session.commit() + + embeddings = [random.uniform(-1, 1) for e in range(0, 768)] + embeddings_text = "".join( + random.choices(string.ascii_uppercase + string.digits, k=100) + ) + cre_embeddings.append( + self.collection.add_embedding( + db_object=dbca, + doctype=defs.Credoctypes.CRE.value, + embeddings=embeddings, + embedding_text=embeddings_text, + ) + ) + + # add node embeddings + node_embeddings = [] + for i in range(0, 10): + dbsa = db.Node( + subsection=f"4.5.{i}", + section=f"FooStand-{i}", + name="BarStand", + link="https://example.com", + ntype=defs.Credoctypes.Standard.value, + ) + self.collection.session.add(dbsa) + self.collection.session.commit() + + embeddings = [random.uniform(-1, 1) for e in range(0, 768)] + embeddings_text = "".join( + random.choices(string.ascii_uppercase + string.digits, k=100) + ) + ne = self.collection.add_embedding( + db_object=dbsa, + doctype=defs.Credoctypes.Standard.value, + embeddings=embeddings, + embedding_text=embeddings_text, + ) + node_embeddings.append(ne) + + ( + cre_emb, + total_pages, + curr_page, + ) = self.collection.get_embeddings_by_doc_type_paginated( + defs.Credoctypes.CRE.value, page=1, per_page=1 + ) + self.assertNotEqual(list(cre_emb.keys())[0], "") + self.assertIn(list(cre_emb.keys())[0], list([e.cre_id for e in cre_embeddings])) + self.assertNotIn( + list(cre_emb.keys())[0], list([e.node_id for e in cre_embeddings]) + ) + self.assertEqual(total_pages, 10) + self.assertEqual(curr_page, 1) + + ( + node_emb, + total_pages, + curr_page, + ) = self.collection.get_embeddings_by_doc_type_paginated( + defs.Credoctypes.Standard.value, page=1, per_page=1 + ) + self.assertNotEqual(list(node_emb.keys())[0], "") + self.assertIn( + list(node_emb.keys())[0], list([e.node_id for e in node_embeddings]) + ) + self.assertNotIn( + list(node_emb.keys())[0], list([e.cre_id for e in cre_embeddings]) + ) + self.assertEqual(total_pages, 10) + self.assertEqual(curr_page, 1) + + ( + tool_emb, + total_pages, + curr_page, + ) = self.collection.get_embeddings_by_doc_type_paginated( + defs.Credoctypes.Tool.value, page=1, per_page=1 + ) + self.assertEqual(total_pages, 0) + self.assertEqual(tool_emb, {}) + + def test_get_embeddings_by_doc_type(self): + """Given: a range of embedding for Nodes and a range of embeddings for CREs + when called with doc_type CRE return the cre embeddings + when called with doc_type Standard/Tool return the node embeddings""" + # add cre embeddings + cre_embeddings = [] + for i in range(0, 10): + dbca = db.CRE(external_id=f"{i}", description=f"C{i}", name=f"C{i}") + self.collection.session.add(dbca) + self.collection.session.commit() + + embeddings = [random.uniform(-1, 1) for e in range(0, 768)] + embeddings_text = "".join( + random.choices(string.ascii_uppercase + string.digits, k=100) + ) + cre_embeddings.append( + self.collection.add_embedding( + db_object=dbca, + doctype=defs.Credoctypes.CRE.value, + embeddings=embeddings, + embedding_text=embeddings_text, + ) + ) + + # add node embeddings + node_embeddings = [] + for i in range(0, 10): + dbsa = db.Node( + subsection=f"4.5.{i}", + section=f"FooStand-{i}", + name="BarStand", + link="https://example.com", + ntype=defs.Credoctypes.Standard.value, + ) + self.collection.session.add(dbsa) + self.collection.session.commit() + + embeddings = [random.uniform(-1, 1) for e in range(0, 768)] + embeddings_text = "".join( + random.choices(string.ascii_uppercase + string.digits, k=100) + ) + ne = self.collection.add_embedding( + db_object=dbsa, + doctype=defs.Credoctypes.Standard.value, + embeddings=embeddings, + embedding_text=embeddings_text, + ) + node_embeddings.append(ne) + + cre_emb = self.collection.get_embeddings_by_doc_type(defs.Credoctypes.CRE.value) + self.assertNotEqual(list(cre_emb.keys())[0], "") + self.assertIn(list(cre_emb.keys())[0], list([e.cre_id for e in cre_embeddings])) + self.assertNotIn( + list(cre_emb.keys())[0], list([e.node_id for e in cre_embeddings]) + ) + + node_emb = self.collection.get_embeddings_by_doc_type( + defs.Credoctypes.Standard.value + ) + self.assertNotEqual(list(node_emb.keys())[0], "") + self.assertIn( + list(node_emb.keys())[0], list([e.node_id for e in node_embeddings]) + ) + self.assertNotIn( + list(node_emb.keys())[0], list([e.cre_id for e in cre_embeddings]) + ) + + tool_emb = self.collection.get_embeddings_by_doc_type( + defs.Credoctypes.Tool.value + ) + self.assertEqual(tool_emb, {}) + + def test_get_standard_names(self): + for s in ["sa", "sb", "sc", "sd"]: + for sub in ["suba", "subb", "subc", "subd"]: + self.collection.add_node( + defs.Standard(name=s, section=sub, subsection=sub) + ) + self.assertCountEqual( + ["BarStand", "Unlinked", "sa", "sb", "sc", "sd"], + self.collection.standards(), + ) + + def test_all_cres_with_pagination(self): + """""" + cres = [] + nodes = [] + dbcres = [] + dbnodes = [] + sqla.session.remove() + sqla.drop_all() + sqla.create_all() + collection = db.Node_collection() + for i in range(0, 8): + if i == 0 or i == 1: + cres.append(defs.CRE(name=f">> C{i}", id=f"{i}{i}{i}-{i}{i}{i}")) + else: + cres.append(defs.CRE(name=f"C{i}", id=f"{i}")) + + dbcres.append(collection.add_cre(cres[i])) + nodes.append(defs.Standard(section=f"S{i}", name=f"N{i}")) + dbnodes.append(collection.add_node(nodes[i])) + cres[i].add_link( + defs.Link(document=copy(nodes[i]), ltype=defs.LinkTypes.LinkedTo) + ) + collection.add_link( + cre=dbcres[i], node=dbnodes[i], ltype=defs.LinkTypes.LinkedTo + ) + + collection.session.commit() + + paginated_cres, page, total_pages = collection.all_cres_with_pagination( + page=1, per_page=2 + ) + self.maxDiff = None + # from pprint import pprint + # pprint(cres) + self.assertEqual(paginated_cres, [cres[0], cres[1]]) + self.assertEqual(page, 1) + self.assertEqual(total_pages, 4) + + def test_all_cres_with_pagination(self): + """""" + cres = [] + nodes = [] + dbcres = [] + dbnodes = [] + sqla.session.remove() + sqla.drop_all() + sqla.create_all() + collection = db.Node_collection() + for i in range(0, 8): + if i == 0 or i == 1: + cres.append(defs.CRE(name=f">> C{i}", id=f"{i}{i}{i}-{i}{i}{i}")) + else: + cres.append(defs.CRE(name=f"C{i}", id=f"{i}{i}{i}-{i}{i}{i}")) + + dbcres.append(collection.add_cre(cres[i])) + nodes.append(defs.Standard(section=f"S{i}", name=f"N{i}")) + dbnodes.append(collection.add_node(nodes[i])) + cres[i].add_link( + defs.Link(document=copy(nodes[i]), ltype=defs.LinkTypes.LinkedTo) + ) + collection.add_link( + cre=dbcres[i], node=dbnodes[i], ltype=defs.LinkTypes.LinkedTo + ) + + collection.session.commit() + + paginated_cres, page, total_pages = collection.all_cres_with_pagination( + page=1, per_page=2 + ) + self.maxDiff = None + self.assertEqual(paginated_cres, [cres[0], cres[1]]) + self.assertEqual(page, 1) + self.assertEqual(total_pages, 4) + + def test_get_cre_hierarchy(self) -> None: + # this needs a clean database and a clean graph so reinit everything + # sqla.session.remove() + # sqla.drop_all() + # sqla.create_all() + collection = self.collection # db.Node_collection().with_graph() + # collection.graph.with_graph(graph=nx.DiGraph(), graph_data=[]) + + _, inputDocs = export_format_data() + importItems = [] + for name, items in inputDocs.items(): + for item in items: + importItems.append(item) + if name == defs.Credoctypes.CRE: + dbitem = collection.add_cre(item) + else: + dbitem = collection.add_node(item) + for link in item.links: + if link.document.doctype == defs.Credoctypes.CRE: + linked_item = collection.add_cre(link.document) + if item.doctype == defs.Credoctypes.CRE: + collection.add_internal_link( + dbitem, linked_item, ltype=link.ltype + ) + else: + collection.add_link( + node=dbitem, cre=linked_item, ltype=link.ltype + ) + else: + linked_item = collection.add_node(link.document) + if item.doctype == defs.Credoctypes.CRE: + collection.add_link( + cre=dbitem, node=linked_item, ltype=link.ltype + ) + else: + collection.add_internal_link( + cre=linked_item, node=dbitem, ltype=link.ltype + ) + cres = inputDocs[defs.Credoctypes.CRE] + c0 = [c for c in cres if c.name == "C0"][0] + self.assertEqual(collection.get_cre_hierarchy(c0), 0) + c2 = [c for c in cres if c.name == "C2"][0] + self.assertEqual(collection.get_cre_hierarchy(c2), 1) + c3 = [c for c in cres if c.name == "C3"][0] + self.assertEqual(collection.get_cre_hierarchy(c3), 2) + c4 = [c for c in cres if c.name == "C4"][0] + self.assertEqual(collection.get_cre_hierarchy(c4), 3) + c5 = [c for c in cres if c.name == "C5"][0] + self.assertEqual(collection.get_cre_hierarchy(c5), 4) + c6 = [c for c in cres if c.name == "C6"][0] + self.assertEqual(collection.get_cre_hierarchy(c6), 0) + c7 = [c for c in cres if c.name == "C7"][0] + self.assertEqual(collection.get_cre_hierarchy(c7), 0) + c8 = [c for c in cres if c.name == "C8"][0] + self.assertEqual(collection.get_cre_hierarchy(c8), 0) diff --git a/application/tests/web_main_test.py b/application/tests/web_main_test.py index 9e219b4ce..5389cbac3 100644 --- a/application/tests/web_main_test.py +++ b/application/tests/web_main_test.py @@ -10,7 +10,12 @@ from unittest.mock import patch import redis -import rq + +try: + import rq +except (ValueError, ImportError): + rq = None + import os import networkx as nx @@ -29,7 +34,7 @@ def id(self): return "ABC" def get_status(self): - return rq.job.JobStatus.STARTED + return rq.job.JobStatus.STARTED if rq else "started" class TestMain(unittest.TestCase): diff --git a/application/utils/external_project_parsers/base_parser.py b/application/utils/external_project_parsers/base_parser.py index 5bff50e63..5605fbbdb 100644 --- a/application/utils/external_project_parsers/base_parser.py +++ b/application/utils/external_project_parsers/base_parser.py @@ -1,5 +1,10 @@ from application.utils.external_project_parsers import base_parser_defs -from rq import Queue + +try: + from rq import Queue +except (ValueError, ImportError): + Queue = None + from application.utils import redis from application.prompt_client import prompt_client as prompt_client import logging diff --git a/application/utils/gap_analysis.py b/application/utils/gap_analysis.py index 9e3dab04d..6bd95042e 100644 --- a/application/utils/gap_analysis.py +++ b/application/utils/gap_analysis.py @@ -2,7 +2,12 @@ import time import logging import os -from rq import Queue, job, exceptions + +try: + from rq import Queue, job, exceptions +except (ValueError, ImportError): + Queue, job, exceptions = None, None, None + from typing import List, Dict from application.utils import redis from flask import json as flask_json @@ -30,7 +35,7 @@ def make_resources_key(array: List[str]): def make_subresources_key(standards: List[str], key: str) -> str: - return str(make_resources_key(standards)) + "->" + key + return f"{str(make_resources_key(standards))}->{str(key)}" def get_path_score(path): @@ -128,48 +133,72 @@ def schedule(standards: List[str], database): logger.info(f"Gap analysis result for {standards_hash} does not exist") - conn = redis.connect() - if conn is None: - logger.error( - "Redis is not available. Please run 'make start-containers' first." - ) - return { - "error": "Redis is not available. Please run 'make start-containers' first." - } - gap_analysis_results = conn.get(standards_hash) - if ( - gap_analysis_results - ): # perhaps its calculated but not cached yet, get it from redis - gap_analysis_dict = json.loads(gap_analysis_results) - if gap_analysis_dict.get("job_id"): - try: - res = job.Job.fetch(id=gap_analysis_dict.get("job_id"), connection=conn) - except exceptions.NoSuchJobError as nje: - logger.error( - f"Could not find job id for gap analysis {standards}, this is a bug" - ) - return {"error": 404} + try: + conn = redis.connect() + if conn is None: + logger.error( + "Redis is not available. Please run 'make start-containers' first." + ) + else: + gap_analysis_results = conn.get(standards_hash) if ( - res.get_status() != job.JobStatus.FAILED - and res.get_status() != job.JobStatus.STOPPED - and res.get_status() != job.JobStatus.CANCELED - ): - logger.info( - f'gap analysis job id {gap_analysis_dict.get("job_id")}, for standards: {standards[0]}>>{standards[1]} already exists, returning early' + gap_analysis_results + ): # perhaps its calculated but not cached yet, get it from redis + gap_analysis_dict = json.loads(gap_analysis_results) + if gap_analysis_dict.get("job_id"): + try: + res = job.Job.fetch( + id=gap_analysis_dict.get("job_id"), connection=conn + ) + except exceptions.NoSuchJobError as nje: + logger.error( + f"Could not find job id for gap analysis {standards}, this is a bug" + ) + return {"error": 404} + if ( + res.get_status() != job.JobStatus.FAILED + and res.get_status() != job.JobStatus.STOPPED + and res.get_status() != job.JobStatus.CANCELED + ): + logger.info( + f'gap analysis job id {gap_analysis_dict.get("job_id")}, for standards: {standards[0]}>>{standards[1]} already exists, returning early' + ) + return {"job_id": gap_analysis_dict.get("job_id")} + if Queue: + q = Queue(connection=conn) + gap_analysis_job = q.enqueue_call( + db.gap_analysis, + kwargs={ + "graph_db": database.graph_db, + "node_names": standards, + "cache_key": standards_hash, + }, + timeout=GAP_ANALYSIS_TIMEOUT, + ) + conn.set( + standards_hash, + json.dumps({"job_id": gap_analysis_job.id, "result": ""}), ) - return {"job_id": gap_analysis_dict.get("job_id")} - q = Queue(connection=conn) - gap_analysis_job = q.enqueue_call( - db.gap_analysis, - kwargs={ - "neo_db": database.neo_db, - "node_names": standards, - "cache_key": standards_hash, - }, - timeout=GAP_ANALYSIS_TIMEOUT, + return {"job_id": gap_analysis_job.id} + except Exception as e: + logger.warning( + f"Redis operation failed or timed out: {e}. Falling back to sync." + ) + + logger.info("RQ is not available, running gap analysis synchronously") + res = db.gap_analysis( + graph_db=database.graph_db, + node_names=standards, + cache_key=standards_hash, ) - conn.set(standards_hash, json.dumps({"job_id": gap_analysis_job.id, "result": ""})) - return {"job_id": gap_analysis_job.id} + if res: + _, grouped_paths, _ = res + return {"result": grouped_paths} + else: + logger.error(f"Gap analysis failed to return results for {standards}") + return { + "error": "Gap analysis failed to return results. Please ensure standard names are correct and the graph is populated." + } def preload(target_url: str): diff --git a/application/utils/redis.py b/application/utils/redis.py index 90b2296bd..e2a1d3c94 100644 --- a/application/utils/redis.py +++ b/application/utils/redis.py @@ -3,7 +3,13 @@ from urllib.parse import urlparse import logging from typing import Callable, List -import rq + +try: + import rq +except (ValueError, ImportError): + # rq (specifically rq-scheduler) may fail on Windows due to 'fork' context requirement + rq = None + import time logging.basicConfig() @@ -31,13 +37,14 @@ def connect(): password=os.getenv("REDIS_PASSWORD", None), ssl=False if redis_no_ssl else True, ssl_cert_reqs=None, + socket_timeout=5, ) elif redis_url: logger.debug( f"Attempting to connect to Redis instance using a URL at {redis_url}" ) if redis_url == "redis://localhost:6379": - return redis.from_url(redis_url) + return redis.from_url(redis_url, socket_timeout=5) else: url = urlparse(redis_url) return redis.Redis( @@ -46,12 +53,17 @@ def connect(): password=url.password, ssl=False if redis_no_ssl else True, ssl_cert_reqs=None, + socket_timeout=5, ) else: logger.warning("Starting without Redis, functionality may be limited!") -def wait_for_jobs(jobs: List[rq.job.Job], callback: Callable = None): +def wait_for_jobs(jobs: List, callback: Callable = None): + if not rq: + logger.warning("RQ is not available on this system. Cannot wait for jobs.") + return + def do_nothing(): pass diff --git a/application/web/web_main.py b/application/web/web_main.py index 29567470a..78afb077a 100644 --- a/application/web/web_main.py +++ b/application/web/web_main.py @@ -14,7 +14,11 @@ from typing import Any from application.utils import oscal_utils, redis -from rq import job, exceptions +try: + from rq import job, exceptions +except (ValueError, ImportError): + job, exceptions = None, None + from application.utils import spreadsheet_parsers from application.utils import oscal_utils, redis @@ -738,6 +742,8 @@ def login(): session["name"] = "dev user" return redirect("/chatbot") flow_instance = CREFlow.instance() + if not flow_instance or not flow_instance.flow: + return "Login not configured. Please set up Google OAuth in .env.", 501 authorization_url, state = flow_instance.flow.authorization_url() session["state"] = state return redirect(authorization_url) diff --git a/application/worker.py b/application/worker.py index a26256622..0c6dfa32d 100644 --- a/application/worker.py +++ b/application/worker.py @@ -1,4 +1,8 @@ -from rq import Worker, Queue +try: + from rq import Worker, Queue +except (ValueError, ImportError): + Worker, Queue = None, None + import logging from application.utils import redis @@ -10,6 +14,12 @@ def start_worker(): + if not Worker: + logger.error( + "RQ Worker is not supported on Windows (requires os.fork). " + "Gap analysis will run synchronously in the web server instead." + ) + return logger.info(f"Worker Starting") worker = Worker(listen, connection=redis.connect()) worker.work() diff --git a/cre.py b/cre.py index 2e7a9d5cc..150b861a8 100644 --- a/cre.py +++ b/cre.py @@ -205,9 +205,10 @@ def main() -> None: help="for every node, download the text pointed to by the hyperlink and generate embeddings for the content of the specific node", ) parser.add_argument( + "--populate_graph_db", "--populate_neo4j_db", action="store_true", - help="populate the neo4j db", + help="populate the graph db (Neo4j or Apache AGE)", ) parser.add_argument( "--start_worker", diff --git a/scripts/benchmark_gap.py b/scripts/benchmark_gap.py index dde1e5cb3..0a27cb7c6 100644 --- a/scripts/benchmark_gap.py +++ b/scripts/benchmark_gap.py @@ -1,23 +1,3 @@ -""" -Benchmark script for Gap Analysis performance (Issue #587) -=========================================================== - -Measures wall-clock time and peak memory for: - - MODE A: Original exhaustive traversal (always runs wildcard [*..20] twice) - - MODE B: Optimized tiered pruning (early exit on strong/medium links) - -Usage: - # List available standards in Neo4j: - python scripts/benchmark_gap.py --list-standards - - # Run benchmark on two standards: - python scripts/benchmark_gap.py --standard1 "OWASP Top 10 2021" --standard2 "NIST 800-53" - -Requirements: - Neo4j must be running (use: make docker-neo4j) - NEO4J_URL env var or default: neo4j://neo4j:password@localhost:7687 -""" - import argparse import os import sys @@ -29,230 +9,105 @@ sys.path.insert(0, os.path.abspath(_project_root)) try: - from neomodel import config as neo_config, db as neomodel_db - - # Must import the project's DB models so neomodel registers NeoStandard, - # NeoCRE etc. — otherwise resolve_objects=True raises NodeClassNotDefined. - import application.database.db # noqa: F401 + from application.database import db + from application.config import Config except ImportError as exc: print(f"[ERROR] Could not import project modules: {exc}") print(" Make sure you run from the project root with venv activated.") sys.exit(1) -def connect_neo4j(): - url = os.environ.get("NEO4J_URL", "neo4j://neo4j:password@localhost:7687") - neo_config.DATABASE_URL = url - print(f" → Connected to Neo4j at: {url}\n") - - def list_available_standards(): - connect_neo4j() - results, _ = neomodel_db.cypher_query( - "MATCH (n:NeoStandard) RETURN DISTINCT n.name ORDER BY n.name" - ) - if not results: - print(" [!] No NeoStandard nodes found. Import data first:") - print(" make import-neo4j") + collection = db.Node_collection() + standards = collection.standards() + if not standards: + print(" [!] No standards found in database.") return - print(f"Found {len(results)} standards:") - for row in results: - print(f" • {row[0]}") - - -def run_original(name_1, name_2): - """Original pre-PR#716 approach: always runs BOTH queries unconditionally.""" - denylist = ["Cross-cutting concerns"] - - # Query 1 — wildcard (the expensive one) - r1, _ = neomodel_db.cypher_query( - """ - MATCH (BaseStandard:NeoStandard {name: $name1}) - MATCH (CompareStandard:NeoStandard {name: $name2}) - MATCH p = allShortestPaths((BaseStandard)-[*..20]-(CompareStandard)) - WITH p - WHERE length(p) > 1 AND ALL(n in NODES(p) WHERE - (n:NeoCRE OR n = BaseStandard OR n = CompareStandard) - AND NOT n.name IN $denylist) - RETURN p - """, - {"name1": name_1, "name2": name_2, "denylist": denylist}, - resolve_objects=True, - ) + print(f"Found {len(standards)} standards:") + for s in standards: + print(f" • {s}") - # Query 2 — filtered (also always ran) - r2, _ = neomodel_db.cypher_query( - """ - MATCH (BaseStandard:NeoStandard {name: $name1}) - MATCH (CompareStandard:NeoStandard {name: $name2}) - MATCH p = allShortestPaths((BaseStandard)-[:(LINKED_TO|AUTOMATICALLY_LINKED_TO|CONTAINS)*..20]-(CompareStandard)) - WITH p - WHERE length(p) > 1 AND ALL(n in NODES(p) WHERE - (n:NeoCRE OR n = BaseStandard OR n = CompareStandard) - AND NOT n.name IN $denylist) - RETURN p - """, - {"name1": name_1, "name2": name_2, "denylist": denylist}, - resolve_objects=True, - ) - return len(r1) + len(r2), 2 # paths, num_queries_run +def run_benchmark(name_1, name_2, db_type, runs=3): + """Run benchmark for a specific database type.""" + os.environ["GRAPH_DB_TYPE"] = db_type + # Force re-initialization of GraphDB if needed (Factory usually handles this) + collection = db.Node_collection() + graph_db = collection.graph_db + print(f"▶ Benchmarking {db_type.upper()} gap analysis: '{name_1}' ↔ '{name_2}'") -def run_optimized(name_1, name_2): - """Tiered pruning from PR #716/#717: exits early when strong/medium links found.""" - denylist = ["Cross-cutting concerns"] - - # Tier 1 — strong links only - r, _ = neomodel_db.cypher_query( - """ - MATCH (BaseStandard:NeoStandard {name: $name1}) - MATCH (CompareStandard:NeoStandard {name: $name2}) - MATCH p = allShortestPaths((BaseStandard)-[:(LINKED_TO|AUTOMATICALLY_LINKED_TO|SAME)*..20]-(CompareStandard)) - WITH p - WHERE length(p) > 1 AND ALL(n in NODES(p) WHERE - (n:NeoCRE OR n = BaseStandard OR n = CompareStandard) - AND NOT n.name IN $denylist) - RETURN p - """, - {"name1": name_1, "name2": name_2, "denylist": denylist}, - resolve_objects=True, - ) - if r: - return len(r), 1, "Tier 1 — strong links (LINKED_TO/SAME/AUTO)" - - # Tier 2 — adds CONTAINS - r, _ = neomodel_db.cypher_query( - """ - MATCH (BaseStandard:NeoStandard {name: $name1}) - MATCH (CompareStandard:NeoStandard {name: $name2}) - MATCH p = allShortestPaths((BaseStandard)-[:(LINKED_TO|AUTOMATICALLY_LINKED_TO|SAME|CONTAINS)*..20]-(CompareStandard)) - WITH p - WHERE length(p) > 1 AND ALL(n in NODES(p) WHERE - (n:NeoCRE OR n = BaseStandard OR n = CompareStandard) - AND NOT n.name IN $denylist) - RETURN p - """, - {"name1": name_1, "name2": name_2, "denylist": denylist}, - resolve_objects=True, - ) - if r: - return len(r), 2, "Tier 2 — medium links (adds CONTAINS)" - - # Tier 3 — wildcard fallback - r, _ = neomodel_db.cypher_query( - """ - MATCH (BaseStandard:NeoStandard {name: $name1}) - MATCH (CompareStandard:NeoStandard {name: $name2}) - MATCH p = allShortestPaths((BaseStandard)-[*..20]-(CompareStandard)) - WITH p - WHERE length(p) > 1 AND ALL(n in NODES(p) WHERE - (n:NeoCRE OR n = BaseStandard OR n = CompareStandard) - AND NOT n.name IN $denylist) - RETURN p - """, - {"name1": name_1, "name2": name_2, "denylist": denylist}, - resolve_objects=True, - ) - return len(r), 3, "Tier 3 — wildcard fallback (no strong/medium paths found)" - - -def benchmark(name_1, name_2, runs=3): - connect_neo4j() - print(f"Benchmarking gap analysis: '{name_1}' ↔ '{name_2}'") - print(f"Averaging over {runs} run(s) per mode\n") - print("=" * 68) - - # MODE A — Original - a_times, a_mems, a_paths, a_queries = [], [], 0, 0 - print("▶ MODE A — Original exhaustive (pre-PR #716 behaviour)...") + times, mems, paths = [], [], 0 for i in range(runs): tracemalloc.start() t0 = time.perf_counter() - a_paths, a_queries = run_original(name_1, name_2) - elapsed = time.perf_counter() - t0 - _, peak = tracemalloc.get_traced_memory() - tracemalloc.stop() - a_times.append(elapsed) - a_mems.append(peak / 1024 / 1024) - print(f" Run {i+1}: {elapsed:.3f}s | peak mem {a_mems[-1]:.2f} MB") - - avg_a_t = sum(a_times) / runs - avg_a_m = sum(a_mems) / runs - print() + # db.gap_analysis returns (node_names, grouped_paths, extra_paths_dict) + _, res_paths, _ = db.gap_analysis(graph_db, [name_1, name_2]) - # MODE B — Optimized - b_times, b_mems, b_paths, b_queries, b_tier = [], [], 0, 0, "" - print("▶ MODE B — Optimized tiered pruning (GAP_ANALYSIS_OPTIMIZED=true)...") - for i in range(runs): - tracemalloc.start() - t0 = time.perf_counter() - b_paths, b_queries, b_tier = run_optimized(name_1, name_2) elapsed = time.perf_counter() - t0 _, peak = tracemalloc.get_traced_memory() tracemalloc.stop() - b_times.append(elapsed) - b_mems.append(peak / 1024 / 1024) - print( - f" Run {i+1}: {elapsed:.3f}s | peak mem {b_mems[-1]:.2f} MB | queries run: {b_queries}" + + times.append(elapsed) + mems.append(peak / 1024 / 1024) + # res_paths is a dict of paths, we count all paths in all groups + total_p = sum( + len(group["paths"]) + group["extra"] for group in res_paths.values() ) + paths = total_p + print(f" Run {i+1}: {elapsed:.3f}s | peak mem {mems[-1]:.2f} MB") - avg_b_t = sum(b_times) / runs - avg_b_m = sum(b_mems) / runs + avg_t = sum(times) / runs + avg_m = sum(mems) / runs + return avg_t, avg_m, paths - t_pct = ((avg_a_t - avg_b_t) / avg_a_t * 100) if avg_a_t > 0 else 0 - m_pct = ((avg_a_m - avg_b_m) / avg_a_m * 100) if avg_a_m > 0 else 0 - direction = "faster" if t_pct >= 0 else "slower" +def benchmark(name_1, name_2, runs=3): + print(f"Comparative Benchmark: '{name_1}' ↔ '{name_2}'") + print(f"Averaging over {runs} run(s) per backend\n") + print("=" * 68) + + # Benchmark Neo4j + neo_t, neo_m, neo_p = run_benchmark(name_1, name_2, "neo4j", runs) + print() + + # Benchmark AGE + age_t, age_m, age_p = run_benchmark(name_1, name_2, "age", runs) print() + + t_pct = ((neo_t - age_t) / neo_t * 100) if neo_t > 0 else 0 + direction = "faster" if t_pct >= 0 else "slower" + print("=" * 68) print("RESULTS") print("=" * 68) - print(f" Pair: '{name_1}' ↔ '{name_2}' | {runs} run(s)\n") - print(f" {'Metric':<26} {'MODE A (original)':>18} {'MODE B (optimized)':>18}") - print(f" {'-'*26} {'-'*18} {'-'*18}") - print(f" {'Avg time (s)':<26} {avg_a_t:>18.3f} {avg_b_t:>18.3f}") - print(f" {'Avg peak memory (MB)':<26} {avg_a_m:>18.2f} {avg_b_m:>18.2f}") - print(f" {'Total paths returned':<26} {a_paths:>18} {b_paths:>18}") - print(f" {'DB queries executed':<26} {a_queries:>18} {b_queries:>18}") + print(f" Metric Neo4j Apache AGE") + print(f" {'-'*60}") + print(f" Avg time (s) {neo_t:>10.3f} {age_t:>10.3f}") + print(f" Avg peak memory (MB) {neo_m:>10.2f} {age_m:>10.2f}") + print(f" Total paths {neo_p:>10} {age_p:>10}") print() - print(f" ⚡ Mode B is {abs(t_pct):.1f}% {direction} than Mode A") - print( - f" 🧠 Mode B used {abs(m_pct):.1f}% {'less' if m_pct >= 0 else 'more'} peak memory" - ) - print(f" 🔍 Mode B exited at: {b_tier}") + print(f" ⚡ Apache AGE is {abs(t_pct):.1f}% {direction} than Neo4j") print("=" * 68) - # GitHub-ready table - print() - print("### GitHub-ready Benchmark Table\n") - print( - "| Metric | Original (`GAP_ANALYSIS_OPTIMIZED=false`) | Optimized (`GAP_ANALYSIS_OPTIMIZED=true`) | Δ |" - ) - print( - "|--------|------------------------------------------|------------------------------------------|---|" - ) - print( - f"| Avg query time | `{avg_a_t:.3f}s` | `{avg_b_t:.3f}s` | **{abs(t_pct):.1f}% {direction}** |" - ) - print( - f"| Peak memory | `{avg_a_m:.2f} MB` | `{avg_b_m:.2f} MB` | **{abs(m_pct):.1f}% {'less' if m_pct >= 0 else 'more'}** |" - ) - print(f"| Paths returned | `{a_paths}` | `{b_paths}` | — |") - print( - f"| DB queries run | `{a_queries}` (always both) | `{b_queries}` (early exit at {b_tier.split('—')[0].strip()}) | — |" - ) - if __name__ == "__main__": - p = argparse.ArgumentParser(description="Gap analysis benchmark — Issue #587") + p = argparse.ArgumentParser(description="Multi-backend Gap analysis benchmark") p.add_argument("--standard1", default="OWASP Top 10 2021") p.add_argument("--standard2", default="NIST 800-53") p.add_argument("--runs", type=int, default=3) p.add_argument("--list-standards", action="store_true") + p.add_argument("--db", choices=["neo4j", "age", "both"], default="both") args = p.parse_args() - if args.list_standards: - list_available_standards() - else: - benchmark(args.standard1, args.standard2, args.runs) + from application import create_app + + app = create_app(mode=os.getenv("FLASK_CONFIG", "development")) + + with app.app_context(): + if args.list_standards: + list_available_standards() + elif args.db == "both": + benchmark(args.standard1, args.standard2, args.runs) + else: + run_benchmark(args.standard1, args.standard2, args.db, args.runs)