Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ def pytest_configure(config):
from zou.app import app
from zou.app.utils import dbhelpers

# Register the admin blueprint so it can be tested.
from zou.app.blueprints.admin import blueprint as admin_blueprint

if "admin" not in app.blueprints:
app.register_blueprint(admin_blueprint)

with app.app_context():
from zou.app import db

Expand Down
Empty file added tests/admin/__init__.py
Empty file.
118 changes: 118 additions & 0 deletions tests/admin/test_config_check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import orjson as json

from tests.base import ApiDBTestCase
from zou.app import config
from zou.app.stores import config_store

TEST_TOKEN = "test-admin-token"


class ConfigCheckTestCase(ApiDBTestCase):
def setUp(self):
super().setUp()
self._original_token = config.ADMIN_TOKEN
config.ADMIN_TOKEN = TEST_TOKEN

def tearDown(self):
config.ADMIN_TOKEN = self._original_token
if config_store.config_store is not None:
for key in [
config_store.USER_LIMIT_KEY,
config_store.DEFAULT_TIMEZONE_KEY,
config_store.DEFAULT_LOCALE_KEY,
config_store.NOMAD_HOST_KEY,
config_store.NOMAD_NORMALIZE_JOB_KEY,
config_store.NOMAD_PLAYLIST_JOB_KEY,
]:
config_store.config_store.delete(key)
super().tearDown()

def test_403_without_token(self):
response = self.app.get("admin/config/check")
self.assertEqual(response.status_code, 403)

def test_403_with_wrong_token(self):
response = self.app.get(
"admin/config/check",
headers={"Authorization": "Bearer wrong-token"},
)
self.assertEqual(response.status_code, 403)

def test_403_with_malformed_header(self):
response = self.app.get(
"admin/config/check",
headers={"Authorization": TEST_TOKEN},
)
self.assertEqual(response.status_code, 403)

def test_403_when_admin_token_not_set(self):
config.ADMIN_TOKEN = ""
response = self.app.get(
"admin/config/check",
headers={"Authorization": f"Bearer {TEST_TOKEN}"},
)
self.assertEqual(response.status_code, 403)

def test_200_with_correct_token(self):
response = self.app.get(
"admin/config/check",
headers={"Authorization": f"Bearer {TEST_TOKEN}"},
)
self.assertEqual(response.status_code, 200)
data = json.loads(response.data)

for key in [
"user_limit",
"default_timezone",
"default_locale",
"nomad_host",
"nomad_normalize_job",
"nomad_playlist_job",
]:
self.assertIn(key, data)
self.assertIn("env", data[key])
self.assertIn("redis", data[key])

self.assertIn("active_users", data)

def test_env_values_match_config(self):
response = self.app.get(
"admin/config/check",
headers={"Authorization": f"Bearer {TEST_TOKEN}"},
)
data = json.loads(response.data)

self.assertEqual(data["user_limit"]["env"], config.USER_LIMIT)
self.assertEqual(
data["default_timezone"]["env"], config.DEFAULT_TIMEZONE
)
self.assertEqual(data["default_locale"]["env"], config.DEFAULT_LOCALE)
self.assertEqual(
data["nomad_host"]["env"], config.JOB_QUEUE_NOMAD_HOST
)

def test_redis_values_after_sync(self):
config_store.sync_config()
response = self.app.get(
"admin/config/check",
headers={"Authorization": f"Bearer {TEST_TOKEN}"},
)
data = json.loads(response.data)

self.assertEqual(int(data["user_limit"]["redis"]), config.USER_LIMIT)
self.assertEqual(
data["default_timezone"]["redis"], config.DEFAULT_TIMEZONE
)
self.assertEqual(
data["default_locale"]["redis"], config.DEFAULT_LOCALE
)

def test_active_users_count(self):
response = self.app.get(
"admin/config/check",
headers={"Authorization": f"Bearer {TEST_TOKEN}"},
)
data = json.loads(response.data)

self.assertIsInstance(data["active_users"], int)
self.assertGreaterEqual(data["active_users"], 1)
143 changes: 143 additions & 0 deletions tests/admin/test_get_current_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
import copy
import os
import unittest
from unittest.mock import patch, MagicMock

import requests
from click.testing import CliRunner

# Ensure ADMIN_TOKEN is set before zou.cli is imported so the command
# is registered.
os.environ.setdefault("ADMIN_TOKEN", "test-admin-token")

from zou.cli import cli # noqa: E402

TEST_TOKEN = "test-admin-token"

# Env vars as seen by the CLI process during tests.
CLI_ENV = {
"ADMIN_TOKEN": TEST_TOKEN,
"USER_LIMIT": "100",
"DEFAULT_TIMEZONE": "Europe/Paris",
"DEFAULT_LOCALE": "en_US",
"JOB_QUEUE_NOMAD_HOST": "zou-nomad-01.zou",
"JOB_QUEUE_NOMAD_NORMALIZE_JOB": "",
"JOB_QUEUE_NOMAD_PLAYLIST_JOB": "zou-playlist",
}

# API response when everything is in sync with CLI_ENV.
API_RESPONSE_SYNCED = {
"user_limit": {"env": 100, "redis": "100"},
"default_timezone": {"env": "Europe/Paris", "redis": "Europe/Paris"},
"default_locale": {"env": "en_US", "redis": "en_US"},
"nomad_host": {"env": "zou-nomad-01.zou", "redis": "zou-nomad-01.zou"},
"nomad_normalize_job": {"env": "", "redis": ""},
"nomad_playlist_job": {"env": "zou-playlist", "redis": "zou-playlist"},
"active_users": 42,
}

# API response where redis has a stale value for user_limit.
API_RESPONSE_DESYNC = {
"user_limit": {"env": 200, "redis": "200"},
"default_timezone": {"env": "Europe/Paris", "redis": "Europe/Paris"},
"default_locale": {"env": "en_US", "redis": "en_US"},
"nomad_host": {"env": "zou-nomad-01.zou", "redis": "zou-nomad-01.zou"},
"nomad_normalize_job": {"env": "", "redis": ""},
"nomad_playlist_job": {"env": "zou-playlist", "redis": "zou-playlist"},
"active_users": 10,
}


def _mock_response(json_data, status_code=200):
resp = MagicMock()
resp.status_code = status_code
resp.json.return_value = copy.deepcopy(json_data)
resp.text = str(json_data)
if status_code >= 400:
resp.raise_for_status.side_effect = requests.HTTPError(response=resp)
else:
resp.raise_for_status.return_value = None
return resp


class GetCurrentConfigCommandTestCase(unittest.TestCase):
def setUp(self):
self.runner = CliRunner()

@patch.dict(os.environ, CLI_ENV, clear=False)
@patch("requests.get")
def test_exit_0_when_synced(self, mock_get):
mock_get.return_value = _mock_response(API_RESPONSE_SYNCED)

result = self.runner.invoke(cli, ["get-current-config"], color=True)
self.assertEqual(result.exit_code, 0)
self.assertIn("Config check", result.output)
self.assertIn("user_limit", result.output)
self.assertIn("Active users: 42", result.output)
self.assertNotIn("✗", result.output)

@patch.dict(os.environ, CLI_ENV, clear=False)
@patch("requests.get")
def test_exit_1_when_desync(self, mock_get):
"""env-cli has USER_LIMIT=100 but redis has 200 → mismatch."""
mock_get.return_value = _mock_response(API_RESPONSE_DESYNC)

result = self.runner.invoke(cli, ["get-current-config"], color=True)
self.assertNotEqual(result.exit_code, 0)
self.assertIn("✗", result.output)
self.assertIn("out of sync", result.output)

@patch.dict(os.environ, CLI_ENV, clear=False)
@patch("requests.get")
def test_shows_three_columns(self, mock_get):
mock_get.return_value = _mock_response(API_RESPONSE_SYNCED)

result = self.runner.invoke(cli, ["get-current-config"], color=True)
self.assertIn("Env (CLI)", result.output)
self.assertIn("Redis", result.output)
self.assertIn("Env (API)", result.output)

@patch.dict(os.environ, CLI_ENV, clear=False)
@patch("requests.get")
def test_null_redis_shows_empty_symbol(self, mock_get):
data = copy.deepcopy(API_RESPONSE_SYNCED)
data["nomad_normalize_job"]["redis"] = None
mock_get.return_value = _mock_response(data)

result = self.runner.invoke(cli, ["get-current-config"], color=True)
self.assertIn("∅", result.output)
# env-cli is "" but redis is "∅" → mismatch
self.assertNotEqual(result.exit_code, 0)

@patch.dict(os.environ, CLI_ENV, clear=False)
@patch("requests.get")
def test_connection_error(self, mock_get):
mock_get.side_effect = requests.ConnectionError()

result = self.runner.invoke(cli, ["get-current-config"])
self.assertNotEqual(result.exit_code, 0)
self.assertIn("Cannot connect", result.output)

@patch.dict(os.environ, CLI_ENV, clear=False)
@patch("requests.get")
def test_http_error(self, mock_get):
mock_get.return_value = _mock_response({}, status_code=403)

result = self.runner.invoke(cli, ["get-current-config"])
self.assertNotEqual(result.exit_code, 0)
self.assertIn("Error 403", result.output)

@patch.dict(os.environ, CLI_ENV, clear=False)
@patch("requests.get")
def test_custom_host_option(self, mock_get):
mock_get.return_value = _mock_response(API_RESPONSE_SYNCED)

result = self.runner.invoke(
cli,
["get-current-config", "--host", "http://myhost:8080"],
)
self.assertEqual(result.exit_code, 0)
mock_get.assert_called_once_with(
"http://myhost:8080/admin/config/check",
headers={"Authorization": f"Bearer {TEST_TOKEN}"},
)
7 changes: 7 additions & 0 deletions zou/app/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from zou.app.blueprints.edits import blueprint as edits_blueprint
from zou.app.blueprints.concepts import blueprint as concepts_blueprint

from zou.app import config
from zou.app.utils.plugins import load_plugins
from zou.app.utils import events

Expand Down Expand Up @@ -70,6 +71,12 @@ def configure_api_routes(app):
app.register_blueprint(edits_blueprint)
app.register_blueprint(search_blueprint)
app.register_blueprint(concepts_blueprint)

if config.ADMIN_TOKEN:
from zou.app.blueprints.admin import blueprint as admin_blueprint

app.register_blueprint(admin_blueprint)

return app


Expand Down
13 changes: 13 additions & 0 deletions zou/app/blueprints/admin/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from flask import Blueprint
from zou.app.utils.api import configure_api_from_blueprint

from zou.app.blueprints.admin.resources import (
ConfigCheckResource,
)

routes = [
("/admin/config/check", ConfigCheckResource),
]

blueprint = Blueprint("admin", "admin")
api = configure_api_from_blueprint(blueprint, routes)
20 changes: 20 additions & 0 deletions zou/app/blueprints/admin/resources.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from flask import abort, request
from flask_restful import Resource

from zou.app import config
from zou.app.models.person import Person
from zou.app.stores import config_store


class ConfigCheckResource(Resource):

def get(self):
token = request.headers.get("Authorization", "")
if not token.startswith("Bearer ") or token[7:] != config.ADMIN_TOKEN:
abort(403)

comparison = config_store.get_config_comparison()
comparison["active_users"] = Person.query.filter(
Person.active, Person.is_bot.isnot(True)
).count()
return comparison
1 change: 1 addition & 0 deletions zou/app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@
DEFAULT_LOCALE = os.getenv("DEFAULT_LOCALE", "en_US")

USER_LIMIT = int(os.getenv("USER_LIMIT", "100"))
ADMIN_TOKEN = os.getenv("ADMIN_TOKEN", "")
MIN_PASSWORD_LENGTH = int(os.getenv("MIN_PASSWORD_LENGTH", 8))
PROTECTED_ACCOUNTS = env_with_semicolon_to_list("PROTECTED_ACCOUNTS")
ENFORCE_2FA = envtobool("ENFORCE_2FA", False)
Expand Down
38 changes: 38 additions & 0 deletions zou/app/stores/config_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,15 @@
print("Cannot access to the required Redis instance")


def _get_redis_raw(key):
if config_store is not None:
try:
return config_store.get(key)
except redis.ConnectionError:
pass
return None


def _get(key, fallback):
if config_store is not None:
try:
Expand Down Expand Up @@ -72,6 +81,35 @@ def get_nomad_playlist_job():
return _get(NOMAD_PLAYLIST_JOB_KEY, config.JOB_QUEUE_NOMAD_PLAYLIST_JOB)


def get_config_comparison():
return {
"user_limit": {
"env": config.USER_LIMIT,
"redis": _get_redis_raw(USER_LIMIT_KEY),
},
"default_timezone": {
"env": config.DEFAULT_TIMEZONE,
"redis": _get_redis_raw(DEFAULT_TIMEZONE_KEY),
},
"default_locale": {
"env": config.DEFAULT_LOCALE,
"redis": _get_redis_raw(DEFAULT_LOCALE_KEY),
},
"nomad_host": {
"env": config.JOB_QUEUE_NOMAD_HOST,
"redis": _get_redis_raw(NOMAD_HOST_KEY),
},
"nomad_normalize_job": {
"env": config.JOB_QUEUE_NOMAD_NORMALIZE_JOB,
"redis": _get_redis_raw(NOMAD_NORMALIZE_JOB_KEY),
},
"nomad_playlist_job": {
"env": config.JOB_QUEUE_NOMAD_PLAYLIST_JOB,
"redis": _get_redis_raw(NOMAD_PLAYLIST_JOB_KEY),
},
}


def sync_config():
"""
Read config values from environment variables and push them to
Expand Down
Loading
Loading